Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Bolin
cb6f67d284 Merge remote-tracking branch 'origin/main' into shell-process-group-timeout 2025-11-07 16:38:49 -08:00
luca
9238c58460 Kill shell tool process groups on timeout 2025-10-16 10:42:45 -05:00
4 changed files with 60 additions and 16 deletions

View File

@@ -518,6 +518,7 @@ async fn consume_truncated_output(
}
Err(_) => {
// timeout
kill_child_process_group(&mut child)?;
child.start_kill()?;
// Debatable whether `child.wait().await` should be called here.
(synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + TIMEOUT_CODE), true)
@@ -525,6 +526,7 @@ async fn consume_truncated_output(
}
}
_ = tokio::signal::ctrl_c() => {
kill_child_process_group(&mut child)?;
child.start_kill()?;
(synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + SIGKILL_CODE), false)
}
@@ -621,6 +623,38 @@ fn synthetic_exit_status(code: i32) -> ExitStatus {
std::process::ExitStatus::from_raw(code as u32)
}
#[cfg(unix)]
fn kill_child_process_group(child: &mut Child) -> io::Result<()> {
use std::io::ErrorKind;
if let Some(pid) = child.id() {
let pid = pid as libc::pid_t;
let pgid = unsafe { libc::getpgid(pid) };
if pgid == -1 {
let err = std::io::Error::last_os_error();
if err.kind() != ErrorKind::NotFound {
return Err(err);
}
return Ok(());
}
let result = unsafe { libc::killpg(pgid, libc::SIGKILL) };
if result == -1 {
let err = std::io::Error::last_os_error();
if err.kind() != ErrorKind::NotFound {
return Err(err);
}
}
}
Ok(())
}
#[cfg(not(unix))]
fn kill_child_process_group(_: &mut Child) -> io::Result<()> {
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -64,24 +64,31 @@ pub(crate) async fn spawn_child_async(
// any child processes that were spawned as part of a `"shell"` tool call
// to also be terminated.
// This relies on prctl(2), so it only works on Linux.
#[cfg(target_os = "linux")]
#[cfg(unix)]
unsafe {
let parent_pid = libc::getpid();
cmd.pre_exec(move || {
// This prctl call effectively requests, "deliver SIGTERM when my
// current parent dies."
if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM) == -1 {
cmd.pre_exec(|| {
if libc::setpgid(0, 0) == -1 {
return Err(std::io::Error::last_os_error());
}
// Though if there was a race condition and this pre_exec() block is
// run _after_ the parent (i.e., the Codex process) has already
// exited, then parent will be the closest configured "subreaper"
// ancestor process, or PID 1 (init). If the Codex process has exited
// already, so should the child process.
if libc::getppid() != parent_pid {
libc::raise(libc::SIGTERM);
// This relies on prctl(2), so it only works on Linux.
#[cfg(target_os = "linux")]
{
// This prctl call effectively requests, "deliver SIGTERM when my
// current parent dies."
if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM) == -1 {
return Err(std::io::Error::last_os_error());
}
// Though if there was a race condition and this pre_exec() block is
// run _after_ the parent (i.e., the Codex process) has already
// exited, then parent will be the closest configured "subreaper"
// ancestor process, or PID 1 (init). If the Codex process has exited
// already, so should the child process.
if libc::getppid() != parent_pid {
libc::raise(libc::SIGTERM);
}
}
Ok(())
});

View File

@@ -156,7 +156,8 @@ fn create_exec_command_tool() -> ToolSpec {
"yield_time_ms".to_string(),
JsonSchema::Number {
description: Some(
"How long to wait (in milliseconds) for output before yielding.".to_string(),
"Maximum time in milliseconds to wait for output after writing the input (default: 1000)."
.to_string(),
),
},
);
@@ -246,7 +247,9 @@ fn create_shell_tool() -> ToolSpec {
properties.insert(
"timeout_ms".to_string(),
JsonSchema::Number {
description: Some("The timeout for the command in milliseconds".to_string()),
description: Some(
"The timeout for the command in milliseconds (default: 1000).".to_string(),
),
},
);

View File

@@ -298,7 +298,7 @@ pub struct ShellToolCallParams {
pub command: Vec<String>,
pub workdir: Option<String>,
/// This is the maximum time in milliseconds that the command is allowed to run.
/// Maximum time in milliseconds that the command is allowed to run (defaults to 1_000 ms when omitted).
#[serde(alias = "timeout")]
pub timeout_ms: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]