diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index 908a9e1b05..c49e9cdb98 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -114,7 +114,7 @@ impl ExecServerClient { stdin, format!("exec-server stdio command `{shell_command}`"), ) - .with_lifetime_guard(Box::new(StdioChildGuard { child: Some(child) })), + .with_transport_lifetime(Box::new(StdioChildGuard { child: Some(child) })), args.into(), ) .await diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index d76cbcae86..0ec4c6bc5f 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -15,13 +15,13 @@ use tokio::io::BufWriter; pub(crate) const CHANNEL_CAPACITY: usize = 128; -pub(crate) type JsonRpcConnectionLifetimeGuard = Box; +pub(crate) type JsonRpcTransportLifetime = Box; pub(crate) type JsonRpcConnectionParts = ( mpsc::Sender, mpsc::Receiver, watch::Receiver, Vec>, - Option, + Option, ); #[derive(Debug)] @@ -36,7 +36,7 @@ pub(crate) struct JsonRpcConnection { incoming_rx: mpsc::Receiver, disconnected_rx: watch::Receiver, task_handles: Vec>, - lifetime_guard: Option, + transport_lifetime: Option, } impl JsonRpcConnection { @@ -127,7 +127,7 @@ impl JsonRpcConnection { incoming_rx, disconnected_rx, task_handles: vec![reader_task, writer_task], - lifetime_guard: None, + transport_lifetime: None, } } @@ -262,12 +262,12 @@ impl JsonRpcConnection { incoming_rx, disconnected_rx, task_handles: vec![reader_task, writer_task], - lifetime_guard: None, + transport_lifetime: None, } } - pub(crate) fn with_lifetime_guard(mut self, guard: JsonRpcConnectionLifetimeGuard) -> Self { - self.lifetime_guard = Some(guard); + pub(crate) fn with_transport_lifetime(mut self, lifetime: JsonRpcTransportLifetime) -> Self { + self.transport_lifetime = Some(lifetime); self } @@ -277,7 +277,7 @@ impl JsonRpcConnection { self.incoming_rx, self.disconnected_rx, self.task_handles, - self.lifetime_guard, + self.transport_lifetime, ) } } diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 8985849ec9..d9d8fbbf72 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -24,7 +24,7 @@ use tokio::task::JoinHandle; use crate::connection::JsonRpcConnection; use crate::connection::JsonRpcConnectionEvent; -use crate::connection::JsonRpcConnectionLifetimeGuard; +use crate::connection::JsonRpcTransportLifetime; #[derive(Debug)] pub(crate) enum RpcCallError { @@ -231,13 +231,19 @@ pub(crate) struct RpcClient { disconnected_rx: watch::Receiver, next_request_id: AtomicI64, transport_tasks: Vec>, - _transport_lifetime_guard: Option>, + _transport_lifetime: Option, reader_task: JoinHandle<()>, } +// Holds transport-owned resources, such as a stdio child process, for as long +// as the RPC client owns the underlying connection. +struct TransportLifetime { + _guard: StdMutex, +} + impl RpcClient { pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver) { - let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks, lifetime_guard) = + let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks, transport_lifetime) = connection.into_parts(); let pending = Arc::new(Mutex::new(HashMap::::new())); let (event_tx, event_rx) = mpsc::channel(128); @@ -279,7 +285,9 @@ impl RpcClient { disconnected_rx, next_request_id: AtomicI64::new(1), transport_tasks, - _transport_lifetime_guard: lifetime_guard.map(StdMutex::new), + _transport_lifetime: transport_lifetime.map(|lifetime| TransportLifetime { + _guard: StdMutex::new(lifetime), + }), reader_task, }, event_rx, diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs index 50907f7e42..b7e1a03bd5 100644 --- a/codex-rs/exec-server/src/server/processor.rs +++ b/codex-rs/exec-server/src/server/processor.rs @@ -47,8 +47,13 @@ async fn run_connection( runtime_paths: ExecServerRuntimePaths, ) { let router = Arc::new(build_router()); - let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks, _lifetime_guard) = - connection.into_parts(); + let ( + json_outgoing_tx, + mut incoming_rx, + mut disconnected_rx, + connection_tasks, + _transport_lifetime, + ) = connection.into_parts(); let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); let notifications = RpcNotificationSender::new(outgoing_tx.clone());