mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-02-01 22:48:03 +00:00
[feat] Extension Reloading - respect updates to exclude tools (#12728)
This commit is contained in:
@@ -752,11 +752,11 @@ describe('mergeMcpServers', () => {
|
||||
});
|
||||
|
||||
describe('mergeExcludeTools', () => {
|
||||
const defaultExcludes = [
|
||||
const defaultExcludes = new Set([
|
||||
SHELL_TOOL_NAME,
|
||||
EDIT_TOOL_NAME,
|
||||
WRITE_FILE_TOOL_NAME,
|
||||
];
|
||||
]);
|
||||
const originalIsTTY = process.stdin.isTTY;
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -799,7 +799,7 @@ describe('mergeExcludeTools', () => {
|
||||
argv,
|
||||
);
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4', 'tool5']),
|
||||
new Set(['tool1', 'tool2', 'tool3', 'tool4', 'tool5']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toHaveLength(5);
|
||||
});
|
||||
@@ -821,7 +821,7 @@ describe('mergeExcludeTools', () => {
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2', 'tool3']),
|
||||
new Set(['tool1', 'tool2', 'tool3']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toHaveLength(3);
|
||||
});
|
||||
@@ -852,7 +852,7 @@ describe('mergeExcludeTools', () => {
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4']),
|
||||
new Set(['tool1', 'tool2', 'tool3', 'tool4']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toHaveLength(4);
|
||||
});
|
||||
@@ -863,7 +863,7 @@ describe('mergeExcludeTools', () => {
|
||||
process.argv = ['node', 'script.js'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getExcludeTools()).toEqual([]);
|
||||
expect(config.getExcludeTools()).toEqual(new Set([]));
|
||||
});
|
||||
|
||||
it('should return default excludes when no excludeTools are specified and it is not interactive', async () => {
|
||||
@@ -881,9 +881,7 @@ describe('mergeExcludeTools', () => {
|
||||
const settings: Settings = { tools: { exclude: ['tool1', 'tool2'] } };
|
||||
vi.spyOn(ExtensionManager.prototype, 'getExtensions').mockReturnValue([]);
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toEqual(new Set(['tool1', 'tool2']));
|
||||
expect(config.getExcludeTools()).toHaveLength(2);
|
||||
});
|
||||
|
||||
@@ -903,9 +901,7 @@ describe('mergeExcludeTools', () => {
|
||||
process.argv = ['node', 'script.js'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toEqual(new Set(['tool1', 'tool2']));
|
||||
expect(config.getExcludeTools()).toHaveLength(2);
|
||||
});
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ describe('handleAtCommand', () => {
|
||||
getToolRegistry,
|
||||
getTargetDir: () => testRootDir,
|
||||
isSandboxed: () => false,
|
||||
getExcludeTools: vi.fn(),
|
||||
getFileService: () => new FileDiscoveryService(testRootDir),
|
||||
getFileFilteringRespectGitIgnore: () => true,
|
||||
getFileFilteringRespectGeminiIgnore: () => true,
|
||||
|
||||
@@ -872,27 +872,6 @@ describe('Server Config (config.ts)', () => {
|
||||
expect(wasShellToolRegistered).toBe(true);
|
||||
});
|
||||
|
||||
it('should not register a tool if excludeTools contains the non-minified class name', async () => {
|
||||
const params: ConfigParameters = {
|
||||
...baseParams,
|
||||
coreTools: undefined, // all tools enabled by default
|
||||
excludeTools: ['ShellTool'],
|
||||
};
|
||||
const config = new Config(params);
|
||||
await config.initialize();
|
||||
|
||||
const registerToolMock = (
|
||||
(await vi.importMock('../tools/tool-registry')) as {
|
||||
ToolRegistry: { prototype: { registerTool: Mock } };
|
||||
}
|
||||
).ToolRegistry.prototype.registerTool;
|
||||
|
||||
const wasShellToolRegistered = (
|
||||
registerToolMock as Mock
|
||||
).mock.calls.some((call) => call[0] instanceof vi.mocked(ShellTool));
|
||||
expect(wasShellToolRegistered).toBe(false);
|
||||
});
|
||||
|
||||
it('should register a tool if coreTools contains an argument-specific pattern with the non-minified class name', async () => {
|
||||
const params: ConfigParameters = {
|
||||
...baseParams,
|
||||
|
||||
@@ -826,7 +826,7 @@ export class Config {
|
||||
*
|
||||
* May change over time.
|
||||
*/
|
||||
getExcludeTools(): string[] | undefined {
|
||||
getExcludeTools(): Set<string> | undefined {
|
||||
const excludeToolsSet = new Set([...(this.excludeTools ?? [])]);
|
||||
for (const extension of this.getExtensionLoader().getExtensions()) {
|
||||
if (!extension.isActive) {
|
||||
@@ -836,7 +836,7 @@ export class Config {
|
||||
excludeToolsSet.add(tool);
|
||||
}
|
||||
}
|
||||
return [...excludeToolsSet];
|
||||
return excludeToolsSet;
|
||||
}
|
||||
|
||||
getToolDiscoveryCommand(): string | undefined {
|
||||
@@ -1282,7 +1282,6 @@ export class Config {
|
||||
const className = ToolClass.name;
|
||||
const toolName = ToolClass.Name || className;
|
||||
const coreTools = this.getCoreTools();
|
||||
const excludeTools = this.getExcludeTools() || [];
|
||||
// On some platforms, the className can be minified to _ClassName.
|
||||
const normalizedClassName = className.replace(/^_+/, '');
|
||||
|
||||
@@ -1297,14 +1296,6 @@ export class Config {
|
||||
);
|
||||
}
|
||||
|
||||
const isExcluded = excludeTools.some(
|
||||
(tool) => tool === toolName || tool === normalizedClassName,
|
||||
);
|
||||
|
||||
if (isExcluded) {
|
||||
isEnabled = false;
|
||||
}
|
||||
|
||||
if (isEnabled) {
|
||||
// Pass message bus to tools when feature flag is enabled
|
||||
// This first implementation is only focused on the general case of
|
||||
@@ -1363,15 +1354,12 @@ export class Config {
|
||||
);
|
||||
if (definition) {
|
||||
// We must respect the main allowed/exclude lists for agents too.
|
||||
const excludeTools = this.getExcludeTools() || [];
|
||||
|
||||
const allowedTools = this.getAllowedTools();
|
||||
|
||||
const isExcluded = excludeTools.includes(definition.name);
|
||||
const isAllowed =
|
||||
!allowedTools || allowedTools.includes(definition.name);
|
||||
|
||||
if (isAllowed && !isExcluded) {
|
||||
if (isAllowed) {
|
||||
const messageBusEnabled = this.getEnableMessageBusIntegration();
|
||||
const wrapper = new SubagentToolWrapper(
|
||||
definition,
|
||||
|
||||
@@ -243,6 +243,10 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
);
|
||||
}
|
||||
|
||||
getFullyQualifiedPrefix(): string {
|
||||
return `${this.serverName}__`;
|
||||
}
|
||||
|
||||
asFullyQualifiedTool(): DiscoveredMCPTool {
|
||||
return new DiscoveredMCPTool(
|
||||
this.mcpTool,
|
||||
@@ -251,7 +255,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
this.description,
|
||||
this.parameterSchema,
|
||||
this.trust,
|
||||
`${this.serverName}__${this.serverToolName}`,
|
||||
`${this.getFullyQualifiedPrefix()}${this.serverToolName}`,
|
||||
this.cliConfig,
|
||||
this.extensionName,
|
||||
this.extensionId,
|
||||
|
||||
@@ -82,7 +82,7 @@ describe('ShellTool', () => {
|
||||
getAllowedTools: vi.fn().mockReturnValue([]),
|
||||
getApprovalMode: vi.fn().mockReturnValue('strict'),
|
||||
getCoreTools: vi.fn().mockReturnValue([]),
|
||||
getExcludeTools: vi.fn().mockReturnValue([]),
|
||||
getExcludeTools: vi.fn().mockReturnValue(new Set([])),
|
||||
getDebugMode: vi.fn().mockReturnValue(false),
|
||||
getTargetDir: vi.fn().mockReturnValue(tempRootDir),
|
||||
getSummarizeToolOutputConfig: vi.fn().mockReturnValue(undefined),
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import type { Mocked } from 'vitest';
|
||||
import type { Mocked, MockInstance } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import type { ConfigParameters } from '../config/config.js';
|
||||
import { Config } from '../config/config.js';
|
||||
@@ -109,6 +109,9 @@ describe('ToolRegistry', () => {
|
||||
let config: Config;
|
||||
let toolRegistry: ToolRegistry;
|
||||
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
||||
let mockConfigGetExcludedTools: MockInstance<
|
||||
typeof Config.prototype.getExcludeTools
|
||||
>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.mocked(fs.existsSync).mockReturnValue(true);
|
||||
@@ -132,6 +135,7 @@ describe('ToolRegistry', () => {
|
||||
config,
|
||||
'getToolDiscoveryCommand',
|
||||
);
|
||||
mockConfigGetExcludedTools = vi.spyOn(config, 'getExcludeTools');
|
||||
vi.spyOn(config, 'getMcpServers');
|
||||
vi.spyOn(config, 'getMcpServerCommand');
|
||||
vi.spyOn(config, 'getPromptRegistry').mockReturnValue({
|
||||
@@ -152,6 +156,75 @@ describe('ToolRegistry', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('excluded tools', () => {
|
||||
const simpleTool = new MockTool({
|
||||
name: 'tool-a',
|
||||
displayName: 'Tool a',
|
||||
});
|
||||
const excludedTool = new ExcludedMockTool({
|
||||
name: 'excluded-tool-class',
|
||||
displayName: 'Excluded Tool Class',
|
||||
});
|
||||
const mockCallable = {} as CallableTool;
|
||||
const mcpTool = new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
'mcp-server',
|
||||
'excluded-mcp-tool',
|
||||
'description',
|
||||
{},
|
||||
);
|
||||
const allowedTool = new MockTool({
|
||||
name: 'allowed-tool',
|
||||
displayName: 'Allowed Tool',
|
||||
});
|
||||
|
||||
it.each([
|
||||
{
|
||||
name: 'should match simple names',
|
||||
tools: [simpleTool],
|
||||
excludedTools: ['tool-a'],
|
||||
},
|
||||
{
|
||||
name: 'should match simple MCP tool names, when qualified or unqualified',
|
||||
tools: [mcpTool, mcpTool.asFullyQualifiedTool()],
|
||||
excludedTools: [mcpTool.name],
|
||||
},
|
||||
{
|
||||
name: 'should match qualified MCP tool names when qualified or unqualified',
|
||||
tools: [mcpTool, mcpTool.asFullyQualifiedTool()],
|
||||
excludedTools: [`${mcpTool.getFullyQualifiedPrefix()}${mcpTool.name}`],
|
||||
},
|
||||
{
|
||||
name: 'should match class names',
|
||||
tools: [excludedTool],
|
||||
excludedTools: ['ExcludedMockTool'],
|
||||
},
|
||||
])('$name', ({ tools, excludedTools }) => {
|
||||
toolRegistry.registerTool(allowedTool);
|
||||
for (const tool of tools) {
|
||||
toolRegistry.registerTool(tool);
|
||||
}
|
||||
mockConfigGetExcludedTools.mockReturnValue(new Set(excludedTools));
|
||||
|
||||
expect(toolRegistry.getAllTools()).toEqual([allowedTool]);
|
||||
expect(toolRegistry.getAllToolNames()).toEqual([allowedTool.name]);
|
||||
expect(toolRegistry.getFunctionDeclarations()).toEqual(
|
||||
toolRegistry.getFunctionDeclarationsFiltered([allowedTool.name]),
|
||||
);
|
||||
for (const tool of tools) {
|
||||
expect(toolRegistry.getTool(tool.name)).toBeUndefined();
|
||||
expect(
|
||||
toolRegistry.getFunctionDeclarationsFiltered([tool.name]),
|
||||
).toHaveLength(0);
|
||||
if (tool instanceof DiscoveredMCPTool) {
|
||||
expect(toolRegistry.getToolsByServer(tool.serverName)).toHaveLength(
|
||||
0,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllTools', () => {
|
||||
it('should return all registered tools sorted alphabetically by displayName', () => {
|
||||
// Register tools with displayNames in non-alphabetical order
|
||||
@@ -521,3 +594,12 @@ describe('ToolRegistry', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* Used for tests that exclude by class name.
|
||||
*/
|
||||
class ExcludedMockTool extends MockTool {
|
||||
constructor(options: ConstructorParameters<typeof MockTool>[0]) {
|
||||
super(options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,7 +189,9 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
||||
|
||||
export class ToolRegistry {
|
||||
// The tools keyed by tool name as seen by the LLM.
|
||||
private tools: Map<string, AnyDeclarativeTool> = new Map();
|
||||
// This includes tools which are currently not active, use `getActiveTools`
|
||||
// and `isActive` to get only the active tools.
|
||||
private allKnownTools: Map<string, AnyDeclarativeTool> = new Map();
|
||||
private config: Config;
|
||||
private messageBus?: MessageBus;
|
||||
|
||||
@@ -207,10 +209,14 @@ export class ToolRegistry {
|
||||
|
||||
/**
|
||||
* Registers a tool definition.
|
||||
*
|
||||
* Note that excluded tools are still registered to allow for enabling them
|
||||
* later in the session.
|
||||
*
|
||||
* @param tool - The tool object containing schema and execution logic.
|
||||
*/
|
||||
registerTool(tool: AnyDeclarativeTool): void {
|
||||
if (this.tools.has(tool.name)) {
|
||||
if (this.allKnownTools.has(tool.name)) {
|
||||
if (tool instanceof DiscoveredMCPTool) {
|
||||
tool = tool.asFullyQualifiedTool();
|
||||
} else {
|
||||
@@ -220,7 +226,7 @@ export class ToolRegistry {
|
||||
);
|
||||
}
|
||||
}
|
||||
this.tools.set(tool.name, tool);
|
||||
this.allKnownTools.set(tool.name, tool);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -229,7 +235,7 @@ export class ToolRegistry {
|
||||
* 2. Discovered tools.
|
||||
* 3. MCP tools ordered by server name.
|
||||
*
|
||||
* This is a stable sort in that ties preseve existing order.
|
||||
* This is a stable sort in that tries preserve existing order.
|
||||
*/
|
||||
sortTools(): void {
|
||||
const getPriority = (tool: AnyDeclarativeTool): number => {
|
||||
@@ -238,8 +244,8 @@ export class ToolRegistry {
|
||||
return 0; // Built-in
|
||||
};
|
||||
|
||||
this.tools = new Map(
|
||||
Array.from(this.tools.entries()).sort((a, b) => {
|
||||
this.allKnownTools = new Map(
|
||||
Array.from(this.allKnownTools.entries()).sort((a, b) => {
|
||||
const toolA = a[1];
|
||||
const toolB = b[1];
|
||||
const priorityA = getPriority(toolA);
|
||||
@@ -261,9 +267,9 @@ export class ToolRegistry {
|
||||
}
|
||||
|
||||
private removeDiscoveredTools(): void {
|
||||
for (const tool of this.tools.values()) {
|
||||
for (const tool of this.allKnownTools.values()) {
|
||||
if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
|
||||
this.tools.delete(tool.name);
|
||||
this.allKnownTools.delete(tool.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -273,9 +279,9 @@ export class ToolRegistry {
|
||||
* @param serverName The name of the server to remove tools from.
|
||||
*/
|
||||
removeMcpToolsByServer(serverName: string): void {
|
||||
for (const [name, tool] of this.tools.entries()) {
|
||||
for (const [name, tool] of this.allKnownTools.entries()) {
|
||||
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
|
||||
this.tools.delete(name);
|
||||
this.allKnownTools.delete(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -416,6 +422,45 @@ export class ToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns All the tools that are not excluded.
|
||||
*/
|
||||
private getActiveTools(): AnyDeclarativeTool[] {
|
||||
const excludedTools = this.config.getExcludeTools() ?? new Set([]);
|
||||
const activeTools: AnyDeclarativeTool[] = [];
|
||||
for (const tool of this.allKnownTools.values()) {
|
||||
if (this.isActiveTool(tool, excludedTools)) {
|
||||
activeTools.push(tool);
|
||||
}
|
||||
}
|
||||
return activeTools;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param tool
|
||||
* @param excludeTools (optional, helps performance for repeated calls)
|
||||
* @returns Whether or not the `tool` is not excluded.
|
||||
*/
|
||||
private isActiveTool(
|
||||
tool: AnyDeclarativeTool,
|
||||
excludeTools?: Set<string>,
|
||||
): boolean {
|
||||
excludeTools ??= this.config.getExcludeTools() ?? new Set([]);
|
||||
const normalizedClassName = tool.constructor.name.replace(/^_+/, '');
|
||||
const possibleNames = [tool.name, normalizedClassName];
|
||||
if (tool instanceof DiscoveredMCPTool) {
|
||||
// Check both the unqualified and qualified name for MCP tools.
|
||||
if (tool.name.startsWith(tool.getFullyQualifiedPrefix())) {
|
||||
possibleNames.push(
|
||||
tool.name.substring(tool.getFullyQualifiedPrefix().length),
|
||||
);
|
||||
} else {
|
||||
possibleNames.push(`${tool.getFullyQualifiedPrefix()}${tool.name}`);
|
||||
}
|
||||
}
|
||||
return !possibleNames.some((name) => excludeTools.has(name));
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the list of tool schemas (FunctionDeclaration array).
|
||||
* Extracts the declarations from the ToolListUnion structure.
|
||||
@@ -424,7 +469,7 @@ export class ToolRegistry {
|
||||
*/
|
||||
getFunctionDeclarations(): FunctionDeclaration[] {
|
||||
const declarations: FunctionDeclaration[] = [];
|
||||
this.tools.forEach((tool) => {
|
||||
this.getActiveTools().forEach((tool) => {
|
||||
declarations.push(tool.schema);
|
||||
});
|
||||
return declarations;
|
||||
@@ -438,8 +483,8 @@ export class ToolRegistry {
|
||||
getFunctionDeclarationsFiltered(toolNames: string[]): FunctionDeclaration[] {
|
||||
const declarations: FunctionDeclaration[] = [];
|
||||
for (const name of toolNames) {
|
||||
const tool = this.tools.get(name);
|
||||
if (tool) {
|
||||
const tool = this.allKnownTools.get(name);
|
||||
if (tool && this.isActiveTool(tool)) {
|
||||
declarations.push(tool.schema);
|
||||
}
|
||||
}
|
||||
@@ -447,17 +492,18 @@ export class ToolRegistry {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an array of all registered and discovered tool names.
|
||||
* Returns an array of all registered and discovered tool names which are not
|
||||
* excluded via configuration.
|
||||
*/
|
||||
getAllToolNames(): string[] {
|
||||
return Array.from(this.tools.keys());
|
||||
return this.getActiveTools().map((tool) => tool.name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an array of all registered and discovered tool instances.
|
||||
*/
|
||||
getAllTools(): AnyDeclarativeTool[] {
|
||||
return Array.from(this.tools.values()).sort((a, b) =>
|
||||
return this.getActiveTools().sort((a, b) =>
|
||||
a.displayName.localeCompare(b.displayName),
|
||||
);
|
||||
}
|
||||
@@ -467,7 +513,7 @@ export class ToolRegistry {
|
||||
*/
|
||||
getToolsByServer(serverName: string): AnyDeclarativeTool[] {
|
||||
const serverTools: AnyDeclarativeTool[] = [];
|
||||
for (const tool of this.tools.values()) {
|
||||
for (const tool of this.getActiveTools()) {
|
||||
if ((tool as DiscoveredMCPTool)?.serverName === serverName) {
|
||||
serverTools.push(tool);
|
||||
}
|
||||
@@ -479,6 +525,10 @@ export class ToolRegistry {
|
||||
* Get the definition of a specific tool.
|
||||
*/
|
||||
getTool(name: string): AnyDeclarativeTool | undefined {
|
||||
return this.tools.get(name);
|
||||
const tool = this.allKnownTools.get(name);
|
||||
if (tool && this.isActiveTool(tool)) {
|
||||
return tool;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,10 +4,19 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
expect,
|
||||
it,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type MockInstance,
|
||||
} from 'vitest';
|
||||
import { SimpleExtensionLoader } from './extensionLoader.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { Config, GeminiCLIExtension } from '../config/config.js';
|
||||
import { type McpClientManager } from '../tools/mcp-client-manager.js';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
|
||||
const mockRefreshServerHierarchicalMemory = vi.hoisted(() => vi.fn());
|
||||
|
||||
@@ -23,15 +32,20 @@ describe('SimpleExtensionLoader', () => {
|
||||
let mockConfig: Config;
|
||||
let extensionReloadingEnabled: boolean;
|
||||
let mockMcpClientManager: McpClientManager;
|
||||
const activeExtension = {
|
||||
let mockGeminiClientSetTools: MockInstance<
|
||||
typeof GeminiClient.prototype.setTools
|
||||
>;
|
||||
|
||||
const activeExtension: GeminiCLIExtension = {
|
||||
name: 'test-extension',
|
||||
isActive: true,
|
||||
version: '1.0.0',
|
||||
path: '/path/to/extension',
|
||||
contextFiles: [],
|
||||
excludeTools: ['some-tool'],
|
||||
id: '123',
|
||||
};
|
||||
const inactiveExtension = {
|
||||
const inactiveExtension: GeminiCLIExtension = {
|
||||
name: 'test-extension',
|
||||
isActive: false,
|
||||
version: '1.0.0',
|
||||
@@ -46,9 +60,14 @@ describe('SimpleExtensionLoader', () => {
|
||||
stopExtension: vi.fn(),
|
||||
} as unknown as McpClientManager;
|
||||
extensionReloadingEnabled = false;
|
||||
mockGeminiClientSetTools = vi.fn();
|
||||
mockConfig = {
|
||||
getMcpClientManager: () => mockMcpClientManager,
|
||||
getEnableExtensionReloading: () => extensionReloadingEnabled,
|
||||
getGeminiClient: vi.fn(() => ({
|
||||
isInitialized: () => true,
|
||||
setTools: mockGeminiClientSetTools,
|
||||
})),
|
||||
} as unknown as Config;
|
||||
});
|
||||
|
||||
@@ -106,11 +125,14 @@ describe('SimpleExtensionLoader', () => {
|
||||
mockMcpClientManager.startExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
||||
expect(mockGeminiClientSetTools).toHaveBeenCalledOnce();
|
||||
} else {
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
||||
expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce();
|
||||
}
|
||||
mockRefreshServerHierarchicalMemory.mockClear();
|
||||
mockGeminiClientSetTools.mockClear();
|
||||
|
||||
await loader.unloadExtension(activeExtension);
|
||||
if (reloadingEnabled) {
|
||||
@@ -118,9 +140,11 @@ describe('SimpleExtensionLoader', () => {
|
||||
mockMcpClientManager.stopExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
||||
expect(mockGeminiClientSetTools).toHaveBeenCalledOnce();
|
||||
} else {
|
||||
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
|
||||
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
||||
expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce();
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -73,6 +73,8 @@ export abstract class ExtensionLoader {
|
||||
});
|
||||
try {
|
||||
await this.config.getMcpClientManager()!.startExtension(extension);
|
||||
await this.maybeRefreshGeminiTools(extension);
|
||||
|
||||
// Note: Context files are loaded only once all extensions are done
|
||||
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
||||
// below.
|
||||
@@ -80,9 +82,6 @@ export abstract class ExtensionLoader {
|
||||
// TODO: Update custom command updating away from the event based system
|
||||
// and call directly into a custom command manager here. See the
|
||||
// useSlashCommandProcessor hook which responds to events fired here today.
|
||||
|
||||
// TODO: Move all enablement of extension features here, including at least:
|
||||
// - excluded tool configuration
|
||||
} finally {
|
||||
this.startCompletedCount++;
|
||||
this.eventEmitter?.emit('extensionsStarting', {
|
||||
@@ -115,6 +114,21 @@ export abstract class ExtensionLoader {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the gemini tools list if it is initialized and the extension has
|
||||
* any excludeTools settings.
|
||||
*/
|
||||
private async maybeRefreshGeminiTools(
|
||||
extension: GeminiCLIExtension,
|
||||
): Promise<void> {
|
||||
if (extension.excludeTools && extension.excludeTools.length > 0) {
|
||||
const geminiClient = this.config?.getGeminiClient();
|
||||
if (geminiClient?.isInitialized()) {
|
||||
await geminiClient.setTools();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* If extension reloading is enabled and `start` has already been called,
|
||||
* then calls `startExtension` to include all extension features into the
|
||||
@@ -150,6 +164,8 @@ export abstract class ExtensionLoader {
|
||||
|
||||
try {
|
||||
await this.config.getMcpClientManager()!.stopExtension(extension);
|
||||
await this.maybeRefreshGeminiTools(extension);
|
||||
|
||||
// Note: Context files are loaded only once all extensions are done
|
||||
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
||||
// below.
|
||||
@@ -157,9 +173,6 @@ export abstract class ExtensionLoader {
|
||||
// TODO: Update custom command updating away from the event based system
|
||||
// and call directly into a custom command manager here. See the
|
||||
// useSlashCommandProcessor hook which responds to events fired here today.
|
||||
|
||||
// TODO: Remove all extension features here, including at least:
|
||||
// - excluded tools
|
||||
} finally {
|
||||
this.stopCompletedCount++;
|
||||
this.eventEmitter?.emit('extensionsStopping', {
|
||||
|
||||
@@ -58,7 +58,7 @@ beforeEach(() => {
|
||||
);
|
||||
config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => [],
|
||||
getExcludeTools: () => new Set([]),
|
||||
getAllowedTools: () => [],
|
||||
} as unknown as Config;
|
||||
});
|
||||
@@ -89,7 +89,7 @@ describe('isCommandAllowed', () => {
|
||||
});
|
||||
|
||||
it('should block a command if it is in the blocked list', () => {
|
||||
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
|
||||
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
|
||||
const result = isCommandAllowed('badCommand --danger', config);
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
@@ -99,7 +99,7 @@ describe('isCommandAllowed', () => {
|
||||
|
||||
it('should prioritize the blocklist over the allowlist', () => {
|
||||
config.getCoreTools = () => ['ShellTool(badCommand --danger)'];
|
||||
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
|
||||
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
|
||||
const result = isCommandAllowed('badCommand --danger', config);
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
@@ -114,7 +114,7 @@ describe('isCommandAllowed', () => {
|
||||
});
|
||||
|
||||
it('should block any command when a wildcard is in excludeTools', () => {
|
||||
config.getExcludeTools = () => ['run_shell_command'];
|
||||
config.getExcludeTools = () => new Set(['run_shell_command']);
|
||||
const result = isCommandAllowed('any random command', config);
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
@@ -124,7 +124,7 @@ describe('isCommandAllowed', () => {
|
||||
|
||||
it('should block a command on the blocklist even with a wildcard allow', () => {
|
||||
config.getCoreTools = () => ['ShellTool'];
|
||||
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
|
||||
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
|
||||
const result = isCommandAllowed('badCommand --danger', config);
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
@@ -145,7 +145,7 @@ describe('isCommandAllowed', () => {
|
||||
});
|
||||
|
||||
it('should block a chained command if any part is blocked', () => {
|
||||
config.getExcludeTools = () => ['run_shell_command(badCommand)'];
|
||||
config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']);
|
||||
const result = isCommandAllowed(
|
||||
'echo "hello" && badCommand --danger',
|
||||
config,
|
||||
@@ -159,7 +159,7 @@ describe('isCommandAllowed', () => {
|
||||
it('should block a command that redefines an allowed function to run an unlisted command', () => {
|
||||
config.getCoreTools = () => ['run_shell_command(echo)'];
|
||||
const result = isCommandAllowed(
|
||||
'echo () (curl google.com) ; echo Hello Wolrd',
|
||||
'echo () (curl google.com) ; echo Hello World',
|
||||
config,
|
||||
);
|
||||
expect(result.allowed).toBe(false);
|
||||
@@ -355,7 +355,7 @@ describe('checkCommandPermissions', () => {
|
||||
});
|
||||
|
||||
it('should return a detailed failure object for a blocked command', () => {
|
||||
config.getExcludeTools = () => ['ShellTool(badCommand)'];
|
||||
config.getExcludeTools = () => new Set(['ShellTool(badCommand)']);
|
||||
const result = checkCommandPermissions('badCommand --danger', config);
|
||||
expect(result).toEqual({
|
||||
allAllowed: false,
|
||||
@@ -424,7 +424,7 @@ describe('checkCommandPermissions', () => {
|
||||
});
|
||||
|
||||
it('should block a command on the sessionAllowlist if it is also globally blocked', () => {
|
||||
config.getExcludeTools = () => ['run_shell_command(badCommand)'];
|
||||
config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']);
|
||||
const result = checkCommandPermissions(
|
||||
'badCommand --danger',
|
||||
config,
|
||||
|
||||
@@ -605,9 +605,9 @@ export function checkCommandPermissions(
|
||||
} as AnyToolInvocation & { params: { command: string } };
|
||||
|
||||
// 1. Blocklist Check (Highest Priority)
|
||||
const excludeTools = config.getExcludeTools() || [];
|
||||
const excludeTools = config.getExcludeTools() || new Set([]);
|
||||
const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) =>
|
||||
excludeTools.includes(name),
|
||||
excludeTools.has(name),
|
||||
);
|
||||
|
||||
if (isWildcardBlocked) {
|
||||
@@ -622,7 +622,9 @@ export function checkCommandPermissions(
|
||||
for (const cmd of commandsToValidate) {
|
||||
invocation.params['command'] = cmd;
|
||||
if (
|
||||
doesToolInvocationMatch('run_shell_command', invocation, excludeTools)
|
||||
doesToolInvocationMatch('run_shell_command', invocation, [
|
||||
...excludeTools,
|
||||
])
|
||||
) {
|
||||
return {
|
||||
allAllowed: false,
|
||||
|
||||
Reference in New Issue
Block a user