Compare commits

...

4 Commits

Author SHA1 Message Date
Joe Gershenson
4da08c2b31 fix(ci): restore stdio MCP pool checks 2026-03-20 15:31:04 -07:00
Joe Gershenson
05c88b6639 feat(core): pool stdio MCP backends in thread manager 2026-03-20 11:50:47 -07:00
Joe Gershenson
7591ac9184 fix(core): align rebased MCP manager API 2026-03-19 23:52:09 -07:00
Joe Gershenson
945b674884 refactor(core): split MCP backend from session state 2026-03-19 23:37:06 -07:00
14 changed files with 1771 additions and 162 deletions

6
MODULE.bazel.lock generated

File diff suppressed because one or more lines are too long

12
codex-rs/Cargo.lock generated
View File

@@ -800,9 +800,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "aws-lc-rs"
version = "1.15.4"
version = "1.16.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b7b6141e96a8c160799cc2d5adecd5cbbe5054cb8c7c4af53da0f83bb7ad256"
checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc"
dependencies = [
"aws-lc-sys",
"untrusted 0.7.1",
@@ -811,9 +811,9 @@ dependencies = [
[[package]]
name = "aws-lc-sys"
version = "0.37.0"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c34dda4df7017c8db52132f0f8a2e0f8161649d15723ed63fc00c82d0f2081a"
checksum = "1fa7e52a4c5c547c741610a2c6f123f3881e409b714cd27e6798ef020c514f0a"
dependencies = [
"cc",
"cmake",
@@ -8239,9 +8239,9 @@ dependencies = [
[[package]]
name = "rustls-webpki"
version = "0.103.9"
version = "0.103.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53"
checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef"
dependencies = [
"aws-lc-rs",
"ring",

View File

@@ -3,5 +3,8 @@ load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "app-server",
crate_name = "codex_app_server",
extra_binaries = [
"//codex-rs/rmcp-client:test_stdio_server",
],
test_tags = ["no-sandbox"],
)

View File

@@ -33,6 +33,7 @@ mod thread_archive;
mod thread_fork;
mod thread_list;
mod thread_loaded_list;
mod thread_mcp_pool;
mod thread_metadata_update;
mod thread_name_websocket;
mod thread_read;

View File

@@ -0,0 +1,450 @@
use std::io::ErrorKind;
use std::path::Path;
use std::time::Duration;
use anyhow::Context;
use anyhow::Result;
use anyhow::bail;
use app_test_support::McpProcess;
use app_test_support::create_mock_responses_server_sequence_unchecked;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ServerNotification;
use codex_app_server_protocol::ThreadArchiveParams;
use codex_app_server_protocol::ThreadArchiveResponse;
use codex_app_server_protocol::ThreadArchivedNotification;
use codex_app_server_protocol::ThreadClosedNotification;
use codex_app_server_protocol::ThreadResumeParams;
use codex_app_server_protocol::ThreadResumeResponse;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::ThreadStatus;
use codex_app_server_protocol::ThreadUnarchiveParams;
use codex_app_server_protocol::ThreadUnarchiveResponse;
use codex_app_server_protocol::ThreadUnarchivedNotification;
use codex_app_server_protocol::ThreadUnsubscribeParams;
use codex_app_server_protocol::ThreadUnsubscribeResponse;
use codex_app_server_protocol::ThreadUnsubscribeStatus;
use codex_app_server_protocol::TurnCompletedNotification;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
use codex_app_server_protocol::TurnStatus;
use codex_app_server_protocol::UserInput;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use core_test_support::stdio_server_bin;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::TempDir;
use tokio::time::Instant;
use tokio::time::sleep;
use tokio::time::timeout;
/// Upper bound for any single read of an app-server RPC response or notification.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(10);
/// How often the startup-count log file is re-read while polling.
const STARTUP_COUNT_POLL_INTERVAL: Duration = Duration::from_millis(25);
/// How long to wait for the startup count to reach an expected value before failing.
const STARTUP_COUNT_WAIT_TIMEOUT: Duration = Duration::from_secs(5);
/// How long the startup count must remain unchanged to count as "stable"
/// (i.e. no spurious extra backend launch happened).
const STARTUP_COUNT_STABILITY_WINDOW: Duration = Duration::from_millis(250);
/// Env var the stdio test server reads to find the file it appends a line to
/// on every startup; the line count is our observed launch count.
const STARTUP_COUNT_FILE_ENV_VAR: &str = "MCP_STARTUP_COUNT_FILE";
/// Pooled stdio MCP backends must outlive any single subscriber: after one of
/// two loaded threads unsubscribes, the remaining thread can still run a tool
/// turn and no backend restart is observed (startup count stays at 1).
#[tokio::test]
async fn mcp_pool_survives_unsubscribe_of_one_loaded_thread() -> Result<()> {
    skip_if_no_network!(Ok(()));
    let responses_server =
        create_mock_responses_server_sequence_unchecked(rmcp_echo_turn_bodies("after-unsubscribe"))
            .await;
    let codex_home = TempDir::new()?;
    // The stdio server appends one line here per launch; the line count is the
    // number of backend starts observed so far.
    let startup_count_file = codex_home.path().join("rmcp-startups.log");
    create_config_toml(
        codex_home.path(),
        &responses_server.uri(),
        &stdio_server_bin()?,
    )?;
    let startup_count_file_value = startup_count_file.to_string_lossy().to_string();
    let mut mcp = McpProcess::new_with_env(
        codex_home.path(),
        &[(
            STARTUP_COUNT_FILE_ENV_VAR,
            Some(startup_count_file_value.as_str()),
        )],
    )
    .await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
    let thread_a = start_thread(&mut mcp).await?;
    wait_for_startup_count(&startup_count_file, 1).await?;
    // Starting a second thread must reuse the pooled backend, not launch another.
    let thread_b = start_thread(&mut mcp).await?;
    assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
    unsubscribe_thread(&mut mcp, &thread_a).await?;
    // The surviving thread still reaches the pooled backend after the unsubscribe.
    run_rmcp_echo_turn(&mut mcp, &thread_b, "after-unsubscribe").await?;
    assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
    Ok(())
}
/// Archiving one of two loaded threads must not tear down the shared stdio MCP
/// backend: the other thread keeps using it and no restart is observed.
#[tokio::test]
async fn mcp_pool_survives_archive_of_one_loaded_thread() -> Result<()> {
    skip_if_no_network!(Ok(()));
    // Two turns run in this test, so queue both SSE body sequences up front,
    // in the order the mock server will serve them.
    let responses_server = create_mock_responses_server_sequence_unchecked(
        [
            rmcp_echo_turn_bodies("materialize-archived-thread"),
            rmcp_echo_turn_bodies("after-archive"),
        ]
        .into_iter()
        .flatten()
        .collect(),
    )
    .await;
    let codex_home = TempDir::new()?;
    // One line is appended here per stdio backend launch.
    let startup_count_file = codex_home.path().join("rmcp-startups.log");
    create_config_toml(
        codex_home.path(),
        &responses_server.uri(),
        &stdio_server_bin()?,
    )?;
    let startup_count_file_value = startup_count_file.to_string_lossy().to_string();
    let mut mcp = McpProcess::new_with_env(
        codex_home.path(),
        &[(
            STARTUP_COUNT_FILE_ENV_VAR,
            Some(startup_count_file_value.as_str()),
        )],
    )
    .await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
    let thread_a = start_thread(&mut mcp).await?;
    wait_for_startup_count(&startup_count_file, 1).await?;
    let thread_b = start_thread(&mut mcp).await?;
    assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
    // Run a turn on thread A first so it has materialized state before archiving.
    run_rmcp_echo_turn(&mut mcp, &thread_a, "materialize-archived-thread").await?;
    archive_thread(&mut mcp, &thread_a).await?;
    // Thread B keeps working against the same pooled backend.
    run_rmcp_echo_turn(&mut mcp, &thread_b, "after-archive").await?;
    assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
    Ok(())
}
/// Once the last loaded thread using a pooled backend is archived, the backend
/// may be dropped; resuming a thread afterwards must transparently launch a
/// fresh backend (startup count goes from 1 to 2) and tool calls keep working.
#[tokio::test]
async fn mcp_pool_recreates_backend_after_last_archive_and_resume() -> Result<()> {
    skip_if_no_network!(Ok(()));
    // Three turns run in this test; queue all SSE bodies in serving order.
    let responses_server = create_mock_responses_server_sequence_unchecked(
        [
            rmcp_echo_turn_bodies("materialize-a"),
            rmcp_echo_turn_bodies("materialize-b"),
            rmcp_echo_turn_bodies("after-resume"),
        ]
        .into_iter()
        .flatten()
        .collect(),
    )
    .await;
    let codex_home = TempDir::new()?;
    // One line is appended here per stdio backend launch.
    let startup_count_file = codex_home.path().join("rmcp-startups.log");
    create_config_toml(
        codex_home.path(),
        &responses_server.uri(),
        &stdio_server_bin()?,
    )?;
    let startup_count_file_value = startup_count_file.to_string_lossy().to_string();
    let mut mcp = McpProcess::new_with_env(
        codex_home.path(),
        &[(
            STARTUP_COUNT_FILE_ENV_VAR,
            Some(startup_count_file_value.as_str()),
        )],
    )
    .await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
    let thread_a = start_thread(&mut mcp).await?;
    wait_for_startup_count(&startup_count_file, 1).await?;
    let thread_b = start_thread(&mut mcp).await?;
    assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
    // Give both threads materialized state, then archive both so no loaded
    // thread references the pooled backend any more.
    run_rmcp_echo_turn(&mut mcp, &thread_a, "materialize-a").await?;
    run_rmcp_echo_turn(&mut mcp, &thread_b, "materialize-b").await?;
    archive_thread(&mut mcp, &thread_a).await?;
    archive_thread(&mut mcp, &thread_b).await?;
    // Unarchiving alone does not load the thread, so no new backend start yet.
    let unarchive = unarchive_thread(&mut mcp, &thread_a).await?;
    assert_eq!(unarchive.thread.status, ThreadStatus::NotLoaded);
    assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
    // Resuming loads the thread and must spin up a second backend instance.
    let resume = resume_thread(&mut mcp, &thread_a).await?;
    assert_eq!(resume.thread.status, ThreadStatus::Idle);
    wait_for_startup_count(&startup_count_file, 2).await?;
    run_rmcp_echo_turn(&mut mcp, &thread_a, "after-resume").await?;
    assert_startup_count_stays(&startup_count_file, 2, STARTUP_COUNT_STABILITY_WINDOW).await?;
    Ok(())
}
fn create_config_toml(codex_home: &Path, server_uri: &str, rmcp_server_bin: &str) -> Result<()> {
let config_toml = codex_home.join("config.toml");
let server_uri = serde_json::to_string(&format!("{server_uri}/v1"))?;
let rmcp_server_bin = serde_json::to_string(rmcp_server_bin)?;
std::fs::write(
config_toml,
format!(
r#"model = "mock-model"
approval_policy = "never"
sandbox_mode = "read-only"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = {server_uri}
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
[mcp_servers.rmcp]
command = {rmcp_server_bin}
env_vars = ["{STARTUP_COUNT_FILE_ENV_VAR}"]
startup_timeout_sec = 10.0
"#
),
)
.context("write config.toml")
}
/// Build the two SSE response bodies for one mocked turn: the model first
/// emits a call to the `mcp__rmcp__echo` tool, then — after the tool output
/// round-trips — emits a final assistant message completing the turn.
fn rmcp_echo_turn_bodies(label: &str) -> Vec<String> {
    let arguments = json!({ "message": format!("ping-{label}") }).to_string();
    // First body: the model requests the pooled MCP echo tool call.
    let tool_call_body = responses::sse(vec![
        responses::ev_response_created(&format!("resp-{label}")),
        responses::ev_function_call(
            &format!("call-rmcp-echo-{label}"),
            "mcp__rmcp__echo",
            &arguments,
        ),
        responses::ev_completed(&format!("resp-{label}")),
    ]);
    // Second body: once the tool result is fed back, the model finishes.
    let completion_body = responses::sse(vec![
        responses::ev_response_created(&format!("resp-{label}-done")),
        responses::ev_assistant_message(
            &format!("msg-{label}"),
            &format!("rmcp echo completed for {label}"),
        ),
        responses::ev_completed(&format!("resp-{label}-done")),
    ]);
    vec![tool_call_body, completion_body]
}
async fn start_thread(mcp: &mut McpProcess) -> Result<String> {
let request_id = mcp
.send_thread_start_request(ThreadStartParams {
model: Some("mock-model".to_string()),
..Default::default()
})
.await?;
let response: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
)
.await??;
let ThreadStartResponse { thread, .. } = to_response(response)?;
Ok(thread.id)
}
/// Kick off a turn on `thread_id` that triggers the mocked `mcp__rmcp__echo`
/// tool call, then block until the matching `turn/completed` notification
/// arrives and assert the turn finished with `TurnStatus::Completed`.
///
/// Returns an error if the turn-start RPC fails or the overall deadline
/// (`DEFAULT_READ_TIMEOUT`) elapses before completion is observed.
async fn run_rmcp_echo_turn(mcp: &mut McpProcess, thread_id: &str, label: &str) -> Result<()> {
    let request_id = mcp
        .send_turn_start_request(TurnStartParams {
            thread_id: thread_id.to_string(),
            input: vec![UserInput::Text {
                text: format!("call the rmcp echo tool for {label}"),
                text_elements: Vec::new(),
            }],
            ..Default::default()
        })
        .await?;
    let response: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
    )
    .await??;
    let TurnStartResponse { turn } = to_response(response)?;
    // Drain the stream until we see `turn/completed` for this exact thread and
    // turn; unrelated messages (responses, other threads' events) are skipped.
    let deadline = Instant::now() + DEFAULT_READ_TIMEOUT;
    loop {
        // Shrink each read's timeout so the whole wait stays within `deadline`.
        let remaining = deadline.saturating_duration_since(Instant::now());
        let message = timeout(remaining, mcp.read_next_message()).await??;
        let codex_app_server_protocol::JSONRPCMessage::Notification(notification) = message else {
            continue;
        };
        if notification.method != "turn/completed" {
            continue;
        }
        // `notification` is owned here, so take `params` by value; the
        // previous `.clone()` before `.context(...)` was redundant.
        let completed: TurnCompletedNotification =
            serde_json::from_value(notification.params.context("turn/completed params")?)?;
        if completed.thread_id != thread_id || completed.turn.id != turn.id {
            continue;
        }
        assert_eq!(completed.turn.status, TurnStatus::Completed);
        return Ok(());
    }
}
async fn unsubscribe_thread(mcp: &mut McpProcess, thread_id: &str) -> Result<()> {
let request_id = mcp
.send_thread_unsubscribe_request(ThreadUnsubscribeParams {
thread_id: thread_id.to_string(),
})
.await?;
let response: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
)
.await??;
let unsubscribe: ThreadUnsubscribeResponse = to_response(response)?;
assert_eq!(unsubscribe.status, ThreadUnsubscribeStatus::Unsubscribed);
let notification: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("thread/closed"),
)
.await??;
let ServerNotification::ThreadClosed(ThreadClosedNotification {
thread_id: closed_thread_id,
}) = notification.try_into()?
else {
bail!("expected thread/closed notification");
};
assert_eq!(closed_thread_id, thread_id);
Ok(())
}
/// Archive `thread_id` and wait for the confirming `thread/archived`
/// notification naming that same thread.
async fn archive_thread(mcp: &mut McpProcess, thread_id: &str) -> Result<()> {
    let params = ThreadArchiveParams {
        thread_id: thread_id.to_string(),
    };
    let request_id = mcp.send_thread_archive_request(params).await?;
    let response: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
    )
    .await??;
    // The response payload itself is uninteresting; only decode it to confirm
    // the RPC succeeded.
    to_response::<ThreadArchiveResponse>(response)?;
    let notification: JSONRPCNotification = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_notification_message("thread/archived"),
    )
    .await??;
    let raw_params = notification.params.context("thread/archived params")?;
    let archived: ThreadArchivedNotification = serde_json::from_value(raw_params)?;
    assert_eq!(archived.thread_id, thread_id);
    Ok(())
}
/// Unarchive `thread_id`, wait for the confirming `thread/unarchived`
/// notification, and return the RPC response so callers can inspect the
/// resulting thread status.
async fn unarchive_thread(
    mcp: &mut McpProcess,
    thread_id: &str,
) -> Result<ThreadUnarchiveResponse> {
    let params = ThreadUnarchiveParams {
        thread_id: thread_id.to_string(),
    };
    let request_id = mcp.send_thread_unarchive_request(params).await?;
    let response: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
    )
    .await??;
    let unarchive: ThreadUnarchiveResponse = to_response(response)?;
    let notification: JSONRPCNotification = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_notification_message("thread/unarchived"),
    )
    .await??;
    // Confirm the notification is about the thread we just unarchived.
    let raw_params = notification.params.context("thread/unarchived params")?;
    let unarchived: ThreadUnarchivedNotification = serde_json::from_value(raw_params)?;
    assert_eq!(unarchived.thread_id, thread_id);
    Ok(unarchive)
}
/// Resume a previously unarchived thread and return the full response so
/// callers can check the thread status it reports.
async fn resume_thread(mcp: &mut McpProcess, thread_id: &str) -> Result<ThreadResumeResponse> {
    let params = ThreadResumeParams {
        thread_id: thread_id.to_string(),
        ..Default::default()
    };
    let request_id = mcp.send_thread_resume_request(params).await?;
    let response: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
    )
    .await??;
    to_response(response)
}
/// Poll the startup-count file until it reports exactly `expected` launches.
///
/// Bails out immediately if the count overshoots — a spurious extra backend
/// start is a hard failure, not something to wait through — and gives up
/// after `STARTUP_COUNT_WAIT_TIMEOUT`.
async fn wait_for_startup_count(path: &Path, expected: usize) -> Result<()> {
    let give_up_at = Instant::now() + STARTUP_COUNT_WAIT_TIMEOUT;
    loop {
        let actual = read_startup_count(path)?;
        match actual.cmp(&expected) {
            std::cmp::Ordering::Equal => return Ok(()),
            std::cmp::Ordering::Greater => {
                bail!("startup count exceeded expectation: expected {expected}, got {actual}")
            }
            std::cmp::Ordering::Less => {}
        }
        if Instant::now() >= give_up_at {
            bail!("timed out waiting for startup count {expected}; last observed {actual}");
        }
        sleep(STARTUP_COUNT_POLL_INTERVAL).await;
    }
}
/// Repeatedly assert that the startup count equals `expected` for the whole
/// `duration`, polling at `STARTUP_COUNT_POLL_INTERVAL`. Panics (via
/// `assert_eq!`) the moment the count deviates; checks at least once even for
/// a zero duration, and once more after the window has elapsed.
async fn assert_startup_count_stays(
    path: &Path,
    expected: usize,
    duration: Duration,
) -> Result<()> {
    let window_end = Instant::now() + duration;
    loop {
        assert_eq!(read_startup_count(path)?, expected);
        if Instant::now() >= window_end {
            return Ok(());
        }
        sleep(STARTUP_COUNT_POLL_INTERVAL).await;
    }
}
fn read_startup_count(path: &Path) -> Result<usize> {
let contents = match std::fs::read_to_string(path) {
Ok(contents) => contents,
Err(error) if error.kind() == ErrorKind::NotFound => return Ok(0),
Err(error) => return Err(error).context("read startup count file"),
};
Ok(contents.lines().count())
}

