Compare commits

...

1 Commits

Author SHA1 Message Date
Roy Han
48cada6438 roll back context-overflow turns 2026-05-20 15:35:05 -07:00
4 changed files with 465 additions and 71 deletions

View File

@@ -12,6 +12,7 @@ use tracing::info_span;
use crate::session::SteerInputError;
use crate::session::session::Session;
use crate::session::session::SessionSettingsUpdate;
use crate::session::turn_context::TurnContext;
use crate::config::Config;
use crate::realtime_context::REALTIME_TURN_TOKEN_BUDGET;
@@ -499,61 +500,23 @@ pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32
}
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
let live_thread = match sess.live_thread_for_persistence("rollback thread") {
Ok(live_thread) => live_thread,
Err(_) => {
sess.send_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: "thread rollback requires persisted thread history".to_string(),
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
}),
})
.await;
return;
}
};
if let Err(err) = live_thread.flush().await {
sess.send_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: format!("failed to flush thread persistence for rollback replay: {err}"),
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
}),
})
.await;
return;
}
let (rollback_msg, flush_error) =
match apply_thread_rollback(sess, &turn_context, num_turns).await {
Ok(outcome) => outcome,
Err(message) => {
sess.send_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message,
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
}),
})
.await;
return;
}
};
let stored_history = match live_thread.load_history(/*include_archived*/ false).await {
Ok(history) => history,
Err(err) => {
sess.send_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: format!("failed to load thread history for rollback replay: {err}"),
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
}),
})
.await;
return;
}
};
let rollback_event = ThreadRolledBackEvent { num_turns };
let rollback_msg = EventMsg::ThreadRolledBack(rollback_event.clone());
let replay_items = stored_history
.items
.into_iter()
.chain(std::iter::once(RolloutItem::EventMsg(rollback_msg.clone())))
.collect::<Vec<_>>();
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
.await;
sess.recompute_token_usage(turn_context.as_ref()).await;
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
.await;
if let Err(err) = sess.flush_rollout().await {
if let Some(err) = flush_error {
sess.send_event(
turn_context.as_ref(),
EventMsg::Warning(WarningEvent {
@@ -572,6 +535,40 @@ pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32
.await;
}
/// Rolls the session's thread back by `num_turns` turns by replaying persisted history.
///
/// The steps run in this order: flush the live thread, load its stored
/// (non-archived) history, replay that history followed by a
/// `ThreadRolledBack` marker through `apply_rollout_reconstruction`,
/// recompute token usage, persist the marker, and flush the rollout.
///
/// This function does not emit the rollback event itself; the marker is
/// handed back so the caller can decide how to deliver it.
///
/// # Returns
/// `Ok((marker, flush_error))`, where `marker` is the
/// `EventMsg::ThreadRolledBack` message that was replayed and persisted, and
/// `flush_error` is the stringified error from the final `flush_rollout`
/// call, or `None` when that flush succeeded.
///
/// # Errors
/// Returns a human-readable message when the session has no live thread for
/// persistence, when flushing the live thread fails, or when the stored
/// history cannot be loaded.
pub(super) async fn apply_thread_rollback(
    sess: &Arc<Session>,
    turn_context: &Arc<TurnContext>,
    num_turns: u32,
) -> Result<(EventMsg, Option<String>), String> {
    let Ok(live_thread) = sess.live_thread_for_persistence("rollback thread") else {
        return Err("thread rollback requires persisted thread history".to_string());
    };

    // Flush first so the history loaded below reflects everything recorded so far.
    if let Err(err) = live_thread.flush().await {
        return Err(format!(
            "failed to flush thread persistence for rollback replay: {err}"
        ));
    }

    let stored_history = match live_thread.load_history(/*include_archived*/ false).await {
        Ok(history) => history,
        Err(err) => {
            return Err(format!(
                "failed to load thread history for rollback replay: {err}"
            ));
        }
    };

    let rollback_msg = EventMsg::ThreadRolledBack(ThreadRolledBackEvent { num_turns });

    // Replay the stored items with the rollback marker appended as the last item.
    let mut replay_items: Vec<_> = stored_history.items.into_iter().collect();
    replay_items.push(RolloutItem::EventMsg(rollback_msg.clone()));
    sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
        .await;
    sess.recompute_token_usage(turn_context.as_ref()).await;

    let marker = [RolloutItem::EventMsg(rollback_msg.clone())];
    sess.persist_rollout_items(&marker).await;

    // A failed final flush is reported to the caller rather than treated as a
    // rollback failure: the in-memory state has already been rebuilt above.
    let flush_error = match sess.flush_rollout().await {
        Ok(_) => None,
        Err(err) => Some(err.to_string()),
    };
    Ok((rollback_msg, flush_error))
}
pub(super) async fn persist_thread_memory_mode_update(
sess: &Arc<Session>,
mode: ThreadMemoryMode,

View File

@@ -37,6 +37,7 @@ use crate::mentions::collect_tool_mentions_from_messages;
use crate::plugins::build_plugin_injections;
use crate::session::PreviousTurnSettings;
use crate::session::TurnInput;
use crate::session::handlers::apply_thread_rollback;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::stream_events_utils::HandleOutputCtx;
@@ -87,6 +88,7 @@ use codex_protocol::protocol::AgentMessageContentDeltaEvent;
use codex_protocol::protocol::AgentReasoningSectionBreakEvent;
use codex_protocol::protocol::CodexErrorInfo;
use codex_protocol::protocol::ErrorEvent;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::PlanDeltaEvent;
use codex_protocol::protocol::ReasoningContentDeltaEvent;
@@ -138,6 +140,7 @@ pub(crate) async fn run_turn(
) -> Option<String> {
let mut client_session =
prewarmed_client_session.unwrap_or_else(|| sess.services.model_client.new_session());
let mut initial_user_input_recorded = false;
// TODO(ccunningham): Pre-turn compaction runs before context updates and the
// new user message are recorded. Estimate pending incoming items (context
// diffs/full reinjection + user input) and trigger compaction preemptively
@@ -192,6 +195,7 @@ pub(crate) async fn run_turn(
user_prompt_submit_outcome.additional_contexts,
)
.await;
initial_user_input_recorded = true;
}
sess.merge_connector_selection(explicitly_enabled_connectors.clone())
@@ -337,6 +341,19 @@ pub(crate) async fn run_turn(
{
Ok(reset_client_session) => reset_client_session,
Err(err) => {
if matches!(&err, CodexErr::ContextWindowExceeded) {
match rollback_recorded_initial_turn_for_context_window(
&sess,
&turn_context,
initial_user_input_recorded,
)
.await
{
ContextWindowRollbackResult::RolledBack => return None,
ContextWindowRollbackResult::NotAttempted
| ContextWindowRollbackResult::Failed => {}
}
}
if err.to_codex_protocol_error() == CodexErrorInfo::UsageLimitExceeded
&& let Err(err) = sess
.goal_runtime_apply(GoalRuntimeEvent::UsageLimitReached {
@@ -409,6 +426,32 @@ pub(crate) async fn run_turn(
// Aborted turn is reported via a different event.
break;
}
Err(CodexErr::ContextWindowExceeded) => {
match rollback_recorded_initial_turn_for_context_window(
&sess,
&turn_context,
initial_user_input_recorded,
)
.await
{
ContextWindowRollbackResult::RolledBack => {}
ContextWindowRollbackResult::NotAttempted => {
let event = EventMsg::Error(
CodexErr::ContextWindowExceeded
.to_error_event(/*message_prefix*/ None),
);
sess.send_event(&turn_context, event).await;
}
ContextWindowRollbackResult::Failed => {
let event = EventMsg::Error(
CodexErr::ContextWindowExceeded
.to_error_event(/*message_prefix*/ None),
);
sess.send_event(&turn_context, event).await;
}
}
break;
}
Err(CodexErr::InvalidImageRequest()) => {
{
let mut state = sess.state.lock().await;
@@ -450,6 +493,73 @@ pub(crate) async fn run_turn(
last_agent_message
}
/// Outcome of `rollback_recorded_initial_turn_for_context_window`.
enum ContextWindowRollbackResult {
/// No rollback was tried because the initial user input had not been
/// recorded; no event is emitted in this case.
NotAttempted,
/// `apply_thread_rollback` succeeded; the rollback marker and a
/// `ContextWindowExceeded` error event have already been sent.
RolledBack,
/// `apply_thread_rollback` returned an error; a `ThreadRollbackFailed`
/// error event has already been sent.
Failed,
}
async fn rollback_recorded_initial_turn_for_context_window(
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
initial_user_input_recorded: bool,
) -> ContextWindowRollbackResult {
if !initial_user_input_recorded {
return ContextWindowRollbackResult::NotAttempted;
}
warn!(
turn_id = %turn_context.sub_id,
"context window exceeded after initial user input was recorded; rolling back current turn"
);
match apply_thread_rollback(sess, turn_context, /*num_turns*/ 1).await {
Ok((rollback_msg, flush_error)) => {
if let Some(err) = flush_error {
warn!(
"rolled back context-window turn in memory, but failed to flush rollback marker: {err}"
);
sess.send_event(
turn_context.as_ref(),
EventMsg::Warning(WarningEvent {
message: format!(
"Rolled the thread back, but failed to save the rollback marker. Codex may retry this failure after resume. Error: {err}"
),
}),
)
.await;
}
sess.deliver_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: rollback_msg,
})
.await;
sess.send_event(
turn_context.as_ref(),
EventMsg::Error(ErrorEvent {
message: "This turn exceeded the model context window and was rolled back. Split the message into smaller chunks, or save the content to a file and ask Codex to inspect it.".to_string(),
codex_error_info: Some(CodexErrorInfo::ContextWindowExceeded),
}),
)
.await;
ContextWindowRollbackResult::RolledBack
}
Err(err) => {
warn!("failed to rollback context-window turn: {err}");
sess.send_event(
turn_context.as_ref(),
EventMsg::Error(ErrorEvent {
message: format!(
"This turn exceeded the model context window, but Codex could not roll it back automatically: {err}"
),
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
}),
)
.await;
ContextWindowRollbackResult::Failed
}
}
}
#[expect(
clippy::await_holding_invalid_type,
reason = "MCP tool listing borrows the read guard across cancellation-aware await"

View File

@@ -25,7 +25,6 @@ use codex_protocol::config_types::ModelProviderAuthInfo;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::Settings;
use codex_protocol::config_types::Verbosity;
use codex_protocol::error::CodexErr;
use codex_protocol::models::ContentItem;
use codex_protocol::models::DEFAULT_IMAGE_DETAIL;
use codex_protocol::models::FunctionCallOutputContentItem;
@@ -40,6 +39,7 @@ use codex_protocol::models::ReasoningItemReasoningSummary;
use codex_protocol::models::ResponseItem;
use codex_protocol::models::WebSearchAction;
use codex_protocol::openai_models::ReasoningEffort;
use codex_protocol::protocol::CodexErrorInfo;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::RolloutItem;
@@ -2699,13 +2699,18 @@ async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Res
)
.await;
let TestCodex { codex, .. } = test_codex()
.with_config(|config| {
config.model = Some("gpt-5.4".to_string());
config.model_context_window = Some(272_000);
})
.build(&server)
.await?;
let mut builder = test_codex().with_config(|config| {
config.model = Some("gpt-5.4".to_string());
config.model_context_window = Some(272_000);
});
let test = builder.build(&server).await?;
let codex = Arc::clone(&test.codex);
let home = Arc::clone(&test.home);
let rollout_path = test
.session_configured
.rollout_path
.clone()
.expect("rollout path");
codex
.submit(Op::UserInput {
@@ -2761,18 +2766,138 @@ async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Res
EFFECTIVE_CONTEXT_WINDOW
);
let rollback_event =
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ThreadRolledBack(_))).await;
assert!(
matches!(
rollback_event,
EventMsg::ThreadRolledBack(ref rollback) if rollback.num_turns == 1
),
"expected automatic rollback after context window error; got {rollback_event:?}"
);
let error_event = wait_for_event(&codex, |ev| matches!(ev, EventMsg::Error(_))).await;
let expected_context_window_message = CodexErr::ContextWindowExceeded.to_string();
assert!(
matches!(
error_event,
EventMsg::Error(ref err) if err.message == expected_context_window_message
EventMsg::Error(ref err)
if err.codex_error_info == Some(CodexErrorInfo::ContextWindowExceeded)
&& err.message.contains("was rolled back")
),
"expected context window error; got {error_event:?}"
"expected context window rollback error; got {error_event:?}"
);
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
let follow_up_prompt = "after rollback";
let follow_up_mock = mount_sse_once_match(
&server,
body_string_contains(follow_up_prompt),
sse(vec![
ev_response_created("resp_after_rollback"),
ev_completed("resp_after_rollback"),
]),
)
.await;
let mut resume_builder = test_codex().with_config(|config| {
config.model = Some("gpt-5.4".to_string());
config.model_context_window = Some(272_000);
});
let resumed = resume_builder.resume(&server, home, rollout_path).await?;
resumed
.codex
.submit(Op::UserInput {
environments: None,
items: vec![UserInput::Text {
text: follow_up_prompt.into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
responsesapi_client_metadata: None,
thread_settings: Default::default(),
})
.await?;
wait_for_event(&resumed.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
let follow_up_request = follow_up_mock.single_request().body_json().to_string();
assert!(
!follow_up_request.contains("trigger context window"),
"rolled-back user input should not be replayed after resume: {follow_up_request}"
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn context_window_error_reports_rollback_failure_when_thread_is_not_persisted()
-> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
mount_sse_once_match(
&server,
body_string_contains("trigger context window"),
sse_failed(
"resp_context_window",
"context_length_exceeded",
"Your input exceeds the context window of this model. Please adjust your input and try again.",
),
)
.await;
let mut builder = test_codex().with_config(|config| {
config.ephemeral = true;
config.model = Some("gpt-5.4".to_string());
config.model_context_window = Some(272_000);
});
let test = builder.build(&server).await?;
test.codex
.submit(Op::UserInput {
environments: None,
items: vec![UserInput::Text {
text: "trigger context window".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
responsesapi_client_metadata: None,
thread_settings: Default::default(),
})
.await?;
let mut saw_repair_failure = false;
let mut saw_context_window_error = false;
let mut saw_rollback = false;
loop {
match wait_for_event(&test.codex, |_| true).await {
EventMsg::Error(err)
if err.codex_error_info == Some(CodexErrorInfo::ThreadRollbackFailed) =>
{
saw_repair_failure = true;
}
EventMsg::Error(err)
if err.codex_error_info == Some(CodexErrorInfo::ContextWindowExceeded) =>
{
saw_context_window_error = true;
}
EventMsg::ThreadRolledBack(_) => saw_rollback = true,
EventMsg::TurnComplete(_) => break,
_ => {}
}
}
assert!(
saw_repair_failure,
"expected repair failure error when automatic rollback cannot load persisted history"
);
assert!(
saw_context_window_error,
"expected original context-window error after rollback failure"
);
assert!(
!saw_rollback,
"ephemeral thread should not emit a rollback marker"
);
Ok(())
}

View File

@@ -3062,6 +3062,159 @@ async fn snapshot_request_shape_mid_turn_continuation_compaction() {
);
}
/// End-to-end check that a context-window failure during mid-turn auto
/// compaction rolls back the already-recorded turn, and that the rolled-back
/// user input is not sent again after the thread is resumed.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mid_turn_compaction_context_window_exceeded_rolls_back_recorded_turn() {
skip_if_no_network!();
let server = start_mock_server().await;
// Tiny window so the first response's reported token count (96) lands above
// the configured auto-compact limit (90).
let context_window = 100;
let limit = context_window * 90 / 100;
let over_limit_tokens = context_window * 95 / 100 + 1;
// First response: a function call, completed with the over-limit token count.
let first_turn = sse(vec![
ev_function_call(DUMMY_CALL_ID, DUMMY_FUNCTION_NAME, "{}"),
ev_completed_with_tokens("r1", over_limit_tokens),
]);
// Every later response fails with a context-length error.
let auto_compact_failure = || {
sse_failed(
"compact-failed",
"context_length_exceeded",
CONTEXT_LIMIT_MESSAGE,
)
};
// Five failures are queued after the first turn.
// NOTE(review): presumably headroom for repeated compaction attempts — confirm how many are actually consumed.
let request_log = mount_sse_sequence(
&server,
vec![
first_turn,
auto_compact_failure(),
auto_compact_failure(),
auto_compact_failure(),
auto_compact_failure(),
auto_compact_failure(),
],
)
.await;
let mut model_provider = non_openai_model_provider(&server);
// NOTE(review): presumably disables stream-level retries so failures surface immediately — confirm.
model_provider.stream_max_retries = Some(0);
let mut builder = test_codex().with_config(move |config| {
config.model_provider = model_provider;
set_test_compact_prompt(config);
config.model_context_window = Some(context_window);
config.model_auto_compact_token_limit = Some(limit);
});
let initial = builder.build(&server).await.unwrap();
// Keep the home directory and rollout path so the thread can be resumed below.
let home = initial.home.clone();
let rollout_path = initial
.session_configured
.rollout_path
.clone()
.expect("rollout path");
initial
.codex
.submit(Op::UserInput {
environments: None,
items: vec![UserInput::Text {
text: FUNCTION_CALL_LIMIT_MSG.into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
responsesapi_client_metadata: None,
thread_settings: Default::default(),
})
.await
.unwrap();
// Drain all events until the turn completes, collecting error messages and
// checking that any rollback marker covers exactly one turn.
let mut saw_rollback = false;
let mut error_messages = Vec::new();
loop {
match wait_for_event(&initial.codex, |_| true).await {
EventMsg::Error(err) => error_messages.push(err.message),
EventMsg::ThreadRolledBack(rollback) => {
assert_eq!(rollback.num_turns, 1);
saw_rollback = true;
}
EventMsg::TurnComplete(_) => break,
_ => {}
}
}
assert!(
saw_rollback,
"expected automatic rollback after failed mid-turn compaction, got errors: {error_messages:?}"
);
assert!(
error_messages
.iter()
.any(|message| message.contains("was rolled back")),
"expected rollback recovery message, got {error_messages:?}"
);
// Request 0 is the original user turn; request 1 is the auto-compact attempt.
let requests = request_log.requests();
let first_request = requests[0].body_json().to_string();
assert!(
first_request.contains(FUNCTION_CALL_LIMIT_MSG),
"initial request should include the user input that later gets rolled back"
);
let auto_compact_body = requests[1].body_json().to_string();
assert!(
body_contains_text(&auto_compact_body, SUMMARIZATION_PROMPT),
"mid-turn auto compact request should include the summarization prompt"
);
// Mount a successful response that matches only requests containing the follow-up prompt.
let follow_up_prompt = "AFTER_ROLLBACK";
let follow_up_mock = mount_sse_once_match(
&server,
move |req: &wiremock::Request| {
std::str::from_utf8(&req.body).is_ok_and(|body| body.contains(follow_up_prompt))
},
sse(vec![
ev_assistant_message("m_after_rollback", FINAL_REPLY),
ev_completed("r_after_rollback"),
]),
)
.await;
// Resume the same rollout with a much larger context window and compaction limit.
let resume_model_provider = non_openai_model_provider(&server);
let mut resume_builder = test_codex().with_config(move |config| {
config.model_provider = resume_model_provider;
set_test_compact_prompt(config);
config.model_context_window = Some(200_000);
config.model_auto_compact_token_limit = Some(180_000);
});
let resumed = resume_builder
.resume(&server, home, rollout_path)
.await
.unwrap();
resumed
.codex
.submit(disabled_permission_user_turn(
follow_up_prompt,
resumed.cwd.path().to_path_buf(),
resumed.session_configured.model.clone(),
))
.await
.unwrap();
wait_for_event(&resumed.codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
// The resumed request must carry the follow-up prompt but not the rolled-back input.
let follow_up_requests = follow_up_mock.requests();
let follow_up_request = follow_up_requests
.iter()
.find(|request| request.body_json().to_string().contains(follow_up_prompt))
.unwrap_or_else(|| panic!("expected resumed request containing {follow_up_prompt}"))
.body_json()
.to_string();
assert!(
!follow_up_request.contains(FUNCTION_CALL_LIMIT_MSG),
"rolled-back user input should not be replayed after resume: {follow_up_request}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn auto_compact_clamps_config_limit_to_context_window() {
skip_if_no_network!();
@@ -3855,12 +4008,16 @@ async fn snapshot_request_shape_pre_turn_compaction_context_window_exceeded() {
})
.await
.expect("submit second user");
let error_message = wait_for_event_match(&codex, |event| match event {
EventMsg::Error(err) => Some(err.message.clone()),
_ => None,
})
.await;
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
let mut error_message = None;
let mut saw_rollback = false;
loop {
match wait_for_event(&codex, |_| true).await {
EventMsg::Error(err) => error_message = Some(err.message),
EventMsg::ThreadRolledBack(_) => saw_rollback = true,
EventMsg::TurnComplete(_) => break,
_ => {}
}
}
let requests = request_log.requests();
assert!(
@@ -3879,6 +4036,11 @@ async fn snapshot_request_shape_pre_turn_compaction_context_window_exceeded() {
)
);
let error_message = error_message.expect("expected context-window error");
assert!(
!saw_rollback,
"pre-turn compaction failure should not roll back incoming input that was not recorded"
);
assert!(
error_message.contains("ran out of room in the model's context window"),
"expected context window exceeded message, got {error_message}"