This commit is contained in:
Ahmed Ibrahim
2025-09-10 08:34:25 -07:00
parent 70dc8db746
commit 35107a121a
5 changed files with 149 additions and 43 deletions

View File

@@ -565,6 +565,11 @@ impl Session {
/// Persist the event to rollout and send it to clients.
pub(crate) async fn send_event(&self, event: Event) {
// Persist the event into event_msgs in memory
self.state
.lock_unchecked()
.event_msgs
.record_items(std::slice::from_ref(&event.msg));
// Persist the event into rollout (recorder filters as needed)
let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
self.persist_rollout_items(&rollout_items).await;
@@ -1384,8 +1389,9 @@ async fn submission_loop(
let sub_id = sub.id.clone();
let entries = {
let state = sess.state.lock_unchecked();
let rolled: Vec<RolloutItem> = (&state.response_items).into();
rolled
let rolled_response_items: Vec<RolloutItem> = (&state.response_items).into();
let rolled_event_msgs: Vec<RolloutItem> = (&state.event_msgs).into();
[rolled_response_items, rolled_event_msgs].concat()
};
let event = Event {
id: sub_id.clone(),

View File

@@ -75,9 +75,20 @@ impl EventMsgsHistory {
I::Item: std::ops::Deref<Target = EventMsg>,
{
for item in items {
self.items.push(item.clone());
if self.should_record_item(&item) {
self.items.push(item.clone());
}
}
}
fn should_record_item(&self, item: &EventMsg) -> bool {
!matches!(
item,
EventMsg::AgentMessageDelta(_)
| EventMsg::AgentReasoningDelta(_)
| EventMsg::AgentReasoningRawContentDelta(_)
| EventMsg::ExecCommandOutputDelta(_)
)
}
}
impl From<&ResponseItemsHistory> for Vec<RolloutItem> {

View File

@@ -7,6 +7,7 @@ use crate::codex_conversation::CodexConversation;
use crate::config::Config;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::event_mapping::map_response_item_to_event_messages;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::SessionConfiguredEvent;
@@ -150,7 +151,7 @@ impl ConversationManager {
/// caller's `config`). The new conversation will have a fresh id.
pub async fn fork_conversation(
&self,
conversation_history: Vec<ResponseItem>,
conversation_history: InitialHistory,
num_messages_to_drop: usize,
config: Config,
) -> CodexResult<NewConversation> {
@@ -171,38 +172,106 @@ impl ConversationManager {
/// Return a prefix of `items` obtained by dropping the last `n` user messages
/// and all items that follow them.
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> InitialHistory {
fn truncate_after_dropping_last_messages(history: InitialHistory, n: usize) -> InitialHistory {
// Work from response items for cut logic; preserve any existing rollout items when possible.
let rollout_items: Vec<RolloutItem> = history.get_rollout_items();
let response_items: Vec<ResponseItem> = history.get_response_items();
if n == 0 {
let rolled: Vec<RolloutItem> = items.into_iter().map(RolloutItem::ResponseItem).collect();
return InitialHistory::Forked(rolled);
return history;
}
// Walk backwards counting only `user` Message items, find cut index.
let mut count = 0usize;
let mut cut_index = 0usize;
for (idx, item) in items.iter().enumerate().rev() {
let Some(cut_resp_index) = find_cut_response_index(&response_items, n) else {
return InitialHistory::New;
};
if cut_resp_index == 0 {
return InitialHistory::New;
}
let cut_events_index =
find_matching_user_event_index_in_rollout(&rollout_items, &response_items, cut_resp_index);
let rolled = build_truncated_rollout(rollout_items, cut_resp_index, cut_events_index);
InitialHistory::Forked(rolled)
}
/// Find the index (into response items) of the Nth user message from the end.
fn find_cut_response_index(response_items: &[ResponseItem], n: usize) -> Option<usize> {
if n == 0 {
return None;
}
let mut remaining = n;
for (idx, item) in response_items.iter().enumerate().rev() {
if let ResponseItem::Message { role, .. } = item
&& role == "user"
{
count += 1;
if count == n {
// Cut everything from this user message to the end.
cut_index = idx;
break;
remaining -= 1;
if remaining == 0 {
return Some(idx);
}
}
}
if cut_index == 0 {
// No prefix remains after dropping; start a new conversation.
InitialHistory::New
} else {
let rolled: Vec<RolloutItem> = items
.into_iter()
.take(cut_index)
.map(RolloutItem::ResponseItem)
.collect();
InitialHistory::Forked(rolled)
None
}
/// Derive the user message text (if any) associated with a response item using event mapping.
fn user_message_text_for_response(item: &ResponseItem) -> Option<String> {
let mapped = map_response_item_to_event_messages(item, false);
mapped.into_iter().find_map(|ev| match ev {
EventMsg::UserMessage(u) => Some(u.message),
_ => None,
})
}
/// Given rollout items and the response-item cut index, locate the matching user EventMsg index.
fn find_matching_user_event_index_in_rollout(
rollout_items: &[RolloutItem],
response_items: &[ResponseItem],
cut_resp_index: usize,
) -> Option<usize> {
let target_message = user_message_text_for_response(&response_items[cut_resp_index])?;
rollout_items
.iter()
.enumerate()
.find_map(|(i, it)| match it {
RolloutItem::EventMsg(EventMsg::UserMessage(u)) if u.message == target_message => {
Some(i)
}
_ => None,
})
}
/// Build a truncated rollout keeping response items strictly before `cut_resp_index` and
/// event messages strictly before `event_cut_index` (when provided). Always keeps session meta.
fn build_truncated_rollout(
rollout_items: Vec<RolloutItem>,
cut_resp_index: usize,
event_cut_index: Option<usize>,
) -> Vec<RolloutItem> {
let mut kept_response_seen = 0usize;
let mut rolled: Vec<RolloutItem> = Vec::new();
for (abs_idx, it) in rollout_items.into_iter().enumerate() {
match &it {
RolloutItem::ResponseItem(_) => {
if kept_response_seen < cut_resp_index {
rolled.push(it);
}
kept_response_seen += 1;
}
RolloutItem::EventMsg(_) => {
if let Some(evt_cut) = event_cut_index
&& abs_idx < evt_cut
{
rolled.push(it);
}
}
RolloutItem::SessionMeta(_) => {
rolled.push(it);
}
}
}
rolled
}
#[cfg(test)]
@@ -256,7 +325,13 @@ mod tests {
assistant_msg("a4"),
];
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
// Wrap as InitialHistory::Forked with response items only.
let initial: Vec<RolloutItem> = items
.iter()
.cloned()
.map(RolloutItem::ResponseItem)
.collect();
let truncated = truncate_after_dropping_last_messages(InitialHistory::Forked(initial), 1);
let got_items = truncated.get_rollout_items();
let expected_items = vec![
RolloutItem::ResponseItem(items[0].clone()),
@@ -268,7 +343,12 @@ mod tests {
serde_json::to_value(&expected_items).unwrap()
);
let truncated2 = truncate_after_dropping_last_messages(items, 2);
let initial2: Vec<RolloutItem> = items
.iter()
.cloned()
.map(RolloutItem::ResponseItem)
.collect();
let truncated2 = truncate_after_dropping_last_messages(InitialHistory::Forked(initial2), 2);
assert!(matches!(truncated2, InitialHistory::New));
}
}

View File

@@ -77,12 +77,13 @@ async fn fork_conversation_twice_drops_to_first_message() {
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationHistory(_))).await;
// Capture entries from the base history and compute expected prefixes after each fork.
let entries_after_three = match &base_history {
let history_after_three = match &base_history {
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { history, .. }) => {
entries.clone()
history.clone()
}
_ => panic!("expected ConversationHistory event"),
};
let entries_after_three = history_after_three.get_rollout_items();
// History layout for this test:
// [0] user instructions,
// [1] environment context,
@@ -113,7 +114,7 @@ async fn fork_conversation_twice_drops_to_first_message() {
conversation: codex_fork1,
..
} = conversation_manager
.fork_conversation(entries_after_three.clone(), 1, config_for_fork.clone())
.fork_conversation(history_after_three.clone(), 1, config_for_fork.clone())
.await
.expect("fork 1");
@@ -122,13 +123,14 @@ async fn fork_conversation_twice_drops_to_first_message() {
matches!(ev, EventMsg::ConversationHistory(_))
})
.await;
let entries_after_first_fork = match &fork1_history {
let history_after_first_fork = match &fork1_history {
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { history, .. }) => {
assert!(matches!(
fork1_history,
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { ref history, .. }) if *entries == expected_after_first
));
entries.clone()
let got = history.get_rollout_items();
assert_eq!(
serde_json::to_value(&got).unwrap(),
serde_json::to_value(&expected_after_first).unwrap()
);
history.clone()
}
_ => panic!("expected ConversationHistory event after first fork"),
};
@@ -138,7 +140,7 @@ async fn fork_conversation_twice_drops_to_first_message() {
conversation: codex_fork2,
..
} = conversation_manager
.fork_conversation(entries_after_first_fork.clone(), 1, config_for_fork.clone())
.fork_conversation(history_after_first_fork.clone(), 1, config_for_fork.clone())
.await
.expect("fork 2");
@@ -147,8 +149,14 @@ async fn fork_conversation_twice_drops_to_first_message() {
matches!(ev, EventMsg::ConversationHistory(_))
})
.await;
assert!(matches!(
fork2_history,
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { ref history, .. }) if *entries == expected_after_second
));
match &fork2_history {
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { history, .. }) => {
let got = history.get_rollout_items();
assert_eq!(
serde_json::to_value(&got).unwrap(),
serde_json::to_value(&expected_after_second).unwrap()
);
}
_ => panic!("expected ConversationHistory event after second fork"),
}
}

View File

@@ -3,6 +3,7 @@ use crate::backtrack_helpers;
use crate::pager_overlay::Overlay;
use crate::tui;
use crate::tui::TuiEvent;
use codex_core::InitialHistory;
use codex_core::protocol::ConversationHistoryResponseEvent;
use codex_protocol::mcp_protocol::ConversationId;
use color_eyre::eyre::Result;
@@ -301,7 +302,7 @@ impl App {
/// Thin wrapper around ConversationManager::fork_conversation.
async fn perform_fork(
&self,
entries: Vec<codex_protocol::models::ResponseItem>,
entries: InitialHistory,
drop_count: usize,
cfg: codex_core::config::Config,
) -> codex_core::error::Result<codex_core::NewConversation> {