diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index cd0e72297f..104d09f57d 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -5,8 +5,12 @@ use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::time::Duration; +use axum::extract::ws::Message as AxumWebSocketMessage; +use axum::extract::ws::WebSocket as AxumWebSocket; use codex_app_server_protocol::JSONRPCMessage; +use futures::Sink; use futures::SinkExt; +use futures::Stream; use futures::StreamExt; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; @@ -314,11 +318,30 @@ impl JsonRpcConnection { pub(crate) fn from_websocket(stream: WebSocketStream, connection_label: String) -> Self where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let (websocket_writer, websocket_reader) = stream.split(); + Self::from_websocket_parts(websocket_writer, websocket_reader, connection_label) + } + + pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self { + let (websocket_writer, websocket_reader) = stream.split(); + Self::from_websocket_parts(websocket_writer, websocket_reader, connection_label) + } + + fn from_websocket_parts( + mut websocket_writer: W, + mut websocket_reader: R, + connection_label: String, + ) -> Self + where + W: Sink + Unpin + Send + 'static, + R: Stream> + Unpin + Send + 'static, + M: JsonRpcWebSocketMessage, + E: std::fmt::Display + Send + 'static, { let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); let (disconnected_tx, disconnected_rx) = watch::channel(false); - let (mut websocket_writer, mut websocket_reader) = stream.split(); let reader_label = connection_label.clone(); let incoming_tx_for_reader = incoming_tx.clone(); @@ -326,41 +349,36 @@ impl JsonRpcConnection { let reader_task = tokio::spawn(async move { loop { match websocket_reader.next().await { - Some(Ok(Message::Text(text))) => { - match serde_json::from_str::(text.as_ref()) { - Ok(message) => { - if incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Message(message)) - .await - .is_err() - { - break; - } - } - Err(err) => { - send_malformed_message( - &incoming_tx_for_reader, - Some(format!( - "failed to parse websocket JSON-RPC message from {reader_label}: {err}" - )), - ) - .await; + Some(Ok(message)) => match message.parse_jsonrpc_frame() { + Ok(JsonRpcWebSocketFrame::Message(message)) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; } } - } - Some(Ok(Message::Close(_))) => { - send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - /*reason*/ None, - ) - .await; - break; - } - Some(Ok(Message::Binary(_))) - | Some(Ok(Message::Ping(_))) - | Some(Ok(Message::Pong(_))) - | Some(Ok(Message::Frame(_))) => {} + Err(err) => { + send_malformed_message( + &incoming_tx_for_reader, + Some(format!( + "failed to parse websocket JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + } + Ok(JsonRpcWebSocketFrame::Close) => { + send_disconnected( + &incoming_tx_for_reader, + &disconnected_tx_for_reader, + /*reason*/ None, + ) + .await; + break; + } + Ok(JsonRpcWebSocketFrame::Ignore) => {} + }, Some(Err(err)) => { send_disconnected( &incoming_tx_for_reader, @@ -389,8 +407,7 @@ impl JsonRpcConnection { while let Some(message) = outgoing_rx.recv().await { match serialize_jsonrpc_message(&message) { Ok(encoded) => { - if let Err(err) = websocket_writer.send(Message::Text(encoded.into())).await - { + if let Err(err) = websocket_writer.send(M::from_text(encoded)).await { send_disconnected( &incoming_tx, &disconnected_tx, @@ -432,6 +449,53 @@ impl JsonRpcConnection { } } +enum JsonRpcWebSocketFrame { + Message(JSONRPCMessage), + Close, + Ignore, +} + +trait JsonRpcWebSocketMessage: Send + 'static { + fn parse_jsonrpc_frame(self) -> Result; + fn from_text(text: String) -> Self; +} + +impl JsonRpcWebSocketMessage for Message { + fn parse_jsonrpc_frame(self) -> Result { + match self { + Message::Text(text) => { + serde_json::from_str(text.as_ref()).map(JsonRpcWebSocketFrame::Message) + } + Message::Close(_) => Ok(JsonRpcWebSocketFrame::Close), + Message::Binary(_) | Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => { + Ok(JsonRpcWebSocketFrame::Ignore) + } + } + } + + fn from_text(text: String) -> Self { + Self::Text(text.into()) + } +} + +impl JsonRpcWebSocketMessage for AxumWebSocketMessage { + fn parse_jsonrpc_frame(self) -> Result { + match self { + AxumWebSocketMessage::Text(text) => { + serde_json::from_str(text.as_ref()).map(JsonRpcWebSocketFrame::Message) + } + AxumWebSocketMessage::Close(_) => Ok(JsonRpcWebSocketFrame::Close), + AxumWebSocketMessage::Binary(_) + | AxumWebSocketMessage::Ping(_) + | AxumWebSocketMessage::Pong(_) => Ok(JsonRpcWebSocketFrame::Ignore), + } + } + + fn from_text(text: String) -> Self { + Self::Text(text.into()) + } +} + async fn send_disconnected( incoming_tx: &mpsc::Sender, disconnected_tx: &watch::Sender, diff --git a/codex-rs/exec-server/src/server/transport.rs b/codex-rs/exec-server/src/server/transport.rs index 9deafa9b24..92b2c77aaa 100644 --- a/codex-rs/exec-server/src/server/transport.rs +++ b/codex-rs/exec-server/src/server/transport.rs @@ -1,12 +1,18 @@ +use axum::Router; +use axum::extract::ConnectInfo; +use axum::extract::State; +use axum::extract::ws::WebSocketUpgrade; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::routing::any; +use axum::routing::get; use std::io::Write as _; use std::net::SocketAddr; use tokio::io; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tokio::net::TcpListener; -use tokio_tungstenite::accept_async; use tracing::info; -use tracing::warn; use crate::ExecServerRuntimePaths; use crate::connection::JsonRpcConnection; @@ -114,27 +120,42 @@ async fn run_websocket_listener( println!("ws://{local_addr}"); std::io::stdout().flush()?; - loop { - let (stream, peer_addr) = listener.accept().await?; - let processor = processor.clone(); - tokio::spawn(async move { - match accept_async(stream).await { - Ok(websocket) => { - processor - .run_connection(JsonRpcConnection::from_websocket( - websocket, - format!("exec-server websocket {peer_addr}"), - )) - .await; - } - Err(err) => { - warn!( - "failed to accept exec-server websocket connection from {peer_addr}: {err}" - ); - } - } - }); - } + let router = Router::new() + .route("/", any(websocket_upgrade_handler)) + .route("/readyz", get(readiness_handler)) + .with_state(ExecServerWebSocketState { processor }); + axum::serve( + listener, + router.into_make_service_with_connect_info::(), + ) + .await?; + Ok(()) +} + +#[derive(Clone)] +struct ExecServerWebSocketState { + processor: ConnectionProcessor, +} + +async fn readiness_handler() -> StatusCode { + StatusCode::OK +} + +async fn websocket_upgrade_handler( + websocket: WebSocketUpgrade, + ConnectInfo(peer_addr): ConnectInfo, + State(state): State, +) -> impl IntoResponse { + info!(%peer_addr, "exec-server websocket client connected"); + websocket.on_upgrade(move |stream| async move { + state + .processor + .run_connection(JsonRpcConnection::from_axum_websocket( + stream, + format!("exec-server websocket {peer_addr}"), + )) + .await; + }) } #[cfg(test)] diff --git a/codex-rs/exec-server/tests/health.rs b/codex-rs/exec-server/tests/health.rs new file mode 100644 index 0000000000..91b3806a22 --- /dev/null +++ b/codex-rs/exec-server/tests/health.rs @@ -0,0 +1,21 @@ +#![cfg(unix)] + +mod common; + +use common::exec_server::exec_server; +use pretty_assertions::assert_eq; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_server_serves_readyz_alongside_websocket_endpoint() -> anyhow::Result<()> { + let mut server = exec_server().await?; + let http_base_url = server + .websocket_url() + .strip_prefix("ws://") + .expect("websocket URL should use ws://"); + + let response = reqwest::get(format!("http://{http_base_url}/readyz")).await?; + assert_eq!(response.status(), reqwest::StatusCode::OK); + + server.shutdown().await?; + Ok(()) +}