diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 8a1178c93d..346b7d841b 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -154,8 +154,7 @@ pub(crate) struct Session { struct Inner { // Keep the underlying transport connection alive and drop it before the RPC // client starts tearing down its channel/task handles. - #[allow(dead_code)] - connection: JsonRpcConnection, + connection: Option, client: RpcClient, // The remote transport delivers one shared notification stream for every // process on the connection. Keep a local process_id -> session registry so @@ -185,6 +184,7 @@ struct Inner { impl Drop for Inner { fn drop(&mut self) { self.reader_task.abort(); + drop(self.connection.take()); } } @@ -466,7 +466,7 @@ impl ExecServerClient { }); Inner { - connection, + connection: Some(connection), client: rpc_client, sessions: ArcSwap::from_pointee(HashMap::new()), sessions_write_lock: Mutex::new(()), @@ -890,6 +890,7 @@ mod tests { use tokio::io::BufReader; use tokio::io::duplex; use tokio::sync::mpsc; + use tokio::sync::oneshot; use tokio::time::Duration; #[cfg(unix)] use tokio::time::sleep; @@ -1235,6 +1236,92 @@ mod tests { server.await.expect("server task should finish"); } + #[tokio::test] + async fn transport_disconnect_fails_sessions_and_rejects_new_sessions() { + let (client_stdin, server_reader) = duplex(1 << 20); + let (mut server_writer, client_stdout) = duplex(1 << 20); + let (disconnect_tx, disconnect_rx) = oneshot::channel(); + let server = tokio::spawn(async move { + let mut lines = BufReader::new(server_reader).lines(); + let initialize = read_jsonrpc_line(&mut lines).await; + let request = match initialize { + JSONRPCMessage::Request(request) if request.method == INITIALIZE_METHOD => request, + other => panic!("expected initialize request, got {other:?}"), + }; + write_jsonrpc_line( + &mut server_writer, + JSONRPCMessage::Response(JSONRPCResponse { + id: request.id, + result: serde_json::to_value(InitializeResponse { + session_id: "session-1".to_string(), + }) + .expect("initialize response should serialize"), + }), + ) + .await; + + let initialized = read_jsonrpc_line(&mut lines).await; + match initialized { + JSONRPCMessage::Notification(notification) + if notification.method == INITIALIZED_METHOD => {} + other => panic!("expected initialized notification, got {other:?}"), + } + + let _ = disconnect_rx.await; + drop(server_writer); + }); + + let client = ExecServerClient::connect( + JsonRpcConnection::from_stdio( + client_stdout, + client_stdin, + "test-exec-server-client".to_string(), + ), + ExecServerClientConnectOptions::default(), + ) + .await + .expect("client should connect"); + + let process_id = ProcessId::from("disconnect"); + let session = client + .register_session(&process_id) + .await + .expect("session should register"); + let mut events = session.subscribe_events(); + + disconnect_tx.send(()).expect("disconnect should signal"); + + let event = timeout(Duration::from_secs(1), events.recv()) + .await + .expect("session failure should not time out") + .expect("session event stream should stay open"); + let ExecProcessEvent::Failed(message) = event else { + panic!("expected session failure after disconnect, got {event:?}"); + }; + assert_eq!(message, "exec-server transport disconnected"); + + let response = session + .read( + /*after_seq*/ None, /*max_bytes*/ None, /*wait_ms*/ None, + ) + .await + .expect("disconnected session read should synthesize a response"); + assert_eq!( + response.failure.as_deref(), + Some("exec-server transport disconnected") + ); + assert!(response.closed); + + let new_session = client.register_session(&ProcessId::from("new")).await; + assert!(matches!( + new_session, + Err(super::ExecServerError::Disconnected(_)) + )); + + drop(client); + server.await.expect("server task should finish"); + } + #[tokio::test] async fn wake_notifications_do_not_block_other_sessions() { let (client_stdin, server_reader) = duplex(1 << 20); diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs index 20520f002e..9320efac30 100644 --- a/codex-rs/exec-server/src/client_api.rs +++ b/codex-rs/exec-server/src/client_api.rs @@ -49,7 +49,6 @@ pub(crate) struct StdioExecServerCommand { #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum ExecServerTransportParams { WebSocketUrl(String), - StdioCommand(StdioExecServerCommand), } /// Sends HTTP requests through a runtime-selected transport. diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index 560630e3db..6f4492853c 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -24,27 +24,16 @@ impl ExecServerClient { pub(crate) async fn connect_for_transport( transport_params: crate::client_api::ExecServerTransportParams, ) -> Result { - match transport_params { - crate::client_api::ExecServerTransportParams::WebSocketUrl(websocket_url) => { - Self::connect_websocket(RemoteExecServerConnectArgs { - websocket_url, - client_name: ENVIRONMENT_CLIENT_NAME.to_string(), - connect_timeout: ENVIRONMENT_CONNECT_TIMEOUT, - initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT, - resume_session_id: None, - }) - .await - } - crate::client_api::ExecServerTransportParams::StdioCommand(command) => { - Self::connect_stdio_command(StdioExecServerConnectArgs { - command, - client_name: ENVIRONMENT_CLIENT_NAME.to_string(), - initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT, - resume_session_id: None, - }) - .await - } - } + let crate::client_api::ExecServerTransportParams::WebSocketUrl(websocket_url) = + transport_params; + Self::connect_websocket(RemoteExecServerConnectArgs { + websocket_url, + client_name: ENVIRONMENT_CLIENT_NAME.to_string(), + connect_timeout: ENVIRONMENT_CONNECT_TIMEOUT, + initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT, + resume_session_id: None, + }) + .await } pub async fn connect_websocket( diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index 05f8c94f2f..367d83f15d 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -334,10 +334,7 @@ impl JsonRpcConnection { incoming_rx, disconnected_rx, task_handles, - } = self - .runtime - .take() - .expect("JSON-RPC client runtime already taken"); + } = self.take_runtime("JSON-RPC client runtime already taken"); (outgoing_tx, incoming_rx, disconnected_rx, task_handles) } @@ -359,12 +356,16 @@ impl JsonRpcConnection { incoming_rx, disconnected_rx, task_handles, - } = self - .runtime - .take() - .expect("JSON-RPC connection parts already taken"); + } = self.take_runtime("JSON-RPC connection parts already taken"); (outgoing_tx, incoming_rx, disconnected_rx, task_handles) } + + fn take_runtime(&mut self, message: &'static str) -> JsonRpcConnectionRuntime { + match self.runtime.take() { + Some(runtime) => runtime, + None => panic!("{message}"), + } + } } async fn send_disconnected( diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 8eb9445077..82948b920c 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -682,4 +682,53 @@ mod tests { assert!(matches!(result, Err(super::RpcCallError::Closed))); assert_eq!(client.pending_request_count().await, 0); } + + #[tokio::test] + async fn rpc_client_drains_pending_call_on_transport_eof() { + let (client_stdin, server_reader) = tokio::io::duplex(4096); + let (server_writer, client_stdout) = tokio::io::duplex(4096); + let mut connection = + JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string()); + let (client, mut events_rx) = RpcClient::new(&mut connection); + + let server = tokio::spawn(async move { + let mut lines = BufReader::new(server_reader).lines(); + let request = read_jsonrpc_line(&mut lines).await; + match request { + JSONRPCMessage::Request(request) if request.method == "will-close" => {} + other => panic!("expected will-close request, got {other:?}"), + } + drop(server_writer); + }); + + let result = timeout( + Duration::from_secs(1), + client.call::<_, serde_json::Value>("will-close", &serde_json::json!({})), + ) + .await + .expect("timed out waiting for closed call"); + assert!(matches!(result, Err(super::RpcCallError::Closed))); + + let event = timeout(Duration::from_secs(1), events_rx.recv()) + .await + .expect("timed out waiting for disconnect event"); + assert!(matches!( + event, + Some(RpcClientEvent::Disconnected { reason: None }) + )); + assert_eq!(client.pending_request_count().await, 0); + + let result = timeout( + Duration::from_secs(1), + client.call::<_, serde_json::Value>("after-close", &serde_json::json!({})), + ) + .await + .expect("timed out waiting for fast closed call"); + assert!(matches!(result, Err(super::RpcCallError::Closed))); + + let notify = client.notify("after-close", &serde_json::json!({})).await; + assert!(notify.is_err()); + + server.await.expect("server task should finish"); + } } diff --git a/codex-rs/tui/src/resume_picker.rs b/codex-rs/tui/src/resume_picker.rs index e06ccfd118..06ad0a61a7 100644 --- a/codex-rs/tui/src/resume_picker.rs +++ b/codex-rs/tui/src/resume_picker.rs @@ -5753,7 +5753,6 @@ session_picker_view = "dense" text: String::from("1. Do the thing"), }, ], - items_view: codex_app_server_protocol::TurnItemsView::Full, status: codex_app_server_protocol::TurnStatus::Completed, error: None, started_at: None, @@ -5805,7 +5804,6 @@ session_picker_view = "dense" summary: Vec::new(), content: vec![String::from("private raw chain of thought")], }], - items_view: codex_app_server_protocol::TurnItemsView::Full, status: codex_app_server_protocol::TurnStatus::Completed, error: None, started_at: None, @@ -5861,7 +5859,6 @@ session_picker_view = "dense" summary: vec![String::from("public summary")], content: vec![String::from("raw reasoning content")], }], - items_view: codex_app_server_protocol::TurnItemsView::Full, status: codex_app_server_protocol::TurnStatus::Completed, error: None, started_at: None,