Compare commits

...

1 Commits

Author SHA1 Message Date
jif-oai
62ca1aa2ec feat: better rollback boundaries 2026-03-31 11:35:07 +02:00
2 changed files with 207 additions and 20 deletions

View File

@@ -7,8 +7,21 @@ use crate::event_mapping;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::InterAgentCommunication;
use codex_protocol::protocol::RolloutItem;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ForkTurnBoundaryKind {
User,
TriggerTurnEnvelope,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ForkTurnBoundary {
idx: usize,
kind: ForkTurnBoundaryKind,
}
/// Return the indices of user message boundaries in a rollout.
///
/// A user message boundary is a `RolloutItem::ResponseItem(ResponseItem::Message { .. })`
@@ -18,26 +31,29 @@ use codex_protocol::protocol::RolloutItem;
/// last N user turns were removed from the effective thread history; we apply them here so
/// indexing uses the post-rollback history rather than the raw stream.
pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec<usize> {
let mut user_positions = Vec::new();
for (idx, item) in items.iter().enumerate() {
match item {
RolloutItem::ResponseItem(item @ ResponseItem::Message { .. })
if matches!(
event_mapping::parse_turn_item(item),
Some(TurnItem::UserMessage(_))
) =>
{
user_positions.push(idx);
}
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
let num_turns = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX);
let new_len = user_positions.len().saturating_sub(num_turns);
user_positions.truncate(new_len);
}
_ => {}
}
}
user_positions
fork_turn_boundaries_in_rollout(items)
.into_iter()
.filter_map(|boundary| match boundary.kind {
ForkTurnBoundaryKind::User => Some(boundary.idx),
ForkTurnBoundaryKind::TriggerTurnEnvelope => None,
})
.collect()
}
/// Return the indices of fork-turn boundaries in a rollout.
///
/// A fork-turn boundary is either:
/// - a real user message boundary, or
/// - an assistant inter-agent envelope whose parsed `trigger_turn` is `true`.
///
/// Rollbacks are applied to the effective instruction-turn stack rather than to user-only
/// boundaries, so a rollback can correctly remove trigger-turn inter-agent envelopes even when
/// there are no real user messages in the rolled-back suffix.
pub(crate) fn fork_turn_positions_in_rollout(items: &[RolloutItem]) -> Vec<usize> {
fork_turn_boundaries_in_rollout(items)
.into_iter()
.map(|boundary| boundary.idx)
.collect()
}
/// Return a prefix of `items` obtained by cutting strictly before the nth user message.
@@ -68,6 +84,70 @@ pub(crate) fn truncate_rollout_before_nth_user_message_from_start(
items[..cut_idx].to_vec()
}
/// Return a suffix of `items` that keeps the last `n_from_end` fork turns.
///
/// If fewer than or equal to `n_from_end` fork turns exist, this returns the full rollout.
pub(crate) fn truncate_rollout_to_last_n_fork_turns(
items: &[RolloutItem],
n_from_end: usize,
) -> Vec<RolloutItem> {
if n_from_end == 0 {
return Vec::new();
}
let fork_turn_positions = fork_turn_positions_in_rollout(items);
if fork_turn_positions.len() <= n_from_end {
return items.to_vec();
}
let keep_idx = fork_turn_positions[fork_turn_positions.len() - n_from_end];
items[keep_idx..].to_vec()
}
fn fork_turn_boundaries_in_rollout(items: &[RolloutItem]) -> Vec<ForkTurnBoundary> {
let mut boundaries = Vec::new();
for (idx, item) in items.iter().enumerate() {
match item {
RolloutItem::ResponseItem(item) if is_real_user_message_boundary(item) => {
boundaries.push(ForkTurnBoundary {
idx,
kind: ForkTurnBoundaryKind::User,
});
}
RolloutItem::ResponseItem(item) if is_trigger_turn_boundary(item) => {
boundaries.push(ForkTurnBoundary {
idx,
kind: ForkTurnBoundaryKind::TriggerTurnEnvelope,
});
}
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
let num_turns = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX);
let new_len = boundaries.len().saturating_sub(num_turns);
boundaries.truncate(new_len);
}
_ => {}
}
}
boundaries
}
fn is_real_user_message_boundary(item: &ResponseItem) -> bool {
matches!(
event_mapping::parse_turn_item(item),
Some(TurnItem::UserMessage(_))
)
}
fn is_trigger_turn_boundary(item: &ResponseItem) -> bool {
let ResponseItem::Message { role, content, .. } = item else {
return false;
};
role == "assistant"
&& InterAgentCommunication::from_message_content(content)
.is_some_and(|communication| communication.trigger_turn)
}
#[cfg(test)]
#[path = "thread_rollout_truncation_tests.rs"]
mod tests;

