diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 2842838ce4..37f3e2455a 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -12,6 +12,7 @@ use futures::FutureExt; use futures::future::BoxFuture; use serde_json::Value; use tokio::sync::Mutex; +use tokio::sync::OnceCell; use tokio::sync::mpsc; use tokio::sync::watch; @@ -152,6 +153,9 @@ pub(crate) struct Session { struct Inner { client: RpcClient, + // Keep the connection alive for any transport-specific owned state such as + // the stdio child process. RpcClient only takes the runtime channels/tasks. + _connection: JsonRpcConnection, // The remote transport delivers one shared notification stream for every // process on the connection. Keep a local process_id -> session registry so // we can turn those connection-global notifications into process wakeups @@ -191,28 +195,25 @@ pub struct ExecServerClient { #[derive(Clone)] pub(crate) struct LazyRemoteExecServerClient { transport: ExecServerTransport, - client: Arc>>, + client: Arc>, } impl LazyRemoteExecServerClient { pub(crate) fn new(transport: ExecServerTransport) -> Self { Self { transport, - client: Arc::new(Mutex::new(None)), + client: Arc::new(OnceCell::new()), } } pub(crate) async fn get(&self) -> Result { - let mut client = self.client.lock().await; - if let Some(client) = client.as_ref() - && !client.is_disconnected() - { - return Ok(client.clone()); - } - - let connected = ExecServerClient::connect_for_environment(self.transport.clone()).await?; - *client = Some(connected.clone()); - Ok(connected) + self.client + .get_or_try_init(|| { + let transport = self.transport.clone(); + async move { ExecServerClient::connect_for_environment(transport).await } + }) + .await + .cloned() } } @@ -276,10 +277,6 @@ pub enum ExecServerError { } impl ExecServerClient { - fn is_disconnected(&self) -> bool { - self.inner.disconnected_error().is_some() || self.inner.client.is_disconnected() - } - pub async fn initialize( &self, options: ExecServerClientConnectOptions, @@ -429,10 +426,10 @@ impl ExecServerClient { } pub(crate) async fn connect( - connection: JsonRpcConnection, + mut connection: JsonRpcConnection, options: ExecServerClientConnectOptions, ) -> Result { - let (rpc_client, mut events_rx) = RpcClient::new(connection); + let (rpc_client, mut events_rx) = RpcClient::new(&mut connection); let inner = Arc::new_cyclic(|weak| { let weak = weak.clone(); let reader_task = tokio::spawn(async move { @@ -467,6 +464,7 @@ impl ExecServerClient { Inner { client: rpc_client, + _connection: connection, sessions: ArcSwap::from_pointee(HashMap::new()), sessions_write_lock: Mutex::new(()), disconnected: OnceLock::new(), diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index d6d1c92deb..df9d84beab 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -3,9 +3,7 @@ use std::time::Duration; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; -use tokio::process::Child; use tokio::process::Command; -use tokio::sync::oneshot; use tokio::time::timeout; use tokio_tungstenite::connect_async; use tracing::debug; @@ -79,7 +77,6 @@ impl ExecServerClient { args: StdioExecServerConnectArgs, ) -> Result { let mut child = stdio_command_process(&args.command) - .kill_on_drop(true) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) @@ -110,60 +107,13 @@ impl ExecServerClient { Self::connect( JsonRpcConnection::from_stdio(stdout, stdin, "exec-server stdio command".to_string()) - .with_transport_lifetime(Box::new(StdioChildGuard::spawn(child))), + .with_stdio_child(child), args.into(), ) .await } } -struct StdioChildGuard { - shutdown_tx: Option>, -} - -impl StdioChildGuard { - fn spawn(child: Child) -> Self { - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - tokio::spawn(supervise_stdio_child(child, shutdown_rx)); - Self { - shutdown_tx: Some(shutdown_tx), - } - } -} - -impl Drop for StdioChildGuard { - fn drop(&mut self) { - if let Some(shutdown_tx) = self.shutdown_tx.take() { - let _ = shutdown_tx.send(()); - } - } -} - -async fn supervise_stdio_child(mut child: Child, shutdown_rx: oneshot::Receiver<()>) { - let shutdown_requested = tokio::select! { - result = child.wait() => { - if let Err(err) = result { - debug!("failed to wait for exec-server stdio child: {err}"); - } - false - } - _ = shutdown_rx => true, - }; - - if shutdown_requested { - kill_stdio_child(&mut child); - if let Err(err) = child.wait().await { - debug!("failed to wait for exec-server stdio child after shutdown: {err}"); - } - } -} - -fn kill_stdio_child(child: &mut Child) { - if let Err(err) = child.start_kill() { - debug!("failed to terminate exec-server stdio child: {err}"); - } -} - fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command { let mut command = Command::new(&stdio_command.program); command.args(&stdio_command.args); diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index c0a4a42a54..f7832e5a34 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -3,10 +3,12 @@ use futures::SinkExt; use futures::StreamExt; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; +use tokio::process::Child; use tokio::sync::mpsc; use tokio::sync::watch; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Message; +use tracing::debug; use tokio::io::AsyncBufReadExt; use tokio::io::AsyncWriteExt; @@ -15,16 +17,6 @@ use tokio::io::BufWriter; pub(crate) const CHANNEL_CAPACITY: usize = 128; -pub(crate) type JsonRpcTransportLifetime = Box; - -pub(crate) struct JsonRpcConnectionParts { - pub(crate) outgoing_tx: mpsc::Sender, - pub(crate) incoming_rx: mpsc::Receiver, - pub(crate) disconnected_rx: watch::Receiver, - pub(crate) task_handles: Vec>, - pub(crate) transport_lifetime: Option, -} - #[derive(Debug)] pub(crate) enum JsonRpcConnectionEvent { Message(JSONRPCMessage), @@ -32,12 +24,24 @@ pub(crate) enum JsonRpcConnectionEvent { Disconnected { reason: Option }, } +struct StdioTransport { + child: Child, +} + +impl Drop for StdioTransport { + fn drop(&mut self) { + if let Err(err) = self.child.start_kill() { + debug!("failed to terminate exec-server stdio child: {err}"); + } + } +} + pub(crate) struct JsonRpcConnection { - outgoing_tx: mpsc::Sender, - incoming_rx: mpsc::Receiver, - disconnected_rx: watch::Receiver, + outgoing_tx: Option>, + incoming_rx: Option>, + disconnected_rx: Option>, task_handles: Vec>, - transport_lifetime: Option, + _stdio_transport: Option, } impl JsonRpcConnection { @@ -124,11 +128,11 @@ impl JsonRpcConnection { }); Self { - outgoing_tx, - incoming_rx, - disconnected_rx, + outgoing_tx: Some(outgoing_tx), + incoming_rx: Some(incoming_rx), + disconnected_rx: Some(disconnected_rx), task_handles: vec![reader_task, writer_task], - transport_lifetime: None, + _stdio_transport: None, } } @@ -259,16 +263,38 @@ impl JsonRpcConnection { }); Self { - outgoing_tx, - incoming_rx, - disconnected_rx, + outgoing_tx: Some(outgoing_tx), + incoming_rx: Some(incoming_rx), + disconnected_rx: Some(disconnected_rx), task_handles: vec![reader_task, writer_task], - transport_lifetime: None, + _stdio_transport: None, } } - pub(crate) fn with_transport_lifetime(mut self, lifetime: JsonRpcTransportLifetime) -> Self { - self.transport_lifetime = Some(lifetime); + pub(crate) fn take_client_runtime( + &mut self, + ) -> ( + mpsc::Sender, + mpsc::Receiver, + watch::Receiver, + Vec>, + ) { + ( + self.outgoing_tx + .take() + .expect("JSON-RPC client runtime already taken"), + self.incoming_rx + .take() + .expect("JSON-RPC client runtime already taken"), + self.disconnected_rx + .take() + .expect("JSON-RPC client runtime already taken"), + std::mem::take(&mut self.task_handles), + ) + } + + pub(crate) fn with_stdio_child(mut self, child: Child) -> Self { + self._stdio_transport = Some(StdioTransport { child }); self } @@ -281,22 +307,15 @@ impl JsonRpcConnection { Vec>, ) { ( - self.outgoing_tx, - self.incoming_rx, - self.disconnected_rx, + self.outgoing_tx + .expect("JSON-RPC connection parts already taken"), + self.incoming_rx + .expect("JSON-RPC connection parts already taken"), + self.disconnected_rx + .expect("JSON-RPC connection parts already taken"), self.task_handles, ) } - - pub(crate) fn into_parts_with_lifetime(self) -> JsonRpcConnectionParts { - JsonRpcConnectionParts { - outgoing_tx: self.outgoing_tx, - incoming_rx: self.incoming_rx, - disconnected_rx: self.disconnected_rx, - task_handles: self.task_handles, - transport_lifetime: self.transport_lifetime, - } - } } async fn send_disconnected( diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 49438b4fd8..a9c77d549a 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -233,19 +233,16 @@ pub(crate) struct RpcClient { } impl RpcClient { - pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver) { - let connection_parts = connection.into_parts_with_lifetime(); - let write_tx = connection_parts.outgoing_tx; - let mut incoming_rx = connection_parts.incoming_rx; - let disconnected_rx = connection_parts.disconnected_rx; - let transport_tasks = connection_parts.task_handles; - let transport_lifetime = connection_parts.transport_lifetime; + pub(crate) fn new( + connection: &mut JsonRpcConnection, + ) -> (Self, mpsc::Receiver) { + let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = + connection.take_client_runtime(); let pending = Arc::new(Mutex::new(HashMap::::new())); let (event_tx, event_rx) = mpsc::channel(128); let pending_for_reader = Arc::clone(&pending); let reader_task = tokio::spawn(async move { - let _transport_lifetime = transport_lifetime; while let Some(event) = incoming_rx.recv().await { match event { JsonRpcConnectionEvent::Message(message) => { @@ -307,10 +304,6 @@ impl RpcClient { }) } - pub(crate) fn is_disconnected(&self) -> bool { - *self.disconnected_rx.borrow() - } - pub(crate) async fn call(&self, method: &str, params: &P) -> Result where P: Serialize, @@ -575,11 +568,9 @@ mod tests { async fn rpc_client_matches_out_of_order_responses_by_request_id() { let (client_stdin, server_reader) = tokio::io::duplex(4096); let (mut server_writer, client_stdout) = tokio::io::duplex(4096); - let (client, _events_rx) = RpcClient::new(JsonRpcConnection::from_stdio( - client_stdout, - client_stdin, - "test-rpc".to_string(), - )); + let mut connection = + JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string()); + let (client, _events_rx) = RpcClient::new(&mut connection); let server = tokio::spawn(async move { let mut lines = BufReader::new(server_reader).lines();