mirror of
https://github.com/openai/codex.git
synced 2026-05-28 23:10:20 +00:00
fix(linux-sandbox): preserve shell cleanup on interruption (#22729)
## Why

Interrupted `shell_command` calls can race with the outer tool-dispatch cancellation path. When that happens, the runtime future may be dropped before the spawned process gets a chance to run `SIGTERM` cleanup. For bwrapd-backed Linux sandbox commands, that can leave behind synthetic protected-path mount bookkeeping (such as `.git/.codex` registrations under `/tmp`) after a TUI interruption.

The relevant cancellation points are the outer dispatch race in [`core/src/tools/parallel.rs`](bd184ba847/codex-rs/core/src/tools/parallel.rs (L91-L132)) and the process shutdown logic in [`core/src/exec.rs`](bd184ba847/codex-rs/core/src/exec.rs (L1367-L1393)).

## What changed

- Keep `shell_command` dispatch alive long enough for the runtime to finish cancellation cleanup instead of immediately returning the synthetic aborted response.
- Fold shell-turn cancellation into the existing `ExecExpiration` path in [`core/src/tools/runtimes/shell.rs`](bd184ba847/codex-rs/core/src/tools/runtimes/shell.rs (L267-L274)), so cancellation and timeout behavior stay centralized.
- On cancellation, send `SIGTERM` first, wait briefly for cleanup to run, then hard-kill any remaining descendants in the original process group.
- Treat `ESRCH` as an already-gone process-group cleanup case in `codex-utils-pty`, which keeps best-effort teardown from surfacing a stale-process race as an error.

## Verification

- `cargo test -p codex-core cancellation`
- Added regression coverage for:
  - `shell_tool_cancellation_waits_for_runtime_cleanup`
  - `process_exec_tool_call_cancellation_allows_sigterm_cleanup`
This commit is contained in:
@@ -56,6 +56,7 @@ const SIGKILL_CODE: i32 = 9;
|
||||
const TIMEOUT_CODE: i32 = 64;
|
||||
const EXIT_CODE_SIGNAL_BASE: i32 = 128; // conventional shell: 128 + signal
|
||||
const EXEC_TIMEOUT_EXIT_CODE: i32 = 124; // conventional timeout exit code
|
||||
// How long a cancelled child is awaited after the process group has been sent
// the terminate signal, before any remaining members are hard-killed.
const CANCELLATION_TERMINATION_GRACE_PERIOD: Duration = Duration::from_millis(50);
|
||||
|
||||
// I/O buffer sizing
|
||||
const READ_CHUNK_SIZE: usize = 8192; // bytes per read
|
||||
@@ -1358,15 +1359,49 @@ async fn consume_output(
|
||||
(exit_status, false)
|
||||
}
|
||||
outcome = &mut expiration_wait => {
|
||||
kill_child_process_group(&mut child)?;
|
||||
child.start_kill()?;
|
||||
let timed_out = matches!(outcome, Some(ExecExpirationOutcome::TimedOut));
|
||||
let exit_status = if timed_out {
|
||||
synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + TIMEOUT_CODE)
|
||||
} else {
|
||||
synthetic_exit_status_for_code(/*code*/ 1)
|
||||
};
|
||||
(exit_status, timed_out)
|
||||
match outcome {
|
||||
Some(ExecExpirationOutcome::TimedOut) => {
|
||||
kill_child_process_group(&mut child)?;
|
||||
child.start_kill()?;
|
||||
(
|
||||
synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + TIMEOUT_CODE),
|
||||
true,
|
||||
)
|
||||
}
|
||||
Some(ExecExpirationOutcome::Cancelled) => {
|
||||
// Let TERM-aware processes run cleanup briefly, then kill any
|
||||
// remaining members of the original process group.
|
||||
let process_group_id = child.id();
|
||||
let should_escalate = if let Some(process_group_id) = process_group_id {
|
||||
codex_utils_pty::process_group::terminate_process_group(process_group_id)?
|
||||
} else {
|
||||
false
|
||||
};
|
||||
match tokio::time::timeout(
|
||||
CANCELLATION_TERMINATION_GRACE_PERIOD,
|
||||
child.wait(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(status) => {
|
||||
status?;
|
||||
if should_escalate
|
||||
&& let Some(process_group_id) = process_group_id
|
||||
{
|
||||
codex_utils_pty::process_group::kill_process_group(
|
||||
process_group_id,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
kill_child_process_group(&mut child)?;
|
||||
child.start_kill()?;
|
||||
}
|
||||
}
|
||||
(synthetic_exit_status_for_code(/*code*/ 1), false)
|
||||
}
|
||||
None => unreachable!("expiration wait only resolves while expiration is active"),
|
||||
}
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
kill_child_process_group(&mut child)?;
|
||||
|
||||
@@ -1128,6 +1128,115 @@ async fn process_exec_tool_call_respects_cancellation_token() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn process_exec_tool_call_cancellation_allows_sigterm_cleanup() -> Result<()> {
|
||||
let temp_dir = tempfile::TempDir::new()?;
|
||||
let ready_marker = temp_dir.path().join("ready");
|
||||
let cleanup_marker = temp_dir.path().join("cleanup");
|
||||
let descendant_pid_marker = temp_dir.path().join("descendant-pid");
|
||||
// The parent handles TERM and records cleanup, while a TERM-ignoring child
|
||||
// proves cancellation still escalates any survivors in the process group.
|
||||
let command = vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
r#"(trap '' TERM; sleep 60) &
|
||||
printf '%s' "$!" > "$DESCENDANT_PID_MARKER"
|
||||
trap 'printf cleaned > "$CLEANUP_MARKER"; exit 0' TERM
|
||||
printf ready > "$READY_MARKER"
|
||||
while :; do sleep 1; done"#
|
||||
.to_string(),
|
||||
];
|
||||
let cwd = codex_utils_absolute_path::AbsolutePathBuf::current_dir()?;
|
||||
let mut env: HashMap<String, String> = std::env::vars().collect();
|
||||
env.insert(
|
||||
"READY_MARKER".to_string(),
|
||||
ready_marker.to_string_lossy().into_owned(),
|
||||
);
|
||||
env.insert(
|
||||
"CLEANUP_MARKER".to_string(),
|
||||
cleanup_marker.to_string_lossy().into_owned(),
|
||||
);
|
||||
env.insert(
|
||||
"DESCENDANT_PID_MARKER".to_string(),
|
||||
descendant_pid_marker.to_string_lossy().into_owned(),
|
||||
);
|
||||
let cancel_token = CancellationToken::new();
|
||||
let cancel_tx = cancel_token.clone();
|
||||
tokio::spawn(async move {
|
||||
for _ in 0..50 {
|
||||
if ready_marker.exists() {
|
||||
cancel_tx.cancel();
|
||||
return;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
cancel_tx.cancel();
|
||||
});
|
||||
let params = ExecParams {
|
||||
command,
|
||||
cwd: cwd.clone(),
|
||||
expiration: ExecExpiration::DefaultTimeout.with_cancellation(cancel_token),
|
||||
capture_policy: ExecCapturePolicy::ShellTool,
|
||||
env,
|
||||
network: None,
|
||||
sandbox_permissions: SandboxPermissions::UseDefault,
|
||||
windows_sandbox_level: codex_protocol::config_types::WindowsSandboxLevel::Disabled,
|
||||
windows_sandbox_private_desktop: false,
|
||||
justification: None,
|
||||
arg0: None,
|
||||
};
|
||||
|
||||
let result = timeout(
|
||||
Duration::from_secs(5),
|
||||
process_exec_tool_call(
|
||||
params,
|
||||
&PermissionProfile::Disabled,
|
||||
&cwd,
|
||||
&None,
|
||||
/*use_legacy_landlock*/ false,
|
||||
/*stdout_stream*/ None,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("cancellation should stop the process promptly");
|
||||
let output = result.expect("cancellation should return a non-timeout exec result");
|
||||
assert!(!output.timed_out);
|
||||
assert_eq!(
|
||||
std::fs::read_to_string(cleanup_marker)?,
|
||||
"cleaned",
|
||||
"SIGTERM cleanup trap should run before cancellation falls back to a hard kill"
|
||||
);
|
||||
let descendant_pid = std::fs::read_to_string(descendant_pid_marker)?
|
||||
.parse::<i32>()
|
||||
.map_err(|error| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("failed to parse descendant pid: {error}"),
|
||||
)
|
||||
})?;
|
||||
let mut killed = false;
|
||||
for _ in 0..20 {
|
||||
if unsafe { libc::kill(descendant_pid, 0) } == -1
|
||||
&& let Some(libc::ESRCH) = std::io::Error::last_os_error().raw_os_error()
|
||||
{
|
||||
killed = true;
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
if !killed {
|
||||
unsafe {
|
||||
libc::kill(descendant_pid, libc::SIGKILL);
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
killed,
|
||||
"TERM-ignoring descendant process with pid {descendant_pid} is still alive"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn long_running_command() -> Vec<String> {
|
||||
vec![
|
||||
|
||||
@@ -10131,6 +10131,76 @@ async fn rejects_escalated_permissions_when_policy_not_on_request() {
|
||||
ExecApprovalRequirement::Skip { .. }
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn shell_tool_cancellation_waits_for_runtime_cleanup() -> anyhow::Result<()> {
|
||||
let session = make_session_with_config(|config| {
|
||||
let cwd = config.cwd.clone();
|
||||
config
|
||||
.permissions
|
||||
.set_legacy_sandbox_policy(SandboxPolicy::DangerFullAccess, cwd.as_path())
|
||||
.expect("test setup should allow sandbox policy");
|
||||
})
|
||||
.await?;
|
||||
let turn_context = session.new_default_turn().await;
|
||||
let session = Arc::new(session);
|
||||
let turn_context = Arc::new(turn_context);
|
||||
let temp_dir = tempfile::TempDir::new()?;
|
||||
let ready_marker = temp_dir.path().join("ready");
|
||||
let cleanup_marker = temp_dir.path().join("cleanup");
|
||||
// Interrupt after the shell starts, then verify dispatch waits for its TERM cleanup trap.
|
||||
let command = format!(
|
||||
r#"trap 'printf cleaned > "{}"; exit 0' TERM
|
||||
printf ready > "{}"
|
||||
while :; do sleep 1; done"#,
|
||||
cleanup_marker.display(),
|
||||
ready_marker.display(),
|
||||
);
|
||||
let item = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "shell_command".to_string(),
|
||||
namespace: None,
|
||||
arguments: serde_json::json!({
|
||||
"command": command,
|
||||
"timeout_ms": 60_000,
|
||||
})
|
||||
.to_string(),
|
||||
call_id: "shell-cleanup-call".to_string(),
|
||||
};
|
||||
let call = ToolRouter::build_tool_call(item)?
|
||||
.expect("shell command response item should build a tool call");
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let cancellation_tx = cancellation_token.clone();
|
||||
let handle = tokio::spawn(
|
||||
test_tool_runtime(Arc::clone(&session), Arc::clone(&turn_context))
|
||||
.handle_tool_call(call, cancellation_token),
|
||||
);
|
||||
|
||||
let mut ready = false;
|
||||
for _ in 0..50 {
|
||||
if ready_marker.exists() {
|
||||
ready = true;
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
if !ready {
|
||||
cancellation_tx.cancel();
|
||||
let _ = timeout(Duration::from_secs(5), handle).await;
|
||||
anyhow::bail!("shell command should reach the ready marker");
|
||||
}
|
||||
|
||||
cancellation_tx.cancel();
|
||||
timeout(Duration::from_secs(5), handle)
|
||||
.await
|
||||
.expect("cancelled shell tool should finish promptly")
|
||||
.expect("shell tool task should join")
|
||||
.expect("cancelled shell tool should return a response item");
|
||||
assert_eq!(std::fs::read_to_string(cleanup_marker)?, "cleaned");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unified_exec_rejects_escalated_permissions_when_policy_not_on_request() {
|
||||
use crate::sandboxing::SandboxPermissions;
|
||||
|
||||
@@ -2,6 +2,7 @@ use codex_features::Feature;
|
||||
use codex_protocol::models::ShellCommandToolCallParams;
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::sync::Arc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::exec::ExecParams;
|
||||
use crate::exec_policy::ExecApprovalRequest;
|
||||
@@ -44,6 +45,7 @@ fn shell_command_payload_command(payload: &ToolPayload) -> Option<String> {
|
||||
struct RunExecLikeArgs {
|
||||
tool_name: ToolName,
|
||||
exec_params: ExecParams,
|
||||
cancellation_token: CancellationToken,
|
||||
hook_command: String,
|
||||
shell_type: Option<ShellType>,
|
||||
additional_permissions: Option<AdditionalPermissionProfile>,
|
||||
@@ -59,6 +61,7 @@ async fn run_exec_like(args: RunExecLikeArgs) -> Result<FunctionToolOutput, Func
|
||||
let RunExecLikeArgs {
|
||||
tool_name,
|
||||
exec_params,
|
||||
cancellation_token,
|
||||
hook_command,
|
||||
shell_type,
|
||||
additional_permissions,
|
||||
@@ -183,6 +186,7 @@ async fn run_exec_like(args: RunExecLikeArgs) -> Result<FunctionToolOutput, Func
|
||||
hook_command,
|
||||
cwd: exec_params.cwd.clone(),
|
||||
timeout_ms: exec_params.expiration.timeout_ms(),
|
||||
cancellation_token,
|
||||
env: exec_params.env.clone(),
|
||||
explicit_env_overrides,
|
||||
network: exec_params.network.clone(),
|
||||
|
||||
@@ -148,6 +148,7 @@ impl ToolExecutor<ToolInvocation> for ShellCommandHandler {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
cancellation_token,
|
||||
tracker,
|
||||
call_id,
|
||||
payload,
|
||||
@@ -185,6 +186,7 @@ impl ToolExecutor<ToolInvocation> for ShellCommandHandler {
|
||||
run_exec_like(RunExecLikeArgs {
|
||||
tool_name,
|
||||
exec_params,
|
||||
cancellation_token,
|
||||
hook_command: params.command,
|
||||
shell_type,
|
||||
additional_permissions: params.additional_permissions.clone(),
|
||||
@@ -205,6 +207,10 @@ impl CoreToolRuntime for ShellCommandHandler {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
/// Always `true` for this handler.
fn waits_for_runtime_cancellation(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
|
||||
shell_command_payload_command(&invocation.payload).map(|command| PreToolUsePayload {
|
||||
tool_name: HookToolName::bash(),
|
||||
|
||||
@@ -92,6 +92,7 @@ impl ToolCallRuntime {
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
let invocation_cancellation_token = cancellation_token.clone();
|
||||
let wait_for_runtime_cancellation = self.router.tool_waits_for_runtime_cancellation(&call);
|
||||
let started = Instant::now();
|
||||
let abort_session = Arc::clone(&session);
|
||||
let abort_source = source.clone();
|
||||
@@ -140,23 +141,35 @@ impl ToolCallRuntime {
|
||||
} else {
|
||||
let secs = started.elapsed().as_secs_f32().max(0.1);
|
||||
abort_dispatch_span.record("aborted", true);
|
||||
handle.abort();
|
||||
match handle.await {
|
||||
Ok(result) => result,
|
||||
Err(err) if err.is_cancelled() => {
|
||||
let response = Self::aborted_response(&call, secs);
|
||||
notify_tool_aborted(
|
||||
abort_session.as_ref(),
|
||||
abort_turn.as_ref(),
|
||||
call.call_id.as_str(),
|
||||
&call.tool_name,
|
||||
abort_source,
|
||||
)
|
||||
.await;
|
||||
Ok(response)
|
||||
if wait_for_runtime_cancellation {
|
||||
if terminal_outcome_reached.swap(true, Ordering::AcqRel) {
|
||||
return handle.await.map_err(Self::tool_task_join_error)?;
|
||||
}
|
||||
// The abort owns the terminal outcome; await only so
|
||||
// the runtime can finish process teardown.
|
||||
match handle.await {
|
||||
Ok(_) => {}
|
||||
Err(err) if err.is_cancelled() => {}
|
||||
Err(err) => return Err(Self::tool_task_join_error(err)),
|
||||
}
|
||||
} else {
|
||||
handle.abort();
|
||||
match handle.await {
|
||||
Ok(result) => return result,
|
||||
Err(err) if err.is_cancelled() => {}
|
||||
Err(err) => return Err(Self::tool_task_join_error(err)),
|
||||
}
|
||||
Err(err) => Err(Self::tool_task_join_error(err)),
|
||||
}
|
||||
let response = Self::aborted_response(&call, secs);
|
||||
notify_tool_aborted(
|
||||
abort_session.as_ref(),
|
||||
abort_turn.as_ref(),
|
||||
call.call_id.as_str(),
|
||||
&call.tool_name,
|
||||
abort_source,
|
||||
)
|
||||
.await;
|
||||
Ok(response)
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -274,6 +287,85 @@ mod tests {
|
||||
|
||||
impl CoreToolRuntime for ImmediateHandler {}
|
||||
|
||||
struct CancellationCleanupHandler {
|
||||
tool_name: codex_tools::ToolName,
|
||||
started: std::sync::Mutex<Option<oneshot::Sender<()>>>,
|
||||
cleanup_started: std::sync::Mutex<Option<oneshot::Sender<()>>>,
|
||||
allow_cleanup: Arc<Notify>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ToolExecutor<ToolInvocation> for CancellationCleanupHandler {
|
||||
fn tool_name(&self) -> codex_tools::ToolName {
|
||||
self.tool_name.clone()
|
||||
}
|
||||
|
||||
fn spec(&self) -> codex_tools::ToolSpec {
|
||||
codex_tools::ToolSpec::Function(codex_tools::ResponsesApiTool {
|
||||
name: self.tool_name.name.clone(),
|
||||
description: "Cancellation cleanup test tool.".to_string(),
|
||||
strict: false,
|
||||
defer_loading: None,
|
||||
parameters: codex_tools::JsonSchema::default(),
|
||||
output_schema: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
|
||||
let started = self
|
||||
.started
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.take();
|
||||
if let Some(started) = started {
|
||||
let _ = started.send(());
|
||||
}
|
||||
invocation.cancellation_token.cancelled().await;
|
||||
let cleanup_started = self
|
||||
.cleanup_started
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.take();
|
||||
if let Some(cleanup_started) = cleanup_started {
|
||||
let _ = cleanup_started.send(());
|
||||
}
|
||||
self.allow_cleanup.notified().await;
|
||||
Ok(Box::new(FunctionToolOutput::from_text(
|
||||
"cleanup complete".to_string(),
|
||||
Some(false),
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
impl CoreToolRuntime for CancellationCleanupHandler {
|
||||
fn waits_for_runtime_cancellation(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
struct FinishRecorder {
|
||||
records: Arc<std::sync::Mutex<Vec<ToolCallOutcome>>>,
|
||||
}
|
||||
|
||||
impl codex_extension_api::ToolLifecycleContributor for FinishRecorder {
|
||||
fn on_tool_finish<'a>(
|
||||
&'a self,
|
||||
input: codex_extension_api::ToolFinishInput<'a>,
|
||||
) -> codex_extension_api::ToolLifecycleFuture<'a> {
|
||||
let records = Arc::clone(&self.records);
|
||||
let outcome = input.outcome;
|
||||
Box::pin(async move {
|
||||
records
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.push(outcome);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct BlockingFinishContributor {
|
||||
records: Arc<std::sync::Mutex<Vec<ToolCallOutcome>>>,
|
||||
finish_started: std::sync::Mutex<Option<oneshot::Sender<()>>>,
|
||||
@@ -375,4 +467,75 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancellation_waiting_for_runtime_cleanup_emits_only_aborted_lifecycle()
|
||||
-> anyhow::Result<()> {
|
||||
let (mut session, turn_context) = crate::session::tests::make_session_and_context().await;
|
||||
let records = Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let mut builder =
|
||||
codex_extension_api::ExtensionRegistryBuilder::<crate::config::Config>::new();
|
||||
builder.tool_lifecycle_contributor(Arc::new(FinishRecorder {
|
||||
records: Arc::clone(&records),
|
||||
}));
|
||||
session.services.extensions = Arc::new(builder.build());
|
||||
|
||||
let session = Arc::new(session);
|
||||
let turn_context = Arc::new(turn_context);
|
||||
let tool_name = codex_tools::ToolName::plain("cleanup_tool");
|
||||
let (started_tx, started_rx) = oneshot::channel();
|
||||
let (cleanup_started_tx, cleanup_started_rx) = oneshot::channel();
|
||||
let allow_cleanup = Arc::new(Notify::new());
|
||||
let handler = Arc::new(CancellationCleanupHandler {
|
||||
tool_name: tool_name.clone(),
|
||||
started: std::sync::Mutex::new(Some(started_tx)),
|
||||
cleanup_started: std::sync::Mutex::new(Some(cleanup_started_tx)),
|
||||
allow_cleanup: Arc::clone(&allow_cleanup),
|
||||
}) as Arc<dyn CoreToolRuntime>;
|
||||
let router = Arc::new(ToolRouter::from_parts(
|
||||
ToolRegistry::from_tools([handler]),
|
||||
Vec::new(),
|
||||
));
|
||||
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let runtime = ToolCallRuntime::new(router, session, turn_context, tracker);
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let call = ToolCall {
|
||||
tool_name,
|
||||
call_id: "call-1".to_string(),
|
||||
payload: ToolPayload::Function {
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let response_task =
|
||||
tokio::spawn(runtime.handle_tool_call(call, cancellation_token.clone()));
|
||||
started_rx.await.expect("handler should start");
|
||||
cancellation_token.cancel();
|
||||
cleanup_started_rx
|
||||
.await
|
||||
.expect("handler should start cleanup");
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
allow_cleanup.notify_one();
|
||||
|
||||
let response = tokio::time::timeout(Duration::from_secs(1), response_task)
|
||||
.await
|
||||
.expect("timed out waiting for tool response")
|
||||
.expect("tool response task should join")?;
|
||||
let ResponseInputItem::FunctionCallOutput { output, .. } = response else {
|
||||
anyhow::bail!("cancelled tool should return function output");
|
||||
};
|
||||
let FunctionCallOutputBody::Text(text) = output.body else {
|
||||
anyhow::bail!("cancelled tool output should be text");
|
||||
};
|
||||
assert!(text.contains("aborted by user"));
|
||||
|
||||
let actual = records
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.drain(..)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(vec![ToolCallOutcome::Aborted], actual);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +57,12 @@ pub(crate) trait CoreToolRuntime: ToolExecutor<ToolInvocation> {
|
||||
)
|
||||
}
|
||||
|
||||
/// Whether cancellation should let the handler finish teardown before the
|
||||
/// host returns an aborted tool response.
|
||||
///
/// The default implementation returns `false`.
fn waits_for_runtime_cancellation(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn telemetry_tags<'a>(
|
||||
&'a self,
|
||||
_invocation: &'a ToolInvocation,
|
||||
@@ -284,6 +290,10 @@ impl CoreToolRuntime for ExposureOverride {
|
||||
self.handler.matches_kind(payload)
|
||||
}
|
||||
|
||||
/// Forwards the answer of `self.handler` unchanged.
fn waits_for_runtime_cancellation(&self) -> bool {
|
||||
self.handler.waits_for_runtime_cancellation()
|
||||
}
|
||||
|
||||
fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
|
||||
self.handler.pre_tool_use_payload(invocation)
|
||||
}
|
||||
@@ -381,6 +391,11 @@ impl ToolRegistry {
|
||||
Some(tool.supports_parallel_tool_calls())
|
||||
}
|
||||
|
||||
/// Looks up the tool registered under `name` and reports its
/// `waits_for_runtime_cancellation` answer, or `None` when the lookup fails.
pub(crate) fn waits_for_runtime_cancellation(&self, name: &ToolName) -> Option<bool> {
    self.tool(name)
        .map(|tool| tool.waits_for_runtime_cancellation())
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn dispatch_any(
|
||||
&self,
|
||||
@@ -497,10 +512,12 @@ impl ToolRegistry {
|
||||
PreToolUseHookResult::Blocked(message) => {
|
||||
let err = FunctionCallError::RespondToModel(message);
|
||||
dispatch_trace.record_failed(&err);
|
||||
if let Some(terminal_outcome_reached) = &terminal_outcome_reached {
|
||||
terminal_outcome_reached.store(true, Ordering::Release);
|
||||
}
|
||||
notify_tool_finish(&invocation, ToolCallOutcome::Blocked).await;
|
||||
notify_tool_finish_if_unclaimed(
|
||||
&invocation,
|
||||
terminal_outcome_reached.as_deref(),
|
||||
ToolCallOutcome::Blocked,
|
||||
)
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
PreToolUseHookResult::Continue {
|
||||
@@ -511,11 +528,9 @@ impl ToolRegistry {
|
||||
}
|
||||
Err(err) => {
|
||||
dispatch_trace.record_failed(&err);
|
||||
if let Some(terminal_outcome_reached) = &terminal_outcome_reached {
|
||||
terminal_outcome_reached.store(true, Ordering::Release);
|
||||
}
|
||||
notify_tool_finish(
|
||||
notify_tool_finish_if_unclaimed(
|
||||
&invocation,
|
||||
terminal_outcome_reached.as_deref(),
|
||||
ToolCallOutcome::Failed {
|
||||
handler_executed: false,
|
||||
},
|
||||
@@ -638,18 +653,21 @@ impl ToolRegistry {
|
||||
handler_executed: true,
|
||||
},
|
||||
};
|
||||
if let Some(terminal_outcome_reached) = &terminal_outcome_reached {
|
||||
terminal_outcome_reached.store(true, Ordering::Release);
|
||||
}
|
||||
notify_tool_finish(&invocation, lifecycle_outcome).await;
|
||||
let finished = notify_tool_finish_if_unclaimed(
|
||||
&invocation,
|
||||
terminal_outcome_reached.as_deref(),
|
||||
lifecycle_outcome,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(err) = invocation
|
||||
.session
|
||||
.goal_runtime_apply(GoalRuntimeEvent::ToolCompleted {
|
||||
turn_context: invocation.turn.as_ref(),
|
||||
tool_name: tool_name.name.as_str(),
|
||||
})
|
||||
.await
|
||||
if finished
|
||||
&& let Err(err) = invocation
|
||||
.session
|
||||
.goal_runtime_apply(GoalRuntimeEvent::ToolCompleted {
|
||||
turn_context: invocation.turn.as_ref(),
|
||||
tool_name: tool_name.name.as_str(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to account thread goal progress after tool call: {err}");
|
||||
}
|
||||
@@ -676,6 +694,19 @@ impl ToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
async fn notify_tool_finish_if_unclaimed(
|
||||
invocation: &ToolInvocation,
|
||||
terminal_outcome_reached: Option<&AtomicBool>,
|
||||
outcome: ToolCallOutcome,
|
||||
) -> bool {
|
||||
if terminal_outcome_reached.is_some_and(|reached| reached.swap(true, Ordering::AcqRel)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
notify_tool_finish(invocation, outcome).await;
|
||||
true
|
||||
}
|
||||
|
||||
async fn handle_any_tool(
|
||||
tool: &dyn CoreToolRuntime,
|
||||
invocation: ToolInvocation,
|
||||
|
||||
@@ -86,6 +86,12 @@ impl ToolRouter {
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Whether the tool targeted by `call` wants cancellation to wait for its
/// handler to finish. A tool the registry has no answer for counts as `false`.
pub fn tool_waits_for_runtime_cancellation(&self, call: &ToolCall) -> bool {
    let registered_answer = self
        .registry
        .waits_for_runtime_cancellation(&call.tool_name);
    matches!(registered_answer, Some(true))
}
|
||||
|
||||
#[instrument(level = "trace", skip_all, err)]
|
||||
pub fn build_tool_call(item: ResponseItem) -> Result<Option<ToolCall>, FunctionCallError> {
|
||||
match item {
|
||||
|
||||
@@ -44,6 +44,7 @@ use codex_shell_command::powershell::prefix_powershell_script_with_utf8;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use futures::future::BoxFuture;
|
||||
use std::collections::HashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ShellRequest {
|
||||
@@ -52,6 +53,7 @@ pub struct ShellRequest {
|
||||
pub hook_command: String,
|
||||
pub cwd: AbsolutePathBuf,
|
||||
pub timeout_ms: Option<u64>,
|
||||
pub cancellation_token: CancellationToken,
|
||||
pub env: HashMap<String, String>,
|
||||
pub explicit_env_overrides: HashMap<String, String>,
|
||||
pub network: Option<NetworkProxy>,
|
||||
@@ -265,6 +267,7 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
|
||||
let command =
|
||||
build_sandbox_command(&command, &req.cwd, &env, req.additional_permissions.clone())?;
|
||||
let mut expiration: crate::exec::ExecExpiration = req.timeout_ms.into();
|
||||
expiration = expiration.with_cancellation(req.cancellation_token.clone());
|
||||
if let Some(cancellation) = attempt.network_denial_cancellation_token.clone() {
|
||||
expiration = expiration.with_cancellation(cancellation);
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ pub fn kill_process_group_by_pid(pid: u32) -> io::Result<()> {
|
||||
let pgid = unsafe { libc::getpgid(pid) };
|
||||
if pgid == -1 {
|
||||
let err = io::Error::last_os_error();
|
||||
if err.kind() != ErrorKind::NotFound {
|
||||
if err.kind() != ErrorKind::NotFound && err.raw_os_error() != Some(libc::ESRCH) {
|
||||
return Err(err);
|
||||
}
|
||||
return Ok(());
|
||||
@@ -103,7 +103,7 @@ pub fn kill_process_group_by_pid(pid: u32) -> io::Result<()> {
|
||||
let result = unsafe { libc::killpg(pgid, libc::SIGKILL) };
|
||||
if result == -1 {
|
||||
let err = io::Error::last_os_error();
|
||||
if err.kind() != ErrorKind::NotFound {
|
||||
if err.kind() != ErrorKind::NotFound && err.raw_os_error() != Some(libc::ESRCH) {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
@@ -129,7 +129,7 @@ pub fn terminate_process_group(process_group_id: u32) -> io::Result<bool> {
|
||||
let result = unsafe { libc::killpg(pgid, libc::SIGTERM) };
|
||||
if result == -1 {
|
||||
let err = io::Error::last_os_error();
|
||||
if err.kind() == ErrorKind::NotFound {
|
||||
if err.kind() == ErrorKind::NotFound || err.raw_os_error() == Some(libc::ESRCH) {
|
||||
return Ok(false);
|
||||
}
|
||||
return Err(err);
|
||||
@@ -153,7 +153,7 @@ pub fn kill_process_group(process_group_id: u32) -> io::Result<()> {
|
||||
let result = unsafe { libc::killpg(pgid, libc::SIGKILL) };
|
||||
if result == -1 {
|
||||
let err = io::Error::last_os_error();
|
||||
if err.kind() != ErrorKind::NotFound {
|
||||
if err.kind() != ErrorKind::NotFound && err.raw_os_error() != Some(libc::ESRCH) {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user