diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 54a67fea3c..2842838ce4 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -12,7 +12,6 @@ use futures::FutureExt; use futures::future::BoxFuture; use serde_json::Value; use tokio::sync::Mutex; -use tokio::sync::OnceCell; use tokio::sync::mpsc; use tokio::sync::watch; @@ -192,25 +191,28 @@ pub struct ExecServerClient { #[derive(Clone)] pub(crate) struct LazyRemoteExecServerClient { transport: ExecServerTransport, - client: Arc>, + client: Arc>>, } impl LazyRemoteExecServerClient { pub(crate) fn new(transport: ExecServerTransport) -> Self { Self { transport, - client: Arc::new(OnceCell::new()), + client: Arc::new(Mutex::new(None)), } } pub(crate) async fn get(&self) -> Result { - self.client - .get_or_try_init(|| { - let transport = self.transport.clone(); - async move { ExecServerClient::connect_for_environment(transport).await } - }) - .await - .cloned() + let mut client = self.client.lock().await; + if let Some(client) = client.as_ref() + && !client.is_disconnected() + { + return Ok(client.clone()); + } + + let connected = ExecServerClient::connect_for_environment(self.transport.clone()).await?; + *client = Some(connected.clone()); + Ok(connected) } } @@ -274,6 +276,10 @@ pub enum ExecServerError { } impl ExecServerClient { + fn is_disconnected(&self) -> bool { + self.inner.disconnected_error().is_some() || self.inner.client.is_disconnected() + } + pub async fn initialize( &self, options: ExecServerClientConnectOptions, @@ -872,6 +878,7 @@ mod tests { use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::JSONRPCResponse; use pretty_assertions::assert_eq; + use std::collections::HashMap; #[cfg(unix)] use std::path::Path; #[cfg(unix)] @@ -890,7 +897,7 @@ mod tests { use super::ExecServerClient; use super::ExecServerClientConnectOptions; use crate::ProcessId; - #[cfg(not(windows))] + use crate::StdioExecServerCommand; use crate::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; use crate::process::ExecProcessEvent; @@ -933,7 +940,38 @@ mod tests { #[tokio::test] async fn connect_stdio_command_initializes_json_rpc_client() { let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { - shell_command: "read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(), + command: StdioExecServerCommand { + program: "sh".to_string(), + args: vec![ + "-c".to_string(), + "read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(), + ], + env: HashMap::new(), + cwd: None, + }, + client_name: "stdio-test-client".to_string(), + initialize_timeout: Duration::from_secs(1), + resume_session_id: None, + }) + .await + .expect("stdio client should connect"); + + assert_eq!(client.session_id().as_deref(), Some("stdio-test")); + } + + #[cfg(windows)] + #[tokio::test] + async fn connect_stdio_command_initializes_json_rpc_client_on_windows() { + let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { + command: StdioExecServerCommand { + program: "cmd".to_string(), + args: vec![ + "/C".to_string(), + "set /p _line= & echo {\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}} & set /p _line= & ping -n 60 127.0.0.1 >nul".to_string(), + ], + env: HashMap::new(), + cwd: None, + }, client_name: "stdio-test-client".to_string(), initialize_timeout: Duration::from_secs(1), resume_session_id: None, @@ -946,43 +984,71 @@ mod tests { #[cfg(unix)] #[tokio::test] - async fn dropping_stdio_client_terminates_shell_process_group() { + async fn dropping_stdio_client_terminates_spawned_process() { let tempdir = tempfile::tempdir().expect("tempdir should be created"); - let pid_file = tempdir.path().join("child.pid"); - let shell_command = format!( + let pid_file = tempdir.path().join("server.pid"); + let stdio_script = format!( "read _line; \ - (trap 'exit 0' TERM; while true; do sleep 1; done) & \ - child=$!; \ - echo \"$child\" > {}; \ + echo \"$$\" > {}; \ printf '%s\\n' '{{\"id\":1,\"result\":{{\"sessionId\":\"stdio-test\"}}}}'; \ read _line; \ - wait \"$child\"", + sleep 60", shell_quote(pid_file.as_path()), ); let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { - shell_command, + command: StdioExecServerCommand { + program: "sh".to_string(), + args: vec!["-c".to_string(), stdio_script], + env: HashMap::new(), + cwd: None, + }, client_name: "stdio-test-client".to_string(), initialize_timeout: Duration::from_secs(1), resume_session_id: None, }) .await .expect("stdio client should connect"); - let child_pid = read_pid_file(pid_file.as_path()).await; + let server_pid = read_pid_file(pid_file.as_path()).await; assert!( - process_exists(child_pid), - "wrapper child process should be running before client drop" + process_exists(server_pid), + "spawned stdio process should be running before client drop" ); drop(client); - for _ in 0..20 { - if !process_exists(child_pid) { - return; - } - sleep(Duration::from_millis(100)).await; - } - panic!("wrapper child process {child_pid} should exit after client drop"); + wait_for_process_exit(server_pid).await; + } + + #[cfg(unix)] + #[tokio::test] + async fn malformed_stdio_message_terminates_spawned_process() { + let tempdir = tempfile::tempdir().expect("tempdir should be created"); + let pid_file = tempdir.path().join("server.pid"); + let stdio_script = format!( + "read _line; \ + echo \"$$\" > {}; \ + printf '%s\\n' 'not-json'; \ + sleep 60", + shell_quote(pid_file.as_path()), + ); + + let result = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { + command: StdioExecServerCommand { + program: "sh".to_string(), + args: vec!["-c".to_string(), stdio_script], + env: HashMap::new(), + cwd: None, + }, + client_name: "stdio-test-client".to_string(), + initialize_timeout: Duration::from_secs(1), + resume_session_id: None, + }) + .await; + assert!(result.is_err(), "malformed stdio server should not connect"); + + let server_pid = read_pid_file(pid_file.as_path()).await; + wait_for_process_exit(server_pid).await; } #[cfg(unix)] @@ -999,6 +1065,17 @@ mod tests { panic!("pid file {} should be written", path.display()); } + #[cfg(unix)] + async fn wait_for_process_exit(pid: u32) { + for _ in 0..20 { + if !process_exists(pid) { + return; + } + sleep(Duration::from_millis(100)).await; + } + panic!("process {pid} should exit"); + } + #[cfg(unix)] fn process_exists(pid: u32) -> bool { Command::new("kill") diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs index 95ed053476..8110ee24e5 100644 --- a/codex-rs/exec-server/src/client_api.rs +++ b/codex-rs/exec-server/src/client_api.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; +use std::path::PathBuf; use std::time::Duration; use futures::future::BoxFuture; @@ -28,17 +30,26 @@ pub struct RemoteExecServerConnectArgs { /// Stdio connection arguments for a command-backed exec-server. #[derive(Debug, Clone, PartialEq, Eq)] pub struct StdioExecServerConnectArgs { - pub shell_command: String, + pub command: StdioExecServerCommand, pub client_name: String, pub initialize_timeout: Duration, pub resume_session_id: Option, } +/// Structured process command used to start an exec-server over stdio. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StdioExecServerCommand { + pub program: String, + pub args: Vec, + pub env: HashMap, + pub cwd: Option, +} + /// Transport used to connect to a remote exec-server environment. #[derive(Debug, Clone, PartialEq, Eq)] pub enum ExecServerTransport { WebSocketUrl(String), - StdioShellCommand(String), + StdioCommand(StdioExecServerCommand), } /// Sends HTTP requests through a runtime-selected transport. diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index 00ecc2bd2e..d6d1c92deb 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -1,19 +1,11 @@ use std::process::Stdio; -#[cfg(unix)] -use std::thread::sleep; -#[cfg(unix)] -use std::thread::spawn; use std::time::Duration; -#[cfg(unix)] -use codex_utils_pty::process_group::kill_process_group; -#[cfg(unix)] -use codex_utils_pty::process_group::terminate_process_group; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Child; use tokio::process::Command; -use tokio::runtime::Handle; +use tokio::sync::oneshot; use tokio::time::timeout; use tokio_tungstenite::connect_async; use tracing::debug; @@ -22,14 +14,13 @@ use tracing::warn; use crate::ExecServerClient; use crate::ExecServerError; use crate::client_api::RemoteExecServerConnectArgs; +use crate::client_api::StdioExecServerCommand; use crate::client_api::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment"; const ENVIRONMENT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); const ENVIRONMENT_INITIALIZE_TIMEOUT: Duration = Duration::from_secs(5); -#[cfg(unix)] -const STDIO_CHILD_TERM_GRACE_PERIOD: Duration = Duration::from_millis(500); impl ExecServerClient { pub(crate) async fn connect_for_environment( @@ -46,9 +37,9 @@ impl ExecServerClient { }) .await } - crate::client_api::ExecServerTransport::StdioShellCommand(shell_command) => { + crate::client_api::ExecServerTransport::StdioCommand(command) => { Self::connect_stdio_command(StdioExecServerConnectArgs { - shell_command, + command, client_name: ENVIRONMENT_CLIENT_NAME.to_string(), initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT, resume_session_id: None, @@ -87,15 +78,13 @@ impl ExecServerClient { pub async fn connect_stdio_command( args: StdioExecServerConnectArgs, ) -> Result { - let shell_command = args.shell_command.clone(); - let mut child = shell_command_process(&shell_command) + let mut child = stdio_command_process(&args.command) .kill_on_drop(true) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .spawn() .map_err(ExecServerError::Spawn)?; - let process_id = child.id(); let stdin = child.stdin.take().ok_or_else(|| { ExecServerError::Protocol("spawned exec-server command has no stdin".to_string()) @@ -120,15 +109,8 @@ impl ExecServerClient { } Self::connect( - JsonRpcConnection::from_stdio( - stdout, - stdin, - format!("exec-server stdio command `{shell_command}`"), - ) - .with_transport_lifetime(Box::new(StdioChildGuard { - child: Some(child), - process_id, - })), + JsonRpcConnection::from_stdio(stdout, stdin, "exec-server stdio command".to_string()) + .with_transport_lifetime(Box::new(StdioChildGuard::spawn(child))), args.into(), ) .await @@ -136,70 +118,44 @@ impl ExecServerClient { } struct StdioChildGuard { - child: Option, - process_id: Option, + shutdown_tx: Option>, +} + +impl StdioChildGuard { + fn spawn(child: Child) -> Self { + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + tokio::spawn(supervise_stdio_child(child, shutdown_rx)); + Self { + shutdown_tx: Some(shutdown_tx), + } + } } impl Drop for StdioChildGuard { fn drop(&mut self) { - let Some(mut child) = self.child.take() else { - return; - }; - - terminate_stdio_child_process(self.process_id, &mut child); - - if let Ok(handle) = Handle::try_current() { - let _wait_task = handle.spawn(wait_stdio_child(child)); + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); } } } -async fn wait_stdio_child(mut child: Child) { - if let Err(err) = child.wait().await { - debug!("failed to wait for exec-server stdio child: {err}"); - } -} - -#[cfg(unix)] -fn terminate_stdio_child_process(process_group_id: Option, child: &mut Child) { - let Some(process_group_id) = process_group_id else { - kill_stdio_child(child); - return; - }; - - let should_escalate = match terminate_process_group(process_group_id) { - Ok(exists) => exists, - Err(err) => { - debug!("failed to terminate exec-server stdio process group {process_group_id}: {err}"); +async fn supervise_stdio_child(mut child: Child, shutdown_rx: oneshot::Receiver<()>) { + let shutdown_requested = tokio::select! { + result = child.wait() => { + if let Err(err) = result { + debug!("failed to wait for exec-server stdio child: {err}"); + } false } + _ = shutdown_rx => true, }; - if should_escalate { - spawn(move || { - sleep(STDIO_CHILD_TERM_GRACE_PERIOD); - if let Err(err) = kill_process_group(process_group_id) { - debug!("failed to kill exec-server stdio process group {process_group_id}: {err}"); - } - }); - } -} -#[cfg(windows)] -fn terminate_stdio_child_process(process_id: Option, child: &mut Child) { - if let Some(process_id) = process_id { - let _ = std::process::Command::new("taskkill") - .arg("/PID") - .arg(process_id.to_string()) - .arg("/T") - .arg("/F") - .output(); + if shutdown_requested { + kill_stdio_child(&mut child); + if let Err(err) = child.wait().await { + debug!("failed to wait for exec-server stdio child after shutdown: {err}"); + } } - kill_stdio_child(child); -} - -#[cfg(not(any(unix, windows)))] -fn terminate_stdio_child_process(_process_id: Option, child: &mut Child) { - kill_stdio_child(child); } fn kill_stdio_child(child: &mut Child) { @@ -208,19 +164,12 @@ fn kill_stdio_child(child: &mut Child) { } } -fn shell_command_process(shell_command: &str) -> Command { - #[cfg(windows)] - { - let mut command = Command::new("cmd"); - command.arg("/C").arg(shell_command); - command - } - - #[cfg(not(windows))] - { - let mut command = Command::new("sh"); - command.arg("-lc").arg(shell_command); - command.process_group(0); - command +fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command { + let mut command = Command::new(&stdio_command.program); + command.args(&stdio_command.args); + command.envs(&stdio_command.env); + if let Some(cwd) = &stdio_command.cwd { + command.current_dir(cwd); } + command } diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index 9d06f0841d..c0a4a42a54 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -272,7 +272,23 @@ impl JsonRpcConnection { self } - pub(crate) fn into_parts(self) -> JsonRpcConnectionParts { + pub(crate) fn into_parts( + self, + ) -> ( + mpsc::Sender, + mpsc::Receiver, + watch::Receiver, + Vec>, + ) { + ( + self.outgoing_tx, + self.incoming_rx, + self.disconnected_rx, + self.task_handles, + ) + } + + pub(crate) fn into_parts_with_lifetime(self) -> JsonRpcConnectionParts { JsonRpcConnectionParts { outgoing_tx: self.outgoing_tx, incoming_rx: self.incoming_rx, diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index 9bec4ed1da..29897c5285 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -28,6 +28,7 @@ pub use client_api::ExecServerClientConnectOptions; pub use client_api::ExecServerTransport; pub use client_api::HttpClient; pub use client_api::RemoteExecServerConnectArgs; +pub use client_api::StdioExecServerCommand; pub use client_api::StdioExecServerConnectArgs; pub use codex_file_system::CopyOptions; pub use codex_file_system::CreateDirectoryOptions; diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 3b155c08a2..49438b4fd8 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use std::sync::Mutex as StdMutex; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -24,7 +23,6 @@ use tokio::task::JoinHandle; use crate::connection::JsonRpcConnection; use crate::connection::JsonRpcConnectionEvent; -use crate::connection::JsonRpcTransportLifetime; #[derive(Debug)] pub(crate) enum RpcCallError { @@ -231,19 +229,12 @@ pub(crate) struct RpcClient { disconnected_rx: watch::Receiver, next_request_id: AtomicI64, transport_tasks: Vec>, - _transport_lifetime: Option, reader_task: JoinHandle<()>, } -// Holds transport-owned resources, such as a stdio child process, for as long -// as the RPC client owns the underlying connection. -struct TransportLifetime { - _guard: StdMutex, -} - impl RpcClient { pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver) { - let connection_parts = connection.into_parts(); + let connection_parts = connection.into_parts_with_lifetime(); let write_tx = connection_parts.outgoing_tx; let mut incoming_rx = connection_parts.incoming_rx; let disconnected_rx = connection_parts.disconnected_rx; @@ -254,6 +245,7 @@ impl RpcClient { let pending_for_reader = Arc::clone(&pending); let reader_task = tokio::spawn(async move { + let _transport_lifetime = transport_lifetime; while let Some(event) = incoming_rx.recv().await { match event { JsonRpcConnectionEvent::Message(message) => { @@ -289,9 +281,6 @@ impl RpcClient { disconnected_rx, next_request_id: AtomicI64::new(1), transport_tasks, - _transport_lifetime: transport_lifetime.map(|lifetime| TransportLifetime { - _guard: StdMutex::new(lifetime), - }), reader_task, }, event_rx, @@ -318,6 +307,10 @@ impl RpcClient { }) } + pub(crate) fn is_disconnected(&self) -> bool { + *self.disconnected_rx.borrow() + } + pub(crate) async fn call(&self, method: &str, params: &P) -> Result where P: Serialize, diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs index 11472dc626..dc1a9b9ffe 100644 --- a/codex-rs/exec-server/src/server/processor.rs +++ b/codex-rs/exec-server/src/server/processor.rs @@ -47,12 +47,8 @@ async fn run_connection( runtime_paths: ExecServerRuntimePaths, ) { let router = Arc::new(build_router()); - let connection_parts = connection.into_parts(); - let json_outgoing_tx = connection_parts.outgoing_tx; - let mut incoming_rx = connection_parts.incoming_rx; - let mut disconnected_rx = connection_parts.disconnected_rx; - let connection_tasks = connection_parts.task_handles; - let _transport_lifetime = connection_parts.transport_lifetime; + let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) = + connection.into_parts(); let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); let notifications = RpcNotificationSender::new(outgoing_tx.clone());