diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 4b1501380b..3e68b7ed70 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -493,6 +493,7 @@ dependencies = [ "bytes", "clap", "codex-apply-patch", + "codex-mcp-client", "dirs", "env-flags", "eventsource-stream", @@ -500,6 +501,7 @@ dependencies = [ "futures", "landlock", "libc", + "mcp-types", "mime_guess", "openssl-sys", "patch", diff --git a/codex-rs/README.md b/codex-rs/README.md index 3c42ceff4a..f5a1e24de2 100644 --- a/codex-rs/README.md +++ b/codex-rs/README.md @@ -79,6 +79,38 @@ sandbox_permissions = [ ] ``` +### mcp_servers + +Defines the list of MCP servers that Codex can consult for tool use. Currently, only servers that are launched by executing a program that communicates over stdio are supported. For servers that use the SSE transport, consider an adapter like [mcp-proxy](https://github.com/sparfenyuk/mcp-proxy). + +**Note:** Codex may cache the list of tools and resources from an MCP server so that Codex can include this information in context at startup without spawning all the servers. This is designed to save resources by loading MCP servers lazily. + +This config option is comparable to how Claude and Cursor define `mcpServers` in their respective JSON config files, though because Codex uses TOML for its config language, the format is slightly different. For example, the following config in JSON: + +```json +{ + "mcpServers": { + "server-name": { + "command": "npx", + "args": ["-y", "mcp-server"], + "env": { + "API_KEY": "value" + } + } + } +} +``` + +Should be represented as follows in `~/.codex/config.toml`: + +```toml +# IMPORTANT: the top-level key is `mcp_servers` rather than `mcpServers`. 
+[mcp_servers.server-name] +command = "npx" +args = ["-y", "mcp-server"] +env = { "API_KEY" = "value" } +``` + ### disable_response_storage Currently, customers whose accounts are set to use Zero Data Retention (ZDR) must set `disable_response_storage` to `true` so that Codex uses an alternative to the Responses API that works with ZDR: diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 0ed550f9a8..abd0e607ec 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -14,11 +14,13 @@ base64 = "0.21" bytes = "1.10.1" clap = { version = "4", features = ["derive", "wrap_help"], optional = true } codex-apply-patch = { path = "../apply-patch" } +codex-mcp-client = { path = "../mcp-client" } dirs = "6" env-flags = "0.1.1" eventsource-stream = "0.2.3" fs-err = "3.1.0" futures = "0.3" +mcp-types = { path = "../mcp-types" } mime_guess = "2.0" patch = "0.7" path-absolutize = "3.1.1" diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 10ec0b9780..e0d2892c94 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -1,4 +1,5 @@ use std::collections::BTreeMap; +use std::collections::HashMap; use std::io::BufRead; use std::path::Path; use std::pin::Pin; @@ -13,6 +14,7 @@ use futures::prelude::*; use reqwest::StatusCode; use serde::Deserialize; use serde::Serialize; +use serde_json::json; use serde_json::Value; use tokio::sync::mpsc; use tokio::time::timeout; @@ -42,6 +44,11 @@ pub struct Prompt { pub instructions: Option, /// Whether to store response on server side (disable_response_storage = !store). pub store: bool, + + /// Additional tools sourced from external MCP servers. Note each key is + /// the "fully qualified" tool name (i.e., prefixed with the server name), + /// which should be reported to the model in place of Tool::name. 
+ pub extra_tools: HashMap, } #[derive(Debug)] @@ -59,7 +66,7 @@ struct Payload<'a> { // we code defensively to avoid this case, but perhaps we should use a // separate enum for serialization. input: &'a Vec, - tools: &'a [Tool], + tools: &'a [serde_json::Value], tool_choice: &'static str, parallel_tool_calls: bool, reasoning: Option, @@ -78,7 +85,7 @@ struct Reasoning { } #[derive(Debug, Serialize)] -struct Tool { +struct ToolInternal { name: &'static str, #[serde(rename = "type")] kind: &'static str, // "function" @@ -105,7 +112,7 @@ enum JsonSchema { } /// Tool usage specification -static TOOLS: LazyLock> = LazyLock::new(|| { +static TOOLS_INTERNAL: LazyLock> = LazyLock::new(|| { let mut properties = BTreeMap::new(); properties.insert( "command".to_string(), @@ -116,7 +123,7 @@ static TOOLS: LazyLock> = LazyLock::new(|| { properties.insert("workdir".to_string(), JsonSchema::String); properties.insert("timeout".to_string(), JsonSchema::Number); - vec![Tool { + vec![ToolInternal { name: "shell", kind: "function", description: "Runs a shell command, and returns its output.", @@ -149,11 +156,26 @@ impl ModelClient { return stream_from_fixture(path).await; } + // Assemble tool list: built-in tools + any extra tools from the prompt. 
+ let mut tools_json: Vec = TOOLS_INTERNAL + .iter() + .map(|t| serde_json::to_value(t).expect("serialize builtin tool")) + .collect(); + tools_json.extend( + prompt + .extra_tools + .clone() + .into_iter() + .map(|(name, tool)| mcp_tool_to_openai_tool(name, tool)), + ); + + debug!("tools_json: {}", serde_json::to_string_pretty(&tools_json)?); + let payload = Payload { model: &self.model, instructions: prompt.instructions.as_ref(), input: &prompt.input, - tools: &TOOLS, + tools: &tools_json, tool_choice: "auto", parallel_tool_calls: false, reasoning: Some(Reasoning { @@ -235,6 +257,18 @@ impl ModelClient { } } +fn mcp_tool_to_openai_tool( + fully_qualified_name: String, + tool: mcp_types::Tool, +) -> serde_json::Value { + json!({ + "name": fully_qualified_name, + "description": tool.description, + "parameters": tool.input_schema, + "type": "function", + }) +} + #[derive(Debug, Deserialize, Serialize)] struct SseEvent { #[serde(rename = "type")] diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index c74d0079ee..b1c4c42f29 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -38,6 +38,10 @@ use crate::exec::ExecParams; use crate::exec::ExecToolCallOutput; use crate::exec::SandboxType; use crate::flags::OPENAI_STREAM_MAX_RETRIES; +use crate::mcp_connection_manager::create_mcp_connection_manager; +use crate::mcp_connection_manager::try_parse_fully_qualified_tool_name; +use crate::mcp_connection_manager::McpConnectionManager; +use crate::mcp_tool_call::handle_mcp_tool_call; use crate::models::ContentItem; use crate::models::FunctionCallOutputPayload; use crate::models::ResponseInputItem; @@ -188,9 +192,9 @@ impl Recorder { /// Context for an initialized model agent /// /// A session has at most 1 running task at a time, and can be interrupted by user input. 
-struct Session { +pub(crate) struct Session { client: ModelClient, - tx_event: Sender, + pub(crate) tx_event: Sender, ctrl_c: Arc, /// The session's current working directory. All relative paths provided by @@ -202,6 +206,9 @@ struct Session { sandbox_policy: SandboxPolicy, writable_roots: Mutex>, + /// Manager for external MCP servers/tools. + pub(crate) mcp: crate::mcp_connection_manager::McpConnectionManager, + /// External notifier command (will be passed as args to exec()). When /// `None` this feature is disabled. notify: Option>, @@ -433,7 +440,7 @@ impl State { } /// A series of Turns in response to user input. -struct AgentTask { +pub(crate) struct AgentTask { sess: Arc, sub_id: String, handle: AbortHandle, @@ -554,6 +561,30 @@ async fn submission_loop( }; let writable_roots = Mutex::new(get_writable_roots(&cwd)); + + // Load config to initialise the MCP connection manager. + let config = match crate::config::Config::load_with_overrides( + crate::config::ConfigOverrides::default(), + ) { + Ok(cfg) => cfg, + Err(e) => { + error!("Failed to load config for MCP servers: {e:#}"); + // Fall back to empty server map so the session can still proceed. + crate::config::Config::load_default_config_for_test() + } + }; + + let mcp = match create_mcp_connection_manager(config.mcp_servers.clone()).await { + Ok(mgr) => mgr, + Err(e) => { + error!("Failed to create MCP connection manager: {e:#}"); + // Use an empty manager so we can still continue. 
+ McpConnectionManager::new(HashMap::new()) + .await + .expect("empty manager should never fail") + } + }; + sess = Some(Arc::new(Session { client, tx_event: tx_event.clone(), @@ -565,6 +596,7 @@ async fn submission_loop( writable_roots, notify, state: Mutex::new(state), + mcp, })); // ack @@ -753,11 +785,15 @@ async fn run_turn( } else { None }; + + let extra_tools = sess.mcp.list_all_tools(); + let prompt = Prompt { input, prev_id, instructions, store, + extra_tools, }; let mut retries = 0; @@ -1141,13 +1177,20 @@ async fn handle_function_call( } } _ => { - // Unknown function: reply with structured failure so the model can adapt. - ResponseInputItem::FunctionCallOutput { - call_id, - output: crate::models::FunctionCallOutputPayload { - content: format!("unsupported call: {}", name), - success: None, - }, + match try_parse_fully_qualified_tool_name(&name) { + Some((server, tool_name)) => { + handle_mcp_tool_call(sess, &sub_id, call_id, server, tool_name, arguments).await + } + None => { + // Unknown function: reply with structured failure so the model can adapt. + ResponseInputItem::FunctionCallOutput { + call_id, + output: crate::models::FunctionCallOutputPayload { + content: format!("unsupported call: {}", name), + success: None, + }, + } + } } } } diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 554173c537..f3140e0e9f 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -1,9 +1,11 @@ use crate::flags::OPENAI_DEFAULT_MODEL; +use crate::mcp_server_config::McpServerConfig; use crate::protocol::AskForApproval; use crate::protocol::SandboxPermission; use crate::protocol::SandboxPolicy; use dirs::home_dir; use serde::Deserialize; +use std::collections::HashMap; use std::path::PathBuf; /// Embedded fallback instructions that mirror the TypeScript CLI’s default @@ -56,6 +58,9 @@ pub struct Config { /// for the session. All relative paths inside the business-logic layer are /// resolved against this path. 
pub cwd: PathBuf, + + /// Definition for MCP servers that Codex can reach out to for tool calls. + pub mcp_servers: HashMap, } /// Base config deserialized from ~/.codex/config.toml. @@ -84,6 +89,10 @@ pub struct ConfigToml { /// System instructions. pub instructions: Option, + + /// Definition for MCP servers that Codex can reach out to for tool calls. + #[serde(default)] + pub mcp_servers: HashMap, } impl ConfigToml { @@ -212,6 +221,7 @@ impl Config { .unwrap_or(false), notify: cfg.notify, instructions, + mcp_servers: cfg.mcp_servers, } } diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index a5909ed63d..3878fada0d 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -15,6 +15,9 @@ mod flags; mod is_safe_command; #[cfg(target_os = "linux")] pub mod linux; +mod mcp_connection_manager; +pub mod mcp_server_config; +mod mcp_tool_call; mod models; pub mod protocol; mod safety; diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs new file mode 100644 index 0000000000..8f13952e40 --- /dev/null +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -0,0 +1,192 @@ +//! Connection manager for Model Context Protocol (MCP) servers. +//! +//! The [`McpConnectionManager`] owns one [`codex_mcp_client::McpClient`] per +//! configured server (keyed by the *server name*). It offers convenience +//! helpers to query the available tools across *all* servers and returns them +//! in a single aggregated map using the fully-qualified tool name +//! `""` as the key. + +use std::collections::HashMap; + +use anyhow::anyhow; +use anyhow::Result; +use codex_mcp_client::McpClient; +use mcp_types::Tool; +use tokio::task::JoinSet; +use tracing::info; +use tracing::warn; + +use crate::mcp_server_config::McpServerConfig; + +/// Delimiter used to separate the server name from the tool name in a fully +/// qualified tool name. 
+/// +/// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must +/// choose a delimiter from this character set. +const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__"; + +fn fully_qualified_tool_name(server: &str, tool: &str) -> String { + format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}") +} + +pub(crate) fn try_parse_fully_qualified_tool_name(fq_name: &str) -> Option<(String, String)> { + let (server, tool) = fq_name.split_once(MCP_TOOL_NAME_DELIMITER)?; + if server.is_empty() || tool.is_empty() { + return None; + } + Some((server.to_string(), tool.to_string())) +} + +/// A thin wrapper around a set of running [`McpClient`] instances. +/// +/// The struct is intentionally lightweight – cloning just clones the internal +/// `HashMap` of clients which in turn clones the `Arc`s of each client. +pub(crate) struct McpConnectionManager { + /// Server-name → client instance. + /// + /// The server name originates from the keys of the `mcp_servers` map in + /// the user configuration. + clients: HashMap>, // Arc to cheaply clone + + tools: HashMap, +} + +impl McpConnectionManager { + /// Spawn a [`McpClient`] for each configured server. + /// + /// * `mcp_servers` – Map loaded from the user configuration where *keys* + /// are human-readable server identifiers and *values* are the spawn + /// instructions. + pub async fn new(mcp_servers: HashMap) -> Result { + // Early exit if no servers are configured. + if mcp_servers.is_empty() { + return Ok(Self { + clients: HashMap::new(), + tools: HashMap::new(), + }); + } + + // Spin up all servers concurrently. + let mut join_set = JoinSet::new(); + + // Spawn tasks to launch each server. + for (server_name, cfg) in mcp_servers { + // Perform slash validation up-front so we can return early without + // spawning any tasks when the name is invalid. 
+ if server_name.contains('/') { + return Err(anyhow!( + "MCP server name '{server_name}' must not contain a forward slash (/)" + )); + } + + join_set.spawn(async move { + let McpServerConfig { command, args, env } = cfg; + let client_res = McpClient::new_stdio_client(command, args, env).await; + + (server_name, client_res) + }); + } + + let mut clients: HashMap> = HashMap::new(); + while let Some(res) = join_set.join_next().await { + let (server_name, client_res) = res?; // propagate JoinError + + let client = client_res + .map_err(|e| anyhow!("failed to spawn MCP server '{server_name}': {e}"))?; + + clients.insert(server_name, std::sync::Arc::new(client)); + } + + let tools = list_all_tools(&clients).await?; + + Ok(Self { clients, tools }) + } + + /// Returns a single map that contains **all** tools. Each key is the + /// fully-qualified name for the tool. + pub fn list_all_tools(&self) -> HashMap { + self.tools.clone() + } + + /// Route a fully-qualified tool call to the matching server. + pub async fn call_tool( + &self, + server: &str, + tool: &str, + arguments: Option, + ) -> Result { + let client = self + .clients + .get(server) + .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))? + .clone(); + + client + .call_tool(tool.to_string(), arguments) + .await + .map_err(|e| anyhow!("tool call failed for '{server}/{tool}': {e}")) + } +} + +/// Query every server for its available tools and return a single map that +/// contains **all** tools. Each key is the fully-qualified name for the tool. +pub async fn list_all_tools( + clients: &HashMap>, +) -> Result> { + let mut join_set = JoinSet::new(); + + // Spawn one task per server so we can query them concurrently. This + // keeps the overall latency roughly at the slowest server instead of + // the cumulative latency. 
+ for (server_name, client) in clients { + let server_name_cloned = server_name.clone(); + let client_clone = client.clone(); + join_set.spawn(async move { + let res = client_clone.list_tools(None).await; + (server_name_cloned, res) + }); + } + + let mut aggregated: HashMap = HashMap::new(); + + while let Some(join_res) = join_set.join_next().await { + let (server_name, list_result) = join_res?; // propagate JoinError + + let list_result = list_result?; + + for tool in list_result.tools { + if tool.name.contains('/') { + warn!( + server = %server_name, + tool_name = %tool.name, + "tool name contains '/' – skipping to avoid ambiguity" + ); + continue; + } + + let fq_name = fully_qualified_tool_name(&server_name, &tool.name); + + if aggregated.insert(fq_name.clone(), tool).is_some() { + warn!("tool name collision for '{fq_name}' – overwriting previous entry"); + } + } + } + + info!( + "aggregated {} tools from {} servers", + aggregated.len(), + clients.len() + ); + + Ok(aggregated) +} + +/// Convenience helper that mirrors the previous `create_mcp_connection_manager` +/// free-standing function but returns `Result` and is **async**. Existing +/// call-sites can continue to call the function while new code can use the +/// `McpConnectionManager::new` associated function directly. 
+pub(crate) async fn create_mcp_connection_manager( + mcp_servers: HashMap, +) -> Result { + McpConnectionManager::new(mcp_servers).await +} diff --git a/codex-rs/core/src/mcp_server_config.rs b/codex-rs/core/src/mcp_server_config.rs new file mode 100644 index 0000000000..261a75d13e --- /dev/null +++ b/codex-rs/core/src/mcp_server_config.rs @@ -0,0 +1,14 @@ +use std::collections::HashMap; + +use serde::Deserialize; + +#[derive(Deserialize, Debug, Clone)] +pub struct McpServerConfig { + pub command: String, + + #[serde(default)] + pub args: Vec, + + #[serde(default)] + pub env: Option>, +} diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs new file mode 100644 index 0000000000..2a93939228 --- /dev/null +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -0,0 +1,94 @@ +use tracing::error; + +use crate::codex::Session; +use crate::models::FunctionCallOutputPayload; +use crate::models::ResponseInputItem; +use crate::protocol::Event; +use crate::protocol::EventMsg; + +/// Handles the specified tool call and dispatches the appropriate +/// `McpToolCallBegin` and `McpToolCallEnd` events to the `Session`. +pub(crate) async fn handle_mcp_tool_call( + sess: &Session, + sub_id: &str, + call_id: String, + server: String, + tool_name: String, + arguments: String, +) -> ResponseInputItem { + // Attempt to route to external MCP server. 
+ let arguments_value: Option = serde_json::from_str(&arguments).ok(); + + let tool_call_begin_event = EventMsg::McpToolCallBegin { + call_id: call_id.clone(), + server: server.clone(), + tool: tool_name.clone(), + arguments: arguments_value.clone(), + }; + if let Err(e) = sess + .tx_event + .send(Event { + id: sub_id.to_string(), + msg: tool_call_begin_event, + }) + .await + { + error!("failed to send tool call begin event: {e}"); + } + + let (tool_call_end_event, tool_call_err) = match sess + .mcp + .call_tool(&server, &tool_name, arguments_value) + .await + { + Ok(result) => ( + EventMsg::McpToolCallEnd { + call_id, + success: !result.is_error.unwrap_or(false), + result: Some(result), + }, + None, + ), + Err(e) => ( + EventMsg::McpToolCallEnd { + call_id, + success: false, + result: None, + }, + Some(e), + ), + }; + if let Err(e) = sess + .tx_event + .send(Event { + id: sub_id.to_string(), + msg: tool_call_end_event.clone(), + }) + .await + { + error!("failed to send tool call end event: {e}"); + } + + let EventMsg::McpToolCallEnd { + call_id, + success, + result, + } = tool_call_end_event + else { + unimplemented!("unexpected event type"); + }; + + ResponseInputItem::FunctionCallOutput { + call_id, + output: FunctionCallOutputPayload { + content: result.map_or_else( + || format!("err: {tool_call_err:?}"), + |result| { + serde_json::to_string(&result) + .unwrap_or_else(|e| format!("JSON serialization error: {e}")) + }, + ), + success: Some(success), + }, + } +} diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs index 851d80e2b9..4796381dbf 100644 --- a/codex-rs/core/src/protocol.rs +++ b/codex-rs/core/src/protocol.rs @@ -7,6 +7,7 @@ use std::collections::HashMap; use std::path::Path; use std::path::PathBuf; +use mcp_types::CallToolResult; use serde::Deserialize; use serde::Serialize; @@ -316,6 +317,32 @@ pub enum EventMsg { model: String, }, + McpToolCallBegin { + /// Identifier so this can be paired with the McpToolCallEnd event. 
+ call_id: String, + + /// Name of the MCP server as defined in the config. + server: String, + + /// Name of the tool as given by the MCP server. + tool: String, + + /// Arguments to the tool call. + arguments: Option, + }, + + McpToolCallEnd { + /// Identifier for the McpToolCallBegin that finished. + call_id: String, + + /// Whether the tool call was successful. If `false`, `result` might + /// not be present. + success: bool, + + /// Result of the tool call. Note this could be an error. + result: Option, + }, + /// Notification that the server is about to execute a command. ExecCommandBegin { /// Identifier so this can be paired with the ExecCommandEnd event. diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index 54c4804750..ec295a57d1 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -328,6 +328,21 @@ impl ChatWidget<'_> { .record_completed_exec_command(call_id, stdout, stderr, exit_code); self.request_redraw()?; } + EventMsg::McpToolCallBegin { + call_id, + server, + tool, + arguments, + } => { + todo!() + } + EventMsg::McpToolCallEnd { + call_id, + success, + result, + } => { + todo!() + } event => { self.conversation_history .add_background_event(format!("{event:?}"));