Wait deterministically for turn context apply

This commit is contained in:
Eric Traut
2026-05-14 14:07:35 -07:00
parent 8af05a6aa5
commit ea06c8a6e8
2 changed files with 52 additions and 15 deletions

View File

@@ -52,8 +52,7 @@ struct ResolvedTurnContextOverrides {
overrides: CodexThreadTurnContextOverrides,
}
const TURN_CONTEXT_APPLY_TIMEOUT: Duration = Duration::from_secs(5);
const TURN_CONTEXT_APPLY_POLL_INTERVAL: Duration = Duration::from_millis(10);
const TURN_STARTED_ACK_TIMEOUT: Duration = Duration::from_secs(5);
impl TurnRequestProcessor {
#[allow(clippy::too_many_arguments)]
@@ -572,21 +571,28 @@ impl TurnRequestProcessor {
error
})?;
if let Some(after_turn_context) = after_turn_context {
if let Some(mut after_turn_context) = after_turn_context {
if before_turn_context != after_turn_context {
let started = Instant::now();
loop {
let config_snapshot = thread.config_snapshot().await;
if thread_turn_context_from_snapshot(&config_snapshot) == after_turn_context {
break;
}
if started.elapsed() >= TURN_CONTEXT_APPLY_TIMEOUT {
return Err(internal_error(
"timed out waiting for turn context overrides to apply".to_string(),
));
}
tokio::time::sleep(TURN_CONTEXT_APPLY_POLL_INTERVAL).await;
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
let turn_started = {
let mut thread_state = thread_state.lock().await;
thread_state.turn_started_receiver(&turn_id)
};
if let Some(turn_started) = turn_started {
// Bound how long the RPC waits for the core turn-start acknowledgement.
tokio::time::timeout(TURN_STARTED_ACK_TIMEOUT, turn_started)
.await
.map_err(|_| {
internal_error(
"timed out waiting for turn context overrides to apply".to_string(),
)
})?
.map_err(|_| {
internal_error("turn context override waiter was cancelled".to_string())
})?;
}
after_turn_context =
thread_turn_context_from_snapshot(&thread.config_snapshot().await);
}
self.maybe_emit_turn_context_updated(
&params.thread_id,

View File

@@ -78,6 +78,7 @@ pub(crate) struct ThreadState {
pub(crate) listener_generation: u64,
listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
current_turn_history: ThreadHistoryBuilder,
pending_turn_started_waiters: HashMap<String, Vec<oneshot::Sender<()>>>,
listener_thread: Option<Weak<CodexThread>>,
watch_registration: WatchRegistration,
}
@@ -112,6 +113,7 @@ impl ThreadState {
let _ = cancel_tx.send(());
}
self.listener_command_tx = None;
self.pending_turn_started_waiters.clear();
self.current_turn_history.reset();
self.listener_thread = None;
self.watch_registration = WatchRegistration::default();
@@ -131,16 +133,45 @@ impl ThreadState {
self.current_turn_history.active_turn_snapshot()
}
pub(crate) fn turn_started_receiver(&mut self, turn_id: &str) -> Option<oneshot::Receiver<()>> {
if self
.active_turn_snapshot()
.is_some_and(|turn| turn.id == turn_id)
|| self.last_terminal_turn_id.as_deref() == Some(turn_id)
{
return None;
}
let (tx, rx) = oneshot::channel();
self.pending_turn_started_waiters
.entry(turn_id.to_string())
.or_default()
.push(tx);
Some(rx)
}
pub(crate) fn track_current_turn_event(&mut self, event_turn_id: &str, event: &EventMsg) {
if let EventMsg::TurnStarted(payload) = event {
self.turn_summary.started_at = payload.started_at;
}
self.current_turn_history.handle_event(event);
if let EventMsg::TurnStarted(payload) = event {
self.notify_turn_started(&payload.turn_id);
}
if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_))
&& !self.current_turn_history.has_active_turn()
{
self.last_terminal_turn_id = Some(event_turn_id.to_string());
self.current_turn_history.reset();
self.notify_turn_started(event_turn_id);
}
}
fn notify_turn_started(&mut self, turn_id: &str) {
if let Some(waiters) = self.pending_turn_started_waiters.remove(turn_id) {
for waiter in waiters {
let _ = waiter.send(());
}
}
}
}