From 3b89bedc71ddfa88f2a2acfa0c03fc60324304d5 Mon Sep 17 00:00:00 2001 From: Michael Chen Date: Mon, 11 May 2026 15:39:14 -0700 Subject: [PATCH] [Core] Track prior user input in MCP metadata --- codex-rs/core/src/mcp_tool_call_tests.rs | 26 ++- .../src/tools/handlers/request_user_input.rs | 1 + .../handlers/request_user_input_tests.rs | 194 +++++++++++++++--- codex-rs/core/src/turn_metadata.rs | 18 ++ codex-rs/core/src/turn_metadata_tests.rs | 42 ++++ 5 files changed, 250 insertions(+), 31 deletions(-) diff --git a/codex-rs/core/src/mcp_tool_call_tests.rs b/codex-rs/core/src/mcp_tool_call_tests.rs index 326e0cfe31..75e4bb2160 100644 --- a/codex-rs/core/src/mcp_tool_call_tests.rs +++ b/codex-rs/core/src/mcp_tool_call_tests.rs @@ -6,6 +6,7 @@ use crate::session::tests::make_session_and_context_with_rx; use crate::state::ActiveTurn; use crate::test_support::models_manager_with_provider; use crate::turn_metadata::McpTurnMetadataContext; +use crate::turn_metadata::PRIOR_USER_INPUT_REQUESTED_KEY; use codex_config::CONFIG_TOML_FILE; use codex_config::config_toml::ConfigToml; use codex_config::types::AppConfig; @@ -2555,7 +2556,7 @@ async fn guardian_mode_mcp_denial_returns_rationale_message() { #[tokio::test] async fn prompt_mode_waits_for_approval_when_annotations_do_not_require_approval() { - let (session, turn_context, _rx_event) = make_session_and_context_with_rx().await; + let (session, turn_context, rx_event) = make_session_and_context_with_rx().await; { let mut active_turn = session.active_turn.lock().await; *active_turn = Some(ActiveTurn::default()); @@ -2598,6 +2599,29 @@ async fn prompt_mode_waits_for_approval_when_annotations_do_not_require_approval }) }; + tokio::time::timeout(std::time::Duration::from_secs(1), async { + loop { + let event = rx_event.recv().await.expect("expected event"); + if matches!( + event.msg, + EventMsg::RequestUserInput(_) | EventMsg::ElicitationRequest(_) + ) { + break; + } + } + }) + .await + .expect("MCP approval event timed out"); + + let meta = turn_context + .turn_metadata_state + .current_meta_value_for_mcp_request(mcp_turn_metadata_context(&turn_context)) + .expect("turn metadata should be present"); + assert!( + meta.get(PRIOR_USER_INPUT_REQUESTED_KEY).is_none(), + "current-call MCP approval should not mark prior user input metadata" + ); + assert!( tokio::time::timeout(std::time::Duration::from_millis(200), &mut approval_task) .await diff --git a/codex-rs/core/src/tools/handlers/request_user_input.rs b/codex-rs/core/src/tools/handlers/request_user_input.rs index 6d26234858..960dfa4deb 100644 --- a/codex-rs/core/src/tools/handlers/request_user_input.rs +++ b/codex-rs/core/src/tools/handlers/request_user_input.rs @@ -68,6 +68,7 @@ impl ToolHandler for RequestUserInputHandler { let args: RequestUserInputArgs = parse_arguments(&arguments)?; let args = normalize_request_user_input_args(args).map_err(FunctionCallError::RespondToModel)?; + turn.turn_metadata_state.mark_turn_user_input_requested(); let response = session .request_user_input(turn.as_ref(), call_id, args) .await diff --git a/codex-rs/core/src/tools/handlers/request_user_input_tests.rs b/codex-rs/core/src/tools/handlers/request_user_input_tests.rs index 7c577c54be..0e8c44dc9d 100644 --- a/codex-rs/core/src/tools/handlers/request_user_input_tests.rs +++ b/codex-rs/core/src/tools/handlers/request_user_input_tests.rs @@ -1,16 +1,93 @@ use super::*; use crate::session::tests::make_session_and_context; +use crate::session::tests::make_session_and_context_with_rx; +use crate::session::turn_context::TurnContext; +use crate::state::ActiveTurn; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::turn_diff_tracker::TurnDiffTracker; +use crate::turn_metadata::McpTurnMetadataContext; +use crate::turn_metadata::PRIOR_USER_INPUT_REQUESTED_KEY; use codex_protocol::ThreadId; +use codex_protocol::config_types::ModeKind; +use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; +use codex_protocol::request_user_input::RequestUserInputAnswer; +use codex_protocol::request_user_input::RequestUserInputArgs; +use codex_protocol::request_user_input::RequestUserInputQuestion; +use codex_protocol::request_user_input::RequestUserInputQuestionOption; +use codex_protocol::request_user_input::RequestUserInputResponse; use pretty_assertions::assert_eq; -use serde_json::json; +use std::collections::HashMap; use std::sync::Arc; +use std::time::Duration; use tokio::sync::Mutex; +fn request_user_input_args() -> RequestUserInputArgs { + RequestUserInputArgs { + questions: vec![RequestUserInputQuestion { + id: "pick_one".to_string(), + header: "Hdr".to_string(), + question: "Pick one".to_string(), + is_other: false, + is_secret: false, + options: Some(vec![ + RequestUserInputQuestionOption { + label: "A".to_string(), + description: "A".to_string(), + }, + RequestUserInputQuestionOption { + label: "B".to_string(), + description: "B".to_string(), + }, + ]), + }], + } +} + +fn request_user_input_response() -> RequestUserInputResponse { + RequestUserInputResponse { + answers: HashMap::from([( + "pick_one".to_string(), + RequestUserInputAnswer { + answers: vec!["A".to_string()], + }, + )]), + } +} + +fn prior_user_input_requested(turn: &TurnContext) -> Option { + let meta = turn + .turn_metadata_state + .current_meta_value_for_mcp_request(McpTurnMetadataContext { + model: turn.model_info.slug.as_str(), + reasoning_effort: turn.effective_reasoning_effort(), + }) + .expect("turn metadata should be present"); + meta.get(PRIOR_USER_INPUT_REQUESTED_KEY) + .and_then(serde_json::Value::as_bool) +} + +fn request_user_input_tool_invocation( + session: Arc, + turn: Arc, +) -> ToolInvocation { + ToolInvocation { + session, + turn, + cancellation_token: tokio_util::sync::CancellationToken::new(), + tracker: Arc::new(Mutex::new(TurnDiffTracker::default())), + call_id: "call-1".to_string(), + tool_name: codex_tools::ToolName::plain(REQUEST_USER_INPUT_TOOL_NAME), + source: crate::tools::context::ToolCallSource::Direct, + payload: ToolPayload::Function { + arguments: serde_json::to_string(&request_user_input_args()) + .expect("serialize request_user_input args"), + }, + } +} + #[tokio::test] async fn multi_agent_v2_request_user_input_rejects_subagent_threads() { let (session, mut turn) = make_session_and_context().await; @@ -25,35 +102,10 @@ async fn multi_agent_v2_request_user_input_rejects_subagent_threads() { let result = RequestUserInputHandler { available_modes: Vec::new(), } - .handle(ToolInvocation { - session: Arc::new(session), - turn: Arc::new(turn), - cancellation_token: tokio_util::sync::CancellationToken::new(), - tracker: Arc::new(Mutex::new(TurnDiffTracker::default())), - call_id: "call-1".to_string(), - tool_name: codex_tools::ToolName::plain(REQUEST_USER_INPUT_TOOL_NAME), - source: crate::tools::context::ToolCallSource::Direct, - payload: ToolPayload::Function { - arguments: json!({ - "questions": [{ - "header": "Hdr", - "question": "Pick one", - "id": "pick_one", - "options": [ - { - "label": "A", - "description": "A" - }, - { - "label": "B", - "description": "B" - } - ] - }] - }) - .to_string(), - }, - }) + .handle(request_user_input_tool_invocation( + Arc::new(session), + Arc::new(turn), + )) .await; let Err(err) = result else { @@ -66,3 +118,85 @@ async fn multi_agent_v2_request_user_input_rejects_subagent_threads() { ) ); } + +#[tokio::test] +async fn request_user_input_handler_marks_prior_user_input_for_mcp_metadata() { + let (session, turn, rx_event) = make_session_and_context_with_rx().await; + *session.active_turn.lock().await = Some(ActiveTurn::default()); + + assert_eq!(prior_user_input_requested(&turn), None); + + let handler_task = tokio::spawn({ + let session = Arc::clone(&session); + let turn = Arc::clone(&turn); + async move { + RequestUserInputHandler { + available_modes: vec![ModeKind::Default], + } + .handle(request_user_input_tool_invocation(session, turn)) + .await + } + }); + + let event = tokio::time::timeout(Duration::from_secs(1), rx_event.recv()) + .await + .expect("request_user_input event timed out") + .expect("expected request_user_input event"); + let EventMsg::RequestUserInput(request) = event.msg else { + panic!("expected request_user_input event"); + }; + + assert_eq!(prior_user_input_requested(&turn), Some(true)); + + session + .notify_user_input_response(&request.turn_id, request_user_input_response()) + .await; + + tokio::time::timeout(Duration::from_secs(1), handler_task) + .await + .expect("request_user_input handler timed out") + .expect("request_user_input handler task failed") + .expect("request_user_input handler should succeed"); +} + +#[tokio::test] +async fn lower_level_session_request_user_input_does_not_mark_prior_user_input() { + let (session, turn, rx_event) = make_session_and_context_with_rx().await; + *session.active_turn.lock().await = Some(ActiveTurn::default()); + + let request_task = tokio::spawn({ + let session = Arc::clone(&session); + let turn = Arc::clone(&turn); + async move { + session + .request_user_input( + turn.as_ref(), + "call-1".to_string(), + request_user_input_args(), + ) + .await + } + }); + + let event = tokio::time::timeout(Duration::from_secs(1), rx_event.recv()) + .await + .expect("request_user_input event timed out") + .expect("expected request_user_input event"); + let EventMsg::RequestUserInput(request) = event.msg else { + panic!("expected request_user_input event"); + }; + + assert_eq!(prior_user_input_requested(&turn), None); + + session + .notify_user_input_response(&request.turn_id, request_user_input_response()) + .await; + + tokio::time::timeout(Duration::from_secs(1), request_task) + .await + .expect("request_user_input timed out") + .expect("request_user_input task failed") + .expect("request_user_input should receive response"); + + assert_eq!(prior_user_input_requested(&turn), None); +} diff --git a/codex-rs/core/src/turn_metadata.rs b/codex-rs/core/src/turn_metadata.rs index 02760582f2..5e3490f8b5 100644 --- a/codex-rs/core/src/turn_metadata.rs +++ b/codex-rs/core/src/turn_metadata.rs @@ -3,6 +3,8 @@ use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex; use std::sync::RwLock; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use codex_utils_string::to_ascii_json_string; use serde::Serialize; @@ -21,6 +23,7 @@ use codex_protocol::protocol::ThreadSource; use codex_utils_absolute_path::AbsolutePathBuf; const MODEL_KEY: &str = "model"; +pub(crate) const PRIOR_USER_INPUT_REQUESTED_KEY: &str = "codex_prior_user_input_requested"; const REASONING_EFFORT_KEY: &str = "reasoning_effort"; const TURN_STARTED_AT_UNIX_MS_KEY: &str = "turn_started_at_unix_ms"; @@ -186,6 +189,7 @@ pub(crate) struct TurnMetadataState { enriched_header: Arc>>, turn_started_at_unix_ms: Arc>>, responsesapi_client_metadata: Arc>>>, + prior_user_input_requested: Arc, enrichment_task: Arc>>>, } @@ -231,6 +235,7 @@ impl TurnMetadataState { enriched_header: Arc::new(RwLock::new(None)), turn_started_at_unix_ms: Arc::new(RwLock::new(None)), responsesapi_client_metadata: Arc::new(RwLock::new(None)), + prior_user_input_requested: Arc::new(AtomicBool::new(false)), enrichment_task: Arc::new(Mutex::new(None)), } } @@ -285,9 +290,22 @@ impl TurnMetadataState { metadata.remove(REASONING_EFFORT_KEY); } } + if self.prior_user_input_requested.load(Ordering::Relaxed) { + metadata.insert( + PRIOR_USER_INPUT_REQUESTED_KEY.to_string(), + Value::Bool(true), + ); + } else { + metadata.remove(PRIOR_USER_INPUT_REQUESTED_KEY); + } Some(Value::Object(metadata)) } + pub(crate) fn mark_turn_user_input_requested(&self) { + self.prior_user_input_requested + .store(true, Ordering::Relaxed); + } + pub(crate) fn set_responsesapi_client_metadata( &self, responsesapi_client_metadata: HashMap, diff --git a/codex-rs/core/src/turn_metadata_tests.rs b/codex-rs/core/src/turn_metadata_tests.rs index 2a38447f86..38625d166b 100644 --- a/codex-rs/core/src/turn_metadata_tests.rs +++ b/codex-rs/core/src/turn_metadata_tests.rs @@ -213,6 +213,48 @@ fn turn_metadata_state_includes_model_and_reasoning_effort_only_in_request_meta( ); } +#[test] +fn turn_metadata_state_marks_prior_user_input_only_for_mcp_request_meta() { + let temp_dir = TempDir::new().expect("temp dir"); + let cwd = temp_dir.path().abs(); + let permission_profile = PermissionProfile::read_only(); + + let state = TurnMetadataState::new( + "session-a".to_string(), + "thread-a".to_string(), + /*thread_source*/ None, + "turn-a".to_string(), + cwd, + &permission_profile, + WindowsSandboxLevel::Disabled, + /*enforce_managed_network*/ false, + ); + + let header = state.current_header_value().expect("header"); + let header_json: Value = serde_json::from_str(&header).expect("json"); + assert!(header_json.get(PRIOR_USER_INPUT_REQUESTED_KEY).is_none()); + + let meta = state + .current_meta_value_for_mcp_request(test_mcp_turn_metadata_context()) + .expect("turn metadata should be present"); + assert!(meta.get(PRIOR_USER_INPUT_REQUESTED_KEY).is_none()); + + state.mark_turn_user_input_requested(); + + let header = state.current_header_value().expect("header"); + let header_json: Value = serde_json::from_str(&header).expect("json"); + assert!(header_json.get(PRIOR_USER_INPUT_REQUESTED_KEY).is_none()); + + let meta = state + .current_meta_value_for_mcp_request(test_mcp_turn_metadata_context()) + .expect("turn metadata should be present"); + assert_eq!( + meta.get(PRIOR_USER_INPUT_REQUESTED_KEY) + .and_then(Value::as_bool), + Some(true) + ); +} + #[test] fn turn_metadata_state_ignores_client_turn_started_at_unix_ms_before_start() { let temp_dir = TempDir::new().expect("temp dir");