Fix websocket backpressure regression tests

This commit is contained in:
starr-openai
2026-05-18 15:36:41 -07:00
parent 9215e15ee3
commit 6d6cdeb128
2 changed files with 67 additions and 4 deletions

View File

@@ -722,12 +722,17 @@ mod tests {
#[tokio::test]
async fn websocket_connection_keeps_outbound_message_while_send_is_backpressured()
-> anyhow::Result<()> {
let (websocket, control, mut outbound_rx) = ControlledWebSocket::new(false);
let mut connection =
JsonRpcConnection::from_websocket_stream(websocket, "test".into(), None);
let (websocket, control, mut outbound_rx) =
ControlledWebSocket::new(/*write_ready*/ false);
let mut connection = JsonRpcConnection::from_websocket_stream(
websocket,
"test".into(),
/*ping_interval*/ None,
);
let message = test_jsonrpc_message();
connection.outgoing_tx.send(message.clone()).await?;
control.wait_for_blocked_write().await?;
control.send_inbound(Message::Ping(b"check".to_vec().into()))?;
assert!(
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
@@ -777,12 +782,16 @@ mod tests {
inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
outbound_tx: futures_mpsc::UnboundedSender<Message>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
struct ControlledWebSocketHandle {
inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
@@ -797,17 +806,23 @@ mod tests {
let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
let write_ready = Arc::new(AtomicBool::new(write_ready));
let write_blocked = Arc::new(AtomicBool::new(false));
let write_blocked_waker = Arc::new(AtomicWaker::new());
let write_waker = Arc::new(AtomicWaker::new());
(
Self {
inbound_rx,
outbound_tx,
write_ready: Arc::clone(&write_ready),
write_blocked: Arc::clone(&write_blocked),
write_blocked_waker: Arc::clone(&write_blocked_waker),
write_waker: Arc::clone(&write_waker),
},
ControlledWebSocketHandle {
inbound_tx,
write_ready,
write_blocked,
write_blocked_waker,
write_waker,
},
outbound_rx,
@@ -826,6 +841,22 @@ mod tests {
self.write_ready.store(true, Ordering::Release);
self.write_waker.wake();
}
async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
timeout(
Duration::from_secs(1),
futures::future::poll_fn(|cx| {
if self.write_blocked.load(Ordering::Acquire) {
Poll::Ready(())
} else {
self.write_blocked_waker.register(cx.waker());
Poll::Pending
}
}),
)
.await?;
Ok(())
}
}
impl Sink<Message> for ControlledWebSocket {
@@ -835,6 +866,8 @@ mod tests {
if self.write_ready.load(Ordering::Acquire) {
Poll::Ready(Ok(()))
} else {
self.write_blocked.store(true, Ordering::Release);
self.write_blocked_waker.wake();
self.write_waker.register(cx.waker());
Poll::Pending
}

View File

@@ -622,7 +622,8 @@ mod tests {
#[tokio::test]
async fn harness_connection_keeps_outbound_frame_while_send_is_backpressured()
-> anyhow::Result<()> {
let (websocket, control, mut outbound_rx) = ControlledWebSocket::new(true);
let (websocket, control, mut outbound_rx) =
ControlledWebSocket::new(/*write_ready*/ true);
let mut connection = harness_connection_from_websocket(websocket, "test".to_string());
let Message::Binary(resume_payload) = timeout(Duration::from_secs(1), outbound_rx.next())
.await?
@@ -635,6 +636,7 @@ mod tests {
control.set_write_blocked();
connection.outgoing_tx.send(message.clone()).await?;
control.wait_for_blocked_write().await?;
control.send_inbound(Message::Ping(b"check".to_vec().into()))?;
assert!(
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
@@ -722,12 +724,16 @@ mod tests {
inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
outbound_tx: futures_mpsc::UnboundedSender<Message>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
struct ControlledWebSocketHandle {
inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
@@ -742,17 +748,23 @@ mod tests {
let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
let write_ready = Arc::new(AtomicBool::new(write_ready));
let write_blocked = Arc::new(AtomicBool::new(false));
let write_blocked_waker = Arc::new(AtomicWaker::new());
let write_waker = Arc::new(AtomicWaker::new());
(
Self {
inbound_rx,
outbound_tx,
write_ready: Arc::clone(&write_ready),
write_blocked: Arc::clone(&write_blocked),
write_blocked_waker: Arc::clone(&write_blocked_waker),
write_waker: Arc::clone(&write_waker),
},
ControlledWebSocketHandle {
inbound_tx,
write_ready,
write_blocked,
write_blocked_waker,
write_waker,
},
outbound_rx,
@@ -775,6 +787,22 @@ mod tests {
self.write_ready.store(true, Ordering::Release);
self.write_waker.wake();
}
async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
timeout(
Duration::from_secs(1),
futures::future::poll_fn(|cx| {
if self.write_blocked.load(Ordering::Acquire) {
Poll::Ready(())
} else {
self.write_blocked_waker.register(cx.waker());
Poll::Pending
}
}),
)
.await?;
Ok(())
}
}
impl Sink<Message> for ControlledWebSocket {
@@ -784,6 +812,8 @@ mod tests {
if self.write_ready.load(Ordering::Acquire) {
Poll::Ready(Ok(()))
} else {
self.write_blocked.store(true, Ordering::Release);
self.write_blocked_waker.wake();
self.write_waker.register(cx.waker());
Poll::Pending
}