From e2fae5de8d88260874fe5fb1bf500f176a5356f4 Mon Sep 17 00:00:00 2001 From: Charles Cunningham Date: Fri, 6 Mar 2026 00:53:14 -0800 Subject: [PATCH] Clarify turn context ownership and retry refresh CAS Rename RunningTask.turn_context to initial_turn_context to make task snapshot semantics explicit, and retry session-context refresh when compare-and-swap installation races with concurrent mid-turn updates. Co-authored-by: Codex --- codex-rs/core/src/codex.rs | 63 +++++++++++++-------- codex-rs/core/src/state/turn.rs | 6 +- codex-rs/core/src/tasks/mod.rs | 20 ++++--- codex-rs/core/src/tools/network_approval.rs | 11 ++-- 4 files changed, 57 insertions(+), 43 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index f820e9765e..2b27b07b66 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -2403,25 +2403,38 @@ impl Session { } async fn refresh_current_active_turn_context_from_session_configuration(&self) { - let Some(current_turn_context) = self.current_active_turn_context().await else { - return; - }; - let session_configuration = { - let state = self.state.lock().await; - state.session_configuration.clone() - }; - let realtime_active = self.conversation.running_state().await.is_some(); - let next_turn_context = self - .build_updated_turn_context(current_turn_context.as_ref(), &session_configuration) - .await; - let next_turn_context = if next_turn_context.realtime_active == realtime_active { - next_turn_context - } else { - Arc::new(next_turn_context.with_realtime_active(realtime_active)) - }; - let _ = self - .set_current_active_turn_context(Some(&current_turn_context), next_turn_context) - .await; + const MAX_CONTEXT_REFRESH_ATTEMPTS: usize = 3; + for attempt in 0..MAX_CONTEXT_REFRESH_ATTEMPTS { + let Some(current_turn_context) = self.current_active_turn_context().await else { + return; + }; + let session_configuration = { + let state = self.state.lock().await; + state.session_configuration.clone() + }; + let 
realtime_active = self.conversation.running_state().await.is_some(); + let next_turn_context = self + .build_updated_turn_context(current_turn_context.as_ref(), &session_configuration) + .await; + let next_turn_context = if next_turn_context.realtime_active == realtime_active { + next_turn_context + } else { + Arc::new(next_turn_context.with_realtime_active(realtime_active)) + }; + if self + .set_current_active_turn_context(Some(&current_turn_context), next_turn_context) + .await + { + return; + } + + if attempt + 1 == MAX_CONTEXT_REFRESH_ATTEMPTS { + warn!( + "failed to refresh active turn context from session configuration after {} attempts", + MAX_CONTEXT_REFRESH_ATTEMPTS + ); + } + } } pub(crate) async fn refresh_current_active_turn_context_from_realtime_state(&self) { @@ -2761,7 +2774,7 @@ impl Session { turn.tasks.get(sub_id).map(|task| { turn.current_turn_context .clone() - .unwrap_or_else(|| Arc::clone(&task.turn_context)) + .unwrap_or_else(|| Arc::clone(&task.initial_turn_context)) }) }) } @@ -2772,7 +2785,7 @@ impl Session { turn.current_turn_context.clone().or_else(|| { turn.tasks .first() - .map(|(_, task)| Arc::clone(&task.turn_context)) + .map(|(_, task)| Arc::clone(&task.initial_turn_context)) }) } @@ -2792,7 +2805,7 @@ impl Session { let Some(current_turn_context) = turn.current_turn_context.clone().or_else(|| { turn.tasks .first() - .map(|(_, task)| Arc::clone(&task.turn_context)) + .map(|(_, task)| Arc::clone(&task.initial_turn_context)) }) else { return false; }; @@ -2828,7 +2841,7 @@ impl Session { Some(( turn.current_turn_context .clone() - .unwrap_or_else(|| Arc::clone(&task.turn_context)), + .unwrap_or_else(|| Arc::clone(&task.initial_turn_context)), task.cancellation_token.child_token(), )) } @@ -9640,7 +9653,7 @@ mod tests { handle: Arc::new(tokio_util::task::AbortOnDropHandle::new(tokio::spawn( async {}, ))), - turn_context: Arc::clone(&replacement_turn_context), + initial_turn_context: Arc::clone(&replacement_turn_context), _timer: None, }, 
)]), @@ -9767,7 +9780,7 @@ mod tests { handle: Arc::new(tokio_util::task::AbortOnDropHandle::new(tokio::spawn( async {}, ))), - turn_context: Arc::clone(&tc), + initial_turn_context: Arc::clone(&tc), _timer: None, }, )]), diff --git a/codex-rs/core/src/state/turn.rs b/codex-rs/core/src/state/turn.rs index bb48f98f38..3ca73cce90 100644 --- a/codex-rs/core/src/state/turn.rs +++ b/codex-rs/core/src/state/turn.rs @@ -48,15 +48,15 @@ pub(crate) struct RunningTask { pub(crate) task: Arc, pub(crate) cancellation_token: CancellationToken, pub(crate) handle: Arc>, - pub(crate) turn_context: Arc, + pub(crate) initial_turn_context: Arc, // Timer recorded when the task drops to capture the full turn duration. pub(crate) _timer: Option, } impl ActiveTurn { pub(crate) fn add_task(&mut self, task: RunningTask) { - self.current_turn_context = Some(Arc::clone(&task.turn_context)); - let sub_id = task.turn_context.sub_id.clone(); + self.current_turn_context = Some(Arc::clone(&task.initial_turn_context)); + let sub_id = task.initial_turn_context.sub_id.clone(); self.tasks.insert(sub_id, task); } diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index f32957c5ad..50e4505800 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -185,7 +185,7 @@ impl Session { kind: task_kind, task, cancellation_token, - turn_context: Arc::clone(&turn_context), + initial_turn_context: Arc::clone(&turn_context), _timer: timer, }; self.register_new_active_task(running_task).await; @@ -361,14 +361,14 @@ impl Session { } async fn handle_task_abort(self: &Arc, task: RunningTask, reason: TurnAbortReason) { - let sub_id = task.turn_context.sub_id.clone(); + let sub_id = task.initial_turn_context.sub_id.clone(); if task.cancellation_token.is_cancelled() { return; } trace!(task_kind = ?task.kind, sub_id, "aborting running task"); task.cancellation_token.cancel(); - task.turn_context + task.initial_turn_context .turn_metadata_state 
.cancel_git_enrichment_task(); let session_task = task.task; @@ -385,7 +385,7 @@ impl Session { let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); session_task - .abort(session_ctx, Arc::clone(&task.turn_context)) + .abort(session_ctx, Arc::clone(&task.initial_turn_context)) .await; if reason == TurnAbortReason::Interrupted { @@ -400,8 +400,11 @@ impl Session { end_turn: None, phase: None, }; - self.record_into_history(std::slice::from_ref(&marker), task.turn_context.as_ref()) - .await; + self.record_into_history( + std::slice::from_ref(&marker), + task.initial_turn_context.as_ref(), + ) + .await; self.persist_rollout_items(&[RolloutItem::ResponseItem(marker)]) .await; // Ensure the marker is durably visible before emitting TurnAborted: some clients @@ -410,10 +413,11 @@ impl Session { } let event = EventMsg::TurnAborted(TurnAbortedEvent { - turn_id: Some(task.turn_context.sub_id.clone()), + turn_id: Some(task.initial_turn_context.sub_id.clone()), reason, }); - self.send_event(task.turn_context.as_ref(), event).await; + self.send_event(task.initial_turn_context.as_ref(), event) + .await; } } diff --git a/codex-rs/core/src/tools/network_approval.rs b/codex-rs/core/src/tools/network_approval.rs index 4e4f06b739..020de7f9d5 100644 --- a/codex-rs/core/src/tools/network_approval.rs +++ b/codex-rs/core/src/tools/network_approval.rs @@ -248,10 +248,6 @@ impl NetworkApprovalService { .await; } - async fn active_turn_context(session: &Session) -> Option> { - session.current_active_turn_context().await - } - fn format_network_target(protocol: &str, host: &str, port: u16) -> String { format!("{protocol}://{host}:{port}") } @@ -299,7 +295,7 @@ impl NetworkApprovalService { format!("Network access to \"{target}\" was blocked by policy."); let prompt_reason = format!("{} is not in the allowed_domains", request.host); - let Some(turn_context) = Self::active_turn_context(session).await else { + let Some(turn_context) = 
session.current_active_turn_context().await else { pending.set_decision(PendingApprovalDecision::Deny).await; let mut pending_approvals = self.pending_host_approvals.lock().await; pending_approvals.remove(&key); @@ -753,14 +749,15 @@ mod tests { task: Arc::new(NoopTask), cancellation_token: CancellationToken::new(), handle: Arc::new(AbortOnDropHandle::new(handle)), - turn_context: Arc::clone(&original_turn_context), + initial_turn_context: Arc::clone(&original_turn_context), _timer: None, }, )]), ..Default::default() }); - let active_turn_context = NetworkApprovalService::active_turn_context(&session) + let active_turn_context = session + .current_active_turn_context() .await .expect("active turn context");