mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-02-01 22:48:03 +00:00
feat(hooks): implement STOP_EXECUTION and enhance hook decision handling (#15685)
This commit is contained in:
@@ -1,2 +1 @@
|
||||
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"functionCall":{"name":"write_file","args":{"file_path":"test.txt","content":"hello"}}}],"role":"model"},"finishReason":"STOP","index":0}]}]}
|
||||
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Okay, stopping."}],"role":"model"},"finishReason":"STOP","index":0}]}]}
|
||||
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"**Initializing File Creation**\n\nI'm starting to think about how to make a new file called `test.txt`. My plan is to use a `write_file` tool. I'll need to specify the location and what the file should contain. For now, it will be empty.\n\n\n","thought":true}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":13216,"totalTokenCount":13269,"promptTokensDetails":[{"modality":"TEXT","tokenCount":13216}],"thoughtsTokenCount":53}},{"candidates":[{"content":{"parts":[{"functionCall":{"name":"write_file","args":{"file_path":"test.txt","content":""}},"thoughtSignature":"CiQBcsjafJ20Qbx0YvING6aZ0wYoGWJh3eqornOG4E4AfBLiVsQKXwFyyNp8UlwYs/pv9IRQQGhDlrmlOJF2hfQijryyUYLI+qjDYTpZ6KKIfZF4+vS0soL2BJ3eTXA6gaadFEfNQem3WQVeQoKLFoW4Hv4mbasXqQc0K3p15DuSAtZZENTbCnsBcsjafGK+BJyF/Npnd7gyU0TL5PXePT0nuDFjhJDxlSRUJHDP315TewD3PUYsXd10oWsfhy4B5AngyUiBPUoajdsxg8WxaxnOZYqcp8EIuwtGZrCTev6IihT5nE5jj7u0P9vtnCmkAc6p+4O7Q7Jku1uVGqeJChgzI4YKSAFyyNp8EXSdbttV4xzX+NLKkc276L8Y63tnKU6/Y7fc9/58tU29DSdrgwfe9qmvwtTsO0piFXSLazqHJt8h2bgR7A7GnKDiIA=="}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":13216,"candidatesTokenCount":21,"totalTokenCount":13290,"promptTokensDetails":[{"modality":"TEXT","tokenCount":13216}],"thoughtsTokenCount":53}},{"candidates":[{"content":{"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":13216,"candidatesTokenCount":21,"totalTokenCount":13290,"promptTokensDetails":[{"modality":"TEXT","tokenCount":13216}],"thoughtsTokenCount":53}}]}
|
||||
|
||||
@@ -1608,7 +1608,7 @@ console.log(JSON.stringify({decision: "block", systemMessage: "Disabled hook sho
|
||||
});
|
||||
|
||||
const result = await rig.run({
|
||||
args: 'Run tool',
|
||||
args: 'Use write_file to create test.txt',
|
||||
});
|
||||
|
||||
// The hook should have stopped execution message (returned from tool)
|
||||
|
||||
@@ -1745,4 +1745,56 @@ describe('runNonInteractive', () => {
|
||||
);
|
||||
expect(getWrittenOutput()).toContain('Done');
|
||||
});
|
||||
|
||||
it('should stop agent execution immediately when a tool call returns STOP_EXECUTION error', async () => {
|
||||
const toolCallEvent: ServerGeminiStreamEvent = {
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: {
|
||||
callId: 'stop-call',
|
||||
name: 'stopTool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-stop',
|
||||
},
|
||||
};
|
||||
|
||||
// Mock tool execution returning STOP_EXECUTION
|
||||
mockCoreExecuteToolCall.mockResolvedValue({
|
||||
status: 'error',
|
||||
request: toolCallEvent.value,
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'stop-call',
|
||||
responseParts: [{ text: 'error occurred' }],
|
||||
errorType: ToolErrorType.STOP_EXECUTION,
|
||||
error: new Error('Stop reason from hook'),
|
||||
resultDisplay: undefined,
|
||||
},
|
||||
});
|
||||
|
||||
const firstCallEvents: ServerGeminiStreamEvent[] = [
|
||||
{ type: GeminiEventType.Content, value: 'Executing tool...' },
|
||||
toolCallEvent,
|
||||
];
|
||||
|
||||
// Setup the mock to return events for the first call.
|
||||
// We expect the loop to terminate after the tool execution.
|
||||
// If it doesn't, it might call sendMessageStream again, which we'll assert against.
|
||||
mockGeminiClient.sendMessageStream
|
||||
.mockReturnValueOnce(createStreamFromEvents(firstCallEvents))
|
||||
.mockReturnValueOnce(createStreamFromEvents([]));
|
||||
|
||||
await runNonInteractive({
|
||||
config: mockConfig,
|
||||
settings: mockSettings,
|
||||
input: 'Run stop tool',
|
||||
prompt_id: 'prompt-id-stop',
|
||||
});
|
||||
|
||||
expect(mockCoreExecuteToolCall).toHaveBeenCalled();
|
||||
|
||||
// The key assertion: sendMessageStream should have been called ONLY ONCE (initial user input).
|
||||
expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
CoreEvent,
|
||||
createWorkingStdio,
|
||||
recordToolCallInteractions,
|
||||
ToolErrorType,
|
||||
} from '@google/gemini-cli-core';
|
||||
|
||||
import type { Content, Part } from '@google/genai';
|
||||
@@ -416,6 +417,43 @@ export async function runNonInteractive({
|
||||
);
|
||||
}
|
||||
|
||||
// Check if any tool requested to stop execution immediately
|
||||
const stopExecutionTool = completedToolCalls.find(
|
||||
(tc) => tc.response.errorType === ToolErrorType.STOP_EXECUTION,
|
||||
);
|
||||
|
||||
if (stopExecutionTool && stopExecutionTool.response.error) {
|
||||
const stopMessage = `Agent execution stopped: ${stopExecutionTool.response.error.message}`;
|
||||
|
||||
if (config.getOutputFormat() === OutputFormat.TEXT) {
|
||||
process.stderr.write(`${stopMessage}\n`);
|
||||
}
|
||||
|
||||
// Emit final result event for streaming JSON
|
||||
if (streamFormatter) {
|
||||
const metrics = uiTelemetryService.getMetrics();
|
||||
const durationMs = Date.now() - startTime;
|
||||
streamFormatter.emitEvent({
|
||||
type: JsonStreamEventType.RESULT,
|
||||
timestamp: new Date().toISOString(),
|
||||
status: 'success',
|
||||
stats: streamFormatter.convertToStreamStats(
|
||||
metrics,
|
||||
durationMs,
|
||||
),
|
||||
});
|
||||
} else if (config.getOutputFormat() === OutputFormat.JSON) {
|
||||
const formatter = new JsonFormatter();
|
||||
const stats = uiTelemetryService.getMetrics();
|
||||
textOutput.write(
|
||||
formatter.format(config.getSessionId(), responseText, stats),
|
||||
);
|
||||
} else {
|
||||
textOutput.ensureTrailingNewline(); // Ensure a final newline
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
currentMessages = [{ role: 'user', parts: toolResponseParts }];
|
||||
} else {
|
||||
// Emit final result event for streaming JSON
|
||||
|
||||
@@ -694,6 +694,99 @@ describe('useGeminiStream', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should stop agent execution immediately when a tool call returns STOP_EXECUTION error', async () => {
|
||||
const stopExecutionToolCalls: TrackedToolCall[] = [
|
||||
{
|
||||
request: {
|
||||
callId: 'stop-call',
|
||||
name: 'stopTool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-stop',
|
||||
},
|
||||
status: 'error',
|
||||
response: {
|
||||
callId: 'stop-call',
|
||||
responseParts: [{ text: 'error occurred' }],
|
||||
errorType: ToolErrorType.STOP_EXECUTION,
|
||||
error: new Error('Stop reason from hook'),
|
||||
resultDisplay: undefined,
|
||||
},
|
||||
responseSubmittedToGemini: false,
|
||||
tool: {
|
||||
displayName: 'stop tool',
|
||||
},
|
||||
invocation: {
|
||||
getDescription: () => `Mock description`,
|
||||
} as unknown as AnyToolInvocation,
|
||||
} as unknown as TrackedCompletedToolCall,
|
||||
];
|
||||
const client = new MockedGeminiClientClass(mockConfig);
|
||||
|
||||
// Capture the onComplete callback
|
||||
let capturedOnComplete:
|
||||
| ((completedTools: TrackedToolCall[]) => Promise<void>)
|
||||
| null = null;
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [
|
||||
[],
|
||||
mockScheduleToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
vi.fn(),
|
||||
mockCancelAllToolCalls,
|
||||
];
|
||||
});
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
client,
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockLoadedSettings,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
80,
|
||||
24,
|
||||
),
|
||||
);
|
||||
|
||||
// Trigger the onComplete callback with STOP_EXECUTION tool
|
||||
await act(async () => {
|
||||
if (capturedOnComplete) {
|
||||
await (capturedOnComplete as any)(stopExecutionToolCalls);
|
||||
}
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['stop-call']);
|
||||
// Should add an info message to history
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: MessageType.INFO,
|
||||
text: expect.stringContaining(
|
||||
'Agent execution stopped: Stop reason from hook',
|
||||
),
|
||||
}),
|
||||
expect.any(Number),
|
||||
);
|
||||
// Ensure we do NOT call back to the API
|
||||
expect(mockSendMessageStream).not.toHaveBeenCalled();
|
||||
// Streaming state should be Idle
|
||||
expect(result.current.streamingState).toBe(StreamingState.Idle);
|
||||
});
|
||||
});
|
||||
|
||||
it('should group multiple cancelled tool call responses into a single history entry', async () => {
|
||||
const cancelledToolCall1: TrackedCancelledToolCall = {
|
||||
request: {
|
||||
|
||||
@@ -39,6 +39,7 @@ import {
|
||||
EDIT_TOOL_NAMES,
|
||||
processRestorableToolCalls,
|
||||
recordToolCallInteractions,
|
||||
ToolErrorType,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { type Part, type PartListUnion, FinishReason } from '@google/genai';
|
||||
import type {
|
||||
@@ -1153,6 +1154,28 @@ export const useGeminiStream = (
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if any tool requested to stop execution immediately
|
||||
const stopExecutionTool = geminiTools.find(
|
||||
(tc) => tc.response.errorType === ToolErrorType.STOP_EXECUTION,
|
||||
);
|
||||
|
||||
if (stopExecutionTool && stopExecutionTool.response.error) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: `Agent execution stopped: ${stopExecutionTool.response.error.message}`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
setIsResponding(false);
|
||||
|
||||
const callIdsToMarkAsSubmitted = geminiTools.map(
|
||||
(toolCall) => toolCall.request.callId,
|
||||
);
|
||||
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
|
||||
return;
|
||||
}
|
||||
|
||||
// If all the tools were cancelled, don't submit a response to Gemini.
|
||||
const allToolsCancelled = geminiTools.every(
|
||||
(tc) => tc.status === 'cancelled',
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { executeToolWithHooks } from './coreToolHookTriggers.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import {
|
||||
BaseToolInvocation,
|
||||
type ToolResult,
|
||||
@@ -17,8 +18,8 @@ import {
|
||||
type HookExecutionResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
|
||||
class MockInvocation extends BaseToolInvocation<{ key: string }, ToolResult> {
|
||||
constructor(params: { key: string }) {
|
||||
class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
|
||||
constructor(params: { key?: string }) {
|
||||
super(params);
|
||||
}
|
||||
getDescription() {
|
||||
@@ -26,8 +27,10 @@ class MockInvocation extends BaseToolInvocation<{ key: string }, ToolResult> {
|
||||
}
|
||||
async execute() {
|
||||
return {
|
||||
llmContent: `key: ${this.params.key}`,
|
||||
returnDisplay: `key: ${this.params.key}`,
|
||||
llmContent: this.params.key ? `key: ${this.params.key}` : 'success',
|
||||
returnDisplay: this.params.key
|
||||
? `key: ${this.params.key}`
|
||||
: 'success display',
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -39,12 +42,145 @@ describe('executeToolWithHooks', () => {
|
||||
beforeEach(() => {
|
||||
messageBus = {
|
||||
request: vi.fn(),
|
||||
publish: vi.fn(),
|
||||
subscribe: vi.fn(),
|
||||
unsubscribe: vi.fn(),
|
||||
} as unknown as MessageBus;
|
||||
mockTool = {
|
||||
build: vi.fn().mockImplementation((params) => new MockInvocation(params)),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
});
|
||||
|
||||
it('should prioritize continue: false over decision: block in BeforeTool', async () => {
|
||||
const invocation = new MockInvocation({});
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
vi.mocked(messageBus.request).mockResolvedValue({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: 'test-id',
|
||||
success: true,
|
||||
output: {
|
||||
continue: false,
|
||||
stopReason: 'Stop immediately',
|
||||
decision: 'block',
|
||||
reason: 'Should be ignored because continue is false',
|
||||
},
|
||||
} as HookExecutionResponse);
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
'test_tool',
|
||||
abortSignal,
|
||||
messageBus,
|
||||
true,
|
||||
mockTool,
|
||||
);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
|
||||
expect(result.error?.message).toBe('Stop immediately');
|
||||
});
|
||||
|
||||
it('should block execution in BeforeTool if decision is block', async () => {
|
||||
const invocation = new MockInvocation({});
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
vi.mocked(messageBus.request).mockResolvedValue({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: 'test-id',
|
||||
success: true,
|
||||
output: {
|
||||
decision: 'block',
|
||||
reason: 'Execution blocked',
|
||||
},
|
||||
} as HookExecutionResponse);
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
'test_tool',
|
||||
abortSignal,
|
||||
messageBus,
|
||||
true,
|
||||
mockTool,
|
||||
);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
|
||||
expect(result.error?.message).toBe('Execution blocked');
|
||||
});
|
||||
|
||||
it('should handle continue: false in AfterTool', async () => {
|
||||
const invocation = new MockInvocation({});
|
||||
const abortSignal = new AbortController().signal;
|
||||
const spy = vi.spyOn(invocation, 'execute');
|
||||
|
||||
// BeforeTool allow
|
||||
vi.mocked(messageBus.request)
|
||||
.mockResolvedValueOnce({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: 'test-id',
|
||||
success: true,
|
||||
output: { decision: 'allow' },
|
||||
} as HookExecutionResponse)
|
||||
// AfterTool stop
|
||||
.mockResolvedValueOnce({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: 'test-id',
|
||||
success: true,
|
||||
output: {
|
||||
continue: false,
|
||||
stopReason: 'Stop after execution',
|
||||
},
|
||||
} as HookExecutionResponse);
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
'test_tool',
|
||||
abortSignal,
|
||||
messageBus,
|
||||
true,
|
||||
mockTool,
|
||||
);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
|
||||
expect(result.error?.message).toBe('Stop after execution');
|
||||
expect(spy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should block result in AfterTool if decision is deny', async () => {
|
||||
const invocation = new MockInvocation({});
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
// BeforeTool allow
|
||||
vi.mocked(messageBus.request)
|
||||
.mockResolvedValueOnce({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: 'test-id',
|
||||
success: true,
|
||||
output: { decision: 'allow' },
|
||||
} as HookExecutionResponse)
|
||||
// AfterTool deny
|
||||
.mockResolvedValueOnce({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: 'test-id',
|
||||
success: true,
|
||||
output: {
|
||||
decision: 'deny',
|
||||
reason: 'Result denied',
|
||||
},
|
||||
} as HookExecutionResponse);
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
'test_tool',
|
||||
abortSignal,
|
||||
messageBus,
|
||||
true,
|
||||
mockTool,
|
||||
);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
|
||||
expect(result.error?.message).toBe('Result denied');
|
||||
});
|
||||
|
||||
it('should apply modified tool input from BeforeTool hook', async () => {
|
||||
const params = { key: 'original' };
|
||||
const invocation = new MockInvocation(params);
|
||||
|
||||
@@ -273,6 +273,19 @@ export async function executeToolWithHooks(
|
||||
toolInput,
|
||||
);
|
||||
|
||||
// Check if hook requested to stop entire agent execution
|
||||
if (beforeOutput?.shouldStopExecution()) {
|
||||
const reason = beforeOutput.getEffectiveReason();
|
||||
return {
|
||||
llmContent: `Agent execution stopped by hook: ${reason}`,
|
||||
returnDisplay: `Agent execution stopped by hook: ${reason}`,
|
||||
error: {
|
||||
type: ToolErrorType.STOP_EXECUTION,
|
||||
message: reason,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Check if hook blocked the tool execution
|
||||
const blockingError = beforeOutput?.getBlockingError();
|
||||
if (blockingError?.blocked) {
|
||||
@@ -286,19 +299,6 @@ export async function executeToolWithHooks(
|
||||
};
|
||||
}
|
||||
|
||||
// Check if hook requested to stop entire agent execution
|
||||
if (beforeOutput?.shouldStopExecution()) {
|
||||
const reason = beforeOutput.getEffectiveReason();
|
||||
return {
|
||||
llmContent: `Agent execution stopped by hook: ${reason}`,
|
||||
returnDisplay: `Agent execution stopped by hook: ${reason}`,
|
||||
error: {
|
||||
type: ToolErrorType.EXECUTION_FAILED,
|
||||
message: `Agent execution stopped: ${reason}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Check if hook requested to update tool input
|
||||
if (beforeOutput instanceof BeforeToolHookOutput) {
|
||||
const modifiedInput = beforeOutput.getModifiedToolInput();
|
||||
@@ -386,9 +386,22 @@ export async function executeToolWithHooks(
|
||||
return {
|
||||
llmContent: `Agent execution stopped by hook: ${reason}`,
|
||||
returnDisplay: `Agent execution stopped by hook: ${reason}`,
|
||||
error: {
|
||||
type: ToolErrorType.STOP_EXECUTION,
|
||||
message: reason,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Check if hook blocked the tool result
|
||||
const blockingError = afterOutput?.getBlockingError();
|
||||
if (blockingError?.blocked) {
|
||||
return {
|
||||
llmContent: `Tool result blocked: ${blockingError.reason}`,
|
||||
returnDisplay: `Tool result blocked: ${blockingError.reason}`,
|
||||
error: {
|
||||
type: ToolErrorType.EXECUTION_FAILED,
|
||||
message: `Agent execution stopped: ${reason}`,
|
||||
message: blockingError.reason,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -180,7 +180,7 @@ export class DefaultHookOutput implements HookOutput {
|
||||
* Get the effective reason for blocking or stopping
|
||||
*/
|
||||
getEffectiveReason(): string {
|
||||
return this.reason || this.stopReason || 'No reason provided';
|
||||
return this.stopReason || this.reason || 'No reason provided';
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -71,6 +71,9 @@ export enum ToolErrorType {
|
||||
|
||||
// WebSearch-specific Errors
|
||||
WEB_SEARCH_FAILED = 'web_search_failed',
|
||||
|
||||
// Hook-specific Errors
|
||||
STOP_EXECUTION = 'stop_execution',
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user