View File

@@ -1,7 +1,9 @@
use super::*;
use crate::codex::make_session_and_context;
use codex_protocol::AgentPath;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemReasoningSummary;
use codex_protocol::protocol::InterAgentCommunication;
use codex_protocol::protocol::ThreadRolledBackEvent;
use pretty_assertions::assert_eq;
@@ -29,6 +31,17 @@ fn assistant_msg(text: &str) -> ResponseItem {
}
}
fn inter_agent_msg(text: &str, trigger_turn: bool) -> ResponseItem {
let communication = InterAgentCommunication::new(
AgentPath::root(),
AgentPath::try_from("/root/worker").expect("agent path"),
Vec::new(),
text.to_string(),
trigger_turn,
);
communication.to_response_input_item().into()
}
#[test]
fn truncates_rollout_from_start_before_nth_user_only() {
let items = [
@@ -157,3 +170,97 @@ async fn ignores_session_prefix_messages_when_truncating_rollout_from_start() {
serde_json::to_value(&expected).unwrap()
);
}
#[test]
fn truncates_rollout_to_last_n_fork_turns_counts_trigger_turn_messages() {
let rollout = vec![
RolloutItem::ResponseItem(user_msg("u1")),
RolloutItem::ResponseItem(assistant_msg("a1")),
RolloutItem::ResponseItem(inter_agent_msg(
"queued message",
/*trigger_turn*/ false,
)),
RolloutItem::ResponseItem(assistant_msg("a2")),
RolloutItem::ResponseItem(inter_agent_msg(
"triggered task",
/*trigger_turn*/ true,
)),
RolloutItem::ResponseItem(assistant_msg("a3")),
RolloutItem::ResponseItem(user_msg("u2")),
RolloutItem::ResponseItem(assistant_msg("a4")),
];
let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, 2);
let expected = rollout[4..].to_vec();
assert_eq!(
serde_json::to_value(&truncated).unwrap(),
serde_json::to_value(&expected).unwrap()
);
}
#[test]
fn truncates_rollout_to_last_n_fork_turns_applies_thread_rollback_markers() {
let rollout = vec![
RolloutItem::ResponseItem(user_msg("u1")),
RolloutItem::ResponseItem(assistant_msg("a1")),
RolloutItem::ResponseItem(inter_agent_msg(
"triggered task",
/*trigger_turn*/ true,
)),
RolloutItem::ResponseItem(assistant_msg("a2")),
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent {
num_turns: 1,
})),
RolloutItem::ResponseItem(user_msg("u2")),
RolloutItem::ResponseItem(assistant_msg("a3")),
];
let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, 2);
assert_eq!(
serde_json::to_value(&truncated).unwrap(),
serde_json::to_value(&rollout).unwrap()
);
}
#[test]
fn truncates_rollout_to_last_n_fork_turns_discards_rolled_back_trigger_turn_only_suffix() {
let rollout = vec![
RolloutItem::ResponseItem(inter_agent_msg("task A", /*trigger_turn*/ true)),
RolloutItem::ResponseItem(assistant_msg("working on task A")),
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent {
num_turns: 1,
})),
RolloutItem::ResponseItem(inter_agent_msg("task B", /*trigger_turn*/ true)),
RolloutItem::ResponseItem(assistant_msg("working on task B")),
];
let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, 2);
let expected = rollout[3..].to_vec();
assert_eq!(
serde_json::to_value(&truncated).unwrap(),
serde_json::to_value(&expected).unwrap()
);
}
#[test]
fn truncates_rollout_to_last_n_fork_turns_keeps_full_rollout_when_n_is_large() {
let rollout = vec![
RolloutItem::ResponseItem(user_msg("u1")),
RolloutItem::ResponseItem(assistant_msg("a1")),
RolloutItem::ResponseItem(inter_agent_msg(
"triggered task",
/*trigger_turn*/ true,
)),
RolloutItem::ResponseItem(assistant_msg("a2")),
];
let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, 10);
assert_eq!(
serde_json::to_value(&truncated).unwrap(),
serde_json::to_value(&rollout).unwrap()
);
}