From 39d320d3fddbe72ff8b9b7cd5253ebb89ed5e490 Mon Sep 17 00:00:00 2001 From: Brent Traut Date: Mon, 18 May 2026 20:23:04 -0700 Subject: [PATCH] app-server: tighten deferred plan cleanup edge cases --- .../app-server/src/bespoke_event_handling.rs | 280 ++++++++++++------ codex-rs/app-server/src/thread_state.rs | 14 - 2 files changed, 185 insertions(+), 109 deletions(-) diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index 7d497427f1..587d422781 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -1123,6 +1123,14 @@ pub(crate) async fn apply_bespoke_event_handling( // All per-thread requests are bound to a turn, so abort them. outgoing.abort_pending_server_requests().await; respond_to_pending_interrupts(&thread_state, &outgoing).await; + let has_pending_terminal_plan_cleanup = !thread_state + .lock() + .await + .pending_terminal_plan_cleanups + .is_empty(); + let preserve_terminal_plan_progress = has_pending_terminal_plan_cleanup + && should_preserve_terminal_plan_progress(conversation.as_ref(), conversation_id) + .await; thread_watch_manager .note_turn_interrupted(&conversation_id.to_string()) @@ -1131,6 +1139,7 @@ pub(crate) async fn apply_bespoke_event_handling( conversation_id, event_turn_id, turn_aborted_event, + preserve_terminal_plan_progress, &outgoing, &thread_state, ) @@ -1139,8 +1148,11 @@ pub(crate) async fn apply_bespoke_event_handling( EventMsg::ThreadRolledBack(_rollback_event) => { let pending = { let mut state = thread_state.lock().await; - state.prune_pending_terminal_plan_cleanups_after_rollback(); - state.pending_rollbacks.take() + let pending = state.pending_rollbacks.take(); + if pending.is_none() { + state.pending_terminal_plan_cleanups.clear(); + } + pending }; if let Some(request_id) = pending { @@ -1196,6 +1208,13 @@ pub(crate) async fn apply_bespoke_event_handling( return; } }; + { + let mut state = thread_state.lock().await; + retain_pending_terminal_plan_cleanups_for_turns( + &mut state.pending_terminal_plan_cleanups, + &response.thread.turns, + ); + } outgoing.send_response(request_id, response).await; } @@ -1266,6 +1285,7 @@ async fn handle_turn_plan_update( outgoing: &ThreadScopedOutgoingMessageSender, thread_state: &Arc>, ) { + flush_pending_terminal_plan_cleanup(conversation_id, thread_state, outgoing).await; { let mut state = thread_state.lock().await; state @@ -1284,7 +1304,6 @@ async fn handle_turn_plan_update( }); } } - flush_pending_terminal_plan_cleanup(conversation_id, thread_state, outgoing).await; emit_turn_plan_updated(conversation_id, event_turn_id, plan_update_event, outgoing).await; } @@ -1380,6 +1399,14 @@ fn terminal_plan_cleanup_updates( .collect() } +fn retain_pending_terminal_plan_cleanups_for_turns( + pending_terminal_plan_cleanups: &mut Vec, + retained_turns: &[Turn], +) { + pending_terminal_plan_cleanups + .retain(|cleanup| retained_turns.iter().any(|turn| turn.id == cleanup.turn_id)); +} + async fn emit_turn_plan_updated( conversation_id: ThreadId, event_turn_id: &str, @@ -1630,11 +1657,14 @@ async fn handle_turn_interrupted( conversation_id: ThreadId, event_turn_id: String, turn_aborted_event: TurnAbortedEvent, + preserve_terminal_plan_progress: bool, outgoing: &ThreadScopedOutgoingMessageSender, thread_state: &Arc>, ) { let turn_summary = find_and_remove_turn_summary(thread_state).await; - flush_all_pending_terminal_plan_cleanup(conversation_id, thread_state, outgoing).await; + if !preserve_terminal_plan_progress { + flush_all_pending_terminal_plan_cleanup(conversation_id, thread_state, outgoing).await; + } emit_turn_completed_with_status( conversation_id, @@ -3502,6 +3532,7 @@ mod tests { conversation_id, event_turn_id.clone(), turn_aborted_event(&event_turn_id), + /*preserve_terminal_plan_progress*/ false, &outgoing, &thread_state, ) @@ -3654,6 +3685,59 @@ mod tests { Ok(()) } + #[tokio::test] + async fn mid_turn_plan_update_keeps_cleanup_until_terminal_event_without_snapshot() -> Result<()> + { + let conversation_id = ThreadId::new(); + let turn_id = "resume-mid-turn".to_string(); + let thread_state = new_thread_state(); + let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); + let outgoing = Arc::new(OutgoingMessageSender::new( + tx, + codex_analytics::AnalyticsEventsClient::disabled(), + )); + let outgoing = ThreadScopedOutgoingMessageSender::new( + outgoing, + vec![ConnectionId(1)], + conversation_id, + ); + + handle_turn_plan_update( + conversation_id, + &turn_id, + UpdatePlanArgs { + explanation: Some("still working".to_string()), + plan: vec![PlanItemArg { + step: "first".to_string(), + status: StepStatus::InProgress, + }], + }, + &outgoing, + &thread_state, + ) + .await; + + let msg = recv_broadcast_message(&mut rx).await?; + match msg { + OutgoingMessage::AppServerNotification(ServerNotification::TurnPlanUpdated(n)) => { + assert_eq!(n.turn_id, turn_id); + assert_eq!(n.plan[0].status, TurnPlanStepStatus::InProgress); + } + other => bail!("unexpected message: {other:?}"), + } + assert_eq!( + thread_state + .lock() + .await + .pending_terminal_plan_cleanups + .len(), + 1, + "the live plan remains eligible for terminal cleanup" + ); + assert!(rx.try_recv().is_err(), "no extra messages expected"); + Ok(()) + } + #[tokio::test] async fn new_live_plan_flushes_older_deferred_cleanup() -> Result<()> { let conversation_id = ThreadId::new(); @@ -3857,6 +3941,7 @@ mod tests { conversation_id, event_turn_id.clone(), turn_aborted_event(&event_turn_id), + /*preserve_terminal_plan_progress*/ false, &outgoing, &thread_state, ) @@ -3893,6 +3978,65 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_handle_turn_interrupted_preserves_in_progress_plan_with_active_goal() -> Result<()> + { + let conversation_id = ThreadId::new(); + let event_turn_id = "interrupt_active_goal".to_string(); + let thread_state = new_thread_state(); + thread_state.lock().await.pending_terminal_plan_cleanups = + vec![PendingTerminalPlanCleanup { + turn_id: event_turn_id.clone(), + plan_update: UpdatePlanArgs { + explanation: Some("still working".to_string()), + plan: vec![PlanItemArg { + step: "first".to_string(), + status: StepStatus::InProgress, + }], + }, + }]; + let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); + let outgoing = Arc::new(OutgoingMessageSender::new( + tx, + codex_analytics::AnalyticsEventsClient::disabled(), + )); + let outgoing = ThreadScopedOutgoingMessageSender::new( + outgoing, + vec![ConnectionId(1)], + ThreadId::new(), + ); + + handle_turn_interrupted( + conversation_id, + event_turn_id.clone(), + turn_aborted_event(&event_turn_id), + /*preserve_terminal_plan_progress*/ true, + &outgoing, + &thread_state, + ) + .await; + + let msg = recv_broadcast_message(&mut rx).await?; + match msg { + OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { + assert_eq!(n.turn.id, event_turn_id); + assert_eq!(n.turn.status, TurnStatus::Interrupted); + } + other => bail!("unexpected message: {other:?}"), + } + assert_eq!( + thread_state + .lock() + .await + .pending_terminal_plan_cleanups + .len(), + 1, + "active goals retain pending cleanup until the goal settles" + ); + assert!(rx.try_recv().is_err(), "no extra messages expected"); + Ok(()) + } + #[tokio::test] async fn replaced_turn_downgrades_in_progress_plan() -> Result<()> { let codex_home = TempDir::new()?; @@ -4394,102 +4538,48 @@ mod tests { Ok(()) } - #[tokio::test] - async fn rollback_preserves_cleanup_for_turns_that_survive() -> Result<()> { - let codex_home = TempDir::new()?; - let config = load_default_config_for_test(&codex_home).await; - let thread_manager = Arc::new( - codex_core::test_support::thread_manager_with_models_provider_and_home( - CodexAuth::create_dummy_chatgpt_auth_for_testing(), - config.model_provider.clone(), - config.codex_home.to_path_buf(), - Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), - ), - ); - let codex_core::NewThread { - thread_id: conversation_id, - thread: conversation, - .. - } = thread_manager.start_thread(config).await?; - let rollback_event = codex_protocol::protocol::ThreadRolledBackEvent { num_turns: 1 }; - let thread_state = new_thread_state(); - { - let mut state = thread_state.lock().await; - for turn_id in ["older-turn", "rolled-back-turn"] { - state.track_current_turn_event( - turn_id, - &EventMsg::TurnStarted(codex_protocol::protocol::TurnStartedEvent { - turn_id: turn_id.to_string(), - started_at: Some(42), - model_context_window: None, - collaboration_mode_kind: Default::default(), - }), - ); - state.track_current_turn_event( - turn_id, - &EventMsg::TurnComplete(turn_complete_event(turn_id)), - ); - } - state.track_current_turn_event( - "rollback-turn", - &EventMsg::ThreadRolledBack(rollback_event.clone()), - ); - state.pending_terminal_plan_cleanups = vec![ - PendingTerminalPlanCleanup { - turn_id: "older-turn".to_string(), - plan_update: UpdatePlanArgs { - explanation: Some("still working".to_string()), - plan: vec![PlanItemArg { - step: "older".to_string(), - status: StepStatus::InProgress, - }], - }, + #[test] + fn rollback_response_preserves_cleanup_for_turns_that_survive() { + let mut pending_terminal_plan_cleanups = vec![ + PendingTerminalPlanCleanup { + turn_id: "older-turn".to_string(), + plan_update: UpdatePlanArgs { + explanation: Some("still working".to_string()), + plan: vec![PlanItemArg { + step: "older".to_string(), + status: StepStatus::InProgress, + }], }, - PendingTerminalPlanCleanup { - turn_id: "rolled-back-turn".to_string(), - plan_update: UpdatePlanArgs { - explanation: Some("still working".to_string()), - plan: vec![PlanItemArg { - step: "newer".to_string(), - status: StepStatus::InProgress, - }], - }, - }, - ]; - } - let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); - let outgoing = Arc::new(OutgoingMessageSender::new( - tx, - codex_analytics::AnalyticsEventsClient::disabled(), - )); - let outgoing = ThreadScopedOutgoingMessageSender::new( - outgoing, - vec![ConnectionId(1)], - conversation_id, - ); - - apply_bespoke_event_handling( - Event { - id: "rollback-turn".to_string(), - msg: EventMsg::ThreadRolledBack(rollback_event), }, - conversation_id, - conversation, - thread_manager, - outgoing, - thread_state.clone(), - ThreadWatchManager::new(), - Arc::new(tokio::sync::Semaphore::new(/*permits*/ 1)), - "test-provider".to_string(), - ) - .await; + PendingTerminalPlanCleanup { + turn_id: "rolled-back-turn".to_string(), + plan_update: UpdatePlanArgs { + explanation: Some("still working".to_string()), + plan: vec![PlanItemArg { + step: "newer".to_string(), + status: StepStatus::InProgress, + }], + }, + }, + ]; + let retained_turns = vec![Turn { + id: "older-turn".to_string(), + items: Vec::new(), + items_view: TurnItemsView::NotLoaded, + error: None, + status: TurnStatus::Completed, + started_at: None, + completed_at: None, + duration_ms: None, + }]; + + retain_pending_terminal_plan_cleanups_for_turns( + &mut pending_terminal_plan_cleanups, + &retained_turns, + ); - let pending_terminal_plan_cleanups = - &thread_state.lock().await.pending_terminal_plan_cleanups; assert_eq!(pending_terminal_plan_cleanups.len(), 1); assert_eq!(pending_terminal_plan_cleanups[0].turn_id, "older-turn"); - assert!(rx.try_recv().is_err(), "no extra messages expected"); - Ok(()) } #[tokio::test] diff --git a/codex-rs/app-server/src/thread_state.rs b/codex-rs/app-server/src/thread_state.rs index dc67c56a0f..43dd892912 100644 --- a/codex-rs/app-server/src/thread_state.rs +++ b/codex-rs/app-server/src/thread_state.rs @@ -80,7 +80,6 @@ pub(crate) struct ThreadState { pub(crate) turn_summary: TurnSummary, pub(crate) pending_terminal_plan_cleanups: Vec, pub(crate) last_terminal_turn_id: Option, - terminal_turn_ids: Vec, pub(crate) cancel_tx: Option>, pub(crate) experimental_raw_events: bool, pub(crate) listener_generation: u64, @@ -139,27 +138,14 @@ impl ThreadState { self.current_turn_history.active_turn_snapshot() } - pub(crate) fn prune_pending_terminal_plan_cleanups_after_rollback(&mut self) { - self.pending_terminal_plan_cleanups - .retain(|cleanup| self.terminal_turn_ids.contains(&cleanup.turn_id)); - } - pub(crate) fn track_current_turn_event(&mut self, event_turn_id: &str, event: &EventMsg) { if let EventMsg::TurnStarted(payload) = event { self.turn_summary.started_at = payload.started_at; } self.current_turn_history.handle_event(event); - if let EventMsg::ThreadRolledBack(payload) = event { - let num_turns = usize::try_from(payload.num_turns).unwrap_or(usize::MAX); - self.terminal_turn_ids - .truncate(self.terminal_turn_ids.len().saturating_sub(num_turns)); - } if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_)) && !self.current_turn_history.has_active_turn() { - if self.last_terminal_turn_id.as_deref() != Some(event_turn_id) { - self.terminal_turn_ids.push(event_turn_id.to_string()); - } self.last_terminal_turn_id = Some(event_turn_id.to_string()); self.current_turn_history.reset(); }