fix(hooks): deduplicate agent hooks and add cross-platform integration tests (#15701)

This commit is contained in:
Abhi
2025-12-30 14:13:16 -05:00
committed by GitHub
parent 4e6fee7fcd
commit 15c9f88da6
6 changed files with 779 additions and 227 deletions

View File

@@ -0,0 +1,2 @@
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"."}}}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}]}
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Final Answer"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}]}

View File

@@ -0,0 +1 @@
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"**Responding**\n\nI will respond to the user's request.\n\n","thought":true}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":100,"totalTokenCount":120,"promptTokensDetails":[{"modality":"TEXT","tokenCount":100}],"thoughtsTokenCount":20}},{"candidates":[{"content":{"parts":[{"text":"Response to: "}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":5,"totalTokenCount":125,"promptTokensDetails":[{"modality":"TEXT","tokenCount":100}],"thoughtsTokenCount":20}},{"candidates":[{"content":{"parts":[{"text":"Hello World"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":7,"totalTokenCount":127,"promptTokensDetails":[{"modality":"TEXT","tokenCount":100}],"thoughtsTokenCount":20}}]}

View File

@@ -0,0 +1,238 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { TestRig } from './test-helper.js';
import { join } from 'node:path';
import { writeFileSync } from 'node:fs';
describe('Hooks Agent Flow', () => {
let rig: TestRig;
beforeEach(() => {
rig = new TestRig();
});
afterEach(async () => {
if (rig) {
await rig.cleanup();
}
});
describe('BeforeAgent Hooks', () => {
it('should inject additional context via BeforeAgent hook', async () => {
await rig.setup('should inject additional context via BeforeAgent hook', {
fakeResponsesPath: join(
import.meta.dirname,
'hooks-agent-flow.responses',
),
});
const hookScript = `
try {
const output = {
decision: "allow",
hookSpecificOutput: {
hookEventName: "BeforeAgent",
additionalContext: "SYSTEM INSTRUCTION: This is injected context."
}
};
process.stdout.write(JSON.stringify(output));
} catch (e) {
console.error('Failed to write stdout:', e);
process.exit(1);
}
console.error('DEBUG: BeforeAgent hook executed');
`;
const scriptPath = join(rig.testDir!, 'before_agent_context.cjs');
writeFileSync(scriptPath, hookScript);
await rig.setup('should inject additional context via BeforeAgent hook', {
settings: {
tools: {
enableHooks: true,
},
hooks: {
BeforeAgent: [
{
hooks: [
{
type: 'command',
command: `node "${scriptPath}"`,
timeout: 5000,
},
],
},
],
},
},
});
await rig.run({ args: 'Hello test' });
// Verify hook execution and telemetry
const hookTelemetryFound = await rig.waitForTelemetryEvent('hook_call');
expect(hookTelemetryFound).toBeTruthy();
const hookLogs = rig.readHookLogs();
const beforeAgentLog = hookLogs.find(
(log) => log.hookCall.hook_event_name === 'BeforeAgent',
);
expect(beforeAgentLog).toBeDefined();
expect(beforeAgentLog?.hookCall.stdout).toContain('injected context');
expect(beforeAgentLog?.hookCall.stdout).toContain('"decision":"allow"');
expect(beforeAgentLog?.hookCall.stdout).toContain(
'SYSTEM INSTRUCTION: This is injected context.',
);
});
});
describe('AfterAgent Hooks', () => {
it('should receive prompt and response in AfterAgent hook', async () => {
await rig.setup('should receive prompt and response in AfterAgent hook', {
fakeResponsesPath: join(
import.meta.dirname,
'hooks-agent-flow.responses',
),
});
const hookScript = `
const fs = require('fs');
try {
const input = fs.readFileSync(0, 'utf-8');
console.error('DEBUG: AfterAgent hook input received');
process.stdout.write("Received Input: " + input);
// Ensure separation between the echo and the JSON output if they were to run together (though relying on separate console calls usually separates by newline)
// usage of process.stdout.write does NOT add newline.
// But here we want strictly the output "Received Input..." to be present.
// We also need to output the JSON decision for the hook runner to consider it successful?
// Actually HookRunner parses the *last* valid JSON block or treats text as system message.
// If we output mixed text and JSON, HookRunner might get confused if we don't handle it right.
// Existing test expects "Received Input" in stdout. And "Hello World".
// It DOES NOT parse the decision?
// Wait, HookRunner logic:
// "if (exitCode === EXIT_CODE_SUCCESS && stdout.trim()) ... JSON.parse ..."
// If JSON.parse fails: "Not JSON, convert plain text to structured output"
// So if we output formatted text, it becomes "systemMessage".
// That is fine for this test as we don't check the decision, just the stdout content.
} catch (err) {
console.error('Hook Failed:', err);
process.exit(1);
}
`;
const scriptPath = join(rig.testDir!, 'after_agent_verify.cjs');
writeFileSync(scriptPath, hookScript);
await rig.setup('should receive prompt and response in AfterAgent hook', {
settings: {
tools: {
enableHooks: true,
},
hooks: {
AfterAgent: [
{
hooks: [
{
type: 'command',
command: `node "${scriptPath}"`,
timeout: 5000,
},
],
},
],
},
},
});
await rig.run({ args: 'Hello validation' });
const hookTelemetryFound = await rig.waitForTelemetryEvent('hook_call');
expect(hookTelemetryFound).toBeTruthy();
const hookLogs = rig.readHookLogs();
const afterAgentLog = hookLogs.find(
(log) => log.hookCall.hook_event_name === 'AfterAgent',
);
expect(afterAgentLog).toBeDefined();
// Verify the hook stdout contains the input we echoed which proves the hook received the prompt and response
expect(afterAgentLog?.hookCall.stdout).toContain('Received Input');
expect(afterAgentLog?.hookCall.stdout).toContain('Hello validation');
// The fake response contains "Hello World"
expect(afterAgentLog?.hookCall.stdout).toContain('Hello World');
});
});
describe('Multi-step Loops', () => {
it('should fire BeforeAgent and AfterAgent exactly once per turn despite tool calls', async () => {
await rig.setup(
'should fire BeforeAgent and AfterAgent exactly once per turn despite tool calls',
{
fakeResponsesPath: join(
import.meta.dirname,
'hooks-agent-flow-multistep.responses',
),
settings: {
tools: {
enableHooks: true,
},
hooks: {
BeforeAgent: [
{
hooks: [
{
type: 'command',
command: `node -e "console.log('BeforeAgent Fired')"`,
timeout: 5000,
},
],
},
],
AfterAgent: [
{
hooks: [
{
type: 'command',
command: `node -e "console.log('AfterAgent Fired')"`,
timeout: 5000,
},
],
},
],
},
},
},
);
await rig.run({ args: 'Do a multi-step task' });
const hookLogs = rig.readHookLogs();
const beforeAgentLogs = hookLogs.filter(
(log) => log.hookCall.hook_event_name === 'BeforeAgent',
);
const afterAgentLogs = hookLogs.filter(
(log) => log.hookCall.hook_event_name === 'AfterAgent',
);
// Should ensure BeforeAgent fired once
expect(beforeAgentLogs).toHaveLength(1);
// Should ensure AfterAgent fired once
// Note: If the tool call itself triggered BeforeTool/AfterTool, that's fine,
// but BeforeAgent/AfterAgent should only wrap the *entire* turn (User Request -> Final Answer).
expect(afterAgentLogs).toHaveLength(1);
// Verify the output log content to ensure we actually got the final answer
// (This implies the loop completed successfully)
const afterAgentLog = afterAgentLogs[0];
expect(afterAgentLog).toBeDefined();
expect(afterAgentLog?.hookCall.stdout).toContain('AfterAgent Fired');
});
});
});

View File

@@ -44,7 +44,7 @@ describe('Hooks System Integration', () => {
{
type: 'command',
command:
'echo "{\\"decision\\": \\"block\\", \\"reason\\": \\"File writing blocked by security policy\\"}"',
"node -e \"console.log(JSON.stringify({decision: 'block', reason: 'File writing blocked by security policy'}))\"",
timeout: 5000,
},
],
@@ -97,7 +97,7 @@ describe('Hooks System Integration', () => {
{
type: 'command',
command:
'echo "{\\"decision\\": \\"allow\\", \\"reason\\": \\"File writing approved\\"}"',
"node -e \"console.log(JSON.stringify({decision: 'allow', reason: 'File writing approved'}))\"",
timeout: 5000,
},
],
@@ -129,7 +129,7 @@ describe('Hooks System Integration', () => {
describe('Command Hooks - Additional Context', () => {
it('should add additional context from AfterTool hooks', async () => {
const command =
'echo "{\\"hookSpecificOutput\\": {\\"hookEventName\\": \\"AfterTool\\", \\"additionalContext\\": \\"Security scan: File content appears safe\\"}}"';
"node -e \"console.log(JSON.stringify({hookSpecificOutput: {hookEventName: 'AfterTool', additionalContext: 'Security scan: File content appears safe'}}))\"";
await rig.setup('should add additional context from AfterTool hooks', {
fakeResponsesPath: join(
import.meta.dirname,
@@ -190,27 +190,24 @@ describe('Hooks System Integration', () => {
'hooks-system.before-model.responses',
),
});
const hookScript = `#!/bin/bash
echo '{
"decision": "allow",
"hookSpecificOutput": {
"hookEventName": "BeforeModel",
"llm_request": {
"messages": [
const hookScript = `const fs = require('fs');
console.log(JSON.stringify({
decision: "allow",
hookSpecificOutput: {
hookEventName: "BeforeModel",
llm_request: {
messages: [
{
"role": "user",
"content": "Please respond with exactly: The security hook modified this request successfully."
role: "user",
content: "Please respond with exactly: The security hook modified this request successfully."
}
]
}
}
}'`;
}));`;
const scriptPath = join(rig.testDir!, 'before_model_hook.sh');
const scriptPath = join(rig.testDir!, 'before_model_hook.cjs');
writeFileSync(scriptPath, hookScript);
// Make executable
const { execSync } = await import('node:child_process');
execSync(`chmod +x "${scriptPath}"`);
await rig.setup('should modify LLM requests with BeforeModel hooks', {
settings: {
@@ -223,7 +220,7 @@ echo '{
hooks: [
{
type: 'command',
command: scriptPath,
command: `node "${scriptPath}"`,
timeout: 5000,
},
],
@@ -250,7 +247,9 @@ echo '{
expect(hookTelemetryFound[0].hookCall.hook_event_name).toBe(
'BeforeModel',
);
expect(hookTelemetryFound[0].hookCall.hook_name).toBe(scriptPath);
expect(hookTelemetryFound[0].hookCall.hook_name).toBe(
`node "${scriptPath}"`,
);
expect(hookTelemetryFound[0].hookCall.hook_input).toBeDefined();
expect(hookTelemetryFound[0].hookCall.hook_output).toBeDefined();
expect(hookTelemetryFound[0].hookCall.exit_code).toBe(0);
@@ -270,30 +269,28 @@ echo '{
),
});
// Create a hook script that modifies the LLM response
const hookScript = `#!/bin/bash
echo '{
"hookSpecificOutput": {
"hookEventName": "AfterModel",
"llm_response": {
"candidates": [
const hookScript = `const fs = require('fs');
console.log(JSON.stringify({
hookSpecificOutput: {
hookEventName: "AfterModel",
llm_response: {
candidates: [
{
"content": {
"role": "model",
"parts": [
content: {
role: "model",
parts: [
"[FILTERED] Response has been filtered for security compliance."
]
},
"finishReason": "STOP"
finishReason: "STOP"
}
]
}
}
}'`;
}));`;
const scriptPath = join(rig.testDir!, 'after_model_hook.sh');
const scriptPath = join(rig.testDir!, 'after_model_hook.cjs');
writeFileSync(scriptPath, hookScript);
const { execSync } = await import('node:child_process');
execSync(`chmod +x "${scriptPath}"`);
await rig.setup('should modify LLM responses with AfterModel hooks', {
settings: {
@@ -306,7 +303,7 @@ echo '{
hooks: [
{
type: 'command',
command: scriptPath,
command: `node "${scriptPath}"`,
timeout: 5000,
},
],
@@ -343,7 +340,7 @@ echo '{
);
// Create inline hook command (works on both Unix and Windows)
const hookCommand =
'echo "{\\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeToolSelection\\", \\"toolConfig\\": {\\"mode\\": \\"ANY\\", \\"allowedFunctionNames\\": [\\"read_file\\", \\"run_shell_command\\"]}}}"';
"node -e \"console.log(JSON.stringify({hookSpecificOutput: {hookEventName: 'BeforeToolSelection', toolConfig: {mode: 'ANY', allowedFunctionNames: ['read_file', 'run_shell_command']}}}))\"";
await rig.setup(
'should modify tool selection with BeforeToolSelection hooks',
@@ -404,19 +401,17 @@ echo '{
),
});
// Create a hook script that adds context to the prompt
const hookScript = `#!/bin/bash
echo '{
"decision": "allow",
"hookSpecificOutput": {
"hookEventName": "BeforeAgent",
"additionalContext": "SYSTEM INSTRUCTION: You are in a secure environment. Always mention security compliance in your responses."
const hookScript = `const fs = require('fs');
console.log(JSON.stringify({
decision: "allow",
hookSpecificOutput: {
hookEventName: "BeforeAgent",
additionalContext: "SYSTEM INSTRUCTION: You are in a secure environment. Always mention security compliance in your responses."
}
}'`;
}));`;
const scriptPath = join(rig.testDir!, 'before_agent_hook.sh');
const scriptPath = join(rig.testDir!, 'before_agent_hook.cjs');
writeFileSync(scriptPath, hookScript);
const { execSync } = await import('node:child_process');
execSync(`chmod +x "${scriptPath}"`);
await rig.setup('should augment prompts with BeforeAgent hooks', {
settings: {
@@ -429,7 +424,7 @@ echo '{
hooks: [
{
type: 'command',
command: scriptPath,
command: `node "${scriptPath}"`,
timeout: 5000,
},
],
@@ -452,9 +447,10 @@ echo '{
describe('Notification Hooks - Permission Handling', () => {
it('should handle notification hooks for tool permissions', async () => {
// Create inline hook command (works on both Unix and Windows)
// Create inline hook command (works on both Unix and Windows)
const hookCommand =
'echo "{\\"suppressOutput\\": false, \\"systemMessage\\": \\"Permission request logged by security hook\\"}"';
'node -e "console.log(JSON.stringify({suppressOutput: false, systemMessage: \'Permission request logged by security hook\'}))"';
await rig.setup('should handle notification hooks for tool permissions', {
fakeResponsesPath: join(
@@ -548,9 +544,9 @@ echo '{
it('should execute hooks sequentially when configured', async () => {
// Create inline hook commands (works on both Unix and Windows)
const hook1Command =
'echo "{\\"decision\\": \\"allow\\", \\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeAgent\\", \\"additionalContext\\": \\"Step 1: Initial validation passed.\\"}}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', hookSpecificOutput: {hookEventName: 'BeforeAgent', additionalContext: 'Step 1: Initial validation passed.'}}))\"";
const hook2Command =
'echo "{\\"decision\\": \\"allow\\", \\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeAgent\\", \\"additionalContext\\": \\"Step 2: Security check completed.\\"}}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', hookSpecificOutput: {hookEventName: 'BeforeAgent', additionalContext: 'Step 2: Security check completed.'}}))\"";
await rig.setup('should execute hooks sequentially when configured', {
fakeResponsesPath: join(
@@ -621,23 +617,22 @@ echo '{
),
});
// Create a hook script that validates the input format
const hookScript = `#!/bin/bash
# Read JSON input from stdin
input=$(cat)
const hookScript = `const fs = require('fs');
const input = fs.readFileSync(0, 'utf-8');
try {
const json = JSON.parse(input);
// Check fields
if (json.session_id && json.cwd && json.hook_event_name && json.timestamp && json.tool_name && json.tool_input) {
console.log(JSON.stringify({decision: "allow", reason: "Input format is correct"}));
} else {
console.log(JSON.stringify({decision: "block", reason: "Input format is invalid"}));
}
} catch (e) {
console.log(JSON.stringify({decision: "block", reason: "Invalid JSON"}));
}`;
# Check for required fields
if echo "$input" | jq -e '.session_id and .cwd and .hook_event_name and .timestamp and .tool_name and .tool_input' > /dev/null 2>&1; then
echo '{"decision": "allow", "reason": "Input format is correct"}'
exit 0
else
echo '{"decision": "block", "reason": "Input format is invalid"}'
exit 0
fi`;
const scriptPath = join(rig.testDir!, 'input_validation_hook.sh');
const scriptPath = join(rig.testDir!, 'input_validation_hook.cjs');
writeFileSync(scriptPath, hookScript);
const { execSync } = await import('node:child_process');
execSync(`chmod +x "${scriptPath}"`);
await rig.setup('should provide correct input format to hooks', {
settings: {
@@ -650,7 +645,7 @@ fi`;
hooks: [
{
type: 'command',
command: scriptPath,
command: `node "${scriptPath}"`,
timeout: 5000,
},
],
@@ -682,11 +677,11 @@ fi`;
it('should handle hooks for all major event types', async () => {
// Create inline hook commands (works on both Unix and Windows)
const beforeToolCommand =
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"BeforeTool: File operation logged\\"}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'BeforeTool: File operation logged'}))\"";
const afterToolCommand =
'echo "{\\"hookSpecificOutput\\": {\\"hookEventName\\": \\"AfterTool\\", \\"additionalContext\\": \\"AfterTool: Operation completed successfully\\"}}"';
"node -e \"console.log(JSON.stringify({hookSpecificOutput: {hookEventName: 'AfterTool', additionalContext: 'AfterTool: Operation completed successfully'}}))\"";
const beforeAgentCommand =
'echo "{\\"decision\\": \\"allow\\", \\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeAgent\\", \\"additionalContext\\": \\"BeforeAgent: User request processed\\"}}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', hookSpecificOutput: {hookEventName: 'BeforeAgent', additionalContext: 'BeforeAgent: User request processed'}}))\"";
await rig.setup('should handle hooks for all major event types', {
fakeResponsesPath: join(
@@ -802,10 +797,10 @@ fi`;
// Create a hook script that fails
// Create inline hook commands (works on both Unix and Windows)
// Failing hook: exits with non-zero code
const failingCommand = 'exit 1';
const failingCommand = 'node -e "process.exit(1)"';
// Working hook: returns success with JSON
const workingCommand =
'echo "{\\"decision\\": \\"allow\\", \\"reason\\": \\"Working hook succeeded\\"}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', reason: 'Working hook succeeded'}))\"";
await rig.setup('should handle hook failures gracefully', {
settings: {
@@ -855,7 +850,7 @@ fi`;
it('should generate telemetry events for hook executions', async () => {
// Create inline hook command (works on both Unix and Windows)
const hookCommand =
'echo "{\\"decision\\": \\"allow\\", \\"reason\\": \\"Telemetry test hook\\"}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', reason: 'Telemetry test hook'}))\"";
await rig.setup('should generate telemetry events for hook executions', {
fakeResponsesPath: join(
@@ -898,7 +893,7 @@ fi`;
it('should fire SessionStart hook on app startup', async () => {
// Create inline hook command that outputs JSON
const sessionStartCommand =
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"Session starting on startup\\"}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'Session starting on startup'}))\"";
await rig.setup('should fire SessionStart hook on app startup', {
fakeResponsesPath: join(
@@ -958,9 +953,9 @@ fi`;
it('should fire SessionEnd and SessionStart hooks on /clear command', async () => {
// Create inline hook commands for both SessionEnd and SessionStart
const sessionEndCommand =
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"Session ending due to clear\\"}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'Session ending due to clear'}))\"";
const sessionStartCommand =
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"Session starting after clear\\"}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'Session starting after clear'}))\"";
await rig.setup(
'should fire SessionEnd and SessionStart hooks on /clear command',
@@ -1136,7 +1131,7 @@ fi`;
it('should fire PreCompress hook on automatic compression', async () => {
// Create inline hook command that outputs JSON
const preCompressCommand =
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"PreCompress hook executed for automatic compression\\"}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'PreCompress hook executed for automatic compression'}))\"";
await rig.setup('should fire PreCompress hook on automatic compression', {
fakeResponsesPath: join(
@@ -1203,7 +1198,7 @@ fi`;
describe('SessionEnd on Exit', () => {
it('should fire SessionEnd hook on graceful exit in non-interactive mode', async () => {
const sessionEndCommand =
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"SessionEnd hook executed on exit\\"}"';
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'SessionEnd hook executed on exit'}))\"";
await rig.setup('should fire SessionEnd hook on graceful exit', {
fakeResponsesPath: join(
@@ -1297,20 +1292,17 @@ fi`;
});
// Create two hook scripts - one enabled, one disabled
const enabledHookScript = `#!/bin/bash
echo '{"decision": "allow", "systemMessage": "Enabled hook executed"}'`;
const enabledHookScript = `const fs = require('fs');
console.log(JSON.stringify({decision: "allow", systemMessage: "Enabled hook executed"}));`;
const disabledHookScript = `#!/bin/bash
echo '{"decision": "block", "systemMessage": "Disabled hook should not execute", "reason": "This hook should be disabled"}'`;
const disabledHookScript = `const fs = require('fs');
console.log(JSON.stringify({decision: "block", systemMessage: "Disabled hook should not execute", reason: "This hook should be disabled"}));`;
const enabledPath = join(rig.testDir!, 'enabled_hook.sh');
const disabledPath = join(rig.testDir!, 'disabled_hook.sh');
const enabledPath = join(rig.testDir!, 'enabled_hook.cjs');
const disabledPath = join(rig.testDir!, 'disabled_hook.cjs');
writeFileSync(enabledPath, enabledHookScript);
writeFileSync(disabledPath, disabledHookScript);
const { execSync } = await import('node:child_process');
execSync(`chmod +x "${enabledPath}"`);
execSync(`chmod +x "${disabledPath}"`);
await rig.setup('should not execute hooks disabled in settings file', {
settings: {
@@ -1323,18 +1315,18 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
hooks: [
{
type: 'command',
command: enabledPath,
command: `node "${enabledPath}"`,
timeout: 5000,
},
{
type: 'command',
command: disabledPath,
command: `node "${disabledPath}"`,
timeout: 5000,
},
],
},
],
disabled: [disabledPath], // Disable the second hook
disabled: [`node "${disabledPath}"`], // Disable the second hook
},
},
});
@@ -1358,10 +1350,10 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
// Check hook telemetry - only enabled hook should have executed
const hookLogs = rig.readHookLogs();
const enabledHookLog = hookLogs.find(
(log) => log.hookCall.hook_name === enabledPath,
(log) => log.hookCall.hook_name === `node "${enabledPath}"`,
);
const disabledHookLog = hookLogs.find(
(log) => log.hookCall.hook_name === disabledPath,
(log) => log.hookCall.hook_name === `node "${disabledPath}"`,
);
expect(enabledHookLog).toBeDefined();
@@ -1380,20 +1372,17 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
);
// Create two hook scripts - one that will be disabled, one that won't
const activeHookScript = `#!/bin/bash
echo '{"decision": "allow", "systemMessage": "Active hook executed"}'`;
const activeHookScript = `const fs = require('fs');
console.log(JSON.stringify({decision: "allow", systemMessage: "Active hook executed"}));`;
const disabledHookScript = `#!/bin/bash
echo '{"decision": "block", "systemMessage": "Disabled hook should not execute", "reason": "This hook is disabled"}'`;
const disabledHookScript = `const fs = require('fs');
console.log(JSON.stringify({decision: "block", systemMessage: "Disabled hook should not execute", reason: "This hook is disabled"}));`;
const activePath = join(rig.testDir!, 'active_hook.sh');
const disabledPath = join(rig.testDir!, 'disabled_hook.sh');
const activePath = join(rig.testDir!, 'active_hook.cjs');
const disabledPath = join(rig.testDir!, 'disabled_hook.cjs');
writeFileSync(activePath, activeHookScript);
writeFileSync(disabledPath, disabledHookScript);
const { execSync } = await import('node:child_process');
execSync(`chmod +x "${activePath}"`);
execSync(`chmod +x "${disabledPath}"`);
await rig.setup(
'should respect disabled hooks across multiple operations',
@@ -1408,18 +1397,18 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
hooks: [
{
type: 'command',
command: activePath,
command: `node "${activePath}"`,
timeout: 5000,
},
{
type: 'command',
command: disabledPath,
command: `node "${disabledPath}"`,
timeout: 5000,
},
],
},
],
disabled: [disabledPath], // Disable the second hook
disabled: [`node "${disabledPath}"`], // Disable the second hook
},
},
},
@@ -1441,10 +1430,10 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
// Check hook telemetry
const hookLogs1 = rig.readHookLogs();
const activeHookLog1 = hookLogs1.find(
(log) => log.hookCall.hook_name === activePath,
(log) => log.hookCall.hook_name === `node "${activePath}"`,
);
const disabledHookLog1 = hookLogs1.find(
(log) => log.hookCall.hook_name === disabledPath,
(log) => log.hookCall.hook_name === `node "${disabledPath}"`,
);
expect(activeHookLog1).toBeDefined();
@@ -1465,7 +1454,7 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
// Verify disabled hook still hasn't executed
const hookLogs2 = rig.readHookLogs();
const disabledHookCalls = hookLogs2.filter(
(log) => log.hookCall.hook_name === disabledPath,
(log) => log.hookCall.hook_name === `node "${disabledPath}"`,
);
expect(disabledHookCalls.length).toBe(0);
});

View File

@@ -81,6 +81,10 @@ vi.mock('node:fs', () => {
});
// --- Mocks ---
interface MockTurnContext {
getResponseText: Mock<() => string>;
}
const mockTurnRunFn = vi.fn();
vi.mock('./turn', async (importOriginal) => {
@@ -94,6 +98,8 @@ vi.mock('./turn', async (importOriginal) => {
constructor() {
// The constructor can be empty or do some mock setup
}
getResponseText = vi.fn().mockReturnValue('Mock Response');
}
// Export the mock class as 'Turn'
return {
@@ -129,6 +135,15 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
},
}));
vi.mock('../hooks/hookSystem.js');
vi.mock('./clientHookTriggers.js', () => ({
fireBeforeAgentHook: vi.fn(),
fireAfterAgentHook: vi.fn().mockResolvedValue({
decision: 'allow',
continue: false,
suppressOutput: false,
systemMessage: undefined,
}),
}));
/**
* Array.fromAsync ponyfill, which will be available in es 2024.
@@ -543,16 +558,22 @@ describe('Gemini Client (client.ts)', () => {
await client.tryCompressChat('prompt-1', false); // force = false
// 3. Assert Step 1: Check that the flag became true
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((client as any).hasFailedCompressionAttempt).toBe(true);
// 3. Assert Step 1: Check that the flag became true
expect(
(client as unknown as { hasFailedCompressionAttempt: boolean })
.hasFailedCompressionAttempt,
).toBe(true);
// 4. Test Step 2: Trigger a forced failure
await client.tryCompressChat('prompt-2', true); // force = true
// 5. Assert Step 2: Check that the flag REMAINS true
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((client as any).hasFailedCompressionAttempt).toBe(true);
// 5. Assert Step 2: Check that the flag REMAINS true
expect(
(client as unknown as { hasFailedCompressionAttempt: boolean })
.hasFailedCompressionAttempt,
).toBe(true);
});
it('should not trigger summarization if token count is below threshold', async () => {
@@ -2615,5 +2636,152 @@ ${JSON.stringify(
'test-session-id',
);
});
describe('Hook System', () => {
let mockMessageBus: { publish: Mock; subscribe: Mock };
beforeEach(() => {
vi.clearAllMocks();
mockMessageBus = { publish: vi.fn(), subscribe: vi.fn() };
// Force override config methods on the client instance
client['config'].getEnableHooks = vi.fn().mockReturnValue(true);
client['config'].getMessageBus = vi
.fn()
.mockReturnValue(mockMessageBus);
});
it('should fire BeforeAgent and AfterAgent exactly once for a simple turn', async () => {
const promptId = 'test-prompt-hook-1';
const request = { text: 'Hello Hooks' };
const signal = new AbortController().signal;
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
'./clientHookTriggers.js'
);
mockTurnRunFn.mockImplementation(async function* (
this: MockTurnContext,
) {
this.getResponseText.mockReturnValue('Hook Response');
yield { type: GeminiEventType.Content, value: 'Hook Response' };
});
const stream = client.sendMessageStream(request, signal, promptId);
while (!(await stream.next()).done);
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
expect(fireAfterAgentHook).toHaveBeenCalledWith(
expect.anything(),
request,
'Hook Response',
);
// Map should be empty
expect(client['hookStateMap'].size).toBe(0);
});
it('should fire BeforeAgent once and AfterAgent once even with recursion', async () => {
const { checkNextSpeaker } = await import(
'../utils/nextSpeakerChecker.js'
);
vi.mocked(checkNextSpeaker)
.mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
.mockResolvedValueOnce(null);
const promptId = 'test-prompt-hook-recursive';
const request = { text: 'Recursion Test' };
const signal = new AbortController().signal;
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
'./clientHookTriggers.js'
);
let callCount = 0;
mockTurnRunFn.mockImplementation(async function* (
this: MockTurnContext,
) {
callCount++;
const response = `Response ${callCount}`;
this.getResponseText.mockReturnValue(response);
yield { type: GeminiEventType.Content, value: response };
});
const stream = client.sendMessageStream(request, signal, promptId);
while (!(await stream.next()).done);
// BeforeAgent should fire ONLY once despite multiple internal turns
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
// AfterAgent should fire ONLY when the stack unwinds
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
// Check cumulative response (separated by newline)
expect(fireAfterAgentHook).toHaveBeenCalledWith(
expect.anything(),
request,
'Response 1\nResponse 2',
);
expect(client['hookStateMap'].size).toBe(0);
});
it('should use original request in AfterAgent hook even when continuation happened', async () => {
const { checkNextSpeaker } = await import(
'../utils/nextSpeakerChecker.js'
);
vi.mocked(checkNextSpeaker)
.mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
.mockResolvedValueOnce(null);
const promptId = 'test-prompt-hook-original-req';
const request = { text: 'Do something' };
const signal = new AbortController().signal;
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
mockTurnRunFn.mockImplementation(async function* (
this: MockTurnContext,
) {
this.getResponseText.mockReturnValue('Ok');
yield { type: GeminiEventType.Content, value: 'Ok' };
});
const stream = client.sendMessageStream(request, signal, promptId);
while (!(await stream.next()).done);
expect(fireAfterAgentHook).toHaveBeenCalledWith(
expect.anything(),
request, // Should be 'Do something'
expect.stringContaining('Ok'),
);
});
it('should cleanup state when prompt_id changes', async () => {
const signal = new AbortController().signal;
mockTurnRunFn.mockImplementation(async function* (
this: MockTurnContext,
) {
this.getResponseText.mockReturnValue('Ok');
yield { type: GeminiEventType.Content, value: 'Ok' };
});
client['hookStateMap'].set('old-id', {
hasFiredBeforeAgent: true,
cumulativeResponse: 'Old',
activeCalls: 0,
originalRequest: { text: 'Old' },
});
client['lastPromptId'] = 'old-id';
const stream = client.sendMessageStream(
{ text: 'New' },
signal,
'new-id',
);
await stream.next();
expect(client['hookStateMap'].has('old-id')).toBe(false);
expect(client['hookStateMap'].has('new-id')).toBe(true);
});
});
});
});

View File

@@ -11,6 +11,7 @@ import type {
Tool,
GenerateContentResponse,
} from '@google/genai';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import {
getDirectoryContextString,
getInitialChatHistory,
@@ -42,6 +43,7 @@ import {
fireBeforeAgentHook,
fireAfterAgentHook,
} from './clientHookTriggers.js';
import type { DefaultHookOutput } from '../hooks/types.js';
import {
ContentRetryFailureEvent,
NextSpeakerCheckEvent,
@@ -61,6 +63,14 @@ import type { RetryAvailabilityContext } from '../utils/retry.js';
const MAX_TURNS = 100;
type BeforeAgentHookReturn =
| {
type: GeminiEventType.Error;
value: { error: Error };
}
| { additionalContext: string | undefined }
| undefined;
export class GeminiClient {
private chat?: GeminiChat;
private sessionTurnCount = 0;
@@ -84,6 +94,95 @@ export class GeminiClient {
this.lastPromptId = this.config.getSessionId();
}
// Hook state to deduplicate BeforeAgent calls and track response for
// AfterAgent
private hookStateMap = new Map<
string,
{
hasFiredBeforeAgent: boolean;
cumulativeResponse: string;
activeCalls: number;
originalRequest: PartListUnion;
}
>();
private async fireBeforeAgentHookSafe(
messageBus: MessageBus,
request: PartListUnion,
prompt_id: string,
): Promise<BeforeAgentHookReturn> {
let hookState = this.hookStateMap.get(prompt_id);
if (!hookState) {
hookState = {
hasFiredBeforeAgent: false,
cumulativeResponse: '',
activeCalls: 0,
originalRequest: request,
};
this.hookStateMap.set(prompt_id, hookState);
}
// Increment active calls for this prompt_id
// This is called at the start of sendMessageStream, so it acts as an entry
// counter. We increment here, assuming this helper is ALWAYS called at
// entry.
hookState.activeCalls++;
if (hookState.hasFiredBeforeAgent) {
return undefined;
}
const hookOutput = await fireBeforeAgentHook(messageBus, request);
hookState.hasFiredBeforeAgent = true;
if (hookOutput?.isBlockingDecision() || hookOutput?.shouldStopExecution()) {
return {
type: GeminiEventType.Error,
value: {
error: new Error(
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
),
},
};
}
const additionalContext = hookOutput?.getAdditionalContext();
if (additionalContext) {
return { additionalContext };
}
return undefined;
}
private async fireAfterAgentHookSafe(
messageBus: MessageBus,
currentRequest: PartListUnion,
prompt_id: string,
turn?: Turn,
): Promise<DefaultHookOutput | undefined> {
const hookState = this.hookStateMap.get(prompt_id);
// Only fire on the outermost call (when activeCalls is 1)
if (!hookState || hookState.activeCalls !== 1) {
return undefined;
}
if (turn && turn.pendingToolCalls.length > 0) {
return undefined;
}
const finalResponseText =
hookState.cumulativeResponse ||
turn?.getResponseText() ||
'[no response text]';
const finalRequest = hookState.originalRequest || currentRequest;
const hookOutput = await fireAfterAgentHook(
messageBus,
finalRequest,
finalResponseText,
);
return hookOutput;
}
private updateTelemetryTokenCount() {
if (this.chat) {
uiTelemetryService.setLastPromptTokenCount(
@@ -400,63 +499,27 @@ export class GeminiClient {
return this.config.getActiveModel();
}
async *sendMessageStream(
private async *processTurn(
request: PartListUnion,
signal: AbortSignal,
prompt_id: string,
turns: number = MAX_TURNS,
isInvalidStreamRetry: boolean = false,
boundedTurns: number,
isInvalidStreamRetry: boolean,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!isInvalidStreamRetry) {
this.config.resetTurn();
}
// Re-initialize turn (it was empty before if in loop, or new instance)
let turn = new Turn(this.getChat(), prompt_id);
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
if (hooksEnabled && messageBus) {
const hookOutput = await fireBeforeAgentHook(messageBus, request);
if (
hookOutput?.isBlockingDecision() ||
hookOutput?.shouldStopExecution()
) {
yield {
type: GeminiEventType.Error,
value: {
error: new Error(
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
),
},
};
return new Turn(this.getChat(), prompt_id);
}
// Add additional context from hooks to the request
const additionalContext = hookOutput?.getAdditionalContext();
if (additionalContext) {
const requestArray = Array.isArray(request) ? request : [request];
request = [...requestArray, { text: additionalContext }];
}
}
if (this.lastPromptId !== prompt_id) {
this.loopDetector.reset(prompt_id);
this.lastPromptId = prompt_id;
this.currentSequenceModel = null;
}
this.sessionTurnCount++;
if (
this.config.getMaxSessionTurns() > 0 &&
this.sessionTurnCount > this.config.getMaxSessionTurns()
) {
yield { type: GeminiEventType.MaxSessionTurns };
return new Turn(this.getChat(), prompt_id);
return turn;
}
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
const boundedTurns = Math.min(turns, MAX_TURNS);
if (!boundedTurns) {
return new Turn(this.getChat(), prompt_id);
return turn;
}
// Check for context window overflow
@@ -478,7 +541,7 @@ export class GeminiClient {
type: GeminiEventType.ContextWindowWillOverflow,
value: { estimatedRequestTokenCount, remainingTokenCount },
};
return new Turn(this.getChat(), prompt_id);
return turn;
}
const compressed = await this.tryCompressChat(prompt_id, false);
@@ -514,7 +577,8 @@ export class GeminiClient {
this.forceFullIdeContext = false;
}
const turn = new Turn(this.getChat(), prompt_id);
// Re-initialize turn with fresh history
turn = new Turn(this.getChat(), prompt_id);
const controller = new AbortController();
const linkedSignal = AbortSignal.any([signal, controller.signal]);
@@ -555,6 +619,9 @@ export class GeminiClient {
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
const resultStream = turn.run(modelConfigKey, request, linkedSignal);
let isError = false;
let isInvalidStream = false;
for await (const event of resultStream) {
if (this.loopDetector.addAndCheck(event)) {
yield { type: GeminiEventType.LoopDetected };
@@ -566,94 +633,181 @@ export class GeminiClient {
this.updateTelemetryTokenCount();
if (event.type === GeminiEventType.InvalidStream) {
if (this.config.getContinueOnFailedApiCall()) {
if (isInvalidStreamRetry) {
// We already retried once, so stop here.
logContentRetryFailure(
this.config,
new ContentRetryFailureEvent(
4, // 2 initial + 2 after injections
'FAILED_AFTER_PROMPT_INJECTION',
modelToUse,
),
);
return turn;
}
const nextRequest = [{ text: 'System: Please continue.' }];
yield* this.sendMessageStream(
nextRequest,
signal,
prompt_id,
boundedTurns - 1,
true, // Set isInvalidStreamRetry to true
isInvalidStream = true;
}
if (event.type === GeminiEventType.Error) {
isError = true;
}
}
if (isError) {
return turn;
}
// Update cumulative response in hook state
// We do this immediately after the stream finishes for THIS turn.
const hooksEnabled = this.config.getEnableHooks();
if (hooksEnabled) {
const responseText = turn.getResponseText() || '';
const hookState = this.hookStateMap.get(prompt_id);
if (hookState && responseText) {
// Append with newline if not empty
hookState.cumulativeResponse = hookState.cumulativeResponse
? `${hookState.cumulativeResponse}\n${responseText}`
: responseText;
}
}
if (isInvalidStream) {
if (this.config.getContinueOnFailedApiCall()) {
if (isInvalidStreamRetry) {
logContentRetryFailure(
this.config,
new ContentRetryFailureEvent(
4,
'FAILED_AFTER_PROMPT_INJECTION',
modelToUse,
),
);
return turn;
}
}
if (event.type === GeminiEventType.Error) {
return turn;
}
}
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
// Check if next speaker check is needed
if (this.config.getQuotaErrorOccurred()) {
return turn;
}
if (this.config.getSkipNextSpeakerCheck()) {
return turn;
}
const nextSpeakerCheck = await checkNextSpeaker(
this.getChat(),
this.config.getBaseLlmClient(),
signal,
prompt_id,
);
logNextSpeakerCheck(
this.config,
new NextSpeakerCheckEvent(
prompt_id,
turn.finishReason?.toString() || '',
nextSpeakerCheck?.next_speaker || '',
),
);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
// This recursive call's events will be yielded out, and the final
// turn object from the recursive call will be returned.
return yield* this.sendMessageStream(
const nextRequest = [{ text: 'System: Please continue.' }];
// Recursive call - update turn with result
turn = yield* this.sendMessageStream(
nextRequest,
signal,
prompt_id,
boundedTurns - 1,
// isInvalidStreamRetry is false here, as this is a next speaker check
true,
);
return turn;
}
}
// Fire AfterAgent hook through MessageBus (only if hooks are enabled)
if (hooksEnabled && messageBus) {
const responseText = turn.getResponseText() || '[no response text]';
const hookOutput = await fireAfterAgentHook(
messageBus,
request,
responseText,
);
// For AfterAgent hooks, blocking/stop execution should force continuation
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
if (
hookOutput?.isBlockingDecision() ||
hookOutput?.shouldStopExecution()
!this.config.getQuotaErrorOccurred() &&
!this.config.getSkipNextSpeakerCheck()
) {
const continueReason = hookOutput.getEffectiveReason();
const continueRequest = [{ text: continueReason }];
yield* this.sendMessageStream(
continueRequest,
const nextSpeakerCheck = await checkNextSpeaker(
this.getChat(),
this.config.getBaseLlmClient(),
signal,
prompt_id,
boundedTurns - 1,
);
logNextSpeakerCheck(
this.config,
new NextSpeakerCheckEvent(
prompt_id,
turn.finishReason?.toString() || '',
nextSpeakerCheck?.next_speaker || '',
),
);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
turn = yield* this.sendMessageStream(
nextRequest,
signal,
prompt_id,
boundedTurns - 1,
// isInvalidStreamRetry is false
);
return turn;
}
}
}
return turn;
}
async *sendMessageStream(
request: PartListUnion,
signal: AbortSignal,
prompt_id: string,
turns: number = MAX_TURNS,
isInvalidStreamRetry: boolean = false,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!isInvalidStreamRetry) {
this.config.resetTurn();
}
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
if (this.lastPromptId !== prompt_id) {
this.loopDetector.reset(prompt_id);
this.hookStateMap.delete(this.lastPromptId);
this.lastPromptId = prompt_id;
this.currentSequenceModel = null;
}
if (hooksEnabled && messageBus) {
const hookResult = await this.fireBeforeAgentHookSafe(
messageBus,
request,
prompt_id,
);
if (hookResult) {
if ('type' in hookResult && hookResult.type === GeminiEventType.Error) {
yield hookResult;
return new Turn(this.getChat(), prompt_id);
} else if ('additionalContext' in hookResult) {
const additionalContext = hookResult.additionalContext;
if (additionalContext) {
const requestArray = Array.isArray(request) ? request : [request];
request = [...requestArray, { text: additionalContext }];
}
}
}
}
const boundedTurns = Math.min(turns, MAX_TURNS);
let turn = new Turn(this.getChat(), prompt_id);
try {
turn = yield* this.processTurn(
request,
signal,
prompt_id,
boundedTurns,
isInvalidStreamRetry,
);
// Fire AfterAgent hook if we have a turn and no pending tools
if (hooksEnabled && messageBus) {
const hookOutput = await this.fireAfterAgentHookSafe(
messageBus,
request,
prompt_id,
turn,
);
if (
hookOutput?.isBlockingDecision() ||
hookOutput?.shouldStopExecution()
) {
const continueReason = hookOutput.getEffectiveReason();
const continueRequest = [{ text: continueReason }];
yield* this.sendMessageStream(
continueRequest,
signal,
prompt_id,
boundedTurns - 1,
);
}
}
} finally {
const hookState = this.hookStateMap.get(prompt_id);
if (hookState) {
hookState.activeCalls--;
const isPendingTools =
turn?.pendingToolCalls && turn.pendingToolCalls.length > 0;
const isAborted = signal?.aborted;
if (hookState.activeCalls <= 0) {
if (!isPendingTools || isAborted) {
this.hookStateMap.delete(prompt_id);
}
}
}
}