diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index 4c1ea63062..9d88a41a0c 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -722,12 +722,17 @@ mod tests { #[tokio::test] async fn websocket_connection_keeps_outbound_message_while_send_is_backpressured() -> anyhow::Result<()> { - let (websocket, control, mut outbound_rx) = ControlledWebSocket::new(false); - let mut connection = - JsonRpcConnection::from_websocket_stream(websocket, "test".into(), None); + let (websocket, control, mut outbound_rx) = + ControlledWebSocket::new(/*write_ready*/ false); + let mut connection = JsonRpcConnection::from_websocket_stream( + websocket, + "test".into(), + /*ping_interval*/ None, + ); let message = test_jsonrpc_message(); connection.outgoing_tx.send(message.clone()).await?; + control.wait_for_blocked_write().await?; control.send_inbound(Message::Ping(b"check".to_vec().into()))?; assert!( timeout(Duration::from_millis(50), connection.incoming_rx.recv()) @@ -777,12 +782,16 @@ mod tests { inbound_rx: futures_mpsc::UnboundedReceiver>, outbound_tx: futures_mpsc::UnboundedSender, write_ready: Arc, + write_blocked: Arc, + write_blocked_waker: Arc, write_waker: Arc, } struct ControlledWebSocketHandle { inbound_tx: futures_mpsc::UnboundedSender>, write_ready: Arc, + write_blocked: Arc, + write_blocked_waker: Arc, write_waker: Arc, } @@ -797,17 +806,23 @@ mod tests { let (inbound_tx, inbound_rx) = futures_mpsc::unbounded(); let (outbound_tx, outbound_rx) = futures_mpsc::unbounded(); let write_ready = Arc::new(AtomicBool::new(write_ready)); + let write_blocked = Arc::new(AtomicBool::new(false)); + let write_blocked_waker = Arc::new(AtomicWaker::new()); let write_waker = Arc::new(AtomicWaker::new()); ( Self { inbound_rx, outbound_tx, write_ready: Arc::clone(&write_ready), + write_blocked: Arc::clone(&write_blocked), + write_blocked_waker: Arc::clone(&write_blocked_waker), write_waker: Arc::clone(&write_waker), }, ControlledWebSocketHandle { inbound_tx, write_ready, + write_blocked, + write_blocked_waker, write_waker, }, outbound_rx, @@ -826,6 +841,22 @@ mod tests { self.write_ready.store(true, Ordering::Release); self.write_waker.wake(); } + + async fn wait_for_blocked_write(&self) -> anyhow::Result<()> { + timeout( + Duration::from_secs(1), + futures::future::poll_fn(|cx| { + if self.write_blocked.load(Ordering::Acquire) { + Poll::Ready(()) + } else { + self.write_blocked_waker.register(cx.waker()); + Poll::Pending + } + }), + ) + .await?; + Ok(()) + } } impl Sink for ControlledWebSocket { @@ -835,6 +866,8 @@ mod tests { if self.write_ready.load(Ordering::Acquire) { Poll::Ready(Ok(())) } else { + self.write_blocked.store(true, Ordering::Release); + self.write_blocked_waker.wake(); self.write_waker.register(cx.waker()); Poll::Pending } diff --git a/codex-rs/exec-server/src/relay.rs b/codex-rs/exec-server/src/relay.rs index 7d3751c67d..4dd607a305 100644 --- a/codex-rs/exec-server/src/relay.rs +++ b/codex-rs/exec-server/src/relay.rs @@ -622,7 +622,8 @@ mod tests { #[tokio::test] async fn harness_connection_keeps_outbound_frame_while_send_is_backpressured() -> anyhow::Result<()> { - let (websocket, control, mut outbound_rx) = ControlledWebSocket::new(true); + let (websocket, control, mut outbound_rx) = + ControlledWebSocket::new(/*write_ready*/ true); let mut connection = harness_connection_from_websocket(websocket, "test".to_string()); let Message::Binary(resume_payload) = timeout(Duration::from_secs(1), outbound_rx.next()) .await? @@ -635,6 +636,7 @@ mod tests { control.set_write_blocked(); connection.outgoing_tx.send(message.clone()).await?; + control.wait_for_blocked_write().await?; control.send_inbound(Message::Ping(b"check".to_vec().into()))?; assert!( timeout(Duration::from_millis(50), connection.incoming_rx.recv()) @@ -722,12 +724,16 @@ mod tests { inbound_rx: futures_mpsc::UnboundedReceiver>, outbound_tx: futures_mpsc::UnboundedSender, write_ready: Arc, + write_blocked: Arc, + write_blocked_waker: Arc, write_waker: Arc, } struct ControlledWebSocketHandle { inbound_tx: futures_mpsc::UnboundedSender>, write_ready: Arc, + write_blocked: Arc, + write_blocked_waker: Arc, write_waker: Arc, } @@ -742,17 +748,23 @@ mod tests { let (inbound_tx, inbound_rx) = futures_mpsc::unbounded(); let (outbound_tx, outbound_rx) = futures_mpsc::unbounded(); let write_ready = Arc::new(AtomicBool::new(write_ready)); + let write_blocked = Arc::new(AtomicBool::new(false)); + let write_blocked_waker = Arc::new(AtomicWaker::new()); let write_waker = Arc::new(AtomicWaker::new()); ( Self { inbound_rx, outbound_tx, write_ready: Arc::clone(&write_ready), + write_blocked: Arc::clone(&write_blocked), + write_blocked_waker: Arc::clone(&write_blocked_waker), write_waker: Arc::clone(&write_waker), }, ControlledWebSocketHandle { inbound_tx, write_ready, + write_blocked, + write_blocked_waker, write_waker, }, outbound_rx, @@ -775,6 +787,22 @@ mod tests { self.write_ready.store(true, Ordering::Release); self.write_waker.wake(); } + + async fn wait_for_blocked_write(&self) -> anyhow::Result<()> { + timeout( + Duration::from_secs(1), + futures::future::poll_fn(|cx| { + if self.write_blocked.load(Ordering::Acquire) { + Poll::Ready(()) + } else { + self.write_blocked_waker.register(cx.waker()); + Poll::Pending + } + }), + ) + .await?; + Ok(()) + } } impl Sink for ControlledWebSocket { @@ -784,6 +812,8 @@ mod tests { if self.write_ready.load(Ordering::Acquire) { Poll::Ready(Ok(())) } else { + self.write_blocked.store(true, Ordering::Release); + self.write_blocked_waker.wake(); self.write_waker.register(cx.waker()); Poll::Pending }