mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
Preserve interrupt semantics during async abort
Keep aborting turns installed until cleanup finishes and surface interrupting network approval denials as interrupted outcomes. Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -205,6 +205,7 @@ mod rollout_reconstruction_tests;
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum SteerInputError {
|
||||
NoActiveTurn(Vec<UserInput>),
|
||||
TurnAbortInProgress(Vec<UserInput>),
|
||||
ExpectedTurnMismatch { expected: String, actual: String },
|
||||
ActiveTurnNotSteerable { turn_kind: NonSteerableTurnKind },
|
||||
EmptyInput,
|
||||
@@ -217,6 +218,10 @@ impl SteerInputError {
|
||||
message: "no active turn to steer".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::BadRequest),
|
||||
},
|
||||
Self::TurnAbortInProgress(_) => ErrorEvent {
|
||||
message: "turn abort is still in progress".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::BadRequest),
|
||||
},
|
||||
Self::ExpectedTurnMismatch { expected, actual } => ErrorEvent {
|
||||
message: format!("expected active turn id `{expected}` but found `{actual}`"),
|
||||
codex_error_info: Some(CodexErrorInfo::BadRequest),
|
||||
@@ -4159,6 +4164,9 @@ impl Session {
|
||||
let Some(active_turn) = active.as_mut() else {
|
||||
return Err(SteerInputError::NoActiveTurn(input));
|
||||
};
|
||||
if active_turn.is_aborting() {
|
||||
return Err(SteerInputError::TurnAbortInProgress(input));
|
||||
}
|
||||
|
||||
let Some((active_turn_id, _)) = active_turn.tasks.first() else {
|
||||
return Err(SteerInputError::NoActiveTurn(input));
|
||||
@@ -5017,6 +5025,21 @@ mod handlers {
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let abort_in_progress = {
|
||||
let active = sess.active_turn.lock().await;
|
||||
active
|
||||
.as_ref()
|
||||
.is_some_and(crate::state::ActiveTurn::is_aborting)
|
||||
};
|
||||
if abort_in_progress {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(SteerInputError::TurnAbortInProgress(items).to_error_event()),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let Ok(current_context) = sess.new_turn_with_sub_id(sub_id.clone(), updates).await else {
|
||||
// new_turn_with_sub_id already emits the error event.
|
||||
return;
|
||||
|
||||
@@ -4654,6 +4654,38 @@ impl SessionTask for NeverEndingTask {
|
||||
}
|
||||
}
|
||||
|
||||
struct SlowAbortTask {
|
||||
abort_started: Arc<tokio::sync::Notify>,
|
||||
finish_abort: Arc<tokio::sync::Notify>,
|
||||
}
|
||||
|
||||
impl SessionTask for SlowAbortTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Regular
|
||||
}
|
||||
|
||||
fn span_name(&self) -> &'static str {
|
||||
"session_task.slow_abort"
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
_session: Arc<SessionTaskContext>,
|
||||
_ctx: Arc<TurnContext>,
|
||||
_input: Vec<UserInput>,
|
||||
_cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
loop {
|
||||
sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn abort(&self, _session: Arc<SessionTaskContext>, _ctx: Arc<TurnContext>) {
|
||||
self.abort_started.notify_waiters();
|
||||
self.finish_abort.notified().await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[test_log::test]
|
||||
async fn abort_regular_task_emits_turn_aborted_only() {
|
||||
@@ -4721,6 +4753,67 @@ async fn abort_gracefully_emits_turn_aborted_only() {
|
||||
assert!(rx.try_recv().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn interrupt_keeps_active_turn_installed_until_abort_cleanup_finishes() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
let abort_started = Arc::new(tokio::sync::Notify::new());
|
||||
let finish_abort = Arc::new(tokio::sync::Notify::new());
|
||||
let input = vec![UserInput::Text {
|
||||
text: "hello".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
input,
|
||||
SlowAbortTask {
|
||||
abort_started: Arc::clone(&abort_started),
|
||||
finish_abort: Arc::clone(&finish_abort),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.interrupt_task().await;
|
||||
tokio::time::timeout(Duration::from_secs(2), abort_started.notified())
|
||||
.await
|
||||
.expect("timeout waiting for abort cleanup to start");
|
||||
|
||||
{
|
||||
let active_turn = sess.active_turn.lock().await;
|
||||
assert!(
|
||||
active_turn
|
||||
.as_ref()
|
||||
.is_some_and(crate::state::ActiveTurn::is_aborting)
|
||||
);
|
||||
}
|
||||
|
||||
let err = sess
|
||||
.steer_input(
|
||||
vec![UserInput::Text {
|
||||
text: "new prompt".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
/*expected_turn_id*/ None,
|
||||
/*responsesapi_client_metadata*/ None,
|
||||
)
|
||||
.await
|
||||
.expect_err("interrupting turn should not accept new steer input");
|
||||
assert!(matches!(err, SteerInputError::TurnAbortInProgress(_)));
|
||||
|
||||
finish_abort.notify_waiters();
|
||||
|
||||
let evt = tokio::time::timeout(Duration::from_secs(2), rx.recv())
|
||||
.await
|
||||
.expect("timeout waiting for abort event")
|
||||
.expect("event");
|
||||
match evt.msg {
|
||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
|
||||
let active_turn = sess.active_turn.lock().await;
|
||||
assert!(active_turn.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
@@ -25,6 +25,7 @@ use codex_protocol::protocol::TokenUsage;
|
||||
|
||||
/// Metadata about the currently running turn.
|
||||
pub(crate) struct ActiveTurn {
|
||||
abort_in_progress: bool,
|
||||
pub(crate) tasks: IndexMap<String, RunningTask>,
|
||||
pub(crate) turn_state: Arc<Mutex<TurnState>>,
|
||||
}
|
||||
@@ -53,6 +54,7 @@ pub(crate) enum MailboxDeliveryPhase {
|
||||
impl Default for ActiveTurn {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
abort_in_progress: false,
|
||||
tasks: IndexMap::new(),
|
||||
turn_state: Arc::new(Mutex::new(TurnState::default())),
|
||||
}
|
||||
@@ -83,11 +85,23 @@ impl ActiveTurn {
|
||||
self.tasks.insert(sub_id, task);
|
||||
}
|
||||
|
||||
pub(crate) fn is_aborting(&self) -> bool {
|
||||
self.abort_in_progress
|
||||
}
|
||||
|
||||
pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool {
|
||||
self.tasks.swap_remove(sub_id);
|
||||
self.tasks.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn begin_abort(&mut self) -> Option<Vec<RunningTask>> {
|
||||
if self.abort_in_progress {
|
||||
return None;
|
||||
}
|
||||
self.abort_in_progress = true;
|
||||
Some(self.drain_tasks())
|
||||
}
|
||||
|
||||
pub(crate) fn drain_tasks(&mut self) -> Vec<RunningTask> {
|
||||
self.tasks.drain(..).map(|(_, task)| task).collect()
|
||||
}
|
||||
@@ -245,11 +259,3 @@ impl TurnState {
|
||||
self.granted_permissions.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
/// Clear any pending approvals and input buffered for the current turn.
|
||||
pub(crate) async fn clear_pending(&self) {
|
||||
let mut ts = self.turn_state.lock().await;
|
||||
ts.clear_pending();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,6 +261,7 @@ impl Session {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
let turn = active.get_or_insert_with(ActiveTurn::default);
|
||||
debug_assert!(turn.tasks.is_empty());
|
||||
debug_assert!(!turn.is_aborting());
|
||||
Arc::clone(&turn.turn_state)
|
||||
};
|
||||
{
|
||||
@@ -277,6 +278,7 @@ impl Session {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
let turn = active.get_or_insert_with(ActiveTurn::default);
|
||||
debug_assert!(turn.tasks.is_empty());
|
||||
debug_assert!(!turn.is_aborting());
|
||||
let done_clone = Arc::clone(&done);
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
let ctx = Arc::clone(&turn_context);
|
||||
@@ -384,9 +386,16 @@ impl Session {
|
||||
|
||||
pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
|
||||
if let Some(mut active_turn) = self.take_active_turn().await {
|
||||
let tasks = self.prepare_tasks_for_abort(&mut active_turn).await;
|
||||
let turn_state = Arc::clone(&active_turn.turn_state);
|
||||
let tasks = active_turn.drain_tasks();
|
||||
self.prepare_tasks_for_abort(&tasks, &turn_state).await;
|
||||
for task in tasks {
|
||||
self.handle_task_abort(task, reason.clone()).await;
|
||||
self.handle_task_abort(
|
||||
task,
|
||||
reason.clone(),
|
||||
/*clear_active_turn_before_event*/ false,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
if reason == TurnAbortReason::Interrupted {
|
||||
@@ -395,14 +404,24 @@ impl Session {
|
||||
}
|
||||
|
||||
pub(crate) async fn request_task_abort(self: &Arc<Self>, reason: TurnAbortReason) {
|
||||
let Some(mut active_turn) = self.take_active_turn().await else {
|
||||
return;
|
||||
let (tasks, turn_state) = {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
let Some(active_turn) = active.as_mut() else {
|
||||
return;
|
||||
};
|
||||
let Some(tasks) = active_turn.begin_abort() else {
|
||||
return;
|
||||
};
|
||||
(tasks, Arc::clone(&active_turn.turn_state))
|
||||
};
|
||||
let tasks = self.prepare_tasks_for_abort(&mut active_turn).await;
|
||||
self.prepare_tasks_for_abort(&tasks, &turn_state).await;
|
||||
let session = Arc::clone(self);
|
||||
tokio::spawn(async move {
|
||||
for task in tasks {
|
||||
session.handle_task_abort(task, reason.clone()).await;
|
||||
let num_tasks = tasks.len();
|
||||
for (index, task) in tasks.into_iter().enumerate() {
|
||||
session
|
||||
.handle_task_abort(task, reason.clone(), index + 1 == num_tasks)
|
||||
.await;
|
||||
}
|
||||
if reason == TurnAbortReason::Interrupted {
|
||||
session.maybe_start_turn_for_pending_work().await;
|
||||
@@ -564,9 +583,12 @@ impl Session {
|
||||
active.take()
|
||||
}
|
||||
|
||||
async fn prepare_tasks_for_abort(&self, active_turn: &mut ActiveTurn) -> Vec<RunningTask> {
|
||||
let tasks = active_turn.drain_tasks();
|
||||
for task in &tasks {
|
||||
async fn prepare_tasks_for_abort(
|
||||
&self,
|
||||
tasks: &[RunningTask],
|
||||
turn_state: &Arc<tokio::sync::Mutex<crate::state::TurnState>>,
|
||||
) {
|
||||
for task in tasks {
|
||||
task.cancellation_token.cancel();
|
||||
task.turn_context
|
||||
.turn_metadata_state
|
||||
@@ -574,8 +596,8 @@ impl Session {
|
||||
}
|
||||
// Let interrupted tasks observe cancellation before dropping pending approvals, or an
|
||||
// in-flight approval wait can surface as a model-visible rejection before TurnAborted.
|
||||
active_turn.clear_pending().await;
|
||||
tasks
|
||||
let mut state = turn_state.lock().await;
|
||||
state.clear_pending();
|
||||
}
|
||||
|
||||
pub(crate) async fn close_unified_exec_processes(&self) {
|
||||
@@ -593,7 +615,12 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
|
||||
async fn handle_task_abort(
|
||||
self: &Arc<Self>,
|
||||
task: RunningTask,
|
||||
reason: TurnAbortReason,
|
||||
clear_active_turn_before_event: bool,
|
||||
) {
|
||||
let sub_id = task.turn_context.sub_id.clone();
|
||||
if !task.cancellation_token.is_cancelled() {
|
||||
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
|
||||
@@ -634,6 +661,13 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
if clear_active_turn_before_event {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if active.as_ref().is_some_and(ActiveTurn::is_aborting) {
|
||||
*active = None;
|
||||
}
|
||||
}
|
||||
|
||||
let (completed_at, duration_ms) = task
|
||||
.turn_context
|
||||
.turn_timing_state
|
||||
|
||||
@@ -118,6 +118,7 @@ enum PendingApprovalDecision {
|
||||
enum NetworkApprovalOutcome {
|
||||
DeniedByUser,
|
||||
DeniedByPolicy(String),
|
||||
Interrupted,
|
||||
}
|
||||
|
||||
/// Whether an allowlist miss may be reviewed instead of hard-denied.
|
||||
@@ -269,7 +270,7 @@ impl NetworkApprovalService {
|
||||
let mut call_outcomes = self.call_outcomes.lock().await;
|
||||
if matches!(
|
||||
call_outcomes.get(registration_id),
|
||||
Some(NetworkApprovalOutcome::DeniedByUser)
|
||||
Some(NetworkApprovalOutcome::DeniedByUser | NetworkApprovalOutcome::Interrupted)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
@@ -404,11 +405,13 @@ impl NetworkApprovalService {
|
||||
}
|
||||
PermissionRequestDecision::Deny { message, interrupt } => {
|
||||
if let Some(owner_call) = owner_call.as_ref() {
|
||||
self.record_call_outcome(
|
||||
&owner_call.registration_id,
|
||||
NetworkApprovalOutcome::DeniedByPolicy(message.clone()),
|
||||
)
|
||||
.await;
|
||||
let outcome = if interrupt {
|
||||
NetworkApprovalOutcome::Interrupted
|
||||
} else {
|
||||
NetworkApprovalOutcome::DeniedByPolicy(message.clone())
|
||||
};
|
||||
self.record_call_outcome(&owner_call.registration_id, outcome)
|
||||
.await;
|
||||
}
|
||||
if interrupt {
|
||||
session.interrupt_task().await;
|
||||
@@ -671,6 +674,7 @@ pub(crate) async fn finish_immediate_network_approval(
|
||||
Err(ToolError::Rejected("rejected by user".to_string()))
|
||||
}
|
||||
Some(NetworkApprovalOutcome::DeniedByPolicy(message)) => Err(ToolError::Rejected(message)),
|
||||
Some(NetworkApprovalOutcome::Interrupted) => Err(ToolError::Interrupted),
|
||||
None => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use super::*;
|
||||
use crate::codex::make_session_and_context_with_rx;
|
||||
use codex_network_proxy::BlockedRequestArgs;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
@@ -254,6 +255,60 @@ async fn blocked_request_policy_does_not_override_user_denial_outcome() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blocked_request_policy_does_not_override_interrupted_outcome() {
|
||||
let service = NetworkApprovalService::default();
|
||||
service
|
||||
.register_call(
|
||||
"registration-1".to_string(),
|
||||
"turn-1".to_string(),
|
||||
"curl http://example.com".to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
service
|
||||
.record_call_outcome("registration-1", NetworkApprovalOutcome::Interrupted)
|
||||
.await;
|
||||
service
|
||||
.record_blocked_request(denied_blocked_request("example.com"))
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
service.take_call_outcome("registration-1").await,
|
||||
Some(NetworkApprovalOutcome::Interrupted)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn finish_immediate_network_approval_returns_interrupted() {
|
||||
let (session, _turn_context, _rx) = make_session_and_context_with_rx().await;
|
||||
session
|
||||
.services
|
||||
.network_approval
|
||||
.register_call(
|
||||
"registration-1".to_string(),
|
||||
"turn-1".to_string(),
|
||||
"curl http://example.com".to_string(),
|
||||
)
|
||||
.await;
|
||||
session
|
||||
.services
|
||||
.network_approval
|
||||
.record_call_outcome("registration-1", NetworkApprovalOutcome::Interrupted)
|
||||
.await;
|
||||
|
||||
let result = finish_immediate_network_approval(
|
||||
&session,
|
||||
ActiveNetworkApproval {
|
||||
registration_id: Some("registration-1".to_string()),
|
||||
mode: NetworkApprovalMode::Immediate,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ToolError::Interrupted)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn record_blocked_request_ignores_ambiguous_unattributed_blocked_requests() {
|
||||
let service = NetworkApprovalService::default();
|
||||
|
||||
Reference in New Issue
Block a user