core: estimate post-usage tokens from all added items

This commit is contained in:
Charles Cunningham
2026-02-06 23:14:03 -08:00
parent 055c82c98b
commit 027f4318dd
2 changed files with 104 additions and 70 deletions

View File

@@ -26,6 +26,8 @@ pub(crate) struct ContextManager {
/// The oldest items are at the beginning of the vector.
items: Vec<ResponseItem>,
token_info: Option<TokenUsageInfo>,
/// Number of history items present when `last_token_usage` was last updated.
last_usage_items_len: usize,
}
#[derive(Debug, Clone, Copy, Default)]
@@ -41,6 +43,7 @@ impl ContextManager {
Self {
items: Vec::new(),
token_info: TokenUsageInfo::new_or_append(&None, &None, None),
last_usage_items_len: 0,
}
}
@@ -49,7 +52,9 @@ impl ContextManager {
}
pub(crate) fn set_token_info(&mut self, info: Option<TokenUsageInfo>) {
let has_info = info.is_some();
self.token_info = info;
self.last_usage_items_len = if has_info { self.items.len() } else { 0 };
}
pub(crate) fn set_token_usage_full(&mut self, context_window: i64) {
@@ -59,6 +64,7 @@ impl ContextManager {
self.token_info = Some(TokenUsageInfo::full_context_window(context_window));
}
}
self.last_usage_items_len = self.items.len();
}
/// `items` is ordered from oldest to newest.
@@ -216,6 +222,7 @@ impl ContextManager {
&Some(usage.clone()),
model_context_window,
);
self.last_usage_items_len = self.items.len();
}
fn get_non_last_reasoning_items_tokens(&self) -> i64 {
@@ -245,30 +252,29 @@ impl ContextManager {
})
}
fn get_trailing_codex_generated_items_tokens(&self) -> i64 {
let mut total = 0i64;
for item in self.items.iter().rev() {
if !is_codex_generated_item(item) {
break;
}
total = total.saturating_add(estimate_item_token_count(item));
}
total
fn items_added_since_last_usage(&self) -> &[ResponseItem] {
let start = self.last_usage_items_len.min(self.items.len());
&self.items[start..]
}
fn get_trailing_codex_generated_items_bytes(&self) -> usize {
let mut total = 0usize;
for item in self.items.iter().rev() {
if !is_codex_generated_item(item) {
break;
}
total = total.saturating_add(
serde_json::to_vec(item)
.map(|bytes| bytes.len())
.unwrap_or_default(),
);
}
total
fn get_items_added_since_last_usage_tokens(&self) -> i64 {
self.items_added_since_last_usage()
.iter()
.fold(0i64, |acc, item| {
acc.saturating_add(estimate_item_token_count(item))
})
}
fn get_items_added_since_last_usage_bytes(&self) -> usize {
self.items_added_since_last_usage()
.iter()
.fold(0usize, |acc, item| {
acc.saturating_add(
serde_json::to_vec(item)
.map(|bytes| bytes.len())
.unwrap_or_default(),
)
})
}
/// When true, the server already accounted for past reasoning tokens and
@@ -279,13 +285,13 @@ impl ContextManager {
.as_ref()
.map(|info| info.last_token_usage.total_tokens)
.unwrap_or(0);
let trailing_codex_generated_tokens = self.get_trailing_codex_generated_items_tokens();
let items_added_since_last_usage_tokens = self.get_items_added_since_last_usage_tokens();
if server_reasoning_included {
last_tokens.saturating_add(trailing_codex_generated_tokens)
last_tokens.saturating_add(items_added_since_last_usage_tokens)
} else {
last_tokens
.saturating_add(self.get_non_last_reasoning_items_tokens())
.saturating_add(trailing_codex_generated_tokens)
.saturating_add(items_added_since_last_usage_tokens)
}
}
@@ -306,9 +312,9 @@ impl ContextManager {
.unwrap_or(usize::MAX)
},
estimated_tokens_of_items_added_since_last_successful_api_response: self
.get_trailing_codex_generated_items_tokens(),
.get_items_added_since_last_usage_tokens(),
estimated_bytes_of_items_added_since_last_successful_api_response: self
.get_trailing_codex_generated_items_bytes(),
.get_items_added_since_last_usage_bytes(),
}
}

View File

