Simplify exec-server connection ownership

Remove the runtime extraction helpers and make JsonRpcConnection ownership explicit at the destructuring sites. Let the stdio transport clean up through Drop so ExecServerClient no longer needs to call an explicit shutdown hook.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
starr-openai
2026-05-05 16:30:33 -07:00
parent 7557a7307a
commit c72f484068
4 changed files with 49 additions and 90 deletions

View File

@@ -152,8 +152,6 @@ pub(crate) struct Session {
}
struct Inner {
// Keep the underlying transport connection alive for the client lifetime.
connection: JsonRpcConnection,
client: RpcClient,
// The remote transport delivers one shared notification stream for every
// process on the connection. Keep a local process_id -> session registry so
@@ -183,7 +181,6 @@ struct Inner {
impl Drop for Inner {
fn drop(&mut self) {
self.reader_task.abort();
self.connection.shutdown();
}
}
@@ -428,10 +425,10 @@ impl ExecServerClient {
}
pub(crate) async fn connect(
mut connection: JsonRpcConnection,
connection: JsonRpcConnection,
options: ExecServerClientConnectOptions,
) -> Result<Self, ExecServerError> {
let (rpc_client, mut events_rx) = RpcClient::new(&mut connection);
let (rpc_client, mut events_rx) = RpcClient::new(connection);
let inner = Arc::new_cyclic(|weak| {
let weak = weak.clone();
let reader_task = tokio::spawn(async move {
@@ -465,7 +462,6 @@ impl ExecServerClient {
});
Inner {
connection,
client: rpc_client,
sessions: ArcSwap::from_pointee(HashMap::new()),
sessions_write_lock: Mutex::new(()),

View File

@@ -24,7 +24,7 @@ pub(crate) enum JsonRpcConnectionEvent {
Disconnected { reason: Option<String> },
}
enum JsonRpcTransport {
pub(crate) enum JsonRpcTransport {
Plain,
Stdio(StdioTransport),
}
@@ -35,21 +35,14 @@ impl JsonRpcTransport {
child_process: Some(child_process),
})
}
fn shutdown(&mut self) {
match self {
Self::Plain => {}
Self::Stdio(transport) => transport.shutdown(),
}
}
}
struct StdioTransport {
pub(crate) struct StdioTransport {
child_process: Option<Child>,
}
impl StdioTransport {
fn shutdown(&mut self) {
impl Drop for StdioTransport {
fn drop(&mut self) {
let Some(mut child_process) = self.child_process.take() else {
return;
};
@@ -72,29 +65,19 @@ impl StdioTransport {
}
}
struct JsonRpcConnectionRuntime {
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
disconnected_rx: watch::Receiver<bool>,
task_handles: Vec<tokio::task::JoinHandle<()>>,
pub(crate) struct JsonRpcConnectionRuntime {
pub(crate) outgoing_tx: mpsc::Sender<JSONRPCMessage>,
pub(crate) incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
pub(crate) disconnected_rx: watch::Receiver<bool>,
pub(crate) task_handles: Vec<tokio::task::JoinHandle<()>>,
}
pub(crate) struct JsonRpcConnection {
runtime: Option<JsonRpcConnectionRuntime>,
transport: JsonRpcTransport,
}
impl Drop for JsonRpcConnection {
fn drop(&mut self) {
self.shutdown();
}
pub(crate) runtime: JsonRpcConnectionRuntime,
pub(crate) transport: JsonRpcTransport,
}
impl JsonRpcConnection {
pub(crate) fn shutdown(&mut self) {
self.transport.shutdown();
}
pub(crate) fn from_stdio<R, W>(reader: R, writer: W, connection_label: String) -> Self
where
R: AsyncRead + Unpin + Send + 'static,
@@ -178,12 +161,12 @@ impl JsonRpcConnection {
});
Self {
runtime: Some(JsonRpcConnectionRuntime {
runtime: JsonRpcConnectionRuntime {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
}),
},
transport: JsonRpcTransport::Plain,
}
}
@@ -315,60 +298,20 @@ impl JsonRpcConnection {
});
Self {
runtime: Some(JsonRpcConnectionRuntime {
runtime: JsonRpcConnectionRuntime {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
}),
},
transport: JsonRpcTransport::Plain,
}
}
pub(crate) fn take_runtime(
&mut self,
) -> (
mpsc::Sender<JSONRPCMessage>,
mpsc::Receiver<JsonRpcConnectionEvent>,
Vec<tokio::task::JoinHandle<()>>,
) {
let JsonRpcConnectionRuntime {
outgoing_tx,
incoming_rx,
disconnected_rx: _,
task_handles,
} = self.take_runtime_or_panic("JSON-RPC connection runtime already taken");
(outgoing_tx, incoming_rx, task_handles)
}
pub(crate) fn into_parts(
mut self,
) -> (
mpsc::Sender<JSONRPCMessage>,
mpsc::Receiver<JsonRpcConnectionEvent>,
watch::Receiver<bool>,
Vec<tokio::task::JoinHandle<()>>,
) {
let JsonRpcConnectionRuntime {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles,
} = self.take_runtime_or_panic("JSON-RPC connection runtime already taken");
(outgoing_tx, incoming_rx, disconnected_rx, task_handles)
}
pub(crate) fn with_child_process(mut self, child_process: Child) -> Self {
self.transport = JsonRpcTransport::from_child_process(child_process);
self
}
fn take_runtime_or_panic(&mut self, message: &'static str) -> JsonRpcConnectionRuntime {
match self.runtime.take() {
Some(runtime) => runtime,
None => panic!("{message}"),
}
}
}
async fn send_disconnected(

View File

@@ -23,6 +23,8 @@ use tokio::task::JoinHandle;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
use crate::connection::JsonRpcConnectionRuntime;
use crate::connection::JsonRpcTransport;
#[derive(Debug)]
pub(crate) enum RpcCallError {
@@ -225,14 +227,22 @@ pub(crate) struct RpcClient {
closed: Arc<AtomicBool>,
next_request_id: AtomicI64,
transport_tasks: Vec<JoinHandle<()>>,
_transport: JsonRpcTransport,
reader_task: JoinHandle<()>,
}
impl RpcClient {
pub(crate) fn new(
connection: &mut JsonRpcConnection,
) -> (Self, mpsc::Receiver<RpcClientEvent>) {
let (write_tx, mut incoming_rx, transport_tasks) = connection.take_runtime();
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
let JsonRpcConnection {
runtime:
JsonRpcConnectionRuntime {
outgoing_tx: write_tx,
incoming_rx: mut incoming_rx,
disconnected_rx: _,
task_handles: transport_tasks,
},
transport,
} = connection;
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
let (event_tx, event_rx) = mpsc::channel(128);
let closed = Arc::new(AtomicBool::new(false));
@@ -272,6 +282,7 @@ impl RpcClient {
closed,
next_request_id: AtomicI64::new(1),
transport_tasks,
_transport: transport,
reader_task,
},
event_rx,
@@ -570,9 +581,9 @@ mod tests {
async fn rpc_client_matches_out_of_order_responses_by_request_id() {
let (client_stdin, server_reader) = tokio::io::duplex(4096);
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
let mut connection =
let connection =
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
let (client, _events_rx) = RpcClient::new(&mut connection);
let (client, _events_rx) = RpcClient::new(connection);
let server = tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();
@@ -636,9 +647,9 @@ mod tests {
async fn rpc_client_rejects_new_calls_after_reader_protocol_error() {
let (client_stdin, _server_reader) = tokio::io::duplex(4096);
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
let mut connection =
let connection =
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
let (client, mut events_rx) = RpcClient::new(&mut connection);
let (client, mut events_rx) = RpcClient::new(connection);
write_jsonrpc_line(
&mut server_writer,
@@ -681,9 +692,9 @@ mod tests {
async fn rpc_client_drains_pending_call_on_transport_eof() {
let (client_stdin, server_reader) = tokio::io::duplex(4096);
let (server_writer, client_stdout) = tokio::io::duplex(4096);
let mut connection =
let connection =
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
let (client, mut events_rx) = RpcClient::new(&mut connection);
let (client, mut events_rx) = RpcClient::new(connection);
let server = tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();

View File

@@ -8,6 +8,7 @@ use crate::ExecServerRuntimePaths;
use crate::connection::CHANNEL_CAPACITY;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
use crate::connection::JsonRpcConnectionRuntime;
use crate::rpc::RpcNotificationSender;
use crate::rpc::RpcServerOutboundMessage;
use crate::rpc::encode_server_message;
@@ -47,8 +48,16 @@ async fn run_connection(
runtime_paths: ExecServerRuntimePaths,
) {
let router = Arc::new(build_router());
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
connection.into_parts();
let JsonRpcConnection {
runtime:
JsonRpcConnectionRuntime {
outgoing_tx: json_outgoing_tx,
incoming_rx: mut incoming_rx,
disconnected_rx: mut disconnected_rx,
task_handles: connection_tasks,
},
transport: _transport,
} = connection;
let (outgoing_tx, mut outgoing_rx) =
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
let notifications = RpcNotificationSender::new(outgoing_tx.clone());