codex: address stdio transport review feedback (#20664)

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
starr-openai
2026-05-07 16:06:20 -07:00
parent 9f125d25cb
commit 26899a2d5b
4 changed files with 209 additions and 32 deletions

View File

@@ -894,6 +894,7 @@ mod tests {
use super::ExecServerClient;
use super::ExecServerClientConnectOptions;
use crate::ProcessId;
use crate::client_api::ExecServerTransportParams;
use crate::client_api::StdioExecServerCommand;
use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
@@ -956,6 +957,26 @@ mod tests {
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
}
#[cfg(not(windows))]
#[tokio::test]
async fn connect_for_transport_initializes_stdio_command() {
let client = ExecServerClient::connect_for_transport(
ExecServerTransportParams::StdioCommand(StdioExecServerCommand {
program: "sh".to_string(),
args: vec![
"-c".to_string(),
"read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(),
],
env: HashMap::new(),
cwd: None,
}),
)
.await
.expect("stdio transport should connect");
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
}
#[cfg(windows)]
#[tokio::test]
async fn connect_stdio_command_initializes_json_rpc_client_on_windows() {
@@ -985,13 +1006,16 @@ mod tests {
async fn dropping_stdio_client_terminates_spawned_process() {
let tempdir = tempfile::tempdir().expect("tempdir should be created");
let pid_file = tempdir.path().join("server.pid");
let child_pid_file = tempdir.path().join("server-child.pid");
let stdio_script = format!(
"read _line; \
echo \"$$\" > {}; \
sleep 60 >/dev/null 2>&1 & echo \"$!\" > {}; \
printf '%s\\n' '{{\"id\":1,\"result\":{{\"sessionId\":\"stdio-test\"}}}}'; \
read _line; \
sleep 60",
wait",
shell_quote(pid_file.as_path()),
shell_quote(child_pid_file.as_path()),
);
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
@@ -1008,14 +1032,20 @@ mod tests {
.await
.expect("stdio client should connect");
let server_pid = read_pid_file(pid_file.as_path()).await;
let child_pid = read_pid_file(child_pid_file.as_path()).await;
assert!(
process_exists(server_pid),
"spawned stdio process should be running before client drop"
);
assert!(
process_exists(child_pid),
"spawned stdio child process should be running before client drop"
);
drop(client);
wait_for_process_exit(server_pid).await;
wait_for_process_exit(child_pid).await;
}
#[cfg(unix)]

View File

@@ -121,5 +121,7 @@ fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command {
if let Some(cwd) = &stdio_command.cwd {
command.current_dir(cwd);
}
#[cfg(unix)]
command.process_group(0);
command
}

View File

@@ -1,3 +1,8 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Duration;
use codex_app_server_protocol::JSONRPCMessage;
use futures::SinkExt;
use futures::StreamExt;
@@ -6,9 +11,11 @@ use tokio::io::AsyncWrite;
use tokio::process::Child;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::time::timeout;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message;
use tracing::debug;
use tracing::warn;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
@@ -16,6 +23,7 @@ use tokio::io::BufReader;
use tokio::io::BufWriter;
pub(crate) const CHANNEL_CAPACITY: usize = 128;
const STDIO_TERMINATION_GRACE_PERIOD: Duration = Duration::from_secs(2);
#[derive(Debug)]
pub(crate) enum JsonRpcConnectionEvent {
@@ -24,46 +32,177 @@ pub(crate) enum JsonRpcConnectionEvent {
Disconnected { reason: Option<String> },
}
#[derive(Clone)]
pub(crate) enum JsonRpcTransport {
Plain,
Stdio { _transport: Box<StdioTransport> },
Stdio { transport: StdioTransport },
}
impl JsonRpcTransport {
fn from_child_process(child_process: Child) -> Self {
Self::Stdio {
_transport: Box::new(StdioTransport {
child_process: Some(child_process),
}),
transport: StdioTransport::spawn(child_process),
}
}
pub(crate) fn terminate(&self) {
match self {
Self::Plain => {}
Self::Stdio { transport } => transport.terminate(),
}
}
}
#[derive(Clone)]
pub(crate) struct StdioTransport {
child_process: Option<Child>,
handle: Arc<StdioTransportHandle>,
}
impl Drop for StdioTransport {
fn drop(&mut self) {
let Some(mut child_process) = self.child_process.take() else {
return;
};
struct StdioTransportHandle {
terminate_tx: watch::Sender<bool>,
terminate_requested: AtomicBool,
}
if let Err(err) = child_process.start_kill() {
debug!("failed to terminate exec-server stdio child: {err}");
impl StdioTransport {
fn spawn(child_process: Child) -> Self {
let (terminate_tx, terminate_rx) = watch::channel(false);
let handle = Arc::new(StdioTransportHandle {
terminate_tx,
terminate_requested: AtomicBool::new(false),
});
spawn_stdio_child_supervisor(child_process, terminate_rx);
Self { handle }
}
fn terminate(&self) {
self.handle.terminate();
}
}
impl StdioTransportHandle {
fn terminate(&self) {
if !self.terminate_requested.swap(true, Ordering::AcqRel) {
let _ = self.terminate_tx.send(true);
}
match tokio::runtime::Handle::try_current() {
Ok(handle) => {
handle.spawn(async move {
if let Err(err) = child_process.wait().await {
debug!("failed to wait for exec-server stdio child: {err}");
}
});
}
}
impl Drop for StdioTransportHandle {
fn drop(&mut self) {
self.terminate();
}
}
fn spawn_stdio_child_supervisor(mut child_process: Child, mut terminate_rx: watch::Receiver<bool>) {
let process_group_id = child_process.id();
tokio::spawn(async move {
tokio::select! {
result = child_process.wait() => {
log_stdio_child_wait_result(result);
kill_process_tree(&mut child_process, process_group_id);
}
Err(err) => {
debug!("failed to wait for exec-server stdio child without a Tokio runtime: {err}");
() = wait_for_stdio_termination(&mut terminate_rx) => {
terminate_stdio_child(&mut child_process, process_group_id).await;
}
}
});
}
async fn wait_for_stdio_termination(terminate_rx: &mut watch::Receiver<bool>) {
loop {
if *terminate_rx.borrow() {
return;
}
if terminate_rx.changed().await.is_err() {
return;
}
}
}
async fn terminate_stdio_child(child_process: &mut Child, process_group_id: Option<u32>) {
terminate_process_tree(child_process, process_group_id);
match timeout(STDIO_TERMINATION_GRACE_PERIOD, child_process.wait()).await {
Ok(result) => {
log_stdio_child_wait_result(result);
}
Err(_) => {
kill_process_tree(child_process, process_group_id);
log_stdio_child_wait_result(child_process.wait().await);
}
}
}
fn terminate_process_tree(child_process: &mut Child, process_group_id: Option<u32>) {
let Some(process_group_id) = process_group_id else {
kill_direct_child(child_process, "terminate");
return;
};
#[cfg(unix)]
if let Err(err) = codex_utils_pty::process_group::terminate_process_group(process_group_id) {
warn!("failed to terminate exec-server stdio process group {process_group_id}: {err}");
kill_direct_child(child_process, "terminate");
}
#[cfg(windows)]
if !kill_windows_process_tree(process_group_id) {
kill_direct_child(child_process, "terminate");
}
#[cfg(not(any(unix, windows)))]
{
let _ = process_group_id;
kill_direct_child(child_process, "terminate");
}
}
fn kill_process_tree(child_process: &mut Child, process_group_id: Option<u32>) {
let Some(process_group_id) = process_group_id else {
kill_direct_child(child_process, "kill");
return;
};
#[cfg(unix)]
if let Err(err) = codex_utils_pty::process_group::kill_process_group(process_group_id) {
warn!("failed to kill exec-server stdio process group {process_group_id}: {err}");
}
#[cfg(windows)]
if !kill_windows_process_tree(process_group_id) {
kill_direct_child(child_process, "kill");
}
#[cfg(not(any(unix, windows)))]
{
let _ = process_group_id;
kill_direct_child(child_process, "kill");
}
}
fn kill_direct_child(child_process: &mut Child, action: &str) {
if let Err(err) = child_process.start_kill() {
debug!("failed to {action} exec-server stdio child: {err}");
}
}
#[cfg(windows)]
fn kill_windows_process_tree(pid: u32) -> bool {
let pid = pid.to_string();
match std::process::Command::new("taskkill")
.args(["/PID", pid.as_str(), "/T", "/F"])
.status()
{
Ok(status) => status.success(),
Err(err) => {
warn!("failed to run taskkill for exec-server stdio process tree {pid}: {err}");
false
}
}
}
fn log_stdio_child_wait_result(result: std::io::Result<std::process::ExitStatus>) {
if let Err(err) = result {
debug!("failed to wait for exec-server stdio child: {err}");
}
}

View File

@@ -227,7 +227,7 @@ pub(crate) struct RpcClient {
disconnected_rx: watch::Receiver<bool>,
next_request_id: AtomicI64,
transport_tasks: Vec<JoinHandle<()>>,
_transport: JsonRpcTransport,
transport: JsonRpcTransport,
reader_task: JoinHandle<()>,
}
@@ -244,33 +244,38 @@ impl RpcClient {
let (event_tx, event_rx) = mpsc::channel(128);
let pending_for_reader = Arc::clone(&pending);
let transport_for_reader = transport.clone();
let reader_task = tokio::spawn(async move {
while let Some(event) = incoming_rx.recv().await {
let disconnect_reason = loop {
let Some(event) = incoming_rx.recv().await else {
break None;
};
match event {
JsonRpcConnectionEvent::Message(message) => {
if let Err(err) =
handle_server_message(&pending_for_reader, &event_tx, message).await
{
let _ = err;
break;
break None;
}
}
JsonRpcConnectionEvent::MalformedMessage { reason } => {
let _ = reason;
break;
break None;
}
JsonRpcConnectionEvent::Disconnected { reason } => {
let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await;
drain_pending(&pending_for_reader).await;
return;
break reason;
}
}
}
};
let _ = event_tx
.send(RpcClientEvent::Disconnected { reason: None })
.send(RpcClientEvent::Disconnected {
reason: disconnect_reason,
})
.await;
drain_pending(&pending_for_reader).await;
transport_for_reader.terminate();
});
(
@@ -280,7 +285,7 @@ impl RpcClient {
disconnected_rx,
next_request_id: AtomicI64::new(1),
transport_tasks,
_transport: transport,
transport,
reader_task,
},
event_rx,
@@ -370,6 +375,7 @@ impl RpcClient {
impl Drop for RpcClient {
fn drop(&mut self) {
self.transport.terminate();
for task in &self.transport_tasks {
task.abort();
}