Checkpoint in support /tool commands.

This commit is contained in:
jacob314
2026-05-14 09:35:37 -07:00
parent 84423e6ea1
commit b2946b1052
5 changed files with 201 additions and 72 deletions

View File

@@ -1180,37 +1180,40 @@ Logging in with Google... Restarting Gemini CLI to continue.
[config, getPreferredEditor],
);
const activeStream = streamAgent
? // eslint-disable-next-line react-hooks/rules-of-hooks
useAgentStream({
agent: streamAgent,
addItem: historyManager.addItem,
onCancelSubmit,
isShellFocused: embeddedShellFocused,
logger,
})
: // eslint-disable-next-line react-hooks/rules-of-hooks
useGeminiStream(
config.getGeminiClient(),
historyManager.history,
historyManager.addItem,
config,
settings,
setDebugMessage,
handleSlashCommand,
shellModeActive,
getPreferredEditor,
onAuthError,
performMemoryRefresh,
modelSwitchedFromQuotaError,
setModelSwitchedFromQuotaError,
onCancelSubmit,
setEmbeddedShellFocused,
terminalWidth,
terminalHeight,
embeddedShellFocused,
consumePendingHints,
);
const agentStream = useAgentStream({
agent: streamAgent,
addItem: historyManager.addItem,
handleSlashCommand,
onCancelSubmit,
isShellFocused: embeddedShellFocused,
logger,
isActive: !!streamAgent,
});
const geminiStream = useGeminiStream(
config.getGeminiClient(),
historyManager.history,
historyManager.addItem,
config,
settings,
setDebugMessage,
handleSlashCommand,
shellModeActive,
getPreferredEditor,
onAuthError,
performMemoryRefresh,
modelSwitchedFromQuotaError,
setModelSwitchedFromQuotaError,
onCancelSubmit,
setEmbeddedShellFocused,
terminalWidth,
terminalHeight,
embeddedShellFocused,
consumePendingHints,
!streamAgent,
);
const activeStream = streamAgent ? agentStream : geminiStream;
const {
streamingState,

View File

@@ -6,12 +6,12 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { act } from 'react';
import type { LegacyAgentProtocol } from '@google/gemini-cli-core';
import type { AgentProtocol } from '@google/gemini-cli-core';
import { renderHookWithProviders } from '../../test-utils/render.js';
// --- MOCKS ---
const mockLegacyAgentProtocol = vi.hoisted(() => ({
const mockAgentProtocol = vi.hoisted(() => ({
send: vi.fn().mockResolvedValue({ streamId: 'test-stream-id' }),
subscribe: vi.fn().mockReturnValue(() => {}),
abort: vi.fn().mockResolvedValue(undefined),
@@ -43,21 +43,24 @@ describe('useAgentStream', () => {
it('should initialize on mount', async () => {
await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
agent: mockAgentProtocol as unknown as AgentProtocol,
addItem: mockAddItem,
handleSlashCommand: vi.fn().mockResolvedValue(false),
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
expect(mockLegacyAgentProtocol.subscribe).toHaveBeenCalled();
expect(mockAgentProtocol.subscribe).toHaveBeenCalled();
});
it('should call agent.send when submitQuery is called', async () => {
const mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
agent: mockAgentProtocol as unknown as AgentProtocol,
addItem: mockAddItem,
handleSlashCommand: mockHandleSlashCommand,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
@@ -67,7 +70,7 @@ describe('useAgentStream', () => {
await result.current.submitQuery('hello');
});
expect(mockLegacyAgentProtocol.send).toHaveBeenCalledWith({
expect(mockAgentProtocol.send).toHaveBeenCalledWith({
message: { content: [{ type: 'text', text: 'hello' }] },
});
expect(mockAddItem).toHaveBeenCalledWith(
@@ -76,17 +79,73 @@ describe('useAgentStream', () => {
);
});
it('should update streamingState based on agent_start and agent_end events', async () => {
it('should intercept slash commands and not call agent.send if handled', async () => {
const mockHandleSlashCommand = vi
.fn()
.mockResolvedValue({ type: 'handled' });
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
agent: mockAgentProtocol as unknown as AgentProtocol,
addItem: mockAddItem,
handleSlashCommand: mockHandleSlashCommand,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
const eventHandler = vi.mocked(mockLegacyAgentProtocol.subscribe).mock
await act(async () => {
await result.current.submitQuery('/about');
});
expect(mockHandleSlashCommand).toHaveBeenCalledWith('/about');
expect(mockAgentProtocol.send).not.toHaveBeenCalled();
});
it('should intercept slash commands and call agent.send with new content if submit_prompt', async () => {
const mockHandleSlashCommand = vi.fn().mockResolvedValue({
type: 'submit_prompt',
content: 'modified prompt',
});
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockAgentProtocol as unknown as AgentProtocol,
addItem: mockAddItem,
handleSlashCommand: mockHandleSlashCommand,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
await act(async () => {
await result.current.submitQuery('/mcp-prompt');
});
expect(mockHandleSlashCommand).toHaveBeenCalledWith('/mcp-prompt');
expect(mockAgentProtocol.send).toHaveBeenCalledWith({
message: { content: [{ type: 'text', text: 'modified prompt' }] },
});
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageType.USER,
text: 'modified prompt',
}),
expect.any(Number),
);
});
it('should update streamingState based on agent_start and agent_end events', async () => {
const mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockAgentProtocol as unknown as AgentProtocol,
addItem: mockAddItem,
handleSlashCommand: mockHandleSlashCommand,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
const eventHandler = vi.mocked(mockAgentProtocol.subscribe).mock
.calls[0][0];
expect(result.current.streamingState).toBe(StreamingState.Idle);
@@ -116,14 +175,15 @@ describe('useAgentStream', () => {
it('should accumulate text content and update pendingHistoryItems', async () => {
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
agent: mockAgentProtocol as unknown as AgentProtocol,
addItem: mockAddItem,
handleSlashCommand: vi.fn().mockResolvedValue(false),
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
const eventHandler = vi.mocked(mockLegacyAgentProtocol.subscribe).mock
const eventHandler = vi.mocked(mockAgentProtocol.subscribe).mock
.calls[0][0];
act(() => {
@@ -160,14 +220,15 @@ describe('useAgentStream', () => {
it('should process thought events and update thought state', async () => {
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
agent: mockAgentProtocol as unknown as AgentProtocol,
addItem: mockAddItem,
handleSlashCommand: vi.fn().mockResolvedValue(false),
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
const eventHandler = vi.mocked(mockLegacyAgentProtocol.subscribe).mock
const eventHandler = vi.mocked(mockAgentProtocol.subscribe).mock
.calls[0][0];
act(() => {
@@ -190,8 +251,9 @@ describe('useAgentStream', () => {
it('should call agent.abort when cancelOngoingRequest is called', async () => {
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
agent: mockAgentProtocol as unknown as AgentProtocol,
addItem: mockAddItem,
handleSlashCommand: vi.fn().mockResolvedValue(false),
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
@@ -201,7 +263,7 @@ describe('useAgentStream', () => {
await result.current.cancelOngoingRequest();
});
expect(mockLegacyAgentProtocol.abort).toHaveBeenCalled();
expect(mockAgentProtocol.abort).toHaveBeenCalled();
expect(mockOnCancelSubmit).toHaveBeenCalledWith(false, true);
});
});

View File

@@ -5,12 +5,14 @@
*/
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import { type PartListUnion } from '@google/genai';
import {
getErrorMessage,
MessageSenderType,
debugLogger,
geminiPartsToContentParts,
displayContentToString,
partToString,
parseThought,
CoreToolCallStatus,
type ApprovalMode,
@@ -20,15 +22,16 @@ import {
type AgentEvent,
type AgentProtocol,
type Logger,
type Part,
} from '@google/gemini-cli-core';
import type {
HistoryItemWithoutId,
LoopDetectionConfirmationRequest,
IndividualToolCallDisplay,
HistoryItemToolDisplayGroup,
SlashCommandProcessorResult,
} from '../types.js';
import { StreamingState, MessageType } from '../types.js';
import { isSlashCommand } from '../utils/commandUtils.js';
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { getToolGroupBorderAppearance } from '../utils/borderStyles.js';
import { type BackgroundTask } from './useExecutionLifecycle.js';
@@ -41,12 +44,16 @@ import { useKeypress } from './useKeypress.js';
export interface UseAgentStreamOptions {
agent?: AgentProtocol;
addItem: UseHistoryManagerReturn['addItem'];
handleSlashCommand: (
cmd: string,
) => Promise<SlashCommandProcessorResult | false>;
onCancelSubmit: (
shouldRestorePrompt?: boolean,
clearBuffer?: boolean,
) => void;
isShellFocused?: boolean;
logger?: Logger | null;
isActive?: boolean;
}
/**
@@ -56,9 +63,11 @@ export interface UseAgentStreamOptions {
export const useAgentStream = ({
agent,
addItem,
handleSlashCommand,
onCancelSubmit,
isShellFocused,
logger,
isActive = true,
}: UseAgentStreamOptions) => {
const [initError] = useState<string | null>(null);
const [retryStatus] = useState<RetryAttemptPayload | null>(null);
@@ -327,9 +336,10 @@ export const useAgentStream = ({
);
useEffect(() => {
if (!isActive) return;
const unsubscribe = agent?.subscribe(handleEvent);
return () => unsubscribe?.();
}, [agent, handleEvent]);
}, [agent, handleEvent, isActive]);
useKeypress(
(key) => {
@@ -341,14 +351,15 @@ export const useAgentStream = ({
},
{
isActive:
streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation,
isActive &&
(streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation),
},
);
const submitQuery = useCallback(
async (
query: Part[] | string,
query: PartListUnion,
options?: { isContinuation: boolean },
_prompt_id?: string,
) => {
@@ -360,16 +371,40 @@ export const useAgentStream = ({
geminiMessageBufferRef.current = '';
let localQuery: PartListUnion = query;
if (!options?.isContinuation) {
if (typeof query === 'string') {
addItem({ type: MessageType.USER, text: query }, timestamp);
void logger?.logMessage(MessageSenderType.USER, query);
if (typeof localQuery === 'string') {
const trimmedQuery = localQuery.trim();
if (isSlashCommand(trimmedQuery)) {
const slashResult = await handleSlashCommand(trimmedQuery);
if (slashResult) {
if (slashResult.type === 'handled') {
return;
}
if (slashResult.type === 'submit_prompt') {
localQuery = slashResult.content;
}
// schedule_tool is not yet supported in useAgentStream (mirrors handleAtCommand lack of support here)
}
}
}
const queryText =
typeof localQuery === 'string'
? localQuery
: partToString(localQuery);
addItem({ type: MessageType.USER, text: queryText }, timestamp);
void logger?.logMessage(MessageSenderType.USER, queryText);
startNewPrompt();
}
const parts = geminiPartsToContentParts(
typeof query === 'string' ? [{ text: query }] : query,
(Array.isArray(localQuery) ? localQuery : [localQuery]).map((p) =>
typeof p === 'string' ? { text: p } : p,
),
);
try {
@@ -384,10 +419,11 @@ export const useAgentStream = ({
);
}
},
[agent, addItem, logger, startNewPrompt],
[agent, addItem, logger, startNewPrompt, handleSlashCommand],
);
useEffect(() => {
if (!isActive) return;
if (trackedTools.length > 0) {
const isNewBatch = !trackedTools.some((tc) =>
pushedToolCallIdsRef.current.has(tc.callId),
@@ -406,11 +442,12 @@ export const useAgentStream = ({
setPushedToolCallIds,
setIsFirstToolInGroup,
streamingState,
isActive,
]);
// Push completed tools to history
useEffect(() => {
if (trackedTools.length === 0) return;
if (!isActive || trackedTools.length === 0) return;
// We only push to history once all currently known tools in the turn are terminal.
// This allows ToolGroupDisplay to correctly hoist ALL notices (topics) for the turn.
@@ -480,6 +517,7 @@ export const useAgentStream = ({
activePtyId,
isShellFocused,
backgroundTasks,
isActive,
]);
const pendingToolGroupItems = useMemo((): HistoryItemWithoutId[] => {

View File

@@ -86,6 +86,7 @@ export const useExecutionLifecycle = (
terminalHeight?: number,
activeBackgroundExecutionId?: number,
isWaitingForConfirmation?: boolean,
isActive: boolean = true,
) => {
const [state, dispatch] = useReducer(shellReducer, initialState);
@@ -111,6 +112,7 @@ export const useExecutionLifecycle = (
state.activeShellPtyId ?? activeBackgroundExecutionId ?? undefined;
useEffect(() => {
if (!isActive) return;
const isForegroundActive = !!activePtyId || !!isWaitingForConfirmation;
if (isForegroundActive) {
@@ -144,20 +146,23 @@ export const useExecutionLifecycle = (
state.isBackgroundTaskVisible,
m,
dispatch,
isActive,
]);
useEffect(
() => () => {
if (!isActive) return;
// Unsubscribe from all background task events on unmount
for (const unsubscribe of m.subscriptions.values()) {
unsubscribe();
}
m.subscriptions.clear();
},
[m],
[m, isActive],
);
const toggleBackgroundTasks = useCallback(() => {
if (!isActive) return;
if (state.backgroundTasks.size > 0) {
const willBeVisible = !state.isBackgroundTaskVisible;
dispatch({ type: 'TOGGLE_VISIBILITY' });
@@ -193,9 +198,11 @@ export const useExecutionLifecycle = (
isWaitingForConfirmation,
m,
dispatch,
isActive,
]);
const backgroundCurrentExecution = useCallback(() => {
if (!isActive) return;
const pidToBackground =
state.activeShellPtyId ?? activeBackgroundExecutionId;
if (pidToBackground) {
@@ -218,10 +225,11 @@ export const useExecutionLifecycle = (
m.restoreTimeout = null;
}
}
}, [state.activeShellPtyId, activeBackgroundExecutionId, m]);
}, [state.activeShellPtyId, activeBackgroundExecutionId, m, isActive]);
const dismissBackgroundTask = useCallback(
async (pid: number) => {
if (!isActive) return;
const shell = state.backgroundTasks.get(pid);
if (shell) {
if (shell.status === 'running') {
@@ -240,7 +248,7 @@ export const useExecutionLifecycle = (
}
}
},
[state.backgroundTasks, dispatch, m],
[state.backgroundTasks, dispatch, m, isActive],
);
const registerBackgroundTask = useCallback(
@@ -250,6 +258,7 @@ export const useExecutionLifecycle = (
initialOutput: string | AnsiOutput,
completionBehavior?: CompletionBehavior,
) => {
if (!isActive) return;
m.backgroundedPids.add(pid);
dispatch({
type: 'REGISTER_TASK',
@@ -313,7 +322,7 @@ export const useExecutionLifecycle = (
dataUnsubscribe();
});
},
[dispatch, m],
[dispatch, m, isActive],
);
// Auto-register any execution that gets backgrounded, regardless of type.
@@ -321,6 +330,7 @@ export const useExecutionLifecycle = (
// ExecutionLifecycleService.createExecution() or attachExecution()
// automatically gets Ctrl+B support — no UI changes needed per tool.
useEffect(() => {
if (!isActive) return;
const listener = (info: {
executionId: number;
label: string;
@@ -342,7 +352,7 @@ export const useExecutionLifecycle = (
return () => {
ExecutionLifecycleService.offBackground(listener);
};
}, [registerBackgroundTask, m]);
}, [registerBackgroundTask, m, isActive]);
const handleShellCommand = useCallback(
(rawQuery: PartListUnion, abortSignal: AbortSignal): boolean => {

View File

@@ -238,6 +238,7 @@ export const useGeminiStream = (
terminalHeight: number,
isShellFocused?: boolean,
consumeUserHint?: () => string | null,
isActive: boolean = true,
) => {
const [initError, setInitError] = useState<string | null>(null);
const [retryStatus, setRetryStatus] = useState<RetryAttemptPayload | null>(
@@ -276,6 +277,7 @@ export const useGeminiStream = (
}, [config]);
useEffect(() => {
if (!isActive) return;
const handleRetryAttempt = (payload: RetryAttemptPayload) => {
if (turnCancelledRef.current || !isRespondingRef.current) {
return;
@@ -286,7 +288,7 @@ export const useGeminiStream = (
return () => {
coreEvents.off(CoreEvent.RetryAttempt, handleRetryAttempt);
};
}, [isRespondingRef]);
}, [isRespondingRef, isActive]);
const [
toolCalls,
@@ -387,6 +389,11 @@ export const useGeminiStream = (
[setIsResponding],
);
const streamingState = useMemo(
() => calculateStreamingState(isResponding, toolCalls),
[isResponding, toolCalls],
);
const {
handleShellCommand,
activeShellPtyId,
@@ -409,15 +416,13 @@ export const useGeminiStream = (
terminalWidth,
terminalHeight,
activeBackgroundExecutionId,
);
const streamingState = useMemo(
() => calculateStreamingState(isResponding, toolCalls),
[isResponding, toolCalls],
streamingState === StreamingState.WaitingForConfirmation,
isActive,
);
// Reset tracking when a new batch of tools starts
useEffect(() => {
if (!isActive) return;
if (toolCalls.length > 0) {
const isNewBatch = !toolCalls.some((tc) =>
pushedToolCallIdsRef.current.has(tc.request.callId),
@@ -437,10 +442,12 @@ export const useGeminiStream = (
setPushedToolCallIds,
setIsFirstToolInGroup,
streamingState,
isActive,
]);
// Push completed tools to history as they finish
useEffect(() => {
if (!isActive) return;
const toolsToPush: TrackedToolCall[] = [];
for (let i = 0; i < toolCalls.length; i++) {
const tc = toolCalls[i];
@@ -598,8 +605,11 @@ export const useGeminiStream = (
isShellFocused,
backgroundTasks,
settings.merged.ui?.compactToolOutput,
isActive,
]);
const pendingToolGroupItems = useMemo((): HistoryItemWithoutId[] => {
if (!isActive) return [];
const remainingTools = toolCalls.filter(
(tc) => !pushedToolCallIds.has(tc.request.callId),
);
@@ -704,6 +714,7 @@ export const useGeminiStream = (
isShellFocused,
backgroundTasks,
settings.merged.ui?.compactToolOutput,
isActive,
]);
const lastQueryRef = useRef<PartListUnion | null>(null);
@@ -721,6 +732,7 @@ export const useGeminiStream = (
const prevActiveShellPtyIdRef = useRef<number | null>(null);
useEffect(() => {
if (!isActive) return;
if (
turnCancelledRef.current &&
prevActiveShellPtyIdRef.current !== null &&
@@ -730,9 +742,10 @@ export const useGeminiStream = (
setIsResponding(false);
}
prevActiveShellPtyIdRef.current = activeShellPtyId;
}, [activeShellPtyId, addItem, setIsResponding]);
}, [activeShellPtyId, addItem, setIsResponding, isActive]);
useEffect(() => {
if (!isActive) return;
if (
config.getApprovalMode() === ApprovalMode.YOLO &&
streamingState === StreamingState.Idle
@@ -751,7 +764,7 @@ export const useGeminiStream = (
);
}
}
}, [streamingState, config, history]);
}, [streamingState, config, history, isActive]);
useEffect(() => {
if (!isResponding) {
@@ -809,6 +822,7 @@ export const useGeminiStream = (
const cancelOngoingRequest = useCallback(
(clearBuffer: boolean = true) => {
if (!isActive) return;
// If we are already cancelled, do nothing
if (turnCancelledRef.current) {
if (clearBuffer) {
@@ -920,6 +934,7 @@ export const useGeminiStream = (
toolCalls,
activeShellPtyId,
setIsResponding,
isActive,
],
);
@@ -933,8 +948,9 @@ export const useGeminiStream = (
},
{
isActive:
streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation,
isActive &&
(streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation),
},
);