Clarify turn context ownership and retry refresh CAS

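Rename `RunningTask.turn_context` to `initial_turn_context` to make explicit that the field is the task's initial snapshot rather than the turn's live context, which readers take from `current_turn_context` and only fall back to the snapshot when it is unset. Retry the session-configuration refresh of the active turn context (up to three attempts, logging a warning if every attempt fails) when its compare-and-swap installation races with a concurrent mid-turn update. Also remove the `NetworkApprovalService::active_turn_context` wrapper and call `Session::current_active_turn_context` directly.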

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Charles Cunningham
2026-03-06 00:53:14 -08:00
parent c360485a13
commit e2fae5de8d
4 changed files with 57 additions and 43 deletions

View File

@@ -2403,25 +2403,38 @@ impl Session {
}
async fn refresh_current_active_turn_context_from_session_configuration(&self) {
let Some(current_turn_context) = self.current_active_turn_context().await else {
return;
};
let session_configuration = {
let state = self.state.lock().await;
state.session_configuration.clone()
};
let realtime_active = self.conversation.running_state().await.is_some();
let next_turn_context = self
.build_updated_turn_context(current_turn_context.as_ref(), &session_configuration)
.await;
let next_turn_context = if next_turn_context.realtime_active == realtime_active {
next_turn_context
} else {
Arc::new(next_turn_context.with_realtime_active(realtime_active))
};
let _ = self
.set_current_active_turn_context(Some(&current_turn_context), next_turn_context)
.await;
const MAX_CONTEXT_REFRESH_ATTEMPTS: usize = 3;
for attempt in 0..MAX_CONTEXT_REFRESH_ATTEMPTS {
let Some(current_turn_context) = self.current_active_turn_context().await else {
return;
};
let session_configuration = {
let state = self.state.lock().await;
state.session_configuration.clone()
};
let realtime_active = self.conversation.running_state().await.is_some();
let next_turn_context = self
.build_updated_turn_context(current_turn_context.as_ref(), &session_configuration)
.await;
let next_turn_context = if next_turn_context.realtime_active == realtime_active {
next_turn_context
} else {
Arc::new(next_turn_context.with_realtime_active(realtime_active))
};
if self
.set_current_active_turn_context(Some(&current_turn_context), next_turn_context)
.await
{
return;
}
if attempt + 1 == MAX_CONTEXT_REFRESH_ATTEMPTS {
warn!(
"failed to refresh active turn context from session configuration after {} attempts",
MAX_CONTEXT_REFRESH_ATTEMPTS
);
}
}
}
pub(crate) async fn refresh_current_active_turn_context_from_realtime_state(&self) {
@@ -2761,7 +2774,7 @@ impl Session {
turn.tasks.get(sub_id).map(|task| {
turn.current_turn_context
.clone()
.unwrap_or_else(|| Arc::clone(&task.turn_context))
.unwrap_or_else(|| Arc::clone(&task.initial_turn_context))
})
})
}
@@ -2772,7 +2785,7 @@ impl Session {
turn.current_turn_context.clone().or_else(|| {
turn.tasks
.first()
.map(|(_, task)| Arc::clone(&task.turn_context))
.map(|(_, task)| Arc::clone(&task.initial_turn_context))
})
}
@@ -2792,7 +2805,7 @@ impl Session {
let Some(current_turn_context) = turn.current_turn_context.clone().or_else(|| {
turn.tasks
.first()
.map(|(_, task)| Arc::clone(&task.turn_context))
.map(|(_, task)| Arc::clone(&task.initial_turn_context))
}) else {
return false;
};
@@ -2828,7 +2841,7 @@ impl Session {
Some((
turn.current_turn_context
.clone()
.unwrap_or_else(|| Arc::clone(&task.turn_context)),
.unwrap_or_else(|| Arc::clone(&task.initial_turn_context)),
task.cancellation_token.child_token(),
))
}
@@ -9640,7 +9653,7 @@ mod tests {
handle: Arc::new(tokio_util::task::AbortOnDropHandle::new(tokio::spawn(
async {},
))),
turn_context: Arc::clone(&replacement_turn_context),
initial_turn_context: Arc::clone(&replacement_turn_context),
_timer: None,
},
)]),
@@ -9767,7 +9780,7 @@ mod tests {
handle: Arc::new(tokio_util::task::AbortOnDropHandle::new(tokio::spawn(
async {},
))),
turn_context: Arc::clone(&tc),
initial_turn_context: Arc::clone(&tc),
_timer: None,
},
)]),

View File

@@ -48,15 +48,15 @@ pub(crate) struct RunningTask {
pub(crate) task: Arc<dyn SessionTask>,
pub(crate) cancellation_token: CancellationToken,
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
pub(crate) turn_context: Arc<TurnContext>,
pub(crate) initial_turn_context: Arc<TurnContext>,
// Timer recorded when the task drops to capture the full turn duration.
pub(crate) _timer: Option<codex_otel::Timer>,
}
impl ActiveTurn {
pub(crate) fn add_task(&mut self, task: RunningTask) {
self.current_turn_context = Some(Arc::clone(&task.turn_context));
let sub_id = task.turn_context.sub_id.clone();
self.current_turn_context = Some(Arc::clone(&task.initial_turn_context));
let sub_id = task.initial_turn_context.sub_id.clone();
self.tasks.insert(sub_id, task);
}

View File

@@ -185,7 +185,7 @@ impl Session {
kind: task_kind,
task,
cancellation_token,
turn_context: Arc::clone(&turn_context),
initial_turn_context: Arc::clone(&turn_context),
_timer: timer,
};
self.register_new_active_task(running_task).await;
@@ -361,14 +361,14 @@ impl Session {
}
async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
let sub_id = task.turn_context.sub_id.clone();
let sub_id = task.initial_turn_context.sub_id.clone();
if task.cancellation_token.is_cancelled() {
return;
}
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
task.cancellation_token.cancel();
task.turn_context
task.initial_turn_context
.turn_metadata_state
.cancel_git_enrichment_task();
let session_task = task.task;
@@ -385,7 +385,7 @@ impl Session {
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
session_task
.abort(session_ctx, Arc::clone(&task.turn_context))
.abort(session_ctx, Arc::clone(&task.initial_turn_context))
.await;
if reason == TurnAbortReason::Interrupted {
@@ -400,8 +400,11 @@ impl Session {
end_turn: None,
phase: None,
};
self.record_into_history(std::slice::from_ref(&marker), task.turn_context.as_ref())
.await;
self.record_into_history(
std::slice::from_ref(&marker),
task.initial_turn_context.as_ref(),
)
.await;
self.persist_rollout_items(&[RolloutItem::ResponseItem(marker)])
.await;
// Ensure the marker is durably visible before emitting TurnAborted: some clients
@@ -410,10 +413,11 @@ impl Session {
}
let event = EventMsg::TurnAborted(TurnAbortedEvent {
turn_id: Some(task.turn_context.sub_id.clone()),
turn_id: Some(task.initial_turn_context.sub_id.clone()),
reason,
});
self.send_event(task.turn_context.as_ref(), event).await;
self.send_event(task.initial_turn_context.as_ref(), event)
.await;
}
}

View File

@@ -248,10 +248,6 @@ impl NetworkApprovalService {
.await;
}
async fn active_turn_context(session: &Session) -> Option<Arc<crate::codex::TurnContext>> {
session.current_active_turn_context().await
}
fn format_network_target(protocol: &str, host: &str, port: u16) -> String {
format!("{protocol}://{host}:{port}")
}
@@ -299,7 +295,7 @@ impl NetworkApprovalService {
format!("Network access to \"{target}\" was blocked by policy.");
let prompt_reason = format!("{} is not in the allowed_domains", request.host);
let Some(turn_context) = Self::active_turn_context(session).await else {
let Some(turn_context) = session.current_active_turn_context().await else {
pending.set_decision(PendingApprovalDecision::Deny).await;
let mut pending_approvals = self.pending_host_approvals.lock().await;
pending_approvals.remove(&key);
@@ -753,14 +749,15 @@ mod tests {
task: Arc::new(NoopTask),
cancellation_token: CancellationToken::new(),
handle: Arc::new(AbortOnDropHandle::new(handle)),
turn_context: Arc::clone(&original_turn_context),
initial_turn_context: Arc::clone(&original_turn_context),
_timer: None,
},
)]),
..Default::default()
});
let active_turn_context = NetworkApprovalService::active_turn_context(&session)
let active_turn_context = session
.current_active_turn_context()
.await
.expect("active turn context");