mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-24 22:55:13 +00:00
feat: Persistent "Always Allow" policies with granular shell & MCP support (#14737)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -114,6 +114,11 @@ export const ToolConfirmationMessage: React.FC<
|
||||
value: ToolConfirmationOutcome.ProceedAlways,
|
||||
key: 'Yes, allow always',
|
||||
});
|
||||
options.push({
|
||||
label: 'Yes, allow always and save to policy',
|
||||
value: ToolConfirmationOutcome.ProceedAlwaysAndSave,
|
||||
key: 'Yes, allow always and save to policy',
|
||||
});
|
||||
}
|
||||
if (!config.getIdeMode() || !isDiffingEnabled) {
|
||||
options.push({
|
||||
@@ -145,6 +150,11 @@ export const ToolConfirmationMessage: React.FC<
|
||||
value: ToolConfirmationOutcome.ProceedAlways,
|
||||
key: `Yes, allow always ...`,
|
||||
});
|
||||
options.push({
|
||||
label: `Yes, allow always and save to policy`,
|
||||
value: ToolConfirmationOutcome.ProceedAlwaysAndSave,
|
||||
key: `Yes, allow always and save to policy`,
|
||||
});
|
||||
}
|
||||
options.push({
|
||||
label: 'No, suggest changes (esc)',
|
||||
@@ -164,6 +174,11 @@ export const ToolConfirmationMessage: React.FC<
|
||||
value: ToolConfirmationOutcome.ProceedAlways,
|
||||
key: 'Yes, allow always',
|
||||
});
|
||||
options.push({
|
||||
label: 'Yes, allow always and save to policy',
|
||||
value: ToolConfirmationOutcome.ProceedAlwaysAndSave,
|
||||
key: 'Yes, allow always and save to policy',
|
||||
});
|
||||
}
|
||||
options.push({
|
||||
label: 'No, suggest changes (esc)',
|
||||
@@ -190,6 +205,11 @@ export const ToolConfirmationMessage: React.FC<
|
||||
value: ToolConfirmationOutcome.ProceedAlwaysServer,
|
||||
key: `Yes, always allow all tools from server "${mcpProps.serverName}"`,
|
||||
});
|
||||
options.push({
|
||||
label: `Yes, allow always tool "${mcpProps.toolName}" and save to policy`,
|
||||
value: ToolConfirmationOutcome.ProceedAlwaysAndSave,
|
||||
key: `Yes, allow always tool "${mcpProps.toolName}" and save to policy`,
|
||||
});
|
||||
}
|
||||
options.push({
|
||||
label: 'No, suggest changes (esc)',
|
||||
|
||||
@@ -10,7 +10,8 @@ Do you want to proceed?
|
||||
|
||||
● 1. Yes, allow once
|
||||
2. Yes, allow always
|
||||
3. No, suggest changes (esc)
|
||||
3. Yes, allow always and save to policy
|
||||
4. No, suggest changes (esc)
|
||||
"
|
||||
`;
|
||||
|
||||
@@ -21,7 +22,8 @@ Do you want to proceed?
|
||||
|
||||
● 1. Yes, allow once
|
||||
2. Yes, allow always
|
||||
3. No, suggest changes (esc)
|
||||
3. Yes, allow always and save to policy
|
||||
4. No, suggest changes (esc)
|
||||
"
|
||||
`;
|
||||
|
||||
@@ -51,8 +53,9 @@ Apply this change?
|
||||
|
||||
● 1. Yes, allow once
|
||||
2. Yes, allow always
|
||||
3. Modify with external editor
|
||||
4. No, suggest changes (esc)
|
||||
3. Yes, allow always and save to policy
|
||||
4. Modify with external editor
|
||||
5. No, suggest changes (esc)
|
||||
"
|
||||
`;
|
||||
|
||||
@@ -73,7 +76,8 @@ Allow execution of: 'echo'?
|
||||
|
||||
● 1. Yes, allow once
|
||||
2. Yes, allow always ...
|
||||
3. No, suggest changes (esc)
|
||||
3. Yes, allow always and save to policy
|
||||
4. No, suggest changes (esc)
|
||||
"
|
||||
`;
|
||||
|
||||
@@ -94,7 +98,8 @@ Do you want to proceed?
|
||||
|
||||
● 1. Yes, allow once
|
||||
2. Yes, allow always
|
||||
3. No, suggest changes (esc)
|
||||
3. Yes, allow always and save to policy
|
||||
4. No, suggest changes (esc)
|
||||
"
|
||||
`;
|
||||
|
||||
@@ -118,6 +123,7 @@ Allow execution of MCP tool "test-tool" from server "test-server"?
|
||||
● 1. Yes, allow once
|
||||
2. Yes, always allow tool "test-tool" from server "test-server"
|
||||
3. Yes, always allow all tools from server "test-server"
|
||||
4. No, suggest changes (esc)
|
||||
4. Yes, allow always tool "test-tool" and save to policy
|
||||
5. No, suggest changes (esc)
|
||||
"
|
||||
`;
|
||||
|
||||
@@ -39,7 +39,8 @@ exports[`<ToolGroupMessage /> > Confirmation Handling > shows confirmation dialo
|
||||
│ │
|
||||
│ ● 1. Yes, allow once │
|
||||
│ 2. Yes, allow always │
|
||||
│ 3. No, suggest changes (esc) │
|
||||
│ 3. Yes, allow always and save to policy │
|
||||
│ 4. No, suggest changes (esc) │
|
||||
│ │
|
||||
│ │
|
||||
│ ? second-confirm A tool for testing │
|
||||
@@ -122,7 +123,8 @@ exports[`<ToolGroupMessage /> > Golden Snapshots > renders tool call awaiting co
|
||||
│ │
|
||||
│ ● 1. Yes, allow once │
|
||||
│ 2. Yes, allow always │
|
||||
│ 3. No, suggest changes (esc) │
|
||||
│ 3. Yes, allow always and save to policy │
|
||||
│ 4. No, suggest changes (esc) │
|
||||
│ │
|
||||
╰──────────────────────────────────────────────────────────────────────────────╯"
|
||||
`;
|
||||
|
||||
@@ -449,6 +449,7 @@ export class Session {
|
||||
);
|
||||
case ToolConfirmationOutcome.ProceedOnce:
|
||||
case ToolConfirmationOutcome.ProceedAlways:
|
||||
case ToolConfirmationOutcome.ProceedAlwaysAndSave:
|
||||
case ToolConfirmationOutcome.ProceedAlwaysServer:
|
||||
case ToolConfirmationOutcome.ProceedAlwaysTool:
|
||||
case ToolConfirmationOutcome.ModifyWithEditor:
|
||||
|
||||
@@ -39,6 +39,10 @@ export interface ToolConfirmationResponse {
|
||||
export interface UpdatePolicy {
|
||||
type: MessageBusType.UPDATE_POLICY;
|
||||
toolName: string;
|
||||
persist?: boolean;
|
||||
argsPattern?: string;
|
||||
commandPrefix?: string;
|
||||
mcpName?: string;
|
||||
}
|
||||
|
||||
export interface ToolPolicyRejection {
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
import { describe, it, expect, vi, afterEach, beforeEach } from 'vitest';
|
||||
import { detectIde, IDE_DEFINITIONS } from './detect-ide.js';
|
||||
|
||||
beforeEach(() => {
|
||||
// Ensure Antigravity detection doesn't interfere with other tests
|
||||
vi.stubEnv('ANTIGRAVITY_CLI_ALIAS', '');
|
||||
});
|
||||
|
||||
describe('detectIde', () => {
|
||||
const ideProcessInfo = { pid: 123, command: 'some/path/to/code' };
|
||||
const ideProcessInfoNoCode = { pid: 123, command: 'some/path/to/fork' };
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
import { Storage } from '../config/storage.js';
|
||||
@@ -15,7 +16,12 @@ import {
|
||||
type PolicySettings,
|
||||
} from './types.js';
|
||||
import type { PolicyEngine } from './policy-engine.js';
|
||||
import { loadPoliciesFromToml, type PolicyFileError } from './toml-loader.js';
|
||||
import {
|
||||
loadPoliciesFromToml,
|
||||
type PolicyFileError,
|
||||
escapeRegex,
|
||||
} from './toml-loader.js';
|
||||
import toml from '@iarna/toml';
|
||||
import {
|
||||
MessageBusType,
|
||||
type UpdatePolicy,
|
||||
@@ -233,14 +239,35 @@ export async function createPolicyEngineConfig(
|
||||
};
|
||||
}
|
||||
|
||||
interface TomlRule {
|
||||
toolName?: string;
|
||||
mcpName?: string;
|
||||
decision?: string;
|
||||
priority?: number;
|
||||
commandPrefix?: string;
|
||||
argsPattern?: string;
|
||||
// Index signature to satisfy Record type if needed for toml.stringify
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export function createPolicyUpdater(
|
||||
policyEngine: PolicyEngine,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
messageBus.subscribe(
|
||||
MessageBusType.UPDATE_POLICY,
|
||||
(message: UpdatePolicy) => {
|
||||
async (message: UpdatePolicy) => {
|
||||
const toolName = message.toolName;
|
||||
let argsPattern = message.argsPattern
|
||||
? new RegExp(message.argsPattern)
|
||||
: undefined;
|
||||
|
||||
if (message.commandPrefix) {
|
||||
// Convert commandPrefix to argsPattern for in-memory rule
|
||||
// This mimics what toml-loader does
|
||||
const escapedPrefix = escapeRegex(message.commandPrefix);
|
||||
argsPattern = new RegExp(`"command":"${escapedPrefix}`);
|
||||
}
|
||||
|
||||
policyEngine.addRule({
|
||||
toolName,
|
||||
@@ -249,7 +276,77 @@ export function createPolicyUpdater(
|
||||
// This ensures user "always allow" selections are high priority
|
||||
// but still lose to admin policies (3.xxx) and settings excludes (200)
|
||||
priority: 2.95,
|
||||
argsPattern,
|
||||
});
|
||||
|
||||
if (message.persist) {
|
||||
try {
|
||||
const userPoliciesDir = Storage.getUserPoliciesDir();
|
||||
await fs.mkdir(userPoliciesDir, { recursive: true });
|
||||
const policyFile = path.join(userPoliciesDir, 'auto-saved.toml');
|
||||
|
||||
// Read existing file
|
||||
let existingData: { rule?: TomlRule[] } = {};
|
||||
try {
|
||||
const fileContent = await fs.readFile(policyFile, 'utf-8');
|
||||
existingData = toml.parse(fileContent) as { rule?: TomlRule[] };
|
||||
} catch (error) {
|
||||
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') {
|
||||
console.warn(
|
||||
`Failed to parse ${policyFile}, overwriting with new policy.`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize rule array if needed
|
||||
if (!existingData.rule) {
|
||||
existingData.rule = [];
|
||||
}
|
||||
|
||||
// Create new rule object
|
||||
const newRule: TomlRule = {};
|
||||
|
||||
if (message.mcpName) {
|
||||
newRule.mcpName = message.mcpName;
|
||||
// Extract simple tool name
|
||||
const simpleToolName = toolName.startsWith(`${message.mcpName}__`)
|
||||
? toolName.slice(message.mcpName.length + 2)
|
||||
: toolName;
|
||||
newRule.toolName = simpleToolName;
|
||||
newRule.decision = 'allow';
|
||||
newRule.priority = 200;
|
||||
} else {
|
||||
newRule.toolName = toolName;
|
||||
newRule.decision = 'allow';
|
||||
newRule.priority = 100;
|
||||
}
|
||||
|
||||
if (message.commandPrefix) {
|
||||
newRule.commandPrefix = message.commandPrefix;
|
||||
} else if (message.argsPattern) {
|
||||
newRule.argsPattern = message.argsPattern;
|
||||
}
|
||||
|
||||
// Add to rules
|
||||
existingData.rule.push(newRule);
|
||||
|
||||
// Serialize back to TOML
|
||||
// @iarna/toml stringify might not produce beautiful output but it handles escaping correctly
|
||||
const newContent = toml.stringify(existingData as toml.JsonMap);
|
||||
|
||||
// Atomic write: write to tmp then rename
|
||||
const tmpFile = `${policyFile}.tmp`;
|
||||
await fs.writeFile(tmpFile, newContent, 'utf-8');
|
||||
await fs.rename(tmpFile, policyFile);
|
||||
} catch (error) {
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Failed to persist policy for ${toolName}`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
209
packages/core/src/policy/persistence.test.ts
Normal file
209
packages/core/src/policy/persistence.test.ts
Normal file
@@ -0,0 +1,209 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import { createPolicyUpdater } from './config.js';
|
||||
import { PolicyEngine } from './policy-engine.js';
|
||||
import { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { MessageBusType } from '../confirmation-bus/types.js';
|
||||
import { Storage } from '../config/storage.js';
|
||||
|
||||
vi.mock('node:fs/promises');
|
||||
vi.mock('../config/storage.js');
|
||||
|
||||
describe('createPolicyUpdater', () => {
|
||||
let policyEngine: PolicyEngine;
|
||||
let messageBus: MessageBus;
|
||||
|
||||
beforeEach(() => {
|
||||
policyEngine = new PolicyEngine({ rules: [], checkers: [] });
|
||||
messageBus = new MessageBus(policyEngine);
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should persist policy when persist flag is true', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus);
|
||||
|
||||
const userPoliciesDir = '/mock/user/policies';
|
||||
vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir);
|
||||
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.readFile as unknown as Mock).mockRejectedValue(
|
||||
new Error('File not found'),
|
||||
); // Simulate new file
|
||||
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
|
||||
|
||||
const toolName = 'test_tool';
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName,
|
||||
persist: true,
|
||||
});
|
||||
|
||||
// Wait for async operations (microtasks)
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
|
||||
expect(Storage.getUserPoliciesDir).toHaveBeenCalled();
|
||||
expect(fs.mkdir).toHaveBeenCalledWith(userPoliciesDir, {
|
||||
recursive: true,
|
||||
});
|
||||
|
||||
// Check written content
|
||||
const expectedContent = expect.stringContaining(`toolName = "test_tool"`);
|
||||
expect(fs.writeFile).toHaveBeenCalledWith(
|
||||
expect.stringMatching(/\.tmp$/),
|
||||
expectedContent,
|
||||
'utf-8',
|
||||
);
|
||||
expect(fs.rename).toHaveBeenCalledWith(
|
||||
expect.stringMatching(/\.tmp$/),
|
||||
path.join(userPoliciesDir, 'auto-saved.toml'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not persist policy when persist flag is false or undefined', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus);
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName: 'test_tool',
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
|
||||
expect(fs.writeFile).not.toHaveBeenCalled();
|
||||
expect(fs.rename).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should persist policy with commandPrefix when provided', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus);
|
||||
|
||||
const userPoliciesDir = '/mock/user/policies';
|
||||
vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir);
|
||||
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.readFile as unknown as Mock).mockRejectedValue(
|
||||
new Error('File not found'),
|
||||
);
|
||||
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
|
||||
|
||||
const toolName = 'run_shell_command';
|
||||
const commandPrefix = 'git status';
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName,
|
||||
persist: true,
|
||||
commandPrefix,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
|
||||
// In-memory rule check (unchanged)
|
||||
const rules = policyEngine.getRules();
|
||||
const addedRule = rules.find((r) => r.toolName === toolName);
|
||||
expect(addedRule).toBeDefined();
|
||||
expect(addedRule?.priority).toBe(2.95);
|
||||
expect(addedRule?.argsPattern).toEqual(new RegExp(`"command":"git status`));
|
||||
|
||||
// Verify file written
|
||||
expect(fs.writeFile).toHaveBeenCalledWith(
|
||||
expect.stringMatching(/\.tmp$/),
|
||||
expect.stringContaining(`commandPrefix = "git status"`),
|
||||
'utf-8',
|
||||
);
|
||||
});
|
||||
|
||||
it('should persist policy with mcpName and toolName when provided', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus);
|
||||
|
||||
const userPoliciesDir = '/mock/user/policies';
|
||||
vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir);
|
||||
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.readFile as unknown as Mock).mockRejectedValue(
|
||||
new Error('File not found'),
|
||||
);
|
||||
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
|
||||
|
||||
const mcpName = 'my-jira-server';
|
||||
const simpleToolName = 'search';
|
||||
const toolName = `${mcpName}__${simpleToolName}`;
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName,
|
||||
persist: true,
|
||||
mcpName,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
|
||||
// Verify file written
|
||||
const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0];
|
||||
const writtenContent = writeCall[1] as string;
|
||||
expect(writtenContent).toContain(`mcpName = "${mcpName}"`);
|
||||
expect(writtenContent).toContain(`toolName = "${simpleToolName}"`);
|
||||
expect(writtenContent).toContain('priority = 200');
|
||||
});
|
||||
|
||||
it('should escape special characters in toolName and mcpName', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus);
|
||||
|
||||
const userPoliciesDir = '/mock/user/policies';
|
||||
vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir);
|
||||
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.readFile as unknown as Mock).mockRejectedValue(
|
||||
new Error('File not found'),
|
||||
);
|
||||
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
|
||||
|
||||
const mcpName = 'my"jira"server';
|
||||
const toolName = `my"jira"server__search"tool"`;
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName,
|
||||
persist: true,
|
||||
mcpName,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
|
||||
const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0];
|
||||
const writtenContent = writeCall[1] as string;
|
||||
|
||||
// Verify escaping - should be valid TOML
|
||||
// Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar'
|
||||
// instead of "foo\"bar\"" if there are no single quotes in the string.
|
||||
try {
|
||||
expect(writtenContent).toContain(`mcpName = "my\\"jira\\"server"`);
|
||||
} catch {
|
||||
expect(writtenContent).toContain(`mcpName = 'my"jira"server'`);
|
||||
}
|
||||
|
||||
try {
|
||||
expect(writtenContent).toContain(`toolName = "search\\"tool\\""`);
|
||||
} catch {
|
||||
expect(writtenContent).toContain(`toolName = 'search"tool"'`);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -126,7 +126,7 @@ export interface PolicyLoadResult {
|
||||
* @param str The string to escape
|
||||
* @returns The escaped string safe for use in a regex
|
||||
*/
|
||||
function escapeRegex(str: string): string {
|
||||
export function escapeRegex(str: string): string {
|
||||
return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
}
|
||||
|
||||
|
||||
@@ -100,6 +100,11 @@ vi.mock('../../utils/installationManager.js');
|
||||
const mockUserAccount = vi.mocked(UserAccountManager.prototype);
|
||||
const mockInstallMgr = vi.mocked(InstallationManager.prototype);
|
||||
|
||||
beforeEach(() => {
|
||||
// Ensure Antigravity detection doesn't interfere with other tests
|
||||
vi.stubEnv('ANTIGRAVITY_CLI_ALIAS', '');
|
||||
});
|
||||
|
||||
// TODO(richieforeman): Consider moving this to test setup globally.
|
||||
beforeAll(() => {
|
||||
server.listen({});
|
||||
|
||||
@@ -313,6 +313,7 @@ class EditToolInvocation
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
}
|
||||
await this.publishPolicyUpdate(outcome);
|
||||
|
||||
if (ideConfirmation) {
|
||||
const result = await ideConfirmation;
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
ToolConfirmationOutcome,
|
||||
type PolicyUpdateOptions,
|
||||
} from './tools.js';
|
||||
import type { CallableTool, FunctionCall, Part } from '@google/genai';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
@@ -87,6 +88,12 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
);
|
||||
}
|
||||
|
||||
protected override getPolicyUpdateOptions(
|
||||
_outcome: ToolConfirmationOutcome,
|
||||
): PolicyUpdateOptions | undefined {
|
||||
return { mcpName: this.serverName };
|
||||
}
|
||||
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
@@ -115,6 +122,9 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey);
|
||||
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
|
||||
DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey);
|
||||
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) {
|
||||
DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey);
|
||||
await this.publishPolicyUpdate(outcome);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
@@ -226,6 +226,7 @@ class MemoryToolInvocation extends BaseToolInvocation<
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
MemoryToolInvocation.allowlist.add(allowlistKey);
|
||||
}
|
||||
await this.publishPolicyUpdate(outcome);
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
BaseToolInvocation,
|
||||
ToolConfirmationOutcome,
|
||||
Kind,
|
||||
type PolicyUpdateOptions,
|
||||
} from './tools.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
|
||||
@@ -83,6 +84,15 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
return description;
|
||||
}
|
||||
|
||||
protected override getPolicyUpdateOptions(
|
||||
outcome: ToolConfirmationOutcome,
|
||||
): PolicyUpdateOptions | undefined {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) {
|
||||
return { commandPrefix: this.params.command };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
@@ -124,6 +134,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
commandsToConfirm.forEach((command) => this.allowlist.add(command));
|
||||
}
|
||||
await this.publishPolicyUpdate(outcome);
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
|
||||
@@ -683,6 +683,7 @@ class EditToolInvocation
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
}
|
||||
await this.publishPolicyUpdate(outcome);
|
||||
|
||||
if (ideConfirmation) {
|
||||
const result = await ideConfirmation;
|
||||
|
||||
@@ -65,6 +65,14 @@ export interface ToolInvocation<
|
||||
): Promise<TResult>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Options for policy updates that can be customized by tool invocations.
|
||||
*/
|
||||
export interface PolicyUpdateOptions {
|
||||
commandPrefix?: string;
|
||||
mcpName?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* A convenience base class for ToolInvocation.
|
||||
*/
|
||||
@@ -112,6 +120,40 @@ export abstract class BaseToolInvocation<
|
||||
return this.getConfirmationDetails(abortSignal);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns tool-specific options for policy updates.
|
||||
* Subclasses can override this to provide additional options like
|
||||
* commandPrefix (for shell) or mcpName (for MCP tools).
|
||||
*/
|
||||
protected getPolicyUpdateOptions(
|
||||
_outcome: ToolConfirmationOutcome,
|
||||
): PolicyUpdateOptions | undefined {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method to publish a policy update when user selects
|
||||
* ProceedAlways or ProceedAlwaysAndSave.
|
||||
*/
|
||||
protected async publishPolicyUpdate(
|
||||
outcome: ToolConfirmationOutcome,
|
||||
): Promise<void> {
|
||||
if (
|
||||
outcome === ToolConfirmationOutcome.ProceedAlways ||
|
||||
outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave
|
||||
) {
|
||||
if (this.messageBus && this._toolName) {
|
||||
const options = this.getPolicyUpdateOptions(outcome);
|
||||
await this.messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName: this._toolName,
|
||||
persist: outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave,
|
||||
...options,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Subclasses should override this method to provide custom confirmation UI
|
||||
* when the policy engine's decision is 'ASK_USER'.
|
||||
@@ -129,15 +171,7 @@ export abstract class BaseToolInvocation<
|
||||
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
|
||||
prompt: this.getDescription(),
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
if (this.messageBus && this._toolName) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
this.messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName: this._toolName,
|
||||
});
|
||||
}
|
||||
}
|
||||
await this.publishPolicyUpdate(outcome);
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
@@ -686,6 +720,7 @@ export type ToolCallConfirmationDetails =
|
||||
export enum ToolConfirmationOutcome {
|
||||
ProceedOnce = 'proceed_once',
|
||||
ProceedAlways = 'proceed_always',
|
||||
ProceedAlwaysAndSave = 'proceed_always_and_save',
|
||||
ProceedAlwaysServer = 'proceed_always_server',
|
||||
ProceedAlwaysTool = 'proceed_always_tool',
|
||||
ModifyWithEditor = 'modify_with_editor',
|
||||
|
||||
@@ -244,6 +244,7 @@ ${textContent}
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
}
|
||||
await this.publishPolicyUpdate(outcome);
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
|
||||
@@ -224,6 +224,7 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
}
|
||||
await this.publishPolicyUpdate(outcome);
|
||||
|
||||
if (ideConfirmation) {
|
||||
const result = await ideConfirmation;
|
||||
|
||||
Reference in New Issue
Block a user