From 8a9300e92ad191d0268a4c946bf97e2a5f2cea5f Mon Sep 17 00:00:00 2001 From: starr-openai Date: Mon, 18 May 2026 10:04:34 -0700 Subject: [PATCH] Refactor exec-server websocket pump --- codex-rs/exec-server/src/connection.rs | 323 ++++++++++++------------ codex-rs/exec-server/src/relay.rs | 325 ++++++++++++------------- 2 files changed, 308 insertions(+), 340 deletions(-) diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index cf504bbea7..0974a21c3a 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -323,37 +323,20 @@ impl JsonRpcConnection { where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let (websocket_writer, websocket_reader) = stream.split(); - Self::from_websocket_parts( - websocket_writer, - websocket_reader, - connection_label, - Some(WEBSOCKET_KEEPALIVE_INTERVAL), - ) + Self::from_websocket_stream(stream, connection_label, /*ping_interval*/ None) } pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self { - let (websocket_writer, websocket_reader) = stream.split(); - Self::from_websocket_parts( - websocket_writer, - websocket_reader, - connection_label, - // Axum only wraps inbound exec-server websocket accepts. Outbound websocket clients - // own keepalive pings so one side does not accidentally create redundant traffic. - /*keepalive_interval*/ - None, - ) + Self::from_websocket_stream(stream, connection_label, Some(WEBSOCKET_KEEPALIVE_INTERVAL)) } - fn from_websocket_parts( - mut websocket_writer: W, - mut websocket_reader: R, + fn from_websocket_stream( + mut websocket: T, connection_label: String, - keepalive_interval: Option, + ping_interval: Option, ) -> Self where - W: Sink + Unpin + Send + 'static, - R: Stream> + Unpin + Send + 'static, + T: Sink + Stream> + Unpin + Send + 'static, M: JsonRpcWebSocketMessage, E: std::fmt::Display + Send + 'static, { @@ -361,118 +344,119 @@ impl JsonRpcConnection { let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); let (disconnected_tx, disconnected_rx) = watch::channel(false); - let reader_label = connection_label.clone(); - let incoming_tx_for_reader = incoming_tx.clone(); - let disconnected_tx_for_reader = disconnected_tx.clone(); - let reader_task = tokio::spawn(async move { + let websocket_task = tokio::spawn(async move { + let mut ping_interval = ping_interval.map(|ping_interval| { + let mut interval = tokio::time::interval_at( + tokio::time::Instant::now() + ping_interval, + ping_interval, + ); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + interval + }); + loop { - match websocket_reader.next().await { - Some(Ok(message)) => match message.parse_jsonrpc_frame() { - Ok(JsonRpcWebSocketFrame::Message(message)) => { - if incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Message(message)) - .await - .is_err() - { - break; - } + tokio::select! { + maybe_message = outgoing_rx.recv() => { + let Some(message) = maybe_message else { + break; + }; + if let Err(reason) = send_websocket_jsonrpc_message( + &mut websocket, + &connection_label, + &message, + ) + .await + { + send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await; + break; } - Err(err) => { - send_malformed_message( - &incoming_tx_for_reader, - Some(format!( - "failed to parse websocket JSON-RPC message from {reader_label}: {err}" - )), - ) - .await; + } + _ = async { + match ping_interval.as_mut() { + Some(interval) => interval.tick().await, + None => std::future::pending().await, } - Ok(JsonRpcWebSocketFrame::Close) => { + } => { + if let Err(err) = websocket.send(M::ping()).await { send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - /*reason*/ None, + &incoming_tx, + &disconnected_tx, + Some(format!( + "failed to write websocket ping to {connection_label}: {err}" + )), ) .await; break; } - Ok(JsonRpcWebSocketFrame::Ignore) => {} - }, - Some(Err(err)) => { - send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - Some(format!( - "failed to read websocket JSON-RPC message from {reader_label}: {err}" - )), - ) - .await; - break; } - None => { - send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - /*reason*/ None, - ) - .await; - break; - } - } - } - }); - - let writer_task = tokio::spawn(async move { - if let Some(keepalive_interval) = keepalive_interval { - let mut keepalive = tokio::time::interval_at( - tokio::time::Instant::now() + keepalive_interval, - keepalive_interval, - ); - keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - loop { - tokio::select! { - maybe_message = outgoing_rx.recv() => { - let Some(message) = maybe_message else { - break; - }; - if let Err(reason) = send_websocket_jsonrpc_message( - &mut websocket_writer, - &connection_label, - &message, - ) - .await - { - send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await; - break; - } - } - _ = keepalive.tick() => { - if let Err(err) = websocket_writer.send(M::ping()).await { + incoming_message = websocket.next() => { + match incoming_message { + Some(Ok(message)) => match message.parse_jsonrpc_frame() { + Ok(JsonRpcWebSocketFrame::Message(message)) => { + if incoming_tx + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Ok(JsonRpcWebSocketFrame::Ping(payload)) => { + if let Err(err) = websocket.send(M::pong(payload)).await { + send_disconnected( + &incoming_tx, + &disconnected_tx, + Some(format!( + "failed to write websocket pong to {connection_label}: {err}" + )), + ) + .await; + break; + } + } + Ok(JsonRpcWebSocketFrame::Close) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + /*reason*/ None, + ) + .await; + break; + } + Ok(JsonRpcWebSocketFrame::Ignore) => {} + Err(err) => { + send_malformed_message( + &incoming_tx, + Some(format!( + "failed to parse websocket JSON-RPC message from {connection_label}: {err}" + )), + ) + .await; + } + }, + Some(Err(err)) => { send_disconnected( &incoming_tx, &disconnected_tx, Some(format!( - "failed to write websocket ping to {connection_label}: {err}" + "failed to read websocket JSON-RPC message from {connection_label}: {err}" )), ) .await; break; } + None => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + /*reason*/ None, + ) + .await; + break; + } } } } - } else { - while let Some(message) = outgoing_rx.recv().await { - if let Err(reason) = send_websocket_jsonrpc_message( - &mut websocket_writer, - &connection_label, - &message, - ) - .await - { - send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await; - break; - } - } } }); @@ -480,7 +464,7 @@ impl JsonRpcConnection { outgoing_tx, incoming_rx, disconnected_rx, - task_handles: vec![reader_task, writer_task], + task_handles: vec![websocket_task], transport: JsonRpcTransport::Plain, } } @@ -493,6 +477,7 @@ impl JsonRpcConnection { enum JsonRpcWebSocketFrame { Message(JSONRPCMessage), + Ping(bytes::Bytes), Close, Ignore, } @@ -501,6 +486,7 @@ trait JsonRpcWebSocketMessage: Send + 'static { fn parse_jsonrpc_frame(self) -> Result; fn from_text(text: String) -> Self; fn ping() -> Self; + fn pong(payload: bytes::Bytes) -> Self; } impl JsonRpcWebSocketMessage for Message { @@ -513,9 +499,8 @@ impl JsonRpcWebSocketMessage for Message { serde_json::from_slice(bytes.as_ref()).map(JsonRpcWebSocketFrame::Message) } Message::Close(_) => Ok(JsonRpcWebSocketFrame::Close), - Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => { - Ok(JsonRpcWebSocketFrame::Ignore) - } + Message::Ping(payload) => Ok(JsonRpcWebSocketFrame::Ping(payload)), + Message::Pong(_) | Message::Frame(_) => Ok(JsonRpcWebSocketFrame::Ignore), } } @@ -526,6 +511,10 @@ impl JsonRpcWebSocketMessage for Message { fn ping() -> Self { Self::Ping(Vec::new().into()) } + + fn pong(payload: bytes::Bytes) -> Self { + Self::Pong(payload) + } } impl JsonRpcWebSocketMessage for AxumWebSocketMessage { @@ -538,9 +527,8 @@ impl JsonRpcWebSocketMessage for AxumWebSocketMessage { serde_json::from_slice(bytes.as_ref()).map(JsonRpcWebSocketFrame::Message) } AxumWebSocketMessage::Close(_) => Ok(JsonRpcWebSocketFrame::Close), - AxumWebSocketMessage::Ping(_) | AxumWebSocketMessage::Pong(_) => { - Ok(JsonRpcWebSocketFrame::Ignore) - } + AxumWebSocketMessage::Ping(payload) => Ok(JsonRpcWebSocketFrame::Ping(payload)), + AxumWebSocketMessage::Pong(_) => Ok(JsonRpcWebSocketFrame::Ignore), } } @@ -551,6 +539,10 @@ impl JsonRpcWebSocketMessage for AxumWebSocketMessage { fn ping() -> Self { Self::Ping(Vec::new().into()) } + + fn pong(payload: bytes::Bytes) -> Self { + Self::Pong(payload) + } } async fn send_disconnected( @@ -618,71 +610,60 @@ fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result, - } + #[tokio::test] + async fn websocket_connection_pongs_server_ping() -> anyhow::Result<()> { + let (client_websocket, mut server_websocket) = websocket_pair().await?; + let connection = JsonRpcConnection::from_websocket(client_websocket, "test".into()); - impl Sink for TestWebSocketSink { - type Error = std::convert::Infallible; + server_websocket + .send(Message::Ping(b"check".to_vec().into())) + .await?; + let message = timeout(Duration::from_secs(1), server_websocket.next()) + .await? + .expect("websocket should stay open")?; + assert_eq!(message, Message::Pong(b"check".to_vec().into())); - fn poll_ready( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - - fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - self.get_mut() - .message_tx - .unbounded_send(item) - .expect("test websocket receiver should stay open"); - Ok(()) - } - - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_close( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } + drop(connection); + Ok(()) } #[tokio::test] - async fn websocket_connection_sends_keepalive_ping() { - let (message_tx, mut message_rx) = futures_mpsc::unbounded::(); - let websocket_writer = TestWebSocketSink { message_tx }; - let websocket_reader = stream::pending::>(); - let connection = JsonRpcConnection::from_websocket_parts( - websocket_writer, - websocket_reader, + async fn websocket_connection_sends_configured_ping() -> anyhow::Result<()> { + let (client_websocket, mut server_websocket) = websocket_pair().await?; + let connection = JsonRpcConnection::from_websocket_stream( + client_websocket, "test".into(), Some(WEBSOCKET_KEEPALIVE_INTERVAL), ); - let message = timeout(Duration::from_secs(1), message_rx.next()) - .await - .expect("keepalive ping should arrive before timeout") - .expect("keepalive ping should be sent"); + let message = timeout(Duration::from_secs(1), server_websocket.next()) + .await? + .expect("websocket should stay open")?; assert!(matches!(message, Message::Ping(_))); drop(connection); + Ok(()) + } + + async fn websocket_pair() -> anyhow::Result<( + WebSocketStream>, + WebSocketStream, + )> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let websocket_url = format!("ws://{}", listener.local_addr()?); + let server_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await?; + accept_async(stream).await.map_err(anyhow::Error::from) + }); + let (client_websocket, _) = connect_async(websocket_url).await?; + let server_websocket = server_task.await??; + Ok((client_websocket, server_websocket)) } } diff --git a/codex-rs/exec-server/src/relay.rs b/codex-rs/exec-server/src/relay.rs index 7470a6290c..471a2e658e 100644 --- a/codex-rs/exec-server/src/relay.rs +++ b/codex-rs/exec-server/src/relay.rs @@ -19,7 +19,6 @@ use crate::connection::CHANNEL_CAPACITY; use crate::connection::JsonRpcConnection; use crate::connection::JsonRpcConnectionEvent; use crate::connection::JsonRpcTransport; -use crate::connection::WEBSOCKET_KEEPALIVE_INTERVAL; use crate::relay_proto::RelayData; use crate::relay_proto::RelayMessageFrame; use crate::relay_proto::RelayResume; @@ -148,113 +147,16 @@ where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let stream_id = Uuid::new_v4().to_string(); - let (mut websocket_writer, mut websocket_reader) = stream.split(); let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); let (disconnected_tx, disconnected_rx) = watch::channel(false); - let reader_label = connection_label; - let reader_stream_id = stream_id.clone(); - let incoming_tx_for_reader = incoming_tx; - let disconnected_tx_for_reader = disconnected_tx.clone(); - let reader_task = tokio::spawn(async move { - loop { - match websocket_reader.next().await { - Some(Ok(Message::Binary(payload))) => { - let frame = match decode_relay_message_frame(payload.as_ref()) { - Ok(frame) => frame, - Err(err) => { - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::MalformedMessage { - reason: format!( - "failed to parse relay message frame from {reader_label}: {err}" - ), - }) - .await; - continue; - } - }; - if frame.stream_id != reader_stream_id { - continue; - } - let kind = match frame.validate() { - Ok(kind) => kind, - Err(err) => { - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::MalformedMessage { - reason: err.to_string(), - }) - .await; - continue; - } - }; - match kind { - RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() { - Ok(message) => { - if incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Message(message)) - .await - .is_err() - { - break; - } - } - Err(err) => { - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::MalformedMessage { - reason: err.to_string(), - }) - .await; - } - }, - RelayFrameBodyKind::Reset => { - let _ = disconnected_tx_for_reader.send(true); - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Disconnected { - reason: frame.into_reset_reason(), - }) - .await; - break; - } - RelayFrameBodyKind::Ack - | RelayFrameBodyKind::Resume - | RelayFrameBodyKind::Heartbeat => {} - } - } - Some(Ok(Message::Close(_))) | None => { - let _ = disconnected_tx_for_reader.send(true); - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Disconnected { reason: None }) - .await; - break; - } - Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {} - Some(Ok(Message::Text(_))) => { - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::MalformedMessage { - reason: "relay exec-server transport expects binary protobuf frames" - .to_string(), - }) - .await; - } - Some(Err(err)) => { - let _ = disconnected_tx_for_reader.send(true); - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Disconnected { - reason: Some(format!( - "failed to read relay websocket frame from {reader_label}: {err}" - )), - }) - .await; - break; - } - } - } - }); - - let writer_task = tokio::spawn(async move { + let websocket_task = tokio::spawn(async move { + let mut websocket = stream; + let reader_label = connection_label; + let reader_stream_id = stream_id.clone(); let resume = RelayMessageFrame::resume(stream_id.clone()); - if websocket_writer + if websocket .send(Message::Binary(encode_relay_message_frame(&resume).into())) .await .is_err() @@ -263,11 +165,6 @@ where return; } - let mut keepalive = tokio::time::interval_at( - tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL, - WEBSOCKET_KEEPALIVE_INTERVAL, - ); - keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); let mut next_seq = 0u32; loop { tokio::select! { @@ -284,7 +181,7 @@ where }; let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload); next_seq = next_seq.wrapping_add(1); - if websocket_writer + if websocket .send(Message::Binary(encode_relay_message_frame(&frame).into())) .await .is_err() @@ -293,10 +190,102 @@ where break; } } - _ = keepalive.tick() => { - if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() { - let _ = disconnected_tx.send(true); - break; + incoming_message = websocket.next() => { + match incoming_message { + Some(Ok(Message::Binary(payload))) => { + let frame = match decode_relay_message_frame(payload.as_ref()) { + Ok(frame) => frame, + Err(err) => { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: format!( + "failed to parse relay message frame from {reader_label}: {err}" + ), + }) + .await; + continue; + } + }; + if frame.stream_id != reader_stream_id { + continue; + } + let kind = match frame.validate() { + Ok(kind) => kind, + Err(err) => { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: err.to_string(), + }) + .await; + continue; + } + }; + match kind { + RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() { + Ok(message) => { + if incoming_tx + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: err.to_string(), + }) + .await; + } + }, + RelayFrameBodyKind::Reset => { + let _ = disconnected_tx.send(true); + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { + reason: frame.into_reset_reason(), + }) + .await; + break; + } + RelayFrameBodyKind::Ack + | RelayFrameBodyKind::Resume + | RelayFrameBodyKind::Heartbeat => {} + } + } + Some(Ok(Message::Ping(payload))) => { + if websocket.send(Message::Pong(payload)).await.is_err() { + let _ = disconnected_tx.send(true); + break; + } + } + Some(Ok(Message::Close(_))) | None => { + let _ = disconnected_tx.send(true); + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { reason: None }) + .await; + break; + } + Some(Ok(Message::Pong(_) | Message::Frame(_))) => {} + Some(Ok(Message::Text(_))) => { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: "relay exec-server transport expects binary protobuf frames" + .to_string(), + }) + .await; + } + Some(Err(err)) => { + let _ = disconnected_tx.send(true); + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { + reason: Some(format!( + "failed to read relay websocket frame from {reader_label}: {err}" + )), + }) + .await; + break; + } } } } @@ -307,7 +296,7 @@ where outgoing_tx, incoming_rx, disconnected_rx, - task_handles: vec![reader_task, writer_task], + task_handles: vec![websocket_task], transport: JsonRpcTransport::Plain, } } @@ -318,59 +307,48 @@ pub(crate) async fn run_multiplexed_executor( ) where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let (mut websocket_writer, mut websocket_reader) = stream.split(); + let mut websocket = stream; let (physical_outgoing_tx, mut physical_outgoing_rx) = mpsc::channel::>(CHANNEL_CAPACITY); - let writer_task = tokio::spawn(async move { - let mut keepalive = tokio::time::interval_at( - tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL, - WEBSOCKET_KEEPALIVE_INTERVAL, - ); - keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - loop { - tokio::select! { - maybe_encoded = physical_outgoing_rx.recv() => { - let Some(encoded) = maybe_encoded else { - break; - }; - if websocket_writer - .send(Message::Binary(encoded.into())) - .await - .is_err() - { - break; - } - } - _ = keepalive.tick() => { - if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() { - break; - } - } - } - } - }); let mut streams: HashMap = HashMap::new(); loop { - let frame = match websocket_reader.next().await { - Some(Ok(Message::Binary(payload))) => { - match decode_relay_message_frame(payload.as_ref()) { - Ok(frame) => frame, - Err(err) => { - warn!("dropping malformed relay message frame from harness: {err}"); - continue; - } + let frame = tokio::select! { + maybe_encoded = physical_outgoing_rx.recv() => { + let Some(encoded) = maybe_encoded else { + break; + }; + if websocket.send(Message::Binary(encoded.into())).await.is_err() { + break; } - } - Some(Ok(Message::Close(_))) | None => break, - Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue, - Some(Ok(Message::Text(_))) => { - warn!("dropping non-binary relay message frame from harness"); continue; } - Some(Err(err)) => { - debug!("multiplexed executor websocket read failed: {err}"); - break; + incoming_message = websocket.next() => match incoming_message { + Some(Ok(Message::Binary(payload))) => { + match decode_relay_message_frame(payload.as_ref()) { + Ok(frame) => frame, + Err(err) => { + warn!("dropping malformed relay message frame from harness: {err}"); + continue; + } + } + } + Some(Ok(Message::Ping(payload))) => { + if websocket.send(Message::Pong(payload)).await.is_err() { + break; + } + continue; + } + Some(Ok(Message::Close(_))) | None => break, + Some(Ok(Message::Pong(_) | Message::Frame(_))) => continue, + Some(Ok(Message::Text(_))) => { + warn!("dropping non-binary relay message frame from harness"); + continue; + } + Some(Err(err)) => { + debug!("multiplexed executor websocket read failed: {err}"); + break; + } } }; @@ -423,7 +401,6 @@ pub(crate) async fn run_multiplexed_executor( stream.disconnect(/*reason*/ None).await; } drop(physical_outgoing_tx); - let _ = writer_task.await; } struct VirtualStream { @@ -511,14 +488,17 @@ mod tests { } #[tokio::test] - async fn multiplexed_executor_sends_keepalive_ping() -> anyhow::Result<()> { + async fn multiplexed_executor_pongs_server_ping() -> anyhow::Result<()> { let (client_websocket, mut server_websocket) = websocket_pair().await?; let executor_task = tokio::spawn(run_multiplexed_executor( client_websocket, ConnectionProcessor::new(test_runtime_paths()?), )); - read_keepalive_ping(&mut server_websocket).await?; + server_websocket + .send(Message::Ping(b"check".to_vec().into())) + .await?; + read_pong(&mut server_websocket).await?; executor_task.abort(); let _ = executor_task.await; @@ -526,11 +506,14 @@ mod tests { } #[tokio::test] - async fn harness_connection_sends_keepalive_ping() -> anyhow::Result<()> { + async fn harness_connection_pongs_server_ping() -> anyhow::Result<()> { let (client_websocket, mut server_websocket) = websocket_pair().await?; let connection = harness_connection_from_websocket(client_websocket, "test".to_string()); - read_keepalive_ping(&mut server_websocket).await?; + server_websocket + .send(Message::Ping(b"check".to_vec().into())) + .await?; + read_pong(&mut server_websocket).await?; drop(connection); Ok(()) @@ -551,17 +534,21 @@ mod tests { Ok((client_websocket, server_websocket)) } - async fn read_keepalive_ping( + async fn read_pong( websocket: &mut WebSocketStream, ) -> anyhow::Result<()> { loop { let Some(message) = timeout(Duration::from_secs(1), websocket.next()).await? else { - anyhow::bail!("websocket closed before keepalive ping"); + anyhow::bail!("websocket closed before pong"); }; match message? { - Message::Ping(_) => return Ok(()), - Message::Binary(_) | Message::Text(_) | Message::Pong(_) | Message::Frame(_) => {} - Message::Close(_) => anyhow::bail!("websocket closed before keepalive ping"), + Message::Pong(payload) if payload.as_ref() == b"check" => return Ok(()), + Message::Binary(_) + | Message::Text(_) + | Message::Ping(_) + | Message::Pong(_) + | Message::Frame(_) => {} + Message::Close(_) => anyhow::bail!("websocket closed before pong"), } } }