mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-02-01 22:48:03 +00:00
Extensions MCP refactor (#12413)
This commit is contained in:
@@ -113,7 +113,7 @@ export class Task {
|
||||
// state managed within the @gemini-cli/core module.
|
||||
async getMetadata(): Promise<TaskMetadata> {
|
||||
const toolRegistry = await this.config.getToolRegistry();
|
||||
const mcpServers = this.config.getMcpServers() || {};
|
||||
const mcpServers = this.config.getMcpClientManager()?.getMcpServers() || {};
|
||||
const serverStatuses = getAllMCPServerStatuses();
|
||||
const servers = Object.keys(mcpServers).map((serverName) => ({
|
||||
name: serverName,
|
||||
|
||||
@@ -20,9 +20,7 @@ import {
|
||||
GEMINI_DIR,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
type GeminiCLIExtension,
|
||||
type ExtensionLoader,
|
||||
debugLogger,
|
||||
} from '@google/gemini-cli-core';
|
||||
|
||||
import { logger } from '../utils/logger.js';
|
||||
@@ -34,7 +32,6 @@ export async function loadConfig(
|
||||
extensionLoader: ExtensionLoader,
|
||||
taskId: string,
|
||||
): Promise<Config> {
|
||||
const mcpServers = mergeMcpServers(settings, extensionLoader.getExtensions());
|
||||
const workspaceDir = process.cwd();
|
||||
const adcFilePath = process.env['GOOGLE_APPLICATION_CREDENTIALS'];
|
||||
|
||||
@@ -54,7 +51,7 @@ export async function loadConfig(
|
||||
process.env['GEMINI_YOLO_MODE'] === 'true'
|
||||
? ApprovalMode.YOLO
|
||||
: ApprovalMode.DEFAULT,
|
||||
mcpServers,
|
||||
mcpServers: settings.mcpServers,
|
||||
cwd: workspaceDir,
|
||||
telemetry: {
|
||||
enabled: settings.telemetry?.enabled,
|
||||
@@ -120,25 +117,6 @@ export async function loadConfig(
|
||||
return config;
|
||||
}
|
||||
|
||||
export function mergeMcpServers(
|
||||
settings: Settings,
|
||||
extensions: GeminiCLIExtension[],
|
||||
) {
|
||||
const mcpServers = { ...(settings.mcpServers || {}) };
|
||||
for (const extension of extensions) {
|
||||
Object.entries(extension.mcpServers || {}).forEach(([key, server]) => {
|
||||
if (mcpServers[key]) {
|
||||
debugLogger.warn(
|
||||
`Skipping extension MCP config for server with key "${key}" as it already exists.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
mcpServers[key] = server;
|
||||
});
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
|
||||
export function setTargetDir(agentSettings: AgentSettings | undefined): string {
|
||||
const originalCWD = process.cwd();
|
||||
const targetDir =
|
||||
|
||||
@@ -21,27 +21,33 @@ vi.mock('../../config/extensions/storage.js', () => ({
|
||||
},
|
||||
}));
|
||||
vi.mock('../../config/extension-manager.js');
|
||||
vi.mock('@google/gemini-cli-core', () => ({
|
||||
createTransport: vi.fn(),
|
||||
MCPServerStatus: {
|
||||
CONNECTED: 'CONNECTED',
|
||||
CONNECTING: 'CONNECTING',
|
||||
DISCONNECTED: 'DISCONNECTED',
|
||||
},
|
||||
Storage: vi.fn().mockImplementation((_cwd: string) => ({
|
||||
getGlobalSettingsPath: () => '/tmp/gemini/settings.json',
|
||||
getWorkspaceSettingsPath: () => '/tmp/gemini/workspace-settings.json',
|
||||
getProjectTempDir: () => '/test/home/.gemini/tmp/mocked_hash',
|
||||
})),
|
||||
GEMINI_DIR: '.gemini',
|
||||
getErrorMessage: (e: unknown) => (e instanceof Error ? e.message : String(e)),
|
||||
debugLogger: {
|
||||
log: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
},
|
||||
}));
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const original =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...original,
|
||||
createTransport: vi.fn(),
|
||||
MCPServerStatus: {
|
||||
CONNECTED: 'CONNECTED',
|
||||
CONNECTING: 'CONNECTING',
|
||||
DISCONNECTED: 'DISCONNECTED',
|
||||
},
|
||||
Storage: vi.fn().mockImplementation((_cwd: string) => ({
|
||||
getGlobalSettingsPath: () => '/tmp/gemini/settings.json',
|
||||
getWorkspaceSettingsPath: () => '/tmp/gemini/workspace-settings.json',
|
||||
getProjectTempDir: () => '/test/home/.gemini/tmp/mocked_hash',
|
||||
})),
|
||||
GEMINI_DIR: '.gemini',
|
||||
getErrorMessage: (e: unknown) =>
|
||||
e instanceof Error ? e.message : String(e),
|
||||
debugLogger: {
|
||||
log: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||
|
||||
const mockedGetUserExtensionsDir =
|
||||
|
||||
@@ -1146,9 +1146,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
|
||||
];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig(baseSettings, 'test-session', argv);
|
||||
expect(config.getMcpServers()).toEqual({
|
||||
server1: { url: 'http://localhost:8080' },
|
||||
});
|
||||
expect(config.getAllowedMcpServers()).toEqual(['server1']);
|
||||
});
|
||||
|
||||
it('should allow multiple specified MCP servers', async () => {
|
||||
@@ -1162,10 +1160,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
|
||||
];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig(baseSettings, 'test-session', argv);
|
||||
expect(config.getMcpServers()).toEqual({
|
||||
server1: { url: 'http://localhost:8080' },
|
||||
server3: { url: 'http://localhost:8082' },
|
||||
});
|
||||
expect(config.getAllowedMcpServers()).toEqual(['server1', 'server3']);
|
||||
});
|
||||
|
||||
it('should handle server names that do not exist', async () => {
|
||||
@@ -1179,16 +1174,14 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
|
||||
];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig(baseSettings, 'test-session', argv);
|
||||
expect(config.getMcpServers()).toEqual({
|
||||
server1: { url: 'http://localhost:8080' },
|
||||
});
|
||||
expect(config.getAllowedMcpServers()).toEqual(['server1', 'server4']);
|
||||
});
|
||||
|
||||
it('should allow no MCP servers if the flag is provided but empty', async () => {
|
||||
process.argv = ['node', 'script.js', '--allowed-mcp-server-names', ''];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig(baseSettings, 'test-session', argv);
|
||||
expect(config.getMcpServers()).toEqual({});
|
||||
expect(config.getAllowedMcpServers()).toEqual(['']);
|
||||
});
|
||||
|
||||
it('should read allowMCPServers from settings', async () => {
|
||||
@@ -1199,10 +1192,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
|
||||
mcp: { allowed: ['server1', 'server2'] },
|
||||
};
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getMcpServers()).toEqual({
|
||||
server1: { url: 'http://localhost:8080' },
|
||||
server2: { url: 'http://localhost:8081' },
|
||||
});
|
||||
expect(config.getAllowedMcpServers()).toEqual(['server1', 'server2']);
|
||||
});
|
||||
|
||||
it('should read excludeMCPServers from settings', async () => {
|
||||
@@ -1213,9 +1203,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
|
||||
mcp: { excluded: ['server1', 'server2'] },
|
||||
};
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getMcpServers()).toEqual({
|
||||
server3: { url: 'http://localhost:8082' },
|
||||
});
|
||||
expect(config.getBlockedMcpServers()).toEqual(['server1', 'server2']);
|
||||
});
|
||||
|
||||
it('should override allowMCPServers with excludeMCPServers if overlapping', async () => {
|
||||
@@ -1229,9 +1217,8 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
|
||||
},
|
||||
};
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getMcpServers()).toEqual({
|
||||
server2: { url: 'http://localhost:8081' },
|
||||
});
|
||||
expect(config.getAllowedMcpServers()).toEqual(['server1', 'server2']);
|
||||
expect(config.getBlockedMcpServers()).toEqual(['server1']);
|
||||
});
|
||||
|
||||
it('should prioritize mcp server flag if set', async () => {
|
||||
@@ -1250,9 +1237,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
|
||||
},
|
||||
};
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getMcpServers()).toEqual({
|
||||
server1: { url: 'http://localhost:8080' },
|
||||
});
|
||||
expect(config.getAllowedMcpServers()).toEqual(['server1']);
|
||||
});
|
||||
|
||||
it('should prioritize CLI flag over both allowed and excluded settings', async () => {
|
||||
@@ -1273,10 +1258,8 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
|
||||
},
|
||||
};
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getMcpServers()).toEqual({
|
||||
server2: { url: 'http://localhost:8081' },
|
||||
server3: { url: 'http://localhost:8082' },
|
||||
});
|
||||
expect(config.getAllowedMcpServers()).toEqual(['server2', 'server3']);
|
||||
expect(config.getBlockedMcpServers()).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -13,9 +13,7 @@ import process from 'node:process';
|
||||
import { mcpCommand } from '../commands/mcp.js';
|
||||
import type {
|
||||
FileFilteringOptions,
|
||||
MCPServerConfig,
|
||||
OutputFormat,
|
||||
GeminiCLIExtension,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { extensionsCommand } from '../commands/extensions.js';
|
||||
import {
|
||||
@@ -49,9 +47,13 @@ import { appEvents } from '../utils/events.js';
|
||||
import { isWorkspaceTrusted } from './trustedFolders.js';
|
||||
import { createPolicyEngineConfig } from './policy.js';
|
||||
import { ExtensionManager } from './extension-manager.js';
|
||||
import type { ExtensionLoader } from '@google/gemini-cli-core/src/utils/extensionLoader.js';
|
||||
import type {
|
||||
ExtensionEvents,
|
||||
ExtensionLoader,
|
||||
} from '@google/gemini-cli-core/src/utils/extensionLoader.js';
|
||||
import { requestConsentNonInteractive } from './extensions/consent.js';
|
||||
import { promptForSetting } from './extensions/extensionSettings.js';
|
||||
import type { EventEmitter } from 'node:stream';
|
||||
|
||||
export interface CliArgs {
|
||||
query: string | undefined;
|
||||
@@ -429,6 +431,7 @@ export async function loadCliConfig(
|
||||
requestSetting: promptForSetting,
|
||||
workspaceDir: cwd,
|
||||
enabledExtensionOverrides: argv.extensions,
|
||||
eventEmitter: appEvents as EventEmitter<ExtensionEvents>,
|
||||
});
|
||||
await extensionManager.loadExtensions();
|
||||
|
||||
@@ -448,7 +451,6 @@ export async function loadCliConfig(
|
||||
memoryFileFiltering,
|
||||
);
|
||||
|
||||
let mcpServers = mergeMcpServers(settings, extensionManager.getExtensions());
|
||||
const question = argv.promptInteractive || argv.prompt || '';
|
||||
|
||||
// Determine approval mode with backward compatibility
|
||||
@@ -565,37 +567,8 @@ export async function loadCliConfig(
|
||||
|
||||
const excludeTools = mergeExcludeTools(
|
||||
settings,
|
||||
extensionManager.getExtensions(),
|
||||
extraExcludes.length > 0 ? extraExcludes : undefined,
|
||||
);
|
||||
const blockedMcpServers: Array<{ name: string; extensionName: string }> = [];
|
||||
|
||||
if (!argv.allowedMcpServerNames) {
|
||||
if (settings.mcp?.allowed) {
|
||||
mcpServers = allowedMcpServers(
|
||||
mcpServers,
|
||||
settings.mcp.allowed,
|
||||
blockedMcpServers,
|
||||
);
|
||||
}
|
||||
|
||||
if (settings.mcp?.excluded) {
|
||||
const excludedNames = new Set(settings.mcp.excluded.filter(Boolean));
|
||||
if (excludedNames.size > 0) {
|
||||
mcpServers = Object.fromEntries(
|
||||
Object.entries(mcpServers).filter(([key]) => !excludedNames.has(key)),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (argv.allowedMcpServerNames) {
|
||||
mcpServers = allowedMcpServers(
|
||||
mcpServers,
|
||||
argv.allowedMcpServerNames,
|
||||
blockedMcpServers,
|
||||
);
|
||||
}
|
||||
|
||||
const useModelRouter = settings.experimental?.useModelRouter ?? true;
|
||||
const defaultModel = useModelRouter
|
||||
@@ -633,7 +606,11 @@ export async function loadCliConfig(
|
||||
toolDiscoveryCommand: settings.tools?.discoveryCommand,
|
||||
toolCallCommand: settings.tools?.callCommand,
|
||||
mcpServerCommand: settings.mcp?.serverCommand,
|
||||
mcpServers,
|
||||
mcpServers: settings.mcpServers,
|
||||
allowedMcpServers: argv.allowedMcpServerNames ?? settings.mcp?.allowed,
|
||||
blockedMcpServers: argv.allowedMcpServerNames
|
||||
? [] // explicitly allowed servers overrides everything
|
||||
: settings.mcp?.excluded,
|
||||
userMemory: memoryContent,
|
||||
geminiMdFileCount: fileCount,
|
||||
geminiMdFilePaths: filePaths,
|
||||
@@ -663,7 +640,6 @@ export async function loadCliConfig(
|
||||
enabledExtensions: argv.extensions,
|
||||
extensionLoader: extensionManager,
|
||||
enableExtensionReloading: settings.experimental?.extensionReloading,
|
||||
blockedMcpServers,
|
||||
noBrowser: !!process.env['NO_BROWSER'],
|
||||
summarizeToolOutput: settings.model?.summarizeToolOutput,
|
||||
ideMode,
|
||||
@@ -699,75 +675,13 @@ export async function loadCliConfig(
|
||||
});
|
||||
}
|
||||
|
||||
function allowedMcpServers(
|
||||
mcpServers: { [x: string]: MCPServerConfig },
|
||||
allowMCPServers: string[],
|
||||
blockedMcpServers: Array<{ name: string; extensionName: string }>,
|
||||
) {
|
||||
const allowedNames = new Set(allowMCPServers.filter(Boolean));
|
||||
if (allowedNames.size > 0) {
|
||||
mcpServers = Object.fromEntries(
|
||||
Object.entries(mcpServers).filter(([key, server]) => {
|
||||
const isAllowed = allowedNames.has(key);
|
||||
if (!isAllowed) {
|
||||
blockedMcpServers.push({
|
||||
name: key,
|
||||
extensionName: server.extension?.name || '',
|
||||
});
|
||||
}
|
||||
return isAllowed;
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
blockedMcpServers.push(
|
||||
...Object.entries(mcpServers).map(([key, server]) => ({
|
||||
name: key,
|
||||
extensionName: server.extension?.name || '',
|
||||
})),
|
||||
);
|
||||
mcpServers = {};
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
|
||||
function mergeMcpServers(settings: Settings, extensions: GeminiCLIExtension[]) {
|
||||
const mcpServers = { ...(settings.mcpServers || {}) };
|
||||
for (const extension of extensions) {
|
||||
if (!extension.isActive) {
|
||||
continue;
|
||||
}
|
||||
Object.entries(extension.mcpServers || {}).forEach(([key, server]) => {
|
||||
if (mcpServers[key]) {
|
||||
debugLogger.warn(
|
||||
`Skipping extension MCP config for server with key "${key}" as it already exists.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
mcpServers[key] = {
|
||||
...server,
|
||||
extension,
|
||||
};
|
||||
});
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
|
||||
function mergeExcludeTools(
|
||||
settings: Settings,
|
||||
extensions: GeminiCLIExtension[],
|
||||
extraExcludes?: string[] | undefined,
|
||||
): string[] {
|
||||
const allExcludeTools = new Set([
|
||||
...(settings.tools?.exclude || []),
|
||||
...(extraExcludes || []),
|
||||
]);
|
||||
for (const extension of extensions) {
|
||||
if (!extension.isActive) {
|
||||
continue;
|
||||
}
|
||||
for (const tool of extension.excludeTools || []) {
|
||||
allExcludeTools.add(tool);
|
||||
}
|
||||
}
|
||||
return [...allExcludeTools];
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
ExtensionDisableEvent,
|
||||
ExtensionEnableEvent,
|
||||
ExtensionInstallEvent,
|
||||
ExtensionLoader,
|
||||
ExtensionUninstallEvent,
|
||||
ExtensionUpdateEvent,
|
||||
getErrorMessage,
|
||||
@@ -36,6 +37,7 @@ import {
|
||||
logExtensionInstallEvent,
|
||||
logExtensionUninstall,
|
||||
logExtensionUpdateEvent,
|
||||
type ExtensionEvents,
|
||||
type MCPServerConfig,
|
||||
type ExtensionInstallMetadata,
|
||||
type GeminiCLIExtension,
|
||||
@@ -54,11 +56,7 @@ import {
|
||||
maybePromptForSettings,
|
||||
type ExtensionSetting,
|
||||
} from './extensions/extensionSettings.js';
|
||||
import type {
|
||||
ExtensionEvents,
|
||||
ExtensionLoader,
|
||||
} from '@google/gemini-cli-core/src/utils/extensionLoader.js';
|
||||
import { EventEmitter } from 'node:events';
|
||||
import type { EventEmitter } from 'node:stream';
|
||||
|
||||
interface ExtensionManagerParams {
|
||||
enabledExtensionOverrides?: string[];
|
||||
@@ -66,6 +64,7 @@ interface ExtensionManagerParams {
|
||||
requestConsent: (consent: string) => Promise<boolean>;
|
||||
requestSetting: ((setting: ExtensionSetting) => Promise<string>) | null;
|
||||
workspaceDir: string;
|
||||
eventEmitter?: EventEmitter<ExtensionEvents>;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -73,7 +72,7 @@ interface ExtensionManagerParams {
|
||||
*
|
||||
* You must call `loadExtensions` prior to calling other methods on this class.
|
||||
*/
|
||||
export class ExtensionManager implements ExtensionLoader {
|
||||
export class ExtensionManager extends ExtensionLoader {
|
||||
private extensionEnablementManager: ExtensionEnablementManager;
|
||||
private settings: Settings;
|
||||
private requestConsent: (consent: string) => Promise<boolean>;
|
||||
@@ -83,9 +82,9 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
private telemetryConfig: Config;
|
||||
private workspaceDir: string;
|
||||
private loadedExtensions: GeminiCLIExtension[] | undefined;
|
||||
private eventEmitter: EventEmitter<ExtensionEvents>;
|
||||
|
||||
constructor(options: ExtensionManagerParams) {
|
||||
super(options.eventEmitter);
|
||||
this.workspaceDir = options.workspaceDir;
|
||||
this.extensionEnablementManager = new ExtensionEnablementManager(
|
||||
options.enabledExtensionOverrides,
|
||||
@@ -102,7 +101,6 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
});
|
||||
this.requestConsent = options.requestConsent;
|
||||
this.requestSetting = options.requestSetting ?? undefined;
|
||||
this.eventEmitter = new EventEmitter();
|
||||
}
|
||||
|
||||
setRequestConsent(
|
||||
@@ -126,10 +124,6 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
return this.loadedExtensions!;
|
||||
}
|
||||
|
||||
extensionEvents(): EventEmitter<ExtensionEvents> {
|
||||
return this.eventEmitter;
|
||||
}
|
||||
|
||||
async installOrUpdateExtension(
|
||||
installMetadata: ExtensionInstallMetadata,
|
||||
previousExtensionConfig?: ExtensionConfig,
|
||||
@@ -303,7 +297,7 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
await fs.promises.writeFile(metadataPath, metadataString);
|
||||
|
||||
// TODO: Gracefully handle this call failing, we should back up the old
|
||||
// extension prior to overwriting it and then restore it.
|
||||
// extension prior to overwriting it and then restore and restart it.
|
||||
extension = await this.loadExtension(destinationPath)!;
|
||||
if (!extension) {
|
||||
throw new Error(`Extension not found`);
|
||||
@@ -320,7 +314,6 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
'success',
|
||||
),
|
||||
);
|
||||
this.eventEmitter.emit('extensionUpdated', { extension });
|
||||
} else {
|
||||
logExtensionInstallEvent(
|
||||
this.telemetryConfig,
|
||||
@@ -332,7 +325,6 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
'success',
|
||||
),
|
||||
);
|
||||
this.eventEmitter.emit('extensionInstalled', { extension });
|
||||
this.enableExtension(newExtensionConfig.name, SettingScope.User);
|
||||
}
|
||||
} finally {
|
||||
@@ -397,7 +389,7 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
if (!extension) {
|
||||
throw new Error(`Extension not found.`);
|
||||
}
|
||||
this.unloadExtension(extension);
|
||||
await this.unloadExtension(extension);
|
||||
const storage = new ExtensionStorage(extension.name);
|
||||
|
||||
await fs.promises.rm(storage.getExtensionDir(), {
|
||||
@@ -419,9 +411,11 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
'success',
|
||||
),
|
||||
);
|
||||
this.eventEmitter.emit('extensionUninstalled', { extension });
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads all installed extensions, should only be called once.
|
||||
*/
|
||||
async loadExtensions(): Promise<GeminiCLIExtension[]> {
|
||||
if (this.loadedExtensions) {
|
||||
throw new Error('Extensions already loaded, only load extensions once.');
|
||||
@@ -433,12 +427,14 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
}
|
||||
for (const subdir of fs.readdirSync(extensionsDir)) {
|
||||
const extensionDir = path.join(extensionsDir, subdir);
|
||||
|
||||
await this.loadExtension(extensionDir);
|
||||
}
|
||||
return this.loadedExtensions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds `extension` to the list of extensions and starts it if appropriate.
|
||||
*/
|
||||
private async loadExtension(
|
||||
extensionDir: string,
|
||||
): Promise<GeminiCLIExtension | null> {
|
||||
@@ -499,8 +495,9 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
),
|
||||
id: getExtensionId(config, installMetadata),
|
||||
};
|
||||
this.eventEmitter.emit('extensionLoaded', { extension });
|
||||
this.getExtensions().push(extension);
|
||||
this.loadedExtensions = [...this.loadedExtensions, extension];
|
||||
|
||||
await this.maybeStartExtension(extension);
|
||||
return extension;
|
||||
} catch (e) {
|
||||
debugLogger.error(
|
||||
@@ -512,11 +509,17 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
}
|
||||
}
|
||||
|
||||
private unloadExtension(extension: GeminiCLIExtension) {
|
||||
/**
|
||||
* Removes `extension` from the list of extensions and stops it if
|
||||
* appropriate.
|
||||
*/
|
||||
private unloadExtension(
|
||||
extension: GeminiCLIExtension,
|
||||
): Promise<void> | undefined {
|
||||
this.loadedExtensions = this.getExtensions().filter(
|
||||
(entry) => extension !== entry,
|
||||
);
|
||||
this.eventEmitter.emit('extensionUnloaded', { extension });
|
||||
return this.maybeStopExtension(extension);
|
||||
}
|
||||
|
||||
loadExtensionConfig(extensionDir: string): ExtensionConfig {
|
||||
@@ -616,14 +619,18 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
const scopePath =
|
||||
scope === SettingScope.Workspace ? this.workspaceDir : os.homedir();
|
||||
this.extensionEnablementManager.disable(name, true, scopePath);
|
||||
extension.isActive = false;
|
||||
await this.maybeStopExtension(extension);
|
||||
logExtensionDisable(
|
||||
this.telemetryConfig,
|
||||
new ExtensionDisableEvent(hashValue(name), extension.id, scope),
|
||||
);
|
||||
extension.isActive = false;
|
||||
this.eventEmitter.emit('extensionDisabled', { extension });
|
||||
}
|
||||
|
||||
/**
|
||||
* Enables an existing extension for a given scope, and starts it if
|
||||
* appropriate.
|
||||
*/
|
||||
async enableExtension(name: string, scope: SettingScope) {
|
||||
if (
|
||||
scope === SettingScope.System ||
|
||||
@@ -645,7 +652,7 @@ export class ExtensionManager implements ExtensionLoader {
|
||||
new ExtensionEnableEvent(hashValue(name), extension.id, scope),
|
||||
);
|
||||
extension.isActive = true;
|
||||
this.eventEmitter.emit('extensionEnabled', { extension });
|
||||
await this.maybeStartExtension(extension);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1666,7 +1666,10 @@ This extension will run the following MCP servers:
|
||||
});
|
||||
|
||||
await extensionManager.loadExtensions();
|
||||
extensionManager.disableExtension('my-extension', SettingScope.User);
|
||||
await extensionManager.disableExtension(
|
||||
'my-extension',
|
||||
SettingScope.User,
|
||||
);
|
||||
expect(
|
||||
isEnabled({
|
||||
name: 'my-extension',
|
||||
@@ -1683,7 +1686,10 @@ This extension will run the following MCP servers:
|
||||
});
|
||||
|
||||
await extensionManager.loadExtensions();
|
||||
extensionManager.disableExtension('my-extension', SettingScope.Workspace);
|
||||
await extensionManager.disableExtension(
|
||||
'my-extension',
|
||||
SettingScope.Workspace,
|
||||
);
|
||||
expect(
|
||||
isEnabled({
|
||||
name: 'my-extension',
|
||||
@@ -1706,8 +1712,14 @@ This extension will run the following MCP servers:
|
||||
});
|
||||
|
||||
await extensionManager.loadExtensions();
|
||||
extensionManager.disableExtension('my-extension', SettingScope.User);
|
||||
extensionManager.disableExtension('my-extension', SettingScope.User);
|
||||
await extensionManager.disableExtension(
|
||||
'my-extension',
|
||||
SettingScope.User,
|
||||
);
|
||||
await extensionManager.disableExtension(
|
||||
'my-extension',
|
||||
SettingScope.User,
|
||||
);
|
||||
expect(
|
||||
isEnabled({
|
||||
name: 'my-extension',
|
||||
@@ -1738,7 +1750,7 @@ This extension will run the following MCP servers:
|
||||
});
|
||||
|
||||
await extensionManager.loadExtensions();
|
||||
extensionManager.disableExtension('ext1', SettingScope.Workspace);
|
||||
await extensionManager.disableExtension('ext1', SettingScope.Workspace);
|
||||
|
||||
expect(mockLogExtensionDisable).toHaveBeenCalled();
|
||||
expect(ExtensionDisableEvent).toHaveBeenCalledWith(
|
||||
@@ -1766,7 +1778,7 @@ This extension will run the following MCP servers:
|
||||
version: '1.0.0',
|
||||
});
|
||||
await extensionManager.loadExtensions();
|
||||
extensionManager.disableExtension('ext1', SettingScope.User);
|
||||
await extensionManager.disableExtension('ext1', SettingScope.User);
|
||||
let activeExtensions = getActiveExtensions();
|
||||
expect(activeExtensions).toHaveLength(0);
|
||||
|
||||
@@ -1783,7 +1795,7 @@ This extension will run the following MCP servers:
|
||||
version: '1.0.0',
|
||||
});
|
||||
await extensionManager.loadExtensions();
|
||||
extensionManager.disableExtension('ext1', SettingScope.Workspace);
|
||||
await extensionManager.disableExtension('ext1', SettingScope.Workspace);
|
||||
let activeExtensions = getActiveExtensions();
|
||||
expect(activeExtensions).toHaveLength(0);
|
||||
|
||||
@@ -1804,8 +1816,8 @@ This extension will run the following MCP servers:
|
||||
},
|
||||
});
|
||||
await extensionManager.loadExtensions();
|
||||
extensionManager.disableExtension('ext1', SettingScope.Workspace);
|
||||
extensionManager.enableExtension('ext1', SettingScope.Workspace);
|
||||
await extensionManager.disableExtension('ext1', SettingScope.Workspace);
|
||||
await extensionManager.enableExtension('ext1', SettingScope.Workspace);
|
||||
|
||||
expect(mockLogExtensionEnable).toHaveBeenCalled();
|
||||
expect(ExtensionEnableEvent).toHaveBeenCalledWith(
|
||||
|
||||
@@ -190,6 +190,7 @@ describe('gemini.tsx main function', () => {
|
||||
getDebugMode: () => false,
|
||||
getListExtensions: () => false,
|
||||
getMcpServers: () => ({}),
|
||||
getMcpClientManager: vi.fn(),
|
||||
initialize: vi.fn(),
|
||||
getIdeMode: () => false,
|
||||
getExperimentalZedIntegration: () => false,
|
||||
@@ -339,6 +340,7 @@ describe('gemini.tsx main function kitty protocol', () => {
|
||||
getDebugMode: () => false,
|
||||
getListExtensions: () => false,
|
||||
getMcpServers: () => ({}),
|
||||
getMcpClientManager: vi.fn(),
|
||||
initialize: vi.fn(),
|
||||
getIdeMode: () => false,
|
||||
getExperimentalZedIntegration: () => false,
|
||||
|
||||
@@ -159,8 +159,10 @@ describe('McpPromptLoader', () => {
|
||||
|
||||
describe('loadCommands', () => {
|
||||
const mockConfigWithPrompts = {
|
||||
getMcpServers: () => ({
|
||||
'test-server': { httpUrl: 'https://test-server.com' },
|
||||
getMcpClientManager: () => ({
|
||||
getMcpServers: () => ({
|
||||
'test-server': { httpUrl: 'https://test-server.com' },
|
||||
}),
|
||||
}),
|
||||
} as unknown as Config;
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ export class McpPromptLoader implements ICommandLoader {
|
||||
if (!this.config) {
|
||||
return Promise.resolve([]);
|
||||
}
|
||||
const mcpServers = this.config.getMcpServers() || {};
|
||||
const mcpServers = this.config.getMcpClientManager()?.getMcpServers() || {};
|
||||
for (const serverName in mcpServers) {
|
||||
const prompts = getMCPServerPrompts(this.config, serverName) || [];
|
||||
for (const prompt of prompts) {
|
||||
@@ -101,7 +101,8 @@ export class McpPromptLoader implements ICommandLoader {
|
||||
}
|
||||
|
||||
try {
|
||||
const mcpServers = this.config.getMcpServers() || {};
|
||||
const mcpServers =
|
||||
this.config.getMcpClientManager()?.getMcpServers() || {};
|
||||
const mcpServerConfig = mcpServers[serverName];
|
||||
if (!mcpServerConfig) {
|
||||
return {
|
||||
|
||||
@@ -295,6 +295,7 @@ describe('AppContainer State Management', () => {
|
||||
getExtensions: vi.fn().mockReturnValue([]),
|
||||
setRequestConsent: vi.fn(),
|
||||
setRequestSetting: vi.fn(),
|
||||
start: vi.fn(),
|
||||
} as unknown as ExtensionManager);
|
||||
vi.spyOn(mockConfig, 'getExtensionLoader').mockReturnValue(
|
||||
mockExtensionManager,
|
||||
|
||||
@@ -62,6 +62,7 @@ describe('mcpCommand', () => {
|
||||
getBlockedMcpServers: ReturnType<typeof vi.fn>;
|
||||
getPromptRegistry: ReturnType<typeof vi.fn>;
|
||||
getGeminiClient: ReturnType<typeof vi.fn>;
|
||||
getMcpClientManager: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -88,6 +89,10 @@ describe('mcpCommand', () => {
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
}),
|
||||
getGeminiClient: vi.fn(),
|
||||
getMcpClientManager: vi.fn().mockImplementation(() => ({
|
||||
getBlockedMcpServers: vi.fn(),
|
||||
getMcpServers: vi.fn(),
|
||||
})),
|
||||
};
|
||||
|
||||
mockContext = createMockCommandContext({
|
||||
|
||||
@@ -43,7 +43,7 @@ const authCommand: SlashCommand = {
|
||||
};
|
||||
}
|
||||
|
||||
const mcpServers = config.getMcpServers() || {};
|
||||
const mcpServers = config.getMcpClientManager()?.getMcpServers() ?? {};
|
||||
|
||||
if (!serverName) {
|
||||
// List servers that support OAuth
|
||||
@@ -119,20 +119,20 @@ const authCommand: SlashCommand = {
|
||||
);
|
||||
|
||||
// Trigger tool re-discovery to pick up authenticated server
|
||||
const toolRegistry = config.getToolRegistry();
|
||||
if (toolRegistry) {
|
||||
const mcpClientManager = config.getMcpClientManager();
|
||||
if (mcpClientManager) {
|
||||
context.ui.addItem(
|
||||
{
|
||||
type: 'info',
|
||||
text: `Re-discovering tools from '${serverName}'...`,
|
||||
text: `Restarting MCP server '${serverName}'...`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
await toolRegistry.discoverToolsForServer(serverName);
|
||||
await mcpClientManager.restartServer(serverName);
|
||||
}
|
||||
// Update the client with the new tools
|
||||
const geminiClient = config.getGeminiClient();
|
||||
if (geminiClient) {
|
||||
if (geminiClient?.isInitialized()) {
|
||||
await geminiClient.setTools();
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ const authCommand: SlashCommand = {
|
||||
const { config } = context.services;
|
||||
if (!config) return [];
|
||||
|
||||
const mcpServers = config.getMcpServers() || {};
|
||||
const mcpServers = config.getMcpClientManager()?.getMcpServers() || {};
|
||||
return Object.keys(mcpServers).filter((name) =>
|
||||
name.startsWith(partialArg),
|
||||
);
|
||||
@@ -188,9 +188,10 @@ const listAction = async (
|
||||
};
|
||||
}
|
||||
|
||||
const mcpServers = config.getMcpServers() || {};
|
||||
const mcpServers = config.getMcpClientManager()?.getMcpServers() || {};
|
||||
const serverNames = Object.keys(mcpServers);
|
||||
const blockedMcpServers = config.getBlockedMcpServers() || [];
|
||||
const blockedMcpServers =
|
||||
config.getMcpClientManager()?.getBlockedMcpServers() || [];
|
||||
|
||||
const connectingServers = serverNames.filter(
|
||||
(name) => getMCPServerStatus(name) === MCPServerStatus.CONNECTING,
|
||||
@@ -299,12 +300,12 @@ const refreshCommand: SlashCommand = {
|
||||
};
|
||||
}
|
||||
|
||||
const toolRegistry = config.getToolRegistry();
|
||||
if (!toolRegistry) {
|
||||
const mcpClientManager = config.getMcpClientManager();
|
||||
if (!mcpClientManager) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Could not retrieve tool registry.',
|
||||
content: 'Could not retrieve mcp client manager.',
|
||||
};
|
||||
}
|
||||
|
||||
@@ -316,11 +317,11 @@ const refreshCommand: SlashCommand = {
|
||||
Date.now(),
|
||||
);
|
||||
|
||||
await toolRegistry.restartMcpServers();
|
||||
await mcpClientManager.restart();
|
||||
|
||||
// Update the client with the new tools
|
||||
const geminiClient = config.getGeminiClient();
|
||||
if (geminiClient) {
|
||||
if (geminiClient?.isInitialized()) {
|
||||
await geminiClient.setTools();
|
||||
}
|
||||
|
||||
|
||||
@@ -144,7 +144,10 @@ const createMockConfig = (overrides = {}) => ({
|
||||
getDebugMode: vi.fn(() => false),
|
||||
getAccessibility: vi.fn(() => ({})),
|
||||
getMcpServers: vi.fn(() => ({})),
|
||||
getBlockedMcpServers: vi.fn(() => []),
|
||||
getMcpClientManager: vi.fn().mockImplementation(() => ({
|
||||
getBlockedMcpServers: vi.fn(),
|
||||
getMcpServers: vi.fn(),
|
||||
})),
|
||||
...overrides,
|
||||
});
|
||||
|
||||
|
||||
@@ -101,8 +101,10 @@ export const Composer = () => {
|
||||
ideContext={uiState.ideContextState}
|
||||
geminiMdFileCount={uiState.geminiMdFileCount}
|
||||
contextFileNames={contextFileNames}
|
||||
mcpServers={config.getMcpServers()}
|
||||
blockedMcpServers={config.getBlockedMcpServers()}
|
||||
mcpServers={config.getMcpClientManager()?.getMcpServers() ?? {}}
|
||||
blockedMcpServers={
|
||||
config.getMcpClientManager()?.getBlockedMcpServers() ?? []
|
||||
}
|
||||
/>
|
||||
)
|
||||
)}
|
||||
|
||||
@@ -5,15 +5,13 @@
|
||||
*/
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { appEvents } from './../../utils/events.js';
|
||||
import { AppEvent, appEvents } from './../../utils/events.js';
|
||||
import { Box, Text } from 'ink';
|
||||
import { useConfig } from '../contexts/ConfigContext.js';
|
||||
import { type McpClient, MCPServerStatus } from '@google/gemini-cli-core';
|
||||
import { GeminiSpinner } from './GeminiRespondingSpinner.js';
|
||||
import { theme } from '../semantic-colors.js';
|
||||
|
||||
export const ConfigInitDisplay = () => {
|
||||
const config = useConfig();
|
||||
const [message, setMessage] = useState('Initializing...');
|
||||
|
||||
useEffect(() => {
|
||||
@@ -31,11 +29,11 @@ export const ConfigInitDisplay = () => {
|
||||
setMessage(`Connecting to MCP servers... (${connected}/${clients.size})`);
|
||||
};
|
||||
|
||||
appEvents.on('mcp-client-update', onChange);
|
||||
appEvents.on(AppEvent.McpClientUpdate, onChange);
|
||||
return () => {
|
||||
appEvents.off('mcp-client-update', onChange);
|
||||
appEvents.off(AppEvent.McpClientUpdate, onChange);
|
||||
};
|
||||
}, [config]);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Box marginTop={1}>
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ExtensionEvents, McpClient } from '@google/gemini-cli-core';
|
||||
import { EventEmitter } from 'node:events';
|
||||
|
||||
export enum AppEvent {
|
||||
@@ -11,6 +12,15 @@ export enum AppEvent {
|
||||
LogError = 'log-error',
|
||||
OauthDisplayMessage = 'oauth-display-message',
|
||||
Flicker = 'flicker',
|
||||
McpClientUpdate = 'mcp-client-update',
|
||||
}
|
||||
|
||||
export const appEvents = new EventEmitter();
|
||||
export interface AppEvents extends ExtensionEvents {
|
||||
[AppEvent.OpenDebugConsole]: never[];
|
||||
[AppEvent.LogError]: string[];
|
||||
[AppEvent.OauthDisplayMessage]: string[];
|
||||
[AppEvent.Flicker]: never[];
|
||||
[AppEvent.McpClientUpdate]: Array<Map<string, McpClient> | never>;
|
||||
}
|
||||
|
||||
export const appEvents = new EventEmitter<AppEvents>();
|
||||
|
||||
@@ -154,6 +154,7 @@ import {
|
||||
type ExtensionLoader,
|
||||
SimpleExtensionLoader,
|
||||
} from '../utils/extensionLoader.js';
|
||||
import { McpClientManager } from '../tools/mcp-client-manager.js';
|
||||
|
||||
export type { FileFilteringOptions };
|
||||
export {
|
||||
@@ -251,7 +252,8 @@ export interface ConfigParameters {
|
||||
extensionLoader?: ExtensionLoader;
|
||||
enabledExtensions?: string[];
|
||||
enableExtensionReloading?: boolean;
|
||||
blockedMcpServers?: Array<{ name: string; extensionName: string }>;
|
||||
allowedMcpServers?: string[];
|
||||
blockedMcpServers?: string[];
|
||||
noBrowser?: boolean;
|
||||
summarizeToolOutput?: Record<string, SummarizeToolOutputSettings>;
|
||||
folderTrust?: boolean;
|
||||
@@ -293,6 +295,9 @@ export interface ConfigParameters {
|
||||
|
||||
export class Config {
|
||||
private toolRegistry!: ToolRegistry;
|
||||
private mcpClientManager?: McpClientManager;
|
||||
private allowedMcpServers: string[];
|
||||
private blockedMcpServers: string[];
|
||||
private promptRegistry!: PromptRegistry;
|
||||
private agentRegistry!: AgentRegistry;
|
||||
private readonly sessionId: string;
|
||||
@@ -347,10 +352,6 @@ export class Config {
|
||||
private readonly _extensionLoader: ExtensionLoader;
|
||||
private readonly _enabledExtensions: string[];
|
||||
private readonly enableExtensionReloading: boolean;
|
||||
private readonly _blockedMcpServers: Array<{
|
||||
name: string;
|
||||
extensionName: string;
|
||||
}>;
|
||||
fallbackModelHandler?: FallbackModelHandler;
|
||||
private quotaErrorOccurred: boolean = false;
|
||||
private readonly summarizeToolOutput:
|
||||
@@ -417,6 +418,8 @@ export class Config {
|
||||
this.toolCallCommand = params.toolCallCommand;
|
||||
this.mcpServerCommand = params.mcpServerCommand;
|
||||
this.mcpServers = params.mcpServers;
|
||||
this.allowedMcpServers = params.allowedMcpServers ?? [];
|
||||
this.blockedMcpServers = params.blockedMcpServers ?? [];
|
||||
this.userMemory = params.userMemory ?? '';
|
||||
this.geminiMdFileCount = params.geminiMdFileCount ?? 0;
|
||||
this.geminiMdFilePaths = params.geminiMdFilePaths ?? [];
|
||||
@@ -458,7 +461,6 @@ export class Config {
|
||||
this._extensionLoader =
|
||||
params.extensionLoader ?? new SimpleExtensionLoader([]);
|
||||
this._enabledExtensions = params.enabledExtensions ?? [];
|
||||
this._blockedMcpServers = params.blockedMcpServers ?? [];
|
||||
this.noBrowser = params.noBrowser ?? false;
|
||||
this.summarizeToolOutput = params.summarizeToolOutput;
|
||||
this.folderTrust = params.folderTrust ?? false;
|
||||
@@ -572,6 +574,15 @@ export class Config {
|
||||
await this.agentRegistry.initialize();
|
||||
|
||||
this.toolRegistry = await this.createToolRegistry();
|
||||
this.mcpClientManager = new McpClientManager(
|
||||
this.toolRegistry,
|
||||
this,
|
||||
this.eventEmitter,
|
||||
);
|
||||
await Promise.all([
|
||||
await this.mcpClientManager.startConfiguredMcpServers(),
|
||||
await this.getExtensionLoader().start(this),
|
||||
]);
|
||||
|
||||
await this.geminiClient.initialize();
|
||||
}
|
||||
@@ -752,8 +763,23 @@ export class Config {
|
||||
return this.allowedTools;
|
||||
}
|
||||
|
||||
/**
|
||||
* All the excluded tools from static configuration, loaded extensions, or
|
||||
* other sources.
|
||||
*
|
||||
* May change over time.
|
||||
*/
|
||||
getExcludeTools(): string[] | undefined {
|
||||
return this.excludeTools;
|
||||
const excludeToolsSet = new Set([...(this.excludeTools ?? [])]);
|
||||
for (const extension of this.getExtensionLoader().getExtensions()) {
|
||||
if (!extension.isActive) {
|
||||
continue;
|
||||
}
|
||||
for (const tool of extension.excludeTools || []) {
|
||||
excludeToolsSet.add(tool);
|
||||
}
|
||||
}
|
||||
return [...excludeToolsSet];
|
||||
}
|
||||
|
||||
getToolDiscoveryCommand(): string | undefined {
|
||||
@@ -768,10 +794,27 @@ export class Config {
|
||||
return this.mcpServerCommand;
|
||||
}
|
||||
|
||||
/**
|
||||
* The user configured MCP servers (via gemini settings files).
|
||||
*
|
||||
* Does NOT include mcp servers configured by extensions.
|
||||
*/
|
||||
getMcpServers(): Record<string, MCPServerConfig> | undefined {
|
||||
return this.mcpServers;
|
||||
}
|
||||
|
||||
getMcpClientManager(): McpClientManager | undefined {
|
||||
return this.mcpClientManager;
|
||||
}
|
||||
|
||||
getAllowedMcpServers(): string[] | undefined {
|
||||
return this.allowedMcpServers;
|
||||
}
|
||||
|
||||
getBlockedMcpServers(): string[] | undefined {
|
||||
return this.blockedMcpServers;
|
||||
}
|
||||
|
||||
setMcpServers(mcpServers: Record<string, MCPServerConfig>): void {
|
||||
this.mcpServers = mcpServers;
|
||||
}
|
||||
@@ -955,10 +998,6 @@ export class Config {
|
||||
return this.enableExtensionReloading;
|
||||
}
|
||||
|
||||
getBlockedMcpServers(): Array<{ name: string; extensionName: string }> {
|
||||
return this._blockedMcpServers;
|
||||
}
|
||||
|
||||
getNoBrowser(): boolean {
|
||||
return this.noBrowser;
|
||||
}
|
||||
@@ -1155,7 +1194,7 @@ export class Config {
|
||||
}
|
||||
|
||||
async createToolRegistry(): Promise<ToolRegistry> {
|
||||
const registry = new ToolRegistry(this, this.eventEmitter);
|
||||
const registry = new ToolRegistry(this);
|
||||
|
||||
// Set message bus on tool registry before discovery so MCP tools can access it
|
||||
if (this.getEnableMessageBusIntegration()) {
|
||||
@@ -1250,6 +1289,7 @@ export class Config {
|
||||
if (definition) {
|
||||
// We must respect the main allowed/exclude lists for agents too.
|
||||
const excludeTools = this.getExcludeTools() || [];
|
||||
|
||||
const allowedTools = this.getAllowedTools();
|
||||
|
||||
const isExcluded = excludeTools.includes(definition.name);
|
||||
|
||||
@@ -192,11 +192,9 @@ describe('loggers', () => {
|
||||
getFileFilteringRespectGitIgnore: () => true,
|
||||
getFileFilteringAllowBuildArtifacts: () => false,
|
||||
getDebugMode: () => true,
|
||||
getMcpServers: () => ({
|
||||
'test-server': {
|
||||
command: 'test-command',
|
||||
},
|
||||
}),
|
||||
getMcpServers: () => {
|
||||
throw new Error('Should not call');
|
||||
},
|
||||
getQuestion: () => 'test-question',
|
||||
getTargetDir: () => 'target-dir',
|
||||
getProxy: () => 'http://test.proxy.com:8080',
|
||||
@@ -206,6 +204,13 @@ describe('loggers', () => {
|
||||
{ name: 'ext-one', id: 'id-one' },
|
||||
{ name: 'ext-two', id: 'id-two' },
|
||||
] as GeminiCLIExtension[],
|
||||
getMcpClientManager: () => ({
|
||||
getMcpServers: () => ({
|
||||
'test-server': {
|
||||
command: 'test-command',
|
||||
},
|
||||
}),
|
||||
}),
|
||||
} as unknown as Config;
|
||||
|
||||
const startSessionEvent = new StartSessionEvent(mockConfig);
|
||||
|
||||
@@ -74,7 +74,8 @@ export class StartSessionEvent implements BaseTelemetryEvent {
|
||||
|
||||
constructor(config: Config, toolRegistry?: ToolRegistry) {
|
||||
const generatorConfig = config.getContentGeneratorConfig();
|
||||
const mcpServers = config.getMcpServers();
|
||||
const mcpServers =
|
||||
config.getMcpClientManager()?.getMcpServers() ?? config.getMcpServers();
|
||||
|
||||
let useGemini = false;
|
||||
let useVertex = false;
|
||||
|
||||
@@ -4,86 +4,193 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import {
|
||||
afterEach,
|
||||
beforeEach,
|
||||
describe,
|
||||
expect,
|
||||
it,
|
||||
vi,
|
||||
type MockedObject,
|
||||
} from 'vitest';
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { McpClient } from './mcp-client.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { SimpleExtensionLoader } from '../utils/extensionLoader.js';
|
||||
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
return {
|
||||
...originalModule,
|
||||
McpClient: vi.fn(),
|
||||
populateMcpServerCommand: vi.fn(() => ({
|
||||
'test-server': {},
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
describe('McpClientManager', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
let mockedMcpClient: MockedObject<McpClient>;
|
||||
let mockConfig: MockedObject<Config>;
|
||||
|
||||
it('should discover tools from all servers', async () => {
|
||||
const mockedMcpClient = {
|
||||
beforeEach(() => {
|
||||
mockedMcpClient = vi.mockObject({
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
};
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager(
|
||||
{} as ToolRegistry,
|
||||
{
|
||||
isTrustedFolder: () => true,
|
||||
getExtensionLoader: () => new SimpleExtensionLoader([]),
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getEnableExtensionReloading: () => false,
|
||||
} as unknown as Config,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
getServerConfig: vi.fn(),
|
||||
} as unknown as McpClient);
|
||||
vi.mocked(McpClient).mockReturnValue(mockedMcpClient);
|
||||
mockConfig = vi.mockObject({
|
||||
isTrustedFolder: vi.fn().mockReturnValue(true),
|
||||
getMcpServers: vi.fn().mockReturnValue({}),
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getAllowedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getBlockedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getMcpServerCommand: vi.fn().mockReturnValue(''),
|
||||
getGeminiClient: vi.fn().mockReturnValue({
|
||||
isInitialized: vi.fn(),
|
||||
}),
|
||||
} as unknown as Config);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should discover tools from all configured', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should not discover tools if folder is not trusted', async () => {
|
||||
const mockedMcpClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
};
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager(
|
||||
{} as ToolRegistry,
|
||||
{
|
||||
isTrustedFolder: () => false,
|
||||
getExtensionLoader: () => new SimpleExtensionLoader([]),
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getEnableExtensionReloading: () => false,
|
||||
} as unknown as Config,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockConfig.isTrustedFolder.mockReturnValue(false);
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not start blocked servers', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should only start allowed servers if allow list is not empty', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
'another-server': {},
|
||||
});
|
||||
mockConfig.getAllowedMcpServers.mockReturnValue(['another-server']);
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should start servers from extensions', async () => {
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startExtension({
|
||||
name: 'test-extension',
|
||||
mcpServers: {
|
||||
'test-server': {},
|
||||
},
|
||||
isActive: true,
|
||||
version: '1.0.0',
|
||||
path: '/some-path',
|
||||
contextFiles: [],
|
||||
id: '123',
|
||||
});
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should not start servers from disabled extensions', async () => {
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startExtension({
|
||||
name: 'test-extension',
|
||||
mcpServers: {
|
||||
'test-server': {},
|
||||
},
|
||||
isActive: false,
|
||||
version: '1.0.0',
|
||||
path: '/some-path',
|
||||
contextFiles: [],
|
||||
id: '123',
|
||||
});
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should add blocked servers to the blockedMcpServers list', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(manager.getBlockedMcpServers()).toEqual([
|
||||
{ name: 'test-server', extensionName: '' },
|
||||
]);
|
||||
});
|
||||
|
||||
describe('restart', () => {
|
||||
it('should restart all running servers', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockedMcpClient.getServerConfig.mockReturnValue({});
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
|
||||
await manager.restart();
|
||||
|
||||
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('restartServer', () => {
|
||||
it('should restart the specified server', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockedMcpClient.getServerConfig.mockReturnValue({});
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
|
||||
|
||||
await manager.restartServer('test-server');
|
||||
|
||||
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should throw an error if the server does not exist', async () => {
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await expect(manager.restartServer('non-existent')).rejects.toThrow(
|
||||
'No MCP server registered with the name "non-existent"',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -33,6 +33,10 @@ export class McpClientManager {
|
||||
private discoveryPromise: Promise<void> | undefined;
|
||||
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
||||
private readonly eventEmitter?: EventEmitter;
|
||||
private readonly blockedMcpServers: Array<{
|
||||
name: string;
|
||||
extensionName: string;
|
||||
}> = [];
|
||||
|
||||
constructor(
|
||||
toolRegistry: ToolRegistry,
|
||||
@@ -42,19 +46,10 @@ export class McpClientManager {
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.cliConfig = cliConfig;
|
||||
this.eventEmitter = eventEmitter;
|
||||
if (this.cliConfig.getEnableExtensionReloading()) {
|
||||
this.cliConfig
|
||||
.getExtensionLoader()
|
||||
.extensionEvents()
|
||||
.on('extensionLoaded', (event) => this.loadExtension(event.extension))
|
||||
.on('extensionEnabled', (event) => this.loadExtension(event.extension))
|
||||
.on('extensionDisabled', (event) =>
|
||||
this.unloadExtension(event.extension),
|
||||
)
|
||||
.on('extensionUnloaded', (event) =>
|
||||
this.unloadExtension(event.extension),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
getBlockedMcpServers() {
|
||||
return this.blockedMcpServers;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -64,21 +59,13 @@ export class McpClientManager {
|
||||
* - Disconnects all MCP clients from their servers.
|
||||
* - Updates the Gemini chat configuration to load the new tools.
|
||||
*/
|
||||
private async unloadExtension(extension: GeminiCLIExtension) {
|
||||
async stopExtension(extension: GeminiCLIExtension) {
|
||||
debugLogger.log(`Unloading extension: ${extension.name}`);
|
||||
await Promise.all(
|
||||
Object.keys(extension.mcpServers ?? {}).map((name) => {
|
||||
const newMcpServers = {
|
||||
...this.cliConfig.getMcpServers(),
|
||||
};
|
||||
delete newMcpServers[name];
|
||||
this.cliConfig.setMcpServers(newMcpServers);
|
||||
return this.disconnectClient(name);
|
||||
}),
|
||||
Object.keys(extension.mcpServers ?? {}).map(
|
||||
this.disconnectClient.bind(this),
|
||||
),
|
||||
);
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
this.cliConfig.getGeminiClient().setTools();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -88,20 +75,36 @@ export class McpClientManager {
|
||||
* - Connects MCP clients to each server and discovers their tools.
|
||||
* - Updates the Gemini chat configuration to load the new tools.
|
||||
*/
|
||||
private async loadExtension(extension: GeminiCLIExtension) {
|
||||
async startExtension(extension: GeminiCLIExtension) {
|
||||
debugLogger.log(`Loading extension: ${extension.name}`);
|
||||
await Promise.all(
|
||||
Object.entries(extension.mcpServers ?? {}).map(([name, config]) => {
|
||||
this.cliConfig.setMcpServers({
|
||||
...this.cliConfig.getMcpServers(),
|
||||
[name]: config,
|
||||
});
|
||||
return this.discoverMcpTools(name, config);
|
||||
}),
|
||||
Object.entries(extension.mcpServers ?? {}).map(([name, config]) =>
|
||||
this.maybeDiscoverMcpServer(name, {
|
||||
...config,
|
||||
extension,
|
||||
}),
|
||||
),
|
||||
);
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
this.cliConfig.getGeminiClient().setTools();
|
||||
}
|
||||
|
||||
private isAllowedMcpServer(name: string) {
|
||||
const allowedNames = this.cliConfig.getAllowedMcpServers();
|
||||
if (
|
||||
allowedNames &&
|
||||
allowedNames.length > 0 &&
|
||||
allowedNames.indexOf(name) === -1
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
const blockedNames = this.cliConfig.getBlockedMcpServers();
|
||||
if (
|
||||
blockedNames &&
|
||||
blockedNames.length > 0 &&
|
||||
blockedNames.indexOf(name) !== -1
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private async disconnectClient(name: string) {
|
||||
@@ -115,36 +118,68 @@ export class McpClientManager {
|
||||
debugLogger.warn(
|
||||
`Error stopping client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
} finally {
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
const geminiClient = this.cliConfig.getGeminiClient();
|
||||
if (geminiClient.isInitialized()) {
|
||||
await geminiClient.setTools();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
discoverMcpTools(
|
||||
maybeDiscoverMcpServer(
|
||||
name: string,
|
||||
config: MCPServerConfig,
|
||||
): Promise<void> | void {
|
||||
if (!this.isAllowedMcpServer(name)) {
|
||||
if (!this.blockedMcpServers.find((s) => s.name === name)) {
|
||||
this.blockedMcpServers?.push({
|
||||
name,
|
||||
extensionName: config.extension?.name ?? '',
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
if (config.extension && !config.extension.isActive) {
|
||||
return;
|
||||
}
|
||||
const existing = this.clients.get(name);
|
||||
if (existing && existing.getServerConfig().extension !== config.extension) {
|
||||
const extensionText = config.extension
|
||||
? ` from extension "${config.extension.name}"`
|
||||
: '';
|
||||
debugLogger.warn(
|
||||
`Skipping MCP config for server with name "${name}"${extensionText} as it already exists.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const currentDiscoveryPromise = new Promise<void>((resolve, _reject) => {
|
||||
(async () => {
|
||||
try {
|
||||
await this.disconnectClient(name);
|
||||
if (existing) {
|
||||
await existing.disconnect();
|
||||
}
|
||||
|
||||
const client = new McpClient(
|
||||
name,
|
||||
config,
|
||||
this.toolRegistry,
|
||||
this.cliConfig.getPromptRegistry(),
|
||||
this.cliConfig.getWorkspaceContext(),
|
||||
this.cliConfig.getDebugMode(),
|
||||
);
|
||||
this.clients.set(name, client);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
const client =
|
||||
existing ??
|
||||
new McpClient(
|
||||
name,
|
||||
config,
|
||||
this.toolRegistry,
|
||||
this.cliConfig.getPromptRegistry(),
|
||||
this.cliConfig.getWorkspaceContext(),
|
||||
this.cliConfig.getDebugMode(),
|
||||
);
|
||||
if (!existing) {
|
||||
this.clients.set(name, client);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
}
|
||||
try {
|
||||
await client.connect();
|
||||
await client.discover(this.cliConfig);
|
||||
@@ -161,6 +196,12 @@ export class McpClientManager {
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
const geminiClient = this.cliConfig.getGeminiClient();
|
||||
if (geminiClient.isInitialized()) {
|
||||
await geminiClient.setTools();
|
||||
}
|
||||
resolve();
|
||||
}
|
||||
})();
|
||||
@@ -174,6 +215,7 @@ export class McpClientManager {
|
||||
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
this.discoveryPromise = currentDiscoveryPromise;
|
||||
}
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
const currentPromise = this.discoveryPromise;
|
||||
currentPromise.then((_) => {
|
||||
// If we are the last recorded discoveryPromise, then we are done, reset
|
||||
@@ -187,15 +229,21 @@ export class McpClientManager {
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiates the tool discovery process for all configured MCP servers.
|
||||
* Initiates the tool discovery process for all configured MCP servers (via
|
||||
* gemini settings or command line arguments).
|
||||
*
|
||||
* It connects to each server, discovers its available tools, and registers
|
||||
* them with the `ToolRegistry`.
|
||||
*
|
||||
* For any server which is already connected, it will first be disconnected.
|
||||
*
|
||||
* This does NOT load extension MCP servers - this happens when the
|
||||
* ExtensionLoader explicitly calls `loadExtension`.
|
||||
*/
|
||||
async discoverAllMcpTools(): Promise<void> {
|
||||
async startConfiguredMcpServers(): Promise<void> {
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
await this.stop();
|
||||
|
||||
const servers = populateMcpServerCommand(
|
||||
this.cliConfig.getMcpServers() || {},
|
||||
@@ -204,12 +252,40 @@ export class McpClientManager {
|
||||
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
await Promise.all(
|
||||
Object.entries(servers).map(async ([name, config]) =>
|
||||
this.discoverMcpTools(name, config),
|
||||
Object.entries(servers).map(([name, config]) =>
|
||||
this.maybeDiscoverMcpServer(name, config),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Restarts all active MCP Clients.
|
||||
*/
|
||||
async restart(): Promise<void> {
|
||||
await Promise.all(
|
||||
Array.from(this.clients.entries()).map(async ([name, client]) => {
|
||||
try {
|
||||
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Error restarting client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Restart a single MCP server by name.
|
||||
*/
|
||||
async restartServer(name: string) {
|
||||
const client = this.clients.get(name);
|
||||
if (!client) {
|
||||
throw new Error(`No MCP server registered with the name "${name}"`);
|
||||
}
|
||||
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops all running local MCP servers and closes all client connections.
|
||||
* This is the cleanup method to be called on application exit.
|
||||
@@ -236,4 +312,15 @@ export class McpClientManager {
|
||||
getDiscoveryState(): MCPDiscoveryState {
|
||||
return this.discoveryState;
|
||||
}
|
||||
|
||||
/**
|
||||
* All of the MCP server configurations currently loaded.
|
||||
*/
|
||||
getMcpServers(): Record<string, MCPServerConfig> {
|
||||
const mcpServers: Record<string, MCPServerConfig> = {};
|
||||
for (const [name, client] of this.clients.entries()) {
|
||||
mcpServers[name] = client.getServerConfig();
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,6 +303,71 @@ describe('mcp-client', () => {
|
||||
expect(mockedMcpToTool).toHaveBeenCalledOnce();
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should remove tools and prompts on disconnect', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
close: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: {}, prompts: {} }),
|
||||
request: vi.fn().mockResolvedValue({
|
||||
prompts: [{ id: 'prompt1', text: 'a prompt' }],
|
||||
}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'testTool',
|
||||
description: 'A test tool',
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
unregisterTool: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
} as unknown as ToolRegistry;
|
||||
const mockedPromptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
unregisterPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
mockedPromptRegistry,
|
||||
workspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover({} as Config);
|
||||
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce();
|
||||
|
||||
await client.disconnect();
|
||||
|
||||
expect(mockedClient.close).toHaveBeenCalledOnce();
|
||||
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledOnce();
|
||||
expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
describe('appendMcpServerCommand', () => {
|
||||
it('should do nothing if no MCP servers or command are configured', () => {
|
||||
|
||||
@@ -161,6 +161,7 @@ export class McpClient {
|
||||
return;
|
||||
}
|
||||
this.toolRegistry.removeMcpToolsByServer(this.serverName);
|
||||
this.promptRegistry.removePromptsByServer(this.serverName);
|
||||
this.updateStatus(MCPServerStatus.DISCONNECTING);
|
||||
const client = this.client;
|
||||
this.client = undefined;
|
||||
@@ -208,6 +209,10 @@ export class McpClient {
|
||||
this.assertConnected();
|
||||
return discoverPrompts(this.serverName, this.client!, this.promptRegistry);
|
||||
}
|
||||
|
||||
getServerConfig(): MCPServerConfig {
|
||||
return this.serverConfig;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -19,20 +19,10 @@ import { spawn } from 'node:child_process';
|
||||
|
||||
import fs from 'node:fs';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
|
||||
vi.mock('node:fs');
|
||||
|
||||
// Mock ./mcp-client.js to control its behavior within tool-registry tests
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
return {
|
||||
...originalModule,
|
||||
};
|
||||
});
|
||||
|
||||
// Mock node:child_process
|
||||
vi.mock('node:child_process', async () => {
|
||||
const actual = await vi.importActual('node:child_process');
|
||||
@@ -401,27 +391,6 @@ describe('ToolRegistry', () => {
|
||||
expect(result.llmContent).toContain('Stderr: Something went wrong');
|
||||
expect(result.llmContent).toContain('Exit Code: 1');
|
||||
});
|
||||
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
const discoverSpy = vi.spyOn(
|
||||
McpClientManager.prototype,
|
||||
'discoverAllMcpTools',
|
||||
);
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
'my-mcp-server': {
|
||||
command: 'mcp-server-cmd',
|
||||
args: ['--port', '1234'],
|
||||
trust: true,
|
||||
},
|
||||
};
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverAllTools();
|
||||
|
||||
expect(discoverSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('DiscoveredToolInvocation', () => {
|
||||
|
||||
@@ -14,13 +14,10 @@ import { Kind, BaseDeclarativeTool, BaseToolInvocation } from './tools.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { StringDecoder } from 'node:string_decoder';
|
||||
import { connectAndDiscover } from './mcp-client.js';
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
@@ -176,12 +173,10 @@ export class ToolRegistry {
|
||||
// The tools keyed by tool name as seen by the LLM.
|
||||
private tools: Map<string, AnyDeclarativeTool> = new Map();
|
||||
private config: Config;
|
||||
private mcpClientManager: McpClientManager;
|
||||
private messageBus?: MessageBus;
|
||||
|
||||
constructor(config: Config, eventEmitter?: EventEmitter) {
|
||||
constructor(config: Config) {
|
||||
this.config = config;
|
||||
this.mcpClientManager = new McpClientManager(this, config, eventEmitter);
|
||||
}
|
||||
|
||||
setMessageBus(messageBus: MessageBus): void {
|
||||
@@ -238,64 +233,7 @@ export class ToolRegistry {
|
||||
async discoverAllTools(): Promise<void> {
|
||||
// remove any previously discovered tools
|
||||
this.removeDiscoveredTools();
|
||||
|
||||
this.config.getPromptRegistry().clear();
|
||||
|
||||
await this.discoverAndRegisterToolsFromCommand();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers tools from project (if available and configured).
|
||||
* Can be called multiple times to update discovered tools.
|
||||
* This will NOT discover tools from the command line, only from MCP servers.
|
||||
*/
|
||||
async discoverMcpTools(): Promise<void> {
|
||||
// remove any previously discovered tools
|
||||
this.removeDiscoveredTools();
|
||||
|
||||
this.config.getPromptRegistry().clear();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
* Restarts all MCP servers and re-discovers tools.
|
||||
*/
|
||||
async restartMcpServers(): Promise<void> {
|
||||
await this.discoverMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover or re-discover tools for a single MCP server.
|
||||
* @param serverName - The name of the server to discover tools from.
|
||||
*/
|
||||
async discoverToolsForServer(serverName: string): Promise<void> {
|
||||
// Remove any previously discovered tools from this server
|
||||
for (const [name, tool] of this.tools.entries()) {
|
||||
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
|
||||
this.tools.delete(name);
|
||||
}
|
||||
}
|
||||
|
||||
this.config.getPromptRegistry().removePromptsByServer(serverName);
|
||||
|
||||
const mcpServers = this.config.getMcpServers() ?? {};
|
||||
const serverConfig = mcpServers[serverName];
|
||||
if (serverConfig) {
|
||||
await connectAndDiscover(
|
||||
serverName,
|
||||
serverConfig,
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
this.config.getWorkspaceContext(),
|
||||
this.config,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
|
||||
|
||||
108
packages/core/src/utils/extensionLoader.test.ts
Normal file
108
packages/core/src/utils/extensionLoader.test.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { SimpleExtensionLoader } from './extensionLoader.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { type McpClientManager } from '../tools/mcp-client-manager.js';
|
||||
|
||||
describe('SimpleExtensionLoader', () => {
|
||||
let mockConfig: Config;
|
||||
let extensionReloadingEnabled: boolean;
|
||||
let mockMcpClientManager: McpClientManager;
|
||||
const activeExtension = {
|
||||
name: 'test-extension',
|
||||
isActive: true,
|
||||
version: '1.0.0',
|
||||
path: '/path/to/extension',
|
||||
contextFiles: [],
|
||||
id: '123',
|
||||
};
|
||||
const inactiveExtension = {
|
||||
name: 'test-extension',
|
||||
isActive: false,
|
||||
version: '1.0.0',
|
||||
path: '/path/to/extension',
|
||||
contextFiles: [],
|
||||
id: '123',
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockMcpClientManager = {
|
||||
startExtension: vi.fn(),
|
||||
stopExtension: vi.fn(),
|
||||
} as unknown as McpClientManager;
|
||||
extensionReloadingEnabled = false;
|
||||
mockConfig = {
|
||||
getMcpClientManager: () => mockMcpClientManager,
|
||||
getEnableExtensionReloading: () => extensionReloadingEnabled,
|
||||
} as unknown as Config;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should start active extensions', async () => {
|
||||
const loader = new SimpleExtensionLoader([activeExtension]);
|
||||
await loader.start(mockConfig);
|
||||
expect(mockMcpClientManager.startExtension).toHaveBeenCalledExactlyOnceWith(
|
||||
activeExtension,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not start inactive extensions', async () => {
|
||||
const loader = new SimpleExtensionLoader([inactiveExtension]);
|
||||
await loader.start(mockConfig);
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe('interactive extension loading and unloading', () => {
|
||||
it('should not call `start` or `stop` if the loader is not already started', async () => {
|
||||
const loader = new SimpleExtensionLoader([]);
|
||||
await loader.loadExtension(activeExtension);
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
await loader.unloadExtension(activeExtension);
|
||||
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should start extensions that were explicitly loaded prior to initializing the loader', async () => {
|
||||
const loader = new SimpleExtensionLoader([]);
|
||||
await loader.loadExtension(activeExtension);
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
await loader.start(mockConfig);
|
||||
expect(
|
||||
mockMcpClientManager.startExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
});
|
||||
|
||||
it.each([true, false])(
|
||||
'should only call `start` and `stop` if extension reloading is enabled ($i)',
|
||||
async (reloadingEnabled) => {
|
||||
extensionReloadingEnabled = reloadingEnabled;
|
||||
const loader = new SimpleExtensionLoader([]);
|
||||
await loader.start(mockConfig);
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
await loader.loadExtension(activeExtension);
|
||||
if (reloadingEnabled) {
|
||||
expect(
|
||||
mockMcpClientManager.startExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
} else {
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
}
|
||||
await loader.unloadExtension(activeExtension);
|
||||
if (reloadingEnabled) {
|
||||
expect(
|
||||
mockMcpClientManager.stopExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
} else {
|
||||
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
|
||||
}
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -4,45 +4,194 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { EventEmitter } from 'node:events';
|
||||
import type { GeminiCLIExtension } from '../config/config.js';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import type { Config, GeminiCLIExtension } from '../config/config.js';
|
||||
|
||||
export interface ExtensionLoader {
|
||||
getExtensions(): GeminiCLIExtension[];
|
||||
export abstract class ExtensionLoader {
|
||||
// Assigned in `start`.
|
||||
protected config: Config | undefined;
|
||||
|
||||
extensionEvents(): EventEmitter<ExtensionEvents>;
|
||||
// Used to track the count of currently starting and stopping extensions and
|
||||
// fire appropriate events.
|
||||
protected startingCount: number = 0;
|
||||
protected startCompletedCount: number = 0;
|
||||
protected stoppingCount: number = 0;
|
||||
protected stopCompletedCount: number = 0;
|
||||
|
||||
constructor(private readonly eventEmitter?: EventEmitter<ExtensionEvents>) {}
|
||||
|
||||
/**
|
||||
* All currently known extensions, both active and inactive.
|
||||
*/
|
||||
abstract getExtensions(): GeminiCLIExtension[];
|
||||
|
||||
/**
|
||||
* Fully initializes all active extensions.
|
||||
*
|
||||
* Called within `Config.initialize`, which must already have an
|
||||
* McpClientManager, PromptRegistry, and GeminiChat set up.
|
||||
*/
|
||||
async start(config: Config): Promise<void> {
|
||||
if (!this.config) {
|
||||
this.config = config;
|
||||
} else {
|
||||
throw new Error('Already started, you may only call `start` once.');
|
||||
}
|
||||
await Promise.all(
|
||||
this.getExtensions()
|
||||
.filter((e) => e.isActive)
|
||||
.map(this.startExtension.bind(this)),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unconditionally starts an `extension` and loads all its MCP servers,
|
||||
* context, custom commands, etc. Assumes that `start` has already been called
|
||||
* and we have a Config object.
|
||||
*
|
||||
* This should typically only be called from `start`, most other calls should
|
||||
* go through `maybeStartExtension` which will only start the extension if
|
||||
* extension reloading is enabled and the `config` object is initialized.
|
||||
*/
|
||||
protected async startExtension(extension: GeminiCLIExtension) {
|
||||
if (!this.config) {
|
||||
throw new Error('Cannot call `startExtension` prior to calling `start`.');
|
||||
}
|
||||
this.startingCount++;
|
||||
this.eventEmitter?.emit('extensionsStarting', {
|
||||
total: this.startingCount,
|
||||
completed: this.startCompletedCount,
|
||||
});
|
||||
try {
|
||||
await this.config.getMcpClientManager()!.startExtension(extension);
|
||||
// TODO: Move all extension features here, including at least:
|
||||
// - context file loading
|
||||
// - custom command loading
|
||||
// - excluded tool configuration
|
||||
} finally {
|
||||
this.startCompletedCount++;
|
||||
this.eventEmitter?.emit('extensionsStarting', {
|
||||
total: this.startingCount,
|
||||
completed: this.startCompletedCount,
|
||||
});
|
||||
if (this.startingCount === this.startCompletedCount) {
|
||||
this.startingCount = 0;
|
||||
this.startCompletedCount = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* If extension reloading is enabled and `start` has already been called,
|
||||
* then calls `startExtension` to include all extension features into the
|
||||
* program.
|
||||
*/
|
||||
protected maybeStartExtension(
|
||||
extension: GeminiCLIExtension,
|
||||
): Promise<void> | undefined {
|
||||
if (this.config && this.config.getEnableExtensionReloading()) {
|
||||
return this.startExtension(extension);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Unconditionally stops an `extension` and unloads all its MCP servers,
|
||||
* context, custom commands, etc. Assumes that `start` has already been called
|
||||
* and we have a Config object.
|
||||
*
|
||||
* Most calls should go through `maybeStopExtension` which will only stop the
|
||||
* extension if extension reloading is enabled and the `config` object is
|
||||
* initialized.
|
||||
*/
|
||||
protected async stopExtension(extension: GeminiCLIExtension) {
|
||||
if (!this.config) {
|
||||
throw new Error('Cannot call `stopExtension` prior to calling `start`.');
|
||||
}
|
||||
this.stoppingCount++;
|
||||
this.eventEmitter?.emit('extensionsStopping', {
|
||||
total: this.stoppingCount,
|
||||
completed: this.stopCompletedCount,
|
||||
});
|
||||
|
||||
try {
|
||||
await this.config.getMcpClientManager()!.stopExtension(extension);
|
||||
// TODO: Remove all extension features here, including at least:
|
||||
// - context files
|
||||
// - custom commands
|
||||
// - excluded tools
|
||||
} finally {
|
||||
this.stopCompletedCount++;
|
||||
this.eventEmitter?.emit('extensionsStopping', {
|
||||
total: this.stoppingCount,
|
||||
completed: this.stopCompletedCount,
|
||||
});
|
||||
if (this.stoppingCount === this.stopCompletedCount) {
|
||||
this.stoppingCount = 0;
|
||||
this.stopCompletedCount = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* If extension reloading is enabled and `start` has already been called,
|
||||
* then this also performs all necessary steps to remove all extension
|
||||
* features from the rest of the system.
|
||||
*/
|
||||
protected maybeStopExtension(
|
||||
extension: GeminiCLIExtension,
|
||||
): Promise<void> | undefined {
|
||||
if (this.config && this.config.getEnableExtensionReloading()) {
|
||||
return this.stopExtension(extension);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
export interface ExtensionEvents {
|
||||
extensionEnabled: ExtensionEnableEvent[];
|
||||
extensionDisabled: ExtensionDisableEvent[];
|
||||
extensionLoaded: ExtensionLoadEvent[];
|
||||
extensionUnloaded: ExtensionUnloadEvent[];
|
||||
extensionInstalled: ExtensionInstallEvent[];
|
||||
extensionUninstalled: ExtensionUninstallEvent[];
|
||||
extensionUpdated: ExtensionUpdateEvent[];
|
||||
extensionsStarting: ExtensionsStartingEvent[];
|
||||
extensionsStopping: ExtensionsStoppingEvent[];
|
||||
}
|
||||
|
||||
interface BaseExtensionEvent {
|
||||
extension: GeminiCLIExtension;
|
||||
export interface ExtensionsStartingEvent {
|
||||
total: number;
|
||||
completed: number;
|
||||
}
|
||||
export type ExtensionDisableEvent = BaseExtensionEvent;
|
||||
export type ExtensionEnableEvent = BaseExtensionEvent;
|
||||
export type ExtensionInstallEvent = BaseExtensionEvent;
|
||||
export type ExtensionLoadEvent = BaseExtensionEvent;
|
||||
export type ExtensionUnloadEvent = BaseExtensionEvent;
|
||||
export type ExtensionUninstallEvent = BaseExtensionEvent;
|
||||
export type ExtensionUpdateEvent = BaseExtensionEvent;
|
||||
|
||||
export class SimpleExtensionLoader implements ExtensionLoader {
|
||||
private _eventEmitter = new EventEmitter<ExtensionEvents>();
|
||||
constructor(private readonly extensions: GeminiCLIExtension[]) {}
|
||||
export interface ExtensionsStoppingEvent {
|
||||
total: number;
|
||||
completed: number;
|
||||
}
|
||||
|
||||
extensionEvents(): EventEmitter<ExtensionEvents> {
|
||||
return this._eventEmitter;
|
||||
export class SimpleExtensionLoader extends ExtensionLoader {
|
||||
constructor(
|
||||
protected readonly extensions: GeminiCLIExtension[],
|
||||
eventEmitter?: EventEmitter<ExtensionEvents>,
|
||||
) {
|
||||
super(eventEmitter);
|
||||
}
|
||||
|
||||
getExtensions(): GeminiCLIExtension[] {
|
||||
return this.extensions;
|
||||
}
|
||||
|
||||
/// Adds `extension` to the list of extensions and calls
|
||||
/// `maybeStartExtension`.
|
||||
///
|
||||
/// This is intended for dynamic loading of extensions after calling `start`.
|
||||
async loadExtension(extension: GeminiCLIExtension) {
|
||||
this.extensions.push(extension);
|
||||
await this.maybeStartExtension(extension);
|
||||
}
|
||||
|
||||
/// Removes `extension` from the list of extensions and calls
|
||||
// `maybeStopExtension` if it was found.
|
||||
///
|
||||
/// This is intended for dynamic unloading of extensions after calling `start`.
|
||||
async unloadExtension(extension: GeminiCLIExtension) {
|
||||
const index = this.extensions.indexOf(extension);
|
||||
if (index === -1) return;
|
||||
this.extensions.splice(index, 1);
|
||||
await this.maybeStopExtension(extension);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user