mirror of
https://github.com/openai/codex.git
synced 2026-05-16 01:02:48 +00:00
[Core] Track prior user input in MCP metadata
This commit is contained in:
@@ -6,6 +6,7 @@ use crate::session::tests::make_session_and_context_with_rx;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::test_support::models_manager_with_provider;
|
||||
use crate::turn_metadata::McpTurnMetadataContext;
|
||||
use crate::turn_metadata::PRIOR_USER_INPUT_REQUESTED_KEY;
|
||||
use codex_config::CONFIG_TOML_FILE;
|
||||
use codex_config::config_toml::ConfigToml;
|
||||
use codex_config::types::AppConfig;
|
||||
@@ -2555,7 +2556,7 @@ async fn guardian_mode_mcp_denial_returns_rationale_message() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_mode_waits_for_approval_when_annotations_do_not_require_approval() {
|
||||
let (session, turn_context, _rx_event) = make_session_and_context_with_rx().await;
|
||||
let (session, turn_context, rx_event) = make_session_and_context_with_rx().await;
|
||||
{
|
||||
let mut active_turn = session.active_turn.lock().await;
|
||||
*active_turn = Some(ActiveTurn::default());
|
||||
@@ -2598,6 +2599,29 @@ async fn prompt_mode_waits_for_approval_when_annotations_do_not_require_approval
|
||||
})
|
||||
};
|
||||
|
||||
tokio::time::timeout(std::time::Duration::from_secs(1), async {
|
||||
loop {
|
||||
let event = rx_event.recv().await.expect("expected event");
|
||||
if matches!(
|
||||
event.msg,
|
||||
EventMsg::RequestUserInput(_) | EventMsg::ElicitationRequest(_)
|
||||
) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("MCP approval event timed out");
|
||||
|
||||
let meta = turn_context
|
||||
.turn_metadata_state
|
||||
.current_meta_value_for_mcp_request(mcp_turn_metadata_context(&turn_context))
|
||||
.expect("turn metadata should be present");
|
||||
assert!(
|
||||
meta.get(PRIOR_USER_INPUT_REQUESTED_KEY).is_none(),
|
||||
"current-call MCP approval should not mark prior user input metadata"
|
||||
);
|
||||
|
||||
assert!(
|
||||
tokio::time::timeout(std::time::Duration::from_millis(200), &mut approval_task)
|
||||
.await
|
||||
|
||||
@@ -68,6 +68,7 @@ impl ToolHandler for RequestUserInputHandler {
|
||||
let args: RequestUserInputArgs = parse_arguments(&arguments)?;
|
||||
let args =
|
||||
normalize_request_user_input_args(args).map_err(FunctionCallError::RespondToModel)?;
|
||||
turn.turn_metadata_state.mark_turn_user_input_requested();
|
||||
let response = session
|
||||
.request_user_input(turn.as_ref(), call_id, args)
|
||||
.await
|
||||
|
||||
@@ -1,16 +1,93 @@
|
||||
use super::*;
|
||||
use crate::session::tests::make_session_and_context;
|
||||
use crate::session::tests::make_session_and_context_with_rx;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::turn_metadata::McpTurnMetadataContext;
|
||||
use crate::turn_metadata::PRIOR_USER_INPUT_REQUESTED_KEY;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use codex_protocol::request_user_input::RequestUserInputAnswer;
|
||||
use codex_protocol::request_user_input::RequestUserInputArgs;
|
||||
use codex_protocol::request_user_input::RequestUserInputQuestion;
|
||||
use codex_protocol::request_user_input::RequestUserInputQuestionOption;
|
||||
use codex_protocol::request_user_input::RequestUserInputResponse;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
fn request_user_input_args() -> RequestUserInputArgs {
|
||||
RequestUserInputArgs {
|
||||
questions: vec![RequestUserInputQuestion {
|
||||
id: "pick_one".to_string(),
|
||||
header: "Hdr".to_string(),
|
||||
question: "Pick one".to_string(),
|
||||
is_other: false,
|
||||
is_secret: false,
|
||||
options: Some(vec![
|
||||
RequestUserInputQuestionOption {
|
||||
label: "A".to_string(),
|
||||
description: "A".to_string(),
|
||||
},
|
||||
RequestUserInputQuestionOption {
|
||||
label: "B".to_string(),
|
||||
description: "B".to_string(),
|
||||
},
|
||||
]),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
fn request_user_input_response() -> RequestUserInputResponse {
|
||||
RequestUserInputResponse {
|
||||
answers: HashMap::from([(
|
||||
"pick_one".to_string(),
|
||||
RequestUserInputAnswer {
|
||||
answers: vec!["A".to_string()],
|
||||
},
|
||||
)]),
|
||||
}
|
||||
}
|
||||
|
||||
fn prior_user_input_requested(turn: &TurnContext) -> Option<bool> {
|
||||
let meta = turn
|
||||
.turn_metadata_state
|
||||
.current_meta_value_for_mcp_request(McpTurnMetadataContext {
|
||||
model: turn.model_info.slug.as_str(),
|
||||
reasoning_effort: turn.effective_reasoning_effort(),
|
||||
})
|
||||
.expect("turn metadata should be present");
|
||||
meta.get(PRIOR_USER_INPUT_REQUESTED_KEY)
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
}
|
||||
|
||||
fn request_user_input_tool_invocation(
|
||||
session: Arc<crate::session::session::Session>,
|
||||
turn: Arc<TurnContext>,
|
||||
) -> ToolInvocation {
|
||||
ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
cancellation_token: tokio_util::sync::CancellationToken::new(),
|
||||
tracker: Arc::new(Mutex::new(TurnDiffTracker::default())),
|
||||
call_id: "call-1".to_string(),
|
||||
tool_name: codex_tools::ToolName::plain(REQUEST_USER_INPUT_TOOL_NAME),
|
||||
source: crate::tools::context::ToolCallSource::Direct,
|
||||
payload: ToolPayload::Function {
|
||||
arguments: serde_json::to_string(&request_user_input_args())
|
||||
.expect("serialize request_user_input args"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn multi_agent_v2_request_user_input_rejects_subagent_threads() {
|
||||
let (session, mut turn) = make_session_and_context().await;
|
||||
@@ -25,35 +102,10 @@ async fn multi_agent_v2_request_user_input_rejects_subagent_threads() {
|
||||
let result = RequestUserInputHandler {
|
||||
available_modes: Vec::new(),
|
||||
}
|
||||
.handle(ToolInvocation {
|
||||
session: Arc::new(session),
|
||||
turn: Arc::new(turn),
|
||||
cancellation_token: tokio_util::sync::CancellationToken::new(),
|
||||
tracker: Arc::new(Mutex::new(TurnDiffTracker::default())),
|
||||
call_id: "call-1".to_string(),
|
||||
tool_name: codex_tools::ToolName::plain(REQUEST_USER_INPUT_TOOL_NAME),
|
||||
source: crate::tools::context::ToolCallSource::Direct,
|
||||
payload: ToolPayload::Function {
|
||||
arguments: json!({
|
||||
"questions": [{
|
||||
"header": "Hdr",
|
||||
"question": "Pick one",
|
||||
"id": "pick_one",
|
||||
"options": [
|
||||
{
|
||||
"label": "A",
|
||||
"description": "A"
|
||||
},
|
||||
{
|
||||
"label": "B",
|
||||
"description": "B"
|
||||
}
|
||||
]
|
||||
}]
|
||||
})
|
||||
.to_string(),
|
||||
},
|
||||
})
|
||||
.handle(request_user_input_tool_invocation(
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
))
|
||||
.await;
|
||||
|
||||
let Err(err) = result else {
|
||||
@@ -66,3 +118,85 @@ async fn multi_agent_v2_request_user_input_rejects_subagent_threads() {
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn request_user_input_handler_marks_prior_user_input_for_mcp_metadata() {
|
||||
let (session, turn, rx_event) = make_session_and_context_with_rx().await;
|
||||
*session.active_turn.lock().await = Some(ActiveTurn::default());
|
||||
|
||||
assert_eq!(prior_user_input_requested(&turn), None);
|
||||
|
||||
let handler_task = tokio::spawn({
|
||||
let session = Arc::clone(&session);
|
||||
let turn = Arc::clone(&turn);
|
||||
async move {
|
||||
RequestUserInputHandler {
|
||||
available_modes: vec![ModeKind::Default],
|
||||
}
|
||||
.handle(request_user_input_tool_invocation(session, turn))
|
||||
.await
|
||||
}
|
||||
});
|
||||
|
||||
let event = tokio::time::timeout(Duration::from_secs(1), rx_event.recv())
|
||||
.await
|
||||
.expect("request_user_input event timed out")
|
||||
.expect("expected request_user_input event");
|
||||
let EventMsg::RequestUserInput(request) = event.msg else {
|
||||
panic!("expected request_user_input event");
|
||||
};
|
||||
|
||||
assert_eq!(prior_user_input_requested(&turn), Some(true));
|
||||
|
||||
session
|
||||
.notify_user_input_response(&request.turn_id, request_user_input_response())
|
||||
.await;
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(1), handler_task)
|
||||
.await
|
||||
.expect("request_user_input handler timed out")
|
||||
.expect("request_user_input handler task failed")
|
||||
.expect("request_user_input handler should succeed");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lower_level_session_request_user_input_does_not_mark_prior_user_input() {
|
||||
let (session, turn, rx_event) = make_session_and_context_with_rx().await;
|
||||
*session.active_turn.lock().await = Some(ActiveTurn::default());
|
||||
|
||||
let request_task = tokio::spawn({
|
||||
let session = Arc::clone(&session);
|
||||
let turn = Arc::clone(&turn);
|
||||
async move {
|
||||
session
|
||||
.request_user_input(
|
||||
turn.as_ref(),
|
||||
"call-1".to_string(),
|
||||
request_user_input_args(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
});
|
||||
|
||||
let event = tokio::time::timeout(Duration::from_secs(1), rx_event.recv())
|
||||
.await
|
||||
.expect("request_user_input event timed out")
|
||||
.expect("expected request_user_input event");
|
||||
let EventMsg::RequestUserInput(request) = event.msg else {
|
||||
panic!("expected request_user_input event");
|
||||
};
|
||||
|
||||
assert_eq!(prior_user_input_requested(&turn), None);
|
||||
|
||||
session
|
||||
.notify_user_input_response(&request.turn_id, request_user_input_response())
|
||||
.await;
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(1), request_task)
|
||||
.await
|
||||
.expect("request_user_input timed out")
|
||||
.expect("request_user_input task failed")
|
||||
.expect("request_user_input should receive response");
|
||||
|
||||
assert_eq!(prior_user_input_requested(&turn), None);
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use codex_utils_string::to_ascii_json_string;
|
||||
use serde::Serialize;
|
||||
@@ -21,6 +23,7 @@ use codex_protocol::protocol::ThreadSource;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
|
||||
const MODEL_KEY: &str = "model";
|
||||
pub(crate) const PRIOR_USER_INPUT_REQUESTED_KEY: &str = "codex_prior_user_input_requested";
|
||||
const REASONING_EFFORT_KEY: &str = "reasoning_effort";
|
||||
const TURN_STARTED_AT_UNIX_MS_KEY: &str = "turn_started_at_unix_ms";
|
||||
|
||||
@@ -186,6 +189,7 @@ pub(crate) struct TurnMetadataState {
|
||||
enriched_header: Arc<RwLock<Option<String>>>,
|
||||
turn_started_at_unix_ms: Arc<RwLock<Option<i64>>>,
|
||||
responsesapi_client_metadata: Arc<RwLock<Option<HashMap<String, String>>>>,
|
||||
prior_user_input_requested: Arc<AtomicBool>,
|
||||
enrichment_task: Arc<Mutex<Option<JoinHandle<()>>>>,
|
||||
}
|
||||
|
||||
@@ -231,6 +235,7 @@ impl TurnMetadataState {
|
||||
enriched_header: Arc::new(RwLock::new(None)),
|
||||
turn_started_at_unix_ms: Arc::new(RwLock::new(None)),
|
||||
responsesapi_client_metadata: Arc::new(RwLock::new(None)),
|
||||
prior_user_input_requested: Arc::new(AtomicBool::new(false)),
|
||||
enrichment_task: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
@@ -285,9 +290,22 @@ impl TurnMetadataState {
|
||||
metadata.remove(REASONING_EFFORT_KEY);
|
||||
}
|
||||
}
|
||||
if self.prior_user_input_requested.load(Ordering::Relaxed) {
|
||||
metadata.insert(
|
||||
PRIOR_USER_INPUT_REQUESTED_KEY.to_string(),
|
||||
Value::Bool(true),
|
||||
);
|
||||
} else {
|
||||
metadata.remove(PRIOR_USER_INPUT_REQUESTED_KEY);
|
||||
}
|
||||
Some(Value::Object(metadata))
|
||||
}
|
||||
|
||||
pub(crate) fn mark_turn_user_input_requested(&self) {
|
||||
self.prior_user_input_requested
|
||||
.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub(crate) fn set_responsesapi_client_metadata(
|
||||
&self,
|
||||
responsesapi_client_metadata: HashMap<String, String>,
|
||||
|
||||
@@ -213,6 +213,48 @@ fn turn_metadata_state_includes_model_and_reasoning_effort_only_in_request_meta(
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn turn_metadata_state_marks_prior_user_input_only_for_mcp_request_meta() {
|
||||
let temp_dir = TempDir::new().expect("temp dir");
|
||||
let cwd = temp_dir.path().abs();
|
||||
let permission_profile = PermissionProfile::read_only();
|
||||
|
||||
let state = TurnMetadataState::new(
|
||||
"session-a".to_string(),
|
||||
"thread-a".to_string(),
|
||||
/*thread_source*/ None,
|
||||
"turn-a".to_string(),
|
||||
cwd,
|
||||
&permission_profile,
|
||||
WindowsSandboxLevel::Disabled,
|
||||
/*enforce_managed_network*/ false,
|
||||
);
|
||||
|
||||
let header = state.current_header_value().expect("header");
|
||||
let header_json: Value = serde_json::from_str(&header).expect("json");
|
||||
assert!(header_json.get(PRIOR_USER_INPUT_REQUESTED_KEY).is_none());
|
||||
|
||||
let meta = state
|
||||
.current_meta_value_for_mcp_request(test_mcp_turn_metadata_context())
|
||||
.expect("turn metadata should be present");
|
||||
assert!(meta.get(PRIOR_USER_INPUT_REQUESTED_KEY).is_none());
|
||||
|
||||
state.mark_turn_user_input_requested();
|
||||
|
||||
let header = state.current_header_value().expect("header");
|
||||
let header_json: Value = serde_json::from_str(&header).expect("json");
|
||||
assert!(header_json.get(PRIOR_USER_INPUT_REQUESTED_KEY).is_none());
|
||||
|
||||
let meta = state
|
||||
.current_meta_value_for_mcp_request(test_mcp_turn_metadata_context())
|
||||
.expect("turn metadata should be present");
|
||||
assert_eq!(
|
||||
meta.get(PRIOR_USER_INPUT_REQUESTED_KEY)
|
||||
.and_then(Value::as_bool),
|
||||
Some(true)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn turn_metadata_state_ignores_client_turn_started_at_unix_ms_before_start() {
|
||||
let temp_dir = TempDir::new().expect("temp dir");
|
||||
|
||||
Reference in New Issue
Block a user