Compare commits

...

1 Commits

Author SHA1 Message Date
Eric Traut
e94ddb85c9 Invalidate model window after thread rollback (#21986) 2026-05-12 00:43:28 -07:00
3 changed files with 71 additions and 1 deletions

View File

@@ -563,6 +563,7 @@ pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
.await;
sess.recompute_token_usage(turn_context.as_ref()).await;
sess.services.model_client.advance_window_generation();
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
.await;

View File

@@ -395,7 +395,13 @@ impl Session {
resumed_history
.history
.iter()
.filter(|item| matches!(item, RolloutItem::Compacted(_)))
.filter(|item| {
matches!(
item,
RolloutItem::Compacted(_)
| RolloutItem::EventMsg(EventMsg::ThreadRolledBack(_))
)
})
.count(),
)
.unwrap_or(u64::MAX),

View File

@@ -102,6 +102,69 @@ async fn window_id_advances_after_compact_persists_on_resume_and_resets_on_fork(
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn window_id_advances_after_rollback_and_persists_on_resume() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let request_log = mount_sse_sequence(
&server,
vec![
sse(vec![ev_completed("resp-1")]),
sse(vec![ev_completed("resp-2")]),
sse(vec![ev_completed("resp-3")]),
sse(vec![ev_completed("resp-4")]),
],
)
.await;
let mut builder = test_codex().with_config(|config| {
config.model_provider.name = "Non-OpenAI Model provider".to_string();
});
let initial = builder.build(&server).await?;
let initial_thread = Arc::clone(&initial.codex);
let rollout_path = initial
.session_configured
.rollout_path
.clone()
.expect("rollout path");
submit_user_turn(&initial_thread, "before rollback").await?;
submit_user_turn(&initial_thread, "discard me").await?;
initial_thread
.submit(Op::ThreadRollback { num_turns: 1 })
.await?;
wait_for_event(&initial_thread, |event| {
matches!(event, EventMsg::ThreadRolledBack(_))
})
.await;
submit_user_turn(&initial_thread, "after rollback").await?;
shutdown_thread(&initial_thread).await?;
let resumed = builder
.resume(&server, initial.home.clone(), rollout_path.clone())
.await?;
submit_user_turn(&resumed.codex, "after resume").await?;
shutdown_thread(&resumed.codex).await?;
let requests = request_log.requests();
assert_eq!(requests.len(), 4, "expected four model requests");
let window_ids = requests.iter().map(window_id_parts).collect::<Vec<_>>();
let initial_thread_id = window_ids[0].0.clone();
assert_eq!(
window_ids,
vec![
(initial_thread_id.clone(), 0),
(initial_thread_id.clone(), 0),
(initial_thread_id.clone(), 1),
(initial_thread_id, 1),
]
);
Ok(())
}
async fn submit_user_turn(codex: &Arc<CodexThread>, text: &str) -> Result<()> {
codex
.submit(Op::UserInput {