Preserve interrupt semantics during async abort

Keep aborting turns installed until cleanup finishes and surface interrupting network approval denials as interrupted outcomes.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Abhinav Vedmala
2026-04-13 19:01:58 -07:00
parent 71df25b787
commit 9e6f3c18fd
6 changed files with 242 additions and 27 deletions

View File

@@ -205,6 +205,7 @@ mod rollout_reconstruction_tests;
#[derive(Debug, PartialEq)]
pub enum SteerInputError {
NoActiveTurn(Vec<UserInput>),
TurnAbortInProgress(Vec<UserInput>),
ExpectedTurnMismatch { expected: String, actual: String },
ActiveTurnNotSteerable { turn_kind: NonSteerableTurnKind },
EmptyInput,
@@ -217,6 +218,10 @@ impl SteerInputError {
message: "no active turn to steer".to_string(),
codex_error_info: Some(CodexErrorInfo::BadRequest),
},
Self::TurnAbortInProgress(_) => ErrorEvent {
message: "turn abort is still in progress".to_string(),
codex_error_info: Some(CodexErrorInfo::BadRequest),
},
Self::ExpectedTurnMismatch { expected, actual } => ErrorEvent {
message: format!("expected active turn id `{expected}` but found `{actual}`"),
codex_error_info: Some(CodexErrorInfo::BadRequest),
@@ -4159,6 +4164,9 @@ impl Session {
let Some(active_turn) = active.as_mut() else {
return Err(SteerInputError::NoActiveTurn(input));
};
if active_turn.is_aborting() {
return Err(SteerInputError::TurnAbortInProgress(input));
}
let Some((active_turn_id, _)) = active_turn.tasks.first() else {
return Err(SteerInputError::NoActiveTurn(input));
@@ -5017,6 +5025,21 @@ mod handlers {
_ => unreachable!(),
};
let abort_in_progress = {
let active = sess.active_turn.lock().await;
active
.as_ref()
.is_some_and(crate::state::ActiveTurn::is_aborting)
};
if abort_in_progress {
sess.send_event_raw(Event {
id: sub_id,
msg: EventMsg::Error(SteerInputError::TurnAbortInProgress(items).to_error_event()),
})
.await;
return;
}
let Ok(current_context) = sess.new_turn_with_sub_id(sub_id.clone(), updates).await else {
// new_turn_with_sub_id already emits the error event.
return;

View File

@@ -4654,6 +4654,38 @@ impl SessionTask for NeverEndingTask {
}
}
/// Test task whose `abort` hook blocks until the test explicitly releases it.
///
/// `run` never completes on its own and does not look at the cancellation token, so the
/// only way this task ends is through the session's abort path.
struct SlowAbortTask {
    /// Signalled by `abort` as soon as abort cleanup has begun.
    abort_started: Arc<tokio::sync::Notify>,
    /// Awaited by `abort`; the test signals it to let cleanup finish.
    finish_abort: Arc<tokio::sync::Notify>,
}

impl SessionTask for SlowAbortTask {
    fn kind(&self) -> TaskKind {
        TaskKind::Regular
    }

    fn span_name(&self) -> &'static str {
        "session_task.slow_abort"
    }

    /// Sleeps in 60-second steps forever; the task never returns a value.
    async fn run(
        self: Arc<Self>,
        _session: Arc<SessionTaskContext>,
        _ctx: Arc<TurnContext>,
        _input: Vec<UserInput>,
        _cancellation_token: CancellationToken,
    ) -> Option<String> {
        loop {
            sleep(Duration::from_secs(60)).await;
        }
    }

    /// Announces that abort cleanup started, then parks until `finish_abort` is signalled.
    async fn abort(&self, _session: Arc<SessionTaskContext>, _ctx: Arc<TurnContext>) {
        // Create the release future *before* announcing that abort started: a `Notified`
        // future observes `notify_waiters()` calls made after its creation, so a release
        // sent right after the announcement cannot be missed.
        let release = self.finish_abort.notified();
        // `notify_one` stores a permit when nobody is waiting yet, whereas
        // `notify_waiters` would silently drop the signal if the observer has not
        // started waiting.
        self.abort_started.notify_one();
        release.await;
    }
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[test_log::test]
async fn abort_regular_task_emits_turn_aborted_only() {
@@ -4721,6 +4753,67 @@ async fn abort_gracefully_emits_turn_aborted_only() {
assert!(rx.try_recv().is_err());
}
/// An interrupt must leave the aborting turn installed (and reject steering input) until the
/// task's abort cleanup has finished; only then is `TurnAborted` emitted and the turn cleared.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn interrupt_keeps_active_turn_installed_until_abort_cleanup_finishes() {
    let (sess, tc, rx) = make_session_and_context_with_rx().await;
    let abort_started = Arc::new(tokio::sync::Notify::new());
    let finish_abort = Arc::new(tokio::sync::Notify::new());
    let input = vec![UserInput::Text {
        text: "hello".to_string(),
        text_elements: Vec::new(),
    }];
    sess.spawn_task(
        Arc::clone(&tc),
        input,
        SlowAbortTask {
            abort_started: Arc::clone(&abort_started),
            finish_abort: Arc::clone(&finish_abort),
        },
    )
    .await;

    // Register interest before triggering the interrupt. The abort runs on a spawned task,
    // and `notify_waiters()` does not store a permit, so a future created afterwards could
    // miss the signal and make this test time out spuriously.
    let abort_started_signal = abort_started.notified();
    sess.interrupt_task().await;
    tokio::time::timeout(Duration::from_secs(2), abort_started_signal)
        .await
        .expect("timeout waiting for abort cleanup to start");

    // While cleanup is parked, the turn must still be installed and flagged as aborting.
    {
        let active_turn = sess.active_turn.lock().await;
        assert!(
            active_turn
                .as_ref()
                .is_some_and(crate::state::ActiveTurn::is_aborting)
        );
    }

    // Steering input is rejected with the dedicated "abort in progress" error.
    let err = sess
        .steer_input(
            vec![UserInput::Text {
                text: "new prompt".to_string(),
                text_elements: Vec::new(),
            }],
            /*expected_turn_id*/ None,
            /*responsesapi_client_metadata*/ None,
        )
        .await
        .expect_err("interrupting turn should not accept new steer input");
    assert!(matches!(err, SteerInputError::TurnAbortInProgress(_)));

    // `notify_one` leaves a permit behind if the abort hook has not reached its wait yet,
    // so the release cannot be lost.
    finish_abort.notify_one();

    let evt = tokio::time::timeout(Duration::from_secs(2), rx.recv())
        .await
        .expect("timeout waiting for abort event")
        .expect("event");
    match evt.msg {
        EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
        other => panic!("unexpected event: {other:?}"),
    }

    // Once the abort event has been observed, the turn must be gone.
    let active_turn = sess.active_turn.lock().await;
    assert!(active_turn.is_none());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input() {
let (sess, tc, rx) = make_session_and_context_with_rx().await;

View File

@@ -25,6 +25,7 @@ use codex_protocol::protocol::TokenUsage;
/// Metadata about the currently running turn.
pub(crate) struct ActiveTurn {
/// Latched by `begin_abort` and reported by `is_aborting`; starts out `false`.
abort_in_progress: bool,
/// Tasks registered on this turn, keyed by the `sub_id` they were added under.
pub(crate) tasks: IndexMap<String, RunningTask>,
/// Per-turn state behind a shared, lockable handle so callers can clone the `Arc`
/// and use it without holding on to the `ActiveTurn` itself.
pub(crate) turn_state: Arc<Mutex<TurnState>>,
}
@@ -53,6 +54,7 @@ pub(crate) enum MailboxDeliveryPhase {
impl Default for ActiveTurn {
fn default() -> Self {
Self {
abort_in_progress: false,
tasks: IndexMap::new(),
turn_state: Arc::new(Mutex::new(TurnState::default())),
}
@@ -83,11 +85,23 @@ impl ActiveTurn {
self.tasks.insert(sub_id, task);
}
/// Returns `true` once `begin_abort` has latched this turn as being aborted.
pub(crate) fn is_aborting(&self) -> bool {
self.abort_in_progress
}
/// Unregisters the task stored under `sub_id`, if there is one.
///
/// Returns `true` when no tasks remain on the turn afterwards.
pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool {
    // The removed entry (if any) is dropped right away; only emptiness is reported.
    drop(self.tasks.swap_remove(sub_id));
    self.tasks.is_empty()
}
/// Claims this turn's tasks for cancellation.
///
/// Returns `None` if an earlier call already latched the abort, so only one caller ever
/// receives the tasks. Otherwise every registered task is drained and handed back.
///
/// The abort flag is latched only when at least one task was drained.
pub(crate) fn begin_abort(&mut self) -> Option<Vec<RunningTask>> {
    if self.abort_in_progress {
        return None;
    }
    let tasks = self.drain_tasks();
    // NOTE(review): the only reset of the aborting turn visible in this change happens from
    // per-task abort cleanup (for the last drained task). With zero tasks that cleanup never
    // runs, so latching here would leave the turn reporting `is_aborting()` forever.
    // Confirm no other path clears it before relying on the unconditional latch.
    self.abort_in_progress = !tasks.is_empty();
    Some(tasks)
}
/// Removes every task from the turn and returns them, discarding their `sub_id` keys.
pub(crate) fn drain_tasks(&mut self) -> Vec<RunningTask> {
    let mut drained = Vec::with_capacity(self.tasks.len());
    for (_sub_id, running) in self.tasks.drain(..) {
        drained.push(running);
    }
    drained
}
@@ -245,11 +259,3 @@ impl TurnState {
self.granted_permissions.clone()
}
}
impl ActiveTurn {
/// Clear any pending approvals and input buffered for the current turn.
pub(crate) async fn clear_pending(&self) {
let mut ts = self.turn_state.lock().await;
ts.clear_pending();
}
}

View File

@@ -261,6 +261,7 @@ impl Session {
let mut active = self.active_turn.lock().await;
let turn = active.get_or_insert_with(ActiveTurn::default);
debug_assert!(turn.tasks.is_empty());
debug_assert!(!turn.is_aborting());
Arc::clone(&turn.turn_state)
};
{
@@ -277,6 +278,7 @@ impl Session {
let mut active = self.active_turn.lock().await;
let turn = active.get_or_insert_with(ActiveTurn::default);
debug_assert!(turn.tasks.is_empty());
debug_assert!(!turn.is_aborting());
let done_clone = Arc::clone(&done);
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
let ctx = Arc::clone(&turn_context);
@@ -384,9 +386,16 @@ impl Session {
pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
if let Some(mut active_turn) = self.take_active_turn().await {
let tasks = self.prepare_tasks_for_abort(&mut active_turn).await;
let turn_state = Arc::clone(&active_turn.turn_state);
let tasks = active_turn.drain_tasks();
self.prepare_tasks_for_abort(&tasks, &turn_state).await;
for task in tasks {
self.handle_task_abort(task, reason.clone()).await;
self.handle_task_abort(
task,
reason.clone(),
/*clear_active_turn_before_event*/ false,
)
.await;
}
}
if reason == TurnAbortReason::Interrupted {
@@ -395,14 +404,24 @@ impl Session {
}
pub(crate) async fn request_task_abort(self: &Arc<Self>, reason: TurnAbortReason) {
let Some(mut active_turn) = self.take_active_turn().await else {
return;
let (tasks, turn_state) = {
let mut active = self.active_turn.lock().await;
let Some(active_turn) = active.as_mut() else {
return;
};
let Some(tasks) = active_turn.begin_abort() else {
return;
};
(tasks, Arc::clone(&active_turn.turn_state))
};
let tasks = self.prepare_tasks_for_abort(&mut active_turn).await;
self.prepare_tasks_for_abort(&tasks, &turn_state).await;
let session = Arc::clone(self);
tokio::spawn(async move {
for task in tasks {
session.handle_task_abort(task, reason.clone()).await;
let num_tasks = tasks.len();
for (index, task) in tasks.into_iter().enumerate() {
session
.handle_task_abort(task, reason.clone(), index + 1 == num_tasks)
.await;
}
if reason == TurnAbortReason::Interrupted {
session.maybe_start_turn_for_pending_work().await;
@@ -564,9 +583,12 @@ impl Session {
active.take()
}
async fn prepare_tasks_for_abort(&self, active_turn: &mut ActiveTurn) -> Vec<RunningTask> {
let tasks = active_turn.drain_tasks();
for task in &tasks {
async fn prepare_tasks_for_abort(
&self,
tasks: &[RunningTask],
turn_state: &Arc<tokio::sync::Mutex<crate::state::TurnState>>,
) {
for task in tasks {
task.cancellation_token.cancel();
task.turn_context
.turn_metadata_state
@@ -574,8 +596,8 @@ impl Session {
}
// Let interrupted tasks observe cancellation before dropping pending approvals, or an
// in-flight approval wait can surface as a model-visible rejection before TurnAborted.
active_turn.clear_pending().await;
tasks
let mut state = turn_state.lock().await;
state.clear_pending();
}
pub(crate) async fn close_unified_exec_processes(&self) {
@@ -593,7 +615,12 @@ impl Session {
}
}
async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
async fn handle_task_abort(
self: &Arc<Self>,
task: RunningTask,
reason: TurnAbortReason,
clear_active_turn_before_event: bool,
) {
let sub_id = task.turn_context.sub_id.clone();
if !task.cancellation_token.is_cancelled() {
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
@@ -634,6 +661,13 @@ impl Session {
}
}
if clear_active_turn_before_event {
let mut active = self.active_turn.lock().await;
if active.as_ref().is_some_and(ActiveTurn::is_aborting) {
*active = None;
}
}
let (completed_at, duration_ms) = task
.turn_context
.turn_timing_state

View File

@@ -118,6 +118,7 @@ enum PendingApprovalDecision {
enum NetworkApprovalOutcome {
/// The user denied the request; a later blocked-request policy result does not replace it.
DeniedByUser,
/// Denied by policy; the carried message becomes the `ToolError::Rejected` text.
DeniedByPolicy(String),
/// Recorded when a denial also asks to interrupt the turn; surfaced as
/// `ToolError::Interrupted` and, like `DeniedByUser`, not replaced by a later
/// blocked-request policy result.
Interrupted,
}
/// Whether an allowlist miss may be reviewed instead of hard-denied.
@@ -269,7 +270,7 @@ impl NetworkApprovalService {
let mut call_outcomes = self.call_outcomes.lock().await;
if matches!(
call_outcomes.get(registration_id),
Some(NetworkApprovalOutcome::DeniedByUser)
Some(NetworkApprovalOutcome::DeniedByUser | NetworkApprovalOutcome::Interrupted)
) {
return;
}
@@ -404,11 +405,13 @@ impl NetworkApprovalService {
}
PermissionRequestDecision::Deny { message, interrupt } => {
if let Some(owner_call) = owner_call.as_ref() {
self.record_call_outcome(
&owner_call.registration_id,
NetworkApprovalOutcome::DeniedByPolicy(message.clone()),
)
.await;
let outcome = if interrupt {
NetworkApprovalOutcome::Interrupted
} else {
NetworkApprovalOutcome::DeniedByPolicy(message.clone())
};
self.record_call_outcome(&owner_call.registration_id, outcome)
.await;
}
if interrupt {
session.interrupt_task().await;
@@ -671,6 +674,7 @@ pub(crate) async fn finish_immediate_network_approval(
Err(ToolError::Rejected("rejected by user".to_string()))
}
Some(NetworkApprovalOutcome::DeniedByPolicy(message)) => Err(ToolError::Rejected(message)),
Some(NetworkApprovalOutcome::Interrupted) => Err(ToolError::Interrupted),
None => Ok(()),
}
}

View File

@@ -1,4 +1,5 @@
use super::*;
use crate::codex::make_session_and_context_with_rx;
use codex_network_proxy::BlockedRequestArgs;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
@@ -254,6 +255,60 @@ async fn blocked_request_policy_does_not_override_user_denial_outcome() {
);
}
/// A blocked request recorded after an `Interrupted` outcome must leave that outcome intact.
#[tokio::test]
async fn blocked_request_policy_does_not_override_interrupted_outcome() {
    let registration_id = "registration-1";
    let service = NetworkApprovalService::default();

    service
        .register_call(
            registration_id.to_string(),
            "turn-1".to_string(),
            "curl http://example.com".to_string(),
        )
        .await;

    // Record the interrupt first, then let the policy-level block arrive.
    service
        .record_call_outcome(registration_id, NetworkApprovalOutcome::Interrupted)
        .await;
    let blocked = denied_blocked_request("example.com");
    service.record_blocked_request(blocked).await;

    let recorded = service.take_call_outcome(registration_id).await;
    assert_eq!(recorded, Some(NetworkApprovalOutcome::Interrupted));
}
/// An `Interrupted` outcome recorded for a call is reported as `ToolError::Interrupted`
/// when the immediate network approval is finished.
#[tokio::test]
async fn finish_immediate_network_approval_returns_interrupted() {
    let (session, _turn_context, _rx) = make_session_and_context_with_rx().await;
    let registration_id = "registration-1";

    let approvals = &session.services.network_approval;
    approvals
        .register_call(
            registration_id.to_string(),
            "turn-1".to_string(),
            "curl http://example.com".to_string(),
        )
        .await;
    approvals
        .record_call_outcome(registration_id, NetworkApprovalOutcome::Interrupted)
        .await;

    let approval = ActiveNetworkApproval {
        registration_id: Some(registration_id.to_string()),
        mode: NetworkApprovalMode::Immediate,
    };
    let result = finish_immediate_network_approval(&session, approval).await;

    assert!(matches!(result, Err(ToolError::Interrupted)));
}
#[tokio::test]
async fn record_blocked_request_ignores_ambiguous_unattributed_blocked_requests() {
let service = NetworkApprovalService::default();