Compare commits

...

12 Commits

Author SHA1 Message Date
Ahmed Ibrahim
e7e272b235 Merge branch 'main' into conversation_history 2025-10-16 08:07:56 -07:00
Ahmed Ibrahim
8ab406300c Revert to 4428986e (tests) by reverting subsequent commits 2025-10-16 08:07:14 -07:00
Ahmed Ibrahim
dc1852a0be feedback 2025-10-15 16:39:05 -07:00
Ahmed Ibrahim
79f4124533 feedback 2025-10-15 15:59:34 -07:00
Ahmed Ibrahim
4428986eae tests 2025-10-14 18:07:48 -07:00
Ahmed Ibrahim
6f8d42d7ee tests 2025-10-14 18:07:40 -07:00
Ahmed Ibrahim
6d58b62f18 tests 2025-10-14 18:07:26 -07:00
Ahmed Ibrahim
411a8f125b tests 2025-10-14 18:05:48 -07:00
Ahmed Ibrahim
f5310a3db9 tests 2025-10-14 18:03:34 -07:00
Ahmed Ibrahim
9488400bfb add failing tests 2025-10-14 17:41:24 -07:00
Ahmed Ibrahim
3dae8cdfc6 renaming 2025-10-14 15:42:51 -07:00
Ahmed Ibrahim
5299bd0ea6 refactor 2025-10-14 15:20:12 -07:00
4 changed files with 239 additions and 150 deletions

View File

