Compare commits

...

5 Commits

6 changed files with 367 additions and 10 deletions

View File

@@ -391,7 +391,6 @@ impl ModelClient {
self.state
.window_generation
.store(window_generation, Ordering::Relaxed);
self.store_cached_websocket_session(WebsocketSession::default());
}
pub(crate) fn advance_window_generation(&self) {

View File

@@ -569,6 +569,11 @@ pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
.await;
sess.recompute_token_usage(turn_context.as_ref()).await;
sess.services.model_client.set_window_generation(
super::rollout_reconstruction::effective_window_generation_from_rollout(
replay_items.as_slice(),
),
);
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
.await;

View File

@@ -35,6 +35,81 @@ struct ActiveReplaySegment<'a> {
base_replacement_history: Option<&'a [ResponseItem]>,
}
#[derive(Debug, Default)]
struct WindowGenerationReplaySegment {
counts_as_user_turn: bool,
compaction_count: u64,
}
fn finalize_window_generation_segment(
active_segment: WindowGenerationReplaySegment,
window_generation: &mut u64,
pending_rollback_turns: &mut usize,
) {
if *pending_rollback_turns > 0 {
if active_segment.counts_as_user_turn {
*pending_rollback_turns -= 1;
}
return;
}
*window_generation = window_generation.saturating_add(active_segment.compaction_count);
}
/// Replays rollout segments newest-to-oldest so compactions in rolled-back suffixes do not
/// contribute to the public context-window lineage generation.
pub(super) fn effective_window_generation_from_rollout(rollout_items: &[RolloutItem]) -> u64 {
let mut window_generation = 0u64;
let mut pending_rollback_turns = 0usize;
let mut active_segment: Option<WindowGenerationReplaySegment> = None;
for item in rollout_items.iter().rev() {
match item {
RolloutItem::Compacted(_) => {
let active_segment =
active_segment.get_or_insert_with(WindowGenerationReplaySegment::default);
active_segment.compaction_count = active_segment.compaction_count.saturating_add(1);
}
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
pending_rollback_turns = pending_rollback_turns
.saturating_add(usize::try_from(rollback.num_turns).unwrap_or(usize::MAX));
}
RolloutItem::EventMsg(EventMsg::TurnStarted(_)) => {
if let Some(active_segment) = active_segment.take() {
finalize_window_generation_segment(
active_segment,
&mut window_generation,
&mut pending_rollback_turns,
);
}
}
RolloutItem::EventMsg(EventMsg::UserMessage(_)) => {
active_segment
.get_or_insert_with(WindowGenerationReplaySegment::default)
.counts_as_user_turn = true;
}
RolloutItem::ResponseItem(response_item) => {
let active_segment =
active_segment.get_or_insert_with(WindowGenerationReplaySegment::default);
active_segment.counts_as_user_turn |= is_user_turn_boundary(response_item);
}
RolloutItem::EventMsg(_)
| RolloutItem::TurnContext(_)
| RolloutItem::SessionMeta(_) => {}
}
}
if let Some(active_segment) = active_segment {
finalize_window_generation_segment(
active_segment,
&mut window_generation,
&mut pending_rollback_turns,
);
}
window_generation
}
fn turn_ids_are_compatible(active_turn_id: Option<&str>, item_turn_id: Option<&str>) -> bool {
active_turn_id
.is_none_or(|turn_id| item_turn_id.is_none_or(|item_turn_id| item_turn_id == turn_id))

View File

@@ -524,15 +524,15 @@ impl Session {
InitialHistory::Resumed(resumed_history) => resumed_history.conversation_id,
};
let window_generation = match &initial_history {
InitialHistory::Resumed(resumed_history) => u64::try_from(
resumed_history
.history
.iter()
.filter(|item| matches!(item, RolloutItem::Compacted(_)))
.count(),
)
.unwrap_or(u64::MAX),
InitialHistory::New | InitialHistory::Cleared | InitialHistory::Forked(_) => 0,
InitialHistory::Resumed(resumed_history) => {
super::rollout_reconstruction::effective_window_generation_from_rollout(
&resumed_history.history,
)
}
InitialHistory::Forked(history) => {
super::rollout_reconstruction::effective_window_generation_from_rollout(history)
}
InitialHistory::New | InitialHistory::Cleared => 0,
};
// Kick off independent async setup tasks in parallel to reduce startup latency.
//

View File

@@ -1,6 +1,8 @@
use anyhow::Result;
use codex_features::Feature;
use codex_protocol::config_types::ServiceTier;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::Op;
use core_test_support::responses::WebSocketConnectionConfig;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
@@ -10,6 +12,7 @@ use core_test_support::responses::start_websocket_server;
use core_test_support::responses::start_websocket_server_with_headers;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use pretty_assertions::assert_eq;
use serde_json::Value;
use std::time::Duration;
@@ -255,6 +258,62 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_v2_rollback_reuses_connection_without_previous_response_id() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server(vec![vec![
vec![ev_response_created("warm-1"), ev_completed("warm-1")],
vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "kept"),
ev_completed("resp-1"),
],
vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-2", "discarded"),
ev_completed("resp-2"),
],
vec![
ev_response_created("resp-3"),
ev_assistant_message("msg-3", "after rollback"),
ev_completed("resp-3"),
],
]])
.await;
let mut builder = test_codex().with_config(|config| {
config
.features
.enable(Feature::ResponsesWebsocketsV2)
.expect("test config should allow feature update");
});
let test = builder.build_with_websocket_server(&server).await?;
test.submit_turn("before rollback").await?;
test.submit_turn("discard me").await?;
test.codex
.submit(Op::ThreadRollback { num_turns: 1 })
.await?;
wait_for_event(&test.codex, |event| {
matches!(event, EventMsg::ThreadRolledBack(_))
})
.await;
test.submit_turn("after rollback").await?;
assert_eq!(server.handshakes().len(), 1);
let connections = server.connections();
assert_eq!(connections.len(), 1);
assert_eq!(connections[0].len(), 4);
let after_rollback = connections[0][3].body_json();
assert_eq!(after_rollback["type"].as_str(), Some("response.create"));
assert_eq!(after_rollback.get("previous_response_id"), None);
server.shutdown().await;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_v2_first_turn_uses_updated_fast_tier_after_startup_prewarm() -> Result<()> {
skip_if_no_network!(Ok(()));

View File

@@ -102,6 +102,225 @@ async fn window_id_advances_after_compact_persists_on_resume_and_resets_on_fork(
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn window_id_stays_stable_after_rollback_and_resume() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let request_log = mount_sse_sequence(
&server,
vec![
sse(vec![ev_completed("resp-1")]),
sse(vec![ev_completed("resp-2")]),
sse(vec![ev_completed("resp-3")]),
sse(vec![ev_completed("resp-4")]),
],
)
.await;
let mut builder = test_codex().with_config(|config| {
config.model_provider.name = "Non-OpenAI Model provider".to_string();
});
let initial = builder.build(&server).await?;
let initial_thread = Arc::clone(&initial.codex);
let rollout_path = initial
.session_configured
.rollout_path
.clone()
.expect("rollout path");
submit_user_turn(&initial_thread, "before rollback").await?;
submit_user_turn(&initial_thread, "discard me").await?;
initial_thread
.submit(Op::ThreadRollback { num_turns: 1 })
.await?;
wait_for_event(&initial_thread, |event| {
matches!(event, EventMsg::ThreadRolledBack(_))
})
.await;
submit_user_turn(&initial_thread, "after rollback").await?;
shutdown_thread(&initial_thread).await?;
let resumed = builder
.resume(&server, initial.home.clone(), rollout_path)
.await?;
submit_user_turn(&resumed.codex, "after resume").await?;
shutdown_thread(&resumed.codex).await?;
let requests = request_log.requests();
assert_eq!(requests.len(), 4, "expected four model requests");
let window_ids = requests.iter().map(window_id_parts).collect::<Vec<_>>();
let initial_thread_id = window_ids[0].0.clone();
assert_eq!(
window_ids,
vec![
(initial_thread_id.clone(), 0),
(initial_thread_id.clone(), 0),
(initial_thread_id.clone(), 0),
(initial_thread_id, 0),
]
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn window_id_rolls_back_across_compaction_and_persists_on_resume() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let request_log = mount_sse_sequence(
&server,
vec![
sse(vec![
ev_assistant_message("msg-1", "first reply"),
ev_completed("resp-1"),
]),
sse(vec![
ev_assistant_message("msg-2", "summary"),
ev_completed("resp-2"),
]),
sse(vec![ev_completed("resp-3")]),
sse(vec![ev_completed("resp-4")]),
sse(vec![ev_completed("resp-5")]),
],
)
.await;
let mut builder = test_codex().with_config(|config| {
config.model_provider.name = "Non-OpenAI Model provider".to_string();
config.compact_prompt = Some(SUMMARIZATION_PROMPT.to_string());
});
let initial = builder.build(&server).await?;
let initial_thread = Arc::clone(&initial.codex);
let rollout_path = initial
.session_configured
.rollout_path
.clone()
.expect("rollout path");
submit_user_turn(&initial_thread, "before compact").await?;
submit_compact_turn(&initial_thread).await?;
submit_user_turn(&initial_thread, "discard me").await?;
initial_thread
.submit(Op::ThreadRollback { num_turns: 2 })
.await?;
wait_for_event(&initial_thread, |event| {
matches!(event, EventMsg::ThreadRolledBack(_))
})
.await;
submit_user_turn(&initial_thread, "after rollback").await?;
shutdown_thread(&initial_thread).await?;
let resumed = builder
.resume(&server, initial.home.clone(), rollout_path)
.await?;
submit_user_turn(&resumed.codex, "after resume").await?;
shutdown_thread(&resumed.codex).await?;
let requests = request_log.requests();
assert_eq!(requests.len(), 5, "expected five model requests");
let window_ids = requests.iter().map(window_id_parts).collect::<Vec<_>>();
let initial_thread_id = window_ids[0].0.clone();
assert_eq!(
window_ids,
vec![
(initial_thread_id.clone(), 0),
(initial_thread_id.clone(), 0),
(initial_thread_id.clone(), 1),
(initial_thread_id.clone(), 0),
(initial_thread_id, 0),
]
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn forked_compacted_history_inherits_effective_generation_on_resume() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let request_log = mount_sse_sequence(
&server,
vec![
sse(vec![
ev_assistant_message("msg-1", "first reply"),
ev_completed("resp-1"),
]),
sse(vec![
ev_assistant_message("msg-2", "summary"),
ev_completed("resp-2"),
]),
sse(vec![ev_completed("resp-3")]),
sse(vec![ev_completed("resp-4")]),
],
)
.await;
let mut builder = test_codex().with_config(|config| {
config.model_provider.name = "Non-OpenAI Model provider".to_string();
config.compact_prompt = Some(SUMMARIZATION_PROMPT.to_string());
});
let initial = builder.build(&server).await?;
let initial_thread = Arc::clone(&initial.codex);
let rollout_path = initial
.session_configured
.rollout_path
.clone()
.expect("rollout path");
submit_user_turn(&initial_thread, "before compact").await?;
submit_compact_turn(&initial_thread).await?;
shutdown_thread(&initial_thread).await?;
let forked = initial
.thread_manager
.fork_thread(
/*snapshot*/ usize::MAX,
initial.config.clone(),
rollout_path,
/*thread_source*/ None,
/*persist_extended_history*/ false,
/*parent_trace*/ None,
)
.await?;
let fork_rollout_path = forked
.session_configured
.rollout_path
.clone()
.expect("fork rollout path");
submit_user_turn(&forked.thread, "after fork").await?;
shutdown_thread(&forked.thread).await?;
let resumed = builder
.resume(&server, initial.home.clone(), fork_rollout_path)
.await?;
submit_user_turn(&resumed.codex, "after fork resume").await?;
shutdown_thread(&resumed.codex).await?;
let requests = request_log.requests();
assert_eq!(requests.len(), 4, "expected four model requests");
let window_ids = requests.iter().map(window_id_parts).collect::<Vec<_>>();
let initial_thread_id = window_ids[0].0.clone();
let forked_thread_id = window_ids[2].0.clone();
assert_ne!(forked_thread_id, initial_thread_id);
assert_eq!(
window_ids,
vec![
(initial_thread_id.clone(), 0),
(initial_thread_id, 0),
(forked_thread_id.clone(), 1),
(forked_thread_id, 1),
]
);
Ok(())
}
async fn submit_user_turn(codex: &Arc<CodexThread>, text: &str) -> Result<()> {
codex
.submit(Op::UserInput {