diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index 367d83f15d..bd9c9ccce3 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -5,7 +5,6 @@ use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tokio::process::Child; use tokio::sync::mpsc; -use tokio::sync::watch; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Message; use tracing::debug; @@ -75,7 +74,6 @@ impl StdioTransport { struct JsonRpcConnectionRuntime { outgoing_tx: mpsc::Sender, incoming_rx: mpsc::Receiver, - disconnected_rx: watch::Receiver, task_handles: Vec>, } @@ -98,11 +96,9 @@ impl JsonRpcConnection { { let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); - let (disconnected_tx, disconnected_rx) = watch::channel(false); let reader_label = connection_label.clone(); let incoming_tx_for_reader = incoming_tx.clone(); - let disconnected_tx_for_reader = disconnected_tx.clone(); let reader_task = tokio::spawn(async move { let mut lines = BufReader::new(reader).lines(); loop { @@ -133,18 +129,12 @@ impl JsonRpcConnection { } } Ok(None) => { - send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - /*reason*/ None, - ) - .await; + send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; break; } Err(err) => { send_disconnected( &incoming_tx_for_reader, - &disconnected_tx_for_reader, Some(format!( "failed to read JSON-RPC message from {reader_label}: {err}" )), @@ -162,7 +152,6 @@ impl JsonRpcConnection { if let Err(err) = write_jsonrpc_line_message(&mut writer, &message).await { send_disconnected( &incoming_tx, - &disconnected_tx, Some(format!( "failed to write JSON-RPC message to {connection_label}: {err}" )), @@ -177,7 +166,6 @@ impl JsonRpcConnection { runtime: Some(JsonRpcConnectionRuntime { outgoing_tx, incoming_rx, - disconnected_rx, task_handles: vec![reader_task, writer_task], }), transport: JsonRpcTransport::Plain, @@ -190,12 +178,10 @@ impl JsonRpcConnection { { let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); - let (disconnected_tx, disconnected_rx) = watch::channel(false); let (mut websocket_writer, mut websocket_reader) = stream.split(); let reader_label = connection_label.clone(); let incoming_tx_for_reader = incoming_tx.clone(); - let disconnected_tx_for_reader = disconnected_tx.clone(); let reader_task = tokio::spawn(async move { loop { match websocket_reader.next().await { @@ -244,12 +230,7 @@ impl JsonRpcConnection { } } Some(Ok(Message::Close(_))) => { - send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - /*reason*/ None, - ) - .await; + send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; break; } Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {} @@ -257,7 +238,6 @@ impl JsonRpcConnection { Some(Err(err)) => { send_disconnected( &incoming_tx_for_reader, - &disconnected_tx_for_reader, Some(format!( "failed to read websocket JSON-RPC message from {reader_label}: {err}" )), @@ -266,12 +246,7 @@ impl JsonRpcConnection { break; } None => { - send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - /*reason*/ None, - ) - .await; + send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; break; } } @@ -286,7 +261,6 @@ impl JsonRpcConnection { { send_disconnected( &incoming_tx, - &disconnected_tx, Some(format!( "failed to write websocket JSON-RPC message to {connection_label}: {err}" )), @@ -298,7 +272,6 @@ impl JsonRpcConnection { Err(err) => { send_disconnected( &incoming_tx, - &disconnected_tx, Some(format!( "failed to serialize JSON-RPC message for {connection_label}: {err}" )), @@ -314,7 +287,6 @@ impl JsonRpcConnection { runtime: Some(JsonRpcConnectionRuntime { outgoing_tx, incoming_rx, - disconnected_rx, task_handles: vec![reader_task, writer_task], }), transport: JsonRpcTransport::Plain, @@ -326,16 +298,14 @@ impl JsonRpcConnection { ) -> ( mpsc::Sender, mpsc::Receiver, - watch::Receiver, Vec>, ) { let JsonRpcConnectionRuntime { outgoing_tx, incoming_rx, - disconnected_rx, task_handles, } = self.take_runtime("JSON-RPC client runtime already taken"); - (outgoing_tx, incoming_rx, disconnected_rx, task_handles) + (outgoing_tx, incoming_rx, task_handles) } pub(crate) fn with_child_process(mut self, child_process: Child) -> Self { @@ -348,16 +318,14 @@ impl JsonRpcConnection { ) -> ( mpsc::Sender, mpsc::Receiver, - watch::Receiver, Vec>, ) { let JsonRpcConnectionRuntime { outgoing_tx, incoming_rx, - disconnected_rx, task_handles, } = self.take_runtime("JSON-RPC connection parts already taken"); - (outgoing_tx, incoming_rx, disconnected_rx, task_handles) + (outgoing_tx, incoming_rx, task_handles) } fn take_runtime(&mut self, message: &'static str) -> JsonRpcConnectionRuntime { @@ -370,10 +338,8 @@ impl JsonRpcConnection { async fn send_disconnected( incoming_tx: &mpsc::Sender, - disconnected_tx: &watch::Sender, reason: Option, ) { - let _ = disconnected_tx.send(true); let _ = incoming_tx .send(JsonRpcConnectionEvent::Disconnected { reason }) .await; diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 82948b920c..65cc363aa7 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -18,7 +19,6 @@ use serde_json::Value; use tokio::sync::Mutex; use tokio::sync::mpsc; use tokio::sync::oneshot; -use tokio::sync::watch; use tokio::task::JoinHandle; use crate::connection::JsonRpcConnection; @@ -223,10 +223,9 @@ where pub(crate) struct RpcClient { write_tx: mpsc::Sender, pending: Arc>>, - // This flips when either the underlying transport closes or the RPC reader - // exits, so new calls fail quickly after the connection is no longer usable. - closed_rx: watch::Receiver, - transport_disconnected_rx: watch::Receiver, + // This flips before the ordered RPC reader drains pending requests, so new + // calls fail instead of registering work that can never complete. + closed: Arc, next_request_id: AtomicI64, transport_tasks: Vec>, reader_task: JoinHandle<()>, @@ -236,13 +235,13 @@ impl RpcClient { pub(crate) fn new( connection: &mut JsonRpcConnection, ) -> (Self, mpsc::Receiver) { - let (write_tx, mut incoming_rx, transport_disconnected_rx, transport_tasks) = - connection.take_client_runtime(); + let (write_tx, mut incoming_rx, transport_tasks) = connection.take_client_runtime(); let pending = Arc::new(Mutex::new(HashMap::::new())); let (event_tx, event_rx) = mpsc::channel(128); - let (closed_tx, closed_rx) = watch::channel(*transport_disconnected_rx.borrow()); + let closed = Arc::new(AtomicBool::new(false)); let pending_for_reader = Arc::clone(&pending); + let closed_for_reader = Arc::clone(&closed); let reader_task = tokio::spawn(async move { let reason = loop { let Some(event) = incoming_rx.recv().await else { @@ -264,7 +263,7 @@ impl RpcClient { } }; - let _ = closed_tx.send(true); + closed_for_reader.store(true, Ordering::SeqCst); let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await; drain_pending(&pending_for_reader).await; }); @@ -273,8 +272,7 @@ impl RpcClient { Self { write_tx, pending, - closed_rx, - transport_disconnected_rx, + closed, next_request_id: AtomicI64::new(1), transport_tasks, reader_task, @@ -289,7 +287,7 @@ impl RpcClient { params: &P, ) -> Result<(), serde_json::Error> { let params = serde_json::to_value(params)?; - if *self.closed_rx.borrow() || *self.transport_disconnected_rx.borrow() { + if self.closed.load(Ordering::SeqCst) { return Err(serde_json::Error::io(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "JSON-RPC transport closed", @@ -321,7 +319,7 @@ impl RpcClient { // Registering the pending request and checking disconnect must be // atomic with the reader's drain_pending path. Otherwise a call // can sneak in after the drain and wait forever. - if *self.closed_rx.borrow() || *self.transport_disconnected_rx.borrow() { + if self.closed.load(Ordering::SeqCst) { return Err(RpcCallError::Closed); } pending.insert(request_id.clone(), response_tx); diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs index dc1a9b9ffe..132de63921 100644 --- a/codex-rs/exec-server/src/server/processor.rs +++ b/codex-rs/exec-server/src/server/processor.rs @@ -47,8 +47,7 @@ async fn run_connection( runtime_paths: ExecServerRuntimePaths, ) { let router = Arc::new(build_router()); - let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) = - connection.into_parts(); + let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.into_parts(); let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); let notifications = RpcNotificationSender::new(outgoing_tx.clone()); @@ -96,13 +95,7 @@ async fn run_connection( JsonRpcConnectionEvent::Message(message) => match message { codex_app_server_protocol::JSONRPCMessage::Request(request) => { if let Some(route) = router.request_route(request.method.as_str()) { - let message = tokio::select! { - message = route(Arc::clone(&handler), request) => message, - _ = disconnected_rx.changed() => { - debug!("exec-server transport disconnected while handling request"); - break; - } - }; + let message = route(Arc::clone(&handler), request).await; if let Some(message) = message && outgoing_tx.send(message).await.is_err() { @@ -131,15 +124,7 @@ async fn run_connection( ); break; }; - let result = tokio::select! { - result = route(Arc::clone(&handler), notification) => result, - _ = disconnected_rx.changed() => { - debug!( - "exec-server transport disconnected while handling notification" - ); - break; - } - }; + let result = route(Arc::clone(&handler), notification).await; if let Err(err) = result { warn!("closing exec-server connection after protocol error: {err}"); break;