@@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::fmt::Debug;
use std::path::PathBuf;
use std::sync::Arc;
@@ -560,7 +559,8 @@ impl Session {
InitialHistory::New => {
// Build and record initial items (user instructions + environment context)
let items = self.build_initial_context(turn_context);
self.record_conversation_items(&items).await;
self.record_conversation_items(&items, TaskKind::Regular)
.await;
}
InitialHistory::Resumed(_) | InitialHistory::Forked(_) => {
let rollout_items = conversation_history.get_rollout_items();
@@ -570,7 +570,8 @@ impl Session {
let reconstructed_history =
self.reconstruct_history_from_rollout(turn_context, &rollout_items);
if !reconstructed_history.is_empty() {
self.record_into_history(&reconstructed_history).await;
self.record_into_history(&reconstructed_history, TaskKind::Regular)
.await;
}
// If persisting, persist all rollout items as-is (recorder filters)
@@ -697,9 +698,21 @@ impl Session {
/// Records input items: always append to conversation history and
/// persist these response items to rollout.
async fn record_conversation_items(&self, items: &[ResponseItem]) {
self.record_into_history(items).await;
self.persist_rollout_response_items(items).await;
async fn record_conversation_items(&self, items: &[ResponseItem], task_kind: TaskKind) {
match task_kind {
TaskKind::Regular | TaskKind::Compact => {
self.record_into_history(items, task_kind).await;
self.persist_rollout_response_items(items).await;
}
TaskKind::Review => {
self.record_into_history(items, task_kind).await;
}
}
}
async fn clear_review_thread(&self) {
let mut state = self.state.lock().await;
state.clear_review_thread();
}
fn reconstruct_history_from_rollout(
@@ -711,7 +724,7 @@ impl Session {
for item in rollout_items {
match item {
RolloutItem::ResponseItem(response_item) => {
history.record_items(std::iter::once(response_item));
history.record_items(std::iter::once(response_item.clone()), TaskKind::Regular);
}
RolloutItem::Compacted(compacted) => {
let snapshot = history.contents();
@@ -730,9 +743,9 @@ impl Session {
}
/// Append ResponseItems to the in-memory conversation history only.
async fn record_into_history(&self, items: &[ResponseItem]) {
async fn record_into_history(&self, items: &[ResponseItem], task_kind: TaskKind) {
let mut state = self.state.lock().await;
state.record_items(items.iter());
state.record_items(items.iter().cloned(), task_kind);
}
async fn replace_history(&self, items: Vec<ResponseItem>) {
@@ -829,13 +842,9 @@ impl Session {
}
}
/// Record a user input item to conversation history and also persist a
/// corresponding UserMessage EventMsg to rollout.
async fn record_input_and_rollout_usermsg(&self, response_input: &ResponseInputItem) {
/// persist a corresponding UserMessage EventMsg to rollout.
async fn persist_user_msg_to_rollout(&self, response_input: &ResponseInputItem) {
let response_item: ResponseItem = response_input.clone().into();
// Add to conversation history and persist response item to rollout
self.record_conversation_items(std::slice::from_ref(&response_item))
.await;
// Derive user message events and persist only UserMessage to rollout
let msgs =
@@ -852,6 +861,38 @@ impl Session {
}
}
async fn initialize_task_history(
&self,
response_input: &ResponseInputItem,
task_kind: TaskKind,
initial_context: Vec<ResponseItem>,
) {
match task_kind {
TaskKind::Regular | TaskKind::Compact => {
let response_item: ResponseItem = response_input.clone().into();
self.record_conversation_items(std::slice::from_ref(&response_item), task_kind)
.await;
self.persist_user_msg_to_rollout(response_input).await;
}
TaskKind::Review => {
let mut state = self.state.lock().await;
state.initialize_review_history(response_input, initial_context);
}
}
}
async fn prepare_prompt_input(
&self,
pending_input: Vec<ResponseItem>,
task_kind: TaskKind,
) -> Vec<ResponseItem> {
if !pending_input.is_empty() && matches!(task_kind, TaskKind::Regular | TaskKind::Compact) {
self.persist_rollout_response_items(&pending_input).await;
}
let mut state = self.state.lock().await;
state.prepare_prompt_input(task_kind, pending_input)
}
async fn on_exec_command_begin(
&self,
turn_diff_tracker: SharedTurnDiffTracker,
@@ -1243,13 +1284,16 @@ async fn submission_loop(
// Optionally persist changes to model / effort
if cwd.is_some() || approval_policy.is_some() || sandbox_policy.is_some() {
sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new(
cwd,
approval_policy,
sandbox_policy,
// Shell is not configurable from turn to turn
None,
))])
sess.record_conversation_items(
&[ResponseItem::from(EnvironmentContext::new(
cwd,
approval_policy,
sandbox_policy,
// Shell is not configurable from turn to turn
None,
))],
TaskKind::Regular,
)
.await;
}
}
@@ -1336,8 +1380,11 @@ async fn submission_loop(
let new_env_context = EnvironmentContext::from(&fresh_turn_context);
if !new_env_context.equals_except_shell(&previous_env_context) {
let env_response_item = ResponseItem::from(new_env_context);
sess.record_conversation_items(std::slice::from_ref(&env_response_item))
.await;
sess.record_conversation_items(
std::slice::from_ref(&env_response_item),
TaskKind::Regular,
)
.await;
for msg in map_response_item_to_event_messages(
&env_response_item,
sess.show_raw_agent_reasoning(),
@@ -1664,19 +1711,9 @@ pub(crate) async fn run_task(
sess.send_event(event).await;
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
// For review threads, keep an isolated in-memory history so the
// model sees a fresh conversation without the parent session's history.
// For normal turns, continue recording to the session history as before.
let is_review_mode = turn_context.is_review_mode;
let mut review_thread_history: Vec<ResponseItem> = Vec::new();
if is_review_mode {
// Seed review threads with environment context so the model knows the working directory.
review_thread_history.extend(sess.build_initial_context(turn_context.as_ref()));
review_thread_history.push(initial_input_for_turn.into());
} else {
sess.record_input_and_rollout_usermsg(&initial_input_for_turn)
.await;
}
let initial_context = sess.build_initial_context(turn_context.as_ref());
sess.initialize_task_history(&initial_input_for_turn, task_kind, initial_context)
.await;
let mut last_agent_message: Option<String> = None;
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
@@ -1695,25 +1732,7 @@ pub(crate) async fn run_task(
.map(ResponseItem::from)
.collect::<Vec<ResponseItem>>();
// Construct the input that we will send to the model.
//
// - For review threads, use the isolated in-memory history so the
// model sees a fresh conversation (no parent history/user_instructions).
//
// - For normal turns, use the session's full history. When using the
// chat completions API (or ZDR clients), the model needs the full
// conversation history on each turn. The rollout file, however, should
// only record the new items that originated in this turn so that it
// represents an append-only log without duplicates.
let turn_input: Vec<ResponseItem> = if is_review_mode {
if !pending_input.is_empty() {
review_thread_history.extend(pending_input);
}
review_thread_history.clone()
} else {
sess.record_conversation_items(&pending_input).await;
sess.turn_input_with_history(pending_input).await
};
let turn_input = sess.prepare_prompt_input(pending_input, task_kind).await;
let turn_input_messages: Vec<String> = turn_input
.iter()
@@ -1848,13 +1867,11 @@ pub(crate) async fn run_task(
// Only attempt to take the lock if there is something to record.
if !items_to_record_in_conversation_history.is_empty() {
if is_review_mode {
review_thread_history
.extend(items_to_record_in_conversation_history.clone());
} else {
sess.record_conversation_items(&items_to_record_in_conversation_history)
.await;
}
sess.record_conversation_items(
&items_to_record_in_conversation_history,
task_kind,
)
.await;
}
if token_limit_reached {
@@ -2067,61 +2084,6 @@ async fn try_run_turn(
prompt: &Prompt,
task_kind: TaskKind,
) -> CodexResult<TurnRunResult> {
// call_ids that are part of this response.
let completed_call_ids = prompt
.input
.iter()
.filter_map(|ri| match ri {
ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id),
ResponseItem::CustomToolCallOutput { call_id, .. } => Some(call_id),
_ => None,
})
.collect::<Vec<_>>();
// call_ids that were pending but are not part of this response.
// This usually happens because the user interrupted the model before we responded to one of its tool calls
// and then the user sent a follow-up message.
let missing_calls = {
prompt
.input
.iter()
.filter_map(|ri| match ri {
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id),
ResponseItem::CustomToolCall { call_id, .. } => Some(call_id),
_ => None,
})
.filter_map(|call_id| {
if completed_call_ids.contains(&call_id) {
None
} else {
Some(call_id.clone())
}
})
.map(|call_id| ResponseItem::CustomToolCallOutput {
call_id,
output: "aborted".to_string(),
})
.collect::<Vec<_>>()
};
let prompt: Cow<Prompt> = if missing_calls.is_empty() {
Cow::Borrowed(prompt)
} else {
// Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses.
let input = [missing_calls, prompt.input.clone()].concat();
Cow::Owned(Prompt {
input,
..prompt.clone()
})
};
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
@@ -2134,7 +2096,7 @@ async fn try_run_turn(
let mut stream = turn_context
.client
.clone()
.stream_with_task_kind(prompt.as_ref(), task_kind)
.stream_with_task_kind(prompt, task_kind)
.await?;
let tool_runtime = ToolCallRuntime::new(
@@ -2456,12 +2418,16 @@ pub(crate) async fn exit_review_mode(
}
session
.record_conversation_items(&[ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText { text: user_message }],
}])
.record_conversation_items(
&[ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText { text: user_message }],
}],
TaskKind::Regular,
)
.await;
session.clear_review_thread().await;
}
use crate::executor::errors::ExecError;
@@ -3032,7 +2998,7 @@ mod tests {
for item in &initial_context {
rollout_items.push(RolloutItem::ResponseItem(item.clone()));
}
live_history.record_items(initial_context.iter());
live_history.record_items(initial_context.iter().cloned(), TaskKind::Regular);
let user1 = ResponseItem::Message {
id: None,
@@ -3041,8 +3007,8 @@ mod tests {
text: "first user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user1));
rollout_items.push(RolloutItem::ResponseItem(user1.clone()));
live_history.record_items(std::iter::once(user1.clone()), TaskKind::Regular);
rollout_items.push(RolloutItem::ResponseItem(user1));
let assistant1 = ResponseItem::Message {
id: None,
@@ -3051,8 +3017,8 @@ mod tests {
text: "assistant reply one".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant1));
rollout_items.push(RolloutItem::ResponseItem(assistant1.clone()));
live_history.record_items(std::iter::once(assistant1.clone()), TaskKind::Regular);
rollout_items.push(RolloutItem::ResponseItem(assistant1));
let summary1 = "summary one";
let snapshot1 = live_history.contents();
@@ -3074,8 +3040,8 @@ mod tests {
text: "second user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user2));
rollout_items.push(RolloutItem::ResponseItem(user2.clone()));
live_history.record_items(std::iter::once(user2.clone()), TaskKind::Regular);
rollout_items.push(RolloutItem::ResponseItem(user2));
let assistant2 = ResponseItem::Message {
id: None,
@@ -3084,8 +3050,8 @@ mod tests {
text: "assistant reply two".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant2));
rollout_items.push(RolloutItem::ResponseItem(assistant2.clone()));
live_history.record_items(std::iter::once(assistant2.clone()), TaskKind::Regular);
rollout_items.push(RolloutItem::ResponseItem(assistant2));
let summary2 = "summary two";
let snapshot2 = live_history.contents();
@@ -3107,8 +3073,8 @@ mod tests {
text: "third user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user3));
rollout_items.push(RolloutItem::ResponseItem(user3.clone()));
live_history.record_items(std::iter::once(user3.clone()), TaskKind::Regular);
rollout_items.push(RolloutItem::ResponseItem(user3));
let assistant3 = ResponseItem::Message {
id: None,
@@ -3117,8 +3083,8 @@ mod tests {
text: "assistant reply three".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant3));
rollout_items.push(RolloutItem::ResponseItem(assistant3.clone()));
live_history.record_items(std::iter::once(assistant3.clone()), TaskKind::Regular);
rollout_items.push(RolloutItem::ResponseItem(assistant3));
(rollout_items, live_history.contents())
}

View File

@@ -274,7 +274,8 @@ async fn drain_to_completed(
};
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
sess.record_into_history(std::slice::from_ref(&item)).await;
sess.record_into_history(std::slice::from_ref(&item), TaskKind::Compact)
.await;
}
Ok(ResponseEvent::RateLimits(snapshot)) => {
sess.update_rate_limits(sub_id, snapshot).await;

View File

@@ -1,15 +1,22 @@
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use crate::state::TaskKind;
/// Transcript of conversation history
#[derive(Debug, Clone, Default)]
pub(crate) struct ConversationHistory {
/// The oldest items are at the beginning of the vector.
items: Vec<ResponseItem>,
review_thread_history: Vec<ResponseItem>,
}
impl ConversationHistory {
pub(crate) fn new() -> Self {
Self { items: Vec::new() }
Self {
items: Vec::new(),
review_thread_history: Vec::new(),
}
}
/// Returns a clone of the contents in the transcript.
@@ -17,24 +24,113 @@ impl ConversationHistory {
self.items.clone()
}
pub(crate) fn review_thread_contents(&self) -> Vec<ResponseItem> {
self.review_thread_history.clone()
}
pub(crate) fn clear_review_thread(&mut self) {
self.review_thread_history.clear();
}
/// `items` is ordered from oldest to newest.
pub(crate) fn record_items<I>(&mut self, items: I)
pub(crate) fn record_items<I>(&mut self, items: I, task_kind: TaskKind)
where
I: IntoIterator,
I::Item: std::ops::Deref<Target = ResponseItem>,
I: IntoIterator<Item = ResponseItem>,
{
for item in items {
if !is_api_message(&item) {
continue;
}
self.items.push(item.clone());
match task_kind {
TaskKind::Regular | TaskKind::Compact => {
self.items.push(item);
}
TaskKind::Review => {
self.review_thread_history.push(item);
}
}
}
}
pub(crate) fn replace(&mut self, items: Vec<ResponseItem>) {
self.items = items;
}
pub(crate) fn initialize_review_history(
&mut self,
response_input: &ResponseInputItem,
initial_context: Vec<ResponseItem>,
) {
self.clear_review_thread();
self.record_items(initial_context, TaskKind::Review);
self.record_items(
std::iter::once(ResponseItem::from(response_input.clone())),
TaskKind::Review,
);
}
pub(crate) fn add_pending_input(
&mut self,
pending_input: Vec<ResponseItem>,
task_kind: TaskKind,
) {
self.record_items(pending_input, task_kind);
}
pub(crate) fn handle_missing_tool_call_output(&mut self, task_kind: TaskKind) {
// call_ids that are part of this response.
let content = match task_kind {
TaskKind::Regular => self.contents(),
TaskKind::Review => self.review_thread_contents(),
TaskKind::Compact => self.contents(),
};
let completed_call_ids = content
.iter()
.filter_map(|ri| match ri {
ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
ResponseItem::CustomToolCallOutput { call_id, .. } => Some(call_id),
_ => None,
})
.collect::<Vec<_>>();
// call_ids that were pending but are not part of this response.
// This usually happens because the user interrupted the model before we responded to one of its tool calls
// and then the user sent a follow-up message.
let missing_calls = {
content
.iter()
.filter_map(|ri| match ri {
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id),
ResponseItem::CustomToolCall { call_id, .. } => Some(call_id),
_ => None,
})
.filter_map(|call_id| {
if completed_call_ids.contains(&call_id) {
None
} else {
Some(call_id.clone())
}
})
.map(|call_id| ResponseItem::CustomToolCallOutput {
call_id,
output: "aborted".to_string(),
})
.collect::<Vec<_>>()
};
self.record_items(missing_calls, task_kind);
}
pub(crate) fn prompt(&self, task_kind: TaskKind) -> Vec<ResponseItem> {
match task_kind {
TaskKind::Regular | TaskKind::Compact => self.contents(),
TaskKind::Review => self.review_thread_contents(),
}
}
}
/// Anything that is not a system message or "reasoning" message is considered
@@ -89,12 +185,12 @@ mod tests {
text: "ignored".to_string(),
}],
};
h.record_items([&system, &ResponseItem::Other]);
h.record_items([system, ResponseItem::Other], TaskKind::Regular);
// User and assistant should be retained.
let u = user_msg("hi");
let a = assistant_msg("hello");
h.record_items([&u, &a]);
h.record_items([u, a], TaskKind::Regular);
let items = h.contents();
assert_eq!(
@@ -114,7 +210,7 @@ mod tests {
text: "hello".to_string()
}]
}
]
],
);
}
}

View File

@@ -1,11 +1,13 @@
//! Session-wide mutable state.
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use crate::conversation_history::ConversationHistory;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use crate::protocol::TokenUsageInfo;
use crate::state::TaskKind;
/// Persistent, session-scoped state previously stored directly on `Session`.
#[derive(Default)]
@@ -25,12 +27,11 @@ impl SessionState {
}
// History helpers
pub(crate) fn record_items<I>(&mut self, items: I)
pub(crate) fn record_items<I>(&mut self, items: I, task_kind: TaskKind)
where
I: IntoIterator,
I::Item: std::ops::Deref<Target = ResponseItem>,
I: IntoIterator<Item = ResponseItem>,
{
self.history.record_items(items)
self.history.record_items(items, task_kind)
}
pub(crate) fn history_snapshot(&self) -> Vec<ResponseItem> {
@@ -41,6 +42,31 @@ impl SessionState {
self.history.replace(items);
}
pub(crate) fn clear_review_thread(&mut self) {
self.history.clear_review_thread();
}
pub(crate) fn initialize_review_history(
&mut self,
response_input: &ResponseInputItem,
initial_context: Vec<ResponseItem>,
) {
self.history
.initialize_review_history(response_input, initial_context);
}
pub(crate) fn prepare_prompt_input(
&mut self,
task_kind: TaskKind,
pending_input: Vec<ResponseItem>,
) -> Vec<ResponseItem> {
if !pending_input.is_empty() {
self.history.add_pending_input(pending_input, task_kind);
}
self.history.handle_missing_tool_call_output(task_kind);
self.history.prompt(task_kind)
}
// Token/rate limit helpers
pub(crate) fn update_token_info_from_usage(
&mut self,