mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-24 22:55:13 +00:00
feat(acp): add session resume support (#18043)
Co-authored-by: Jack Wotherspoon <jackwoth@google.com>
This commit is contained in:
@@ -24,7 +24,14 @@ import { coreEvents } from '@google/gemini-cli-core';
|
|||||||
// Mock modules
|
// Mock modules
|
||||||
vi.mock('fs/promises');
|
vi.mock('fs/promises');
|
||||||
vi.mock('path');
|
vi.mock('path');
|
||||||
vi.mock('../../utils/sessionUtils.js');
|
vi.mock('../../utils/sessionUtils.js', async (importOriginal) => {
|
||||||
|
const actual =
|
||||||
|
await importOriginal<typeof import('../../utils/sessionUtils.js')>();
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
getSessionFiles: vi.fn(),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
const MOCKED_PROJECT_TEMP_DIR = '/test/project/temp';
|
const MOCKED_PROJECT_TEMP_DIR = '/test/project/temp';
|
||||||
const MOCKED_CHATS_DIR = '/test/project/temp/chats';
|
const MOCKED_CHATS_DIR = '/test/project/temp/chats';
|
||||||
|
|||||||
@@ -13,11 +13,12 @@ import type {
|
|||||||
ConversationRecord,
|
ConversationRecord,
|
||||||
ResumedSessionData,
|
ResumedSessionData,
|
||||||
} from '@google/gemini-cli-core';
|
} from '@google/gemini-cli-core';
|
||||||
import type { Part } from '@google/genai';
|
import { coreEvents } from '@google/gemini-cli-core';
|
||||||
import { partListUnionToString, coreEvents } from '@google/gemini-cli-core';
|
|
||||||
import { checkExhaustive } from '../../utils/checks.js';
|
|
||||||
import type { SessionInfo } from '../../utils/sessionUtils.js';
|
import type { SessionInfo } from '../../utils/sessionUtils.js';
|
||||||
import { MessageType, ToolCallStatus } from '../types.js';
|
import { convertSessionToHistoryFormats } from '../../utils/sessionUtils.js';
|
||||||
|
import type { Part } from '@google/genai';
|
||||||
|
|
||||||
|
export { convertSessionToHistoryFormats };
|
||||||
|
|
||||||
export const useSessionBrowser = (
|
export const useSessionBrowser = (
|
||||||
config: Config,
|
config: Config,
|
||||||
@@ -112,190 +113,3 @@ export const useSessionBrowser = (
|
|||||||
),
|
),
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts session/conversation data into UI history and Gemini client history formats.
|
|
||||||
*/
|
|
||||||
export function convertSessionToHistoryFormats(
|
|
||||||
messages: ConversationRecord['messages'],
|
|
||||||
): {
|
|
||||||
uiHistory: HistoryItemWithoutId[];
|
|
||||||
clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }>;
|
|
||||||
} {
|
|
||||||
const uiHistory: HistoryItemWithoutId[] = [];
|
|
||||||
|
|
||||||
for (const msg of messages) {
|
|
||||||
// Add the message only if it has content
|
|
||||||
const displayContentString = msg.displayContent
|
|
||||||
? partListUnionToString(msg.displayContent)
|
|
||||||
: undefined;
|
|
||||||
const contentString = partListUnionToString(msg.content);
|
|
||||||
const uiText = displayContentString || contentString;
|
|
||||||
|
|
||||||
if (uiText.trim()) {
|
|
||||||
let messageType: MessageType;
|
|
||||||
switch (msg.type) {
|
|
||||||
case 'user':
|
|
||||||
messageType = MessageType.USER;
|
|
||||||
break;
|
|
||||||
case 'info':
|
|
||||||
messageType = MessageType.INFO;
|
|
||||||
break;
|
|
||||||
case 'error':
|
|
||||||
messageType = MessageType.ERROR;
|
|
||||||
break;
|
|
||||||
case 'warning':
|
|
||||||
messageType = MessageType.WARNING;
|
|
||||||
break;
|
|
||||||
case 'gemini':
|
|
||||||
messageType = MessageType.GEMINI;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
checkExhaustive(msg);
|
|
||||||
messageType = MessageType.GEMINI;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
uiHistory.push({
|
|
||||||
type: messageType,
|
|
||||||
text: uiText,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add tool calls if present
|
|
||||||
if (
|
|
||||||
msg.type !== 'user' &&
|
|
||||||
'toolCalls' in msg &&
|
|
||||||
msg.toolCalls &&
|
|
||||||
msg.toolCalls.length > 0
|
|
||||||
) {
|
|
||||||
uiHistory.push({
|
|
||||||
type: 'tool_group',
|
|
||||||
tools: msg.toolCalls.map((tool) => ({
|
|
||||||
callId: tool.id,
|
|
||||||
name: tool.displayName || tool.name,
|
|
||||||
description: tool.description || '',
|
|
||||||
renderOutputAsMarkdown: tool.renderOutputAsMarkdown ?? true,
|
|
||||||
status:
|
|
||||||
tool.status === 'success'
|
|
||||||
? ToolCallStatus.Success
|
|
||||||
: ToolCallStatus.Error,
|
|
||||||
resultDisplay: tool.resultDisplay,
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
})),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to Gemini client history format
|
|
||||||
const clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }> = [];
|
|
||||||
|
|
||||||
for (const msg of messages) {
|
|
||||||
// Skip system/error messages and user slash commands
|
|
||||||
if (msg.type === 'info' || msg.type === 'error' || msg.type === 'warning') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (msg.type === 'user') {
|
|
||||||
// Skip user slash commands
|
|
||||||
const contentString = partListUnionToString(msg.content);
|
|
||||||
if (
|
|
||||||
contentString.trim().startsWith('/') ||
|
|
||||||
contentString.trim().startsWith('?')
|
|
||||||
) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add regular user message
|
|
||||||
clientHistory.push({
|
|
||||||
role: 'user',
|
|
||||||
parts: Array.isArray(msg.content)
|
|
||||||
? (msg.content as Part[])
|
|
||||||
: [{ text: contentString }],
|
|
||||||
});
|
|
||||||
} else if (msg.type === 'gemini') {
|
|
||||||
// Handle Gemini messages with potential tool calls
|
|
||||||
const hasToolCalls = msg.toolCalls && msg.toolCalls.length > 0;
|
|
||||||
|
|
||||||
if (hasToolCalls) {
|
|
||||||
// Create model message with function calls
|
|
||||||
const modelParts: Part[] = [];
|
|
||||||
|
|
||||||
// Add text content if present
|
|
||||||
const contentString = partListUnionToString(msg.content);
|
|
||||||
if (msg.content && contentString.trim()) {
|
|
||||||
modelParts.push({ text: contentString });
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add function calls
|
|
||||||
for (const toolCall of msg.toolCalls!) {
|
|
||||||
modelParts.push({
|
|
||||||
functionCall: {
|
|
||||||
name: toolCall.name,
|
|
||||||
args: toolCall.args,
|
|
||||||
...(toolCall.id && { id: toolCall.id }),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
clientHistory.push({
|
|
||||||
role: 'model',
|
|
||||||
parts: modelParts,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Create single function response message with all tool call responses
|
|
||||||
const functionResponseParts: Part[] = [];
|
|
||||||
for (const toolCall of msg.toolCalls!) {
|
|
||||||
if (toolCall.result) {
|
|
||||||
// Convert PartListUnion result to function response format
|
|
||||||
let responseData: Part;
|
|
||||||
|
|
||||||
if (typeof toolCall.result === 'string') {
|
|
||||||
responseData = {
|
|
||||||
functionResponse: {
|
|
||||||
id: toolCall.id,
|
|
||||||
name: toolCall.name,
|
|
||||||
response: {
|
|
||||||
output: toolCall.result,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
} else if (Array.isArray(toolCall.result)) {
|
|
||||||
// toolCall.result is an array containing properly formatted
|
|
||||||
// function responses
|
|
||||||
functionResponseParts.push(...(toolCall.result as Part[]));
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
// Fallback for non-array results
|
|
||||||
responseData = toolCall.result;
|
|
||||||
}
|
|
||||||
|
|
||||||
functionResponseParts.push(responseData);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only add user message if we have function responses
|
|
||||||
if (functionResponseParts.length > 0) {
|
|
||||||
clientHistory.push({
|
|
||||||
role: 'user',
|
|
||||||
parts: functionResponseParts,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Regular Gemini message without tool calls
|
|
||||||
const contentString = partListUnionToString(msg.content);
|
|
||||||
if (msg.content && contentString.trim()) {
|
|
||||||
clientHistory.push({
|
|
||||||
role: 'model',
|
|
||||||
parts: [{ text: contentString }],
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
uiHistory,
|
|
||||||
clientHistory,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -16,6 +16,13 @@ import {
|
|||||||
import * as fs from 'node:fs/promises';
|
import * as fs from 'node:fs/promises';
|
||||||
import path from 'node:path';
|
import path from 'node:path';
|
||||||
import { stripUnsafeCharacters } from '../ui/utils/textUtils.js';
|
import { stripUnsafeCharacters } from '../ui/utils/textUtils.js';
|
||||||
|
import type { Part } from '@google/genai';
|
||||||
|
import { checkExhaustive } from './checks.js';
|
||||||
|
import {
|
||||||
|
MessageType,
|
||||||
|
ToolCallStatus,
|
||||||
|
type HistoryItemWithoutId,
|
||||||
|
} from '../ui/types.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constant for the resume "latest" identifier.
|
* Constant for the resume "latest" identifier.
|
||||||
@@ -514,3 +521,190 @@ export class SessionSelector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts session/conversation data into UI history and Gemini client history formats.
|
||||||
|
*/
|
||||||
|
export function convertSessionToHistoryFormats(
|
||||||
|
messages: ConversationRecord['messages'],
|
||||||
|
): {
|
||||||
|
uiHistory: HistoryItemWithoutId[];
|
||||||
|
clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }>;
|
||||||
|
} {
|
||||||
|
const uiHistory: HistoryItemWithoutId[] = [];
|
||||||
|
|
||||||
|
for (const msg of messages) {
|
||||||
|
// Add the message only if it has content
|
||||||
|
const displayContentString = msg.displayContent
|
||||||
|
? partListUnionToString(msg.displayContent)
|
||||||
|
: undefined;
|
||||||
|
const contentString = partListUnionToString(msg.content);
|
||||||
|
const uiText = displayContentString || contentString;
|
||||||
|
|
||||||
|
if (uiText.trim()) {
|
||||||
|
let messageType: MessageType;
|
||||||
|
switch (msg.type) {
|
||||||
|
case 'user':
|
||||||
|
messageType = MessageType.USER;
|
||||||
|
break;
|
||||||
|
case 'info':
|
||||||
|
messageType = MessageType.INFO;
|
||||||
|
break;
|
||||||
|
case 'error':
|
||||||
|
messageType = MessageType.ERROR;
|
||||||
|
break;
|
||||||
|
case 'warning':
|
||||||
|
messageType = MessageType.WARNING;
|
||||||
|
break;
|
||||||
|
case 'gemini':
|
||||||
|
messageType = MessageType.GEMINI;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
checkExhaustive(msg);
|
||||||
|
messageType = MessageType.GEMINI;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
uiHistory.push({
|
||||||
|
type: messageType,
|
||||||
|
text: uiText,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool calls if present
|
||||||
|
if (
|
||||||
|
msg.type !== 'user' &&
|
||||||
|
'toolCalls' in msg &&
|
||||||
|
msg.toolCalls &&
|
||||||
|
msg.toolCalls.length > 0
|
||||||
|
) {
|
||||||
|
uiHistory.push({
|
||||||
|
type: 'tool_group',
|
||||||
|
tools: msg.toolCalls.map((tool) => ({
|
||||||
|
callId: tool.id,
|
||||||
|
name: tool.displayName || tool.name,
|
||||||
|
description: tool.description || '',
|
||||||
|
renderOutputAsMarkdown: tool.renderOutputAsMarkdown ?? true,
|
||||||
|
status:
|
||||||
|
tool.status === 'success'
|
||||||
|
? ToolCallStatus.Success
|
||||||
|
: ToolCallStatus.Error,
|
||||||
|
resultDisplay: tool.resultDisplay,
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
})),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to Gemini client history format
|
||||||
|
const clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }> = [];
|
||||||
|
|
||||||
|
for (const msg of messages) {
|
||||||
|
// Skip system/error messages and user slash commands
|
||||||
|
if (msg.type === 'info' || msg.type === 'error' || msg.type === 'warning') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (msg.type === 'user') {
|
||||||
|
// Skip user slash commands
|
||||||
|
const contentString = partListUnionToString(msg.content);
|
||||||
|
if (
|
||||||
|
contentString.trim().startsWith('/') ||
|
||||||
|
contentString.trim().startsWith('?')
|
||||||
|
) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add regular user message
|
||||||
|
clientHistory.push({
|
||||||
|
role: 'user',
|
||||||
|
parts: Array.isArray(msg.content)
|
||||||
|
? (msg.content as Part[])
|
||||||
|
: [{ text: contentString }],
|
||||||
|
});
|
||||||
|
} else if (msg.type === 'gemini') {
|
||||||
|
// Handle Gemini messages with potential tool calls
|
||||||
|
const hasToolCalls = msg.toolCalls && msg.toolCalls.length > 0;
|
||||||
|
|
||||||
|
if (hasToolCalls) {
|
||||||
|
// Create model message with function calls
|
||||||
|
const modelParts: Part[] = [];
|
||||||
|
|
||||||
|
// Add text content if present
|
||||||
|
const contentString = partListUnionToString(msg.content);
|
||||||
|
if (msg.content && contentString.trim()) {
|
||||||
|
modelParts.push({ text: contentString });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add function calls
|
||||||
|
for (const toolCall of msg.toolCalls!) {
|
||||||
|
modelParts.push({
|
||||||
|
functionCall: {
|
||||||
|
name: toolCall.name,
|
||||||
|
args: toolCall.args,
|
||||||
|
...(toolCall.id && { id: toolCall.id }),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
clientHistory.push({
|
||||||
|
role: 'model',
|
||||||
|
parts: modelParts,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create single function response message with all tool call responses
|
||||||
|
const functionResponseParts: Part[] = [];
|
||||||
|
for (const toolCall of msg.toolCalls!) {
|
||||||
|
if (toolCall.result) {
|
||||||
|
// Convert PartListUnion result to function response format
|
||||||
|
let responseData: Part;
|
||||||
|
|
||||||
|
if (typeof toolCall.result === 'string') {
|
||||||
|
responseData = {
|
||||||
|
functionResponse: {
|
||||||
|
id: toolCall.id,
|
||||||
|
name: toolCall.name,
|
||||||
|
response: {
|
||||||
|
output: toolCall.result,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else if (Array.isArray(toolCall.result)) {
|
||||||
|
// toolCall.result is an array containing properly formatted
|
||||||
|
// function responses
|
||||||
|
functionResponseParts.push(...(toolCall.result as Part[]));
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
// Fallback for non-array results
|
||||||
|
responseData = toolCall.result;
|
||||||
|
}
|
||||||
|
|
||||||
|
functionResponseParts.push(responseData);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only add user message if we have function responses
|
||||||
|
if (functionResponseParts.length > 0) {
|
||||||
|
clientHistory.push({
|
||||||
|
role: 'user',
|
||||||
|
parts: functionResponseParts,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Regular Gemini message without tool calls
|
||||||
|
const contentString = partListUnionToString(msg.content);
|
||||||
|
if (msg.content && contentString.trim()) {
|
||||||
|
clientHistory.push({
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: contentString }],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
uiHistory,
|
||||||
|
clientHistory,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|||||||
224
packages/cli/src/zed-integration/acpResume.test.ts
Normal file
224
packages/cli/src/zed-integration/acpResume.test.ts
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
describe,
|
||||||
|
it,
|
||||||
|
expect,
|
||||||
|
vi,
|
||||||
|
beforeEach,
|
||||||
|
type Mocked,
|
||||||
|
type Mock,
|
||||||
|
} from 'vitest';
|
||||||
|
import { GeminiAgent } from './zedIntegration.js';
|
||||||
|
import * as acp from '@agentclientprotocol/sdk';
|
||||||
|
import { AuthType, type Config } from '@google/gemini-cli-core';
|
||||||
|
import { loadCliConfig, type CliArgs } from '../config/config.js';
|
||||||
|
import {
|
||||||
|
SessionSelector,
|
||||||
|
convertSessionToHistoryFormats,
|
||||||
|
} from '../utils/sessionUtils.js';
|
||||||
|
import type { LoadedSettings } from '../config/settings.js';
|
||||||
|
|
||||||
|
vi.mock('../config/config.js', () => ({
|
||||||
|
loadCliConfig: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock('../utils/sessionUtils.js', async (importOriginal) => {
|
||||||
|
const actual =
|
||||||
|
await importOriginal<typeof import('../utils/sessionUtils.js')>();
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
SessionSelector: vi.fn(),
|
||||||
|
convertSessionToHistoryFormats: vi.fn(),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('GeminiAgent Session Resume', () => {
|
||||||
|
let mockConfig: Mocked<Config>;
|
||||||
|
let mockSettings: Mocked<LoadedSettings>;
|
||||||
|
let mockArgv: CliArgs;
|
||||||
|
let mockConnection: Mocked<acp.AgentSideConnection>;
|
||||||
|
let agent: GeminiAgent;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockConfig = {
|
||||||
|
refreshAuth: vi.fn().mockResolvedValue(undefined),
|
||||||
|
initialize: vi.fn().mockResolvedValue(undefined),
|
||||||
|
getFileSystemService: vi.fn(),
|
||||||
|
setFileSystemService: vi.fn(),
|
||||||
|
getGeminiClient: vi.fn().mockReturnValue({
|
||||||
|
initialize: vi.fn().mockResolvedValue(undefined),
|
||||||
|
resumeChat: vi.fn().mockResolvedValue(undefined),
|
||||||
|
getChat: vi.fn().mockReturnValue({}),
|
||||||
|
}),
|
||||||
|
storage: {
|
||||||
|
getProjectTempDir: vi.fn().mockReturnValue('/tmp/project'),
|
||||||
|
},
|
||||||
|
} as unknown as Mocked<Config>;
|
||||||
|
mockSettings = {
|
||||||
|
merged: {
|
||||||
|
security: { auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE } },
|
||||||
|
mcpServers: {},
|
||||||
|
},
|
||||||
|
setValue: vi.fn(),
|
||||||
|
} as unknown as Mocked<LoadedSettings>;
|
||||||
|
mockArgv = {} as unknown as CliArgs;
|
||||||
|
mockConnection = {
|
||||||
|
sessionUpdate: vi.fn().mockResolvedValue(undefined),
|
||||||
|
} as unknown as Mocked<acp.AgentSideConnection>;
|
||||||
|
|
||||||
|
(loadCliConfig as Mock).mockResolvedValue(mockConfig);
|
||||||
|
|
||||||
|
agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockConnection);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should advertise loadSession capability', async () => {
|
||||||
|
const response = await agent.initialize({
|
||||||
|
protocolVersion: acp.PROTOCOL_VERSION,
|
||||||
|
});
|
||||||
|
expect(response.agentCapabilities?.loadSession).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should load a session, resume chat, and stream all message types', async () => {
|
||||||
|
const sessionId = 'existing-session-id';
|
||||||
|
const sessionData = {
|
||||||
|
sessionId,
|
||||||
|
messages: [
|
||||||
|
{ type: 'user', content: [{ text: 'Hello' }] },
|
||||||
|
{
|
||||||
|
type: 'gemini',
|
||||||
|
content: [{ text: 'Hi there' }],
|
||||||
|
thoughts: [{ subject: 'Thinking', description: 'about greeting' }],
|
||||||
|
toolCalls: [
|
||||||
|
{
|
||||||
|
id: 'call-1',
|
||||||
|
name: 'test_tool',
|
||||||
|
displayName: 'Test Tool',
|
||||||
|
status: 'success',
|
||||||
|
resultDisplay: 'Tool output',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'gemini',
|
||||||
|
content: [{ text: 'Trying a write' }],
|
||||||
|
toolCalls: [
|
||||||
|
{
|
||||||
|
id: 'call-2',
|
||||||
|
name: 'write_file',
|
||||||
|
displayName: 'Write File',
|
||||||
|
status: 'error',
|
||||||
|
resultDisplay: 'Permission denied',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
mockConfig.getToolRegistry = vi.fn().mockReturnValue({
|
||||||
|
getTool: vi.fn().mockReturnValue({ kind: 'read' }),
|
||||||
|
});
|
||||||
|
|
||||||
|
(SessionSelector as unknown as Mock).mockImplementation(() => ({
|
||||||
|
resolveSession: vi.fn().mockResolvedValue({
|
||||||
|
sessionData,
|
||||||
|
sessionPath: '/path/to/session.json',
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mockClientHistory = [
|
||||||
|
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'Hi there' }] },
|
||||||
|
];
|
||||||
|
(convertSessionToHistoryFormats as unknown as Mock).mockReturnValue({
|
||||||
|
clientHistory: mockClientHistory,
|
||||||
|
uiHistory: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await agent.loadSession({
|
||||||
|
sessionId,
|
||||||
|
cwd: '/tmp',
|
||||||
|
mcpServers: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response).toEqual({});
|
||||||
|
|
||||||
|
// Verify resumeChat received the correct arguments
|
||||||
|
expect(mockConfig.getGeminiClient().resumeChat).toHaveBeenCalledWith(
|
||||||
|
mockClientHistory,
|
||||||
|
expect.objectContaining({
|
||||||
|
conversation: sessionData,
|
||||||
|
filePath: '/path/to/session.json',
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
await vi.waitFor(() => {
|
||||||
|
// User message
|
||||||
|
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
update: expect.objectContaining({
|
||||||
|
sessionUpdate: 'user_message_chunk',
|
||||||
|
content: expect.objectContaining({ text: 'Hello' }),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Agent thought
|
||||||
|
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
update: expect.objectContaining({
|
||||||
|
sessionUpdate: 'agent_thought_chunk',
|
||||||
|
content: expect.objectContaining({
|
||||||
|
text: '**Thinking**\nabout greeting',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Agent message
|
||||||
|
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
update: expect.objectContaining({
|
||||||
|
sessionUpdate: 'agent_message_chunk',
|
||||||
|
content: expect.objectContaining({ text: 'Hi there' }),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Successful tool call → 'completed'
|
||||||
|
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
update: expect.objectContaining({
|
||||||
|
sessionUpdate: 'tool_call',
|
||||||
|
toolCallId: 'call-1',
|
||||||
|
status: 'completed',
|
||||||
|
title: 'Test Tool',
|
||||||
|
kind: 'read',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'content',
|
||||||
|
content: { type: 'text', text: 'Tool output' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Failed tool call → 'failed'
|
||||||
|
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
update: expect.objectContaining({
|
||||||
|
sessionUpdate: 'tool_call',
|
||||||
|
toolCallId: 'call-2',
|
||||||
|
status: 'failed',
|
||||||
|
title: 'Write File',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -129,7 +129,7 @@ describe('GeminiAgent', () => {
|
|||||||
|
|
||||||
expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION);
|
expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION);
|
||||||
expect(response.authMethods).toHaveLength(3);
|
expect(response.authMethods).toHaveLength(3);
|
||||||
expect(response.agentCapabilities?.loadSession).toBe(false);
|
expect(response.agentCapabilities?.loadSession).toBe(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should authenticate correctly', async () => {
|
it('should authenticate correctly', async () => {
|
||||||
@@ -273,6 +273,7 @@ describe('Session', () => {
|
|||||||
mockChat = {
|
mockChat = {
|
||||||
sendMessageStream: vi.fn(),
|
sendMessageStream: vi.fn(),
|
||||||
addHistory: vi.fn(),
|
addHistory: vi.fn(),
|
||||||
|
recordCompletedToolCalls: vi.fn(),
|
||||||
} as unknown as Mocked<GeminiChat>;
|
} as unknown as Mocked<GeminiChat>;
|
||||||
mockTool = {
|
mockTool = {
|
||||||
kind: 'native',
|
kind: 'native',
|
||||||
@@ -293,6 +294,7 @@ describe('Session', () => {
|
|||||||
} as unknown as Mocked<MessageBus>;
|
} as unknown as Mocked<MessageBus>;
|
||||||
mockConfig = {
|
mockConfig = {
|
||||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||||
|
getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||||
getPreviewFeatures: vi.fn().mockReturnValue({}),
|
getPreviewFeatures: vi.fn().mockReturnValue({}),
|
||||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||||
getFileService: vi.fn().mockReturnValue({
|
getFileService: vi.fn().mockReturnValue({
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import type {
|
|||||||
ToolResult,
|
ToolResult,
|
||||||
ToolCallConfirmationDetails,
|
ToolCallConfirmationDetails,
|
||||||
FilterFilesOptions,
|
FilterFilesOptions,
|
||||||
|
ConversationRecord,
|
||||||
} from '@google/gemini-cli-core';
|
} from '@google/gemini-cli-core';
|
||||||
import {
|
import {
|
||||||
AuthType,
|
AuthType,
|
||||||
@@ -32,6 +33,7 @@ import {
|
|||||||
createWorkingStdio,
|
createWorkingStdio,
|
||||||
startupProfiler,
|
startupProfiler,
|
||||||
Kind,
|
Kind,
|
||||||
|
partListUnionToString,
|
||||||
} from '@google/gemini-cli-core';
|
} from '@google/gemini-cli-core';
|
||||||
import * as acp from '@agentclientprotocol/sdk';
|
import * as acp from '@agentclientprotocol/sdk';
|
||||||
import { AcpFileSystemService } from './fileSystemService.js';
|
import { AcpFileSystemService } from './fileSystemService.js';
|
||||||
@@ -47,6 +49,10 @@ import { randomUUID } from 'node:crypto';
|
|||||||
import type { CliArgs } from '../config/config.js';
|
import type { CliArgs } from '../config/config.js';
|
||||||
import { loadCliConfig } from '../config/config.js';
|
import { loadCliConfig } from '../config/config.js';
|
||||||
import { runExitCleanup } from '../utils/cleanup.js';
|
import { runExitCleanup } from '../utils/cleanup.js';
|
||||||
|
import {
|
||||||
|
SessionSelector,
|
||||||
|
convertSessionToHistoryFormats,
|
||||||
|
} from '../utils/sessionUtils.js';
|
||||||
|
|
||||||
export async function runZedIntegration(
|
export async function runZedIntegration(
|
||||||
config: Config,
|
config: Config,
|
||||||
@@ -107,7 +113,7 @@ export class GeminiAgent {
|
|||||||
protocolVersion: acp.PROTOCOL_VERSION,
|
protocolVersion: acp.PROTOCOL_VERSION,
|
||||||
authMethods,
|
authMethods,
|
||||||
agentCapabilities: {
|
agentCapabilities: {
|
||||||
loadSession: false,
|
loadSession: true,
|
||||||
promptCapabilities: {
|
promptCapabilities: {
|
||||||
image: true,
|
image: true,
|
||||||
audio: true,
|
audio: true,
|
||||||
@@ -146,23 +152,11 @@ export class GeminiAgent {
|
|||||||
mcpServers,
|
mcpServers,
|
||||||
}: acp.NewSessionRequest): Promise<acp.NewSessionResponse> {
|
}: acp.NewSessionRequest): Promise<acp.NewSessionResponse> {
|
||||||
const sessionId = randomUUID();
|
const sessionId = randomUUID();
|
||||||
const config = await this.newSessionConfig(sessionId, cwd, mcpServers);
|
const config = await this.initializeSessionConfig(
|
||||||
|
sessionId,
|
||||||
let isAuthenticated = false;
|
cwd,
|
||||||
if (this.settings.merged.security.auth.selectedType) {
|
mcpServers,
|
||||||
try {
|
);
|
||||||
await config.refreshAuth(
|
|
||||||
this.settings.merged.security.auth.selectedType,
|
|
||||||
);
|
|
||||||
isAuthenticated = true;
|
|
||||||
} catch (e) {
|
|
||||||
debugLogger.error(`Authentication failed: ${e}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isAuthenticated) {
|
|
||||||
throw acp.RequestError.authRequired();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.clientCapabilities?.fs) {
|
if (this.clientCapabilities?.fs) {
|
||||||
const acpFileSystemService = new AcpFileSystemService(
|
const acpFileSystemService = new AcpFileSystemService(
|
||||||
@@ -184,6 +178,88 @@ export class GeminiAgent {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async loadSession({
|
||||||
|
sessionId,
|
||||||
|
cwd,
|
||||||
|
mcpServers,
|
||||||
|
}: acp.LoadSessionRequest): Promise<acp.LoadSessionResponse> {
|
||||||
|
const config = await this.initializeSessionConfig(
|
||||||
|
sessionId,
|
||||||
|
cwd,
|
||||||
|
mcpServers,
|
||||||
|
);
|
||||||
|
|
||||||
|
const sessionSelector = new SessionSelector(config);
|
||||||
|
const { sessionData, sessionPath } =
|
||||||
|
await sessionSelector.resolveSession(sessionId);
|
||||||
|
|
||||||
|
if (this.clientCapabilities?.fs) {
|
||||||
|
const acpFileSystemService = new AcpFileSystemService(
|
||||||
|
this.connection,
|
||||||
|
sessionId,
|
||||||
|
this.clientCapabilities.fs,
|
||||||
|
config.getFileSystemService(),
|
||||||
|
);
|
||||||
|
config.setFileSystemService(acpFileSystemService);
|
||||||
|
}
|
||||||
|
|
||||||
|
const { clientHistory } = convertSessionToHistoryFormats(
|
||||||
|
sessionData.messages,
|
||||||
|
);
|
||||||
|
|
||||||
|
const geminiClient = config.getGeminiClient();
|
||||||
|
await geminiClient.initialize();
|
||||||
|
await geminiClient.resumeChat(clientHistory, {
|
||||||
|
conversation: sessionData,
|
||||||
|
filePath: sessionPath,
|
||||||
|
});
|
||||||
|
|
||||||
|
const session = new Session(
|
||||||
|
sessionId,
|
||||||
|
geminiClient.getChat(),
|
||||||
|
config,
|
||||||
|
this.connection,
|
||||||
|
);
|
||||||
|
this.sessions.set(sessionId, session);
|
||||||
|
|
||||||
|
// Stream history back to client
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||||
|
session.streamHistory(sessionData.messages);
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
private async initializeSessionConfig(
|
||||||
|
sessionId: string,
|
||||||
|
cwd: string,
|
||||||
|
mcpServers: acp.McpServer[],
|
||||||
|
): Promise<Config> {
|
||||||
|
const selectedAuthType = this.settings.merged.security.auth.selectedType;
|
||||||
|
if (!selectedAuthType) {
|
||||||
|
throw acp.RequestError.authRequired();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. Create config WITHOUT initializing it (no MCP servers started yet)
|
||||||
|
const config = await this.newSessionConfig(sessionId, cwd, mcpServers);
|
||||||
|
|
||||||
|
// 2. Authenticate BEFORE initializing configuration or starting MCP servers.
|
||||||
|
// This satisfies the security requirement to verify the user before executing
|
||||||
|
// potentially unsafe server definitions.
|
||||||
|
try {
|
||||||
|
await config.refreshAuth(selectedAuthType);
|
||||||
|
} catch (e) {
|
||||||
|
debugLogger.error(`Authentication failed: ${e}`);
|
||||||
|
throw acp.RequestError.authRequired();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Now that we are authenticated, it is safe to initialize the config
|
||||||
|
// which starts the MCP servers and other heavy resources.
|
||||||
|
await config.initialize();
|
||||||
|
startupProfiler.flush(config);
|
||||||
|
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
async newSessionConfig(
|
async newSessionConfig(
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
cwd: string,
|
cwd: string,
|
||||||
@@ -228,8 +304,6 @@ export class GeminiAgent {
|
|||||||
|
|
||||||
const config = await loadCliConfig(settings, sessionId, this.argv, { cwd });
|
const config = await loadCliConfig(settings, sessionId, this.argv, { cwd });
|
||||||
|
|
||||||
await config.initialize();
|
|
||||||
startupProfiler.flush(config);
|
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,6 +343,73 @@ export class Session {
|
|||||||
this.pendingPrompt = null;
|
this.pendingPrompt = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async streamHistory(messages: ConversationRecord['messages']): Promise<void> {
|
||||||
|
for (const msg of messages) {
|
||||||
|
const contentString = partListUnionToString(msg.content);
|
||||||
|
|
||||||
|
if (msg.type === 'user') {
|
||||||
|
if (contentString.trim()) {
|
||||||
|
await this.sendUpdate({
|
||||||
|
sessionUpdate: 'user_message_chunk',
|
||||||
|
content: { type: 'text', text: contentString },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else if (msg.type === 'gemini') {
|
||||||
|
// Thoughts
|
||||||
|
if (msg.thoughts) {
|
||||||
|
for (const thought of msg.thoughts) {
|
||||||
|
const thoughtText = `**${thought.subject}**\n${thought.description}`;
|
||||||
|
await this.sendUpdate({
|
||||||
|
sessionUpdate: 'agent_thought_chunk',
|
||||||
|
content: { type: 'text', text: thoughtText },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message text
|
||||||
|
if (contentString.trim()) {
|
||||||
|
await this.sendUpdate({
|
||||||
|
sessionUpdate: 'agent_message_chunk',
|
||||||
|
content: { type: 'text', text: contentString },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool calls
|
||||||
|
if (msg.toolCalls) {
|
||||||
|
for (const toolCall of msg.toolCalls) {
|
||||||
|
const toolCallContent: acp.ToolCallContent[] = [];
|
||||||
|
if (toolCall.resultDisplay) {
|
||||||
|
if (typeof toolCall.resultDisplay === 'string') {
|
||||||
|
toolCallContent.push({
|
||||||
|
type: 'content',
|
||||||
|
content: { type: 'text', text: toolCall.resultDisplay },
|
||||||
|
});
|
||||||
|
} else if ('fileName' in toolCall.resultDisplay) {
|
||||||
|
toolCallContent.push({
|
||||||
|
type: 'diff',
|
||||||
|
path: toolCall.resultDisplay.fileName,
|
||||||
|
oldText: toolCall.resultDisplay.originalContent,
|
||||||
|
newText: toolCall.resultDisplay.newContent,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const tool = this.config.getToolRegistry().getTool(toolCall.name);
|
||||||
|
|
||||||
|
await this.sendUpdate({
|
||||||
|
sessionUpdate: 'tool_call',
|
||||||
|
toolCallId: toolCall.id,
|
||||||
|
status: toolCall.status === 'success' ? 'completed' : 'failed',
|
||||||
|
title: toolCall.displayName || toolCall.name,
|
||||||
|
content: toolCallContent,
|
||||||
|
kind: tool ? toAcpToolKind(tool.kind) : 'other',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async prompt(params: acp.PromptRequest): Promise<acp.PromptResponse> {
|
async prompt(params: acp.PromptRequest): Promise<acp.PromptResponse> {
|
||||||
this.pendingPrompt?.abort();
|
this.pendingPrompt?.abort();
|
||||||
const pendingSend = new AbortController();
|
const pendingSend = new AbortController();
|
||||||
@@ -533,6 +674,33 @@ export class Session {
|
|||||||
),
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
this.chat.recordCompletedToolCalls(this.config.getActiveModel(), [
|
||||||
|
{
|
||||||
|
status: 'success',
|
||||||
|
request: {
|
||||||
|
callId,
|
||||||
|
name: fc.name,
|
||||||
|
args,
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: promptId,
|
||||||
|
},
|
||||||
|
tool,
|
||||||
|
invocation,
|
||||||
|
response: {
|
||||||
|
callId,
|
||||||
|
responseParts: convertToFunctionResponse(
|
||||||
|
fc.name,
|
||||||
|
callId,
|
||||||
|
toolResult.llmContent,
|
||||||
|
this.config.getActiveModel(),
|
||||||
|
),
|
||||||
|
resultDisplay: toolResult.returnDisplay,
|
||||||
|
error: undefined,
|
||||||
|
errorType: undefined,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
return convertToFunctionResponse(
|
return convertToFunctionResponse(
|
||||||
fc.name,
|
fc.name,
|
||||||
callId,
|
callId,
|
||||||
@@ -551,6 +719,35 @@ export class Session {
|
|||||||
],
|
],
|
||||||
});
|
});
|
||||||
|
|
||||||
|
this.chat.recordCompletedToolCalls(this.config.getActiveModel(), [
|
||||||
|
{
|
||||||
|
status: 'error',
|
||||||
|
request: {
|
||||||
|
callId,
|
||||||
|
name: fc.name,
|
||||||
|
args,
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: promptId,
|
||||||
|
},
|
||||||
|
tool,
|
||||||
|
response: {
|
||||||
|
callId,
|
||||||
|
responseParts: [
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
id: callId,
|
||||||
|
name: fc.name ?? '',
|
||||||
|
response: { error: error.message },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
resultDisplay: error.message,
|
||||||
|
error,
|
||||||
|
errorType: undefined,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
return errorResponse(error);
|
return errorResponse(error);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user