mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
fix: update fork boundaries computation (#16322)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
//! In core, "user turns" are detected by scanning `ResponseItem::Message` items and
|
||||
//! interpreting them via `event_mapping::parse_turn_item(...)`.
|
||||
|
||||
use crate::context_manager::is_user_turn_boundary;
|
||||
use crate::event_mapping;
|
||||
use codex_protocol::items::TurnItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -48,36 +49,37 @@ pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec<us
|
||||
/// - an assistant inter-agent envelope whose parsed `trigger_turn` is `true`.
|
||||
///
|
||||
/// Like `user_message_positions_in_rollout`, this applies `ThreadRolledBack` markers so indexing
|
||||
/// reflects the effective post-rollback history. Rollback counts real user turns only, so a
|
||||
/// rollback removes the stale suffix starting at the earliest rolled-back user boundary instead of
|
||||
/// simply truncating the mixed fork-boundary list.
|
||||
/// reflects the effective post-rollback history. Rollback counts instruction turns, so a rollback
|
||||
/// removes the stale suffix starting at the earliest rolled-back instruction-turn boundary instead
|
||||
/// of simply truncating the mixed fork-boundary list.
|
||||
pub(crate) fn fork_turn_positions_in_rollout(items: &[RolloutItem]) -> Vec<usize> {
|
||||
let mut user_positions = Vec::new();
|
||||
let mut rollback_turn_positions = Vec::new();
|
||||
let mut fork_turn_positions = Vec::new();
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(item) if is_real_user_message_boundary(item) => {
|
||||
user_positions.push(idx);
|
||||
fork_turn_positions.push(idx);
|
||||
}
|
||||
RolloutItem::ResponseItem(item) if is_trigger_turn_boundary(item) => {
|
||||
fork_turn_positions.push(idx);
|
||||
RolloutItem::ResponseItem(item) => {
|
||||
if is_user_turn_boundary(item) {
|
||||
rollback_turn_positions.push(idx);
|
||||
}
|
||||
if is_real_user_message_boundary(item) || is_trigger_turn_boundary(item) {
|
||||
fork_turn_positions.push(idx);
|
||||
}
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
|
||||
let num_turns = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX);
|
||||
if num_turns == 0 {
|
||||
continue;
|
||||
}
|
||||
let Some(rollback_start_idx) = user_positions
|
||||
let Some(rollback_start_idx) = rollback_turn_positions
|
||||
.len()
|
||||
.checked_sub(num_turns)
|
||||
.map(|rollback_start| user_positions[rollback_start])
|
||||
.or_else(|| user_positions.first().copied())
|
||||
.map(|rollback_start| rollback_turn_positions[rollback_start])
|
||||
.or_else(|| rollback_turn_positions.first().copied())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
let new_user_len = user_positions.len().saturating_sub(num_turns);
|
||||
user_positions.truncate(new_user_len);
|
||||
let new_rollback_len = rollback_turn_positions.len().saturating_sub(num_turns);
|
||||
rollback_turn_positions.truncate(new_rollback_len);
|
||||
fork_turn_positions.retain(|position| *position < rollback_start_idx);
|
||||
}
|
||||
_ => {}
|
||||
|
||||
@@ -260,9 +260,40 @@ fn truncates_rollout_to_last_n_fork_turns_discards_trigger_boundaries_in_rolled_
|
||||
|
||||
let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, /*n_from_end*/ 2);
|
||||
|
||||
let expected = rollout[1..].to_vec();
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_value(&truncated).unwrap(),
|
||||
serde_json::to_value(&rollout).unwrap()
|
||||
serde_json::to_value(&expected).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_rollout_to_last_n_fork_turns_discards_rolled_back_assistant_instruction_turns() {
|
||||
let rollout = vec![
|
||||
RolloutItem::ResponseItem(user_msg("u1")),
|
||||
RolloutItem::ResponseItem(assistant_msg("a1")),
|
||||
RolloutItem::ResponseItem(inter_agent_msg(
|
||||
"triggered task 1",
|
||||
/*trigger_turn*/ true,
|
||||
)),
|
||||
RolloutItem::ResponseItem(assistant_msg("a2")),
|
||||
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent {
|
||||
num_turns: 1,
|
||||
})),
|
||||
RolloutItem::ResponseItem(inter_agent_msg(
|
||||
"triggered task 2",
|
||||
/*trigger_turn*/ true,
|
||||
)),
|
||||
RolloutItem::ResponseItem(assistant_msg("a3")),
|
||||
];
|
||||
|
||||
let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, /*n_from_end*/ 1);
|
||||
let expected = rollout[5..].to_vec();
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_value(&truncated).unwrap(),
|
||||
serde_json::to_value(&expected).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user