fix: update fork boundaries computation (#16322)

This commit is contained in:
jif-oai
2026-03-31 14:10:43 +02:00
committed by GitHub
parent 1fc8aa0e16
commit 4c72e62d0b
2 changed files with 49 additions and 16 deletions

View File

@@ -3,6 +3,7 @@
//! In core, "user turns" are detected by scanning `ResponseItem::Message` items and
//! interpreting them via `event_mapping::parse_turn_item(...)`.
use crate::context_manager::is_user_turn_boundary;
use crate::event_mapping;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ResponseItem;
@@ -48,36 +49,37 @@ pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec<us
/// - an assistant inter-agent envelope whose parsed `trigger_turn` is `true`.
///
/// Like `user_message_positions_in_rollout`, this applies `ThreadRolledBack` markers so indexing
/// reflects the effective post-rollback history. Rollback counts real user turns only, so a
/// rollback removes the stale suffix starting at the earliest rolled-back user boundary instead of
/// simply truncating the mixed fork-boundary list.
/// reflects the effective post-rollback history. Rollback counts instruction turns, so a rollback
/// removes the stale suffix starting at the earliest rolled-back instruction-turn boundary instead
/// of simply truncating the mixed fork-boundary list.
pub(crate) fn fork_turn_positions_in_rollout(items: &[RolloutItem]) -> Vec<usize> {
let mut user_positions = Vec::new();
let mut rollback_turn_positions = Vec::new();
let mut fork_turn_positions = Vec::new();
for (idx, item) in items.iter().enumerate() {
match item {
RolloutItem::ResponseItem(item) if is_real_user_message_boundary(item) => {
user_positions.push(idx);
fork_turn_positions.push(idx);
}
RolloutItem::ResponseItem(item) if is_trigger_turn_boundary(item) => {
fork_turn_positions.push(idx);
RolloutItem::ResponseItem(item) => {
if is_user_turn_boundary(item) {
rollback_turn_positions.push(idx);
}
if is_real_user_message_boundary(item) || is_trigger_turn_boundary(item) {
fork_turn_positions.push(idx);
}
}
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
let num_turns = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX);
if num_turns == 0 {
continue;
}
let Some(rollback_start_idx) = user_positions
let Some(rollback_start_idx) = rollback_turn_positions
.len()
.checked_sub(num_turns)
.map(|rollback_start| user_positions[rollback_start])
.or_else(|| user_positions.first().copied())
.map(|rollback_start| rollback_turn_positions[rollback_start])
.or_else(|| rollback_turn_positions.first().copied())
else {
continue;
};
let new_user_len = user_positions.len().saturating_sub(num_turns);
user_positions.truncate(new_user_len);
let new_rollback_len = rollback_turn_positions.len().saturating_sub(num_turns);
rollback_turn_positions.truncate(new_rollback_len);
fork_turn_positions.retain(|position| *position < rollback_start_idx);
}
_ => {}

View File

@@ -260,9 +260,40 @@ fn truncates_rollout_to_last_n_fork_turns_discards_trigger_boundaries_in_rolled_
let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, /*n_from_end*/ 2);
let expected = rollout[1..].to_vec();
assert_eq!(
serde_json::to_value(&truncated).unwrap(),
serde_json::to_value(&rollout).unwrap()
serde_json::to_value(&expected).unwrap()
);
}
#[test]
fn truncates_rollout_to_last_n_fork_turns_discards_rolled_back_assistant_instruction_turns() {
let rollout = vec![
RolloutItem::ResponseItem(user_msg("u1")),
RolloutItem::ResponseItem(assistant_msg("a1")),
RolloutItem::ResponseItem(inter_agent_msg(
"triggered task 1",
/*trigger_turn*/ true,
)),
RolloutItem::ResponseItem(assistant_msg("a2")),
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent {
num_turns: 1,
})),
RolloutItem::ResponseItem(inter_agent_msg(
"triggered task 2",
/*trigger_turn*/ true,
)),
RolloutItem::ResponseItem(assistant_msg("a3")),
];
let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, /*n_from_end*/ 1);
let expected = rollout[5..].to_vec();
assert_eq!(
serde_json::to_value(&truncated).unwrap(),
serde_json::to_value(&expected).unwrap()
);
}