mirror of
https://github.com/openai/codex.git
synced 2026-05-24 13:04:29 +00:00
Refactor exec-server websocket pump
This commit is contained in:
@@ -323,37 +323,20 @@ impl JsonRpcConnection {
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let (websocket_writer, websocket_reader) = stream.split();
|
||||
Self::from_websocket_parts(
|
||||
websocket_writer,
|
||||
websocket_reader,
|
||||
connection_label,
|
||||
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
|
||||
)
|
||||
Self::from_websocket_stream(stream, connection_label, /*ping_interval*/ None)
|
||||
}
|
||||
|
||||
pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self {
|
||||
let (websocket_writer, websocket_reader) = stream.split();
|
||||
Self::from_websocket_parts(
|
||||
websocket_writer,
|
||||
websocket_reader,
|
||||
connection_label,
|
||||
// Axum only wraps inbound exec-server websocket accepts. Outbound websocket clients
|
||||
// own keepalive pings so one side does not accidentally create redundant traffic.
|
||||
/*keepalive_interval*/
|
||||
None,
|
||||
)
|
||||
Self::from_websocket_stream(stream, connection_label, Some(WEBSOCKET_KEEPALIVE_INTERVAL))
|
||||
}
|
||||
|
||||
fn from_websocket_parts<W, R, M, E>(
|
||||
mut websocket_writer: W,
|
||||
mut websocket_reader: R,
|
||||
fn from_websocket_stream<T, M, E>(
|
||||
mut websocket: T,
|
||||
connection_label: String,
|
||||
keepalive_interval: Option<Duration>,
|
||||
ping_interval: Option<Duration>,
|
||||
) -> Self
|
||||
where
|
||||
W: Sink<M, Error = E> + Unpin + Send + 'static,
|
||||
R: Stream<Item = Result<M, E>> + Unpin + Send + 'static,
|
||||
T: Sink<M, Error = E> + Stream<Item = Result<M, E>> + Unpin + Send + 'static,
|
||||
M: JsonRpcWebSocketMessage,
|
||||
E: std::fmt::Display + Send + 'static,
|
||||
{
|
||||
@@ -361,118 +344,119 @@ impl JsonRpcConnection {
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let websocket_task = tokio::spawn(async move {
|
||||
let mut ping_interval = ping_interval.map(|ping_interval| {
|
||||
let mut interval = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + ping_interval,
|
||||
ping_interval,
|
||||
);
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
interval
|
||||
});
|
||||
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
|
||||
Ok(JsonRpcWebSocketFrame::Message(message)) => {
|
||||
if incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
tokio::select! {
|
||||
maybe_message = outgoing_rx.recv() => {
|
||||
let Some(message) = maybe_message else {
|
||||
break;
|
||||
};
|
||||
if let Err(reason) = send_websocket_jsonrpc_message(
|
||||
&mut websocket,
|
||||
&connection_label,
|
||||
&message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
send_malformed_message(
|
||||
&incoming_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
_ = async {
|
||||
match ping_interval.as_mut() {
|
||||
Some(interval) => interval.tick().await,
|
||||
None => std::future::pending().await,
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Close) => {
|
||||
} => {
|
||||
if let Err(err) = websocket.send(M::ping()).await {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket ping to {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Ignore) => {}
|
||||
},
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
if let Some(keepalive_interval) = keepalive_interval {
|
||||
let mut keepalive = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + keepalive_interval,
|
||||
keepalive_interval,
|
||||
);
|
||||
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_message = outgoing_rx.recv() => {
|
||||
let Some(message) = maybe_message else {
|
||||
break;
|
||||
};
|
||||
if let Err(reason) = send_websocket_jsonrpc_message(
|
||||
&mut websocket_writer,
|
||||
&connection_label,
|
||||
&message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = keepalive.tick() => {
|
||||
if let Err(err) = websocket_writer.send(M::ping()).await {
|
||||
incoming_message = websocket.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
|
||||
Ok(JsonRpcWebSocketFrame::Message(message)) => {
|
||||
if incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Ping(payload)) => {
|
||||
if let Err(err) = websocket.send(M::pong(payload)).await {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket pong to {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Close) => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Ignore) => {}
|
||||
Err(err) => {
|
||||
send_malformed_message(
|
||||
&incoming_tx,
|
||||
Some(format!(
|
||||
"failed to parse websocket JSON-RPC message from {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
},
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket ping to {connection_label}: {err}"
|
||||
"failed to read websocket JSON-RPC message from {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
while let Some(message) = outgoing_rx.recv().await {
|
||||
if let Err(reason) = send_websocket_jsonrpc_message(
|
||||
&mut websocket_writer,
|
||||
&connection_label,
|
||||
&message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -480,7 +464,7 @@ impl JsonRpcConnection {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
task_handles: vec![websocket_task],
|
||||
transport: JsonRpcTransport::Plain,
|
||||
}
|
||||
}
|
||||
@@ -493,6 +477,7 @@ impl JsonRpcConnection {
|
||||
|
||||
enum JsonRpcWebSocketFrame {
|
||||
Message(JSONRPCMessage),
|
||||
Ping(bytes::Bytes),
|
||||
Close,
|
||||
Ignore,
|
||||
}
|
||||
@@ -501,6 +486,7 @@ trait JsonRpcWebSocketMessage: Send + 'static {
|
||||
fn parse_jsonrpc_frame(self) -> Result<JsonRpcWebSocketFrame, serde_json::Error>;
|
||||
fn from_text(text: String) -> Self;
|
||||
fn ping() -> Self;
|
||||
fn pong(payload: bytes::Bytes) -> Self;
|
||||
}
|
||||
|
||||
impl JsonRpcWebSocketMessage for Message {
|
||||
@@ -513,9 +499,8 @@ impl JsonRpcWebSocketMessage for Message {
|
||||
serde_json::from_slice(bytes.as_ref()).map(JsonRpcWebSocketFrame::Message)
|
||||
}
|
||||
Message::Close(_) => Ok(JsonRpcWebSocketFrame::Close),
|
||||
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {
|
||||
Ok(JsonRpcWebSocketFrame::Ignore)
|
||||
}
|
||||
Message::Ping(payload) => Ok(JsonRpcWebSocketFrame::Ping(payload)),
|
||||
Message::Pong(_) | Message::Frame(_) => Ok(JsonRpcWebSocketFrame::Ignore),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -526,6 +511,10 @@ impl JsonRpcWebSocketMessage for Message {
|
||||
fn ping() -> Self {
|
||||
Self::Ping(Vec::new().into())
|
||||
}
|
||||
|
||||
fn pong(payload: bytes::Bytes) -> Self {
|
||||
Self::Pong(payload)
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonRpcWebSocketMessage for AxumWebSocketMessage {
|
||||
@@ -538,9 +527,8 @@ impl JsonRpcWebSocketMessage for AxumWebSocketMessage {
|
||||
serde_json::from_slice(bytes.as_ref()).map(JsonRpcWebSocketFrame::Message)
|
||||
}
|
||||
AxumWebSocketMessage::Close(_) => Ok(JsonRpcWebSocketFrame::Close),
|
||||
AxumWebSocketMessage::Ping(_) | AxumWebSocketMessage::Pong(_) => {
|
||||
Ok(JsonRpcWebSocketFrame::Ignore)
|
||||
}
|
||||
AxumWebSocketMessage::Ping(payload) => Ok(JsonRpcWebSocketFrame::Ping(payload)),
|
||||
AxumWebSocketMessage::Pong(_) => Ok(JsonRpcWebSocketFrame::Ignore),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -551,6 +539,10 @@ impl JsonRpcWebSocketMessage for AxumWebSocketMessage {
|
||||
fn ping() -> Self {
|
||||
Self::Ping(Vec::new().into())
|
||||
}
|
||||
|
||||
fn pong(payload: bytes::Bytes) -> Self {
|
||||
Self::Pong(payload)
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_disconnected(
|
||||
@@ -618,71 +610,60 @@ fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result<String, serde_j
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::Pin;
|
||||
|
||||
use futures::channel::mpsc as futures_mpsc;
|
||||
use futures::stream;
|
||||
use futures::task::Context;
|
||||
use futures::task::Poll;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tokio_tungstenite::connect_async;
|
||||
|
||||
use super::*;
|
||||
|
||||
struct TestWebSocketSink {
|
||||
message_tx: futures_mpsc::UnboundedSender<Message>,
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_pongs_server_ping() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
|
||||
|
||||
impl Sink<Message> for TestWebSocketSink {
|
||||
type Error = std::convert::Infallible;
|
||||
server_websocket
|
||||
.send(Message::Ping(b"check".to_vec().into()))
|
||||
.await?;
|
||||
let message = timeout(Duration::from_secs(1), server_websocket.next())
|
||||
.await?
|
||||
.expect("websocket should stay open")?;
|
||||
assert_eq!(message, Message::Pong(b"check".to_vec().into()));
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
self.get_mut()
|
||||
.message_tx
|
||||
.unbounded_send(item)
|
||||
.expect("test websocket receiver should stay open");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_sends_keepalive_ping() {
|
||||
let (message_tx, mut message_rx) = futures_mpsc::unbounded::<Message>();
|
||||
let websocket_writer = TestWebSocketSink { message_tx };
|
||||
let websocket_reader = stream::pending::<Result<Message, std::convert::Infallible>>();
|
||||
let connection = JsonRpcConnection::from_websocket_parts(
|
||||
websocket_writer,
|
||||
websocket_reader,
|
||||
async fn websocket_connection_sends_configured_ping() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let connection = JsonRpcConnection::from_websocket_stream(
|
||||
client_websocket,
|
||||
"test".into(),
|
||||
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
|
||||
);
|
||||
|
||||
let message = timeout(Duration::from_secs(1), message_rx.next())
|
||||
.await
|
||||
.expect("keepalive ping should arrive before timeout")
|
||||
.expect("keepalive ping should be sent");
|
||||
let message = timeout(Duration::from_secs(1), server_websocket.next())
|
||||
.await?
|
||||
.expect("websocket should stay open")?;
|
||||
assert!(matches!(message, Message::Ping(_)));
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn websocket_pair() -> anyhow::Result<(
|
||||
WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
|
||||
WebSocketStream<tokio::net::TcpStream>,
|
||||
)> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await?;
|
||||
let websocket_url = format!("ws://{}", listener.local_addr()?);
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
accept_async(stream).await.map_err(anyhow::Error::from)
|
||||
});
|
||||
let (client_websocket, _) = connect_async(websocket_url).await?;
|
||||
let server_websocket = server_task.await??;
|
||||
Ok((client_websocket, server_websocket))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ use crate::connection::CHANNEL_CAPACITY;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::connection::JsonRpcConnectionEvent;
|
||||
use crate::connection::JsonRpcTransport;
|
||||
use crate::connection::WEBSOCKET_KEEPALIVE_INTERVAL;
|
||||
use crate::relay_proto::RelayData;
|
||||
use crate::relay_proto::RelayMessageFrame;
|
||||
use crate::relay_proto::RelayResume;
|
||||
@@ -148,113 +147,16 @@ where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let stream_id = Uuid::new_v4().to_string();
|
||||
let (mut websocket_writer, mut websocket_reader) = stream.split();
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
|
||||
let reader_label = connection_label;
|
||||
let reader_stream_id = stream_id.clone();
|
||||
let incoming_tx_for_reader = incoming_tx;
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
Some(Ok(Message::Binary(payload))) => {
|
||||
let frame = match decode_relay_message_frame(payload.as_ref()) {
|
||||
Ok(frame) => frame,
|
||||
Err(err) => {
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: format!(
|
||||
"failed to parse relay message frame from {reader_label}: {err}"
|
||||
),
|
||||
})
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if frame.stream_id != reader_stream_id {
|
||||
continue;
|
||||
}
|
||||
let kind = match frame.validate() {
|
||||
Ok(kind) => kind,
|
||||
Err(err) => {
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match kind {
|
||||
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
|
||||
Ok(message) => {
|
||||
if incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
},
|
||||
RelayFrameBodyKind::Reset => {
|
||||
let _ = disconnected_tx_for_reader.send(true);
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Disconnected {
|
||||
reason: frame.into_reset_reason(),
|
||||
})
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
RelayFrameBodyKind::Ack
|
||||
| RelayFrameBodyKind::Resume
|
||||
| RelayFrameBodyKind::Heartbeat => {}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
let _ = disconnected_tx_for_reader.send(true);
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: "relay exec-server transport expects binary protobuf frames"
|
||||
.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
let _ = disconnected_tx_for_reader.send(true);
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Disconnected {
|
||||
reason: Some(format!(
|
||||
"failed to read relay websocket frame from {reader_label}: {err}"
|
||||
)),
|
||||
})
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let websocket_task = tokio::spawn(async move {
|
||||
let mut websocket = stream;
|
||||
let reader_label = connection_label;
|
||||
let reader_stream_id = stream_id.clone();
|
||||
let resume = RelayMessageFrame::resume(stream_id.clone());
|
||||
if websocket_writer
|
||||
if websocket
|
||||
.send(Message::Binary(encode_relay_message_frame(&resume).into()))
|
||||
.await
|
||||
.is_err()
|
||||
@@ -263,11 +165,6 @@ where
|
||||
return;
|
||||
}
|
||||
|
||||
let mut keepalive = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
|
||||
WEBSOCKET_KEEPALIVE_INTERVAL,
|
||||
);
|
||||
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
let mut next_seq = 0u32;
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -284,7 +181,7 @@ where
|
||||
};
|
||||
let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload);
|
||||
next_seq = next_seq.wrapping_add(1);
|
||||
if websocket_writer
|
||||
if websocket
|
||||
.send(Message::Binary(encode_relay_message_frame(&frame).into()))
|
||||
.await
|
||||
.is_err()
|
||||
@@ -293,10 +190,102 @@ where
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = keepalive.tick() => {
|
||||
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
|
||||
let _ = disconnected_tx.send(true);
|
||||
break;
|
||||
incoming_message = websocket.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(Message::Binary(payload))) => {
|
||||
let frame = match decode_relay_message_frame(payload.as_ref()) {
|
||||
Ok(frame) => frame,
|
||||
Err(err) => {
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: format!(
|
||||
"failed to parse relay message frame from {reader_label}: {err}"
|
||||
),
|
||||
})
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if frame.stream_id != reader_stream_id {
|
||||
continue;
|
||||
}
|
||||
let kind = match frame.validate() {
|
||||
Ok(kind) => kind,
|
||||
Err(err) => {
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match kind {
|
||||
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
|
||||
Ok(message) => {
|
||||
if incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
},
|
||||
RelayFrameBodyKind::Reset => {
|
||||
let _ = disconnected_tx.send(true);
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Disconnected {
|
||||
reason: frame.into_reset_reason(),
|
||||
})
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
RelayFrameBodyKind::Ack
|
||||
| RelayFrameBodyKind::Resume
|
||||
| RelayFrameBodyKind::Heartbeat => {}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Ping(payload))) => {
|
||||
if websocket.send(Message::Pong(payload)).await.is_err() {
|
||||
let _ = disconnected_tx.send(true);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
let _ = disconnected_tx.send(true);
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Pong(_) | Message::Frame(_))) => {}
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: "relay exec-server transport expects binary protobuf frames"
|
||||
.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
let _ = disconnected_tx.send(true);
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Disconnected {
|
||||
reason: Some(format!(
|
||||
"failed to read relay websocket frame from {reader_label}: {err}"
|
||||
)),
|
||||
})
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -307,7 +296,7 @@ where
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
task_handles: vec![websocket_task],
|
||||
transport: JsonRpcTransport::Plain,
|
||||
}
|
||||
}
|
||||
@@ -318,59 +307,48 @@ pub(crate) async fn run_multiplexed_executor<S>(
|
||||
) where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let (mut websocket_writer, mut websocket_reader) = stream.split();
|
||||
let mut websocket = stream;
|
||||
let (physical_outgoing_tx, mut physical_outgoing_rx) =
|
||||
mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let mut keepalive = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
|
||||
WEBSOCKET_KEEPALIVE_INTERVAL,
|
||||
);
|
||||
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_encoded = physical_outgoing_rx.recv() => {
|
||||
let Some(encoded) = maybe_encoded else {
|
||||
break;
|
||||
};
|
||||
if websocket_writer
|
||||
.send(Message::Binary(encoded.into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = keepalive.tick() => {
|
||||
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut streams: HashMap<String, VirtualStream> = HashMap::new();
|
||||
loop {
|
||||
let frame = match websocket_reader.next().await {
|
||||
Some(Ok(Message::Binary(payload))) => {
|
||||
match decode_relay_message_frame(payload.as_ref()) {
|
||||
Ok(frame) => frame,
|
||||
Err(err) => {
|
||||
warn!("dropping malformed relay message frame from harness: {err}");
|
||||
continue;
|
||||
}
|
||||
let frame = tokio::select! {
|
||||
maybe_encoded = physical_outgoing_rx.recv() => {
|
||||
let Some(encoded) = maybe_encoded else {
|
||||
break;
|
||||
};
|
||||
if websocket.send(Message::Binary(encoded.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => break,
|
||||
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
warn!("dropping non-binary relay message frame from harness");
|
||||
continue;
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
debug!("multiplexed executor websocket read failed: {err}");
|
||||
break;
|
||||
incoming_message = websocket.next() => match incoming_message {
|
||||
Some(Ok(Message::Binary(payload))) => {
|
||||
match decode_relay_message_frame(payload.as_ref()) {
|
||||
Ok(frame) => frame,
|
||||
Err(err) => {
|
||||
warn!("dropping malformed relay message frame from harness: {err}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Ping(payload))) => {
|
||||
if websocket.send(Message::Pong(payload)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => break,
|
||||
Some(Ok(Message::Pong(_) | Message::Frame(_))) => continue,
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
warn!("dropping non-binary relay message frame from harness");
|
||||
continue;
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
debug!("multiplexed executor websocket read failed: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -423,7 +401,6 @@ pub(crate) async fn run_multiplexed_executor<S>(
|
||||
stream.disconnect(/*reason*/ None).await;
|
||||
}
|
||||
drop(physical_outgoing_tx);
|
||||
let _ = writer_task.await;
|
||||
}
|
||||
|
||||
struct VirtualStream {
|
||||
@@ -511,14 +488,17 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn multiplexed_executor_sends_keepalive_ping() -> anyhow::Result<()> {
|
||||
async fn multiplexed_executor_pongs_server_ping() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let executor_task = tokio::spawn(run_multiplexed_executor(
|
||||
client_websocket,
|
||||
ConnectionProcessor::new(test_runtime_paths()?),
|
||||
));
|
||||
|
||||
read_keepalive_ping(&mut server_websocket).await?;
|
||||
server_websocket
|
||||
.send(Message::Ping(b"check".to_vec().into()))
|
||||
.await?;
|
||||
read_pong(&mut server_websocket).await?;
|
||||
|
||||
executor_task.abort();
|
||||
let _ = executor_task.await;
|
||||
@@ -526,11 +506,14 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn harness_connection_sends_keepalive_ping() -> anyhow::Result<()> {
|
||||
async fn harness_connection_pongs_server_ping() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let connection = harness_connection_from_websocket(client_websocket, "test".to_string());
|
||||
|
||||
read_keepalive_ping(&mut server_websocket).await?;
|
||||
server_websocket
|
||||
.send(Message::Ping(b"check".to_vec().into()))
|
||||
.await?;
|
||||
read_pong(&mut server_websocket).await?;
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
@@ -551,17 +534,21 @@ mod tests {
|
||||
Ok((client_websocket, server_websocket))
|
||||
}
|
||||
|
||||
async fn read_keepalive_ping(
|
||||
async fn read_pong(
|
||||
websocket: &mut WebSocketStream<tokio::net::TcpStream>,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let Some(message) = timeout(Duration::from_secs(1), websocket.next()).await? else {
|
||||
anyhow::bail!("websocket closed before keepalive ping");
|
||||
anyhow::bail!("websocket closed before pong");
|
||||
};
|
||||
match message? {
|
||||
Message::Ping(_) => return Ok(()),
|
||||
Message::Binary(_) | Message::Text(_) | Message::Pong(_) | Message::Frame(_) => {}
|
||||
Message::Close(_) => anyhow::bail!("websocket closed before keepalive ping"),
|
||||
Message::Pong(payload) if payload.as_ref() == b"check" => return Ok(()),
|
||||
Message::Binary(_)
|
||||
| Message::Text(_)
|
||||
| Message::Ping(_)
|
||||
| Message::Pong(_)
|
||||
| Message::Frame(_) => {}
|
||||
Message::Close(_) => anyhow::bail!("websocket closed before pong"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user