Use TurnContextItem for rollback modes

This commit is contained in:
Charles Cunningham
2026-01-30 16:15:32 -08:00
parent 854f1ba7c4
commit 6c9e06638c
2 changed files with 140 additions and 6 deletions

View File

@@ -1081,6 +1081,14 @@ impl Session {
self.record_into_history(&reconstructed_history, &turn_context)
.await;
}
let user_turns = Self::user_turn_count(&reconstructed_history);
let turn_context_history =
Self::reconstruct_turn_context_history_from_rollout(&rollout_items);
{
let mut state = self.state.lock().await;
state.set_turn_context_history(turn_context_history);
state.reset_turn_context_history(user_turns);
}
// Seed usage info from the recorded rollout so UIs can show token counts
// immediately on resume/fork.
@@ -1102,6 +1110,14 @@ impl Session {
self.record_into_history(&reconstructed_history, &turn_context)
.await;
}
let user_turns = Self::user_turn_count(&reconstructed_history);
let turn_context_history =
Self::reconstruct_turn_context_history_from_rollout(&rollout_items);
{
let mut state = self.state.lock().await;
state.set_turn_context_history(turn_context_history);
state.reset_turn_context_history(user_turns);
}
// Seed usage info from the recorded rollout so UIs can show token counts
// immediately on resume/fork.
@@ -1786,6 +1802,46 @@ impl Session {
history.raw_items().to_vec()
}
fn reconstruct_turn_context_history_from_rollout(
rollout_items: &[RolloutItem],
) -> Vec<Option<TurnContextItem>> {
let mut history = Vec::new();
let mut awaiting_turn_context = false;
for item in rollout_items {
match item {
RolloutItem::ResponseItem(ResponseItem::Message { role, .. }) if role == "user" => {
history.push(None);
awaiting_turn_context = true;
}
RolloutItem::TurnContext(ctx) => {
if awaiting_turn_context {
if let Some(last) = history.last_mut() {
*last = Some(ctx.clone());
} else {
history.push(Some(ctx.clone()));
}
awaiting_turn_context = false;
}
}
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
let drop = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX);
let new_len = history.len().saturating_sub(drop);
history.truncate(new_len);
awaiting_turn_context = false;
}
_ => {}
}
}
history
}
fn user_turn_count(items: &[ResponseItem]) -> usize {
items
.iter()
.filter(|item| matches!(item, ResponseItem::Message { role, .. } if role == "user"))
.count()
}
/// Append ResponseItems to the in-memory conversation history only.
pub(crate) async fn record_into_history(
&self,
@@ -1793,6 +1849,7 @@ impl Session {
turn_context: &TurnContext,
) {
let mut state = self.state.lock().await;
state.record_user_turn_placeholders(items);
state.record_items(items.iter(), turn_context.truncation_policy);
}
@@ -1813,8 +1870,10 @@ impl Session {
}
pub(crate) async fn replace_history(&self, items: Vec<ResponseItem>) {
let user_turns = Self::user_turn_count(&items);
let mut state = self.state.lock().await;
state.replace_history(items);
state.reset_turn_context_history(user_turns);
}
pub(crate) async fn seed_initial_context_if_needed(&self, turn_context: &TurnContext) {
@@ -2078,6 +2137,23 @@ impl Session {
// those spans, and `record_response_item_and_emit_turn_item` would drop them.
self.record_conversation_items(turn_context, std::slice::from_ref(&response_item))
.await;
let collaboration_mode = self.current_collaboration_mode().await;
let turn_context_item = TurnContextItem {
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
sandbox_policy: turn_context.sandbox_policy.clone(),
model: turn_context.client.get_model(),
personality: turn_context.personality,
collaboration_mode: Some(collaboration_mode),
effort: turn_context.client.get_reasoning_effort(),
summary: turn_context.client.get_reasoning_summary(),
user_instructions: turn_context.user_instructions.clone(),
developer_instructions: turn_context.developer_instructions.clone(),
final_output_json_schema: turn_context.final_output_json_schema.clone(),
truncation_policy: Some(turn_context.truncation_policy.into()),
};
let mut state = self.state.lock().await;
state.set_last_turn_context(turn_context_item);
let turn_item = TurnItem::UserMessage(UserMessageItem::new(input));
self.emit_turn_item_started(turn_context, &turn_item).await;
self.emit_turn_item_completed(turn_context, turn_item).await;
@@ -2985,17 +3061,26 @@ mod handlers {
let mut history = sess.clone_history().await;
history.drop_last_n_user_turns(num_turns);
if let Some(mask) = last_collaboration_mask(history.raw_items()) {
let mut state = sess.state.lock().await;
// Replace with the raw items. We don't want to replace with a normalized
// version of the history.
let user_turns = Self::user_turn_count(history.raw_items());
sess.replace_history(history.raw_items().to_vec()).await;
let mut state = sess.state.lock().await;
let mut applied = false;
if state.turn_context_history.len() == user_turns
&& let Some(turn_context) = state.last_turn_context()
&& let Some(collaboration_mode) = turn_context.collaboration_mode.clone()
{
state.session_configuration.collaboration_mode = collaboration_mode;
applied = true;
}
if !applied && let Some(mask) = last_collaboration_mask(history.raw_items()) {
state.session_configuration.collaboration_mode = state
.session_configuration
.collaboration_mode
.apply_mask(&mask);
}
// Replace with the raw items. We don't want to replace with a normalized
// version of the history.
sess.replace_history(history.raw_items().to_vec()).await;
sess.recompute_token_usage(turn_context.as_ref()).await;
sess.send_event_raw_flushed(Event {

View File

@@ -1,6 +1,7 @@
//! Session-wide mutable state.
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::TurnContextItem;
use std::collections::HashMap;
use std::collections::HashSet;
@@ -15,6 +16,7 @@ use crate::truncate::TruncationPolicy;
pub(crate) struct SessionState {
pub(crate) session_configuration: SessionConfiguration,
pub(crate) history: ContextManager,
pub(crate) turn_context_history: Vec<Option<TurnContextItem>>,
pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
pub(crate) server_reasoning_included: bool,
pub(crate) dependency_env: HashMap<String, String>,
@@ -33,6 +35,7 @@ impl SessionState {
Self {
session_configuration,
history,
turn_context_history: Vec::new(),
latest_rate_limits: None,
server_reasoning_included: false,
dependency_env: HashMap::new(),
@@ -50,6 +53,52 @@ impl SessionState {
self.history.record_items(items, policy);
}
pub(crate) fn record_user_turn_placeholders(&mut self, items: &[ResponseItem]) {
for item in items {
if matches!(item, ResponseItem::Message { role, .. } if role == "user") {
self.turn_context_history.push(None);
}
}
}
pub(crate) fn set_last_turn_context(&mut self, turn_context: TurnContextItem) {
if let Some(last) = self.turn_context_history.last_mut()
&& last.is_none()
{
*last = Some(turn_context);
return;
}
self.turn_context_history.push(Some(turn_context));
}
pub(crate) fn reset_turn_context_history(&mut self, user_turn_count: usize) {
let existing_len = self.turn_context_history.len();
if existing_len >= user_turn_count {
let start = existing_len - user_turn_count;
self.turn_context_history = self.turn_context_history.split_off(start);
} else {
let mut new_history = Vec::with_capacity(user_turn_count);
let padding = user_turn_count - existing_len;
new_history.resize_with(padding, || None);
new_history.append(&mut self.turn_context_history);
self.turn_context_history = new_history;
}
}
pub(crate) fn last_turn_context(&self) -> Option<&TurnContextItem> {
self.turn_context_history
.iter()
.rev()
.find_map(Option::as_ref)
}
pub(crate) fn set_turn_context_history(
&mut self,
turn_context_history: Vec<Option<TurnContextItem>>,
) {
self.turn_context_history = turn_context_history;
}
pub(crate) fn clone_history(&self) -> ContextManager {
self.history.clone()
}