Refactor exec-server websocket pump

This commit is contained in:
starr-openai
2026-05-18 10:04:34 -07:00
parent 82061660ae
commit 8a9300e92a
2 changed files with 308 additions and 340 deletions

View File

@@ -323,37 +323,20 @@ impl JsonRpcConnection {
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let (websocket_writer, websocket_reader) = stream.split();
Self::from_websocket_parts(
websocket_writer,
websocket_reader,
connection_label,
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
)
Self::from_websocket_stream(stream, connection_label, /*ping_interval*/ None)
}
pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self {
let (websocket_writer, websocket_reader) = stream.split();
Self::from_websocket_parts(
websocket_writer,
websocket_reader,
connection_label,
// Axum only wraps inbound exec-server websocket accepts. Outbound websocket clients
// own keepalive pings so one side does not accidentally create redundant traffic.
/*keepalive_interval*/
None,
)
Self::from_websocket_stream(stream, connection_label, Some(WEBSOCKET_KEEPALIVE_INTERVAL))
}
fn from_websocket_parts<W, R, M, E>(
mut websocket_writer: W,
mut websocket_reader: R,
fn from_websocket_stream<T, M, E>(
mut websocket: T,
connection_label: String,
keepalive_interval: Option<Duration>,
ping_interval: Option<Duration>,
) -> Self
where
W: Sink<M, Error = E> + Unpin + Send + 'static,
R: Stream<Item = Result<M, E>> + Unpin + Send + 'static,
T: Sink<M, Error = E> + Stream<Item = Result<M, E>> + Unpin + Send + 'static,
M: JsonRpcWebSocketMessage,
E: std::fmt::Display + Send + 'static,
{
@@ -361,118 +344,119 @@ impl JsonRpcConnection {
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (disconnected_tx, disconnected_rx) = watch::channel(false);
let reader_label = connection_label.clone();
let incoming_tx_for_reader = incoming_tx.clone();
let disconnected_tx_for_reader = disconnected_tx.clone();
let reader_task = tokio::spawn(async move {
let websocket_task = tokio::spawn(async move {
let mut ping_interval = ping_interval.map(|ping_interval| {
let mut interval = tokio::time::interval_at(
tokio::time::Instant::now() + ping_interval,
ping_interval,
);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
interval
});
loop {
match websocket_reader.next().await {
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
Ok(JsonRpcWebSocketFrame::Message(message)) => {
if incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
tokio::select! {
maybe_message = outgoing_rx.recv() => {
let Some(message) = maybe_message else {
break;
};
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
Err(err) => {
send_malformed_message(
&incoming_tx_for_reader,
Some(format!(
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
)),
)
.await;
}
_ = async {
match ping_interval.as_mut() {
Some(interval) => interval.tick().await,
None => std::future::pending().await,
}
Ok(JsonRpcWebSocketFrame::Close) => {
} => {
if let Err(err) = websocket.send(M::ping()).await {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket ping to {connection_label}: {err}"
)),
)
.await;
break;
}
Ok(JsonRpcWebSocketFrame::Ignore) => {}
},
Some(Err(err)) => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
Some(format!(
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
)),
)
.await;
break;
}
None => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
)
.await;
break;
}
}
}
});
let writer_task = tokio::spawn(async move {
if let Some(keepalive_interval) = keepalive_interval {
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + keepalive_interval,
keepalive_interval,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
maybe_message = outgoing_rx.recv() => {
let Some(message) = maybe_message else {
break;
};
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket_writer,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
}
_ = keepalive.tick() => {
if let Err(err) = websocket_writer.send(M::ping()).await {
incoming_message = websocket.next() => {
match incoming_message {
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
Ok(JsonRpcWebSocketFrame::Message(message)) => {
if incoming_tx
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Ok(JsonRpcWebSocketFrame::Ping(payload)) => {
if let Err(err) = websocket.send(M::pong(payload)).await {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket pong to {connection_label}: {err}"
)),
)
.await;
break;
}
}
Ok(JsonRpcWebSocketFrame::Close) => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
/*reason*/ None,
)
.await;
break;
}
Ok(JsonRpcWebSocketFrame::Ignore) => {}
Err(err) => {
send_malformed_message(
&incoming_tx,
Some(format!(
"failed to parse websocket JSON-RPC message from {connection_label}: {err}"
)),
)
.await;
}
},
Some(Err(err)) => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket ping to {connection_label}: {err}"
"failed to read websocket JSON-RPC message from {connection_label}: {err}"
)),
)
.await;
break;
}
None => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
/*reason*/ None,
)
.await;
break;
}
}
}
}
} else {
while let Some(message) = outgoing_rx.recv().await {
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket_writer,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
}
}
});
@@ -480,7 +464,7 @@ impl JsonRpcConnection {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
task_handles: vec![websocket_task],
transport: JsonRpcTransport::Plain,
}
}
@@ -493,6 +477,7 @@ impl JsonRpcConnection {
enum JsonRpcWebSocketFrame {
Message(JSONRPCMessage),
Ping(bytes::Bytes),
Close,
Ignore,
}
@@ -501,6 +486,7 @@ trait JsonRpcWebSocketMessage: Send + 'static {
fn parse_jsonrpc_frame(self) -> Result<JsonRpcWebSocketFrame, serde_json::Error>;
fn from_text(text: String) -> Self;
fn ping() -> Self;
fn pong(payload: bytes::Bytes) -> Self;
}
impl JsonRpcWebSocketMessage for Message {
@@ -513,9 +499,8 @@ impl JsonRpcWebSocketMessage for Message {
serde_json::from_slice(bytes.as_ref()).map(JsonRpcWebSocketFrame::Message)
}
Message::Close(_) => Ok(JsonRpcWebSocketFrame::Close),
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {
Ok(JsonRpcWebSocketFrame::Ignore)
}
Message::Ping(payload) => Ok(JsonRpcWebSocketFrame::Ping(payload)),
Message::Pong(_) | Message::Frame(_) => Ok(JsonRpcWebSocketFrame::Ignore),
}
}
@@ -526,6 +511,10 @@ impl JsonRpcWebSocketMessage for Message {
fn ping() -> Self {
Self::Ping(Vec::new().into())
}
fn pong(payload: bytes::Bytes) -> Self {
Self::Pong(payload)
}
}
impl JsonRpcWebSocketMessage for AxumWebSocketMessage {
@@ -538,9 +527,8 @@ impl JsonRpcWebSocketMessage for AxumWebSocketMessage {
serde_json::from_slice(bytes.as_ref()).map(JsonRpcWebSocketFrame::Message)
}
AxumWebSocketMessage::Close(_) => Ok(JsonRpcWebSocketFrame::Close),
AxumWebSocketMessage::Ping(_) | AxumWebSocketMessage::Pong(_) => {
Ok(JsonRpcWebSocketFrame::Ignore)
}
AxumWebSocketMessage::Ping(payload) => Ok(JsonRpcWebSocketFrame::Ping(payload)),
AxumWebSocketMessage::Pong(_) => Ok(JsonRpcWebSocketFrame::Ignore),
}
}
@@ -551,6 +539,10 @@ impl JsonRpcWebSocketMessage for AxumWebSocketMessage {
fn ping() -> Self {
Self::Ping(Vec::new().into())
}
fn pong(payload: bytes::Bytes) -> Self {
Self::Pong(payload)
}
}
async fn send_disconnected(
@@ -618,71 +610,60 @@ fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result<String, serde_j
#[cfg(test)]
mod tests {
use std::pin::Pin;
use futures::channel::mpsc as futures_mpsc;
use futures::stream;
use futures::task::Context;
use futures::task::Poll;
use tokio::net::TcpListener;
use tokio::time::timeout;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::connect_async;
use super::*;
struct TestWebSocketSink {
message_tx: futures_mpsc::UnboundedSender<Message>,
}
#[tokio::test]
async fn websocket_connection_pongs_server_ping() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
impl Sink<Message> for TestWebSocketSink {
type Error = std::convert::Infallible;
server_websocket
.send(Message::Ping(b"check".to_vec().into()))
.await?;
let message = timeout(Duration::from_secs(1), server_websocket.next())
.await?
.expect("websocket should stay open")?;
assert_eq!(message, Message::Pong(b"check".to_vec().into()));
fn poll_ready(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
self.get_mut()
.message_tx
.unbounded_send(item)
.expect("test websocket receiver should stay open");
Ok(())
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
drop(connection);
Ok(())
}
#[tokio::test]
async fn websocket_connection_sends_keepalive_ping() {
let (message_tx, mut message_rx) = futures_mpsc::unbounded::<Message>();
let websocket_writer = TestWebSocketSink { message_tx };
let websocket_reader = stream::pending::<Result<Message, std::convert::Infallible>>();
let connection = JsonRpcConnection::from_websocket_parts(
websocket_writer,
websocket_reader,
async fn websocket_connection_sends_configured_ping() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let connection = JsonRpcConnection::from_websocket_stream(
client_websocket,
"test".into(),
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
);
let message = timeout(Duration::from_secs(1), message_rx.next())
.await
.expect("keepalive ping should arrive before timeout")
.expect("keepalive ping should be sent");
let message = timeout(Duration::from_secs(1), server_websocket.next())
.await?
.expect("websocket should stay open")?;
assert!(matches!(message, Message::Ping(_)));
drop(connection);
Ok(())
}
async fn websocket_pair() -> anyhow::Result<(
WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
WebSocketStream<tokio::net::TcpStream>,
)> {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let websocket_url = format!("ws://{}", listener.local_addr()?);
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await?;
accept_async(stream).await.map_err(anyhow::Error::from)
});
let (client_websocket, _) = connect_async(websocket_url).await?;
let server_websocket = server_task.await??;
Ok((client_websocket, server_websocket))
}
}

View File

@@ -19,7 +19,6 @@ use crate::connection::CHANNEL_CAPACITY;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
use crate::connection::JsonRpcTransport;
use crate::connection::WEBSOCKET_KEEPALIVE_INTERVAL;
use crate::relay_proto::RelayData;
use crate::relay_proto::RelayMessageFrame;
use crate::relay_proto::RelayResume;
@@ -148,113 +147,16 @@ where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let stream_id = Uuid::new_v4().to_string();
let (mut websocket_writer, mut websocket_reader) = stream.split();
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (disconnected_tx, disconnected_rx) = watch::channel(false);
let reader_label = connection_label;
let reader_stream_id = stream_id.clone();
let incoming_tx_for_reader = incoming_tx;
let disconnected_tx_for_reader = disconnected_tx.clone();
let reader_task = tokio::spawn(async move {
loop {
match websocket_reader.next().await {
Some(Ok(Message::Binary(payload))) => {
let frame = match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: format!(
"failed to parse relay message frame from {reader_label}: {err}"
),
})
.await;
continue;
}
};
if frame.stream_id != reader_stream_id {
continue;
}
let kind = match frame.validate() {
Ok(kind) => kind,
Err(err) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
continue;
}
};
match kind {
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
Ok(message) => {
if incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Err(err) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
}
},
RelayFrameBodyKind::Reset => {
let _ = disconnected_tx_for_reader.send(true);
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Disconnected {
reason: frame.into_reset_reason(),
})
.await;
break;
}
RelayFrameBodyKind::Ack
| RelayFrameBodyKind::Resume
| RelayFrameBodyKind::Heartbeat => {}
}
}
Some(Ok(Message::Close(_))) | None => {
let _ = disconnected_tx_for_reader.send(true);
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
.await;
break;
}
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
Some(Ok(Message::Text(_))) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: "relay exec-server transport expects binary protobuf frames"
.to_string(),
})
.await;
}
Some(Err(err)) => {
let _ = disconnected_tx_for_reader.send(true);
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Disconnected {
reason: Some(format!(
"failed to read relay websocket frame from {reader_label}: {err}"
)),
})
.await;
break;
}
}
}
});
let writer_task = tokio::spawn(async move {
let websocket_task = tokio::spawn(async move {
let mut websocket = stream;
let reader_label = connection_label;
let reader_stream_id = stream_id.clone();
let resume = RelayMessageFrame::resume(stream_id.clone());
if websocket_writer
if websocket
.send(Message::Binary(encode_relay_message_frame(&resume).into()))
.await
.is_err()
@@ -263,11 +165,6 @@ where
return;
}
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
WEBSOCKET_KEEPALIVE_INTERVAL,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut next_seq = 0u32;
loop {
tokio::select! {
@@ -284,7 +181,7 @@ where
};
let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload);
next_seq = next_seq.wrapping_add(1);
if websocket_writer
if websocket
.send(Message::Binary(encode_relay_message_frame(&frame).into()))
.await
.is_err()
@@ -293,10 +190,102 @@ where
break;
}
}
_ = keepalive.tick() => {
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
let _ = disconnected_tx.send(true);
break;
incoming_message = websocket.next() => {
match incoming_message {
Some(Ok(Message::Binary(payload))) => {
let frame = match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: format!(
"failed to parse relay message frame from {reader_label}: {err}"
),
})
.await;
continue;
}
};
if frame.stream_id != reader_stream_id {
continue;
}
let kind = match frame.validate() {
Ok(kind) => kind,
Err(err) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
continue;
}
};
match kind {
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
Ok(message) => {
if incoming_tx
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Err(err) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
}
},
RelayFrameBodyKind::Reset => {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected {
reason: frame.into_reset_reason(),
})
.await;
break;
}
RelayFrameBodyKind::Ack
| RelayFrameBodyKind::Resume
| RelayFrameBodyKind::Heartbeat => {}
}
}
Some(Ok(Message::Ping(payload))) => {
if websocket.send(Message::Pong(payload)).await.is_err() {
let _ = disconnected_tx.send(true);
break;
}
}
Some(Ok(Message::Close(_))) | None => {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
.await;
break;
}
Some(Ok(Message::Pong(_) | Message::Frame(_))) => {}
Some(Ok(Message::Text(_))) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: "relay exec-server transport expects binary protobuf frames"
.to_string(),
})
.await;
}
Some(Err(err)) => {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected {
reason: Some(format!(
"failed to read relay websocket frame from {reader_label}: {err}"
)),
})
.await;
break;
}
}
}
}
@@ -307,7 +296,7 @@ where
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
task_handles: vec![websocket_task],
transport: JsonRpcTransport::Plain,
}
}
@@ -318,59 +307,48 @@ pub(crate) async fn run_multiplexed_executor<S>(
) where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let (mut websocket_writer, mut websocket_reader) = stream.split();
let mut websocket = stream;
let (physical_outgoing_tx, mut physical_outgoing_rx) =
mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
let writer_task = tokio::spawn(async move {
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
WEBSOCKET_KEEPALIVE_INTERVAL,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
maybe_encoded = physical_outgoing_rx.recv() => {
let Some(encoded) = maybe_encoded else {
break;
};
if websocket_writer
.send(Message::Binary(encoded.into()))
.await
.is_err()
{
break;
}
}
_ = keepalive.tick() => {
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
break;
}
}
}
}
});
let mut streams: HashMap<String, VirtualStream> = HashMap::new();
loop {
let frame = match websocket_reader.next().await {
Some(Ok(Message::Binary(payload))) => {
match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
warn!("dropping malformed relay message frame from harness: {err}");
continue;
}
let frame = tokio::select! {
maybe_encoded = physical_outgoing_rx.recv() => {
let Some(encoded) = maybe_encoded else {
break;
};
if websocket.send(Message::Binary(encoded.into())).await.is_err() {
break;
}
}
Some(Ok(Message::Close(_))) | None => break,
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
Some(Ok(Message::Text(_))) => {
warn!("dropping non-binary relay message frame from harness");
continue;
}
Some(Err(err)) => {
debug!("multiplexed executor websocket read failed: {err}");
break;
incoming_message = websocket.next() => match incoming_message {
Some(Ok(Message::Binary(payload))) => {
match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
warn!("dropping malformed relay message frame from harness: {err}");
continue;
}
}
}
Some(Ok(Message::Ping(payload))) => {
if websocket.send(Message::Pong(payload)).await.is_err() {
break;
}
continue;
}
Some(Ok(Message::Close(_))) | None => break,
Some(Ok(Message::Pong(_) | Message::Frame(_))) => continue,
Some(Ok(Message::Text(_))) => {
warn!("dropping non-binary relay message frame from harness");
continue;
}
Some(Err(err)) => {
debug!("multiplexed executor websocket read failed: {err}");
break;
}
}
};
@@ -423,7 +401,6 @@ pub(crate) async fn run_multiplexed_executor<S>(
stream.disconnect(/*reason*/ None).await;
}
drop(physical_outgoing_tx);
let _ = writer_task.await;
}
struct VirtualStream {
@@ -511,14 +488,17 @@ mod tests {
}
#[tokio::test]
async fn multiplexed_executor_sends_keepalive_ping() -> anyhow::Result<()> {
async fn multiplexed_executor_pongs_server_ping() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let executor_task = tokio::spawn(run_multiplexed_executor(
client_websocket,
ConnectionProcessor::new(test_runtime_paths()?),
));
read_keepalive_ping(&mut server_websocket).await?;
server_websocket
.send(Message::Ping(b"check".to_vec().into()))
.await?;
read_pong(&mut server_websocket).await?;
executor_task.abort();
let _ = executor_task.await;
@@ -526,11 +506,14 @@ mod tests {
}
#[tokio::test]
async fn harness_connection_sends_keepalive_ping() -> anyhow::Result<()> {
async fn harness_connection_pongs_server_ping() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let connection = harness_connection_from_websocket(client_websocket, "test".to_string());
read_keepalive_ping(&mut server_websocket).await?;
server_websocket
.send(Message::Ping(b"check".to_vec().into()))
.await?;
read_pong(&mut server_websocket).await?;
drop(connection);
Ok(())
@@ -551,17 +534,21 @@ mod tests {
Ok((client_websocket, server_websocket))
}
async fn read_keepalive_ping(
async fn read_pong(
websocket: &mut WebSocketStream<tokio::net::TcpStream>,
) -> anyhow::Result<()> {
loop {
let Some(message) = timeout(Duration::from_secs(1), websocket.next()).await? else {
anyhow::bail!("websocket closed before keepalive ping");
anyhow::bail!("websocket closed before pong");
};
match message? {
Message::Ping(_) => return Ok(()),
Message::Binary(_) | Message::Text(_) | Message::Pong(_) | Message::Frame(_) => {}
Message::Close(_) => anyhow::bail!("websocket closed before keepalive ping"),
Message::Pong(payload) if payload.as_ref() == b"check" => return Ok(()),
Message::Binary(_)
| Message::Text(_)
| Message::Ping(_)
| Message::Pong(_)
| Message::Frame(_) => {}
Message::Close(_) => anyhow::bail!("websocket closed before pong"),
}
}
}