mirror of
https://github.com/openai/codex.git
synced 2026-05-23 20:44:50 +00:00
codex: address stdio transport review feedback (#20664)
Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -894,6 +894,7 @@ mod tests {
|
||||
use super::ExecServerClient;
|
||||
use super::ExecServerClientConnectOptions;
|
||||
use crate::ProcessId;
|
||||
use crate::client_api::ExecServerTransportParams;
|
||||
use crate::client_api::StdioExecServerCommand;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
@@ -956,6 +957,26 @@ mod tests {
|
||||
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn connect_for_transport_initializes_stdio_command() {
|
||||
let client = ExecServerClient::connect_for_transport(
|
||||
ExecServerTransportParams::StdioCommand(StdioExecServerCommand {
|
||||
program: "sh".to_string(),
|
||||
args: vec![
|
||||
"-c".to_string(),
|
||||
"read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(),
|
||||
],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.expect("stdio transport should connect");
|
||||
|
||||
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_command_initializes_json_rpc_client_on_windows() {
|
||||
@@ -985,13 +1006,16 @@ mod tests {
|
||||
async fn dropping_stdio_client_terminates_spawned_process() {
|
||||
let tempdir = tempfile::tempdir().expect("tempdir should be created");
|
||||
let pid_file = tempdir.path().join("server.pid");
|
||||
let child_pid_file = tempdir.path().join("server-child.pid");
|
||||
let stdio_script = format!(
|
||||
"read _line; \
|
||||
echo \"$$\" > {}; \
|
||||
sleep 60 >/dev/null 2>&1 & echo \"$!\" > {}; \
|
||||
printf '%s\\n' '{{\"id\":1,\"result\":{{\"sessionId\":\"stdio-test\"}}}}'; \
|
||||
read _line; \
|
||||
sleep 60",
|
||||
wait",
|
||||
shell_quote(pid_file.as_path()),
|
||||
shell_quote(child_pid_file.as_path()),
|
||||
);
|
||||
|
||||
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
@@ -1008,14 +1032,20 @@ mod tests {
|
||||
.await
|
||||
.expect("stdio client should connect");
|
||||
let server_pid = read_pid_file(pid_file.as_path()).await;
|
||||
let child_pid = read_pid_file(child_pid_file.as_path()).await;
|
||||
assert!(
|
||||
process_exists(server_pid),
|
||||
"spawned stdio process should be running before client drop"
|
||||
);
|
||||
assert!(
|
||||
process_exists(child_pid),
|
||||
"spawned stdio child process should be running before client drop"
|
||||
);
|
||||
|
||||
drop(client);
|
||||
|
||||
wait_for_process_exit(server_pid).await;
|
||||
wait_for_process_exit(child_pid).await;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
|
||||
@@ -121,5 +121,7 @@ fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command {
|
||||
if let Some(cwd) = &stdio_command.cwd {
|
||||
command.current_dir(cwd);
|
||||
}
|
||||
#[cfg(unix)]
|
||||
command.process_group(0);
|
||||
command
|
||||
}
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
@@ -6,9 +11,11 @@ use tokio::io::AsyncWrite;
|
||||
use tokio::process::Child;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::debug;
|
||||
use tracing::warn;
|
||||
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
@@ -16,6 +23,7 @@ use tokio::io::BufReader;
|
||||
use tokio::io::BufWriter;
|
||||
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
const STDIO_TERMINATION_GRACE_PERIOD: Duration = Duration::from_secs(2);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum JsonRpcConnectionEvent {
|
||||
@@ -24,46 +32,177 @@ pub(crate) enum JsonRpcConnectionEvent {
|
||||
Disconnected { reason: Option<String> },
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum JsonRpcTransport {
|
||||
Plain,
|
||||
Stdio { _transport: Box<StdioTransport> },
|
||||
Stdio { transport: StdioTransport },
|
||||
}
|
||||
|
||||
impl JsonRpcTransport {
|
||||
fn from_child_process(child_process: Child) -> Self {
|
||||
Self::Stdio {
|
||||
_transport: Box::new(StdioTransport {
|
||||
child_process: Some(child_process),
|
||||
}),
|
||||
transport: StdioTransport::spawn(child_process),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn terminate(&self) {
|
||||
match self {
|
||||
Self::Plain => {}
|
||||
Self::Stdio { transport } => transport.terminate(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct StdioTransport {
|
||||
child_process: Option<Child>,
|
||||
handle: Arc<StdioTransportHandle>,
|
||||
}
|
||||
|
||||
impl Drop for StdioTransport {
|
||||
fn drop(&mut self) {
|
||||
let Some(mut child_process) = self.child_process.take() else {
|
||||
return;
|
||||
};
|
||||
struct StdioTransportHandle {
|
||||
terminate_tx: watch::Sender<bool>,
|
||||
terminate_requested: AtomicBool,
|
||||
}
|
||||
|
||||
if let Err(err) = child_process.start_kill() {
|
||||
debug!("failed to terminate exec-server stdio child: {err}");
|
||||
impl StdioTransport {
|
||||
fn spawn(child_process: Child) -> Self {
|
||||
let (terminate_tx, terminate_rx) = watch::channel(false);
|
||||
let handle = Arc::new(StdioTransportHandle {
|
||||
terminate_tx,
|
||||
terminate_requested: AtomicBool::new(false),
|
||||
});
|
||||
spawn_stdio_child_supervisor(child_process, terminate_rx);
|
||||
Self { handle }
|
||||
}
|
||||
|
||||
fn terminate(&self) {
|
||||
self.handle.terminate();
|
||||
}
|
||||
}
|
||||
|
||||
impl StdioTransportHandle {
|
||||
fn terminate(&self) {
|
||||
if !self.terminate_requested.swap(true, Ordering::AcqRel) {
|
||||
let _ = self.terminate_tx.send(true);
|
||||
}
|
||||
match tokio::runtime::Handle::try_current() {
|
||||
Ok(handle) => {
|
||||
handle.spawn(async move {
|
||||
if let Err(err) = child_process.wait().await {
|
||||
debug!("failed to wait for exec-server stdio child: {err}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StdioTransportHandle {
|
||||
fn drop(&mut self) {
|
||||
self.terminate();
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_stdio_child_supervisor(mut child_process: Child, mut terminate_rx: watch::Receiver<bool>) {
|
||||
let process_group_id = child_process.id();
|
||||
tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
result = child_process.wait() => {
|
||||
log_stdio_child_wait_result(result);
|
||||
kill_process_tree(&mut child_process, process_group_id);
|
||||
}
|
||||
Err(err) => {
|
||||
debug!("failed to wait for exec-server stdio child without a Tokio runtime: {err}");
|
||||
() = wait_for_stdio_termination(&mut terminate_rx) => {
|
||||
terminate_stdio_child(&mut child_process, process_group_id).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn wait_for_stdio_termination(terminate_rx: &mut watch::Receiver<bool>) {
|
||||
loop {
|
||||
if *terminate_rx.borrow() {
|
||||
return;
|
||||
}
|
||||
if terminate_rx.changed().await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn terminate_stdio_child(child_process: &mut Child, process_group_id: Option<u32>) {
|
||||
terminate_process_tree(child_process, process_group_id);
|
||||
match timeout(STDIO_TERMINATION_GRACE_PERIOD, child_process.wait()).await {
|
||||
Ok(result) => {
|
||||
log_stdio_child_wait_result(result);
|
||||
}
|
||||
Err(_) => {
|
||||
kill_process_tree(child_process, process_group_id);
|
||||
log_stdio_child_wait_result(child_process.wait().await);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn terminate_process_tree(child_process: &mut Child, process_group_id: Option<u32>) {
|
||||
let Some(process_group_id) = process_group_id else {
|
||||
kill_direct_child(child_process, "terminate");
|
||||
return;
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
if let Err(err) = codex_utils_pty::process_group::terminate_process_group(process_group_id) {
|
||||
warn!("failed to terminate exec-server stdio process group {process_group_id}: {err}");
|
||||
kill_direct_child(child_process, "terminate");
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
if !kill_windows_process_tree(process_group_id) {
|
||||
kill_direct_child(child_process, "terminate");
|
||||
}
|
||||
|
||||
#[cfg(not(any(unix, windows)))]
|
||||
{
|
||||
let _ = process_group_id;
|
||||
kill_direct_child(child_process, "terminate");
|
||||
}
|
||||
}
|
||||
|
||||
fn kill_process_tree(child_process: &mut Child, process_group_id: Option<u32>) {
|
||||
let Some(process_group_id) = process_group_id else {
|
||||
kill_direct_child(child_process, "kill");
|
||||
return;
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
if let Err(err) = codex_utils_pty::process_group::kill_process_group(process_group_id) {
|
||||
warn!("failed to kill exec-server stdio process group {process_group_id}: {err}");
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
if !kill_windows_process_tree(process_group_id) {
|
||||
kill_direct_child(child_process, "kill");
|
||||
}
|
||||
|
||||
#[cfg(not(any(unix, windows)))]
|
||||
{
|
||||
let _ = process_group_id;
|
||||
kill_direct_child(child_process, "kill");
|
||||
}
|
||||
}
|
||||
|
||||
fn kill_direct_child(child_process: &mut Child, action: &str) {
|
||||
if let Err(err) = child_process.start_kill() {
|
||||
debug!("failed to {action} exec-server stdio child: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn kill_windows_process_tree(pid: u32) -> bool {
|
||||
let pid = pid.to_string();
|
||||
match std::process::Command::new("taskkill")
|
||||
.args(["/PID", pid.as_str(), "/T", "/F"])
|
||||
.status()
|
||||
{
|
||||
Ok(status) => status.success(),
|
||||
Err(err) => {
|
||||
warn!("failed to run taskkill for exec-server stdio process tree {pid}: {err}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn log_stdio_child_wait_result(result: std::io::Result<std::process::ExitStatus>) {
|
||||
if let Err(err) = result {
|
||||
debug!("failed to wait for exec-server stdio child: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -227,7 +227,7 @@ pub(crate) struct RpcClient {
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
next_request_id: AtomicI64,
|
||||
transport_tasks: Vec<JoinHandle<()>>,
|
||||
_transport: JsonRpcTransport,
|
||||
transport: JsonRpcTransport,
|
||||
reader_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
@@ -244,33 +244,38 @@ impl RpcClient {
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
|
||||
let pending_for_reader = Arc::clone(&pending);
|
||||
let transport_for_reader = transport.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
while let Some(event) = incoming_rx.recv().await {
|
||||
let disconnect_reason = loop {
|
||||
let Some(event) = incoming_rx.recv().await else {
|
||||
break None;
|
||||
};
|
||||
match event {
|
||||
JsonRpcConnectionEvent::Message(message) => {
|
||||
if let Err(err) =
|
||||
handle_server_message(&pending_for_reader, &event_tx, message).await
|
||||
{
|
||||
let _ = err;
|
||||
break;
|
||||
break None;
|
||||
}
|
||||
}
|
||||
JsonRpcConnectionEvent::MalformedMessage { reason } => {
|
||||
let _ = reason;
|
||||
break;
|
||||
break None;
|
||||
}
|
||||
JsonRpcConnectionEvent::Disconnected { reason } => {
|
||||
let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await;
|
||||
drain_pending(&pending_for_reader).await;
|
||||
return;
|
||||
break reason;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let _ = event_tx
|
||||
.send(RpcClientEvent::Disconnected { reason: None })
|
||||
.send(RpcClientEvent::Disconnected {
|
||||
reason: disconnect_reason,
|
||||
})
|
||||
.await;
|
||||
drain_pending(&pending_for_reader).await;
|
||||
transport_for_reader.terminate();
|
||||
});
|
||||
|
||||
(
|
||||
@@ -280,7 +285,7 @@ impl RpcClient {
|
||||
disconnected_rx,
|
||||
next_request_id: AtomicI64::new(1),
|
||||
transport_tasks,
|
||||
_transport: transport,
|
||||
transport,
|
||||
reader_task,
|
||||
},
|
||||
event_rx,
|
||||
@@ -370,6 +375,7 @@ impl RpcClient {
|
||||
|
||||
impl Drop for RpcClient {
|
||||
fn drop(&mut self) {
|
||||
self.transport.terminate();
|
||||
for task in &self.transport_tasks {
|
||||
task.abort();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user