This commit is contained in:
Ahmed Ibrahim
2025-10-23 23:25:45 -07:00
parent 0f02954edb
commit 9446de0923
4 changed files with 57 additions and 83 deletions

View File

@@ -542,7 +542,7 @@ impl Session {
);
// Create the mutable state for the Session.
let state = SessionState::new(session_configuration.clone());
let state = SessionState::new(session_configuration.clone(), config.model_context_window);
let services = SessionServices {
mcp_connection_manager,
@@ -876,7 +876,7 @@ impl Session {
turn_context: &TurnContext,
rollout_items: &[RolloutItem],
) -> CodexResult<Vec<ResponseItem>> {
let mut history = ConversationHistory::new();
let mut history = ConversationHistory::new(turn_context.client.get_model_context_window());
for item in rollout_items {
match item {
RolloutItem::ResponseItem(response_item) => {
@@ -1543,7 +1543,8 @@ pub(crate) async fn run_task(
// For normal turns, continue recording to the session history as before.
let is_review_mode = turn_context.is_review_mode;
let mut review_thread_history: ConversationHistory = ConversationHistory::new();
let mut review_thread_history: ConversationHistory =
ConversationHistory::new(turn_context.client.get_model_context_window());
if is_review_mode {
// Seed review threads with environment context so the model knows the working directory.
review_thread_history
@@ -2569,7 +2570,7 @@ mod tests {
original_config_do_not_use: Arc::clone(&config),
};
let state = SessionState::new(session_configuration.clone());
let state = SessionState::new(session_configuration.clone(), config.model_context_window);
let services = SessionServices {
mcp_connection_manager: McpConnectionManager::default(),
@@ -2637,7 +2638,7 @@ mod tests {
original_config_do_not_use: Arc::clone(&config),
};
let state = SessionState::new(session_configuration.clone());
let state = SessionState::new(session_configuration.clone(), config.model_context_window);
let services = SessionServices {
mcp_connection_manager: McpConnectionManager::default(),
@@ -2860,7 +2861,8 @@ mod tests {
turn_context: &TurnContext,
) -> CodexResult<(Vec<RolloutItem>, Vec<ResponseItem>)> {
let mut rollout_items = Vec::new();
let mut live_history = ConversationHistory::new();
let mut live_history =
ConversationHistory::new(turn_context.client.get_model_context_window());
let initial_context = session.build_initial_context(turn_context);
for item in &initial_context {

View File

@@ -1,14 +1,17 @@
use std::sync::OnceLock;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::TokenUsage;
use codex_protocol::protocol::TokenUsageInfo;
use codex_utils_tokenizer::Tokenizer;
use codex_utils_tokenizer::TokenizerError;
use tracing::error;
use crate::error::CodexErr;
/// Transcript of conversation history
#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone)]
pub(crate) struct ConversationHistory {
/// The oldest items are at the beginning of the vector.
items: Vec<ResponseItem>,
@@ -16,13 +19,26 @@ pub(crate) struct ConversationHistory {
}
impl ConversationHistory {
pub(crate) fn new() -> Self {
pub(crate) fn new(model_context_window: Option<i64>) -> Self {
let token_info = model_context_window.map(|context_window| TokenUsageInfo {
total_token_usage: TokenUsage::default(),
last_token_usage: TokenUsage::default(),
model_context_window: Some(context_window),
});
Self {
items: Vec::new(),
token_info: TokenUsageInfo::new_or_append(&None, &None, None),
token_info,
}
}
fn tokenizer() -> Result<&'static Tokenizer, CodexErr> {
static TOKENIZER: OnceLock<Result<Tokenizer, TokenizerError>> = OnceLock::new();
let tokenizer = TOKENIZER.get_or_init(Tokenizer::try_default);
tokenizer
.as_ref()
.map_err(|e| CodexErr::InvalidInput(format!("tokenizer error: {e}")))
}
pub(crate) fn token_info(&self) -> Option<TokenUsageInfo> {
self.token_info.clone()
}
@@ -113,8 +129,7 @@ impl ConversationHistory {
return Ok(());
};
let tokenizer = Tokenizer::try_default()
.map_err(|e| CodexErr::InvalidInput(format!("tokenizer error: {e}")))?;
let tokenizer = Self::tokenizer()?;
let mut input_tokens: i64 = 0;
for item in content {
@@ -131,8 +146,8 @@ impl ConversationHistory {
}
}
let last_turn_total = info.last_token_usage.total_tokens;
let combined_tokens = input_tokens + last_turn_total;
let prior_total = info.total_token_usage.total_tokens;
let combined_tokens = prior_total.saturating_add(input_tokens);
let threshold = context_window * 95 / 100;
if combined_tokens > threshold {
return Err(CodexErr::InvalidInput("input too large".to_string()));
@@ -394,6 +409,12 @@ impl ConversationHistory {
}
}
/// A default `ConversationHistory` is one created without a known model
/// context window.
impl Default for ConversationHistory {
    fn default() -> Self {
        // Passing `None` means `token_info` starts out as `None`
        // (no context window to map into a `TokenUsageInfo`).
        Self::new(None)
    }
}
#[inline]
fn error_or_panic(message: String) {
if cfg!(debug_assertions) || env!("CARGO_PKG_VERSION").contains("alpha") {
@@ -440,7 +461,7 @@ mod tests {
}
fn create_history_with_items(items: Vec<ResponseItem>) -> ConversationHistory {
let mut h = ConversationHistory::new();
let mut h = ConversationHistory::new(None);
h.record_items(items.iter()).unwrap();
h
}

View File

@@ -18,10 +18,13 @@ pub(crate) struct SessionState {
impl SessionState {
/// Create a new session state mirroring previous `State::default()` semantics.
pub(crate) fn new(session_configuration: SessionConfiguration) -> Self {
pub(crate) fn new(
session_configuration: SessionConfiguration,
model_context_window: Option<i64>,
) -> Self {
Self {
session_configuration,
history: ConversationHistory::new(),
history: ConversationHistory::new(model_context_window),
latest_rate_limits: None,
}
}

View File

@@ -279,11 +279,6 @@ async fn auto_compact_runs_after_token_limit_hit() {
ev_completed_with_tokens("r2", 330_000),
]);
let sse3 = sse(vec![
ev_assistant_message("m3", AUTO_SUMMARY_TEXT),
ev_completed_with_tokens("r3", 200),
]);
let first_matcher = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
body.contains(FIRST_AUTO_MSG)
@@ -300,12 +295,6 @@ async fn auto_compact_runs_after_token_limit_hit() {
};
mount_sse_once_match(&server, second_matcher, sse2).await;
let third_matcher = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
body.contains("You have exceeded the maximum number of tokens")
};
mount_sse_once_match(&server, third_matcher, sse3).await;
let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
..built_in_model_providers()["openai"].clone()
@@ -342,69 +331,28 @@ async fn auto_compact_runs_after_token_limit_hit() {
.await
.unwrap();
let error_event = wait_for_event(&codex, |ev| matches!(ev, EventMsg::Error(_))).await;
let EventMsg::Error(error_event) = error_event else {
unreachable!("wait_for_event returned unexpected payload");
};
assert_eq!(error_event.message, "invalid input: input too large");
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
// wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.unwrap();
assert!(
requests.len() >= 3,
"auto compact should add at least a third request, got {}",
requests.len()
assert_eq!(
requests.len(),
2,
"auto compact should reject oversize prompts before issuing another request"
);
let is_auto_compact = |req: &wiremock::Request| {
let saw_compact_prompt = requests.iter().any(|req| {
std::str::from_utf8(&req.body)
.unwrap_or("")
.contains("You have exceeded the maximum number of tokens")
};
let auto_compact_count = requests.iter().filter(|req| is_auto_compact(req)).count();
assert_eq!(
auto_compact_count, 1,
"expected exactly one auto compact request"
);
let auto_compact_index = requests
.iter()
.enumerate()
.find_map(|(idx, req)| is_auto_compact(req).then_some(idx))
.expect("auto compact request missing");
assert_eq!(
auto_compact_index, 2,
"auto compact should add a third request"
);
let body_first = requests[0].body_json::<serde_json::Value>().unwrap();
let body3 = requests[auto_compact_index]
.body_json::<serde_json::Value>()
.unwrap();
let instructions = body3
.get("instructions")
.and_then(|v| v.as_str())
.unwrap_or_default();
let baseline_instructions = body_first
.get("instructions")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
assert_eq!(
instructions, baseline_instructions,
"auto compact should keep the standard developer instructions",
);
let input3 = body3.get("input").and_then(|v| v.as_array()).unwrap();
let last3 = input3
.last()
.expect("auto compact request should append a user message");
assert_eq!(last3.get("type").and_then(|v| v.as_str()), Some("message"));
assert_eq!(last3.get("role").and_then(|v| v.as_str()), Some("user"));
let last_text = last3
.get("content")
.and_then(|v| v.as_array())
.and_then(|items| items.first())
.and_then(|item| item.get("text"))
.and_then(|text| text.as_str())
.unwrap_or_default();
assert_eq!(
last_text, SUMMARIZATION_PROMPT,
"auto compact should send the summarization prompt as a user message",
});
assert!(
!saw_compact_prompt,
"no auto compact request should be sent when the summarization prompt exceeds the limit"
);
}