[Core] Track prior user input in MCP metadata

This commit is contained in:
Michael Chen
2026-05-11 15:39:14 -07:00
parent 4859d80ffe
commit 3b89bedc71
5 changed files with 250 additions and 31 deletions

View File

@@ -6,6 +6,7 @@ use crate::session::tests::make_session_and_context_with_rx;
use crate::state::ActiveTurn;
use crate::test_support::models_manager_with_provider;
use crate::turn_metadata::McpTurnMetadataContext;
use crate::turn_metadata::PRIOR_USER_INPUT_REQUESTED_KEY;
use codex_config::CONFIG_TOML_FILE;
use codex_config::config_toml::ConfigToml;
use codex_config::types::AppConfig;
@@ -2555,7 +2556,7 @@ async fn guardian_mode_mcp_denial_returns_rationale_message() {
#[tokio::test]
async fn prompt_mode_waits_for_approval_when_annotations_do_not_require_approval() {
let (session, turn_context, _rx_event) = make_session_and_context_with_rx().await;
let (session, turn_context, rx_event) = make_session_and_context_with_rx().await;
{
let mut active_turn = session.active_turn.lock().await;
*active_turn = Some(ActiveTurn::default());
@@ -2598,6 +2599,29 @@ async fn prompt_mode_waits_for_approval_when_annotations_do_not_require_approval
})
};
tokio::time::timeout(std::time::Duration::from_secs(1), async {
loop {
let event = rx_event.recv().await.expect("expected event");
if matches!(
event.msg,
EventMsg::RequestUserInput(_) | EventMsg::ElicitationRequest(_)
) {
break;
}
}
})
.await
.expect("MCP approval event timed out");
let meta = turn_context
.turn_metadata_state
.current_meta_value_for_mcp_request(mcp_turn_metadata_context(&turn_context))
.expect("turn metadata should be present");
assert!(
meta.get(PRIOR_USER_INPUT_REQUESTED_KEY).is_none(),
"current-call MCP approval should not mark prior user input metadata"
);
assert!(
tokio::time::timeout(std::time::Duration::from_millis(200), &mut approval_task)
.await

View File

@@ -68,6 +68,7 @@ impl ToolHandler for RequestUserInputHandler {
let args: RequestUserInputArgs = parse_arguments(&arguments)?;
let args =
normalize_request_user_input_args(args).map_err(FunctionCallError::RespondToModel)?;
turn.turn_metadata_state.mark_turn_user_input_requested();
let response = session
.request_user_input(turn.as_ref(), call_id, args)
.await

View File

@@ -1,16 +1,93 @@
use super::*;
use crate::session::tests::make_session_and_context;
use crate::session::tests::make_session_and_context_with_rx;
use crate::session::turn_context::TurnContext;
use crate::state::ActiveTurn;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::turn_metadata::McpTurnMetadataContext;
use crate::turn_metadata::PRIOR_USER_INPUT_REQUESTED_KEY;
use codex_protocol::ThreadId;
use codex_protocol::config_types::ModeKind;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::request_user_input::RequestUserInputAnswer;
use codex_protocol::request_user_input::RequestUserInputArgs;
use codex_protocol::request_user_input::RequestUserInputQuestion;
use codex_protocol::request_user_input::RequestUserInputQuestionOption;
use codex_protocol::request_user_input::RequestUserInputResponse;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
/// Builds the canned `request_user_input` arguments used by the tests in this file.
///
/// The payload holds a single question with id `pick_one`, offering two
/// options whose label and description are both `A` and `B` respectively.
fn request_user_input_args() -> RequestUserInputArgs {
    // Label and description are identical for each option, so derive both
    // from the same short name.
    let options = ["A", "B"]
        .into_iter()
        .map(|name| RequestUserInputQuestionOption {
            label: name.to_string(),
            description: name.to_string(),
        })
        .collect();

    let question = RequestUserInputQuestion {
        id: "pick_one".to_string(),
        header: "Hdr".to_string(),
        question: "Pick one".to_string(),
        is_other: false,
        is_secret: false,
        options: Some(options),
    };

    RequestUserInputArgs {
        questions: vec![question],
    }
}
/// Builds the canned reply to [`request_user_input_args`]: the `pick_one`
/// question answered with the single choice `A`.
fn request_user_input_response() -> RequestUserInputResponse {
    let picked = RequestUserInputAnswer {
        answers: vec!["A".to_string()],
    };

    let mut answers = HashMap::new();
    answers.insert("pick_one".to_string(), picked);

    RequestUserInputResponse { answers }
}
/// Reads the prior-user-input marker from the MCP request metadata of `turn`.
///
/// Returns `None` when the `PRIOR_USER_INPUT_REQUESTED_KEY` entry is absent
/// (or is not a JSON boolean), otherwise the boolean stored under it.
///
/// # Panics
///
/// Panics if the turn produces no MCP request metadata at all.
fn prior_user_input_requested(turn: &TurnContext) -> Option<bool> {
    let context = McpTurnMetadataContext {
        model: turn.model_info.slug.as_str(),
        reasoning_effort: turn.effective_reasoning_effort(),
    };
    let meta = turn
        .turn_metadata_state
        .current_meta_value_for_mcp_request(context)
        .expect("turn metadata should be present");

    meta.get(PRIOR_USER_INPUT_REQUESTED_KEY)?.as_bool()
}
/// Wraps the canned `request_user_input` arguments in a direct
/// `ToolInvocation` for the given session and turn.
///
/// The invocation uses call id `call-1`, a fresh cancellation token and a
/// default diff tracker; the arguments are the JSON serialization of
/// [`request_user_input_args`].
fn request_user_input_tool_invocation(
    session: Arc<crate::session::session::Session>,
    turn: Arc<TurnContext>,
) -> ToolInvocation {
    let arguments = serde_json::to_string(&request_user_input_args())
        .expect("serialize request_user_input args");
    let tracker = Arc::new(Mutex::new(TurnDiffTracker::default()));

    ToolInvocation {
        session,
        turn,
        cancellation_token: tokio_util::sync::CancellationToken::new(),
        tracker,
        call_id: String::from("call-1"),
        tool_name: codex_tools::ToolName::plain(REQUEST_USER_INPUT_TOOL_NAME),
        source: crate::tools::context::ToolCallSource::Direct,
        payload: ToolPayload::Function { arguments },
    }
}
#[tokio::test]
async fn multi_agent_v2_request_user_input_rejects_subagent_threads() {
let (session, mut turn) = make_session_and_context().await;
@@ -25,35 +102,10 @@ async fn multi_agent_v2_request_user_input_rejects_subagent_threads() {
let result = RequestUserInputHandler {
available_modes: Vec::new(),
}
.handle(ToolInvocation {
session: Arc::new(session),
turn: Arc::new(turn),
cancellation_token: tokio_util::sync::CancellationToken::new(),
tracker: Arc::new(Mutex::new(TurnDiffTracker::default())),
call_id: "call-1".to_string(),
tool_name: codex_tools::ToolName::plain(REQUEST_USER_INPUT_TOOL_NAME),
source: crate::tools::context::ToolCallSource::Direct,
payload: ToolPayload::Function {
arguments: json!({
"questions": [{
"header": "Hdr",
"question": "Pick one",
"id": "pick_one",
"options": [
{
"label": "A",
"description": "A"
},
{
"label": "B",
"description": "B"
}
]
}]
})
.to_string(),
},
})
.handle(request_user_input_tool_invocation(
Arc::new(session),
Arc::new(turn),
))
.await;
let Err(err) = result else {
@@ -66,3 +118,85 @@ async fn multi_agent_v2_request_user_input_rejects_subagent_threads() {
)
);
}
/// Verifies that running `RequestUserInputHandler::handle` sets the
/// prior-user-input marker in the turn's MCP request metadata, and that the
/// marker is already visible while the request is still awaiting an answer.
#[tokio::test]
async fn request_user_input_handler_marks_prior_user_input_for_mcp_metadata() {
let (session, turn, rx_event) = make_session_and_context_with_rx().await;
// NOTE(review): presumably an active turn must exist for the pending request
// to be tracked and later answered — confirm against `Session`.
*session.active_turn.lock().await = Some(ActiveTurn::default());
// Baseline: the marker is absent before the handler has run.
assert_eq!(prior_user_input_requested(&turn), None);
// Run the handler on its own task so this test can observe the emitted event
// and supply the response while `handle` is still pending.
let handler_task = tokio::spawn({
let session = Arc::clone(&session);
let turn = Arc::clone(&turn);
async move {
RequestUserInputHandler {
available_modes: vec![ModeKind::Default],
}
.handle(request_user_input_tool_invocation(session, turn))
.await
}
});
// Only the first event is inspected: anything other than `RequestUserInput`
// fails the test.
let event = tokio::time::timeout(Duration::from_secs(1), rx_event.recv())
.await
.expect("request_user_input event timed out")
.expect("expected request_user_input event");
let EventMsg::RequestUserInput(request) = event.msg else {
panic!("expected request_user_input event");
};
// Checked before the response is delivered, so the marker must be set by the
// time the request event is observable.
assert_eq!(prior_user_input_requested(&turn), Some(true));
// Answer the request so the handler task can finish.
session
.notify_user_input_response(&request.turn_id, request_user_input_response())
.await;
// The three `expect`s cover, in order: the timeout, the task join, and the
// handler's own result.
tokio::time::timeout(Duration::from_secs(1), handler_task)
.await
.expect("request_user_input handler timed out")
.expect("request_user_input handler task failed")
.expect("request_user_input handler should succeed");
}
/// Verifies that calling `Session::request_user_input` directly (bypassing
/// `RequestUserInputHandler`) leaves the prior-user-input marker out of the
/// turn's MCP request metadata, both while the request is pending and after
/// it has been answered.
#[tokio::test]
async fn lower_level_session_request_user_input_does_not_mark_prior_user_input() {
let (session, turn, rx_event) = make_session_and_context_with_rx().await;
// NOTE(review): presumably an active turn must exist for the pending request
// to be tracked and later answered — confirm against `Session`.
*session.active_turn.lock().await = Some(ActiveTurn::default());
// Issue the request on its own task so this test can observe the emitted
// event and supply the response while the call is still pending.
let request_task = tokio::spawn({
let session = Arc::clone(&session);
let turn = Arc::clone(&turn);
async move {
session
.request_user_input(
turn.as_ref(),
"call-1".to_string(),
request_user_input_args(),
)
.await
}
});
// Only the first event is inspected: anything other than `RequestUserInput`
// fails the test.
let event = tokio::time::timeout(Duration::from_secs(1), rx_event.recv())
.await
.expect("request_user_input event timed out")
.expect("expected request_user_input event");
let EventMsg::RequestUserInput(request) = event.msg else {
panic!("expected request_user_input event");
};
// While the request is pending, the marker must still be absent.
assert_eq!(prior_user_input_requested(&turn), None);
// Answer the request so the spawned call can return.
session
.notify_user_input_response(&request.turn_id, request_user_input_response())
.await;
// The three `expect`s cover, in order: the timeout, the task join, and the
// presence of a response.
tokio::time::timeout(Duration::from_secs(1), request_task)
.await
.expect("request_user_input timed out")
.expect("request_user_input task failed")
.expect("request_user_input should receive response");
// Still absent once the request has completed.
assert_eq!(prior_user_input_requested(&turn), None);
}

View File

@@ -3,6 +3,8 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::RwLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use codex_utils_string::to_ascii_json_string;
use serde::Serialize;
@@ -21,6 +23,7 @@ use codex_protocol::protocol::ThreadSource;
use codex_utils_absolute_path::AbsolutePathBuf;
const MODEL_KEY: &str = "model";
pub(crate) const PRIOR_USER_INPUT_REQUESTED_KEY: &str = "codex_prior_user_input_requested";
const REASONING_EFFORT_KEY: &str = "reasoning_effort";
const TURN_STARTED_AT_UNIX_MS_KEY: &str = "turn_started_at_unix_ms";
@@ -186,6 +189,7 @@ pub(crate) struct TurnMetadataState {
enriched_header: Arc<RwLock<Option<String>>>,
turn_started_at_unix_ms: Arc<RwLock<Option<i64>>>,
responsesapi_client_metadata: Arc<RwLock<Option<HashMap<String, String>>>>,
prior_user_input_requested: Arc<AtomicBool>,
enrichment_task: Arc<Mutex<Option<JoinHandle<()>>>>,
}
@@ -231,6 +235,7 @@ impl TurnMetadataState {
enriched_header: Arc::new(RwLock::new(None)),
turn_started_at_unix_ms: Arc::new(RwLock::new(None)),
responsesapi_client_metadata: Arc::new(RwLock::new(None)),
prior_user_input_requested: Arc::new(AtomicBool::new(false)),
enrichment_task: Arc::new(Mutex::new(None)),
}
}
@@ -285,9 +290,22 @@ impl TurnMetadataState {
metadata.remove(REASONING_EFFORT_KEY);
}
}
if self.prior_user_input_requested.load(Ordering::Relaxed) {
metadata.insert(
PRIOR_USER_INPUT_REQUESTED_KEY.to_string(),
Value::Bool(true),
);
} else {
metadata.remove(PRIOR_USER_INPUT_REQUESTED_KEY);
}
Some(Value::Object(metadata))
}
/// Records that user input has been requested during this turn.
///
/// Sets the shared `prior_user_input_requested` flag to `true`. While the flag
/// is set, the metadata built for MCP requests carries
/// `PRIOR_USER_INPUT_REQUESTED_KEY` with the value `true`.
///
/// NOTE(review): nothing in the visible code clears the flag again — confirm
/// it is meant to stay set for the rest of the turn.
pub(crate) fn mark_turn_user_input_requested(&self) {
// NOTE(review): `Relaxed` assumes no other memory has to be ordered against
// this flag — confirm.
self.prior_user_input_requested
.store(true, Ordering::Relaxed);
}
pub(crate) fn set_responsesapi_client_metadata(
&self,
responsesapi_client_metadata: HashMap<String, String>,

View File

@@ -213,6 +213,48 @@ fn turn_metadata_state_includes_model_and_reasoning_effort_only_in_request_meta(
);
}
/// Checks that marking the turn as having requested user input surfaces
/// `PRIOR_USER_INPUT_REQUESTED_KEY` in the MCP request metadata only, and
/// never in the turn metadata header.
#[test]
fn turn_metadata_state_marks_prior_user_input_only_for_mcp_request_meta() {
    let temp_dir = TempDir::new().expect("temp dir");
    let cwd = temp_dir.path().abs();
    let permission_profile = PermissionProfile::read_only();

    let state = TurnMetadataState::new(
        "session-a".to_string(),
        "thread-a".to_string(),
        /*thread_source*/ None,
        "turn-a".to_string(),
        cwd,
        &permission_profile,
        WindowsSandboxLevel::Disabled,
        /*enforce_managed_network*/ false,
    );

    // Marker as it appears in the serialized header, if present.
    let marker_in_header = || -> Option<Value> {
        let header = state.current_header_value().expect("header");
        let parsed: Value = serde_json::from_str(&header).expect("json");
        parsed.get(PRIOR_USER_INPUT_REQUESTED_KEY).cloned()
    };
    // Marker as it appears in the MCP request metadata, if present.
    let marker_in_mcp_meta = || -> Option<Value> {
        state
            .current_meta_value_for_mcp_request(test_mcp_turn_metadata_context())
            .expect("turn metadata should be present")
            .get(PRIOR_USER_INPUT_REQUESTED_KEY)
            .cloned()
    };

    // Before marking: absent from both views.
    assert!(marker_in_header().is_none());
    assert!(marker_in_mcp_meta().is_none());

    state.mark_turn_user_input_requested();

    // After marking: still absent from the header, `true` in the MCP metadata.
    assert!(marker_in_header().is_none());
    assert_eq!(
        marker_in_mcp_meta().and_then(|marker| marker.as_bool()),
        Some(true)
    );
}
#[test]
fn turn_metadata_state_ignores_client_turn_started_at_unix_ms_before_start() {
let temp_dir = TempDir::new().expect("temp dir");