diff --git a/codex-rs/core/src/session/mod.rs b/codex-rs/core/src/session/mod.rs index 646528cf97..11d71ad9dd 100644 --- a/codex-rs/core/src/session/mod.rs +++ b/codex-rs/core/src/session/mod.rs @@ -2511,7 +2511,22 @@ impl Session { turn_context: &TurnContext, items: &[ResponseItem], ) { - self.record_into_history(items, turn_context).await; + self.record_conversation_items_with_history_policy( + turn_context, + items, + turn_context.truncation_policy, + ) + .await; + } + + pub(crate) async fn record_conversation_items_with_history_policy( + &self, + turn_context: &TurnContext, + items: &[ResponseItem], + history_truncation_policy: TruncationPolicy, + ) { + self.record_into_history_with_policy(items, history_truncation_policy) + .await; self.persist_rollout_response_items(items).await; self.send_raw_response_items(turn_context, items).await; } @@ -2521,9 +2536,18 @@ impl Session { &self, items: &[ResponseItem], turn_context: &TurnContext, + ) { + self.record_into_history_with_policy(items, turn_context.truncation_policy) + .await; + } + + pub(crate) async fn record_into_history_with_policy( + &self, + items: &[ResponseItem], + history_truncation_policy: TruncationPolicy, ) { let mut state = self.state.lock().await; - state.record_items(items.iter(), turn_context.truncation_policy); + state.record_items(items.iter(), history_truncation_policy); } async fn maybe_warn_on_server_model_mismatch( diff --git a/codex-rs/core/src/session/turn.rs b/codex-rs/core/src/session/turn.rs index 4ac511f4fa..3a8c31bf28 100644 --- a/codex-rs/core/src/session/turn.rs +++ b/codex-rs/core/src/session/turn.rs @@ -51,6 +51,7 @@ use crate::stream_events_utils::record_completed_response_item_with_finalized_fa use crate::tools::ToolRouter; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::parallel::ToolCallRuntime; +use crate::tools::registry::RecordedToolResponse; use crate::tools::registry::ToolArgumentDiffConsumer; use crate::tools::router::ToolRouterParams; use crate::tools::router::extension_tool_executors; @@ -81,7 +82,6 @@ use codex_protocol::items::build_hook_prompt_message; use codex_protocol::models::BaseInstructions; use codex_protocol::models::ContentItem; use codex_protocol::models::MessagePhase; -use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::AgentMessageContentDeltaEvent; use codex_protocol::protocol::AgentReasoningSectionBreakEvent; @@ -1676,16 +1676,22 @@ async fn handle_assistant_item_done_in_plan_mode( } async fn drain_in_flight( - in_flight: &mut FuturesOrdered>>, + in_flight: &mut FuturesOrdered>>, sess: Arc, turn_context: Arc, ) -> CodexResult<()> { while let Some(res) = in_flight.next().await { match res { - Ok(response_input) => { - let response_item = response_input.into(); - sess.record_conversation_items(&turn_context, std::slice::from_ref(&response_item)) - .await; + Ok(recorded_tool_response) => { + let response_item = recorded_tool_response.response_item.into(); + sess.record_conversation_items_with_history_policy( + &turn_context, + std::slice::from_ref(&response_item), + recorded_tool_response + .history_truncation_policy + .unwrap_or(turn_context.truncation_policy), + ) + .await; mark_thread_memory_mode_polluted_if_external_context( sess.as_ref(), turn_context.as_ref(), @@ -1747,7 +1753,7 @@ async fn try_run_sampling_request( .instrument(trace_span!("stream_request")) .or_cancel(&cancellation_token) .await??; - let mut in_flight: FuturesOrdered>> = + let mut in_flight: FuturesOrdered>> = FuturesOrdered::new(); let mut needs_follow_up = false; let mut last_agent_message: Option = None; diff --git a/codex-rs/core/src/stream_events_utils.rs b/codex-rs/core/src/stream_events_utils.rs index 0e32ccb9fb..24c60b81d0 100644 --- a/codex-rs/core/src/stream_events_utils.rs +++ b/codex-rs/core/src/stream_events_utils.rs @@ -242,7 +242,7 @@ async fn record_stage1_output_usage_for_memory_citation( /// queuing any tool execution futures. This records items immediately so /// history and rollout stay in sync even if the turn is later cancelled. pub(crate) type InFlightFuture<'f> = - Pin> + Send + 'f>>; + Pin> + Send + 'f>>; #[derive(Default)] pub(crate) struct OutputItemResult { diff --git a/codex-rs/core/src/tools/code_mode/execute_handler.rs b/codex-rs/core/src/tools/code_mode/execute_handler.rs index e101e1a800..76ead2fb2f 100644 --- a/codex-rs/core/src/tools/code_mode/execute_handler.rs +++ b/codex-rs/core/src/tools/code_mode/execute_handler.rs @@ -10,6 +10,7 @@ use codex_tools::ToolSpec; use super::ExecContext; use super::PUBLIC_TOOL_NAME; +use super::code_mode_output_truncation_policy; use super::handle_runtime_response; use super::is_exec_tool_name; @@ -127,4 +128,15 @@ impl CoreToolRuntime for CodeModeExecuteHandler { fn matches_kind(&self, payload: &ToolPayload) -> bool { matches!(payload, ToolPayload::Custom { .. }) } + + fn history_truncation_policy( + &self, + invocation: &ToolInvocation, + ) -> Option { + let ToolPayload::Custom { input } = &invocation.payload else { + return None; + }; + let args = codex_code_mode::parse_exec_source(input).ok()?; + Some(code_mode_output_truncation_policy(args.max_output_tokens)) + } } diff --git a/codex-rs/core/src/tools/code_mode/mod.rs b/codex-rs/core/src/tools/code_mode/mod.rs index ff9f8c8893..97f9b003a4 100644 --- a/codex-rs/core/src/tools/code_mode/mod.rs +++ b/codex-rs/core/src/tools/code_mode/mod.rs @@ -243,8 +243,7 @@ fn truncate_code_mode_result( items: Vec, max_output_tokens: Option, ) -> Vec { - let max_output_tokens = resolve_max_tokens(max_output_tokens); - let policy = TruncationPolicy::Tokens(max_output_tokens); + let policy = code_mode_output_truncation_policy(max_output_tokens); if items .iter() .all(|item| matches!(item, FunctionCallOutputContentItem::InputText { .. })) @@ -257,6 +256,12 @@ fn truncate_code_mode_result( truncate_function_output_items_with_policy(&items, policy) } +pub(super) fn code_mode_output_truncation_policy( + max_output_tokens: Option, +) -> TruncationPolicy { + TruncationPolicy::Tokens(resolve_max_tokens(max_output_tokens)) +} + async fn call_nested_tool( _exec: ExecContext, tool_runtime: ToolCallRuntime, diff --git a/codex-rs/core/src/tools/code_mode/wait_handler.rs b/codex-rs/core/src/tools/code_mode/wait_handler.rs index 725535339f..e05fd1d4c4 100644 --- a/codex-rs/core/src/tools/code_mode/wait_handler.rs +++ b/codex-rs/core/src/tools/code_mode/wait_handler.rs @@ -12,6 +12,7 @@ use codex_tools::ToolSpec; use super::DEFAULT_WAIT_YIELD_TIME_MS; use super::ExecContext; use super::WAIT_TOOL_NAME; +use super::code_mode_output_truncation_policy; use super::handle_runtime_response; use super::wait_spec::create_wait_tool; @@ -110,4 +111,15 @@ impl ToolExecutor for CodeModeWaitHandler { } } -impl CoreToolRuntime for CodeModeWaitHandler {} +impl CoreToolRuntime for CodeModeWaitHandler { + fn history_truncation_policy( + &self, + invocation: &ToolInvocation, + ) -> Option { + let ToolPayload::Function { arguments } = &invocation.payload else { + return None; + }; + let args: ExecWaitArgs = parse_arguments(arguments).ok()?; + Some(code_mode_output_truncation_policy(args.max_tokens)) + } +} diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 15954869e2..55bd9adba7 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -20,6 +20,7 @@ use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolPayload; use crate::tools::lifecycle::notify_tool_aborted; use crate::tools::registry::AnyToolResult; +use crate::tools::registry::RecordedToolResponse; use crate::tools::registry::ToolArgumentDiffConsumer; use crate::tools::router::ToolCall; use crate::tools::router::ToolCallSource; @@ -64,13 +65,13 @@ impl ToolCallRuntime { self, call: ToolCall, cancellation_token: CancellationToken, - ) -> impl std::future::Future> { + ) -> impl std::future::Future> { let error_call = call.clone(); let future = self.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token); async move { match future.await { - Ok(response) => Ok(response.into_response()), + Ok(response) => Ok(response.into_recorded_response()), Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), Err(other) => Ok(Self::failure_response(error_call, other)), } @@ -170,9 +171,9 @@ impl ToolCallRuntime { FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}")) } - fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem { + fn failure_response(call: ToolCall, err: FunctionCallError) -> RecordedToolResponse { let message = err.to_string(); - match call.payload { + let response_item = match call.payload { ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput { call_id: call.call_id, status: "completed".to_string(), @@ -194,6 +195,10 @@ impl ToolCallRuntime { success: Some(false), }, }, + }; + RecordedToolResponse { + response_item, + history_truncation_policy: None, } } @@ -205,6 +210,7 @@ impl ToolCallRuntime { message: Self::abort_message(call, secs), }), post_tool_use_payload: None, + history_truncation_policy: None, } } @@ -353,7 +359,7 @@ mod tests { success: Some(true), }, }; - assert_eq!(expected_response, response); + assert_eq!(expected_response, response.response_item); let actual = records .lock() diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 96837497f8..cbb5bab167 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -30,6 +30,7 @@ use codex_protocol::models::ResponseInputItem; use codex_protocol::protocol::EventMsg; use codex_tools::ToolName; use codex_tools::ToolSpec; +use codex_utils_output_truncation::TruncationPolicy; use futures::future::BoxFuture; use serde_json::Value; use tracing::warn; @@ -70,6 +71,10 @@ pub(crate) trait CoreToolRuntime: ToolExecutor { None } + fn history_truncation_policy(&self, _invocation: &ToolInvocation) -> Option { + None + } + fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option { None } @@ -112,9 +117,16 @@ pub(crate) struct AnyToolResult { pub(crate) payload: ToolPayload, pub(crate) result: Box, pub(crate) post_tool_use_payload: Option, + pub(crate) history_truncation_policy: Option, +} + +pub(crate) struct RecordedToolResponse { + pub(crate) response_item: ResponseInputItem, + pub(crate) history_truncation_policy: Option, } impl AnyToolResult { + #[cfg(test)] pub(crate) fn into_response(self) -> ResponseInputItem { let Self { call_id, @@ -125,6 +137,20 @@ impl AnyToolResult { result.to_response_item(&call_id, &payload) } + pub(crate) fn into_recorded_response(self) -> RecordedToolResponse { + let Self { + call_id, + payload, + result, + history_truncation_policy, + .. + } = self; + RecordedToolResponse { + response_item: result.to_response_item(&call_id, &payload), + history_truncation_policy, + } + } + pub(crate) fn code_mode_result(self) -> serde_json::Value { let Self { payload, result, .. @@ -225,6 +251,10 @@ impl CoreToolRuntime for ExposureOverride { self.handler.post_tool_use_payload(invocation, result) } + fn history_truncation_policy(&self, invocation: &ToolInvocation) -> Option { + self.handler.history_truncation_policy(invocation) + } + fn with_updated_hook_input( &self, invocation: ToolInvocation, @@ -610,11 +640,13 @@ async fn handle_any_tool( let output = tool.handle(invocation.clone()).await?; let post_tool_use_payload = CoreToolRuntime::post_tool_use_payload(tool, &invocation, output.as_ref()); + let history_truncation_policy = CoreToolRuntime::history_truncation_policy(tool, &invocation); Ok(AnyToolResult { call_id, payload, result: output, post_tool_use_payload, + history_truncation_policy, }) }