diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 47359393d3..0729be151d 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -17,13 +17,14 @@ use tokio::sync::mpsc; use tokio::sync::watch; use tokio::time::timeout; -use tokio_tungstenite::connect_async; use tracing::debug; use crate::ProcessId; use crate::client_api::ExecServerClientConnectOptions; +use crate::client_api::ExecServerTransport; use crate::client_api::HttpClient; use crate::client_api::RemoteExecServerConnectArgs; +use crate::client_api::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; use crate::process::ExecProcessEvent; use crate::process::ExecProcessEventLog; @@ -105,6 +106,16 @@ impl From for ExecServerClientConnectOptions { } } +impl From for ExecServerClientConnectOptions { + fn from(value: StdioExecServerConnectArgs) -> Self { + Self { + client_name: value.client_name, + initialize_timeout: value.initialize_timeout, + resume_session_id: value.resume_session_id, + } + } +} + impl RemoteExecServerConnectArgs { pub fn new(websocket_url: String, client_name: String) -> Self { Self { @@ -180,29 +191,23 @@ pub struct ExecServerClient { #[derive(Clone)] pub(crate) struct LazyRemoteExecServerClient { - websocket_url: String, + transport: ExecServerTransport, client: Arc>, } impl LazyRemoteExecServerClient { - pub(crate) fn new(websocket_url: String) -> Self { + pub(crate) fn new(transport: ExecServerTransport) -> Self { Self { - websocket_url, + transport, client: Arc::new(OnceCell::new()), } } pub(crate) async fn get(&self) -> Result { self.client - .get_or_try_init(|| async { - ExecServerClient::connect_websocket(RemoteExecServerConnectArgs { - websocket_url: self.websocket_url.clone(), - client_name: "codex-environment".to_string(), - connect_timeout: Duration::from_secs(5), - initialize_timeout: Duration::from_secs(5), - resume_session_id: None, - }) - .await + .get_or_try_init(|| { + let transport = self.transport.clone(); + async move { transport.connect_for_environment().await } }) .await .cloned() @@ -269,32 +274,6 @@ pub enum ExecServerError { } impl ExecServerClient { - pub async fn connect_websocket( - args: RemoteExecServerConnectArgs, - ) -> Result { - let websocket_url = args.websocket_url.clone(); - let connect_timeout = args.connect_timeout; - let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str())) - .await - .map_err(|_| ExecServerError::WebSocketConnectTimeout { - url: websocket_url.clone(), - timeout: connect_timeout, - })? - .map_err(|source| ExecServerError::WebSocketConnect { - url: websocket_url.clone(), - source, - })?; - - Self::connect( - JsonRpcConnection::from_websocket( - stream, - format!("exec-server websocket {websocket_url}"), - ), - args.into(), - ) - .await - } - pub async fn initialize( &self, options: ExecServerClientConnectOptions, @@ -443,7 +422,7 @@ impl ExecServerClient { .clone() } - async fn connect( + pub(crate) async fn connect( connection: JsonRpcConnection, options: ExecServerClientConnectOptions, ) -> Result { @@ -905,6 +884,7 @@ mod tests { use super::ExecServerClient; use super::ExecServerClientConnectOptions; use crate::ProcessId; + use crate::client_api::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; use crate::process::ExecProcessEvent; use crate::protocol::EXEC_CLOSED_METHOD; @@ -942,6 +922,21 @@ mod tests { .expect("json-rpc line should write"); } + #[cfg(not(windows))] + #[tokio::test] + async fn connect_stdio_command_initializes_json_rpc_client() { + let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { + shell_command: "read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(), + client_name: "stdio-test-client".to_string(), + initialize_timeout: Duration::from_secs(1), + resume_session_id: None, + }) + .await + .expect("stdio client should connect"); + + assert_eq!(client.session_id().as_deref(), Some("stdio-test")); + } + #[tokio::test] async fn process_events_are_delivered_in_seq_order_when_notifications_are_reordered() { let (client_stdin, server_reader) = duplex(1 << 20); diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs index b1761b69f1..95ed053476 100644 --- a/codex-rs/exec-server/src/client_api.rs +++ b/codex-rs/exec-server/src/client_api.rs @@ -25,6 +25,22 @@ pub struct RemoteExecServerConnectArgs { pub resume_session_id: Option, } +/// Stdio connection arguments for a command-backed exec-server. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StdioExecServerConnectArgs { + pub shell_command: String, + pub client_name: String, + pub initialize_timeout: Duration, + pub resume_session_id: Option, +} + +/// Transport used to connect to a remote exec-server environment. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ExecServerTransport { + WebSocketUrl(String), + StdioShellCommand(String), +} + /// Sends HTTP requests through a runtime-selected transport. /// /// This is the HTTP capability counterpart to [`crate::ExecBackend`]. Callers diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs new file mode 100644 index 0000000000..908a9e1b05 --- /dev/null +++ b/codex-rs/exec-server/src/client_transport.rs @@ -0,0 +1,176 @@ +use std::process::Stdio; +use std::time::Duration; + +use tokio::io::AsyncBufReadExt; +use tokio::io::BufReader; +use tokio::process::Child; +use tokio::process::Command; +use tokio::runtime::Handle; +use tokio::time::timeout; +use tokio_tungstenite::connect_async; +use tracing::debug; +use tracing::warn; + +use crate::ExecServerClient; +use crate::ExecServerError; +use crate::client_api::ExecServerTransport; +use crate::client_api::RemoteExecServerConnectArgs; +use crate::client_api::StdioExecServerConnectArgs; +use crate::connection::JsonRpcConnection; + +const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment"; +const ENVIRONMENT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); +const ENVIRONMENT_INITIALIZE_TIMEOUT: Duration = Duration::from_secs(5); + +impl ExecServerTransport { + pub(crate) async fn connect_for_environment(self) -> Result { + match self { + ExecServerTransport::WebSocketUrl(websocket_url) => { + ExecServerClient::connect_websocket(RemoteExecServerConnectArgs { + websocket_url, + client_name: ENVIRONMENT_CLIENT_NAME.to_string(), + connect_timeout: ENVIRONMENT_CONNECT_TIMEOUT, + initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT, + resume_session_id: None, + }) + .await + } + ExecServerTransport::StdioShellCommand(shell_command) => { + ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { + shell_command, + client_name: ENVIRONMENT_CLIENT_NAME.to_string(), + initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT, + resume_session_id: None, + }) + .await + } + } + } +} + +impl ExecServerClient { + pub async fn connect_websocket( + args: RemoteExecServerConnectArgs, + ) -> Result { + let websocket_url = args.websocket_url.clone(); + let connect_timeout = args.connect_timeout; + let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str())) + .await + .map_err(|_| ExecServerError::WebSocketConnectTimeout { + url: websocket_url.clone(), + timeout: connect_timeout, + })? + .map_err(|source| ExecServerError::WebSocketConnect { + url: websocket_url.clone(), + source, + })?; + + Self::connect( + JsonRpcConnection::from_websocket( + stream, + format!("exec-server websocket {websocket_url}"), + ), + args.into(), + ) + .await + } + + pub async fn connect_stdio_command( + args: StdioExecServerConnectArgs, + ) -> Result { + let shell_command = args.shell_command.clone(); + let mut child = shell_command_process(&shell_command) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(ExecServerError::Spawn)?; + + let stdin = child.stdin.take().ok_or_else(|| { + ExecServerError::Protocol("spawned exec-server command has no stdin".to_string()) + })?; + let stdout = child.stdout.take().ok_or_else(|| { + ExecServerError::Protocol("spawned exec-server command has no stdout".to_string()) + })?; + if let Some(stderr) = child.stderr.take() { + tokio::spawn(async move { + let mut lines = BufReader::new(stderr).lines(); + loop { + match lines.next_line().await { + Ok(Some(line)) => debug!("exec-server stdio stderr: {line}"), + Ok(None) => break, + Err(err) => { + warn!("failed to read exec-server stdio stderr: {err}"); + break; + } + } + } + }); + } + + Self::connect( + JsonRpcConnection::from_stdio( + stdout, + stdin, + format!("exec-server stdio command `{shell_command}`"), + ) + .with_lifetime_guard(Box::new(StdioChildGuard { child: Some(child) })), + args.into(), + ) + .await + } +} + +struct StdioChildGuard { + child: Option, +} + +impl Drop for StdioChildGuard { + fn drop(&mut self) { + let Some(child) = self.child.take() else { + return; + }; + + match Handle::try_current() { + Ok(handle) => { + let _terminate_task = handle.spawn(terminate_stdio_child(child)); + } + Err(_) => { + terminate_stdio_child_now(child); + } + } + } +} + +async fn terminate_stdio_child(mut child: Child) { + kill_stdio_child(&mut child); + if let Err(err) = child.wait().await { + debug!("failed to wait for exec-server stdio child: {err}"); + } +} + +fn terminate_stdio_child_now(mut child: Child) { + kill_stdio_child(&mut child); +} + +fn kill_stdio_child(child: &mut Child) { + if let Err(err) = child.start_kill() { + debug!("failed to terminate exec-server stdio child: {err}"); + } +} + +fn shell_command_process(shell_command: &str) -> Command { + #[cfg(windows)] + { + let mut command = Command::new("cmd"); + command.arg("/C").arg(shell_command); + command + } + + #[cfg(not(windows))] + { + let mut command = Command::new("sh"); + command.arg("-lc").arg(shell_command); + command + } +} diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index 71f4f31059..f79ac3da57 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -15,6 +15,8 @@ use tokio::io::BufWriter; pub(crate) const CHANNEL_CAPACITY: usize = 128; +pub(crate) type JsonRpcConnectionLifetimeGuard = Box; + #[derive(Debug)] pub(crate) enum JsonRpcConnectionEvent { Message(JSONRPCMessage), @@ -27,6 +29,7 @@ pub(crate) struct JsonRpcConnection { incoming_rx: mpsc::Receiver, disconnected_rx: watch::Receiver, task_handles: Vec>, + lifetime_guard: Option, } impl JsonRpcConnection { @@ -117,6 +120,7 @@ impl JsonRpcConnection { incoming_rx, disconnected_rx, task_handles: vec![reader_task, writer_task], + lifetime_guard: None, } } @@ -251,9 +255,15 @@ impl JsonRpcConnection { incoming_rx, disconnected_rx, task_handles: vec![reader_task, writer_task], + lifetime_guard: None, } } + pub(crate) fn with_lifetime_guard(mut self, guard: JsonRpcConnectionLifetimeGuard) -> Self { + self.lifetime_guard = Some(guard); + self + } + pub(crate) fn into_parts( self, ) -> ( @@ -261,12 +271,14 @@ impl JsonRpcConnection { mpsc::Receiver, watch::Receiver, Vec>, + Option, ) { ( self.outgoing_tx, self.incoming_rx, self.disconnected_rx, self.task_handles, + self.lifetime_guard, ) } } diff --git a/codex-rs/exec-server/src/environment.rs b/codex-rs/exec-server/src/environment.rs index 855989dafb..3764b29fe5 100644 --- a/codex-rs/exec-server/src/environment.rs +++ b/codex-rs/exec-server/src/environment.rs @@ -7,6 +7,7 @@ use crate::ExecutorFileSystem; use crate::HttpClient; use crate::client::LazyRemoteExecServerClient; use crate::client::http_client::ReqwestHttpClient; +use crate::client_api::ExecServerTransport; use crate::environment_provider::DefaultEnvironmentProvider; use crate::environment_provider::EnvironmentProvider; use crate::environment_provider::normalize_exec_server_url; @@ -274,7 +275,9 @@ impl Environment { exec_server_url: String, local_runtime_paths: Option, ) -> Self { - let client = LazyRemoteExecServerClient::new(exec_server_url.clone()); + let client = LazyRemoteExecServerClient::new(ExecServerTransport::WebSocketUrl( + exec_server_url.clone(), + )); let exec_backend: Arc = Arc::new(RemoteProcess::new(client.clone())); let filesystem: Arc = Arc::new(RemoteFileSystem::new(client.clone())); diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index d860d59aba..9bec4ed1da 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -1,5 +1,6 @@ mod client; mod client_api; +mod client_transport; mod connection; mod environment; mod environment_provider; @@ -24,8 +25,10 @@ pub use client::ExecServerError; pub use client::http_client::HttpResponseBodyStream; pub use client::http_client::ReqwestHttpClient; pub use client_api::ExecServerClientConnectOptions; +pub use client_api::ExecServerTransport; pub use client_api::HttpClient; pub use client_api::RemoteExecServerConnectArgs; +pub use client_api::StdioExecServerConnectArgs; pub use codex_file_system::CopyOptions; pub use codex_file_system::CreateDirectoryOptions; pub use codex_file_system::ExecutorFileSystem; diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 723b99f502..2cce5c0400 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -23,6 +23,7 @@ use tokio::task::JoinHandle; use crate::connection::JsonRpcConnection; use crate::connection::JsonRpcConnectionEvent; +use crate::connection::JsonRpcConnectionLifetimeGuard; #[derive(Debug)] pub(crate) enum RpcCallError { @@ -229,12 +230,14 @@ pub(crate) struct RpcClient { disconnected_rx: watch::Receiver, next_request_id: AtomicI64, transport_tasks: Vec>, + _transport_lifetime_guard: Option, reader_task: JoinHandle<()>, } impl RpcClient { pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver) { - let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = connection.into_parts(); + let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks, lifetime_guard) = + connection.into_parts(); let pending = Arc::new(Mutex::new(HashMap::::new())); let (event_tx, event_rx) = mpsc::channel(128); @@ -275,6 +278,7 @@ impl RpcClient { disconnected_rx, next_request_id: AtomicI64::new(1), transport_tasks, + _transport_lifetime_guard: lifetime_guard, reader_task, }, event_rx, diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs index dc1a9b9ffe..50907f7e42 100644 --- a/codex-rs/exec-server/src/server/processor.rs +++ b/codex-rs/exec-server/src/server/processor.rs @@ -47,7 +47,7 @@ async fn run_connection( runtime_paths: ExecServerRuntimePaths, ) { let router = Arc::new(build_router()); - let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) = + let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks, _lifetime_guard) = connection.into_parts(); let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY);