diff --git a/packages/core/src/agents/agent-tool.test.ts b/packages/core/src/agents/agent-tool.test.ts index 424f1c6bd9..09c6cdef11 100644 --- a/packages/core/src/agents/agent-tool.test.ts +++ b/packages/core/src/agents/agent-tool.test.ts @@ -12,6 +12,8 @@ import type { Config } from '../config/config.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { LocalSubagentInvocation } from './local-invocation.js'; import { RemoteAgentInvocation } from './remote-invocation.js'; +import { LocalSessionInvocation } from './local-session-invocation.js'; +import { RemoteSessionInvocation } from './remote-session-invocation.js'; import { BrowserAgentInvocation } from './browser/browserAgentInvocation.js'; import { BROWSER_AGENT_NAME } from './browser/browserAgentDefinition.js'; import { AgentRegistry } from './registry.js'; @@ -19,6 +21,8 @@ import type { LocalAgentDefinition, RemoteAgentDefinition } from './types.js'; vi.mock('./local-invocation.js'); vi.mock('./remote-invocation.js'); +vi.mock('./local-session-invocation.js'); +vi.mock('./remote-session-invocation.js'); vi.mock('./browser/browserAgentInvocation.js'); describe('AgentTool', () => { @@ -141,4 +145,112 @@ describe('AgentTool', () => { 'Invoke Browser Agent', ); }); + + describe('agentSessionSubagentEnabled feature flag', () => { + it('should use LocalSessionInvocation when flag is enabled for local agent', async () => { + vi.spyOn(mockConfig, 'isAgentSessionEnabled').mockReturnValue(true); + tool = new AgentTool(mockConfig, mockMessageBus); + + const params = { + agent_name: 'TestLocalAgent', + prompt: 'Do something', + }; + const invocation = tool['createInvocation'](params, mockMessageBus); + await invocation.shouldConfirmExecute(new AbortController().signal); + + expect(LocalSessionInvocation).toHaveBeenCalledWith( + testLocalDefinition, + mockConfig, + { objective: 'Do something' }, + mockMessageBus, + undefined, + ); + expect(LocalSubagentInvocation).not.toHaveBeenCalled(); + }); + + it('should use RemoteSessionInvocation when flag is enabled for remote agent', async () => { + vi.spyOn(mockConfig, 'isAgentSessionEnabled').mockReturnValue(true); + tool = new AgentTool(mockConfig, mockMessageBus); + + const params = { + agent_name: 'TestRemoteAgent', + prompt: 'Search something', + }; + const invocation = tool['createInvocation'](params, mockMessageBus); + await invocation.shouldConfirmExecute(new AbortController().signal); + + expect(RemoteSessionInvocation).toHaveBeenCalledWith( + testRemoteDefinition, + mockConfig, + { query: 'Search something' }, + mockMessageBus, + undefined, + ); + expect(RemoteAgentInvocation).not.toHaveBeenCalled(); + }); + + it('should use legacy invocations when flag is disabled (default)', async () => { + vi.spyOn(mockConfig, 'isAgentSessionEnabled').mockReturnValue(false); + tool = new AgentTool(mockConfig, mockMessageBus); + + const localParams = { + agent_name: 'TestLocalAgent', + prompt: 'Do something', + }; + const localInv = tool['createInvocation'](localParams, mockMessageBus); + await localInv.shouldConfirmExecute(new AbortController().signal); + + expect(LocalSubagentInvocation).toHaveBeenCalled(); + expect(LocalSessionInvocation).not.toHaveBeenCalled(); + + vi.clearAllMocks(); + + const remoteParams = { + agent_name: 'TestRemoteAgent', + prompt: 'Search', + }; + const remoteInv = tool['createInvocation'](remoteParams, mockMessageBus); + await remoteInv.shouldConfirmExecute(new AbortController().signal); + + expect(RemoteAgentInvocation).toHaveBeenCalled(); + expect(RemoteSessionInvocation).not.toHaveBeenCalled(); + }); + + it('should thread onAgentEvent to session invocations', async () => { + vi.spyOn(mockConfig, 'isAgentSessionEnabled').mockReturnValue(true); + const onEvent = vi.fn(); + tool = new AgentTool(mockConfig, mockMessageBus, onEvent); + + const params = { + agent_name: 'TestLocalAgent', + prompt: 'Do something', + }; + const invocation = tool['createInvocation'](params, mockMessageBus); + await invocation.shouldConfirmExecute(new AbortController().signal); + + expect(LocalSessionInvocation).toHaveBeenCalledWith( + testLocalDefinition, + mockConfig, + { objective: 'Do something' }, + mockMessageBus, + { onAgentEvent: onEvent }, + ); + }); + + it('should always use BrowserAgentInvocation for browser agent regardless of flag', async () => { + vi.spyOn(mockConfig, 'isAgentSessionEnabled').mockReturnValue(true); + tool = new AgentTool(mockConfig, mockMessageBus); + + const params = { + agent_name: BROWSER_AGENT_NAME, + prompt: 'Open page', + }; + const invocation = tool['createInvocation'](params, mockMessageBus); + await invocation.shouldConfirmExecute(new AbortController().signal); + + expect(BrowserAgentInvocation).toHaveBeenCalled(); + expect(LocalSessionInvocation).not.toHaveBeenCalled(); + expect(RemoteSessionInvocation).not.toHaveBeenCalled(); + }); + }); }); diff --git a/packages/core/src/agents/agent-tool.ts b/packages/core/src/agents/agent-tool.ts index 899266f77f..2258bfb43a 100644 --- a/packages/core/src/agents/agent-tool.ts +++ b/packages/core/src/agents/agent-tool.ts @@ -18,8 +18,11 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { AgentDefinition, AgentInputs } from './types.js'; import { LocalSubagentInvocation } from './local-invocation.js'; import { RemoteAgentInvocation } from './remote-invocation.js'; +import { LocalSessionInvocation } from './local-session-invocation.js'; +import { RemoteSessionInvocation } from './remote-session-invocation.js'; import { BROWSER_AGENT_NAME } from './browser/browserAgentDefinition.js'; import { BrowserAgentInvocation } from './browser/browserAgentInvocation.js'; +import type { AgentEvent } from '../agent/types.js'; import { formatUserHintsForModel } from '../utils/fastAckHelper.js'; import { isRecord } from '../utils/markdownUtils.js'; import { runInDevTraceSpan } from '../telemetry/trace.js'; @@ -46,6 +49,7 @@ export class AgentTool extends BaseDeclarativeTool< constructor( private readonly context: AgentLoopContext, messageBus: MessageBus, + private readonly onAgentEvent?: (event: AgentEvent) => void, ) { super( AGENT_TOOL_NAME, @@ -100,6 +104,7 @@ export class AgentTool extends BaseDeclarativeTool< this.context, _toolName, _toolDisplayName, + this.onAgentEvent, ); } @@ -133,6 +138,7 @@ class DelegateInvocation extends BaseToolInvocation< private readonly context: AgentLoopContext, _toolName?: string, _toolDisplayName?: string, + private readonly onAgentEvent?: (event: AgentEvent) => void, ) { super( params, @@ -160,7 +166,21 @@ class DelegateInvocation extends BaseToolInvocation< ); } + const useSession = this.context.config.isAgentSessionEnabled(); + const options = this.onAgentEvent + ? { onAgentEvent: this.onAgentEvent } + : undefined; + if (this.definition.kind === 'remote') { + if (useSession) { + return new RemoteSessionInvocation( + this.definition, + this.context, + agentArgs, + this.messageBus, + options, + ); + } return new RemoteAgentInvocation( this.definition, this.context, @@ -168,6 +188,15 @@ class DelegateInvocation extends BaseToolInvocation< this.messageBus, ); } else { + if (useSession) { + return new LocalSessionInvocation( + this.definition, + this.context, + agentArgs, + this.messageBus, + options, + ); + } return new LocalSubagentInvocation( this.definition, this.context, diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index fdd6284d78..aa912533bc 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -2585,6 +2585,10 @@ export class Config implements McpContext, AgentLoopContext { return this.experimentalContextManagementConfig; } + isAgentSessionEnabled(): boolean { + return this.agentSessionSubagentEnabled; + } + getContextManagementConfig(): ContextManagementConfig { return this.contextManagement; } @@ -3780,6 +3784,10 @@ export class Config implements McpContext, AgentLoopContext { ); } + getAgentSessionSubagentEnabled(): boolean { + return this.agentSessionSubagentEnabled; + } + /** * Get override settings for a specific agent. * Reads from agents.overrides..