From 08c3b7aacee485cfc3b4f2eae15b2ff233514ed0 Mon Sep 17 00:00:00 2001 From: Bryan Morgan Date: Sat, 31 Jan 2026 23:25:12 -0500 Subject: [PATCH] feat(acp): add session resume support --- .../cli/src/ui/hooks/useSessionBrowser.ts | 196 +----------------- packages/cli/src/utils/sessionUtils.ts | 194 +++++++++++++++++ .../cli/src/zed-integration/acpResume.test.ts | 149 +++++++++++++ .../zed-integration/zedIntegration.test.ts | 2 +- .../cli/src/zed-integration/zedIntegration.ts | 119 ++++++++++- 5 files changed, 467 insertions(+), 193 deletions(-) create mode 100644 packages/cli/src/zed-integration/acpResume.test.ts diff --git a/packages/cli/src/ui/hooks/useSessionBrowser.ts b/packages/cli/src/ui/hooks/useSessionBrowser.ts index c214011c8b..de6495c3b9 100644 --- a/packages/cli/src/ui/hooks/useSessionBrowser.ts +++ b/packages/cli/src/ui/hooks/useSessionBrowser.ts @@ -13,11 +13,12 @@ import type { ConversationRecord, ResumedSessionData, } from '@google/gemini-cli-core'; -import type { Part } from '@google/genai'; -import { partListUnionToString, coreEvents } from '@google/gemini-cli-core'; -import { checkExhaustive } from '../../utils/checks.js'; +import { coreEvents } from '@google/gemini-cli-core'; import type { SessionInfo } from '../../utils/sessionUtils.js'; -import { MessageType, ToolCallStatus } from '../types.js'; +import { convertSessionToHistoryFormats } from '../../utils/sessionUtils.js'; +import type { Part } from '@google/genai'; + +export { convertSessionToHistoryFormats }; export const useSessionBrowser = ( config: Config, @@ -112,190 +113,3 @@ export const useSessionBrowser = ( ), }; }; - -/** - * Converts session/conversation data into UI history and Gemini client history formats. - */ -export function convertSessionToHistoryFormats( - messages: ConversationRecord['messages'], -): { - uiHistory: HistoryItemWithoutId[]; - clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }>; -} { - const uiHistory: HistoryItemWithoutId[] = []; - - for (const msg of messages) { - // Add the message only if it has content - const displayContentString = msg.displayContent - ? partListUnionToString(msg.displayContent) - : undefined; - const contentString = partListUnionToString(msg.content); - const uiText = displayContentString || contentString; - - if (uiText.trim()) { - let messageType: MessageType; - switch (msg.type) { - case 'user': - messageType = MessageType.USER; - break; - case 'info': - messageType = MessageType.INFO; - break; - case 'error': - messageType = MessageType.ERROR; - break; - case 'warning': - messageType = MessageType.WARNING; - break; - case 'gemini': - messageType = MessageType.GEMINI; - break; - default: - checkExhaustive(msg); - messageType = MessageType.GEMINI; - break; - } - - uiHistory.push({ - type: messageType, - text: uiText, - }); - } - - // Add tool calls if present - if ( - msg.type !== 'user' && - 'toolCalls' in msg && - msg.toolCalls && - msg.toolCalls.length > 0 - ) { - uiHistory.push({ - type: 'tool_group', - tools: msg.toolCalls.map((tool) => ({ - callId: tool.id, - name: tool.displayName || tool.name, - description: tool.description || '', - renderOutputAsMarkdown: tool.renderOutputAsMarkdown ?? true, - status: - tool.status === 'success' - ? ToolCallStatus.Success - : ToolCallStatus.Error, - resultDisplay: tool.resultDisplay, - confirmationDetails: undefined, - })), - }); - } - } - - // Convert to Gemini client history format - const clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }> = []; - - for (const msg of messages) { - // Skip system/error messages and user slash commands - if (msg.type === 'info' || msg.type === 'error' || msg.type === 'warning') { - continue; - } - - if (msg.type === 'user') { - // Skip user slash commands - const contentString = partListUnionToString(msg.content); - if ( - contentString.trim().startsWith('/') || - contentString.trim().startsWith('?') - ) { - continue; - } - - // Add regular user message - clientHistory.push({ - role: 'user', - parts: Array.isArray(msg.content) - ? (msg.content as Part[]) - : [{ text: contentString }], - }); - } else if (msg.type === 'gemini') { - // Handle Gemini messages with potential tool calls - const hasToolCalls = msg.toolCalls && msg.toolCalls.length > 0; - - if (hasToolCalls) { - // Create model message with function calls - const modelParts: Part[] = []; - - // Add text content if present - const contentString = partListUnionToString(msg.content); - if (msg.content && contentString.trim()) { - modelParts.push({ text: contentString }); - } - - // Add function calls - for (const toolCall of msg.toolCalls!) { - modelParts.push({ - functionCall: { - name: toolCall.name, - args: toolCall.args, - ...(toolCall.id && { id: toolCall.id }), - }, - }); - } - - clientHistory.push({ - role: 'model', - parts: modelParts, - }); - - // Create single function response message with all tool call responses - const functionResponseParts: Part[] = []; - for (const toolCall of msg.toolCalls!) { - if (toolCall.result) { - // Convert PartListUnion result to function response format - let responseData: Part; - - if (typeof toolCall.result === 'string') { - responseData = { - functionResponse: { - id: toolCall.id, - name: toolCall.name, - response: { - output: toolCall.result, - }, - }, - }; - } else if (Array.isArray(toolCall.result)) { - // toolCall.result is an array containing properly formatted - // function responses - functionResponseParts.push(...(toolCall.result as Part[])); - continue; - } else { - // Fallback for non-array results - responseData = toolCall.result; - } - - functionResponseParts.push(responseData); - } - } - - // Only add user message if we have function responses - if (functionResponseParts.length > 0) { - clientHistory.push({ - role: 'user', - parts: functionResponseParts, - }); - } - } else { - // Regular Gemini message without tool calls - const contentString = partListUnionToString(msg.content); - if (msg.content && contentString.trim()) { - clientHistory.push({ - role: 'model', - parts: [{ text: contentString }], - }); - } - } - } - } - - return { - uiHistory, - clientHistory, - }; -} diff --git a/packages/cli/src/utils/sessionUtils.ts b/packages/cli/src/utils/sessionUtils.ts index 1d7be693b8..63ccf4d14a 100644 --- a/packages/cli/src/utils/sessionUtils.ts +++ b/packages/cli/src/utils/sessionUtils.ts @@ -16,6 +16,13 @@ import { import * as fs from 'node:fs/promises'; import path from 'node:path'; import { stripUnsafeCharacters } from '../ui/utils/textUtils.js'; +import type { Part } from '@google/genai'; +import { checkExhaustive } from './checks.js'; +import { + MessageType, + ToolCallStatus, + type HistoryItemWithoutId, +} from '../ui/types.js'; /** * Constant for the resume "latest" identifier. @@ -514,3 +521,190 @@ export class SessionSelector { } } } + +/** + * Converts session/conversation data into UI history and Gemini client history formats. + */ +export function convertSessionToHistoryFormats( + messages: ConversationRecord['messages'], +): { + uiHistory: HistoryItemWithoutId[]; + clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }>; +} { + const uiHistory: HistoryItemWithoutId[] = []; + + for (const msg of messages) { + // Add the message only if it has content + const displayContentString = msg.displayContent + ? partListUnionToString(msg.displayContent) + : undefined; + const contentString = partListUnionToString(msg.content); + const uiText = displayContentString || contentString; + + if (uiText.trim()) { + let messageType: MessageType; + switch (msg.type) { + case 'user': + messageType = MessageType.USER; + break; + case 'info': + messageType = MessageType.INFO; + break; + case 'error': + messageType = MessageType.ERROR; + break; + case 'warning': + messageType = MessageType.WARNING; + break; + case 'gemini': + messageType = MessageType.GEMINI; + break; + default: + checkExhaustive(msg); + messageType = MessageType.GEMINI; + break; + } + + uiHistory.push({ + type: messageType, + text: uiText, + }); + } + + // Add tool calls if present + if ( + msg.type !== 'user' && + 'toolCalls' in msg && + msg.toolCalls && + msg.toolCalls.length > 0 + ) { + uiHistory.push({ + type: 'tool_group', + tools: msg.toolCalls.map((tool) => ({ + callId: tool.id, + name: tool.displayName || tool.name, + description: tool.description || '', + renderOutputAsMarkdown: tool.renderOutputAsMarkdown ?? true, + status: + tool.status === 'success' + ? ToolCallStatus.Success + : ToolCallStatus.Error, + resultDisplay: tool.resultDisplay, + confirmationDetails: undefined, + })), + }); + } + } + + // Convert to Gemini client history format + const clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }> = []; + + for (const msg of messages) { + // Skip system/error messages and user slash commands + if (msg.type === 'info' || msg.type === 'error' || msg.type === 'warning') { + continue; + } + + if (msg.type === 'user') { + // Skip user slash commands + const contentString = partListUnionToString(msg.content); + if ( + contentString.trim().startsWith('/') || + contentString.trim().startsWith('?') + ) { + continue; + } + + // Add regular user message + clientHistory.push({ + role: 'user', + parts: Array.isArray(msg.content) + ? (msg.content as Part[]) + : [{ text: contentString }], + }); + } else if (msg.type === 'gemini') { + // Handle Gemini messages with potential tool calls + const hasToolCalls = msg.toolCalls && msg.toolCalls.length > 0; + + if (hasToolCalls) { + // Create model message with function calls + const modelParts: Part[] = []; + + // Add text content if present + const contentString = partListUnionToString(msg.content); + if (msg.content && contentString.trim()) { + modelParts.push({ text: contentString }); + } + + // Add function calls + for (const toolCall of msg.toolCalls!) { + modelParts.push({ + functionCall: { + name: toolCall.name, + args: toolCall.args, + ...(toolCall.id && { id: toolCall.id }), + }, + }); + } + + clientHistory.push({ + role: 'model', + parts: modelParts, + }); + + // Create single function response message with all tool call responses + const functionResponseParts: Part[] = []; + for (const toolCall of msg.toolCalls!) { + if (toolCall.result) { + // Convert PartListUnion result to function response format + let responseData: Part; + + if (typeof toolCall.result === 'string') { + responseData = { + functionResponse: { + id: toolCall.id, + name: toolCall.name, + response: { + output: toolCall.result, + }, + }, + }; + } else if (Array.isArray(toolCall.result)) { + // toolCall.result is an array containing properly formatted + // function responses + functionResponseParts.push(...(toolCall.result as Part[])); + continue; + } else { + // Fallback for non-array results + responseData = toolCall.result; + } + + functionResponseParts.push(responseData); + } + } + + // Only add user message if we have function responses + if (functionResponseParts.length > 0) { + clientHistory.push({ + role: 'user', + parts: functionResponseParts, + }); + } + } else { + // Regular Gemini message without tool calls + const contentString = partListUnionToString(msg.content); + if (msg.content && contentString.trim()) { + clientHistory.push({ + role: 'model', + parts: [{ text: contentString }], + }); + } + } + } + } + + return { + uiHistory, + clientHistory, + }; +} diff --git a/packages/cli/src/zed-integration/acpResume.test.ts b/packages/cli/src/zed-integration/acpResume.test.ts new file mode 100644 index 0000000000..c98c7c917f --- /dev/null +++ b/packages/cli/src/zed-integration/acpResume.test.ts @@ -0,0 +1,149 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + type Mocked, + type Mock, +} from 'vitest'; +import { GeminiAgent } from './zedIntegration.js'; +import * as acp from '@agentclientprotocol/sdk'; +import { AuthType, type Config } from '@google/gemini-cli-core'; +import { loadCliConfig, type CliArgs } from '../config/config.js'; +import { + SessionSelector, + convertSessionToHistoryFormats, +} from '../utils/sessionUtils.js'; +import type { LoadedSettings } from '../config/settings.js'; + +vi.mock('../config/config.js', () => ({ + loadCliConfig: vi.fn(), +})); + +vi.mock('../utils/sessionUtils.js', () => ({ + SessionSelector: vi.fn(), + convertSessionToHistoryFormats: vi.fn(), +})); + +describe('GeminiAgent Session Resume', () => { + let mockConfig: Mocked; + let mockSettings: Mocked; + let mockArgv: CliArgs; + let mockConnection: Mocked; + let agent: GeminiAgent; + + beforeEach(() => { + mockConfig = { + refreshAuth: vi.fn().mockResolvedValue(undefined), + initialize: vi.fn().mockResolvedValue(undefined), + getFileSystemService: vi.fn(), + setFileSystemService: vi.fn(), + getGeminiClient: vi.fn().mockReturnValue({ + initialize: vi.fn().mockResolvedValue(undefined), + resumeChat: vi.fn().mockResolvedValue(undefined), + getChat: vi.fn().mockReturnValue({}), + }), + storage: { + getProjectTempDir: vi.fn().mockReturnValue('/tmp/project'), + }, + } as unknown as Mocked; + mockSettings = { + merged: { + security: { auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE } }, + mcpServers: {}, + }, + setValue: vi.fn(), + } as unknown as Mocked; + mockArgv = {} as unknown as CliArgs; + mockConnection = { + sessionUpdate: vi.fn().mockResolvedValue(undefined), + } as unknown as Mocked; + + (loadCliConfig as Mock).mockResolvedValue(mockConfig); + + agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockConnection); + }); + + it('should advertise loadSession capability', async () => { + const response = await agent.initialize({ + protocolVersion: acp.PROTOCOL_VERSION, + }); + expect(response.agentCapabilities?.loadSession).toBe(true); + }); + + it('should load an existing session and stream history', async () => { + const sessionId = 'existing-session-id'; + const sessionData = { + sessionId, + messages: [ + { type: 'user', content: [{ text: 'Hello' }] }, + { + type: 'gemini', + content: [{ text: 'Hi there' }], + thoughts: [{ subject: 'Thinking', description: 'about greeting' }], + }, + ], + }; + + (SessionSelector as unknown as Mock).mockImplementation(() => ({ + resolveSession: vi.fn().mockResolvedValue({ + sessionData, + sessionPath: '/path/to/session.json', + }), + })); + + (convertSessionToHistoryFormats as unknown as Mock).mockReturnValue({ + clientHistory: [], + uiHistory: [], + }); + + const response = await agent.loadSession({ + sessionId, + cwd: '/tmp', + mcpServers: [], + }); + + expect(response).toEqual({}); + expect(mockConfig.getGeminiClient().resumeChat).toHaveBeenCalled(); + + // Verify history streaming (it's called async, so we might need to wait or use a spy on Session) + // In this case, we can verify mockConnection.sessionUpdate calls. + // Since it's not awaited in loadSession, we might need a small delay or use vi.waitFor + + await vi.waitFor(() => { + expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'user_message_chunk', + content: expect.objectContaining({ text: 'Hello' }), + }), + }), + ); + expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'agent_thought_chunk', + content: expect.objectContaining({ + text: '**Thinking**\nabout greeting', + }), + }), + }), + ); + expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'agent_message_chunk', + content: expect.objectContaining({ text: 'Hi there' }), + }), + }), + ); + }); + }); +}); diff --git a/packages/cli/src/zed-integration/zedIntegration.test.ts b/packages/cli/src/zed-integration/zedIntegration.test.ts index f0ceec4e22..7b3d0a266d 100644 --- a/packages/cli/src/zed-integration/zedIntegration.test.ts +++ b/packages/cli/src/zed-integration/zedIntegration.test.ts @@ -129,7 +129,7 @@ describe('GeminiAgent', () => { expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION); expect(response.authMethods).toHaveLength(3); - expect(response.agentCapabilities?.loadSession).toBe(false); + expect(response.agentCapabilities?.loadSession).toBe(true); }); it('should authenticate correctly', async () => { diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 7273c0b961..ce40f047f5 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -10,6 +10,7 @@ import type { ToolResult, ToolCallConfirmationDetails, FilterFilesOptions, + ConversationRecord, } from '@google/gemini-cli-core'; import { AuthType, @@ -32,6 +33,7 @@ import { createWorkingStdio, startupProfiler, Kind, + partListUnionToString, } from '@google/gemini-cli-core'; import * as acp from '@agentclientprotocol/sdk'; import { AcpFileSystemService } from './fileSystemService.js'; @@ -47,6 +49,10 @@ import { randomUUID } from 'node:crypto'; import type { CliArgs } from '../config/config.js'; import { loadCliConfig } from '../config/config.js'; import { runExitCleanup } from '../utils/cleanup.js'; +import { + SessionSelector, + convertSessionToHistoryFormats, +} from '../utils/sessionUtils.js'; export async function runZedIntegration( config: Config, @@ -107,7 +113,7 @@ export class GeminiAgent { protocolVersion: acp.PROTOCOL_VERSION, authMethods, agentCapabilities: { - loadSession: false, + loadSession: true, promptCapabilities: { image: true, audio: true, @@ -184,6 +190,69 @@ export class GeminiAgent { }; } + async loadSession({ + sessionId, + cwd, + mcpServers, + }: acp.LoadSessionRequest): Promise { + const config = await this.newSessionConfig(sessionId, cwd, mcpServers); + + let isAuthenticated = false; + if (this.settings.merged.security.auth.selectedType) { + try { + await config.refreshAuth( + this.settings.merged.security.auth.selectedType, + ); + isAuthenticated = true; + } catch (e) { + debugLogger.error(`Authentication failed: ${e}`); + } + } + + if (!isAuthenticated) { + throw acp.RequestError.authRequired(); + } + + const sessionSelector = new SessionSelector(config); + const { sessionData, sessionPath } = + await sessionSelector.resolveSession(sessionId); + + if (this.clientCapabilities?.fs) { + const acpFileSystemService = new AcpFileSystemService( + this.connection, + sessionId, + this.clientCapabilities.fs, + config.getFileSystemService(), + ); + config.setFileSystemService(acpFileSystemService); + } + + const { clientHistory } = convertSessionToHistoryFormats( + sessionData.messages, + ); + + const geminiClient = config.getGeminiClient(); + await geminiClient.initialize(); + await geminiClient.resumeChat(clientHistory, { + conversation: sessionData, + filePath: sessionPath, + }); + + const session = new Session( + sessionId, + geminiClient.getChat(), + config, + this.connection, + ); + this.sessions.set(sessionId, session); + + // Stream history back to client + // eslint-disable-next-line @typescript-eslint/no-floating-promises + session.streamHistory(sessionData.messages); + + return {}; + } + async newSessionConfig( sessionId: string, cwd: string, @@ -269,6 +338,54 @@ export class Session { this.pendingPrompt = null; } + async streamHistory(messages: ConversationRecord['messages']): Promise { + for (const msg of messages) { + const contentString = partListUnionToString(msg.content); + + if (msg.type === 'user') { + if (contentString.trim()) { + await this.sendUpdate({ + sessionUpdate: 'user_message_chunk', + content: { type: 'text', text: contentString }, + }); + } + } else if (msg.type === 'gemini') { + // Thoughts + if (msg.thoughts) { + for (const thought of msg.thoughts) { + const thoughtText = `**${thought.subject}**\n${thought.description}`; + await this.sendUpdate({ + sessionUpdate: 'agent_thought_chunk', + content: { type: 'text', text: thoughtText }, + }); + } + } + + // Message text + if (contentString.trim()) { + await this.sendUpdate({ + sessionUpdate: 'agent_message_chunk', + content: { type: 'text', text: contentString }, + }); + } + + // Tool calls + if (msg.toolCalls) { + for (const toolCall of msg.toolCalls) { + await this.sendUpdate({ + sessionUpdate: 'tool_call', + toolCallId: toolCall.id, + status: toolCall.status === 'success' ? 'completed' : 'failed', + title: toolCall.displayName || toolCall.name, + content: [], // We could potentially reconstruct content here if needed + kind: 'other', // We don't have Kind here easily without re-resolving tools + }); + } + } + } + } + } + async prompt(params: acp.PromptRequest): Promise { this.pendingPrompt?.abort(); const pendingSend = new AbortController();