mirror of
https://github.com/openai/codex.git
synced 2026-06-01 19:02:59 +00:00
Narrow stdio client lifetime handling
Keep the retained transport ownership needed for stdio child cleanup, but drop the broader AtomicBool closed-state behavior and its targeted tests from this PR. Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -2,7 +2,6 @@ use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
@@ -19,6 +18,7 @@ use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::connection::JsonRpcConnection;
|
||||
@@ -221,9 +221,10 @@ where
|
||||
pub(crate) struct RpcClient {
|
||||
write_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
pending: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
|
||||
// This flips before the ordered RPC reader drains pending requests, so new
|
||||
// calls fail instead of registering work that can never complete.
|
||||
closed: Arc<AtomicBool>,
|
||||
// Shared transport state from `JsonRpcConnection`. Calls use this to fail
|
||||
// immediately when the socket closes, even if no JSON-RPC error response
|
||||
// can be delivered for their request id.
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
next_request_id: AtomicI64,
|
||||
transport_tasks: Vec<JoinHandle<()>>,
|
||||
_transport: JsonRpcTransport,
|
||||
@@ -235,39 +236,40 @@ impl RpcClient {
|
||||
let JsonRpcConnection {
|
||||
outgoing_tx: write_tx,
|
||||
incoming_rx: mut incoming_rx,
|
||||
disconnected_rx: _,
|
||||
disconnected_rx,
|
||||
task_handles: transport_tasks,
|
||||
transport,
|
||||
} = connection;
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
let closed = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let pending_for_reader = Arc::clone(&pending);
|
||||
let closed_for_reader = Arc::clone(&closed);
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let reason = loop {
|
||||
let Some(event) = incoming_rx.recv().await else {
|
||||
break None;
|
||||
};
|
||||
|
||||
while let Some(event) = incoming_rx.recv().await {
|
||||
match event {
|
||||
JsonRpcConnectionEvent::Message(message) => {
|
||||
if let Err(err) =
|
||||
handle_server_message(&pending_for_reader, &event_tx, message).await
|
||||
{
|
||||
break Some(err);
|
||||
let _ = err;
|
||||
break;
|
||||
}
|
||||
}
|
||||
JsonRpcConnectionEvent::MalformedMessage { reason } => {
|
||||
break Some(reason);
|
||||
let _ = reason;
|
||||
break;
|
||||
}
|
||||
JsonRpcConnectionEvent::Disconnected { reason } => {
|
||||
let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await;
|
||||
drain_pending(&pending_for_reader).await;
|
||||
return;
|
||||
}
|
||||
JsonRpcConnectionEvent::Disconnected { reason } => break reason,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
closed_for_reader.store(true, Ordering::SeqCst);
|
||||
let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await;
|
||||
let _ = event_tx
|
||||
.send(RpcClientEvent::Disconnected { reason: None })
|
||||
.await;
|
||||
drain_pending(&pending_for_reader).await;
|
||||
});
|
||||
|
||||
@@ -275,7 +277,7 @@ impl RpcClient {
|
||||
Self {
|
||||
write_tx,
|
||||
pending,
|
||||
closed,
|
||||
disconnected_rx,
|
||||
next_request_id: AtomicI64::new(1),
|
||||
transport_tasks,
|
||||
_transport: transport,
|
||||
@@ -291,12 +293,6 @@ impl RpcClient {
|
||||
params: &P,
|
||||
) -> Result<(), serde_json::Error> {
|
||||
let params = serde_json::to_value(params)?;
|
||||
if self.closed.load(Ordering::SeqCst) {
|
||||
return Err(serde_json::Error::io(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
"JSON-RPC transport closed",
|
||||
)));
|
||||
}
|
||||
self.write_tx
|
||||
.send(JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: method.to_string(),
|
||||
@@ -323,7 +319,7 @@ impl RpcClient {
|
||||
// Registering the pending request and checking disconnect must be
|
||||
// atomic with the reader's drain_pending path. Otherwise a call
|
||||
// can sneak in after the drain and wait forever.
|
||||
if self.closed.load(Ordering::SeqCst) {
|
||||
if *self.disconnected_rx.borrow() {
|
||||
return Err(RpcCallError::Closed);
|
||||
}
|
||||
pending.insert(request_id.clone(), response_tx);
|
||||
@@ -524,9 +520,7 @@ mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
@@ -534,7 +528,6 @@ mod tests {
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::RpcClient;
|
||||
use super::RpcClientEvent;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
|
||||
async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
|
||||
@@ -638,98 +631,4 @@ mod tests {
|
||||
panic!("server task failed: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rpc_client_rejects_new_calls_after_reader_protocol_error() {
|
||||
let (client_stdin, _server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
let connection =
|
||||
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
|
||||
let (client, mut events_rx) = RpcClient::new(connection);
|
||||
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "server/request".to_string(),
|
||||
params: None,
|
||||
trace: None,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let event = timeout(Duration::from_secs(1), events_rx.recv())
|
||||
.await
|
||||
.expect("timed out waiting for disconnect event");
|
||||
match event {
|
||||
Some(RpcClientEvent::Disconnected { reason }) => {
|
||||
assert!(
|
||||
reason
|
||||
.as_deref()
|
||||
.is_some_and(|reason| reason.contains("unexpected JSON-RPC request")),
|
||||
"unexpected disconnect reason: {reason:?}"
|
||||
);
|
||||
}
|
||||
event => panic!("expected disconnect event, got {event:?}"),
|
||||
}
|
||||
|
||||
let result = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.call::<_, serde_json::Value>("after-close", &serde_json::json!({})),
|
||||
)
|
||||
.await
|
||||
.expect("timed out waiting for closed call");
|
||||
|
||||
assert!(matches!(result, Err(super::RpcCallError::Closed)));
|
||||
assert_eq!(client.pending_request_count().await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rpc_client_drains_pending_call_on_transport_eof() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
let connection =
|
||||
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
|
||||
let (client, mut events_rx) = RpcClient::new(connection);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
let request = read_jsonrpc_line(&mut lines).await;
|
||||
match request {
|
||||
JSONRPCMessage::Request(request) if request.method == "will-close" => {}
|
||||
other => panic!("expected will-close request, got {other:?}"),
|
||||
}
|
||||
drop(server_writer);
|
||||
});
|
||||
|
||||
let result = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.call::<_, serde_json::Value>("will-close", &serde_json::json!({})),
|
||||
)
|
||||
.await
|
||||
.expect("timed out waiting for closed call");
|
||||
assert!(matches!(result, Err(super::RpcCallError::Closed)));
|
||||
|
||||
let event = timeout(Duration::from_secs(1), events_rx.recv())
|
||||
.await
|
||||
.expect("timed out waiting for disconnect event");
|
||||
assert!(matches!(
|
||||
event,
|
||||
Some(RpcClientEvent::Disconnected { reason: None })
|
||||
));
|
||||
assert_eq!(client.pending_request_count().await, 0);
|
||||
|
||||
let result = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.call::<_, serde_json::Value>("after-close", &serde_json::json!({})),
|
||||
)
|
||||
.await
|
||||
.expect("timed out waiting for fast closed call");
|
||||
assert!(matches!(result, Err(super::RpcCallError::Closed)));
|
||||
|
||||
let notify = client.notify("after-close", &serde_json::json!({})).await;
|
||||
assert!(notify.is_err());
|
||||
|
||||
server.await.expect("server task should finish");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user