mirror of
https://github.com/openai/codex.git
synced 2026-04-24 22:54:54 +00:00
feat: waiting for an elicitation should not count against a shell tool timeout (#6973)
Previously, we were running into an issue where we would run the `shell` tool call with a timeout of 10s, but it fired an elicitation asking for user approval, the time the user took to respond to the elicitation was counted against the 10s timeout, so the `shell` tool call would fail with a timeout error unless the user is very fast! This PR addresses this issue by introducing a "stopwatch" abstraction that is used to manage the timeout. The idea is: - `Stopwatch::new()` is called with the _real_ timeout of the `shell` tool call. - `process_exec_tool_call()` is called with the `Cancellation` variant of `ExecExpiration` because it should not manage its own timeout in this case - the `Stopwatch` expiration is wired up to the `cancel_rx` passed to `process_exec_tool_call()` - when an elicitation for the `shell` tool call is received, the `Stopwatch` pauses - because it is possible for multiple elicitations to arrive concurrently, it keeps track of the number of "active pauses" and does not resume until that counter goes down to zero I verified that I can test the MCP server using `@modelcontextprotocol/inspector` and specify `git status` as the `command` with a timeout of 500ms and that the elicitation pops up and I have all the time in the world to respond whereas prior to this PR, that would not have been possible. --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/openai/codex/pull/6973). * #7005 * __->__ #6973 * #6972
This commit is contained in:
@@ -49,6 +49,7 @@ tokio = { workspace = true, features = [
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] }
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ mod escalation_policy;
|
||||
mod mcp;
|
||||
mod mcp_escalation_policy;
|
||||
mod socket;
|
||||
mod stopwatch;
|
||||
|
||||
/// Default value of --execve option relative to the current executable.
|
||||
/// Note this must match the name of the binary as specified in Cargo.toml.
|
||||
|
||||
@@ -13,6 +13,7 @@ use codex_core::exec::process_exec_tool_call;
|
||||
use codex_core::get_platform_sandbox;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use tokio::process::Command;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
@@ -24,6 +25,7 @@ use crate::posix::escalate_protocol::SuperExecResult;
|
||||
use crate::posix::escalation_policy::EscalationPolicy;
|
||||
use crate::posix::socket::AsyncDatagramSocket;
|
||||
use crate::posix::socket::AsyncSocket;
|
||||
use codex_core::exec::ExecExpiration;
|
||||
|
||||
pub(crate) struct EscalateServer {
|
||||
bash_path: PathBuf,
|
||||
@@ -48,7 +50,7 @@ impl EscalateServer {
|
||||
command: String,
|
||||
env: HashMap<String, String>,
|
||||
workdir: PathBuf,
|
||||
timeout_ms: Option<u64>,
|
||||
cancel_rx: CancellationToken,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?;
|
||||
let client_socket = escalate_client.into_inner();
|
||||
@@ -79,7 +81,7 @@ impl EscalateServer {
|
||||
command,
|
||||
],
|
||||
cwd: PathBuf::from(&workdir),
|
||||
expiration: timeout_ms.into(),
|
||||
expiration: ExecExpiration::Cancellation(cancel_rx),
|
||||
env,
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
|
||||
@@ -22,6 +22,7 @@ use crate::posix::escalate_server::EscalateServer;
|
||||
use crate::posix::escalate_server::{self};
|
||||
use crate::posix::mcp_escalation_policy::ExecPolicy;
|
||||
use crate::posix::mcp_escalation_policy::McpEscalationPolicy;
|
||||
use crate::posix::stopwatch::Stopwatch;
|
||||
|
||||
/// Path to our patched bash.
|
||||
const CODEX_BASH_PATH_ENV_VAR: &str = "CODEX_BASH_PATH";
|
||||
@@ -87,10 +88,17 @@ impl ExecTool {
|
||||
context: RequestContext<RoleServer>,
|
||||
Parameters(params): Parameters<ExecParams>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let effective_timeout = Duration::from_millis(
|
||||
params
|
||||
.timeout_ms
|
||||
.unwrap_or(codex_core::exec::DEFAULT_EXEC_COMMAND_TIMEOUT_MS),
|
||||
);
|
||||
let stopwatch = Stopwatch::new(effective_timeout);
|
||||
let cancel_token = stopwatch.cancellation_token();
|
||||
let escalate_server = EscalateServer::new(
|
||||
self.bash_path.clone(),
|
||||
self.execve_wrapper.clone(),
|
||||
McpEscalationPolicy::new(self.policy, context),
|
||||
McpEscalationPolicy::new(self.policy, context, stopwatch.clone()),
|
||||
);
|
||||
let result = escalate_server
|
||||
.exec(
|
||||
@@ -98,7 +106,7 @@ impl ExecTool {
|
||||
// TODO: use ShellEnvironmentPolicy
|
||||
std::env::vars().collect(),
|
||||
PathBuf::from(¶ms.workdir),
|
||||
params.timeout_ms,
|
||||
cancel_token,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))?;
|
||||
|
||||
@@ -10,6 +10,7 @@ use rmcp::service::RequestContext;
|
||||
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
use crate::posix::escalation_policy::EscalationPolicy;
|
||||
use crate::posix::stopwatch::Stopwatch;
|
||||
|
||||
/// This is the policy which decides how to handle an exec() call.
|
||||
///
|
||||
@@ -34,11 +35,20 @@ pub(crate) enum ExecPolicyOutcome {
|
||||
pub(crate) struct McpEscalationPolicy {
|
||||
policy: ExecPolicy,
|
||||
context: RequestContext<RoleServer>,
|
||||
stopwatch: Stopwatch,
|
||||
}
|
||||
|
||||
impl McpEscalationPolicy {
|
||||
pub(crate) fn new(policy: ExecPolicy, context: RequestContext<RoleServer>) -> Self {
|
||||
Self { policy, context }
|
||||
pub(crate) fn new(
|
||||
policy: ExecPolicy,
|
||||
context: RequestContext<RoleServer>,
|
||||
stopwatch: Stopwatch,
|
||||
) -> Self {
|
||||
Self {
|
||||
policy,
|
||||
context,
|
||||
stopwatch,
|
||||
}
|
||||
}
|
||||
|
||||
async fn prompt(
|
||||
@@ -54,25 +64,34 @@ impl McpEscalationPolicy {
|
||||
} else {
|
||||
format!("{} {}", file.display(), args)
|
||||
};
|
||||
context
|
||||
.peer
|
||||
.create_elicitation(CreateElicitationRequestParam {
|
||||
message: format!("Allow agent to run `{command}` in `{}`?", workdir.display()),
|
||||
requested_schema: ElicitationSchema::builder()
|
||||
.title("Execution Permission Request")
|
||||
.optional_string_with("reason", |schema| {
|
||||
schema.description("Optional reason for allowing or denying execution")
|
||||
self.stopwatch
|
||||
.pause_for(async {
|
||||
context
|
||||
.peer
|
||||
.create_elicitation(CreateElicitationRequestParam {
|
||||
message: format!(
|
||||
"Allow agent to run `{command}` in `{}`?",
|
||||
workdir.display()
|
||||
),
|
||||
requested_schema: ElicitationSchema::builder()
|
||||
.title("Execution Permission Request")
|
||||
.optional_string_with("reason", |schema| {
|
||||
schema.description(
|
||||
"Optional reason for allowing or denying execution",
|
||||
)
|
||||
})
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
McpError::internal_error(
|
||||
format!("failed to build elicitation schema: {e}"),
|
||||
None,
|
||||
)
|
||||
})?,
|
||||
})
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
McpError::internal_error(
|
||||
format!("failed to build elicitation schema: {e}"),
|
||||
None,
|
||||
)
|
||||
})?,
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
211
codex-rs/exec-server/src/posix/stopwatch.rs
Normal file
211
codex-rs/exec-server/src/posix/stopwatch.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
/// Tracks how much *running* (non-paused) time has elapsed toward a fixed
/// limit. Used so that time spent waiting on a user elicitation does not
/// count against a shell tool call's timeout: the stopwatch is paused while
/// the elicitation is pending and resumed when it completes.
///
/// Cloning is cheap and intentional — clones share the same underlying state
/// via `Arc`, so a clone handed to the escalation policy pauses the same
/// clock the cancellation task is watching.
#[derive(Clone, Debug)]
pub(crate) struct Stopwatch {
    /// Total running time allowed before cancellation tokens fire.
    limit: Duration,
    /// Shared mutable state: accumulated elapsed time and pause bookkeeping.
    inner: Arc<Mutex<StopwatchState>>,
    /// Wakes the deadline task whenever the stopwatch is paused or resumed
    /// so it can recompute the remaining time.
    notify: Arc<Notify>,
}
|
||||
|
||||
/// Mutable state shared between all `Stopwatch` clones and the background
/// deadline task spawned by `cancellation_token()`.
#[derive(Debug)]
struct StopwatchState {
    /// Running time accumulated before the most recent pause. The true
    /// elapsed total is this plus `running_since.elapsed()` when running.
    elapsed: Duration,
    /// When the stopwatch last (re)started running; `None` while paused.
    running_since: Option<Instant>,
    /// Number of overlapping pauses currently in flight (multiple
    /// elicitations may be pending concurrently); the clock resumes only
    /// when this drops back to zero.
    active_pauses: u32,
}
|
||||
|
||||
impl Stopwatch {
|
||||
pub(crate) fn new(limit: Duration) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(StopwatchState {
|
||||
elapsed: Duration::ZERO,
|
||||
running_since: Some(Instant::now()),
|
||||
active_pauses: 0,
|
||||
})),
|
||||
notify: Arc::new(Notify::new()),
|
||||
limit,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn cancellation_token(&self) -> CancellationToken {
|
||||
let limit = self.limit;
|
||||
let token = CancellationToken::new();
|
||||
let cancel = token.clone();
|
||||
let inner = Arc::clone(&self.inner);
|
||||
let notify = Arc::clone(&self.notify);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (remaining, running) = {
|
||||
let guard = inner.lock().await;
|
||||
let elapsed = guard.elapsed
|
||||
+ guard
|
||||
.running_since
|
||||
.map(|since| since.elapsed())
|
||||
.unwrap_or_default();
|
||||
if elapsed >= limit {
|
||||
break;
|
||||
}
|
||||
(limit - elapsed, guard.running_since.is_some())
|
||||
};
|
||||
|
||||
if !running {
|
||||
notify.notified().await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let sleep = tokio::time::sleep(remaining);
|
||||
tokio::pin!(sleep);
|
||||
tokio::select! {
|
||||
_ = &mut sleep => {
|
||||
break;
|
||||
}
|
||||
_ = notify.notified() => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel.cancel();
|
||||
});
|
||||
token
|
||||
}
|
||||
|
||||
/// Runs `fut`, pausing the stopwatch while the future is pending. The clock
|
||||
/// resumes automatically when the future completes. Nested/overlapping
|
||||
/// calls are reference-counted so the stopwatch only resumes when every
|
||||
/// pause is lifted.
|
||||
pub(crate) async fn pause_for<F, T>(&self, fut: F) -> T
|
||||
where
|
||||
F: Future<Output = T>,
|
||||
{
|
||||
self.pause().await;
|
||||
let result = fut.await;
|
||||
self.resume().await;
|
||||
result
|
||||
}
|
||||
|
||||
async fn pause(&self) {
|
||||
let mut guard = self.inner.lock().await;
|
||||
guard.active_pauses += 1;
|
||||
if guard.active_pauses == 1
|
||||
&& let Some(since) = guard.running_since.take()
|
||||
{
|
||||
guard.elapsed += since.elapsed();
|
||||
self.notify.notify_waiters();
|
||||
}
|
||||
}
|
||||
|
||||
async fn resume(&self) {
|
||||
let mut guard = self.inner.lock().await;
|
||||
if guard.active_pauses == 0 {
|
||||
return;
|
||||
}
|
||||
guard.active_pauses -= 1;
|
||||
if guard.active_pauses == 0 && guard.running_since.is_none() {
|
||||
guard.running_since = Some(Instant::now());
|
||||
self.notify.notify_waiters();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::Stopwatch;
    use tokio::time::Duration;
    use tokio::time::Instant;
    use tokio::time::sleep;
    use tokio::time::timeout;

    // NOTE(review): these tests use real wall-clock sleeps, so they can be
    // slow/flaky under heavy CI load; consider
    // `#[tokio::test(start_paused = true)]` if that becomes a problem —
    // verify auto-advance interacts correctly with the Notify-driven task.

    /// With no pauses, the token fires once the limit has elapsed.
    #[tokio::test]
    async fn cancellation_receiver_fires_after_limit() {
        let stopwatch = Stopwatch::new(Duration::from_millis(50));
        let token = stopwatch.cancellation_token();
        let start = Instant::now();
        token.cancelled().await;
        // Only assert a lower bound: the task may fire slightly late.
        assert!(start.elapsed() >= Duration::from_millis(50));
    }

    /// A pause longer than the limit must hold off cancellation; the token
    /// fires only after the pause is lifted and the clock runs out.
    #[tokio::test]
    async fn pause_prevents_timeout_until_resumed() {
        let stopwatch = Stopwatch::new(Duration::from_millis(50));
        let token = stopwatch.cancellation_token();

        // Pause for 100ms — twice the 50ms limit.
        let pause_handle = tokio::spawn({
            let stopwatch = stopwatch.clone();
            async move {
                stopwatch
                    .pause_for(async {
                        sleep(Duration::from_millis(100)).await;
                    })
                    .await;
            }
        });

        // While paused, the token must not fire within the original limit.
        assert!(
            timeout(Duration::from_millis(30), token.cancelled())
                .await
                .is_err()
        );

        pause_handle.await.expect("pause task should finish");

        // After resume, the remaining budget runs out and the token fires.
        token.cancelled().await;
    }

    /// Two overlapping pauses: the clock stays stopped until BOTH are
    /// lifted (active_pauses reference counting), not just the first.
    #[tokio::test]
    async fn overlapping_pauses_only_resume_once() {
        let stopwatch = Stopwatch::new(Duration::from_millis(50));
        let token = stopwatch.cancellation_token();

        // First pause.
        let pause1 = {
            let stopwatch = stopwatch.clone();
            tokio::spawn(async move {
                stopwatch
                    .pause_for(async {
                        sleep(Duration::from_millis(80)).await;
                    })
                    .await;
            })
        };

        // Overlapping pause that ends sooner.
        let pause2 = {
            let stopwatch = stopwatch.clone();
            tokio::spawn(async move {
                stopwatch
                    .pause_for(async {
                        sleep(Duration::from_millis(30)).await;
                    })
                    .await;
            })
        };

        // While both pauses are active, the cancellation should not fire.
        assert!(
            timeout(Duration::from_millis(40), token.cancelled())
                .await
                .is_err()
        );

        pause2.await.expect("short pause should complete");

        // Still paused because the long pause is active.
        assert!(
            timeout(Duration::from_millis(30), token.cancelled())
                .await
                .is_err()
        );

        pause1.await.expect("long pause should complete");

        // Now the stopwatch should resume and hit the limit shortly after.
        token.cancelled().await;
    }
}
|
||||
Reference in New Issue
Block a user