[feat] Extension Reloading - respect updates to exclude tools (#12728)

This commit is contained in:
Jacob MacDonald
2025-11-07 12:18:35 -08:00
committed by GitHub
parent 2077521f84
commit c883403147
12 changed files with 230 additions and 91 deletions

View File

@@ -752,11 +752,11 @@ describe('mergeMcpServers', () => {
});
describe('mergeExcludeTools', () => {
const defaultExcludes = [
const defaultExcludes = new Set([
SHELL_TOOL_NAME,
EDIT_TOOL_NAME,
WRITE_FILE_TOOL_NAME,
];
]);
const originalIsTTY = process.stdin.isTTY;
beforeEach(() => {
@@ -799,7 +799,7 @@ describe('mergeExcludeTools', () => {
argv,
);
expect(config.getExcludeTools()).toEqual(
expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4', 'tool5']),
new Set(['tool1', 'tool2', 'tool3', 'tool4', 'tool5']),
);
expect(config.getExcludeTools()).toHaveLength(5);
});
@@ -821,7 +821,7 @@ describe('mergeExcludeTools', () => {
const argv = await parseArguments({} as Settings);
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getExcludeTools()).toEqual(
expect.arrayContaining(['tool1', 'tool2', 'tool3']),
new Set(['tool1', 'tool2', 'tool3']),
);
expect(config.getExcludeTools()).toHaveLength(3);
});
@@ -852,7 +852,7 @@ describe('mergeExcludeTools', () => {
const argv = await parseArguments({} as Settings);
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getExcludeTools()).toEqual(
expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4']),
new Set(['tool1', 'tool2', 'tool3', 'tool4']),
);
expect(config.getExcludeTools()).toHaveLength(4);
});
@@ -863,7 +863,7 @@ describe('mergeExcludeTools', () => {
process.argv = ['node', 'script.js'];
const argv = await parseArguments({} as Settings);
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getExcludeTools()).toEqual([]);
expect(config.getExcludeTools()).toEqual(new Set([]));
});
it('should return default excludes when no excludeTools are specified and it is not interactive', async () => {
@@ -881,9 +881,7 @@ describe('mergeExcludeTools', () => {
const settings: Settings = { tools: { exclude: ['tool1', 'tool2'] } };
vi.spyOn(ExtensionManager.prototype, 'getExtensions').mockReturnValue([]);
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getExcludeTools()).toEqual(
expect.arrayContaining(['tool1', 'tool2']),
);
expect(config.getExcludeTools()).toEqual(new Set(['tool1', 'tool2']));
expect(config.getExcludeTools()).toHaveLength(2);
});
@@ -903,9 +901,7 @@ describe('mergeExcludeTools', () => {
process.argv = ['node', 'script.js'];
const argv = await parseArguments({} as Settings);
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getExcludeTools()).toEqual(
expect.arrayContaining(['tool1', 'tool2']),
);
expect(config.getExcludeTools()).toEqual(new Set(['tool1', 'tool2']));
expect(config.getExcludeTools()).toHaveLength(2);
});

View File

@@ -57,6 +57,7 @@ describe('handleAtCommand', () => {
getToolRegistry,
getTargetDir: () => testRootDir,
isSandboxed: () => false,
getExcludeTools: vi.fn(),
getFileService: () => new FileDiscoveryService(testRootDir),
getFileFilteringRespectGitIgnore: () => true,
getFileFilteringRespectGeminiIgnore: () => true,

View File

@@ -872,27 +872,6 @@ describe('Server Config (config.ts)', () => {
expect(wasShellToolRegistered).toBe(true);
});
it('should not register a tool if excludeTools contains the non-minified class name', async () => {
const params: ConfigParameters = {
...baseParams,
coreTools: undefined, // all tools enabled by default
excludeTools: ['ShellTool'],
};
const config = new Config(params);
await config.initialize();
const registerToolMock = (
(await vi.importMock('../tools/tool-registry')) as {
ToolRegistry: { prototype: { registerTool: Mock } };
}
).ToolRegistry.prototype.registerTool;
const wasShellToolRegistered = (
registerToolMock as Mock
).mock.calls.some((call) => call[0] instanceof vi.mocked(ShellTool));
expect(wasShellToolRegistered).toBe(false);
});
it('should register a tool if coreTools contains an argument-specific pattern with the non-minified class name', async () => {
const params: ConfigParameters = {
...baseParams,

View File

@@ -826,7 +826,7 @@ export class Config {
*
* May change over time.
*/
getExcludeTools(): string[] | undefined {
getExcludeTools(): Set<string> | undefined {
const excludeToolsSet = new Set([...(this.excludeTools ?? [])]);
for (const extension of this.getExtensionLoader().getExtensions()) {
if (!extension.isActive) {
@@ -836,7 +836,7 @@ export class Config {
excludeToolsSet.add(tool);
}
}
return [...excludeToolsSet];
return excludeToolsSet;
}
getToolDiscoveryCommand(): string | undefined {
@@ -1282,7 +1282,6 @@ export class Config {
const className = ToolClass.name;
const toolName = ToolClass.Name || className;
const coreTools = this.getCoreTools();
const excludeTools = this.getExcludeTools() || [];
// On some platforms, the className can be minified to _ClassName.
const normalizedClassName = className.replace(/^_+/, '');
@@ -1297,14 +1296,6 @@ export class Config {
);
}
const isExcluded = excludeTools.some(
(tool) => tool === toolName || tool === normalizedClassName,
);
if (isExcluded) {
isEnabled = false;
}
if (isEnabled) {
// Pass message bus to tools when feature flag is enabled
// This first implementation is only focused on the general case of
@@ -1363,15 +1354,12 @@ export class Config {
);
if (definition) {
// We must respect the main allowed/exclude lists for agents too.
const excludeTools = this.getExcludeTools() || [];
const allowedTools = this.getAllowedTools();
const isExcluded = excludeTools.includes(definition.name);
const isAllowed =
!allowedTools || allowedTools.includes(definition.name);
if (isAllowed && !isExcluded) {
if (isAllowed) {
const messageBusEnabled = this.getEnableMessageBusIntegration();
const wrapper = new SubagentToolWrapper(
definition,

View File

@@ -243,6 +243,10 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
);
}
getFullyQualifiedPrefix(): string {
return `${this.serverName}__`;
}
asFullyQualifiedTool(): DiscoveredMCPTool {
return new DiscoveredMCPTool(
this.mcpTool,
@@ -251,7 +255,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
this.description,
this.parameterSchema,
this.trust,
`${this.serverName}__${this.serverToolName}`,
`${this.getFullyQualifiedPrefix()}${this.serverToolName}`,
this.cliConfig,
this.extensionName,
this.extensionId,

View File

@@ -82,7 +82,7 @@ describe('ShellTool', () => {
getAllowedTools: vi.fn().mockReturnValue([]),
getApprovalMode: vi.fn().mockReturnValue('strict'),
getCoreTools: vi.fn().mockReturnValue([]),
getExcludeTools: vi.fn().mockReturnValue([]),
getExcludeTools: vi.fn().mockReturnValue(new Set([])),
getDebugMode: vi.fn().mockReturnValue(false),
getTargetDir: vi.fn().mockReturnValue(tempRootDir),
getSummarizeToolOutputConfig: vi.fn().mockReturnValue(undefined),

View File

@@ -5,7 +5,7 @@
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import type { Mocked } from 'vitest';
import type { Mocked, MockInstance } from 'vitest';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import type { ConfigParameters } from '../config/config.js';
import { Config } from '../config/config.js';
@@ -109,6 +109,9 @@ describe('ToolRegistry', () => {
let config: Config;
let toolRegistry: ToolRegistry;
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
let mockConfigGetExcludedTools: MockInstance<
typeof Config.prototype.getExcludeTools
>;
beforeEach(() => {
vi.mocked(fs.existsSync).mockReturnValue(true);
@@ -132,6 +135,7 @@ describe('ToolRegistry', () => {
config,
'getToolDiscoveryCommand',
);
mockConfigGetExcludedTools = vi.spyOn(config, 'getExcludeTools');
vi.spyOn(config, 'getMcpServers');
vi.spyOn(config, 'getMcpServerCommand');
vi.spyOn(config, 'getPromptRegistry').mockReturnValue({
@@ -152,6 +156,75 @@ describe('ToolRegistry', () => {
});
});
describe('excluded tools', () => {
const simpleTool = new MockTool({
name: 'tool-a',
displayName: 'Tool a',
});
const excludedTool = new ExcludedMockTool({
name: 'excluded-tool-class',
displayName: 'Excluded Tool Class',
});
const mockCallable = {} as CallableTool;
const mcpTool = new DiscoveredMCPTool(
mockCallable,
'mcp-server',
'excluded-mcp-tool',
'description',
{},
);
const allowedTool = new MockTool({
name: 'allowed-tool',
displayName: 'Allowed Tool',
});
it.each([
{
name: 'should match simple names',
tools: [simpleTool],
excludedTools: ['tool-a'],
},
{
name: 'should match simple MCP tool names, when qualified or unqualified',
tools: [mcpTool, mcpTool.asFullyQualifiedTool()],
excludedTools: [mcpTool.name],
},
{
name: 'should match qualified MCP tool names when qualified or unqualified',
tools: [mcpTool, mcpTool.asFullyQualifiedTool()],
excludedTools: [`${mcpTool.getFullyQualifiedPrefix()}${mcpTool.name}`],
},
{
name: 'should match class names',
tools: [excludedTool],
excludedTools: ['ExcludedMockTool'],
},
])('$name', ({ tools, excludedTools }) => {
toolRegistry.registerTool(allowedTool);
for (const tool of tools) {
toolRegistry.registerTool(tool);
}
mockConfigGetExcludedTools.mockReturnValue(new Set(excludedTools));
expect(toolRegistry.getAllTools()).toEqual([allowedTool]);
expect(toolRegistry.getAllToolNames()).toEqual([allowedTool.name]);
expect(toolRegistry.getFunctionDeclarations()).toEqual(
toolRegistry.getFunctionDeclarationsFiltered([allowedTool.name]),
);
for (const tool of tools) {
expect(toolRegistry.getTool(tool.name)).toBeUndefined();
expect(
toolRegistry.getFunctionDeclarationsFiltered([tool.name]),
).toHaveLength(0);
if (tool instanceof DiscoveredMCPTool) {
expect(toolRegistry.getToolsByServer(tool.serverName)).toHaveLength(
0,
);
}
}
});
});
describe('getAllTools', () => {
it('should return all registered tools sorted alphabetically by displayName', () => {
// Register tools with displayNames in non-alphabetical order
@@ -521,3 +594,12 @@ describe('ToolRegistry', () => {
});
});
});
/**
* Used for tests that exclude by class name.
*/
class ExcludedMockTool extends MockTool {
constructor(options: ConstructorParameters<typeof MockTool>[0]) {
super(options);
}
}

View File

@@ -189,7 +189,9 @@ Signal: Signal number or \`(none)\` if no signal was received.
export class ToolRegistry {
// The tools keyed by tool name as seen by the LLM.
private tools: Map<string, AnyDeclarativeTool> = new Map();
// This includes tools which are currently not active, use `getActiveTools`
// and `isActive` to get only the active tools.
private allKnownTools: Map<string, AnyDeclarativeTool> = new Map();
private config: Config;
private messageBus?: MessageBus;
@@ -207,10 +209,14 @@ export class ToolRegistry {
/**
* Registers a tool definition.
*
* Note that excluded tools are still registered to allow for enabling them
* later in the session.
*
* @param tool - The tool object containing schema and execution logic.
*/
registerTool(tool: AnyDeclarativeTool): void {
if (this.tools.has(tool.name)) {
if (this.allKnownTools.has(tool.name)) {
if (tool instanceof DiscoveredMCPTool) {
tool = tool.asFullyQualifiedTool();
} else {
@@ -220,7 +226,7 @@ export class ToolRegistry {
);
}
}
this.tools.set(tool.name, tool);
this.allKnownTools.set(tool.name, tool);
}
/**
@@ -229,7 +235,7 @@ export class ToolRegistry {
* 2. Discovered tools.
* 3. MCP tools ordered by server name.
*
* This is a stable sort in that ties preseve existing order.
* This is a stable sort in that tries preserve existing order.
*/
sortTools(): void {
const getPriority = (tool: AnyDeclarativeTool): number => {
@@ -238,8 +244,8 @@ export class ToolRegistry {
return 0; // Built-in
};
this.tools = new Map(
Array.from(this.tools.entries()).sort((a, b) => {
this.allKnownTools = new Map(
Array.from(this.allKnownTools.entries()).sort((a, b) => {
const toolA = a[1];
const toolB = b[1];
const priorityA = getPriority(toolA);
@@ -261,9 +267,9 @@ export class ToolRegistry {
}
private removeDiscoveredTools(): void {
for (const tool of this.tools.values()) {
for (const tool of this.allKnownTools.values()) {
if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
this.tools.delete(tool.name);
this.allKnownTools.delete(tool.name);
}
}
}
@@ -273,9 +279,9 @@ export class ToolRegistry {
* @param serverName The name of the server to remove tools from.
*/
removeMcpToolsByServer(serverName: string): void {
for (const [name, tool] of this.tools.entries()) {
for (const [name, tool] of this.allKnownTools.entries()) {
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
this.tools.delete(name);
this.allKnownTools.delete(name);
}
}
}
@@ -416,6 +422,45 @@ export class ToolRegistry {
}
}
/**
* @returns All the tools that are not excluded.
*/
private getActiveTools(): AnyDeclarativeTool[] {
const excludedTools = this.config.getExcludeTools() ?? new Set([]);
const activeTools: AnyDeclarativeTool[] = [];
for (const tool of this.allKnownTools.values()) {
if (this.isActiveTool(tool, excludedTools)) {
activeTools.push(tool);
}
}
return activeTools;
}
/**
* @param tool
* @param excludeTools (optional, helps performance for repeated calls)
* @returns Whether or not the `tool` is not excluded.
*/
private isActiveTool(
tool: AnyDeclarativeTool,
excludeTools?: Set<string>,
): boolean {
excludeTools ??= this.config.getExcludeTools() ?? new Set([]);
const normalizedClassName = tool.constructor.name.replace(/^_+/, '');
const possibleNames = [tool.name, normalizedClassName];
if (tool instanceof DiscoveredMCPTool) {
// Check both the unqualified and qualified name for MCP tools.
if (tool.name.startsWith(tool.getFullyQualifiedPrefix())) {
possibleNames.push(
tool.name.substring(tool.getFullyQualifiedPrefix().length),
);
} else {
possibleNames.push(`${tool.getFullyQualifiedPrefix()}${tool.name}`);
}
}
return !possibleNames.some((name) => excludeTools.has(name));
}
/**
* Retrieves the list of tool schemas (FunctionDeclaration array).
* Extracts the declarations from the ToolListUnion structure.
@@ -424,7 +469,7 @@ export class ToolRegistry {
*/
getFunctionDeclarations(): FunctionDeclaration[] {
const declarations: FunctionDeclaration[] = [];
this.tools.forEach((tool) => {
this.getActiveTools().forEach((tool) => {
declarations.push(tool.schema);
});
return declarations;
@@ -438,8 +483,8 @@ export class ToolRegistry {
getFunctionDeclarationsFiltered(toolNames: string[]): FunctionDeclaration[] {
const declarations: FunctionDeclaration[] = [];
for (const name of toolNames) {
const tool = this.tools.get(name);
if (tool) {
const tool = this.allKnownTools.get(name);
if (tool && this.isActiveTool(tool)) {
declarations.push(tool.schema);
}
}
@@ -447,17 +492,18 @@ export class ToolRegistry {
}
/**
* Returns an array of all registered and discovered tool names.
* Returns an array of all registered and discovered tool names which are not
* excluded via configuration.
*/
getAllToolNames(): string[] {
return Array.from(this.tools.keys());
return this.getActiveTools().map((tool) => tool.name);
}
/**
* Returns an array of all registered and discovered tool instances.
*/
getAllTools(): AnyDeclarativeTool[] {
return Array.from(this.tools.values()).sort((a, b) =>
return this.getActiveTools().sort((a, b) =>
a.displayName.localeCompare(b.displayName),
);
}
@@ -467,7 +513,7 @@ export class ToolRegistry {
*/
getToolsByServer(serverName: string): AnyDeclarativeTool[] {
const serverTools: AnyDeclarativeTool[] = [];
for (const tool of this.tools.values()) {
for (const tool of this.getActiveTools()) {
if ((tool as DiscoveredMCPTool)?.serverName === serverName) {
serverTools.push(tool);
}
@@ -479,6 +525,10 @@ export class ToolRegistry {
* Get the definition of a specific tool.
*/
getTool(name: string): AnyDeclarativeTool | undefined {
return this.tools.get(name);
const tool = this.allKnownTools.get(name);
if (tool && this.isActiveTool(tool)) {
return tool;
}
return;
}
}

View File

@@ -4,10 +4,19 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest';
import {
describe,
expect,
it,
vi,
beforeEach,
afterEach,
type MockInstance,
} from 'vitest';
import { SimpleExtensionLoader } from './extensionLoader.js';
import type { Config } from '../config/config.js';
import type { Config, GeminiCLIExtension } from '../config/config.js';
import { type McpClientManager } from '../tools/mcp-client-manager.js';
import type { GeminiClient } from '../core/client.js';
const mockRefreshServerHierarchicalMemory = vi.hoisted(() => vi.fn());
@@ -23,15 +32,20 @@ describe('SimpleExtensionLoader', () => {
let mockConfig: Config;
let extensionReloadingEnabled: boolean;
let mockMcpClientManager: McpClientManager;
const activeExtension = {
let mockGeminiClientSetTools: MockInstance<
typeof GeminiClient.prototype.setTools
>;
const activeExtension: GeminiCLIExtension = {
name: 'test-extension',
isActive: true,
version: '1.0.0',
path: '/path/to/extension',
contextFiles: [],
excludeTools: ['some-tool'],
id: '123',
};
const inactiveExtension = {
const inactiveExtension: GeminiCLIExtension = {
name: 'test-extension',
isActive: false,
version: '1.0.0',
@@ -46,9 +60,14 @@ describe('SimpleExtensionLoader', () => {
stopExtension: vi.fn(),
} as unknown as McpClientManager;
extensionReloadingEnabled = false;
mockGeminiClientSetTools = vi.fn();
mockConfig = {
getMcpClientManager: () => mockMcpClientManager,
getEnableExtensionReloading: () => extensionReloadingEnabled,
getGeminiClient: vi.fn(() => ({
isInitialized: () => true,
setTools: mockGeminiClientSetTools,
})),
} as unknown as Config;
});
@@ -106,11 +125,14 @@ describe('SimpleExtensionLoader', () => {
mockMcpClientManager.startExtension,
).toHaveBeenCalledExactlyOnceWith(activeExtension);
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
expect(mockGeminiClientSetTools).toHaveBeenCalledOnce();
} else {
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce();
}
mockRefreshServerHierarchicalMemory.mockClear();
mockGeminiClientSetTools.mockClear();
await loader.unloadExtension(activeExtension);
if (reloadingEnabled) {
@@ -118,9 +140,11 @@ describe('SimpleExtensionLoader', () => {
mockMcpClientManager.stopExtension,
).toHaveBeenCalledExactlyOnceWith(activeExtension);
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
expect(mockGeminiClientSetTools).toHaveBeenCalledOnce();
} else {
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce();
}
});

View File

@@ -73,6 +73,8 @@ export abstract class ExtensionLoader {
});
try {
await this.config.getMcpClientManager()!.startExtension(extension);
await this.maybeRefreshGeminiTools(extension);
// Note: Context files are loaded only once all extensions are done
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
// below.
@@ -80,9 +82,6 @@ export abstract class ExtensionLoader {
// TODO: Update custom command updating away from the event based system
// and call directly into a custom command manager here. See the
// useSlashCommandProcessor hook which responds to events fired here today.
// TODO: Move all enablement of extension features here, including at least:
// - excluded tool configuration
} finally {
this.startCompletedCount++;
this.eventEmitter?.emit('extensionsStarting', {
@@ -115,6 +114,21 @@ export abstract class ExtensionLoader {
}
}
/**
* Refreshes the gemini tools list if it is initialized and the extension has
* any excludeTools settings.
*/
private async maybeRefreshGeminiTools(
extension: GeminiCLIExtension,
): Promise<void> {
if (extension.excludeTools && extension.excludeTools.length > 0) {
const geminiClient = this.config?.getGeminiClient();
if (geminiClient?.isInitialized()) {
await geminiClient.setTools();
}
}
}
/**
* If extension reloading is enabled and `start` has already been called,
* then calls `startExtension` to include all extension features into the
@@ -150,6 +164,8 @@ export abstract class ExtensionLoader {
try {
await this.config.getMcpClientManager()!.stopExtension(extension);
await this.maybeRefreshGeminiTools(extension);
// Note: Context files are loaded only once all extensions are done
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
// below.
@@ -157,9 +173,6 @@ export abstract class ExtensionLoader {
// TODO: Update custom command updating away from the event based system
// and call directly into a custom command manager here. See the
// useSlashCommandProcessor hook which responds to events fired here today.
// TODO: Remove all extension features here, including at least:
// - excluded tools
} finally {
this.stopCompletedCount++;
this.eventEmitter?.emit('extensionsStopping', {

View File

@@ -58,7 +58,7 @@ beforeEach(() => {
);
config = {
getCoreTools: () => [],
getExcludeTools: () => [],
getExcludeTools: () => new Set([]),
getAllowedTools: () => [],
} as unknown as Config;
});
@@ -89,7 +89,7 @@ describe('isCommandAllowed', () => {
});
it('should block a command if it is in the blocked list', () => {
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
const result = isCommandAllowed('badCommand --danger', config);
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
@@ -99,7 +99,7 @@ describe('isCommandAllowed', () => {
it('should prioritize the blocklist over the allowlist', () => {
config.getCoreTools = () => ['ShellTool(badCommand --danger)'];
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
const result = isCommandAllowed('badCommand --danger', config);
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
@@ -114,7 +114,7 @@ describe('isCommandAllowed', () => {
});
it('should block any command when a wildcard is in excludeTools', () => {
config.getExcludeTools = () => ['run_shell_command'];
config.getExcludeTools = () => new Set(['run_shell_command']);
const result = isCommandAllowed('any random command', config);
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
@@ -124,7 +124,7 @@ describe('isCommandAllowed', () => {
it('should block a command on the blocklist even with a wildcard allow', () => {
config.getCoreTools = () => ['ShellTool'];
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
const result = isCommandAllowed('badCommand --danger', config);
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
@@ -145,7 +145,7 @@ describe('isCommandAllowed', () => {
});
it('should block a chained command if any part is blocked', () => {
config.getExcludeTools = () => ['run_shell_command(badCommand)'];
config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']);
const result = isCommandAllowed(
'echo "hello" && badCommand --danger',
config,
@@ -159,7 +159,7 @@ describe('isCommandAllowed', () => {
it('should block a command that redefines an allowed function to run an unlisted command', () => {
config.getCoreTools = () => ['run_shell_command(echo)'];
const result = isCommandAllowed(
'echo () (curl google.com) ; echo Hello Wolrd',
'echo () (curl google.com) ; echo Hello World',
config,
);
expect(result.allowed).toBe(false);
@@ -355,7 +355,7 @@ describe('checkCommandPermissions', () => {
});
it('should return a detailed failure object for a blocked command', () => {
config.getExcludeTools = () => ['ShellTool(badCommand)'];
config.getExcludeTools = () => new Set(['ShellTool(badCommand)']);
const result = checkCommandPermissions('badCommand --danger', config);
expect(result).toEqual({
allAllowed: false,
@@ -424,7 +424,7 @@ describe('checkCommandPermissions', () => {
});
it('should block a command on the sessionAllowlist if it is also globally blocked', () => {
config.getExcludeTools = () => ['run_shell_command(badCommand)'];
config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']);
const result = checkCommandPermissions(
'badCommand --danger',
config,

View File

@@ -605,9 +605,9 @@ export function checkCommandPermissions(
} as AnyToolInvocation & { params: { command: string } };
// 1. Blocklist Check (Highest Priority)
const excludeTools = config.getExcludeTools() || [];
const excludeTools = config.getExcludeTools() || new Set([]);
const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) =>
excludeTools.includes(name),
excludeTools.has(name),
);
if (isWildcardBlocked) {
@@ -622,7 +622,9 @@ export function checkCommandPermissions(
for (const cmd of commandsToValidate) {
invocation.params['command'] = cmd;
if (
doesToolInvocationMatch('run_shell_command', invocation, excludeTools)
doesToolInvocationMatch('run_shell_command', invocation, [
...excludeTools,
])
) {
return {
allAllowed: false,