mirror of
https://github.com/openai/codex.git
synced 2026-05-16 01:02:48 +00:00
Simplify exec-server disconnect plumbing
Keep transport shutdown responsible for stdio child cleanup, and remove the separate disconnect watch channel from the JSON-RPC connection/runtime. The RPC client now keeps a single closed flag for rejecting calls after the ordered reader exits.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -5,7 +5,6 @@ use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::process::Child;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::debug;
|
||||
@@ -75,7 +74,6 @@ impl StdioTransport {
|
||||
struct JsonRpcConnectionRuntime {
|
||||
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
@@ -98,11 +96,9 @@ impl JsonRpcConnection {
|
||||
{
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
loop {
|
||||
@@ -133,18 +129,12 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to read JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
@@ -162,7 +152,6 @@ impl JsonRpcConnection {
|
||||
if let Err(err) = write_jsonrpc_line_message(&mut writer, &message).await {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write JSON-RPC message to {connection_label}: {err}"
|
||||
)),
|
||||
@@ -177,7 +166,6 @@ impl JsonRpcConnection {
|
||||
runtime: Some(JsonRpcConnectionRuntime {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
}),
|
||||
transport: JsonRpcTransport::Plain,
|
||||
@@ -190,12 +178,10 @@ impl JsonRpcConnection {
|
||||
{
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
let (mut websocket_writer, mut websocket_reader) = stream.split();
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
@@ -244,12 +230,7 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {}
|
||||
@@ -257,7 +238,6 @@ impl JsonRpcConnection {
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
@@ -266,12 +246,7 @@ impl JsonRpcConnection {
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -286,7 +261,6 @@ impl JsonRpcConnection {
|
||||
{
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket JSON-RPC message to {connection_label}: {err}"
|
||||
)),
|
||||
@@ -298,7 +272,6 @@ impl JsonRpcConnection {
|
||||
Err(err) => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to serialize JSON-RPC message for {connection_label}: {err}"
|
||||
)),
|
||||
@@ -314,7 +287,6 @@ impl JsonRpcConnection {
|
||||
runtime: Some(JsonRpcConnectionRuntime {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
}),
|
||||
transport: JsonRpcTransport::Plain,
|
||||
@@ -326,16 +298,14 @@ impl JsonRpcConnection {
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
watch::Receiver<bool>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
let JsonRpcConnectionRuntime {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles,
|
||||
} = self.take_runtime("JSON-RPC client runtime already taken");
|
||||
(outgoing_tx, incoming_rx, disconnected_rx, task_handles)
|
||||
(outgoing_tx, incoming_rx, task_handles)
|
||||
}
|
||||
|
||||
pub(crate) fn with_child_process(mut self, child_process: Child) -> Self {
|
||||
@@ -348,16 +318,14 @@ impl JsonRpcConnection {
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
watch::Receiver<bool>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
let JsonRpcConnectionRuntime {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles,
|
||||
} = self.take_runtime("JSON-RPC connection parts already taken");
|
||||
(outgoing_tx, incoming_rx, disconnected_rx, task_handles)
|
||||
(outgoing_tx, incoming_rx, task_handles)
|
||||
}
|
||||
|
||||
fn take_runtime(&mut self, message: &'static str) -> JsonRpcConnectionRuntime {
|
||||
@@ -370,10 +338,8 @@ impl JsonRpcConnection {
|
||||
|
||||
async fn send_disconnected(
|
||||
incoming_tx: &mpsc::Sender<JsonRpcConnectionEvent>,
|
||||
disconnected_tx: &watch::Sender<bool>,
|
||||
reason: Option<String>,
|
||||
) {
|
||||
let _ = disconnected_tx.send(true);
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Disconnected { reason })
|
||||
.await;
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
@@ -18,7 +19,6 @@ use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::connection::JsonRpcConnection;
|
||||
@@ -223,10 +223,9 @@ where
|
||||
pub(crate) struct RpcClient {
|
||||
write_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
pending: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
|
||||
// This flips when either the underlying transport closes or the RPC reader
|
||||
// exits, so new calls fail quickly after the connection is no longer usable.
|
||||
closed_rx: watch::Receiver<bool>,
|
||||
transport_disconnected_rx: watch::Receiver<bool>,
|
||||
// This flips before the ordered RPC reader drains pending requests, so new
|
||||
// calls fail instead of registering work that can never complete.
|
||||
closed: Arc<AtomicBool>,
|
||||
next_request_id: AtomicI64,
|
||||
transport_tasks: Vec<JoinHandle<()>>,
|
||||
reader_task: JoinHandle<()>,
|
||||
@@ -236,13 +235,13 @@ impl RpcClient {
|
||||
pub(crate) fn new(
|
||||
connection: &mut JsonRpcConnection,
|
||||
) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let (write_tx, mut incoming_rx, transport_disconnected_rx, transport_tasks) =
|
||||
connection.take_client_runtime();
|
||||
let (write_tx, mut incoming_rx, transport_tasks) = connection.take_client_runtime();
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
let (closed_tx, closed_rx) = watch::channel(*transport_disconnected_rx.borrow());
|
||||
let closed = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let pending_for_reader = Arc::clone(&pending);
|
||||
let closed_for_reader = Arc::clone(&closed);
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let reason = loop {
|
||||
let Some(event) = incoming_rx.recv().await else {
|
||||
@@ -264,7 +263,7 @@ impl RpcClient {
|
||||
}
|
||||
};
|
||||
|
||||
let _ = closed_tx.send(true);
|
||||
closed_for_reader.store(true, Ordering::SeqCst);
|
||||
let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await;
|
||||
drain_pending(&pending_for_reader).await;
|
||||
});
|
||||
@@ -273,8 +272,7 @@ impl RpcClient {
|
||||
Self {
|
||||
write_tx,
|
||||
pending,
|
||||
closed_rx,
|
||||
transport_disconnected_rx,
|
||||
closed,
|
||||
next_request_id: AtomicI64::new(1),
|
||||
transport_tasks,
|
||||
reader_task,
|
||||
@@ -289,7 +287,7 @@ impl RpcClient {
|
||||
params: &P,
|
||||
) -> Result<(), serde_json::Error> {
|
||||
let params = serde_json::to_value(params)?;
|
||||
if *self.closed_rx.borrow() || *self.transport_disconnected_rx.borrow() {
|
||||
if self.closed.load(Ordering::SeqCst) {
|
||||
return Err(serde_json::Error::io(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
"JSON-RPC transport closed",
|
||||
@@ -321,7 +319,7 @@ impl RpcClient {
|
||||
// Registering the pending request and checking disconnect must be
|
||||
// atomic with the reader's drain_pending path. Otherwise a call
|
||||
// can sneak in after the drain and wait forever.
|
||||
if *self.closed_rx.borrow() || *self.transport_disconnected_rx.borrow() {
|
||||
if self.closed.load(Ordering::SeqCst) {
|
||||
return Err(RpcCallError::Closed);
|
||||
}
|
||||
pending.insert(request_id.clone(), response_tx);
|
||||
|
||||
@@ -47,8 +47,7 @@ async fn run_connection(
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) {
|
||||
let router = Arc::new(build_router());
|
||||
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
|
||||
connection.into_parts();
|
||||
let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.into_parts();
|
||||
let (outgoing_tx, mut outgoing_rx) =
|
||||
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
|
||||
let notifications = RpcNotificationSender::new(outgoing_tx.clone());
|
||||
@@ -96,13 +95,7 @@ async fn run_connection(
|
||||
JsonRpcConnectionEvent::Message(message) => match message {
|
||||
codex_app_server_protocol::JSONRPCMessage::Request(request) => {
|
||||
if let Some(route) = router.request_route(request.method.as_str()) {
|
||||
let message = tokio::select! {
|
||||
message = route(Arc::clone(&handler), request) => message,
|
||||
_ = disconnected_rx.changed() => {
|
||||
debug!("exec-server transport disconnected while handling request");
|
||||
break;
|
||||
}
|
||||
};
|
||||
let message = route(Arc::clone(&handler), request).await;
|
||||
if let Some(message) = message
|
||||
&& outgoing_tx.send(message).await.is_err()
|
||||
{
|
||||
@@ -131,15 +124,7 @@ async fn run_connection(
|
||||
);
|
||||
break;
|
||||
};
|
||||
let result = tokio::select! {
|
||||
result = route(Arc::clone(&handler), notification) => result,
|
||||
_ = disconnected_rx.changed() => {
|
||||
debug!(
|
||||
"exec-server transport disconnected while handling notification"
|
||||
);
|
||||
break;
|
||||
}
|
||||
};
|
||||
let result = route(Arc::clone(&handler), notification).await;
|
||||
if let Err(err) = result {
|
||||
warn!("closing exec-server connection after protocol error: {err}");
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user