Fix small estimator/payload instruction mismatch in remote compaction (#10692)

## Summary
This PR fixes a deterministic mismatch in remote compaction where
pre-trim estimation and the `/v1/responses/compact` payload could use
different base instructions.

Before this change:
- pre-trim estimation used model-derived instructions
(`model_info.get_model_instructions(...)`)
- compact payload used session base instructions
(`sess.get_base_instructions()`)

After this change:
- remote pre-trim estimation and compact payload both use the same
`BaseInstructions` instance from session state.

## Changes
- Added a shared estimator entry point in `ContextManager`:
- `estimate_token_count_with_base_instructions(&self, base_instructions:
&BaseInstructions) -> Option<i64>`
- Kept `estimate_token_count(&TurnContext)` as a thin wrapper that
resolves model/personality instructions and delegates to the new helper.
- Updated remote compaction flow to fetch base instructions once and
reuse it for both:
  - trim preflight estimation
  - compact request payload construction
- Added regression coverage for parity and behavior:
  - unit test verifying explicit-base estimator behavior
  - integration test proving remote compaction uses session override
    instructions and trims accordingly

## Why this matters
This removes a source of deterministic divergence where the pre-trim
estimate could conclude the request fits while the actual compact request
exceeded the context window because its instructions were longer or
different.

## Scope
In scope:
- estimator/payload base-instructions parity in remote compaction

Out of scope:
- retry-on-`context_length_exceeded`
- compaction threshold/headroom policy changes
- broader trimming policy changes

## Codex author
`codex fork 019c2b24-c2df-7b31-a482-fb8cf7a28559`
This commit is contained in:
Charley Cunningham
2026-02-04 23:24:06 -08:00
committed by GitHub
parent cd5f49a619
commit dc7007beaa
5 changed files with 314 additions and 10 deletions

View File

@@ -2117,10 +2117,10 @@ impl Session {
}
pub(crate) async fn recompute_token_usage(&self, turn_context: &TurnContext) {
let Some(estimated_total_tokens) = self
.clone_history()
.await
.estimate_token_count(turn_context)
let history = self.clone_history().await;
let base_instructions = self.get_base_instructions().await;
let Some(estimated_total_tokens) =
history.estimate_token_count_with_base_instructions(&base_instructions)
else {
return;
};
@@ -4782,6 +4782,7 @@ mod tests {
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_app_server_protocol::AppInfo;
use codex_app_server_protocol::AuthMode;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use std::path::Path;
@@ -5061,6 +5062,46 @@ mod tests {
assert_eq!(actual, Some(info2));
}
#[tokio::test]
async fn recompute_token_usage_uses_session_base_instructions() {
let (session, turn_context) = make_session_and_context().await;
let override_instructions = "SESSION_OVERRIDE_INSTRUCTIONS_ONLY".repeat(120);
{
let mut state = session.state.lock().await;
state.session_configuration.base_instructions = override_instructions.clone();
}
let item = user_message("hello");
session
.record_into_history(std::slice::from_ref(&item), &turn_context)
.await;
let history = session.clone_history().await;
let session_base_instructions = BaseInstructions {
text: override_instructions,
};
let expected_tokens = history
.estimate_token_count_with_base_instructions(&session_base_instructions)
.expect("estimate with session base instructions");
let model_estimated_tokens = history
.estimate_token_count(&turn_context)
.expect("estimate with model instructions");
assert_ne!(expected_tokens, model_estimated_tokens);
session.recompute_token_usage(&turn_context).await;
let actual_tokens = session
.state
.lock()
.await
.token_info()
.expect("token info")
.last_token_usage
.total_tokens;
assert_eq!(actual_tokens, expected_tokens.max(0));
}
#[tokio::test]
async fn record_initial_history_reconstructs_forked_transcript() {
let (session, turn_context) = make_session_and_context().await;

View File

@@ -12,6 +12,7 @@ use crate::protocol::RolloutItem;
use crate::protocol::TurnStartedEvent;
use codex_protocol::items::ContextCompactionItem;
use codex_protocol::items::TurnItem;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ResponseItem;
use tracing::info;
@@ -49,8 +50,12 @@ async fn run_remote_compact_task_inner_impl(
sess.emit_turn_item_started(turn_context, &compaction_item)
.await;
let mut history = sess.clone_history().await;
let deleted_items =
trim_function_call_history_to_fit_context_window(&mut history, turn_context.as_ref());
let base_instructions = sess.get_base_instructions().await;
let deleted_items = trim_function_call_history_to_fit_context_window(
&mut history,
turn_context.as_ref(),
&base_instructions,
);
if deleted_items > 0 {
info!(
turn_id = %turn_context.sub_id,
@@ -71,7 +76,7 @@ async fn run_remote_compact_task_inner_impl(
input: history.for_prompt(),
tools: vec![],
parallel_tool_calls: false,
base_instructions: sess.get_base_instructions().await,
base_instructions,
personality: turn_context.personality,
output_schema: None,
};
@@ -107,6 +112,7 @@ async fn run_remote_compact_task_inner_impl(
fn trim_function_call_history_to_fit_context_window(
history: &mut ContextManager,
turn_context: &TurnContext,
base_instructions: &BaseInstructions,
) -> usize {
let mut deleted_items = 0usize;
let Some(context_window) = turn_context.model_context_window() else {
@@ -114,7 +120,7 @@ fn trim_function_call_history_to_fit_context_window(
};
while history
.estimate_token_count(turn_context)
.estimate_token_count_with_base_instructions(base_instructions)
.is_some_and(|estimated_tokens| estimated_tokens > context_window)
{
let Some(last_item) = history.raw_items().last() else {

View File

@@ -9,6 +9,7 @@ use crate::truncate::approx_tokens_from_byte_count;
use crate::truncate::truncate_function_output_items_with_policy;
use crate::truncate::truncate_text;
use crate::user_shell_command::is_user_shell_command_text;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputContentItem;
@@ -88,8 +89,18 @@ impl ContextManager {
pub(crate) fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
let model_info = &turn_context.model_info;
let personality = turn_context.personality.or(turn_context.config.personality);
let base_instructions = model_info.get_model_instructions(personality);
let base_tokens = i64::try_from(approx_token_count(&base_instructions)).unwrap_or(i64::MAX);
let base_instructions = BaseInstructions {
text: model_info.get_model_instructions(personality),
};
self.estimate_token_count_with_base_instructions(&base_instructions)
}
pub(crate) fn estimate_token_count_with_base_instructions(
&self,
base_instructions: &BaseInstructions,
) -> Option<i64> {
let base_tokens =
i64::try_from(approx_token_count(&base_instructions.text)).unwrap_or(i64::MAX);
let items_tokens = self.items.iter().fold(0i64, |acc, item| {
acc.saturating_add(estimate_item_token_count(item))

View File

@@ -2,6 +2,7 @@ use super::*;
use crate::truncate;
use crate::truncate::TruncationPolicy;
use codex_git::GhostCommit;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputContentItem;
@@ -103,6 +104,10 @@ fn truncate_exec_output(content: &str) -> String {
truncate::truncate_text(content, TruncationPolicy::Tokens(EXEC_FORMAT_MAX_TOKENS))
}
fn approx_token_count_for_text(text: &str) -> i64 {
i64::try_from(text.len().saturating_add(3) / 4).unwrap_or(i64::MAX)
}
#[test]
fn filters_non_api_messages() {
let mut h = ContextManager::default();
@@ -250,6 +255,28 @@ fn get_history_for_prompt_drops_ghost_commits() {
assert_eq!(filtered, vec![]);
}
#[test]
fn estimate_token_count_with_base_instructions_uses_provided_text() {
let history = create_history_with_items(vec![assistant_msg("hello from history")]);
let short_base = BaseInstructions {
text: "short".to_string(),
};
let long_base = BaseInstructions {
text: "x".repeat(1_000),
};
let short_estimate = history
.estimate_token_count_with_base_instructions(&short_base)
.expect("token estimate");
let long_estimate = history
.estimate_token_count_with_base_instructions(&long_base)
.expect("token estimate");
let expected_delta = approx_token_count_for_text(&long_base.text)
- approx_token_count_for_text(&short_base.text);
assert_eq!(long_estimate - short_estimate, expected_delta);
}
#[test]
fn remove_first_item_removes_matching_output_for_function_call() {
let items = vec![

View File

@@ -25,6 +25,21 @@ use core_test_support::wait_for_event;
use core_test_support::wait_for_event_match;
use pretty_assertions::assert_eq;
fn approx_token_count(text: &str) -> i64 {
i64::try_from(text.len().saturating_add(3) / 4).unwrap_or(i64::MAX)
}
fn estimate_compact_input_tokens(request: &responses::ResponsesRequest) -> i64 {
request.input().into_iter().fold(0i64, |acc, item| {
acc.saturating_add(approx_token_count(&item.to_string()))
})
}
fn estimate_compact_payload_tokens(request: &responses::ResponsesRequest) -> i64 {
estimate_compact_input_tokens(request)
.saturating_add(approx_token_count(&request.instructions_text()))
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_compact_replaces_history_for_followups() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -351,6 +366,210 @@ async fn remote_compact_trims_function_call_history_to_fit_context_window() -> R
Ok(())
}
#[cfg_attr(target_os = "windows", ignore)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_compact_trim_estimate_uses_session_base_instructions() -> Result<()> {
skip_if_no_network!(Ok(()));
let first_user_message = "turn with baseline shell call";
let second_user_message = "turn with trailing shell call";
let baseline_retained_call_id = "baseline-retained-call";
let baseline_trailing_call_id = "baseline-trailing-call";
let override_retained_call_id = "override-retained-call";
let override_trailing_call_id = "override-trailing-call";
let retained_command = "printf retained-shell-output";
let trailing_command = "printf trailing-shell-output";
let baseline_harness = TestCodexHarness::with_builder(
test_codex()
.with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
.with_config(|config| {
config.features.enable(Feature::RemoteCompaction);
config.model_context_window = Some(200_000);
}),
)
.await?;
let baseline_codex = baseline_harness.test().codex.clone();
responses::mount_sse_sequence(
baseline_harness.server(),
vec![
sse(vec![
responses::ev_shell_command_call(baseline_retained_call_id, retained_command),
responses::ev_completed("baseline-retained-call-response"),
]),
sse(vec![
responses::ev_assistant_message("baseline-retained-assistant", "retained complete"),
responses::ev_completed("baseline-retained-final-response"),
]),
sse(vec![
responses::ev_shell_command_call(baseline_trailing_call_id, trailing_command),
responses::ev_completed("baseline-trailing-call-response"),
]),
sse(vec![responses::ev_completed(
"baseline-trailing-final-response",
)]),
],
)
.await;
baseline_codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: first_user_message.into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
})
.await?;
wait_for_event(&baseline_codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
baseline_codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: second_user_message.into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
})
.await?;
wait_for_event(&baseline_codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
let baseline_compact_mock = responses::mount_compact_json_once(
baseline_harness.server(),
serde_json::json!({ "output": [] }),
)
.await;
baseline_codex.submit(Op::Compact).await?;
wait_for_event(&baseline_codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
let baseline_compact_request = baseline_compact_mock.single_request();
assert!(
baseline_compact_request.has_function_call(baseline_retained_call_id),
"expected baseline compact request to retain older function call history"
);
assert!(
baseline_compact_request.has_function_call(baseline_trailing_call_id),
"expected baseline compact request to retain trailing function call history"
);
let baseline_input_tokens = estimate_compact_input_tokens(&baseline_compact_request);
let baseline_payload_tokens = estimate_compact_payload_tokens(&baseline_compact_request);
let override_base_instructions =
format!("REMOTE_BASE_INSTRUCTIONS_OVERRIDE {}", "x".repeat(120_000));
let override_context_window = baseline_payload_tokens.saturating_add(1_000);
let pretrim_override_estimate =
baseline_input_tokens.saturating_add(approx_token_count(&override_base_instructions));
assert!(
pretrim_override_estimate > override_context_window,
"expected override instructions to push pre-trim estimate past the context window"
);
let override_harness = TestCodexHarness::with_builder(
test_codex()
.with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
.with_config({
let override_base_instructions = override_base_instructions.clone();
move |config| {
config.features.enable(Feature::RemoteCompaction);
config.model_context_window = Some(override_context_window);
config.base_instructions = Some(override_base_instructions);
}
}),
)
.await?;
let override_codex = override_harness.test().codex.clone();
responses::mount_sse_sequence(
override_harness.server(),
vec![
sse(vec![
responses::ev_shell_command_call(override_retained_call_id, retained_command),
responses::ev_completed("override-retained-call-response"),
]),
sse(vec![
responses::ev_assistant_message("override-retained-assistant", "retained complete"),
responses::ev_completed("override-retained-final-response"),
]),
sse(vec![
responses::ev_shell_command_call(override_trailing_call_id, trailing_command),
responses::ev_completed("override-trailing-call-response"),
]),
sse(vec![responses::ev_completed(
"override-trailing-final-response",
)]),
],
)
.await;
override_codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: first_user_message.into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
})
.await?;
wait_for_event(&override_codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
override_codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: second_user_message.into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
})
.await?;
wait_for_event(&override_codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
let override_compact_mock = responses::mount_compact_json_once(
override_harness.server(),
serde_json::json!({ "output": [] }),
)
.await;
override_codex.submit(Op::Compact).await?;
wait_for_event(&override_codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
let override_compact_request = override_compact_mock.single_request();
assert_eq!(
override_compact_request.instructions_text(),
override_base_instructions
);
assert!(
override_compact_request.has_function_call(override_retained_call_id),
"expected remote compact request to preserve older function call history"
);
assert!(
!override_compact_request.has_function_call(override_trailing_call_id),
"expected remote compact request to trim trailing function call history with override instructions"
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_manual_compact_emits_context_compaction_items() -> Result<()> {
skip_if_no_network!(Ok(()));