Compare commits

...

12 Commits

Author SHA1 Message Date
Ahmed Ibrahim
5cccaa5a0e tests 2025-09-10 12:10:51 -07:00
Ahmed Ibrahim
6564b498ad tests 2025-09-10 11:27:10 -07:00
Ahmed Ibrahim
92d0bb41ed progress 2025-09-10 10:27:01 -07:00
Ahmed Ibrahim
daf45dcf17 refactor 2025-09-10 10:22:07 -07:00
Ahmed Ibrahim
de4524e280 refactor 2025-09-10 10:22:06 -07:00
Ahmed Ibrahim
bc5499d3f5 refactor 2025-09-10 10:22:06 -07:00
Ahmed Ibrahim
f147df539e refactor 2025-09-10 10:22:06 -07:00
Ahmed Ibrahim
d2730f2b93 forking 2025-09-10 10:22:06 -07:00
Ahmed Ibrahim
7b3d8c83c9 forking 2025-09-10 10:22:06 -07:00
Ahmed Ibrahim
2f50987567 forking 2025-09-10 10:22:06 -07:00
Ahmed Ibrahim
35107a121a forking 2025-09-10 10:22:06 -07:00
Ahmed Ibrahim
70dc8db746 progress 2025-09-10 10:22:06 -07:00
6 changed files with 305 additions and 68 deletions

View File

