Compare commits

...

1 Commits

Author SHA1 Message Date
Charles Cunningham
bbb5ff106f Trim synthetic context during rollback
When a rollback targets an earlier user turn, also trim the contiguous developer and contextual user messages injected immediately before that turn.

This keeps stale per-turn context out of rebuilt history and adds unit plus integration coverage for adjacent and non-adjacent cases.

Co-authored-by: Codex <noreply@openai.com>
2026-03-13 13:08:30 -07:00
3 changed files with 148 additions and 0 deletions

View File

@@ -7,6 +7,7 @@ use crate::config_loader::ConfigLayerStackOrdering;
use crate::config_loader::NetworkConstraints;
use crate::config_loader::RequirementSource;
use crate::config_loader::Sourced;
use crate::contextual_user_message::ENVIRONMENT_CONTEXT_FRAGMENT;
use crate::exec::ExecToolCallOutput;
use crate::function_tool::FunctionCallError;
use crate::mcp_connection_manager::ToolInfo;
@@ -1079,6 +1080,53 @@ async fn thread_rollback_clears_history_when_num_turns_exceeds_existing_turns()
assert_eq!(initial_context, history.raw_items());
}
#[tokio::test]
async fn thread_rollback_trims_adjacent_synthetic_context_before_removed_turn() {
    // Integration coverage: rolling back past turn 2 must also drop the
    // synthetic developer/contextual messages injected immediately before it.
    let (sess, tc, rx) = make_session_and_context_with_rx().await;
    attach_rollout_recorder(&sess).await;

    let initial_context = sess.build_initial_context(tc.as_ref()).await;
    let first_turn = vec![
        user_message("turn 1 user"),
        assistant_message("turn 1 assistant"),
    ];
    // Synthetic per-turn context inserted right before the second user turn.
    let injected_context = vec![
        ResponseItem::from(DeveloperInstructions::new(
            "<model_switch>switched to gpt-test</model_switch>".to_string(),
        )),
        ENVIRONMENT_CONTEXT_FRAGMENT
            .into_message(ENVIRONMENT_CONTEXT_FRAGMENT.wrap("<cwd>/tmp/updated</cwd>".to_string())),
    ];
    let second_turn = vec![
        user_message("turn 2 user"),
        assistant_message("turn 2 assistant"),
    ];

    let full_history: Vec<_> = initial_context
        .iter()
        .cloned()
        .chain(first_turn.iter().cloned())
        .chain(injected_context)
        .chain(second_turn)
        .collect();
    sess.replace_history(full_history.clone(), Some(tc.to_turn_context_item()))
        .await;

    let rollout_items: Vec<RolloutItem> = full_history
        .into_iter()
        .map(RolloutItem::ResponseItem)
        .collect();
    sess.persist_rollout_items(&rollout_items).await;

    handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await;

    let rollback_event = wait_for_thread_rolled_back(&rx).await;
    assert_eq!(rollback_event.num_turns, 1);

    // Only the initial context plus turn 1 survive; the injected context that
    // sat directly above the removed turn must be gone as well.
    let expected: Vec<_> = initial_context.into_iter().chain(first_turn).collect();
    let history = sess.clone_history().await;
    assert_eq!(expected, history.raw_items());
}
#[tokio::test]
async fn thread_rollback_fails_without_persisted_rollout_path() {
let (sess, tc, rx) = make_session_and_context_with_rx().await;

View File

@@ -212,6 +212,9 @@ impl ContextManager {
/// - if there are no user turns, this is a no-op
/// - if `num_turns` exceeds the number of user turns, all user turns are dropped while
/// preserving any items that occurred before the first user message.
/// - when rolling back to an earlier turn, contiguous developer/contextual messages inserted
/// immediately before the removed user turn are also trimmed so stale per-turn context does
/// not survive above the restored boundary.
pub(crate) fn drop_last_n_user_turns(&mut self, num_turns: u32) {
if num_turns == 0 {
return;
@@ -230,6 +233,8 @@ impl ContextManager {
} else {
user_positions[user_positions.len() - n_from_end]
};
let cut_idx =
trim_adjacent_rollback_context_before_user_boundary(&snapshot, first_user_idx, cut_idx);
self.replace(snapshot[..cut_idx].to_vec());
}
@@ -637,6 +642,36 @@ pub(crate) fn is_user_turn_boundary(item: &ResponseItem) -> bool {
role == "user" && !is_contextual_user_message_content(content)
}
fn trim_adjacent_rollback_context_before_user_boundary(
items: &[ResponseItem],
first_user_idx: usize,
cut_idx: usize,
) -> usize {
if cut_idx <= first_user_idx {
return cut_idx;
}
let mut adjusted_cut_idx = cut_idx;
while adjusted_cut_idx > first_user_idx
&& items
.get(adjusted_cut_idx.saturating_sub(1))
.is_some_and(is_adjacent_rollback_context_item)
{
adjusted_cut_idx -= 1;
}
adjusted_cut_idx
}
/// Returns `true` for items that count as synthetic rollback context: any
/// developer message, or a user message whose content has the contextual
/// (injected) shape rather than real user input.
fn is_adjacent_rollback_context_item(item: &ResponseItem) -> bool {
    let ResponseItem::Message { role, content, .. } = item else {
        return false;
    };
    if role == "developer" {
        true
    } else if role == "user" {
        is_contextual_user_message_content(content)
    } else {
        false
    }
}
fn user_message_positions(items: &[ResponseItem]) -> Vec<usize> {
let mut positions = Vec::new();
for (idx, item) in items.iter().enumerate() {

View File

@@ -1,4 +1,5 @@
use super::*;
use crate::contextual_user_message::ENVIRONMENT_CONTEXT_FRAGMENT;
use crate::truncate;
use crate::truncate::TruncationPolicy;
use base64::Engine;
@@ -6,6 +7,7 @@ use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use codex_git::GhostCommit;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
use codex_protocol::models::DeveloperInstructions;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::FunctionCallOutputPayload;
@@ -70,6 +72,10 @@ fn user_input_text_msg(text: &str) -> ResponseItem {
}
}
fn developer_msg(text: &str) -> ResponseItem {
DeveloperInstructions::new(text.to_string()).into()
}
fn custom_tool_call_output(call_id: &str, output: &str) -> ResponseItem {
ResponseItem::CustomToolCallOutput {
call_id: call_id.to_string(),
@@ -794,6 +800,65 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
assert_eq!(history.for_prompt(&modalities), expected_prefix_only);
}
#[test]
fn drop_last_n_user_turns_trims_adjacent_developer_and_contextual_messages() {
    // A developer message and a contextual update injected directly before
    // turn 2 must disappear together with that turn when it is rolled back.
    let cwd_update = ENVIRONMENT_CONTEXT_FRAGMENT
        .into_message(ENVIRONMENT_CONTEXT_FRAGMENT.wrap("<cwd>/tmp/updated</cwd>".to_string()));
    let modalities = default_input_modalities();
    let mut history = create_history_with_items(vec![
        user_input_text_msg("<environment_context>session prefix</environment_context>"),
        user_msg("turn 1 user"),
        assistant_msg("turn 1 assistant"),
        developer_msg("<model_switch>switched to gpt-test</model_switch>"),
        cwd_update,
        user_msg("turn 2 user"),
        assistant_msg("turn 2 assistant"),
    ]);

    history.drop_last_n_user_turns(1);

    let expected = vec![
        user_input_text_msg("<environment_context>session prefix</environment_context>"),
        user_msg("turn 1 user"),
        assistant_msg("turn 1 assistant"),
    ];
    assert_eq!(history.for_prompt(&modalities), expected);
}
#[test]
fn drop_last_n_user_turns_keeps_non_adjacent_synthetic_context() {
    // Only the contiguous run of synthetic items directly before the removed
    // user turn is trimmed; a developer message separated from that turn by a
    // real assistant message must survive the rollback.
    let cwd_update = ENVIRONMENT_CONTEXT_FRAGMENT
        .into_message(ENVIRONMENT_CONTEXT_FRAGMENT.wrap("<cwd>/tmp/updated</cwd>".to_string()));
    let modalities = default_input_modalities();
    let mut history = create_history_with_items(vec![
        user_input_text_msg("<environment_context>session prefix</environment_context>"),
        user_msg("turn 1 user"),
        assistant_msg("turn 1 assistant"),
        developer_msg("<model_switch>switched to gpt-test</model_switch>"),
        assistant_msg("separator"),
        cwd_update,
        user_msg("turn 2 user"),
        assistant_msg("turn 2 assistant"),
    ]);

    history.drop_last_n_user_turns(1);

    let expected = vec![
        user_input_text_msg("<environment_context>session prefix</environment_context>"),
        user_msg("turn 1 user"),
        assistant_msg("turn 1 assistant"),
        developer_msg("<model_switch>switched to gpt-test</model_switch>"),
        assistant_msg("separator"),
    ];
    assert_eq!(history.for_prompt(&modalities), expected);
}
#[test]
fn remove_first_item_handles_custom_tool_pair() {
let items = vec![