//! Connection manager for Model Context Protocol (MCP) servers. //! //! The [`McpConnectionManager`] owns one [`codex_rmcp_client::RmcpClient`] per //! configured server (keyed by the *server name*). It offers convenience //! helpers to query the available tools across *all* servers and returns them //! in a single aggregated map using the fully-qualified tool name //! `""` as the key. use std::collections::HashMap; use std::collections::HashSet; use std::env; use std::ffi::OsString; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use crate::mcp::CODEX_APPS_MCP_SERVER_NAME; use crate::mcp::auth::McpAuthStatusEntry; use anyhow::Context; use anyhow::Result; use anyhow::anyhow; use async_channel::Sender; use codex_async_utils::CancelErr; use codex_async_utils::OrCancelExt; use codex_protocol::approvals::ElicitationRequestEvent; use codex_protocol::mcp::CallToolResult; use codex_protocol::mcp::RequestId as ProtocolRequestId; use codex_protocol::protocol::Event; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::McpStartupCompleteEvent; use codex_protocol::protocol::McpStartupFailure; use codex_protocol::protocol::McpStartupStatus; use codex_protocol::protocol::McpStartupUpdateEvent; use codex_protocol::protocol::SandboxPolicy; use codex_rmcp_client::ElicitationResponse; use codex_rmcp_client::OAuthCredentialsStoreMode; use codex_rmcp_client::RmcpClient; use codex_rmcp_client::SendElicitation; use futures::future::BoxFuture; use futures::future::FutureExt; use futures::future::Shared; use rmcp::model::ClientCapabilities; use rmcp::model::ElicitationCapability; use rmcp::model::Implementation; use rmcp::model::InitializeRequestParam; use rmcp::model::ListResourceTemplatesResult; use rmcp::model::ListResourcesResult; use rmcp::model::PaginatedRequestParam; use rmcp::model::ProtocolVersion; use rmcp::model::ReadResourceRequestParam; use rmcp::model::ReadResourceResult; use rmcp::model::RequestId; use rmcp::model::Resource; use rmcp::model::ResourceTemplate; use rmcp::model::Tool; use serde::Deserialize; use serde::Serialize; use sha1::Digest; use sha1::Sha1; use tokio::sync::Mutex; use tokio::sync::oneshot; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::instrument; use tracing::warn; use crate::codex::INITIAL_SUBMIT_ID; use crate::config::types::McpServerConfig; use crate::config::types::McpServerTransportConfig; /// Delimiter used to separate the server name from the tool name in a fully /// qualified tool name. /// /// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must /// choose a delimiter from this character set. const MCP_TOOL_NAME_DELIMITER: &str = "__"; const MAX_TOOL_NAME_LENGTH: usize = 64; /// Default timeout for initializing MCP server & initially listing tools. pub const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10); /// Default timeout for individual tool calls. const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(60); /// The Responses API requires tool names to match `^[a-zA-Z0-9_-]+$`. /// MCP server/tool names are user-controlled, so sanitize the fully-qualified /// name we expose to the model by replacing any disallowed character with `_`. fn sanitize_responses_api_tool_name(name: &str) -> String { let mut sanitized = String::with_capacity(name.len()); for c in name.chars() { if c.is_ascii_alphanumeric() || c == '_' || c == '-' { sanitized.push(c); } else { sanitized.push('_'); } } if sanitized.is_empty() { "_".to_string() } else { sanitized } } fn sha1_hex(s: &str) -> String { let mut hasher = Sha1::new(); hasher.update(s.as_bytes()); let sha1 = hasher.finalize(); format!("{sha1:x}") } fn qualify_tools(tools: I) -> HashMap where I: IntoIterator, { let mut used_names = HashSet::new(); let mut seen_raw_names = HashSet::new(); let mut qualified_tools = HashMap::new(); for tool in tools { let qualified_name_raw = format!( "mcp{}{}{}{}", MCP_TOOL_NAME_DELIMITER, tool.server_name, MCP_TOOL_NAME_DELIMITER, tool.tool_name ); if !seen_raw_names.insert(qualified_name_raw.clone()) { warn!("skipping duplicated tool {}", qualified_name_raw); continue; } // Start from a "pretty" name (sanitized), then deterministically disambiguate on // collisions by appending a hash of the *raw* (unsanitized) qualified name. This // ensures tools like `foo.bar` and `foo_bar` don't collapse to the same key. let mut qualified_name = sanitize_responses_api_tool_name(&qualified_name_raw); // Enforce length constraints early; use the raw name for the hash input so the // output remains stable even when sanitization changes. if qualified_name.len() > MAX_TOOL_NAME_LENGTH { let sha1_str = sha1_hex(&qualified_name_raw); let prefix_len = MAX_TOOL_NAME_LENGTH - sha1_str.len(); qualified_name = format!("{}{}", &qualified_name[..prefix_len], sha1_str); } if used_names.contains(&qualified_name) { warn!("skipping duplicated tool {}", qualified_name); continue; } used_names.insert(qualified_name.clone()); qualified_tools.insert(qualified_name, tool); } qualified_tools } #[derive(Clone)] pub(crate) struct ToolInfo { pub(crate) server_name: String, pub(crate) tool_name: String, pub(crate) tool: Tool, pub(crate) connector_id: Option, pub(crate) connector_name: Option, } type ResponderMap = HashMap<(String, RequestId), oneshot::Sender>; #[derive(Clone, Default)] struct ElicitationRequestManager { requests: Arc>, } impl ElicitationRequestManager { async fn resolve( &self, server_name: String, id: RequestId, response: ElicitationResponse, ) -> Result<()> { self.requests .lock() .await .remove(&(server_name, id)) .ok_or_else(|| anyhow!("elicitation request not found"))? .send(response) .map_err(|e| anyhow!("failed to send elicitation response: {e:?}")) } fn make_sender(&self, server_name: String, tx_event: Sender) -> SendElicitation { let elicitation_requests = self.requests.clone(); Box::new(move |id, elicitation| { let elicitation_requests = elicitation_requests.clone(); let tx_event = tx_event.clone(); let server_name = server_name.clone(); async move { let (tx, rx) = oneshot::channel(); { let mut lock = elicitation_requests.lock().await; lock.insert((server_name.clone(), id.clone()), tx); } let _ = tx_event .send(Event { id: "mcp_elicitation_request".to_string(), msg: EventMsg::ElicitationRequest(ElicitationRequestEvent { server_name, id: match id.clone() { rmcp::model::NumberOrString::String(value) => { ProtocolRequestId::String(value.to_string()) } rmcp::model::NumberOrString::Number(value) => { ProtocolRequestId::Integer(value) } }, message: elicitation.message, }), }) .await; rx.await .context("elicitation request channel closed unexpectedly") } .boxed() }) } } #[derive(Clone)] struct ManagedClient { client: Arc, tools: Vec, tool_filter: ToolFilter, tool_timeout: Option, server_supports_sandbox_state_capability: bool, } impl ManagedClient { /// Returns once the server has ack'd the sandbox state update. async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { if !self.server_supports_sandbox_state_capability { return Ok(()); } let _response = self .client .send_custom_request( MCP_SANDBOX_STATE_METHOD, Some(serde_json::to_value(sandbox_state)?), ) .await?; Ok(()) } } #[derive(Clone)] struct AsyncManagedClient { client: Shared>>, } impl AsyncManagedClient { fn new( server_name: String, config: McpServerConfig, store_mode: OAuthCredentialsStoreMode, cancel_token: CancellationToken, tx_event: Sender, elicitation_requests: ElicitationRequestManager, ) -> Self { let tool_filter = ToolFilter::from_config(&config); let fut = async move { if let Err(error) = validate_mcp_server_name(&server_name) { return Err(error.into()); } let client = Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?); match start_server_task( server_name, client, config.startup_timeout_sec.or(Some(DEFAULT_STARTUP_TIMEOUT)), config.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT), tool_filter, tx_event, elicitation_requests, ) .or_cancel(&cancel_token) .await { Ok(result) => result, Err(CancelErr::Cancelled) => Err(StartupOutcomeError::Cancelled), } }; Self { client: fut.boxed().shared(), } } async fn client(&self) -> Result { self.client.clone().await } async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { let managed = self.client().await?; managed.notify_sandbox_state_change(sandbox_state).await } } pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state"; /// Custom MCP request to push sandbox state updates. /// When used, the `params` field of the notification is [`SandboxState`]. pub const MCP_SANDBOX_STATE_METHOD: &str = "codex/sandbox-state/update"; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SandboxState { pub sandbox_policy: SandboxPolicy, pub codex_linux_sandbox_exe: Option, pub sandbox_cwd: PathBuf, #[serde(default)] pub use_linux_sandbox_bwrap: bool, } /// A thin wrapper around a set of running [`RmcpClient`] instances. #[derive(Default)] pub(crate) struct McpConnectionManager { clients: HashMap, elicitation_requests: ElicitationRequestManager, } impl McpConnectionManager { pub async fn initialize( &mut self, mcp_servers: &HashMap, store_mode: OAuthCredentialsStoreMode, auth_entries: HashMap, tx_event: Sender, cancel_token: CancellationToken, initial_sandbox_state: SandboxState, ) { if cancel_token.is_cancelled() { return; } let mut clients = HashMap::new(); let mut join_set = JoinSet::new(); let elicitation_requests = ElicitationRequestManager::default(); let mcp_servers = mcp_servers.clone(); for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) { let cancel_token = cancel_token.child_token(); let _ = emit_update( &tx_event, McpStartupUpdateEvent { server: server_name.clone(), status: McpStartupStatus::Starting, }, ) .await; let async_managed_client = AsyncManagedClient::new( server_name.clone(), cfg, store_mode, cancel_token.clone(), tx_event.clone(), elicitation_requests.clone(), ); clients.insert(server_name.clone(), async_managed_client.clone()); let tx_event = tx_event.clone(); let auth_entry = auth_entries.get(&server_name).cloned(); let sandbox_state = initial_sandbox_state.clone(); join_set.spawn(async move { let outcome = async_managed_client.client().await; if cancel_token.is_cancelled() { return (server_name, Err(StartupOutcomeError::Cancelled)); } let status = match &outcome { Ok(_) => { // Send sandbox state notification immediately after Ready if let Err(e) = async_managed_client .notify_sandbox_state_change(&sandbox_state) .await { warn!( "Failed to notify sandbox state to MCP server {server_name}: {e:#}", ); } McpStartupStatus::Ready } Err(error) => { let error_str = mcp_init_error_display( server_name.as_str(), auth_entry.as_ref(), error, ); McpStartupStatus::Failed { error: error_str } } }; let _ = emit_update( &tx_event, McpStartupUpdateEvent { server: server_name.clone(), status, }, ) .await; (server_name, outcome) }); } self.clients = clients; self.elicitation_requests = elicitation_requests.clone(); tokio::spawn(async move { let outcomes = join_set.join_all().await; let mut summary = McpStartupCompleteEvent::default(); for (server_name, outcome) in outcomes { match outcome { Ok(_) => summary.ready.push(server_name), Err(StartupOutcomeError::Cancelled) => summary.cancelled.push(server_name), Err(StartupOutcomeError::Failed { error }) => { summary.failed.push(McpStartupFailure { server: server_name, error, }) } } } let _ = tx_event .send(Event { id: INITIAL_SUBMIT_ID.to_owned(), msg: EventMsg::McpStartupComplete(summary), }) .await; }); } async fn client_by_name(&self, name: &str) -> Result { self.clients .get(name) .ok_or_else(|| anyhow!("unknown MCP server '{name}'"))? .client() .await .context("failed to get client") } pub async fn resolve_elicitation( &self, server_name: String, id: RequestId, response: ElicitationResponse, ) -> Result<()> { self.elicitation_requests .resolve(server_name, id, response) .await } pub(crate) async fn wait_for_server_ready(&self, server_name: &str, timeout: Duration) -> bool { let Some(async_managed_client) = self.clients.get(server_name) else { return false; }; match tokio::time::timeout(timeout, async_managed_client.client()).await { Ok(Ok(_)) => true, Ok(Err(_)) | Err(_) => false, } } /// Returns a single map that contains all tools. Each key is the /// fully-qualified name for the tool. #[instrument(level = "trace", skip_all)] pub async fn list_all_tools(&self) -> HashMap { let mut tools = HashMap::new(); for managed_client in self.clients.values() { let client = managed_client.client().await.ok(); if let Some(client) = client { tools.extend(qualify_tools(filter_tools( client.tools, client.tool_filter, ))); } } tools } /// Returns a single map that contains all resources. Each key is the /// server name and the value is a vector of resources. pub async fn list_all_resources(&self) -> HashMap> { let mut join_set = JoinSet::new(); let clients_snapshot = &self.clients; for (server_name, async_managed_client) in clients_snapshot { let server_name = server_name.clone(); let Ok(managed_client) = async_managed_client.client().await else { continue; }; let timeout = managed_client.tool_timeout; let client = managed_client.client.clone(); join_set.spawn(async move { let mut collected: Vec = Vec::new(); let mut cursor: Option = None; loop { let params = cursor.as_ref().map(|next| PaginatedRequestParam { cursor: Some(next.clone()), }); let response = match client.list_resources(params, timeout).await { Ok(result) => result, Err(err) => return (server_name, Err(err)), }; collected.extend(response.resources); match response.next_cursor { Some(next) => { if cursor.as_ref() == Some(&next) { return ( server_name, Err(anyhow!("resources/list returned duplicate cursor")), ); } cursor = Some(next); } None => return (server_name, Ok(collected)), } } }); } let mut aggregated: HashMap> = HashMap::new(); while let Some(join_res) = join_set.join_next().await { match join_res { Ok((server_name, Ok(resources))) => { aggregated.insert(server_name, resources); } Ok((server_name, Err(err))) => { warn!("Failed to list resources for MCP server '{server_name}': {err:#}"); } Err(err) => { warn!("Task panic when listing resources for MCP server: {err:#}"); } } } aggregated } /// Returns a single map that contains all resource templates. Each key is the /// server name and the value is a vector of resource templates. pub async fn list_all_resource_templates(&self) -> HashMap> { let mut join_set = JoinSet::new(); let clients_snapshot = &self.clients; for (server_name, async_managed_client) in clients_snapshot { let server_name_cloned = server_name.clone(); let Ok(managed_client) = async_managed_client.client().await else { continue; }; let client = managed_client.client.clone(); let timeout = managed_client.tool_timeout; join_set.spawn(async move { let mut collected: Vec = Vec::new(); let mut cursor: Option = None; loop { let params = cursor.as_ref().map(|next| PaginatedRequestParam { cursor: Some(next.clone()), }); let response = match client.list_resource_templates(params, timeout).await { Ok(result) => result, Err(err) => return (server_name_cloned, Err(err)), }; collected.extend(response.resource_templates); match response.next_cursor { Some(next) => { if cursor.as_ref() == Some(&next) { return ( server_name_cloned, Err(anyhow!( "resources/templates/list returned duplicate cursor" )), ); } cursor = Some(next); } None => return (server_name_cloned, Ok(collected)), } } }); } let mut aggregated: HashMap> = HashMap::new(); while let Some(join_res) = join_set.join_next().await { match join_res { Ok((server_name, Ok(templates))) => { aggregated.insert(server_name, templates); } Ok((server_name, Err(err))) => { warn!( "Failed to list resource templates for MCP server '{server_name}': {err:#}" ); } Err(err) => { warn!("Task panic when listing resource templates for MCP server: {err:#}"); } } } aggregated } /// Invoke the tool indicated by the (server, tool) pair. pub async fn call_tool( &self, server: &str, tool: &str, arguments: Option, ) -> Result { let client = self.client_by_name(server).await?; if !client.tool_filter.allows(tool) { return Err(anyhow!( "tool '{tool}' is disabled for MCP server '{server}'" )); } let result: rmcp::model::CallToolResult = client .client .call_tool(tool.to_string(), arguments, client.tool_timeout) .await .with_context(|| format!("tool call failed for `{server}/{tool}`"))?; let content = result .content .into_iter() .map(|content| { serde_json::to_value(content) .unwrap_or_else(|_| serde_json::Value::String("".to_string())) }) .collect(); Ok(CallToolResult { content, structured_content: result.structured_content, is_error: result.is_error, meta: result.meta.and_then(|meta| serde_json::to_value(meta).ok()), }) } /// List resources from the specified server. pub async fn list_resources( &self, server: &str, params: Option, ) -> Result { let managed = self.client_by_name(server).await?; let timeout = managed.tool_timeout; managed .client .list_resources(params, timeout) .await .with_context(|| format!("resources/list failed for `{server}`")) } /// List resource templates from the specified server. pub async fn list_resource_templates( &self, server: &str, params: Option, ) -> Result { let managed = self.client_by_name(server).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; client .list_resource_templates(params, timeout) .await .with_context(|| format!("resources/templates/list failed for `{server}`")) } /// Read a resource from the specified server. pub async fn read_resource( &self, server: &str, params: ReadResourceRequestParam, ) -> Result { let managed = self.client_by_name(server).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; let uri = params.uri.clone(); client .read_resource(params, timeout) .await .with_context(|| format!("resources/read failed for `{server}` ({uri})")) } pub async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> { self.list_all_tools() .await .get(tool_name) .map(|tool| (tool.server_name.clone(), tool.tool_name.clone())) } pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { let mut join_set = JoinSet::new(); for async_managed_client in self.clients.values() { let sandbox_state = sandbox_state.clone(); let async_managed_client = async_managed_client.clone(); join_set.spawn(async move { async_managed_client .notify_sandbox_state_change(&sandbox_state) .await }); } while let Some(join_res) = join_set.join_next().await { match join_res { Ok(Ok(())) => {} Ok(Err(err)) => { warn!("Failed to notify sandbox state change to MCP server: {err:#}"); } Err(err) => { warn!("Task panic when notifying sandbox state change to MCP server: {err:#}"); } } } Ok(()) } } async fn emit_update( tx_event: &Sender, update: McpStartupUpdateEvent, ) -> Result<(), async_channel::SendError> { tx_event .send(Event { id: INITIAL_SUBMIT_ID.to_owned(), msg: EventMsg::McpStartupUpdate(update), }) .await } /// A tool is allowed to be used if both are true: /// 1. enabled is None (no allowlist is set) or the tool is explicitly enabled. /// 2. The tool is not explicitly disabled. #[derive(Default, Clone)] pub(crate) struct ToolFilter { enabled: Option>, disabled: HashSet, } impl ToolFilter { fn from_config(cfg: &McpServerConfig) -> Self { let enabled = cfg .enabled_tools .as_ref() .map(|tools| tools.iter().cloned().collect::>()); let disabled = cfg .disabled_tools .as_ref() .map(|tools| tools.iter().cloned().collect::>()) .unwrap_or_default(); Self { enabled, disabled } } fn allows(&self, tool_name: &str) -> bool { if let Some(enabled) = &self.enabled && !enabled.contains(tool_name) { return false; } !self.disabled.contains(tool_name) } } fn filter_tools(tools: Vec, filter: ToolFilter) -> Vec { tools .into_iter() .filter(|tool| filter.allows(&tool.tool_name)) .collect() } fn normalize_codex_apps_tool_title( server_name: &str, connector_name: Option<&str>, value: &str, ) -> String { if server_name != CODEX_APPS_MCP_SERVER_NAME { return value.to_string(); } let Some(connector_name) = connector_name .map(str::trim) .filter(|name| !name.is_empty()) else { return value.to_string(); }; let prefix = format!("{connector_name}_"); if let Some(stripped) = value.strip_prefix(&prefix) && !stripped.is_empty() { return stripped.to_string(); } value.to_string() } fn resolve_bearer_token( server_name: &str, bearer_token_env_var: Option<&str>, ) -> Result> { let Some(env_var) = bearer_token_env_var else { return Ok(None); }; match env::var(env_var) { Ok(value) => { if value.is_empty() { Err(anyhow!( "Environment variable {env_var} for MCP server '{server_name}' is empty" )) } else { Ok(Some(value)) } } Err(env::VarError::NotPresent) => Err(anyhow!( "Environment variable {env_var} for MCP server '{server_name}' is not set" )), Err(env::VarError::NotUnicode(_)) => Err(anyhow!( "Environment variable {env_var} for MCP server '{server_name}' contains invalid Unicode" )), } } #[derive(Debug, Clone, thiserror::Error)] enum StartupOutcomeError { #[error("MCP startup cancelled")] Cancelled, // We can't store the original error here because anyhow::Error doesn't implement // `Clone`. #[error("MCP startup failed: {error}")] Failed { error: String }, } impl From for StartupOutcomeError { fn from(error: anyhow::Error) -> Self { Self::Failed { error: error.to_string(), } } } async fn start_server_task( server_name: String, client: Arc, startup_timeout: Option, // TODO: cancel_token should handle this. tool_timeout: Duration, tool_filter: ToolFilter, tx_event: Sender, elicitation_requests: ElicitationRequestManager, ) -> Result { let params = InitializeRequestParam { capabilities: ClientCapabilities { experimental: None, roots: None, sampling: None, // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities // indicates this should be an empty object. elicitation: Some(ElicitationCapability { schema_validation: None, }), }, client_info: Implementation { name: "codex-mcp-client".to_owned(), version: env!("CARGO_PKG_VERSION").to_owned(), title: Some("Codex".into()), icons: None, website_url: None, }, protocol_version: ProtocolVersion::V_2025_06_18, }; let send_elicitation = elicitation_requests.make_sender(server_name.clone(), tx_event); let initialize_result = client .initialize(params, startup_timeout, send_elicitation) .await .map_err(StartupOutcomeError::from)?; let tools = list_tools_for_client(&server_name, &client, startup_timeout) .await .map_err(StartupOutcomeError::from)?; let server_supports_sandbox_state_capability = initialize_result .capabilities .experimental .as_ref() .and_then(|exp| exp.get(MCP_SANDBOX_STATE_CAPABILITY)) .is_some(); let managed = ManagedClient { client: Arc::clone(&client), tools, tool_timeout: Some(tool_timeout), tool_filter, server_supports_sandbox_state_capability, }; Ok(managed) } async fn make_rmcp_client( server_name: &str, transport: McpServerTransportConfig, store_mode: OAuthCredentialsStoreMode, ) -> Result { match transport { McpServerTransportConfig::Stdio { command, args, env, env_vars, cwd, } => { let command_os: OsString = command.into(); let args_os: Vec = args.into_iter().map(Into::into).collect(); RmcpClient::new_stdio_client(command_os, args_os, env, &env_vars, cwd) .await .map_err(|err| StartupOutcomeError::from(anyhow!(err))) } McpServerTransportConfig::StreamableHttp { url, http_headers, env_http_headers, bearer_token_env_var, } => { let resolved_bearer_token = match resolve_bearer_token(server_name, bearer_token_env_var.as_deref()) { Ok(token) => token, Err(error) => return Err(error.into()), }; RmcpClient::new_streamable_http_client( server_name, &url, resolved_bearer_token, http_headers, env_http_headers, store_mode, ) .await .map_err(StartupOutcomeError::from) } } } async fn list_tools_for_client( server_name: &str, client: &Arc, timeout: Option, ) -> Result> { let resp = client.list_tools_with_connector_ids(None, timeout).await?; Ok(resp .tools .into_iter() .map(|tool| { let connector_name = tool.connector_name; let mut tool_def = tool.tool; if let Some(title) = tool_def.title.as_deref() { let normalized_title = normalize_codex_apps_tool_title(server_name, connector_name.as_deref(), title); if tool_def.title.as_deref() != Some(normalized_title.as_str()) { tool_def.title = Some(normalized_title); } } ToolInfo { server_name: server_name.to_owned(), tool_name: tool_def.name.to_string(), tool: tool_def, connector_id: tool.connector_id, connector_name, } }) .collect()) } fn validate_mcp_server_name(server_name: &str) -> Result<()> { let re = regex_lite::Regex::new(r"^[a-zA-Z0-9_-]+$")?; if !re.is_match(server_name) { return Err(anyhow!( "Invalid MCP server name '{server_name}': must match pattern {pattern}", pattern = re.as_str() )); } Ok(()) } fn mcp_init_error_display( server_name: &str, entry: Option<&McpAuthStatusEntry>, err: &StartupOutcomeError, ) -> String { if let Some(McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, http_headers, .. }) = &entry.map(|entry| &entry.config.transport) && url == "https://api.githubcopilot.com/mcp/" && bearer_token_env_var.is_none() && http_headers.as_ref().map(HashMap::is_empty).unwrap_or(true) { format!( "GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN" ) } else if is_mcp_client_auth_required_error(err) { format!( "The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`." ) } else if is_mcp_client_startup_timeout_error(err) { let startup_timeout_secs = match entry { Some(entry) => match entry.config.startup_timeout_sec { Some(timeout) => timeout, None => DEFAULT_STARTUP_TIMEOUT, }, None => DEFAULT_STARTUP_TIMEOUT, } .as_secs(); format!( "MCP client for `{server_name}` timed out after {startup_timeout_secs} seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.{server_name}]\nstartup_timeout_sec = XX" ) } else { format!("MCP client for `{server_name}` failed to start: {err:#}") } } fn is_mcp_client_auth_required_error(error: &StartupOutcomeError) -> bool { match error { StartupOutcomeError::Failed { error } => error.contains("Auth required"), _ => false, } } fn is_mcp_client_startup_timeout_error(error: &StartupOutcomeError) -> bool { match error { StartupOutcomeError::Failed { error } => { error.contains("request timed out") || error.contains("timed out handshaking with MCP server") } _ => false, } } #[cfg(test)] mod mcp_init_error_display_tests {} #[cfg(test)] mod tests { use super::*; use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { ToolInfo { server_name: server_name.to_string(), tool_name: tool_name.to_string(), tool: Tool { name: tool_name.to_string().into(), title: None, description: Some(format!("Test tool: {tool_name}").into()), input_schema: Arc::new(JsonObject::default()), output_schema: None, annotations: None, icons: None, meta: None, }, connector_id: None, connector_name: None, } } #[test] fn test_qualify_tools_short_non_duplicated_names() { let tools = vec![ create_test_tool("server1", "tool1"), create_test_tool("server1", "tool2"), ]; let qualified_tools = qualify_tools(tools); assert_eq!(qualified_tools.len(), 2); assert!(qualified_tools.contains_key("mcp__server1__tool1")); assert!(qualified_tools.contains_key("mcp__server1__tool2")); } #[test] fn test_qualify_tools_duplicated_names_skipped() { let tools = vec![ create_test_tool("server1", "duplicate_tool"), create_test_tool("server1", "duplicate_tool"), ]; let qualified_tools = qualify_tools(tools); // Only the first tool should remain, the second is skipped assert_eq!(qualified_tools.len(), 1); assert!(qualified_tools.contains_key("mcp__server1__duplicate_tool")); } #[test] fn test_qualify_tools_long_names_same_server() { let server_name = "my_server"; let tools = vec![ create_test_tool( server_name, "extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits", ), create_test_tool( server_name, "yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits", ), ]; let qualified_tools = qualify_tools(tools); assert_eq!(qualified_tools.len(), 2); let mut keys: Vec<_> = qualified_tools.keys().cloned().collect(); keys.sort(); assert_eq!(keys[0].len(), 64); assert_eq!( keys[0], "mcp__my_server__extremel119a2b97664e41363932dc84de21e2ff1b93b3e9" ); assert_eq!(keys[1].len(), 64); assert_eq!( keys[1], "mcp__my_server__yet_anot419a82a89325c1b477274a41f8c65ea5f3a7f341" ); } #[test] fn test_qualify_tools_sanitizes_invalid_characters() { let tools = vec![create_test_tool("server.one", "tool.two")]; let qualified_tools = qualify_tools(tools); assert_eq!(qualified_tools.len(), 1); let (qualified_name, tool) = qualified_tools.into_iter().next().expect("one tool"); assert_eq!(qualified_name, "mcp__server_one__tool_two"); // The key is sanitized for OpenAI, but we keep original parts for the actual MCP call. assert_eq!(tool.server_name, "server.one"); assert_eq!(tool.tool_name, "tool.two"); assert!( qualified_name .chars() .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'), "qualified name must be Responses API compatible: {qualified_name:?}" ); } #[test] fn tool_filter_allows_by_default() { let filter = ToolFilter::default(); assert!(filter.allows("any")); } #[test] fn tool_filter_applies_enabled_list() { let filter = ToolFilter { enabled: Some(HashSet::from(["allowed".to_string()])), disabled: HashSet::new(), }; assert!(filter.allows("allowed")); assert!(!filter.allows("denied")); } #[test] fn tool_filter_applies_disabled_list() { let filter = ToolFilter { enabled: None, disabled: HashSet::from(["blocked".to_string()]), }; assert!(!filter.allows("blocked")); assert!(filter.allows("open")); } #[test] fn tool_filter_applies_enabled_then_disabled() { let filter = ToolFilter { enabled: Some(HashSet::from(["keep".to_string(), "remove".to_string()])), disabled: HashSet::from(["remove".to_string()]), }; assert!(filter.allows("keep")); assert!(!filter.allows("remove")); assert!(!filter.allows("unknown")); } #[test] fn filter_tools_applies_per_server_filters() { let server1_tools = vec![ create_test_tool("server1", "tool_a"), create_test_tool("server1", "tool_b"), ]; let server2_tools = vec![create_test_tool("server2", "tool_a")]; let server1_filter = ToolFilter { enabled: Some(HashSet::from(["tool_a".to_string(), "tool_b".to_string()])), disabled: HashSet::from(["tool_b".to_string()]), }; let server2_filter = ToolFilter { enabled: None, disabled: HashSet::from(["tool_a".to_string()]), }; let filtered: Vec<_> = filter_tools(server1_tools, server1_filter) .into_iter() .chain(filter_tools(server2_tools, server2_filter)) .collect(); assert_eq!(filtered.len(), 1); assert_eq!(filtered[0].server_name, "server1"); assert_eq!(filtered[0].tool_name, "tool_a"); } #[test] fn mcp_init_error_display_prompts_for_github_pat() { let server_name = "github"; let entry = McpAuthStatusEntry { config: McpServerConfig { transport: McpServerTransportConfig::StreamableHttp { url: "https://api.githubcopilot.com/mcp/".to_string(), bearer_token_env_var: None, http_headers: None, env_http_headers: None, }, enabled: true, disabled_reason: None, startup_timeout_sec: None, tool_timeout_sec: None, enabled_tools: None, disabled_tools: None, scopes: None, }, auth_status: McpAuthStatus::Unsupported, }; let err: StartupOutcomeError = anyhow::anyhow!("OAuth is unsupported").into(); let display = mcp_init_error_display(server_name, Some(&entry), &err); let expected = format!( "GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN" ); assert_eq!(expected, display); } #[test] fn mcp_init_error_display_prompts_for_login_when_auth_required() { let server_name = "example"; let err: StartupOutcomeError = anyhow::anyhow!("Auth required for server").into(); let display = mcp_init_error_display(server_name, None, &err); let expected = format!( "The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`." ); assert_eq!(expected, display); } #[test] fn mcp_init_error_display_reports_generic_errors() { let server_name = "custom"; let entry = McpAuthStatusEntry { config: McpServerConfig { transport: McpServerTransportConfig::StreamableHttp { url: "https://example.com".to_string(), bearer_token_env_var: Some("TOKEN".to_string()), http_headers: None, env_http_headers: None, }, enabled: true, disabled_reason: None, startup_timeout_sec: None, tool_timeout_sec: None, enabled_tools: None, disabled_tools: None, scopes: None, }, auth_status: McpAuthStatus::Unsupported, }; let err: StartupOutcomeError = anyhow::anyhow!("boom").into(); let display = mcp_init_error_display(server_name, Some(&entry), &err); let expected = format!("MCP client for `{server_name}` failed to start: {err:#}"); assert_eq!(expected, display); } #[test] fn mcp_init_error_display_includes_startup_timeout_hint() { let server_name = "slow"; let err: StartupOutcomeError = anyhow::anyhow!("request timed out").into(); let display = mcp_init_error_display(server_name, None, &err); assert_eq!( "MCP client for `slow` timed out after 10 seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.slow]\nstartup_timeout_sec = XX", display ); } }