Simplify stdio exec-server transport ownership

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
starr-openai
2026-05-05 13:17:06 -07:00
parent 52ca8fa8b8
commit 21faf08349
4 changed files with 81 additions and 123 deletions

View File

@@ -12,6 +12,7 @@ use futures::FutureExt;
use futures::future::BoxFuture;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::OnceCell;
use tokio::sync::mpsc;
use tokio::sync::watch;
@@ -152,6 +153,9 @@ pub(crate) struct Session {
struct Inner {
client: RpcClient,
// Keep the connection alive for any transport-specific owned state such as
// the stdio child process. RpcClient only takes the runtime channels/tasks.
_connection: JsonRpcConnection,
// The remote transport delivers one shared notification stream for every
// process on the connection. Keep a local process_id -> session registry so
// we can turn those connection-global notifications into process wakeups
@@ -191,28 +195,25 @@ pub struct ExecServerClient {
#[derive(Clone)]
pub(crate) struct LazyRemoteExecServerClient {
transport: ExecServerTransport,
client: Arc<Mutex<Option<ExecServerClient>>>,
client: Arc<OnceCell<ExecServerClient>>,
}
impl LazyRemoteExecServerClient {
pub(crate) fn new(transport: ExecServerTransport) -> Self {
Self {
transport,
client: Arc::new(Mutex::new(None)),
client: Arc::new(OnceCell::new()),
}
}
pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
let mut client = self.client.lock().await;
if let Some(client) = client.as_ref()
&& !client.is_disconnected()
{
return Ok(client.clone());
}
let connected = ExecServerClient::connect_for_environment(self.transport.clone()).await?;
*client = Some(connected.clone());
Ok(connected)
self.client
.get_or_try_init(|| {
let transport = self.transport.clone();
async move { ExecServerClient::connect_for_environment(transport).await }
})
.await
.cloned()
}
}
@@ -276,10 +277,6 @@ pub enum ExecServerError {
}
impl ExecServerClient {
fn is_disconnected(&self) -> bool {
self.inner.disconnected_error().is_some() || self.inner.client.is_disconnected()
}
pub async fn initialize(
&self,
options: ExecServerClientConnectOptions,
@@ -429,10 +426,10 @@ impl ExecServerClient {
}
pub(crate) async fn connect(
connection: JsonRpcConnection,
mut connection: JsonRpcConnection,
options: ExecServerClientConnectOptions,
) -> Result<Self, ExecServerError> {
let (rpc_client, mut events_rx) = RpcClient::new(connection);
let (rpc_client, mut events_rx) = RpcClient::new(&mut connection);
let inner = Arc::new_cyclic(|weak| {
let weak = weak.clone();
let reader_task = tokio::spawn(async move {
@@ -467,6 +464,7 @@ impl ExecServerClient {
Inner {
client: rpc_client,
_connection: connection,
sessions: ArcSwap::from_pointee(HashMap::new()),
sessions_write_lock: Mutex::new(()),
disconnected: OnceLock::new(),

View File

@@ -3,9 +3,7 @@ use std::time::Duration;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Child;
use tokio::process::Command;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tracing::debug;
@@ -79,7 +77,6 @@ impl ExecServerClient {
args: StdioExecServerConnectArgs,
) -> Result<Self, ExecServerError> {
let mut child = stdio_command_process(&args.command)
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
@@ -110,60 +107,13 @@ impl ExecServerClient {
Self::connect(
JsonRpcConnection::from_stdio(stdout, stdin, "exec-server stdio command".to_string())
.with_transport_lifetime(Box::new(StdioChildGuard::spawn(child))),
.with_stdio_child(child),
args.into(),
)
.await
}
}
struct StdioChildGuard {
shutdown_tx: Option<oneshot::Sender<()>>,
}
impl StdioChildGuard {
fn spawn(child: Child) -> Self {
let (shutdown_tx, shutdown_rx) = oneshot::channel();
tokio::spawn(supervise_stdio_child(child, shutdown_rx));
Self {
shutdown_tx: Some(shutdown_tx),
}
}
}
impl Drop for StdioChildGuard {
fn drop(&mut self) {
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(());
}
}
}
async fn supervise_stdio_child(mut child: Child, shutdown_rx: oneshot::Receiver<()>) {
let shutdown_requested = tokio::select! {
result = child.wait() => {
if let Err(err) = result {
debug!("failed to wait for exec-server stdio child: {err}");
}
false
}
_ = shutdown_rx => true,
};
if shutdown_requested {
kill_stdio_child(&mut child);
if let Err(err) = child.wait().await {
debug!("failed to wait for exec-server stdio child after shutdown: {err}");
}
}
}
fn kill_stdio_child(child: &mut Child) {
if let Err(err) = child.start_kill() {
debug!("failed to terminate exec-server stdio child: {err}");
}
}
fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command {
let mut command = Command::new(&stdio_command.program);
command.args(&stdio_command.args);

View File

@@ -3,10 +3,12 @@ use futures::SinkExt;
use futures::StreamExt;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::process::Child;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message;
use tracing::debug;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
@@ -15,16 +17,6 @@ use tokio::io::BufWriter;
pub(crate) const CHANNEL_CAPACITY: usize = 128;
pub(crate) type JsonRpcTransportLifetime = Box<dyn Send>;
pub(crate) struct JsonRpcConnectionParts {
pub(crate) outgoing_tx: mpsc::Sender<JSONRPCMessage>,
pub(crate) incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
pub(crate) disconnected_rx: watch::Receiver<bool>,
pub(crate) task_handles: Vec<tokio::task::JoinHandle<()>>,
pub(crate) transport_lifetime: Option<JsonRpcTransportLifetime>,
}
#[derive(Debug)]
pub(crate) enum JsonRpcConnectionEvent {
Message(JSONRPCMessage),
@@ -32,12 +24,24 @@ pub(crate) enum JsonRpcConnectionEvent {
Disconnected { reason: Option<String> },
}
struct StdioTransport {
child: Child,
}
impl Drop for StdioTransport {
fn drop(&mut self) {
if let Err(err) = self.child.start_kill() {
debug!("failed to terminate exec-server stdio child: {err}");
}
}
}
pub(crate) struct JsonRpcConnection {
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
disconnected_rx: watch::Receiver<bool>,
outgoing_tx: Option<mpsc::Sender<JSONRPCMessage>>,
incoming_rx: Option<mpsc::Receiver<JsonRpcConnectionEvent>>,
disconnected_rx: Option<watch::Receiver<bool>>,
task_handles: Vec<tokio::task::JoinHandle<()>>,
transport_lifetime: Option<JsonRpcTransportLifetime>,
_stdio_transport: Option<StdioTransport>,
}
impl JsonRpcConnection {
@@ -124,11 +128,11 @@ impl JsonRpcConnection {
});
Self {
outgoing_tx,
incoming_rx,
disconnected_rx,
outgoing_tx: Some(outgoing_tx),
incoming_rx: Some(incoming_rx),
disconnected_rx: Some(disconnected_rx),
task_handles: vec![reader_task, writer_task],
transport_lifetime: None,
_stdio_transport: None,
}
}
@@ -259,16 +263,38 @@ impl JsonRpcConnection {
});
Self {
outgoing_tx,
incoming_rx,
disconnected_rx,
outgoing_tx: Some(outgoing_tx),
incoming_rx: Some(incoming_rx),
disconnected_rx: Some(disconnected_rx),
task_handles: vec![reader_task, writer_task],
transport_lifetime: None,
_stdio_transport: None,
}
}
pub(crate) fn with_transport_lifetime(mut self, lifetime: JsonRpcTransportLifetime) -> Self {
self.transport_lifetime = Some(lifetime);
pub(crate) fn take_client_runtime(
&mut self,
) -> (
mpsc::Sender<JSONRPCMessage>,
mpsc::Receiver<JsonRpcConnectionEvent>,
watch::Receiver<bool>,
Vec<tokio::task::JoinHandle<()>>,
) {
(
self.outgoing_tx
.take()
.expect("JSON-RPC client runtime already taken"),
self.incoming_rx
.take()
.expect("JSON-RPC client runtime already taken"),
self.disconnected_rx
.take()
.expect("JSON-RPC client runtime already taken"),
std::mem::take(&mut self.task_handles),
)
}
pub(crate) fn with_stdio_child(mut self, child: Child) -> Self {
self._stdio_transport = Some(StdioTransport { child });
self
}
@@ -281,22 +307,15 @@ impl JsonRpcConnection {
Vec<tokio::task::JoinHandle<()>>,
) {
(
self.outgoing_tx,
self.incoming_rx,
self.disconnected_rx,
self.outgoing_tx
.expect("JSON-RPC connection parts already taken"),
self.incoming_rx
.expect("JSON-RPC connection parts already taken"),
self.disconnected_rx
.expect("JSON-RPC connection parts already taken"),
self.task_handles,
)
}
pub(crate) fn into_parts_with_lifetime(self) -> JsonRpcConnectionParts {
JsonRpcConnectionParts {
outgoing_tx: self.outgoing_tx,
incoming_rx: self.incoming_rx,
disconnected_rx: self.disconnected_rx,
task_handles: self.task_handles,
transport_lifetime: self.transport_lifetime,
}
}
}
async fn send_disconnected(

View File

@@ -233,19 +233,16 @@ pub(crate) struct RpcClient {
}
impl RpcClient {
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
let connection_parts = connection.into_parts_with_lifetime();
let write_tx = connection_parts.outgoing_tx;
let mut incoming_rx = connection_parts.incoming_rx;
let disconnected_rx = connection_parts.disconnected_rx;
let transport_tasks = connection_parts.task_handles;
let transport_lifetime = connection_parts.transport_lifetime;
pub(crate) fn new(
connection: &mut JsonRpcConnection,
) -> (Self, mpsc::Receiver<RpcClientEvent>) {
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) =
connection.take_client_runtime();
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
let (event_tx, event_rx) = mpsc::channel(128);
let pending_for_reader = Arc::clone(&pending);
let reader_task = tokio::spawn(async move {
let _transport_lifetime = transport_lifetime;
while let Some(event) = incoming_rx.recv().await {
match event {
JsonRpcConnectionEvent::Message(message) => {
@@ -307,10 +304,6 @@ impl RpcClient {
})
}
pub(crate) fn is_disconnected(&self) -> bool {
*self.disconnected_rx.borrow()
}
pub(crate) async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, RpcCallError>
where
P: Serialize,
@@ -575,11 +568,9 @@ mod tests {
async fn rpc_client_matches_out_of_order_responses_by_request_id() {
let (client_stdin, server_reader) = tokio::io::duplex(4096);
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
let (client, _events_rx) = RpcClient::new(JsonRpcConnection::from_stdio(
client_stdout,
client_stdin,
"test-rpc".to_string(),
));
let mut connection =
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
let (client, _events_rx) = RpcClient::new(&mut connection);
let server = tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();