mirror of
https://github.com/openai/codex.git
synced 2026-04-17 11:14:48 +00:00
Compare commits
25 Commits
dev/ningyi
...
dev/remote
| Author | SHA1 | Date | |
|---|---|---|---|
| | 2ac989f2c8 | | |
| | 05c5a981f0 | | |
| | 4fd43636d7 | | |
| | 12c2397da2 | | |
| | cc17c031c1 | | |
| | b3f866f46e | | |
| | 076eb80792 | | |
| | 42555edee1 | | |
| | d7586c4629 | | |
| | 5d8c2338ed | | |
| | cb80678ac9 | | |
| | 4478580860 | | |
| | 330e966cf3 | | |
| | 1ae664ecda | | |
| | 2917f7273d | | |
| | 1d4a2663d2 | | |
| | a087158ebe | | |
| | fdfe2ca044 | | |
| | 97fd5d45f3 | | |
| | 0666bc7110 | | |
| | d94930b3fa | | |
| | 6e1452d0db | | |
| | e77367575f | | |
| | 53ea2a8dda | | |
| | 086ae0abe5 | | |
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
@@ -2396,6 +2396,7 @@ dependencies = [
|
||||
"async-channel",
|
||||
"codex-async-utils",
|
||||
"codex-config",
|
||||
"codex-exec-server",
|
||||
"codex-login",
|
||||
"codex-otel",
|
||||
"codex-plugin",
|
||||
@@ -2683,6 +2684,7 @@ dependencies = [
|
||||
"axum",
|
||||
"codex-client",
|
||||
"codex-config",
|
||||
"codex-exec-server",
|
||||
"codex-keyring-store",
|
||||
"codex-protocol",
|
||||
"codex-utils-cargo-bin",
|
||||
|
||||
@@ -270,6 +270,7 @@ use codex_login::default_client::set_default_client_residency_requirement;
|
||||
use codex_login::login_with_api_key;
|
||||
use codex_login::request_device_code;
|
||||
use codex_login::run_login_server;
|
||||
use codex_mcp::McpRuntimeEnvironment;
|
||||
use codex_mcp::McpServerStatusSnapshot;
|
||||
use codex_mcp::McpSnapshotDetail;
|
||||
use codex_mcp::collect_mcp_server_status_snapshot_with_detail;
|
||||
@@ -5392,10 +5393,38 @@ impl CodexMessageProcessor {
|
||||
.to_mcp_config(self.thread_manager.plugins_manager().as_ref())
|
||||
.await;
|
||||
let auth = self.auth_manager.auth().await;
|
||||
let runtime_environment = match self.thread_manager.environment_manager().current().await {
|
||||
Ok(Some(environment)) => {
|
||||
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf())
|
||||
}
|
||||
Ok(None) => McpRuntimeEnvironment::new(
|
||||
Arc::new(codex_exec_server::Environment::default()),
|
||||
config.cwd.to_path_buf(),
|
||||
),
|
||||
Err(err) => {
|
||||
// TODO(aibrahim): Investigate degrading MCP status listing when
|
||||
// executor environment creation fails.
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to create environment: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
Self::list_mcp_server_status_task(outgoing, request, params, config, mcp_config, auth)
|
||||
.await;
|
||||
Self::list_mcp_server_status_task(
|
||||
outgoing,
|
||||
request,
|
||||
params,
|
||||
config,
|
||||
mcp_config,
|
||||
auth,
|
||||
runtime_environment,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -5406,6 +5435,7 @@ impl CodexMessageProcessor {
|
||||
config: Config,
|
||||
mcp_config: codex_mcp::McpConfig,
|
||||
auth: Option<CodexAuth>,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) {
|
||||
let detail = match params.detail.unwrap_or(McpServerStatusDetail::Full) {
|
||||
McpServerStatusDetail::Full => McpSnapshotDetail::Full,
|
||||
@@ -5416,6 +5446,7 @@ impl CodexMessageProcessor {
|
||||
&mcp_config,
|
||||
auth.as_ref(),
|
||||
request_id.request_id.to_string(),
|
||||
runtime_environment,
|
||||
detail,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -16,6 +16,7 @@ anyhow = { workspace = true }
|
||||
async-channel = { workspace = true }
|
||||
codex-async-utils = { workspace = true }
|
||||
codex-config = { workspace = true }
|
||||
codex-exec-server = { workspace = true }
|
||||
codex-login = { workspace = true }
|
||||
codex-otel = { workspace = true }
|
||||
codex-plugin = { workspace = true }
|
||||
|
||||
@@ -38,6 +38,7 @@ pub use mcp_connection_manager::CodexAppsToolsCacheKey;
|
||||
pub use mcp_connection_manager::DEFAULT_STARTUP_TIMEOUT;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_META_CAPABILITY;
|
||||
pub use mcp_connection_manager::McpConnectionManager;
|
||||
pub use mcp_connection_manager::McpRuntimeEnvironment;
|
||||
pub use mcp_connection_manager::SandboxState;
|
||||
pub use mcp_connection_manager::ToolInfo;
|
||||
pub use mcp_connection_manager::codex_apps_tools_cache_key;
|
||||
|
||||
@@ -35,6 +35,7 @@ use codex_protocol::protocol::SandboxPolicy;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::mcp_connection_manager::McpRuntimeEnvironment;
|
||||
use crate::mcp_connection_manager::codex_apps_tools_cache_key;
|
||||
pub type McpManager = McpConnectionManager;
|
||||
|
||||
@@ -320,14 +321,23 @@ pub async fn collect_mcp_snapshot(
|
||||
config: &McpConfig,
|
||||
auth: Option<&CodexAuth>,
|
||||
submit_id: String,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> McpListToolsResponseEvent {
|
||||
collect_mcp_snapshot_with_detail(config, auth, submit_id, McpSnapshotDetail::Full).await
|
||||
collect_mcp_snapshot_with_detail(
|
||||
config,
|
||||
auth,
|
||||
submit_id,
|
||||
runtime_environment,
|
||||
McpSnapshotDetail::Full,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn collect_mcp_snapshot_with_detail(
|
||||
config: &McpConfig,
|
||||
auth: Option<&CodexAuth>,
|
||||
submit_id: String,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
detail: McpSnapshotDetail,
|
||||
) -> McpListToolsResponseEvent {
|
||||
let mcp_servers = effective_mcp_servers(config, auth);
|
||||
@@ -355,6 +365,7 @@ pub async fn collect_mcp_snapshot_with_detail(
|
||||
submit_id,
|
||||
tx_event,
|
||||
SandboxPolicy::new_read_only_policy(),
|
||||
runtime_environment,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth),
|
||||
tool_plugin_provenance,
|
||||
@@ -385,15 +396,23 @@ pub async fn collect_mcp_server_status_snapshot(
|
||||
config: &McpConfig,
|
||||
auth: Option<&CodexAuth>,
|
||||
submit_id: String,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> McpServerStatusSnapshot {
|
||||
collect_mcp_server_status_snapshot_with_detail(config, auth, submit_id, McpSnapshotDetail::Full)
|
||||
.await
|
||||
collect_mcp_server_status_snapshot_with_detail(
|
||||
config,
|
||||
auth,
|
||||
submit_id,
|
||||
runtime_environment,
|
||||
McpSnapshotDetail::Full,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn collect_mcp_server_status_snapshot_with_detail(
|
||||
config: &McpConfig,
|
||||
auth: Option<&CodexAuth>,
|
||||
submit_id: String,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
detail: McpSnapshotDetail,
|
||||
) -> McpServerStatusSnapshot {
|
||||
let mcp_servers = effective_mcp_servers(config, auth);
|
||||
@@ -421,6 +440,7 @@ pub async fn collect_mcp_server_status_snapshot_with_detail(
|
||||
submit_id,
|
||||
tx_event,
|
||||
SandboxPolicy::new_read_only_policy(),
|
||||
runtime_environment,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth),
|
||||
tool_plugin_provenance,
|
||||
|
||||
@@ -36,6 +36,7 @@ use codex_async_utils::CancelErr;
|
||||
use codex_async_utils::OrCancelExt;
|
||||
use codex_config::Constrained;
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
use codex_exec_server::Environment;
|
||||
use codex_protocol::ToolName;
|
||||
use codex_protocol::approvals::ElicitationRequest;
|
||||
use codex_protocol::approvals::ElicitationRequestEvent;
|
||||
@@ -50,8 +51,11 @@ use codex_protocol::protocol::McpStartupStatus;
|
||||
use codex_protocol::protocol::McpStartupUpdateEvent;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::ExecutorStdioServerLauncher;
|
||||
use codex_rmcp_client::LocalStdioServerLauncher;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use codex_rmcp_client::SendElicitation;
|
||||
use codex_rmcp_client::StdioServerLauncher;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::future::FutureExt;
|
||||
use futures::future::Shared;
|
||||
@@ -491,6 +495,7 @@ impl AsyncManagedClient {
|
||||
elicitation_requests: ElicitationRequestManager,
|
||||
codex_apps_tools_cache_context: Option<CodexAppsToolsCacheContext>,
|
||||
tool_plugin_provenance: Arc<ToolPluginProvenance>,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> Self {
|
||||
let tool_filter = ToolFilter::from_config(&config);
|
||||
let startup_snapshot = load_startup_cached_codex_apps_tools_snapshot(
|
||||
@@ -507,8 +512,15 @@ impl AsyncManagedClient {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let client =
|
||||
Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?);
|
||||
let client = Arc::new(
|
||||
make_rmcp_client(
|
||||
&server_name,
|
||||
config.clone(),
|
||||
store_mode,
|
||||
runtime_environment,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
match start_server_task(
|
||||
server_name,
|
||||
client,
|
||||
@@ -648,6 +660,32 @@ pub struct McpConnectionManager {
|
||||
elicitation_requests: ElicitationRequestManager,
|
||||
}
|
||||
|
||||
/// Runtime placement information used when starting MCP server transports.
|
||||
///
|
||||
/// `McpConfig` describes what servers exist. This value describes where those
|
||||
/// servers should run for the current caller. Keep it explicit at manager
|
||||
/// construction time so status/snapshot paths and real sessions make the same
|
||||
/// local-vs-remote decision.
|
||||
#[derive(Clone)]
|
||||
pub struct McpRuntimeEnvironment {
|
||||
environment: Arc<Environment>,
|
||||
cwd: PathBuf,
|
||||
}
|
||||
|
||||
impl McpRuntimeEnvironment {
|
||||
pub fn new(environment: Arc<Environment>, cwd: PathBuf) -> Self {
|
||||
Self { environment, cwd }
|
||||
}
|
||||
|
||||
fn environment(&self) -> Arc<Environment> {
|
||||
Arc::clone(&self.environment)
|
||||
}
|
||||
|
||||
fn cwd(&self) -> PathBuf {
|
||||
self.cwd.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl McpConnectionManager {
|
||||
pub fn configured_servers(&self, config: &McpConfig) -> HashMap<String, McpServerConfig> {
|
||||
configured_mcp_servers(config)
|
||||
@@ -708,6 +746,7 @@ impl McpConnectionManager {
|
||||
submit_id: String,
|
||||
tx_event: Sender<Event>,
|
||||
initial_sandbox_policy: SandboxPolicy,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
codex_home: PathBuf,
|
||||
codex_apps_tools_cache_key: CodexAppsToolsCacheKey,
|
||||
tool_plugin_provenance: ToolPluginProvenance,
|
||||
@@ -752,6 +791,7 @@ impl McpConnectionManager {
|
||||
elicitation_requests.clone(),
|
||||
codex_apps_tools_cache_context,
|
||||
Arc::clone(&tool_plugin_provenance),
|
||||
runtime_environment.clone(),
|
||||
);
|
||||
clients.insert(server_name.clone(), async_managed_client.clone());
|
||||
let tx_event = tx_event.clone();
|
||||
@@ -1482,9 +1522,25 @@ struct StartServerTaskParams {
|
||||
|
||||
async fn make_rmcp_client(
|
||||
server_name: &str,
|
||||
transport: McpServerTransportConfig,
|
||||
config: McpServerConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> Result<RmcpClient, StartupOutcomeError> {
|
||||
let McpServerConfig {
|
||||
transport,
|
||||
experimental_environment,
|
||||
..
|
||||
} = config;
|
||||
let remote_environment = match experimental_environment.as_deref() {
|
||||
None | Some("local") => false,
|
||||
Some("remote") => true,
|
||||
Some(environment) => {
|
||||
return Err(StartupOutcomeError::from(anyhow!(
|
||||
"unsupported experimental_environment `{environment}` for MCP server `{server_name}`"
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
match transport {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
@@ -1500,7 +1556,25 @@ async fn make_rmcp_client(
|
||||
.map(|(key, value)| (key.into(), value.into()))
|
||||
.collect::<HashMap<_, _>>()
|
||||
});
|
||||
RmcpClient::new_stdio_client(command_os, args_os, env_os, &env_vars, cwd)
|
||||
let launcher = if remote_environment {
|
||||
let exec_environment = runtime_environment.environment();
|
||||
if !exec_environment.is_remote() {
|
||||
return Err(StartupOutcomeError::from(anyhow!(
|
||||
"remote MCP server `{server_name}` requires a remote executor environment"
|
||||
)));
|
||||
}
|
||||
Arc::new(ExecutorStdioServerLauncher::new(
|
||||
exec_environment.get_exec_backend(),
|
||||
runtime_environment.cwd(),
|
||||
))
|
||||
} else {
|
||||
Arc::new(LocalStdioServerLauncher) as Arc<dyn StdioServerLauncher>
|
||||
};
|
||||
|
||||
// `RmcpClient` always sees a launched MCP stdio server. The
|
||||
// launcher hides whether that means a local child process or an
|
||||
// executor process whose stdin/stdout bytes cross the process API.
|
||||
RmcpClient::new_stdio_client(command_os, args_os, env_os, &env_vars, cwd, launcher)
|
||||
.await
|
||||
.map_err(|err| StartupOutcomeError::from(anyhow!(err)))
|
||||
}
|
||||
@@ -1510,6 +1584,24 @@ async fn make_rmcp_client(
|
||||
env_http_headers,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
if remote_environment {
|
||||
let exec_environment = runtime_environment.environment();
|
||||
if !exec_environment.is_remote() {
|
||||
return Err(StartupOutcomeError::from(anyhow!(
|
||||
"remote MCP server `{server_name}` requires a remote executor environment"
|
||||
)));
|
||||
}
|
||||
return Err(StartupOutcomeError::from(anyhow!(
|
||||
// Remote HTTP needs the future low-level executor
|
||||
// `network/request` API so reqwest runs on the executor side.
|
||||
// Do not fall back to local HTTP here; the config explicitly
|
||||
// asked for remote placement.
|
||||
"remote streamable HTTP MCP server `{server_name}` is not implemented yet"
|
||||
)));
|
||||
}
|
||||
|
||||
// Local streamable HTTP remains the existing reqwest path from
|
||||
// the orchestrator process.
|
||||
let resolved_bearer_token =
|
||||
match resolve_bearer_token(server_name, bearer_token_env_var.as_deref()) {
|
||||
Ok(token) => token,
|
||||
|
||||
@@ -80,6 +80,7 @@ use codex_login::CodexAuth;
|
||||
use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
|
||||
use codex_login::default_client::originator;
|
||||
use codex_mcp::McpConnectionManager;
|
||||
use codex_mcp::McpRuntimeEnvironment;
|
||||
use codex_mcp::ToolInfo;
|
||||
use codex_mcp::codex_apps_tools_cache_key;
|
||||
#[cfg(test)]
|
||||
@@ -2154,7 +2155,7 @@ impl Session {
|
||||
code_mode_service: crate::tools::code_mode::CodeModeService::new(
|
||||
config.js_repl_node_path.clone(),
|
||||
),
|
||||
environment,
|
||||
environment: environment.clone(),
|
||||
};
|
||||
services
|
||||
.model_client
|
||||
@@ -2248,6 +2249,12 @@ impl Session {
|
||||
INITIAL_SUBMIT_ID.to_owned(),
|
||||
tx_event.clone(),
|
||||
session_configuration.sandbox_policy.get().clone(),
|
||||
McpRuntimeEnvironment::new(
|
||||
environment
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(Environment::default())),
|
||||
session_configuration.cwd.to_path_buf(),
|
||||
),
|
||||
config.codex_home.to_path_buf(),
|
||||
codex_apps_tools_cache_key(auth),
|
||||
tool_plugin_provenance,
|
||||
@@ -4584,6 +4591,13 @@ impl Session {
|
||||
turn_context.sub_id.clone(),
|
||||
self.get_tx_event(),
|
||||
turn_context.sandbox_policy.get().clone(),
|
||||
McpRuntimeEnvironment::new(
|
||||
turn_context
|
||||
.environment
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(Environment::default())),
|
||||
turn_context.cwd.to_path_buf(),
|
||||
),
|
||||
config.codex_home.to_path_buf(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
tool_plugin_provenance,
|
||||
|
||||
@@ -14,6 +14,7 @@ pub use codex_app_server_protocol::AppInfo;
|
||||
pub use codex_app_server_protocol::AppMetadata;
|
||||
use codex_connectors::AllConnectorsCacheKey;
|
||||
use codex_connectors::DirectoryListResponse;
|
||||
use codex_exec_server::Environment;
|
||||
use codex_login::token_data::TokenData;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_tools::DiscoverableTool;
|
||||
@@ -40,6 +41,7 @@ use codex_login::default_client::is_first_party_chat_originator;
|
||||
use codex_login::default_client::originator;
|
||||
use codex_mcp::CODEX_APPS_MCP_SERVER_NAME;
|
||||
use codex_mcp::McpConnectionManager;
|
||||
use codex_mcp::McpRuntimeEnvironment;
|
||||
use codex_mcp::ToolInfo;
|
||||
use codex_mcp::ToolPluginProvenance;
|
||||
use codex_mcp::codex_apps_tools_cache_key;
|
||||
@@ -233,6 +235,7 @@ pub async fn list_accessible_connectors_from_mcp_tools_with_options_and_status(
|
||||
INITIAL_SUBMIT_ID.to_owned(),
|
||||
tx_event,
|
||||
SandboxPolicy::new_read_only_policy(),
|
||||
McpRuntimeEnvironment::new(Arc::new(Environment::default()), config.cwd.to_path_buf()),
|
||||
config.codex_home.to_path_buf(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
ToolPluginProvenance::default(),
|
||||
|
||||
@@ -152,6 +152,7 @@ fn exec_server_params_for_request(
|
||||
env_policy,
|
||||
env,
|
||||
tty,
|
||||
pipe_stdin: false,
|
||||
arg0: request.arg0.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use super::process::UnifiedExecProcess;
|
||||
use crate::unified_exec::UnifiedExecError;
|
||||
use async_trait::async_trait;
|
||||
use codex_exec_server::ExecProcess;
|
||||
use codex_exec_server::ExecProcessEventReceiver;
|
||||
use codex_exec_server::ExecServerError;
|
||||
use codex_exec_server::ProcessId;
|
||||
use codex_exec_server::ReadResponse;
|
||||
@@ -33,6 +34,10 @@ impl ExecProcess for MockExecProcess {
|
||||
self.wake_tx.subscribe()
|
||||
}
|
||||
|
||||
fn subscribe_events(&self) -> ExecProcessEventReceiver {
|
||||
panic!("MockExecProcess does not support event streaming")
|
||||
}
|
||||
|
||||
async fn read(
|
||||
&self,
|
||||
_after_seq: Option<u64>,
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
use anyhow::Context as _;
|
||||
use anyhow::ensure;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsStr;
|
||||
use std::ffi::OsString;
|
||||
use std::fs;
|
||||
use std::net::TcpListener;
|
||||
use std::path::Path;
|
||||
use std::process::Command as StdCommand;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
@@ -34,6 +39,7 @@ use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_utils_cargo_bin::cargo_bin;
|
||||
use core_test_support::assert_regex_match;
|
||||
use core_test_support::remote_env_env_var;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_custom_tool_call;
|
||||
use core_test_support::responses::mount_models_once;
|
||||
@@ -86,6 +92,66 @@ enum McpCallEvent {
|
||||
End(String),
|
||||
}
|
||||
|
||||
const REMOTE_MCP_ENVIRONMENT: &str = "remote";
|
||||
|
||||
fn remote_aware_experimental_environment() -> Option<String> {
|
||||
// These tests run locally in normal CI and against the Docker-backed
|
||||
// executor in full-ci. Match that shared test environment instead of
|
||||
// parameterizing each stdio MCP test with its own local/remote cases.
|
||||
std::env::var_os(remote_env_env_var()).map(|_| REMOTE_MCP_ENVIRONMENT.to_string())
|
||||
}
|
||||
|
||||
fn remote_aware_stdio_server_bin() -> anyhow::Result<String> {
|
||||
let bin = stdio_server_bin()?;
|
||||
let Some(container_name) = std::env::var_os(remote_env_env_var()) else {
|
||||
return Ok(bin);
|
||||
};
|
||||
let container_name = container_name
|
||||
.into_string()
|
||||
.map_err(|value| anyhow::anyhow!("remote env container name must be utf-8: {value:?}"))?;
|
||||
|
||||
// Keep the Docker path rewrite scoped to tests that use `build_remote_aware`.
|
||||
// Other MCP tests still start their stdio server from the orchestrator test
|
||||
// process, even when the full-ci remote env is present.
|
||||
//
|
||||
// Remote-aware MCP tests run the executor inside Docker. The stdio test
|
||||
// server is built on the host, so hand the executor a copied in-container
|
||||
// path instead of the host build artifact path.
|
||||
// Several remote-aware MCP tests can run in parallel; give each copied
|
||||
// binary its own path so one test cannot replace another test's executable.
|
||||
let unique_suffix = SystemTime::now().duration_since(UNIX_EPOCH)?.as_nanos();
|
||||
let remote_path = format!(
|
||||
"/tmp/codex-remote-env/test_stdio_server-{}-{unique_suffix}",
|
||||
std::process::id()
|
||||
);
|
||||
let container_target = format!("{container_name}:{remote_path}");
|
||||
let copy_output = StdCommand::new("docker")
|
||||
.arg("cp")
|
||||
.arg(&bin)
|
||||
.arg(&container_target)
|
||||
.output()
|
||||
.with_context(|| format!("copy {bin} to remote MCP test env"))?;
|
||||
ensure!(
|
||||
copy_output.status.success(),
|
||||
"docker cp test_stdio_server failed: stdout={} stderr={}",
|
||||
String::from_utf8_lossy(&copy_output.stdout).trim(),
|
||||
String::from_utf8_lossy(&copy_output.stderr).trim()
|
||||
);
|
||||
|
||||
let chmod_output = StdCommand::new("docker")
|
||||
.args(["exec", &container_name, "chmod", "+x", remote_path.as_str()])
|
||||
.output()
|
||||
.context("mark remote test_stdio_server executable")?;
|
||||
ensure!(
|
||||
chmod_output.status.success(),
|
||||
"docker chmod test_stdio_server failed: stdout={} stderr={}",
|
||||
String::from_utf8_lossy(&chmod_output.stdout).trim(),
|
||||
String::from_utf8_lossy(&chmod_output.stderr).trim()
|
||||
);
|
||||
|
||||
Ok(remote_path)
|
||||
}
|
||||
|
||||
async fn wait_for_mcp_tool(fixture: &TestCodex, tool_name: &str) -> anyhow::Result<()> {
|
||||
let tools_ready_deadline = Instant::now() + Duration::from_secs(30);
|
||||
loop {
|
||||
@@ -115,6 +181,7 @@ async fn wait_for_mcp_tool(fixture: &TestCodex, tool_name: &str) -> anyhow::Resu
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestMcpServerOptions {
|
||||
experimental_environment: Option<String>,
|
||||
supports_parallel_tool_calls: bool,
|
||||
tool_timeout_sec: Option<Duration>,
|
||||
}
|
||||
@@ -144,7 +211,7 @@ fn insert_mcp_server(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
transport,
|
||||
experimental_environment: None,
|
||||
experimental_environment: options.experimental_environment,
|
||||
enabled: true,
|
||||
required: false,
|
||||
supports_parallel_tool_calls: options.supports_parallel_tool_calls,
|
||||
@@ -198,7 +265,7 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
|
||||
.await;
|
||||
|
||||
let expected_env_value = "propagated-env";
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let rmcp_test_server_bin = remote_aware_stdio_server_bin()?;
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
@@ -213,10 +280,13 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
|
||||
)])),
|
||||
Vec::new(),
|
||||
),
|
||||
TestMcpServerOptions::default(),
|
||||
TestMcpServerOptions {
|
||||
experimental_environment: remote_aware_experimental_environment(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.build_remote_aware(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
@@ -342,17 +412,20 @@ async fn stdio_mcp_tool_call_includes_sandbox_state_meta() -> anyhow::Result<()>
|
||||
)
|
||||
.await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let rmcp_test_server_bin = remote_aware_stdio_server_bin()?;
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
insert_mcp_server(
|
||||
config,
|
||||
server_name,
|
||||
stdio_transport(rmcp_test_server_bin, /*env*/ None, Vec::new()),
|
||||
TestMcpServerOptions::default(),
|
||||
TestMcpServerOptions {
|
||||
experimental_environment: remote_aware_experimental_environment(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.build_remote_aware(&server)
|
||||
.await?;
|
||||
|
||||
let tools_ready_deadline = Instant::now() + Duration::from_secs(30);
|
||||
@@ -414,7 +487,7 @@ async fn stdio_mcp_tool_call_includes_sandbox_state_meta() -> anyhow::Result<()>
|
||||
);
|
||||
assert_eq!(
|
||||
sandbox_meta.get("sandboxCwd").and_then(Value::as_str),
|
||||
fixture.cwd.path().to_str()
|
||||
fixture.config.cwd.as_path().to_str()
|
||||
);
|
||||
assert_eq!(sandbox_meta.get("useLegacyLandlock"), Some(&json!(false)));
|
||||
|
||||
@@ -454,7 +527,7 @@ async fn stdio_mcp_parallel_tool_calls_default_false_runs_serially() -> anyhow::
|
||||
)
|
||||
.await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let rmcp_test_server_bin = remote_aware_stdio_server_bin()?;
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
@@ -463,12 +536,13 @@ async fn stdio_mcp_parallel_tool_calls_default_false_runs_serially() -> anyhow::
|
||||
server_name,
|
||||
stdio_transport(rmcp_test_server_bin, /*env*/ None, Vec::new()),
|
||||
TestMcpServerOptions {
|
||||
experimental_environment: remote_aware_experimental_environment(),
|
||||
tool_timeout_sec: Some(Duration::from_secs(2)),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.build_remote_aware(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
@@ -585,7 +659,7 @@ async fn stdio_mcp_parallel_tool_calls_opt_in_runs_concurrently() -> anyhow::Res
|
||||
)
|
||||
.await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let rmcp_test_server_bin = remote_aware_stdio_server_bin()?;
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
@@ -594,12 +668,13 @@ async fn stdio_mcp_parallel_tool_calls_opt_in_runs_concurrently() -> anyhow::Res
|
||||
server_name,
|
||||
stdio_transport(rmcp_test_server_bin, /*env*/ None, Vec::new()),
|
||||
TestMcpServerOptions {
|
||||
experimental_environment: remote_aware_experimental_environment(),
|
||||
supports_parallel_tool_calls: true,
|
||||
tool_timeout_sec: Some(Duration::from_secs(2)),
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.build_remote_aware(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
@@ -675,7 +750,7 @@ async fn stdio_image_responses_round_trip() -> anyhow::Result<()> {
|
||||
.await;
|
||||
|
||||
// Build the stdio rmcp server and pass the image as data URL so it can construct ImageContent.
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let rmcp_test_server_bin = remote_aware_stdio_server_bin()?;
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
@@ -690,10 +765,13 @@ async fn stdio_image_responses_round_trip() -> anyhow::Result<()> {
|
||||
)])),
|
||||
Vec::new(),
|
||||
),
|
||||
TestMcpServerOptions::default(),
|
||||
TestMcpServerOptions {
|
||||
experimental_environment: remote_aware_experimental_environment(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.build_remote_aware(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
@@ -828,7 +906,7 @@ async fn stdio_image_responses_preserve_original_detail_metadata() -> anyhow::Re
|
||||
)
|
||||
.await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let rmcp_test_server_bin = remote_aware_stdio_server_bin()?;
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_model("gpt-5.3-codex")
|
||||
@@ -837,10 +915,13 @@ async fn stdio_image_responses_preserve_original_detail_metadata() -> anyhow::Re
|
||||
config,
|
||||
server_name,
|
||||
stdio_transport(rmcp_test_server_bin, /*env*/ None, Vec::new()),
|
||||
TestMcpServerOptions::default(),
|
||||
TestMcpServerOptions {
|
||||
experimental_environment: remote_aware_experimental_environment(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.build_remote_aware(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
@@ -1050,7 +1131,7 @@ async fn stdio_image_responses_are_sanitized_for_text_only_model() -> anyhow::Re
|
||||
)
|
||||
.await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let rmcp_test_server_bin = remote_aware_stdio_server_bin()?;
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
|
||||
@@ -1066,10 +1147,13 @@ async fn stdio_image_responses_are_sanitized_for_text_only_model() -> anyhow::Re
|
||||
)])),
|
||||
Vec::new(),
|
||||
),
|
||||
TestMcpServerOptions::default(),
|
||||
TestMcpServerOptions {
|
||||
experimental_environment: remote_aware_experimental_environment(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.build_remote_aware(&server)
|
||||
.await?;
|
||||
|
||||
fixture
|
||||
@@ -1165,7 +1249,7 @@ async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> {
|
||||
|
||||
let expected_env_value = "propagated-env-from-whitelist";
|
||||
let _guard = EnvVarGuard::set("MCP_TEST_VALUE", OsStr::new(expected_env_value));
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let rmcp_test_server_bin = remote_aware_stdio_server_bin()?;
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
@@ -1177,10 +1261,13 @@ async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> {
|
||||
/*env*/ None,
|
||||
vec!["MCP_TEST_VALUE".to_string()],
|
||||
),
|
||||
TestMcpServerOptions::default(),
|
||||
TestMcpServerOptions {
|
||||
experimental_environment: remote_aware_experimental_environment(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.build_remote_aware(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
|
||||
@@ -85,6 +85,7 @@ Request params:
|
||||
"PATH": "/usr/bin:/bin"
|
||||
},
|
||||
"tty": true,
|
||||
"pipeStdin": false,
|
||||
"arg0": null
|
||||
}
|
||||
```
|
||||
@@ -95,8 +96,8 @@ Field definitions:
|
||||
- `argv`: command vector. It must be non-empty.
|
||||
- `cwd`: absolute working directory used for the child process.
|
||||
- `env`: environment variables passed to the child process.
|
||||
- `tty`: when `true`, spawn a PTY-backed interactive process; when `false`,
|
||||
spawn a pipe-backed process with closed stdin.
|
||||
- `tty`: when `true`, spawn a PTY-backed interactive process.
|
||||
- `pipeStdin`: when `true`, keep non-PTY stdin writable via `process/write`.
|
||||
- `arg0`: optional argv0 override forwarded to `codex-utils-pty`.
|
||||
|
||||
Response:
|
||||
@@ -111,7 +112,7 @@ Behavior notes:
|
||||
|
||||
- Reusing an existing `processId` is rejected.
|
||||
- PTY-backed processes accept later writes through `process/write`.
|
||||
- Pipe-backed processes are launched with stdin closed and reject writes.
|
||||
- Non-PTY processes reject writes unless `pipeStdin` is `true`.
|
||||
- Output is streamed asynchronously via `process/output`.
|
||||
- Exit is reported asynchronously via `process/exited`.
|
||||
|
||||
@@ -153,7 +154,7 @@ Response:
|
||||
|
||||
### `process/write`
|
||||
|
||||
Writes raw bytes to a running PTY-backed process stdin.
|
||||
Writes raw bytes to a running process stdin.
|
||||
|
||||
Request params:
|
||||
|
||||
@@ -177,7 +178,7 @@ Response:
|
||||
Behavior notes:
|
||||
|
||||
- Writes to an unknown `processId` are rejected.
|
||||
- Writes to a non-PTY process are rejected because stdin is already closed.
|
||||
- Writes to a non-PTY process are rejected unless it started with `pipeStdin`.
|
||||
|
||||
### `process/terminate`
|
||||
|
||||
@@ -325,7 +326,7 @@ Initialize:
|
||||
Start a process:
|
||||
|
||||
```json
|
||||
{"id":2,"method":"process/start","params":{"processId":"proc-1","argv":["bash","-lc","printf 'ready\\n'; while IFS= read -r line; do printf 'echo:%s\\n' \"$line\"; done"],"cwd":"/tmp","env":{"PATH":"/usr/bin:/bin"},"tty":true,"arg0":null}}
|
||||
{"id":2,"method":"process/start","params":{"processId":"proc-1","argv":["bash","-lc","printf 'ready\\n'; while IFS= read -r line; do printf 'echo:%s\\n' \"$line\"; done"],"cwd":"/tmp","env":{"PATH":"/usr/bin:/bin"},"tty":true,"pipeStdin":false,"arg0":null}}
|
||||
{"id":2,"result":{"processId":"proc-1"}}
|
||||
{"method":"process/output","params":{"processId":"proc-1","seq":1,"stream":"stdout","chunk":"cmVhZHkK"}}
|
||||
```
|
||||
|
||||
@@ -16,6 +16,9 @@ use crate::ProcessId;
|
||||
use crate::client_api::ExecServerClientConnectOptions;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::process::ExecProcessEvent;
|
||||
use crate::process::ExecProcessEventLog;
|
||||
use crate::process::ExecProcessEventReceiver;
|
||||
use crate::protocol::EXEC_CLOSED_METHOD;
|
||||
use crate::protocol::EXEC_EXITED_METHOD;
|
||||
use crate::protocol::EXEC_METHOD;
|
||||
@@ -53,6 +56,7 @@ use crate::protocol::INITIALIZE_METHOD;
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
use crate::protocol::InitializeParams;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::ProcessOutputChunk;
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::ReadResponse;
|
||||
use crate::protocol::TerminateParams;
|
||||
@@ -65,6 +69,7 @@ use crate::rpc::RpcClientEvent;
|
||||
|
||||
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const PROCESS_EVENT_CHANNEL_CAPACITY: usize = 256;
|
||||
|
||||
impl Default for ExecServerClientConnectOptions {
|
||||
fn default() -> Self {
|
||||
@@ -100,6 +105,7 @@ impl RemoteExecServerConnectArgs {
|
||||
|
||||
pub(crate) struct SessionState {
|
||||
wake_tx: watch::Sender<u64>,
|
||||
events: ExecProcessEventLog,
|
||||
failure: Mutex<Option<String>>,
|
||||
}
|
||||
|
||||
@@ -121,6 +127,11 @@ struct Inner {
|
||||
// need serialization so concurrent register/remove operations do not
|
||||
// overwrite each other's copy-on-write updates.
|
||||
sessions_write_lock: Mutex<()>,
|
||||
// Once the transport closes, every executor operation should fail quickly
|
||||
// with the same message. This process/filesystem-level latch prevents
|
||||
// callers from waiting on request-specific timeouts after the environment
|
||||
// is gone.
|
||||
disconnected: std::sync::RwLock<Option<String>>,
|
||||
session_id: std::sync::RwLock<Option<String>>,
|
||||
reader_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
@@ -152,6 +163,8 @@ pub enum ExecServerError {
|
||||
InitializeTimedOut { timeout: Duration },
|
||||
#[error("exec-server transport closed")]
|
||||
Closed,
|
||||
#[error("{0}")]
|
||||
Disconnected(String),
|
||||
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error("exec-server protocol error: {0}")]
|
||||
@@ -227,19 +240,11 @@ impl ExecServerClient {
|
||||
}
|
||||
|
||||
pub async fn exec(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(EXEC_METHOD, &params)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(EXEC_METHOD, &params).await
|
||||
}
|
||||
|
||||
pub async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(EXEC_READ_METHOD, &params)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(EXEC_READ_METHOD, &params).await
|
||||
}
|
||||
|
||||
pub async fn write(
|
||||
@@ -247,107 +252,73 @@ impl ExecServerClient {
|
||||
process_id: &ProcessId,
|
||||
chunk: Vec<u8>,
|
||||
) -> Result<WriteResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(
|
||||
EXEC_WRITE_METHOD,
|
||||
&WriteParams {
|
||||
process_id: process_id.clone(),
|
||||
chunk: chunk.into(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(
|
||||
EXEC_WRITE_METHOD,
|
||||
&WriteParams {
|
||||
process_id: process_id.clone(),
|
||||
chunk: chunk.into(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn terminate(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
) -> Result<TerminateResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(
|
||||
EXEC_TERMINATE_METHOD,
|
||||
&TerminateParams {
|
||||
process_id: process_id.clone(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(
|
||||
EXEC_TERMINATE_METHOD,
|
||||
&TerminateParams {
|
||||
process_id: process_id.clone(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn fs_read_file(
|
||||
&self,
|
||||
params: FsReadFileParams,
|
||||
) -> Result<FsReadFileResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_READ_FILE_METHOD, &params)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_READ_FILE_METHOD, &params).await
|
||||
}
|
||||
|
||||
pub async fn fs_write_file(
|
||||
&self,
|
||||
params: FsWriteFileParams,
|
||||
) -> Result<FsWriteFileResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_WRITE_FILE_METHOD, &params)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_WRITE_FILE_METHOD, &params).await
|
||||
}
|
||||
|
||||
pub async fn fs_create_directory(
|
||||
&self,
|
||||
params: FsCreateDirectoryParams,
|
||||
) -> Result<FsCreateDirectoryResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_CREATE_DIRECTORY_METHOD, &params)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_CREATE_DIRECTORY_METHOD, &params).await
|
||||
}
|
||||
|
||||
pub async fn fs_get_metadata(
|
||||
&self,
|
||||
params: FsGetMetadataParams,
|
||||
) -> Result<FsGetMetadataResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_GET_METADATA_METHOD, &params)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_GET_METADATA_METHOD, &params).await
|
||||
}
|
||||
|
||||
pub async fn fs_read_directory(
|
||||
&self,
|
||||
params: FsReadDirectoryParams,
|
||||
) -> Result<FsReadDirectoryResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_READ_DIRECTORY_METHOD, &params)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_READ_DIRECTORY_METHOD, &params).await
|
||||
}
|
||||
|
||||
pub async fn fs_remove(
|
||||
&self,
|
||||
params: FsRemoveParams,
|
||||
) -> Result<FsRemoveResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_REMOVE_METHOD, &params)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_REMOVE_METHOD, &params).await
|
||||
}
|
||||
|
||||
pub async fn fs_copy(&self, params: FsCopyParams) -> Result<FsCopyResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_COPY_METHOD, &params)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_COPY_METHOD, &params).await
|
||||
}
|
||||
|
||||
pub(crate) async fn register_session(
|
||||
@@ -392,7 +363,7 @@ impl ExecServerClient {
|
||||
&& let Err(err) =
|
||||
handle_server_notification(&inner, notification).await
|
||||
{
|
||||
fail_all_sessions(
|
||||
mark_disconnected(
|
||||
&inner,
|
||||
format!("exec-server notification handling failed: {err}"),
|
||||
)
|
||||
@@ -402,7 +373,7 @@ impl ExecServerClient {
|
||||
}
|
||||
RpcClientEvent::Disconnected { reason } => {
|
||||
if let Some(inner) = weak.upgrade() {
|
||||
fail_all_sessions(&inner, disconnected_message(reason.as_deref()))
|
||||
mark_disconnected(&inner, disconnected_message(reason.as_deref()))
|
||||
.await;
|
||||
}
|
||||
return;
|
||||
@@ -415,6 +386,7 @@ impl ExecServerClient {
|
||||
client: rpc_client,
|
||||
sessions: ArcSwap::from_pointee(HashMap::new()),
|
||||
sessions_write_lock: Mutex::new(()),
|
||||
disconnected: std::sync::RwLock::new(None),
|
||||
session_id: std::sync::RwLock::new(None),
|
||||
reader_task,
|
||||
}
|
||||
@@ -432,6 +404,37 @@ impl ExecServerClient {
|
||||
.await
|
||||
.map_err(ExecServerError::Json)
|
||||
}
|
||||
|
||||
async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, ExecServerError>
|
||||
where
|
||||
P: serde::Serialize,
|
||||
T: serde::de::DeserializeOwned,
|
||||
{
|
||||
// Reject new work before allocating a JSON-RPC request id. MCP tool
|
||||
// calls, process writes, and fs operations all pass through here, so
|
||||
// this is the shared low-level failure path after executor disconnect.
|
||||
if let Some(error) = self.inner.disconnected_error() {
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
match self.inner.client.call(method, params).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(error) => {
|
||||
let error = ExecServerError::from(error);
|
||||
if is_transport_closed_error(&error) {
|
||||
// A call can race with disconnect after the preflight
|
||||
// check. Latch the disconnect once and fail every
|
||||
// registered process session before returning this call
|
||||
// error.
|
||||
let message = disconnected_message(/*reason*/ None);
|
||||
let message = mark_disconnected(&self.inner, message).await;
|
||||
Err(ExecServerError::Disconnected(message))
|
||||
} else {
|
||||
Err(error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RpcCallError> for ExecServerError {
|
||||
@@ -452,6 +455,7 @@ impl SessionState {
|
||||
let (wake_tx, _wake_rx) = watch::channel(0);
|
||||
Self {
|
||||
wake_tx,
|
||||
events: ExecProcessEventLog::new(PROCESS_EVENT_CHANNEL_CAPACITY),
|
||||
failure: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
@@ -460,19 +464,31 @@ impl SessionState {
|
||||
self.wake_tx.subscribe()
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe_events(&self) -> ExecProcessEventReceiver {
|
||||
self.events.subscribe()
|
||||
}
|
||||
|
||||
fn note_change(&self, seq: u64) {
|
||||
let next = (*self.wake_tx.borrow()).max(seq);
|
||||
let _ = self.wake_tx.send(next);
|
||||
}
|
||||
|
||||
fn publish_event(&self, event: ExecProcessEvent) {
|
||||
self.events.publish(event);
|
||||
}
|
||||
|
||||
async fn set_failure(&self, message: String) {
|
||||
let mut failure = self.failure.lock().await;
|
||||
if failure.is_none() {
|
||||
*failure = Some(message);
|
||||
let should_publish = failure.is_none();
|
||||
if should_publish {
|
||||
*failure = Some(message.clone());
|
||||
}
|
||||
drop(failure);
|
||||
let next = (*self.wake_tx.borrow()).saturating_add(1);
|
||||
let _ = self.wake_tx.send(next);
|
||||
if should_publish {
|
||||
self.publish_event(ExecProcessEvent::Failed(message));
|
||||
}
|
||||
}
|
||||
|
||||
async fn failed_response(&self) -> Option<ReadResponse> {
|
||||
@@ -505,6 +521,10 @@ impl Session {
|
||||
self.state.subscribe()
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe_events(&self) -> ExecProcessEventReceiver {
|
||||
self.state.subscribe_events()
|
||||
}
|
||||
|
||||
pub(crate) async fn read(
|
||||
&self,
|
||||
after_seq: Option<u64>,
|
||||
@@ -550,6 +570,26 @@ impl Session {
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
fn disconnected_error(&self) -> Option<ExecServerError> {
|
||||
self.disconnected
|
||||
.read()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone()
|
||||
.map(ExecServerError::Disconnected)
|
||||
}
|
||||
|
||||
fn set_disconnected(&self, message: String) -> Option<String> {
|
||||
let mut disconnected = self
|
||||
.disconnected
|
||||
.write()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if disconnected.is_some() {
|
||||
return None;
|
||||
}
|
||||
*disconnected = Some(message.clone());
|
||||
Some(message)
|
||||
}
|
||||
|
||||
fn get_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
|
||||
self.sessions.load().get(process_id).cloned()
|
||||
}
|
||||
@@ -560,6 +600,12 @@ impl Inner {
|
||||
session: Arc<SessionState>,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let _sessions_write_guard = self.sessions_write_lock.lock().await;
|
||||
// Do not register a process session that can never receive executor
|
||||
// notifications. Without this check, remote MCP startup could create a
|
||||
// dead session and wait for process output that will never arrive.
|
||||
if let Some(error) = self.disconnected_error() {
|
||||
return Err(error);
|
||||
}
|
||||
let sessions = self.sessions.load();
|
||||
if sessions.contains_key(process_id) {
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
@@ -600,20 +646,42 @@ fn disconnected_message(reason: Option<&str>) -> String {
|
||||
}
|
||||
|
||||
fn is_transport_closed_error(error: &ExecServerError) -> bool {
|
||||
matches!(error, ExecServerError::Closed)
|
||||
|| matches!(
|
||||
error,
|
||||
ExecServerError::Server {
|
||||
code: -32000,
|
||||
message,
|
||||
} if message == "JSON-RPC transport closed"
|
||||
)
|
||||
matches!(
|
||||
error,
|
||||
ExecServerError::Closed | ExecServerError::Disconnected(_)
|
||||
) || matches!(
|
||||
error,
|
||||
ExecServerError::Server {
|
||||
code: -32000,
|
||||
message,
|
||||
} if message == "JSON-RPC transport closed"
|
||||
)
|
||||
}
|
||||
|
||||
async fn mark_disconnected(inner: &Arc<Inner>, message: String) -> String {
|
||||
// The first observer records the canonical disconnect reason and wakes all
|
||||
// sessions. Later observers reuse that message so concurrent tool calls
|
||||
// report one consistent environment failure.
|
||||
if let Some(message) = inner.set_disconnected(message.clone()) {
|
||||
fail_all_sessions(inner, message.clone()).await;
|
||||
message
|
||||
} else {
|
||||
inner
|
||||
.disconnected
|
||||
.read()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone()
|
||||
.unwrap_or(message)
|
||||
}
|
||||
}
|
||||
|
||||
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
|
||||
let sessions = inner.take_all_sessions().await;
|
||||
|
||||
for (_, session) in sessions {
|
||||
// Sessions synthesize a closed read response and emit a pushed Failed
|
||||
// event. That covers both polling consumers and streaming consumers
|
||||
// such as executor-backed MCP stdio.
|
||||
session.set_failure(message.clone()).await;
|
||||
}
|
||||
}
|
||||
@@ -628,6 +696,11 @@ async fn handle_server_notification(
|
||||
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
|
||||
if let Some(session) = inner.get_session(&params.process_id) {
|
||||
session.note_change(params.seq);
|
||||
session.publish_event(ExecProcessEvent::Output(ProcessOutputChunk {
|
||||
seq: params.seq,
|
||||
stream: params.stream,
|
||||
chunk: params.chunk,
|
||||
}));
|
||||
}
|
||||
}
|
||||
EXEC_EXITED_METHOD => {
|
||||
@@ -635,6 +708,10 @@ async fn handle_server_notification(
|
||||
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
|
||||
if let Some(session) = inner.get_session(&params.process_id) {
|
||||
session.note_change(params.seq);
|
||||
session.publish_event(ExecProcessEvent::Exited {
|
||||
seq: params.seq,
|
||||
exit_code: params.exit_code,
|
||||
});
|
||||
}
|
||||
}
|
||||
EXEC_CLOSED_METHOD => {
|
||||
@@ -645,6 +722,7 @@ async fn handle_server_notification(
|
||||
let session = inner.remove_session(&params.process_id).await;
|
||||
if let Some(session) = session {
|
||||
session.note_change(params.seq);
|
||||
session.publish_event(ExecProcessEvent::Closed { seq: params.seq });
|
||||
}
|
||||
}
|
||||
other => {
|
||||
|
||||
@@ -346,6 +346,7 @@ mod tests {
|
||||
env_policy: None,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await
|
||||
|
||||
@@ -39,6 +39,8 @@ pub use local_file_system::LOCAL_FS;
|
||||
pub use local_file_system::LocalFileSystem;
|
||||
pub use process::ExecBackend;
|
||||
pub use process::ExecProcess;
|
||||
pub use process::ExecProcessEvent;
|
||||
pub use process::ExecProcessEventReceiver;
|
||||
pub use process::StartedExecProcess;
|
||||
pub use process_id::ProcessId;
|
||||
pub use protocol::ExecClosedNotification;
|
||||
@@ -65,6 +67,7 @@ pub use protocol::FsWriteFileParams;
|
||||
pub use protocol::FsWriteFileResponse;
|
||||
pub use protocol::InitializeParams;
|
||||
pub use protocol::InitializeResponse;
|
||||
pub use protocol::ProcessOutputChunk;
|
||||
pub use protocol::ReadParams;
|
||||
pub use protocol::ReadResponse;
|
||||
pub use protocol::TerminateParams;
|
||||
|
||||
@@ -17,9 +17,12 @@ use tokio::sync::watch;
|
||||
|
||||
use crate::ExecBackend;
|
||||
use crate::ExecProcess;
|
||||
use crate::ExecProcessEvent;
|
||||
use crate::ExecProcessEventReceiver;
|
||||
use crate::ExecServerError;
|
||||
use crate::ProcessId;
|
||||
use crate::StartedExecProcess;
|
||||
use crate::process::ExecProcessEventLog;
|
||||
use crate::protocol::EXEC_CLOSED_METHOD;
|
||||
use crate::protocol::ExecClosedNotification;
|
||||
use crate::protocol::ExecEnvPolicy;
|
||||
@@ -44,6 +47,7 @@ use crate::rpc::invalid_request;
|
||||
|
||||
const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024;
|
||||
const NOTIFICATION_CHANNEL_CAPACITY: usize = 256;
|
||||
const PROCESS_EVENT_CHANNEL_CAPACITY: usize = 256;
|
||||
#[cfg(test)]
|
||||
const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25);
|
||||
#[cfg(not(test))]
|
||||
@@ -59,11 +63,13 @@ struct RetainedOutputChunk {
|
||||
struct RunningProcess {
|
||||
session: ExecCommandSession,
|
||||
tty: bool,
|
||||
pipe_stdin: bool,
|
||||
output: VecDeque<RetainedOutputChunk>,
|
||||
retained_bytes: usize,
|
||||
next_seq: u64,
|
||||
exit_code: Option<i32>,
|
||||
wake_tx: watch::Sender<u64>,
|
||||
events: ExecProcessEventLog,
|
||||
output_notify: Arc<Notify>,
|
||||
open_streams: usize,
|
||||
closed: bool,
|
||||
@@ -88,6 +94,7 @@ struct LocalExecProcess {
|
||||
process_id: ProcessId,
|
||||
backend: LocalProcess,
|
||||
wake_tx: watch::Sender<u64>,
|
||||
events: ExecProcessEventLog,
|
||||
}
|
||||
|
||||
impl Default for LocalProcess {
|
||||
@@ -137,7 +144,7 @@ impl LocalProcess {
|
||||
async fn start_process(
|
||||
&self,
|
||||
params: ExecParams,
|
||||
) -> Result<(ExecResponse, watch::Sender<u64>), JSONRPCErrorError> {
|
||||
) -> Result<(ExecResponse, watch::Sender<u64>, ExecProcessEventLog), JSONRPCErrorError> {
|
||||
let process_id = params.process_id.clone();
|
||||
let (program, args) = params
|
||||
.argv
|
||||
@@ -165,6 +172,15 @@ impl LocalProcess {
|
||||
TerminalSize::default(),
|
||||
)
|
||||
.await
|
||||
} else if params.pipe_stdin {
|
||||
codex_utils_pty::spawn_pipe_process(
|
||||
program,
|
||||
args,
|
||||
params.cwd.as_path(),
|
||||
&env,
|
||||
&params.arg0,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
codex_utils_pty::spawn_pipe_process_no_stdin(
|
||||
program,
|
||||
@@ -188,6 +204,7 @@ impl LocalProcess {
|
||||
|
||||
let output_notify = Arc::new(Notify::new());
|
||||
let (wake_tx, _wake_rx) = watch::channel(0);
|
||||
let events = ExecProcessEventLog::new(PROCESS_EVENT_CHANNEL_CAPACITY);
|
||||
{
|
||||
let mut process_map = self.inner.processes.lock().await;
|
||||
process_map.insert(
|
||||
@@ -195,11 +212,13 @@ impl LocalProcess {
|
||||
ProcessEntry::Running(Box::new(RunningProcess {
|
||||
session: spawned.session,
|
||||
tty: params.tty,
|
||||
pipe_stdin: params.pipe_stdin,
|
||||
output: VecDeque::new(),
|
||||
retained_bytes: 0,
|
||||
next_seq: 1,
|
||||
exit_code: None,
|
||||
wake_tx: wake_tx.clone(),
|
||||
events: events.clone(),
|
||||
output_notify: Arc::clone(&output_notify),
|
||||
open_streams: 2,
|
||||
closed: false,
|
||||
@@ -236,13 +255,13 @@ impl LocalProcess {
|
||||
output_notify,
|
||||
));
|
||||
|
||||
Ok((ExecResponse { process_id }, wake_tx))
|
||||
Ok((ExecResponse { process_id }, wake_tx, events))
|
||||
}
|
||||
|
||||
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
|
||||
self.start_process(params)
|
||||
.await
|
||||
.map(|(response, _)| response)
|
||||
.map(|(response, _, _)| response)
|
||||
}
|
||||
|
||||
pub(crate) async fn exec_read(
|
||||
@@ -339,7 +358,7 @@ impl LocalProcess {
|
||||
status: WriteStatus::Starting,
|
||||
});
|
||||
};
|
||||
if !process.tty {
|
||||
if !process.tty && !process.pipe_stdin {
|
||||
return Ok(WriteResponse {
|
||||
status: WriteStatus::StdinClosed,
|
||||
});
|
||||
@@ -413,7 +432,7 @@ fn shell_environment_policy(env_policy: &ExecEnvPolicy) -> ShellEnvironmentPolic
|
||||
#[async_trait]
|
||||
impl ExecBackend for LocalProcess {
|
||||
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
|
||||
let (response, wake_tx) = self
|
||||
let (response, wake_tx, events) = self
|
||||
.start_process(params)
|
||||
.await
|
||||
.map_err(map_handler_error)?;
|
||||
@@ -422,6 +441,7 @@ impl ExecBackend for LocalProcess {
|
||||
process_id: response.process_id,
|
||||
backend: self.clone(),
|
||||
wake_tx,
|
||||
events,
|
||||
}),
|
||||
})
|
||||
}
|
||||
@@ -437,6 +457,10 @@ impl ExecProcess for LocalExecProcess {
|
||||
self.wake_tx.subscribe()
|
||||
}
|
||||
|
||||
fn subscribe_events(&self) -> ExecProcessEventReceiver {
|
||||
self.events.subscribe()
|
||||
}
|
||||
|
||||
async fn read(
|
||||
&self,
|
||||
after_seq: Option<u64>,
|
||||
@@ -537,11 +561,19 @@ async fn stream_output(
|
||||
process.retained_bytes = process.retained_bytes.saturating_sub(evicted.chunk.len());
|
||||
}
|
||||
let _ = process.wake_tx.send(seq);
|
||||
let output = ProcessOutputChunk {
|
||||
seq,
|
||||
stream,
|
||||
chunk: chunk.into(),
|
||||
};
|
||||
process
|
||||
.events
|
||||
.publish(ExecProcessEvent::Output(output.clone()));
|
||||
ExecOutputDeltaNotification {
|
||||
process_id: process_id.clone(),
|
||||
seq,
|
||||
stream,
|
||||
chunk: chunk.into(),
|
||||
chunk: output.chunk,
|
||||
}
|
||||
};
|
||||
output_notify.notify_waiters();
|
||||
@@ -569,6 +601,9 @@ async fn watch_exit(
|
||||
process.next_seq += 1;
|
||||
process.exit_code = Some(exit_code);
|
||||
let _ = process.wake_tx.send(seq);
|
||||
process
|
||||
.events
|
||||
.publish(ExecProcessEvent::Exited { seq, exit_code });
|
||||
Some(ExecExitedNotification {
|
||||
process_id: process_id.clone(),
|
||||
seq,
|
||||
@@ -629,6 +664,7 @@ async fn maybe_emit_closed(process_id: ProcessId, inner: Arc<Inner>) {
|
||||
let seq = process.next_seq;
|
||||
process.next_seq += 1;
|
||||
let _ = process.wake_tx.send(seq);
|
||||
process.events.publish(ExecProcessEvent::Closed { seq });
|
||||
Some(ExecClosedNotification {
|
||||
process_id: process_id.clone(),
|
||||
seq,
|
||||
@@ -667,6 +703,7 @@ mod tests {
|
||||
env_policy: None,
|
||||
env,
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::watch;
|
||||
|
||||
use crate::ExecServerError;
|
||||
use crate::ProcessId;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ProcessOutputChunk;
|
||||
use crate::protocol::ReadResponse;
|
||||
use crate::protocol::WriteResponse;
|
||||
|
||||
@@ -13,12 +17,101 @@ pub struct StartedExecProcess {
|
||||
pub process: Arc<dyn ExecProcess>,
|
||||
}
|
||||
|
||||
/// Pushed process events for consumers that want to follow process output as it
|
||||
/// arrives instead of polling retained output with [`ExecProcess::read`].
|
||||
///
|
||||
/// The stream is scoped to one [`ExecProcess`] handle. `Output` events carry
|
||||
/// stdout, stderr, or pty bytes. `Exited` reports the process exit status, while
|
||||
/// `Closed` means all output streams have ended and no more output events will
|
||||
/// arrive. `Failed` is used when the process session cannot continue, for
|
||||
/// example because the remote executor connection disconnected.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ExecProcessEvent {
|
||||
Output(ProcessOutputChunk),
|
||||
Exited { seq: u64, exit_code: i32 },
|
||||
Closed { seq: u64 },
|
||||
Failed(String),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ExecProcessEventLog {
|
||||
inner: Arc<ExecProcessEventLogInner>,
|
||||
}
|
||||
|
||||
struct ExecProcessEventLogInner {
|
||||
history: StdMutex<VecDeque<ExecProcessEvent>>,
|
||||
live_tx: broadcast::Sender<ExecProcessEvent>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl ExecProcessEventLog {
|
||||
pub(crate) fn new(capacity: usize) -> Self {
|
||||
let (live_tx, _live_rx) = broadcast::channel(capacity);
|
||||
Self {
|
||||
inner: Arc::new(ExecProcessEventLogInner {
|
||||
history: StdMutex::new(VecDeque::new()),
|
||||
live_tx,
|
||||
capacity,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn publish(&self, event: ExecProcessEvent) {
|
||||
let mut history = self
|
||||
.inner
|
||||
.history
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
history.push_back(event.clone());
|
||||
while history.len() > self.inner.capacity {
|
||||
history.pop_front();
|
||||
}
|
||||
|
||||
let _ = self.inner.live_tx.send(event);
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe(&self) -> ExecProcessEventReceiver {
|
||||
let history = self
|
||||
.inner
|
||||
.history
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let live_rx = self.inner.live_tx.subscribe();
|
||||
let replay = history.iter().cloned().collect();
|
||||
|
||||
ExecProcessEventReceiver { replay, live_rx }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ExecProcessEventReceiver {
|
||||
replay: VecDeque<ExecProcessEvent>,
|
||||
live_rx: broadcast::Receiver<ExecProcessEvent>,
|
||||
}
|
||||
|
||||
impl ExecProcessEventReceiver {
|
||||
pub async fn recv(&mut self) -> Result<ExecProcessEvent, broadcast::error::RecvError> {
|
||||
if let Some(event) = self.replay.pop_front() {
|
||||
return Ok(event);
|
||||
}
|
||||
|
||||
self.live_rx.recv().await
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle for an executor-managed process.
|
||||
///
|
||||
/// Implementations must support both retained-output reads and pushed events:
|
||||
/// `read` is the request/response API for callers that want to page through
|
||||
/// buffered output, while `subscribe_events` is the streaming API for callers
|
||||
/// that want output and lifecycle changes delivered as they happen.
|
||||
#[async_trait]
|
||||
pub trait ExecProcess: Send + Sync {
|
||||
fn process_id(&self) -> &ProcessId;
|
||||
|
||||
fn subscribe_wake(&self) -> watch::Receiver<u64>;
|
||||
|
||||
fn subscribe_events(&self) -> ExecProcessEventReceiver;
|
||||
|
||||
async fn read(
|
||||
&self,
|
||||
after_seq: Option<u64>,
|
||||
|
||||
@@ -69,6 +69,9 @@ pub struct ExecParams {
|
||||
pub env_policy: Option<ExecEnvPolicy>,
|
||||
pub env: HashMap<String, String>,
|
||||
pub tty: bool,
|
||||
/// Keep non-tty stdin writable through `process/write`.
|
||||
#[serde(default)]
|
||||
pub pipe_stdin: bool,
|
||||
pub arg0: Option<String>,
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ use tracing::trace;
|
||||
|
||||
use crate::ExecBackend;
|
||||
use crate::ExecProcess;
|
||||
use crate::ExecProcessEventReceiver;
|
||||
use crate::ExecServerError;
|
||||
use crate::StartedExecProcess;
|
||||
use crate::client::ExecServerClient;
|
||||
@@ -56,6 +57,10 @@ impl ExecProcess for RemoteExecProcess {
|
||||
self.session.subscribe_wake()
|
||||
}
|
||||
|
||||
fn subscribe_events(&self) -> ExecProcessEventReceiver {
|
||||
self.session.subscribe_events()
|
||||
}
|
||||
|
||||
async fn read(
|
||||
&self,
|
||||
after_seq: Option<u64>,
|
||||
|
||||
@@ -18,12 +18,23 @@ use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::connection::JsonRpcConnectionEvent;
|
||||
|
||||
type PendingRequest = oneshot::Sender<Result<Value, JSONRPCErrorError>>;
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum RpcCallError {
|
||||
/// The underlying JSON-RPC transport closed before this call completed.
|
||||
Closed,
|
||||
/// The response bytes were valid JSON-RPC but not the expected result type.
|
||||
Json(serde_json::Error),
|
||||
/// The executor returned a JSON-RPC error response for this call.
|
||||
Server(JSONRPCErrorError),
|
||||
}
|
||||
|
||||
type PendingRequest = oneshot::Sender<Result<Value, RpcCallError>>;
|
||||
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
|
||||
type RequestRoute<S> =
|
||||
Box<dyn Fn(Arc<S>, JSONRPCRequest) -> BoxFuture<RpcServerOutboundMessage> + Send + Sync>;
|
||||
@@ -172,6 +183,10 @@ where
|
||||
pub(crate) struct RpcClient {
|
||||
write_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
pending: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
|
||||
// Shared transport state from `JsonRpcConnection`. Calls use this to fail
|
||||
// immediately when the socket closes, even if no JSON-RPC error response
|
||||
// can be delivered for their request id.
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
next_request_id: AtomicI64,
|
||||
transport_tasks: Vec<JoinHandle<()>>,
|
||||
reader_task: JoinHandle<()>,
|
||||
@@ -179,8 +194,7 @@ pub(crate) struct RpcClient {
|
||||
|
||||
impl RpcClient {
|
||||
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let (write_tx, mut incoming_rx, _disconnected_rx, transport_tasks) =
|
||||
connection.into_parts();
|
||||
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = connection.into_parts();
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
|
||||
@@ -218,6 +232,7 @@ impl RpcClient {
|
||||
Self {
|
||||
write_tx,
|
||||
pending,
|
||||
disconnected_rx,
|
||||
next_request_id: AtomicI64::new(1),
|
||||
transport_tasks,
|
||||
reader_task,
|
||||
@@ -251,6 +266,12 @@ impl RpcClient {
|
||||
P: Serialize,
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
// Avoid creating a pending request after the connection is already
|
||||
// known closed.
|
||||
if *self.disconnected_rx.borrow() {
|
||||
return Err(RpcCallError::Closed);
|
||||
}
|
||||
|
||||
let request_id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::SeqCst));
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
self.pending
|
||||
@@ -280,10 +301,20 @@ impl RpcClient {
|
||||
return Err(RpcCallError::Closed);
|
||||
}
|
||||
|
||||
let result = response_rx.await.map_err(|_| RpcCallError::Closed)?;
|
||||
let mut disconnected_rx = self.disconnected_rx.clone();
|
||||
let result: Result<Value, RpcCallError> = tokio::select! {
|
||||
result = response_rx => result.map_err(|_| RpcCallError::Closed)?,
|
||||
_ = disconnected_rx.changed() => {
|
||||
// The connection closed while the request was in flight. Remove
|
||||
// the pending sender here so `drain_pending` does not need to
|
||||
// deliver a second terminal result for the same request.
|
||||
self.pending.lock().await.remove(&request_id);
|
||||
return Err(RpcCallError::Closed);
|
||||
}
|
||||
};
|
||||
let response = match result {
|
||||
Ok(response) => response,
|
||||
Err(error) => return Err(RpcCallError::Server(error)),
|
||||
Err(error) => return Err(error),
|
||||
};
|
||||
serde_json::from_value(response).map_err(RpcCallError::Json)
|
||||
}
|
||||
@@ -304,13 +335,6 @@ impl Drop for RpcClient {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum RpcCallError {
|
||||
Closed,
|
||||
Json(serde_json::Error),
|
||||
Server(JSONRPCErrorError),
|
||||
}
|
||||
|
||||
pub(crate) fn encode_server_message(
|
||||
message: RpcServerOutboundMessage,
|
||||
) -> Result<JSONRPCMessage, serde_json::Error> {
|
||||
@@ -417,7 +441,7 @@ async fn handle_server_message(
|
||||
}
|
||||
JSONRPCMessage::Error(JSONRPCError { id, error }) => {
|
||||
if let Some(pending) = pending.lock().await.remove(&id) {
|
||||
let _ = pending.send(Err(error));
|
||||
let _ = pending.send(Err(RpcCallError::Server(error)));
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Notification(notification) => {
|
||||
@@ -445,11 +469,7 @@ async fn drain_pending(pending: &Mutex<HashMap<RequestId, PendingRequest>>) {
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
for pending in pending {
|
||||
let _ = pending.send(Err(JSONRPCErrorError {
|
||||
code: -32000,
|
||||
data: None,
|
||||
message: "JSON-RPC transport closed".to_string(),
|
||||
}));
|
||||
let _ = pending.send(Err(RpcCallError::Closed));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ fn exec_params_with_argv(process_id: &str, argv: Vec<String>) -> ExecParams {
|
||||
env_policy: None,
|
||||
env: inherited_path_env(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,6 +393,7 @@ mod tests {
|
||||
env_policy: None,
|
||||
env,
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,14 +4,18 @@ mod common;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use codex_exec_server::Environment;
|
||||
use codex_exec_server::ExecBackend;
|
||||
use codex_exec_server::ExecOutputStream;
|
||||
use codex_exec_server::ExecParams;
|
||||
use codex_exec_server::ExecProcess;
|
||||
use codex_exec_server::ExecProcessEvent;
|
||||
use codex_exec_server::ProcessId;
|
||||
use codex_exec_server::ReadResponse;
|
||||
use codex_exec_server::StartedExecProcess;
|
||||
use codex_exec_server::WriteStatus;
|
||||
use pretty_assertions::assert_eq;
|
||||
use test_case::test_case;
|
||||
use tokio::sync::watch;
|
||||
@@ -26,6 +30,22 @@ struct ProcessContext {
|
||||
server: Option<ExecServerHarness>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum ProcessEventSnapshot {
|
||||
Output {
|
||||
seq: u64,
|
||||
stream: ExecOutputStream,
|
||||
text: String,
|
||||
},
|
||||
Exited {
|
||||
seq: u64,
|
||||
exit_code: i32,
|
||||
},
|
||||
Closed {
|
||||
seq: u64,
|
||||
},
|
||||
}
|
||||
|
||||
async fn create_process_context(use_remote: bool) -> Result<ProcessContext> {
|
||||
if use_remote {
|
||||
let server = exec_server().await?;
|
||||
@@ -54,6 +74,7 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
|
||||
env_policy: /*env_policy*/ None,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
@@ -115,6 +136,69 @@ async fn collect_process_output_from_reads(
|
||||
Ok((output, exit_code, true))
|
||||
}
|
||||
|
||||
async fn collect_process_output_from_events(
|
||||
session: Arc<dyn ExecProcess>,
|
||||
) -> Result<(String, String, Option<i32>, bool)> {
|
||||
let mut events = session.subscribe_events();
|
||||
let mut stdout = String::new();
|
||||
let mut stderr = String::new();
|
||||
let mut exit_code = None;
|
||||
loop {
|
||||
match timeout(Duration::from_secs(2), events.recv()).await?? {
|
||||
ExecProcessEvent::Output(chunk) => match chunk.stream {
|
||||
ExecOutputStream::Stdout | ExecOutputStream::Pty => {
|
||||
stdout.push_str(&String::from_utf8_lossy(&chunk.chunk.into_inner()));
|
||||
}
|
||||
ExecOutputStream::Stderr => {
|
||||
stderr.push_str(&String::from_utf8_lossy(&chunk.chunk.into_inner()));
|
||||
}
|
||||
},
|
||||
ExecProcessEvent::Exited {
|
||||
seq: _,
|
||||
exit_code: code,
|
||||
} => {
|
||||
exit_code = Some(code);
|
||||
}
|
||||
ExecProcessEvent::Closed { seq: _ } => {
|
||||
drop(session);
|
||||
return Ok((stdout, stderr, exit_code, true));
|
||||
}
|
||||
ExecProcessEvent::Failed(message) => {
|
||||
anyhow::bail!("process failed before closed state: {message}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_process_event_snapshots(
|
||||
session: Arc<dyn ExecProcess>,
|
||||
) -> Result<Vec<ProcessEventSnapshot>> {
|
||||
let mut events = session.subscribe_events();
|
||||
let mut snapshots = Vec::new();
|
||||
loop {
|
||||
let snapshot = match timeout(Duration::from_secs(2), events.recv()).await?? {
|
||||
ExecProcessEvent::Output(chunk) => ProcessEventSnapshot::Output {
|
||||
seq: chunk.seq,
|
||||
stream: chunk.stream,
|
||||
text: String::from_utf8_lossy(&chunk.chunk.into_inner()).into_owned(),
|
||||
},
|
||||
ExecProcessEvent::Exited { seq, exit_code } => {
|
||||
ProcessEventSnapshot::Exited { seq, exit_code }
|
||||
}
|
||||
ExecProcessEvent::Closed { seq } => ProcessEventSnapshot::Closed { seq },
|
||||
ExecProcessEvent::Failed(message) => {
|
||||
anyhow::bail!("process failed before closed state: {message}");
|
||||
}
|
||||
};
|
||||
let closed = matches!(snapshot, ProcessEventSnapshot::Closed { .. });
|
||||
snapshots.push(snapshot);
|
||||
if closed {
|
||||
drop(session);
|
||||
return Ok(snapshots);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let process_id = "proc-stream".to_string();
|
||||
@@ -131,6 +215,7 @@ async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> {
|
||||
env_policy: /*env_policy*/ None,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
@@ -145,6 +230,96 @@ async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_exec_process_pushes_events(use_remote: bool) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let process_id = "proc-events".to_string();
|
||||
let session = context
|
||||
.backend
|
||||
.start(ExecParams {
|
||||
process_id: process_id.clone().into(),
|
||||
argv: vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
"printf 'event output\\n'; sleep 0.1; printf 'event err\\n' >&2; sleep 0.1; exit 7".to_string(),
|
||||
],
|
||||
cwd: std::env::current_dir()?,
|
||||
env_policy: /*env_policy*/ None,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process.process_id().as_str(), process_id);
|
||||
|
||||
let StartedExecProcess { process } = session;
|
||||
let actual = collect_process_event_snapshots(process).await?;
|
||||
assert_eq!(
|
||||
actual,
|
||||
vec![
|
||||
ProcessEventSnapshot::Output {
|
||||
seq: 1,
|
||||
stream: ExecOutputStream::Stdout,
|
||||
text: "event output\n".to_string(),
|
||||
},
|
||||
ProcessEventSnapshot::Output {
|
||||
seq: 2,
|
||||
stream: ExecOutputStream::Stderr,
|
||||
text: "event err\n".to_string(),
|
||||
},
|
||||
ProcessEventSnapshot::Exited {
|
||||
seq: 3,
|
||||
exit_code: 7,
|
||||
},
|
||||
ProcessEventSnapshot::Closed { seq: 4 },
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_exec_process_replays_events_after_close(use_remote: bool) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let process_id = "proc-events-late".to_string();
|
||||
let session = context
|
||||
.backend
|
||||
.start(ExecParams {
|
||||
process_id: process_id.clone().into(),
|
||||
argv: vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
"printf 'late one\\n'; printf 'late two\\n'".to_string(),
|
||||
],
|
||||
cwd: std::env::current_dir()?,
|
||||
env_policy: /*env_policy*/ None,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process.process_id().as_str(), process_id);
|
||||
|
||||
let StartedExecProcess { process } = session;
|
||||
let wake_rx = process.subscribe_wake();
|
||||
let read_result = collect_process_output_from_reads(Arc::clone(&process), wake_rx).await?;
|
||||
assert_eq!(
|
||||
read_result,
|
||||
("late one\nlate two\n".to_string(), Some(0), true)
|
||||
);
|
||||
|
||||
let event_result = collect_process_output_from_events(process).await?;
|
||||
assert_eq!(
|
||||
event_result,
|
||||
(
|
||||
"late one\nlate two\n".to_string(),
|
||||
String::new(),
|
||||
Some(0),
|
||||
true
|
||||
)
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let process_id = "proc-stdin".to_string();
|
||||
@@ -164,6 +339,7 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> {
|
||||
env_policy: /*env_policy*/ None,
|
||||
env: Default::default(),
|
||||
tty: true,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
@@ -184,6 +360,73 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_exec_process_write_then_read_without_tty(use_remote: bool) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let process_id = "proc-stdin-pipe".to_string();
|
||||
let session = context
|
||||
.backend
|
||||
.start(ExecParams {
|
||||
process_id: process_id.clone().into(),
|
||||
argv: vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
"IFS= read line; printf 'from-stdin:%s\\n' \"$line\"".to_string(),
|
||||
],
|
||||
cwd: std::env::current_dir()?,
|
||||
env_policy: /*env_policy*/ None,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
pipe_stdin: true,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process.process_id().as_str(), process_id);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
let write_response = session.process.write(b"hello\n".to_vec()).await?;
|
||||
assert_eq!(write_response.status, WriteStatus::Accepted);
|
||||
let StartedExecProcess { process } = session;
|
||||
let wake_rx = process.subscribe_wake();
|
||||
let actual = collect_process_output_from_reads(process, wake_rx).await?;
|
||||
|
||||
assert_eq!(actual, ("from-stdin:hello\n".to_string(), Some(0), true));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_exec_process_rejects_write_without_pipe_stdin(use_remote: bool) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let process_id = "proc-stdin-closed".to_string();
|
||||
let session = context
|
||||
.backend
|
||||
.start(ExecParams {
|
||||
process_id: process_id.clone().into(),
|
||||
argv: vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
"sleep 0.3; if IFS= read -r line; then printf 'read:%s\\n' \"$line\"; else printf 'eof\\n'; fi".to_string(),
|
||||
],
|
||||
cwd: std::env::current_dir()?,
|
||||
env_policy: /*env_policy*/ None,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process.process_id().as_str(), process_id);
|
||||
|
||||
let write_response = session.process.write(b"ignored\n".to_vec()).await?;
|
||||
assert_eq!(write_response.status, WriteStatus::StdinClosed);
|
||||
let StartedExecProcess { process } = session;
|
||||
let wake_rx = process.subscribe_wake();
|
||||
let (output, exit_code, closed) = collect_process_output_from_reads(process, wake_rx).await?;
|
||||
|
||||
assert_eq!(output, "eof\n");
|
||||
assert_eq!(exit_code, Some(0));
|
||||
assert!(closed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_exec_process_preserves_queued_events_before_subscribe(
|
||||
use_remote: bool,
|
||||
) -> Result<()> {
|
||||
@@ -201,6 +444,7 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe(
|
||||
env_policy: /*env_policy*/ None,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
@@ -234,19 +478,49 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
|
||||
env_policy: /*env_policy*/ None,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let process = Arc::clone(&session.process);
|
||||
let mut events = process.subscribe_events();
|
||||
let process_for_pending_read = Arc::clone(&process);
|
||||
let pending_read = tokio::spawn(async move {
|
||||
process_for_pending_read
|
||||
.read(
|
||||
/*after_seq*/ None,
|
||||
/*max_bytes*/ None,
|
||||
/*wait_ms*/ Some(60_000),
|
||||
)
|
||||
.await
|
||||
});
|
||||
let server = context
|
||||
.server
|
||||
.as_mut()
|
||||
.expect("remote context should include exec-server harness");
|
||||
server.shutdown().await?;
|
||||
|
||||
let mut wake_rx = session.process.subscribe_wake();
|
||||
let response =
|
||||
read_process_until_change(session.process, &mut wake_rx, /*after_seq*/ None).await?;
|
||||
let event = timeout(Duration::from_secs(2), events.recv()).await??;
|
||||
let ExecProcessEvent::Failed(event_message) = event else {
|
||||
anyhow::bail!("expected process failure event, got {event:?}");
|
||||
};
|
||||
assert!(
|
||||
event_message.starts_with("exec-server transport disconnected"),
|
||||
"unexpected failure event: {event_message}"
|
||||
);
|
||||
|
||||
let pending_response = timeout(Duration::from_secs(2), pending_read).await???;
|
||||
let pending_message = pending_response
|
||||
.failure
|
||||
.expect("pending read should surface disconnect as a failure");
|
||||
assert!(
|
||||
pending_message.starts_with("exec-server transport disconnected"),
|
||||
"unexpected pending failure message: {pending_message}"
|
||||
);
|
||||
|
||||
let mut wake_rx = process.subscribe_wake();
|
||||
let response = read_process_until_change(process, &mut wake_rx, /*after_seq*/ None).await?;
|
||||
let message = response
|
||||
.failure
|
||||
.expect("disconnect should surface as a failure");
|
||||
@@ -259,6 +533,20 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
|
||||
"disconnect should close the process session"
|
||||
);
|
||||
|
||||
let write_result = timeout(
|
||||
Duration::from_secs(2),
|
||||
session.process.write(b"hello".to_vec()),
|
||||
)
|
||||
.await
|
||||
.context("timed out waiting for write after disconnect")?;
|
||||
let write_error = write_result.expect_err("write after disconnect should fail");
|
||||
assert!(
|
||||
write_error
|
||||
.to_string()
|
||||
.starts_with("exec-server transport disconnected"),
|
||||
"unexpected write error: {write_error}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -280,6 +568,24 @@ async fn exec_process_streams_output(use_remote: bool) -> Result<()> {
|
||||
assert_exec_process_streams_output(use_remote).await
|
||||
}
|
||||
|
||||
#[test_case(false ; "local")]
|
||||
#[test_case(true ; "remote")]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
// Serialize tests that launch a real exec-server process through the full CLI.
|
||||
#[serial_test::serial(remote_exec_server)]
|
||||
async fn exec_process_pushes_events(use_remote: bool) -> Result<()> {
|
||||
assert_exec_process_pushes_events(use_remote).await
|
||||
}
|
||||
|
||||
#[test_case(false ; "local")]
|
||||
#[test_case(true ; "remote")]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
// Serialize tests that launch a real exec-server process through the full CLI.
|
||||
#[serial_test::serial(remote_exec_server)]
|
||||
async fn exec_process_replays_events_after_close(use_remote: bool) -> Result<()> {
|
||||
assert_exec_process_replays_events_after_close(use_remote).await
|
||||
}
|
||||
|
||||
#[test_case(false ; "local")]
|
||||
#[test_case(true ; "remote")]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -289,6 +595,24 @@ async fn exec_process_write_then_read(use_remote: bool) -> Result<()> {
|
||||
assert_exec_process_write_then_read(use_remote).await
|
||||
}
|
||||
|
||||
#[test_case(false ; "local")]
|
||||
#[test_case(true ; "remote")]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
// Serialize tests that launch a real exec-server process through the full CLI.
|
||||
#[serial_test::serial(remote_exec_server)]
|
||||
async fn exec_process_write_then_read_without_tty(use_remote: bool) -> Result<()> {
|
||||
assert_exec_process_write_then_read_without_tty(use_remote).await
|
||||
}
|
||||
|
||||
#[test_case(false ; "local")]
|
||||
#[test_case(true ; "remote")]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
// Serialize tests that launch a real exec-server process through the full CLI.
|
||||
#[serial_test::serial(remote_exec_server)]
|
||||
async fn exec_process_rejects_write_without_pipe_stdin(use_remote: bool) -> Result<()> {
|
||||
assert_exec_process_rejects_write_without_pipe_stdin(use_remote).await
|
||||
}
|
||||
|
||||
#[test_case(false ; "local")]
|
||||
#[test_case(true ; "remote")]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
|
||||
@@ -10,6 +10,8 @@ use codex_exec_server::InitializeResponse;
|
||||
use codex_exec_server::ProcessId;
|
||||
use codex_exec_server::ReadResponse;
|
||||
use codex_exec_server::TerminateResponse;
|
||||
use codex_exec_server::WriteResponse;
|
||||
use codex_exec_server::WriteStatus;
|
||||
use common::exec_server::exec_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
@@ -47,6 +49,7 @@ async fn exec_server_starts_process_over_websocket() -> anyhow::Result<()> {
|
||||
"cwd": std::env::current_dir()?,
|
||||
"env": {},
|
||||
"tty": false,
|
||||
"pipeStdin": false,
|
||||
"arg0": null
|
||||
}),
|
||||
)
|
||||
@@ -75,6 +78,99 @@ async fn exec_server_starts_process_over_websocket() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_defaults_omitted_pipe_stdin_to_closed_stdin() -> anyhow::Result<()> {
|
||||
let mut server = exec_server().await?;
|
||||
let initialize_id = server
|
||||
.send_request(
|
||||
"initialize",
|
||||
serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
let _ = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &initialize_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
server
|
||||
.send_notification("initialized", serde_json::json!({}))
|
||||
.await?;
|
||||
|
||||
let process_start_id = server
|
||||
.send_request(
|
||||
"process/start",
|
||||
serde_json::json!({
|
||||
"processId": "proc-default-stdin",
|
||||
"argv": [
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
"sleep 0.3; if IFS= read -r line; then printf 'read:%s\\n' \"$line\"; else printf 'eof\\n'; fi"
|
||||
],
|
||||
"cwd": std::env::current_dir()?,
|
||||
"env": {},
|
||||
"tty": false,
|
||||
"arg0": null
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &process_start_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
|
||||
panic!("expected process/start response");
|
||||
};
|
||||
let process_start_response: ExecResponse = serde_json::from_value(result)?;
|
||||
assert_eq!(
|
||||
process_start_response,
|
||||
ExecResponse {
|
||||
process_id: ProcessId::from("proc-default-stdin")
|
||||
}
|
||||
);
|
||||
|
||||
let write_id = server
|
||||
.send_request(
|
||||
"process/write",
|
||||
serde_json::json!({
|
||||
"processId": "proc-default-stdin",
|
||||
"chunk": "aWdub3JlZAo="
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &write_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
|
||||
panic!("expected process/write response");
|
||||
};
|
||||
let write_response: WriteResponse = serde_json::from_value(result)?;
|
||||
assert_eq!(
|
||||
write_response,
|
||||
WriteResponse {
|
||||
status: WriteStatus::StdinClosed
|
||||
}
|
||||
);
|
||||
|
||||
server.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_resumes_detached_session_without_killing_processes() -> anyhow::Result<()> {
|
||||
let mut server = exec_server().await?;
|
||||
@@ -113,6 +209,7 @@ async fn exec_server_resumes_detached_session_without_killing_processes() -> any
|
||||
"cwd": std::env::current_dir()?,
|
||||
"env": {},
|
||||
"tty": false,
|
||||
"pipeStdin": false,
|
||||
"arg0": null
|
||||
}),
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ axum = { workspace = true, default-features = false, features = [
|
||||
] }
|
||||
codex-client = { workspace = true }
|
||||
codex-config = { workspace = true }
|
||||
codex-exec-server = { workspace = true }
|
||||
codex-keyring-store = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-utils-pty = { workspace = true }
|
||||
|
||||
364
codex-rs/rmcp-client/src/executor_process_transport.rs
Normal file
364
codex-rs/rmcp-client/src/executor_process_transport.rs
Normal file
@@ -0,0 +1,364 @@
|
||||
//! rmcp transport adapter for an executor-managed MCP stdio process.
|
||||
//!
|
||||
//! This module owns the lower-level byte translation after
|
||||
//! `stdio_server_launcher` has already started a process through
|
||||
//! `ExecBackend::start`. It does not choose where the MCP server runs and it
|
||||
//! does not implement MCP lifecycle behavior. MCP protocol ownership stays in
|
||||
//! `RmcpClient` and rmcp:
|
||||
//!
|
||||
//! 1. rmcp serializes a JSON-RPC message and calls [`Transport::send`].
|
||||
//! 2. This transport appends the stdio newline delimiter and writes those bytes
|
||||
//! to executor `process/write`.
|
||||
//! 3. The executor writes the bytes to the child process stdin.
|
||||
//! 4. The child writes newline-delimited JSON-RPC messages to stdout.
|
||||
//! 5. The executor reports stdout bytes through pushed process events.
|
||||
//! 6. This transport buffers stdout until it has one full line, deserializes
|
||||
//! that line, and returns the rmcp message from [`Transport::receive`].
|
||||
//!
|
||||
//! Stderr is deliberately not part of the MCP byte stream. It is logged for
|
||||
//! diagnostics only, matching the local stdio implementation.
|
||||
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::mem::take;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use codex_exec_server::ExecOutputStream;
|
||||
use codex_exec_server::ExecProcess;
|
||||
use codex_exec_server::ExecProcessEvent;
|
||||
use codex_exec_server::ExecProcessEventReceiver;
|
||||
use codex_exec_server::ProcessId;
|
||||
use codex_exec_server::ProcessOutputChunk;
|
||||
use codex_exec_server::WriteStatus;
|
||||
use rmcp::service::RoleClient;
|
||||
use rmcp::service::RxJsonRpcMessage;
|
||||
use rmcp::service::TxJsonRpcMessage;
|
||||
use rmcp::transport::Transport;
|
||||
use serde_json::from_slice;
|
||||
use serde_json::to_vec;
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::debug;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
/// Source of the numeric suffix in `mcp-stdio-{index}` process ids handed out
/// by `ExecutorProcessTransport::next_process_id`; the first id uses index 1.
static PROCESS_COUNTER: AtomicUsize = AtomicUsize::new(1);
|
||||
|
||||
// Remote public implementation.
|
||||
|
||||
/// A client-side rmcp transport backed by an executor-managed process.
|
||||
///
|
||||
/// The orchestrator owns this value and calls rmcp on it. The process it wraps
|
||||
/// may be local or remote depending on the `ExecBackend` used to create it, but
|
||||
/// for remote MCP stdio the process lives on the executor and all interaction
|
||||
/// crosses the executor process RPC boundary.
|
||||
pub(super) struct ExecutorProcessTransport {
|
||||
/// Logical process handle returned by the executor process API.
|
||||
///
|
||||
/// `write` forwards stdin bytes. `terminate` stops the child when rmcp
|
||||
/// closes the transport.
|
||||
process: Arc<dyn ExecProcess>,
|
||||
|
||||
/// Pushed output/lifecycle stream for the process.
|
||||
///
|
||||
/// The executor process API still supports retained-output reads, but MCP
|
||||
/// stdio is naturally streaming. This receiver lets rmcp wait for stdout
|
||||
/// chunks without issuing `process/read` after each output notification.
|
||||
events: ExecProcessEventReceiver,
|
||||
|
||||
/// Human-readable program name used only in diagnostics.
|
||||
program_name: String,
|
||||
|
||||
/// Buffered child stdout bytes that have not yet formed a complete
|
||||
/// newline-delimited JSON-RPC message.
|
||||
stdout: Vec<u8>,
|
||||
|
||||
/// Buffered stderr bytes for diagnostic logging.
|
||||
stderr: Vec<u8>,
|
||||
|
||||
/// Whether the executor has reported process closure or a terminal
|
||||
/// subscription failure. Once closed, any remaining partial stdout line is
|
||||
/// flushed once and then rmcp receives EOF.
|
||||
closed: bool,
|
||||
|
||||
/// Whether this transport already asked the executor to terminate the MCP
|
||||
/// server process.
|
||||
terminated: bool,
|
||||
|
||||
/// Highest executor process event sequence observed by this transport.
|
||||
///
|
||||
/// When the pushed event stream lags, use this as the retained-output read
|
||||
/// cursor to recover missed stdout/stderr chunks from the executor.
|
||||
last_seq: u64,
|
||||
}
|
||||
|
||||
impl ExecutorProcessTransport {
|
||||
pub(super) fn new(process: Arc<dyn ExecProcess>, program_name: String) -> Self {
|
||||
// Subscribe before returning the transport to rmcp. Some test servers
|
||||
// can emit output or exit quickly after `process/start`, and the
|
||||
// process event log will replay anything that landed before this
|
||||
// subscriber was attached.
|
||||
let events = process.subscribe_events();
|
||||
Self {
|
||||
process,
|
||||
events,
|
||||
program_name,
|
||||
stdout: Vec::new(),
|
||||
stderr: Vec::new(),
|
||||
closed: false,
|
||||
terminated: false,
|
||||
last_seq: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn next_process_id() -> ProcessId {
|
||||
// Process IDs are logical handles scoped to the executor connection,
|
||||
// not OS pids. A monotonic client-side id is enough to avoid
|
||||
// collisions between MCP servers started in the same session.
|
||||
let index = PROCESS_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
ProcessId::from(format!("mcp-stdio-{index}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Transport<RoleClient> for ExecutorProcessTransport {
|
||||
type Error = io::Error;
|
||||
|
||||
fn send(
|
||||
&mut self,
|
||||
item: TxJsonRpcMessage<RoleClient>,
|
||||
) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send + 'static {
|
||||
let process = Arc::clone(&self.process);
|
||||
async move {
|
||||
// rmcp hands us a structured JSON-RPC message. Stdio transport on
|
||||
// the wire is JSON plus one newline delimiter.
|
||||
let mut bytes = to_vec(&item).map_err(io::Error::other)?;
|
||||
bytes.push(b'\n');
|
||||
let response = process.write(bytes).await.map_err(io::Error::other)?;
|
||||
match response.status {
|
||||
WriteStatus::Accepted => Ok(()),
|
||||
WriteStatus::UnknownProcess => {
|
||||
Err(io::Error::new(io::ErrorKind::BrokenPipe, "unknown process"))
|
||||
}
|
||||
WriteStatus::StdinClosed => {
|
||||
Err(io::Error::new(io::ErrorKind::BrokenPipe, "stdin closed"))
|
||||
}
|
||||
WriteStatus::Starting => Err(io::Error::new(
|
||||
io::ErrorKind::WouldBlock,
|
||||
"process is starting",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<RoleClient>>> + Send {
|
||||
self.receive_message()
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> std::result::Result<(), Self::Error> {
|
||||
self.process.terminate().await.map_err(io::Error::other)?;
|
||||
self.terminated = true;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutorProcessTransport {
|
||||
async fn receive_message(&mut self) -> Option<RxJsonRpcMessage<RoleClient>> {
|
||||
loop {
|
||||
// rmcp stdio framing is line-oriented JSON. We first drain any
|
||||
// complete line already buffered from an earlier process event.
|
||||
if let Some(message) = self.take_stdout_message(/*allow_partial*/ self.closed) {
|
||||
return Some(message);
|
||||
}
|
||||
if self.closed {
|
||||
self.flush_stderr();
|
||||
return None;
|
||||
}
|
||||
|
||||
match self.events.recv().await {
|
||||
Ok(ExecProcessEvent::Output(chunk)) => {
|
||||
// The executor pushes raw process bytes. This is the only
|
||||
// place where those bytes are split back into the stdout
|
||||
// protocol stream and stderr diagnostics.
|
||||
self.push_process_output_if_new(chunk);
|
||||
}
|
||||
Ok(ExecProcessEvent::Exited { seq, .. }) => {
|
||||
self.note_seq(seq);
|
||||
// Wait for `Closed` before ending the rmcp stream so any
|
||||
// output flushed during process shutdown can still be
|
||||
// decoded into JSON-RPC messages.
|
||||
}
|
||||
Ok(ExecProcessEvent::Closed { seq }) => {
|
||||
self.note_seq(seq);
|
||||
self.closed = true;
|
||||
}
|
||||
Ok(ExecProcessEvent::Failed(message)) => {
|
||||
warn!(
|
||||
"Remote MCP server process failed ({}): {message}",
|
||||
self.program_name
|
||||
);
|
||||
self.closed = true;
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(skipped)) => {
|
||||
warn!(
|
||||
"Remote MCP server output stream lagged ({}): skipped {skipped} events",
|
||||
self.program_name
|
||||
);
|
||||
if let Err(error) = self.recover_lagged_events().await {
|
||||
warn!(
|
||||
"Failed to recover remote MCP server output stream ({}): {error}",
|
||||
self.program_name
|
||||
);
|
||||
self.closed = true;
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
self.closed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn note_seq(&mut self, seq: u64) {
|
||||
self.last_seq = self.last_seq.max(seq);
|
||||
}
|
||||
|
||||
fn should_accept_seq(&mut self, seq: u64) -> bool {
|
||||
if seq <= self.last_seq {
|
||||
return false;
|
||||
}
|
||||
self.last_seq = seq;
|
||||
true
|
||||
}
|
||||
|
||||
async fn recover_lagged_events(&mut self) -> io::Result<()> {
|
||||
let response = self
|
||||
.process
|
||||
.read(
|
||||
Some(self.last_seq),
|
||||
/*max_bytes*/ None,
|
||||
/*wait_ms*/ Some(0),
|
||||
)
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
for chunk in response.chunks {
|
||||
self.push_process_output_if_new(chunk);
|
||||
}
|
||||
self.last_seq = self.last_seq.max(response.next_seq.saturating_sub(1));
|
||||
if let Some(message) = response.failure {
|
||||
warn!(
|
||||
"Remote MCP server process failed ({}): {message}",
|
||||
self.program_name
|
||||
);
|
||||
self.closed = true;
|
||||
} else if response.closed {
|
||||
self.closed = true;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Buffers `chunk` unless its sequence number has already been seen.
fn push_process_output_if_new(&mut self, chunk: ProcessOutputChunk) {
    if self.should_accept_seq(chunk.seq) {
        self.push_process_output(chunk);
    }
}
|
||||
|
||||
fn push_process_output(&mut self, chunk: ProcessOutputChunk) {
|
||||
let bytes = chunk.chunk.into_inner();
|
||||
match chunk.stream {
|
||||
// MCP stdio uses stdout as the protocol stream. PTY output is
|
||||
// accepted defensively because the executor process API has a
|
||||
// unified stream enum, but remote MCP starts with `tty=false`.
|
||||
ExecOutputStream::Stdout | ExecOutputStream::Pty => {
|
||||
self.stdout.extend_from_slice(&bytes);
|
||||
}
|
||||
// Stderr is intentionally out-of-band. It should help debug server
|
||||
// startup failures without entering rmcp framing.
|
||||
ExecOutputStream::Stderr => {
|
||||
self.push_stderr(&bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn take_stdout_message(&mut self, allow_partial: bool) -> Option<RxJsonRpcMessage<RoleClient>> {
|
||||
// A normal MCP stdio server emits one JSON-RPC message per newline.
|
||||
// If the process has already closed, accept a final unterminated line
|
||||
// so EOF after a complete JSON object behaves like local rmcp's
|
||||
// `decode_eof` handling.
|
||||
loop {
|
||||
let line_end = self.stdout.iter().position(|byte| *byte == b'\n');
|
||||
let line = match (line_end, allow_partial && !self.stdout.is_empty()) {
|
||||
(Some(index), _) => {
|
||||
let mut line = self.stdout.drain(..=index).collect::<Vec<_>>();
|
||||
line.pop();
|
||||
line
|
||||
}
|
||||
(None, true) => self.stdout.drain(..).collect(),
|
||||
(None, false) => return None,
|
||||
};
|
||||
let line = Self::trim_trailing_carriage_return(line);
|
||||
match from_slice::<RxJsonRpcMessage<RoleClient>>(&line) {
|
||||
Ok(message) => return Some(message),
|
||||
Err(error) => {
|
||||
debug!(
|
||||
"Failed to parse remote MCP server message ({}): {error}",
|
||||
self.program_name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn push_stderr(&mut self, bytes: &[u8]) {
|
||||
// Keep stderr line-oriented in logs so a chatty MCP server does not
|
||||
// produce one log record per byte chunk.
|
||||
self.stderr.extend_from_slice(bytes);
|
||||
while let Some(index) = self.stderr.iter().position(|byte| *byte == b'\n') {
|
||||
let mut line = self.stderr.drain(..=index).collect::<Vec<_>>();
|
||||
line.pop();
|
||||
if line.last() == Some(&b'\r') {
|
||||
line.pop();
|
||||
}
|
||||
info!(
|
||||
"MCP server stderr ({}): {}",
|
||||
self.program_name,
|
||||
String::from_utf8_lossy(&line)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn flush_stderr(&mut self) {
|
||||
if self.stderr.is_empty() {
|
||||
return;
|
||||
}
|
||||
let line = take(&mut self.stderr);
|
||||
info!(
|
||||
"MCP server stderr ({}): {}",
|
||||
self.program_name,
|
||||
String::from_utf8_lossy(&line)
|
||||
);
|
||||
}
|
||||
|
||||
/// Removes a single trailing `\r` from `line`, if present, so CRLF-terminated
/// lines decode the same way as LF-terminated ones.
fn trim_trailing_carriage_return(mut line: Vec<u8>) -> Vec<u8> {
    if line.ends_with(b"\r") {
        line.truncate(line.len() - 1);
    }
    line
}
|
||||
}
|
||||
|
||||
impl Drop for ExecutorProcessTransport {
|
||||
fn drop(&mut self) {
|
||||
if self.terminated {
|
||||
return;
|
||||
}
|
||||
|
||||
let process = Arc::clone(&self.process);
|
||||
tokio::spawn(async move {
|
||||
if let Err(error) = process.terminate().await {
|
||||
warn!("Failed to terminate remote MCP server process on drop: {error}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
mod auth_status;
|
||||
mod elicitation_client_service;
|
||||
mod executor_process_transport;
|
||||
mod logging_client_handler;
|
||||
mod oauth;
|
||||
mod perform_oauth_login;
|
||||
mod program_resolver;
|
||||
mod rmcp_client;
|
||||
mod stdio_server_launcher;
|
||||
mod utils;
|
||||
|
||||
pub use auth_status::StreamableHttpOAuthDiscovery;
|
||||
@@ -29,3 +31,6 @@ pub use rmcp_client::ListToolsWithConnectorIdResult;
|
||||
pub use rmcp_client::RmcpClient;
|
||||
pub use rmcp_client::SendElicitation;
|
||||
pub use rmcp_client::ToolWithConnectorId;
|
||||
pub use stdio_server_launcher::ExecutorStdioServerLauncher;
|
||||
pub use stdio_server_launcher::LocalStdioServerLauncher;
|
||||
pub use stdio_server_launcher::StdioServerLauncher;
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::ffi::OsString;
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
@@ -52,7 +51,6 @@ use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::transport::auth::AuthClient;
|
||||
use rmcp::transport::auth::AuthError;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use rmcp::transport::child_process::TokioChildProcess;
|
||||
use rmcp::transport::streamable_http_client::AuthRequiredError;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClient;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||
@@ -63,23 +61,20 @@ use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use sse_stream::Sse;
|
||||
use sse_stream::SseStream;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::watch;
|
||||
use tokio::time;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::elicitation_client_service::ElicitationClientService;
|
||||
use crate::load_oauth_tokens;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::program_resolver;
|
||||
use crate::stdio_server_launcher::StdioServerCommand;
|
||||
use crate::stdio_server_launcher::StdioServerLauncher;
|
||||
use crate::stdio_server_launcher::StdioServerTransport;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use crate::utils::create_env_for_mcp_server;
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
|
||||
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
|
||||
@@ -307,9 +302,8 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
}
|
||||
|
||||
enum PendingTransport {
|
||||
ChildProcess {
|
||||
transport: TokioChildProcess,
|
||||
process_group_guard: Option<ProcessGroupGuard>,
|
||||
Stdio {
|
||||
transport: StdioServerTransport,
|
||||
},
|
||||
StreamableHttp {
|
||||
transport: StreamableHttpClientTransport<StreamableHttpResponseClient>,
|
||||
@@ -325,79 +319,16 @@ enum ClientState {
|
||||
transport: Option<PendingTransport>,
|
||||
},
|
||||
Ready {
|
||||
_process_group_guard: Option<ProcessGroupGuard>,
|
||||
service: Arc<RunningService<RoleClient, ElicitationClientService>>,
|
||||
oauth: Option<OAuthPersistor>,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
const PROCESS_GROUP_TERM_GRACE_PERIOD: Duration = Duration::from_secs(2);
|
||||
|
||||
#[cfg(unix)]
|
||||
struct ProcessGroupGuard {
|
||||
process_group_id: u32,
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
struct ProcessGroupGuard;
|
||||
|
||||
impl ProcessGroupGuard {
|
||||
fn new(process_group_id: u32) -> Self {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
Self { process_group_id }
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
let _ = process_group_id;
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn maybe_terminate_process_group(&self) {
|
||||
let process_group_id = self.process_group_id;
|
||||
let should_escalate =
|
||||
match codex_utils_pty::process_group::terminate_process_group(process_group_id) {
|
||||
Ok(exists) => exists,
|
||||
Err(error) => {
|
||||
warn!("Failed to terminate MCP process group {process_group_id}: {error}");
|
||||
false
|
||||
}
|
||||
};
|
||||
if should_escalate {
|
||||
std::thread::spawn(move || {
|
||||
std::thread::sleep(PROCESS_GROUP_TERM_GRACE_PERIOD);
|
||||
if let Err(error) =
|
||||
codex_utils_pty::process_group::kill_process_group(process_group_id)
|
||||
{
|
||||
warn!("Failed to kill MCP process group {process_group_id}: {error}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn maybe_terminate_process_group(&self) {}
|
||||
}
|
||||
|
||||
impl Drop for ProcessGroupGuard {
|
||||
fn drop(&mut self) {
|
||||
if cfg!(unix) {
|
||||
self.maybe_terminate_process_group();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum TransportRecipe {
|
||||
Stdio {
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
env: Option<HashMap<OsString, OsString>>,
|
||||
env_vars: Vec<String>,
|
||||
cwd: Option<PathBuf>,
|
||||
command: StdioServerCommand,
|
||||
launcher: Arc<dyn StdioServerLauncher>,
|
||||
},
|
||||
StreamableHttp {
|
||||
server_name: String,
|
||||
@@ -574,13 +505,11 @@ impl RmcpClient {
|
||||
env: Option<HashMap<OsString, OsString>>,
|
||||
env_vars: &[String],
|
||||
cwd: Option<PathBuf>,
|
||||
launcher: Arc<dyn StdioServerLauncher>,
|
||||
) -> io::Result<Self> {
|
||||
let transport_recipe = TransportRecipe::Stdio {
|
||||
program,
|
||||
args,
|
||||
env,
|
||||
env_vars: env_vars.to_vec(),
|
||||
cwd,
|
||||
command: StdioServerCommand::new(program, args, env, env_vars.to_vec(), cwd),
|
||||
launcher,
|
||||
};
|
||||
let transport = Self::create_pending_transport(&transport_recipe)
|
||||
.await
|
||||
@@ -650,7 +579,7 @@ impl RmcpClient {
|
||||
}
|
||||
};
|
||||
|
||||
let (service, oauth_persistor, process_group_guard) =
|
||||
let (service, oauth_persistor) =
|
||||
Self::connect_pending_transport(pending_transport, client_service.clone(), timeout)
|
||||
.await?;
|
||||
|
||||
@@ -671,7 +600,6 @@ impl RmcpClient {
|
||||
{
|
||||
let mut guard = self.state.lock().await;
|
||||
*guard = ClientState::Ready {
|
||||
_process_group_guard: process_group_guard,
|
||||
service,
|
||||
oauth: oauth_persistor.clone(),
|
||||
};
|
||||
@@ -954,60 +882,9 @@ impl RmcpClient {
|
||||
transport_recipe: &TransportRecipe,
|
||||
) -> Result<PendingTransport> {
|
||||
match transport_recipe {
|
||||
TransportRecipe::Stdio {
|
||||
program,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} => {
|
||||
let program_name = program.to_string_lossy().into_owned();
|
||||
let envs = create_env_for_mcp_server(env.clone(), env_vars);
|
||||
let resolved_program = program_resolver::resolve(program.clone(), &envs)?;
|
||||
|
||||
let mut command = Command::new(resolved_program);
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.env_clear()
|
||||
.envs(envs)
|
||||
.args(args);
|
||||
#[cfg(unix)]
|
||||
command.process_group(0);
|
||||
if let Some(cwd) = cwd {
|
||||
command.current_dir(cwd);
|
||||
}
|
||||
|
||||
let (transport, stderr) = TokioChildProcess::builder(command)
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
let process_group_guard = transport.id().map(ProcessGroupGuard::new);
|
||||
|
||||
if let Some(stderr) = stderr {
|
||||
tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(stderr).lines();
|
||||
loop {
|
||||
match reader.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
info!("MCP server stderr ({program_name}): {line}");
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(error) => {
|
||||
warn!(
|
||||
"Failed to read MCP server stderr ({program_name}): {error}"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(PendingTransport::ChildProcess {
|
||||
transport,
|
||||
process_group_guard,
|
||||
})
|
||||
TransportRecipe::Stdio { command, launcher } => {
|
||||
let transport = launcher.launch(command.clone()).await?;
|
||||
Ok(PendingTransport::Stdio { transport })
|
||||
}
|
||||
TransportRecipe::StreamableHttp {
|
||||
server_name,
|
||||
@@ -1101,21 +978,15 @@ impl RmcpClient {
|
||||
) -> Result<(
|
||||
Arc<RunningService<RoleClient, ElicitationClientService>>,
|
||||
Option<OAuthPersistor>,
|
||||
Option<ProcessGroupGuard>,
|
||||
)> {
|
||||
let (transport, oauth_persistor, process_group_guard) = match pending_transport {
|
||||
PendingTransport::ChildProcess {
|
||||
transport,
|
||||
process_group_guard,
|
||||
} => (
|
||||
let (transport, oauth_persistor) = match pending_transport {
|
||||
PendingTransport::Stdio { transport } => (
|
||||
service::serve_client(client_service, transport).boxed(),
|
||||
None,
|
||||
process_group_guard,
|
||||
),
|
||||
PendingTransport::StreamableHttp { transport } => (
|
||||
service::serve_client(client_service, transport).boxed(),
|
||||
None,
|
||||
None,
|
||||
),
|
||||
PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
@@ -1123,7 +994,6 @@ impl RmcpClient {
|
||||
} => (
|
||||
service::serve_client(client_service, transport).boxed(),
|
||||
Some(oauth_persistor),
|
||||
None,
|
||||
),
|
||||
};
|
||||
|
||||
@@ -1137,7 +1007,7 @@ impl RmcpClient {
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
};
|
||||
|
||||
Ok((Arc::new(service), oauth_persistor, process_group_guard))
|
||||
Ok((Arc::new(service), oauth_persistor))
|
||||
}
|
||||
|
||||
async fn run_service_operation<T, F, Fut>(
|
||||
@@ -1249,7 +1119,7 @@ impl RmcpClient {
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?;
|
||||
let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?;
|
||||
let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(
|
||||
let (service, oauth_persistor) = Self::connect_pending_transport(
|
||||
pending_transport,
|
||||
initialize_context.client_service,
|
||||
initialize_context.timeout,
|
||||
@@ -1259,7 +1129,6 @@ impl RmcpClient {
|
||||
{
|
||||
let mut guard = self.state.lock().await;
|
||||
*guard = ClientState::Ready {
|
||||
_process_group_guard: process_group_guard,
|
||||
service,
|
||||
oauth: oauth_persistor.clone(),
|
||||
};
|
||||
|
||||
427
codex-rs/rmcp-client/src/stdio_server_launcher.rs
Normal file
427
codex-rs/rmcp-client/src/stdio_server_launcher.rs
Normal file
@@ -0,0 +1,427 @@
|
||||
//! Launch MCP stdio servers and return the transport rmcp should use.
|
||||
//!
|
||||
//! This module owns the "where does the server process run?" decision:
|
||||
//!
|
||||
//! - [`LocalStdioServerLauncher`] starts the configured command as a child of
|
||||
//! the orchestrator process.
|
||||
//! - [`ExecutorStdioServerLauncher`] starts the configured command through the
|
||||
//! executor process API.
|
||||
//!
|
||||
//! Both paths return [`StdioServerTransport`], so `RmcpClient` can hand the
|
||||
//! resulting byte stream to rmcp without knowing where the process lives. The
|
||||
//! executor-specific byte adaptation lives in `executor_process_transport`.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
#[cfg(unix)]
|
||||
use std::thread::sleep;
|
||||
#[cfg(unix)]
|
||||
use std::thread::spawn;
|
||||
#[cfg(unix)]
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_config::types::ShellEnvironmentPolicyInherit;
|
||||
use codex_exec_server::ExecBackend;
|
||||
use codex_exec_server::ExecEnvPolicy;
|
||||
use codex_exec_server::ExecParams;
|
||||
#[cfg(unix)]
|
||||
use codex_utils_pty::process_group::kill_process_group;
|
||||
#[cfg(unix)]
|
||||
use codex_utils_pty::process_group::terminate_process_group;
|
||||
use futures::FutureExt;
|
||||
use futures::future::BoxFuture;
|
||||
use rmcp::service::RoleClient;
|
||||
use rmcp::service::RxJsonRpcMessage;
|
||||
use rmcp::service::TxJsonRpcMessage;
|
||||
use rmcp::transport::Transport;
|
||||
use rmcp::transport::child_process::TokioChildProcess;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Command;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::executor_process_transport::ExecutorProcessTransport;
|
||||
use crate::program_resolver;
|
||||
use crate::utils::create_env_for_mcp_server;
|
||||
use crate::utils::create_env_overlay_for_remote_mcp_server;
|
||||
|
||||
// General purpose public code.
|
||||
|
||||
/// Launches an MCP stdio server and returns the transport for rmcp.
|
||||
///
|
||||
/// This trait is the boundary between MCP lifecycle code and process placement.
|
||||
/// `RmcpClient` owns MCP operations such as `initialize` and `tools/list`; the
|
||||
/// launcher owns starting the configured command and producing an rmcp
|
||||
/// [`Transport`] over the server's stdin/stdout bytes.
|
||||
pub trait StdioServerLauncher: private::Sealed + Send + Sync {
|
||||
/// Start the configured stdio server and return its rmcp-facing transport.
|
||||
fn launch(
|
||||
&self,
|
||||
command: StdioServerCommand,
|
||||
) -> BoxFuture<'static, io::Result<StdioServerTransport>>;
|
||||
}
|
||||
|
||||
/// Description of the process to start, shared by every stdio server
/// launcher.
#[derive(Clone)]
pub struct StdioServerCommand {
    // Executable to run.
    program: OsString,
    // Arguments passed after the program.
    args: Vec<OsString>,
    // Literal environment entries taken from the MCP server config.
    env: Option<HashMap<OsString, OsString>>,
    // Names of environment variables to forward from this process.
    env_vars: Vec<String>,
    // Working directory from the MCP server config, when one is set.
    cwd: Option<PathBuf>,
}
|
||||
|
||||
/// Client-side rmcp transport for a launched MCP stdio server.
|
||||
///
|
||||
/// The concrete process placement stays private to this module. `RmcpClient`
|
||||
/// only sees the standard rmcp transport abstraction and can pass this value
|
||||
/// directly to `rmcp::service::serve_client`.
|
||||
pub struct StdioServerTransport {
|
||||
inner: StdioServerTransportInner,
|
||||
// Local child processes can leave subprocesses behind, so the local
|
||||
// variant keeps a process-group guard with the transport. Executor-backed
|
||||
// processes are owned and cleaned up by the executor, so that variant uses
|
||||
// `None`.
|
||||
_process_group_guard: Option<ProcessGroupGuard>,
|
||||
}
|
||||
|
||||
enum StdioServerTransportInner {
|
||||
Local(TokioChildProcess),
|
||||
Executor(ExecutorProcessTransport),
|
||||
}
|
||||
|
||||
impl Transport<RoleClient> for StdioServerTransport {
|
||||
type Error = io::Error;
|
||||
|
||||
fn send(
|
||||
&mut self,
|
||||
item: TxJsonRpcMessage<RoleClient>,
|
||||
) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send + 'static {
|
||||
// Both variants already implement rmcp's transport contract. This
|
||||
// wrapper keeps process placement private while leaving rmcp's send
|
||||
// semantics unchanged.
|
||||
match &mut self.inner {
|
||||
StdioServerTransportInner::Local(transport) => transport.send(item).boxed(),
|
||||
StdioServerTransportInner::Executor(transport) => transport.send(item).boxed(),
|
||||
}
|
||||
}
|
||||
|
||||
fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<RoleClient>>> + Send {
|
||||
// rmcp reads from the same transport shape for both placements. The
|
||||
// executor variant turns pushed process-output events back into the
|
||||
// line-delimited JSON stream expected by rmcp.
|
||||
match &mut self.inner {
|
||||
StdioServerTransportInner::Local(transport) => transport.receive().boxed(),
|
||||
StdioServerTransportInner::Executor(transport) => transport.receive().boxed(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> std::result::Result<(), Self::Error> {
|
||||
match &mut self.inner {
|
||||
StdioServerTransportInner::Local(transport) => transport.close().await,
|
||||
StdioServerTransportInner::Executor(transport) => transport.close().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StdioServerCommand {
|
||||
/// Build the stdio process parameters before choosing where the process
|
||||
/// runs.
|
||||
pub(super) fn new(
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
env: Option<HashMap<OsString, OsString>>,
|
||||
env_vars: Vec<String>,
|
||||
cwd: Option<PathBuf>,
|
||||
) -> Self {
|
||||
Self {
|
||||
program,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Local public implementation.
|
||||
|
||||
/// Launcher that runs MCP stdio servers as local child processes.
///
/// This keeps the existing behaviour for local MCP servers: the
/// orchestrator process spawns the configured command, and rmcp talks to
/// the child directly over its stdin/stdout pipes.
#[derive(Clone)]
pub struct LocalStdioServerLauncher;
|
||||
|
||||
impl StdioServerLauncher for LocalStdioServerLauncher {
|
||||
fn launch(
|
||||
&self,
|
||||
command: StdioServerCommand,
|
||||
) -> BoxFuture<'static, io::Result<StdioServerTransport>> {
|
||||
async move { Self::launch_server(command) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
// Local private implementation.
|
||||
|
||||
/// Delay between asking a process group to terminate and force-killing it.
#[cfg(unix)]
const PROCESS_GROUP_TERM_GRACE_PERIOD: Duration = Duration::from_secs(2);

/// Terminates the recorded process group when dropped (Unix only).
#[cfg(unix)]
struct ProcessGroupGuard {
    // Id of the process group to terminate on drop.
    process_group_id: u32,
}

/// Stateless stand-in on non-Unix targets, where dropping the guard does
/// nothing.
#[cfg(not(unix))]
struct ProcessGroupGuard;
|
||||
|
||||
mod private {
|
||||
pub trait Sealed {}
|
||||
}
|
||||
|
||||
impl private::Sealed for LocalStdioServerLauncher {}
|
||||
|
||||
impl LocalStdioServerLauncher {
|
||||
fn launch_server(command: StdioServerCommand) -> io::Result<StdioServerTransport> {
|
||||
let StdioServerCommand {
|
||||
program,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} = command;
|
||||
let program_name = program.to_string_lossy().into_owned();
|
||||
let envs = create_env_for_mcp_server(env, &env_vars);
|
||||
let resolved_program =
|
||||
program_resolver::resolve(program, &envs).map_err(io::Error::other)?;
|
||||
|
||||
let mut command = Command::new(resolved_program);
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.env_clear()
|
||||
.envs(envs)
|
||||
.args(args);
|
||||
#[cfg(unix)]
|
||||
command.process_group(0);
|
||||
if let Some(cwd) = cwd {
|
||||
command.current_dir(cwd);
|
||||
}
|
||||
|
||||
let (transport, stderr) = TokioChildProcess::builder(command)
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
let process_group_guard = transport.id().map(ProcessGroupGuard::new);
|
||||
|
||||
if let Some(stderr) = stderr {
|
||||
tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(stderr).lines();
|
||||
loop {
|
||||
match reader.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
info!("MCP server stderr ({program_name}): {line}");
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(error) => {
|
||||
warn!("Failed to read MCP server stderr ({program_name}): {error}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(StdioServerTransport {
|
||||
inner: StdioServerTransportInner::Local(transport),
|
||||
_process_group_guard: process_group_guard,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ProcessGroupGuard {
|
||||
fn new(process_group_id: u32) -> Self {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
Self { process_group_id }
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
let _ = process_group_id;
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn maybe_terminate_process_group(&self) {
|
||||
let process_group_id = self.process_group_id;
|
||||
let should_escalate = match terminate_process_group(process_group_id) {
|
||||
Ok(exists) => exists,
|
||||
Err(error) => {
|
||||
warn!("Failed to terminate MCP process group {process_group_id}: {error}");
|
||||
false
|
||||
}
|
||||
};
|
||||
if should_escalate {
|
||||
spawn(move || {
|
||||
sleep(PROCESS_GROUP_TERM_GRACE_PERIOD);
|
||||
if let Err(error) = kill_process_group(process_group_id) {
|
||||
warn!("Failed to kill MCP process group {process_group_id}: {error}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn maybe_terminate_process_group(&self) {}
|
||||
}
|
||||
|
||||
impl Drop for ProcessGroupGuard {
|
||||
fn drop(&mut self) {
|
||||
if cfg!(unix) {
|
||||
self.maybe_terminate_process_group();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remote public implementation.
|
||||
|
||||
/// Starts MCP stdio servers through the executor process API.
|
||||
///
|
||||
/// MCP framing still runs in the orchestrator. The executor only owns the
|
||||
/// child process and transports raw stdin/stdout/stderr bytes, so it does not
|
||||
/// need to know about MCP methods such as `initialize` or `tools/list`.
|
||||
#[derive(Clone)]
|
||||
pub struct ExecutorStdioServerLauncher {
|
||||
exec_backend: Arc<dyn ExecBackend>,
|
||||
default_cwd: PathBuf,
|
||||
}
|
||||
|
||||
impl ExecutorStdioServerLauncher {
|
||||
/// Creates a stdio server launcher backed by the executor process API.
|
||||
///
|
||||
/// `default_cwd` is used only when the MCP server config omits `cwd`.
|
||||
/// Executor `process/start` requires an explicit working directory, unlike
|
||||
/// local `tokio::process::Command`, which can inherit the orchestrator cwd.
|
||||
pub fn new(exec_backend: Arc<dyn ExecBackend>, default_cwd: PathBuf) -> Self {
|
||||
Self {
|
||||
exec_backend,
|
||||
default_cwd,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StdioServerLauncher for ExecutorStdioServerLauncher {
|
||||
fn launch(
|
||||
&self,
|
||||
command: StdioServerCommand,
|
||||
) -> BoxFuture<'static, io::Result<StdioServerTransport>> {
|
||||
let exec_backend = Arc::clone(&self.exec_backend);
|
||||
let default_cwd = self.default_cwd.clone();
|
||||
async move { Self::launch_server(command, exec_backend, default_cwd).await }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
// Remote private implementation.
|
||||
|
||||
// Allows the executor-backed launcher to implement `StdioServerLauncher`.
impl private::Sealed for ExecutorStdioServerLauncher {}
|
||||
|
||||
impl ExecutorStdioServerLauncher {
|
||||
async fn launch_server(
|
||||
command: StdioServerCommand,
|
||||
exec_backend: Arc<dyn ExecBackend>,
|
||||
default_cwd: PathBuf,
|
||||
) -> io::Result<StdioServerTransport> {
|
||||
let StdioServerCommand {
|
||||
program,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} = command;
|
||||
let program_name = program.to_string_lossy().into_owned();
|
||||
let envs = create_env_overlay_for_remote_mcp_server(env, &env_vars);
|
||||
// The executor protocol carries argv/env as UTF-8 strings. Local stdio can
|
||||
// accept arbitrary OsString values because it calls the OS directly; remote
|
||||
// stdio must reject non-Unicode command, argument, or environment data
|
||||
// before sending an executor request.
|
||||
let argv = Self::process_api_argv(&program, &args).map_err(io::Error::other)?;
|
||||
let env = Self::process_api_env(envs).map_err(io::Error::other)?;
|
||||
let process_id = ExecutorProcessTransport::next_process_id();
|
||||
// Start the MCP server process on the executor with raw pipes. `tty=false`
|
||||
// keeps stdout as a clean protocol stream, while `pipe_stdin=true` lets
|
||||
// rmcp write JSON-RPC requests after the process starts.
|
||||
let started = exec_backend
|
||||
.start(ExecParams {
|
||||
process_id,
|
||||
argv,
|
||||
cwd: cwd.unwrap_or(default_cwd),
|
||||
env_policy: Some(Self::remote_env_policy()),
|
||||
env,
|
||||
tty: false,
|
||||
pipe_stdin: true,
|
||||
arg0: None,
|
||||
})
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
Ok(StdioServerTransport {
|
||||
inner: StdioServerTransportInner::Executor(ExecutorProcessTransport::new(
|
||||
started.process,
|
||||
program_name,
|
||||
)),
|
||||
_process_group_guard: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn process_api_argv(program: &OsString, args: &[OsString]) -> Result<Vec<String>> {
|
||||
let mut argv = Vec::with_capacity(args.len() + 1);
|
||||
argv.push(Self::os_string_to_process_api_string(
|
||||
program.clone(),
|
||||
"command",
|
||||
)?);
|
||||
for arg in args {
|
||||
argv.push(Self::os_string_to_process_api_string(
|
||||
arg.clone(),
|
||||
"argument",
|
||||
)?);
|
||||
}
|
||||
Ok(argv)
|
||||
}
|
||||
|
||||
fn process_api_env(env: HashMap<OsString, OsString>) -> Result<HashMap<String, String>> {
|
||||
env.into_iter()
|
||||
.map(|(key, value)| {
|
||||
Ok((
|
||||
Self::os_string_to_process_api_string(key, "environment variable name")?,
|
||||
Self::os_string_to_process_api_string(value, "environment variable value")?,
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn os_string_to_process_api_string(value: OsString, label: &str) -> Result<String> {
|
||||
value
|
||||
.into_string()
|
||||
.map_err(|_| anyhow!("{label} must be valid Unicode for remote MCP stdio"))
|
||||
}
|
||||
|
||||
fn remote_env_policy() -> ExecEnvPolicy {
|
||||
ExecEnvPolicy {
|
||||
inherit: ShellEnvironmentPolicyInherit::Core,
|
||||
ignore_default_excludes: true,
|
||||
exclude: Vec::new(),
|
||||
r#set: HashMap::new(),
|
||||
include_only: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,20 @@ pub(crate) fn create_env_for_mcp_server(
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Builds the environment overlay sent along with a remote MCP stdio server.
///
/// Remote stdio should inherit PATH/HOME/etc. on the executor side rather
/// than from the orchestrator process, so only two sources are forwarded:
/// variables named in `env_vars` that are set in this process, and the
/// literal entries in `extra_env`. A literal entry wins when both sources
/// provide the same name.
pub(crate) fn create_env_overlay_for_remote_mcp_server(
    extra_env: Option<HashMap<OsString, OsString>>,
    env_vars: &[String],
) -> HashMap<OsString, OsString> {
    let mut overlay: HashMap<OsString, OsString> = HashMap::new();
    for name in env_vars {
        // Names that are not set in this process are skipped.
        if let Some(value) = env::var_os(name) {
            overlay.insert(OsString::from(name), value);
        }
    }
    if let Some(literal_entries) = extra_env {
        // Inserted last so config literals override forwarded values.
        overlay.extend(literal_entries);
    }
    overlay
}
|
||||
|
||||
pub(crate) fn build_default_headers(
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
@@ -197,6 +211,26 @@ mod tests {
|
||||
assert_eq!(env.get(OsStr::new(custom_var)), Some(&expected));
|
||||
}
|
||||
|
||||
/// A variable from the default list must not appear in the remote overlay
/// unless it is named explicitly; only the explicitly named variable is
/// forwarded.
#[test]
#[serial(extra_rmcp_env)]
fn create_remote_env_overlay_only_forwards_explicit_variables() {
    let default_name = DEFAULT_ENV_VARS[0];
    let explicit_name = "EXTRA_REMOTE_RMCP_ENV";
    let explicit_value = OsString::from("from-env");
    // Both variables are set for the duration of the test.
    let _default_guard = EnvVarGuard::set(default_name, "from-default");
    let _explicit_guard = EnvVarGuard::set(explicit_name, &explicit_value);

    let overlay = create_env_overlay_for_remote_mcp_server(
        /*extra_env*/ None,
        &[explicit_name.to_string()],
    );

    let expected = HashMap::from([(OsString::from(explicit_name), explicit_value)]);
    assert_eq!(overlay, expected);
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
#[serial(extra_rmcp_env)]
|
||||
|
||||
@@ -4,10 +4,12 @@ use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use codex_rmcp_client::LocalStdioServerLauncher;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
|
||||
fn process_exists(pid: u32) -> bool {
|
||||
@@ -78,6 +80,7 @@ async fn drop_kills_wrapper_process_group() -> Result<()> {
|
||||
)])),
|
||||
&[],
|
||||
/*cwd*/ None,
|
||||
Arc::new(LocalStdioServerLauncher),
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use std::ffi::OsString;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::LocalStdioServerLauncher;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use codex_utils_cargo_bin::CargoBinError;
|
||||
use futures::FutureExt as _;
|
||||
@@ -61,6 +63,7 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> {
|
||||
/*env*/ None,
|
||||
&[],
|
||||
/*cwd*/ None,
|
||||
Arc::new(LocalStdioServerLauncher),
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user