@@ -62,13 +62,6 @@ fn user_input_text_msg(text: &str) -> ResponseItem {
}
}
fn function_call_output(call_id: &str, content: &str) -> ResponseItem {
ResponseItem::FunctionCallOutput {
call_id: call_id.to_string(),
output: FunctionCallOutputPayload::from_text(content.to_string()),
}
}
fn custom_tool_call_output(call_id: &str, output: &str) -> ResponseItem {
ResponseItem::CustomToolCallOutput {
call_id: call_id.to_string(),
@@ -189,48 +182,51 @@ fn non_last_reasoning_tokens_ignore_entries_after_last_user() {
}
#[test]
fn trailing_codex_generated_tokens_stop_at_first_non_generated_item() {
let earlier_output = function_call_output("call-earlier", "earlier output");
let trailing_function_output = function_call_output("call-tail-1", "tail function output");
let trailing_custom_output = custom_tool_call_output("call-tail-2", "tail custom output");
let history = create_history_with_items(vec![
earlier_output,
user_msg("boundary item"),
trailing_function_output.clone(),
trailing_custom_output.clone(),
]);
let expected_tokens = estimate_item_token_count(&trailing_function_output)
.saturating_add(estimate_item_token_count(&trailing_custom_output));
fn usage_breakdown_counts_all_items_added_since_last_usage() {
let mut history = create_history_with_items(vec![assistant_msg("already counted by API")]);
history.update_token_info(
&TokenUsage {
total_tokens: 100,
..Default::default()
},
None,
);
let added_user = user_msg("new user message");
let added_tool_output = custom_tool_call_output("call-tail", "new tool output");
history.record_items(
[&added_user, &added_tool_output],
TruncationPolicy::Tokens(10_000),
);
let expected_tokens = estimate_item_token_count(&added_user)
.saturating_add(estimate_item_token_count(&added_tool_output));
let expected_bytes = serde_json::to_vec(&added_user)
.map(|bytes| bytes.len())
.unwrap_or_default()
.saturating_add(
serde_json::to_vec(&added_tool_output)
.map(|bytes| bytes.len())
.unwrap_or_default(),
);
assert_eq!(
history.get_trailing_codex_generated_items_tokens(),
history
.get_total_token_usage_breakdown()
.estimated_tokens_of_items_added_since_last_successful_api_response,
expected_tokens
);
assert_eq!(
history
.get_total_token_usage_breakdown()
.estimated_bytes_of_items_added_since_last_successful_api_response,
expected_bytes
);
}
#[test]
fn trailing_codex_generated_tokens_exclude_function_call_tail() {
let history = create_history_with_items(vec![ResponseItem::FunctionCall {
id: None,
name: "not-generated".to_string(),
arguments: "{}".to_string(),
call_id: "call-tail".to_string(),
}]);
assert_eq!(history.get_trailing_codex_generated_items_tokens(), 0);
}
#[test]
fn total_token_usage_includes_only_trailing_codex_generated_items() {
let non_trailing_output = function_call_output("call-before-message", "not trailing");
let trailing_assistant = assistant_msg("assistant boundary");
let trailing_output = custom_tool_call_output("tool-tail", "trailing output");
let mut history = create_history_with_items(vec![
non_trailing_output,
user_msg("boundary"),
trailing_assistant,
trailing_output.clone(),
]);
fn usage_breakdown_counts_no_added_items_when_nothing_changed_since_last_usage() {
let mut history = create_history_with_items(vec![assistant_msg("already counted by API")]);
history.update_token_info(
&TokenUsage {
total_tokens: 100,
@@ -239,9 +235,41 @@ fn total_token_usage_includes_only_trailing_codex_generated_items() {
None,
);
assert_eq!(
history
.get_total_token_usage_breakdown()
.estimated_tokens_of_items_added_since_last_successful_api_response,
0
);
assert_eq!(
history
.get_total_token_usage_breakdown()
.estimated_bytes_of_items_added_since_last_successful_api_response,
0
);
}
#[test]
fn total_token_usage_includes_all_items_added_since_last_usage() {
let mut history = create_history_with_items(vec![assistant_msg("already counted by API")]);
history.update_token_info(
&TokenUsage {
total_tokens: 100,
..Default::default()
},
None,
);
let added_user = user_msg("new user message");
let added_tool_output = custom_tool_call_output("tool-tail", "new tool output");
history.record_items(
[&added_user, &added_tool_output],
TruncationPolicy::Tokens(10_000),
);
assert_eq!(
history.get_total_token_usage(true),
100 + estimate_item_token_count(&trailing_output)
100 + estimate_item_token_count(&added_user)
+ estimate_item_token_count(&added_tool_output)
);
}