This commit is contained in:
Ahmed Ibrahim
2026-01-11 15:22:00 -08:00
parent a8d69094ad
commit fcde109c7f
11 changed files with 119 additions and 213 deletions

View File

@@ -131,6 +131,7 @@ use codex_core::config::edit::ConfigEditsBuilder;
use codex_core::config::types::McpServerTransportConfig;
use codex_core::default_client::get_codex_user_agent;
use codex_core::error::CodexErr;
use codex_core::exec::ExecExpiration;
use codex_core::exec::ExecParams;
use codex_core::exec_env::create_env;
use codex_core::features::Feature;
@@ -177,6 +178,7 @@ use std::time::Duration;
use tokio::select;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use toml::Value as TomlValue;
use tracing::error;
use tracing::info;
@@ -1204,10 +1206,11 @@ impl CodexMessageProcessor {
let timeout_ms = params
.timeout_ms
.and_then(|timeout_ms| u64::try_from(timeout_ms).ok());
let cancellation_token = CancellationToken::new();
let exec_params = ExecParams {
command: params.command,
cwd,
expiration: timeout_ms.into(),
expiration: ExecExpiration::from_timeout_ms(timeout_ms, cancellation_token),
env,
sandbox_permissions: SandboxPermissions::UseDefault,
justification: None,

View File

@@ -3984,6 +3984,7 @@ mod tests {
#[tokio::test]
async fn rejects_escalated_permissions_when_policy_not_on_request() {
use crate::exec::ExecExpiration;
use crate::exec::ExecParams;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
@@ -4014,7 +4015,7 @@ mod tests {
]
},
cwd: turn_context.cwd.clone(),
expiration: timeout_ms.into(),
expiration: ExecExpiration::from_timeout_ms(Some(timeout_ms), CancellationToken::new()),
env: HashMap::new(),
sandbox_permissions,
justification: Some("test".to_string()),
@@ -4025,7 +4026,7 @@ mod tests {
sandbox_permissions: SandboxPermissions::UseDefault,
command: params.command.clone(),
cwd: params.cwd.clone(),
expiration: timeout_ms.into(),
expiration: ExecExpiration::from_timeout_ms(Some(timeout_ms), CancellationToken::new()),
env: HashMap::new(),
justification: params.justification.clone(),
arg0: None,

View File

@@ -64,94 +64,68 @@ pub struct ExecParams {
/// Mechanism to terminate an exec invocation before it finishes naturally.
#[derive(Debug)]
pub enum ExecExpiration {
Timeout(Duration),
DefaultTimeout,
Cancellation(CancellationToken),
TimeoutOrCancellation {
timeout: Duration,
cancellation: CancellationToken,
},
DefaultTimeoutOrCancellation(CancellationToken),
pub struct ExecExpiration {
pub timeout: TimeoutSpec,
pub cancellation: CancellationToken,
}
impl From<Option<u64>> for ExecExpiration {
fn from(timeout_ms: Option<u64>) -> Self {
timeout_ms.map_or(ExecExpiration::DefaultTimeout, |timeout_ms| {
ExecExpiration::Timeout(Duration::from_millis(timeout_ms))
})
}
}
impl From<u64> for ExecExpiration {
fn from(timeout_ms: u64) -> Self {
ExecExpiration::Timeout(Duration::from_millis(timeout_ms))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimeoutSpec {
Default,
Explicit(Duration),
None,
}
impl ExecExpiration {
pub fn new(timeout: TimeoutSpec, cancellation: CancellationToken) -> Self {
Self {
timeout,
cancellation,
}
}
pub fn default(cancellation: CancellationToken) -> Self {
Self::new(TimeoutSpec::Default, cancellation)
}
pub fn from_timeout(timeout: Duration, cancellation: CancellationToken) -> Self {
Self::new(TimeoutSpec::Explicit(timeout), cancellation)
}
pub fn from_timeout_ms(timeout_ms: Option<u64>, cancellation: CancellationToken) -> Self {
let timeout = timeout_ms.map_or(TimeoutSpec::Default, |timeout_ms| {
TimeoutSpec::Explicit(Duration::from_millis(timeout_ms))
});
Self::new(timeout, cancellation)
}
pub fn cancel_only(cancellation: CancellationToken) -> Self {
Self::new(TimeoutSpec::None, cancellation)
}
async fn wait(self) -> ExecExpirationOutcome {
match self {
ExecExpiration::Timeout(duration) => {
tokio::time::sleep(duration).await;
ExecExpirationOutcome::TimedOut
}
ExecExpiration::DefaultTimeout => {
tokio::time::sleep(Duration::from_millis(DEFAULT_EXEC_COMMAND_TIMEOUT_MS)).await;
ExecExpirationOutcome::TimedOut
}
ExecExpiration::Cancellation(cancel) => {
cancel.cancelled().await;
match self.timeout {
TimeoutSpec::Explicit(duration) => tokio::select! {
_ = tokio::time::sleep(duration) => ExecExpirationOutcome::TimedOut,
_ = self.cancellation.cancelled() => ExecExpirationOutcome::Cancelled,
},
TimeoutSpec::Default => tokio::select! {
_ = tokio::time::sleep(Duration::from_millis(DEFAULT_EXEC_COMMAND_TIMEOUT_MS)) => ExecExpirationOutcome::TimedOut,
_ = self.cancellation.cancelled() => ExecExpirationOutcome::Cancelled,
},
TimeoutSpec::None => {
self.cancellation.cancelled().await;
ExecExpirationOutcome::Cancelled
}
ExecExpiration::TimeoutOrCancellation {
timeout,
cancellation,
} => tokio::select! {
_ = tokio::time::sleep(timeout) => ExecExpirationOutcome::TimedOut,
_ = cancellation.cancelled() => ExecExpirationOutcome::Cancelled,
},
ExecExpiration::DefaultTimeoutOrCancellation(cancellation) => tokio::select! {
_ = tokio::time::sleep(Duration::from_millis(DEFAULT_EXEC_COMMAND_TIMEOUT_MS)) => ExecExpirationOutcome::TimedOut,
_ = cancellation.cancelled() => ExecExpirationOutcome::Cancelled,
},
}
}
/// If ExecExpiration is a timeout, returns the timeout in milliseconds.
pub(crate) fn timeout_ms(&self) -> Option<u64> {
match self {
ExecExpiration::Timeout(duration) => Some(duration.as_millis() as u64),
ExecExpiration::DefaultTimeout => Some(DEFAULT_EXEC_COMMAND_TIMEOUT_MS),
ExecExpiration::Cancellation(_) => None,
ExecExpiration::TimeoutOrCancellation { timeout, .. } => {
Some(timeout.as_millis() as u64)
}
ExecExpiration::DefaultTimeoutOrCancellation(_) => {
Some(DEFAULT_EXEC_COMMAND_TIMEOUT_MS)
}
}
}
pub(crate) fn with_cancellation(self, cancellation: CancellationToken) -> Self {
match self {
ExecExpiration::Timeout(timeout) => ExecExpiration::TimeoutOrCancellation {
timeout,
cancellation,
},
ExecExpiration::DefaultTimeout => {
ExecExpiration::DefaultTimeoutOrCancellation(cancellation)
}
ExecExpiration::Cancellation(_) => ExecExpiration::Cancellation(cancellation),
ExecExpiration::TimeoutOrCancellation { timeout, .. } => {
ExecExpiration::TimeoutOrCancellation {
timeout,
cancellation,
}
}
ExecExpiration::DefaultTimeoutOrCancellation(_) => {
ExecExpiration::DefaultTimeoutOrCancellation(cancellation)
}
match self.timeout {
TimeoutSpec::Explicit(duration) => Some(duration.as_millis() as u64),
TimeoutSpec::Default => Some(DEFAULT_EXEC_COMMAND_TIMEOUT_MS),
TimeoutSpec::None => None,
}
}
}
@@ -289,8 +263,7 @@ async fn exec_windows_sandbox(
expiration,
..
} = params;
// TODO(iceweasel-oai): run_windows_sandbox_capture should support all
// variants of ExecExpiration, not just timeout.
// TODO(iceweasel-oai): run_windows_sandbox_capture should respect cancellation tokens.
let timeout_ms = expiration.timeout_ms();
let policy_str = serde_json::to_string(sandbox_policy).map_err(|err| {
@@ -950,7 +923,7 @@ mod tests {
let params = ExecParams {
command,
cwd: std::env::current_dir()?,
expiration: 500.into(),
expiration: ExecExpiration::from_timeout_ms(Some(500), CancellationToken::new()),
env,
sandbox_permissions: SandboxPermissions::UseDefault,
justification: None,
@@ -995,7 +968,7 @@ mod tests {
let params = ExecParams {
command,
cwd: cwd.clone(),
expiration: ExecExpiration::Cancellation(cancel_token),
expiration: ExecExpiration::cancel_only(cancel_token),
env,
sandbox_permissions: SandboxPermissions::UseDefault,
justification: None,

View File

@@ -2,14 +2,13 @@ use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use codex_async_utils::CancelErr;
use codex_async_utils::OrCancelExt;
use codex_protocol::user_input::UserInput;
use tokio_util::sync::CancellationToken;
use tracing::error;
use uuid::Uuid;
use crate::codex::TurnContext;
use crate::exec::ExecExpiration;
use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::exec::StdoutStream;
@@ -103,9 +102,12 @@ impl SessionTask for UserShellCommandTask {
command: command.clone(),
cwd: cwd.clone(),
env: create_env(&turn_context.shell_environment_policy),
// TODO(zhao-oai): Now that we have ExecExpiration::Cancellation, we
// should use that instead of an "arbitrarily large" timeout here.
expiration: USER_SHELL_TIMEOUT_MS.into(),
// TODO(zhao-oai): consider whether the user shell should use a shorter
// default timeout now that cancellation is wired through ExecExpiration.
expiration: ExecExpiration::from_timeout_ms(
Some(USER_SHELL_TIMEOUT_MS),
cancellation_token.clone(),
),
sandbox: SandboxType::None,
sandbox_permissions: SandboxPermissions::UseDefault,
justification: None,
@@ -119,52 +121,10 @@ impl SessionTask for UserShellCommandTask {
});
let sandbox_policy = SandboxPolicy::DangerFullAccess;
let exec_result = execute_exec_env(exec_env, &sandbox_policy, stdout_stream)
.or_cancel(&cancellation_token)
.await;
let exec_result = execute_exec_env(exec_env, &sandbox_policy, stdout_stream).await;
match exec_result {
Err(CancelErr::Cancelled) => {
let aborted_message = "command aborted by user".to_string();
let exec_output = ExecToolCallOutput {
exit_code: -1,
stdout: StreamOutput::new(String::new()),
stderr: StreamOutput::new(aborted_message.clone()),
aggregated_output: StreamOutput::new(aborted_message.clone()),
duration: Duration::ZERO,
timed_out: false,
};
let output_items = [user_shell_command_record_item(
&raw_command,
&exec_output,
&turn_context,
)];
session
.record_conversation_items(turn_context.as_ref(), &output_items)
.await;
session
.send_event(
turn_context.as_ref(),
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
call_id,
process_id: None,
turn_id: turn_context.sub_id.clone(),
command: command.clone(),
cwd: cwd.clone(),
parsed_cmd: parsed_cmd.clone(),
source: ExecCommandSource::UserShell,
interaction_input: None,
stdout: String::new(),
stderr: aborted_message.clone(),
aggregated_output: aborted_message.clone(),
exit_code: -1,
duration: Duration::ZERO,
formatted_output: aborted_message,
}),
)
.await;
}
Ok(Ok(output)) => {
Ok(output) => {
session
.send_event(
turn_context.as_ref(),
@@ -199,7 +159,7 @@ impl SessionTask for UserShellCommandTask {
.record_conversation_items(turn_context.as_ref(), &output_items)
.await;
}
Ok(Err(err)) => {
Err(err) => {
error!("user shell command failed: {err:?}");
let message = format!("execution error: {err:?}");
let exec_output = ExecToolCallOutput {

View File

@@ -5,6 +5,7 @@ use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use crate::codex::TurnContext;
use crate::exec::ExecExpiration;
use crate::exec::ExecParams;
use crate::exec_env::create_env;
use crate::function_tool::FunctionCallError;
@@ -30,11 +31,15 @@ pub struct ShellHandler;
pub struct ShellCommandHandler;
impl ShellHandler {
fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams {
fn to_exec_params(
params: ShellToolCallParams,
turn_context: &TurnContext,
cancellation_token: CancellationToken,
) -> ExecParams {
ExecParams {
command: params.command,
cwd: turn_context.resolve_path(params.workdir.clone()),
expiration: params.timeout_ms.into(),
expiration: ExecExpiration::from_timeout_ms(params.timeout_ms, cancellation_token),
env: create_env(&turn_context.shell_environment_policy),
sandbox_permissions: params.sandbox_permissions.unwrap_or_default(),
justification: params.justification,
@@ -53,6 +58,7 @@ impl ShellCommandHandler {
params: ShellCommandToolCallParams,
session: &crate::codex::Session,
turn_context: &TurnContext,
cancellation_token: CancellationToken,
) -> ExecParams {
let shell = session.user_shell();
let command = Self::base_command(shell.as_ref(), &params.command, params.login);
@@ -60,7 +66,7 @@ impl ShellCommandHandler {
ExecParams {
command,
cwd: turn_context.resolve_path(params.workdir.clone()),
expiration: params.timeout_ms.into(),
expiration: ExecExpiration::from_timeout_ms(params.timeout_ms, cancellation_token),
env: create_env(&turn_context.shell_environment_policy),
sandbox_permissions: params.sandbox_permissions.unwrap_or_default(),
justification: params.justification,
@@ -108,7 +114,8 @@ impl ToolHandler for ShellHandler {
match payload {
ToolPayload::Function { arguments } => {
let params: ShellToolCallParams = parse_arguments(&arguments)?;
let exec_params = Self::to_exec_params(params, turn.as_ref());
let exec_params =
Self::to_exec_params(params, turn.as_ref(), cancellation_token.clone());
Self::run_exec_like(
tool_name.as_str(),
exec_params,
@@ -118,7 +125,8 @@ impl ToolHandler for ShellHandler {
.await
}
ToolPayload::LocalShell { params } => {
let exec_params = Self::to_exec_params(params, turn.as_ref());
let exec_params =
Self::to_exec_params(params, turn.as_ref(), cancellation_token.clone());
Self::run_exec_like(
tool_name.as_str(),
exec_params,
@@ -176,7 +184,12 @@ impl ToolHandler for ShellCommandHandler {
};
let params: ShellCommandToolCallParams = parse_arguments(&arguments)?;
let exec_params = Self::to_exec_params(params, session.as_ref(), turn.as_ref());
let exec_params = Self::to_exec_params(
params,
session.as_ref(),
turn.as_ref(),
cancellation_token.clone(),
);
ShellHandler::run_exec_like(
tool_name.as_str(),
exec_params,
@@ -320,6 +333,7 @@ mod tests {
use codex_protocol::models::ShellCommandToolCallParams;
use pretty_assertions::assert_eq;
use tokio_util::sync::CancellationToken;
use crate::codex::make_session_and_context;
use crate::exec_env::create_env;
@@ -403,7 +417,12 @@ mod tests {
justification: justification.clone(),
};
let exec_params = ShellCommandHandler::to_exec_params(params, &session, &turn_context);
let exec_params = ShellCommandHandler::to_exec_params(
params,
&session,
&turn_context,
CancellationToken::new(),
);
// ExecParams cannot derive Eq due to the CancellationToken field, so we manually compare the fields.
assert_eq!(exec_params.command, expected_command);

View File

@@ -1,7 +1,4 @@
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use tokio::sync::RwLock;
use tokio_util::either::Either;
use tokio_util::sync::CancellationToken;
@@ -13,13 +10,10 @@ use tracing::trace_span;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::exec::EXEC_ABORTED_MESSAGE;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolPayload;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
#[derive(Clone)]
@@ -54,15 +48,12 @@ impl ToolCallRuntime {
cancellation_token: CancellationToken,
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
let supports_partial_output = Self::supports_partial_output_on_cancel(&call.tool_name);
let router = Arc::clone(&self.router);
let session = Arc::clone(&self.session);
let turn = Arc::clone(&self.turn_context);
let tracker = Arc::clone(&self.tracker);
let lock = Arc::clone(&self.parallel_execution);
let started = Instant::now();
let dispatch_span = trace_span!(
"dispatch_tool_call",
otel.name = call.tool_name.as_str(),
@@ -93,21 +84,16 @@ impl ToolCallRuntime {
.await
});
tokio::select! {
_ = cancellation_token.cancelled() => {
const CANCEL_OUTPUT_GRACE: Duration = Duration::from_millis(2_000);
let secs = started.elapsed().as_secs_f32().max(0.1);
let outcome = tokio::select! {
res = &mut dispatch_future => Some(res),
_ = cancellation_token.cancelled() => None,
};
match outcome {
Some(res) => res,
None => {
dispatch_span.record("aborted", true);
if supports_partial_output {
match tokio::time::timeout(CANCEL_OUTPUT_GRACE, &mut dispatch_future).await {
Ok(res) => res,
Err(_) => Ok(Self::aborted_response(&call, secs)),
}
} else {
Ok(Self::aborted_response(&call, secs))
}
},
res = &mut dispatch_future => res,
dispatch_future.await
}
}
}));
@@ -125,47 +111,4 @@ impl ToolCallRuntime {
}
}
impl ToolCallRuntime {
fn aborted_response(call: &ToolCall, secs: f32) -> ResponseInputItem {
match &call.payload {
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
call_id: call.call_id.clone(),
output: Self::abort_message(call, secs),
},
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
call_id: call.call_id.clone(),
result: Err(Self::abort_message(call, secs)),
},
_ => ResponseInputItem::FunctionCallOutput {
call_id: call.call_id.clone(),
output: FunctionCallOutputPayload {
content: Self::abort_message(call, secs),
..Default::default()
},
},
}
}
fn abort_message(call: &ToolCall, secs: f32) -> String {
match call.tool_name.as_str() {
"shell" | "container.exec" | "local_shell" | "shell_command" | "unified_exec"
| "exec_command" | "write_stdin" => {
format!("Wall time: {secs:.1} seconds\n{EXEC_ABORTED_MESSAGE}")
}
_ => format!("aborted by user after {secs:.1}s"),
}
}
fn supports_partial_output_on_cancel(tool_name: &str) -> bool {
matches!(
tool_name,
"shell"
| "container.exec"
| "local_shell"
| "shell_command"
| "unified_exec"
| "exec_command"
| "write_stdin"
)
}
}
impl ToolCallRuntime {}

View File

@@ -5,6 +5,7 @@
//! `codex --codex-run-as-apply-patch`, and runs under the current
//! `SandboxAttempt` with a minimal environment.
use crate::CODEX_APPLY_PATCH_ARG1;
use crate::exec::ExecExpiration;
use crate::exec::ExecToolCallOutput;
use crate::sandboxing::CommandSpec;
use crate::sandboxing::SandboxPermissions;
@@ -27,6 +28,7 @@ use codex_utils_absolute_path::AbsolutePathBuf;
use futures::future::BoxFuture;
use std::collections::HashMap;
use std::path::PathBuf;
use tokio_util::sync::CancellationToken;
#[derive(Debug)]
pub struct ApplyPatchRequest {
@@ -46,7 +48,10 @@ impl ApplyPatchRuntime {
Self
}
fn build_command_spec(req: &ApplyPatchRequest) -> Result<CommandSpec, ToolError> {
fn build_command_spec(
req: &ApplyPatchRequest,
cancellation_token: CancellationToken,
) -> Result<CommandSpec, ToolError> {
use std::env;
let exe = if let Some(path) = &req.codex_exe {
path.clone()
@@ -59,7 +64,7 @@ impl ApplyPatchRuntime {
program,
args: vec![CODEX_APPLY_PATCH_ARG1.to_string(), req.action.patch.clone()],
cwd: req.action.cwd.clone(),
expiration: req.timeout_ms.into(),
expiration: ExecExpiration::from_timeout_ms(req.timeout_ms, cancellation_token),
// Run apply_patch with a minimal environment for determinism and to avoid leaks.
env: HashMap::new(),
sandbox_permissions: SandboxPermissions::UseDefault,
@@ -149,7 +154,7 @@ impl ToolRuntime<ApplyPatchRequest, ExecToolCallOutput> for ApplyPatchRuntime {
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx<'_>,
) -> Result<ExecToolCallOutput, ToolError> {
let spec = Self::build_command_spec(req)?;
let spec = Self::build_command_spec(req, ctx.cancellation_token.clone())?;
let env = attempt
.env_for(spec)
.map_err(|err| ToolError::Codex(err.into()))?;

View File

@@ -157,7 +157,7 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
};
let expiration =
ExecExpiration::from(req.timeout_ms).with_cancellation(ctx.cancellation_token.clone());
ExecExpiration::from_timeout_ms(req.timeout_ms, ctx.cancellation_token.clone());
let spec = build_command_spec(
&command,
&req.cwd,

View File

@@ -180,7 +180,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
&command,
&req.cwd,
&req.env,
ExecExpiration::DefaultTimeout,
ExecExpiration::default(ctx.cancellation_token.clone()),
req.sandbox_permissions,
req.justification.clone(),
)

View File

@@ -3,6 +3,7 @@
use std::collections::HashMap;
use std::string::ToString;
use codex_core::exec::ExecExpiration;
use codex_core::exec::ExecParams;
use codex_core::exec::ExecToolCallOutput;
use codex_core::exec::SandboxType;
@@ -15,6 +16,7 @@ use tempfile::TempDir;
use codex_core::error::Result;
use codex_core::get_platform_sandbox;
use tokio_util::sync::CancellationToken;
fn skip_test() -> bool {
if std::env::var(CODEX_SANDBOX_ENV_VAR) == Ok("seatbelt".to_string()) {
@@ -33,7 +35,7 @@ async fn run_test_cmd(tmp: TempDir, cmd: Vec<&str>) -> Result<ExecToolCallOutput
let params = ExecParams {
command: cmd.iter().map(ToString::to_string).collect(),
cwd: tmp.path().to_path_buf(),
expiration: 1000.into(),
expiration: ExecExpiration::from_timeout_ms(Some(1000), CancellationToken::new()),
env: HashMap::new(),
sandbox_permissions: SandboxPermissions::UseDefault,
justification: None,

View File

@@ -84,7 +84,7 @@ impl EscalateServer {
command,
],
cwd: PathBuf::from(&workdir),
expiration: ExecExpiration::Cancellation(cancel_rx),
expiration: ExecExpiration::cancel_only(cancel_rx),
env,
sandbox_permissions: SandboxPermissions::UseDefault,
justification: None,