mirror of
https://github.com/openai/codex.git
synced 2026-02-23 01:03:48 +00:00
Compare commits
12 Commits
pr12518
...
cc/request
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
efff89c933 | ||
|
|
22da6a0afc | ||
|
|
ac4fbab6f3 | ||
|
|
9b6cd7d20c | ||
|
|
1db750d93e | ||
|
|
6e9a5db6b4 | ||
|
|
f1fa175416 | ||
|
|
bbcf8ce408 | ||
|
|
a4184c7f7d | ||
|
|
841b777fdb | ||
|
|
7a5e27b135 | ||
|
|
a68e3f893e |
@@ -411,6 +411,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
);
|
||||
let empty = CoreRequestUserInputResponse {
|
||||
answers: HashMap::new(),
|
||||
interrupted: false,
|
||||
};
|
||||
if let Err(err) = conversation
|
||||
.submit(Op::UserInputAnswer {
|
||||
@@ -1671,6 +1672,7 @@ async fn on_request_user_input_response(
|
||||
error!("request failed with client error: {err:?}");
|
||||
let empty = CoreRequestUserInputResponse {
|
||||
answers: HashMap::new(),
|
||||
interrupted: false,
|
||||
};
|
||||
if let Err(err) = conversation
|
||||
.submit(Op::UserInputAnswer {
|
||||
@@ -1687,6 +1689,7 @@ async fn on_request_user_input_response(
|
||||
error!("request failed: {err:?}");
|
||||
let empty = CoreRequestUserInputResponse {
|
||||
answers: HashMap::new(),
|
||||
interrupted: false,
|
||||
};
|
||||
if let Err(err) = conversation
|
||||
.submit(Op::UserInputAnswer {
|
||||
@@ -1721,6 +1724,7 @@ async fn on_request_user_input_response(
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
interrupted: false,
|
||||
};
|
||||
|
||||
if let Err(err) = conversation
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::Debug;
|
||||
@@ -87,7 +88,7 @@ use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::FuturesOrdered;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use rmcp::model::ListResourceTemplatesResult;
|
||||
use rmcp::model::ListResourcesResult;
|
||||
use rmcp::model::PaginatedRequestParams;
|
||||
@@ -232,8 +233,10 @@ use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use crate::tasks::TaskRunOutput;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolDispatchOutput;
|
||||
use crate::tools::handlers::SEARCH_TOOL_BM25_TOOL_NAME;
|
||||
use crate::tools::js_repl::JsReplHandle;
|
||||
use crate::tools::network_approval::NetworkApprovalService;
|
||||
@@ -4477,9 +4480,9 @@ pub(crate) async fn run_turn(
|
||||
input: Vec<UserInput>,
|
||||
prewarmed_client_session: Option<ModelClientSession>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
) -> TaskRunOutput {
|
||||
if input.is_empty() {
|
||||
return None;
|
||||
return TaskRunOutput::default();
|
||||
}
|
||||
|
||||
let model_info = turn_context.model_info.clone();
|
||||
@@ -4500,7 +4503,7 @@ pub(crate) async fn run_turn(
|
||||
.is_err()
|
||||
{
|
||||
error!("Failed to run pre-sampling compact");
|
||||
return None;
|
||||
return TaskRunOutput::default();
|
||||
}
|
||||
|
||||
let previous_model = sess.previous_model().await;
|
||||
@@ -4528,7 +4531,7 @@ pub(crate) async fn run_turn(
|
||||
.await
|
||||
{
|
||||
Ok(mcp_tools) => mcp_tools,
|
||||
Err(_) => return None,
|
||||
Err(_) => return TaskRunOutput::default(),
|
||||
};
|
||||
connectors::with_app_enabled_state(
|
||||
connectors::accessible_connectors_from_mcp_tools(&mcp_tools),
|
||||
@@ -4639,6 +4642,7 @@ pub(crate) async fn run_turn(
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let mut server_model_warning_emitted_for_turn = false;
|
||||
let mut abort_reason = None;
|
||||
|
||||
// `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse
|
||||
// one instance across retries within this turn.
|
||||
@@ -4709,8 +4713,22 @@ pub(crate) async fn run_turn(
|
||||
Ok(sampling_request_output) => {
|
||||
let SamplingRequestResult {
|
||||
needs_follow_up,
|
||||
interrupted_tool_result,
|
||||
last_agent_message: sampling_request_last_agent_message,
|
||||
} = sampling_request_output;
|
||||
if interrupted_tool_result {
|
||||
cancellation_token.cancel();
|
||||
// Keep interrupt cleanup consistent with abort_all_tasks(): a parallel
|
||||
// unified-exec tool call may still be running when request_user_input
|
||||
// returns an interrupted result.
|
||||
sess.close_unified_exec_processes().await;
|
||||
sess.finish_turn_without_completion_event(turn_context.as_ref())
|
||||
.await;
|
||||
// Defer TurnAborted emission until run_turn unwinds so the caller can
|
||||
// flush the rollout marker without blocking the in-flight tool loop.
|
||||
abort_reason = Some(TurnAbortReason::Interrupted);
|
||||
break;
|
||||
}
|
||||
let total_usage_tokens = sess.get_total_token_usage().await;
|
||||
let token_limit_reached = total_usage_tokens >= auto_compact_limit;
|
||||
|
||||
@@ -4737,7 +4755,7 @@ pub(crate) async fn run_turn(
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return None;
|
||||
return TaskRunOutput::default();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -4799,7 +4817,7 @@ pub(crate) async fn run_turn(
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
return None;
|
||||
return TaskRunOutput::default();
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -4835,7 +4853,10 @@ pub(crate) async fn run_turn(
|
||||
}
|
||||
}
|
||||
|
||||
last_agent_message
|
||||
TaskRunOutput {
|
||||
last_agent_message,
|
||||
abort_reason,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_pre_sampling_compact(
|
||||
@@ -5296,9 +5317,16 @@ async fn built_tools(
|
||||
#[derive(Debug)]
|
||||
struct SamplingRequestResult {
|
||||
needs_follow_up: bool,
|
||||
interrupted_tool_result: bool,
|
||||
last_agent_message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IndexedToolDispatchOutput {
|
||||
seq: usize,
|
||||
output: ToolDispatchOutput,
|
||||
}
|
||||
|
||||
/// Ephemeral per-response state for streaming a single proposed plan.
|
||||
/// This is intentionally not persisted or stored in session/state since it
|
||||
/// only exists while a response is actively streaming. The final plan text
|
||||
@@ -5676,22 +5704,68 @@ async fn handle_assistant_item_done_in_plan_mode(
|
||||
}
|
||||
|
||||
async fn drain_in_flight(
|
||||
in_flight: &mut FuturesOrdered<BoxFuture<'static, CodexResult<ResponseInputItem>>>,
|
||||
in_flight: &mut FuturesUnordered<BoxFuture<'static, CodexResult<IndexedToolDispatchOutput>>>,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
) -> CodexResult<()> {
|
||||
) -> CodexResult<bool> {
|
||||
let mut next_seq = 0usize;
|
||||
let mut ready = BTreeMap::<usize, ToolDispatchOutput>::new();
|
||||
while let Some(res) = in_flight.next().await {
|
||||
match res {
|
||||
Ok(response_input) => {
|
||||
sess.record_conversation_items(&turn_context, &[response_input.into()])
|
||||
.await;
|
||||
Ok(indexed) => {
|
||||
let IndexedToolDispatchOutput { seq, output } = indexed;
|
||||
if output.interrupt_turn {
|
||||
// Drain any completions that are already ready but not yet yielded by
|
||||
// FuturesUnordered so earlier outputs are not lost when we return below.
|
||||
loop {
|
||||
match in_flight.next().now_or_never() {
|
||||
Some(Some(Ok(indexed))) => {
|
||||
let IndexedToolDispatchOutput { seq, output } = indexed;
|
||||
ready.insert(seq, output);
|
||||
}
|
||||
Some(Some(Err(err))) => {
|
||||
error_or_panic(format!(
|
||||
"in-flight tool future failed during interrupt drain: {err}"
|
||||
));
|
||||
}
|
||||
Some(None) | None => break,
|
||||
}
|
||||
}
|
||||
// Preserve any already-completed earlier tool outputs before short-circuiting.
|
||||
// FuturesUnordered may yield the interrupting result before lower-sequence
|
||||
// completions that were buffered waiting on an even earlier slow tool.
|
||||
let _dropped_later_ready = ready.split_off(&seq);
|
||||
for (_ready_seq, ready_output) in std::mem::take(&mut ready) {
|
||||
sess.record_conversation_items(
|
||||
&turn_context,
|
||||
&[ready_output.response_input.into()],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
sess.record_conversation_items(&turn_context, &[output.response_input.into()])
|
||||
.await;
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
ready.insert(seq, output);
|
||||
while let Some(output) = ready.remove(&next_seq) {
|
||||
sess.record_conversation_items(&turn_context, &[output.response_input.into()])
|
||||
.await;
|
||||
next_seq += 1;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
error_or_panic(format!("in-flight tool future failed during drain: {err}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
if !ready.is_empty() {
|
||||
for (_seq, output) in ready {
|
||||
sess.record_conversation_items(&turn_context, &[output.response_input.into()])
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -5747,8 +5821,10 @@ async fn try_run_sampling_request(
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
);
|
||||
let mut in_flight: FuturesOrdered<BoxFuture<'static, CodexResult<ResponseInputItem>>> =
|
||||
FuturesOrdered::new();
|
||||
let mut in_flight: FuturesUnordered<
|
||||
BoxFuture<'static, CodexResult<IndexedToolDispatchOutput>>,
|
||||
> = FuturesUnordered::new();
|
||||
let mut next_in_flight_seq = 0usize;
|
||||
let mut needs_follow_up = false;
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
let mut active_item: Option<TurnItem> = None;
|
||||
@@ -5756,7 +5832,7 @@ async fn try_run_sampling_request(
|
||||
let plan_mode = turn_context.collaboration_mode.mode == ModeKind::Plan;
|
||||
let mut plan_mode_state = plan_mode.then(|| PlanModeStreamState::new(&turn_context.sub_id));
|
||||
let receiving_span = trace_span!("receiving_stream");
|
||||
let outcome: CodexResult<SamplingRequestResult> = loop {
|
||||
let mut outcome: CodexResult<SamplingRequestResult> = loop {
|
||||
let handle_responses = trace_span!(
|
||||
parent: &receiving_span,
|
||||
"handle_responses",
|
||||
@@ -5831,7 +5907,12 @@ async fn try_run_sampling_request(
|
||||
.instrument(handle_responses)
|
||||
.await?;
|
||||
if let Some(tool_future) = output_result.tool_future {
|
||||
in_flight.push_back(tool_future);
|
||||
let seq = next_in_flight_seq;
|
||||
next_in_flight_seq += 1;
|
||||
in_flight.push(Box::pin(async move {
|
||||
let output = tool_future.await?;
|
||||
Ok(IndexedToolDispatchOutput { seq, output })
|
||||
}));
|
||||
}
|
||||
if let Some(agent_message) = output_result.last_agent_message {
|
||||
last_agent_message = Some(agent_message);
|
||||
@@ -5890,6 +5971,7 @@ async fn try_run_sampling_request(
|
||||
|
||||
break Ok(SamplingRequestResult {
|
||||
needs_follow_up,
|
||||
interrupted_tool_result: false,
|
||||
last_agent_message,
|
||||
});
|
||||
}
|
||||
@@ -5971,7 +6053,14 @@ async fn try_run_sampling_request(
|
||||
}
|
||||
};
|
||||
|
||||
drain_in_flight(&mut in_flight, sess.clone(), turn_context.clone()).await?;
|
||||
let interrupted_tool_result =
|
||||
drain_in_flight(&mut in_flight, sess.clone(), turn_context.clone()).await?;
|
||||
if let Ok(result) = outcome.as_mut()
|
||||
&& interrupted_tool_result
|
||||
{
|
||||
result.needs_follow_up = false;
|
||||
result.interrupted_tool_result = true;
|
||||
}
|
||||
|
||||
if should_emit_turn_diff {
|
||||
let unified_diff = {
|
||||
@@ -6050,6 +6139,7 @@ mod tests {
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::context::ToolDispatchOutput;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
@@ -8299,10 +8389,10 @@ mod tests {
|
||||
_ctx: Arc<TurnContext>,
|
||||
_input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
) -> TaskRunOutput {
|
||||
if self.listen_to_cancellation_token {
|
||||
cancellation_token.cancelled().await;
|
||||
return None;
|
||||
return TaskRunOutput::default();
|
||||
}
|
||||
loop {
|
||||
sleep(Duration::from_secs(60)).await;
|
||||
@@ -8344,6 +8434,207 @@ mod tests {
|
||||
assert!(rx.try_recv().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn drain_in_flight_flushes_buffered_earlier_results_before_interrupt() {
|
||||
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
let mut in_flight: FuturesUnordered<
|
||||
BoxFuture<'static, CodexResult<IndexedToolDispatchOutput>>,
|
||||
> = FuturesUnordered::new();
|
||||
|
||||
let (slow_tx, slow_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
let _slow_tx = slow_tx;
|
||||
in_flight.push(Box::pin(async move {
|
||||
let _ = slow_rx.await;
|
||||
Ok(IndexedToolDispatchOutput {
|
||||
seq: 0,
|
||||
output: ToolDispatchOutput {
|
||||
response_input: ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "slow-call".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("slow".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
interrupt_turn: false,
|
||||
},
|
||||
})
|
||||
}));
|
||||
|
||||
in_flight.push(Box::pin(async move {
|
||||
Ok(IndexedToolDispatchOutput {
|
||||
seq: 1,
|
||||
output: ToolDispatchOutput {
|
||||
response_input: ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "fast-call".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("fast".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
interrupt_turn: false,
|
||||
},
|
||||
})
|
||||
}));
|
||||
|
||||
let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
tokio::spawn(async move {
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
let _ = interrupt_tx.send(());
|
||||
});
|
||||
in_flight.push(Box::pin(async move {
|
||||
let _ = interrupt_rx.await;
|
||||
Ok(IndexedToolDispatchOutput {
|
||||
seq: 2,
|
||||
output: ToolDispatchOutput {
|
||||
response_input: ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "interrupt-call".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("interrupt".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
interrupt_turn: true,
|
||||
},
|
||||
})
|
||||
}));
|
||||
|
||||
let interrupted = drain_in_flight(&mut in_flight, Arc::clone(&sess), Arc::clone(&tc))
|
||||
.await
|
||||
.expect("drain_in_flight should succeed");
|
||||
assert!(interrupted);
|
||||
|
||||
let history = sess.clone_history().await;
|
||||
let fast_item = ResponseItem::from(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "fast-call".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("fast".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
});
|
||||
let interrupt_item = ResponseItem::from(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "interrupt-call".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("interrupt".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
});
|
||||
|
||||
let fast_idx = history
|
||||
.raw_items()
|
||||
.iter()
|
||||
.position(|item| item == &fast_item)
|
||||
.expect("buffered earlier tool result should be recorded");
|
||||
let interrupt_idx = history
|
||||
.raw_items()
|
||||
.iter()
|
||||
.position(|item| item == &interrupt_item)
|
||||
.expect("interrupting tool result should be recorded");
|
||||
assert!(
|
||||
fast_idx < interrupt_idx,
|
||||
"buffered earlier tool result should be recorded before interrupt result"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn drain_in_flight_flushes_ready_unyielded_earlier_results_before_interrupt() {
|
||||
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
let mut in_flight: FuturesUnordered<
|
||||
BoxFuture<'static, CodexResult<IndexedToolDispatchOutput>>,
|
||||
> = FuturesUnordered::new();
|
||||
|
||||
let (slow_tx, slow_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
let _slow_tx = slow_tx;
|
||||
in_flight.push(Box::pin(async move {
|
||||
let _ = slow_rx.await;
|
||||
Ok(IndexedToolDispatchOutput {
|
||||
seq: 0,
|
||||
output: ToolDispatchOutput {
|
||||
response_input: ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "slow-call-2".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("slow".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
interrupt_turn: false,
|
||||
},
|
||||
})
|
||||
}));
|
||||
|
||||
let (fast_tx, fast_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
in_flight.push(Box::pin(async move {
|
||||
let _ = fast_rx.await;
|
||||
Ok(IndexedToolDispatchOutput {
|
||||
seq: 1,
|
||||
output: ToolDispatchOutput {
|
||||
response_input: ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "fast-call-2".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("fast".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
interrupt_turn: false,
|
||||
},
|
||||
})
|
||||
}));
|
||||
|
||||
in_flight.push(Box::pin(async move {
|
||||
let _ = fast_tx.send(());
|
||||
Ok(IndexedToolDispatchOutput {
|
||||
seq: 2,
|
||||
output: ToolDispatchOutput {
|
||||
response_input: ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "interrupt-call-2".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("interrupt".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
interrupt_turn: true,
|
||||
},
|
||||
})
|
||||
}));
|
||||
|
||||
let interrupted = drain_in_flight(&mut in_flight, Arc::clone(&sess), Arc::clone(&tc))
|
||||
.await
|
||||
.expect("drain_in_flight should succeed");
|
||||
assert!(interrupted);
|
||||
|
||||
let history = sess.clone_history().await;
|
||||
let fast_item = ResponseItem::from(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "fast-call-2".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("fast".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
});
|
||||
let interrupt_item = ResponseItem::from(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "interrupt-call-2".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("interrupt".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
});
|
||||
|
||||
let fast_idx = history
|
||||
.raw_items()
|
||||
.iter()
|
||||
.position(|item| item == &fast_item)
|
||||
.expect("ready-but-unyielded earlier tool result should be recorded");
|
||||
let interrupt_idx = history
|
||||
.raw_items()
|
||||
.iter()
|
||||
.position(|item| item == &interrupt_item)
|
||||
.expect("interrupting tool result should be recorded");
|
||||
assert!(
|
||||
fast_idx < interrupt_idx,
|
||||
"ready-but-unyielded earlier tool result should be recorded before interrupt result"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_gracefully_emits_turn_aborted_only() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
@@ -422,6 +422,7 @@ where
|
||||
_ = cancel_token.cancelled() => {
|
||||
let empty = RequestUserInputResponse {
|
||||
answers: HashMap::new(),
|
||||
interrupted: false,
|
||||
};
|
||||
parent_session
|
||||
.notify_user_input_response(sub_id, empty.clone())
|
||||
@@ -430,6 +431,7 @@ where
|
||||
}
|
||||
response = fut => response.unwrap_or_else(|| RequestUserInputResponse {
|
||||
answers: HashMap::new(),
|
||||
interrupted: false,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,12 +105,14 @@ async fn should_install_mcp_dependencies(
|
||||
_ = cancellation_token.cancelled() => {
|
||||
let empty = RequestUserInputResponse {
|
||||
answers: HashMap::new(),
|
||||
interrupted: false,
|
||||
};
|
||||
sess.notify_user_input_response(sub_id, empty.clone()).await;
|
||||
empty
|
||||
}
|
||||
response = response_fut => response.unwrap_or_else(|| RequestUserInputResponse {
|
||||
answers: HashMap::new(),
|
||||
interrupted: false,
|
||||
}),
|
||||
};
|
||||
|
||||
|
||||
@@ -133,6 +133,7 @@ pub(crate) async fn request_skill_dependencies(
|
||||
.await
|
||||
.unwrap_or_else(|| RequestUserInputResponse {
|
||||
answers: HashMap::new(),
|
||||
interrupted: false,
|
||||
});
|
||||
|
||||
if response.answers.is_empty() {
|
||||
|
||||
@@ -56,8 +56,11 @@ impl ActiveTurn {
|
||||
self.tasks.insert(sub_id, task);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool {
|
||||
self.tasks.swap_remove(sub_id);
|
||||
pub(crate) fn remove_task(&mut self, sub_id: &str) -> Option<RunningTask> {
|
||||
self.tasks.swap_remove(sub_id)
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.tasks.is_empty()
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ use crate::error::Result;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::parse_turn_item;
|
||||
use crate::proposed_plan_parser::strip_proposed_plan_blocks;
|
||||
use crate::tools::context::ToolDispatchOutput;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
@@ -26,7 +27,7 @@ use tracing::instrument;
|
||||
/// queuing any tool execution futures. This records items immediately so
|
||||
/// history and rollout stay in sync even if the turn is later cancelled.
|
||||
pub(crate) type InFlightFuture<'f> =
|
||||
Pin<Box<dyn Future<Output = Result<ResponseInputItem>> + Send + 'f>>;
|
||||
Pin<Box<dyn Future<Output = Result<ToolDispatchOutput>> + Send + 'f>>;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct OutputItemResult {
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
use super::TaskRunOutput;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::state::TaskKind;
|
||||
use async_trait::async_trait;
|
||||
@@ -23,7 +24,7 @@ impl SessionTask for CompactTask {
|
||||
ctx: Arc<TurnContext>,
|
||||
input: Vec<UserInput>,
|
||||
_cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
) -> TaskRunOutput {
|
||||
let session = session.clone_session();
|
||||
let _ = if crate::compact::should_use_remote_compact_task(&ctx.provider) {
|
||||
let _ = session.services.otel_manager.counter(
|
||||
@@ -40,6 +41,7 @@ impl SessionTask for CompactTask {
|
||||
);
|
||||
crate::compact::run_compact_task(session.clone(), ctx, input).await
|
||||
};
|
||||
None
|
||||
|
||||
TaskRunOutput::default()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::protocol::WarningEvent;
|
||||
use crate::state::TaskKind;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use crate::tasks::TaskRunOutput;
|
||||
use async_trait::async_trait;
|
||||
use codex_git::CreateGhostCommitOptions;
|
||||
use codex_git::GhostSnapshotReport;
|
||||
@@ -38,7 +39,7 @@ impl SessionTask for GhostSnapshotTask {
|
||||
ctx: Arc<TurnContext>,
|
||||
_input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
) -> TaskRunOutput {
|
||||
tokio::task::spawn(async move {
|
||||
let token = self.token;
|
||||
let warnings_enabled = !ctx.ghost_snapshot.disable_warnings;
|
||||
@@ -152,7 +153,7 @@ impl SessionTask for GhostSnapshotTask {
|
||||
Err(err) => warn!("failed to mark ghost snapshot ready: {err}"),
|
||||
}
|
||||
});
|
||||
None
|
||||
TaskRunOutput::default()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -72,6 +72,12 @@ impl SessionTaskContext {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct TaskRunOutput {
|
||||
pub(crate) last_agent_message: Option<String>,
|
||||
pub(crate) abort_reason: Option<TurnAbortReason>,
|
||||
}
|
||||
|
||||
/// Async task that drives a [`Session`] turn.
|
||||
///
|
||||
/// Implementations encapsulate a specific Codex workflow (regular chat,
|
||||
@@ -100,7 +106,7 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
|
||||
ctx: Arc<TurnContext>,
|
||||
input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String>;
|
||||
) -> TaskRunOutput;
|
||||
|
||||
/// Gives the task a chance to perform cleanup after an abort.
|
||||
///
|
||||
@@ -138,17 +144,28 @@ impl Session {
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let ctx_for_finish = Arc::clone(&ctx);
|
||||
let last_agent_message = task_for_run
|
||||
let TaskRunOutput {
|
||||
last_agent_message,
|
||||
abort_reason,
|
||||
} = task_for_run
|
||||
.run(
|
||||
Arc::clone(&session_ctx),
|
||||
ctx,
|
||||
input,
|
||||
task_cancellation_token.child_token(),
|
||||
task_cancellation_token.clone(),
|
||||
)
|
||||
.await;
|
||||
let sess = session_ctx.clone_session();
|
||||
sess.flush_rollout().await;
|
||||
if !task_cancellation_token.is_cancelled() {
|
||||
if let Some(reason) = abort_reason {
|
||||
ctx_for_finish
|
||||
.turn_metadata_state
|
||||
.cancel_git_enrichment_task();
|
||||
// Emit TurnAborted from the spawn site so the rollout flush above
|
||||
// makes the interrupt marker durable before clients observe the event.
|
||||
sess.emit_turn_aborted(ctx_for_finish.as_ref(), reason)
|
||||
.await;
|
||||
} else if !task_cancellation_token.is_cancelled() {
|
||||
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
||||
sess.on_task_finished(Arc::clone(&ctx_for_finish), last_agent_message)
|
||||
.await;
|
||||
@@ -193,29 +210,8 @@ impl Session {
|
||||
turn_context
|
||||
.turn_metadata_state
|
||||
.cancel_git_enrichment_task();
|
||||
|
||||
let mut active = self.active_turn.lock().await;
|
||||
let mut pending_input = Vec::<ResponseInputItem>::new();
|
||||
let mut should_clear_active_turn = false;
|
||||
if let Some(at) = active.as_mut()
|
||||
&& at.remove_task(&turn_context.sub_id)
|
||||
{
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
pending_input = ts.take_pending_input();
|
||||
should_clear_active_turn = true;
|
||||
}
|
||||
if should_clear_active_turn {
|
||||
*active = None;
|
||||
}
|
||||
drop(active);
|
||||
if !pending_input.is_empty() {
|
||||
let pending_response_items = pending_input
|
||||
.into_iter()
|
||||
.map(ResponseItem::from)
|
||||
.collect::<Vec<_>>();
|
||||
self.record_conversation_items(turn_context.as_ref(), &pending_response_items)
|
||||
.await;
|
||||
}
|
||||
self.finish_turn_without_completion_event(turn_context.as_ref())
|
||||
.await;
|
||||
let event = EventMsg::TurnComplete(TurnCompleteEvent {
|
||||
turn_id: turn_context.sub_id.clone(),
|
||||
last_agent_message,
|
||||
@@ -249,6 +245,38 @@ impl Session {
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn finish_turn_without_completion_event(&self, turn_context: &TurnContext) {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
let mut pending_input = Vec::<ResponseInputItem>::new();
|
||||
let mut removed_handle: Option<Arc<AbortOnDropHandle<()>>> = None;
|
||||
let mut should_clear_active_turn = false;
|
||||
if let Some(at) = active.as_mut()
|
||||
&& let Some(task) = at.remove_task(&turn_context.sub_id)
|
||||
{
|
||||
removed_handle = Some(task.handle);
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
pending_input = ts.take_pending_input();
|
||||
should_clear_active_turn = at.is_empty();
|
||||
}
|
||||
if should_clear_active_turn {
|
||||
*active = None;
|
||||
}
|
||||
drop(active);
|
||||
if let Some(handle) = removed_handle
|
||||
&& let Ok(handle) = Arc::try_unwrap(handle)
|
||||
{
|
||||
drop(handle.detach());
|
||||
}
|
||||
if !pending_input.is_empty() {
|
||||
let pending_response_items = pending_input
|
||||
.into_iter()
|
||||
.map(ResponseItem::from)
|
||||
.collect::<Vec<_>>();
|
||||
self.record_conversation_items(turn_context, &pending_response_items)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
|
||||
let sub_id = task.turn_context.sub_id.clone();
|
||||
if task.cancellation_token.is_cancelled() {
|
||||
@@ -276,7 +304,15 @@ impl Session {
|
||||
session_task
|
||||
.abort(session_ctx, Arc::clone(&task.turn_context))
|
||||
.await;
|
||||
self.emit_turn_aborted(task.turn_context.as_ref(), reason)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn emit_turn_aborted(
|
||||
self: &Arc<Self>,
|
||||
turn_context: &TurnContext,
|
||||
reason: TurnAbortReason,
|
||||
) {
|
||||
if reason == TurnAbortReason::Interrupted {
|
||||
let marker = ResponseItem::Message {
|
||||
id: None,
|
||||
@@ -289,7 +325,7 @@ impl Session {
|
||||
end_turn: None,
|
||||
phase: None,
|
||||
};
|
||||
self.record_into_history(std::slice::from_ref(&marker), task.turn_context.as_ref())
|
||||
self.record_into_history(std::slice::from_ref(&marker), turn_context)
|
||||
.await;
|
||||
self.persist_rollout_items(&[RolloutItem::ResponseItem(marker)])
|
||||
.await;
|
||||
@@ -299,10 +335,10 @@ impl Session {
|
||||
}
|
||||
|
||||
let event = EventMsg::TurnAborted(TurnAbortedEvent {
|
||||
turn_id: Some(task.turn_context.sub_id.clone()),
|
||||
turn_id: Some(turn_context.sub_id.clone()),
|
||||
reason,
|
||||
});
|
||||
self.send_event(task.turn_context.as_ref(), event).await;
|
||||
self.send_event(turn_context, event).await;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ use tracing::warn;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
use super::TaskRunOutput;
|
||||
|
||||
type PrewarmedSessionTask = JoinHandle<Option<ModelClientSession>>;
|
||||
|
||||
@@ -89,7 +90,7 @@ impl SessionTask for RegularTask {
|
||||
ctx: Arc<TurnContext>,
|
||||
input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
) -> TaskRunOutput {
|
||||
let sess = session.clone_session();
|
||||
let run_turn_span = trace_span!("run_turn");
|
||||
sess.set_server_reasoning_included(false).await;
|
||||
|
||||
@@ -27,6 +27,7 @@ use codex_protocol::user_input::UserInput;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
use super::TaskRunOutput;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) struct ReviewTask;
|
||||
@@ -49,7 +50,7 @@ impl SessionTask for ReviewTask {
|
||||
ctx: Arc<TurnContext>,
|
||||
input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
) -> TaskRunOutput {
|
||||
let _ = session
|
||||
.session
|
||||
.services
|
||||
@@ -71,7 +72,7 @@ impl SessionTask for ReviewTask {
|
||||
if !cancellation_token.is_cancelled() {
|
||||
exit_review_mode(session.clone_session(), output.clone(), ctx.clone()).await;
|
||||
}
|
||||
None
|
||||
TaskRunOutput::default()
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::protocol::UndoStartedEvent;
|
||||
use crate::state::TaskKind;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use crate::tasks::TaskRunOutput;
|
||||
use async_trait::async_trait;
|
||||
use codex_git::RestoreGhostCommitOptions;
|
||||
use codex_git::restore_ghost_commit_with_options;
|
||||
@@ -37,7 +38,7 @@ impl SessionTask for UndoTask {
|
||||
ctx: Arc<TurnContext>,
|
||||
_input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
) -> TaskRunOutput {
|
||||
let _ = session
|
||||
.session
|
||||
.services
|
||||
@@ -61,7 +62,7 @@ impl SessionTask for UndoTask {
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
return None;
|
||||
return TaskRunOutput::default();
|
||||
}
|
||||
|
||||
let history = sess.clone_history().await;
|
||||
@@ -86,7 +87,7 @@ impl SessionTask for UndoTask {
|
||||
completed.message = Some("No ghost snapshot available to undo.".to_string());
|
||||
sess.send_event(ctx.as_ref(), EventMsg::UndoCompleted(completed))
|
||||
.await;
|
||||
return None;
|
||||
return TaskRunOutput::default();
|
||||
};
|
||||
|
||||
let commit_id = ghost_commit.id().to_string();
|
||||
@@ -122,6 +123,6 @@ impl SessionTask for UndoTask {
|
||||
|
||||
sess.send_event(ctx.as_ref(), EventMsg::UndoCompleted(completed))
|
||||
.await;
|
||||
None
|
||||
TaskRunOutput::default()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ use crate::user_shell_command::user_shell_command_record_item;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
use super::TaskRunOutput;
|
||||
use crate::codex::Session;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -72,7 +73,7 @@ impl SessionTask for UserShellCommandTask {
|
||||
turn_context: Arc<TurnContext>,
|
||||
_input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
) -> TaskRunOutput {
|
||||
execute_user_shell_command(
|
||||
session.clone_session(),
|
||||
turn_context,
|
||||
@@ -81,7 +82,7 @@ impl SessionTask for UserShellCommandTask {
|
||||
UserShellCommandMode::StandaloneTurn,
|
||||
)
|
||||
.await;
|
||||
None
|
||||
TaskRunOutput::default()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -63,15 +63,28 @@ pub enum ToolOutput {
|
||||
body: FunctionCallOutputBody,
|
||||
success: Option<bool>,
|
||||
},
|
||||
FunctionWithControl {
|
||||
// Canonical output body for function-style tools plus internal control
|
||||
// metadata consumed by core dispatch (not exposed on the wire).
|
||||
body: FunctionCallOutputBody,
|
||||
success: Option<bool>,
|
||||
interrupt_turn: bool,
|
||||
},
|
||||
Mcp {
|
||||
result: Result<CallToolResult, String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ToolDispatchOutput {
|
||||
pub response_input: ResponseInputItem,
|
||||
pub interrupt_turn: bool,
|
||||
}
|
||||
|
||||
impl ToolOutput {
|
||||
pub fn log_preview(&self) -> String {
|
||||
match self {
|
||||
ToolOutput::Function { body, .. } => {
|
||||
ToolOutput::Function { body, .. } | ToolOutput::FunctionWithControl { body, .. } => {
|
||||
telemetry_preview(&body.to_text().unwrap_or_default())
|
||||
}
|
||||
ToolOutput::Mcp { result } => format!("{result:?}"),
|
||||
@@ -80,14 +93,23 @@ impl ToolOutput {
|
||||
|
||||
pub fn success_for_logging(&self) -> bool {
|
||||
match self {
|
||||
ToolOutput::Function { success, .. } => success.unwrap_or(true),
|
||||
ToolOutput::Function { success, .. }
|
||||
| ToolOutput::FunctionWithControl { success, .. } => success.unwrap_or(true),
|
||||
ToolOutput::Mcp { result } => result.is_ok(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn interrupt_turn_hint(&self) -> bool {
|
||||
match self {
|
||||
ToolOutput::FunctionWithControl { interrupt_turn, .. } => *interrupt_turn,
|
||||
ToolOutput::Function { .. } | ToolOutput::Mcp { .. } => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_response(self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
match self {
|
||||
ToolOutput::Function { body, success } => {
|
||||
ToolOutput::Function { body, success }
|
||||
| ToolOutput::FunctionWithControl { body, success, .. } => {
|
||||
// `custom_tool_call` is the Responses API item type for freeform
|
||||
// tools (`ToolSpec::Freeform`, e.g. freeform `apply_patch`).
|
||||
// Those payloads must round-trip as `custom_tool_call_output`
|
||||
@@ -205,6 +227,30 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn function_with_control_interrupt_hint_is_internal_only() {
|
||||
let payload = ToolPayload::Function {
|
||||
arguments: "{}".to_string(),
|
||||
};
|
||||
let output = ToolOutput::FunctionWithControl {
|
||||
body: FunctionCallOutputBody::Text("ok".to_string()),
|
||||
success: Some(true),
|
||||
interrupt_turn: true,
|
||||
};
|
||||
|
||||
assert!(output.interrupt_turn_hint());
|
||||
|
||||
let response = output.into_response("fn-ctrl", &payload);
|
||||
match response {
|
||||
ResponseInputItem::FunctionCallOutput { call_id, output } => {
|
||||
assert_eq!(call_id, "fn-ctrl");
|
||||
assert_eq!(output.text_content(), Some("ok"));
|
||||
assert_eq!(output.success, Some(true));
|
||||
}
|
||||
other => panic!("expected FunctionCallOutput, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_tool_calls_can_derive_text_from_content_items() {
|
||||
let payload = ToolPayload::Custom {
|
||||
|
||||
@@ -155,6 +155,7 @@ impl ToolHandler for JsReplHandler {
|
||||
};
|
||||
|
||||
let content = result.output;
|
||||
let interrupt_turn = result.interrupt_turn;
|
||||
let items = vec![FunctionCallOutputContentItem::InputText {
|
||||
text: content.clone(),
|
||||
}];
|
||||
@@ -169,10 +170,18 @@ impl ToolHandler for JsReplHandler {
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
body: FunctionCallOutputBody::ContentItems(items),
|
||||
success: Some(true),
|
||||
})
|
||||
if interrupt_turn {
|
||||
Ok(ToolOutput::FunctionWithControl {
|
||||
body: FunctionCallOutputBody::ContentItems(items),
|
||||
success: Some(true),
|
||||
interrupt_turn: true,
|
||||
})
|
||||
} else {
|
||||
Ok(ToolOutput::Function {
|
||||
body: FunctionCallOutputBody::ContentItems(items),
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ use crate::tools::registry::ToolKind;
|
||||
use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::config_types::TUI_VISIBLE_COLLABORATION_MODES;
|
||||
use codex_protocol::request_user_input::RequestUserInputArgs;
|
||||
use codex_protocol::request_user_input::RequestUserInputResponse;
|
||||
|
||||
fn format_allowed_modes() -> String {
|
||||
let mode_names: Vec<&str> = TUI_VISIBLE_COLLABORATION_MODES
|
||||
@@ -109,6 +110,17 @@ impl ToolHandler for RequestUserInputHandler {
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
|
||||
fn should_interrupt_turn(&self, output: &ToolOutput) -> bool {
|
||||
let ToolOutput::Function { body, .. } = output else {
|
||||
return false;
|
||||
};
|
||||
let Some(content) = body.to_text() else {
|
||||
return false;
|
||||
};
|
||||
serde_json::from_str::<RequestUserInputResponse>(&content)
|
||||
.is_ok_and(|response| response.interrupted)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -148,4 +160,26 @@ mod tests {
|
||||
"Request user input for one to three short questions and wait for the response. This tool is only available in Plan mode.".to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interrupted_response_interrupts_turn() {
|
||||
let handler = RequestUserInputHandler;
|
||||
let output = ToolOutput::Function {
|
||||
body: FunctionCallOutputBody::Text(r#"{"answers":{},"interrupted":true}"#.to_string()),
|
||||
success: Some(true),
|
||||
};
|
||||
|
||||
assert!(handler.should_interrupt_turn(&output));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interrupted_response_does_not_interrupt_turn() {
|
||||
let handler = RequestUserInputHandler;
|
||||
let output = ToolOutput::Function {
|
||||
body: FunctionCallOutputBody::Text(r#"{"answers":{}}"#.to_string()),
|
||||
success: Some(true),
|
||||
};
|
||||
|
||||
assert!(!handler.should_interrupt_turn(&output));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,6 +95,7 @@ pub struct JsReplArgs {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct JsExecResult {
|
||||
pub output: String,
|
||||
pub interrupt_turn: bool,
|
||||
}
|
||||
|
||||
struct KernelState {
|
||||
@@ -116,6 +117,7 @@ struct ExecContext {
|
||||
#[derive(Default)]
|
||||
struct ExecToolCalls {
|
||||
in_flight: usize,
|
||||
interrupted: bool,
|
||||
notify: Arc<Notify>,
|
||||
cancel: CancellationToken,
|
||||
}
|
||||
@@ -359,6 +361,24 @@ impl JsReplManager {
|
||||
}
|
||||
}
|
||||
|
||||
async fn mark_exec_tool_call_interrupted(
|
||||
exec_tool_calls: &Arc<Mutex<HashMap<String, ExecToolCalls>>>,
|
||||
exec_id: &str,
|
||||
) {
|
||||
let mut calls = exec_tool_calls.lock().await;
|
||||
if let Some(state) = calls.get_mut(exec_id) {
|
||||
state.interrupted = true;
|
||||
}
|
||||
}
|
||||
|
||||
async fn exec_tool_calls_interrupted(
|
||||
exec_tool_calls: &Arc<Mutex<HashMap<String, ExecToolCalls>>>,
|
||||
exec_id: &str,
|
||||
) -> bool {
|
||||
let calls = exec_tool_calls.lock().await;
|
||||
calls.get(exec_id).is_some_and(|state| state.interrupted)
|
||||
}
|
||||
|
||||
async fn wait_for_exec_tool_calls_map(
|
||||
exec_tool_calls: &Arc<Mutex<HashMap<String, ExecToolCalls>>>,
|
||||
exec_id: &str,
|
||||
@@ -543,7 +563,13 @@ impl JsReplManager {
|
||||
};
|
||||
|
||||
match response {
|
||||
ExecResultMessage::Ok { output } => Ok(JsExecResult { output }),
|
||||
ExecResultMessage::Ok {
|
||||
output,
|
||||
interrupt_turn,
|
||||
} => Ok(JsExecResult {
|
||||
output,
|
||||
interrupt_turn,
|
||||
}),
|
||||
ExecResultMessage::Err { message } => Err(FunctionCallError::RespondToModel(message)),
|
||||
}
|
||||
}
|
||||
@@ -845,10 +871,15 @@ impl JsReplManager {
|
||||
error,
|
||||
} => {
|
||||
JsReplManager::wait_for_exec_tool_calls_map(&exec_tool_calls, &id).await;
|
||||
let interrupt_turn =
|
||||
JsReplManager::exec_tool_calls_interrupted(&exec_tool_calls, &id).await;
|
||||
let mut pending = pending_execs.lock().await;
|
||||
if let Some(tx) = pending.remove(&id) {
|
||||
let payload = if ok {
|
||||
ExecResultMessage::Ok { output }
|
||||
ExecResultMessage::Ok {
|
||||
output,
|
||||
interrupt_turn,
|
||||
}
|
||||
} else {
|
||||
ExecResultMessage::Err {
|
||||
message: error
|
||||
@@ -871,6 +902,7 @@ impl JsReplManager {
|
||||
ok: false,
|
||||
response: None,
|
||||
error: Some("js_repl exec context not found".to_string()),
|
||||
interrupt_turn: false,
|
||||
});
|
||||
if let Err(err) = JsReplManager::write_message(&stdin, &payload).await {
|
||||
let snapshot =
|
||||
@@ -904,6 +936,7 @@ impl JsReplManager {
|
||||
ok: false,
|
||||
response: None,
|
||||
error: Some("js_repl execution reset".to_string()),
|
||||
interrupt_turn: false,
|
||||
},
|
||||
result = JsReplManager::run_tool_request(ctx, req) => result,
|
||||
}
|
||||
@@ -913,8 +946,16 @@ impl JsReplManager {
|
||||
ok: false,
|
||||
response: None,
|
||||
error: Some("js_repl exec context not found".to_string()),
|
||||
interrupt_turn: false,
|
||||
},
|
||||
};
|
||||
if result.interrupt_turn {
|
||||
JsReplManager::mark_exec_tool_call_interrupted(
|
||||
&exec_tool_calls_for_task,
|
||||
&exec_id,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
JsReplManager::finish_exec_tool_call(&exec_tool_calls_for_task, &exec_id)
|
||||
.await;
|
||||
let payload = HostToKernel::RunToolResult(result);
|
||||
@@ -996,6 +1037,7 @@ impl JsReplManager {
|
||||
ok: false,
|
||||
response: None,
|
||||
error: Some("js_repl cannot invoke itself".to_string()),
|
||||
interrupt_turn: false,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1053,18 +1095,20 @@ impl JsReplManager {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => match serde_json::to_value(response) {
|
||||
Ok(output) => match serde_json::to_value(output.response_input) {
|
||||
Ok(value) => RunToolResult {
|
||||
id: req.id,
|
||||
ok: true,
|
||||
response: Some(value),
|
||||
error: None,
|
||||
interrupt_turn: output.interrupt_turn,
|
||||
},
|
||||
Err(err) => RunToolResult {
|
||||
id: req.id,
|
||||
ok: false,
|
||||
response: None,
|
||||
error: Some(format!("failed to serialize tool output: {err}")),
|
||||
interrupt_turn: false,
|
||||
},
|
||||
},
|
||||
Err(err) => RunToolResult {
|
||||
@@ -1072,6 +1116,7 @@ impl JsReplManager {
|
||||
ok: false,
|
||||
response: None,
|
||||
error: Some(err.to_string()),
|
||||
interrupt_turn: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1161,12 +1206,19 @@ struct RunToolResult {
|
||||
response: Option<JsonValue>,
|
||||
#[serde(default)]
|
||||
error: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
interrupt_turn: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ExecResultMessage {
|
||||
Ok { output: String },
|
||||
Err { message: String },
|
||||
Ok {
|
||||
output: String,
|
||||
interrupt_turn: bool,
|
||||
},
|
||||
Err {
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::codex::TurnContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolDispatchOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolRouter;
|
||||
@@ -51,7 +52,7 @@ impl ToolCallRuntime {
|
||||
self,
|
||||
call: ToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
||||
) -> impl std::future::Future<Output = Result<ToolDispatchOutput, CodexErr>> {
|
||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||
|
||||
let router = Arc::clone(&self.router);
|
||||
@@ -69,7 +70,7 @@ impl ToolCallRuntime {
|
||||
aborted = false,
|
||||
);
|
||||
|
||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||
let handle: AbortOnDropHandle<Result<ToolDispatchOutput, FunctionCallError>> =
|
||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = cancellation_token.cancelled() => {
|
||||
@@ -113,8 +114,8 @@ impl ToolCallRuntime {
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
fn aborted_response(call: &ToolCall, secs: f32) -> ResponseInputItem {
|
||||
match &call.payload {
|
||||
fn aborted_response(call: &ToolCall, secs: f32) -> ToolDispatchOutput {
|
||||
let response_input = match &call.payload {
|
||||
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
output: Self::abort_message(call, secs),
|
||||
@@ -130,6 +131,10 @@ impl ToolCallRuntime {
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
};
|
||||
ToolDispatchOutput {
|
||||
response_input,
|
||||
interrupt_turn: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::function_tool::FunctionCallError;
|
||||
use crate::memories::usage::emit_metric_for_tool_read;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::sandbox_tags::sandbox_tag;
|
||||
use crate::tools::context::ToolDispatchOutput;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
@@ -20,7 +21,6 @@ use codex_hooks::HookResult;
|
||||
use codex_hooks::HookToolInput;
|
||||
use codex_hooks::HookToolInputLocalShell;
|
||||
use codex_hooks::HookToolKind;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_utils_readiness::Readiness;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -53,6 +53,12 @@ pub trait ToolHandler: Send + Sync {
|
||||
/// Perform the actual [ToolInvocation] and returns a [ToolOutput] containing
|
||||
/// the final output to return to the model.
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
|
||||
|
||||
/// Classify whether a successful tool output should interrupt the turn after
|
||||
/// persisting the tool result to history/rollout.
|
||||
fn should_interrupt_turn(&self, _output: &ToolOutput) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolRegistry {
|
||||
@@ -79,7 +85,7 @@ impl ToolRegistry {
|
||||
pub async fn dispatch(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<ResponseInputItem, FunctionCallError> {
|
||||
) -> Result<ToolDispatchOutput, FunctionCallError> {
|
||||
let tool_name = invocation.tool_name.clone();
|
||||
let call_id_owned = invocation.call_id.clone();
|
||||
let otel = invocation.turn.otel_manager.clone();
|
||||
@@ -155,7 +161,7 @@ impl ToolRegistry {
|
||||
}
|
||||
|
||||
let is_mutating = handler.is_mutating(&invocation).await;
|
||||
let output_cell = tokio::sync::Mutex::new(None);
|
||||
let output_cell = tokio::sync::Mutex::new(None::<(ToolOutput, bool)>);
|
||||
let invocation_for_tool = invocation.clone();
|
||||
|
||||
let started = Instant::now();
|
||||
@@ -180,8 +186,10 @@ impl ToolRegistry {
|
||||
Ok(output) => {
|
||||
let preview = output.log_preview();
|
||||
let success = output.success_for_logging();
|
||||
let interrupt_turn = output.interrupt_turn_hint()
|
||||
|| handler.should_interrupt_turn(&output);
|
||||
let mut guard = output_cell.lock().await;
|
||||
*guard = Some(output);
|
||||
*guard = Some((output, interrupt_turn));
|
||||
Ok((preview, success))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
@@ -213,10 +221,13 @@ impl ToolRegistry {
|
||||
match result {
|
||||
Ok(_) => {
|
||||
let mut guard = output_cell.lock().await;
|
||||
let output = guard.take().ok_or_else(|| {
|
||||
let (output, interrupt_turn) = guard.take().ok_or_else(|| {
|
||||
FunctionCallError::Fatal("tool produced no output".to_string())
|
||||
})?;
|
||||
Ok(output.into_response(&call_id_owned, &payload_for_response))
|
||||
Ok(ToolDispatchOutput {
|
||||
response_input: output.into_response(&call_id_owned, &payload_for_response),
|
||||
interrupt_turn,
|
||||
})
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::function_tool::FunctionCallError;
|
||||
use crate::mcp_connection_manager::ToolInfo;
|
||||
use crate::sandboxing::SandboxPermissions;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolDispatchOutput;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ConfiguredToolSpec;
|
||||
@@ -147,7 +148,7 @@ impl ToolRouter {
|
||||
tracker: SharedTurnDiffTracker,
|
||||
call: ToolCall,
|
||||
source: ToolCallSource,
|
||||
) -> Result<ResponseInputItem, FunctionCallError> {
|
||||
) -> Result<ToolDispatchOutput, FunctionCallError> {
|
||||
let ToolCall {
|
||||
tool_name,
|
||||
call_id,
|
||||
@@ -164,7 +165,7 @@ impl ToolRouter {
|
||||
"direct tool calls are disabled; use js_repl and codex.tool(...) instead"
|
||||
.to_string(),
|
||||
);
|
||||
return Ok(Self::failure_response(
|
||||
return Ok(Self::failure_output(
|
||||
failure_call_id,
|
||||
payload_outputs_custom,
|
||||
err,
|
||||
@@ -183,7 +184,7 @@ impl ToolRouter {
|
||||
match self.registry.dispatch(invocation).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(FunctionCallError::Fatal(message)) => Err(FunctionCallError::Fatal(message)),
|
||||
Err(err) => Ok(Self::failure_response(
|
||||
Err(err) => Ok(Self::failure_output(
|
||||
failure_call_id,
|
||||
payload_outputs_custom,
|
||||
err,
|
||||
@@ -191,13 +192,13 @@ impl ToolRouter {
|
||||
}
|
||||
}
|
||||
|
||||
fn failure_response(
|
||||
fn failure_output(
|
||||
call_id: String,
|
||||
payload_outputs_custom: bool,
|
||||
err: FunctionCallError,
|
||||
) -> ResponseInputItem {
|
||||
) -> ToolDispatchOutput {
|
||||
let message = err.to_string();
|
||||
if payload_outputs_custom {
|
||||
let response_input = if payload_outputs_custom {
|
||||
ResponseInputItem::CustomToolCallOutput {
|
||||
call_id,
|
||||
output: message,
|
||||
@@ -210,6 +211,10 @@ impl ToolRouter {
|
||||
success: Some(false),
|
||||
},
|
||||
}
|
||||
};
|
||||
ToolDispatchOutput {
|
||||
response_input,
|
||||
interrupt_turn: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -265,7 +270,7 @@ mod tests {
|
||||
.dispatch_tool_call(session, turn, tracker, call, ToolCallSource::Direct)
|
||||
.await?;
|
||||
|
||||
match response {
|
||||
match response.response_input {
|
||||
ResponseInputItem::FunctionCallOutput { output, .. } => {
|
||||
let content = output.text_content().unwrap_or_default();
|
||||
assert!(
|
||||
@@ -318,7 +323,7 @@ mod tests {
|
||||
.dispatch_tool_call(session, turn, tracker, call, ToolCallSource::JsRepl)
|
||||
.await?;
|
||||
|
||||
match response {
|
||||
match response.response_input {
|
||||
ResponseInputItem::FunctionCallOutput { output, .. } => {
|
||||
let content = output.text_content().unwrap_or_default();
|
||||
assert!(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
@@ -20,6 +21,7 @@ use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
@@ -165,7 +167,10 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
|
||||
answers: vec!["yes".to_string()],
|
||||
},
|
||||
);
|
||||
let response = RequestUserInputResponse { answers };
|
||||
let response = RequestUserInputResponse {
|
||||
answers,
|
||||
interrupted: false,
|
||||
};
|
||||
codex
|
||||
.submit(Op::UserInputAnswer {
|
||||
id: request.turn_id.clone(),
|
||||
@@ -190,6 +195,264 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn request_user_input_interrupted_response_preserves_tool_output() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let builder = test_codex();
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder
|
||||
.with_config(|config| {
|
||||
config.features.enable(Feature::CollaborationModes);
|
||||
})
|
||||
.build(&server)
|
||||
.await?;
|
||||
|
||||
let call_id = "user-input-call-interrupt";
|
||||
let request_args = json!({
|
||||
"questions": [{
|
||||
"id": "confirm_path",
|
||||
"header": "Confirm",
|
||||
"question": "Proceed with the plan?",
|
||||
"options": [{
|
||||
"label": "Yes (Recommended)",
|
||||
"description": "Continue the current plan."
|
||||
}, {
|
||||
"label": "No",
|
||||
"description": "Stop and revisit the approach."
|
||||
}]
|
||||
}]
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "request_user_input", &request_args),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
let follow_up_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "next turn"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
let response_mock = mount_sse_sequence(&server, vec![first_response, follow_up_response]).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "please confirm".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model.clone(),
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
collaboration_mode: Some(CollaborationMode {
|
||||
mode: ModeKind::Plan,
|
||||
settings: Settings {
|
||||
model: session_configured.model.clone(),
|
||||
reasoning_effort: None,
|
||||
developer_instructions: None,
|
||||
},
|
||||
}),
|
||||
personality: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let request = wait_for_event_match(&codex, |event| match event {
|
||||
EventMsg::RequestUserInput(request) => Some(request.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
assert_eq!(request.call_id, call_id);
|
||||
|
||||
let mut answers = HashMap::new();
|
||||
answers.insert(
|
||||
"confirm_path".to_string(),
|
||||
RequestUserInputAnswer {
|
||||
answers: vec!["yes".to_string()],
|
||||
},
|
||||
);
|
||||
codex
|
||||
.submit(Op::UserInputAnswer {
|
||||
id: request.turn_id.clone(),
|
||||
response: RequestUserInputResponse {
|
||||
answers,
|
||||
interrupted: true,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
|
||||
let terminal_event = wait_for_event_match(&codex, |event| match event {
|
||||
EventMsg::TurnAborted(_) => Some("aborted"),
|
||||
EventMsg::TurnComplete(_) => Some("complete"),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
assert_eq!(terminal_event, "aborted", "expected interrupted turn");
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "follow up".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
collaboration_mode: Some(CollaborationMode {
|
||||
mode: ModeKind::Plan,
|
||||
settings: Settings {
|
||||
model: session_configured.model.clone(),
|
||||
reasoning_effort: None,
|
||||
developer_instructions: None,
|
||||
},
|
||||
}),
|
||||
personality: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await;
|
||||
|
||||
let requests = response_mock.requests();
|
||||
let request_with_output = requests
|
||||
.iter()
|
||||
.find(|req| req.function_call_output_text(call_id).is_some())
|
||||
.expect("expected request_user_input function_call_output in later request");
|
||||
let output_text = call_output(request_with_output, call_id);
|
||||
assert!(
|
||||
!output_text.contains("aborted by user"),
|
||||
"request_user_input output should not be replaced by synthetic abort text"
|
||||
);
|
||||
let output_json: Value = serde_json::from_str(&output_text)?;
|
||||
assert_eq!(
|
||||
output_json,
|
||||
json!({
|
||||
"answers": {
|
||||
"confirm_path": { "answers": ["yes"] }
|
||||
},
|
||||
"interrupted": true
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn request_user_input_interrupt_not_blocked_by_earlier_tool() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let builder = test_codex().with_model("test-gpt-5.1-codex");
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder
|
||||
.with_config(|config| {
|
||||
config.features.enable(Feature::CollaborationModes);
|
||||
})
|
||||
.build(&server)
|
||||
.await?;
|
||||
|
||||
let call_id = "user-input-call-fast-interrupt";
|
||||
let slow_tool_args = json!({
|
||||
"sleep_after_ms": 2_000
|
||||
})
|
||||
.to_string();
|
||||
let request_args = json!({
|
||||
"questions": [{
|
||||
"id": "confirm_path",
|
||||
"header": "Confirm",
|
||||
"question": "Proceed with the plan?",
|
||||
"options": [{
|
||||
"label": "Yes (Recommended)",
|
||||
"description": "Continue the current plan."
|
||||
}, {
|
||||
"label": "No",
|
||||
"description": "Stop and revisit the approach."
|
||||
}]
|
||||
}]
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call("slow-call-1", "test_sync_tool", &slow_tool_args),
|
||||
ev_function_call(call_id, "request_user_input", &request_args),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once(&server, first_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "please confirm".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
collaboration_mode: Some(CollaborationMode {
|
||||
mode: ModeKind::Plan,
|
||||
settings: Settings {
|
||||
model: session_configured.model.clone(),
|
||||
reasoning_effort: None,
|
||||
developer_instructions: None,
|
||||
},
|
||||
}),
|
||||
personality: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let request = wait_for_event_match(&codex, |event| match event {
|
||||
EventMsg::RequestUserInput(request) => Some(request.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
assert_eq!(request.call_id, call_id);
|
||||
|
||||
codex
|
||||
.submit(Op::UserInputAnswer {
|
||||
id: request.turn_id.clone(),
|
||||
response: RequestUserInputResponse {
|
||||
answers: HashMap::new(),
|
||||
interrupted: true,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
|
||||
tokio::time::timeout(Duration::from_millis(750), async {
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TurnAborted(_))).await;
|
||||
})
|
||||
.await
|
||||
.expect("interrupting request_user_input should abort promptly");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_request_user_input_rejected<F>(mode_name: &str, build_mode: F) -> anyhow::Result<()>
|
||||
where
|
||||
F: FnOnce(String) -> CollaborationMode,
|
||||
|
||||
@@ -41,6 +41,8 @@ pub struct RequestUserInputAnswer {
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, JsonSchema, TS)]
|
||||
pub struct RequestUserInputResponse {
|
||||
pub answers: HashMap<String, RequestUserInputAnswer>,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub interrupted: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, JsonSchema, TS)]
|
||||
|
||||
@@ -33,6 +33,7 @@ use crate::render::renderable::Renderable;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_protocol::request_user_input::RequestUserInputAnswer;
|
||||
use codex_protocol::request_user_input::RequestUserInputEvent;
|
||||
use codex_protocol::request_user_input::RequestUserInputQuestion;
|
||||
use codex_protocol::request_user_input::RequestUserInputResponse;
|
||||
use codex_protocol::user_input::TextElement;
|
||||
use unicode_width::UnicodeWidthStr;
|
||||
@@ -88,6 +89,9 @@ impl ComposerDraft {
|
||||
struct AnswerState {
|
||||
// Scrollable cursor state for option navigation/highlight.
|
||||
options_state: ScrollState,
|
||||
// Last explicitly committed option selection. We preserve this across later
|
||||
// edits so partial interrupt submission can keep the committed selection.
|
||||
committed_option_idx: Option<usize>,
|
||||
// Per-question notes draft.
|
||||
draft: ComposerDraft,
|
||||
// Whether the answer for this question has been explicitly submitted.
|
||||
@@ -558,6 +562,7 @@ impl RequestUserInputOverlay {
|
||||
}
|
||||
AnswerState {
|
||||
options_state,
|
||||
committed_option_idx: None,
|
||||
draft: ComposerDraft::default(),
|
||||
answer_committed: false,
|
||||
notes_visible: !has_options,
|
||||
@@ -644,6 +649,9 @@ impl RequestUserInputOverlay {
|
||||
let updated = if let Some(answer) = self.current_answer_mut() {
|
||||
answer.options_state.clamp_selection(options_len);
|
||||
answer.answer_committed = committed;
|
||||
if committed {
|
||||
answer.committed_option_idx = answer.options_state.selected_idx;
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
@@ -660,6 +668,7 @@ impl RequestUserInputOverlay {
|
||||
}
|
||||
if let Some(answer) = self.current_answer_mut() {
|
||||
answer.options_state.reset();
|
||||
answer.committed_option_idx = None;
|
||||
answer.draft = ComposerDraft::default();
|
||||
answer.answer_committed = false;
|
||||
answer.notes_visible = false;
|
||||
@@ -710,46 +719,123 @@ impl RequestUserInputOverlay {
|
||||
}
|
||||
}
|
||||
|
||||
fn answer_for_question(
|
||||
&self,
|
||||
idx: usize,
|
||||
question: &RequestUserInputQuestion,
|
||||
committed_only: bool,
|
||||
) -> Option<RequestUserInputAnswer> {
|
||||
let answer_state = &self.answers[idx];
|
||||
if committed_only
|
||||
&& !answer_state.answer_committed
|
||||
&& answer_state.committed_option_idx.is_none()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let options = question.options.as_ref();
|
||||
// For option questions we may still produce no selection.
|
||||
let selected_idx = if options.is_some_and(|opts| !opts.is_empty()) {
|
||||
if answer_state.answer_committed {
|
||||
answer_state.options_state.selected_idx
|
||||
} else if committed_only {
|
||||
answer_state.committed_option_idx
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// Notes are appended as extra answers. For freeform questions, only submit when
|
||||
// the user explicitly committed the draft.
|
||||
let notes = if answer_state.answer_committed {
|
||||
answer_state.draft.text_with_pending().trim().to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
let selected_label = selected_idx
|
||||
.and_then(|selected_idx| Self::option_label_for_index(question, selected_idx));
|
||||
let mut answer_list = selected_label.into_iter().collect::<Vec<_>>();
|
||||
if !notes.is_empty() {
|
||||
answer_list.push(format!("user_note: {notes}"));
|
||||
}
|
||||
|
||||
Some(RequestUserInputAnswer {
|
||||
answers: answer_list,
|
||||
})
|
||||
}
|
||||
|
||||
fn submit_committed_answers_for_interrupt(&mut self) {
|
||||
self.confirm_unanswered = None;
|
||||
|
||||
let mut answers = HashMap::new();
|
||||
for (idx, question) in self.request.questions.iter().enumerate() {
|
||||
if let Some(answer) = self.answer_for_question(idx, question, true) {
|
||||
answers.insert(question.id.clone(), answer);
|
||||
}
|
||||
}
|
||||
|
||||
self.app_event_tx
|
||||
.send(AppEvent::CodexOp(Op::UserInputAnswer {
|
||||
id: self.request.turn_id.clone(),
|
||||
response: RequestUserInputResponse {
|
||||
answers: answers.clone(),
|
||||
interrupted: true,
|
||||
},
|
||||
}));
|
||||
self.app_event_tx.send(AppEvent::InsertHistoryCell(Box::new(
|
||||
history_cell::RequestUserInputResultCell {
|
||||
questions: self.request.questions.clone(),
|
||||
answers,
|
||||
interrupted: true,
|
||||
},
|
||||
)));
|
||||
}
|
||||
|
||||
fn supports_partial_interrupt_submission(&self) -> bool {
|
||||
!self.request.questions.is_empty()
|
||||
&& self
|
||||
.request
|
||||
.questions
|
||||
.iter()
|
||||
.all(Self::is_tool_style_partial_interrupt_question)
|
||||
}
|
||||
|
||||
fn is_tool_style_partial_interrupt_question(question: &RequestUserInputQuestion) -> bool {
|
||||
// Tool `request_user_input` currently emits option questions with `is_other = true`.
|
||||
// Non-tool prompts rely on a true `Op::Interrupt` to abort the turn.
|
||||
question.is_other
|
||||
&& question
|
||||
.options
|
||||
.as_ref()
|
||||
.is_some_and(|options| !options.is_empty())
|
||||
}
|
||||
|
||||
fn interrupt_current_request(&mut self) {
|
||||
if self.supports_partial_interrupt_submission() {
|
||||
self.submit_committed_answers_for_interrupt();
|
||||
} else {
|
||||
self.app_event_tx.send(AppEvent::CodexOp(Op::Interrupt));
|
||||
}
|
||||
self.done = true;
|
||||
}
|
||||
|
||||
/// Build the response payload and dispatch it to the app.
|
||||
fn submit_answers(&mut self) {
|
||||
self.confirm_unanswered = None;
|
||||
self.save_current_draft();
|
||||
let mut answers = HashMap::new();
|
||||
for (idx, question) in self.request.questions.iter().enumerate() {
|
||||
let answer_state = &self.answers[idx];
|
||||
let options = question.options.as_ref();
|
||||
// For option questions we may still produce no selection.
|
||||
let selected_idx =
|
||||
if options.is_some_and(|opts| !opts.is_empty()) && answer_state.answer_committed {
|
||||
answer_state.options_state.selected_idx
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// Notes are appended as extra answers. For freeform questions, only submit when
|
||||
// the user explicitly committed the draft.
|
||||
let notes = if answer_state.answer_committed {
|
||||
answer_state.draft.text_with_pending().trim().to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
let selected_label = selected_idx
|
||||
.and_then(|selected_idx| Self::option_label_for_index(question, selected_idx));
|
||||
let mut answer_list = selected_label.into_iter().collect::<Vec<_>>();
|
||||
if !notes.is_empty() {
|
||||
answer_list.push(format!("user_note: {notes}"));
|
||||
if let Some(answer) = self.answer_for_question(idx, question, false) {
|
||||
answers.insert(question.id.clone(), answer);
|
||||
}
|
||||
answers.insert(
|
||||
question.id.clone(),
|
||||
RequestUserInputAnswer {
|
||||
answers: answer_list,
|
||||
},
|
||||
);
|
||||
}
|
||||
self.app_event_tx
|
||||
.send(AppEvent::CodexOp(Op::UserInputAnswer {
|
||||
id: self.request.turn_id.clone(),
|
||||
response: RequestUserInputResponse {
|
||||
answers: answers.clone(),
|
||||
interrupted: false,
|
||||
},
|
||||
}));
|
||||
self.app_event_tx.send(AppEvent::InsertHistoryCell(Box::new(
|
||||
@@ -925,6 +1011,7 @@ impl RequestUserInputOverlay {
|
||||
if self.has_options() {
|
||||
if let Some(answer) = self.current_answer_mut() {
|
||||
answer.answer_committed = true;
|
||||
answer.committed_option_idx = answer.options_state.selected_idx;
|
||||
}
|
||||
} else if let Some(answer) = self.current_answer_mut() {
|
||||
answer.answer_committed = !text.trim().is_empty();
|
||||
@@ -1005,10 +1092,7 @@ impl BottomPaneView for RequestUserInputOverlay {
|
||||
self.clear_notes_and_focus_options();
|
||||
return;
|
||||
}
|
||||
// TODO: Emit interrupted request_user_input results (including committed answers)
|
||||
// once core supports persisting them reliably without follow-up turn issues.
|
||||
self.app_event_tx.send(AppEvent::CodexOp(Op::Interrupt));
|
||||
self.done = true;
|
||||
self.interrupt_current_request();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1221,10 +1305,7 @@ impl BottomPaneView for RequestUserInputOverlay {
|
||||
fn on_ctrl_c(&mut self) -> CancellationEvent {
|
||||
if self.confirm_unanswered_active() {
|
||||
self.close_unanswered_confirmation();
|
||||
// TODO: Emit interrupted request_user_input results (including committed answers)
|
||||
// once core supports persisting them reliably without follow-up turn issues.
|
||||
self.app_event_tx.send(AppEvent::CodexOp(Op::Interrupt));
|
||||
self.done = true;
|
||||
self.interrupt_current_request();
|
||||
return CancellationEvent::Handled;
|
||||
}
|
||||
if self.focus_is_notes() && !self.composer.current_text_with_pending().is_empty() {
|
||||
@@ -1232,10 +1313,7 @@ impl BottomPaneView for RequestUserInputOverlay {
|
||||
return CancellationEvent::Handled;
|
||||
}
|
||||
|
||||
// TODO: Emit interrupted request_user_input results (including committed answers)
|
||||
// once core supports persisting them reliably without follow-up turn issues.
|
||||
self.app_event_tx.send(AppEvent::CodexOp(Op::Interrupt));
|
||||
self.done = true;
|
||||
self.interrupt_current_request();
|
||||
CancellationEvent::Handled
|
||||
}
|
||||
|
||||
@@ -1298,15 +1376,36 @@ mod tests {
|
||||
(AppEventSender::new(tx_raw), rx)
|
||||
}
|
||||
|
||||
fn expect_interrupt_only(rx: &mut tokio::sync::mpsc::UnboundedReceiver<AppEvent>) {
|
||||
let event = rx.try_recv().expect("expected interrupt AppEvent");
|
||||
fn expect_partial_interrupt_submission(
|
||||
rx: &mut tokio::sync::mpsc::UnboundedReceiver<AppEvent>,
|
||||
expected_turn_id: &str,
|
||||
) -> RequestUserInputResponse {
|
||||
let event = rx.try_recv().expect("expected partial answer AppEvent");
|
||||
let AppEvent::CodexOp(op) = event else {
|
||||
panic!("expected CodexOp");
|
||||
};
|
||||
assert_eq!(op, Op::Interrupt);
|
||||
let Op::UserInputAnswer { id, response } = op else {
|
||||
panic!("expected UserInputAnswer");
|
||||
};
|
||||
assert_eq!(id, expected_turn_id);
|
||||
assert!(response.interrupted, "expected interrupted response");
|
||||
|
||||
let event = rx.try_recv().expect("expected history cell");
|
||||
assert!(matches!(event, AppEvent::InsertHistoryCell(_)));
|
||||
assert!(
|
||||
rx.try_recv().is_err(),
|
||||
"unexpected AppEvents before interrupt completion"
|
||||
"unexpected AppEvents after interrupted submission"
|
||||
);
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
fn expect_interrupt_op(rx: &mut tokio::sync::mpsc::UnboundedReceiver<AppEvent>) {
|
||||
let event = rx.try_recv().expect("expected interrupt AppEvent");
|
||||
assert!(matches!(event, AppEvent::CodexOp(Op::Interrupt)));
|
||||
assert!(
|
||||
rx.try_recv().is_err(),
|
||||
"unexpected AppEvents after interrupt"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1529,7 +1628,7 @@ mod tests {
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Esc));
|
||||
|
||||
assert!(overlay.done, "expected overlay to be done");
|
||||
expect_interrupt_only(&mut rx);
|
||||
expect_interrupt_op(&mut rx);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1910,14 +2009,17 @@ mod tests {
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Esc));
|
||||
|
||||
assert_eq!(overlay.done, true);
|
||||
expect_interrupt_only(&mut rx);
|
||||
expect_interrupt_op(&mut rx);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn esc_in_options_mode_interrupts() {
|
||||
fn esc_in_tool_options_mode_interrupts_with_partial_submission() {
|
||||
let (tx, mut rx) = test_sender();
|
||||
let mut overlay = RequestUserInputOverlay::new(
|
||||
request_event("turn-1", vec![question_with_options("q1", "Pick one")]),
|
||||
request_event(
|
||||
"turn-1",
|
||||
vec![question_with_options_and_other("q1", "Pick one")],
|
||||
),
|
||||
tx,
|
||||
true,
|
||||
false,
|
||||
@@ -1927,7 +2029,8 @@ mod tests {
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Esc));
|
||||
|
||||
assert_eq!(overlay.done, true);
|
||||
expect_interrupt_only(&mut rx);
|
||||
let response = expect_partial_interrupt_submission(&mut rx, "turn-1");
|
||||
assert!(response.answers.is_empty(), "expected no committed answers");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1988,14 +2091,14 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn esc_drops_committed_answers() {
|
||||
fn esc_submits_only_committed_answers_before_interrupt() {
|
||||
let (tx, mut rx) = test_sender();
|
||||
let mut overlay = RequestUserInputOverlay::new(
|
||||
request_event(
|
||||
"turn-1",
|
||||
vec![
|
||||
question_with_options("q1", "First"),
|
||||
question_without_options("q2", "Second"),
|
||||
question_with_options_and_other("q1", "First"),
|
||||
question_with_options_and_other("q2", "Second"),
|
||||
],
|
||||
),
|
||||
tx,
|
||||
@@ -2012,7 +2115,77 @@ mod tests {
|
||||
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Esc));
|
||||
|
||||
expect_interrupt_only(&mut rx);
|
||||
let response = expect_partial_interrupt_submission(&mut rx, "turn-1");
|
||||
let answer = response
|
||||
.answers
|
||||
.get("q1")
|
||||
.expect("missing committed answer");
|
||||
assert_eq!(answer.answers, vec!["Option 1".to_string()]);
|
||||
assert!(
|
||||
!response.answers.contains_key("q2"),
|
||||
"uncommitted question should not be submitted"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn esc_interrupt_preserves_committed_selection_after_notes_clear() {
|
||||
let (tx, mut rx) = test_sender();
|
||||
let mut overlay = RequestUserInputOverlay::new(
|
||||
request_event(
|
||||
"turn-1",
|
||||
vec![
|
||||
question_with_options_and_other("q1", "First"),
|
||||
question_with_options_and_other("q2", "Second"),
|
||||
],
|
||||
),
|
||||
tx,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
);
|
||||
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Char('2')));
|
||||
assert_eq!(overlay.current_index(), 1);
|
||||
assert!(
|
||||
rx.try_recv().is_err(),
|
||||
"unexpected AppEvent before interruption"
|
||||
);
|
||||
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Char('h')));
|
||||
assert_eq!(overlay.current_index(), 0);
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Tab));
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Esc));
|
||||
|
||||
let answer = overlay.current_answer().expect("answer missing");
|
||||
assert_eq!(answer.answer_committed, false);
|
||||
assert_eq!(answer.options_state.selected_idx, Some(1));
|
||||
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Esc));
|
||||
|
||||
let response = expect_partial_interrupt_submission(&mut rx, "turn-1");
|
||||
let answer = response
|
||||
.answers
|
||||
.get("q1")
|
||||
.expect("missing committed answer");
|
||||
assert_eq!(answer.answers, vec!["Option 2".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn esc_on_non_tool_freeform_request_emits_interrupt() {
|
||||
let (tx, mut rx) = test_sender();
|
||||
let mut overlay = RequestUserInputOverlay::new(
|
||||
request_event("turn-1", vec![question_without_options("q1", "Notes")]),
|
||||
tx,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
);
|
||||
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Char('x')));
|
||||
overlay.handle_key_event(KeyEvent::from(KeyCode::Esc));
|
||||
|
||||
assert!(overlay.done, "expected overlay to be done");
|
||||
expect_interrupt_op(&mut rx);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user