Address stdio exec-server review feedback

Spawn stdio exec-server commands directly from structured argv/env/cwd instead of wrapping a shell string, redact the connection label, and tie the stdio child guard to transport disconnect.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
starr-openai
2026-05-05 10:08:36 -07:00
parent 74e96987b8
commit c00a36e727
7 changed files with 185 additions and 142 deletions

View File

@@ -12,7 +12,6 @@ use futures::FutureExt;
use futures::future::BoxFuture;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::OnceCell;
use tokio::sync::mpsc;
use tokio::sync::watch;
@@ -192,25 +191,28 @@ pub struct ExecServerClient {
#[derive(Clone)]
pub(crate) struct LazyRemoteExecServerClient {
transport: ExecServerTransport,
client: Arc<OnceCell<ExecServerClient>>,
client: Arc<Mutex<Option<ExecServerClient>>>,
}
impl LazyRemoteExecServerClient {
pub(crate) fn new(transport: ExecServerTransport) -> Self {
Self {
transport,
client: Arc::new(OnceCell::new()),
client: Arc::new(Mutex::new(None)),
}
}
pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
self.client
.get_or_try_init(|| {
let transport = self.transport.clone();
async move { ExecServerClient::connect_for_environment(transport).await }
})
.await
.cloned()
let mut client = self.client.lock().await;
if let Some(client) = client.as_ref()
&& !client.is_disconnected()
{
return Ok(client.clone());
}
let connected = ExecServerClient::connect_for_environment(self.transport.clone()).await?;
*client = Some(connected.clone());
Ok(connected)
}
}
@@ -274,6 +276,10 @@ pub enum ExecServerError {
}
impl ExecServerClient {
/// Reports whether this client should be treated as no longer connected.
///
/// Returns `true` when the shared inner state has recorded a disconnect
/// error, and otherwise defers to the underlying RPC client's own
/// disconnected flag.
fn is_disconnected(&self) -> bool {
    if self.inner.disconnected_error().is_some() {
        return true;
    }
    self.inner.client.is_disconnected()
}
pub async fn initialize(
&self,
options: ExecServerClientConnectOptions,
@@ -872,6 +878,7 @@ mod tests {
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
#[cfg(unix)]
use std::path::Path;
#[cfg(unix)]
@@ -890,7 +897,7 @@ mod tests {
use super::ExecServerClient;
use super::ExecServerClientConnectOptions;
use crate::ProcessId;
#[cfg(not(windows))]
use crate::StdioExecServerCommand;
use crate::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::process::ExecProcessEvent;
@@ -933,7 +940,38 @@ mod tests {
#[tokio::test]
async fn connect_stdio_command_initializes_json_rpc_client() {
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
shell_command: "read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(),
command: StdioExecServerCommand {
program: "sh".to_string(),
args: vec![
"-c".to_string(),
"read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(),
],
env: HashMap::new(),
cwd: None,
},
client_name: "stdio-test-client".to_string(),
initialize_timeout: Duration::from_secs(1),
resume_session_id: None,
})
.await
.expect("stdio client should connect");
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
}
#[cfg(windows)]
#[tokio::test]
async fn connect_stdio_command_initializes_json_rpc_client_on_windows() {
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
command: StdioExecServerCommand {
program: "cmd".to_string(),
args: vec![
"/C".to_string(),
"set /p _line= & echo {\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}} & set /p _line= & ping -n 60 127.0.0.1 >nul".to_string(),
],
env: HashMap::new(),
cwd: None,
},
client_name: "stdio-test-client".to_string(),
initialize_timeout: Duration::from_secs(1),
resume_session_id: None,
@@ -946,43 +984,71 @@ mod tests {
#[cfg(unix)]
#[tokio::test]
async fn dropping_stdio_client_terminates_shell_process_group() {
async fn dropping_stdio_client_terminates_spawned_process() {
let tempdir = tempfile::tempdir().expect("tempdir should be created");
let pid_file = tempdir.path().join("child.pid");
let shell_command = format!(
let pid_file = tempdir.path().join("server.pid");
let stdio_script = format!(
"read _line; \
(trap 'exit 0' TERM; while true; do sleep 1; done) & \
child=$!; \
echo \"$child\" > {}; \
echo \"$$\" > {}; \
printf '%s\\n' '{{\"id\":1,\"result\":{{\"sessionId\":\"stdio-test\"}}}}'; \
read _line; \
wait \"$child\"",
sleep 60",
shell_quote(pid_file.as_path()),
);
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
shell_command,
command: StdioExecServerCommand {
program: "sh".to_string(),
args: vec!["-c".to_string(), stdio_script],
env: HashMap::new(),
cwd: None,
},
client_name: "stdio-test-client".to_string(),
initialize_timeout: Duration::from_secs(1),
resume_session_id: None,
})
.await
.expect("stdio client should connect");
let child_pid = read_pid_file(pid_file.as_path()).await;
let server_pid = read_pid_file(pid_file.as_path()).await;
assert!(
process_exists(child_pid),
"wrapper child process should be running before client drop"
process_exists(server_pid),
"spawned stdio process should be running before client drop"
);
drop(client);
for _ in 0..20 {
if !process_exists(child_pid) {
return;
}
sleep(Duration::from_millis(100)).await;
}
panic!("wrapper child process {child_pid} should exit after client drop");
wait_for_process_exit(server_pid).await;
}
/// A stdio server that answers the first request with a non-JSON line must
/// cause the connect call to fail, and the spawned process must then exit.
#[cfg(unix)]
#[tokio::test]
async fn malformed_stdio_message_terminates_spawned_process() {
    let scratch_dir = tempfile::tempdir().expect("tempdir should be created");
    let pid_path = scratch_dir.path().join("server.pid");

    // The script waits for one line on stdin, records its own PID, replies
    // with a line that is not JSON, and then lingers so that it only goes
    // away if something terminates it.
    let script = format!(
        "read _line; \
         echo \"$$\" > {}; \
         printf '%s\\n' 'not-json'; \
         sleep 60",
        shell_quote(pid_path.as_path()),
    );

    let command = StdioExecServerCommand {
        program: "sh".to_string(),
        args: vec!["-c".to_string(), script],
        env: HashMap::new(),
        cwd: None,
    };
    let connect_args = StdioExecServerConnectArgs {
        command,
        client_name: "stdio-test-client".to_string(),
        initialize_timeout: Duration::from_secs(1),
        resume_session_id: None,
    };

    let outcome = ExecServerClient::connect_stdio_command(connect_args).await;
    assert!(outcome.is_err(), "malformed stdio server should not connect");

    // The PID file is written before the malformed reply, so it is expected
    // to exist by the time the connect attempt has failed.
    let server_pid = read_pid_file(pid_path.as_path()).await;
    wait_for_process_exit(server_pid).await;
}
#[cfg(unix)]
@@ -999,6 +1065,17 @@ mod tests {
panic!("pid file {} should be written", path.display());
}
/// Polls until the process identified by `pid` no longer exists.
///
/// Checks up to 20 times, sleeping 100 ms after each check that still finds
/// the process alive.
///
/// # Panics
///
/// Panics if the process is still present after the final poll and sleep.
#[cfg(unix)]
async fn wait_for_process_exit(pid: u32) {
    let mut polls_left = 20;
    while polls_left > 0 {
        if !process_exists(pid) {
            return;
        }
        sleep(Duration::from_millis(100)).await;
        polls_left -= 1;
    }
    panic!("process {pid} should exit");
}
#[cfg(unix)]
fn process_exists(pid: u32) -> bool {
Command::new("kill")

View File

@@ -1,3 +1,5 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use futures::future::BoxFuture;
@@ -28,17 +30,26 @@ pub struct RemoteExecServerConnectArgs {
/// Stdio connection arguments for a command-backed exec-server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StdioExecServerConnectArgs {
pub shell_command: String,
pub command: StdioExecServerCommand,
pub client_name: String,
pub initialize_timeout: Duration,
pub resume_session_id: Option<String>,
}
/// Structured process command used to start an exec-server over stdio.
///
/// The fields describe the process directly (program, argv, environment,
/// working directory) rather than a single shell command string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StdioExecServerCommand {
/// Program to execute; used as the executable of the spawned process.
pub program: String,
/// Arguments passed to `program`, in order.
pub args: Vec<String>,
/// Environment variables set on the spawned process.
/// NOTE(review): these appear to be applied on top of the inherited
/// environment (no clearing is visible here) — confirm against the spawner.
pub env: HashMap<String, String>,
/// Working directory for the spawned process; `None` leaves it unset.
pub cwd: Option<PathBuf>,
}
/// Transport used to connect to a remote exec-server environment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecServerTransport {
WebSocketUrl(String),
StdioShellCommand(String),
StdioCommand(StdioExecServerCommand),
}
/// Sends HTTP requests through a runtime-selected transport.

View File

@@ -1,19 +1,11 @@
use std::process::Stdio;
#[cfg(unix)]
use std::thread::sleep;
#[cfg(unix)]
use std::thread::spawn;
use std::time::Duration;
#[cfg(unix)]
use codex_utils_pty::process_group::kill_process_group;
#[cfg(unix)]
use codex_utils_pty::process_group::terminate_process_group;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Child;
use tokio::process::Command;
use tokio::runtime::Handle;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tracing::debug;
@@ -22,14 +14,13 @@ use tracing::warn;
use crate::ExecServerClient;
use crate::ExecServerError;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerCommand;
use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
const ENVIRONMENT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
const ENVIRONMENT_INITIALIZE_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(unix)]
const STDIO_CHILD_TERM_GRACE_PERIOD: Duration = Duration::from_millis(500);
impl ExecServerClient {
pub(crate) async fn connect_for_environment(
@@ -46,9 +37,9 @@ impl ExecServerClient {
})
.await
}
crate::client_api::ExecServerTransport::StdioShellCommand(shell_command) => {
crate::client_api::ExecServerTransport::StdioCommand(command) => {
Self::connect_stdio_command(StdioExecServerConnectArgs {
shell_command,
command,
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
resume_session_id: None,
@@ -87,15 +78,13 @@ impl ExecServerClient {
pub async fn connect_stdio_command(
args: StdioExecServerConnectArgs,
) -> Result<Self, ExecServerError> {
let shell_command = args.shell_command.clone();
let mut child = shell_command_process(&shell_command)
let mut child = stdio_command_process(&args.command)
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(ExecServerError::Spawn)?;
let process_id = child.id();
let stdin = child.stdin.take().ok_or_else(|| {
ExecServerError::Protocol("spawned exec-server command has no stdin".to_string())
@@ -120,15 +109,8 @@ impl ExecServerClient {
}
Self::connect(
JsonRpcConnection::from_stdio(
stdout,
stdin,
format!("exec-server stdio command `{shell_command}`"),
)
.with_transport_lifetime(Box::new(StdioChildGuard {
child: Some(child),
process_id,
})),
JsonRpcConnection::from_stdio(stdout, stdin, "exec-server stdio command".to_string())
.with_transport_lifetime(Box::new(StdioChildGuard::spawn(child))),
args.into(),
)
.await
@@ -136,70 +118,44 @@ impl ExecServerClient {
}
struct StdioChildGuard {
child: Option<Child>,
process_id: Option<u32>,
shutdown_tx: Option<oneshot::Sender<()>>,
}
impl StdioChildGuard {
    /// Hands `child` to a background supervisor task and returns a guard
    /// holding the sending half of that task's shutdown channel.
    ///
    /// Must be called from within a Tokio runtime, since it uses
    /// `tokio::spawn`.
    fn spawn(child: Child) -> Self {
        let (stop_tx, stop_rx) = oneshot::channel();
        // The supervisor task takes ownership of the child from here on; the
        // guard only keeps the means to signal it.
        tokio::spawn(supervise_stdio_child(child, stop_rx));
        Self {
            shutdown_tx: Some(stop_tx),
        }
    }
}
impl Drop for StdioChildGuard {
fn drop(&mut self) {
let Some(mut child) = self.child.take() else {
return;
};
terminate_stdio_child_process(self.process_id, &mut child);
if let Ok(handle) = Handle::try_current() {
let _wait_task = handle.spawn(wait_stdio_child(child));
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(());
}
}
}
async fn wait_stdio_child(mut child: Child) {
if let Err(err) = child.wait().await {
debug!("failed to wait for exec-server stdio child: {err}");
}
}
#[cfg(unix)]
fn terminate_stdio_child_process(process_group_id: Option<u32>, child: &mut Child) {
let Some(process_group_id) = process_group_id else {
kill_stdio_child(child);
return;
};
let should_escalate = match terminate_process_group(process_group_id) {
Ok(exists) => exists,
Err(err) => {
debug!("failed to terminate exec-server stdio process group {process_group_id}: {err}");
async fn supervise_stdio_child(mut child: Child, shutdown_rx: oneshot::Receiver<()>) {
let shutdown_requested = tokio::select! {
result = child.wait() => {
if let Err(err) = result {
debug!("failed to wait for exec-server stdio child: {err}");
}
false
}
_ = shutdown_rx => true,
};
if should_escalate {
spawn(move || {
sleep(STDIO_CHILD_TERM_GRACE_PERIOD);
if let Err(err) = kill_process_group(process_group_id) {
debug!("failed to kill exec-server stdio process group {process_group_id}: {err}");
}
});
}
}
#[cfg(windows)]
fn terminate_stdio_child_process(process_id: Option<u32>, child: &mut Child) {
if let Some(process_id) = process_id {
let _ = std::process::Command::new("taskkill")
.arg("/PID")
.arg(process_id.to_string())
.arg("/T")
.arg("/F")
.output();
if shutdown_requested {
kill_stdio_child(&mut child);
if let Err(err) = child.wait().await {
debug!("failed to wait for exec-server stdio child after shutdown: {err}");
}
}
kill_stdio_child(child);
}
#[cfg(not(any(unix, windows)))]
fn terminate_stdio_child_process(_process_id: Option<u32>, child: &mut Child) {
kill_stdio_child(child);
}
fn kill_stdio_child(child: &mut Child) {
@@ -208,19 +164,12 @@ fn kill_stdio_child(child: &mut Child) {
}
}
fn shell_command_process(shell_command: &str) -> Command {
#[cfg(windows)]
{
let mut command = Command::new("cmd");
command.arg("/C").arg(shell_command);
command
}
#[cfg(not(windows))]
{
let mut command = Command::new("sh");
command.arg("-lc").arg(shell_command);
command.process_group(0);
command
fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command {
let mut command = Command::new(&stdio_command.program);
command.args(&stdio_command.args);
command.envs(&stdio_command.env);
if let Some(cwd) = &stdio_command.cwd {
command.current_dir(cwd);
}
command
}

View File

@@ -272,7 +272,23 @@ impl JsonRpcConnection {
self
}
pub(crate) fn into_parts(self) -> JsonRpcConnectionParts {
/// Consumes the connection and returns its channel endpoints and tasks.
///
/// The tuple holds, in order: the outgoing message sender, the incoming
/// event receiver, the disconnect watch receiver, and the transport task
/// handles. Any other state held by the connection is not returned.
pub(crate) fn into_parts(
    self,
) -> (
    mpsc::Sender<JSONRPCMessage>,
    mpsc::Receiver<JsonRpcConnectionEvent>,
    watch::Receiver<bool>,
    Vec<tokio::task::JoinHandle<()>>,
) {
    let outgoing = self.outgoing_tx;
    let incoming = self.incoming_rx;
    let disconnected = self.disconnected_rx;
    let tasks = self.task_handles;
    (outgoing, incoming, disconnected, tasks)
}
pub(crate) fn into_parts_with_lifetime(self) -> JsonRpcConnectionParts {
JsonRpcConnectionParts {
outgoing_tx: self.outgoing_tx,
incoming_rx: self.incoming_rx,

View File

@@ -28,6 +28,7 @@ pub use client_api::ExecServerClientConnectOptions;
pub use client_api::ExecServerTransport;
pub use client_api::HttpClient;
pub use client_api::RemoteExecServerConnectArgs;
pub use client_api::StdioExecServerCommand;
pub use client_api::StdioExecServerConnectArgs;
pub use codex_file_system::CopyOptions;
pub use codex_file_system::CreateDirectoryOptions;

View File

@@ -2,7 +2,6 @@ use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
@@ -24,7 +23,6 @@ use tokio::task::JoinHandle;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
use crate::connection::JsonRpcTransportLifetime;
#[derive(Debug)]
pub(crate) enum RpcCallError {
@@ -231,19 +229,12 @@ pub(crate) struct RpcClient {
disconnected_rx: watch::Receiver<bool>,
next_request_id: AtomicI64,
transport_tasks: Vec<JoinHandle<()>>,
_transport_lifetime: Option<TransportLifetime>,
reader_task: JoinHandle<()>,
}
// Holds transport-owned resources, such as a stdio child process, for as long
// as the RPC client owns the underlying connection.
struct TransportLifetime {
_guard: StdMutex<JsonRpcTransportLifetime>,
}
impl RpcClient {
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
let connection_parts = connection.into_parts();
let connection_parts = connection.into_parts_with_lifetime();
let write_tx = connection_parts.outgoing_tx;
let mut incoming_rx = connection_parts.incoming_rx;
let disconnected_rx = connection_parts.disconnected_rx;
@@ -254,6 +245,7 @@ impl RpcClient {
let pending_for_reader = Arc::clone(&pending);
let reader_task = tokio::spawn(async move {
let _transport_lifetime = transport_lifetime;
while let Some(event) = incoming_rx.recv().await {
match event {
JsonRpcConnectionEvent::Message(message) => {
@@ -289,9 +281,6 @@ impl RpcClient {
disconnected_rx,
next_request_id: AtomicI64::new(1),
transport_tasks,
_transport_lifetime: transport_lifetime.map(|lifetime| TransportLifetime {
_guard: StdMutex::new(lifetime),
}),
reader_task,
},
event_rx,
@@ -318,6 +307,10 @@ impl RpcClient {
})
}
/// Returns the current value of the connection's disconnected watch flag.
pub(crate) fn is_disconnected(&self) -> bool {
    let flag = self.disconnected_rx.borrow();
    *flag
}
pub(crate) async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, RpcCallError>
where
P: Serialize,

View File

@@ -47,12 +47,8 @@ async fn run_connection(
runtime_paths: ExecServerRuntimePaths,
) {
let router = Arc::new(build_router());
let connection_parts = connection.into_parts();
let json_outgoing_tx = connection_parts.outgoing_tx;
let mut incoming_rx = connection_parts.incoming_rx;
let mut disconnected_rx = connection_parts.disconnected_rx;
let connection_tasks = connection_parts.task_handles;
let _transport_lifetime = connection_parts.transport_lifetime;
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
connection.into_parts();
let (outgoing_tx, mut outgoing_rx) =
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
let notifications = RpcNotificationSender::new(outgoing_tx.clone());