View File

@@ -8,6 +8,7 @@ use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::find_archived_thread_path_by_id_str;
use crate::find_thread_path_by_id_str;
use crate::mcp_connection_manager::SharedMcpBackendPool;
use crate::rollout::RolloutRecorder;
use crate::session_prefix::format_subagent_context_line;
use crate::session_prefix::format_subagent_notification_message;
@@ -90,6 +91,12 @@ impl AgentControl {
}
}
pub(crate) fn shared_mcp_backend_pool(&self) -> Option<Arc<SharedMcpBackendPool>> {
self.manager
.upgrade()
.map(|state| state.shared_mcp_backend_pool())
}
/// Spawn a new agent thread and submit the initial prompt.
pub(crate) async fn spawn_agent(
&self,

View File

@@ -220,6 +220,7 @@ use crate::mcp::auth::compute_auth_statuses;
use crate::mcp::maybe_prompt_and_install_mcp_dependencies;
use crate::mcp::with_codex_apps_mcp;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::mcp_connection_manager::SharedMcpBackendAcquireMode;
use crate::mcp_connection_manager::codex_apps_tools_cache_key;
use crate::mcp_connection_manager::filter_non_codex_apps_mcp_tools_only;
use crate::memories;
@@ -1249,6 +1250,23 @@ impl Session {
per_turn_config
}
fn sandbox_state_for_session_configuration(
session_configuration: &SessionConfiguration,
) -> SandboxState {
SandboxState {
sandbox_policy: session_configuration.sandbox_policy.get().clone(),
codex_linux_sandbox_exe: session_configuration
.original_config_do_not_use
.codex_linux_sandbox_exe
.clone(),
sandbox_cwd: session_configuration.cwd.clone(),
use_legacy_landlock: session_configuration
.original_config_do_not_use
.features
.use_legacy_landlock(),
}
}
pub(crate) async fn codex_home(&self) -> PathBuf {
let state = self.state.lock().await;
state.session_configuration.codex_home().clone()
@@ -1800,6 +1818,7 @@ impl Session {
&config.permissions.approval_policy,
))),
mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()),
shared_mcp_backend_pool: agent_control.shared_mcp_backend_pool(),
unified_exec_manager: UnifiedExecProcessManager::new(
config.background_terminal_max_timeout,
),
@@ -1924,17 +1943,40 @@ impl Session {
cancel_guard.cancel();
*cancel_guard = CancellationToken::new();
}
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
auth_statuses.clone(),
&session_configuration.approval_policy,
tx_event.clone(),
sandbox_state,
config.codex_home.clone(),
codex_apps_tools_cache_key(auth),
tool_plugin_provenance,
)
let (mcp_connection_manager, cancel_token) = async {
match sess.services.shared_mcp_backend_pool.as_ref() {
Some(shared_mcp_backend_pool) => {
McpConnectionManager::new_with_pool(
shared_mcp_backend_pool.as_ref(),
SharedMcpBackendAcquireMode::ReuseExisting,
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
auth_statuses.clone(),
&session_configuration.approval_policy,
tx_event.clone(),
sandbox_state,
config.codex_home.clone(),
codex_apps_tools_cache_key(auth),
tool_plugin_provenance,
)
.await
}
None => {
McpConnectionManager::new(
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
auth_statuses.clone(),
&session_configuration.approval_policy,
tx_event.clone(),
sandbox_state,
config.codex_home.clone(),
codex_apps_tools_cache_key(auth),
tool_plugin_provenance,
)
.await
}
}
}
.instrument(info_span!(
"session_init.mcp_manager_init",
otel.name = "session_init.mcp_manager_init",
@@ -2308,23 +2350,29 @@ impl Session {
) -> ConstraintResult<Arc<TurnContext>> {
let (
session_configuration,
sandbox_policy_changed,
sandbox_state,
sandbox_state_changed,
previous_cwd,
codex_home,
session_source,
) = {
let mut state = self.state.lock().await;
match state.session_configuration.clone().apply(&updates) {
let previous_session_configuration = state.session_configuration.clone();
match previous_session_configuration.apply(&updates) {
Ok(next) => {
let previous_cwd = state.session_configuration.cwd.clone();
let sandbox_policy_changed =
state.session_configuration.sandbox_policy != next.sandbox_policy;
let previous_cwd = previous_session_configuration.cwd.clone();
let previous_sandbox_state = Self::sandbox_state_for_session_configuration(
&previous_session_configuration,
);
let sandbox_state = Self::sandbox_state_for_session_configuration(&next);
let sandbox_state_changed = previous_sandbox_state != sandbox_state;
let codex_home = next.codex_home.clone();
let session_source = next.session_source.clone();
state.session_configuration = next.clone();
(
next,
sandbox_policy_changed,
sandbox_state,
sandbox_state_changed,
previous_cwd,
codex_home,
session_source,
@@ -2357,7 +2405,8 @@ impl Session {
sub_id,
session_configuration,
updates.final_output_json_schema,
sandbox_policy_changed,
sandbox_state,
sandbox_state_changed,
)
.await)
}
@@ -2367,23 +2416,55 @@ impl Session {
sub_id: String,
session_configuration: SessionConfiguration,
final_output_json_schema: Option<Option<Value>>,
sandbox_policy_changed: bool,
sandbox_state: SandboxState,
sandbox_state_changed: bool,
) -> Arc<TurnContext> {
let per_turn_config = Self::build_per_turn_config(&session_configuration);
self.services
.mcp_connection_manager
.read()
.await
.set_approval_policy(&session_configuration.approval_policy);
if sandbox_policy_changed {
let sandbox_state = SandboxState {
sandbox_policy: per_turn_config.permissions.sandbox_policy.get().clone(),
codex_linux_sandbox_exe: per_turn_config.codex_linux_sandbox_exe.clone(),
sandbox_cwd: per_turn_config.cwd.clone(),
use_legacy_landlock: per_turn_config.features.use_legacy_landlock(),
};
if let Err(e) = self
if sandbox_state_changed {
if let Some(shared_mcp_backend_pool) = self.services.shared_mcp_backend_pool.as_ref() {
let auth = self.services.auth_manager.auth().await;
let config = session_configuration.original_config_do_not_use.clone();
let mcp_servers = self
.services
.mcp_manager
.effective_servers(config.as_ref(), auth.as_ref());
let auth_statuses = compute_auth_statuses(
mcp_servers.iter(),
config.mcp_oauth_credentials_store_mode,
)
.await;
let tool_plugin_provenance = self
.services
.mcp_manager
.tool_plugin_provenance(config.as_ref());
{
let manager = self.services.mcp_connection_manager.read().await;
if let Err(e) = manager
.notify_local_sandbox_state_change(&sandbox_state)
.await
{
warn!("failed to notify local MCP sandbox state change: {e:#}");
}
let rebuilt_manager = manager
.rebuild_pooled_backend(
shared_mcp_backend_pool.as_ref(),
SharedMcpBackendAcquireMode::ReuseExisting,
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
auth_statuses,
self.get_tx_event(),
sandbox_state.clone(),
config.codex_home.clone(),
codex_apps_tools_cache_key(auth.as_ref()),
tool_plugin_provenance,
)
.await;
drop(manager);
let mut manager = self.services.mcp_connection_manager.write().await;
*manager = rebuilt_manager;
}
} else if let Err(e) = self
.services
.mcp_connection_manager
.read()
@@ -2394,6 +2475,11 @@ impl Session {
warn!("Failed to notify sandbox state change to MCP servers: {e:#}");
}
}
self.services
.mcp_connection_manager
.read()
.await
.set_approval_policy(&session_configuration.approval_policy);
let model_info = self
.services
@@ -2536,9 +2622,10 @@ impl Session {
};
self.new_turn_from_configuration(
sub_id,
session_configuration,
session_configuration.clone(),
/*final_output_json_schema*/ None,
/*sandbox_policy_changed*/ false,
Self::sandbox_state_for_session_configuration(&session_configuration),
/*sandbox_state_changed*/ false,
)
.await
}
@@ -4059,18 +4146,39 @@ impl Session {
guard.cancel();
*guard = CancellationToken::new();
}
let (refreshed_manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
store_mode,
auth_statuses,
&turn_context.config.permissions.approval_policy,
self.get_tx_event(),
sandbox_state,
config.codex_home.clone(),
codex_apps_tools_cache_key(auth.as_ref()),
tool_plugin_provenance,
)
.await;
let (refreshed_manager, cancel_token) = match self.services.shared_mcp_backend_pool.as_ref()
{
Some(shared_mcp_backend_pool) => {
McpConnectionManager::new_with_pool(
shared_mcp_backend_pool.as_ref(),
SharedMcpBackendAcquireMode::ForceCreate,
&mcp_servers,
store_mode,
auth_statuses,
&turn_context.config.permissions.approval_policy,
self.get_tx_event(),
sandbox_state,
config.codex_home.clone(),
codex_apps_tools_cache_key(auth.as_ref()),
tool_plugin_provenance,
)
.await
}
None => {
McpConnectionManager::new(
&mcp_servers,
store_mode,
auth_statuses,
&turn_context.config.permissions.approval_policy,
self.get_tx_event(),
sandbox_state,
config.codex_home.clone(),
codex_apps_tools_cache_key(auth.as_ref()),
tool_plugin_provenance,
)
.await
}
};
{
let mut guard = self.services.mcp_startup_cancellation_token.lock().await;
if guard.is_cancelled() {

View File

@@ -10,6 +10,9 @@ use crate::config_loader::Sourced;
use crate::exec::ExecCapturePolicy;
use crate::exec::ExecToolCallOutput;
use crate::function_tool::FunctionCallError;
use crate::mcp::ToolPluginProvenance;
use crate::mcp_connection_manager::SharedMcpBackendAcquireMode;
use crate::mcp_connection_manager::SharedMcpBackendPool;
use crate::mcp_connection_manager::ToolInfo;
use crate::models_manager::model_info;
use crate::shell::default_user_shell;
@@ -2466,6 +2469,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
),
)),
mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()),
shared_mcp_backend_pool: None,
unified_exec_manager: UnifiedExecProcessManager::new(
config.background_terminal_max_timeout,
),
@@ -3172,6 +3176,17 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
Arc<Session>,
Arc<TurnContext>,
async_channel::Receiver<Event>,
) {
make_session_and_context_with_dynamic_tools_and_rx_and_pool(dynamic_tools, false).await
}
async fn make_session_and_context_with_dynamic_tools_and_rx_and_pool(
dynamic_tools: Vec<DynamicToolSpec>,
pooled: bool,
) -> (
Arc<Session>,
Arc<TurnContext>,
async_channel::Receiver<Event>,
) {
let (tx_event, rx_event) = async_channel::unbounded();
let codex_home = tempfile::tempdir().expect("create temp dir");
@@ -3256,15 +3271,44 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
.await
.expect("create environment"),
);
let shared_mcp_backend_pool = pooled.then(|| Arc::new(SharedMcpBackendPool::new()));
let sandbox_state = SandboxState {
sandbox_policy: session_configuration.sandbox_policy.get().clone(),
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
sandbox_cwd: session_configuration.cwd.clone(),
use_legacy_landlock: config.features.use_legacy_landlock(),
};
let (mcp_connection_manager, mcp_startup_cancellation_token) =
match shared_mcp_backend_pool.as_ref() {
Some(shared_mcp_backend_pool) => {
McpConnectionManager::new_with_pool(
shared_mcp_backend_pool.as_ref(),
SharedMcpBackendAcquireMode::ReuseExisting,
&HashMap::new(),
config.mcp_oauth_credentials_store_mode,
HashMap::new(),
&config.permissions.approval_policy,
tx_event.clone(),
sandbox_state,
config.codex_home.clone(),
codex_apps_tools_cache_key(Some(&auth_manager.auth().await.expect("auth"))),
ToolPluginProvenance::default(),
)
.await
}
None => (
McpConnectionManager::new_mcp_connection_manager_for_tests(
&config.permissions.approval_policy,
),
CancellationToken::new(),
),
};
let file_watcher = Arc::new(FileWatcher::noop());
let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(
McpConnectionManager::new_mcp_connection_manager_for_tests(
&config.permissions.approval_policy,
),
)),
mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()),
mcp_connection_manager: Arc::new(RwLock::new(mcp_connection_manager)),
mcp_startup_cancellation_token: Mutex::new(mcp_startup_cancellation_token),
shared_mcp_backend_pool,
unified_exec_manager: UnifiedExecProcessManager::new(
config.background_terminal_max_timeout,
),
@@ -3365,6 +3409,14 @@ pub(crate) async fn make_session_and_context_with_rx() -> (
make_session_and_context_with_dynamic_tools_and_rx(Vec::new()).await
}
async fn make_pooled_session_and_context_with_rx() -> (
Arc<Session>,
Arc<TurnContext>,
async_channel::Receiver<Event>,
) {
make_session_and_context_with_dynamic_tools_and_rx_and_pool(Vec::new(), true).await
}
#[tokio::test]
async fn refresh_mcp_servers_is_deferred_until_next_turn() {
let (session, turn_context) = make_session_and_context().await;
@@ -3407,6 +3459,29 @@ async fn refresh_mcp_servers_is_deferred_until_next_turn() {
assert!(!new_token.is_cancelled());
}
#[tokio::test]
async fn pooled_new_turn_keeps_local_mcp_startup_token_when_sandbox_state_changes() {
let (session, _turn_context, _rx) = make_pooled_session_and_context_with_rx().await;
let old_token = session.mcp_startup_cancellation_token().await;
assert!(!old_token.is_cancelled());
let new_cwd = session.get_config().await.cwd.join("other-worktree");
session
.new_turn_with_sub_id(
"sandbox-change".to_string(),
SessionSettingsUpdate {
cwd: Some(new_cwd),
..SessionSettingsUpdate::default()
},
)
.await
.expect("new turn");
assert!(!old_token.is_cancelled());
let new_token = session.mcp_startup_cancellation_token().await;
assert!(!new_token.is_cancelled());
}
#[tokio::test]
async fn record_model_warning_appends_user_message() {
let (mut session, turn_context) = make_session_and_context().await;

View File

@@ -248,6 +248,14 @@ enum CachedCodexAppsToolsLoad {
type ResponderMap = HashMap<(String, RequestId), oneshot::Sender<ElicitationResponse>>;
fn decline_elicitation_response() -> ElicitationResponse {
ElicitationResponse {
action: ElicitationAction::Decline,
content: None,
meta: None,
}
}
fn elicitation_is_rejected_by_policy(approval_policy: AskForApproval) -> bool {
match approval_policy {
AskForApproval::Never => true,
@@ -300,11 +308,7 @@ impl ElicitationRequestManager {
.lock()
.is_ok_and(|policy| elicitation_is_rejected_by_policy(*policy))
{
return Ok(ElicitationResponse {
action: ElicitationAction::Decline,
content: None,
meta: None,
});
return Ok(decline_elicitation_response());
}
let request = match elicitation {
@@ -584,7 +588,7 @@ pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
/// When used, the `params` field of the notification is [`SandboxState`].
pub const MCP_SANDBOX_STATE_METHOD: &str = "codex/sandbox-state/update";
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SandboxState {
pub sandbox_policy: SandboxPolicy,
@@ -595,50 +599,27 @@ pub struct SandboxState {
}
/// A thin wrapper around a set of running [`RmcpClient`] instances.
pub(crate) struct McpConnectionManager {
struct SharedMcpBackend {
clients: HashMap<String, AsyncManagedClient>,
server_origins: HashMap<String, String>,
elicitation_requests: ElicitationRequestManager,
}
impl McpConnectionManager {
pub(crate) fn new_uninitialized(approval_policy: &Constrained<AskForApproval>) -> Self {
impl SharedMcpBackend {
fn new_uninitialized() -> Self {
Self {
clients: HashMap::new(),
server_origins: HashMap::new(),
elicitation_requests: ElicitationRequestManager::new(approval_policy.value()),
}
}
#[cfg(test)]
pub(crate) fn new_mcp_connection_manager_for_tests(
approval_policy: &Constrained<AskForApproval>,
) -> Self {
Self::new_uninitialized(approval_policy)
}
pub(crate) fn has_servers(&self) -> bool {
!self.clients.is_empty()
}
pub(crate) fn server_origin(&self, server_name: &str) -> Option<&str> {
self.server_origins.get(server_name).map(String::as_str)
}
pub fn set_approval_policy(&self, approval_policy: &Constrained<AskForApproval>) {
if let Ok(mut policy) = self.elicitation_requests.approval_policy.lock() {
*policy = approval_policy.value();
}
}
#[allow(clippy::new_ret_no_self, clippy::too_many_arguments)]
pub async fn new(
async fn new(
mcp_servers: &HashMap<String, McpServerConfig>,
store_mode: OAuthCredentialsStoreMode,
auth_entries: HashMap<String, McpAuthStatusEntry>,
approval_policy: &Constrained<AskForApproval>,
tx_event: Sender<Event>,
initial_sandbox_state: SandboxState,
session_handle: SessionMcpHandle,
codex_home: PathBuf,
codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
tool_plugin_provenance: ToolPluginProvenance,
@@ -647,7 +628,7 @@ impl McpConnectionManager {
let mut clients = HashMap::new();
let mut server_origins = HashMap::new();
let mut join_set = JoinSet::new();
let elicitation_requests = ElicitationRequestManager::new(approval_policy.value());
let elicitation_requests = session_handle.elicitation_requests();
let tool_plugin_provenance = Arc::new(tool_plugin_provenance);
let mcp_servers = mcp_servers.clone();
for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) {
@@ -725,10 +706,9 @@ impl McpConnectionManager {
(server_name, outcome)
});
}
let manager = Self {
let backend = Self {
clients,
server_origins,
elicitation_requests: elicitation_requests.clone(),
};
tokio::spawn(async move {
let outcomes = join_set.join_all().await;
@@ -752,7 +732,19 @@ impl McpConnectionManager {
})
.await;
});
(manager, cancel_token)
(backend, cancel_token)
}
fn has_servers(&self) -> bool {
!self.clients.is_empty()
}
fn server_origin(&self, server_name: &str) -> Option<&str> {
self.server_origins.get(server_name).map(String::as_str)
}
fn contains_server(&self, server_name: &str) -> bool {
self.clients.contains_key(server_name)
}
async fn client_by_name(&self, name: &str) -> Result<ManagedClient> {
@@ -764,18 +756,7 @@ impl McpConnectionManager {
.context("failed to get client")
}
pub async fn resolve_elicitation(
&self,
server_name: String,
id: RequestId,
response: ElicitationResponse,
) -> Result<()> {
self.elicitation_requests
.resolve(server_name, id, response)
.await
}
pub(crate) async fn wait_for_server_ready(&self, server_name: &str, timeout: Duration) -> bool {
async fn wait_for_server_ready(&self, server_name: &str, timeout: Duration) -> bool {
let Some(async_managed_client) = self.clients.get(server_name) else {
return false;
};
@@ -786,7 +767,7 @@ impl McpConnectionManager {
}
}
pub(crate) async fn required_startup_failures(
async fn required_startup_failures(
&self,
required_servers: &[String],
) -> Vec<McpStartupFailure> {
@@ -814,7 +795,7 @@ impl McpConnectionManager {
/// Returns a single map that contains all tools. Each key is the
/// fully-qualified name for the tool.
#[instrument(level = "trace", skip_all)]
pub async fn list_all_tools(&self) -> HashMap<String, ToolInfo> {
async fn list_all_tools(&self) -> HashMap<String, ToolInfo> {
let mut tools = HashMap::new();
for managed_client in self.clients.values() {
let Some(server_tools) = managed_client.listed_tools().await else {
@@ -830,7 +811,7 @@ impl McpConnectionManager {
/// On success, the refreshed tools replace the cache contents and the
/// latest filtered tool map is returned directly to the caller. On
/// failure, the existing cache remains unchanged.
pub async fn hard_refresh_codex_apps_tools_cache(&self) -> Result<HashMap<String, ToolInfo>> {
async fn hard_refresh_codex_apps_tools_cache(&self) -> Result<HashMap<String, ToolInfo>> {
let managed_client = self
.clients
.get(CODEX_APPS_MCP_SERVER_NAME)
@@ -874,7 +855,7 @@ impl McpConnectionManager {
/// Returns a single map that contains all resources. Each key is the
/// server name and the value is a vector of resources.
pub async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
let mut join_set = JoinSet::new();
let clients_snapshot = &self.clients;
@@ -886,7 +867,6 @@ impl McpConnectionManager {
};
let timeout = managed_client.tool_timeout;
let client = managed_client.client.clone();
join_set.spawn(async move {
let mut collected: Vec<Resource> = Vec::new();
let mut cursor: Option<String> = None;
@@ -940,7 +920,7 @@ impl McpConnectionManager {
/// Returns a single map that contains all resource templates. Each key is the
/// server name and the value is a vector of resource templates.
pub async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
let mut join_set = JoinSet::new();
let clients_snapshot = &self.clients;
@@ -952,7 +932,6 @@ impl McpConnectionManager {
};
let client = managed_client.client.clone();
let timeout = managed_client.tool_timeout;
join_set.spawn(async move {
let mut collected: Vec<ResourceTemplate> = Vec::new();
let mut cursor: Option<String> = None;
@@ -1009,7 +988,7 @@ impl McpConnectionManager {
}
/// Invoke the tool indicated by the (server, tool) pair.
pub async fn call_tool(
async fn call_tool(
&self,
server: &str,
tool: &str,
@@ -1047,7 +1026,7 @@ impl McpConnectionManager {
}
/// List resources from the specified server.
pub async fn list_resources(
async fn list_resources(
&self,
server: &str,
params: Option<PaginatedRequestParams>,
@@ -1063,7 +1042,7 @@ impl McpConnectionManager {
}
/// List resource templates from the specified server.
pub async fn list_resource_templates(
async fn list_resource_templates(
&self,
server: &str,
params: Option<PaginatedRequestParams>,
@@ -1079,7 +1058,7 @@ impl McpConnectionManager {
}
/// Read a resource from the specified server.
pub async fn read_resource(
async fn read_resource(
&self,
server: &str,
params: ReadResourceRequestParams,
@@ -1095,14 +1074,14 @@ impl McpConnectionManager {
.with_context(|| format!("resources/read failed for `{server}` ({uri})"))
}
pub async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
self.list_all_tools()
.await
.get(tool_name)
.map(|tool| (tool.server_name.clone(), tool.tool.name.to_string()))
}
pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
let mut join_set = JoinSet::new();
for async_managed_client in self.clients.values() {
@@ -1131,6 +1110,625 @@ impl McpConnectionManager {
}
}
/// Deterministic identity of a poolable MCP backend configuration.
///
/// Wraps the canonical JSON serialization of the enabled server configs,
/// credential store mode, and sandbox state; two sessions whose keys compare
/// equal may share one pooled backend.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct SharedMcpBackendCacheKey(String);
/// Recursively rewrites a JSON value so every object's keys appear in sorted
/// order. Arrays keep their element order and scalars pass through unchanged,
/// which makes serialized cache-key payloads byte-stable regardless of the
/// source map's iteration order.
fn canonicalize_json_value(value: serde_json::Value) -> serde_json::Value {
    match value {
        serde_json::Value::Array(items) => serde_json::Value::Array(
            items.into_iter().map(canonicalize_json_value).collect(),
        ),
        serde_json::Value::Object(object) => {
            // Sort the entries by key, then rebuild the map in that order,
            // canonicalizing each nested value along the way.
            let mut sorted_entries: Vec<_> = object.into_iter().collect();
            sorted_entries.sort_by(|left, right| left.0.cmp(&right.0));
            let mut ordered = serde_json::Map::with_capacity(sorted_entries.len());
            for (key, nested) in sorted_entries {
                ordered.insert(key, canonicalize_json_value(nested));
            }
            serde_json::Value::Object(ordered)
        }
        scalar => scalar,
    }
}
impl SharedMcpBackendCacheKey {
pub(crate) fn new(
mcp_servers: &HashMap<String, McpServerConfig>,
store_mode: OAuthCredentialsStoreMode,
sandbox_state: &SandboxState,
) -> Self {
let mut servers = mcp_servers
.iter()
.filter(|(_, config)| config.enabled)
.map(|(name, config)| {
(
name,
canonicalize_json_value(serde_json::to_value(config).unwrap_or_else(|error| {
panic!("serializing MCP server config for cache key: {error}")
})),
)
})
.collect::<Vec<_>>();
servers.sort_by(|(left_name, _), (right_name, _)| left_name.cmp(right_name));
let payload = canonicalize_json_value(serde_json::json!({
"sandboxState": sandbox_state,
"storeMode": store_mode,
"servers": servers,
}));
Self(serde_json::to_string(&payload).unwrap_or_else(|error| {
panic!("serializing shared MCP backend cache key payload: {error}")
}))
}
}
/// Whether a transport may be shared across sessions. Only stdio transports
/// are pooled today; every other transport stays session-local.
fn is_poolable_transport(transport: &McpServerTransportConfig) -> bool {
    matches!(transport, McpServerTransportConfig::Stdio { .. })
}
/// Splits the configured servers into those whose transports can be pooled
/// across sessions (stdio) and those that must remain session-local.
///
/// Returns `(pooled_servers, local_servers)`; together the two maps hold a
/// clone of every input entry, and each entry appears in exactly one map.
fn split_poolable_mcp_servers(
    mcp_servers: &HashMap<String, McpServerConfig>,
) -> (
    HashMap<String, McpServerConfig>,
    HashMap<String, McpServerConfig>,
) {
    // `Iterator::partition` builds both maps in a single pass instead of
    // branching into two manual inserts.
    mcp_servers
        .iter()
        .map(|(server_name, config)| (server_name.clone(), config.clone()))
        .partition(|(_, config)| is_poolable_transport(&config.transport))
}
/// Clones only the auth entries whose server name appears in `mcp_servers`,
/// so a backend never sees credentials for servers it does not host.
fn filter_auth_entries_for_servers(
    auth_entries: &HashMap<String, McpAuthStatusEntry>,
    mcp_servers: &HashMap<String, McpServerConfig>,
) -> HashMap<String, McpAuthStatusEntry> {
    let mut filtered = HashMap::new();
    for (server_name, entry) in auth_entries {
        if mcp_servers.contains_key(server_name) {
            filtered.insert(server_name.clone(), entry.clone());
        }
    }
    filtered
}
/// A clonable, reference-counted lease on a [`SharedMcpBackend`].
///
/// All clones share one [`SharedMcpBackendLeaseInner`]; when the last clone
/// is dropped, the inner `Drop` impl fires the backend's cancellation token,
/// shutting the pooled servers down.
#[derive(Clone)]
struct SharedMcpBackendLease {
    inner: Arc<SharedMcpBackendLeaseInner>,
}
impl SharedMcpBackendLease {
    /// Wraps a freshly created backend and its cancellation token in a lease.
    fn new(backend: SharedMcpBackend, cancel_token: CancellationToken) -> Self {
        Self {
            inner: Arc::new(SharedMcpBackendLeaseInner {
                backend: Arc::new(backend),
                cancel_token,
            }),
        }
    }
    /// Returns a new strong handle to the leased backend.
    fn backend(&self) -> Arc<SharedMcpBackend> {
        Arc::clone(&self.inner.backend)
    }
}
/// Shared state behind every clone of a [`SharedMcpBackendLease`].
struct SharedMcpBackendLeaseInner {
    backend: Arc<SharedMcpBackend>,
    cancel_token: CancellationToken,
}
impl Drop for SharedMcpBackendLeaseInner {
    fn drop(&mut self) {
        // Last lease gone: tear down the pooled backend's servers.
        self.cancel_token.cancel();
    }
}
/// Cache of pooled MCP backends keyed by their configuration.
///
/// Entries are held weakly: the pool itself never keeps a backend alive —
/// only outstanding [`SharedMcpBackendLease`]s do. A `Weak` that fails to
/// upgrade simply means the backend has already been torn down.
#[derive(Default)]
pub(crate) struct SharedMcpBackendPool {
    backends: Mutex<HashMap<SharedMcpBackendCacheKey, std::sync::Weak<SharedMcpBackendLeaseInner>>>,
}
/// Controls whether an acquisition may hand back an existing pooled backend
/// (`ReuseExisting`) or must always build a fresh one (`ForceCreate`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum SharedMcpBackendAcquireMode {
    ReuseExisting,
    ForceCreate,
}
impl SharedMcpBackendPool {
    /// Creates an empty pool.
    pub(crate) fn new() -> Self {
        Self::default()
    }
    /// Returns a lease on the backend identified by `key`, creating one if
    /// needed.
    ///
    /// With `ReuseExisting`, a still-live backend for the same key is shared;
    /// with `ForceCreate` a new backend is always built and replaces the pool
    /// entry for `key`. The pool lock is held across backend construction so
    /// concurrent acquisitions of the same key cannot race into duplicates.
    /// NOTE(review): this also serializes creation of backends for *different*
    /// keys — confirm that throughput cost is acceptable.
    #[allow(clippy::too_many_arguments)]
    async fn acquire_or_create(
        &self,
        key: SharedMcpBackendCacheKey,
        acquire_mode: SharedMcpBackendAcquireMode,
        mcp_servers: &HashMap<String, McpServerConfig>,
        store_mode: OAuthCredentialsStoreMode,
        auth_entries: HashMap<String, McpAuthStatusEntry>,
        tx_event: Sender<Event>,
        initial_sandbox_state: SandboxState,
        session_handle: SessionMcpHandle,
        codex_home: PathBuf,
        codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
        tool_plugin_provenance: ToolPluginProvenance,
    ) -> SharedMcpBackendLease {
        let mut backends = self.backends.lock().await;
        match acquire_mode {
            SharedMcpBackendAcquireMode::ReuseExisting => {
                // Upgrade succeeds only while at least one lease is alive.
                if let Some(existing) = backends.get(&key).and_then(std::sync::Weak::upgrade) {
                    return SharedMcpBackendLease { inner: existing };
                }
            }
            SharedMcpBackendAcquireMode::ForceCreate => {}
        }
        let (backend, cancel_token) = SharedMcpBackend::new(
            mcp_servers,
            store_mode,
            auth_entries,
            tx_event,
            initial_sandbox_state,
            session_handle,
            codex_home,
            codex_apps_tools_cache_key,
            tool_plugin_provenance,
        )
        .await;
        let lease = SharedMcpBackendLease::new(backend, cancel_token);
        // Store only a weak reference; dropping the last lease frees the slot.
        backends.insert(key, Arc::downgrade(&lease.inner));
        lease
    }
}
/// Acquires (or creates) a pooled backend lease covering the poolable (stdio)
/// subset of `mcp_servers`.
///
/// Returns `None` when no enabled stdio server is configured, in which case
/// the caller runs everything session-locally. Auth entries are filtered down
/// to the pooled servers before being handed to the pool, and the cache key
/// is computed from the pooled subset only.
#[allow(clippy::too_many_arguments)]
async fn acquire_shared_backend_lease(
    pool: &SharedMcpBackendPool,
    acquire_mode: SharedMcpBackendAcquireMode,
    mcp_servers: &HashMap<String, McpServerConfig>,
    store_mode: OAuthCredentialsStoreMode,
    auth_entries: HashMap<String, McpAuthStatusEntry>,
    tx_event: Sender<Event>,
    initial_sandbox_state: SandboxState,
    session_handle: SessionMcpHandle,
    codex_home: PathBuf,
    codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
    tool_plugin_provenance: ToolPluginProvenance,
) -> Option<SharedMcpBackendLease> {
    let (pooled_servers, _) = split_poolable_mcp_servers(mcp_servers);
    // Nothing to pool — skip pool interaction entirely.
    if !pooled_servers.values().any(|server| server.enabled) {
        return None;
    }
    let key = SharedMcpBackendCacheKey::new(&pooled_servers, store_mode, &initial_sandbox_state);
    let pooled_auth_entries = filter_auth_entries_for_servers(&auth_entries, &pooled_servers);
    Some(
        pool.acquire_or_create(
            key,
            acquire_mode,
            &pooled_servers,
            store_mode,
            pooled_auth_entries,
            tx_event,
            initial_sandbox_state,
            session_handle,
            codex_home,
            codex_apps_tools_cache_key,
            tool_plugin_provenance,
        )
        .await,
    )
}
/// Session-scoped MCP state that must NOT be shared across sessions:
/// elicitation request routing and its approval policy.
#[derive(Clone)]
struct SessionMcpHandle {
    elicitation_requests: ElicitationRequestManager,
}
impl SessionMcpHandle {
    /// Creates a handle whose elicitation manager starts with `approval_policy`.
    fn new(approval_policy: AskForApproval) -> Self {
        Self {
            elicitation_requests: ElicitationRequestManager::new(approval_policy),
        }
    }
    /// Hands out a clone of the elicitation manager for backends to route
    /// server-initiated requests through.
    fn elicitation_requests(&self) -> ElicitationRequestManager {
        self.elicitation_requests.clone()
    }
    /// Swaps in a new approval policy. A poisoned policy mutex is silently
    /// ignored and the old policy kept.
    fn set_approval_policy(&self, approval_policy: &Constrained<AskForApproval>) {
        if let Ok(mut policy) = self.elicitation_requests.approval_policy.lock() {
            *policy = approval_policy.value();
        }
    }
    /// Completes the pending elicitation `id` from `server_name` with the
    /// user's `response`.
    async fn resolve_elicitation(
        &self,
        server_name: String,
        id: RequestId,
        response: ElicitationResponse,
    ) -> Result<()> {
        self.elicitation_requests
            .resolve(server_name, id, response)
            .await
    }
}
/// A thin facade that keeps today's call sites stable while separating the
/// sharable MCP backend from session-local response routing state.
pub(crate) struct McpConnectionManager {
    // Session-local backend hosting non-poolable servers.
    backend: Arc<SharedMcpBackend>,
    // Pooled backend shared across sessions for stdio servers, if any.
    shared_backend: Option<Arc<SharedMcpBackend>>,
    // Per-session elicitation routing and approval policy.
    session: SessionMcpHandle,
    // Keeps the pooled backend alive; dropping the last lease shuts it down.
    _shared_backend_lease: Option<SharedMcpBackendLease>,
}
impl McpConnectionManager {
    /// Returns whichever backend hosts `server_name` — the session-local
    /// backend takes precedence over the pooled one — or `None` if neither
    /// knows the server.
    fn backend_for_server(&self, server_name: &str) -> Option<Arc<SharedMcpBackend>> {
        if self.backend.contains_server(server_name) {
            Some(Arc::clone(&self.backend))
        } else {
            self.shared_backend
                .as_ref()
                .filter(|backend| backend.contains_server(server_name))
                .map(Arc::clone)
        }
    }
    /// Assembles a manager from its parts. `shared_backend` and
    /// `shared_backend_lease` are `Some` only when pooled (stdio) servers are
    /// in use; the lease is what keeps the pooled backend alive.
    fn from_parts(
        backend: Arc<SharedMcpBackend>,
        shared_backend: Option<Arc<SharedMcpBackend>>,
        session: SessionMcpHandle,
        shared_backend_lease: Option<SharedMcpBackendLease>,
    ) -> Self {
        Self {
            backend,
            shared_backend,
            session,
            _shared_backend_lease: shared_backend_lease,
        }
    }
    /// Creates a manager with no connected servers.
    pub(crate) fn new_uninitialized(approval_policy: &Constrained<AskForApproval>) -> Self {
        Self::from_parts(
            Arc::new(SharedMcpBackend::new_uninitialized()),
            /*shared_backend*/ None,
            SessionMcpHandle::new(approval_policy.value()),
            /*shared_backend_lease*/ None,
        )
    }
    /// Test-only alias for [`Self::new_uninitialized`].
    #[cfg(test)]
    pub(crate) fn new_mcp_connection_manager_for_tests(
        approval_policy: &Constrained<AskForApproval>,
    ) -> Self {
        Self::new_uninitialized(approval_policy)
    }
    /// True if either the local or the pooled backend has at least one server.
    pub(crate) fn has_servers(&self) -> bool {
        self.backend.has_servers()
            || self
                .shared_backend
                .as_ref()
                .is_some_and(|backend| backend.has_servers())
    }
    /// Origin string for `server_name`, consulting the local backend first.
    pub(crate) fn server_origin(&self, server_name: &str) -> Option<&str> {
        self.backend
            .server_origin(server_name)
            .or_else(|| self.shared_backend.as_ref()?.server_origin(server_name))
    }
    /// Updates the approval policy used when resolving elicitations.
    pub fn set_approval_policy(&self, approval_policy: &Constrained<AskForApproval>) {
        self.session.set_approval_policy(approval_policy);
    }
    /// Spawns clients for every enabled server in `mcp_servers` without
    /// pooling: all servers live in this manager's session-local backend.
    /// Returns the manager plus a token that cancels the spawned servers.
    #[allow(clippy::new_ret_no_self, clippy::too_many_arguments)]
    pub async fn new(
        mcp_servers: &HashMap<String, McpServerConfig>,
        store_mode: OAuthCredentialsStoreMode,
        auth_entries: HashMap<String, McpAuthStatusEntry>,
        approval_policy: &Constrained<AskForApproval>,
        tx_event: Sender<Event>,
        initial_sandbox_state: SandboxState,
        codex_home: PathBuf,
        codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
        tool_plugin_provenance: ToolPluginProvenance,
    ) -> (Self, CancellationToken) {
        let session = SessionMcpHandle::new(approval_policy.value());
        let (backend, cancel_token) = SharedMcpBackend::new(
            mcp_servers,
            store_mode,
            auth_entries,
            tx_event,
            initial_sandbox_state,
            session.clone(),
            codex_home,
            codex_apps_tools_cache_key,
            tool_plugin_provenance,
        )
        .await;
        (
            Self::from_parts(
                Arc::new(backend),
                /*shared_backend*/ None,
                session,
                /*shared_backend_lease*/ None,
            ),
            cancel_token,
        )
    }
    /// Like [`Self::new`], but stdio servers are leased from `pool` (subject
    /// to `acquire_mode`) while the remaining servers get a session-local
    /// backend. The returned token cancels only the local backend; the pooled
    /// backend is torn down when its last lease drops.
    #[allow(clippy::too_many_arguments)]
    pub(crate) async fn new_with_pool(
        pool: &SharedMcpBackendPool,
        acquire_mode: SharedMcpBackendAcquireMode,
        mcp_servers: &HashMap<String, McpServerConfig>,
        store_mode: OAuthCredentialsStoreMode,
        auth_entries: HashMap<String, McpAuthStatusEntry>,
        approval_policy: &Constrained<AskForApproval>,
        tx_event: Sender<Event>,
        initial_sandbox_state: SandboxState,
        codex_home: PathBuf,
        codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
        tool_plugin_provenance: ToolPluginProvenance,
    ) -> (Self, CancellationToken) {
        let session = SessionMcpHandle::new(approval_policy.value());
        let (_, local_servers) = split_poolable_mcp_servers(mcp_servers);
        let local_auth_entries = filter_auth_entries_for_servers(&auth_entries, &local_servers);
        let shared_backend_lease = acquire_shared_backend_lease(
            pool,
            acquire_mode,
            mcp_servers,
            store_mode,
            auth_entries,
            tx_event.clone(),
            initial_sandbox_state.clone(),
            session.clone(),
            codex_home.clone(),
            codex_apps_tools_cache_key.clone(),
            tool_plugin_provenance.clone(),
        )
        .await;
        let shared_backend = shared_backend_lease
            .as_ref()
            .map(SharedMcpBackendLease::backend);
        // Only build a local backend when there is an enabled local server;
        // otherwise use an empty backend and a fresh (inert) cancel token.
        let (backend, cancel_token) = if local_servers.values().any(|server| server.enabled) {
            let (backend, cancel_token) = SharedMcpBackend::new(
                &local_servers,
                store_mode,
                local_auth_entries,
                tx_event,
                initial_sandbox_state,
                session.clone(),
                codex_home,
                codex_apps_tools_cache_key,
                tool_plugin_provenance,
            )
            .await;
            (Arc::new(backend), cancel_token)
        } else {
            (
                Arc::new(SharedMcpBackend::new_uninitialized()),
                CancellationToken::new(),
            )
        };
        (
            Self::from_parts(backend, shared_backend, session, shared_backend_lease),
            cancel_token,
        )
    }
    /// Rebuilds only the pooled portion of this manager, reusing the local
    /// backend and session state unchanged. Returns a new manager holding the
    /// fresh lease (the old lease is released when `self` is dropped).
    #[allow(clippy::too_many_arguments)]
    pub(crate) async fn rebuild_pooled_backend(
        &self,
        pool: &SharedMcpBackendPool,
        acquire_mode: SharedMcpBackendAcquireMode,
        mcp_servers: &HashMap<String, McpServerConfig>,
        store_mode: OAuthCredentialsStoreMode,
        auth_entries: HashMap<String, McpAuthStatusEntry>,
        tx_event: Sender<Event>,
        initial_sandbox_state: SandboxState,
        codex_home: PathBuf,
        codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
        tool_plugin_provenance: ToolPluginProvenance,
    ) -> Self {
        let shared_backend_lease = acquire_shared_backend_lease(
            pool,
            acquire_mode,
            mcp_servers,
            store_mode,
            auth_entries,
            tx_event,
            initial_sandbox_state,
            self.session.clone(),
            codex_home,
            codex_apps_tools_cache_key,
            tool_plugin_provenance,
        )
        .await;
        let shared_backend = shared_backend_lease
            .as_ref()
            .map(SharedMcpBackendLease::backend);
        Self::from_parts(
            Arc::clone(&self.backend),
            shared_backend,
            self.session.clone(),
            shared_backend_lease,
        )
    }
    /// Routes an elicitation response through this session's request manager.
    pub async fn resolve_elicitation(
        &self,
        server_name: String,
        id: RequestId,
        response: ElicitationResponse,
    ) -> Result<()> {
        self.session
            .resolve_elicitation(server_name, id, response)
            .await
    }
    /// Waits up to `timeout` for `server_name` to finish startup; `false` if
    /// the server is unknown to both backends or the wait fails.
    pub(crate) async fn wait_for_server_ready(&self, server_name: &str, timeout: Duration) -> bool {
        match self.backend_for_server(server_name) {
            Some(backend) => backend.wait_for_server_ready(server_name, timeout).await,
            None => false,
        }
    }
    /// Collects startup failures for each required server, synthesizing a
    /// "was not initialized" failure for servers neither backend knows about.
    pub(crate) async fn required_startup_failures(
        &self,
        required_servers: &[String],
    ) -> Vec<McpStartupFailure> {
        let mut failures = Vec::new();
        for server_name in required_servers {
            let Some(backend) = self.backend_for_server(server_name) else {
                failures.push(McpStartupFailure {
                    server: server_name.clone(),
                    error: format!("required MCP server `{server_name}` was not initialized"),
                });
                continue;
            };
            failures.extend(
                backend
                    .required_startup_failures(std::slice::from_ref(server_name))
                    .await,
            );
        }
        failures
    }
    /// Returns a single map that contains all tools. Each key is the
    /// fully-qualified name for the tool.
    /// On a name collision, the shared backend's entry wins (`extend`
    /// overwrites the local one).
    #[instrument(level = "trace", skip_all)]
    pub async fn list_all_tools(&self) -> HashMap<String, ToolInfo> {
        let mut tools = self.backend.list_all_tools().await;
        if let Some(shared_backend) = self.shared_backend.as_ref() {
            tools.extend(shared_backend.list_all_tools().await);
        }
        tools
    }
    /// Force-refresh codex apps tools by bypassing the in-process cache.
    ///
    /// On success, the refreshed tools replace the cache contents and the
    /// latest filtered tool map is returned directly to the caller. On
    /// failure, the existing cache remains unchanged. Only the session-local
    /// backend is consulted.
    pub async fn hard_refresh_codex_apps_tools_cache(&self) -> Result<HashMap<String, ToolInfo>> {
        self.backend.hard_refresh_codex_apps_tools_cache().await
    }
    /// Returns a single map that contains all resources. Each key is the
    /// server name and the value is a vector of resources. Shared-backend
    /// entries overwrite local ones on server-name collision.
    pub async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
        let mut resources = self.backend.list_all_resources().await;
        if let Some(shared_backend) = self.shared_backend.as_ref() {
            resources.extend(shared_backend.list_all_resources().await);
        }
        resources
    }
    /// Returns a single map that contains all resource templates. Each key is the
    /// server name and the value is a vector of resource templates.
    pub async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
        let mut templates = self.backend.list_all_resource_templates().await;
        if let Some(shared_backend) = self.shared_backend.as_ref() {
            templates.extend(shared_backend.list_all_resource_templates().await);
        }
        templates
    }
    /// Invoke the tool indicated by the (server, tool) pair.
    pub async fn call_tool(
        &self,
        server: &str,
        tool: &str,
        arguments: Option<serde_json::Value>,
        meta: Option<serde_json::Value>,
    ) -> Result<CallToolResult> {
        self.backend_for_server(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
            .call_tool(server, tool, arguments, meta)
            .await
    }
    /// List resources from the specified server.
    pub async fn list_resources(
        &self,
        server: &str,
        params: Option<PaginatedRequestParams>,
    ) -> Result<ListResourcesResult> {
        self.backend_for_server(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
            .list_resources(server, params)
            .await
    }
    /// List resource templates from the specified server.
    pub async fn list_resource_templates(
        &self,
        server: &str,
        params: Option<PaginatedRequestParams>,
    ) -> Result<ListResourceTemplatesResult> {
        self.backend_for_server(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
            .list_resource_templates(server, params)
            .await
    }
    /// Read a resource from the specified server.
    pub async fn read_resource(
        &self,
        server: &str,
        params: ReadResourceRequestParams,
    ) -> Result<ReadResourceResult> {
        self.backend_for_server(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
            .read_resource(server, params)
            .await
    }
    /// Resolves a fully-qualified tool name to `(server_name, tool_name)`,
    /// checking the local backend first and then the shared one.
    pub async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
        if let Some(parsed) = self.backend.parse_tool_name(tool_name).await {
            Some(parsed)
        } else if let Some(shared_backend) = self.shared_backend.as_ref() {
            shared_backend.parse_tool_name(tool_name).await
        } else {
            None
        }
    }
    /// Fans a sandbox-state update out to both the local and (if present) the
    /// shared backend, failing fast on the first error.
    pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
        self.backend
            .notify_sandbox_state_change(sandbox_state)
            .await?;
        if let Some(shared_backend) = self.shared_backend.as_ref() {
            shared_backend
                .notify_sandbox_state_change(sandbox_state)
                .await?;
        }
        Ok(())
    }
    /// Notifies only the session-local backend of a sandbox-state change.
    /// NOTE(review): the pooled backend is intentionally skipped here —
    /// presumably because pooled backends are keyed by sandbox state; confirm.
    pub async fn notify_local_sandbox_state_change(
        &self,
        sandbox_state: &SandboxState,
    ) -> Result<()> {
        self.backend
            .notify_sandbox_state_change(sandbox_state)
            .await
    }
}
async fn emit_update(
tx_event: &Sender<Event>,
update: McpStartupUpdateEvent,

View File

@@ -1,7 +1,9 @@
use super::*;
use codex_protocol::protocol::GranularApprovalConfig;
use codex_protocol::protocol::McpAuthStatus;
use pretty_assertions::assert_eq;
use rmcp::model::JsonObject;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tempfile::tempdir;
@@ -60,6 +62,98 @@ fn create_codex_apps_tools_cache_context(
}
}
// Cache key with no account identity; suitable for tests that do not
// exercise the codex-apps tools cache.
fn test_codex_apps_tools_cache_key() -> CodexAppsToolsCacheKey {
    CodexAppsToolsCacheKey {
        account_id: None,
        chatgpt_user_id: None,
        is_workspace_account: false,
    }
}
// Enabled stdio MCP server config fixture pointing at `command`, with short
// one-second startup and tool timeouts so tests fail fast.
fn stdio_server_config(command: &str) -> McpServerConfig {
    let transport = McpServerTransportConfig::Stdio {
        command: command.to_string(),
        args: Vec::new(),
        env: None,
        env_vars: Vec::new(),
        cwd: None,
    };
    McpServerConfig {
        transport,
        enabled: true,
        required: false,
        disabled_reason: None,
        startup_timeout_sec: Some(Duration::from_secs(1)),
        tool_timeout_sec: Some(Duration::from_secs(1)),
        enabled_tools: None,
        disabled_tools: None,
        scopes: None,
        oauth_resource: None,
    }
}
// Enabled streamable-HTTP MCP server config fixture pointing at `url`, with
// the same short timeouts as the stdio fixture.
fn http_server_config(url: &str) -> McpServerConfig {
    let transport = McpServerTransportConfig::StreamableHttp {
        url: url.to_string(),
        bearer_token_env_var: None,
        http_headers: None,
        env_http_headers: None,
    };
    McpServerConfig {
        transport,
        enabled: true,
        required: false,
        disabled_reason: None,
        startup_timeout_sec: Some(Duration::from_secs(1)),
        tool_timeout_sec: Some(Duration::from_secs(1)),
        enabled_tools: None,
        disabled_tools: None,
        scopes: None,
        oauth_resource: None,
    }
}
// Sandbox-state fixture: full access, rooted at `cwd`, no Linux sandbox exe.
fn test_sandbox_state(cwd: PathBuf) -> SandboxState {
    SandboxState {
        sandbox_policy: SandboxPolicy::DangerFullAccess,
        codex_linux_sandbox_exe: None,
        sandbox_cwd: cwd,
        use_legacy_landlock: false,
    }
}
// Computes the pool cache key the same way production does: restrict to the
// poolable (stdio) subset first, then build the key from that subset.
fn pooled_cache_key_for_tests(
    mcp_servers: &HashMap<String, McpServerConfig>,
    store_mode: OAuthCredentialsStoreMode,
    sandbox_state: &SandboxState,
) -> SharedMcpBackendCacheKey {
    let (pooled_servers, _) = split_poolable_mcp_servers(mcp_servers);
    SharedMcpBackendCacheKey::new(&pooled_servers, store_mode, sandbox_state)
}
// Builds a pool-backed McpConnectionManager with test defaults (no auth
// entries, /tmp as codex home). The cancellation token is discarded; in
// these tests pooled-backend shutdown is driven by lease drop instead.
async fn new_pooled_manager_for_tests(
    pool: &SharedMcpBackendPool,
    acquire_mode: SharedMcpBackendAcquireMode,
    mcp_servers: &HashMap<String, McpServerConfig>,
    sandbox_state: SandboxState,
) -> McpConnectionManager {
    let approval_policy = Constrained::allow_any(AskForApproval::OnRequest);
    let (tx_event, _rx_event) = async_channel::unbounded();
    let (manager, _cancel_token) = McpConnectionManager::new_with_pool(
        pool,
        acquire_mode,
        mcp_servers,
        OAuthCredentialsStoreMode::Auto,
        HashMap::new(),
        &approval_policy,
        tx_event,
        sandbox_state,
        PathBuf::from("/tmp"),
        test_codex_apps_tools_cache_key(),
        ToolPluginProvenance::default(),
    )
    .await;
    manager
}
#[test]
fn elicitation_granular_policy_defaults_to_prompting() {
assert!(!elicitation_is_rejected_by_policy(
@@ -409,15 +503,18 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() {
.shared();
let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
let mut manager = McpConnectionManager::new_uninitialized(&approval_policy);
manager.clients.insert(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
startup_snapshot: Some(startup_tools),
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
},
);
Arc::get_mut(&mut manager.backend)
.expect("test manager backend should be uniquely owned")
.clients
.insert(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
startup_snapshot: Some(startup_tools),
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
},
);
let tools = manager.list_all_tools().await;
let tool = tools
@@ -434,15 +531,18 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot(
.shared();
let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
let mut manager = McpConnectionManager::new_uninitialized(&approval_policy);
manager.clients.insert(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
startup_snapshot: None,
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
},
);
Arc::get_mut(&mut manager.backend)
.expect("test manager backend should be uniquely owned")
.clients
.insert(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
startup_snapshot: None,
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
},
);
let timeout_result =
tokio::time::timeout(Duration::from_millis(10), manager.list_all_tools()).await;
@@ -456,15 +556,18 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty(
.shared();
let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
let mut manager = McpConnectionManager::new_uninitialized(&approval_policy);
manager.clients.insert(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
startup_snapshot: Some(Vec::new()),
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
},
);
Arc::get_mut(&mut manager.backend)
.expect("test manager backend should be uniquely owned")
.clients
.insert(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
startup_snapshot: Some(Vec::new()),
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
},
);
let timeout_result =
tokio::time::timeout(Duration::from_millis(10), manager.list_all_tools()).await;
@@ -488,15 +591,18 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
let mut manager = McpConnectionManager::new_uninitialized(&approval_policy);
let startup_complete = Arc::new(std::sync::atomic::AtomicBool::new(true));
manager.clients.insert(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: failed_client,
startup_snapshot: Some(startup_tools),
startup_complete,
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
},
);
Arc::get_mut(&mut manager.backend)
.expect("test manager backend should be uniquely owned")
.clients
.insert(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: failed_client,
startup_snapshot: Some(startup_tools),
startup_complete,
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
},
);
let tools = manager.list_all_tools().await;
let tool = tools
@@ -506,6 +612,246 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
assert_eq!(tool.tool_name, "calendar_create_event");
}
#[tokio::test]
async fn parse_tool_name_searches_shared_backend() {
    // A tool known only to the shared (pooled) backend must still resolve via
    // the facade, which falls back from the local backend to the shared one.
    let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
    let shared_tool = create_test_tool("shared_stdio", "tool_a");
    // Never-resolving client: lookup must come from the startup snapshot.
    let pending_client = futures::future::pending::<Result<ManagedClient, StartupOutcomeError>>()
        .boxed()
        .shared();
    let mut shared_backend = SharedMcpBackend::new_uninitialized();
    shared_backend.clients.insert(
        "shared_stdio".to_string(),
        AsyncManagedClient {
            client: pending_client,
            startup_snapshot: Some(vec![shared_tool]),
            startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
        },
    );
    // Local backend is empty; only the shared backend knows the tool.
    let manager = McpConnectionManager::from_parts(
        Arc::new(SharedMcpBackend::new_uninitialized()),
        Some(Arc::new(shared_backend)),
        SessionMcpHandle::new(approval_policy.value()),
        None,
    );
    assert_eq!(
        manager.parse_tool_name("mcp__shared_stdio__tool_a").await,
        Some(("shared_stdio".to_string(), "tool_a".to_string()))
    );
}
#[test]
fn shared_mcp_backend_cache_key_is_stable_for_equivalent_stdio_configs() {
    // The same two servers inserted in opposite orders must produce the same
    // cache key — insertion order must not leak into the key.
    let left_servers = HashMap::from([
        ("beta".to_string(), stdio_server_config("missing-beta")),
        ("alpha".to_string(), stdio_server_config("missing-alpha")),
    ]);
    let right_servers = HashMap::from([
        ("alpha".to_string(), stdio_server_config("missing-alpha")),
        ("beta".to_string(), stdio_server_config("missing-beta")),
    ]);
    let sandbox_state = test_sandbox_state(PathBuf::from("/tmp/shared"));
    assert_eq!(
        pooled_cache_key_for_tests(
            &left_servers,
            OAuthCredentialsStoreMode::Auto,
            &sandbox_state,
        ),
        pooled_cache_key_for_tests(
            &right_servers,
            OAuthCredentialsStoreMode::Auto,
            &sandbox_state,
        ),
    );
}
#[test]
fn shared_mcp_backend_cache_key_ignores_http_servers() {
    // Only the poolable (stdio) subset feeds the key, so differing HTTP
    // configs must not change it.
    let sandbox_state = test_sandbox_state(PathBuf::from("/tmp/shared"));
    let left_servers = HashMap::from([
        ("stdio".to_string(), stdio_server_config("missing-stdio")),
        (
            "http".to_string(),
            http_server_config("http://127.0.0.1:9/left"),
        ),
    ]);
    let right_servers = HashMap::from([
        ("stdio".to_string(), stdio_server_config("missing-stdio")),
        (
            "http".to_string(),
            http_server_config("http://127.0.0.1:9/right"),
        ),
    ]);
    assert_eq!(
        pooled_cache_key_for_tests(
            &left_servers,
            OAuthCredentialsStoreMode::Auto,
            &sandbox_state,
        ),
        pooled_cache_key_for_tests(
            &right_servers,
            OAuthCredentialsStoreMode::Auto,
            &sandbox_state,
        ),
    );
}
#[test]
fn split_poolable_mcp_servers_keeps_http_servers_local() {
    // stdio transports are pooled; HTTP transports must stay session-local.
    let servers = HashMap::from([
        ("stdio".to_string(), stdio_server_config("missing-stdio")),
        (
            "http".to_string(),
            http_server_config("http://127.0.0.1:9/http"),
        ),
    ]);
    let (pooled_servers, local_servers) = split_poolable_mcp_servers(&servers);
    assert_eq!(pooled_servers.len(), 1);
    assert!(pooled_servers.contains_key("stdio"));
    assert_eq!(local_servers.len(), 1);
    assert!(local_servers.contains_key("http"));
}
#[test]
fn shared_mcp_backend_cache_key_separates_sandbox_state() {
let mut servers = HashMap::new();
servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
assert_ne!(
pooled_cache_key_for_tests(
&servers,
OAuthCredentialsStoreMode::Auto,
&test_sandbox_state(PathBuf::from("/tmp/left")),
),
pooled_cache_key_for_tests(
&servers,
OAuthCredentialsStoreMode::Auto,
&test_sandbox_state(PathBuf::from("/tmp/right")),
),
);
}
#[tokio::test]
async fn shared_mcp_backend_pool_reuses_backend_for_same_stdio_config() {
let pool = SharedMcpBackendPool::new();
let sandbox_state = test_sandbox_state(PathBuf::from("/tmp/shared"));
let mut servers = HashMap::new();
servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
let manager_1 = new_pooled_manager_for_tests(
&pool,
SharedMcpBackendAcquireMode::ReuseExisting,
&servers,
sandbox_state.clone(),
)
.await;
let manager_2 = new_pooled_manager_for_tests(
&pool,
SharedMcpBackendAcquireMode::ReuseExisting,
&servers,
sandbox_state,
)
.await;
assert!(Arc::ptr_eq(
manager_1
.shared_backend
.as_ref()
.expect("stdio backend should be pooled"),
manager_2
.shared_backend
.as_ref()
.expect("stdio backend should be pooled"),
));
}
#[tokio::test]
async fn shared_mcp_backend_pool_separates_backends_for_different_sandbox_states() {
let pool = SharedMcpBackendPool::new();
let mut servers = HashMap::new();
servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
let manager_1 = new_pooled_manager_for_tests(
&pool,
SharedMcpBackendAcquireMode::ReuseExisting,
&servers,
test_sandbox_state(PathBuf::from("/tmp/left")),
)
.await;
let manager_2 = new_pooled_manager_for_tests(
&pool,
SharedMcpBackendAcquireMode::ReuseExisting,
&servers,
test_sandbox_state(PathBuf::from("/tmp/right")),
)
.await;
assert!(!Arc::ptr_eq(
manager_1
.shared_backend
.as_ref()
.expect("stdio backend should be pooled"),
manager_2
.shared_backend
.as_ref()
.expect("stdio backend should be pooled"),
));
}
#[tokio::test]
async fn shared_mcp_backend_pool_force_create_replaces_pool_entry_for_same_key() {
let pool = SharedMcpBackendPool::new();
let mut servers = HashMap::new();
servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
let sandbox_state = test_sandbox_state(PathBuf::from("/tmp/shared"));
let manager_1 = new_pooled_manager_for_tests(
&pool,
SharedMcpBackendAcquireMode::ReuseExisting,
&servers,
sandbox_state.clone(),
)
.await;
let manager_2 = new_pooled_manager_for_tests(
&pool,
SharedMcpBackendAcquireMode::ForceCreate,
&servers,
sandbox_state.clone(),
)
.await;
let manager_3 = new_pooled_manager_for_tests(
&pool,
SharedMcpBackendAcquireMode::ReuseExisting,
&servers,
sandbox_state,
)
.await;
let shared_1 = manager_1
.shared_backend
.as_ref()
.expect("stdio backend should be pooled");
let shared_2 = manager_2
.shared_backend
.as_ref()
.expect("stdio backend should be pooled");
let shared_3 = manager_3
.shared_backend
.as_ref()
.expect("stdio backend should be pooled");
assert!(!Arc::ptr_eq(shared_1, shared_2));
assert!(Arc::ptr_eq(shared_2, shared_3));
assert!(!Arc::ptr_eq(shared_1, shared_3));
}
#[test]
fn elicitation_capability_enabled_only_for_codex_apps() {
let codex_apps_capability = elicitation_capability_for_server(CODEX_APPS_MCP_SERVER_NAME);

View File

@@ -11,6 +11,7 @@ use crate::exec_policy::ExecPolicyManager;
use crate::file_watcher::FileWatcher;
use crate::mcp::McpManager;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::mcp_connection_manager::SharedMcpBackendPool;
use crate::models_manager::manager::ModelsManager;
use crate::plugins::PluginsManager;
use crate::skills::SkillsManager;
@@ -33,6 +34,7 @@ use tokio_util::sync::CancellationToken;
pub(crate) struct SessionServices {
pub(crate) mcp_connection_manager: Arc<RwLock<McpConnectionManager>>,
pub(crate) mcp_startup_cancellation_token: Mutex<CancellationToken>,
pub(crate) shared_mcp_backend_pool: Option<Arc<SharedMcpBackendPool>>,
pub(crate) unified_exec_manager: UnifiedExecProcessManager,
#[cfg_attr(not(unix), allow(dead_code))]
pub(crate) shell_zsh_path: Option<PathBuf>,

View File

@@ -14,6 +14,7 @@ use crate::error::Result as CodexResult;
use crate::file_watcher::FileWatcher;
use crate::file_watcher::FileWatcherEvent;
use crate::mcp::McpManager;
use crate::mcp_connection_manager::SharedMcpBackendPool;
use crate::models_manager::collaboration_mode_presets::CollaborationModesConfig;
use crate::models_manager::manager::ModelsManager;
use crate::plugins::PluginsManager;
@@ -155,6 +156,7 @@ pub(crate) struct ThreadManagerState {
skills_manager: Arc<SkillsManager>,
plugins_manager: Arc<PluginsManager>,
mcp_manager: Arc<McpManager>,
shared_mcp_backend_pool: Arc<SharedMcpBackendPool>,
file_watcher: Arc<FileWatcher>,
session_source: SessionSource,
// Captures submitted ops for testing purpose when test mode is enabled.
@@ -202,6 +204,7 @@ impl ThreadManager {
skills_manager,
plugins_manager,
mcp_manager,
shared_mcp_backend_pool: Arc::new(SharedMcpBackendPool::new()),
file_watcher,
auth_manager,
session_source,
@@ -266,6 +269,7 @@ impl ThreadManager {
skills_manager,
plugins_manager,
mcp_manager,
shared_mcp_backend_pool: Arc::new(SharedMcpBackendPool::new()),
file_watcher,
auth_manager,
session_source: SessionSource::Exec,
@@ -582,6 +586,10 @@ impl ThreadManager {
}
impl ThreadManagerState {
pub(crate) fn shared_mcp_backend_pool(&self) -> Arc<SharedMcpBackendPool> {
Arc::clone(&self.shared_mcp_backend_pool)
}
pub(crate) async fn list_thread_ids(&self) -> Vec<ThreadId> {
self.threads.read().await.keys().copied().collect()
}

View File

@@ -73,8 +73,6 @@ ignore = [
{ id = "RUSTSEC-2024-0388", reason = "derivative is unmaintained; pulled in via starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
{ id = "RUSTSEC-2025-0057", reason = "fxhash is unmaintained; pulled in via starlark_map/starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
{ id = "RUSTSEC-2024-0436", reason = "paste is unmaintained; pulled in via ratatui/rmcp/starlark used by tui/execpolicy; no fixed release yet" },
# TODO(joshka, nornagon): remove this exception when once we update the ratatui fork to a version that uses lru 0.13+.
{ id = "RUSTSEC-2026-0002", reason = "lru 0.12.5 is pulled in via ratatui fork; cannot upgrade until the fork is updated" },
# TODO(fcoury): remove this exception when syntect drops yaml-rust and bincode, or updates to versions that have fixed the vulnerabilities.
{ id = "RUSTSEC-2024-0320", reason = "yaml-rust is unmaintained; pulled in via syntect v5.3.0 used by codex-tui for syntax highlighting; no fixed release yet" },
{ id = "RUSTSEC-2025-0141", reason = "bincode is unmaintained; pulled in via syntect v5.3.0 used by codex-tui for syntax highlighting; no fixed release yet" },

View File

@@ -1,5 +1,7 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fs::OpenOptions;
use std::io::Write;
use std::sync::Arc;
use rmcp::ErrorData as McpError;
@@ -36,6 +38,7 @@ struct TestToolServer {
const MEMO_URI: &str = "memo://codex/example-note";
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
const SMALL_PNG_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==";
const STARTUP_COUNT_FILE_ENV_VAR: &str = "MCP_STARTUP_COUNT_FILE";
pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) {
(tokio::io::stdin(), tokio::io::stdout())
@@ -454,8 +457,18 @@ fn parse_data_url(url: &str) -> Option<(String, String)> {
Some((mime.to_string(), data.to_string()))
}
fn record_startup_if_requested() -> std::io::Result<()> {
let Some(path) = std::env::var_os(STARTUP_COUNT_FILE_ENV_VAR) else {
return Ok(());
};
let mut file = OpenOptions::new().create(true).append(true).open(path)?;
writeln!(file, "started")
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
record_startup_if_requested()?;
eprintln!("starting rmcp test server");
// Run the server with STDIO transport. If the client disconnects we simply
// bubble up the error so the process exits.