diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 9261a59542..07a1d31eb5 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -12,7 +12,7 @@ use futures::FutureExt; use futures::future::BoxFuture; use serde_json::Value; use tokio::sync::Mutex; -use tokio::sync::OnceCell; +use tokio::sync::Notify; use tokio::sync::mpsc; use tokio::sync::watch; @@ -25,6 +25,7 @@ use crate::client_api::ExecServerTransportParams; use crate::client_api::HttpClient; use crate::client_api::RemoteExecServerConnectArgs; use crate::client_api::StdioExecServerConnectArgs; +use crate::client_transport::ENVIRONMENT_CLIENT_NAME; use crate::connection::JsonRpcConnection; use crate::process::ExecProcessEvent; use crate::process::ExecProcessEventLog; @@ -128,11 +129,12 @@ impl RemoteExecServerConnectArgs { } } -pub(crate) struct SessionState { +pub(crate) struct ProcessSession { wake_tx: watch::Sender, events: ExecProcessEventLog, ordered_events: StdMutex, failure: Mutex>, + disconnect_behavior: ProcessSessionDisconnectBehavior, } #[derive(Default)] @@ -145,30 +147,46 @@ struct OrderedSessionEvents { } #[derive(Clone)] -pub(crate) struct Session { - client: ExecServerClient, +pub(crate) struct ProcessSessionHandle { + control: ProcessSessionControl, process_id: ProcessId, - state: Arc, + session: Arc, +} + +#[derive(Clone)] +enum ProcessSessionControl { + #[cfg(test)] + // Direct connections are used by one-shot callers and focused client tests. + Connection(ExecServerConnection), + // Remote environments use the logical client so process sessions survive + // connection replacement across reconnect. + RemoteClient(RemoteExecServerClient), +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum ProcessSessionDisconnectBehavior { + Fail, + Preserve, } struct Inner { - client: RpcClient, + rpc_client: RpcClient, // The remote transport delivers one shared notification stream for every - // process on the connection. Keep a local process_id -> session registry so - // we can turn those connection-global notifications into process wakeups + // process on the connection. Keep a local process_id -> session route map + // so we can turn those connection-global notifications into process wakeups // without making notifications the source of truth for output delivery. - sessions: ArcSwap>>, + process_session_routes: ArcSwap>>, // ArcSwap makes reads cheap on the hot notification path, but writes still // need serialization so concurrent register/remove operations do not // overwrite each other's copy-on-write updates. - sessions_write_lock: Mutex<()>, + process_session_routes_write_lock: Mutex<()>, // Once the transport closes, every executor operation should fail quickly - // with the same canonical message. This client never reconnects, so the - // latch only moves from unset to set once. + // with the same canonical message. This connection never reconnects, so + // the latch only moves from unset to set once. disconnected: OnceLock, // Streaming HTTP responses are keyed by a client-generated request id // because they share the same connection-global notification channel as - // process output. Keep the routing table local to the client so higher + // process output. Keep the routing table local to the connection so higher // layers can consume body chunks like a normal byte stream. http_body_streams: ArcSwap>>, http_body_stream_failures: ArcSwap>, @@ -185,43 +203,273 @@ impl Drop for Inner { } #[derive(Clone)] -pub struct ExecServerClient { +pub struct ExecServerConnection { inner: Arc, } #[derive(Clone)] -pub(crate) struct LazyRemoteExecServerClient { +pub(crate) struct RemoteExecServerClient { transport_params: ExecServerTransportParams, - client: Arc>, + session: Arc>, } -impl LazyRemoteExecServerClient { +// Shared state for one logical remote exec-server client. The logical client +// owns resumable session identity and durable process session state; individual +// ExecServerConnection values only bind that state to one live transport. +struct RemoteExecServerSession { + connection: Option, + connection_attempt: Option>, + logical_session_id: Option, + terminal_error: Option, + process_sessions: HashMap>, +} + +enum RemoteExecServerConnectionAction { + Ready(ExecServerConnection), + Wait(BoxFuture<'static, ()>), + Connect { + connection_attempt: Arc, + resume_session_id: Option, + process_sessions: Vec<(ProcessId, Arc)>, + }, +} + +#[derive(Clone)] +struct TerminalReconnectError { + code: i64, + message: String, +} + +impl RemoteExecServerClient { pub(crate) fn new(transport_params: ExecServerTransportParams) -> Self { Self { transport_params, - client: Arc::new(OnceCell::new()), + session: Arc::new(StdMutex::new(RemoteExecServerSession { + connection: None, + connection_attempt: None, + logical_session_id: None, + terminal_error: None, + process_sessions: HashMap::new(), + })), } } - pub(crate) async fn get(&self) -> Result { - self.client - // TODO: Add reconnect/disconnect handling here instead of reusing - // the first successfully initialized connection forever. - .get_or_try_init(|| { - let transport_params = self.transport_params.clone(); - async move { ExecServerClient::connect_for_transport(transport_params).await } - }) + pub(crate) async fn connection(&self) -> Result { + loop { + match self.next_connection_action()? { + RemoteExecServerConnectionAction::Ready(connection) => return Ok(connection), + RemoteExecServerConnectionAction::Wait(connection_attempt) => { + connection_attempt.await; + } + RemoteExecServerConnectionAction::Connect { + connection_attempt, + resume_session_id, + process_sessions, + } => { + let connection = self + .connect_and_rebind(resume_session_id.clone(), process_sessions) + .await; + return self.finish_connection_attempt( + connection_attempt, + resume_session_id, + connection, + ); + } + } + } + } + + pub(crate) async fn register_process_session( + &self, + process_id: &ProcessId, + ) -> Result { + let process_session = Arc::new(ProcessSession::new( + ProcessSessionDisconnectBehavior::Preserve, + )); + { + let mut session = self.lock_session(); + if session.process_sessions.contains_key(process_id) { + return Err(ExecServerError::Protocol(format!( + "session already registered for process {process_id}" + ))); + } + session + .process_sessions + .insert(process_id.clone(), Arc::clone(&process_session)); + } + + let connection = self.connection().await?; + if let Err(err) = connection + .register_process_session_route(process_id, Arc::clone(&process_session)) .await - .cloned() + { + self.unregister_process_session(process_id).await; + return Err(err); + } + + Ok(ProcessSessionHandle { + control: ProcessSessionControl::RemoteClient(self.clone()), + process_id: process_id.clone(), + session: process_session, + }) + } + + async fn read(&self, params: ReadParams) -> Result { + let connection = self.connection().await?; + match connection.read(params.clone()).await { + Ok(response) => Ok(response), + Err(err) if is_transport_closed_error(&err) && self.supports_reconnect() => { + self.connection().await?.read(params).await + } + Err(err) => Err(err), + } + } + + async fn write( + &self, + process_id: &ProcessId, + chunk: Vec, + ) -> Result { + self.connection().await?.write(process_id, chunk).await + } + + async fn terminate( + &self, + process_id: &ProcessId, + ) -> Result { + self.connection().await?.terminate(process_id).await + } + + async fn unregister_process_session(&self, process_id: &ProcessId) { + let connection = { + let mut session = self.lock_session(); + session.process_sessions.remove(process_id); + session.connection.clone() + }; + if let Some(connection) = connection { + connection.unregister_process_session(process_id).await; + } + } + + fn next_connection_action(&self) -> Result { + let mut session = self.lock_session(); + if let Some(error) = &session.terminal_error { + return Err(error.to_exec_server_error()); + } + + if let Some(connection) = &session.connection { + if let Some(error) = connection.disconnected_error() { + if !self.supports_reconnect() { + return Err(error); + } + } else { + return Ok(RemoteExecServerConnectionAction::Ready(connection.clone())); + } + } + + if let Some(connection_attempt) = &session.connection_attempt { + let connection_attempt = Arc::clone(connection_attempt).notified_owned(); + return Ok(RemoteExecServerConnectionAction::Wait( + connection_attempt.boxed(), + )); + } + + let connection_attempt = Arc::new(Notify::new()); + let resume_session_id = session.logical_session_id.clone(); + let process_sessions = session + .process_sessions + .iter() + .map(|(process_id, process_session)| (process_id.clone(), Arc::clone(process_session))) + .collect(); + session.connection_attempt = Some(Arc::clone(&connection_attempt)); + Ok(RemoteExecServerConnectionAction::Connect { + connection_attempt, + resume_session_id, + process_sessions, + }) + } + + async fn connect_and_rebind( + &self, + resume_session_id: Option, + process_sessions: Vec<(ProcessId, Arc)>, + ) -> Result { + let connection = self.connect(resume_session_id).await?; + for (process_id, process_session) in process_sessions { + connection + .register_process_session_route(&process_id, process_session) + .await?; + } + Ok(connection) + } + + fn finish_connection_attempt( + &self, + connection_attempt: Arc, + resume_session_id: Option, + connection: Result, + ) -> Result { + let mut session = self.lock_session(); + if let Err(err) = &connection { + if resume_session_id.is_some() + && let Some(terminal_error) = TerminalReconnectError::from_error(err) + { + session.terminal_error = Some(terminal_error); + } + } else if let Ok(connection) = &connection { + session.logical_session_id = connection.session_id(); + session.connection = Some(connection.clone()); + } + session.connection_attempt = None; + connection_attempt.notify_waiters(); + connection + } + + fn lock_session(&self) -> std::sync::MutexGuard<'_, RemoteExecServerSession> { + self.session + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + async fn connect( + &self, + resume_session_id: Option, + ) -> Result { + match &self.transport_params { + ExecServerTransportParams::WebSocketUrl { + websocket_url, + connect_timeout, + initialize_timeout, + } => { + ExecServerConnection::connect_websocket(RemoteExecServerConnectArgs { + websocket_url: websocket_url.clone(), + client_name: ENVIRONMENT_CLIENT_NAME.to_string(), + connect_timeout: *connect_timeout, + initialize_timeout: *initialize_timeout, + resume_session_id, + }) + .await + } + ExecServerTransportParams::StdioCommand { .. } => { + ExecServerConnection::connect_for_transport(self.transport_params.clone()).await + } + } + } + + fn supports_reconnect(&self) -> bool { + matches!( + &self.transport_params, + ExecServerTransportParams::WebSocketUrl { .. } + ) } } -impl HttpClient for LazyRemoteExecServerClient { +impl HttpClient for RemoteExecServerClient { fn http_request( &self, params: crate::HttpRequestParams, ) -> BoxFuture<'_, Result> { - async move { self.get().await?.http_request(params).await }.boxed() + async move { self.connection().await?.http_request(params).await }.boxed() } fn http_request_stream( @@ -231,7 +479,7 @@ impl HttpClient for LazyRemoteExecServerClient { '_, Result<(crate::HttpRequestResponse, crate::HttpResponseBodyStream), ExecServerError>, > { - async move { self.get().await?.http_request_stream(params).await }.boxed() + async move { self.connection().await?.http_request_stream(params).await }.boxed() } } @@ -275,7 +523,7 @@ pub enum ExecServerError { ExecutorRegistryRequest(#[from] reqwest::Error), } -impl ExecServerClient { +impl ExecServerConnection { pub async fn initialize( &self, options: ExecServerClientConnectOptions, @@ -289,7 +537,7 @@ impl ExecServerClient { timeout(initialize_timeout, async { let response: InitializeResponse = self .inner - .client + .rpc_client .call( INITIALIZE_METHOD, &InitializeParams { @@ -397,23 +645,33 @@ impl ExecServerClient { self.call(FS_COPY_METHOD, ¶ms).await } - pub(crate) async fn register_session( + #[cfg(test)] + pub(crate) async fn register_process_session( &self, process_id: &ProcessId, - ) -> Result { - let state = Arc::new(SessionState::new()); - self.inner - .insert_session(process_id, Arc::clone(&state)) + ) -> Result { + let session = Arc::new(ProcessSession::new(ProcessSessionDisconnectBehavior::Fail)); + self.register_process_session_route(process_id, Arc::clone(&session)) .await?; - Ok(Session { - client: self.clone(), + Ok(ProcessSessionHandle { + control: ProcessSessionControl::Connection(self.clone()), process_id: process_id.clone(), - state, + session, }) } - pub(crate) async fn unregister_session(&self, process_id: &ProcessId) { - self.inner.remove_session(process_id).await; + async fn register_process_session_route( + &self, + process_id: &ProcessId, + session: Arc, + ) -> Result<(), ExecServerError> { + self.inner + .insert_process_session_route(process_id, session) + .await + } + + pub(crate) async fn unregister_process_session(&self, process_id: &ProcessId) { + self.inner.remove_process_session_route(process_id).await; } pub fn session_id(&self) -> Option { @@ -424,6 +682,10 @@ impl ExecServerClient { .clone() } + fn disconnected_error(&self) -> Option { + self.inner.disconnected_error() + } + pub(crate) async fn connect( connection: JsonRpcConnection, options: ExecServerClientConnectOptions, @@ -462,9 +724,9 @@ impl ExecServerClient { }); Inner { - client: rpc_client, - sessions: ArcSwap::from_pointee(HashMap::new()), - sessions_write_lock: Mutex::new(()), + rpc_client, + process_session_routes: ArcSwap::from_pointee(HashMap::new()), + process_session_routes_write_lock: Mutex::new(()), disconnected: OnceLock::new(), http_body_streams: ArcSwap::from_pointee(HashMap::new()), http_body_stream_failures: ArcSwap::from_pointee(HashMap::new()), @@ -475,14 +737,14 @@ impl ExecServerClient { } }); - let client = Self { inner }; - client.initialize(options).await?; - Ok(client) + let connection = Self { inner }; + connection.initialize(options).await?; + Ok(connection) } async fn notify_initialized(&self) -> Result<(), ExecServerError> { self.inner - .client + .rpc_client .notify(INITIALIZED_METHOD, &serde_json::json!({})) .await .map_err(ExecServerError::Json) @@ -500,13 +762,13 @@ impl ExecServerClient { return Err(error); } - match self.inner.client.call(method, params).await { + match self.inner.rpc_client.call(method, params).await { Ok(response) => Ok(response), Err(error) => { let error = ExecServerError::from(error); if is_transport_closed_error(&error) { // A call can race with disconnect after the preflight - // check. Only the reader task drains sessions so queued + // check. Only the reader task drains routes so queued // process notifications stay ordered before disconnect. let message = disconnected_message(/*reason*/ None); let message = record_disconnected(&self.inner, message); @@ -532,8 +794,8 @@ impl From for ExecServerError { } } -impl SessionState { - fn new() -> Self { +impl ProcessSession { + fn new(disconnect_behavior: ProcessSessionDisconnectBehavior) -> Self { let (wake_tx, _wake_rx) = watch::channel(0); Self { wake_tx, @@ -543,6 +805,7 @@ impl SessionState { ), ordered_events: StdMutex::new(OrderedSessionEvents::default()), failure: Mutex::new(None), + disconnect_behavior, } } @@ -638,17 +901,17 @@ impl SessionState { } } -impl Session { +impl ProcessSessionHandle { pub(crate) fn process_id(&self) -> &ProcessId { &self.process_id } pub(crate) fn subscribe_wake(&self) -> watch::Receiver { - self.state.subscribe() + self.session.subscribe() } pub(crate) fn subscribe_events(&self) -> ExecProcessEventReceiver { - self.state.subscribe_events() + self.session.subscribe_events() } pub(crate) async fn read( @@ -657,41 +920,104 @@ impl Session { max_bytes: Option, wait_ms: Option, ) -> Result { - if let Some(response) = self.state.failed_response().await { + if let Some(response) = self.session.failed_response().await { return Ok(response); } - match self - .client - .read(ReadParams { - process_id: self.process_id.clone(), - after_seq, - max_bytes, - wait_ms, - }) - .await - { + let params = ReadParams { + process_id: self.process_id.clone(), + after_seq, + max_bytes, + wait_ms, + }; + match self.control.read(params).await { Ok(response) => Ok(response), - Err(err) if is_transport_closed_error(&err) => { + Err(err) + if is_transport_closed_error(&err) + && self.session.disconnect_behavior + == ProcessSessionDisconnectBehavior::Fail => + { let message = disconnected_message(/*reason*/ None); - self.state.set_failure(message.clone()).await; - Ok(self.state.synthesized_failure(message)) + self.session.set_failure(message.clone()).await; + Ok(self.session.synthesized_failure(message)) } Err(err) => Err(err), } } pub(crate) async fn write(&self, chunk: Vec) -> Result { - self.client.write(&self.process_id, chunk).await + self.control.write(&self.process_id, chunk).await } pub(crate) async fn terminate(&self) -> Result<(), ExecServerError> { - self.client.terminate(&self.process_id).await?; + self.control.terminate(&self.process_id).await?; Ok(()) } pub(crate) async fn unregister(&self) { - self.client.unregister_session(&self.process_id).await; + self.control + .unregister_process_session(&self.process_id) + .await; + } +} + +impl ProcessSessionControl { + async fn read(&self, params: ReadParams) -> Result { + match self { + #[cfg(test)] + Self::Connection(connection) => connection.read(params).await, + Self::RemoteClient(client) => client.read(params).await, + } + } + + async fn write( + &self, + process_id: &ProcessId, + chunk: Vec, + ) -> Result { + match self { + #[cfg(test)] + Self::Connection(connection) => connection.write(process_id, chunk).await, + Self::RemoteClient(client) => client.write(process_id, chunk).await, + } + } + + async fn terminate( + &self, + process_id: &ProcessId, + ) -> Result { + match self { + #[cfg(test)] + Self::Connection(connection) => connection.terminate(process_id).await, + Self::RemoteClient(client) => client.terminate(process_id).await, + } + } + + async fn unregister_process_session(&self, process_id: &ProcessId) { + match self { + #[cfg(test)] + Self::Connection(connection) => connection.unregister_process_session(process_id).await, + Self::RemoteClient(client) => client.unregister_process_session(process_id).await, + } + } +} + +impl TerminalReconnectError { + fn from_error(error: &ExecServerError) -> Option { + match error { + ExecServerError::Server { code, message } if *code == -32600 => Some(Self { + code: *code, + message: message.clone(), + }), + _ => None, + } + } + + fn to_exec_server_error(&self) -> ExecServerError { + ExecServerError::Server { + code: self.code, + message: self.message.clone(), + } } } @@ -710,51 +1036,57 @@ impl Inner { } } - fn get_session(&self, process_id: &ProcessId) -> Option> { - self.sessions.load().get(process_id).cloned() + fn get_process_session_route(&self, process_id: &ProcessId) -> Option> { + self.process_session_routes.load().get(process_id).cloned() } - async fn insert_session( + async fn insert_process_session_route( &self, process_id: &ProcessId, - session: Arc, + session: Arc, ) -> Result<(), ExecServerError> { - let _sessions_write_guard = self.sessions_write_lock.lock().await; + let _routes_write_guard = self.process_session_routes_write_lock.lock().await; // Do not register a process session that can never receive executor // notifications. Without this check, remote MCP startup could create a // dead session and wait for process output that will never arrive. if let Some(error) = self.disconnected_error() { return Err(error); } - let sessions = self.sessions.load(); - if sessions.contains_key(process_id) { + let routes = self.process_session_routes.load(); + if let Some(existing_session) = routes.get(process_id) { + if Arc::ptr_eq(existing_session, &session) { + return Ok(()); + } return Err(ExecServerError::Protocol(format!( "session already registered for process {process_id}" ))); } - let mut next_sessions = sessions.as_ref().clone(); - next_sessions.insert(process_id.clone(), session); - self.sessions.store(Arc::new(next_sessions)); + let mut next_routes = routes.as_ref().clone(); + next_routes.insert(process_id.clone(), session); + self.process_session_routes.store(Arc::new(next_routes)); Ok(()) } - async fn remove_session(&self, process_id: &ProcessId) -> Option> { - let _sessions_write_guard = self.sessions_write_lock.lock().await; - let sessions = self.sessions.load(); - let session = sessions.get(process_id).cloned(); + async fn remove_process_session_route( + &self, + process_id: &ProcessId, + ) -> Option> { + let _routes_write_guard = self.process_session_routes_write_lock.lock().await; + let routes = self.process_session_routes.load(); + let session = routes.get(process_id).cloned(); session.as_ref()?; - let mut next_sessions = sessions.as_ref().clone(); - next_sessions.remove(process_id); - self.sessions.store(Arc::new(next_sessions)); + let mut next_routes = routes.as_ref().clone(); + next_routes.remove(process_id); + self.process_session_routes.store(Arc::new(next_routes)); session } - async fn take_all_sessions(&self) -> HashMap> { - let _sessions_write_guard = self.sessions_write_lock.lock().await; - let sessions = self.sessions.load(); - let drained_sessions = sessions.as_ref().clone(); - self.sessions.store(Arc::new(HashMap::new())); - drained_sessions + async fn take_all_process_session_routes(&self) -> HashMap> { + let _routes_write_guard = self.process_session_routes_write_lock.lock().await; + let routes = self.process_session_routes.load(); + let drained_routes = routes.as_ref().clone(); + self.process_session_routes.store(Arc::new(HashMap::new())); + drained_routes } } @@ -779,9 +1111,9 @@ fn is_transport_closed_error(error: &ExecServerError) -> bool { } fn record_disconnected(inner: &Arc, message: String) -> String { - // The first observer records the canonical disconnect reason. Session - // draining stays with the reader task so it can preserve notification - // ordering before publishing the terminal failure. + // The first observer records the canonical disconnect reason. Process + // session route draining stays with the reader task so it can preserve + // notification ordering before publishing the terminal failure. if let Some(message) = inner.set_disconnected(message.clone()) { message } else { @@ -789,20 +1121,22 @@ fn record_disconnected(inner: &Arc, message: String) -> String { } } -async fn fail_all_sessions(inner: &Arc, message: String) { - let sessions = inner.take_all_sessions().await; +async fn fail_all_process_sessions(inner: &Arc, message: String) { + let routes = inner.take_all_process_session_routes().await; - for (_, session) in sessions { - // Sessions synthesize a closed read response and emit a pushed Failed - // event. That covers both polling consumers and streaming consumers - // such as executor-backed MCP stdio. - session.set_failure(message.clone()).await; + for (_, session) in routes { + // One-shot sessions synthesize a closed read response and emit a + // pushed Failed event. Reconnecting remote sessions keep their local + // event state so a reattached client can bind them again. + if session.disconnect_behavior == ProcessSessionDisconnectBehavior::Fail { + session.set_failure(message.clone()).await; + } } } /// Fails all in-flight work that depends on the shared JSON-RPC transport. async fn fail_all_in_flight_work(inner: &Arc, message: String) { - fail_all_sessions(inner, message.clone()).await; + fail_all_process_sessions(inner, message.clone()).await; inner.fail_all_http_body_streams(message).await; } @@ -814,7 +1148,7 @@ async fn handle_server_notification( EXEC_OUTPUT_DELTA_METHOD => { let params: ExecOutputDeltaNotification = serde_json::from_value(notification.params.unwrap_or(Value::Null))?; - if let Some(session) = inner.get_session(¶ms.process_id) { + if let Some(session) = inner.get_process_session_route(¶ms.process_id) { session.note_change(params.seq); let published_closed = session.publish_ordered_event(ExecProcessEvent::Output(ProcessOutputChunk { @@ -823,28 +1157,28 @@ async fn handle_server_notification( chunk: params.chunk, })); if published_closed { - inner.remove_session(¶ms.process_id).await; + inner.remove_process_session_route(¶ms.process_id).await; } } } EXEC_EXITED_METHOD => { let params: ExecExitedNotification = serde_json::from_value(notification.params.unwrap_or(Value::Null))?; - if let Some(session) = inner.get_session(¶ms.process_id) { + if let Some(session) = inner.get_process_session_route(¶ms.process_id) { session.note_change(params.seq); let published_closed = session.publish_ordered_event(ExecProcessEvent::Exited { seq: params.seq, exit_code: params.exit_code, }); if published_closed { - inner.remove_session(¶ms.process_id).await; + inner.remove_process_session_route(¶ms.process_id).await; } } } EXEC_CLOSED_METHOD => { let params: ExecClosedNotification = serde_json::from_value(notification.params.unwrap_or(Value::Null))?; - if let Some(session) = inner.get_session(¶ms.process_id) { + if let Some(session) = inner.get_process_session_route(¶ms.process_id) { session.note_change(params.seq); // Closed is terminal, but it can arrive before tail output or // exited. Keep routing this process until the ordered publisher @@ -852,7 +1186,7 @@ async fn handle_server_notification( let published_closed = session.publish_ordered_event(ExecProcessEvent::Closed { seq: params.seq }); if published_closed { - inner.remove_session(¶ms.process_id).await; + inner.remove_process_session_route(¶ms.process_id).await; } } } @@ -891,8 +1225,8 @@ mod tests { use tokio::time::sleep; use tokio::time::timeout; - use super::ExecServerClient; use super::ExecServerClientConnectOptions; + use super::ExecServerConnection; use crate::ProcessId; #[cfg(not(windows))] use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT; @@ -940,7 +1274,7 @@ mod tests { #[cfg(not(windows))] #[tokio::test] async fn connect_stdio_command_initializes_json_rpc_client() { - let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { + let client = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs { command: StdioExecServerCommand { program: "sh".to_string(), args: vec![ @@ -963,7 +1297,7 @@ mod tests { #[cfg(not(windows))] #[tokio::test] async fn connect_for_transport_initializes_stdio_command() { - let client = ExecServerClient::connect_for_transport( + let client = ExecServerConnection::connect_for_transport( ExecServerTransportParams::StdioCommand { command: StdioExecServerCommand { program: "sh".to_string(), @@ -986,7 +1320,7 @@ mod tests { #[cfg(windows)] #[tokio::test] async fn connect_stdio_command_initializes_json_rpc_client_on_windows() { - let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { + let client = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs { command: StdioExecServerCommand { program: "powershell".to_string(), args: vec![ @@ -1024,7 +1358,7 @@ mod tests { shell_quote(child_pid_file.as_path()), ); - let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { + let client = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs { command: StdioExecServerCommand { program: "sh".to_string(), args: vec!["-c".to_string(), stdio_script], @@ -1067,7 +1401,7 @@ mod tests { shell_quote(pid_file.as_path()), ); - let result = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs { + let result = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs { command: StdioExecServerCommand { program: "sh".to_string(), args: vec!["-c".to_string(), stdio_script], @@ -1161,7 +1495,7 @@ mod tests { } }); - let client = ExecServerClient::connect( + let client = ExecServerConnection::connect( JsonRpcConnection::from_stdio( client_stdout, client_stdin, @@ -1174,7 +1508,7 @@ mod tests { let process_id = ProcessId::from("reordered"); let session = client - .register_session(&process_id) + .register_process_session(&process_id) .await .expect("session should register"); let mut events = session.subscribe_events(); @@ -1303,7 +1637,7 @@ mod tests { drop(server_writer); }); - let client = ExecServerClient::connect( + let client = ExecServerConnection::connect( JsonRpcConnection::from_stdio( client_stdout, client_stdin, @@ -1316,7 +1650,7 @@ mod tests { let process_id = ProcessId::from("disconnect"); let session = client - .register_session(&process_id) + .register_process_session(&process_id) .await .expect("session should register"); let mut events = session.subscribe_events(); @@ -1344,7 +1678,9 @@ mod tests { ); assert!(response.closed); - let new_session = client.register_session(&ProcessId::from("new")).await; + let new_session = client + .register_process_session(&ProcessId::from("new")) + .await; assert!(matches!( new_session, Err(super::ExecServerError::Disconnected(_)) @@ -1390,7 +1726,7 @@ mod tests { } }); - let client = ExecServerClient::connect( + let client = ExecServerConnection::connect( JsonRpcConnection::from_stdio( client_stdout, client_stdin, @@ -1404,11 +1740,11 @@ mod tests { let noisy_process_id = ProcessId::from("noisy"); let quiet_process_id = ProcessId::from("quiet"); let _noisy_session = client - .register_session(&noisy_process_id) + .register_process_session(&noisy_process_id) .await .expect("noisy session should register"); let quiet_session = client - .register_session(&quiet_process_id) + .register_process_session(&quiet_process_id) .await .expect("quiet session should register"); let mut quiet_wake_rx = quiet_session.subscribe_wake(); diff --git a/codex-rs/exec-server/src/client/http_client.rs b/codex-rs/exec-server/src/client/http_client.rs index cfbb3a60bb..03c9c26e80 100644 --- a/codex-rs/exec-server/src/client/http_client.rs +++ b/codex-rs/exec-server/src/client/http_client.rs @@ -3,7 +3,7 @@ //! This module is the facade for the environment-owned [`crate::HttpClient`] //! capability: //! - [`ReqwestHttpClient`] executes requests directly with `reqwest` -//! - [`ExecServerClient`] forwards requests over the JSON-RPC transport +//! - [`ExecServerConnection`] forwards requests over the JSON-RPC transport //! - [`HttpResponseBodyStream`] presents buffered local bodies and streamed //! remote `http/request/bodyDelta` notifications through one byte-stream API //! @@ -11,7 +11,7 @@ //! - orchestrator process: holds an `Arc` and chooses local or //! remote execution //! - remote runtime: serves the `http/request` RPC and runs the concrete local -//! HTTP request there when the orchestrator uses [`ExecServerClient`] +//! HTTP request there when the orchestrator uses [`ExecServerConnection`] #[path = "reqwest_http_client.rs"] mod reqwest_http_client; diff --git a/codex-rs/exec-server/src/client/rpc_http_client.rs b/codex-rs/exec-server/src/client/rpc_http_client.rs index d2ce842ca9..db7246c15a 100644 --- a/codex-rs/exec-server/src/client/rpc_http_client.rs +++ b/codex-rs/exec-server/src/client/rpc_http_client.rs @@ -14,7 +14,7 @@ use tokio::sync::mpsc; use super::HttpResponseBodyStream; use super::response_body_stream::HttpBodyStreamRegistration; use crate::HttpClient; -use crate::client::ExecServerClient; +use crate::client::ExecServerConnection; use crate::client::ExecServerError; use crate::protocol::HTTP_REQUEST_METHOD; use crate::protocol::HttpRequestParams; @@ -23,7 +23,7 @@ use crate::protocol::HttpRequestResponse; /// Maximum queued body frames per streamed HTTP response. const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256; -impl ExecServerClient { +impl ExecServerConnection { /// Performs an HTTP request and buffers the response body. pub async fn http_request( &self, @@ -67,14 +67,14 @@ impl ExecServerClient { } } -impl HttpClient for ExecServerClient { +impl HttpClient for ExecServerConnection { /// Orchestrator-side adapter that forwards buffered HTTP requests to the /// remote runtime over the shared JSON-RPC connection. fn http_request( &self, params: HttpRequestParams, ) -> BoxFuture<'_, Result> { - async move { ExecServerClient::http_request(self, params).await }.boxed() + async move { ExecServerConnection::http_request(self, params).await }.boxed() } /// Orchestrator-side adapter that forwards streamed HTTP requests to the @@ -83,6 +83,6 @@ impl HttpClient for ExecServerClient { &self, params: HttpRequestParams, ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> { - async move { ExecServerClient::http_request_stream(self, params).await }.boxed() + async move { ExecServerConnection::http_request_stream(self, params).await }.boxed() } } diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index 4bdc09a80e..0d1e519645 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -9,7 +9,7 @@ use tracing::warn; use codex_utils_rustls_provider::ensure_rustls_crypto_provider; -use crate::ExecServerClient; +use crate::ExecServerConnection; use crate::ExecServerError; use crate::client_api::RemoteExecServerConnectArgs; use crate::client_api::StdioExecServerCommand; @@ -17,9 +17,9 @@ use crate::client_api::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; use crate::relay::harness_connection_from_websocket; -const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment"; +pub(crate) const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment"; -impl ExecServerClient { +impl ExecServerConnection { pub(crate) async fn connect_for_transport( transport_params: crate::client_api::ExecServerTransportParams, ) -> Result { diff --git a/codex-rs/exec-server/src/environment.rs b/codex-rs/exec-server/src/environment.rs index 1d7d6e75f5..d179026ee1 100644 --- a/codex-rs/exec-server/src/environment.rs +++ b/codex-rs/exec-server/src/environment.rs @@ -6,7 +6,7 @@ use crate::ExecServerError; use crate::ExecServerRuntimePaths; use crate::ExecutorFileSystem; use crate::HttpClient; -use crate::client::LazyRemoteExecServerClient; +use crate::client::RemoteExecServerClient; use crate::client::http_client::ReqwestHttpClient; use crate::client_api::ExecServerTransportParams; use crate::environment_provider::DefaultEnvironmentProvider; @@ -403,7 +403,7 @@ impl Environment { } => Some(exec_server_url.clone()), ExecServerTransportParams::StdioCommand { .. } => None, }; - let client = LazyRemoteExecServerClient::new(remote_transport.clone()); + let client = RemoteExecServerClient::new(remote_transport.clone()); let exec_backend: Arc = Arc::new(RemoteProcess::new(client.clone())); let filesystem: Arc = Arc::new(RemoteFileSystem::new(client.clone())); diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index bd556638ea..55c371ddd0 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -23,8 +23,9 @@ mod runtime_paths; mod sandboxed_file_system; mod server; -pub use client::ExecServerClient; +pub use client::ExecServerConnection; pub use client::ExecServerError; +pub type ExecServerClient = ExecServerConnection; pub use client::http_client::HttpResponseBodyStream; pub use client::http_client::ReqwestHttpClient; pub use client_api::ExecServerClientConnectOptions; diff --git a/codex-rs/exec-server/src/remote_file_system.rs b/codex-rs/exec-server/src/remote_file_system.rs index 4251fa33ec..2eb955d574 100644 --- a/codex-rs/exec-server/src/remote_file_system.rs +++ b/codex-rs/exec-server/src/remote_file_system.rs @@ -14,7 +14,7 @@ use crate::FileSystemResult; use crate::FileSystemSandboxContext; use crate::ReadDirectoryEntry; use crate::RemoveOptions; -use crate::client::LazyRemoteExecServerClient; +use crate::client::RemoteExecServerClient; use crate::protocol::FsCopyParams; use crate::protocol::FsCreateDirectoryParams; use crate::protocol::FsGetMetadataParams; @@ -28,11 +28,11 @@ const NOT_FOUND_ERROR_CODE: i64 = -32004; #[derive(Clone)] pub(crate) struct RemoteFileSystem { - client: LazyRemoteExecServerClient, + client: RemoteExecServerClient, } impl RemoteFileSystem { - pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self { + pub(crate) fn new(client: RemoteExecServerClient) -> Self { trace!("remote fs new"); Self { client } } @@ -46,8 +46,8 @@ impl ExecutorFileSystem for RemoteFileSystem { sandbox: Option<&FileSystemSandboxContext>, ) -> FileSystemResult> { trace!("remote fs read_file"); - let client = self.client.get().await.map_err(map_remote_error)?; - let response = client + let connection = self.client.connection().await.map_err(map_remote_error)?; + let response = connection .fs_read_file(FsReadFileParams { path: path.clone(), sandbox: remote_sandbox_context(sandbox), @@ -69,8 +69,8 @@ impl ExecutorFileSystem for RemoteFileSystem { sandbox: Option<&FileSystemSandboxContext>, ) -> FileSystemResult<()> { trace!("remote fs write_file"); - let client = self.client.get().await.map_err(map_remote_error)?; - client + let connection = self.client.connection().await.map_err(map_remote_error)?; + connection .fs_write_file(FsWriteFileParams { path: path.clone(), data_base64: STANDARD.encode(contents), @@ -88,8 +88,8 @@ impl ExecutorFileSystem for RemoteFileSystem { sandbox: Option<&FileSystemSandboxContext>, ) -> FileSystemResult<()> { trace!("remote fs create_directory"); - let client = self.client.get().await.map_err(map_remote_error)?; - client + let connection = self.client.connection().await.map_err(map_remote_error)?; + connection .fs_create_directory(FsCreateDirectoryParams { path: path.clone(), recursive: Some(options.recursive), @@ -106,8 +106,8 @@ impl ExecutorFileSystem for RemoteFileSystem { sandbox: Option<&FileSystemSandboxContext>, ) -> FileSystemResult { trace!("remote fs get_metadata"); - let client = self.client.get().await.map_err(map_remote_error)?; - let response = client + let connection = self.client.connection().await.map_err(map_remote_error)?; + let response = connection .fs_get_metadata(FsGetMetadataParams { path: path.clone(), sandbox: remote_sandbox_context(sandbox), @@ -129,8 +129,8 @@ impl ExecutorFileSystem for RemoteFileSystem { sandbox: Option<&FileSystemSandboxContext>, ) -> FileSystemResult> { trace!("remote fs read_directory"); - let client = self.client.get().await.map_err(map_remote_error)?; - let response = client + let connection = self.client.connection().await.map_err(map_remote_error)?; + let response = connection .fs_read_directory(FsReadDirectoryParams { path: path.clone(), sandbox: remote_sandbox_context(sandbox), @@ -155,8 +155,8 @@ impl ExecutorFileSystem for RemoteFileSystem { sandbox: Option<&FileSystemSandboxContext>, ) -> FileSystemResult<()> { trace!("remote fs remove"); - let client = self.client.get().await.map_err(map_remote_error)?; - client + let connection = self.client.connection().await.map_err(map_remote_error)?; + connection .fs_remove(FsRemoveParams { path: path.clone(), recursive: Some(options.recursive), @@ -176,8 +176,8 @@ impl ExecutorFileSystem for RemoteFileSystem { sandbox: Option<&FileSystemSandboxContext>, ) -> FileSystemResult<()> { trace!("remote fs copy"); - let client = self.client.get().await.map_err(map_remote_error)?; - client + let connection = self.client.connection().await.map_err(map_remote_error)?; + connection .fs_copy(FsCopyParams { source_path: source_path.clone(), destination_path: destination_path.clone(), diff --git a/codex-rs/exec-server/src/remote_process.rs b/codex-rs/exec-server/src/remote_process.rs index d8d06735cd..903d3b056e 100644 --- a/codex-rs/exec-server/src/remote_process.rs +++ b/codex-rs/exec-server/src/remote_process.rs @@ -9,23 +9,23 @@ use crate::ExecProcess; use crate::ExecProcessEventReceiver; use crate::ExecServerError; use crate::StartedExecProcess; -use crate::client::LazyRemoteExecServerClient; -use crate::client::Session; +use crate::client::ProcessSessionHandle; +use crate::client::RemoteExecServerClient; use crate::protocol::ExecParams; use crate::protocol::ReadResponse; use crate::protocol::WriteResponse; #[derive(Clone)] pub(crate) struct RemoteProcess { - client: LazyRemoteExecServerClient, + client: RemoteExecServerClient, } struct RemoteExecProcess { - session: Session, + session: ProcessSessionHandle, } impl RemoteProcess { - pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self { + pub(crate) fn new(client: RemoteExecServerClient) -> Self { trace!("remote process new"); Self { client } } @@ -35,9 +35,9 @@ impl RemoteProcess { impl ExecBackend for RemoteProcess { async fn start(&self, params: ExecParams) -> Result { let process_id = params.process_id.clone(); - let client = self.client.get().await?; - let session = client.register_session(&process_id).await?; - if let Err(err) = client.exec(params).await { + let session = self.client.register_process_session(&process_id).await?; + let connection = self.client.connection().await?; + if let Err(err) = connection.exec(params).await { session.unregister().await; return Err(err); } diff --git a/codex-rs/exec-server/tests/exec_process.rs b/codex-rs/exec-server/tests/exec_process.rs index e1f330fc4e..02e9cd48c5 100644 --- a/codex-rs/exec-server/tests/exec_process.rs +++ b/codex-rs/exec-server/tests/exec_process.rs @@ -541,7 +541,7 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe( #[tokio::test(flavor = "multi_thread", worker_threads = 2)] // Serialize tests that launch a real exec-server process through the full CLI. #[serial_test::serial(remote_exec_server)] -async fn remote_exec_process_reports_transport_disconnect() -> Result<()> { +async fn remote_exec_process_surfaces_reconnect_failure_after_server_shutdown() -> Result<()> { let mut context = create_process_context(/*use_remote*/ true).await?; let session = context .backend @@ -562,7 +562,6 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> { .await?; let process = Arc::clone(&session.process); - let mut events = process.subscribe_events(); let process_for_pending_read = Arc::clone(&process); let pending_read = tokio::spawn(async move { process_for_pending_read @@ -579,36 +578,33 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> { .expect("remote context should include exec-server harness"); server.shutdown().await?; - let event = timeout(Duration::from_secs(2), events.recv()).await??; - let ExecProcessEvent::Failed(event_message) = event else { - anyhow::bail!("expected process failure event, got {event:?}"); - }; + let pending_error = timeout(Duration::from_secs(2), pending_read) + .await + .context("timed out waiting for pending read after reconnect failure")?? + .expect_err("pending read should fail after reconnect fails"); assert!( - event_message.starts_with("exec-server transport disconnected"), - "unexpected failure event: {event_message}" + pending_error + .to_string() + .starts_with("failed to connect to exec-server websocket"), + "unexpected pending read error: {pending_error}" ); - let pending_response = timeout(Duration::from_secs(2), pending_read).await???; - let pending_message = pending_response - .failure - .expect("pending read should surface disconnect as a failure"); + let read_error = timeout( + Duration::from_secs(2), + process.read( + /*after_seq*/ None, + /*max_bytes*/ None, + /*wait_ms*/ Some(0), + ), + ) + .await + .context("timed out waiting for read after reconnect failure")? + .expect_err("read after reconnect failure should fail"); assert!( - pending_message.starts_with("exec-server transport disconnected"), - "unexpected pending failure message: {pending_message}" - ); - - let mut wake_rx = process.subscribe_wake(); - let response = read_process_until_change(process, &mut wake_rx, /*after_seq*/ None).await?; - let message = response - .failure - .expect("disconnect should surface as a failure"); - assert!( - message.starts_with("exec-server transport disconnected"), - "unexpected failure message: {message}" - ); - assert!( - response.closed, - "disconnect should close the process session" + read_error + .to_string() + .starts_with("failed to connect to exec-server websocket"), + "unexpected read error: {read_error}" ); let write_result = timeout( @@ -621,7 +617,7 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> { assert!( write_error .to_string() - .starts_with("exec-server transport disconnected"), + .starts_with("failed to connect to exec-server websocket"), "unexpected write error: {write_error}" );