Simplify exec-server disconnect plumbing

Keep transport shutdown responsible for stdio child cleanup, and remove the separate disconnect watch channel from the JSON-RPC connection and its runtime. The RPC client now keeps a single atomic `closed` flag that rejects new calls once the ordered reader has exited.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
starr-openai
2026-05-05 15:37:07 -07:00
parent fb93315b4b
commit bc34e376f7
3 changed files with 19 additions and 70 deletions

View File

@@ -5,7 +5,6 @@ use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::process::Child;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message;
use tracing::debug;
@@ -75,7 +74,6 @@ impl StdioTransport {
struct JsonRpcConnectionRuntime {
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
disconnected_rx: watch::Receiver<bool>,
task_handles: Vec<tokio::task::JoinHandle<()>>,
}
@@ -98,11 +96,9 @@ impl JsonRpcConnection {
{
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (disconnected_tx, disconnected_rx) = watch::channel(false);
let reader_label = connection_label.clone();
let incoming_tx_for_reader = incoming_tx.clone();
let disconnected_tx_for_reader = disconnected_tx.clone();
let reader_task = tokio::spawn(async move {
let mut lines = BufReader::new(reader).lines();
loop {
@@ -133,18 +129,12 @@ impl JsonRpcConnection {
}
}
Ok(None) => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
)
.await;
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
break;
}
Err(err) => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
Some(format!(
"failed to read JSON-RPC message from {reader_label}: {err}"
)),
@@ -162,7 +152,6 @@ impl JsonRpcConnection {
if let Err(err) = write_jsonrpc_line_message(&mut writer, &message).await {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write JSON-RPC message to {connection_label}: {err}"
)),
@@ -177,7 +166,6 @@ impl JsonRpcConnection {
runtime: Some(JsonRpcConnectionRuntime {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
}),
transport: JsonRpcTransport::Plain,
@@ -190,12 +178,10 @@ impl JsonRpcConnection {
{
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (disconnected_tx, disconnected_rx) = watch::channel(false);
let (mut websocket_writer, mut websocket_reader) = stream.split();
let reader_label = connection_label.clone();
let incoming_tx_for_reader = incoming_tx.clone();
let disconnected_tx_for_reader = disconnected_tx.clone();
let reader_task = tokio::spawn(async move {
loop {
match websocket_reader.next().await {
@@ -244,12 +230,7 @@ impl JsonRpcConnection {
}
}
Some(Ok(Message::Close(_))) => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
)
.await;
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
break;
}
Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {}
@@ -257,7 +238,6 @@ impl JsonRpcConnection {
Some(Err(err)) => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
Some(format!(
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
)),
@@ -266,12 +246,7 @@ impl JsonRpcConnection {
break;
}
None => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
)
.await;
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
break;
}
}
@@ -286,7 +261,6 @@ impl JsonRpcConnection {
{
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket JSON-RPC message to {connection_label}: {err}"
)),
@@ -298,7 +272,6 @@ impl JsonRpcConnection {
Err(err) => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to serialize JSON-RPC message for {connection_label}: {err}"
)),
@@ -314,7 +287,6 @@ impl JsonRpcConnection {
runtime: Some(JsonRpcConnectionRuntime {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
}),
transport: JsonRpcTransport::Plain,
@@ -326,16 +298,14 @@ impl JsonRpcConnection {
) -> (
mpsc::Sender<JSONRPCMessage>,
mpsc::Receiver<JsonRpcConnectionEvent>,
watch::Receiver<bool>,
Vec<tokio::task::JoinHandle<()>>,
) {
let JsonRpcConnectionRuntime {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles,
} = self.take_runtime("JSON-RPC client runtime already taken");
(outgoing_tx, incoming_rx, disconnected_rx, task_handles)
(outgoing_tx, incoming_rx, task_handles)
}
pub(crate) fn with_child_process(mut self, child_process: Child) -> Self {
@@ -348,16 +318,14 @@ impl JsonRpcConnection {
) -> (
mpsc::Sender<JSONRPCMessage>,
mpsc::Receiver<JsonRpcConnectionEvent>,
watch::Receiver<bool>,
Vec<tokio::task::JoinHandle<()>>,
) {
let JsonRpcConnectionRuntime {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles,
} = self.take_runtime("JSON-RPC connection parts already taken");
(outgoing_tx, incoming_rx, disconnected_rx, task_handles)
(outgoing_tx, incoming_rx, task_handles)
}
fn take_runtime(&mut self, message: &'static str) -> JsonRpcConnectionRuntime {
@@ -370,10 +338,8 @@ impl JsonRpcConnection {
async fn send_disconnected(
incoming_tx: &mpsc::Sender<JsonRpcConnectionEvent>,
disconnected_tx: &watch::Sender<bool>,
reason: Option<String>,
) {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected { reason })
.await;

View File

@@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
@@ -18,7 +19,6 @@ use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use crate::connection::JsonRpcConnection;
@@ -223,10 +223,9 @@ where
pub(crate) struct RpcClient {
write_tx: mpsc::Sender<JSONRPCMessage>,
pending: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
// This flips when either the underlying transport closes or the RPC reader
// exits, so new calls fail quickly after the connection is no longer usable.
closed_rx: watch::Receiver<bool>,
transport_disconnected_rx: watch::Receiver<bool>,
// This flips before the ordered RPC reader drains pending requests, so new
// calls fail instead of registering work that can never complete.
closed: Arc<AtomicBool>,
next_request_id: AtomicI64,
transport_tasks: Vec<JoinHandle<()>>,
reader_task: JoinHandle<()>,
@@ -236,13 +235,13 @@ impl RpcClient {
pub(crate) fn new(
connection: &mut JsonRpcConnection,
) -> (Self, mpsc::Receiver<RpcClientEvent>) {
let (write_tx, mut incoming_rx, transport_disconnected_rx, transport_tasks) =
connection.take_client_runtime();
let (write_tx, mut incoming_rx, transport_tasks) = connection.take_client_runtime();
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
let (event_tx, event_rx) = mpsc::channel(128);
let (closed_tx, closed_rx) = watch::channel(*transport_disconnected_rx.borrow());
let closed = Arc::new(AtomicBool::new(false));
let pending_for_reader = Arc::clone(&pending);
let closed_for_reader = Arc::clone(&closed);
let reader_task = tokio::spawn(async move {
let reason = loop {
let Some(event) = incoming_rx.recv().await else {
@@ -264,7 +263,7 @@ impl RpcClient {
}
};
let _ = closed_tx.send(true);
closed_for_reader.store(true, Ordering::SeqCst);
let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await;
drain_pending(&pending_for_reader).await;
});
@@ -273,8 +272,7 @@ impl RpcClient {
Self {
write_tx,
pending,
closed_rx,
transport_disconnected_rx,
closed,
next_request_id: AtomicI64::new(1),
transport_tasks,
reader_task,
@@ -289,7 +287,7 @@ impl RpcClient {
params: &P,
) -> Result<(), serde_json::Error> {
let params = serde_json::to_value(params)?;
if *self.closed_rx.borrow() || *self.transport_disconnected_rx.borrow() {
if self.closed.load(Ordering::SeqCst) {
return Err(serde_json::Error::io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"JSON-RPC transport closed",
@@ -321,7 +319,7 @@ impl RpcClient {
// Registering the pending request and checking disconnect must be
// atomic with the reader's drain_pending path. Otherwise a call
// can sneak in after the drain and wait forever.
if *self.closed_rx.borrow() || *self.transport_disconnected_rx.borrow() {
if self.closed.load(Ordering::SeqCst) {
return Err(RpcCallError::Closed);
}
pending.insert(request_id.clone(), response_tx);

View File

@@ -47,8 +47,7 @@ async fn run_connection(
runtime_paths: ExecServerRuntimePaths,
) {
let router = Arc::new(build_router());
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
connection.into_parts();
let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.into_parts();
let (outgoing_tx, mut outgoing_rx) =
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
let notifications = RpcNotificationSender::new(outgoing_tx.clone());
@@ -96,13 +95,7 @@ async fn run_connection(
JsonRpcConnectionEvent::Message(message) => match message {
codex_app_server_protocol::JSONRPCMessage::Request(request) => {
if let Some(route) = router.request_route(request.method.as_str()) {
let message = tokio::select! {
message = route(Arc::clone(&handler), request) => message,
_ = disconnected_rx.changed() => {
debug!("exec-server transport disconnected while handling request");
break;
}
};
let message = route(Arc::clone(&handler), request).await;
if let Some(message) = message
&& outgoing_tx.send(message).await.is_err()
{
@@ -131,15 +124,7 @@ async fn run_connection(
);
break;
};
let result = tokio::select! {
result = route(Arc::clone(&handler), notification) => result,
_ = disconnected_rx.changed() => {
debug!(
"exec-server transport disconnected while handling notification"
);
break;
}
};
let result = route(Arc::clone(&handler), notification).await;
if let Err(err) = result {
warn!("closing exec-server connection after protocol error: {err}");
break;