Files
codex/codex-rs/core/src/mcp_connection_manager.rs
viyatb-oai ae4de43ccc feat(linux-sandbox): add bwrap support (#9938)
## Summary
This PR introduces a gated Bubblewrap (bwrap) Linux sandbox path. The
current Linux sandbox path relies on in-process restrictions (including
Landlock). Bubblewrap gives us a more uniform filesystem isolation
model, especially explicit writable roots with the option to make some
directories read-only and granular network controls.

This is behind a feature flag so we can validate behavior safely before
making it the default.

- Added temporary rollout flag:
  - `features.use_linux_sandbox_bwrap`
- Preserved existing default path when the flag is off.
- In Bubblewrap mode:
  - Added internal retry without /proc when /proc mount is not permitted
    by the host/container.
2026-02-04 11:13:17 -08:00

1332 lines
45 KiB
Rust

//! Connection manager for Model Context Protocol (MCP) servers.
//!
//! The [`McpConnectionManager`] owns one [`codex_rmcp_client::RmcpClient`] per
//! configured server (keyed by the *server name*). It offers convenience
//! helpers to query the available tools across *all* servers and returns them
//! in a single aggregated map using the fully-qualified tool name
//! `"mcp<MCP_TOOL_NAME_DELIMITER><server><MCP_TOOL_NAME_DELIMITER><tool>"`
//! (sanitized and length-limited) as the key.
use std::collections::HashMap;
use std::collections::HashSet;
use std::env;
use std::ffi::OsString;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use crate::mcp::CODEX_APPS_MCP_SERVER_NAME;
use crate::mcp::auth::McpAuthStatusEntry;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use async_channel::Sender;
use codex_async_utils::CancelErr;
use codex_async_utils::OrCancelExt;
use codex_protocol::approvals::ElicitationRequestEvent;
use codex_protocol::mcp::CallToolResult;
use codex_protocol::mcp::RequestId as ProtocolRequestId;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::McpStartupCompleteEvent;
use codex_protocol::protocol::McpStartupFailure;
use codex_protocol::protocol::McpStartupStatus;
use codex_protocol::protocol::McpStartupUpdateEvent;
use codex_protocol::protocol::SandboxPolicy;
use codex_rmcp_client::ElicitationResponse;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use codex_rmcp_client::RmcpClient;
use codex_rmcp_client::SendElicitation;
use futures::future::BoxFuture;
use futures::future::FutureExt;
use futures::future::Shared;
use rmcp::model::ClientCapabilities;
use rmcp::model::ElicitationCapability;
use rmcp::model::Implementation;
use rmcp::model::InitializeRequestParam;
use rmcp::model::ListResourceTemplatesResult;
use rmcp::model::ListResourcesResult;
use rmcp::model::PaginatedRequestParam;
use rmcp::model::ProtocolVersion;
use rmcp::model::ReadResourceRequestParam;
use rmcp::model::ReadResourceResult;
use rmcp::model::RequestId;
use rmcp::model::Resource;
use rmcp::model::ResourceTemplate;
use rmcp::model::Tool;
use serde::Deserialize;
use serde::Serialize;
use sha1::Digest;
use sha1::Sha1;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use tracing::warn;
use crate::codex::INITIAL_SUBMIT_ID;
use crate::config::types::McpServerConfig;
use crate::config::types::McpServerTransportConfig;
/// Delimiter used to separate the server name from the tool name in a fully
/// qualified tool name.
///
/// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must
/// choose a delimiter from this character set.
const MCP_TOOL_NAME_DELIMITER: &str = "__";
/// Maximum length, in bytes, of a fully qualified tool name. Longer names are
/// shortened by `qualify_tools`, which replaces the tail with a SHA-1 hex digest.
const MAX_TOOL_NAME_LENGTH: usize = 64;
/// Default timeout for initializing MCP server & initially listing tools.
pub const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);
/// Default timeout for individual tool calls.
const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(60);
/// Makes `name` acceptable to the Responses API, whose tool names must match
/// `^[a-zA-Z0-9_-]+$`.
///
/// MCP server and tool names are user-controlled, so every character outside
/// that set is replaced with `_` (one `_` per character). An empty input
/// yields `"_"` so the result is never empty.
fn sanitize_responses_api_tool_name(name: &str) -> String {
    if name.is_empty() {
        return "_".to_string();
    }
    name.chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-') {
                ch
            } else {
                '_'
            }
        })
        .collect()
}
/// Returns the SHA-1 digest of `s` (hashed as its UTF-8 bytes) formatted as
/// lowercase hexadecimal.
fn sha1_hex(s: &str) -> String {
    let digest = {
        let mut state = Sha1::new();
        state.update(s.as_bytes());
        state.finalize()
    };
    format!("{digest:x}")
}
/// Builds the map from model-facing tool name to [`ToolInfo`].
///
/// For every tool the raw name `mcp<D><server><D><tool>` is formed (with `<D>`
/// being `MCP_TOOL_NAME_DELIMITER`), sanitized for the Responses API, and, when
/// longer than `MAX_TOOL_NAME_LENGTH`, truncated so that it ends with the SHA-1
/// hex digest of the raw name.
///
/// Tools are processed in iteration order and the first one wins: a tool whose
/// raw name was already seen, or whose final name is already taken (for example
/// `foo.bar` after `foo_bar`), is logged and skipped.
fn qualify_tools<I>(tools: I) -> HashMap<String, ToolInfo>
where
    I: IntoIterator<Item = ToolInfo>,
{
    let mut raw_names_seen = HashSet::new();
    let mut exposed_names = HashSet::new();
    let mut by_qualified_name = HashMap::new();
    for tool in tools {
        let raw_name = [
            "mcp",
            MCP_TOOL_NAME_DELIMITER,
            tool.server_name.as_str(),
            MCP_TOOL_NAME_DELIMITER,
            tool.tool_name.as_str(),
        ]
        .concat();
        if !raw_names_seen.insert(raw_name.clone()) {
            warn!("skipping duplicated tool {}", raw_name);
            continue;
        }

        let sanitized = sanitize_responses_api_tool_name(&raw_name);
        // The digest is taken over the raw name so the shortened name does not
        // depend on how sanitization rewrote it. Slicing is byte-based, which is
        // safe because the sanitized name is pure ASCII.
        let exposed_name = if sanitized.len() > MAX_TOOL_NAME_LENGTH {
            let digest = sha1_hex(&raw_name);
            let keep = MAX_TOOL_NAME_LENGTH - digest.len();
            format!("{}{}", &sanitized[..keep], digest)
        } else {
            sanitized
        };

        // `insert` returns false when the name is already taken.
        if !exposed_names.insert(exposed_name.clone()) {
            warn!("skipping duplicated tool {}", exposed_name);
            continue;
        }
        by_qualified_name.insert(exposed_name, tool);
    }
    by_qualified_name
}
/// A tool advertised by an MCP server, together with the server it came from.
#[derive(Clone)]
pub(crate) struct ToolInfo {
/// Name of the MCP server that provides the tool.
pub(crate) server_name: String,
/// Tool name as reported by the server (taken from the tool definition's `name`).
pub(crate) tool_name: String,
/// Tool definition returned by the server.
pub(crate) tool: Tool,
/// Connector id reported alongside the tool, if any.
pub(crate) connector_id: Option<String>,
/// Connector name reported alongside the tool, if any.
pub(crate) connector_name: Option<String>,
}
/// Pending elicitations keyed by `(server name, request id)`; the value is the
/// channel used to hand the eventual response back to the waiting request.
type ResponderMap = HashMap<(String, RequestId), oneshot::Sender<ElicitationResponse>>;
/// Tracks elicitation requests that are awaiting a response. Clones share the
/// same underlying map.
#[derive(Clone, Default)]
struct ElicitationRequestManager {
requests: Arc<Mutex<ResponderMap>>,
}
impl ElicitationRequestManager {
async fn resolve(
&self,
server_name: String,
id: RequestId,
response: ElicitationResponse,
) -> Result<()> {
self.requests
.lock()
.await
.remove(&(server_name, id))
.ok_or_else(|| anyhow!("elicitation request not found"))?
.send(response)
.map_err(|e| anyhow!("failed to send elicitation response: {e:?}"))
}
fn make_sender(&self, server_name: String, tx_event: Sender<Event>) -> SendElicitation {
let elicitation_requests = self.requests.clone();
Box::new(move |id, elicitation| {
let elicitation_requests = elicitation_requests.clone();
let tx_event = tx_event.clone();
let server_name = server_name.clone();
async move {
let (tx, rx) = oneshot::channel();
{
let mut lock = elicitation_requests.lock().await;
lock.insert((server_name.clone(), id.clone()), tx);
}
let _ = tx_event
.send(Event {
id: "mcp_elicitation_request".to_string(),
msg: EventMsg::ElicitationRequest(ElicitationRequestEvent {
server_name,
id: match id.clone() {
rmcp::model::NumberOrString::String(value) => {
ProtocolRequestId::String(value.to_string())
}
rmcp::model::NumberOrString::Number(value) => {
ProtocolRequestId::Integer(value)
}
},
message: elicitation.message,
}),
})
.await;
rx.await
.context("elicitation request channel closed unexpectedly")
}
.boxed()
})
}
}
/// A started MCP server connection plus the per-server settings captured at startup.
#[derive(Clone)]
struct ManagedClient {
/// Shared handle to the underlying RMCP client.
client: Arc<RmcpClient>,
/// Tools listed by the server during startup, before `tool_filter` is applied.
tools: Vec<ToolInfo>,
/// Allow/deny filter built from the server's configuration.
tool_filter: ToolFilter,
/// Timeout passed to tool calls and resource requests.
tool_timeout: Option<Duration>,
/// Whether the server's initialize result listed `MCP_SANDBOX_STATE_CAPABILITY`
/// among its experimental capabilities.
server_supports_sandbox_state_capability: bool,
}
impl ManagedClient {
    /// Pushes `sandbox_state` to the server and returns once the server has
    /// ack'd the update.
    ///
    /// Servers that did not advertise the sandbox-state capability are left
    /// alone and the call succeeds immediately.
    ///
    /// # Errors
    /// Fails when the state cannot be serialized or the request fails.
    async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
        if self.server_supports_sandbox_state_capability {
            let params = serde_json::to_value(sandbox_state)?;
            self.client
                .send_custom_request(MCP_SANDBOX_STATE_METHOD, Some(params))
                .await?;
        }
        Ok(())
    }
}
/// Cloneable handle to a server whose startup may still be in progress.
///
/// `client` is a shared future that resolves to the started [`ManagedClient`]
/// or to the startup error.
#[derive(Clone)]
struct AsyncManagedClient {
client: Shared<BoxFuture<'static, Result<ManagedClient, StartupOutcomeError>>>,
}
impl AsyncManagedClient {
    /// Prepares (but does not yet drive) the startup of one MCP server.
    ///
    /// The startup future validates the server name, creates the RMCP client
    /// for the configured transport and runs `start_server_task`. Cancelling
    /// `cancel_token` while the server task is running resolves the future to
    /// `StartupOutcomeError::Cancelled`. Missing timeouts in `config` fall
    /// back to `DEFAULT_STARTUP_TIMEOUT` and `DEFAULT_TOOL_TIMEOUT`.
    fn new(
        server_name: String,
        config: McpServerConfig,
        store_mode: OAuthCredentialsStoreMode,
        cancel_token: CancellationToken,
        tx_event: Sender<Event>,
        elicitation_requests: ElicitationRequestManager,
    ) -> Self {
        let tool_filter = ToolFilter::from_config(&config);
        let startup = async move {
            validate_mcp_server_name(&server_name)?;
            let rmcp_client =
                make_rmcp_client(&server_name, config.transport, store_mode).await?;
            let started = start_server_task(
                server_name,
                Arc::new(rmcp_client),
                config.startup_timeout_sec.or(Some(DEFAULT_STARTUP_TIMEOUT)),
                config.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT),
                tool_filter,
                tx_event,
                elicitation_requests,
            )
            .or_cancel(&cancel_token)
            .await;
            match started {
                Ok(result) => result,
                Err(CancelErr::Cancelled) => Err(StartupOutcomeError::Cancelled),
            }
        };
        Self {
            client: startup.boxed().shared(),
        }
    }

    /// Waits for startup to finish and returns its outcome.
    async fn client(&self) -> Result<ManagedClient, StartupOutcomeError> {
        let shared = self.client.clone();
        shared.await
    }

    /// Waits for startup, then forwards `sandbox_state` to the started client.
    ///
    /// # Errors
    /// Fails when startup failed or was cancelled, or when the notification
    /// itself fails.
    async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
        self.client()
            .await?
            .notify_sandbox_state_change(sandbox_state)
            .await
    }
}
/// Key looked up in a server's experimental capabilities (from its initialize
/// result) to decide whether sandbox state updates are sent to it.
pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
/// Custom MCP request to push sandbox state updates.
/// When used, the `params` field of the request is [`SandboxState`].
pub const MCP_SANDBOX_STATE_METHOD: &str = "codex/sandbox-state/update";
/// Payload sent as the `params` of [`MCP_SANDBOX_STATE_METHOD`].
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SandboxState {
/// The sandbox policy being communicated to the server.
pub sandbox_policy: SandboxPolicy,
/// NOTE(review): presumably the path to the Linux sandbox helper executable,
/// when one is configured — confirm against the caller.
pub codex_linux_sandbox_exe: Option<PathBuf>,
/// NOTE(review): presumably the working directory the sandbox policy is
/// evaluated against — confirm against the caller.
pub sandbox_cwd: PathBuf,
/// Bubblewrap rollout flag; deserializes as `false` when the field is absent.
#[serde(default)]
pub use_linux_sandbox_bwrap: bool,
}
/// A thin wrapper around a set of running [`RmcpClient`] instances.
#[derive(Default)]
pub(crate) struct McpConnectionManager {
/// One entry per enabled server, keyed by server name. An entry exists as
/// soon as startup has been requested, even if the server is not ready yet.
clients: HashMap<String, AsyncManagedClient>,
/// Pending elicitation requests shared with every server's client.
elicitation_requests: ElicitationRequestManager,
}
impl McpConnectionManager {
/// Starts every enabled server in `mcp_servers` and replaces this manager's
/// client set and elicitation state.
///
/// Returns without waiting for the servers to finish starting. Progress is
/// reported on `tx_event`: a `Starting` update per server before its client is
/// created, a `Ready` or `Failed` update when its startup resolves, and one
/// `McpStartupComplete` summary once all servers have resolved.
///
/// If `cancel_token` is already cancelled the call returns immediately and
/// leaves the manager untouched. `auth_entries` is only consulted to build the
/// error text for failed servers.
pub async fn initialize(
&mut self,
mcp_servers: &HashMap<String, McpServerConfig>,
store_mode: OAuthCredentialsStoreMode,
auth_entries: HashMap<String, McpAuthStatusEntry>,
tx_event: Sender<Event>,
cancel_token: CancellationToken,
initial_sandbox_state: SandboxState,
) {
if cancel_token.is_cancelled() {
return;
}
let mut clients = HashMap::new();
let mut join_set = JoinSet::new();
let elicitation_requests = ElicitationRequestManager::default();
let mcp_servers = mcp_servers.clone();
for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) {
// Each server gets its own child token derived from the caller's token.
let cancel_token = cancel_token.child_token();
// Best effort: a failure to deliver the update is ignored.
let _ = emit_update(
&tx_event,
McpStartupUpdateEvent {
server: server_name.clone(),
status: McpStartupStatus::Starting,
},
)
.await;
let async_managed_client = AsyncManagedClient::new(
server_name.clone(),
cfg,
store_mode,
cancel_token.clone(),
tx_event.clone(),
elicitation_requests.clone(),
);
clients.insert(server_name.clone(), async_managed_client.clone());
let tx_event = tx_event.clone();
let auth_entry = auth_entries.get(&server_name).cloned();
let sandbox_state = initial_sandbox_state.clone();
join_set.spawn(async move {
let outcome = async_managed_client.client().await;
// A cancelled server is reported as cancelled regardless of the
// outcome, and no Ready/Failed update is emitted for it.
if cancel_token.is_cancelled() {
return (server_name, Err(StartupOutcomeError::Cancelled));
}
let status = match &outcome {
Ok(_) => {
// Send sandbox state notification immediately after Ready
// A failed notification is only logged; the server still counts as ready.
if let Err(e) = async_managed_client
.notify_sandbox_state_change(&sandbox_state)
.await
{
warn!(
"Failed to notify sandbox state to MCP server {server_name}: {e:#}",
);
}
McpStartupStatus::Ready
}
Err(error) => {
let error_str = mcp_init_error_display(
server_name.as_str(),
auth_entry.as_ref(),
error,
);
McpStartupStatus::Failed { error: error_str }
}
};
let _ = emit_update(
&tx_event,
McpStartupUpdateEvent {
server: server_name.clone(),
status,
},
)
.await;
(server_name, outcome)
});
}
// Publish the new clients before startup completes; callers that need a
// started client wait on the per-server future.
self.clients = clients;
self.elicitation_requests = elicitation_requests.clone();
// Detached task: collect every server's outcome and send a single summary.
tokio::spawn(async move {
let outcomes = join_set.join_all().await;
let mut summary = McpStartupCompleteEvent::default();
for (server_name, outcome) in outcomes {
match outcome {
Ok(_) => summary.ready.push(server_name),
Err(StartupOutcomeError::Cancelled) => summary.cancelled.push(server_name),
Err(StartupOutcomeError::Failed { error }) => {
summary.failed.push(McpStartupFailure {
server: server_name,
error,
})
}
}
}
let _ = tx_event
.send(Event {
id: INITIAL_SUBMIT_ID.to_owned(),
msg: EventMsg::McpStartupComplete(summary),
})
.await;
});
}
/// Looks up the server called `name` and waits for its startup to finish.
///
/// # Errors
/// Fails when no such server is registered, or when its startup failed or
/// was cancelled.
async fn client_by_name(&self, name: &str) -> Result<ManagedClient> {
    let Some(entry) = self.clients.get(name) else {
        return Err(anyhow!("unknown MCP server '{name}'"));
    };
    entry.client().await.context("failed to get client")
}
/// Supplies `response` for the pending elicitation `(server_name, id)`.
///
/// # Errors
/// Fails when the request is not pending or the response cannot be delivered.
pub async fn resolve_elicitation(
    &self,
    server_name: String,
    id: RequestId,
    response: ElicitationResponse,
) -> Result<()> {
    let pending = &self.elicitation_requests;
    pending.resolve(server_name, id, response).await
}
/// Waits up to `timeout` for `server_name` to finish starting.
///
/// Returns `true` only when the server is known and started successfully
/// within the timeout; unknown servers, failed startups and timeouts all
/// yield `false`.
pub(crate) async fn wait_for_server_ready(&self, server_name: &str, timeout: Duration) -> bool {
    match self.clients.get(server_name) {
        Some(entry) => matches!(
            tokio::time::timeout(timeout, entry.client()).await,
            Ok(Ok(_))
        ),
        None => false,
    }
}
/// Collects the tools of every server into one map keyed by fully-qualified
/// tool name.
///
/// Waits for each server's startup to resolve; servers that failed to start
/// contribute nothing. Each server's tools are filtered by its tool filter
/// and then qualified.
#[instrument(level = "trace", skip_all)]
pub async fn list_all_tools(&self) -> HashMap<String, ToolInfo> {
    let mut all_tools = HashMap::new();
    for entry in self.clients.values() {
        let Ok(ready) = entry.client().await else {
            continue;
        };
        let visible = filter_tools(ready.tools, ready.tool_filter);
        all_tools.extend(qualify_tools(visible));
    }
    all_tools
}
/// Returns a single map that contains all resources. Each key is the
/// server name and the value is a vector of resources.
pub async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
let mut join_set = JoinSet::new();
let clients_snapshot = &self.clients;
for (server_name, async_managed_client) in clients_snapshot {
let server_name = server_name.clone();
let Ok(managed_client) = async_managed_client.client().await else {
continue;
};
let timeout = managed_client.tool_timeout;
let client = managed_client.client.clone();
join_set.spawn(async move {
let mut collected: Vec<Resource> = Vec::new();
let mut cursor: Option<String> = None;
loop {
let params = cursor.as_ref().map(|next| PaginatedRequestParam {
cursor: Some(next.clone()),
});
let response = match client.list_resources(params, timeout).await {
Ok(result) => result,
Err(err) => return (server_name, Err(err)),
};
collected.extend(response.resources);
match response.next_cursor {
Some(next) => {
if cursor.as_ref() == Some(&next) {
return (
server_name,
Err(anyhow!("resources/list returned duplicate cursor")),
);
}
cursor = Some(next);
}
None => return (server_name, Ok(collected)),
}
}
});
}
let mut aggregated: HashMap<String, Vec<Resource>> = HashMap::new();
while let Some(join_res) = join_set.join_next().await {
match join_res {
Ok((server_name, Ok(resources))) => {
aggregated.insert(server_name, resources);
}
Ok((server_name, Err(err))) => {
warn!("Failed to list resources for MCP server '{server_name}': {err:#}");
}
Err(err) => {
warn!("Task panic when listing resources for MCP server: {err:#}");
}
}
}
aggregated
}
/// Returns a single map that contains all resource templates. Each key is the
/// server name and the value is a vector of resource templates.
pub async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
let mut join_set = JoinSet::new();
let clients_snapshot = &self.clients;
for (server_name, async_managed_client) in clients_snapshot {
let server_name_cloned = server_name.clone();
let Ok(managed_client) = async_managed_client.client().await else {
continue;
};
let client = managed_client.client.clone();
let timeout = managed_client.tool_timeout;
join_set.spawn(async move {
let mut collected: Vec<ResourceTemplate> = Vec::new();
let mut cursor: Option<String> = None;
loop {
let params = cursor.as_ref().map(|next| PaginatedRequestParam {
cursor: Some(next.clone()),
});
let response = match client.list_resource_templates(params, timeout).await {
Ok(result) => result,
Err(err) => return (server_name_cloned, Err(err)),
};
collected.extend(response.resource_templates);
match response.next_cursor {
Some(next) => {
if cursor.as_ref() == Some(&next) {
return (
server_name_cloned,
Err(anyhow!(
"resources/templates/list returned duplicate cursor"
)),
);
}
cursor = Some(next);
}
None => return (server_name_cloned, Ok(collected)),
}
}
});
}
let mut aggregated: HashMap<String, Vec<ResourceTemplate>> = HashMap::new();
while let Some(join_res) = join_set.join_next().await {
match join_res {
Ok((server_name, Ok(templates))) => {
aggregated.insert(server_name, templates);
}
Ok((server_name, Err(err))) => {
warn!(
"Failed to list resource templates for MCP server '{server_name}': {err:#}"
);
}
Err(err) => {
warn!("Task panic when listing resource templates for MCP server: {err:#}");
}
}
}
aggregated
}
/// Invokes `tool` on `server` with the given `arguments`.
///
/// Each content item of the server's result is converted to JSON; an item
/// that cannot be serialized is replaced by the string `"<content>"`. Result
/// metadata that cannot be serialized is dropped.
///
/// # Errors
/// Fails when the server is unknown or failed to start, when the server's
/// tool filter disables `tool`, or when the call itself fails.
pub async fn call_tool(
    &self,
    server: &str,
    tool: &str,
    arguments: Option<serde_json::Value>,
) -> Result<CallToolResult> {
    let managed = self.client_by_name(server).await?;
    if !managed.tool_filter.allows(tool) {
        return Err(anyhow!(
            "tool '{tool}' is disabled for MCP server '{server}'"
        ));
    }

    let raw: rmcp::model::CallToolResult = managed
        .client
        .call_tool(tool.to_string(), arguments, managed.tool_timeout)
        .await
        .with_context(|| format!("tool call failed for `{server}/{tool}`"))?;

    let meta = raw.meta.and_then(|meta| serde_json::to_value(meta).ok());
    let content = raw
        .content
        .into_iter()
        .map(|item| {
            serde_json::to_value(item)
                .unwrap_or_else(|_| serde_json::Value::String("<content>".to_string()))
        })
        .collect();
    Ok(CallToolResult {
        content,
        structured_content: raw.structured_content,
        is_error: raw.is_error,
        meta,
    })
}
/// List resources from the specified server.
pub async fn list_resources(
&self,
server: &str,
params: Option<PaginatedRequestParam>,
) -> Result<ListResourcesResult> {
let managed = self.client_by_name(server).await?;
let timeout = managed.tool_timeout;
managed
.client
.list_resources(params, timeout)
.await
.with_context(|| format!("resources/list failed for `{server}`"))
}
/// List resource templates from the specified server.
pub async fn list_resource_templates(
&self,
server: &str,
params: Option<PaginatedRequestParam>,
) -> Result<ListResourceTemplatesResult> {
let managed = self.client_by_name(server).await?;
let client = managed.client.clone();
let timeout = managed.tool_timeout;
client
.list_resource_templates(params, timeout)
.await
.with_context(|| format!("resources/templates/list failed for `{server}`"))
}
/// Read a resource from the specified server.
pub async fn read_resource(
&self,
server: &str,
params: ReadResourceRequestParam,
) -> Result<ReadResourceResult> {
let managed = self.client_by_name(server).await?;
let client = managed.client.clone();
let timeout = managed.tool_timeout;
let uri = params.uri.clone();
client
.read_resource(params, timeout)
.await
.with_context(|| format!("resources/read failed for `{server}` ({uri})"))
}
/// Resolves a fully-qualified `tool_name` to its `(server name, tool name)`
/// pair, or `None` when no currently listed tool has that name.
pub async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
    let mut tools = self.list_all_tools().await;
    let info = tools.remove(tool_name)?;
    Some((info.server_name, info.tool_name))
}
/// Sends `sandbox_state` to every registered server concurrently and waits
/// for all notifications to finish.
///
/// Individual failures are logged rather than propagated, so this always
/// returns `Ok(())`.
pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
    let mut notifications = JoinSet::new();
    for entry in self.clients.values().cloned() {
        let state = sandbox_state.clone();
        notifications
            .spawn(async move { entry.notify_sandbox_state_change(&state).await });
    }

    while let Some(joined) = notifications.join_next().await {
        match joined {
            Ok(Ok(())) => {}
            Ok(Err(err)) => {
                warn!("Failed to notify sandbox state change to MCP server: {err:#}");
            }
            Err(err) => {
                warn!("Task panic when notifying sandbox state change to MCP server: {err:#}");
            }
        }
    }
    Ok(())
}
}
/// Sends `update` on `tx_event` as an `McpStartupUpdate` event tagged with
/// `INITIAL_SUBMIT_ID`.
///
/// # Errors
/// Returns the channel's send error when the event cannot be delivered.
async fn emit_update(
    tx_event: &Sender<Event>,
    update: McpStartupUpdateEvent,
) -> Result<(), async_channel::SendError<Event>> {
    let event = Event {
        id: INITIAL_SUBMIT_ID.to_owned(),
        msg: EventMsg::McpStartupUpdate(update),
    };
    tx_event.send(event).await
}
/// A tool is allowed to be used if both are true:
/// 1. enabled is None (no allowlist is set) or the tool is explicitly enabled.
/// 2. The tool is not explicitly disabled.
#[derive(Default, Clone)]
pub(crate) struct ToolFilter {
/// Optional allowlist of tool names; `None` means no allowlist is set.
enabled: Option<HashSet<String>>,
/// Tool names that are always rejected, even when they are on the allowlist.
disabled: HashSet<String>,
}
impl ToolFilter {
fn from_config(cfg: &McpServerConfig) -> Self {
let enabled = cfg
.enabled_tools
.as_ref()
.map(|tools| tools.iter().cloned().collect::<HashSet<_>>());
let disabled = cfg
.disabled_tools
.as_ref()
.map(|tools| tools.iter().cloned().collect::<HashSet<_>>())
.unwrap_or_default();
Self { enabled, disabled }
}
fn allows(&self, tool_name: &str) -> bool {
if let Some(enabled) = &self.enabled
&& !enabled.contains(tool_name)
{
return false;
}
!self.disabled.contains(tool_name)
}
}
/// Keeps only the tools whose `tool_name` is allowed by `filter`, preserving
/// their order.
fn filter_tools(tools: Vec<ToolInfo>, filter: ToolFilter) -> Vec<ToolInfo> {
    let mut kept = tools;
    kept.retain(|info| filter.allows(&info.tool_name));
    kept
}
fn normalize_codex_apps_tool_title(
server_name: &str,
connector_name: Option<&str>,
value: &str,
) -> String {
if server_name != CODEX_APPS_MCP_SERVER_NAME {
return value.to_string();
}
let Some(connector_name) = connector_name
.map(str::trim)
.filter(|name| !name.is_empty())
else {
return value.to_string();
};
let prefix = format!("{connector_name}_");
if let Some(stripped) = value.strip_prefix(&prefix)
&& !stripped.is_empty()
{
return stripped.to_string();
}
value.to_string()
}
/// Reads the bearer token for `server_name` from the environment variable
/// named by `bearer_token_env_var`.
///
/// Returns `Ok(None)` when no variable is configured.
///
/// # Errors
/// Fails when the configured variable is unset, empty, or not valid Unicode.
fn resolve_bearer_token(
    server_name: &str,
    bearer_token_env_var: Option<&str>,
) -> Result<Option<String>> {
    let Some(env_var) = bearer_token_env_var else {
        return Ok(None);
    };
    match env::var(env_var) {
        Ok(token) if token.is_empty() => Err(anyhow!(
            "Environment variable {env_var} for MCP server '{server_name}' is empty"
        )),
        Ok(token) => Ok(Some(token)),
        Err(env::VarError::NotPresent) => Err(anyhow!(
            "Environment variable {env_var} for MCP server '{server_name}' is not set"
        )),
        Err(env::VarError::NotUnicode(_)) => Err(anyhow!(
            "Environment variable {env_var} for MCP server '{server_name}' contains invalid Unicode"
        )),
    }
}
/// Why an MCP server did not become ready.
///
/// The type is `Clone`, which is why failures are stored as text rather than
/// as the original error.
#[derive(Debug, Clone, thiserror::Error)]
enum StartupOutcomeError {
/// Startup was cancelled before it completed.
#[error("MCP startup cancelled")]
Cancelled,
// We can't store the original error here because anyhow::Error doesn't implement
// `Clone`.
/// Startup failed; `error` holds the failure message.
#[error("MCP startup failed: {error}")]
Failed { error: String },
}
impl From<anyhow::Error> for StartupOutcomeError {
fn from(error: anyhow::Error) -> Self {
Self::Failed {
error: error.to_string(),
}
}
}
/// Performs the MCP handshake with a server and captures what is needed to
/// use it afterwards.
///
/// Sends `initialize` (advertising elicitation support), lists the server's
/// tools, and records whether the server advertised
/// `MCP_SANDBOX_STATE_CAPABILITY` among its experimental capabilities.
/// `startup_timeout` is applied to both the initialize and the tool-listing
/// request; `tool_timeout` is stored for later calls.
///
/// # Errors
/// Returns `StartupOutcomeError::Failed` when initialization or tool listing
/// fails.
async fn start_server_task(
    server_name: String,
    client: Arc<RmcpClient>,
    startup_timeout: Option<Duration>, // TODO: cancel_token should handle this.
    tool_timeout: Duration,
    tool_filter: ToolFilter,
    tx_event: Sender<Event>,
    elicitation_requests: ElicitationRequestManager,
) -> Result<ManagedClient, StartupOutcomeError> {
    let client_info = Implementation {
        name: "codex-mcp-client".to_owned(),
        version: env!("CARGO_PKG_VERSION").to_owned(),
        title: Some("Codex".into()),
        icons: None,
        website_url: None,
    };
    let capabilities = ClientCapabilities {
        experimental: None,
        roots: None,
        sampling: None,
        // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
        // indicates this should be an empty object.
        elicitation: Some(ElicitationCapability {
            schema_validation: None,
        }),
    };
    let params = InitializeRequestParam {
        capabilities,
        client_info,
        protocol_version: ProtocolVersion::V_2025_06_18,
    };

    let send_elicitation = elicitation_requests.make_sender(server_name.clone(), tx_event);
    let initialize_result = client
        .initialize(params, startup_timeout, send_elicitation)
        .await?;
    let tools = list_tools_for_client(&server_name, &client, startup_timeout).await?;

    let server_supports_sandbox_state_capability = initialize_result
        .capabilities
        .experimental
        .as_ref()
        .is_some_and(|experimental| experimental.get(MCP_SANDBOX_STATE_CAPABILITY).is_some());

    Ok(ManagedClient {
        client,
        tools,
        tool_filter,
        tool_timeout: Some(tool_timeout),
        server_supports_sandbox_state_capability,
    })
}
/// Creates the RMCP client for `server_name` according to its transport.
///
/// * `Stdio` — creates a stdio client from the configured command,
///   arguments, environment and working directory.
/// * `StreamableHttp` — resolves the optional bearer token from the
///   environment, then creates a streamable-HTTP client.
///
/// # Errors
/// Returns `StartupOutcomeError::Failed` when the bearer token cannot be
/// resolved or the client cannot be created.
async fn make_rmcp_client(
    server_name: &str,
    transport: McpServerTransportConfig,
    store_mode: OAuthCredentialsStoreMode,
) -> Result<RmcpClient, StartupOutcomeError> {
    match transport {
        McpServerTransportConfig::Stdio {
            command,
            args,
            env,
            env_vars,
            cwd,
        } => {
            let program = OsString::from(command);
            let arguments = args.into_iter().map(OsString::from).collect::<Vec<_>>();
            RmcpClient::new_stdio_client(program, arguments, env, &env_vars, cwd)
                .await
                .map_err(|err| StartupOutcomeError::from(anyhow!(err)))
        }
        McpServerTransportConfig::StreamableHttp {
            url,
            http_headers,
            env_http_headers,
            bearer_token_env_var,
        } => {
            let bearer_token =
                resolve_bearer_token(server_name, bearer_token_env_var.as_deref())?;
            RmcpClient::new_streamable_http_client(
                server_name,
                &url,
                bearer_token,
                http_headers,
                env_http_headers,
                store_mode,
            )
            .await
            .map_err(StartupOutcomeError::from)
        }
    }
}
/// Lists the tools of `server_name` and wraps each one in a [`ToolInfo`].
///
/// A tool's title, when present, is passed through
/// `normalize_codex_apps_tool_title` before being stored.
///
/// # Errors
/// Fails when the tool-listing request fails.
async fn list_tools_for_client(
    server_name: &str,
    client: &Arc<RmcpClient>,
    timeout: Option<Duration>,
) -> Result<Vec<ToolInfo>> {
    let listing = client.list_tools_with_connector_ids(None, timeout).await?;
    let mut infos = Vec::new();
    for entry in listing.tools {
        let connector_name = entry.connector_name;
        let mut definition = entry.tool;
        let normalized = definition.title.as_deref().map(|title| {
            normalize_codex_apps_tool_title(server_name, connector_name.as_deref(), title)
        });
        if let Some(normalized) = normalized {
            definition.title = Some(normalized);
        }
        infos.push(ToolInfo {
            server_name: server_name.to_owned(),
            tool_name: definition.name.to_string(),
            tool: definition,
            connector_id: entry.connector_id,
            connector_name,
        });
    }
    Ok(infos)
}
fn validate_mcp_server_name(server_name: &str) -> Result<()> {
let re = regex_lite::Regex::new(r"^[a-zA-Z0-9_-]+$")?;
if !re.is_match(server_name) {
return Err(anyhow!(
"Invalid MCP server name '{server_name}': must match pattern {pattern}",
pattern = re.as_str()
));
}
Ok(())
}
/// Turns a startup failure into the message shown to the user.
///
/// Checked in order:
/// 1. The GitHub MCP endpoint configured without a bearer token variable or
///    HTTP headers — explain how to supply a personal access token.
/// 2. An authentication-required failure — suggest `codex mcp login`.
/// 3. A startup timeout — report the effective timeout (the configured value
///    or `DEFAULT_STARTUP_TIMEOUT`) and how to change it.
/// 4. Anything else — a generic failure message including the error.
fn mcp_init_error_display(
    server_name: &str,
    entry: Option<&McpAuthStatusEntry>,
    err: &StartupOutcomeError,
) -> String {
    let needs_github_pat = match entry.map(|entry| &entry.config.transport) {
        Some(McpServerTransportConfig::StreamableHttp {
            url,
            bearer_token_env_var,
            http_headers,
            ..
        }) => {
            url == "https://api.githubcopilot.com/mcp/"
                && bearer_token_env_var.is_none()
                && http_headers.as_ref().map(HashMap::is_empty).unwrap_or(true)
        }
        _ => false,
    };
    if needs_github_pat {
        return format!(
            "GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN"
        );
    }
    if is_mcp_client_auth_required_error(err) {
        return format!(
            "The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`."
        );
    }
    if is_mcp_client_startup_timeout_error(err) {
        let startup_timeout_secs = entry
            .and_then(|entry| entry.config.startup_timeout_sec)
            .unwrap_or(DEFAULT_STARTUP_TIMEOUT)
            .as_secs();
        return format!(
            "MCP client for `{server_name}` timed out after {startup_timeout_secs} seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.{server_name}]\nstartup_timeout_sec = XX"
        );
    }
    format!("MCP client for `{server_name}` failed to start: {err:#}")
}
fn is_mcp_client_auth_required_error(error: &StartupOutcomeError) -> bool {
match error {
StartupOutcomeError::Failed { error } => error.contains("Auth required"),
_ => false,
}
}
fn is_mcp_client_startup_timeout_error(error: &StartupOutcomeError) -> bool {
match error {
StartupOutcomeError::Failed { error } => {
error.contains("request timed out")
|| error.contains("timed out handshaking with MCP server")
}
_ => false,
}
}
// Empty test module: it currently declares no tests. The tests for
// `mcp_init_error_display` live in the `tests` module.
#[cfg(test)]
mod mcp_init_error_display_tests {}
// Unit tests for tool-name qualification, tool filtering and startup error
// messages.
#[cfg(test)]
mod tests {
use super::*;
use codex_protocol::protocol::McpAuthStatus;
use rmcp::model::JsonObject;
use std::collections::HashSet;
use std::sync::Arc;
/// Builds a minimal `ToolInfo` for `tool_name` on `server_name`, with an empty
/// input schema and no connector information.
fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
ToolInfo {
server_name: server_name.to_string(),
tool_name: tool_name.to_string(),
tool: Tool {
name: tool_name.to_string().into(),
title: None,
description: Some(format!("Test tool: {tool_name}").into()),
input_schema: Arc::new(JsonObject::default()),
output_schema: None,
annotations: None,
icons: None,
meta: None,
},
connector_id: None,
connector_name: None,
}
}
// Short, distinct names are qualified as `mcp__<server>__<tool>` unchanged.
#[test]
fn test_qualify_tools_short_non_duplicated_names() {
let tools = vec![
create_test_tool("server1", "tool1"),
create_test_tool("server1", "tool2"),
];
let qualified_tools = qualify_tools(tools);
assert_eq!(qualified_tools.len(), 2);
assert!(qualified_tools.contains_key("mcp__server1__tool1"));
assert!(qualified_tools.contains_key("mcp__server1__tool2"));
}
#[test]
fn test_qualify_tools_duplicated_names_skipped() {
let tools = vec![
create_test_tool("server1", "duplicate_tool"),
create_test_tool("server1", "duplicate_tool"),
];
let qualified_tools = qualify_tools(tools);
// Only the first tool should remain, the second is skipped
assert_eq!(qualified_tools.len(), 1);
assert!(qualified_tools.contains_key("mcp__server1__duplicate_tool"));
}
// Over-long names are cut to 64 characters and end with the SHA-1 of the raw name.
#[test]
fn test_qualify_tools_long_names_same_server() {
let server_name = "my_server";
let tools = vec![
create_test_tool(
server_name,
"extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
),
create_test_tool(
server_name,
"yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
),
];
let qualified_tools = qualify_tools(tools);
assert_eq!(qualified_tools.len(), 2);
let mut keys: Vec<_> = qualified_tools.keys().cloned().collect();
keys.sort();
assert_eq!(keys[0].len(), 64);
assert_eq!(
keys[0],
"mcp__my_server__extremel119a2b97664e41363932dc84de21e2ff1b93b3e9"
);
assert_eq!(keys[1].len(), 64);
assert_eq!(
keys[1],
"mcp__my_server__yet_anot419a82a89325c1b477274a41f8c65ea5f3a7f341"
);
}
#[test]
fn test_qualify_tools_sanitizes_invalid_characters() {
let tools = vec![create_test_tool("server.one", "tool.two")];
let qualified_tools = qualify_tools(tools);
assert_eq!(qualified_tools.len(), 1);
let (qualified_name, tool) = qualified_tools.into_iter().next().expect("one tool");
assert_eq!(qualified_name, "mcp__server_one__tool_two");
// The key is sanitized for OpenAI, but we keep original parts for the actual MCP call.
assert_eq!(tool.server_name, "server.one");
assert_eq!(tool.tool_name, "tool.two");
assert!(
qualified_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'),
"qualified name must be Responses API compatible: {qualified_name:?}"
);
}
// With neither list set, every tool is allowed.
#[test]
fn tool_filter_allows_by_default() {
let filter = ToolFilter::default();
assert!(filter.allows("any"));
}
#[test]
fn tool_filter_applies_enabled_list() {
let filter = ToolFilter {
enabled: Some(HashSet::from(["allowed".to_string()])),
disabled: HashSet::new(),
};
assert!(filter.allows("allowed"));
assert!(!filter.allows("denied"));
}
#[test]
fn tool_filter_applies_disabled_list() {
let filter = ToolFilter {
enabled: None,
disabled: HashSet::from(["blocked".to_string()]),
};
assert!(!filter.allows("blocked"));
assert!(filter.allows("open"));
}
// The deny list wins over the allowlist.
#[test]
fn tool_filter_applies_enabled_then_disabled() {
let filter = ToolFilter {
enabled: Some(HashSet::from(["keep".to_string(), "remove".to_string()])),
disabled: HashSet::from(["remove".to_string()]),
};
assert!(filter.allows("keep"));
assert!(!filter.allows("remove"));
assert!(!filter.allows("unknown"));
}
// Each server's filter is applied to that server's tools only.
#[test]
fn filter_tools_applies_per_server_filters() {
let server1_tools = vec![
create_test_tool("server1", "tool_a"),
create_test_tool("server1", "tool_b"),
];
let server2_tools = vec![create_test_tool("server2", "tool_a")];
let server1_filter = ToolFilter {
enabled: Some(HashSet::from(["tool_a".to_string(), "tool_b".to_string()])),
disabled: HashSet::from(["tool_b".to_string()]),
};
let server2_filter = ToolFilter {
enabled: None,
disabled: HashSet::from(["tool_a".to_string()]),
};
let filtered: Vec<_> = filter_tools(server1_tools, server1_filter)
.into_iter()
.chain(filter_tools(server2_tools, server2_filter))
.collect();
assert_eq!(filtered.len(), 1);
assert_eq!(filtered[0].server_name, "server1");
assert_eq!(filtered[0].tool_name, "tool_a");
}
// The GitHub endpoint without token or headers yields the PAT instructions,
// whatever the underlying error was.
#[test]
fn mcp_init_error_display_prompts_for_github_pat() {
let server_name = "github";
let entry = McpAuthStatusEntry {
config: McpServerConfig {
transport: McpServerTransportConfig::StreamableHttp {
url: "https://api.githubcopilot.com/mcp/".to_string(),
bearer_token_env_var: None,
http_headers: None,
env_http_headers: None,
},
enabled: true,
disabled_reason: None,
startup_timeout_sec: None,
tool_timeout_sec: None,
enabled_tools: None,
disabled_tools: None,
scopes: None,
},
auth_status: McpAuthStatus::Unsupported,
};
let err: StartupOutcomeError = anyhow::anyhow!("OAuth is unsupported").into();
let display = mcp_init_error_display(server_name, Some(&entry), &err);
let expected = format!(
"GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN"
);
assert_eq!(expected, display);
}
#[test]
fn mcp_init_error_display_prompts_for_login_when_auth_required() {
let server_name = "example";
let err: StartupOutcomeError = anyhow::anyhow!("Auth required for server").into();
let display = mcp_init_error_display(server_name, None, &err);
let expected = format!(
"The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`."
);
assert_eq!(expected, display);
}
#[test]
fn mcp_init_error_display_reports_generic_errors() {
let server_name = "custom";
let entry = McpAuthStatusEntry {
config: McpServerConfig {
transport: McpServerTransportConfig::StreamableHttp {
url: "https://example.com".to_string(),
bearer_token_env_var: Some("TOKEN".to_string()),
http_headers: None,
env_http_headers: None,
},
enabled: true,
disabled_reason: None,
startup_timeout_sec: None,
tool_timeout_sec: None,
enabled_tools: None,
disabled_tools: None,
scopes: None,
},
auth_status: McpAuthStatus::Unsupported,
};
let err: StartupOutcomeError = anyhow::anyhow!("boom").into();
let display = mcp_init_error_display(server_name, Some(&entry), &err);
let expected = format!("MCP client for `{server_name}` failed to start: {err:#}");
assert_eq!(expected, display);
}
// Without a config entry the hint reports the default startup timeout.
#[test]
fn mcp_init_error_display_includes_startup_timeout_hint() {
let server_name = "slow";
let err: StartupOutcomeError = anyhow::anyhow!("request timed out").into();
let display = mcp_init_error_display(server_name, None, &err);
assert_eq!(
"MCP client for `slow` timed out after 10 seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.slow]\nstartup_timeout_sec = XX",
display
);
}
}