mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-02-01 22:48:03 +00:00
fix(hooks): deduplicate agent hooks and add cross-platform integration tests (#15701)
This commit is contained in:
2
integration-tests/hooks-agent-flow-multistep.responses
Normal file
2
integration-tests/hooks-agent-flow-multistep.responses
Normal file
@@ -0,0 +1,2 @@
|
||||
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"."}}}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}]}
|
||||
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Final Answer"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}]}
|
||||
1
integration-tests/hooks-agent-flow.responses
Normal file
1
integration-tests/hooks-agent-flow.responses
Normal file
@@ -0,0 +1 @@
|
||||
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"**Responding**\n\nI will respond to the user's request.\n\n","thought":true}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":100,"totalTokenCount":120,"promptTokensDetails":[{"modality":"TEXT","tokenCount":100}],"thoughtsTokenCount":20}},{"candidates":[{"content":{"parts":[{"text":"Response to: "}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":5,"totalTokenCount":125,"promptTokensDetails":[{"modality":"TEXT","tokenCount":100}],"thoughtsTokenCount":20}},{"candidates":[{"content":{"parts":[{"text":"Hello World"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":7,"totalTokenCount":127,"promptTokensDetails":[{"modality":"TEXT","tokenCount":100}],"thoughtsTokenCount":20}}]}
|
||||
238
integration-tests/hooks-agent-flow.test.ts
Normal file
238
integration-tests/hooks-agent-flow.test.ts
Normal file
@@ -0,0 +1,238 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { TestRig } from './test-helper.js';
|
||||
import { join } from 'node:path';
|
||||
import { writeFileSync } from 'node:fs';
|
||||
|
||||
describe('Hooks Agent Flow', () => {
|
||||
let rig: TestRig;
|
||||
|
||||
beforeEach(() => {
|
||||
rig = new TestRig();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
if (rig) {
|
||||
await rig.cleanup();
|
||||
}
|
||||
});
|
||||
|
||||
describe('BeforeAgent Hooks', () => {
|
||||
it('should inject additional context via BeforeAgent hook', async () => {
|
||||
await rig.setup('should inject additional context via BeforeAgent hook', {
|
||||
fakeResponsesPath: join(
|
||||
import.meta.dirname,
|
||||
'hooks-agent-flow.responses',
|
||||
),
|
||||
});
|
||||
|
||||
const hookScript = `
|
||||
try {
|
||||
const output = {
|
||||
decision: "allow",
|
||||
hookSpecificOutput: {
|
||||
hookEventName: "BeforeAgent",
|
||||
additionalContext: "SYSTEM INSTRUCTION: This is injected context."
|
||||
}
|
||||
};
|
||||
process.stdout.write(JSON.stringify(output));
|
||||
} catch (e) {
|
||||
console.error('Failed to write stdout:', e);
|
||||
process.exit(1);
|
||||
}
|
||||
console.error('DEBUG: BeforeAgent hook executed');
|
||||
`;
|
||||
|
||||
const scriptPath = join(rig.testDir!, 'before_agent_context.cjs');
|
||||
writeFileSync(scriptPath, hookScript);
|
||||
|
||||
await rig.setup('should inject additional context via BeforeAgent hook', {
|
||||
settings: {
|
||||
tools: {
|
||||
enableHooks: true,
|
||||
},
|
||||
hooks: {
|
||||
BeforeAgent: [
|
||||
{
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: `node "${scriptPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await rig.run({ args: 'Hello test' });
|
||||
|
||||
// Verify hook execution and telemetry
|
||||
const hookTelemetryFound = await rig.waitForTelemetryEvent('hook_call');
|
||||
expect(hookTelemetryFound).toBeTruthy();
|
||||
|
||||
const hookLogs = rig.readHookLogs();
|
||||
const beforeAgentLog = hookLogs.find(
|
||||
(log) => log.hookCall.hook_event_name === 'BeforeAgent',
|
||||
);
|
||||
|
||||
expect(beforeAgentLog).toBeDefined();
|
||||
expect(beforeAgentLog?.hookCall.stdout).toContain('injected context');
|
||||
expect(beforeAgentLog?.hookCall.stdout).toContain('"decision":"allow"');
|
||||
expect(beforeAgentLog?.hookCall.stdout).toContain(
|
||||
'SYSTEM INSTRUCTION: This is injected context.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('AfterAgent Hooks', () => {
|
||||
it('should receive prompt and response in AfterAgent hook', async () => {
|
||||
await rig.setup('should receive prompt and response in AfterAgent hook', {
|
||||
fakeResponsesPath: join(
|
||||
import.meta.dirname,
|
||||
'hooks-agent-flow.responses',
|
||||
),
|
||||
});
|
||||
|
||||
const hookScript = `
|
||||
const fs = require('fs');
|
||||
try {
|
||||
const input = fs.readFileSync(0, 'utf-8');
|
||||
console.error('DEBUG: AfterAgent hook input received');
|
||||
process.stdout.write("Received Input: " + input);
|
||||
// Ensure separation between the echo and the JSON output if they were to run together (though relying on separate console calls usually separates by newline)
|
||||
// usage of process.stdout.write does NOT add newline.
|
||||
// But here we want strictly the output "Received Input..." to be present.
|
||||
// We also need to output the JSON decision for the hook runner to consider it successful?
|
||||
// Actually HookRunner parses the *last* valid JSON block or treats text as system message.
|
||||
// If we output mixed text and JSON, HookRunner might get confused if we don't handle it right.
|
||||
// Existing test expects "Received Input" in stdout. And "Hello World".
|
||||
// It DOES NOT parse the decision?
|
||||
// Wait, HookRunner logic:
|
||||
// "if (exitCode === EXIT_CODE_SUCCESS && stdout.trim()) ... JSON.parse ..."
|
||||
// If JSON.parse fails: "Not JSON, convert plain text to structured output"
|
||||
// So if we output formatted text, it becomes "systemMessage".
|
||||
// That is fine for this test as we don't check the decision, just the stdout content.
|
||||
} catch (err) {
|
||||
console.error('Hook Failed:', err);
|
||||
process.exit(1);
|
||||
}
|
||||
`;
|
||||
|
||||
const scriptPath = join(rig.testDir!, 'after_agent_verify.cjs');
|
||||
writeFileSync(scriptPath, hookScript);
|
||||
|
||||
await rig.setup('should receive prompt and response in AfterAgent hook', {
|
||||
settings: {
|
||||
tools: {
|
||||
enableHooks: true,
|
||||
},
|
||||
hooks: {
|
||||
AfterAgent: [
|
||||
{
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: `node "${scriptPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await rig.run({ args: 'Hello validation' });
|
||||
|
||||
const hookTelemetryFound = await rig.waitForTelemetryEvent('hook_call');
|
||||
expect(hookTelemetryFound).toBeTruthy();
|
||||
|
||||
const hookLogs = rig.readHookLogs();
|
||||
const afterAgentLog = hookLogs.find(
|
||||
(log) => log.hookCall.hook_event_name === 'AfterAgent',
|
||||
);
|
||||
|
||||
expect(afterAgentLog).toBeDefined();
|
||||
// Verify the hook stdout contains the input we echoed which proves the hook received the prompt and response
|
||||
expect(afterAgentLog?.hookCall.stdout).toContain('Received Input');
|
||||
expect(afterAgentLog?.hookCall.stdout).toContain('Hello validation');
|
||||
// The fake response contains "Hello World"
|
||||
expect(afterAgentLog?.hookCall.stdout).toContain('Hello World');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Multi-step Loops', () => {
|
||||
it('should fire BeforeAgent and AfterAgent exactly once per turn despite tool calls', async () => {
|
||||
await rig.setup(
|
||||
'should fire BeforeAgent and AfterAgent exactly once per turn despite tool calls',
|
||||
{
|
||||
fakeResponsesPath: join(
|
||||
import.meta.dirname,
|
||||
'hooks-agent-flow-multistep.responses',
|
||||
),
|
||||
settings: {
|
||||
tools: {
|
||||
enableHooks: true,
|
||||
},
|
||||
hooks: {
|
||||
BeforeAgent: [
|
||||
{
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: `node -e "console.log('BeforeAgent Fired')"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
AfterAgent: [
|
||||
{
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: `node -e "console.log('AfterAgent Fired')"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
await rig.run({ args: 'Do a multi-step task' });
|
||||
|
||||
const hookLogs = rig.readHookLogs();
|
||||
const beforeAgentLogs = hookLogs.filter(
|
||||
(log) => log.hookCall.hook_event_name === 'BeforeAgent',
|
||||
);
|
||||
const afterAgentLogs = hookLogs.filter(
|
||||
(log) => log.hookCall.hook_event_name === 'AfterAgent',
|
||||
);
|
||||
|
||||
// Should ensure BeforeAgent fired once
|
||||
expect(beforeAgentLogs).toHaveLength(1);
|
||||
|
||||
// Should ensure AfterAgent fired once
|
||||
// Note: If the tool call itself triggered BeforeTool/AfterTool, that's fine,
|
||||
// but BeforeAgent/AfterAgent should only wrap the *entire* turn (User Request -> Final Answer).
|
||||
expect(afterAgentLogs).toHaveLength(1);
|
||||
|
||||
// Verify the output log content to ensure we actually got the final answer
|
||||
// (This implies the loop completed successfully)
|
||||
const afterAgentLog = afterAgentLogs[0];
|
||||
expect(afterAgentLog).toBeDefined();
|
||||
expect(afterAgentLog?.hookCall.stdout).toContain('AfterAgent Fired');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -44,7 +44,7 @@ describe('Hooks System Integration', () => {
|
||||
{
|
||||
type: 'command',
|
||||
command:
|
||||
'echo "{\\"decision\\": \\"block\\", \\"reason\\": \\"File writing blocked by security policy\\"}"',
|
||||
"node -e \"console.log(JSON.stringify({decision: 'block', reason: 'File writing blocked by security policy'}))\"",
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
@@ -97,7 +97,7 @@ describe('Hooks System Integration', () => {
|
||||
{
|
||||
type: 'command',
|
||||
command:
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"reason\\": \\"File writing approved\\"}"',
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', reason: 'File writing approved'}))\"",
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
@@ -129,7 +129,7 @@ describe('Hooks System Integration', () => {
|
||||
describe('Command Hooks - Additional Context', () => {
|
||||
it('should add additional context from AfterTool hooks', async () => {
|
||||
const command =
|
||||
'echo "{\\"hookSpecificOutput\\": {\\"hookEventName\\": \\"AfterTool\\", \\"additionalContext\\": \\"Security scan: File content appears safe\\"}}"';
|
||||
"node -e \"console.log(JSON.stringify({hookSpecificOutput: {hookEventName: 'AfterTool', additionalContext: 'Security scan: File content appears safe'}}))\"";
|
||||
await rig.setup('should add additional context from AfterTool hooks', {
|
||||
fakeResponsesPath: join(
|
||||
import.meta.dirname,
|
||||
@@ -190,27 +190,24 @@ describe('Hooks System Integration', () => {
|
||||
'hooks-system.before-model.responses',
|
||||
),
|
||||
});
|
||||
const hookScript = `#!/bin/bash
|
||||
echo '{
|
||||
"decision": "allow",
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "BeforeModel",
|
||||
"llm_request": {
|
||||
"messages": [
|
||||
const hookScript = `const fs = require('fs');
|
||||
console.log(JSON.stringify({
|
||||
decision: "allow",
|
||||
hookSpecificOutput: {
|
||||
hookEventName: "BeforeModel",
|
||||
llm_request: {
|
||||
messages: [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please respond with exactly: The security hook modified this request successfully."
|
||||
role: "user",
|
||||
content: "Please respond with exactly: The security hook modified this request successfully."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}'`;
|
||||
}));`;
|
||||
|
||||
const scriptPath = join(rig.testDir!, 'before_model_hook.sh');
|
||||
const scriptPath = join(rig.testDir!, 'before_model_hook.cjs');
|
||||
writeFileSync(scriptPath, hookScript);
|
||||
// Make executable
|
||||
const { execSync } = await import('node:child_process');
|
||||
execSync(`chmod +x "${scriptPath}"`);
|
||||
|
||||
await rig.setup('should modify LLM requests with BeforeModel hooks', {
|
||||
settings: {
|
||||
@@ -223,7 +220,7 @@ echo '{
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: scriptPath,
|
||||
command: `node "${scriptPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
@@ -250,7 +247,9 @@ echo '{
|
||||
expect(hookTelemetryFound[0].hookCall.hook_event_name).toBe(
|
||||
'BeforeModel',
|
||||
);
|
||||
expect(hookTelemetryFound[0].hookCall.hook_name).toBe(scriptPath);
|
||||
expect(hookTelemetryFound[0].hookCall.hook_name).toBe(
|
||||
`node "${scriptPath}"`,
|
||||
);
|
||||
expect(hookTelemetryFound[0].hookCall.hook_input).toBeDefined();
|
||||
expect(hookTelemetryFound[0].hookCall.hook_output).toBeDefined();
|
||||
expect(hookTelemetryFound[0].hookCall.exit_code).toBe(0);
|
||||
@@ -270,30 +269,28 @@ echo '{
|
||||
),
|
||||
});
|
||||
// Create a hook script that modifies the LLM response
|
||||
const hookScript = `#!/bin/bash
|
||||
echo '{
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "AfterModel",
|
||||
"llm_response": {
|
||||
"candidates": [
|
||||
const hookScript = `const fs = require('fs');
|
||||
console.log(JSON.stringify({
|
||||
hookSpecificOutput: {
|
||||
hookEventName: "AfterModel",
|
||||
llm_response: {
|
||||
candidates: [
|
||||
{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [
|
||||
content: {
|
||||
role: "model",
|
||||
parts: [
|
||||
"[FILTERED] Response has been filtered for security compliance."
|
||||
]
|
||||
},
|
||||
"finishReason": "STOP"
|
||||
finishReason: "STOP"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}'`;
|
||||
}));`;
|
||||
|
||||
const scriptPath = join(rig.testDir!, 'after_model_hook.sh');
|
||||
const scriptPath = join(rig.testDir!, 'after_model_hook.cjs');
|
||||
writeFileSync(scriptPath, hookScript);
|
||||
const { execSync } = await import('node:child_process');
|
||||
execSync(`chmod +x "${scriptPath}"`);
|
||||
|
||||
await rig.setup('should modify LLM responses with AfterModel hooks', {
|
||||
settings: {
|
||||
@@ -306,7 +303,7 @@ echo '{
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: scriptPath,
|
||||
command: `node "${scriptPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
@@ -343,7 +340,7 @@ echo '{
|
||||
);
|
||||
// Create inline hook command (works on both Unix and Windows)
|
||||
const hookCommand =
|
||||
'echo "{\\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeToolSelection\\", \\"toolConfig\\": {\\"mode\\": \\"ANY\\", \\"allowedFunctionNames\\": [\\"read_file\\", \\"run_shell_command\\"]}}}"';
|
||||
"node -e \"console.log(JSON.stringify({hookSpecificOutput: {hookEventName: 'BeforeToolSelection', toolConfig: {mode: 'ANY', allowedFunctionNames: ['read_file', 'run_shell_command']}}}))\"";
|
||||
|
||||
await rig.setup(
|
||||
'should modify tool selection with BeforeToolSelection hooks',
|
||||
@@ -404,19 +401,17 @@ echo '{
|
||||
),
|
||||
});
|
||||
// Create a hook script that adds context to the prompt
|
||||
const hookScript = `#!/bin/bash
|
||||
echo '{
|
||||
"decision": "allow",
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "BeforeAgent",
|
||||
"additionalContext": "SYSTEM INSTRUCTION: You are in a secure environment. Always mention security compliance in your responses."
|
||||
const hookScript = `const fs = require('fs');
|
||||
console.log(JSON.stringify({
|
||||
decision: "allow",
|
||||
hookSpecificOutput: {
|
||||
hookEventName: "BeforeAgent",
|
||||
additionalContext: "SYSTEM INSTRUCTION: You are in a secure environment. Always mention security compliance in your responses."
|
||||
}
|
||||
}'`;
|
||||
}));`;
|
||||
|
||||
const scriptPath = join(rig.testDir!, 'before_agent_hook.sh');
|
||||
const scriptPath = join(rig.testDir!, 'before_agent_hook.cjs');
|
||||
writeFileSync(scriptPath, hookScript);
|
||||
const { execSync } = await import('node:child_process');
|
||||
execSync(`chmod +x "${scriptPath}"`);
|
||||
|
||||
await rig.setup('should augment prompts with BeforeAgent hooks', {
|
||||
settings: {
|
||||
@@ -429,7 +424,7 @@ echo '{
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: scriptPath,
|
||||
command: `node "${scriptPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
@@ -452,9 +447,10 @@ echo '{
|
||||
|
||||
describe('Notification Hooks - Permission Handling', () => {
|
||||
it('should handle notification hooks for tool permissions', async () => {
|
||||
// Create inline hook command (works on both Unix and Windows)
|
||||
// Create inline hook command (works on both Unix and Windows)
|
||||
const hookCommand =
|
||||
'echo "{\\"suppressOutput\\": false, \\"systemMessage\\": \\"Permission request logged by security hook\\"}"';
|
||||
'node -e "console.log(JSON.stringify({suppressOutput: false, systemMessage: \'Permission request logged by security hook\'}))"';
|
||||
|
||||
await rig.setup('should handle notification hooks for tool permissions', {
|
||||
fakeResponsesPath: join(
|
||||
@@ -548,9 +544,9 @@ echo '{
|
||||
it('should execute hooks sequentially when configured', async () => {
|
||||
// Create inline hook commands (works on both Unix and Windows)
|
||||
const hook1Command =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeAgent\\", \\"additionalContext\\": \\"Step 1: Initial validation passed.\\"}}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', hookSpecificOutput: {hookEventName: 'BeforeAgent', additionalContext: 'Step 1: Initial validation passed.'}}))\"";
|
||||
const hook2Command =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeAgent\\", \\"additionalContext\\": \\"Step 2: Security check completed.\\"}}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', hookSpecificOutput: {hookEventName: 'BeforeAgent', additionalContext: 'Step 2: Security check completed.'}}))\"";
|
||||
|
||||
await rig.setup('should execute hooks sequentially when configured', {
|
||||
fakeResponsesPath: join(
|
||||
@@ -621,23 +617,22 @@ echo '{
|
||||
),
|
||||
});
|
||||
// Create a hook script that validates the input format
|
||||
const hookScript = `#!/bin/bash
|
||||
# Read JSON input from stdin
|
||||
input=$(cat)
|
||||
const hookScript = `const fs = require('fs');
|
||||
const input = fs.readFileSync(0, 'utf-8');
|
||||
try {
|
||||
const json = JSON.parse(input);
|
||||
// Check fields
|
||||
if (json.session_id && json.cwd && json.hook_event_name && json.timestamp && json.tool_name && json.tool_input) {
|
||||
console.log(JSON.stringify({decision: "allow", reason: "Input format is correct"}));
|
||||
} else {
|
||||
console.log(JSON.stringify({decision: "block", reason: "Input format is invalid"}));
|
||||
}
|
||||
} catch (e) {
|
||||
console.log(JSON.stringify({decision: "block", reason: "Invalid JSON"}));
|
||||
}`;
|
||||
|
||||
# Check for required fields
|
||||
if echo "$input" | jq -e '.session_id and .cwd and .hook_event_name and .timestamp and .tool_name and .tool_input' > /dev/null 2>&1; then
|
||||
echo '{"decision": "allow", "reason": "Input format is correct"}'
|
||||
exit 0
|
||||
else
|
||||
echo '{"decision": "block", "reason": "Input format is invalid"}'
|
||||
exit 0
|
||||
fi`;
|
||||
|
||||
const scriptPath = join(rig.testDir!, 'input_validation_hook.sh');
|
||||
const scriptPath = join(rig.testDir!, 'input_validation_hook.cjs');
|
||||
writeFileSync(scriptPath, hookScript);
|
||||
const { execSync } = await import('node:child_process');
|
||||
execSync(`chmod +x "${scriptPath}"`);
|
||||
|
||||
await rig.setup('should provide correct input format to hooks', {
|
||||
settings: {
|
||||
@@ -650,7 +645,7 @@ fi`;
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: scriptPath,
|
||||
command: `node "${scriptPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
@@ -682,11 +677,11 @@ fi`;
|
||||
it('should handle hooks for all major event types', async () => {
|
||||
// Create inline hook commands (works on both Unix and Windows)
|
||||
const beforeToolCommand =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"BeforeTool: File operation logged\\"}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'BeforeTool: File operation logged'}))\"";
|
||||
const afterToolCommand =
|
||||
'echo "{\\"hookSpecificOutput\\": {\\"hookEventName\\": \\"AfterTool\\", \\"additionalContext\\": \\"AfterTool: Operation completed successfully\\"}}"';
|
||||
"node -e \"console.log(JSON.stringify({hookSpecificOutput: {hookEventName: 'AfterTool', additionalContext: 'AfterTool: Operation completed successfully'}}))\"";
|
||||
const beforeAgentCommand =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeAgent\\", \\"additionalContext\\": \\"BeforeAgent: User request processed\\"}}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', hookSpecificOutput: {hookEventName: 'BeforeAgent', additionalContext: 'BeforeAgent: User request processed'}}))\"";
|
||||
|
||||
await rig.setup('should handle hooks for all major event types', {
|
||||
fakeResponsesPath: join(
|
||||
@@ -802,10 +797,10 @@ fi`;
|
||||
// Create a hook script that fails
|
||||
// Create inline hook commands (works on both Unix and Windows)
|
||||
// Failing hook: exits with non-zero code
|
||||
const failingCommand = 'exit 1';
|
||||
const failingCommand = 'node -e "process.exit(1)"';
|
||||
// Working hook: returns success with JSON
|
||||
const workingCommand =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"reason\\": \\"Working hook succeeded\\"}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', reason: 'Working hook succeeded'}))\"";
|
||||
|
||||
await rig.setup('should handle hook failures gracefully', {
|
||||
settings: {
|
||||
@@ -855,7 +850,7 @@ fi`;
|
||||
it('should generate telemetry events for hook executions', async () => {
|
||||
// Create inline hook command (works on both Unix and Windows)
|
||||
const hookCommand =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"reason\\": \\"Telemetry test hook\\"}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', reason: 'Telemetry test hook'}))\"";
|
||||
|
||||
await rig.setup('should generate telemetry events for hook executions', {
|
||||
fakeResponsesPath: join(
|
||||
@@ -898,7 +893,7 @@ fi`;
|
||||
it('should fire SessionStart hook on app startup', async () => {
|
||||
// Create inline hook command that outputs JSON
|
||||
const sessionStartCommand =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"Session starting on startup\\"}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'Session starting on startup'}))\"";
|
||||
|
||||
await rig.setup('should fire SessionStart hook on app startup', {
|
||||
fakeResponsesPath: join(
|
||||
@@ -958,9 +953,9 @@ fi`;
|
||||
it('should fire SessionEnd and SessionStart hooks on /clear command', async () => {
|
||||
// Create inline hook commands for both SessionEnd and SessionStart
|
||||
const sessionEndCommand =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"Session ending due to clear\\"}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'Session ending due to clear'}))\"";
|
||||
const sessionStartCommand =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"Session starting after clear\\"}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'Session starting after clear'}))\"";
|
||||
|
||||
await rig.setup(
|
||||
'should fire SessionEnd and SessionStart hooks on /clear command',
|
||||
@@ -1136,7 +1131,7 @@ fi`;
|
||||
it('should fire PreCompress hook on automatic compression', async () => {
|
||||
// Create inline hook command that outputs JSON
|
||||
const preCompressCommand =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"PreCompress hook executed for automatic compression\\"}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'PreCompress hook executed for automatic compression'}))\"";
|
||||
|
||||
await rig.setup('should fire PreCompress hook on automatic compression', {
|
||||
fakeResponsesPath: join(
|
||||
@@ -1203,7 +1198,7 @@ fi`;
|
||||
describe('SessionEnd on Exit', () => {
|
||||
it('should fire SessionEnd hook on graceful exit in non-interactive mode', async () => {
|
||||
const sessionEndCommand =
|
||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"SessionEnd hook executed on exit\\"}"';
|
||||
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'SessionEnd hook executed on exit'}))\"";
|
||||
|
||||
await rig.setup('should fire SessionEnd hook on graceful exit', {
|
||||
fakeResponsesPath: join(
|
||||
@@ -1297,20 +1292,17 @@ fi`;
|
||||
});
|
||||
|
||||
// Create two hook scripts - one enabled, one disabled
|
||||
const enabledHookScript = `#!/bin/bash
|
||||
echo '{"decision": "allow", "systemMessage": "Enabled hook executed"}'`;
|
||||
const enabledHookScript = `const fs = require('fs');
|
||||
console.log(JSON.stringify({decision: "allow", systemMessage: "Enabled hook executed"}));`;
|
||||
|
||||
const disabledHookScript = `#!/bin/bash
|
||||
echo '{"decision": "block", "systemMessage": "Disabled hook should not execute", "reason": "This hook should be disabled"}'`;
|
||||
const disabledHookScript = `const fs = require('fs');
|
||||
console.log(JSON.stringify({decision: "block", systemMessage: "Disabled hook should not execute", reason: "This hook should be disabled"}));`;
|
||||
|
||||
const enabledPath = join(rig.testDir!, 'enabled_hook.sh');
|
||||
const disabledPath = join(rig.testDir!, 'disabled_hook.sh');
|
||||
const enabledPath = join(rig.testDir!, 'enabled_hook.cjs');
|
||||
const disabledPath = join(rig.testDir!, 'disabled_hook.cjs');
|
||||
|
||||
writeFileSync(enabledPath, enabledHookScript);
|
||||
writeFileSync(disabledPath, disabledHookScript);
|
||||
const { execSync } = await import('node:child_process');
|
||||
execSync(`chmod +x "${enabledPath}"`);
|
||||
execSync(`chmod +x "${disabledPath}"`);
|
||||
|
||||
await rig.setup('should not execute hooks disabled in settings file', {
|
||||
settings: {
|
||||
@@ -1323,18 +1315,18 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: enabledPath,
|
||||
command: `node "${enabledPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
{
|
||||
type: 'command',
|
||||
command: disabledPath,
|
||||
command: `node "${disabledPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
disabled: [disabledPath], // Disable the second hook
|
||||
disabled: [`node "${disabledPath}"`], // Disable the second hook
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -1358,10 +1350,10 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
||||
// Check hook telemetry - only enabled hook should have executed
|
||||
const hookLogs = rig.readHookLogs();
|
||||
const enabledHookLog = hookLogs.find(
|
||||
(log) => log.hookCall.hook_name === enabledPath,
|
||||
(log) => log.hookCall.hook_name === `node "${enabledPath}"`,
|
||||
);
|
||||
const disabledHookLog = hookLogs.find(
|
||||
(log) => log.hookCall.hook_name === disabledPath,
|
||||
(log) => log.hookCall.hook_name === `node "${disabledPath}"`,
|
||||
);
|
||||
|
||||
expect(enabledHookLog).toBeDefined();
|
||||
@@ -1380,20 +1372,17 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
||||
);
|
||||
|
||||
// Create two hook scripts - one that will be disabled, one that won't
|
||||
const activeHookScript = `#!/bin/bash
|
||||
echo '{"decision": "allow", "systemMessage": "Active hook executed"}'`;
|
||||
const activeHookScript = `const fs = require('fs');
|
||||
console.log(JSON.stringify({decision: "allow", systemMessage: "Active hook executed"}));`;
|
||||
|
||||
const disabledHookScript = `#!/bin/bash
|
||||
echo '{"decision": "block", "systemMessage": "Disabled hook should not execute", "reason": "This hook is disabled"}'`;
|
||||
const disabledHookScript = `const fs = require('fs');
|
||||
console.log(JSON.stringify({decision: "block", systemMessage: "Disabled hook should not execute", reason: "This hook is disabled"}));`;
|
||||
|
||||
const activePath = join(rig.testDir!, 'active_hook.sh');
|
||||
const disabledPath = join(rig.testDir!, 'disabled_hook.sh');
|
||||
const activePath = join(rig.testDir!, 'active_hook.cjs');
|
||||
const disabledPath = join(rig.testDir!, 'disabled_hook.cjs');
|
||||
|
||||
writeFileSync(activePath, activeHookScript);
|
||||
writeFileSync(disabledPath, disabledHookScript);
|
||||
const { execSync } = await import('node:child_process');
|
||||
execSync(`chmod +x "${activePath}"`);
|
||||
execSync(`chmod +x "${disabledPath}"`);
|
||||
|
||||
await rig.setup(
|
||||
'should respect disabled hooks across multiple operations',
|
||||
@@ -1408,18 +1397,18 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: activePath,
|
||||
command: `node "${activePath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
{
|
||||
type: 'command',
|
||||
command: disabledPath,
|
||||
command: `node "${disabledPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
disabled: [disabledPath], // Disable the second hook
|
||||
disabled: [`node "${disabledPath}"`], // Disable the second hook
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1441,10 +1430,10 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
||||
// Check hook telemetry
|
||||
const hookLogs1 = rig.readHookLogs();
|
||||
const activeHookLog1 = hookLogs1.find(
|
||||
(log) => log.hookCall.hook_name === activePath,
|
||||
(log) => log.hookCall.hook_name === `node "${activePath}"`,
|
||||
);
|
||||
const disabledHookLog1 = hookLogs1.find(
|
||||
(log) => log.hookCall.hook_name === disabledPath,
|
||||
(log) => log.hookCall.hook_name === `node "${disabledPath}"`,
|
||||
);
|
||||
|
||||
expect(activeHookLog1).toBeDefined();
|
||||
@@ -1465,7 +1454,7 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
||||
// Verify disabled hook still hasn't executed
|
||||
const hookLogs2 = rig.readHookLogs();
|
||||
const disabledHookCalls = hookLogs2.filter(
|
||||
(log) => log.hookCall.hook_name === disabledPath,
|
||||
(log) => log.hookCall.hook_name === `node "${disabledPath}"`,
|
||||
);
|
||||
expect(disabledHookCalls.length).toBe(0);
|
||||
});
|
||||
|
||||
@@ -81,6 +81,10 @@ vi.mock('node:fs', () => {
|
||||
});
|
||||
|
||||
// --- Mocks ---
|
||||
interface MockTurnContext {
|
||||
getResponseText: Mock<() => string>;
|
||||
}
|
||||
|
||||
const mockTurnRunFn = vi.fn();
|
||||
|
||||
vi.mock('./turn', async (importOriginal) => {
|
||||
@@ -94,6 +98,8 @@ vi.mock('./turn', async (importOriginal) => {
|
||||
constructor() {
|
||||
// The constructor can be empty or do some mock setup
|
||||
}
|
||||
|
||||
getResponseText = vi.fn().mockReturnValue('Mock Response');
|
||||
}
|
||||
// Export the mock class as 'Turn'
|
||||
return {
|
||||
@@ -129,6 +135,15 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
|
||||
},
|
||||
}));
|
||||
vi.mock('../hooks/hookSystem.js');
|
||||
vi.mock('./clientHookTriggers.js', () => ({
|
||||
fireBeforeAgentHook: vi.fn(),
|
||||
fireAfterAgentHook: vi.fn().mockResolvedValue({
|
||||
decision: 'allow',
|
||||
continue: false,
|
||||
suppressOutput: false,
|
||||
systemMessage: undefined,
|
||||
}),
|
||||
}));
|
||||
|
||||
/**
|
||||
* Array.fromAsync ponyfill, which will be available in es 2024.
|
||||
@@ -543,16 +558,22 @@ describe('Gemini Client (client.ts)', () => {
|
||||
await client.tryCompressChat('prompt-1', false); // force = false
|
||||
|
||||
// 3. Assert Step 1: Check that the flag became true
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((client as any).hasFailedCompressionAttempt).toBe(true);
|
||||
// 3. Assert Step 1: Check that the flag became true
|
||||
expect(
|
||||
(client as unknown as { hasFailedCompressionAttempt: boolean })
|
||||
.hasFailedCompressionAttempt,
|
||||
).toBe(true);
|
||||
|
||||
// 4. Test Step 2: Trigger a forced failure
|
||||
|
||||
await client.tryCompressChat('prompt-2', true); // force = true
|
||||
|
||||
// 5. Assert Step 2: Check that the flag REMAINS true
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((client as any).hasFailedCompressionAttempt).toBe(true);
|
||||
// 5. Assert Step 2: Check that the flag REMAINS true
|
||||
expect(
|
||||
(client as unknown as { hasFailedCompressionAttempt: boolean })
|
||||
.hasFailedCompressionAttempt,
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it('should not trigger summarization if token count is below threshold', async () => {
|
||||
@@ -2615,5 +2636,152 @@ ${JSON.stringify(
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
|
||||
describe('Hook System', () => {
|
||||
let mockMessageBus: { publish: Mock; subscribe: Mock };
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockMessageBus = { publish: vi.fn(), subscribe: vi.fn() };
|
||||
|
||||
// Force override config methods on the client instance
|
||||
client['config'].getEnableHooks = vi.fn().mockReturnValue(true);
|
||||
client['config'].getMessageBus = vi
|
||||
.fn()
|
||||
.mockReturnValue(mockMessageBus);
|
||||
});
|
||||
|
||||
it('should fire BeforeAgent and AfterAgent exactly once for a simple turn', async () => {
|
||||
const promptId = 'test-prompt-hook-1';
|
||||
const request = { text: 'Hello Hooks' };
|
||||
const signal = new AbortController().signal;
|
||||
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
||||
'./clientHookTriggers.js'
|
||||
);
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
) {
|
||||
this.getResponseText.mockReturnValue('Hook Response');
|
||||
yield { type: GeminiEventType.Content, value: 'Hook Response' };
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(request, signal, promptId);
|
||||
while (!(await stream.next()).done);
|
||||
|
||||
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
request,
|
||||
'Hook Response',
|
||||
);
|
||||
|
||||
// Map should be empty
|
||||
expect(client['hookStateMap'].size).toBe(0);
|
||||
});
|
||||
|
||||
it('should fire BeforeAgent once and AfterAgent once even with recursion', async () => {
|
||||
const { checkNextSpeaker } = await import(
|
||||
'../utils/nextSpeakerChecker.js'
|
||||
);
|
||||
vi.mocked(checkNextSpeaker)
|
||||
.mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
|
||||
.mockResolvedValueOnce(null);
|
||||
|
||||
const promptId = 'test-prompt-hook-recursive';
|
||||
const request = { text: 'Recursion Test' };
|
||||
const signal = new AbortController().signal;
|
||||
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
||||
'./clientHookTriggers.js'
|
||||
);
|
||||
|
||||
let callCount = 0;
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
) {
|
||||
callCount++;
|
||||
const response = `Response ${callCount}`;
|
||||
this.getResponseText.mockReturnValue(response);
|
||||
yield { type: GeminiEventType.Content, value: response };
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(request, signal, promptId);
|
||||
while (!(await stream.next()).done);
|
||||
|
||||
// BeforeAgent should fire ONLY once despite multiple internal turns
|
||||
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
||||
|
||||
// AfterAgent should fire ONLY when the stack unwinds
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Check cumulative response (separated by newline)
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
request,
|
||||
'Response 1\nResponse 2',
|
||||
);
|
||||
|
||||
expect(client['hookStateMap'].size).toBe(0);
|
||||
});
|
||||
|
||||
it('should use original request in AfterAgent hook even when continuation happened', async () => {
|
||||
const { checkNextSpeaker } = await import(
|
||||
'../utils/nextSpeakerChecker.js'
|
||||
);
|
||||
vi.mocked(checkNextSpeaker)
|
||||
.mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
|
||||
.mockResolvedValueOnce(null);
|
||||
|
||||
const promptId = 'test-prompt-hook-original-req';
|
||||
const request = { text: 'Do something' };
|
||||
const signal = new AbortController().signal;
|
||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
) {
|
||||
this.getResponseText.mockReturnValue('Ok');
|
||||
yield { type: GeminiEventType.Content, value: 'Ok' };
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(request, signal, promptId);
|
||||
while (!(await stream.next()).done);
|
||||
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
request, // Should be 'Do something'
|
||||
expect.stringContaining('Ok'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should cleanup state when prompt_id changes', async () => {
|
||||
const signal = new AbortController().signal;
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
) {
|
||||
this.getResponseText.mockReturnValue('Ok');
|
||||
yield { type: GeminiEventType.Content, value: 'Ok' };
|
||||
});
|
||||
|
||||
client['hookStateMap'].set('old-id', {
|
||||
hasFiredBeforeAgent: true,
|
||||
cumulativeResponse: 'Old',
|
||||
activeCalls: 0,
|
||||
originalRequest: { text: 'Old' },
|
||||
});
|
||||
client['lastPromptId'] = 'old-id';
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
{ text: 'New' },
|
||||
signal,
|
||||
'new-id',
|
||||
);
|
||||
await stream.next();
|
||||
|
||||
expect(client['hookStateMap'].has('old-id')).toBe(false);
|
||||
expect(client['hookStateMap'].has('new-id')).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
Tool,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
getDirectoryContextString,
|
||||
getInitialChatHistory,
|
||||
@@ -42,6 +43,7 @@ import {
|
||||
fireBeforeAgentHook,
|
||||
fireAfterAgentHook,
|
||||
} from './clientHookTriggers.js';
|
||||
import type { DefaultHookOutput } from '../hooks/types.js';
|
||||
import {
|
||||
ContentRetryFailureEvent,
|
||||
NextSpeakerCheckEvent,
|
||||
@@ -61,6 +63,14 @@ import type { RetryAvailabilityContext } from '../utils/retry.js';
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
|
||||
type BeforeAgentHookReturn =
|
||||
| {
|
||||
type: GeminiEventType.Error;
|
||||
value: { error: Error };
|
||||
}
|
||||
| { additionalContext: string | undefined }
|
||||
| undefined;
|
||||
|
||||
export class GeminiClient {
|
||||
private chat?: GeminiChat;
|
||||
private sessionTurnCount = 0;
|
||||
@@ -84,6 +94,95 @@ export class GeminiClient {
|
||||
this.lastPromptId = this.config.getSessionId();
|
||||
}
|
||||
|
||||
// Hook state to deduplicate BeforeAgent calls and track response for
|
||||
// AfterAgent
|
||||
private hookStateMap = new Map<
|
||||
string,
|
||||
{
|
||||
hasFiredBeforeAgent: boolean;
|
||||
cumulativeResponse: string;
|
||||
activeCalls: number;
|
||||
originalRequest: PartListUnion;
|
||||
}
|
||||
>();
|
||||
|
||||
private async fireBeforeAgentHookSafe(
|
||||
messageBus: MessageBus,
|
||||
request: PartListUnion,
|
||||
prompt_id: string,
|
||||
): Promise<BeforeAgentHookReturn> {
|
||||
let hookState = this.hookStateMap.get(prompt_id);
|
||||
if (!hookState) {
|
||||
hookState = {
|
||||
hasFiredBeforeAgent: false,
|
||||
cumulativeResponse: '',
|
||||
activeCalls: 0,
|
||||
originalRequest: request,
|
||||
};
|
||||
this.hookStateMap.set(prompt_id, hookState);
|
||||
}
|
||||
|
||||
// Increment active calls for this prompt_id
|
||||
// This is called at the start of sendMessageStream, so it acts as an entry
|
||||
// counter. We increment here, assuming this helper is ALWAYS called at
|
||||
// entry.
|
||||
hookState.activeCalls++;
|
||||
|
||||
if (hookState.hasFiredBeforeAgent) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
||||
hookState.hasFiredBeforeAgent = true;
|
||||
|
||||
if (hookOutput?.isBlockingDecision() || hookOutput?.shouldStopExecution()) {
|
||||
return {
|
||||
type: GeminiEventType.Error,
|
||||
value: {
|
||||
error: new Error(
|
||||
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
|
||||
),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const additionalContext = hookOutput?.getAdditionalContext();
|
||||
if (additionalContext) {
|
||||
return { additionalContext };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private async fireAfterAgentHookSafe(
|
||||
messageBus: MessageBus,
|
||||
currentRequest: PartListUnion,
|
||||
prompt_id: string,
|
||||
turn?: Turn,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
const hookState = this.hookStateMap.get(prompt_id);
|
||||
// Only fire on the outermost call (when activeCalls is 1)
|
||||
if (!hookState || hookState.activeCalls !== 1) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (turn && turn.pendingToolCalls.length > 0) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const finalResponseText =
|
||||
hookState.cumulativeResponse ||
|
||||
turn?.getResponseText() ||
|
||||
'[no response text]';
|
||||
const finalRequest = hookState.originalRequest || currentRequest;
|
||||
|
||||
const hookOutput = await fireAfterAgentHook(
|
||||
messageBus,
|
||||
finalRequest,
|
||||
finalResponseText,
|
||||
);
|
||||
return hookOutput;
|
||||
}
|
||||
|
||||
private updateTelemetryTokenCount() {
|
||||
if (this.chat) {
|
||||
uiTelemetryService.setLastPromptTokenCount(
|
||||
@@ -400,63 +499,27 @@ export class GeminiClient {
|
||||
return this.config.getActiveModel();
|
||||
}
|
||||
|
||||
async *sendMessageStream(
|
||||
private async *processTurn(
|
||||
request: PartListUnion,
|
||||
signal: AbortSignal,
|
||||
prompt_id: string,
|
||||
turns: number = MAX_TURNS,
|
||||
isInvalidStreamRetry: boolean = false,
|
||||
boundedTurns: number,
|
||||
isInvalidStreamRetry: boolean,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (!isInvalidStreamRetry) {
|
||||
this.config.resetTurn();
|
||||
}
|
||||
// Re-initialize turn (it was empty before if in loop, or new instance)
|
||||
let turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
||||
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
) {
|
||||
yield {
|
||||
type: GeminiEventType.Error,
|
||||
value: {
|
||||
error: new Error(
|
||||
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
|
||||
),
|
||||
},
|
||||
};
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
}
|
||||
|
||||
// Add additional context from hooks to the request
|
||||
const additionalContext = hookOutput?.getAdditionalContext();
|
||||
if (additionalContext) {
|
||||
const requestArray = Array.isArray(request) ? request : [request];
|
||||
request = [...requestArray, { text: additionalContext }];
|
||||
}
|
||||
}
|
||||
|
||||
if (this.lastPromptId !== prompt_id) {
|
||||
this.loopDetector.reset(prompt_id);
|
||||
this.lastPromptId = prompt_id;
|
||||
this.currentSequenceModel = null;
|
||||
}
|
||||
this.sessionTurnCount++;
|
||||
if (
|
||||
this.config.getMaxSessionTurns() > 0 &&
|
||||
this.sessionTurnCount > this.config.getMaxSessionTurns()
|
||||
) {
|
||||
yield { type: GeminiEventType.MaxSessionTurns };
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
return turn;
|
||||
}
|
||||
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
|
||||
const boundedTurns = Math.min(turns, MAX_TURNS);
|
||||
|
||||
if (!boundedTurns) {
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
return turn;
|
||||
}
|
||||
|
||||
// Check for context window overflow
|
||||
@@ -478,7 +541,7 @@ export class GeminiClient {
|
||||
type: GeminiEventType.ContextWindowWillOverflow,
|
||||
value: { estimatedRequestTokenCount, remainingTokenCount },
|
||||
};
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
return turn;
|
||||
}
|
||||
|
||||
const compressed = await this.tryCompressChat(prompt_id, false);
|
||||
@@ -514,7 +577,8 @@ export class GeminiClient {
|
||||
this.forceFullIdeContext = false;
|
||||
}
|
||||
|
||||
const turn = new Turn(this.getChat(), prompt_id);
|
||||
// Re-initialize turn with fresh history
|
||||
turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
const controller = new AbortController();
|
||||
const linkedSignal = AbortSignal.any([signal, controller.signal]);
|
||||
@@ -555,6 +619,9 @@ export class GeminiClient {
|
||||
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
||||
|
||||
const resultStream = turn.run(modelConfigKey, request, linkedSignal);
|
||||
let isError = false;
|
||||
let isInvalidStream = false;
|
||||
|
||||
for await (const event of resultStream) {
|
||||
if (this.loopDetector.addAndCheck(event)) {
|
||||
yield { type: GeminiEventType.LoopDetected };
|
||||
@@ -566,94 +633,181 @@ export class GeminiClient {
|
||||
this.updateTelemetryTokenCount();
|
||||
|
||||
if (event.type === GeminiEventType.InvalidStream) {
|
||||
if (this.config.getContinueOnFailedApiCall()) {
|
||||
if (isInvalidStreamRetry) {
|
||||
// We already retried once, so stop here.
|
||||
logContentRetryFailure(
|
||||
this.config,
|
||||
new ContentRetryFailureEvent(
|
||||
4, // 2 initial + 2 after injections
|
||||
'FAILED_AFTER_PROMPT_INJECTION',
|
||||
modelToUse,
|
||||
),
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
const nextRequest = [{ text: 'System: Please continue.' }];
|
||||
yield* this.sendMessageStream(
|
||||
nextRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
true, // Set isInvalidStreamRetry to true
|
||||
isInvalidStream = true;
|
||||
}
|
||||
if (event.type === GeminiEventType.Error) {
|
||||
isError = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isError) {
|
||||
return turn;
|
||||
}
|
||||
|
||||
// Update cumulative response in hook state
|
||||
// We do this immediately after the stream finishes for THIS turn.
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
if (hooksEnabled) {
|
||||
const responseText = turn.getResponseText() || '';
|
||||
const hookState = this.hookStateMap.get(prompt_id);
|
||||
if (hookState && responseText) {
|
||||
// Append with newline if not empty
|
||||
hookState.cumulativeResponse = hookState.cumulativeResponse
|
||||
? `${hookState.cumulativeResponse}\n${responseText}`
|
||||
: responseText;
|
||||
}
|
||||
}
|
||||
|
||||
if (isInvalidStream) {
|
||||
if (this.config.getContinueOnFailedApiCall()) {
|
||||
if (isInvalidStreamRetry) {
|
||||
logContentRetryFailure(
|
||||
this.config,
|
||||
new ContentRetryFailureEvent(
|
||||
4,
|
||||
'FAILED_AFTER_PROMPT_INJECTION',
|
||||
modelToUse,
|
||||
),
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
if (event.type === GeminiEventType.Error) {
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||
// Check if next speaker check is needed
|
||||
if (this.config.getQuotaErrorOccurred()) {
|
||||
return turn;
|
||||
}
|
||||
|
||||
if (this.config.getSkipNextSpeakerCheck()) {
|
||||
return turn;
|
||||
}
|
||||
|
||||
const nextSpeakerCheck = await checkNextSpeaker(
|
||||
this.getChat(),
|
||||
this.config.getBaseLlmClient(),
|
||||
signal,
|
||||
prompt_id,
|
||||
);
|
||||
logNextSpeakerCheck(
|
||||
this.config,
|
||||
new NextSpeakerCheckEvent(
|
||||
prompt_id,
|
||||
turn.finishReason?.toString() || '',
|
||||
nextSpeakerCheck?.next_speaker || '',
|
||||
),
|
||||
);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
// This recursive call's events will be yielded out, and the final
|
||||
// turn object from the recursive call will be returned.
|
||||
return yield* this.sendMessageStream(
|
||||
const nextRequest = [{ text: 'System: Please continue.' }];
|
||||
// Recursive call - update turn with result
|
||||
turn = yield* this.sendMessageStream(
|
||||
nextRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
// isInvalidStreamRetry is false here, as this is a next speaker check
|
||||
true,
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
|
||||
// Fire AfterAgent hook through MessageBus (only if hooks are enabled)
|
||||
if (hooksEnabled && messageBus) {
|
||||
const responseText = turn.getResponseText() || '[no response text]';
|
||||
const hookOutput = await fireAfterAgentHook(
|
||||
messageBus,
|
||||
request,
|
||||
responseText,
|
||||
);
|
||||
|
||||
// For AfterAgent hooks, blocking/stop execution should force continuation
|
||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
!this.config.getQuotaErrorOccurred() &&
|
||||
!this.config.getSkipNextSpeakerCheck()
|
||||
) {
|
||||
const continueReason = hookOutput.getEffectiveReason();
|
||||
const continueRequest = [{ text: continueReason }];
|
||||
yield* this.sendMessageStream(
|
||||
continueRequest,
|
||||
const nextSpeakerCheck = await checkNextSpeaker(
|
||||
this.getChat(),
|
||||
this.config.getBaseLlmClient(),
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
);
|
||||
logNextSpeakerCheck(
|
||||
this.config,
|
||||
new NextSpeakerCheckEvent(
|
||||
prompt_id,
|
||||
turn.finishReason?.toString() || '',
|
||||
nextSpeakerCheck?.next_speaker || '',
|
||||
),
|
||||
);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
turn = yield* this.sendMessageStream(
|
||||
nextRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
// isInvalidStreamRetry is false
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
}
|
||||
return turn;
|
||||
}
|
||||
|
||||
async *sendMessageStream(
|
||||
request: PartListUnion,
|
||||
signal: AbortSignal,
|
||||
prompt_id: string,
|
||||
turns: number = MAX_TURNS,
|
||||
isInvalidStreamRetry: boolean = false,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (!isInvalidStreamRetry) {
|
||||
this.config.resetTurn();
|
||||
}
|
||||
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
|
||||
if (this.lastPromptId !== prompt_id) {
|
||||
this.loopDetector.reset(prompt_id);
|
||||
this.hookStateMap.delete(this.lastPromptId);
|
||||
this.lastPromptId = prompt_id;
|
||||
this.currentSequenceModel = null;
|
||||
}
|
||||
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookResult = await this.fireBeforeAgentHookSafe(
|
||||
messageBus,
|
||||
request,
|
||||
prompt_id,
|
||||
);
|
||||
if (hookResult) {
|
||||
if ('type' in hookResult && hookResult.type === GeminiEventType.Error) {
|
||||
yield hookResult;
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
} else if ('additionalContext' in hookResult) {
|
||||
const additionalContext = hookResult.additionalContext;
|
||||
if (additionalContext) {
|
||||
const requestArray = Array.isArray(request) ? request : [request];
|
||||
request = [...requestArray, { text: additionalContext }];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const boundedTurns = Math.min(turns, MAX_TURNS);
|
||||
let turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
try {
|
||||
turn = yield* this.processTurn(
|
||||
request,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns,
|
||||
isInvalidStreamRetry,
|
||||
);
|
||||
|
||||
// Fire AfterAgent hook if we have a turn and no pending tools
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookOutput = await this.fireAfterAgentHookSafe(
|
||||
messageBus,
|
||||
request,
|
||||
prompt_id,
|
||||
turn,
|
||||
);
|
||||
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
) {
|
||||
const continueReason = hookOutput.getEffectiveReason();
|
||||
const continueRequest = [{ text: continueReason }];
|
||||
yield* this.sendMessageStream(
|
||||
continueRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
const hookState = this.hookStateMap.get(prompt_id);
|
||||
if (hookState) {
|
||||
hookState.activeCalls--;
|
||||
const isPendingTools =
|
||||
turn?.pendingToolCalls && turn.pendingToolCalls.length > 0;
|
||||
const isAborted = signal?.aborted;
|
||||
|
||||
if (hookState.activeCalls <= 0) {
|
||||
if (!isPendingTools || isAborted) {
|
||||
this.hookStateMap.delete(prompt_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user