diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 8f53890e8d..e11517757e 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -3134,8 +3134,11 @@ mod handlers { state.set_turn_context_history(updated_turn_context_history); let mut applied = false; if state.turn_context_history.len() == user_turns - && let Some(turn_context) = state.last_turn_context() - && let Some(collaboration_mode) = turn_context.collaboration_mode.clone() + && let Some(collaboration_mode) = state + .turn_context_history + .last() + .and_then(|turn_context| turn_context.as_ref()) + .and_then(|turn_context| turn_context.collaboration_mode.clone()) { state.session_configuration.collaboration_mode = collaboration_mode; applied = true; @@ -5143,6 +5146,76 @@ mod tests { assert!(update_items.contains(&expected_item)); } + #[tokio::test] + async fn thread_rollback_ignores_stale_turn_context() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + let base_mode = sess.current_collaboration_mode().await; + let plan_mode = CollaborationMode { + mode: ModeKind::Plan, + settings: base_mode.settings.clone(), + }; + let code_mode = CollaborationMode { + mode: ModeKind::Code, + settings: base_mode.settings.clone(), + }; + + { + let mut state = sess.state.lock().await; + state.session_configuration.collaboration_mode = plan_mode.clone(); + } + let input1 = vec![UserInput::Text { + text: "turn 1".to_string(), + text_elements: Vec::new(), + }]; + let response_item1: ResponseItem = ResponseInputItem::from(input1.clone()).into(); + sess.record_user_prompt_and_emit_turn_item(tc.as_ref(), &input1, response_item1) + .await; + + let turn_2 = vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "turn 2".to_string(), + }], + end_turn: None, + }]; + sess.record_into_history(&turn_2, tc.as_ref()).await; + + { + let mut state = sess.state.lock().await; + state.session_configuration.collaboration_mode = code_mode.clone(); + } + let input3 = vec![UserInput::Text { + text: "turn 3".to_string(), + text_elements: Vec::new(), + }]; + let response_item3: ResponseItem = ResponseInputItem::from(input3.clone()).into(); + sess.record_user_prompt_and_emit_turn_item(tc.as_ref(), &input3, response_item3) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + let rollback_event = wait_for_thread_rolled_back(&rx).await; + pretty_assertions::assert_eq!(rollback_event.num_turns, 1); + + let (collaboration_mode, force_collaboration_instructions) = { + let mut state = sess.state.lock().await; + let force = state.force_collaboration_instructions; + state.force_collaboration_instructions = false; + ( + state.session_configuration.collaboration_mode.clone(), + force, + ) + }; + + pretty_assertions::assert_eq!(collaboration_mode, code_mode); + assert!(force_collaboration_instructions); + } + #[tokio::test] async fn thread_rollback_clears_history_when_num_turns_exceeds_existing_turns() { let (sess, tc, rx) = make_session_and_context_with_rx().await; diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index fee67b5aff..2b47cbbe92 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -87,13 +87,6 @@ impl SessionState { } } - pub(crate) fn last_turn_context(&self) -> Option<&TurnContextItem> { - self.turn_context_history - .iter() - .rev() - .find_map(Option::as_ref) - } - pub(crate) fn set_turn_context_history( &mut self, turn_context_history: Vec>,