@@ -9,6 +9,7 @@ use std::sync::atomic::AtomicU64;
use std::time::Duration;
use crate::AuthManager;
use crate::conversation_history::EventMsgsHistory;
use crate::event_mapping::map_response_item_to_event_messages;
use async_channel::Receiver;
use async_channel::Sender;
@@ -44,7 +45,7 @@ use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::config::Config;
use crate::config_types::ShellEnvironmentPolicy;
use crate::conversation_history::ConversationHistory;
use crate::conversation_history::ResponseItemsHistory;
use crate::environment_context::EnvironmentContext;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
@@ -263,7 +264,8 @@ struct State {
current_task: Option<AgentTask>,
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
pending_input: Vec<ResponseInputItem>,
history: ConversationHistory,
response_items: ResponseItemsHistory,
event_msgs: EventMsgsHistory,
token_info: Option<TokenUsageInfo>,
}
@@ -417,7 +419,7 @@ impl Session {
let rollout_path = rollout_recorder.rollout_path.clone();
// Create the mutable state for the Session.
let state = State {
history: ConversationHistory::new(),
response_items: ResponseItemsHistory::new(),
..Default::default()
};
@@ -541,7 +543,7 @@ impl Session {
InitialHistory::New => {
// Build and record initial items (user instructions + environment context)
let items = self.build_initial_context(turn_context);
self.record_conversation_items(&items).await;
self.record_response_items(&items).await;
}
InitialHistory::Resumed(_) | InitialHistory::Forked(_) => {
let rollout_items = conversation_history.get_rollout_items();
@@ -550,7 +552,13 @@ impl Session {
// Always add response items to conversation history
let response_items = conversation_history.get_response_items();
if !response_items.is_empty() {
self.record_into_history(&response_items);
self.record_into_history_response_items(&response_items);
}
// Always add event msgs to conversation history
let event_msgs = conversation_history.get_event_msgs();
if let Some(event_msgs) = event_msgs {
self.record_into_history_event_msgs(&event_msgs);
}
// If persisting, persist all rollout items as-is (recorder filters)
@@ -563,9 +571,9 @@ impl Session {
/// Persist the event to rollout and send it to clients.
pub(crate) async fn send_event(&self, event: Event) {
// Persist the event into rollout (recorder filters as needed)
let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
self.persist_rollout_items(&rollout_items).await;
// Persist the event into event_msgs in memory
self.record_conversation_event_msgs(std::slice::from_ref(&event.msg))
.await;
if let Err(e) = self.tx_event.send(event).await {
error!("failed to send tool call event: {e}");
}
@@ -655,18 +663,31 @@ impl Session {
state.approved_commands.insert(cmd);
}
/// Records event messages: append them to the in-memory event history and
/// then persist them to the rollout.
async fn record_conversation_event_msgs(&self, items: &[EventMsg]) {
self.record_into_history_event_msgs(items);
self.persist_rollout_event_msgs(items).await;
}
/// Records input items: always append to conversation history and
/// persist these response items to rollout.
async fn record_conversation_items(&self, items: &[ResponseItem]) {
self.record_into_history(items);
async fn record_response_items(&self, items: &[ResponseItem]) {
self.record_into_history_response_items(items);
self.persist_rollout_response_items(items).await;
}
/// Append ResponseItems to the in-memory conversation history only.
fn record_into_history(&self, items: &[ResponseItem]) {
fn record_into_history_response_items(&self, items: &[ResponseItem]) {
self.state
.lock_unchecked()
.history
.response_items
.record_items(items.iter());
}
/// Push the given event messages onto the in-memory event history.
/// Nothing is written to the rollout here.
fn record_into_history_event_msgs(&self, items: &[EventMsg]) {
    // The guard lives until the end of this function, i.e. just past the call.
    let mut state = self.state.lock_unchecked();
    state.event_msgs.record_items(items);
}
@@ -679,6 +700,12 @@ impl Session {
self.persist_rollout_items(&rollout_items).await;
}
/// Wrap each event message in `RolloutItem::EventMsg` and hand the batch to
/// `persist_rollout_items`.
async fn persist_rollout_event_msgs(&self, items: &[EventMsg]) {
    let mut rollout_items: Vec<RolloutItem> = Vec::with_capacity(items.len());
    for msg in items {
        rollout_items.push(RolloutItem::EventMsg(msg.clone()));
    }
    self.persist_rollout_items(&rollout_items).await;
}
fn build_initial_context(&self, turn_context: &TurnContext) -> Vec<ResponseItem> {
let mut items = Vec::<ResponseItem>::with_capacity(2);
if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
@@ -710,13 +737,14 @@ impl Session {
async fn record_input_and_rollout_usermsg(&self, response_input: &ResponseInputItem) {
let response_item: ResponseItem = response_input.clone().into();
// Add to conversation history and persist response item to rollout
self.record_conversation_items(std::slice::from_ref(&response_item))
self.record_response_items(std::slice::from_ref(&response_item))
.await;
// Derive user message events and persist only UserMessage to rollout
let msgs =
map_response_item_to_event_messages(&response_item, self.show_raw_agent_reasoning);
let user_msgs: Vec<RolloutItem> = msgs
.clone()
.into_iter()
.filter_map(|m| match m {
EventMsg::UserMessage(ev) => Some(RolloutItem::EventMsg(EventMsg::UserMessage(ev))),
@@ -726,6 +754,7 @@ impl Session {
if !user_msgs.is_empty() {
self.persist_rollout_items(&user_msgs).await;
}
self.state.lock_unchecked().event_msgs.record_items(&msgs);
}
async fn on_exec_command_begin(
@@ -908,7 +937,7 @@ impl Session {
/// Build the full turn input by concatenating the current conversation
/// history with additional items for this turn.
pub fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
[self.state.lock_unchecked().history.contents(), extra].concat()
[self.state.lock_unchecked().response_items.contents(), extra].concat()
}
/// Returns the input if there was no task running to inject into
@@ -1163,7 +1192,7 @@ async fn submission_loop(
// Install the new persistent context for subsequent tasks/turns.
turn_context = Arc::new(new_turn_context);
if cwd.is_some() || approval_policy.is_some() || sandbox_policy.is_some() {
sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new(
sess.record_response_items(&[ResponseItem::from(EnvironmentContext::new(
cwd,
approval_policy,
sandbox_policy,
@@ -1380,12 +1409,17 @@ async fn submission_loop(
}
Op::GetHistory => {
let sub_id = sub.id.clone();
let entries = {
let state = sess.state.lock_unchecked();
let rolled_response_items: Vec<RolloutItem> = (&state.response_items).into();
let rolled_event_msgs: Vec<RolloutItem> = (&state.event_msgs).into();
[rolled_response_items, rolled_event_msgs].concat()
};
let event = Event {
id: sub_id.clone(),
msg: EventMsg::ConversationHistory(ConversationHistoryResponseEvent {
conversation_id: sess.conversation_id,
entries: sess.state.lock_unchecked().history.contents(),
history: InitialHistory::Forked(entries),
}),
};
sess.send_event(event).await;
@@ -1446,7 +1480,7 @@ async fn run_task(
.into_iter()
.map(ResponseItem::from)
.collect::<Vec<ResponseItem>>();
sess.record_conversation_items(&pending_input).await;
sess.record_response_items(&pending_input).await;
// Construct the input that we will send to the model. When using the
// Chat completions API (or ZDR clients), the model needs the full
@@ -1573,7 +1607,7 @@ async fn run_task(
// Only attempt to take the lock if there is something to record.
if !items_to_record_in_conversation_history.is_empty() {
sess.record_conversation_items(&items_to_record_in_conversation_history)
sess.record_response_items(&items_to_record_in_conversation_history)
.await;
}
@@ -1928,7 +1962,7 @@ async fn run_compact_task(
{
let mut state = sess.state.lock_unchecked();
state.history.keep_last_messages(1);
state.response_items.keep_last_messages(1);
}
let event = Event {
@@ -2882,7 +2916,9 @@ async fn drain_to_completed(
Ok(ResponseEvent::OutputItemDone(item)) => {
// Record only to in-memory conversation history; avoid state snapshot.
let mut state = sess.state.lock_unchecked();
state.history.record_items(std::slice::from_ref(&item));
state
.response_items
.record_items(std::slice::from_ref(&item));
}
Ok(ResponseEvent::Completed {
response_id: _,

View File

@@ -1,13 +1,16 @@
use crate::rollout::policy::should_persist_event_msg;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::RolloutItem;
/// Transcript of conversation history
#[derive(Debug, Clone, Default)]
pub(crate) struct ConversationHistory {
pub(crate) struct ResponseItemsHistory {
/// The oldest items are at the beginning of the vector.
items: Vec<ResponseItem>,
}
impl ConversationHistory {
impl ResponseItemsHistory {
pub(crate) fn new() -> Self {
Self { items: Vec::new() }
}
@@ -61,6 +64,51 @@ impl ConversationHistory {
}
}
/// In-memory list of recorded event messages, kept in insertion order.
#[derive(Debug, Clone, Default)]
pub(crate) struct EventMsgsHistory {
    items: Vec<EventMsg>,
}

impl EventMsgsHistory {
    /// Append a clone of every item in `items` that passes
    /// `should_record_item`; all other items are dropped.
    pub(crate) fn record_items<I>(&mut self, items: I)
    where
        I: IntoIterator,
        I::Item: std::ops::Deref<Target = EventMsg>,
    {
        for item in items {
            let msg: &EventMsg = &item;
            if !self.should_record_item(msg) {
                continue;
            }
            self.items.push(msg.clone());
        }
    }

    /// An event is recorded exactly when `should_persist_event_msg` accepts it.
    fn should_record_item(&self, item: &EventMsg) -> bool {
        should_persist_event_msg(item)
    }
}
/// Clone every stored response item into a `RolloutItem::ResponseItem`,
/// preserving order.
impl From<&ResponseItemsHistory> for Vec<RolloutItem> {
    fn from(history: &ResponseItemsHistory) -> Self {
        let mut rolled = Vec::with_capacity(history.items.len());
        for item in &history.items {
            rolled.push(RolloutItem::ResponseItem(item.clone()));
        }
        rolled
    }
}
/// Clone every stored event message into a `RolloutItem::EventMsg`,
/// preserving order.
impl From<&EventMsgsHistory> for Vec<RolloutItem> {
    fn from(history: &EventMsgsHistory) -> Self {
        let mut rolled = Vec::with_capacity(history.items.len());
        for msg in &history.items {
            rolled.push(RolloutItem::EventMsg(msg.clone()));
        }
        rolled
    }
}
/// Anything that is not a system message or "reasoning" message is considered
/// an API message.
fn is_api_message(message: &ResponseItem) -> bool {
@@ -103,7 +151,7 @@ mod tests {
#[test]
fn filters_non_api_messages() {
let mut h = ConversationHistory::default();
let mut h = ResponseItemsHistory::default();
// System message is not an API message; Other is ignored.
let system = ResponseItem::Message {
id: None,

View File

@@ -150,7 +150,7 @@ impl ConversationManager {
/// caller's `config`). The new conversation will have a fresh id.
pub async fn fork_conversation(
&self,
conversation_history: Vec<ResponseItem>,
conversation_history: InitialHistory,
num_messages_to_drop: usize,
config: Config,
) -> CodexResult<NewConversation> {
@@ -171,43 +171,117 @@ impl ConversationManager {
/// Return a prefix of `history` obtained by dropping the last `n` user messages
/// and all items that follow them.
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> InitialHistory {
fn truncate_after_dropping_last_messages(history: InitialHistory, n: usize) -> InitialHistory {
if n == 0 {
let rolled: Vec<RolloutItem> = items.into_iter().map(RolloutItem::ResponseItem).collect();
return InitialHistory::Forked(rolled);
return history;
}
// Walk backwards counting only `user` Message items, find cut index.
let mut count = 0usize;
let mut cut_index = 0usize;
for (idx, item) in items.iter().enumerate().rev() {
// Compute event prefix by dropping the last `n` user events (counted from the end).
let event_msgs_prefix: Vec<EventMsg> =
build_event_prefix_excluding_last_n_user_turns(&history, n);
// Keep only response items strictly before the cut (drop last `n` user messages).
let response_prefix: Vec<ResponseItem> =
build_response_prefix_excluding_last_n_user_turns(&history, n);
let rolled = build_truncated_rollout(&event_msgs_prefix, &response_prefix);
if rolled.is_empty() {
InitialHistory::New
} else {
InitialHistory::Forked(rolled)
}
}
/// Event messages of `history` that precede the `n`th user event counted
/// from the end. Returns an empty vector when the history exposes no event
/// messages or when no such user event exists.
fn build_event_prefix_excluding_last_n_user_turns(
    history: &InitialHistory,
    n: usize,
) -> Vec<EventMsg> {
    let Some(all_events) = history.get_event_msgs() else {
        return Vec::new();
    };
    let cut = find_cut_event_index(&all_events, n);
    take_prefix_before_index(&all_events, cut)
}
/// Response items of `history` that precede the `n`th user message counted
/// from the end. Returns an empty vector when no such user message exists.
fn build_response_prefix_excluding_last_n_user_turns(
    history: &InitialHistory,
    n: usize,
) -> Vec<ResponseItem> {
    let all_items: Vec<ResponseItem> = history.get_response_items();
    let cut = find_cut_response_index(&all_items, n);
    take_prefix_before_index(&all_items, cut)
}
/// Clone the elements of `items` that come before position `idx`
/// (exclusive). A `None` index yields an empty vector.
///
/// # Panics
/// Panics if `idx` is `Some(i)` with `i > items.len()`.
fn take_prefix_before_index<T: Clone>(items: &[T], idx: Option<usize>) -> Vec<T> {
    idx.map_or_else(Vec::new, |end| items[..end].to_vec())
}
/// Find the index (into response items) of the Nth user message from the end.
fn find_cut_response_index(response_items: &[ResponseItem], n: usize) -> Option<usize> {
if n == 0 {
return None;
}
let mut remaining = n;
for (idx, item) in response_items.iter().enumerate().rev() {
if let ResponseItem::Message { role, .. } = item
&& role == "user"
{
count += 1;
if count == n {
// Cut everything from this user message to the end.
cut_index = idx;
break;
remaining -= 1;
if remaining == 0 {
return Some(idx);
}
}
}
if cut_index == 0 {
// No prefix remains after dropping; start a new conversation.
InitialHistory::New
} else {
let rolled: Vec<RolloutItem> = items
.into_iter()
.take(cut_index)
.map(RolloutItem::ResponseItem)
.collect();
InitialHistory::Forked(rolled)
None
}
/// Locate the `n`th `EventMsg::UserMessage` counted from the end of
/// `event_msgs` and return its index. Returns `None` when `n` is zero or
/// when there are fewer than `n` user events.
fn find_cut_event_index(event_msgs: &[EventMsg], n: usize) -> Option<usize> {
    // `n == 0` has no "nth from the end"; `checked_sub` turns it into `None`.
    let skip = n.checked_sub(1)?;
    event_msgs
        .iter()
        .enumerate()
        .rev()
        .filter(|(_, ev)| matches!(ev, EventMsg::UserMessage(_)))
        .nth(skip)
        .map(|(idx, _)| idx)
}
/// Produce a rollout made of all `event_msgs` followed by all
/// `response_items`, each cloned and wrapped in its `RolloutItem` variant.
/// The inputs are expected to be sliced already; no truncation happens here.
fn build_truncated_rollout(
    event_msgs: &[EventMsg],
    response_items: &[ResponseItem],
) -> Vec<RolloutItem> {
    let events = event_msgs.iter().cloned().map(RolloutItem::EventMsg);
    let responses = response_items
        .iter()
        .cloned()
        .map(RolloutItem::ResponseItem);
    events.chain(responses).collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::event_mapping::map_response_item_to_event_messages;
use crate::protocol::EventMsg;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemReasoningSummary;
use codex_protocol::models::ResponseItem;
@@ -221,6 +295,15 @@ mod tests {
}],
}
}
/// Test helper: a `user`-role message holding a single `InputText` item.
fn user_input(text: &str) -> ResponseItem {
    let content = vec![ContentItem::InputText {
        text: text.to_owned(),
    }];
    ResponseItem::Message {
        id: None,
        role: String::from("user"),
        content,
    }
}
fn assistant_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
@@ -256,7 +339,13 @@ mod tests {
assistant_msg("a4"),
];
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
// Wrap as InitialHistory::Forked with response items only.
let initial: Vec<RolloutItem> = items
.iter()
.cloned()
.map(RolloutItem::ResponseItem)
.collect();
let truncated = truncate_after_dropping_last_messages(InitialHistory::Forked(initial), 1);
let got_items = truncated.get_rollout_items();
let expected_items = vec![
RolloutItem::ResponseItem(items[0].clone()),
@@ -268,7 +357,62 @@ mod tests {
serde_json::to_value(&expected_items).unwrap()
);
let truncated2 = truncate_after_dropping_last_messages(items, 2);
let initial2: Vec<RolloutItem> = items
.iter()
.cloned()
.map(RolloutItem::ResponseItem)
.collect();
let truncated2 = truncate_after_dropping_last_messages(InitialHistory::Forked(initial2), 2);
assert!(matches!(truncated2, InitialHistory::New));
}
#[test]
fn event_prefix_counts_from_end_with_duplicate_user_prompts() {
    // The same user prompt appears twice, each followed by an assistant reply.
    let responses = vec![
        user_input("same"),
        assistant_msg("a1"),
        user_input("same"),
        assistant_msg("a2"),
    ];

    // Event messages derived from the responses, in the same order.
    let events: Vec<EventMsg> = responses
        .iter()
        .flat_map(|r| map_response_item_to_event_messages(r, false))
        .collect();

    // Initial history: every event first, then every response item.
    let initial: Vec<RolloutItem> = events
        .iter()
        .cloned()
        .map(RolloutItem::EventMsg)
        .chain(responses.iter().cloned().map(RolloutItem::ResponseItem))
        .collect();

    // Fork while dropping the last user turn.
    let truncated = truncate_after_dropping_last_messages(InitialHistory::Forked(initial), 1);
    let got_items = truncated.get_rollout_items();

    // Only the first two events and the first two response items should
    // survive (the first user turn and its assistant reply).
    let expected: Vec<RolloutItem> = events[..2]
        .iter()
        .cloned()
        .map(RolloutItem::EventMsg)
        .chain(responses[..2].iter().cloned().map(RolloutItem::ResponseItem))
        .collect();

    assert_eq!(
        serde_json::to_value(&got_items).unwrap(),
        serde_json::to_value(&expected).unwrap()
    );
}
}

View File

@@ -77,12 +77,13 @@ async fn fork_conversation_twice_drops_to_first_message() {
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationHistory(_))).await;
// Capture entries from the base history and compute expected prefixes after each fork.
let entries_after_three = match &base_history {
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { entries, .. }) => {
entries.clone()
let history_after_three = match &base_history {
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { history, .. }) => {
history.clone()
}
_ => panic!("expected ConversationHistory event"),
};
let entries_after_three = history_after_three.get_rollout_items();
// History layout for this test:
// [0] user instructions,
// [1] environment context,
@@ -113,7 +114,7 @@ async fn fork_conversation_twice_drops_to_first_message() {
conversation: codex_fork1,
..
} = conversation_manager
.fork_conversation(entries_after_three.clone(), 1, config_for_fork.clone())
.fork_conversation(history_after_three.clone(), 1, config_for_fork.clone())
.await
.expect("fork 1");
@@ -122,13 +123,14 @@ async fn fork_conversation_twice_drops_to_first_message() {
matches!(ev, EventMsg::ConversationHistory(_))
})
.await;
let entries_after_first_fork = match &fork1_history {
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { entries, .. }) => {
assert!(matches!(
fork1_history,
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { ref entries, .. }) if *entries == expected_after_first
));
entries.clone()
let history_after_first_fork = match &fork1_history {
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { history, .. }) => {
let got = history.get_rollout_items();
assert_eq!(
serde_json::to_value(&got).unwrap(),
serde_json::to_value(&expected_after_first).unwrap()
);
history.clone()
}
_ => panic!("expected ConversationHistory event after first fork"),
};
@@ -138,7 +140,7 @@ async fn fork_conversation_twice_drops_to_first_message() {
conversation: codex_fork2,
..
} = conversation_manager
.fork_conversation(entries_after_first_fork.clone(), 1, config_for_fork.clone())
.fork_conversation(history_after_first_fork.clone(), 1, config_for_fork.clone())
.await
.expect("fork 2");
@@ -147,8 +149,14 @@ async fn fork_conversation_twice_drops_to_first_message() {
matches!(ev, EventMsg::ConversationHistory(_))
})
.await;
assert!(matches!(
fork2_history,
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { ref entries, .. }) if *entries == expected_after_second
));
match &fork2_history {
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { history, .. }) => {
let got = history.get_rollout_items();
assert_eq!(
serde_json::to_value(&got).unwrap(),
serde_json::to_value(&expected_after_second).unwrap()
);
}
_ => panic!("expected ConversationHistory event after second fork"),
}
}

View File

@@ -803,7 +803,7 @@ pub struct WebSearchEndEvent {
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct ConversationHistoryResponseEvent {
pub conversation_id: ConversationId,
pub entries: Vec<ResponseItem>,
pub history: InitialHistory,
}
#[derive(Debug, Clone, Deserialize, Serialize, TS)]

View File

@@ -3,6 +3,7 @@ use crate::backtrack_helpers;
use crate::pager_overlay::Overlay;
use crate::tui;
use crate::tui::TuiEvent;
use codex_core::InitialHistory;
use codex_core::protocol::ConversationHistoryResponseEvent;
use codex_protocol::mcp_protocol::ConversationId;
use color_eyre::eyre::Result;
@@ -288,7 +289,7 @@ impl App {
let cfg = self.chat_widget.config_ref().clone();
// Perform the fork via a thin wrapper for clarity/testability.
let result = self
.perform_fork(ev.entries.clone(), drop_count, cfg.clone())
.perform_fork(ev.history.clone(), drop_count, cfg.clone())
.await;
match result {
Ok(new_conv) => {
@@ -301,7 +302,7 @@ impl App {
/// Thin wrapper around ConversationManager::fork_conversation.
async fn perform_fork(
&self,
entries: Vec<codex_protocol::models::ResponseItem>,
entries: InitialHistory,
drop_count: usize,
cfg: codex_core::config::Config,
) -> codex_core::error::Result<codex_core::NewConversation> {