diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 1838032e95..544908bd35 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -810,6 +810,102 @@ describe('mcp-client', () => { }); }); + it('should transform nullable array schemas and preserve properties during discovery', async () => { + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + setNotificationHandler: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'nullableTool', + description: 'Tool with nullable array', + inputSchema: { + type: 'object', + properties: { + tags: { + type: ['array', 'null'], + items: { type: 'string' }, + }, + }, + $defs: { + SomeType: { type: 'string' }, + }, + }, + }, + ], + }), + listPrompts: vi.fn().mockResolvedValue({ + prompts: [], + }), + request: vi.fn().mockResolvedValue({}), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + const mockedToolRegistry = { + registerTool: vi.fn(), + sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + const promptRegistry = { + registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), + removePromptsByServer: vi.fn(), + } as unknown as PromptRegistry; + const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), + setResourcesForServer: vi.fn(), + removeResourcesByServer: vi.fn(), + } as unknown as ResourceRegistry; + const client = new McpClient( + 'test-server', + { + command: 'test-command', + }, + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + await client.connect(); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); + expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); + const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock + .calls[0][0]; + expect(registeredTool.schema.parametersJsonSchema).toEqual({ + type: 'object', + properties: { + tags: { + type: 'array', + nullable: true, + items: { type: 'string' }, + }, + wait_for_previous: { + type: 'boolean', + description: + 'Set to true to wait for all previously requested tools in this turn to complete before starting. Set to false (or omit) to run in parallel. Use true when this tool depends on the output of previous tools.', + }, + }, + $defs: { + SomeType: { type: 'string' }, + }, + }); + }); + it('should discover resources when a server only exposes resources', async () => { const mockedClient = { connect: vi.fn(), diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 3cadad99be..1bf73c2b27 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -1308,6 +1308,46 @@ export async function discoverTools( continue; } + if (toolDef.inputSchema) { + try { + const transform = (obj: unknown): unknown => { + if (obj === null || typeof obj !== 'object') return obj; + if (Array.isArray(obj)) return obj.map(transform); + + const res = { ...obj } as Record; + + if (Array.isArray(res['type']) && res['type'].length === 2) { + const nIdx = res['type'].indexOf('null'); + if (nIdx !== -1 && typeof res['type'][1 - nIdx] === 'string') { + res['type'] = res['type'][1 - nIdx]; + res['nullable'] = true; + } + } + + for (const k in res) { + if (Object.prototype.hasOwnProperty.call(res, k)) { + res[k] = transform(res[k]); + } + } + return res; + }; + + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + toolDef.inputSchema = transform(toolDef.inputSchema) as { + type: 'object'; + properties?: Record; + required?: string[]; + }; + } catch (error) { + cliConfig.emitMcpDiagnostic( + 'error', + `Failed to parse adjusted inputSchema for tool '${toolDef.name}' from server '${mcpServerName}'. Using original schema. Error: ${error instanceof Error ? error.message : String(error)}`, + error, + mcpServerName, + ); + } + } + const mcpCallableTool = new McpCallableTool( mcpClient, toolDef,