diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index 313f377b02..f1f802df2b 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -1180,37 +1180,40 @@ Logging in with Google... Restarting Gemini CLI to continue. [config, getPreferredEditor], ); - const activeStream = streamAgent - ? // eslint-disable-next-line react-hooks/rules-of-hooks - useAgentStream({ - agent: streamAgent, - addItem: historyManager.addItem, - onCancelSubmit, - isShellFocused: embeddedShellFocused, - logger, - }) - : // eslint-disable-next-line react-hooks/rules-of-hooks - useGeminiStream( - config.getGeminiClient(), - historyManager.history, - historyManager.addItem, - config, - settings, - setDebugMessage, - handleSlashCommand, - shellModeActive, - getPreferredEditor, - onAuthError, - performMemoryRefresh, - modelSwitchedFromQuotaError, - setModelSwitchedFromQuotaError, - onCancelSubmit, - setEmbeddedShellFocused, - terminalWidth, - terminalHeight, - embeddedShellFocused, - consumePendingHints, - ); + const agentStream = useAgentStream({ + agent: streamAgent, + addItem: historyManager.addItem, + handleSlashCommand, + onCancelSubmit, + isShellFocused: embeddedShellFocused, + logger, + isActive: !!streamAgent, + }); + + const geminiStream = useGeminiStream( + config.getGeminiClient(), + historyManager.history, + historyManager.addItem, + config, + settings, + setDebugMessage, + handleSlashCommand, + shellModeActive, + getPreferredEditor, + onAuthError, + performMemoryRefresh, + modelSwitchedFromQuotaError, + setModelSwitchedFromQuotaError, + onCancelSubmit, + setEmbeddedShellFocused, + terminalWidth, + terminalHeight, + embeddedShellFocused, + consumePendingHints, + !streamAgent, + ); + + const activeStream = streamAgent ? agentStream : geminiStream; const { streamingState, diff --git a/packages/cli/src/ui/hooks/useAgentStream.test.tsx b/packages/cli/src/ui/hooks/useAgentStream.test.tsx index 1136a3592e..6c0be9c3d2 100644 --- a/packages/cli/src/ui/hooks/useAgentStream.test.tsx +++ b/packages/cli/src/ui/hooks/useAgentStream.test.tsx @@ -6,12 +6,12 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { act } from 'react'; -import type { LegacyAgentProtocol } from '@google/gemini-cli-core'; +import type { AgentProtocol } from '@google/gemini-cli-core'; import { renderHookWithProviders } from '../../test-utils/render.js'; // --- MOCKS --- -const mockLegacyAgentProtocol = vi.hoisted(() => ({ +const mockAgentProtocol = vi.hoisted(() => ({ send: vi.fn().mockResolvedValue({ streamId: 'test-stream-id' }), subscribe: vi.fn().mockReturnValue(() => {}), abort: vi.fn().mockResolvedValue(undefined), @@ -43,21 +43,24 @@ describe('useAgentStream', () => { it('should initialize on mount', async () => { await renderHookWithProviders(() => useAgentStream({ - agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol, + agent: mockAgentProtocol as unknown as AgentProtocol, addItem: mockAddItem, + handleSlashCommand: vi.fn().mockResolvedValue(false), onCancelSubmit: mockOnCancelSubmit, isShellFocused: false, }), ); - expect(mockLegacyAgentProtocol.subscribe).toHaveBeenCalled(); + expect(mockAgentProtocol.subscribe).toHaveBeenCalled(); }); it('should call agent.send when submitQuery is called', async () => { + const mockHandleSlashCommand = vi.fn().mockResolvedValue(false); const { result } = await renderHookWithProviders(() => useAgentStream({ - agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol, + agent: mockAgentProtocol as unknown as AgentProtocol, addItem: mockAddItem, + handleSlashCommand: mockHandleSlashCommand, onCancelSubmit: mockOnCancelSubmit, isShellFocused: false, }), @@ -67,7 +70,7 @@ describe('useAgentStream', () => { await result.current.submitQuery('hello'); }); - expect(mockLegacyAgentProtocol.send).toHaveBeenCalledWith({ + expect(mockAgentProtocol.send).toHaveBeenCalledWith({ message: { content: [{ type: 'text', text: 'hello' }] }, }); expect(mockAddItem).toHaveBeenCalledWith( @@ -76,17 +79,73 @@ describe('useAgentStream', () => { ); }); - it('should update streamingState based on agent_start and agent_end events', async () => { + it('should intercept slash commands and not call agent.send if handled', async () => { + const mockHandleSlashCommand = vi + .fn() + .mockResolvedValue({ type: 'handled' }); const { result } = await renderHookWithProviders(() => useAgentStream({ - agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol, + agent: mockAgentProtocol as unknown as AgentProtocol, addItem: mockAddItem, + handleSlashCommand: mockHandleSlashCommand, onCancelSubmit: mockOnCancelSubmit, isShellFocused: false, }), ); - const eventHandler = vi.mocked(mockLegacyAgentProtocol.subscribe).mock + await act(async () => { + await result.current.submitQuery('/about'); + }); + + expect(mockHandleSlashCommand).toHaveBeenCalledWith('/about'); + expect(mockAgentProtocol.send).not.toHaveBeenCalled(); + }); + + it('should intercept slash commands and call agent.send with new content if submit_prompt', async () => { + const mockHandleSlashCommand = vi.fn().mockResolvedValue({ + type: 'submit_prompt', + content: 'modified prompt', + }); + const { result } = await renderHookWithProviders(() => + useAgentStream({ + agent: mockAgentProtocol as unknown as AgentProtocol, + addItem: mockAddItem, + handleSlashCommand: mockHandleSlashCommand, + onCancelSubmit: mockOnCancelSubmit, + isShellFocused: false, + }), + ); + + await act(async () => { + await result.current.submitQuery('/mcp-prompt'); + }); + + expect(mockHandleSlashCommand).toHaveBeenCalledWith('/mcp-prompt'); + expect(mockAgentProtocol.send).toHaveBeenCalledWith({ + message: { content: [{ type: 'text', text: 'modified prompt' }] }, + }); + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageType.USER, + text: 'modified prompt', + }), + expect.any(Number), + ); + }); + + it('should update streamingState based on agent_start and agent_end events', async () => { + const mockHandleSlashCommand = vi.fn().mockResolvedValue(false); + const { result } = await renderHookWithProviders(() => + useAgentStream({ + agent: mockAgentProtocol as unknown as AgentProtocol, + addItem: mockAddItem, + handleSlashCommand: mockHandleSlashCommand, + onCancelSubmit: mockOnCancelSubmit, + isShellFocused: false, + }), + ); + + const eventHandler = vi.mocked(mockAgentProtocol.subscribe).mock .calls[0][0]; expect(result.current.streamingState).toBe(StreamingState.Idle); @@ -116,14 +175,15 @@ describe('useAgentStream', () => { it('should accumulate text content and update pendingHistoryItems', async () => { const { result } = await renderHookWithProviders(() => useAgentStream({ - agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol, + agent: mockAgentProtocol as unknown as AgentProtocol, addItem: mockAddItem, + handleSlashCommand: vi.fn().mockResolvedValue(false), onCancelSubmit: mockOnCancelSubmit, isShellFocused: false, }), ); - const eventHandler = vi.mocked(mockLegacyAgentProtocol.subscribe).mock + const eventHandler = vi.mocked(mockAgentProtocol.subscribe).mock .calls[0][0]; act(() => { @@ -160,14 +220,15 @@ describe('useAgentStream', () => { it('should process thought events and update thought state', async () => { const { result } = await renderHookWithProviders(() => useAgentStream({ - agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol, + agent: mockAgentProtocol as unknown as AgentProtocol, addItem: mockAddItem, + handleSlashCommand: vi.fn().mockResolvedValue(false), onCancelSubmit: mockOnCancelSubmit, isShellFocused: false, }), ); - const eventHandler = vi.mocked(mockLegacyAgentProtocol.subscribe).mock + const eventHandler = vi.mocked(mockAgentProtocol.subscribe).mock .calls[0][0]; act(() => { @@ -190,8 +251,9 @@ describe('useAgentStream', () => { it('should call agent.abort when cancelOngoingRequest is called', async () => { const { result } = await renderHookWithProviders(() => useAgentStream({ - agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol, + agent: mockAgentProtocol as unknown as AgentProtocol, addItem: mockAddItem, + handleSlashCommand: vi.fn().mockResolvedValue(false), onCancelSubmit: mockOnCancelSubmit, isShellFocused: false, }), @@ -201,7 +263,7 @@ describe('useAgentStream', () => { await result.current.cancelOngoingRequest(); }); - expect(mockLegacyAgentProtocol.abort).toHaveBeenCalled(); + expect(mockAgentProtocol.abort).toHaveBeenCalled(); expect(mockOnCancelSubmit).toHaveBeenCalledWith(false, true); }); }); diff --git a/packages/cli/src/ui/hooks/useAgentStream.ts b/packages/cli/src/ui/hooks/useAgentStream.ts index 982391a437..c8aaaf2bcf 100644 --- a/packages/cli/src/ui/hooks/useAgentStream.ts +++ b/packages/cli/src/ui/hooks/useAgentStream.ts @@ -5,12 +5,14 @@ */ import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; +import { type PartListUnion } from '@google/genai'; import { getErrorMessage, MessageSenderType, debugLogger, geminiPartsToContentParts, displayContentToString, + partToString, parseThought, CoreToolCallStatus, type ApprovalMode, @@ -20,15 +22,16 @@ import { type AgentEvent, type AgentProtocol, type Logger, - type Part, } from '@google/gemini-cli-core'; import type { HistoryItemWithoutId, LoopDetectionConfirmationRequest, IndividualToolCallDisplay, HistoryItemToolDisplayGroup, + SlashCommandProcessorResult, } from '../types.js'; import { StreamingState, MessageType } from '../types.js'; +import { isSlashCommand } from '../utils/commandUtils.js'; import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { getToolGroupBorderAppearance } from '../utils/borderStyles.js'; import { type BackgroundTask } from './useExecutionLifecycle.js'; @@ -41,12 +44,16 @@ import { useKeypress } from './useKeypress.js'; export interface UseAgentStreamOptions { agent?: AgentProtocol; addItem: UseHistoryManagerReturn['addItem']; + handleSlashCommand: ( + cmd: string, + ) => Promise; onCancelSubmit: ( shouldRestorePrompt?: boolean, clearBuffer?: boolean, ) => void; isShellFocused?: boolean; logger?: Logger | null; + isActive?: boolean; } /** @@ -56,9 +63,11 @@ export interface UseAgentStreamOptions { export const useAgentStream = ({ agent, addItem, + handleSlashCommand, onCancelSubmit, isShellFocused, logger, + isActive = true, }: UseAgentStreamOptions) => { const [initError] = useState(null); const [retryStatus] = useState(null); @@ -327,9 +336,10 @@ export const useAgentStream = ({ ); useEffect(() => { + if (!isActive) return; const unsubscribe = agent?.subscribe(handleEvent); return () => unsubscribe?.(); - }, [agent, handleEvent]); + }, [agent, handleEvent, isActive]); useKeypress( (key) => { @@ -341,14 +351,15 @@ export const useAgentStream = ({ }, { isActive: - streamingState === StreamingState.Responding || - streamingState === StreamingState.WaitingForConfirmation, + isActive && + (streamingState === StreamingState.Responding || + streamingState === StreamingState.WaitingForConfirmation), }, ); const submitQuery = useCallback( async ( - query: Part[] | string, + query: PartListUnion, options?: { isContinuation: boolean }, _prompt_id?: string, ) => { @@ -360,16 +371,40 @@ export const useAgentStream = ({ geminiMessageBufferRef.current = ''; + let localQuery: PartListUnion = query; + if (!options?.isContinuation) { - if (typeof query === 'string') { - addItem({ type: MessageType.USER, text: query }, timestamp); - void logger?.logMessage(MessageSenderType.USER, query); + if (typeof localQuery === 'string') { + const trimmedQuery = localQuery.trim(); + + if (isSlashCommand(trimmedQuery)) { + const slashResult = await handleSlashCommand(trimmedQuery); + if (slashResult) { + if (slashResult.type === 'handled') { + return; + } + if (slashResult.type === 'submit_prompt') { + localQuery = slashResult.content; + } + // schedule_tool is not yet supported in useAgentStream (mirrors handleAtCommand lack of support here) + } + } } + + const queryText = + typeof localQuery === 'string' + ? localQuery + : partToString(localQuery); + + addItem({ type: MessageType.USER, text: queryText }, timestamp); + void logger?.logMessage(MessageSenderType.USER, queryText); startNewPrompt(); } const parts = geminiPartsToContentParts( - typeof query === 'string' ? [{ text: query }] : query, + (Array.isArray(localQuery) ? localQuery : [localQuery]).map((p) => + typeof p === 'string' ? { text: p } : p, + ), ); try { @@ -384,10 +419,11 @@ export const useAgentStream = ({ ); } }, - [agent, addItem, logger, startNewPrompt], + [agent, addItem, logger, startNewPrompt, handleSlashCommand], ); useEffect(() => { + if (!isActive) return; if (trackedTools.length > 0) { const isNewBatch = !trackedTools.some((tc) => pushedToolCallIdsRef.current.has(tc.callId), @@ -406,11 +442,12 @@ export const useAgentStream = ({ setPushedToolCallIds, setIsFirstToolInGroup, streamingState, + isActive, ]); // Push completed tools to history useEffect(() => { - if (trackedTools.length === 0) return; + if (!isActive || trackedTools.length === 0) return; // We only push to history once all currently known tools in the turn are terminal. // This allows ToolGroupDisplay to correctly hoist ALL notices (topics) for the turn. @@ -480,6 +517,7 @@ export const useAgentStream = ({ activePtyId, isShellFocused, backgroundTasks, + isActive, ]); const pendingToolGroupItems = useMemo((): HistoryItemWithoutId[] => { diff --git a/packages/cli/src/ui/hooks/useExecutionLifecycle.ts b/packages/cli/src/ui/hooks/useExecutionLifecycle.ts index 884ab544de..d635ea98b5 100644 --- a/packages/cli/src/ui/hooks/useExecutionLifecycle.ts +++ b/packages/cli/src/ui/hooks/useExecutionLifecycle.ts @@ -86,6 +86,7 @@ export const useExecutionLifecycle = ( terminalHeight?: number, activeBackgroundExecutionId?: number, isWaitingForConfirmation?: boolean, + isActive: boolean = true, ) => { const [state, dispatch] = useReducer(shellReducer, initialState); @@ -111,6 +112,7 @@ export const useExecutionLifecycle = ( state.activeShellPtyId ?? activeBackgroundExecutionId ?? undefined; useEffect(() => { + if (!isActive) return; const isForegroundActive = !!activePtyId || !!isWaitingForConfirmation; if (isForegroundActive) { @@ -144,20 +146,23 @@ export const useExecutionLifecycle = ( state.isBackgroundTaskVisible, m, dispatch, + isActive, ]); useEffect( () => () => { + if (!isActive) return; // Unsubscribe from all background task events on unmount for (const unsubscribe of m.subscriptions.values()) { unsubscribe(); } m.subscriptions.clear(); }, - [m], + [m, isActive], ); const toggleBackgroundTasks = useCallback(() => { + if (!isActive) return; if (state.backgroundTasks.size > 0) { const willBeVisible = !state.isBackgroundTaskVisible; dispatch({ type: 'TOGGLE_VISIBILITY' }); @@ -193,9 +198,11 @@ export const useExecutionLifecycle = ( isWaitingForConfirmation, m, dispatch, + isActive, ]); const backgroundCurrentExecution = useCallback(() => { + if (!isActive) return; const pidToBackground = state.activeShellPtyId ?? activeBackgroundExecutionId; if (pidToBackground) { @@ -218,10 +225,11 @@ export const useExecutionLifecycle = ( m.restoreTimeout = null; } } - }, [state.activeShellPtyId, activeBackgroundExecutionId, m]); + }, [state.activeShellPtyId, activeBackgroundExecutionId, m, isActive]); const dismissBackgroundTask = useCallback( async (pid: number) => { + if (!isActive) return; const shell = state.backgroundTasks.get(pid); if (shell) { if (shell.status === 'running') { @@ -240,7 +248,7 @@ export const useExecutionLifecycle = ( } } }, - [state.backgroundTasks, dispatch, m], + [state.backgroundTasks, dispatch, m, isActive], ); const registerBackgroundTask = useCallback( @@ -250,6 +258,7 @@ export const useExecutionLifecycle = ( initialOutput: string | AnsiOutput, completionBehavior?: CompletionBehavior, ) => { + if (!isActive) return; m.backgroundedPids.add(pid); dispatch({ type: 'REGISTER_TASK', @@ -313,7 +322,7 @@ export const useExecutionLifecycle = ( dataUnsubscribe(); }); }, - [dispatch, m], + [dispatch, m, isActive], ); // Auto-register any execution that gets backgrounded, regardless of type. @@ -321,6 +330,7 @@ export const useExecutionLifecycle = ( // ExecutionLifecycleService.createExecution() or attachExecution() // automatically gets Ctrl+B support — no UI changes needed per tool. useEffect(() => { + if (!isActive) return; const listener = (info: { executionId: number; label: string; @@ -342,7 +352,7 @@ export const useExecutionLifecycle = ( return () => { ExecutionLifecycleService.offBackground(listener); }; - }, [registerBackgroundTask, m]); + }, [registerBackgroundTask, m, isActive]); const handleShellCommand = useCallback( (rawQuery: PartListUnion, abortSignal: AbortSignal): boolean => { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index ac63733fa9..e4a1a0af98 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -238,6 +238,7 @@ export const useGeminiStream = ( terminalHeight: number, isShellFocused?: boolean, consumeUserHint?: () => string | null, + isActive: boolean = true, ) => { const [initError, setInitError] = useState(null); const [retryStatus, setRetryStatus] = useState( @@ -276,6 +277,7 @@ export const useGeminiStream = ( }, [config]); useEffect(() => { + if (!isActive) return; const handleRetryAttempt = (payload: RetryAttemptPayload) => { if (turnCancelledRef.current || !isRespondingRef.current) { return; @@ -286,7 +288,7 @@ export const useGeminiStream = ( return () => { coreEvents.off(CoreEvent.RetryAttempt, handleRetryAttempt); }; - }, [isRespondingRef]); + }, [isRespondingRef, isActive]); const [ toolCalls, @@ -387,6 +389,11 @@ export const useGeminiStream = ( [setIsResponding], ); + const streamingState = useMemo( + () => calculateStreamingState(isResponding, toolCalls), + [isResponding, toolCalls], + ); + const { handleShellCommand, activeShellPtyId, @@ -409,15 +416,13 @@ export const useGeminiStream = ( terminalWidth, terminalHeight, activeBackgroundExecutionId, - ); - - const streamingState = useMemo( - () => calculateStreamingState(isResponding, toolCalls), - [isResponding, toolCalls], + streamingState === StreamingState.WaitingForConfirmation, + isActive, ); // Reset tracking when a new batch of tools starts useEffect(() => { + if (!isActive) return; if (toolCalls.length > 0) { const isNewBatch = !toolCalls.some((tc) => pushedToolCallIdsRef.current.has(tc.request.callId), @@ -437,10 +442,12 @@ export const useGeminiStream = ( setPushedToolCallIds, setIsFirstToolInGroup, streamingState, + isActive, ]); // Push completed tools to history as they finish useEffect(() => { + if (!isActive) return; const toolsToPush: TrackedToolCall[] = []; for (let i = 0; i < toolCalls.length; i++) { const tc = toolCalls[i]; @@ -598,8 +605,11 @@ export const useGeminiStream = ( isShellFocused, backgroundTasks, settings.merged.ui?.compactToolOutput, + isActive, ]); + const pendingToolGroupItems = useMemo((): HistoryItemWithoutId[] => { + if (!isActive) return []; const remainingTools = toolCalls.filter( (tc) => !pushedToolCallIds.has(tc.request.callId), ); @@ -704,6 +714,7 @@ export const useGeminiStream = ( isShellFocused, backgroundTasks, settings.merged.ui?.compactToolOutput, + isActive, ]); const lastQueryRef = useRef(null); @@ -721,6 +732,7 @@ export const useGeminiStream = ( const prevActiveShellPtyIdRef = useRef(null); useEffect(() => { + if (!isActive) return; if ( turnCancelledRef.current && prevActiveShellPtyIdRef.current !== null && @@ -730,9 +742,10 @@ export const useGeminiStream = ( setIsResponding(false); } prevActiveShellPtyIdRef.current = activeShellPtyId; - }, [activeShellPtyId, addItem, setIsResponding]); + }, [activeShellPtyId, addItem, setIsResponding, isActive]); useEffect(() => { + if (!isActive) return; if ( config.getApprovalMode() === ApprovalMode.YOLO && streamingState === StreamingState.Idle @@ -751,7 +764,7 @@ export const useGeminiStream = ( ); } } - }, [streamingState, config, history]); + }, [streamingState, config, history, isActive]); useEffect(() => { if (!isResponding) { @@ -809,6 +822,7 @@ export const useGeminiStream = ( const cancelOngoingRequest = useCallback( (clearBuffer: boolean = true) => { + if (!isActive) return; // If we are already cancelled, do nothing if (turnCancelledRef.current) { if (clearBuffer) { @@ -920,6 +934,7 @@ export const useGeminiStream = ( toolCalls, activeShellPtyId, setIsResponding, + isActive, ], ); @@ -933,8 +948,9 @@ export const useGeminiStream = ( }, { isActive: - streamingState === StreamingState.Responding || - streamingState === StreamingState.WaitingForConfirmation, + isActive && + (streamingState === StreamingState.Responding || + streamingState === StreamingState.WaitingForConfirmation), }, );