diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 7b6e7448d4..18b811bc1d 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -446,7 +446,7 @@ struct AppServerCommand { #[derive(Debug, Parser)] struct ExecServerCommand { - /// Transport endpoint URL. Supported values: `ws://IP:PORT` (default). + /// Transport endpoint URL. Supported values: `ws://IP:PORT` (default), `stdio`, `stdio://`. #[arg( long = "listen", value_name = "URL", diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index 21eac6b4c5..71f4f31059 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -8,13 +8,9 @@ use tokio::sync::watch; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Message; -#[cfg(test)] use tokio::io::AsyncBufReadExt; -#[cfg(test)] use tokio::io::AsyncWriteExt; -#[cfg(test)] use tokio::io::BufReader; -#[cfg(test)] use tokio::io::BufWriter; pub(crate) const CHANNEL_CAPACITY: usize = 128; @@ -34,7 +30,6 @@ pub(crate) struct JsonRpcConnection { } impl JsonRpcConnection { - #[cfg(test)] pub(crate) fn from_stdio(reader: R, writer: W, connection_label: String) -> Self where R: AsyncRead + Unpin + Send + 'static, @@ -298,7 +293,6 @@ async fn send_malformed_message( .await; } -#[cfg(test)] async fn write_jsonrpc_line_message( writer: &mut BufWriter, message: &JSONRPCMessage, diff --git a/codex-rs/exec-server/src/server/transport.rs b/codex-rs/exec-server/src/server/transport.rs index b8a5a086b6..d284bf64bb 100644 --- a/codex-rs/exec-server/src/server/transport.rs +++ b/codex-rs/exec-server/src/server/transport.rs @@ -1,5 +1,8 @@ use std::io::Write as _; use std::net::SocketAddr; +use tokio::io; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; use tokio::net::TcpListener; use tokio_tungstenite::accept_async; use tracing::warn; @@ -10,6 +13,12 @@ use crate::server::processor::ConnectionProcessor; pub const DEFAULT_LISTEN_URL: &str = "ws://127.0.0.1:0"; +#[derive(Debug, Clone, Eq, PartialEq)] +pub(crate) enum ExecServerListenTransport { + WebSocket(SocketAddr), + Stdio, +} + #[derive(Debug, Clone, Eq, PartialEq)] pub enum ExecServerListenUrlParseError { UnsupportedListenUrl(String), @@ -21,7 +30,7 @@ impl std::fmt::Display for ExecServerListenUrlParseError { match self { ExecServerListenUrlParseError::UnsupportedListenUrl(listen_url) => write!( f, - "unsupported --listen URL `{listen_url}`; expected `ws://IP:PORT`" + "unsupported --listen URL `{listen_url}`; expected `ws://IP:PORT` or `stdio`" ), ExecServerListenUrlParseError::InvalidWebSocketListenUrl(listen_url) => write!( f, @@ -35,11 +44,18 @@ impl std::error::Error for ExecServerListenUrlParseError {} pub(crate) fn parse_listen_url( listen_url: &str, -) -> Result { +) -> Result { + if matches!(listen_url, "stdio" | "stdio://") { + return Ok(ExecServerListenTransport::Stdio); + } + if let Some(socket_addr) = listen_url.strip_prefix("ws://") { - return socket_addr.parse::().map_err(|_| { - ExecServerListenUrlParseError::InvalidWebSocketListenUrl(listen_url.to_string()) - }); + return socket_addr + .parse::() + .map(ExecServerListenTransport::WebSocket) + .map_err(|_| { + ExecServerListenUrlParseError::InvalidWebSocketListenUrl(listen_url.to_string()) + }); } Err(ExecServerListenUrlParseError::UnsupportedListenUrl( @@ -51,8 +67,39 @@ pub(crate) async fn run_transport( listen_url: &str, runtime_paths: ExecServerRuntimePaths, ) -> Result<(), Box> { - let bind_address = parse_listen_url(listen_url)?; - run_websocket_listener(bind_address, runtime_paths).await + match parse_listen_url(listen_url)? { + ExecServerListenTransport::WebSocket(bind_address) => { + run_websocket_listener(bind_address, runtime_paths).await + } + ExecServerListenTransport::Stdio => run_stdio_connection(runtime_paths).await, + } +} + +async fn run_stdio_connection( + runtime_paths: ExecServerRuntimePaths, +) -> Result<(), Box> { + run_stdio_connection_with_io(io::stdin(), io::stdout(), runtime_paths).await +} + +async fn run_stdio_connection_with_io( + reader: R, + writer: W, + runtime_paths: ExecServerRuntimePaths, +) -> Result<(), Box> +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + let processor = ConnectionProcessor::new(runtime_paths); + tracing::info!("codex-exec-server listening on stdio"); + processor + .run_connection(JsonRpcConnection::from_stdio( + reader, + writer, + "exec-server stdio".to_string(), + )) + .await; + Ok(()) } async fn run_websocket_listener( diff --git a/codex-rs/exec-server/src/server/transport_tests.rs b/codex-rs/exec-server/src/server/transport_tests.rs index bec91c936e..b9787d8a37 100644 --- a/codex-rs/exec-server/src/server/transport_tests.rs +++ b/codex-rs/exec-server/src/server/transport_tests.rs @@ -1,31 +1,127 @@ use std::net::SocketAddr; +use std::time::Duration; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCRequest; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; use pretty_assertions::assert_eq; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::io::duplex; +use tokio::time::timeout; use super::DEFAULT_LISTEN_URL; +use super::ExecServerListenTransport; use super::parse_listen_url; +use super::run_stdio_connection_with_io; +use crate::ExecServerRuntimePaths; +use crate::protocol::INITIALIZE_METHOD; +use crate::protocol::INITIALIZED_METHOD; +use crate::protocol::InitializeParams; +use crate::protocol::InitializeResponse; #[test] fn parse_listen_url_accepts_default_websocket_url() { - let bind_address = - parse_listen_url(DEFAULT_LISTEN_URL).expect("default listen URL should parse"); + let transport = parse_listen_url(DEFAULT_LISTEN_URL).expect("default listen URL should parse"); assert_eq!( - bind_address, - "127.0.0.1:0" - .parse::() - .expect("valid socket address") + transport, + ExecServerListenTransport::WebSocket( + "127.0.0.1:0" + .parse::() + .expect("valid socket address") + ) ); } +#[test] +fn parse_listen_url_accepts_stdio() { + let transport = parse_listen_url("stdio").expect("stdio listen URL should parse"); + assert_eq!(transport, ExecServerListenTransport::Stdio); +} + +#[test] +fn parse_listen_url_accepts_stdio_url() { + let transport = parse_listen_url("stdio://").expect("stdio listen URL should parse"); + assert_eq!(transport, ExecServerListenTransport::Stdio); +} + +#[tokio::test] +async fn stdio_listen_transport_serves_initialize() { + let transport = parse_listen_url("stdio").expect("stdio listen URL should parse"); + let ExecServerListenTransport::Stdio = transport else { + panic!("expected stdio listen transport, got {transport:?}"); + }; + + let (mut client_writer, server_reader) = duplex(1 << 20); + let (server_writer, client_reader) = duplex(1 << 20); + let server_task = tokio::spawn(run_stdio_connection_with_io( + server_reader, + server_writer, + test_runtime_paths(), + )); + let mut client_lines = BufReader::new(client_reader).lines(); + + let initialize = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: INITIALIZE_METHOD.to_string(), + params: Some( + serde_json::to_value(InitializeParams { + client_name: "exec-server-transport-test".to_string(), + resume_session_id: None, + }) + .expect("initialize params should serialize"), + ), + trace: None, + }); + write_jsonrpc_line(&mut client_writer, &initialize).await; + + let response = timeout(Duration::from_secs(1), client_lines.next_line()) + .await + .expect("initialize response should arrive") + .expect("initialize response read should succeed") + .expect("initialize response should be present"); + let response: JSONRPCMessage = + serde_json::from_str(&response).expect("initialize response should parse"); + let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else { + panic!("expected initialize response, got {response:?}"); + }; + assert_eq!(id, RequestId::Integer(1)); + let initialize_response: InitializeResponse = + serde_json::from_value(result).expect("initialize response should decode"); + assert!( + !initialize_response.session_id.is_empty(), + "initialize should return a session id" + ); + + let initialized = JSONRPCMessage::Notification(JSONRPCNotification { + method: INITIALIZED_METHOD.to_string(), + params: Some(serde_json::to_value(()).expect("initialized params should serialize")), + }); + write_jsonrpc_line(&mut client_writer, &initialized).await; + + drop(client_writer); + drop(client_lines); + timeout(Duration::from_secs(1), server_task) + .await + .expect("stdio transport should finish after client disconnect") + .expect("stdio transport task should join") + .expect("stdio transport should not fail"); +} + #[test] fn parse_listen_url_accepts_websocket_url() { - let bind_address = + let transport = parse_listen_url("ws://127.0.0.1:1234").expect("websocket listen URL should parse"); assert_eq!( - bind_address, - "127.0.0.1:1234" - .parse::() - .expect("valid socket address") + transport, + ExecServerListenTransport::WebSocket( + "127.0.0.1:1234" + .parse::() + .expect("valid socket address") + ) ); } @@ -45,6 +141,26 @@ fn parse_listen_url_rejects_unsupported_url() { parse_listen_url("http://127.0.0.1:1234").expect_err("unsupported scheme should fail"); assert_eq!( err.to_string(), - "unsupported --listen URL `http://127.0.0.1:1234`; expected `ws://IP:PORT`" + "unsupported --listen URL `http://127.0.0.1:1234`; expected `ws://IP:PORT` or `stdio`" ); } + +async fn write_jsonrpc_line(writer: &mut tokio::io::DuplexStream, message: &JSONRPCMessage) { + let encoded = serde_json::to_vec(message).expect("JSON-RPC message should serialize"); + writer + .write_all(&encoded) + .await + .expect("JSON-RPC message should write"); + writer + .write_all(b"\n") + .await + .expect("JSON-RPC newline should write"); +} + +fn test_runtime_paths() -> ExecServerRuntimePaths { + ExecServerRuntimePaths::new( + std::env::current_exe().expect("current exe"), + /*codex_linux_sandbox_exe*/ None, + ) + .expect("runtime paths") +}