diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index ec654051e4..26b32f67c8 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2741,6 +2741,7 @@ dependencies = [ "anyhow", "arc-swap", "async-trait", + "axum", "base64 0.22.1", "bytes", "codex-app-server-protocol", @@ -2751,6 +2752,7 @@ dependencies = [ "codex-test-binary-support", "codex-utils-absolute-path", "codex-utils-pty", + "codex-utils-rustls-provider", "ctor 0.6.3", "futures", "pretty_assertions", diff --git a/codex-rs/app-server-client/src/lib.rs b/codex-rs/app-server-client/src/lib.rs index 3b386ff3ce..3a5be188df 100644 --- a/codex-rs/app-server-client/src/lib.rs +++ b/codex-rs/app-server-client/src/lib.rs @@ -2094,7 +2094,7 @@ mod tests { let config = Arc::new(build_test_config().await); let environment_manager = Arc::new( EnvironmentManager::create_for_tests( - Some("ws://127.0.0.1:8765".to_string()), + Some("ws://127.0.0.1:8765/ws".to_string()), ExecServerRuntimePaths::new( std::env::current_exe().expect("current exe"), /*codex_linux_sandbox_exe*/ None, diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index ae00b08b73..5caa066f08 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -1808,7 +1808,7 @@ mod tests { request_id: request_id(), params: v2::EnvironmentAddParams { environment_id: "remote-a".to_string(), - exec_server_url: "ws://127.0.0.1:8765".to_string(), + exec_server_url: "ws://127.0.0.1:8765/ws".to_string(), }, }; assert_eq!( @@ -2603,7 +2603,7 @@ mod tests { request_id: RequestId::Integer(9), params: v2::EnvironmentAddParams { environment_id: "remote-a".to_string(), - exec_server_url: "ws://127.0.0.1:8765".to_string(), + exec_server_url: "ws://127.0.0.1:8765/ws".to_string(), }, }; assert_eq!( @@ -2612,7 +2612,7 @@ mod tests { "id": 9, "params": { "environmentId": "remote-a", - "execServerUrl": "ws://127.0.0.1:8765" + "execServerUrl": "ws://127.0.0.1:8765/ws" } }), serde_json::to_value(&request)?, @@ -2898,7 +2898,7 @@ mod tests { request_id: RequestId::Integer(1), params: v2::EnvironmentAddParams { environment_id: "remote-a".to_string(), - exec_server_url: "ws://127.0.0.1:8765".to_string(), + exec_server_url: "ws://127.0.0.1:8765/ws".to_string(), }, }; let reason = crate::experimental_api::ExperimentalApi::experimental_reason(&request); diff --git a/codex-rs/app-server/tests/suite/v2/turn_start.rs b/codex-rs/app-server/tests/suite/v2/turn_start.rs index 524b795b81..e152babd10 100644 --- a/codex-rs/app-server/tests/suite/v2/turn_start.rs +++ b/codex-rs/app-server/tests/suite/v2/turn_start.rs @@ -2005,7 +2005,7 @@ async fn turn_start_resolves_sticky_thread_local_environment_and_turn_overrides( r#" [[environments]] id = "remote" -url = "ws://127.0.0.1:1" +url = "ws://127.0.0.1:1/ws" "#, )?; diff --git a/codex-rs/core/src/environment_selection.rs b/codex-rs/core/src/environment_selection.rs index 89808c27ee..247e8f507c 100644 --- a/codex-rs/core/src/environment_selection.rs +++ b/codex-rs/core/src/environment_selection.rs @@ -105,7 +105,7 @@ mod tests { async fn default_thread_environment_selections_use_manager_default_id() { let cwd = AbsolutePathBuf::current_dir().expect("cwd"); let manager = EnvironmentManager::create_for_tests( - Some("ws://127.0.0.1:8765".to_string()), + Some("ws://127.0.0.1:8765/ws".to_string()), test_runtime_paths(), ) .await; @@ -127,7 +127,7 @@ mod tests { r#" [[environments]] id = "remote" -url = "ws://127.0.0.1:8765" +url = "ws://127.0.0.1:8765/ws" "#, ) .expect("write environments.toml"); diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml index 9fbdd91117..936fa412f1 100644 --- a/codex-rs/exec-server/Cargo.toml +++ b/codex-rs/exec-server/Cargo.toml @@ -13,6 +13,7 @@ workspace = true [dependencies] arc-swap = { workspace = true } async-trait = { workspace = true } +axum = { workspace = true, features = ["http1", "tokio", "ws"] } base64 = { workspace = true } bytes = { workspace = true } codex-app-server-protocol = { workspace = true } @@ -22,6 +23,7 @@ codex-protocol = { workspace = true } codex-sandboxing = { workspace = true } codex-utils-absolute-path = { workspace = true } codex-utils-pty = { workspace = true } +codex-utils-rustls-provider = { workspace = true } futures = { workspace = true } reqwest = { workspace = true, features = ["json", "rustls-tls", "stream"] } serde = { workspace = true, features = ["derive"] } diff --git a/codex-rs/exec-server/README.md b/codex-rs/exec-server/README.md index 81664eaca0..e4d8f58932 100644 --- a/codex-rs/exec-server/README.md +++ b/codex-rs/exec-server/README.md @@ -21,7 +21,7 @@ the wire. The CLI entrypoint supports: -- `ws://IP:PORT` (default) +- `ws://IP:PORT` (default bind address; the listener advertises `ws://IP:PORT/ws`) - `--remote URL --executor-id ID [--name NAME]` Remote mode registers the local exec-server with the executor registry, diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index 23dc0bc7b3..8ca1eb0280 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -7,6 +7,8 @@ use tokio_tungstenite::connect_async; use tracing::debug; use tracing::warn; +use codex_utils_rustls_provider::ensure_rustls_crypto_provider; + use crate::ExecServerClient; use crate::ExecServerError; use crate::client_api::RemoteExecServerConnectArgs; @@ -53,6 +55,7 @@ impl ExecServerClient { pub async fn connect_websocket( args: RemoteExecServerConnectArgs, ) -> Result { + ensure_rustls_crypto_provider(); let websocket_url = args.websocket_url.clone(); let connect_timeout = args.connect_timeout; let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str())) diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index c990c89338..a1d666b0f5 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -3,8 +3,12 @@ use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::time::Duration; +use axum::extract::ws::Message as AxumWebSocketMessage; +use axum::extract::ws::WebSocket as AxumWebSocket; use codex_app_server_protocol::JSONRPCMessage; +use futures::Sink; use futures::SinkExt; +use futures::Stream; use futures::StreamExt; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; @@ -309,11 +313,30 @@ impl JsonRpcConnection { pub(crate) fn from_websocket(stream: WebSocketStream, connection_label: String) -> Self where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let (websocket_writer, websocket_reader) = stream.split(); + Self::from_websocket_parts(websocket_writer, websocket_reader, connection_label) + } + + pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self { + let (websocket_writer, websocket_reader) = stream.split(); + Self::from_websocket_parts(websocket_writer, websocket_reader, connection_label) + } + + fn from_websocket_parts( + mut websocket_writer: W, + mut websocket_reader: R, + connection_label: String, + ) -> Self + where + W: Sink + Unpin + Send + 'static, + R: Stream> + Unpin + Send + 'static, + M: JsonRpcWebSocketMessage, + E: std::fmt::Display + Send + 'static, { let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); let (disconnected_tx, disconnected_rx) = watch::channel(false); - let (mut websocket_writer, mut websocket_reader) = stream.split(); let reader_label = connection_label.clone(); let incoming_tx_for_reader = incoming_tx.clone(); @@ -321,61 +344,36 @@ impl JsonRpcConnection { let reader_task = tokio::spawn(async move { loop { match websocket_reader.next().await { - Some(Ok(Message::Text(text))) => { - match serde_json::from_str::(text.as_ref()) { - Ok(message) => { - if incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Message(message)) - .await - .is_err() - { - break; - } - } - Err(err) => { - send_malformed_message( - &incoming_tx_for_reader, - Some(format!( - "failed to parse websocket JSON-RPC message from {reader_label}: {err}" - )), - ) - .await; + Some(Ok(message)) => match message.parse_jsonrpc_frame() { + Ok(JsonRpcWebSocketFrame::Message(message)) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; } } - } - Some(Ok(Message::Binary(bytes))) => { - match serde_json::from_slice::(bytes.as_ref()) { - Ok(message) => { - if incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Message(message)) - .await - .is_err() - { - break; - } - } - Err(err) => { - send_malformed_message( - &incoming_tx_for_reader, - Some(format!( - "failed to parse websocket JSON-RPC message from {reader_label}: {err}" - )), - ) - .await; - } + Err(err) => { + send_malformed_message( + &incoming_tx_for_reader, + Some(format!( + "failed to parse websocket JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; } - } - Some(Ok(Message::Close(_))) => { - send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - /*reason*/ None, - ) - .await; - break; - } - Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {} - Some(Ok(_)) => {} + Ok(JsonRpcWebSocketFrame::Close) => { + send_disconnected( + &incoming_tx_for_reader, + &disconnected_tx_for_reader, + /*reason*/ None, + ) + .await; + break; + } + Ok(JsonRpcWebSocketFrame::Ignore) => {} + }, Some(Err(err)) => { send_disconnected( &incoming_tx_for_reader, @@ -404,8 +402,7 @@ impl JsonRpcConnection { while let Some(message) = outgoing_rx.recv().await { match serialize_jsonrpc_message(&message) { Ok(encoded) => { - if let Err(err) = websocket_writer.send(Message::Text(encoded.into())).await - { + if let Err(err) = websocket_writer.send(M::from_text(encoded)).await { send_disconnected( &incoming_tx, &disconnected_tx, @@ -447,6 +444,59 @@ impl JsonRpcConnection { } } +enum JsonRpcWebSocketFrame { + Message(JSONRPCMessage), + Close, + Ignore, +} + +trait JsonRpcWebSocketMessage: Send + 'static { + fn parse_jsonrpc_frame(self) -> Result; + fn from_text(text: String) -> Self; +} + +impl JsonRpcWebSocketMessage for Message { + fn parse_jsonrpc_frame(self) -> Result { + match self { + Message::Text(text) => { + serde_json::from_str(text.as_ref()).map(JsonRpcWebSocketFrame::Message) + } + Message::Binary(bytes) => { + serde_json::from_slice(bytes.as_ref()).map(JsonRpcWebSocketFrame::Message) + } + Message::Close(_) => Ok(JsonRpcWebSocketFrame::Close), + Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => { + Ok(JsonRpcWebSocketFrame::Ignore) + } + } + } + + fn from_text(text: String) -> Self { + Self::Text(text.into()) + } +} + +impl JsonRpcWebSocketMessage for AxumWebSocketMessage { + fn parse_jsonrpc_frame(self) -> Result { + match self { + AxumWebSocketMessage::Text(text) => { + serde_json::from_str(text.as_ref()).map(JsonRpcWebSocketFrame::Message) + } + AxumWebSocketMessage::Binary(bytes) => { + serde_json::from_slice(bytes.as_ref()).map(JsonRpcWebSocketFrame::Message) + } + AxumWebSocketMessage::Close(_) => Ok(JsonRpcWebSocketFrame::Close), + AxumWebSocketMessage::Ping(_) | AxumWebSocketMessage::Pong(_) => { + Ok(JsonRpcWebSocketFrame::Ignore) + } + } + } + + fn from_text(text: String) -> Self { + Self::Text(text.into()) + } +} + async fn send_disconnected( incoming_tx: &mpsc::Sender, disconnected_tx: &watch::Sender, diff --git a/codex-rs/exec-server/src/environment.rs b/codex-rs/exec-server/src/environment.rs index 7e4a3fb056..dd207fc6ea 100644 --- a/codex-rs/exec-server/src/environment.rs +++ b/codex-rs/exec-server/src/environment.rs @@ -506,7 +506,7 @@ mod tests { #[tokio::test] async fn environment_manager_reports_remote_url() { let manager = EnvironmentManager::create_for_tests( - Some("ws://127.0.0.1:8765".to_string()), + Some("ws://127.0.0.1:8765/ws".to_string()), test_runtime_paths(), ) .await; @@ -517,7 +517,10 @@ mod tests { Some(REMOTE_ENVIRONMENT_ID) ); assert!(environment.is_remote()); - assert_eq!(environment.exec_server_url(), Some("ws://127.0.0.1:8765")); + assert_eq!( + environment.exec_server_url(), + Some("ws://127.0.0.1:8765/ws") + ); assert!(Arc::ptr_eq( &environment, &manager @@ -548,7 +551,7 @@ mod tests { snapshot: EnvironmentProviderSnapshot { environments: vec![( REMOTE_ENVIRONMENT_ID.to_string(), - Environment::create_for_tests(Some("ws://127.0.0.1:8765".to_string())) + Environment::create_for_tests(Some("ws://127.0.0.1:8765/ws".to_string())) .expect("remote environment"), )], default: EnvironmentDefault::EnvironmentId(REMOTE_ENVIRONMENT_ID.to_string()), @@ -620,7 +623,7 @@ mod tests { snapshot: EnvironmentProviderSnapshot { environments: vec![( "devbox".to_string(), - Environment::create_for_tests(Some("ws://127.0.0.1:8765".to_string())) + Environment::create_for_tests(Some("ws://127.0.0.1:8765/ws".to_string())) .expect("remote environment"), )], default: EnvironmentDefault::EnvironmentId("devbox".to_string()), @@ -645,7 +648,7 @@ mod tests { snapshot: EnvironmentProviderSnapshot { environments: vec![( "devbox".to_string(), - Environment::create_for_tests(Some("ws://127.0.0.1:8765".to_string())) + Environment::create_for_tests(Some("ws://127.0.0.1:8765/ws".to_string())) .expect("remote environment"), )], default: EnvironmentDefault::Disabled, @@ -672,7 +675,7 @@ mod tests { snapshot: EnvironmentProviderSnapshot { environments: vec![( "devbox".to_string(), - Environment::create_for_tests(Some("ws://127.0.0.1:8765".to_string())) + Environment::create_for_tests(Some("ws://127.0.0.1:8765/ws".to_string())) .expect("remote environment"), )], default: EnvironmentDefault::EnvironmentId("missing".to_string()), @@ -766,23 +769,29 @@ mod tests { let manager = EnvironmentManager::disabled_for_tests(test_runtime_paths()); manager - .upsert_environment("executor-a".to_string(), "ws://127.0.0.1:8765".to_string()) + .upsert_environment( + "executor-a".to_string(), + "ws://127.0.0.1:8765/ws".to_string(), + ) .expect("remote environment"); let first = manager .get_environment("executor-a") .expect("first remote environment"); assert!(first.is_remote()); - assert_eq!(first.exec_server_url(), Some("ws://127.0.0.1:8765")); + assert_eq!(first.exec_server_url(), Some("ws://127.0.0.1:8765/ws")); assert_eq!(manager.default_environment_id(), None); manager - .upsert_environment("executor-a".to_string(), "ws://127.0.0.1:9876".to_string()) + .upsert_environment( + "executor-a".to_string(), + "ws://127.0.0.1:9876/ws".to_string(), + ) .expect("updated remote environment"); let second = manager .get_environment("executor-a") .expect("second remote environment"); assert!(second.is_remote()); - assert_eq!(second.exec_server_url(), Some("ws://127.0.0.1:9876")); + assert_eq!(second.exec_server_url(), Some("ws://127.0.0.1:9876/ws")); assert!(!Arc::ptr_eq(&first, &second)); } diff --git a/codex-rs/exec-server/src/environment_provider.rs b/codex-rs/exec-server/src/environment_provider.rs index 7e132ee2b4..47f0e82b94 100644 --- a/codex-rs/exec-server/src/environment_provider.rs +++ b/codex-rs/exec-server/src/environment_provider.rs @@ -162,7 +162,7 @@ mod tests { #[tokio::test] async fn default_provider_adds_remote_environment_for_websocket_url() { - let provider = DefaultEnvironmentProvider::new(Some("ws://127.0.0.1:8765".to_string())); + let provider = DefaultEnvironmentProvider::new(Some("ws://127.0.0.1:8765/ws".to_string())); let snapshot = provider.snapshot().await.expect("environments"); let EnvironmentProviderSnapshot { environments, @@ -177,7 +177,7 @@ mod tests { assert!(remote_environment.is_remote()); assert_eq!( remote_environment.exec_server_url(), - Some("ws://127.0.0.1:8765") + Some("ws://127.0.0.1:8765/ws") ); assert_eq!( default, @@ -187,13 +187,14 @@ mod tests { #[tokio::test] async fn default_provider_normalizes_exec_server_url() { - let provider = DefaultEnvironmentProvider::new(Some(" ws://127.0.0.1:8765 ".to_string())); + let provider = + DefaultEnvironmentProvider::new(Some(" ws://127.0.0.1:8765/ws ".to_string())); let snapshot = provider.snapshot().await.expect("environments"); let environments: HashMap<_, _> = snapshot.environments.into_iter().collect(); assert_eq!( environments[REMOTE_ENVIRONMENT_ID].exec_server_url(), - Some("ws://127.0.0.1:8765") + Some("ws://127.0.0.1:8765/ws") ); } } diff --git a/codex-rs/exec-server/src/environment_toml.rs b/codex-rs/exec-server/src/environment_toml.rs index 90f4c78262..5615960eda 100644 --- a/codex-rs/exec-server/src/environment_toml.rs +++ b/codex-rs/exec-server/src/environment_toml.rs @@ -333,7 +333,7 @@ mod tests { environments: vec![ EnvironmentToml { id: "devbox".to_string(), - url: Some(" ws://127.0.0.1:8765 ".to_string()), + url: Some(" ws://127.0.0.1:8765/ws ".to_string()), ..Default::default() }, EnvironmentToml { @@ -370,7 +370,7 @@ mod tests { assert!(!environments.contains_key(LOCAL_ENVIRONMENT_ID)); assert_eq!( environments["devbox"].exec_server_url(), - Some("ws://127.0.0.1:8765") + Some("ws://127.0.0.1:8765/ws") ); assert!(environments["ssh-dev"].is_remote()); assert_eq!(environments["ssh-dev"].exec_server_url(), None); @@ -411,7 +411,7 @@ mod tests { ( EnvironmentToml { id: "local".to_string(), - url: Some("ws://127.0.0.1:8765".to_string()), + url: Some("ws://127.0.0.1:8765/ws".to_string()), ..Default::default() }, "environment id `local` is reserved", @@ -419,7 +419,7 @@ mod tests { ( EnvironmentToml { id: " devbox ".to_string(), - url: Some("ws://127.0.0.1:8765".to_string()), + url: Some("ws://127.0.0.1:8765/ws".to_string()), ..Default::default() }, "environment id ` devbox ` must not contain surrounding whitespace", @@ -427,7 +427,7 @@ mod tests { ( EnvironmentToml { id: "dev box".to_string(), - url: Some("ws://127.0.0.1:8765".to_string()), + url: Some("ws://127.0.0.1:8765/ws".to_string()), ..Default::default() }, "environment id `dev box` must contain only ASCII letters, numbers, '-' or '_'", @@ -443,7 +443,7 @@ mod tests { ( EnvironmentToml { id: "devbox".to_string(), - url: Some("ws://127.0.0.1:8765".to_string()), + url: Some("ws://127.0.0.1:8765/ws".to_string()), program: Some("codex".to_string()), ..Default::default() }, @@ -528,7 +528,7 @@ mod tests { environments: vec![ EnvironmentToml { id: "devbox".to_string(), - url: Some("ws://127.0.0.1:8765".to_string()), + url: Some("ws://127.0.0.1:8765/ws".to_string()), connect_timeout_sec: Some(Duration::from_secs(12)), initialize_timeout_sec: Some(Duration::from_secs(34)), ..Default::default() @@ -546,7 +546,7 @@ mod tests { assert_eq!( provider.environments[0].1, ExecServerTransportParams::WebSocketUrl { - websocket_url: "ws://127.0.0.1:8765".to_string(), + websocket_url: "ws://127.0.0.1:8765/ws".to_string(), connect_timeout: Duration::from_secs(12), initialize_timeout: Duration::from_secs(34), } @@ -591,7 +591,7 @@ mod tests { environments: vec![ EnvironmentToml { id: "devbox".to_string(), - url: Some("ws://127.0.0.1:8765".to_string()), + url: Some("ws://127.0.0.1:8765/ws".to_string()), ..Default::default() }, EnvironmentToml { @@ -616,7 +616,7 @@ mod tests { default: None, environments: vec![EnvironmentToml { id: id.clone(), - url: Some("ws://127.0.0.1:8765".to_string()), + url: Some("ws://127.0.0.1:8765/ws".to_string()), ..Default::default() }], }) @@ -655,7 +655,7 @@ default = "ssh-dev" [[environments]] id = "devbox" -url = "ws://127.0.0.1:4512" +url = "ws://127.0.0.1:4512/ws" connect_timeout_sec = 12.0 initialize_timeout_sec = 34.0 @@ -678,7 +678,7 @@ CODEX_LOG = "debug" environments.environments[0], EnvironmentToml { id: "devbox".to_string(), - url: Some("ws://127.0.0.1:4512".to_string()), + url: Some("ws://127.0.0.1:4512/ws".to_string()), connect_timeout_sec: Some(Duration::from_secs(12)), initialize_timeout_sec: Some(Duration::from_secs(34)), ..Default::default() @@ -712,7 +712,7 @@ CODEX_LOG = "debug" r#" [[environments]] id = "devbox" -url = "ws://127.0.0.1:4512" +url = "ws://127.0.0.1:4512/ws" unknown = true "#, "unknown field `unknown`", diff --git a/codex-rs/exec-server/src/remote.rs b/codex-rs/exec-server/src/remote.rs index 43c424a142..bb22105c19 100644 --- a/codex-rs/exec-server/src/remote.rs +++ b/codex-rs/exec-server/src/remote.rs @@ -7,6 +7,8 @@ use tokio::time::sleep; use tokio_tungstenite::connect_async; use tracing::warn; +use codex_utils_rustls_provider::ensure_rustls_crypto_provider; + use crate::ExecServerError; use crate::ExecServerRuntimePaths; use crate::connection::JsonRpcConnection; @@ -133,6 +135,7 @@ pub async fn run_remote_executor( config: RemoteExecutorConfig, runtime_paths: ExecServerRuntimePaths, ) -> Result<(), ExecServerError> { + ensure_rustls_crypto_provider(); let client = ExecutorRegistryClient::new(config.base_url.clone(), config.bearer_token.clone())?; let processor = ConnectionProcessor::new(runtime_paths); let mut backoff = Duration::from_secs(1); diff --git a/codex-rs/exec-server/src/server/transport.rs b/codex-rs/exec-server/src/server/transport.rs index d284bf64bb..a82962c0d6 100644 --- a/codex-rs/exec-server/src/server/transport.rs +++ b/codex-rs/exec-server/src/server/transport.rs @@ -1,11 +1,18 @@ +use axum::Router; +use axum::extract::ConnectInfo; +use axum::extract::State; +use axum::extract::ws::WebSocketUpgrade; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::routing::any; +use axum::routing::get; use std::io::Write as _; use std::net::SocketAddr; use tokio::io; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tokio::net::TcpListener; -use tokio_tungstenite::accept_async; -use tracing::warn; +use tracing::info; use crate::ExecServerRuntimePaths; use crate::connection::JsonRpcConnection; @@ -109,31 +116,48 @@ async fn run_websocket_listener( let listener = TcpListener::bind(bind_address).await?; let local_addr = listener.local_addr()?; let processor = ConnectionProcessor::new(runtime_paths); - tracing::info!("codex-exec-server listening on ws://{local_addr}"); - println!("ws://{local_addr}"); + info!("codex-exec-server listening on ws://{local_addr}/ws"); + println!("ws://{local_addr}/ws"); std::io::stdout().flush()?; - loop { - let (stream, peer_addr) = listener.accept().await?; - let processor = processor.clone(); - tokio::spawn(async move { - match accept_async(stream).await { - Ok(websocket) => { - processor - .run_connection(JsonRpcConnection::from_websocket( - websocket, - format!("exec-server websocket {peer_addr}"), - )) - .await; - } - Err(err) => { - warn!( - "failed to accept exec-server websocket connection from {peer_addr}: {err}" - ); - } - } - }); - } + let router = Router::new() + .route("/", get(health_check_handler)) + .route("/readyz", get(health_check_handler)) + .route("/healthz", get(health_check_handler)) + .route("/ws", any(websocket_upgrade_handler)) + .with_state(ExecServerWebSocketState { processor }); + axum::serve( + listener, + router.into_make_service_with_connect_info::(), + ) + .await?; + Ok(()) +} + +#[derive(Clone)] +struct ExecServerWebSocketState { + processor: ConnectionProcessor, +} + +async fn health_check_handler() -> StatusCode { + StatusCode::OK +} + +async fn websocket_upgrade_handler( + websocket: WebSocketUpgrade, + ConnectInfo(peer_addr): ConnectInfo, + State(state): State, +) -> impl IntoResponse { + info!(%peer_addr, "exec-server websocket client connected"); + websocket.on_upgrade(move |stream| async move { + state + .processor + .run_connection(JsonRpcConnection::from_axum_websocket( + stream, + format!("exec-server websocket {peer_addr}"), + )) + .await; + }) } #[cfg(test)] diff --git a/codex-rs/exec-server/tests/common/exec_server.rs b/codex-rs/exec-server/tests/common/exec_server.rs index 4ff7408715..f1bf03d25b 100644 --- a/codex-rs/exec-server/tests/common/exec_server.rs +++ b/codex-rs/exec-server/tests/common/exec_server.rs @@ -142,6 +142,11 @@ impl ExecServerHarness { Ok(()) } + pub(crate) async fn send_raw_binary(&mut self, bytes: Vec) -> anyhow::Result<()> { + self.websocket.send(Message::Binary(bytes.into())).await?; + Ok(()) + } + pub(crate) async fn next_event(&mut self) -> anyhow::Result { self.next_event_with_timeout(EVENT_TIMEOUT).await } diff --git a/codex-rs/exec-server/tests/health.rs b/codex-rs/exec-server/tests/health.rs new file mode 100644 index 0000000000..b4c2aa59f2 --- /dev/null +++ b/codex-rs/exec-server/tests/health.rs @@ -0,0 +1,24 @@ +#![cfg(unix)] + +mod common; + +use common::exec_server::exec_server; +use pretty_assertions::assert_eq; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_server_serves_health_checks_alongside_websocket_endpoint() -> anyhow::Result<()> { + let mut server = exec_server().await?; + let http_base_url = server + .websocket_url() + .strip_prefix("ws://") + .and_then(|url| url.strip_suffix("/ws")) + .expect("websocket URL should use ws://.../ws"); + + for path in ["/", "/readyz", "/healthz"] { + let response = reqwest::get(format!("http://{http_base_url}{path}")).await?; + assert_eq!(response.status(), reqwest::StatusCode::OK); + } + + server.shutdown().await?; + Ok(()) +} diff --git a/codex-rs/exec-server/tests/websocket.rs b/codex-rs/exec-server/tests/websocket.rs index 64c9438b81..03b2cabfcd 100644 --- a/codex-rs/exec-server/tests/websocket.rs +++ b/codex-rs/exec-server/tests/websocket.rs @@ -60,3 +60,39 @@ async fn exec_server_reports_malformed_websocket_json_and_keeps_running() -> any server.shutdown().await?; Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_server_accepts_binary_websocket_json() -> anyhow::Result<()> { + let mut server = exec_server().await?; + let initialize_id = codex_app_server_protocol::RequestId::Integer(1); + let initialize = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: initialize_id.clone(), + method: "initialize".to_string(), + params: Some(serde_json::to_value(InitializeParams { + client_name: "exec-server-binary-test".to_string(), + resume_session_id: None, + })?), + trace: None, + }); + server + .send_raw_binary(serde_json::to_vec(&initialize)?) + .await?; + + let response = server + .wait_for_event(|event| { + matches!( + event, + JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &initialize_id + ) + }) + .await?; + let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else { + panic!("expected initialize response for binary input"); + }; + assert_eq!(id, initialize_id); + let initialize_response: InitializeResponse = serde_json::from_value(result)?; + Uuid::parse_str(&initialize_response.session_id)?; + + server.shutdown().await?; + Ok(()) +} diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index ac5b489af1..86ffc1a073 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -2145,7 +2145,7 @@ mod tests { }; let target = AppServerTarget::Embedded; let environment_manager = EnvironmentManager::create_for_tests( - Some("ws://127.0.0.1:8765".to_string()), + Some("ws://127.0.0.1:8765/ws".to_string()), ExecServerRuntimePaths::new( std::env::current_exe().expect("current exe"), /*codex_linux_sandbox_exe*/ None, diff --git a/scripts/start-codex-exec.sh b/scripts/start-codex-exec.sh index 8835d4c095..41b76b0317 100755 --- a/scripts/start-codex-exec.sh +++ b/scripts/start-codex-exec.sh @@ -161,6 +161,7 @@ if [[ -z "${remote_exec_server_pid}" || -z "${listen_url}" || -z "${remote_repo_ fi remote_exec_server_port="${listen_url##*:}" +remote_exec_server_port="${remote_exec_server_port%%/*}" if [[ -z "${remote_exec_server_port}" || "${remote_exec_server_port}" == "${listen_url}" ]]; then echo "failed to parse remote exec server port from ${listen_url}" >&2 exit 1 @@ -170,7 +171,7 @@ echo "Remote exec server: ${listen_url}" echo "Remote exec server log: ${remote_exec_server_log_path}" echo "Press Ctrl-C to stop the SSH tunnel and remote exec server." echo "Start codex via: " -printf ' CODEX_EXEC_SERVER_URL=ws://127.0.0.1:%s codex -C %q\n' \ +printf ' CODEX_EXEC_SERVER_URL=ws://127.0.0.1:%s/ws codex -C %q\n' \ "${local_exec_server_port}" \ "${remote_repo_root}" diff --git a/scripts/test-remote-env.sh b/scripts/test-remote-env.sh index 96743616a2..c7c7b51e9a 100755 --- a/scripts/test-remote-env.sh +++ b/scripts/test-remote-env.sh @@ -85,7 +85,7 @@ setup_remote_env() { return 1 fi export CODEX_TEST_REMOTE_EXEC_SERVER_PID="${remote_exec_server_pid}" - export CODEX_TEST_REMOTE_EXEC_SERVER_URL="ws://${container_ip}:${remote_exec_server_port}" +export CODEX_TEST_REMOTE_EXEC_SERVER_URL="ws://${container_ip}:${remote_exec_server_port}/ws" fi export CODEX_TEST_REMOTE_ENV="${container_name}"