mirror of
https://github.com/openai/codex.git
synced 2026-04-18 19:54:47 +00:00
Compare commits
4 Commits
pr18028
...
codex/jger
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4da08c2b31 | ||
|
|
05c88b6639 | ||
|
|
7591ac9184 | ||
|
|
945b674884 |
6
MODULE.bazel.lock
generated
6
MODULE.bazel.lock
generated
File diff suppressed because one or more lines are too long
12
codex-rs/Cargo.lock
generated
12
codex-rs/Cargo.lock
generated
@@ -800,9 +800,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-rs"
|
||||
version = "1.15.4"
|
||||
version = "1.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b7b6141e96a8c160799cc2d5adecd5cbbe5054cb8c7c4af53da0f83bb7ad256"
|
||||
checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc"
|
||||
dependencies = [
|
||||
"aws-lc-sys",
|
||||
"untrusted 0.7.1",
|
||||
@@ -811,9 +811,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-sys"
|
||||
version = "0.37.0"
|
||||
version = "0.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c34dda4df7017c8db52132f0f8a2e0f8161649d15723ed63fc00c82d0f2081a"
|
||||
checksum = "1fa7e52a4c5c547c741610a2c6f123f3881e409b714cd27e6798ef020c514f0a"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cmake",
|
||||
@@ -8239,9 +8239,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.103.9"
|
||||
version = "0.103.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53"
|
||||
checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"ring",
|
||||
|
||||
@@ -3,5 +3,8 @@ load("//:defs.bzl", "codex_rust_crate")
|
||||
codex_rust_crate(
|
||||
name = "app-server",
|
||||
crate_name = "codex_app_server",
|
||||
extra_binaries = [
|
||||
"//codex-rs/rmcp-client:test_stdio_server",
|
||||
],
|
||||
test_tags = ["no-sandbox"],
|
||||
)
|
||||
|
||||
@@ -33,6 +33,7 @@ mod thread_archive;
|
||||
mod thread_fork;
|
||||
mod thread_list;
|
||||
mod thread_loaded_list;
|
||||
mod thread_mcp_pool;
|
||||
mod thread_metadata_update;
|
||||
mod thread_name_websocket;
|
||||
mod thread_read;
|
||||
|
||||
450
codex-rs/app-server/tests/suite/v2/thread_mcp_pool.rs
Normal file
450
codex-rs/app-server/tests/suite/v2/thread_mcp_pool.rs
Normal file
@@ -0,0 +1,450 @@
|
||||
use std::io::ErrorKind;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::bail;
|
||||
use app_test_support::McpProcess;
|
||||
use app_test_support::create_mock_responses_server_sequence_unchecked;
|
||||
use app_test_support::to_response;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ThreadArchiveParams;
|
||||
use codex_app_server_protocol::ThreadArchiveResponse;
|
||||
use codex_app_server_protocol::ThreadArchivedNotification;
|
||||
use codex_app_server_protocol::ThreadClosedNotification;
|
||||
use codex_app_server_protocol::ThreadResumeParams;
|
||||
use codex_app_server_protocol::ThreadResumeResponse;
|
||||
use codex_app_server_protocol::ThreadStartParams;
|
||||
use codex_app_server_protocol::ThreadStartResponse;
|
||||
use codex_app_server_protocol::ThreadStatus;
|
||||
use codex_app_server_protocol::ThreadUnarchiveParams;
|
||||
use codex_app_server_protocol::ThreadUnarchiveResponse;
|
||||
use codex_app_server_protocol::ThreadUnarchivedNotification;
|
||||
use codex_app_server_protocol::ThreadUnsubscribeParams;
|
||||
use codex_app_server_protocol::ThreadUnsubscribeResponse;
|
||||
use codex_app_server_protocol::ThreadUnsubscribeStatus;
|
||||
use codex_app_server_protocol::TurnCompletedNotification;
|
||||
use codex_app_server_protocol::TurnStartParams;
|
||||
use codex_app_server_protocol::TurnStartResponse;
|
||||
use codex_app_server_protocol::TurnStatus;
|
||||
use codex_app_server_protocol::UserInput;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::stdio_server_bin;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const STARTUP_COUNT_POLL_INTERVAL: Duration = Duration::from_millis(25);
|
||||
const STARTUP_COUNT_WAIT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const STARTUP_COUNT_STABILITY_WINDOW: Duration = Duration::from_millis(250);
|
||||
const STARTUP_COUNT_FILE_ENV_VAR: &str = "MCP_STARTUP_COUNT_FILE";
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_pool_survives_unsubscribe_of_one_loaded_thread() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let responses_server =
|
||||
create_mock_responses_server_sequence_unchecked(rmcp_echo_turn_bodies("after-unsubscribe"))
|
||||
.await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
let startup_count_file = codex_home.path().join("rmcp-startups.log");
|
||||
create_config_toml(
|
||||
codex_home.path(),
|
||||
&responses_server.uri(),
|
||||
&stdio_server_bin()?,
|
||||
)?;
|
||||
|
||||
let startup_count_file_value = startup_count_file.to_string_lossy().to_string();
|
||||
let mut mcp = McpProcess::new_with_env(
|
||||
codex_home.path(),
|
||||
&[(
|
||||
STARTUP_COUNT_FILE_ENV_VAR,
|
||||
Some(startup_count_file_value.as_str()),
|
||||
)],
|
||||
)
|
||||
.await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let thread_a = start_thread(&mut mcp).await?;
|
||||
wait_for_startup_count(&startup_count_file, 1).await?;
|
||||
|
||||
let thread_b = start_thread(&mut mcp).await?;
|
||||
assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
|
||||
|
||||
unsubscribe_thread(&mut mcp, &thread_a).await?;
|
||||
run_rmcp_echo_turn(&mut mcp, &thread_b, "after-unsubscribe").await?;
|
||||
assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_pool_survives_archive_of_one_loaded_thread() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let responses_server = create_mock_responses_server_sequence_unchecked(
|
||||
[
|
||||
rmcp_echo_turn_bodies("materialize-archived-thread"),
|
||||
rmcp_echo_turn_bodies("after-archive"),
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
let startup_count_file = codex_home.path().join("rmcp-startups.log");
|
||||
create_config_toml(
|
||||
codex_home.path(),
|
||||
&responses_server.uri(),
|
||||
&stdio_server_bin()?,
|
||||
)?;
|
||||
|
||||
let startup_count_file_value = startup_count_file.to_string_lossy().to_string();
|
||||
let mut mcp = McpProcess::new_with_env(
|
||||
codex_home.path(),
|
||||
&[(
|
||||
STARTUP_COUNT_FILE_ENV_VAR,
|
||||
Some(startup_count_file_value.as_str()),
|
||||
)],
|
||||
)
|
||||
.await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let thread_a = start_thread(&mut mcp).await?;
|
||||
wait_for_startup_count(&startup_count_file, 1).await?;
|
||||
|
||||
let thread_b = start_thread(&mut mcp).await?;
|
||||
assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
|
||||
|
||||
run_rmcp_echo_turn(&mut mcp, &thread_a, "materialize-archived-thread").await?;
|
||||
archive_thread(&mut mcp, &thread_a).await?;
|
||||
|
||||
run_rmcp_echo_turn(&mut mcp, &thread_b, "after-archive").await?;
|
||||
assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_pool_recreates_backend_after_last_archive_and_resume() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let responses_server = create_mock_responses_server_sequence_unchecked(
|
||||
[
|
||||
rmcp_echo_turn_bodies("materialize-a"),
|
||||
rmcp_echo_turn_bodies("materialize-b"),
|
||||
rmcp_echo_turn_bodies("after-resume"),
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
let startup_count_file = codex_home.path().join("rmcp-startups.log");
|
||||
create_config_toml(
|
||||
codex_home.path(),
|
||||
&responses_server.uri(),
|
||||
&stdio_server_bin()?,
|
||||
)?;
|
||||
|
||||
let startup_count_file_value = startup_count_file.to_string_lossy().to_string();
|
||||
let mut mcp = McpProcess::new_with_env(
|
||||
codex_home.path(),
|
||||
&[(
|
||||
STARTUP_COUNT_FILE_ENV_VAR,
|
||||
Some(startup_count_file_value.as_str()),
|
||||
)],
|
||||
)
|
||||
.await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let thread_a = start_thread(&mut mcp).await?;
|
||||
wait_for_startup_count(&startup_count_file, 1).await?;
|
||||
|
||||
let thread_b = start_thread(&mut mcp).await?;
|
||||
assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
|
||||
|
||||
run_rmcp_echo_turn(&mut mcp, &thread_a, "materialize-a").await?;
|
||||
run_rmcp_echo_turn(&mut mcp, &thread_b, "materialize-b").await?;
|
||||
|
||||
archive_thread(&mut mcp, &thread_a).await?;
|
||||
archive_thread(&mut mcp, &thread_b).await?;
|
||||
|
||||
let unarchive = unarchive_thread(&mut mcp, &thread_a).await?;
|
||||
assert_eq!(unarchive.thread.status, ThreadStatus::NotLoaded);
|
||||
assert_startup_count_stays(&startup_count_file, 1, STARTUP_COUNT_STABILITY_WINDOW).await?;
|
||||
|
||||
let resume = resume_thread(&mut mcp, &thread_a).await?;
|
||||
assert_eq!(resume.thread.status, ThreadStatus::Idle);
|
||||
wait_for_startup_count(&startup_count_file, 2).await?;
|
||||
|
||||
run_rmcp_echo_turn(&mut mcp, &thread_a, "after-resume").await?;
|
||||
assert_startup_count_stays(&startup_count_file, 2, STARTUP_COUNT_STABILITY_WINDOW).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str, rmcp_server_bin: &str) -> Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
let server_uri = serde_json::to_string(&format!("{server_uri}/v1"))?;
|
||||
let rmcp_server_bin = serde_json::to_string(rmcp_server_bin)?;
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "read-only"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = {server_uri}
|
||||
wire_api = "responses"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
|
||||
[mcp_servers.rmcp]
|
||||
command = {rmcp_server_bin}
|
||||
env_vars = ["{STARTUP_COUNT_FILE_ENV_VAR}"]
|
||||
startup_timeout_sec = 10.0
|
||||
"#
|
||||
),
|
||||
)
|
||||
.context("write config.toml")
|
||||
}
|
||||
|
||||
fn rmcp_echo_turn_bodies(label: &str) -> Vec<String> {
|
||||
let response_id = format!("resp-{label}");
|
||||
let completion_id = format!("resp-{label}-done");
|
||||
let call_id = format!("call-rmcp-echo-{label}");
|
||||
let assistant_message_id = format!("msg-{label}");
|
||||
let message = format!("ping-{label}");
|
||||
let final_text = format!("rmcp echo completed for {label}");
|
||||
let arguments = json!({ "message": message }).to_string();
|
||||
|
||||
vec![
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created(&response_id),
|
||||
responses::ev_function_call(&call_id, "mcp__rmcp__echo", &arguments),
|
||||
responses::ev_completed(&response_id),
|
||||
]),
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created(&completion_id),
|
||||
responses::ev_assistant_message(&assistant_message_id, &final_text),
|
||||
responses::ev_completed(&completion_id),
|
||||
]),
|
||||
]
|
||||
}
|
||||
|
||||
async fn start_thread(mcp: &mut McpProcess) -> Result<String> {
|
||||
let request_id = mcp
|
||||
.send_thread_start_request(ThreadStartParams {
|
||||
model: Some("mock-model".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await??;
|
||||
let ThreadStartResponse { thread, .. } = to_response(response)?;
|
||||
Ok(thread.id)
|
||||
}
|
||||
|
||||
async fn run_rmcp_echo_turn(mcp: &mut McpProcess, thread_id: &str, label: &str) -> Result<()> {
|
||||
let request_id = mcp
|
||||
.send_turn_start_request(TurnStartParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
input: vec![UserInput::Text {
|
||||
text: format!("call the rmcp echo tool for {label}"),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await??;
|
||||
let TurnStartResponse { turn } = to_response(response)?;
|
||||
|
||||
let deadline = Instant::now() + DEFAULT_READ_TIMEOUT;
|
||||
loop {
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
let message = timeout(remaining, mcp.read_next_message()).await??;
|
||||
let codex_app_server_protocol::JSONRPCMessage::Notification(notification) = message else {
|
||||
continue;
|
||||
};
|
||||
if notification.method != "turn/completed" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let completed: TurnCompletedNotification = serde_json::from_value(
|
||||
notification
|
||||
.params
|
||||
.clone()
|
||||
.context("turn/completed params")?,
|
||||
)?;
|
||||
if completed.thread_id != thread_id || completed.turn.id != turn.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
assert_eq!(completed.turn.status, TurnStatus::Completed);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
async fn unsubscribe_thread(mcp: &mut McpProcess, thread_id: &str) -> Result<()> {
|
||||
let request_id = mcp
|
||||
.send_thread_unsubscribe_request(ThreadUnsubscribeParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
let response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await??;
|
||||
let unsubscribe: ThreadUnsubscribeResponse = to_response(response)?;
|
||||
assert_eq!(unsubscribe.status, ThreadUnsubscribeStatus::Unsubscribed);
|
||||
|
||||
let notification: JSONRPCNotification = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("thread/closed"),
|
||||
)
|
||||
.await??;
|
||||
let ServerNotification::ThreadClosed(ThreadClosedNotification {
|
||||
thread_id: closed_thread_id,
|
||||
}) = notification.try_into()?
|
||||
else {
|
||||
bail!("expected thread/closed notification");
|
||||
};
|
||||
assert_eq!(closed_thread_id, thread_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn archive_thread(mcp: &mut McpProcess, thread_id: &str) -> Result<()> {
|
||||
let request_id = mcp
|
||||
.send_thread_archive_request(ThreadArchiveParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
let response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await??;
|
||||
let _: ThreadArchiveResponse = to_response(response)?;
|
||||
|
||||
let notification: JSONRPCNotification = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("thread/archived"),
|
||||
)
|
||||
.await??;
|
||||
let archived: ThreadArchivedNotification =
|
||||
serde_json::from_value(notification.params.context("thread/archived params")?)?;
|
||||
assert_eq!(archived.thread_id, thread_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unarchive_thread(
|
||||
mcp: &mut McpProcess,
|
||||
thread_id: &str,
|
||||
) -> Result<ThreadUnarchiveResponse> {
|
||||
let request_id = mcp
|
||||
.send_thread_unarchive_request(ThreadUnarchiveParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
let response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await??;
|
||||
let unarchive: ThreadUnarchiveResponse = to_response(response)?;
|
||||
|
||||
let notification: JSONRPCNotification = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("thread/unarchived"),
|
||||
)
|
||||
.await??;
|
||||
let unarchived: ThreadUnarchivedNotification =
|
||||
serde_json::from_value(notification.params.context("thread/unarchived params")?)?;
|
||||
assert_eq!(unarchived.thread_id, thread_id);
|
||||
|
||||
Ok(unarchive)
|
||||
}
|
||||
|
||||
async fn resume_thread(mcp: &mut McpProcess, thread_id: &str) -> Result<ThreadResumeResponse> {
|
||||
let request_id = mcp
|
||||
.send_thread_resume_request(ThreadResumeParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await??;
|
||||
to_response(response)
|
||||
}
|
||||
|
||||
async fn wait_for_startup_count(path: &Path, expected: usize) -> Result<()> {
|
||||
let deadline = Instant::now() + STARTUP_COUNT_WAIT_TIMEOUT;
|
||||
loop {
|
||||
let actual = read_startup_count(path)?;
|
||||
if actual == expected {
|
||||
return Ok(());
|
||||
}
|
||||
if actual > expected {
|
||||
bail!("startup count exceeded expectation: expected {expected}, got {actual}");
|
||||
}
|
||||
if Instant::now() >= deadline {
|
||||
bail!("timed out waiting for startup count {expected}; last observed {actual}");
|
||||
}
|
||||
sleep(STARTUP_COUNT_POLL_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn assert_startup_count_stays(
|
||||
path: &Path,
|
||||
expected: usize,
|
||||
duration: Duration,
|
||||
) -> Result<()> {
|
||||
let deadline = Instant::now() + duration;
|
||||
loop {
|
||||
let actual = read_startup_count(path)?;
|
||||
assert_eq!(actual, expected);
|
||||
if Instant::now() >= deadline {
|
||||
return Ok(());
|
||||
}
|
||||
sleep(STARTUP_COUNT_POLL_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn read_startup_count(path: &Path) -> Result<usize> {
|
||||
let contents = match std::fs::read_to_string(path) {
|
||||
Ok(contents) => contents,
|
||||
Err(error) if error.kind() == ErrorKind::NotFound => return Ok(0),
|
||||
Err(error) => return Err(error).context("read startup count file"),
|
||||
};
|
||||
Ok(contents.lines().count())
|
||||
}
|
||||
@@ -8,6 +8,7 @@ use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::find_archived_thread_path_by_id_str;
|
||||
use crate::find_thread_path_by_id_str;
|
||||
use crate::mcp_connection_manager::SharedMcpBackendPool;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::session_prefix::format_subagent_context_line;
|
||||
use crate::session_prefix::format_subagent_notification_message;
|
||||
@@ -90,6 +91,12 @@ impl AgentControl {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn shared_mcp_backend_pool(&self) -> Option<Arc<SharedMcpBackendPool>> {
|
||||
self.manager
|
||||
.upgrade()
|
||||
.map(|state| state.shared_mcp_backend_pool())
|
||||
}
|
||||
|
||||
/// Spawn a new agent thread and submit the initial prompt.
|
||||
pub(crate) async fn spawn_agent(
|
||||
&self,
|
||||
|
||||
@@ -220,6 +220,7 @@ use crate::mcp::auth::compute_auth_statuses;
|
||||
use crate::mcp::maybe_prompt_and_install_mcp_dependencies;
|
||||
use crate::mcp::with_codex_apps_mcp;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::mcp_connection_manager::SharedMcpBackendAcquireMode;
|
||||
use crate::mcp_connection_manager::codex_apps_tools_cache_key;
|
||||
use crate::mcp_connection_manager::filter_non_codex_apps_mcp_tools_only;
|
||||
use crate::memories;
|
||||
@@ -1249,6 +1250,23 @@ impl Session {
|
||||
per_turn_config
|
||||
}
|
||||
|
||||
fn sandbox_state_for_session_configuration(
|
||||
session_configuration: &SessionConfiguration,
|
||||
) -> SandboxState {
|
||||
SandboxState {
|
||||
sandbox_policy: session_configuration.sandbox_policy.get().clone(),
|
||||
codex_linux_sandbox_exe: session_configuration
|
||||
.original_config_do_not_use
|
||||
.codex_linux_sandbox_exe
|
||||
.clone(),
|
||||
sandbox_cwd: session_configuration.cwd.clone(),
|
||||
use_legacy_landlock: session_configuration
|
||||
.original_config_do_not_use
|
||||
.features
|
||||
.use_legacy_landlock(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn codex_home(&self) -> PathBuf {
|
||||
let state = self.state.lock().await;
|
||||
state.session_configuration.codex_home().clone()
|
||||
@@ -1800,6 +1818,7 @@ impl Session {
|
||||
&config.permissions.approval_policy,
|
||||
))),
|
||||
mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()),
|
||||
shared_mcp_backend_pool: agent_control.shared_mcp_backend_pool(),
|
||||
unified_exec_manager: UnifiedExecProcessManager::new(
|
||||
config.background_terminal_max_timeout,
|
||||
),
|
||||
@@ -1924,17 +1943,40 @@ impl Session {
|
||||
cancel_guard.cancel();
|
||||
*cancel_guard = CancellationToken::new();
|
||||
}
|
||||
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
|
||||
&mcp_servers,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth_statuses.clone(),
|
||||
&session_configuration.approval_policy,
|
||||
tx_event.clone(),
|
||||
sandbox_state,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth),
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
let (mcp_connection_manager, cancel_token) = async {
|
||||
match sess.services.shared_mcp_backend_pool.as_ref() {
|
||||
Some(shared_mcp_backend_pool) => {
|
||||
McpConnectionManager::new_with_pool(
|
||||
shared_mcp_backend_pool.as_ref(),
|
||||
SharedMcpBackendAcquireMode::ReuseExisting,
|
||||
&mcp_servers,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth_statuses.clone(),
|
||||
&session_configuration.approval_policy,
|
||||
tx_event.clone(),
|
||||
sandbox_state,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth),
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await
|
||||
}
|
||||
None => {
|
||||
McpConnectionManager::new(
|
||||
&mcp_servers,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth_statuses.clone(),
|
||||
&session_configuration.approval_policy,
|
||||
tx_event.clone(),
|
||||
sandbox_state,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth),
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(info_span!(
|
||||
"session_init.mcp_manager_init",
|
||||
otel.name = "session_init.mcp_manager_init",
|
||||
@@ -2308,23 +2350,29 @@ impl Session {
|
||||
) -> ConstraintResult<Arc<TurnContext>> {
|
||||
let (
|
||||
session_configuration,
|
||||
sandbox_policy_changed,
|
||||
sandbox_state,
|
||||
sandbox_state_changed,
|
||||
previous_cwd,
|
||||
codex_home,
|
||||
session_source,
|
||||
) = {
|
||||
let mut state = self.state.lock().await;
|
||||
match state.session_configuration.clone().apply(&updates) {
|
||||
let previous_session_configuration = state.session_configuration.clone();
|
||||
match previous_session_configuration.apply(&updates) {
|
||||
Ok(next) => {
|
||||
let previous_cwd = state.session_configuration.cwd.clone();
|
||||
let sandbox_policy_changed =
|
||||
state.session_configuration.sandbox_policy != next.sandbox_policy;
|
||||
let previous_cwd = previous_session_configuration.cwd.clone();
|
||||
let previous_sandbox_state = Self::sandbox_state_for_session_configuration(
|
||||
&previous_session_configuration,
|
||||
);
|
||||
let sandbox_state = Self::sandbox_state_for_session_configuration(&next);
|
||||
let sandbox_state_changed = previous_sandbox_state != sandbox_state;
|
||||
let codex_home = next.codex_home.clone();
|
||||
let session_source = next.session_source.clone();
|
||||
state.session_configuration = next.clone();
|
||||
(
|
||||
next,
|
||||
sandbox_policy_changed,
|
||||
sandbox_state,
|
||||
sandbox_state_changed,
|
||||
previous_cwd,
|
||||
codex_home,
|
||||
session_source,
|
||||
@@ -2357,7 +2405,8 @@ impl Session {
|
||||
sub_id,
|
||||
session_configuration,
|
||||
updates.final_output_json_schema,
|
||||
sandbox_policy_changed,
|
||||
sandbox_state,
|
||||
sandbox_state_changed,
|
||||
)
|
||||
.await)
|
||||
}
|
||||
@@ -2367,23 +2416,55 @@ impl Session {
|
||||
sub_id: String,
|
||||
session_configuration: SessionConfiguration,
|
||||
final_output_json_schema: Option<Option<Value>>,
|
||||
sandbox_policy_changed: bool,
|
||||
sandbox_state: SandboxState,
|
||||
sandbox_state_changed: bool,
|
||||
) -> Arc<TurnContext> {
|
||||
let per_turn_config = Self::build_per_turn_config(&session_configuration);
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_approval_policy(&session_configuration.approval_policy);
|
||||
|
||||
if sandbox_policy_changed {
|
||||
let sandbox_state = SandboxState {
|
||||
sandbox_policy: per_turn_config.permissions.sandbox_policy.get().clone(),
|
||||
codex_linux_sandbox_exe: per_turn_config.codex_linux_sandbox_exe.clone(),
|
||||
sandbox_cwd: per_turn_config.cwd.clone(),
|
||||
use_legacy_landlock: per_turn_config.features.use_legacy_landlock(),
|
||||
};
|
||||
if let Err(e) = self
|
||||
if sandbox_state_changed {
|
||||
if let Some(shared_mcp_backend_pool) = self.services.shared_mcp_backend_pool.as_ref() {
|
||||
let auth = self.services.auth_manager.auth().await;
|
||||
let config = session_configuration.original_config_do_not_use.clone();
|
||||
let mcp_servers = self
|
||||
.services
|
||||
.mcp_manager
|
||||
.effective_servers(config.as_ref(), auth.as_ref());
|
||||
let auth_statuses = compute_auth_statuses(
|
||||
mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
)
|
||||
.await;
|
||||
let tool_plugin_provenance = self
|
||||
.services
|
||||
.mcp_manager
|
||||
.tool_plugin_provenance(config.as_ref());
|
||||
{
|
||||
let manager = self.services.mcp_connection_manager.read().await;
|
||||
if let Err(e) = manager
|
||||
.notify_local_sandbox_state_change(&sandbox_state)
|
||||
.await
|
||||
{
|
||||
warn!("failed to notify local MCP sandbox state change: {e:#}");
|
||||
}
|
||||
let rebuilt_manager = manager
|
||||
.rebuild_pooled_backend(
|
||||
shared_mcp_backend_pool.as_ref(),
|
||||
SharedMcpBackendAcquireMode::ReuseExisting,
|
||||
&mcp_servers,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth_statuses,
|
||||
self.get_tx_event(),
|
||||
sandbox_state.clone(),
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await;
|
||||
drop(manager);
|
||||
let mut manager = self.services.mcp_connection_manager.write().await;
|
||||
*manager = rebuilt_manager;
|
||||
}
|
||||
} else if let Err(e) = self
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
@@ -2394,6 +2475,11 @@ impl Session {
|
||||
warn!("Failed to notify sandbox state change to MCP servers: {e:#}");
|
||||
}
|
||||
}
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_approval_policy(&session_configuration.approval_policy);
|
||||
|
||||
let model_info = self
|
||||
.services
|
||||
@@ -2536,9 +2622,10 @@ impl Session {
|
||||
};
|
||||
self.new_turn_from_configuration(
|
||||
sub_id,
|
||||
session_configuration,
|
||||
session_configuration.clone(),
|
||||
/*final_output_json_schema*/ None,
|
||||
/*sandbox_policy_changed*/ false,
|
||||
Self::sandbox_state_for_session_configuration(&session_configuration),
|
||||
/*sandbox_state_changed*/ false,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -4059,18 +4146,39 @@ impl Session {
|
||||
guard.cancel();
|
||||
*guard = CancellationToken::new();
|
||||
}
|
||||
let (refreshed_manager, cancel_token) = McpConnectionManager::new(
|
||||
&mcp_servers,
|
||||
store_mode,
|
||||
auth_statuses,
|
||||
&turn_context.config.permissions.approval_policy,
|
||||
self.get_tx_event(),
|
||||
sandbox_state,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await;
|
||||
let (refreshed_manager, cancel_token) = match self.services.shared_mcp_backend_pool.as_ref()
|
||||
{
|
||||
Some(shared_mcp_backend_pool) => {
|
||||
McpConnectionManager::new_with_pool(
|
||||
shared_mcp_backend_pool.as_ref(),
|
||||
SharedMcpBackendAcquireMode::ForceCreate,
|
||||
&mcp_servers,
|
||||
store_mode,
|
||||
auth_statuses,
|
||||
&turn_context.config.permissions.approval_policy,
|
||||
self.get_tx_event(),
|
||||
sandbox_state,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await
|
||||
}
|
||||
None => {
|
||||
McpConnectionManager::new(
|
||||
&mcp_servers,
|
||||
store_mode,
|
||||
auth_statuses,
|
||||
&turn_context.config.permissions.approval_policy,
|
||||
self.get_tx_event(),
|
||||
sandbox_state,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await
|
||||
}
|
||||
};
|
||||
{
|
||||
let mut guard = self.services.mcp_startup_cancellation_token.lock().await;
|
||||
if guard.is_cancelled() {
|
||||
|
||||
@@ -10,6 +10,9 @@ use crate::config_loader::Sourced;
|
||||
use crate::exec::ExecCapturePolicy;
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::mcp::ToolPluginProvenance;
|
||||
use crate::mcp_connection_manager::SharedMcpBackendAcquireMode;
|
||||
use crate::mcp_connection_manager::SharedMcpBackendPool;
|
||||
use crate::mcp_connection_manager::ToolInfo;
|
||||
use crate::models_manager::model_info;
|
||||
use crate::shell::default_user_shell;
|
||||
@@ -2466,6 +2469,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
|
||||
),
|
||||
)),
|
||||
mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()),
|
||||
shared_mcp_backend_pool: None,
|
||||
unified_exec_manager: UnifiedExecProcessManager::new(
|
||||
config.background_terminal_max_timeout,
|
||||
),
|
||||
@@ -3172,6 +3176,17 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
|
||||
Arc<Session>,
|
||||
Arc<TurnContext>,
|
||||
async_channel::Receiver<Event>,
|
||||
) {
|
||||
make_session_and_context_with_dynamic_tools_and_rx_and_pool(dynamic_tools, false).await
|
||||
}
|
||||
|
||||
async fn make_session_and_context_with_dynamic_tools_and_rx_and_pool(
|
||||
dynamic_tools: Vec<DynamicToolSpec>,
|
||||
pooled: bool,
|
||||
) -> (
|
||||
Arc<Session>,
|
||||
Arc<TurnContext>,
|
||||
async_channel::Receiver<Event>,
|
||||
) {
|
||||
let (tx_event, rx_event) = async_channel::unbounded();
|
||||
let codex_home = tempfile::tempdir().expect("create temp dir");
|
||||
@@ -3256,15 +3271,44 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
|
||||
.await
|
||||
.expect("create environment"),
|
||||
);
|
||||
let shared_mcp_backend_pool = pooled.then(|| Arc::new(SharedMcpBackendPool::new()));
|
||||
let sandbox_state = SandboxState {
|
||||
sandbox_policy: session_configuration.sandbox_policy.get().clone(),
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
sandbox_cwd: session_configuration.cwd.clone(),
|
||||
use_legacy_landlock: config.features.use_legacy_landlock(),
|
||||
};
|
||||
let (mcp_connection_manager, mcp_startup_cancellation_token) =
|
||||
match shared_mcp_backend_pool.as_ref() {
|
||||
Some(shared_mcp_backend_pool) => {
|
||||
McpConnectionManager::new_with_pool(
|
||||
shared_mcp_backend_pool.as_ref(),
|
||||
SharedMcpBackendAcquireMode::ReuseExisting,
|
||||
&HashMap::new(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
HashMap::new(),
|
||||
&config.permissions.approval_policy,
|
||||
tx_event.clone(),
|
||||
sandbox_state,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(Some(&auth_manager.auth().await.expect("auth"))),
|
||||
ToolPluginProvenance::default(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
None => (
|
||||
McpConnectionManager::new_mcp_connection_manager_for_tests(
|
||||
&config.permissions.approval_policy,
|
||||
),
|
||||
CancellationToken::new(),
|
||||
),
|
||||
};
|
||||
|
||||
let file_watcher = Arc::new(FileWatcher::noop());
|
||||
let services = SessionServices {
|
||||
mcp_connection_manager: Arc::new(RwLock::new(
|
||||
McpConnectionManager::new_mcp_connection_manager_for_tests(
|
||||
&config.permissions.approval_policy,
|
||||
),
|
||||
)),
|
||||
mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()),
|
||||
mcp_connection_manager: Arc::new(RwLock::new(mcp_connection_manager)),
|
||||
mcp_startup_cancellation_token: Mutex::new(mcp_startup_cancellation_token),
|
||||
shared_mcp_backend_pool,
|
||||
unified_exec_manager: UnifiedExecProcessManager::new(
|
||||
config.background_terminal_max_timeout,
|
||||
),
|
||||
@@ -3365,6 +3409,14 @@ pub(crate) async fn make_session_and_context_with_rx() -> (
|
||||
make_session_and_context_with_dynamic_tools_and_rx(Vec::new()).await
|
||||
}
|
||||
|
||||
async fn make_pooled_session_and_context_with_rx() -> (
|
||||
Arc<Session>,
|
||||
Arc<TurnContext>,
|
||||
async_channel::Receiver<Event>,
|
||||
) {
|
||||
make_session_and_context_with_dynamic_tools_and_rx_and_pool(Vec::new(), true).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_mcp_servers_is_deferred_until_next_turn() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
@@ -3407,6 +3459,29 @@ async fn refresh_mcp_servers_is_deferred_until_next_turn() {
|
||||
assert!(!new_token.is_cancelled());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pooled_new_turn_keeps_local_mcp_startup_token_when_sandbox_state_changes() {
|
||||
let (session, _turn_context, _rx) = make_pooled_session_and_context_with_rx().await;
|
||||
let old_token = session.mcp_startup_cancellation_token().await;
|
||||
assert!(!old_token.is_cancelled());
|
||||
|
||||
let new_cwd = session.get_config().await.cwd.join("other-worktree");
|
||||
session
|
||||
.new_turn_with_sub_id(
|
||||
"sandbox-change".to_string(),
|
||||
SessionSettingsUpdate {
|
||||
cwd: Some(new_cwd),
|
||||
..SessionSettingsUpdate::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("new turn");
|
||||
|
||||
assert!(!old_token.is_cancelled());
|
||||
let new_token = session.mcp_startup_cancellation_token().await;
|
||||
assert!(!new_token.is_cancelled());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn record_model_warning_appends_user_message() {
|
||||
let (mut session, turn_context) = make_session_and_context().await;
|
||||
|
||||
@@ -248,6 +248,14 @@ enum CachedCodexAppsToolsLoad {
|
||||
|
||||
type ResponderMap = HashMap<(String, RequestId), oneshot::Sender<ElicitationResponse>>;
|
||||
|
||||
fn decline_elicitation_response() -> ElicitationResponse {
|
||||
ElicitationResponse {
|
||||
action: ElicitationAction::Decline,
|
||||
content: None,
|
||||
meta: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn elicitation_is_rejected_by_policy(approval_policy: AskForApproval) -> bool {
|
||||
match approval_policy {
|
||||
AskForApproval::Never => true,
|
||||
@@ -300,11 +308,7 @@ impl ElicitationRequestManager {
|
||||
.lock()
|
||||
.is_ok_and(|policy| elicitation_is_rejected_by_policy(*policy))
|
||||
{
|
||||
return Ok(ElicitationResponse {
|
||||
action: ElicitationAction::Decline,
|
||||
content: None,
|
||||
meta: None,
|
||||
});
|
||||
return Ok(decline_elicitation_response());
|
||||
}
|
||||
|
||||
let request = match elicitation {
|
||||
@@ -584,7 +588,7 @@ pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
|
||||
/// When used, the `params` field of the notification is [`SandboxState`].
|
||||
pub const MCP_SANDBOX_STATE_METHOD: &str = "codex/sandbox-state/update";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SandboxState {
|
||||
pub sandbox_policy: SandboxPolicy,
|
||||
@@ -595,50 +599,27 @@ pub struct SandboxState {
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`RmcpClient`] instances.
|
||||
pub(crate) struct McpConnectionManager {
|
||||
struct SharedMcpBackend {
|
||||
clients: HashMap<String, AsyncManagedClient>,
|
||||
server_origins: HashMap<String, String>,
|
||||
elicitation_requests: ElicitationRequestManager,
|
||||
}
|
||||
|
||||
impl McpConnectionManager {
|
||||
pub(crate) fn new_uninitialized(approval_policy: &Constrained<AskForApproval>) -> Self {
|
||||
impl SharedMcpBackend {
|
||||
fn new_uninitialized() -> Self {
|
||||
Self {
|
||||
clients: HashMap::new(),
|
||||
server_origins: HashMap::new(),
|
||||
elicitation_requests: ElicitationRequestManager::new(approval_policy.value()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn new_mcp_connection_manager_for_tests(
|
||||
approval_policy: &Constrained<AskForApproval>,
|
||||
) -> Self {
|
||||
Self::new_uninitialized(approval_policy)
|
||||
}
|
||||
|
||||
pub(crate) fn has_servers(&self) -> bool {
|
||||
!self.clients.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn server_origin(&self, server_name: &str) -> Option<&str> {
|
||||
self.server_origins.get(server_name).map(String::as_str)
|
||||
}
|
||||
|
||||
pub fn set_approval_policy(&self, approval_policy: &Constrained<AskForApproval>) {
|
||||
if let Ok(mut policy) = self.elicitation_requests.approval_policy.lock() {
|
||||
*policy = approval_policy.value();
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::new_ret_no_self, clippy::too_many_arguments)]
|
||||
pub async fn new(
|
||||
async fn new(
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth_entries: HashMap<String, McpAuthStatusEntry>,
|
||||
approval_policy: &Constrained<AskForApproval>,
|
||||
tx_event: Sender<Event>,
|
||||
initial_sandbox_state: SandboxState,
|
||||
session_handle: SessionMcpHandle,
|
||||
codex_home: PathBuf,
|
||||
codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
|
||||
tool_plugin_provenance: ToolPluginProvenance,
|
||||
@@ -647,7 +628,7 @@ impl McpConnectionManager {
|
||||
let mut clients = HashMap::new();
|
||||
let mut server_origins = HashMap::new();
|
||||
let mut join_set = JoinSet::new();
|
||||
let elicitation_requests = ElicitationRequestManager::new(approval_policy.value());
|
||||
let elicitation_requests = session_handle.elicitation_requests();
|
||||
let tool_plugin_provenance = Arc::new(tool_plugin_provenance);
|
||||
let mcp_servers = mcp_servers.clone();
|
||||
for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) {
|
||||
@@ -725,10 +706,9 @@ impl McpConnectionManager {
|
||||
(server_name, outcome)
|
||||
});
|
||||
}
|
||||
let manager = Self {
|
||||
let backend = Self {
|
||||
clients,
|
||||
server_origins,
|
||||
elicitation_requests: elicitation_requests.clone(),
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
let outcomes = join_set.join_all().await;
|
||||
@@ -752,7 +732,19 @@ impl McpConnectionManager {
|
||||
})
|
||||
.await;
|
||||
});
|
||||
(manager, cancel_token)
|
||||
(backend, cancel_token)
|
||||
}
|
||||
|
||||
fn has_servers(&self) -> bool {
|
||||
!self.clients.is_empty()
|
||||
}
|
||||
|
||||
fn server_origin(&self, server_name: &str) -> Option<&str> {
|
||||
self.server_origins.get(server_name).map(String::as_str)
|
||||
}
|
||||
|
||||
fn contains_server(&self, server_name: &str) -> bool {
|
||||
self.clients.contains_key(server_name)
|
||||
}
|
||||
|
||||
async fn client_by_name(&self, name: &str) -> Result<ManagedClient> {
|
||||
@@ -764,18 +756,7 @@ impl McpConnectionManager {
|
||||
.context("failed to get client")
|
||||
}
|
||||
|
||||
pub async fn resolve_elicitation(
|
||||
&self,
|
||||
server_name: String,
|
||||
id: RequestId,
|
||||
response: ElicitationResponse,
|
||||
) -> Result<()> {
|
||||
self.elicitation_requests
|
||||
.resolve(server_name, id, response)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_for_server_ready(&self, server_name: &str, timeout: Duration) -> bool {
|
||||
async fn wait_for_server_ready(&self, server_name: &str, timeout: Duration) -> bool {
|
||||
let Some(async_managed_client) = self.clients.get(server_name) else {
|
||||
return false;
|
||||
};
|
||||
@@ -786,7 +767,7 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn required_startup_failures(
|
||||
async fn required_startup_failures(
|
||||
&self,
|
||||
required_servers: &[String],
|
||||
) -> Vec<McpStartupFailure> {
|
||||
@@ -814,7 +795,7 @@ impl McpConnectionManager {
|
||||
/// Returns a single map that contains all tools. Each key is the
|
||||
/// fully-qualified name for the tool.
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn list_all_tools(&self) -> HashMap<String, ToolInfo> {
|
||||
async fn list_all_tools(&self) -> HashMap<String, ToolInfo> {
|
||||
let mut tools = HashMap::new();
|
||||
for managed_client in self.clients.values() {
|
||||
let Some(server_tools) = managed_client.listed_tools().await else {
|
||||
@@ -830,7 +811,7 @@ impl McpConnectionManager {
|
||||
/// On success, the refreshed tools replace the cache contents and the
|
||||
/// latest filtered tool map is returned directly to the caller. On
|
||||
/// failure, the existing cache remains unchanged.
|
||||
pub async fn hard_refresh_codex_apps_tools_cache(&self) -> Result<HashMap<String, ToolInfo>> {
|
||||
async fn hard_refresh_codex_apps_tools_cache(&self) -> Result<HashMap<String, ToolInfo>> {
|
||||
let managed_client = self
|
||||
.clients
|
||||
.get(CODEX_APPS_MCP_SERVER_NAME)
|
||||
@@ -874,7 +855,7 @@ impl McpConnectionManager {
|
||||
|
||||
/// Returns a single map that contains all resources. Each key is the
|
||||
/// server name and the value is a vector of resources.
|
||||
pub async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
|
||||
async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
let clients_snapshot = &self.clients;
|
||||
@@ -886,7 +867,6 @@ impl McpConnectionManager {
|
||||
};
|
||||
let timeout = managed_client.tool_timeout;
|
||||
let client = managed_client.client.clone();
|
||||
|
||||
join_set.spawn(async move {
|
||||
let mut collected: Vec<Resource> = Vec::new();
|
||||
let mut cursor: Option<String> = None;
|
||||
@@ -940,7 +920,7 @@ impl McpConnectionManager {
|
||||
|
||||
/// Returns a single map that contains all resource templates. Each key is the
|
||||
/// server name and the value is a vector of resource templates.
|
||||
pub async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
|
||||
async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
let clients_snapshot = &self.clients;
|
||||
@@ -952,7 +932,6 @@ impl McpConnectionManager {
|
||||
};
|
||||
let client = managed_client.client.clone();
|
||||
let timeout = managed_client.tool_timeout;
|
||||
|
||||
join_set.spawn(async move {
|
||||
let mut collected: Vec<ResourceTemplate> = Vec::new();
|
||||
let mut cursor: Option<String> = None;
|
||||
@@ -1009,7 +988,7 @@ impl McpConnectionManager {
|
||||
}
|
||||
|
||||
/// Invoke the tool indicated by the (server, tool) pair.
|
||||
pub async fn call_tool(
|
||||
async fn call_tool(
|
||||
&self,
|
||||
server: &str,
|
||||
tool: &str,
|
||||
@@ -1047,7 +1026,7 @@ impl McpConnectionManager {
|
||||
}
|
||||
|
||||
/// List resources from the specified server.
|
||||
pub async fn list_resources(
|
||||
async fn list_resources(
|
||||
&self,
|
||||
server: &str,
|
||||
params: Option<PaginatedRequestParams>,
|
||||
@@ -1063,7 +1042,7 @@ impl McpConnectionManager {
|
||||
}
|
||||
|
||||
/// List resource templates from the specified server.
|
||||
pub async fn list_resource_templates(
|
||||
async fn list_resource_templates(
|
||||
&self,
|
||||
server: &str,
|
||||
params: Option<PaginatedRequestParams>,
|
||||
@@ -1079,7 +1058,7 @@ impl McpConnectionManager {
|
||||
}
|
||||
|
||||
/// Read a resource from the specified server.
|
||||
pub async fn read_resource(
|
||||
async fn read_resource(
|
||||
&self,
|
||||
server: &str,
|
||||
params: ReadResourceRequestParams,
|
||||
@@ -1095,14 +1074,14 @@ impl McpConnectionManager {
|
||||
.with_context(|| format!("resources/read failed for `{server}` ({uri})"))
|
||||
}
|
||||
|
||||
pub async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
|
||||
async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
|
||||
self.list_all_tools()
|
||||
.await
|
||||
.get(tool_name)
|
||||
.map(|tool| (tool.server_name.clone(), tool.tool.name.to_string()))
|
||||
}
|
||||
|
||||
pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
|
||||
async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for async_managed_client in self.clients.values() {
|
||||
@@ -1131,6 +1110,625 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct SharedMcpBackendCacheKey(String);
|
||||
|
||||
fn canonicalize_json_value(value: serde_json::Value) -> serde_json::Value {
|
||||
match value {
|
||||
serde_json::Value::Array(values) => {
|
||||
serde_json::Value::Array(values.into_iter().map(canonicalize_json_value).collect())
|
||||
}
|
||||
serde_json::Value::Object(object) => {
|
||||
let mut entries = object.into_iter().collect::<Vec<_>>();
|
||||
entries.sort_by(|(left_key, _), (right_key, _)| left_key.cmp(right_key));
|
||||
|
||||
let mut canonical = serde_json::Map::with_capacity(entries.len());
|
||||
for (key, value) in entries {
|
||||
canonical.insert(key, canonicalize_json_value(value));
|
||||
}
|
||||
|
||||
serde_json::Value::Object(canonical)
|
||||
}
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
impl SharedMcpBackendCacheKey {
|
||||
pub(crate) fn new(
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
sandbox_state: &SandboxState,
|
||||
) -> Self {
|
||||
let mut servers = mcp_servers
|
||||
.iter()
|
||||
.filter(|(_, config)| config.enabled)
|
||||
.map(|(name, config)| {
|
||||
(
|
||||
name,
|
||||
canonicalize_json_value(serde_json::to_value(config).unwrap_or_else(|error| {
|
||||
panic!("serializing MCP server config for cache key: {error}")
|
||||
})),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
servers.sort_by(|(left_name, _), (right_name, _)| left_name.cmp(right_name));
|
||||
let payload = canonicalize_json_value(serde_json::json!({
|
||||
"sandboxState": sandbox_state,
|
||||
"storeMode": store_mode,
|
||||
"servers": servers,
|
||||
}));
|
||||
Self(serde_json::to_string(&payload).unwrap_or_else(|error| {
|
||||
panic!("serializing shared MCP backend cache key payload: {error}")
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_poolable_transport(transport: &McpServerTransportConfig) -> bool {
|
||||
matches!(transport, McpServerTransportConfig::Stdio { .. })
|
||||
}
|
||||
|
||||
fn split_poolable_mcp_servers(
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
) -> (
|
||||
HashMap<String, McpServerConfig>,
|
||||
HashMap<String, McpServerConfig>,
|
||||
) {
|
||||
let mut pooled_servers = HashMap::new();
|
||||
let mut local_servers = HashMap::new();
|
||||
|
||||
for (server_name, config) in mcp_servers {
|
||||
if is_poolable_transport(&config.transport) {
|
||||
pooled_servers.insert(server_name.clone(), config.clone());
|
||||
} else {
|
||||
local_servers.insert(server_name.clone(), config.clone());
|
||||
}
|
||||
}
|
||||
|
||||
(pooled_servers, local_servers)
|
||||
}
|
||||
|
||||
fn filter_auth_entries_for_servers(
|
||||
auth_entries: &HashMap<String, McpAuthStatusEntry>,
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
) -> HashMap<String, McpAuthStatusEntry> {
|
||||
auth_entries
|
||||
.iter()
|
||||
.filter(|(server_name, _)| mcp_servers.contains_key(*server_name))
|
||||
.map(|(server_name, entry)| (server_name.clone(), entry.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SharedMcpBackendLease {
|
||||
inner: Arc<SharedMcpBackendLeaseInner>,
|
||||
}
|
||||
|
||||
impl SharedMcpBackendLease {
|
||||
fn new(backend: SharedMcpBackend, cancel_token: CancellationToken) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(SharedMcpBackendLeaseInner {
|
||||
backend: Arc::new(backend),
|
||||
cancel_token,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn backend(&self) -> Arc<SharedMcpBackend> {
|
||||
Arc::clone(&self.inner.backend)
|
||||
}
|
||||
}
|
||||
|
||||
struct SharedMcpBackendLeaseInner {
|
||||
backend: Arc<SharedMcpBackend>,
|
||||
cancel_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl Drop for SharedMcpBackendLeaseInner {
|
||||
fn drop(&mut self) {
|
||||
self.cancel_token.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct SharedMcpBackendPool {
|
||||
backends: Mutex<HashMap<SharedMcpBackendCacheKey, std::sync::Weak<SharedMcpBackendLeaseInner>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub(crate) enum SharedMcpBackendAcquireMode {
|
||||
ReuseExisting,
|
||||
ForceCreate,
|
||||
}
|
||||
|
||||
impl SharedMcpBackendPool {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn acquire_or_create(
|
||||
&self,
|
||||
key: SharedMcpBackendCacheKey,
|
||||
acquire_mode: SharedMcpBackendAcquireMode,
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth_entries: HashMap<String, McpAuthStatusEntry>,
|
||||
tx_event: Sender<Event>,
|
||||
initial_sandbox_state: SandboxState,
|
||||
session_handle: SessionMcpHandle,
|
||||
codex_home: PathBuf,
|
||||
codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
|
||||
tool_plugin_provenance: ToolPluginProvenance,
|
||||
) -> SharedMcpBackendLease {
|
||||
let mut backends = self.backends.lock().await;
|
||||
match acquire_mode {
|
||||
SharedMcpBackendAcquireMode::ReuseExisting => {
|
||||
if let Some(existing) = backends.get(&key).and_then(std::sync::Weak::upgrade) {
|
||||
return SharedMcpBackendLease { inner: existing };
|
||||
}
|
||||
}
|
||||
SharedMcpBackendAcquireMode::ForceCreate => {}
|
||||
}
|
||||
|
||||
let (backend, cancel_token) = SharedMcpBackend::new(
|
||||
mcp_servers,
|
||||
store_mode,
|
||||
auth_entries,
|
||||
tx_event,
|
||||
initial_sandbox_state,
|
||||
session_handle,
|
||||
codex_home,
|
||||
codex_apps_tools_cache_key,
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await;
|
||||
let lease = SharedMcpBackendLease::new(backend, cancel_token);
|
||||
backends.insert(key, Arc::downgrade(&lease.inner));
|
||||
lease
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn acquire_shared_backend_lease(
|
||||
pool: &SharedMcpBackendPool,
|
||||
acquire_mode: SharedMcpBackendAcquireMode,
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth_entries: HashMap<String, McpAuthStatusEntry>,
|
||||
tx_event: Sender<Event>,
|
||||
initial_sandbox_state: SandboxState,
|
||||
session_handle: SessionMcpHandle,
|
||||
codex_home: PathBuf,
|
||||
codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
|
||||
tool_plugin_provenance: ToolPluginProvenance,
|
||||
) -> Option<SharedMcpBackendLease> {
|
||||
let (pooled_servers, _) = split_poolable_mcp_servers(mcp_servers);
|
||||
if !pooled_servers.values().any(|server| server.enabled) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let key = SharedMcpBackendCacheKey::new(&pooled_servers, store_mode, &initial_sandbox_state);
|
||||
let pooled_auth_entries = filter_auth_entries_for_servers(&auth_entries, &pooled_servers);
|
||||
|
||||
Some(
|
||||
pool.acquire_or_create(
|
||||
key,
|
||||
acquire_mode,
|
||||
&pooled_servers,
|
||||
store_mode,
|
||||
pooled_auth_entries,
|
||||
tx_event,
|
||||
initial_sandbox_state,
|
||||
session_handle,
|
||||
codex_home,
|
||||
codex_apps_tools_cache_key,
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SessionMcpHandle {
|
||||
elicitation_requests: ElicitationRequestManager,
|
||||
}
|
||||
|
||||
impl SessionMcpHandle {
|
||||
fn new(approval_policy: AskForApproval) -> Self {
|
||||
Self {
|
||||
elicitation_requests: ElicitationRequestManager::new(approval_policy),
|
||||
}
|
||||
}
|
||||
|
||||
fn elicitation_requests(&self) -> ElicitationRequestManager {
|
||||
self.elicitation_requests.clone()
|
||||
}
|
||||
|
||||
fn set_approval_policy(&self, approval_policy: &Constrained<AskForApproval>) {
|
||||
if let Ok(mut policy) = self.elicitation_requests.approval_policy.lock() {
|
||||
*policy = approval_policy.value();
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_elicitation(
|
||||
&self,
|
||||
server_name: String,
|
||||
id: RequestId,
|
||||
response: ElicitationResponse,
|
||||
) -> Result<()> {
|
||||
self.elicitation_requests
|
||||
.resolve(server_name, id, response)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// A thin facade that keeps today's call sites stable while separating the
|
||||
/// sharable MCP backend from session-local response routing state.
|
||||
pub(crate) struct McpConnectionManager {
|
||||
backend: Arc<SharedMcpBackend>,
|
||||
shared_backend: Option<Arc<SharedMcpBackend>>,
|
||||
session: SessionMcpHandle,
|
||||
_shared_backend_lease: Option<SharedMcpBackendLease>,
|
||||
}
|
||||
|
||||
impl McpConnectionManager {
|
||||
fn backend_for_server(&self, server_name: &str) -> Option<Arc<SharedMcpBackend>> {
|
||||
if self.backend.contains_server(server_name) {
|
||||
Some(Arc::clone(&self.backend))
|
||||
} else {
|
||||
self.shared_backend
|
||||
.as_ref()
|
||||
.filter(|backend| backend.contains_server(server_name))
|
||||
.map(Arc::clone)
|
||||
}
|
||||
}
|
||||
|
||||
fn from_parts(
|
||||
backend: Arc<SharedMcpBackend>,
|
||||
shared_backend: Option<Arc<SharedMcpBackend>>,
|
||||
session: SessionMcpHandle,
|
||||
shared_backend_lease: Option<SharedMcpBackendLease>,
|
||||
) -> Self {
|
||||
Self {
|
||||
backend,
|
||||
shared_backend,
|
||||
session,
|
||||
_shared_backend_lease: shared_backend_lease,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_uninitialized(approval_policy: &Constrained<AskForApproval>) -> Self {
|
||||
Self::from_parts(
|
||||
Arc::new(SharedMcpBackend::new_uninitialized()),
|
||||
/*shared_backend*/ None,
|
||||
SessionMcpHandle::new(approval_policy.value()),
|
||||
/*shared_backend_lease*/ None,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn new_mcp_connection_manager_for_tests(
|
||||
approval_policy: &Constrained<AskForApproval>,
|
||||
) -> Self {
|
||||
Self::new_uninitialized(approval_policy)
|
||||
}
|
||||
|
||||
pub(crate) fn has_servers(&self) -> bool {
|
||||
self.backend.has_servers()
|
||||
|| self
|
||||
.shared_backend
|
||||
.as_ref()
|
||||
.is_some_and(|backend| backend.has_servers())
|
||||
}
|
||||
|
||||
pub(crate) fn server_origin(&self, server_name: &str) -> Option<&str> {
|
||||
self.backend
|
||||
.server_origin(server_name)
|
||||
.or_else(|| self.shared_backend.as_ref()?.server_origin(server_name))
|
||||
}
|
||||
|
||||
pub fn set_approval_policy(&self, approval_policy: &Constrained<AskForApproval>) {
|
||||
self.session.set_approval_policy(approval_policy);
|
||||
}
|
||||
|
||||
#[allow(clippy::new_ret_no_self, clippy::too_many_arguments)]
|
||||
pub async fn new(
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth_entries: HashMap<String, McpAuthStatusEntry>,
|
||||
approval_policy: &Constrained<AskForApproval>,
|
||||
tx_event: Sender<Event>,
|
||||
initial_sandbox_state: SandboxState,
|
||||
codex_home: PathBuf,
|
||||
codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
|
||||
tool_plugin_provenance: ToolPluginProvenance,
|
||||
) -> (Self, CancellationToken) {
|
||||
let session = SessionMcpHandle::new(approval_policy.value());
|
||||
let (backend, cancel_token) = SharedMcpBackend::new(
|
||||
mcp_servers,
|
||||
store_mode,
|
||||
auth_entries,
|
||||
tx_event,
|
||||
initial_sandbox_state,
|
||||
session.clone(),
|
||||
codex_home,
|
||||
codex_apps_tools_cache_key,
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await;
|
||||
(
|
||||
Self::from_parts(
|
||||
Arc::new(backend),
|
||||
/*shared_backend*/ None,
|
||||
session,
|
||||
/*shared_backend_lease*/ None,
|
||||
),
|
||||
cancel_token,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn new_with_pool(
|
||||
pool: &SharedMcpBackendPool,
|
||||
acquire_mode: SharedMcpBackendAcquireMode,
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth_entries: HashMap<String, McpAuthStatusEntry>,
|
||||
approval_policy: &Constrained<AskForApproval>,
|
||||
tx_event: Sender<Event>,
|
||||
initial_sandbox_state: SandboxState,
|
||||
codex_home: PathBuf,
|
||||
codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
|
||||
tool_plugin_provenance: ToolPluginProvenance,
|
||||
) -> (Self, CancellationToken) {
|
||||
let session = SessionMcpHandle::new(approval_policy.value());
|
||||
let (_, local_servers) = split_poolable_mcp_servers(mcp_servers);
|
||||
let local_auth_entries = filter_auth_entries_for_servers(&auth_entries, &local_servers);
|
||||
let shared_backend_lease = acquire_shared_backend_lease(
|
||||
pool,
|
||||
acquire_mode,
|
||||
mcp_servers,
|
||||
store_mode,
|
||||
auth_entries,
|
||||
tx_event.clone(),
|
||||
initial_sandbox_state.clone(),
|
||||
session.clone(),
|
||||
codex_home.clone(),
|
||||
codex_apps_tools_cache_key.clone(),
|
||||
tool_plugin_provenance.clone(),
|
||||
)
|
||||
.await;
|
||||
let shared_backend = shared_backend_lease
|
||||
.as_ref()
|
||||
.map(SharedMcpBackendLease::backend);
|
||||
let (backend, cancel_token) = if local_servers.values().any(|server| server.enabled) {
|
||||
let (backend, cancel_token) = SharedMcpBackend::new(
|
||||
&local_servers,
|
||||
store_mode,
|
||||
local_auth_entries,
|
||||
tx_event,
|
||||
initial_sandbox_state,
|
||||
session.clone(),
|
||||
codex_home,
|
||||
codex_apps_tools_cache_key,
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await;
|
||||
(Arc::new(backend), cancel_token)
|
||||
} else {
|
||||
(
|
||||
Arc::new(SharedMcpBackend::new_uninitialized()),
|
||||
CancellationToken::new(),
|
||||
)
|
||||
};
|
||||
(
|
||||
Self::from_parts(backend, shared_backend, session, shared_backend_lease),
|
||||
cancel_token,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn rebuild_pooled_backend(
|
||||
&self,
|
||||
pool: &SharedMcpBackendPool,
|
||||
acquire_mode: SharedMcpBackendAcquireMode,
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth_entries: HashMap<String, McpAuthStatusEntry>,
|
||||
tx_event: Sender<Event>,
|
||||
initial_sandbox_state: SandboxState,
|
||||
codex_home: PathBuf,
|
||||
codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
|
||||
tool_plugin_provenance: ToolPluginProvenance,
|
||||
) -> Self {
|
||||
let shared_backend_lease = acquire_shared_backend_lease(
|
||||
pool,
|
||||
acquire_mode,
|
||||
mcp_servers,
|
||||
store_mode,
|
||||
auth_entries,
|
||||
tx_event,
|
||||
initial_sandbox_state,
|
||||
self.session.clone(),
|
||||
codex_home,
|
||||
codex_apps_tools_cache_key,
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await;
|
||||
let shared_backend = shared_backend_lease
|
||||
.as_ref()
|
||||
.map(SharedMcpBackendLease::backend);
|
||||
Self::from_parts(
|
||||
Arc::clone(&self.backend),
|
||||
shared_backend,
|
||||
self.session.clone(),
|
||||
shared_backend_lease,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn resolve_elicitation(
|
||||
&self,
|
||||
server_name: String,
|
||||
id: RequestId,
|
||||
response: ElicitationResponse,
|
||||
) -> Result<()> {
|
||||
self.session
|
||||
.resolve_elicitation(server_name, id, response)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_for_server_ready(&self, server_name: &str, timeout: Duration) -> bool {
|
||||
match self.backend_for_server(server_name) {
|
||||
Some(backend) => backend.wait_for_server_ready(server_name, timeout).await,
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn required_startup_failures(
|
||||
&self,
|
||||
required_servers: &[String],
|
||||
) -> Vec<McpStartupFailure> {
|
||||
let mut failures = Vec::new();
|
||||
for server_name in required_servers {
|
||||
let Some(backend) = self.backend_for_server(server_name) else {
|
||||
failures.push(McpStartupFailure {
|
||||
server: server_name.clone(),
|
||||
error: format!("required MCP server `{server_name}` was not initialized"),
|
||||
});
|
||||
continue;
|
||||
};
|
||||
|
||||
failures.extend(
|
||||
backend
|
||||
.required_startup_failures(std::slice::from_ref(server_name))
|
||||
.await,
|
||||
);
|
||||
}
|
||||
failures
|
||||
}
|
||||
|
||||
/// Returns a single map that contains all tools. Each key is the
|
||||
/// fully-qualified name for the tool.
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn list_all_tools(&self) -> HashMap<String, ToolInfo> {
|
||||
let mut tools = self.backend.list_all_tools().await;
|
||||
if let Some(shared_backend) = self.shared_backend.as_ref() {
|
||||
tools.extend(shared_backend.list_all_tools().await);
|
||||
}
|
||||
tools
|
||||
}
|
||||
|
||||
/// Force-refresh codex apps tools by bypassing the in-process cache.
|
||||
///
|
||||
/// On success, the refreshed tools replace the cache contents and the
|
||||
/// latest filtered tool map is returned directly to the caller. On
|
||||
/// failure, the existing cache remains unchanged.
|
||||
pub async fn hard_refresh_codex_apps_tools_cache(&self) -> Result<HashMap<String, ToolInfo>> {
|
||||
self.backend.hard_refresh_codex_apps_tools_cache().await
|
||||
}
|
||||
|
||||
/// Returns a single map that contains all resources. Each key is the
|
||||
/// server name and the value is a vector of resources.
|
||||
pub async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
|
||||
let mut resources = self.backend.list_all_resources().await;
|
||||
if let Some(shared_backend) = self.shared_backend.as_ref() {
|
||||
resources.extend(shared_backend.list_all_resources().await);
|
||||
}
|
||||
resources
|
||||
}
|
||||
|
||||
/// Returns a single map that contains all resource templates. Each key is the
|
||||
/// server name and the value is a vector of resource templates.
|
||||
pub async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
|
||||
let mut templates = self.backend.list_all_resource_templates().await;
|
||||
if let Some(shared_backend) = self.shared_backend.as_ref() {
|
||||
templates.extend(shared_backend.list_all_resource_templates().await);
|
||||
}
|
||||
templates
|
||||
}
|
||||
|
||||
/// Invoke the tool indicated by the (server, tool) pair.
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
server: &str,
|
||||
tool: &str,
|
||||
arguments: Option<serde_json::Value>,
|
||||
meta: Option<serde_json::Value>,
|
||||
) -> Result<CallToolResult> {
|
||||
self.backend_for_server(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
|
||||
.call_tool(server, tool, arguments, meta)
|
||||
.await
|
||||
}
|
||||
|
||||
/// List resources from the specified server.
|
||||
pub async fn list_resources(
|
||||
&self,
|
||||
server: &str,
|
||||
params: Option<PaginatedRequestParams>,
|
||||
) -> Result<ListResourcesResult> {
|
||||
self.backend_for_server(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
|
||||
.list_resources(server, params)
|
||||
.await
|
||||
}
|
||||
|
||||
/// List resource templates from the specified server.
|
||||
pub async fn list_resource_templates(
|
||||
&self,
|
||||
server: &str,
|
||||
params: Option<PaginatedRequestParams>,
|
||||
) -> Result<ListResourceTemplatesResult> {
|
||||
self.backend_for_server(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
|
||||
.list_resource_templates(server, params)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Read a resource from the specified server.
|
||||
pub async fn read_resource(
|
||||
&self,
|
||||
server: &str,
|
||||
params: ReadResourceRequestParams,
|
||||
) -> Result<ReadResourceResult> {
|
||||
self.backend_for_server(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
|
||||
.read_resource(server, params)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
|
||||
if let Some(parsed) = self.backend.parse_tool_name(tool_name).await {
|
||||
Some(parsed)
|
||||
} else if let Some(shared_backend) = self.shared_backend.as_ref() {
|
||||
shared_backend.parse_tool_name(tool_name).await
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
|
||||
self.backend
|
||||
.notify_sandbox_state_change(sandbox_state)
|
||||
.await?;
|
||||
if let Some(shared_backend) = self.shared_backend.as_ref() {
|
||||
shared_backend
|
||||
.notify_sandbox_state_change(sandbox_state)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn notify_local_sandbox_state_change(
|
||||
&self,
|
||||
sandbox_state: &SandboxState,
|
||||
) -> Result<()> {
|
||||
self.backend
|
||||
.notify_sandbox_state_change(sandbox_state)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
async fn emit_update(
|
||||
tx_event: &Sender<Event>,
|
||||
update: McpStartupUpdateEvent,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use super::*;
|
||||
use codex_protocol::protocol::GranularApprovalConfig;
|
||||
use codex_protocol::protocol::McpAuthStatus;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::model::JsonObject;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tempfile::tempdir;
|
||||
@@ -60,6 +62,98 @@ fn create_codex_apps_tools_cache_context(
|
||||
}
|
||||
}
|
||||
|
||||
fn test_codex_apps_tools_cache_key() -> CodexAppsToolsCacheKey {
|
||||
CodexAppsToolsCacheKey {
|
||||
account_id: None,
|
||||
chatgpt_user_id: None,
|
||||
is_workspace_account: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn stdio_server_config(command: &str) -> McpServerConfig {
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: command.to_string(),
|
||||
args: Vec::new(),
|
||||
env: None,
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
},
|
||||
enabled: true,
|
||||
required: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(1)),
|
||||
tool_timeout_sec: Some(Duration::from_secs(1)),
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
oauth_resource: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn http_server_config(url: &str) -> McpServerConfig {
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: url.to_string(),
|
||||
bearer_token_env_var: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
},
|
||||
enabled: true,
|
||||
required: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(1)),
|
||||
tool_timeout_sec: Some(Duration::from_secs(1)),
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
oauth_resource: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn test_sandbox_state(cwd: PathBuf) -> SandboxState {
|
||||
SandboxState {
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
codex_linux_sandbox_exe: None,
|
||||
sandbox_cwd: cwd,
|
||||
use_legacy_landlock: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn pooled_cache_key_for_tests(
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
sandbox_state: &SandboxState,
|
||||
) -> SharedMcpBackendCacheKey {
|
||||
let (pooled_servers, _) = split_poolable_mcp_servers(mcp_servers);
|
||||
SharedMcpBackendCacheKey::new(&pooled_servers, store_mode, sandbox_state)
|
||||
}
|
||||
|
||||
async fn new_pooled_manager_for_tests(
|
||||
pool: &SharedMcpBackendPool,
|
||||
acquire_mode: SharedMcpBackendAcquireMode,
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
sandbox_state: SandboxState,
|
||||
) -> McpConnectionManager {
|
||||
let approval_policy = Constrained::allow_any(AskForApproval::OnRequest);
|
||||
let (tx_event, _rx_event) = async_channel::unbounded();
|
||||
let (manager, _cancel_token) = McpConnectionManager::new_with_pool(
|
||||
pool,
|
||||
acquire_mode,
|
||||
mcp_servers,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
HashMap::new(),
|
||||
&approval_policy,
|
||||
tx_event,
|
||||
sandbox_state,
|
||||
PathBuf::from("/tmp"),
|
||||
test_codex_apps_tools_cache_key(),
|
||||
ToolPluginProvenance::default(),
|
||||
)
|
||||
.await;
|
||||
manager
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn elicitation_granular_policy_defaults_to_prompting() {
|
||||
assert!(!elicitation_is_rejected_by_policy(
|
||||
@@ -409,15 +503,18 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() {
|
||||
.shared();
|
||||
let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
|
||||
let mut manager = McpConnectionManager::new_uninitialized(&approval_policy);
|
||||
manager.clients.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
Arc::get_mut(&mut manager.backend)
|
||||
.expect("test manager backend should be uniquely owned")
|
||||
.clients
|
||||
.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
|
||||
let tools = manager.list_all_tools().await;
|
||||
let tool = tools
|
||||
@@ -434,15 +531,18 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot(
|
||||
.shared();
|
||||
let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
|
||||
let mut manager = McpConnectionManager::new_uninitialized(&approval_policy);
|
||||
manager.clients.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
startup_snapshot: None,
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
Arc::get_mut(&mut manager.backend)
|
||||
.expect("test manager backend should be uniquely owned")
|
||||
.clients
|
||||
.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
startup_snapshot: None,
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
|
||||
let timeout_result =
|
||||
tokio::time::timeout(Duration::from_millis(10), manager.list_all_tools()).await;
|
||||
@@ -456,15 +556,18 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty(
|
||||
.shared();
|
||||
let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
|
||||
let mut manager = McpConnectionManager::new_uninitialized(&approval_policy);
|
||||
manager.clients.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
startup_snapshot: Some(Vec::new()),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
Arc::get_mut(&mut manager.backend)
|
||||
.expect("test manager backend should be uniquely owned")
|
||||
.clients
|
||||
.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
startup_snapshot: Some(Vec::new()),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
|
||||
let timeout_result =
|
||||
tokio::time::timeout(Duration::from_millis(10), manager.list_all_tools()).await;
|
||||
@@ -488,15 +591,18 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
|
||||
let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
|
||||
let mut manager = McpConnectionManager::new_uninitialized(&approval_policy);
|
||||
let startup_complete = Arc::new(std::sync::atomic::AtomicBool::new(true));
|
||||
manager.clients.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: failed_client,
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete,
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
Arc::get_mut(&mut manager.backend)
|
||||
.expect("test manager backend should be uniquely owned")
|
||||
.clients
|
||||
.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: failed_client,
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete,
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
|
||||
let tools = manager.list_all_tools().await;
|
||||
let tool = tools
|
||||
@@ -506,6 +612,246 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
|
||||
assert_eq!(tool.tool_name, "calendar_create_event");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_tool_name_searches_shared_backend() {
|
||||
let approval_policy = Constrained::allow_any(AskForApproval::OnFailure);
|
||||
let shared_tool = create_test_tool("shared_stdio", "tool_a");
|
||||
let pending_client = futures::future::pending::<Result<ManagedClient, StartupOutcomeError>>()
|
||||
.boxed()
|
||||
.shared();
|
||||
|
||||
let mut shared_backend = SharedMcpBackend::new_uninitialized();
|
||||
shared_backend.clients.insert(
|
||||
"shared_stdio".to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
startup_snapshot: Some(vec![shared_tool]),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
|
||||
let manager = McpConnectionManager::from_parts(
|
||||
Arc::new(SharedMcpBackend::new_uninitialized()),
|
||||
Some(Arc::new(shared_backend)),
|
||||
SessionMcpHandle::new(approval_policy.value()),
|
||||
None,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
manager.parse_tool_name("mcp__shared_stdio__tool_a").await,
|
||||
Some(("shared_stdio".to_string(), "tool_a".to_string()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shared_mcp_backend_cache_key_is_stable_for_equivalent_stdio_configs() {
|
||||
let mut left_servers = HashMap::new();
|
||||
left_servers.insert("beta".to_string(), stdio_server_config("missing-beta"));
|
||||
left_servers.insert("alpha".to_string(), stdio_server_config("missing-alpha"));
|
||||
|
||||
let mut right_servers = HashMap::new();
|
||||
right_servers.insert("alpha".to_string(), stdio_server_config("missing-alpha"));
|
||||
right_servers.insert("beta".to_string(), stdio_server_config("missing-beta"));
|
||||
|
||||
let sandbox_state = test_sandbox_state(PathBuf::from("/tmp/shared"));
|
||||
|
||||
assert_eq!(
|
||||
pooled_cache_key_for_tests(
|
||||
&left_servers,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&sandbox_state,
|
||||
),
|
||||
pooled_cache_key_for_tests(
|
||||
&right_servers,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&sandbox_state,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shared_mcp_backend_cache_key_ignores_http_servers() {
|
||||
let sandbox_state = test_sandbox_state(PathBuf::from("/tmp/shared"));
|
||||
|
||||
let mut left_servers = HashMap::new();
|
||||
left_servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
|
||||
left_servers.insert(
|
||||
"http".to_string(),
|
||||
http_server_config("http://127.0.0.1:9/left"),
|
||||
);
|
||||
|
||||
let mut right_servers = HashMap::new();
|
||||
right_servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
|
||||
right_servers.insert(
|
||||
"http".to_string(),
|
||||
http_server_config("http://127.0.0.1:9/right"),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
pooled_cache_key_for_tests(
|
||||
&left_servers,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&sandbox_state,
|
||||
),
|
||||
pooled_cache_key_for_tests(
|
||||
&right_servers,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&sandbox_state,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn split_poolable_mcp_servers_keeps_http_servers_local() {
|
||||
let mut servers = HashMap::new();
|
||||
servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
|
||||
servers.insert(
|
||||
"http".to_string(),
|
||||
http_server_config("http://127.0.0.1:9/http"),
|
||||
);
|
||||
|
||||
let (pooled_servers, local_servers) = split_poolable_mcp_servers(&servers);
|
||||
|
||||
assert_eq!(pooled_servers.len(), 1);
|
||||
assert!(pooled_servers.contains_key("stdio"));
|
||||
assert_eq!(local_servers.len(), 1);
|
||||
assert!(local_servers.contains_key("http"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shared_mcp_backend_cache_key_separates_sandbox_state() {
|
||||
let mut servers = HashMap::new();
|
||||
servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
|
||||
|
||||
assert_ne!(
|
||||
pooled_cache_key_for_tests(
|
||||
&servers,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&test_sandbox_state(PathBuf::from("/tmp/left")),
|
||||
),
|
||||
pooled_cache_key_for_tests(
|
||||
&servers,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&test_sandbox_state(PathBuf::from("/tmp/right")),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shared_mcp_backend_pool_reuses_backend_for_same_stdio_config() {
|
||||
let pool = SharedMcpBackendPool::new();
|
||||
let sandbox_state = test_sandbox_state(PathBuf::from("/tmp/shared"));
|
||||
let mut servers = HashMap::new();
|
||||
servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
|
||||
|
||||
let manager_1 = new_pooled_manager_for_tests(
|
||||
&pool,
|
||||
SharedMcpBackendAcquireMode::ReuseExisting,
|
||||
&servers,
|
||||
sandbox_state.clone(),
|
||||
)
|
||||
.await;
|
||||
let manager_2 = new_pooled_manager_for_tests(
|
||||
&pool,
|
||||
SharedMcpBackendAcquireMode::ReuseExisting,
|
||||
&servers,
|
||||
sandbox_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(Arc::ptr_eq(
|
||||
manager_1
|
||||
.shared_backend
|
||||
.as_ref()
|
||||
.expect("stdio backend should be pooled"),
|
||||
manager_2
|
||||
.shared_backend
|
||||
.as_ref()
|
||||
.expect("stdio backend should be pooled"),
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shared_mcp_backend_pool_separates_backends_for_different_sandbox_states() {
|
||||
let pool = SharedMcpBackendPool::new();
|
||||
let mut servers = HashMap::new();
|
||||
servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
|
||||
|
||||
let manager_1 = new_pooled_manager_for_tests(
|
||||
&pool,
|
||||
SharedMcpBackendAcquireMode::ReuseExisting,
|
||||
&servers,
|
||||
test_sandbox_state(PathBuf::from("/tmp/left")),
|
||||
)
|
||||
.await;
|
||||
let manager_2 = new_pooled_manager_for_tests(
|
||||
&pool,
|
||||
SharedMcpBackendAcquireMode::ReuseExisting,
|
||||
&servers,
|
||||
test_sandbox_state(PathBuf::from("/tmp/right")),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(!Arc::ptr_eq(
|
||||
manager_1
|
||||
.shared_backend
|
||||
.as_ref()
|
||||
.expect("stdio backend should be pooled"),
|
||||
manager_2
|
||||
.shared_backend
|
||||
.as_ref()
|
||||
.expect("stdio backend should be pooled"),
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shared_mcp_backend_pool_force_create_replaces_pool_entry_for_same_key() {
|
||||
let pool = SharedMcpBackendPool::new();
|
||||
let mut servers = HashMap::new();
|
||||
servers.insert("stdio".to_string(), stdio_server_config("missing-stdio"));
|
||||
let sandbox_state = test_sandbox_state(PathBuf::from("/tmp/shared"));
|
||||
|
||||
let manager_1 = new_pooled_manager_for_tests(
|
||||
&pool,
|
||||
SharedMcpBackendAcquireMode::ReuseExisting,
|
||||
&servers,
|
||||
sandbox_state.clone(),
|
||||
)
|
||||
.await;
|
||||
let manager_2 = new_pooled_manager_for_tests(
|
||||
&pool,
|
||||
SharedMcpBackendAcquireMode::ForceCreate,
|
||||
&servers,
|
||||
sandbox_state.clone(),
|
||||
)
|
||||
.await;
|
||||
let manager_3 = new_pooled_manager_for_tests(
|
||||
&pool,
|
||||
SharedMcpBackendAcquireMode::ReuseExisting,
|
||||
&servers,
|
||||
sandbox_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
let shared_1 = manager_1
|
||||
.shared_backend
|
||||
.as_ref()
|
||||
.expect("stdio backend should be pooled");
|
||||
let shared_2 = manager_2
|
||||
.shared_backend
|
||||
.as_ref()
|
||||
.expect("stdio backend should be pooled");
|
||||
let shared_3 = manager_3
|
||||
.shared_backend
|
||||
.as_ref()
|
||||
.expect("stdio backend should be pooled");
|
||||
|
||||
assert!(!Arc::ptr_eq(shared_1, shared_2));
|
||||
assert!(Arc::ptr_eq(shared_2, shared_3));
|
||||
assert!(!Arc::ptr_eq(shared_1, shared_3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn elicitation_capability_enabled_only_for_codex_apps() {
|
||||
let codex_apps_capability = elicitation_capability_for_server(CODEX_APPS_MCP_SERVER_NAME);
|
||||
|
||||
@@ -11,6 +11,7 @@ use crate::exec_policy::ExecPolicyManager;
|
||||
use crate::file_watcher::FileWatcher;
|
||||
use crate::mcp::McpManager;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::mcp_connection_manager::SharedMcpBackendPool;
|
||||
use crate::models_manager::manager::ModelsManager;
|
||||
use crate::plugins::PluginsManager;
|
||||
use crate::skills::SkillsManager;
|
||||
@@ -33,6 +34,7 @@ use tokio_util::sync::CancellationToken;
|
||||
pub(crate) struct SessionServices {
|
||||
pub(crate) mcp_connection_manager: Arc<RwLock<McpConnectionManager>>,
|
||||
pub(crate) mcp_startup_cancellation_token: Mutex<CancellationToken>,
|
||||
pub(crate) shared_mcp_backend_pool: Option<Arc<SharedMcpBackendPool>>,
|
||||
pub(crate) unified_exec_manager: UnifiedExecProcessManager,
|
||||
#[cfg_attr(not(unix), allow(dead_code))]
|
||||
pub(crate) shell_zsh_path: Option<PathBuf>,
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::error::Result as CodexResult;
|
||||
use crate::file_watcher::FileWatcher;
|
||||
use crate::file_watcher::FileWatcherEvent;
|
||||
use crate::mcp::McpManager;
|
||||
use crate::mcp_connection_manager::SharedMcpBackendPool;
|
||||
use crate::models_manager::collaboration_mode_presets::CollaborationModesConfig;
|
||||
use crate::models_manager::manager::ModelsManager;
|
||||
use crate::plugins::PluginsManager;
|
||||
@@ -155,6 +156,7 @@ pub(crate) struct ThreadManagerState {
|
||||
skills_manager: Arc<SkillsManager>,
|
||||
plugins_manager: Arc<PluginsManager>,
|
||||
mcp_manager: Arc<McpManager>,
|
||||
shared_mcp_backend_pool: Arc<SharedMcpBackendPool>,
|
||||
file_watcher: Arc<FileWatcher>,
|
||||
session_source: SessionSource,
|
||||
// Captures submitted ops for testing purpose when test mode is enabled.
|
||||
@@ -202,6 +204,7 @@ impl ThreadManager {
|
||||
skills_manager,
|
||||
plugins_manager,
|
||||
mcp_manager,
|
||||
shared_mcp_backend_pool: Arc::new(SharedMcpBackendPool::new()),
|
||||
file_watcher,
|
||||
auth_manager,
|
||||
session_source,
|
||||
@@ -266,6 +269,7 @@ impl ThreadManager {
|
||||
skills_manager,
|
||||
plugins_manager,
|
||||
mcp_manager,
|
||||
shared_mcp_backend_pool: Arc::new(SharedMcpBackendPool::new()),
|
||||
file_watcher,
|
||||
auth_manager,
|
||||
session_source: SessionSource::Exec,
|
||||
@@ -582,6 +586,10 @@ impl ThreadManager {
|
||||
}
|
||||
|
||||
impl ThreadManagerState {
|
||||
pub(crate) fn shared_mcp_backend_pool(&self) -> Arc<SharedMcpBackendPool> {
|
||||
Arc::clone(&self.shared_mcp_backend_pool)
|
||||
}
|
||||
|
||||
pub(crate) async fn list_thread_ids(&self) -> Vec<ThreadId> {
|
||||
self.threads.read().await.keys().copied().collect()
|
||||
}
|
||||
|
||||
@@ -73,8 +73,6 @@ ignore = [
|
||||
{ id = "RUSTSEC-2024-0388", reason = "derivative is unmaintained; pulled in via starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
|
||||
{ id = "RUSTSEC-2025-0057", reason = "fxhash is unmaintained; pulled in via starlark_map/starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
|
||||
{ id = "RUSTSEC-2024-0436", reason = "paste is unmaintained; pulled in via ratatui/rmcp/starlark used by tui/execpolicy; no fixed release yet" },
|
||||
# TODO(joshka, nornagon): remove this exception when once we update the ratatui fork to a version that uses lru 0.13+.
|
||||
{ id = "RUSTSEC-2026-0002", reason = "lru 0.12.5 is pulled in via ratatui fork; cannot upgrade until the fork is updated" },
|
||||
# TODO(fcoury): remove this exception when syntect drops yaml-rust and bincode, or updates to versions that have fixed the vulnerabilities.
|
||||
{ id = "RUSTSEC-2024-0320", reason = "yaml-rust is unmaintained; pulled in via syntect v5.3.0 used by codex-tui for syntax highlighting; no fixed release yet" },
|
||||
{ id = "RUSTSEC-2025-0141", reason = "bincode is unmaintained; pulled in via syntect v5.3.0 used by codex-tui for syntax highlighting; no fixed release yet" },
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rmcp::ErrorData as McpError;
|
||||
@@ -36,6 +38,7 @@ struct TestToolServer {
|
||||
const MEMO_URI: &str = "memo://codex/example-note";
|
||||
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
|
||||
const SMALL_PNG_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==";
|
||||
const STARTUP_COUNT_FILE_ENV_VAR: &str = "MCP_STARTUP_COUNT_FILE";
|
||||
|
||||
pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) {
|
||||
(tokio::io::stdin(), tokio::io::stdout())
|
||||
@@ -454,8 +457,18 @@ fn parse_data_url(url: &str) -> Option<(String, String)> {
|
||||
Some((mime.to_string(), data.to_string()))
|
||||
}
|
||||
|
||||
fn record_startup_if_requested() -> std::io::Result<()> {
|
||||
let Some(path) = std::env::var_os(STARTUP_COUNT_FILE_ENV_VAR) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let mut file = OpenOptions::new().create(true).append(true).open(path)?;
|
||||
writeln!(file, "started")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
record_startup_if_requested()?;
|
||||
eprintln!("starting rmcp test server");
|
||||
// Run the server with STDIO transport. If the client disconnects we simply
|
||||
// bubble up the error so the process exits.
|
||||
|
||||
Reference in New Issue
Block a user