mirror of
https://github.com/openai/codex.git
synced 2026-04-24 22:54:54 +00:00
feat: use process group to kill the PTY (#12688)
Use the process group kill logic to kill the PTY
This commit is contained in:
@@ -35,10 +35,27 @@ pub fn conpty_supported() -> bool {
|
||||
|
||||
struct PtyChildTerminator {
|
||||
killer: Box<dyn portable_pty::ChildKiller + Send + Sync>,
|
||||
#[cfg(unix)]
|
||||
process_group_id: Option<u32>,
|
||||
}
|
||||
|
||||
impl ChildTerminator for PtyChildTerminator {
|
||||
fn kill(&mut self) -> std::io::Result<()> {
|
||||
#[cfg(unix)]
|
||||
if let Some(process_group_id) = self.process_group_id {
|
||||
// Match the pipe backend's hard-kill behavior so descendant
|
||||
// processes from interactive shells/REPLs do not survive shutdown.
|
||||
// Also try the direct child killer in case the cached PGID is stale.
|
||||
let process_group_kill_result =
|
||||
crate::process_group::kill_process_group(process_group_id);
|
||||
let child_kill_result = self.killer.kill();
|
||||
return match child_kill_result {
|
||||
Ok(()) => Ok(()),
|
||||
Err(err) if err.kind() == ErrorKind::NotFound => process_group_kill_result,
|
||||
Err(err) => process_group_kill_result.or(Err(err)),
|
||||
};
|
||||
}
|
||||
|
||||
self.killer.kill()
|
||||
}
|
||||
}
|
||||
@@ -86,6 +103,11 @@ pub async fn spawn_process(
|
||||
}
|
||||
|
||||
let mut child = pair.slave.spawn_command(command_builder)?;
|
||||
#[cfg(unix)]
|
||||
// portable-pty establishes the spawned PTY child as a new session leader on
|
||||
// Unix, so PID == PGID and we can reuse the pipe backend's process-group
|
||||
// hard-kill semantics for descendants.
|
||||
let process_group_id = child.process_id();
|
||||
let killer = child.clone_killer();
|
||||
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||
@@ -156,7 +178,11 @@ pub async fn spawn_process(
|
||||
writer_tx,
|
||||
output_tx,
|
||||
initial_output_rx,
|
||||
Box::new(PtyChildTerminator { killer }),
|
||||
Box::new(PtyChildTerminator {
|
||||
killer,
|
||||
#[cfg(unix)]
|
||||
process_group_id,
|
||||
}),
|
||||
reader_handle,
|
||||
Vec::new(),
|
||||
writer_handle,
|
||||
|
||||
@@ -144,6 +144,73 @@ async fn wait_for_python_repl_ready(
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn process_exists(pid: i32) -> anyhow::Result<bool> {
|
||||
let result = unsafe { libc::kill(pid, 0) };
|
||||
if result == 0 {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let err = std::io::Error::last_os_error();
|
||||
match err.raw_os_error() {
|
||||
Some(libc::ESRCH) => Ok(false),
|
||||
Some(libc::EPERM) => Ok(true),
|
||||
_ => Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn wait_for_marker_pid(
|
||||
output_rx: &mut tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
marker: &str,
|
||||
timeout_ms: u64,
|
||||
) -> anyhow::Result<i32> {
|
||||
let mut collected = Vec::new();
|
||||
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
|
||||
loop {
|
||||
let now = tokio::time::Instant::now();
|
||||
if now >= deadline {
|
||||
anyhow::bail!(
|
||||
"timed out waiting for marker {marker:?} in PTY output: {:?}",
|
||||
String::from_utf8_lossy(&collected)
|
||||
);
|
||||
}
|
||||
|
||||
let remaining = deadline.saturating_duration_since(now);
|
||||
let chunk = tokio::time::timeout(remaining, output_rx.recv())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("timeout waiting for PTY output"))??;
|
||||
collected.extend_from_slice(&chunk);
|
||||
|
||||
let text = String::from_utf8_lossy(&collected);
|
||||
if let Some(marker_idx) = text.find(marker) {
|
||||
let suffix = &text[marker_idx + marker.len()..];
|
||||
let digits: String = suffix
|
||||
.chars()
|
||||
.skip_while(|ch| !ch.is_ascii_digit())
|
||||
.take_while(char::is_ascii_digit)
|
||||
.collect();
|
||||
if !digits.is_empty() {
|
||||
return Ok(digits.parse()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn wait_for_process_exit(pid: i32, timeout_ms: u64) -> anyhow::Result<bool> {
|
||||
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
|
||||
loop {
|
||||
if !process_exists(pid)? {
|
||||
return Ok(true);
|
||||
}
|
||||
if tokio::time::Instant::now() >= deadline {
|
||||
return Ok(false);
|
||||
}
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn pty_python_repl_emits_output_and_exits() -> anyhow::Result<()> {
|
||||
let Some(python) = find_python() else {
|
||||
@@ -341,3 +408,39 @@ async fn pipe_terminate_aborts_detached_readers() -> anyhow::Result<()> {
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn pty_terminate_kills_background_children_in_same_process_group() -> anyhow::Result<()> {
|
||||
let env_map: HashMap<String, String> = std::env::vars().collect();
|
||||
let marker = "__codex_bg_pid:";
|
||||
let script = format!("sleep 1000 & bg=$!; echo {marker}$bg; wait");
|
||||
let (program, args) = shell_command(&script);
|
||||
let mut spawned = spawn_pty_process(&program, &args, Path::new("."), &env_map, &None).await?;
|
||||
|
||||
let bg_pid = match wait_for_marker_pid(&mut spawned.output_rx, marker, 2_000).await {
|
||||
Ok(pid) => pid,
|
||||
Err(err) => {
|
||||
spawned.session.terminate();
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
assert!(
|
||||
process_exists(bg_pid)?,
|
||||
"expected background child pid {bg_pid} to exist before terminate"
|
||||
);
|
||||
|
||||
spawned.session.terminate();
|
||||
|
||||
let exited = wait_for_process_exit(bg_pid, 3_000).await?;
|
||||
if !exited {
|
||||
let _ = unsafe { libc::kill(bg_pid, libc::SIGKILL) };
|
||||
}
|
||||
|
||||
assert!(
|
||||
exited,
|
||||
"background child pid {bg_pid} survived PTY terminate()"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user