diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index 73a4209e59..04c3649723 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -1139,7 +1139,7 @@ pub(crate) async fn apply_bespoke_event_handling( EventMsg::ThreadRolledBack(_rollback_event) => { let pending = { let mut state = thread_state.lock().await; - state.pending_terminal_plan_cleanups.clear(); + /* Keeps the cleanups whose turn ids are still listed in terminal_turn_ids rather than clearing all of them. NOTE(review): presumably track_current_turn_event has already truncated that list for this rollback; confirm the call order. */state.prune_pending_terminal_plan_cleanups_after_rollback(); state.pending_rollbacks.take() }; @@ -1334,6 +1334,20 @@ pub(crate) async fn flush_pending_terminal_plan_cleanup( } } +async fn flush_all_pending_terminal_plan_cleanup( /* Takes the whole pending_terminal_plan_cleanups list out of the thread state, leaving it empty, then calls emit_turn_plan_updated for every (turn_id, plan) pair returned by terminal_plan_cleanup_updates. The guard from lock().await is a temporary of the take statement, so it is dropped before the emit awaits. NOTE(review): the type arguments of Arc look stripped in this copy of the patch; restore them from the original. */+ conversation_id: ThreadId, + thread_state: &Arc>, + outgoing: &ThreadScopedOutgoingMessageSender, +) { + let pending_terminal_plan_cleanups = + std::mem::take(&mut thread_state.lock().await.pending_terminal_plan_cleanups); + for (turn_id, latest_plan_update) in + terminal_plan_cleanup_updates(pending_terminal_plan_cleanups) + { + emit_turn_plan_updated(conversation_id, &turn_id, latest_plan_update, outgoing).await; + } +} + fn terminal_plan_cleanup_updates( pending_terminal_plan_cleanups: Vec, ) -> Vec<(String, UpdatePlanArgs)> { @@ -1614,7 +1628,7 @@ async fn handle_turn_interrupted( thread_state: &Arc>, ) { let turn_summary = find_and_remove_turn_summary(thread_state).await; - flush_pending_terminal_plan_cleanup(conversation_id, thread_state, outgoing).await; + /* On interrupt, drain every pending cleanup; the removed line called flush_pending_terminal_plan_cleanup here. */flush_all_pending_terminal_plan_cleanup(conversation_id, thread_state, outgoing).await; emit_turn_completed_with_status( conversation_id, @@ -3950,6 +3964,93 @@ mod tests { Ok(()) } + #[tokio::test] + async fn tracked_replaced_turn_downgrades_in_progress_plan() -> Result<()> { /* Seeds one InProgress pending cleanup for a turn whose TurnAborted reason is Replaced, dispatches that event, and expects exactly two broadcasts: TurnPlanUpdated, then TurnCompleted. */+ let codex_home = TempDir::new()?; + let config = load_default_config_for_test(&codex_home).await; + let thread_manager = Arc::new( + codex_core::test_support::thread_manager_with_models_provider_and_home( 
CodexAuth::create_dummy_chatgpt_auth_for_testing(), + config.model_provider.clone(), + config.codex_home.to_path_buf(), + Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), + ), + ); + let codex_core::NewThread { + thread_id: conversation_id, + thread: conversation, + .. + } = thread_manager.start_thread(config).await?; + let turn_id = "tracked-replaced-plan-turn".to_string(); + let aborted = TurnAbortedEvent { + turn_id: Some(turn_id.clone()), + reason: codex_protocol::protocol::TurnAbortReason::Replaced, + completed_at: Some(TEST_TURN_COMPLETED_AT), + duration_ms: Some(TEST_TURN_DURATION_MS), + }; + let thread_state = new_thread_state(); + { + let mut state = thread_state.lock().await; + state.track_current_turn_event( + &turn_id, + &EventMsg::TurnStarted(codex_protocol::protocol::TurnStartedEvent { + turn_id: turn_id.clone(), + started_at: Some(42), + model_context_window: None, + collaboration_mode_kind: Default::default(), + }), + ); + state.track_current_turn_event(&turn_id, &EventMsg::TurnAborted(aborted.clone())); + state.pending_terminal_plan_cleanups = vec![PendingTerminalPlanCleanup { + turn_id: turn_id.clone(), + plan_update: UpdatePlanArgs { + explanation: Some("still working".to_string()), + plan: vec![PlanItemArg { + step: "first".to_string(), + status: StepStatus::InProgress, + }], + }, + }]; + } + let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); + let outgoing = Arc::new(OutgoingMessageSender::new( + tx, + codex_analytics::AnalyticsEventsClient::disabled(), + )); + let outgoing = ThreadScopedOutgoingMessageSender::new( + outgoing, + vec![ConnectionId(1)], + conversation_id, + ); + + apply_bespoke_event_handling( + Event { + id: turn_id, + msg: EventMsg::TurnAborted(aborted), + }, + conversation_id, + conversation, + thread_manager, + outgoing, + thread_state, + ThreadWatchManager::new(), + Arc::new(tokio::sync::Semaphore::new(/*permits*/ 1)), + "test-provider".to_string(), + ) + .await; + + assert!(matches!( 
+ recv_broadcast_message(&mut rx).await?, + OutgoingMessage::AppServerNotification(ServerNotification::TurnPlanUpdated(_)) + )); + assert!(matches!( + recv_broadcast_message(&mut rx).await?, + OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(_)) + )); + assert!(rx.try_recv().is_err(), "no extra messages expected"); + Ok(()) + } + #[tokio::test] async fn test_handle_turn_complete_preserves_in_progress_plan_with_active_goal() -> Result<()> { let conversation_id = ThreadId::new(); @@ -4287,6 +4388,104 @@ mod tests { Ok(()) } + #[tokio::test] + async fn rollback_preserves_cleanup_for_turns_that_survive() -> Result<()> { /* Seeds pending cleanups for older-turn and rolled-back-turn, tracks and dispatches a one-turn rollback, and expects only the older-turn cleanup to remain with nothing sent on the channel. */+ let codex_home = TempDir::new()?; + let config = load_default_config_for_test(&codex_home).await; + let thread_manager = Arc::new( + codex_core::test_support::thread_manager_with_models_provider_and_home( + CodexAuth::create_dummy_chatgpt_auth_for_testing(), + config.model_provider.clone(), + config.codex_home.to_path_buf(), + Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), + ), + ); + let codex_core::NewThread { + thread_id: conversation_id, + thread: conversation, + .. 
+ } = thread_manager.start_thread(config).await?; + let rollback_event = codex_protocol::protocol::ThreadRolledBackEvent { num_turns: 1 }; + let thread_state = new_thread_state(); + { + let mut state = thread_state.lock().await; + for turn_id in ["older-turn", "rolled-back-turn"] { + state.track_current_turn_event( + turn_id, + &EventMsg::TurnStarted(codex_protocol::protocol::TurnStartedEvent { + turn_id: turn_id.to_string(), + started_at: Some(42), + model_context_window: None, + collaboration_mode_kind: Default::default(), + }), + ); + state.track_current_turn_event( + turn_id, + &EventMsg::TurnComplete(turn_complete_event(turn_id)), + ); + } + state.track_current_turn_event( + "rollback-turn", + &EventMsg::ThreadRolledBack(rollback_event.clone()), + ); + state.pending_terminal_plan_cleanups = vec![ + PendingTerminalPlanCleanup { + turn_id: "older-turn".to_string(), + plan_update: UpdatePlanArgs { + explanation: Some("still working".to_string()), + plan: vec![PlanItemArg { + step: "older".to_string(), + status: StepStatus::InProgress, + }], + }, + }, + PendingTerminalPlanCleanup { + turn_id: "rolled-back-turn".to_string(), + plan_update: UpdatePlanArgs { + explanation: Some("still working".to_string()), + plan: vec![PlanItemArg { + step: "newer".to_string(), + status: StepStatus::InProgress, + }], + }, + }, + ]; + } + let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); + let outgoing = Arc::new(OutgoingMessageSender::new( + tx, + codex_analytics::AnalyticsEventsClient::disabled(), + )); + let outgoing = ThreadScopedOutgoingMessageSender::new( + outgoing, + vec![ConnectionId(1)], + conversation_id, + ); + + apply_bespoke_event_handling( + Event { + id: "rollback-turn".to_string(), + msg: EventMsg::ThreadRolledBack(rollback_event), + }, + conversation_id, + conversation, + thread_manager, + outgoing, + thread_state.clone(), + ThreadWatchManager::new(), + Arc::new(tokio::sync::Semaphore::new(/*permits*/ 1)), + "test-provider".to_string(), + ) + .await; + + let 
pending_terminal_plan_cleanups = + &thread_state.lock().await.pending_terminal_plan_cleanups; + assert_eq!(pending_terminal_plan_cleanups.len(), 1); + assert_eq!(pending_terminal_plan_cleanups[0].turn_id, "older-turn"); + assert!(rx.try_recv().is_err(), "no extra messages expected"); + Ok(()) + } + #[tokio::test] async fn test_handle_token_count_event_emits_usage_and_rate_limits() -> Result<()> { let conversation_id = ThreadId::new(); diff --git a/codex-rs/app-server/src/thread_state.rs b/codex-rs/app-server/src/thread_state.rs index 43dd892912..4163b6a12a 100644 --- a/codex-rs/app-server/src/thread_state.rs +++ b/codex-rs/app-server/src/thread_state.rs @@ -80,6 +80,7 @@ pub(crate) struct ThreadState { pub(crate) turn_summary: TurnSummary, pub(crate) pending_terminal_plan_cleanups: Vec, pub(crate) last_terminal_turn_id: Option, + /* Ids pushed by track_current_turn_event for TurnAborted or TurnComplete events seen with no active turn; a ThreadRolledBack event truncates the newest entries. */terminal_turn_ids: Vec, pub(crate) cancel_tx: Option>, pub(crate) experimental_raw_events: bool, pub(crate) listener_generation: u64, @@ -138,14 +139,30 @@ impl ThreadState { self.current_turn_history.active_turn_snapshot() } + pub(crate) fn prune_pending_terminal_plan_cleanups_after_rollback(&mut self) { /* Retains only the pending cleanups whose turn_id still appears in terminal_turn_ids and drops the rest. */+ self.pending_terminal_plan_cleanups.retain(|cleanup| { + self.terminal_turn_ids + .iter() + .any(|turn_id| turn_id == &cleanup.turn_id) + }); + } + pub(crate) fn track_current_turn_event(&mut self, event_turn_id: &str, event: &EventMsg) { if let EventMsg::TurnStarted(payload) = event { self.turn_summary.started_at = payload.started_at; } self.current_turn_history.handle_event(event); + if let EventMsg::ThreadRolledBack(payload) = event { /* Forget the newest num_turns recorded ids. A num_turns that does not fit in usize becomes usize::MAX, and saturating_sub then yields 0, emptying the list. NOTE(review): assumes one recorded id per rolled-back turn; confirm. */+ let num_turns = usize::try_from(payload.num_turns).unwrap_or(usize::MAX); + self.terminal_turn_ids + .truncate(self.terminal_turn_ids.len().saturating_sub(num_turns)); + } if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_)) && !self.current_turn_history.has_active_turn() { + if self.last_terminal_turn_id.as_deref() != Some(event_turn_id) { /* Skip the push when this id equals last_terminal_turn_id, so repeated terminal events for one turn add a single entry. */+ 
self.terminal_turn_ids.push(event_turn_id.to_string()); + } self.last_terminal_turn_id = Some(event_turn_id.to_string()); self.current_turn_history.reset(); }