feat: use process group to kill the PTY (#12688)

Use the process group kill logic to kill the PTY
This commit is contained in:
jif-oai
2026-02-24 16:55:23 +00:00
committed by GitHub
parent 97d0068658
commit 9a8adbf6e5
2 changed files with 130 additions and 1 deletions

View File

@@ -35,10 +35,27 @@ pub fn conpty_supported() -> bool {
struct PtyChildTerminator {
killer: Box<dyn portable_pty::ChildKiller + Send + Sync>,
#[cfg(unix)]
process_group_id: Option<u32>,
}
impl ChildTerminator for PtyChildTerminator {
fn kill(&mut self) -> std::io::Result<()> {
#[cfg(unix)]
if let Some(process_group_id) = self.process_group_id {
// Match the pipe backend's hard-kill behavior so descendant
// processes from interactive shells/REPLs do not survive shutdown.
// Also try the direct child killer in case the cached PGID is stale.
let process_group_kill_result =
crate::process_group::kill_process_group(process_group_id);
let child_kill_result = self.killer.kill();
return match child_kill_result {
Ok(()) => Ok(()),
Err(err) if err.kind() == ErrorKind::NotFound => process_group_kill_result,
Err(err) => process_group_kill_result.or(Err(err)),
};
}
self.killer.kill()
}
}
@@ -86,6 +103,11 @@ pub async fn spawn_process(
}
let mut child = pair.slave.spawn_command(command_builder)?;
#[cfg(unix)]
// portable-pty establishes the spawned PTY child as a new session leader on
// Unix, so PID == PGID and we can reuse the pipe backend's process-group
// hard-kill semantics for descendants.
let process_group_id = child.process_id();
let killer = child.clone_killer();
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
@@ -156,7 +178,11 @@ pub async fn spawn_process(
writer_tx,
output_tx,
initial_output_rx,
Box::new(PtyChildTerminator { killer }),
Box::new(PtyChildTerminator {
killer,
#[cfg(unix)]
process_group_id,
}),
reader_handle,
Vec::new(),
writer_handle,

View File

@@ -144,6 +144,73 @@ async fn wait_for_python_repl_ready(
);
}
#[cfg(unix)]
fn process_exists(pid: i32) -> anyhow::Result<bool> {
let result = unsafe { libc::kill(pid, 0) };
if result == 0 {
return Ok(true);
}
let err = std::io::Error::last_os_error();
match err.raw_os_error() {
Some(libc::ESRCH) => Ok(false),
Some(libc::EPERM) => Ok(true),
_ => Err(err.into()),
}
}
#[cfg(unix)]
async fn wait_for_marker_pid(
output_rx: &mut tokio::sync::broadcast::Receiver<Vec<u8>>,
marker: &str,
timeout_ms: u64,
) -> anyhow::Result<i32> {
let mut collected = Vec::new();
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
loop {
let now = tokio::time::Instant::now();
if now >= deadline {
anyhow::bail!(
"timed out waiting for marker {marker:?} in PTY output: {:?}",
String::from_utf8_lossy(&collected)
);
}
let remaining = deadline.saturating_duration_since(now);
let chunk = tokio::time::timeout(remaining, output_rx.recv())
.await
.map_err(|_| anyhow::anyhow!("timeout waiting for PTY output"))??;
collected.extend_from_slice(&chunk);
let text = String::from_utf8_lossy(&collected);
if let Some(marker_idx) = text.find(marker) {
let suffix = &text[marker_idx + marker.len()..];
let digits: String = suffix
.chars()
.skip_while(|ch| !ch.is_ascii_digit())
.take_while(char::is_ascii_digit)
.collect();
if !digits.is_empty() {
return Ok(digits.parse()?);
}
}
}
}
#[cfg(unix)]
async fn wait_for_process_exit(pid: i32, timeout_ms: u64) -> anyhow::Result<bool> {
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
loop {
if !process_exists(pid)? {
return Ok(true);
}
if tokio::time::Instant::now() >= deadline {
return Ok(false);
}
tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pty_python_repl_emits_output_and_exits() -> anyhow::Result<()> {
let Some(python) = find_python() else {
@@ -341,3 +408,39 @@ async fn pipe_terminate_aborts_detached_readers() -> anyhow::Result<()> {
),
}
}
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pty_terminate_kills_background_children_in_same_process_group() -> anyhow::Result<()> {
let env_map: HashMap<String, String> = std::env::vars().collect();
let marker = "__codex_bg_pid:";
let script = format!("sleep 1000 & bg=$!; echo {marker}$bg; wait");
let (program, args) = shell_command(&script);
let mut spawned = spawn_pty_process(&program, &args, Path::new("."), &env_map, &None).await?;
let bg_pid = match wait_for_marker_pid(&mut spawned.output_rx, marker, 2_000).await {
Ok(pid) => pid,
Err(err) => {
spawned.session.terminate();
return Err(err);
}
};
assert!(
process_exists(bg_pid)?,
"expected background child pid {bg_pid} to exist before terminate"
);
spawned.session.terminate();
let exited = wait_for_process_exit(bg_pid, 3_000).await?;
if !exited {
let _ = unsafe { libc::kill(bg_pid, libc::SIGKILL) };
}
assert!(
exited,
"background child pid {bg_pid} survived PTY terminate()"
);
Ok(())
}