diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index ba17612407..e159dade28 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -5,6 +5,7 @@ use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tokio::process::Child; use tokio::sync::mpsc; +use tokio::sync::watch; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Message; use tracing::debug; @@ -74,6 +75,7 @@ impl StdioTransport { struct JsonRpcConnectionRuntime { outgoing_tx: mpsc::Sender, incoming_rx: mpsc::Receiver, + disconnected_rx: watch::Receiver, task_handles: Vec>, } @@ -100,9 +102,11 @@ impl JsonRpcConnection { { let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); let reader_label = connection_label.clone(); let incoming_tx_for_reader = incoming_tx.clone(); + let disconnected_tx_for_reader = disconnected_tx.clone(); let reader_task = tokio::spawn(async move { let mut lines = BufReader::new(reader).lines(); loop { @@ -133,12 +137,18 @@ impl JsonRpcConnection { } } Ok(None) => { - send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; + send_disconnected( + &incoming_tx_for_reader, + &disconnected_tx_for_reader, + /*reason*/ None, + ) + .await; break; } Err(err) => { send_disconnected( &incoming_tx_for_reader, + &disconnected_tx_for_reader, Some(format!( "failed to read JSON-RPC message from {reader_label}: {err}" )), @@ -156,6 +166,7 @@ impl JsonRpcConnection { if let Err(err) = write_jsonrpc_line_message(&mut writer, &message).await { send_disconnected( &incoming_tx, + &disconnected_tx, Some(format!( "failed to write JSON-RPC message to {connection_label}: {err}" )), @@ -170,6 +181,7 @@ impl JsonRpcConnection { runtime: Some(JsonRpcConnectionRuntime { outgoing_tx, incoming_rx, + disconnected_rx, task_handles: 
vec![reader_task, writer_task], }), transport: JsonRpcTransport::Plain, @@ -182,10 +194,12 @@ impl JsonRpcConnection { { let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); let (mut websocket_writer, mut websocket_reader) = stream.split(); let reader_label = connection_label.clone(); let incoming_tx_for_reader = incoming_tx.clone(); + let disconnected_tx_for_reader = disconnected_tx.clone(); let reader_task = tokio::spawn(async move { loop { match websocket_reader.next().await { @@ -234,7 +248,12 @@ impl JsonRpcConnection { } } Some(Ok(Message::Close(_))) => { - send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; + send_disconnected( + &incoming_tx_for_reader, + &disconnected_tx_for_reader, + /*reason*/ None, + ) + .await; break; } Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {} @@ -242,6 +261,7 @@ impl JsonRpcConnection { Some(Err(err)) => { send_disconnected( &incoming_tx_for_reader, + &disconnected_tx_for_reader, Some(format!( "failed to read websocket JSON-RPC message from {reader_label}: {err}" )), @@ -250,7 +270,12 @@ impl JsonRpcConnection { break; } None => { - send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; + send_disconnected( + &incoming_tx_for_reader, + &disconnected_tx_for_reader, + /*reason*/ None, + ) + .await; break; } } @@ -265,6 +290,7 @@ impl JsonRpcConnection { { send_disconnected( &incoming_tx, + &disconnected_tx, Some(format!( "failed to write websocket JSON-RPC message to {connection_label}: {err}" )), @@ -276,6 +302,7 @@ impl JsonRpcConnection { Err(err) => { send_disconnected( &incoming_tx, + &disconnected_tx, Some(format!( "failed to serialize JSON-RPC message for {connection_label}: {err}" )), @@ -291,6 +318,7 @@ impl JsonRpcConnection { runtime: Some(JsonRpcConnectionRuntime { outgoing_tx, incoming_rx, + disconnected_rx, task_handles: 
vec![reader_task, writer_task], }), transport: JsonRpcTransport::Plain, @@ -307,11 +335,29 @@ impl JsonRpcConnection { let JsonRpcConnectionRuntime { outgoing_tx, incoming_rx, + disconnected_rx: _, task_handles, } = self.take_runtime_or_panic("JSON-RPC connection runtime already taken"); (outgoing_tx, incoming_rx, task_handles) } + pub(crate) fn into_parts( + mut self, + ) -> ( + mpsc::Sender, + mpsc::Receiver, + watch::Receiver, + Vec>, + ) { + let JsonRpcConnectionRuntime { + outgoing_tx, + incoming_rx, + disconnected_rx, + task_handles, + } = self.take_runtime_or_panic("JSON-RPC connection runtime already taken"); + (outgoing_tx, incoming_rx, disconnected_rx, task_handles) + } + pub(crate) fn with_child_process(mut self, child_process: Child) -> Self { self.transport = JsonRpcTransport::from_child_process(child_process); self @@ -327,8 +373,10 @@ impl JsonRpcConnection { async fn send_disconnected( incoming_tx: &mpsc::Sender, + disconnected_tx: &watch::Sender, reason: Option, ) { + let _ = disconnected_tx.send(true); let _ = incoming_tx .send(JsonRpcConnectionEvent::Disconnected { reason }) .await; diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs index 88a282e0d5..dc1a9b9ffe 100644 --- a/codex-rs/exec-server/src/server/processor.rs +++ b/codex-rs/exec-server/src/server/processor.rs @@ -42,12 +42,13 @@ impl ConnectionProcessor { } async fn run_connection( - mut connection: JsonRpcConnection, + connection: JsonRpcConnection, session_registry: Arc, runtime_paths: ExecServerRuntimePaths, ) { let router = Arc::new(build_router()); - let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.take_runtime(); + let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) = + connection.into_parts(); let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); let notifications = RpcNotificationSender::new(outgoing_tx.clone()); @@ -95,7 +96,13 @@ async fn run_connection( 
JsonRpcConnectionEvent::Message(message) => match message { codex_app_server_protocol::JSONRPCMessage::Request(request) => { if let Some(route) = router.request_route(request.method.as_str()) { - let message = route(Arc::clone(&handler), request).await; + let message = tokio::select! { + message = route(Arc::clone(&handler), request) => message, + _ = disconnected_rx.changed() => { + debug!("exec-server transport disconnected while handling request"); + break; + } + }; if let Some(message) = message && outgoing_tx.send(message).await.is_err() { @@ -124,7 +131,15 @@ async fn run_connection( ); break; }; - let result = route(Arc::clone(&handler), notification).await; + let result = tokio::select! { + result = route(Arc::clone(&handler), notification) => result, + _ = disconnected_rx.changed() => { + debug!( + "exec-server transport disconnected while handling notification" + ); + break; + } + }; if let Err(err) = result { warn!("closing exec-server connection after protocol error: {err}"); break; @@ -163,3 +178,241 @@ async fn run_connection( } let _ = outbound_task.await; } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + use std::time::Duration; + + use codex_app_server_protocol::JSONRPCMessage; + use codex_app_server_protocol::JSONRPCNotification; + use codex_app_server_protocol::JSONRPCRequest; + use codex_app_server_protocol::JSONRPCResponse; + use codex_app_server_protocol::RequestId; + use serde::Serialize; + use serde::de::DeserializeOwned; + use tokio::io::AsyncBufReadExt; + use tokio::io::AsyncWriteExt; + use tokio::io::BufReader; + use tokio::io::DuplexStream; + use tokio::io::Lines; + use tokio::io::duplex; + use tokio::task::JoinHandle; + use tokio::time::timeout; + + use super::run_connection; + use crate::ExecServerRuntimePaths; + use crate::ProcessId; + use crate::connection::JsonRpcConnection; + use crate::protocol::EXEC_METHOD; + use crate::protocol::EXEC_READ_METHOD; + use crate::protocol::EXEC_TERMINATE_METHOD; 
+ use crate::protocol::ExecParams; + use crate::protocol::ExecResponse; + use crate::protocol::INITIALIZE_METHOD; + use crate::protocol::INITIALIZED_METHOD; + use crate::protocol::InitializeParams; + use crate::protocol::InitializeResponse; + use crate::protocol::ReadParams; + use crate::protocol::TerminateParams; + use crate::protocol::TerminateResponse; + use crate::server::session_registry::SessionRegistry; + + #[tokio::test] + async fn transport_disconnect_detaches_session_during_in_flight_read() { + let registry = SessionRegistry::new(); + let (mut first_writer, mut first_lines, first_task) = + spawn_test_connection(Arc::clone(&registry), "first"); + + send_request( + &mut first_writer, + /*id*/ 1, + INITIALIZE_METHOD, + &InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: None, + }, + ) + .await; + let initialize_response: InitializeResponse = + read_response(&mut first_lines, /*expected_id*/ 1).await; + send_notification(&mut first_writer, INITIALIZED_METHOD, &()).await; + + let process_id = ProcessId::from("proc-long-poll"); + send_request( + &mut first_writer, + /*id*/ 2, + EXEC_METHOD, + &exec_params(process_id.clone()), + ) + .await; + let _: ExecResponse = read_response(&mut first_lines, /*expected_id*/ 2).await; + + send_request( + &mut first_writer, + /*id*/ 3, + EXEC_READ_METHOD, + &ReadParams { + process_id: process_id.clone(), + after_seq: None, + max_bytes: None, + wait_ms: Some(5_000), + }, + ) + .await; + drop(first_writer); + tokio::time::sleep(Duration::from_millis(25)).await; + + let (mut second_writer, mut second_lines, second_task) = + spawn_test_connection(Arc::clone(&registry), "second"); + send_request( + &mut second_writer, + /*id*/ 1, + INITIALIZE_METHOD, + &InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: Some(initialize_response.session_id.clone()), + }, + ) + .await; + let second_initialize_response = timeout( + Duration::from_secs(1), + read_response::<InitializeResponse>(&mut
second_lines, /*expected_id*/ 1), + ) + .await + .expect("resume initialize should not wait for the old read to finish"); + assert_eq!( + second_initialize_response.session_id, + initialize_response.session_id + ); + timeout(Duration::from_secs(1), first_task) + .await + .expect("first processor should exit") + .expect("first processor should join"); + send_notification(&mut second_writer, INITIALIZED_METHOD, &()).await; + + send_request( + &mut second_writer, + /*id*/ 2, + EXEC_TERMINATE_METHOD, + &TerminateParams { process_id }, + ) + .await; + let _: TerminateResponse = read_response(&mut second_lines, /*expected_id*/ 2).await; + + drop(second_writer); + drop(second_lines); + timeout(Duration::from_secs(1), second_task) + .await + .expect("second processor should exit") + .expect("second processor should join"); + } + + fn spawn_test_connection( + registry: Arc<SessionRegistry>, + label: &str, + ) -> (DuplexStream, Lines<BufReader<DuplexStream>>, JoinHandle<()>) { + let (client_writer, server_reader) = duplex(1 << 20); + let (server_writer, client_reader) = duplex(1 << 20); + let connection = + JsonRpcConnection::from_stdio(server_reader, server_writer, label.to_string()); + let task = tokio::spawn(run_connection(connection, registry, test_runtime_paths())); + (client_writer, BufReader::new(client_reader).lines(), task) + } + + fn test_runtime_paths() -> ExecServerRuntimePaths { + ExecServerRuntimePaths::new( + std::env::current_exe().expect("current exe"), + /*codex_linux_sandbox_exe*/ None, + ) + .expect("runtime paths") + } + + async fn send_request<P: Serialize>( + writer: &mut DuplexStream, + id: i64, + method: &str, + params: &P, + ) { + write_message( + writer, + &JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(id), + method: method.to_string(), + params: Some(serde_json::to_value(params).expect("serialize params")), + trace: None, + }), + ) + .await; + } + + async fn send_notification<P: Serialize>(writer: &mut DuplexStream, method: &str, params: &P) { + write_message( + writer,
&JSONRPCMessage::Notification(JSONRPCNotification { + method: method.to_string(), + params: Some(serde_json::to_value(params).expect("serialize params")), + }), + ) + .await; + } + + async fn write_message(writer: &mut DuplexStream, message: &JSONRPCMessage) { + let encoded = serde_json::to_vec(message).expect("serialize JSON-RPC message"); + writer.write_all(&encoded).await.expect("write request"); + writer.write_all(b"\n").await.expect("write newline"); + } + + async fn read_response<T: DeserializeOwned>( + lines: &mut Lines<BufReader<DuplexStream>>, + expected_id: i64, + ) -> T { + let line = lines + .next_line() + .await + .expect("read response") + .expect("response line"); + match serde_json::from_str::<JSONRPCMessage>(&line).expect("decode JSON-RPC response") { + JSONRPCMessage::Response(JSONRPCResponse { id, result }) => { + assert_eq!(id, RequestId::Integer(expected_id)); + serde_json::from_value(result).expect("decode response result") + } + JSONRPCMessage::Error(error) => panic!("unexpected JSON-RPC error: {error:?}"), + other => panic!("expected JSON-RPC response, got {other:?}"), + } + } + + fn exec_params(process_id: ProcessId) -> ExecParams { + let mut env = HashMap::new(); + if let Some(path) = std::env::var_os("PATH") { + env.insert("PATH".to_string(), path.to_string_lossy().into_owned()); + } + ExecParams { + process_id, + argv: sleep_then_print_argv(), + cwd: std::env::current_dir().expect("cwd"), + env_policy: None, + env, + tty: false, + pipe_stdin: false, + arg0: None, + } + } + + fn sleep_then_print_argv() -> Vec<String> { + if cfg!(windows) { + vec![ + std::env::var("COMSPEC").unwrap_or_else(|_| "cmd.exe".to_string()), + "/C".to_string(), + "ping -n 3 127.0.0.1 >NUL && echo late".to_string(), + ] + } else { + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "sleep 1; printf late".to_string(), + ] + } + } +}