app-server: preserve deferred cleanup across rollback

This commit is contained in:
Brent Traut
2026-05-18 20:11:35 -07:00
parent 5e007ccff4
commit 7eeeb53ae5
2 changed files with 218 additions and 2 deletions

View File

@@ -1139,7 +1139,7 @@ pub(crate) async fn apply_bespoke_event_handling(
EventMsg::ThreadRolledBack(_rollback_event) => {
let pending = {
let mut state = thread_state.lock().await;
state.pending_terminal_plan_cleanups.clear();
state.prune_pending_terminal_plan_cleanups_after_rollback();
state.pending_rollbacks.take()
};
@@ -1334,6 +1334,20 @@ pub(crate) async fn flush_pending_terminal_plan_cleanup(
}
}
async fn flush_all_pending_terminal_plan_cleanup(
conversation_id: ThreadId,
thread_state: &Arc<Mutex<ThreadState>>,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let pending_terminal_plan_cleanups =
std::mem::take(&mut thread_state.lock().await.pending_terminal_plan_cleanups);
for (turn_id, latest_plan_update) in
terminal_plan_cleanup_updates(pending_terminal_plan_cleanups)
{
emit_turn_plan_updated(conversation_id, &turn_id, latest_plan_update, outgoing).await;
}
}
fn terminal_plan_cleanup_updates(
pending_terminal_plan_cleanups: Vec<PendingTerminalPlanCleanup>,
) -> Vec<(String, UpdatePlanArgs)> {
@@ -1614,7 +1628,7 @@ async fn handle_turn_interrupted(
thread_state: &Arc<Mutex<ThreadState>>,
) {
let turn_summary = find_and_remove_turn_summary(thread_state).await;
flush_pending_terminal_plan_cleanup(conversation_id, thread_state, outgoing).await;
flush_all_pending_terminal_plan_cleanup(conversation_id, thread_state, outgoing).await;
emit_turn_completed_with_status(
conversation_id,
@@ -3950,6 +3964,93 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn tracked_replaced_turn_downgrades_in_progress_plan() -> Result<()> {
let codex_home = TempDir::new()?;
let config = load_default_config_for_test(&codex_home).await;
let thread_manager = Arc::new(
codex_core::test_support::thread_manager_with_models_provider_and_home(
CodexAuth::create_dummy_chatgpt_auth_for_testing(),
config.model_provider.clone(),
config.codex_home.to_path_buf(),
Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()),
),
);
let codex_core::NewThread {
thread_id: conversation_id,
thread: conversation,
..
} = thread_manager.start_thread(config).await?;
let turn_id = "tracked-replaced-plan-turn".to_string();
let aborted = TurnAbortedEvent {
turn_id: Some(turn_id.clone()),
reason: codex_protocol::protocol::TurnAbortReason::Replaced,
completed_at: Some(TEST_TURN_COMPLETED_AT),
duration_ms: Some(TEST_TURN_DURATION_MS),
};
let thread_state = new_thread_state();
{
let mut state = thread_state.lock().await;
state.track_current_turn_event(
&turn_id,
&EventMsg::TurnStarted(codex_protocol::protocol::TurnStartedEvent {
turn_id: turn_id.clone(),
started_at: Some(42),
model_context_window: None,
collaboration_mode_kind: Default::default(),
}),
);
state.track_current_turn_event(&turn_id, &EventMsg::TurnAborted(aborted.clone()));
state.pending_terminal_plan_cleanups = vec![PendingTerminalPlanCleanup {
turn_id: turn_id.clone(),
plan_update: UpdatePlanArgs {
explanation: Some("still working".to_string()),
plan: vec![PlanItemArg {
step: "first".to_string(),
status: StepStatus::InProgress,
}],
},
}];
}
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(
tx,
codex_analytics::AnalyticsEventsClient::disabled(),
));
let outgoing = ThreadScopedOutgoingMessageSender::new(
outgoing,
vec![ConnectionId(1)],
conversation_id,
);
apply_bespoke_event_handling(
Event {
id: turn_id,
msg: EventMsg::TurnAborted(aborted),
},
conversation_id,
conversation,
thread_manager,
outgoing,
thread_state,
ThreadWatchManager::new(),
Arc::new(tokio::sync::Semaphore::new(/*permits*/ 1)),
"test-provider".to_string(),
)
.await;
assert!(matches!(
recv_broadcast_message(&mut rx).await?,
OutgoingMessage::AppServerNotification(ServerNotification::TurnPlanUpdated(_))
));
assert!(matches!(
recv_broadcast_message(&mut rx).await?,
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(_))
));
assert!(rx.try_recv().is_err(), "no extra messages expected");
Ok(())
}
#[tokio::test]
async fn test_handle_turn_complete_preserves_in_progress_plan_with_active_goal() -> Result<()> {
let conversation_id = ThreadId::new();
@@ -4287,6 +4388,104 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn rollback_preserves_cleanup_for_turns_that_survive() -> Result<()> {
let codex_home = TempDir::new()?;
let config = load_default_config_for_test(&codex_home).await;
let thread_manager = Arc::new(
codex_core::test_support::thread_manager_with_models_provider_and_home(
CodexAuth::create_dummy_chatgpt_auth_for_testing(),
config.model_provider.clone(),
config.codex_home.to_path_buf(),
Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()),
),
);
let codex_core::NewThread {
thread_id: conversation_id,
thread: conversation,
..
} = thread_manager.start_thread(config).await?;
let rollback_event = codex_protocol::protocol::ThreadRolledBackEvent { num_turns: 1 };
let thread_state = new_thread_state();
{
let mut state = thread_state.lock().await;
for turn_id in ["older-turn", "rolled-back-turn"] {
state.track_current_turn_event(
turn_id,
&EventMsg::TurnStarted(codex_protocol::protocol::TurnStartedEvent {
turn_id: turn_id.to_string(),
started_at: Some(42),
model_context_window: None,
collaboration_mode_kind: Default::default(),
}),
);
state.track_current_turn_event(
turn_id,
&EventMsg::TurnComplete(turn_complete_event(turn_id)),
);
}
state.track_current_turn_event(
"rollback-turn",
&EventMsg::ThreadRolledBack(rollback_event.clone()),
);
state.pending_terminal_plan_cleanups = vec![
PendingTerminalPlanCleanup {
turn_id: "older-turn".to_string(),
plan_update: UpdatePlanArgs {
explanation: Some("still working".to_string()),
plan: vec![PlanItemArg {
step: "older".to_string(),
status: StepStatus::InProgress,
}],
},
},
PendingTerminalPlanCleanup {
turn_id: "rolled-back-turn".to_string(),
plan_update: UpdatePlanArgs {
explanation: Some("still working".to_string()),
plan: vec![PlanItemArg {
step: "newer".to_string(),
status: StepStatus::InProgress,
}],
},
},
];
}
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(
tx,
codex_analytics::AnalyticsEventsClient::disabled(),
));
let outgoing = ThreadScopedOutgoingMessageSender::new(
outgoing,
vec![ConnectionId(1)],
conversation_id,
);
apply_bespoke_event_handling(
Event {
id: "rollback-turn".to_string(),
msg: EventMsg::ThreadRolledBack(rollback_event),
},
conversation_id,
conversation,
thread_manager,
outgoing,
thread_state.clone(),
ThreadWatchManager::new(),
Arc::new(tokio::sync::Semaphore::new(/*permits*/ 1)),
"test-provider".to_string(),
)
.await;
let pending_terminal_plan_cleanups =
&thread_state.lock().await.pending_terminal_plan_cleanups;
assert_eq!(pending_terminal_plan_cleanups.len(), 1);
assert_eq!(pending_terminal_plan_cleanups[0].turn_id, "older-turn");
assert!(rx.try_recv().is_err(), "no extra messages expected");
Ok(())
}
#[tokio::test]
async fn test_handle_token_count_event_emits_usage_and_rate_limits() -> Result<()> {
let conversation_id = ThreadId::new();

View File

@@ -80,6 +80,7 @@ pub(crate) struct ThreadState {
pub(crate) turn_summary: TurnSummary,
pub(crate) pending_terminal_plan_cleanups: Vec<PendingTerminalPlanCleanup>,
pub(crate) last_terminal_turn_id: Option<String>,
terminal_turn_ids: Vec<String>,
pub(crate) cancel_tx: Option<oneshot::Sender<()>>,
pub(crate) experimental_raw_events: bool,
pub(crate) listener_generation: u64,
@@ -138,14 +139,30 @@ impl ThreadState {
self.current_turn_history.active_turn_snapshot()
}
pub(crate) fn prune_pending_terminal_plan_cleanups_after_rollback(&mut self) {
self.pending_terminal_plan_cleanups.retain(|cleanup| {
self.terminal_turn_ids
.iter()
.any(|turn_id| turn_id == &cleanup.turn_id)
});
}
pub(crate) fn track_current_turn_event(&mut self, event_turn_id: &str, event: &EventMsg) {
if let EventMsg::TurnStarted(payload) = event {
self.turn_summary.started_at = payload.started_at;
}
self.current_turn_history.handle_event(event);
if let EventMsg::ThreadRolledBack(payload) = event {
let num_turns = usize::try_from(payload.num_turns).unwrap_or(usize::MAX);
self.terminal_turn_ids
.truncate(self.terminal_turn_ids.len().saturating_sub(num_turns));
}
if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_))
&& !self.current_turn_history.has_active_turn()
{
if self.last_terminal_turn_id.as_deref() != Some(event_turn_id) {
self.terminal_turn_ids.push(event_turn_id.to_string());
}
self.last_terminal_turn_id = Some(event_turn_id.to_string());
self.current_turn_history.reset();
}