// Files
// codex/codex-rs/core/src/compact.rs
// 2026-01-09 17:31:17 +00:00
//
// 485 lines
// 16 KiB
// Rust

use std::sync::Arc;
use crate::ModelProviderInfo;
use crate::Prompt;
use crate::client_common::ResponseEvent;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::codex::get_last_assistant_message_from_turn;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::features::Feature;
use crate::protocol::CompactedItem;
use crate::protocol::ContextCompactedEvent;
use crate::protocol::EventMsg;
use crate::protocol::TurnContextItem;
use crate::protocol::TurnStartedEvent;
use crate::protocol::WarningEvent;
use crate::truncate::TruncationPolicy;
use crate::truncate::approx_token_count;
use crate::truncate::truncate_text;
use crate::util::backoff;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::user_input::UserInput;
use futures::prelude::*;
use tracing::error;
/// Contents of `../templates/compact/prompt.md`, embedded at compile time.
/// NOTE(review): presumably the default summarization prompt; it is not referenced by the
/// other code visible in this file — confirm against callers.
pub const SUMMARIZATION_PROMPT: &str = include_str!("../templates/compact/prompt.md");
/// Contents of `../templates/compact/summary_prefix.md`, embedded at compile time.
/// `run_compact_task_inner` places this text, followed by a newline, in front of the
/// model's summary, and `is_summary_message` checks for the same prefix.
pub const SUMMARY_PREFIX: &str = include_str!("../templates/compact/summary_prefix.md");
/// Token limit that `build_compacted_history` passes to
/// `build_compacted_history_with_limit` for the user messages it retains.
const COMPACT_USER_MESSAGE_MAX_TOKENS: usize = 20_000;
/// Returns whether the remote compact task should be used.
///
/// This is `true` only when `provider` reports itself as OpenAI and `session` has
/// `Feature::RemoteCompaction` enabled. The feature flag is not consulted for
/// non-OpenAI providers.
pub(crate) fn should_use_remote_compact_task(
    session: &Session,
    provider: &ModelProviderInfo,
) -> bool {
    if !provider.is_openai() {
        return false;
    }
    session.enabled(Feature::RemoteCompaction)
}
/// Runs a compaction whose only user input is the turn context's compact prompt.
///
/// The work is delegated to `run_compact_task_inner`; this entry point does not send a
/// `TurnStarted` event itself.
pub(crate) async fn run_inline_auto_compact_task(
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
) {
    let compact_prompt = UserInput::Text {
        text: turn_context.compact_prompt().to_string(),
    };
    run_compact_task_inner(sess, turn_context, vec![compact_prompt]).await;
}
/// Runs a compaction for the given `input`, announcing the turn first.
///
/// Sends a `TurnStarted` event carrying the client's model context window, then
/// delegates to `run_compact_task_inner`.
pub(crate) async fn run_compact_task(
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    input: Vec<UserInput>,
) {
    let start_event = EventMsg::TurnStarted(TurnStartedEvent {
        model_context_window: turn_context.client.get_model_context_window(),
    });
    sess.send_event(&turn_context, start_event).await;
    // `sess` is not used after this point, so ownership is handed over instead of
    // cloning the `Arc`.
    run_compact_task_inner(sess, turn_context, input).await;
}
/// Asks the model for a summary of the conversation and replaces the session history
/// with a compacted version built from it.
///
/// Steps, in order:
/// 1. Append `input` to a copy of the session history and persist a `TurnContext`
///    rollout item describing this turn.
/// 2. Stream the prompt to completion via `drain_to_completed`, retrying as described
///    below.
/// 3. Build the replacement history from the session's initial context, the collected
///    user messages and the summary (`SUMMARY_PREFIX`, a newline, then the last
///    assistant message), append any ghost snapshots from the old history, and install
///    it on the session.
/// 4. Persist a `Compacted` rollout item, then emit a `ContextCompacted` event followed
///    by a warning event.
///
/// Failure handling inside the retry loop:
/// - `CodexErr::Interrupted`: returns without emitting anything.
/// - `CodexErr::ContextWindowExceeded`: while the prompt has more than one item, the
///   oldest item of the working copy is removed and the request is retried with the
///   retry counter reset. Otherwise `set_total_tokens_full` is called, an error event
///   is sent and the function returns.
/// - Any other error: retried after `backoff(retries)` up to the provider's
///   `stream_max_retries()`; after that an error event is sent and the function returns.
async fn run_compact_task_inner(
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    input: Vec<UserInput>,
) {
    let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);

    // Work on a copy of the session history: items trimmed below to fit the context
    // window are removed from this copy only. The session history is cloned again
    // after the loop.
    let mut history = sess.clone_history().await;
    history.record_items(
        &[initial_input_for_turn.into()],
        turn_context.truncation_policy,
    );

    let mut truncated_count = 0usize;
    let max_retries = turn_context.client.get_provider().stream_max_retries();
    let mut retries = 0;

    let rollout_item = RolloutItem::TurnContext(TurnContextItem {
        cwd: turn_context.cwd.clone(),
        approval_policy: turn_context.approval_policy,
        sandbox_policy: turn_context.sandbox_policy.clone(),
        model: turn_context.client.get_model(),
        effort: turn_context.client.get_reasoning_effort(),
        summary: turn_context.client.get_reasoning_summary(),
        base_instructions: turn_context.base_instructions.clone(),
        user_instructions: turn_context.user_instructions.clone(),
        developer_instructions: turn_context.developer_instructions.clone(),
        final_output_json_schema: turn_context.final_output_json_schema.clone(),
        truncation_policy: Some(turn_context.truncation_policy.into()),
    });
    sess.persist_rollout_items(&[rollout_item]).await;

    loop {
        // Clone is required because of the loop: `history` must stay available for
        // the next attempt.
        let turn_input = history.clone().for_prompt();
        let turn_input_len = turn_input.len();
        let prompt = Prompt {
            input: turn_input,
            ..Default::default()
        };

        let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
        match attempt_result {
            Ok(()) => {
                if truncated_count > 0 {
                    sess.notify_background_event(
                        turn_context.as_ref(),
                        format!(
                            "Trimmed {truncated_count} older thread item(s) before compacting so the prompt fits the model context window."
                        ),
                    )
                    .await;
                }
                break;
            }
            Err(CodexErr::Interrupted) => return,
            Err(e @ CodexErr::ContextWindowExceeded) if turn_input_len > 1 => {
                // Trim from the beginning to preserve cache (prefix-based) and keep
                // recent messages intact.
                error!(
                    "Context window exceeded while compacting; removing oldest history item. Error: {e}"
                );
                history.remove_first_item();
                truncated_count += 1;
                // A trimmed prompt is a new request, so it gets a fresh retry budget.
                retries = 0;
            }
            Err(e @ CodexErr::ContextWindowExceeded) => {
                // Nothing left to trim: report the failure and give up.
                sess.set_total_tokens_full(turn_context.as_ref()).await;
                let event = EventMsg::Error(e.to_error_event(None));
                sess.send_event(&turn_context, event).await;
                return;
            }
            Err(e) if retries < max_retries => {
                retries += 1;
                let delay = backoff(retries);
                sess.notify_stream_error(
                    turn_context.as_ref(),
                    format!("Reconnecting... {retries}/{max_retries}"),
                    e,
                )
                .await;
                tokio::time::sleep(delay).await;
            }
            Err(e) => {
                let event = EventMsg::Error(e.to_error_event(None));
                sess.send_event(&turn_context, event).await;
                return;
            }
        }
    }

    // `drain_to_completed` recorded the model output into the session history, so the
    // summary is read from a fresh snapshot of it rather than from the local copy.
    let history_snapshot = sess.clone_history().await;
    let history_items = history_snapshot.raw_items();
    let summary_suffix = get_last_assistant_message_from_turn(history_items).unwrap_or_default();
    let summary_text = format!("{SUMMARY_PREFIX}\n{summary_suffix}");
    let user_messages = collect_user_messages(history_items);

    let initial_context = sess.build_initial_context(turn_context.as_ref());
    let mut new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
    // Ghost snapshots from the old history are carried over into the compacted one.
    let ghost_snapshots: Vec<ResponseItem> = history_items
        .iter()
        .filter(|item| matches!(item, ResponseItem::GhostSnapshot { .. }))
        .cloned()
        .collect();
    new_history.extend(ghost_snapshots);
    sess.replace_history(new_history).await;
    sess.recompute_token_usage(&turn_context).await;

    // `summary_text` is not needed afterwards, so it is moved rather than cloned.
    let rollout_item = RolloutItem::Compacted(CompactedItem {
        message: summary_text,
        replacement_history: None,
    });
    sess.persist_rollout_items(&[rollout_item]).await;

    let event = EventMsg::ContextCompacted(ContextCompactedEvent {});
    sess.send_event(&turn_context, event).await;

    let warning = EventMsg::Warning(WarningEvent {
        message: "Heads up: Long threads and multiple compactions can cause the model to be less accurate. Start a new thread when possible to keep threads small and targeted.".to_string(),
    });
    sess.send_event(&turn_context, warning).await;
}
/// Joins the text carried by `content` into a single newline-separated string.
///
/// `InputText` and `OutputText` items contribute their text when it is non-empty;
/// `InputImage` items are skipped. Returns `None` when no item contributes any text.
pub fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
    let segments: Vec<&str> = content
        .iter()
        .filter_map(|item| match item {
            ContentItem::InputText { text } | ContentItem::OutputText { text } => {
                Some(text.as_str())
            }
            ContentItem::InputImage { .. } => None,
        })
        .filter(|text| !text.is_empty())
        .collect();

    (!segments.is_empty()).then(|| segments.join("\n"))
}
/// Returns the text of the user messages found in `items`, in their original order.
///
/// Items that `parse_turn_item` does not map to `TurnItem::UserMessage` are ignored, and
/// so are user messages that `is_summary_message` recognises as summaries.
pub(crate) fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
    items
        .iter()
        .filter_map(|item| match crate::event_mapping::parse_turn_item(item) {
            // Build the message text once; it is both inspected and returned.
            Some(TurnItem::UserMessage(user)) => Some(user.message()),
            _ => None,
        })
        .filter(|message| !is_summary_message(message))
        .collect()
}
/// Returns `true` when `message` begins with `SUMMARY_PREFIX` immediately followed by a
/// newline, which is the shape `run_compact_task_inner` gives to summary messages.
pub(crate) fn is_summary_message(message: &str) -> bool {
    // Checked in two steps so no temporary `String` is allocated per call.
    message
        .strip_prefix(SUMMARY_PREFIX)
        .is_some_and(|rest| rest.starts_with('\n'))
}
/// Builds the compacted history using the default user-message token limit,
/// `COMPACT_USER_MESSAGE_MAX_TOKENS`.
///
/// See `build_compacted_history_with_limit` for how the arguments are combined.
pub(crate) fn build_compacted_history(
    initial_context: Vec<ResponseItem>,
    user_messages: &[String],
    summary_text: &str,
) -> Vec<ResponseItem> {
    let token_limit = COMPACT_USER_MESSAGE_MAX_TOKENS;
    build_compacted_history_with_limit(initial_context, user_messages, summary_text, token_limit)
}
/// Appends the retained user messages and the summary to `history` and returns it.
///
/// User messages are chosen from newest to oldest until `max_tokens` (measured with
/// `approx_token_count`) is used up. The first message that does not fit is cut down to
/// the remaining budget with `truncate_text`, and selection stops there. The chosen
/// messages are appended in their original order as `"user"` messages, followed by one
/// more `"user"` message holding `summary_text`, or `"(no summary available)"` when
/// `summary_text` is empty. With `max_tokens == 0` no user message is kept.
fn build_compacted_history_with_limit(
    mut history: Vec<ResponseItem>,
    user_messages: &[String],
    summary_text: &str,
    max_tokens: usize,
) -> Vec<ResponseItem> {
    /// Wraps `text` in a user message without an id.
    fn user_message(text: String) -> ResponseItem {
        ResponseItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![ContentItem::InputText { text }],
        }
    }

    // Collected newest-first; flipped back to chronological order when appended.
    let mut kept_newest_first: Vec<String> = Vec::new();
    let mut budget = max_tokens;
    for message in user_messages.iter().rev() {
        if budget == 0 {
            break;
        }
        let cost = approx_token_count(message);
        if cost > budget {
            // Does not fit: keep only what the remaining budget allows, then stop.
            kept_newest_first.push(truncate_text(message, TruncationPolicy::Tokens(budget)));
            break;
        }
        kept_newest_first.push(message.clone());
        budget -= cost;
    }

    history.extend(kept_newest_first.into_iter().rev().map(user_message));

    let summary = if summary_text.is_empty() {
        "(no summary available)"
    } else {
        summary_text
    };
    history.push(user_message(summary.to_string()));
    history
}
/// Streams `prompt` through the turn's client and consumes events until the response
/// completes.
///
/// While draining:
/// - `OutputItemDone` items are recorded into the session history,
/// - `RateLimits` snapshots update the session's rate limits,
/// - `Completed` updates the token usage info and ends the call with `Ok(())`,
/// - every other event is ignored.
///
/// # Errors
///
/// Returns the error from opening the stream or the first error yielded by it, and
/// `CodexErr::Stream` when the stream ends before a `Completed` event arrives.
async fn drain_to_completed(
    sess: &Session,
    turn_context: &TurnContext,
    prompt: &Prompt,
) -> CodexResult<()> {
    let mut stream = turn_context.client.clone().stream(prompt).await?;

    while let Some(event) = stream.next().await {
        match event? {
            ResponseEvent::OutputItemDone(item) => {
                sess.record_into_history(std::slice::from_ref(&item), turn_context)
                    .await;
            }
            ResponseEvent::RateLimits(snapshot) => {
                sess.update_rate_limits(turn_context, snapshot).await;
            }
            ResponseEvent::Completed { token_usage, .. } => {
                sess.update_token_usage_info(turn_context, token_usage.as_ref())
                    .await;
                return Ok(());
            }
            _ => {}
        }
    }

    // The stream ran dry without ever reporting completion.
    Err(CodexErr::Stream(
        "stream closed before response.completed".into(),
        None,
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Shorthand for an `InputText` content item.
    fn input_text(text: &str) -> ContentItem {
        ContentItem::InputText {
            text: text.to_string(),
        }
    }

    /// Shorthand for an `OutputText` content item.
    fn output_text(text: &str) -> ContentItem {
        ContentItem::OutputText {
            text: text.to_string(),
        }
    }

    /// Builds a message with the given id and role around a single content item.
    fn message(id: Option<&str>, role: &str, content: ContentItem) -> ResponseItem {
        ResponseItem::Message {
            id: id.map(str::to_string),
            role: role.to_string(),
            content: vec![content],
        }
    }

    /// Builds an id-less user message holding `text`.
    fn user_message(text: &str) -> ResponseItem {
        message(None, "user", input_text(text))
    }

    /// Extracts the text of a user message, panicking on any other item.
    fn user_message_text(item: &ResponseItem) -> String {
        match item {
            ResponseItem::Message { role, content, .. } if role == "user" => {
                content_items_to_text(content).unwrap_or_default()
            }
            other => panic!("unexpected item in history: {other:?}"),
        }
    }

    #[test]
    fn content_items_to_text_joins_non_empty_segments() {
        let items = vec![input_text("hello"), output_text(""), output_text("world")];

        let joined = content_items_to_text(&items);

        assert_eq!(Some("hello\nworld".to_string()), joined);
    }

    #[test]
    fn content_items_to_text_ignores_image_only_content() {
        let items = vec![ContentItem::InputImage {
            image_url: "file://image.png".to_string(),
        }];

        let joined = content_items_to_text(&items);

        assert_eq!(None, joined);
    }

    #[test]
    fn collect_user_messages_extracts_user_text_only() {
        let items = vec![
            message(Some("assistant"), "assistant", output_text("ignored")),
            message(Some("user"), "user", input_text("first")),
            ResponseItem::Other,
        ];

        let collected = collect_user_messages(&items);

        assert_eq!(vec!["first".to_string()], collected);
    }

    #[test]
    fn collect_user_messages_filters_session_prefix_entries() {
        let items = vec![
            user_message(
                "# AGENTS.md instructions for project\n\n<INSTRUCTIONS>\ndo things\n</INSTRUCTIONS>",
            ),
            user_message("<ENVIRONMENT_CONTEXT>cwd=/tmp</ENVIRONMENT_CONTEXT>"),
            user_message("real user message"),
        ];

        let collected = collect_user_messages(&items);

        assert_eq!(vec!["real user message".to_string()], collected);
    }

    #[test]
    fn build_token_limited_compacted_history_truncates_overlong_user_messages() {
        // A small limit keeps the test fast while still forcing the oversized user
        // message through the truncation path.
        let max_tokens = 16;
        let big = "word ".repeat(200);

        let history = build_compacted_history_with_limit(
            Vec::new(),
            std::slice::from_ref(&big),
            "SUMMARY",
            max_tokens,
        );

        assert_eq!(history.len(), 2);
        let truncated_text = user_message_text(&history[0]);
        let summary_text = user_message_text(&history[1]);

        assert!(
            truncated_text.contains("tokens truncated"),
            "expected truncation marker in truncated user message"
        );
        assert!(
            !truncated_text.contains(&big),
            "truncated user message should not include the full oversized user text"
        );
        assert_eq!(summary_text, "SUMMARY");
    }

    #[test]
    fn build_token_limited_compacted_history_appends_summary_message() {
        let summary_text = "summary text";
        let user_messages = vec!["first user message".to_string()];

        let history = build_compacted_history(Vec::new(), &user_messages, summary_text);

        assert!(
            !history.is_empty(),
            "expected compacted history to include summary"
        );
        let summary = match history.last().expect("history should have a summary entry") {
            ResponseItem::Message { role, content, .. } if role == "user" => {
                content_items_to_text(content).unwrap_or_default()
            }
            other => panic!("expected summary message, found {other:?}"),
        };
        assert_eq!(summary, summary_text);
    }
}