diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index 36e4e866d9..a53a1e0fa1 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -6944,6 +6944,7 @@ async fn handle_output_item_done_records_image_save_history_message() { let mut ctx = HandleOutputCtx { sess: Arc::clone(&session), turn_context: Arc::clone(&turn_context), + turn_store: Arc::new(codex_extension_api::ExtensionData::new()), tool_runtime: test_tool_runtime(Arc::clone(&session), Arc::clone(&turn_context)), cancellation_token: CancellationToken::new(), }; @@ -6996,6 +6997,7 @@ async fn handle_output_item_done_skips_image_save_message_when_save_fails() { let mut ctx = HandleOutputCtx { sess: Arc::clone(&session), turn_context: Arc::clone(&turn_context), + turn_store: Arc::new(codex_extension_api::ExtensionData::new()), tool_runtime: test_tool_runtime(Arc::clone(&session), Arc::clone(&turn_context)), cancellation_token: CancellationToken::new(), }; @@ -8883,6 +8885,7 @@ async fn tool_calls_reopen_mailbox_delivery_for_current_turn() { let mut ctx = HandleOutputCtx { sess: Arc::clone(&sess), turn_context: Arc::clone(&tc), + turn_store: Arc::new(codex_extension_api::ExtensionData::new()), tool_runtime: test_tool_runtime(Arc::clone(&sess), Arc::clone(&tc)), cancellation_token: CancellationToken::new(), }; diff --git a/codex-rs/core/src/session/turn.rs b/codex-rs/core/src/session/turn.rs index d910e03b21..86710e4871 100644 --- a/codex-rs/core/src/session/turn.rs +++ b/codex-rs/core/src/session/turn.rs @@ -44,12 +44,14 @@ use crate::session::PreviousTurnSettings; use crate::session::session::Session; use crate::session::turn_context::TurnContext; use crate::stream_events_utils::HandleOutputCtx; +use crate::stream_events_utils::TurnItemContributorPolicy; +use crate::stream_events_utils::finalize_non_tool_response_item; use crate::stream_events_utils::handle_non_tool_response_item; use crate::stream_events_utils::handle_output_item_done; use crate::stream_events_utils::last_assistant_message_from_item; use crate::stream_events_utils::mark_thread_memory_mode_polluted_if_external_context; use crate::stream_events_utils::raw_assistant_output_text_from_item; -use crate::stream_events_utils::record_completed_response_item; +use crate::stream_events_utils::record_completed_response_item_with_finalized_facts; use crate::tools::ToolRouter; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::parallel::ToolCallRuntime; @@ -139,6 +141,7 @@ use tracing::warn; pub(crate) async fn run_turn( sess: Arc, turn_context: Arc, + turn_extension_data: Arc, input: Vec, prewarmed_client_session: Option, cancellation_token: CancellationToken, @@ -454,6 +457,7 @@ pub(crate) async fn run_turn( match run_sampling_request( Arc::clone(&sess), Arc::clone(&turn_context), + Arc::clone(&turn_extension_data), Arc::clone(&turn_diff_tracker), &mut client_session, turn_metadata_header.as_deref(), @@ -1009,6 +1013,7 @@ pub(crate) fn build_prompt( async fn run_sampling_request( sess: Arc, turn_context: Arc, + turn_store: Arc, turn_diff_tracker: SharedTurnDiffTracker, client_session: &mut ModelClientSession, turn_metadata_header: Option<&str>, @@ -1065,6 +1070,7 @@ async fn run_sampling_request( tool_runtime.clone(), Arc::clone(&sess), Arc::clone(&turn_context), + Arc::clone(&turn_store), client_session, turn_metadata_header, Arc::clone(&turn_diff_tracker), @@ -1747,6 +1753,7 @@ async fn emit_turn_item_in_plan_mode( async fn handle_assistant_item_done_in_plan_mode( sess: &Session, turn_context: &TurnContext, + turn_store: &codex_extension_api::ExtensionData, item: &ResponseItem, state: &mut PlanModeStreamState, previously_active_item: Option<&TurnItem>, @@ -1757,21 +1764,38 @@ async fn handle_assistant_item_done_in_plan_mode( { maybe_complete_plan_item_from_message(sess, turn_context, state, item).await; - if let Some(turn_item) = - handle_non_tool_response_item(sess, turn_context, item, /*plan_mode*/ true).await + let mut finalized_facts = None; + if let Some(finalized_turn_item) = finalize_non_tool_response_item( + sess, + turn_context, + TurnItemContributorPolicy::Run(turn_store), + item, + /*plan_mode*/ true, + ) + .await { + finalized_facts = Some(finalized_turn_item.facts.clone()); emit_turn_item_in_plan_mode( sess, turn_context, - turn_item, + finalized_turn_item.turn_item, previously_active_item, state, ) .await; } + let final_last_agent_message = finalized_facts + .as_ref() + .and_then(|facts| facts.last_agent_message.clone()); - record_completed_response_item(sess, turn_context, item).await; - if let Some(agent_message) = last_assistant_message_from_item(item, /*plan_mode*/ true) { + record_completed_response_item_with_finalized_facts( + sess, + turn_context, + item, + finalized_facts.as_ref(), + ) + .await; + if let Some(agent_message) = final_last_agent_message { *last_agent_message = Some(agent_message); } return true; @@ -1817,6 +1841,7 @@ async fn try_run_sampling_request( tool_runtime: ToolCallRuntime, sess: Arc, turn_context: Arc, + turn_store: Arc, client_session: &mut ModelClientSession, turn_metadata_header: Option<&str>, turn_diff_tracker: SharedTurnDiffTracker, @@ -1865,6 +1890,9 @@ async fn try_run_sampling_request( let plan_mode = turn_context.collaboration_mode.mode == ModeKind::Plan; let mut assistant_message_stream_parsers = AssistantMessageStreamParsers::new(plan_mode); let mut plan_mode_state = plan_mode.then(|| PlanModeStreamState::new(&turn_context.sub_id)); + let defer_streamed_turn_items_for_contributors = + !sess.services.extensions.turn_item_contributors().is_empty(); + let mut active_item_is_streaming_to_client = false; let receiving_span = trace_span!("receiving_stream"); let mut completed_response_id: Option = None; let outcome: CodexResult = loop { @@ -1917,7 +1945,13 @@ async fn try_run_sampling_request( sess.send_event(&turn_context, event).await; } let previously_active_item = active_item.take(); - if let Some(previous) = previously_active_item.as_ref() + let previously_streamed_item = if active_item_is_streaming_to_client { + previously_active_item + } else { + None + }; + active_item_is_streaming_to_client = false; + if let Some(previous) = previously_streamed_item.as_ref() && matches!(previous, TurnItem::AgentMessage(_)) { let item_id = previous.id(); @@ -1934,9 +1968,10 @@ async fn try_run_sampling_request( && handle_assistant_item_done_in_plan_mode( &sess, &turn_context, + turn_store.as_ref(), &item, state, - previously_active_item.as_ref(), + previously_streamed_item.as_ref(), &mut last_agent_message, ) .await @@ -1947,6 +1982,7 @@ async fn try_run_sampling_request( let mut ctx = HandleOutputCtx { sess: sess.clone(), turn_context: turn_context.clone(), + turn_store: Arc::clone(&turn_store), tool_runtime: tool_runtime.clone(), cancellation_token: cancellation_token.child_token(), }; @@ -1971,7 +2007,7 @@ async fn try_run_sampling_request( }; let output_result = - match handle_output_item_done(&mut ctx, item, previously_active_item) + match handle_output_item_done(&mut ctx, item, previously_streamed_item) .instrument(handle_responses) .await { @@ -2005,15 +2041,18 @@ async fn try_run_sampling_request( if let Some(turn_item) = handle_non_tool_response_item( sess.as_ref(), turn_context.as_ref(), + TurnItemContributorPolicy::Skip, &item, plan_mode, ) .await { let mut turn_item = turn_item; + let stream_item_to_client = !defer_streamed_turn_items_for_contributors; let mut seeded_parsed: Option = None; let mut seeded_item_id: Option = None; - if matches!(turn_item, TurnItem::AgentMessage(_)) + if stream_item_to_client + && matches!(turn_item, TurnItem::AgentMessage(_)) && let Some(raw_text) = raw_assistant_output_text_from_item(&item) { let item_id = turn_item.id(); @@ -2032,31 +2071,34 @@ async fn try_run_sampling_request( seeded_parsed = plan_mode.then_some(seeded); seeded_item_id = Some(item_id); } - if let Some(state) = plan_mode_state.as_mut() - && matches!(turn_item, TurnItem::AgentMessage(_)) - { - let item_id = turn_item.id(); - state - .pending_agent_message_items - .insert(item_id, turn_item.clone()); - } else { - sess.emit_turn_item_started(&turn_context, &turn_item).await; - } - if let (Some(state), Some(item_id), Some(parsed)) = ( - plan_mode_state.as_mut(), - seeded_item_id.as_deref(), - seeded_parsed, - ) { - emit_streamed_assistant_text_delta( - &sess, - &turn_context, - Some(state), - item_id, - parsed, - ) - .await; + if stream_item_to_client { + if let Some(state) = plan_mode_state.as_mut() + && matches!(turn_item, TurnItem::AgentMessage(_)) + { + let item_id = turn_item.id(); + state + .pending_agent_message_items + .insert(item_id, turn_item.clone()); + } else { + sess.emit_turn_item_started(&turn_context, &turn_item).await; + } + if let (Some(state), Some(item_id), Some(parsed)) = ( + plan_mode_state.as_mut(), + seeded_item_id.as_deref(), + seeded_parsed, + ) { + emit_streamed_assistant_text_delta( + &sess, + &turn_context, + Some(state), + item_id, + parsed, + ) + .await; + } } active_item = Some(turn_item); + active_item_is_streaming_to_client = stream_item_to_client; } } ResponseEvent::ServerModel(server_model) => { @@ -2123,6 +2165,9 @@ async fn try_run_sampling_request( // In review child threads, suppress assistant text deltas; the // UI will show a selection popup from the final ReviewOutput. if let Some(active) = active_item.as_ref() { + if !active_item_is_streaming_to_client { + continue; + } let item_id = active.id(); if matches!(active, TurnItem::AgentMessage(_)) { let parsed = assistant_message_stream_parsers.parse_delta(&item_id, &delta); @@ -2171,6 +2216,9 @@ async fn try_run_sampling_request( summary_index, } => { if let Some(active) = active_item.as_ref() { + if !active_item_is_streaming_to_client { + continue; + } let event = ReasoningContentDeltaEvent { thread_id: sess.conversation_id.to_string(), turn_id: turn_context.sub_id.clone(), @@ -2186,6 +2234,9 @@ async fn try_run_sampling_request( } ResponseEvent::ReasoningSummaryPartAdded { summary_index } => { if let Some(active) = active_item.as_ref() { + if !active_item_is_streaming_to_client { + continue; + } let event = EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent { item_id: active.id(), @@ -2201,6 +2252,9 @@ async fn try_run_sampling_request( content_index, } => { if let Some(active) = active_item.as_ref() { + if !active_item_is_streaming_to_client { + continue; + } let event = ReasoningRawContentDeltaEvent { thread_id: sess.conversation_id.to_string(), turn_id: turn_context.sub_id.clone(), @@ -2270,3 +2324,7 @@ pub(crate) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) - } None } + +#[cfg(test)] +#[path = "turn_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/session/turn_tests.rs b/codex-rs/core/src/session/turn_tests.rs new file mode 100644 index 0000000000..4f19474d38 --- /dev/null +++ b/codex-rs/core/src/session/turn_tests.rs @@ -0,0 +1,67 @@ +use super::*; +use codex_extension_api::ExtensionData; +use codex_extension_api::TurnItemContributionFuture; +use codex_extension_api::TurnItemContributor; +use codex_protocol::items::AgentMessageContent; +use pretty_assertions::assert_eq; +use std::sync::Arc; + +struct RewriteAgentMessageContributor; + +impl TurnItemContributor for RewriteAgentMessageContributor { + fn contribute<'a>( + &'a self, + _thread_store: &'a ExtensionData, + _turn_store: &'a ExtensionData, + item: &'a mut TurnItem, + ) -> TurnItemContributionFuture<'a> { + Box::pin(async move { + if let TurnItem::AgentMessage(agent_message) = item { + agent_message.content = vec![AgentMessageContent::Text { + text: "plan contributed assistant text".to_string(), + }]; + } + Ok(()) + }) + } +} + +fn assistant_output_text(text: &str) -> ResponseItem { + ResponseItem::Message { + id: Some("msg-1".to_string()), + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + phase: None, + } +} + +#[tokio::test] +async fn plan_mode_uses_contributed_turn_item_for_last_agent_message() { + let (mut session, turn_context) = crate::session::tests::make_session_and_context().await; + let mut builder = codex_extension_api::ExtensionRegistryBuilder::new(); + builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor)); + session.services.extensions = Arc::new(builder.build()); + let turn_store = ExtensionData::new(); + let mut state = PlanModeStreamState::new(&turn_context.sub_id); + let mut last_agent_message = None; + let item = assistant_output_text("original assistant text"); + + let handled = handle_assistant_item_done_in_plan_mode( + &session, + &turn_context, + &turn_store, + &item, + &mut state, + /*previously_active_item*/ None, + &mut last_agent_message, + ) + .await; + + assert!(handled); + assert_eq!( + last_agent_message.as_deref(), + Some("plan contributed assistant text") + ); +} diff --git a/codex-rs/core/src/stream_events_utils.rs b/codex-rs/core/src/stream_events_utils.rs index 25d7dada95..bc44fe0c3b 100644 --- a/codex-rs/core/src/stream_events_utils.rs +++ b/codex-rs/core/src/stream_events_utils.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use base64::Engine; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use codex_extension_api::ExtensionData; use codex_protocol::config_types::ModeKind; use codex_protocol::items::TurnItem; use codex_utils_stream_parser::strip_citations; @@ -20,6 +21,7 @@ use codex_memories_read::citations::parse_memory_citation; use codex_memories_read::citations::thread_ids_from_memory_citation; use codex_protocol::error::CodexErr; use codex_protocol::error::Result; +use codex_protocol::memory_citation::MemoryCitation; use codex_protocol::models::FunctionCallOutputBody; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::MessagePhase; @@ -31,6 +33,7 @@ use codex_utils_stream_parser::strip_proposed_plan_blocks; use futures::Future; use tracing::debug; use tracing::instrument; +use tracing::warn; const GENERATED_IMAGE_ARTIFACTS_DIR: &str = "generated_images"; @@ -127,22 +130,50 @@ pub(crate) async fn record_completed_response_item( sess: &Session, turn_context: &TurnContext, item: &ResponseItem, +) { + record_completed_response_item_with_finalized_facts( + sess, + turn_context, + item, + /*finalized_facts*/ None, + ) + .await; +} + +pub(crate) async fn record_completed_response_item_with_finalized_facts( + sess: &Session, + turn_context: &TurnContext, + item: &ResponseItem, + finalized_facts: Option<&FinalizedTurnItemFacts>, ) { sess.record_conversation_items(turn_context, std::slice::from_ref(item)) .await; - if completed_item_defers_mailbox_delivery_to_next_turn( - item, - turn_context.collaboration_mode.mode == ModeKind::Plan, - ) { + let defers_mailbox_delivery = finalized_facts.map_or_else( + || { + completed_item_defers_mailbox_delivery_to_next_turn( + item, + turn_context.collaboration_mode.mode == ModeKind::Plan, + ) + }, + |facts| facts.defers_mailbox_delivery_to_next_turn, + ); + if defers_mailbox_delivery { sess.defer_mailbox_delivery_to_next_turn(&turn_context.sub_id) .await; } mark_thread_memory_mode_polluted_if_external_context(sess, turn_context, item).await; - let has_memory_citation = record_stage1_output_usage_and_detect_memory_citation( - sess.services.state_db.as_ref(), - item, - ) - .await; + let has_memory_citation = if let Some(memory_citation) = + finalized_facts.and_then(|facts| facts.memory_citation.as_ref()) + { + record_stage1_output_usage_for_memory_citation( + sess.services.state_db.as_ref(), + memory_citation, + ) + .await + } else { + record_stage1_output_usage_and_detect_memory_citation(sess.services.state_db.as_ref(), item) + .await + }; if has_memory_citation { sess.record_memory_citation_for_turn(&turn_context.sub_id) .await; @@ -188,7 +219,14 @@ async fn record_stage1_output_usage_and_detect_memory_citation( let Some(memory_citation) = parse_memory_citation(citations) else { return false; }; - let thread_ids = thread_ids_from_memory_citation(&memory_citation); + record_stage1_output_usage_for_memory_citation(state_db_ctx, &memory_citation).await +} + +async fn record_stage1_output_usage_for_memory_citation( + state_db_ctx: Option<&state_db::StateDbHandle>, + memory_citation: &MemoryCitation, +) -> bool { + let thread_ids = thread_ids_from_memory_citation(memory_citation); if thread_ids.is_empty() { return true; } @@ -215,10 +253,91 @@ pub(crate) struct OutputItemResult { pub(crate) struct HandleOutputCtx { pub sess: Arc, pub turn_context: Arc, + pub turn_store: Arc, pub tool_runtime: ToolCallRuntime, pub cancellation_token: CancellationToken, } +async fn apply_turn_item_contributors( + sess: &Session, + turn_store: &ExtensionData, + item: &mut TurnItem, +) { + let contributors = sess.services.extensions.turn_item_contributors().to_vec(); + for contributor in contributors { + if let Err(err) = contributor + .contribute(&sess.services.thread_extension_data, turn_store, item) + .await + { + warn!("turn item contributor failed: {err}"); + } + } +} + +pub(crate) enum TurnItemContributorPolicy<'a> { + Skip, + Run(&'a ExtensionData), +} + +pub(crate) struct FinalizedTurnItem { + pub(crate) turn_item: TurnItem, + pub(crate) facts: FinalizedTurnItemFacts, +} + +#[derive(Clone, Default)] +pub(crate) struct FinalizedTurnItemFacts { + pub(crate) memory_citation: Option, + pub(crate) last_agent_message: Option, + pub(crate) defers_mailbox_delivery_to_next_turn: bool, +} + +pub(crate) async fn finalize_non_tool_response_item( + sess: &Session, + turn_context: &TurnContext, + contributor_policy: TurnItemContributorPolicy<'_>, + item: &ResponseItem, + plan_mode: bool, +) -> Option { + let turn_item = + handle_non_tool_response_item(sess, turn_context, contributor_policy, item, plan_mode) + .await?; + let (memory_citation, last_agent_message, defers_mailbox_delivery_to_next_turn) = + match &turn_item { + TurnItem::AgentMessage(agent_message) => { + let combined = agent_message + .content + .iter() + .map(|entry| match entry { + codex_protocol::items::AgentMessageContent::Text { text } => text.as_str(), + }) + .collect::(); + let last_agent_message = if combined.trim().is_empty() { + None + } else { + Some(combined) + }; + let defers_mailbox_delivery_to_next_turn = + !matches!(agent_message.phase, Some(MessagePhase::Commentary)) + && last_agent_message.is_some(); + ( + agent_message.memory_citation.clone(), + last_agent_message, + defers_mailbox_delivery_to_next_turn, + ) + } + TurnItem::ImageGeneration(_) => (None, None, true), + _ => (None, None, false), + }; + Some(FinalizedTurnItem { + turn_item, + facts: FinalizedTurnItemFacts { + memory_citation, + last_agent_message, + defers_mailbox_delivery_to_next_turn, + }, + }) +} + #[instrument(level = "trace", skip_all)] pub(crate) async fn handle_output_item_done( ctx: &mut HandleOutputCtx, @@ -258,16 +377,20 @@ pub(crate) async fn handle_output_item_done( } // No tool call: convert messages/reasoning into turn items and mark them as complete. Ok(None) => { - let turn_item = handle_non_tool_response_item( + let finalized_turn_item = finalize_non_tool_response_item( ctx.sess.as_ref(), ctx.turn_context.as_ref(), + TurnItemContributorPolicy::Run(ctx.turn_store.as_ref()), &item, plan_mode, ) .await; - if let Some(turn_item) = turn_item { + let finalized_facts = finalized_turn_item + .as_ref() + .map(|finalized| finalized.facts.clone()); + if let Some(finalized_turn_item) = finalized_turn_item { if previously_active_item.is_none() { - let mut started_item = turn_item.clone(); + let mut started_item = finalized_turn_item.turn_item.clone(); if let TurnItem::ImageGeneration(item) = &mut started_item { item.status = "in_progress".to_string(); item.revised_prompt = None; @@ -280,14 +403,18 @@ pub(crate) async fn handle_output_item_done( } ctx.sess - .emit_turn_item_completed(&ctx.turn_context, turn_item) + .emit_turn_item_completed(&ctx.turn_context, finalized_turn_item.turn_item) .await; } - record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item) - .await; - let last_agent_message = last_assistant_message_from_item(&item, plan_mode); + record_completed_response_item_with_finalized_facts( + ctx.sess.as_ref(), + ctx.turn_context.as_ref(), + &item, + finalized_facts.as_ref(), + ) + .await; - output.last_agent_message = last_agent_message; + output.last_agent_message = finalized_facts.and_then(|facts| facts.last_agent_message); } // The tool request should be answered directly (or was denied); push that response into the transcript. Err(FunctionCallError::RespondToModel(message)) => { @@ -323,6 +450,7 @@ pub(crate) async fn handle_output_item_done( pub(crate) async fn handle_non_tool_response_item( sess: &Session, turn_context: &TurnContext, + contributor_policy: TurnItemContributorPolicy<'_>, item: &ResponseItem, plan_mode: bool, ) -> Option { @@ -334,6 +462,9 @@ pub(crate) async fn handle_non_tool_response_item( | ResponseItem::WebSearchCall { .. } | ResponseItem::ImageGenerationCall { .. } => { let mut turn_item = parse_turn_item(item)?; + if let TurnItemContributorPolicy::Run(turn_store) = contributor_policy { + apply_turn_item_contributors(sess, turn_store, &mut turn_item).await; + } if let TurnItem::AgentMessage(agent_message) = &mut turn_item { let combined = agent_message .content @@ -346,7 +477,9 @@ pub(crate) async fn handle_non_tool_response_item( strip_hidden_assistant_markup_and_parse_memory_citation(&combined, plan_mode); agent_message.content = vec![codex_protocol::items::AgentMessageContent::Text { text: stripped }]; - agent_message.memory_citation = memory_citation; + if agent_message.memory_citation.is_none() { + agent_message.memory_citation = memory_citation; + } } if let TurnItem::ImageGeneration(image_item) = &mut turn_item { let session_id = sess.conversation_id.to_string(); diff --git a/codex-rs/core/src/stream_events_utils_tests.rs b/codex-rs/core/src/stream_events_utils_tests.rs index 2012e05aa3..46c7ffed19 100644 --- a/codex-rs/core/src/stream_events_utils_tests.rs +++ b/codex-rs/core/src/stream_events_utils_tests.rs @@ -1,12 +1,24 @@ +use super::HandleOutputCtx; +use super::TurnItemContributorPolicy; use super::completed_item_defers_mailbox_delivery_to_next_turn; +use super::finalize_non_tool_response_item; use super::handle_non_tool_response_item; +use super::handle_output_item_done; use super::image_generation_artifact_path; use super::last_assistant_message_from_item; use super::response_item_may_include_external_context; use super::save_image_generation_result; use crate::session::tests::make_session_and_context; +use crate::tools::ToolRouter; +use crate::tools::parallel::ToolCallRuntime; +use crate::turn_diff_tracker::TurnDiffTracker; +use codex_extension_api::ExtensionData; +use codex_extension_api::TurnItemContributionFuture; +use codex_extension_api::TurnItemContributor; use codex_protocol::error::CodexErr; +use codex_protocol::items::AgentMessageContent; use codex_protocol::items::TurnItem; +use codex_protocol::memory_citation::MemoryCitation; use codex_protocol::models::ContentItem; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::LocalShellAction; @@ -16,6 +28,8 @@ use codex_protocol::models::MessagePhase; use codex_protocol::models::ResponseItem; use codex_utils_absolute_path::test_support::PathExt; use pretty_assertions::assert_eq; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; fn assistant_output_text(text: &str) -> ResponseItem { assistant_output_text_with_phase(text, /*phase*/ None) @@ -117,10 +131,15 @@ async fn handle_non_tool_response_item_strips_citations_from_assistant_message() "hello\nMEMORY.md:1-2|note=[x]\n\n\n019cc2ea-1dff-7902-8d40-c8f6e5d83cc4\n world", ); - let turn_item = - handle_non_tool_response_item(&session, &turn_context, &item, /*plan_mode*/ false) - .await - .expect("assistant message should parse"); + let turn_item = handle_non_tool_response_item( + &session, + &turn_context, + TurnItemContributorPolicy::Skip, + &item, + /*plan_mode*/ false, + ) + .await + .expect("assistant message should parse"); let TurnItem::AgentMessage(agent_message) = turn_item else { panic!("expected agent message"); @@ -144,6 +163,199 @@ async fn handle_non_tool_response_item_strips_citations_from_assistant_message() ); } +struct TestTurnItemContributor; + +#[derive(Debug)] +struct TurnItemContributorRan; + +impl TurnItemContributor for TestTurnItemContributor { + fn contribute<'a>( + &'a self, + _thread_store: &'a ExtensionData, + turn_store: &'a ExtensionData, + item: &'a mut TurnItem, + ) -> TurnItemContributionFuture<'a> { + Box::pin(async move { + turn_store.insert(TurnItemContributorRan); + if let TurnItem::AgentMessage(agent_message) = item { + agent_message.memory_citation = Some(MemoryCitation { + entries: Vec::new(), + rollout_ids: Vec::new(), + }); + } + Ok(()) + }) + } +} + +struct RewriteAgentMessageContributor; + +impl TurnItemContributor for RewriteAgentMessageContributor { + fn contribute<'a>( + &'a self, + _thread_store: &'a ExtensionData, + _turn_store: &'a ExtensionData, + item: &'a mut TurnItem, + ) -> TurnItemContributionFuture<'a> { + Box::pin(async move { + if let TurnItem::AgentMessage(agent_message) = item { + agent_message.content = vec![AgentMessageContent::Text { + text: "contributed assistant text".to_string(), + }]; + } + Ok(()) + }) + } +} + +#[tokio::test] +async fn handle_non_tool_response_item_runs_turn_item_contributors_only_when_requested() { + let (mut session, turn_context) = make_session_and_context().await; + let mut builder = codex_extension_api::ExtensionRegistryBuilder::new(); + builder.turn_item_contributor(Arc::new(TestTurnItemContributor)); + session.services.extensions = Arc::new(builder.build()); + let turn_store = ExtensionData::new(); + let item = assistant_output_text( + "helloignored by memory parser world", + ); + + let provisional_turn_item = handle_non_tool_response_item( + &session, + &turn_context, + TurnItemContributorPolicy::Skip, + &item, + /*plan_mode*/ false, + ) + .await + .expect("assistant message should parse"); + + assert!(turn_store.get::().is_none()); + let TurnItem::AgentMessage(provisional_agent_message) = provisional_turn_item else { + panic!("expected agent message"); + }; + assert_eq!(provisional_agent_message.memory_citation, None); + + let turn_item = handle_non_tool_response_item( + &session, + &turn_context, + TurnItemContributorPolicy::Run(&turn_store), + &item, + /*plan_mode*/ false, + ) + .await + .expect("assistant message should parse"); + + assert!(turn_store.get::().is_some()); + let TurnItem::AgentMessage(agent_message) = turn_item else { + panic!("expected agent message"); + }; + assert!(agent_message.memory_citation.is_some()); + let text = agent_message + .content + .iter() + .map(|entry| match entry { + codex_protocol::items::AgentMessageContent::Text { text } => text.as_str(), + }) + .collect::(); + assert_eq!(text, "hello world"); +} + +#[tokio::test] +async fn handle_output_item_done_returns_contributed_last_agent_message() { + let (mut session, turn_context) = make_session_and_context().await; + let mut builder = codex_extension_api::ExtensionRegistryBuilder::new(); + builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor)); + session.services.extensions = Arc::new(builder.build()); + let session = Arc::new(session); + let turn_context = Arc::new(turn_context); + let router = Arc::new(ToolRouter::from_config( + &turn_context.tools_config, + crate::tools::router::ToolRouterParams { + mcp_tools: None, + deferred_mcp_tools: None, + discoverable_tools: None, + extension_tool_executors: Vec::new(), + dynamic_tools: turn_context.dynamic_tools.as_slice(), + }, + )); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let tool_runtime = ToolCallRuntime::new( + router, + Arc::clone(&session), + Arc::clone(&turn_context), + tracker, + ); + let item = assistant_output_text("original assistant text"); + let mut ctx = HandleOutputCtx { + sess: session, + turn_context, + turn_store: Arc::new(ExtensionData::new()), + tool_runtime, + cancellation_token: CancellationToken::new(), + }; + + let output = handle_output_item_done(&mut ctx, item, /*previously_active_item*/ None) + .await + .expect("assistant message should complete"); + + assert_eq!( + output.last_agent_message.as_deref(), + Some("contributed assistant text") + ); +} + +#[tokio::test] +async fn finalized_turn_item_defers_mailbox_for_contributed_visible_text() { + let (mut session, turn_context) = make_session_and_context().await; + let mut builder = codex_extension_api::ExtensionRegistryBuilder::new(); + builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor)); + session.services.extensions = Arc::new(builder.build()); + let turn_store = ExtensionData::new(); + let item = assistant_output_text("hidden only"); + + let finalized = finalize_non_tool_response_item( + &session, + &turn_context, + TurnItemContributorPolicy::Run(&turn_store), + &item, + /*plan_mode*/ false, + ) + .await + .expect("assistant message should parse"); + + assert_eq!( + finalized.facts.last_agent_message.as_deref(), + Some("contributed assistant text") + ); + assert!(finalized.facts.defers_mailbox_delivery_to_next_turn); +} + +#[tokio::test] +async fn finalized_turn_item_keeps_mailbox_open_for_commentary_text() { + let (mut session, turn_context) = make_session_and_context().await; + let mut builder = codex_extension_api::ExtensionRegistryBuilder::new(); + builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor)); + session.services.extensions = Arc::new(builder.build()); + let turn_store = ExtensionData::new(); + let item = assistant_output_text_with_phase("still working", Some(MessagePhase::Commentary)); + + let finalized = finalize_non_tool_response_item( + &session, + &turn_context, + TurnItemContributorPolicy::Run(&turn_store), + &item, + /*plan_mode*/ false, + ) + .await + .expect("assistant message should parse"); + + assert_eq!( + finalized.facts.last_agent_message.as_deref(), + Some("contributed assistant text") + ); + assert!(!finalized.facts.defers_mailbox_delivery_to_next_turn); +} + #[test] fn last_assistant_message_from_item_strips_citations_and_plan_blocks() { let item = assistant_output_text( diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 6da6d9ad9b..5e6b931500 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use std::time::Duration; use std::time::Instant; +use codex_extension_api::ExtensionData; use futures::future::BoxFuture; use tokio::select; use tokio::sync::Notify; @@ -153,17 +154,25 @@ fn bool_tag(value: bool) -> &'static str { #[derive(Clone)] pub(crate) struct SessionTaskContext { session: Arc, + turn_extension_data: Arc, } impl SessionTaskContext { - pub(crate) fn new(session: Arc) -> Self { - Self { session } + pub(crate) fn new(session: Arc, turn_extension_data: Arc) -> Self { + Self { + session, + turn_extension_data, + } } pub(crate) fn clone_session(&self) -> Arc { Arc::clone(&self.session) } + pub(crate) fn turn_extension_data(&self) -> Arc { + Arc::clone(&self.turn_extension_data) + } + pub(crate) fn auth_manager(&self) -> Arc { Arc::clone(&self.session.services.auth_manager) } @@ -362,7 +371,10 @@ impl Session { let turn = active.get_or_insert_with(ActiveTurn::default); debug_assert!(turn.tasks.is_empty()); let done_clone = Arc::clone(&done); - let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); + let session_ctx = Arc::new(SessionTaskContext::new( + Arc::clone(self), + Arc::clone(&turn_extension_data), + )); let ctx = Arc::clone(&turn_context); let task_for_run = Arc::clone(&task); let task_cancellation_token = cancellation_token.child_token(); @@ -829,7 +841,10 @@ impl Session { task.handle.abort(); - let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); + let session_ctx = Arc::new(SessionTaskContext::new( + Arc::clone(self), + Arc::clone(&task.turn_extension_data), + )); session_task .abort(session_ctx, Arc::clone(&task.turn_context)) .await; diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index 08c4933488..756a691f11 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -45,6 +45,7 @@ impl SessionTask for RegularTask { cancellation_token: CancellationToken, ) -> Option { let sess = session.clone_session(); + let turn_extension_data = session.turn_extension_data(); let run_turn_span = trace_span!("run_turn"); // Regular turns emit `TurnStarted` inline so first-turn lifecycle does // not wait on startup prewarm resolution. @@ -72,6 +73,7 @@ impl SessionTask for RegularTask { let last_agent_message = run_turn( Arc::clone(&sess), Arc::clone(&ctx), + Arc::clone(&turn_extension_data), next_input, prewarmed_client_session.take(), cancellation_token.child_token(),