Compare commits

...

1 Commits

Author SHA1 Message Date
Andrey Mishchenko
365134b7bc Add keepalive in app-server WS client talking to exec-server 2026-05-17 14:49:48 -07:00

View File

@@ -17,6 +17,7 @@ use tokio::io::AsyncWrite;
use tokio::process::Child;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::time::MissedTickBehavior;
use tokio::time::timeout;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message;
@@ -30,6 +31,10 @@ use tokio::io::BufWriter;
pub(crate) const CHANNEL_CAPACITY: usize = 128;
const STDIO_TERMINATION_GRACE_PERIOD: Duration = Duration::from_secs(2);
#[cfg(test)]
const WEBSOCKET_CLIENT_KEEPALIVE_INTERVAL: Duration = Duration::from_millis(50);
#[cfg(not(test))]
const WEBSOCKET_CLIENT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
#[derive(Debug)]
pub(crate) enum JsonRpcConnectionEvent {
@@ -320,18 +325,29 @@ impl JsonRpcConnection {
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let (websocket_writer, websocket_reader) = stream.split();
Self::from_websocket_parts(websocket_writer, websocket_reader, connection_label)
Self::from_websocket_parts(
websocket_writer,
websocket_reader,
connection_label,
Some(WEBSOCKET_CLIENT_KEEPALIVE_INTERVAL),
)
}
pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self {
let (websocket_writer, websocket_reader) = stream.split();
Self::from_websocket_parts(websocket_writer, websocket_reader, connection_label)
Self::from_websocket_parts(
websocket_writer,
websocket_reader,
connection_label,
/*keepalive_interval*/ None,
)
}
fn from_websocket_parts<W, R, M, E>(
mut websocket_writer: W,
mut websocket_reader: R,
connection_label: String,
keepalive_interval: Option<Duration>,
) -> Self
where
W: Sink<M, Error = E> + Unpin + Send + 'static,
@@ -404,30 +420,54 @@ impl JsonRpcConnection {
});
let writer_task = tokio::spawn(async move {
while let Some(message) = outgoing_rx.recv().await {
match serialize_jsonrpc_message(&message) {
Ok(encoded) => {
if let Err(err) = websocket_writer.send(M::from_text(encoded)).await {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket JSON-RPC message to {connection_label}: {err}"
)),
if let Some(keepalive_interval) = keepalive_interval {
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + keepalive_interval,
keepalive_interval,
);
keepalive.set_missed_tick_behavior(MissedTickBehavior::Skip);
loop {
tokio::select! {
maybe_message = outgoing_rx.recv() => {
let Some(message) = maybe_message else {
break;
};
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket_writer,
&connection_label,
&message,
)
.await;
break;
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
}
_ = keepalive.tick() => {
if let Err(err) = websocket_writer.send(M::ping()).await {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket ping to {connection_label}: {err}"
)),
)
.await;
break;
}
}
}
Err(err) => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to serialize JSON-RPC message for {connection_label}: {err}"
)),
)
.await;
}
} else {
while let Some(message) = outgoing_rx.recv().await {
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket_writer,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
}
@@ -449,6 +489,29 @@ impl JsonRpcConnection {
}
}
async fn send_websocket_jsonrpc_message<W, M, E>(
websocket_writer: &mut W,
connection_label: &str,
message: &JSONRPCMessage,
) -> Result<(), String>
where
W: Sink<M, Error = E> + Unpin,
M: JsonRpcWebSocketMessage,
E: std::fmt::Display,
{
match serialize_jsonrpc_message(message) {
Ok(encoded) => websocket_writer
.send(M::from_text(encoded))
.await
.map_err(|err| {
format!("failed to write websocket JSON-RPC message to {connection_label}: {err}")
}),
Err(err) => Err(format!(
"failed to serialize JSON-RPC message for {connection_label}: {err}"
)),
}
}
enum JsonRpcWebSocketFrame {
Message(JSONRPCMessage),
Close,
@@ -458,6 +521,7 @@ enum JsonRpcWebSocketFrame {
trait JsonRpcWebSocketMessage: Send + 'static {
fn parse_jsonrpc_frame(self) -> Result<JsonRpcWebSocketFrame, serde_json::Error>;
fn from_text(text: String) -> Self;
fn ping() -> Self;
}
impl JsonRpcWebSocketMessage for Message {
@@ -479,6 +543,10 @@ impl JsonRpcWebSocketMessage for Message {
fn from_text(text: String) -> Self {
Self::Text(text.into())
}
fn ping() -> Self {
Self::Ping(Vec::new().into())
}
}
impl JsonRpcWebSocketMessage for AxumWebSocketMessage {
@@ -500,6 +568,10 @@ impl JsonRpcWebSocketMessage for AxumWebSocketMessage {
fn from_text(text: String) -> Self {
Self::Text(text.into())
}
fn ping() -> Self {
Self::Ping(Vec::new().into())
}
}
async fn send_disconnected(
@@ -541,3 +613,46 @@ where
fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result<String, serde_json::Error> {
serde_json::to_string(message)
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use futures::StreamExt;
use tokio::net::TcpListener;
use tokio::time::timeout;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message;
use super::JsonRpcConnection;
#[tokio::test]
async fn websocket_client_connection_sends_keepalive_ping() -> anyhow::Result<()> {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let websocket_url = format!("ws://{}", listener.local_addr()?);
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await?;
let mut websocket = accept_async(stream).await?;
loop {
let Some(message) = timeout(Duration::from_secs(1), websocket.next()).await? else {
anyhow::bail!("client websocket closed before keepalive ping");
};
match message? {
Message::Ping(_) => return Ok::<(), anyhow::Error>(()),
Message::Text(_) | Message::Binary(_) | Message::Pong(_) => continue,
other => anyhow::bail!("expected keepalive ping, got {other:?}"),
}
}
});
let (client_websocket, _) = connect_async(websocket_url).await?;
let connection =
JsonRpcConnection::from_websocket(client_websocket, "test websocket".to_string());
server_task.await??;
drop(connection);
Ok(())
}
}