diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 55cd99b6c9..88fce6f4ea 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -894,6 +894,7 @@ mod tests { use super::ExecServerClient; use super::ExecServerClientConnectOptions; use crate::ProcessId; + use crate::client_api::ExecServerTransportParams; use crate::client_api::StdioExecServerCommand; use crate::client_api::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; @@ -956,6 +957,26 @@ mod tests { assert_eq!(client.session_id().as_deref(), Some("stdio-test")); } + #[cfg(not(windows))] + #[tokio::test] + async fn connect_for_transport_initializes_stdio_command() { + let client = ExecServerClient::connect_for_transport( + ExecServerTransportParams::StdioCommand(StdioExecServerCommand { + program: "sh".to_string(), + args: vec![ + "-c".to_string(), + "read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(), + ], + env: HashMap::new(), + cwd: None, + }), + ) + .await + .expect("stdio transport should connect"); + + assert_eq!(client.session_id().as_deref(), Some("stdio-test")); + } + #[cfg(windows)] #[tokio::test] async fn connect_stdio_command_initializes_json_rpc_client_on_windows() { @@ -985,13 +1006,16 @@ mod tests { async fn dropping_stdio_client_terminates_spawned_process() { let tempdir = tempfile::tempdir().expect("tempdir should be created"); let pid_file = tempdir.path().join("server.pid"); + let child_pid_file = tempdir.path().join("server-child.pid"); let stdio_script = format!( "read _line; \ echo \"$$\" > {}; \ + sleep 60 >/dev/null 2>&1 & echo \"$!\" > {}; \ printf '%s\\n' '{{\"id\":1,\"result\":{{\"sessionId\":\"stdio-test\"}}}}'; \ read _line; \ - sleep 60", + wait", shell_quote(pid_file.as_path()), + shell_quote(child_pid_file.as_path()), ); let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { @@ -1008,14 +1032,20 @@ mod tests { .await .expect("stdio client should connect"); let server_pid = read_pid_file(pid_file.as_path()).await; + let child_pid = read_pid_file(child_pid_file.as_path()).await; assert!( process_exists(server_pid), "spawned stdio process should be running before client drop" ); + assert!( + process_exists(child_pid), + "spawned stdio child process should be running before client drop" + ); drop(client); wait_for_process_exit(server_pid).await; + wait_for_process_exit(child_pid).await; } #[cfg(unix)] diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index 560630e3db..3fccfa25c5 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -121,5 +121,7 @@ fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command { if let Some(cwd) = &stdio_command.cwd { command.current_dir(cwd); } + #[cfg(unix)] + command.process_group(0); command } diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index f1e65e321a..c990c89338 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -1,3 +1,8 @@ +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::time::Duration; + use codex_app_server_protocol::JSONRPCMessage; use futures::SinkExt; use futures::StreamExt; @@ -6,9 +11,11 @@ use tokio::io::AsyncWrite; use tokio::process::Child; use tokio::sync::mpsc; use tokio::sync::watch; +use tokio::time::timeout; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Message; use tracing::debug; +use tracing::warn; use tokio::io::AsyncBufReadExt; use tokio::io::AsyncWriteExt; @@ -16,6 +23,7 @@ use tokio::io::BufReader; use tokio::io::BufWriter; pub(crate) const CHANNEL_CAPACITY: usize = 128; +const STDIO_TERMINATION_GRACE_PERIOD: Duration = Duration::from_secs(2); #[derive(Debug)] pub(crate) enum JsonRpcConnectionEvent { @@ -24,46 +32,177 @@ pub(crate) enum JsonRpcConnectionEvent { Disconnected { reason: Option }, } +#[derive(Clone)] pub(crate) enum JsonRpcTransport { Plain, - Stdio { _transport: Box }, + Stdio { transport: StdioTransport }, } impl JsonRpcTransport { fn from_child_process(child_process: Child) -> Self { Self::Stdio { - _transport: Box::new(StdioTransport { - child_process: Some(child_process), - }), + transport: StdioTransport::spawn(child_process), + } + } + + pub(crate) fn terminate(&self) { + match self { + Self::Plain => {} + Self::Stdio { transport } => transport.terminate(), } } } +#[derive(Clone)] pub(crate) struct StdioTransport { - child_process: Option, + handle: Arc, } -impl Drop for StdioTransport { - fn drop(&mut self) { - let Some(mut child_process) = self.child_process.take() else { - return; - }; +struct StdioTransportHandle { + terminate_tx: watch::Sender, + terminate_requested: AtomicBool, +} - if let Err(err) = child_process.start_kill() { - debug!("failed to terminate exec-server stdio child: {err}"); +impl StdioTransport { + fn spawn(child_process: Child) -> Self { + let (terminate_tx, terminate_rx) = watch::channel(false); + let handle = Arc::new(StdioTransportHandle { + terminate_tx, + terminate_requested: AtomicBool::new(false), + }); + spawn_stdio_child_supervisor(child_process, terminate_rx); + Self { handle } + } + + fn terminate(&self) { + self.handle.terminate(); + } +} + +impl StdioTransportHandle { + fn terminate(&self) { + if !self.terminate_requested.swap(true, Ordering::AcqRel) { + let _ = self.terminate_tx.send(true); } - match tokio::runtime::Handle::try_current() { - Ok(handle) => { - handle.spawn(async move { - if let Err(err) = child_process.wait().await { - debug!("failed to wait for exec-server stdio child: {err}"); - } - }); + } +} + +impl Drop for StdioTransportHandle { + fn drop(&mut self) { + self.terminate(); + } +} + +fn spawn_stdio_child_supervisor(mut child_process: Child, mut terminate_rx: watch::Receiver) { + let process_group_id = child_process.id(); + tokio::spawn(async move { + tokio::select! { + result = child_process.wait() => { + log_stdio_child_wait_result(result); + kill_process_tree(&mut child_process, process_group_id); } - Err(err) => { - debug!("failed to wait for exec-server stdio child without a Tokio runtime: {err}"); + () = wait_for_stdio_termination(&mut terminate_rx) => { + terminate_stdio_child(&mut child_process, process_group_id).await; } } + }); +} + +async fn wait_for_stdio_termination(terminate_rx: &mut watch::Receiver) { + loop { + if *terminate_rx.borrow() { + return; + } + if terminate_rx.changed().await.is_err() { + return; + } + } +} + +async fn terminate_stdio_child(child_process: &mut Child, process_group_id: Option) { + terminate_process_tree(child_process, process_group_id); + match timeout(STDIO_TERMINATION_GRACE_PERIOD, child_process.wait()).await { + Ok(result) => { + log_stdio_child_wait_result(result); + } + Err(_) => { + kill_process_tree(child_process, process_group_id); + log_stdio_child_wait_result(child_process.wait().await); + } + } +} + +fn terminate_process_tree(child_process: &mut Child, process_group_id: Option) { + let Some(process_group_id) = process_group_id else { + kill_direct_child(child_process, "terminate"); + return; + }; + + #[cfg(unix)] + if let Err(err) = codex_utils_pty::process_group::terminate_process_group(process_group_id) { + warn!("failed to terminate exec-server stdio process group {process_group_id}: {err}"); + kill_direct_child(child_process, "terminate"); + } + + #[cfg(windows)] + if !kill_windows_process_tree(process_group_id) { + kill_direct_child(child_process, "terminate"); + } + + #[cfg(not(any(unix, windows)))] + { + let _ = process_group_id; + kill_direct_child(child_process, "terminate"); + } +} + +fn kill_process_tree(child_process: &mut Child, process_group_id: Option) { + let Some(process_group_id) = process_group_id else { + kill_direct_child(child_process, "kill"); + return; + }; + + #[cfg(unix)] + if let Err(err) = codex_utils_pty::process_group::kill_process_group(process_group_id) { + warn!("failed to kill exec-server stdio process group {process_group_id}: {err}"); + } + + #[cfg(windows)] + if !kill_windows_process_tree(process_group_id) { + kill_direct_child(child_process, "kill"); + } + + #[cfg(not(any(unix, windows)))] + { + let _ = process_group_id; + kill_direct_child(child_process, "kill"); + } +} + +fn kill_direct_child(child_process: &mut Child, action: &str) { + if let Err(err) = child_process.start_kill() { + debug!("failed to {action} exec-server stdio child: {err}"); + } +} + +#[cfg(windows)] +fn kill_windows_process_tree(pid: u32) -> bool { + let pid = pid.to_string(); + match std::process::Command::new("taskkill") + .args(["/PID", pid.as_str(), "/T", "/F"]) + .status() + { + Ok(status) => status.success(), + Err(err) => { + warn!("failed to run taskkill for exec-server stdio process tree {pid}: {err}"); + false + } + } +} + +fn log_stdio_child_wait_result(result: std::io::Result) { + if let Err(err) = result { + debug!("failed to wait for exec-server stdio child: {err}"); } } diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 9ea41f3854..e4f2ff554a 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -227,7 +227,7 @@ pub(crate) struct RpcClient { disconnected_rx: watch::Receiver, next_request_id: AtomicI64, transport_tasks: Vec>, - _transport: JsonRpcTransport, + transport: JsonRpcTransport, reader_task: JoinHandle<()>, } @@ -244,33 +244,38 @@ impl RpcClient { let (event_tx, event_rx) = mpsc::channel(128); let pending_for_reader = Arc::clone(&pending); + let transport_for_reader = transport.clone(); let reader_task = tokio::spawn(async move { - while let Some(event) = incoming_rx.recv().await { + let disconnect_reason = loop { + let Some(event) = incoming_rx.recv().await else { + break None; + }; match event { JsonRpcConnectionEvent::Message(message) => { if let Err(err) = handle_server_message(&pending_for_reader, &event_tx, message).await { let _ = err; - break; + break None; } } JsonRpcConnectionEvent::MalformedMessage { reason } => { let _ = reason; - break; + break None; } JsonRpcConnectionEvent::Disconnected { reason } => { - let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await; - drain_pending(&pending_for_reader).await; - return; + break reason; } } - } + }; let _ = event_tx - .send(RpcClientEvent::Disconnected { reason: None }) + .send(RpcClientEvent::Disconnected { + reason: disconnect_reason, + }) .await; drain_pending(&pending_for_reader).await; + transport_for_reader.terminate(); }); ( @@ -280,7 +285,7 @@ impl RpcClient { disconnected_rx, next_request_id: AtomicI64::new(1), transport_tasks, - _transport: transport, + transport, reader_task, }, event_rx, @@ -370,6 +375,7 @@ impl RpcClient { impl Drop for RpcClient { fn drop(&mut self) { + self.transport.terminate(); for task in &self.transport_tasks { task.abort(); }