Wire turn item contributors into stream output (#22494)

## Summary
- run registered TurnItemContributor hooks for parsed stream output
items
- plumb the active turn extension store into stream item handling
- preserve existing memory citation parsing as fallback after
contributors run

## Tests
- cargo test -p codex-core stream_events_utils -- --nocapture
- just fmt
- just fix -p codex-core
- git diff --check
This commit is contained in:
jif-oai
2026-05-14 14:48:17 +02:00
committed by GitHub
parent 6d65686313
commit 17cd321c32
7 changed files with 550 additions and 60 deletions

View File

@@ -6944,6 +6944,7 @@ async fn handle_output_item_done_records_image_save_history_message() {
let mut ctx = HandleOutputCtx {
sess: Arc::clone(&session),
turn_context: Arc::clone(&turn_context),
turn_store: Arc::new(codex_extension_api::ExtensionData::new()),
tool_runtime: test_tool_runtime(Arc::clone(&session), Arc::clone(&turn_context)),
cancellation_token: CancellationToken::new(),
};
@@ -6996,6 +6997,7 @@ async fn handle_output_item_done_skips_image_save_message_when_save_fails() {
let mut ctx = HandleOutputCtx {
sess: Arc::clone(&session),
turn_context: Arc::clone(&turn_context),
turn_store: Arc::new(codex_extension_api::ExtensionData::new()),
tool_runtime: test_tool_runtime(Arc::clone(&session), Arc::clone(&turn_context)),
cancellation_token: CancellationToken::new(),
};
@@ -8883,6 +8885,7 @@ async fn tool_calls_reopen_mailbox_delivery_for_current_turn() {
let mut ctx = HandleOutputCtx {
sess: Arc::clone(&sess),
turn_context: Arc::clone(&tc),
turn_store: Arc::new(codex_extension_api::ExtensionData::new()),
tool_runtime: test_tool_runtime(Arc::clone(&sess), Arc::clone(&tc)),
cancellation_token: CancellationToken::new(),
};

View File

@@ -44,12 +44,14 @@ use crate::session::PreviousTurnSettings;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::stream_events_utils::HandleOutputCtx;
use crate::stream_events_utils::TurnItemContributorPolicy;
use crate::stream_events_utils::finalize_non_tool_response_item;
use crate::stream_events_utils::handle_non_tool_response_item;
use crate::stream_events_utils::handle_output_item_done;
use crate::stream_events_utils::last_assistant_message_from_item;
use crate::stream_events_utils::mark_thread_memory_mode_polluted_if_external_context;
use crate::stream_events_utils::raw_assistant_output_text_from_item;
use crate::stream_events_utils::record_completed_response_item;
use crate::stream_events_utils::record_completed_response_item_with_finalized_facts;
use crate::tools::ToolRouter;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::parallel::ToolCallRuntime;
@@ -139,6 +141,7 @@ use tracing::warn;
pub(crate) async fn run_turn(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_extension_data: Arc<codex_extension_api::ExtensionData>,
input: Vec<UserInput>,
prewarmed_client_session: Option<ModelClientSession>,
cancellation_token: CancellationToken,
@@ -454,6 +457,7 @@ pub(crate) async fn run_turn(
match run_sampling_request(
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_extension_data),
Arc::clone(&turn_diff_tracker),
&mut client_session,
turn_metadata_header.as_deref(),
@@ -1009,6 +1013,7 @@ pub(crate) fn build_prompt(
async fn run_sampling_request(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_store: Arc<codex_extension_api::ExtensionData>,
turn_diff_tracker: SharedTurnDiffTracker,
client_session: &mut ModelClientSession,
turn_metadata_header: Option<&str>,
@@ -1065,6 +1070,7 @@ async fn run_sampling_request(
tool_runtime.clone(),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_store),
client_session,
turn_metadata_header,
Arc::clone(&turn_diff_tracker),
@@ -1747,6 +1753,7 @@ async fn emit_turn_item_in_plan_mode(
async fn handle_assistant_item_done_in_plan_mode(
sess: &Session,
turn_context: &TurnContext,
turn_store: &codex_extension_api::ExtensionData,
item: &ResponseItem,
state: &mut PlanModeStreamState,
previously_active_item: Option<&TurnItem>,
@@ -1757,21 +1764,38 @@ async fn handle_assistant_item_done_in_plan_mode(
{
maybe_complete_plan_item_from_message(sess, turn_context, state, item).await;
if let Some(turn_item) =
handle_non_tool_response_item(sess, turn_context, item, /*plan_mode*/ true).await
let mut finalized_facts = None;
if let Some(finalized_turn_item) = finalize_non_tool_response_item(
sess,
turn_context,
TurnItemContributorPolicy::Run(turn_store),
item,
/*plan_mode*/ true,
)
.await
{
finalized_facts = Some(finalized_turn_item.facts.clone());
emit_turn_item_in_plan_mode(
sess,
turn_context,
turn_item,
finalized_turn_item.turn_item,
previously_active_item,
state,
)
.await;
}
let final_last_agent_message = finalized_facts
.as_ref()
.and_then(|facts| facts.last_agent_message.clone());
record_completed_response_item(sess, turn_context, item).await;
if let Some(agent_message) = last_assistant_message_from_item(item, /*plan_mode*/ true) {
record_completed_response_item_with_finalized_facts(
sess,
turn_context,
item,
finalized_facts.as_ref(),
)
.await;
if let Some(agent_message) = final_last_agent_message {
*last_agent_message = Some(agent_message);
}
return true;
@@ -1817,6 +1841,7 @@ async fn try_run_sampling_request(
tool_runtime: ToolCallRuntime,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_store: Arc<codex_extension_api::ExtensionData>,
client_session: &mut ModelClientSession,
turn_metadata_header: Option<&str>,
turn_diff_tracker: SharedTurnDiffTracker,
@@ -1865,6 +1890,9 @@ async fn try_run_sampling_request(
let plan_mode = turn_context.collaboration_mode.mode == ModeKind::Plan;
let mut assistant_message_stream_parsers = AssistantMessageStreamParsers::new(plan_mode);
let mut plan_mode_state = plan_mode.then(|| PlanModeStreamState::new(&turn_context.sub_id));
let defer_streamed_turn_items_for_contributors =
!sess.services.extensions.turn_item_contributors().is_empty();
let mut active_item_is_streaming_to_client = false;
let receiving_span = trace_span!("receiving_stream");
let mut completed_response_id: Option<String> = None;
let outcome: CodexResult<SamplingRequestResult> = loop {
@@ -1917,7 +1945,13 @@ async fn try_run_sampling_request(
sess.send_event(&turn_context, event).await;
}
let previously_active_item = active_item.take();
if let Some(previous) = previously_active_item.as_ref()
let previously_streamed_item = if active_item_is_streaming_to_client {
previously_active_item
} else {
None
};
active_item_is_streaming_to_client = false;
if let Some(previous) = previously_streamed_item.as_ref()
&& matches!(previous, TurnItem::AgentMessage(_))
{
let item_id = previous.id();
@@ -1934,9 +1968,10 @@ async fn try_run_sampling_request(
&& handle_assistant_item_done_in_plan_mode(
&sess,
&turn_context,
turn_store.as_ref(),
&item,
state,
previously_active_item.as_ref(),
previously_streamed_item.as_ref(),
&mut last_agent_message,
)
.await
@@ -1947,6 +1982,7 @@ async fn try_run_sampling_request(
let mut ctx = HandleOutputCtx {
sess: sess.clone(),
turn_context: turn_context.clone(),
turn_store: Arc::clone(&turn_store),
tool_runtime: tool_runtime.clone(),
cancellation_token: cancellation_token.child_token(),
};
@@ -1971,7 +2007,7 @@ async fn try_run_sampling_request(
};
let output_result =
match handle_output_item_done(&mut ctx, item, previously_active_item)
match handle_output_item_done(&mut ctx, item, previously_streamed_item)
.instrument(handle_responses)
.await
{
@@ -2005,15 +2041,18 @@ async fn try_run_sampling_request(
if let Some(turn_item) = handle_non_tool_response_item(
sess.as_ref(),
turn_context.as_ref(),
TurnItemContributorPolicy::Skip,
&item,
plan_mode,
)
.await
{
let mut turn_item = turn_item;
let stream_item_to_client = !defer_streamed_turn_items_for_contributors;
let mut seeded_parsed: Option<ParsedAssistantTextDelta> = None;
let mut seeded_item_id: Option<String> = None;
if matches!(turn_item, TurnItem::AgentMessage(_))
if stream_item_to_client
&& matches!(turn_item, TurnItem::AgentMessage(_))
&& let Some(raw_text) = raw_assistant_output_text_from_item(&item)
{
let item_id = turn_item.id();
@@ -2032,31 +2071,34 @@ async fn try_run_sampling_request(
seeded_parsed = plan_mode.then_some(seeded);
seeded_item_id = Some(item_id);
}
if let Some(state) = plan_mode_state.as_mut()
&& matches!(turn_item, TurnItem::AgentMessage(_))
{
let item_id = turn_item.id();
state
.pending_agent_message_items
.insert(item_id, turn_item.clone());
} else {
sess.emit_turn_item_started(&turn_context, &turn_item).await;
}
if let (Some(state), Some(item_id), Some(parsed)) = (
plan_mode_state.as_mut(),
seeded_item_id.as_deref(),
seeded_parsed,
) {
emit_streamed_assistant_text_delta(
&sess,
&turn_context,
Some(state),
item_id,
parsed,
)
.await;
if stream_item_to_client {
if let Some(state) = plan_mode_state.as_mut()
&& matches!(turn_item, TurnItem::AgentMessage(_))
{
let item_id = turn_item.id();
state
.pending_agent_message_items
.insert(item_id, turn_item.clone());
} else {
sess.emit_turn_item_started(&turn_context, &turn_item).await;
}
if let (Some(state), Some(item_id), Some(parsed)) = (
plan_mode_state.as_mut(),
seeded_item_id.as_deref(),
seeded_parsed,
) {
emit_streamed_assistant_text_delta(
&sess,
&turn_context,
Some(state),
item_id,
parsed,
)
.await;
}
}
active_item = Some(turn_item);
active_item_is_streaming_to_client = stream_item_to_client;
}
}
ResponseEvent::ServerModel(server_model) => {
@@ -2123,6 +2165,9 @@ async fn try_run_sampling_request(
// In review child threads, suppress assistant text deltas; the
// UI will show a selection popup from the final ReviewOutput.
if let Some(active) = active_item.as_ref() {
if !active_item_is_streaming_to_client {
continue;
}
let item_id = active.id();
if matches!(active, TurnItem::AgentMessage(_)) {
let parsed = assistant_message_stream_parsers.parse_delta(&item_id, &delta);
@@ -2171,6 +2216,9 @@ async fn try_run_sampling_request(
summary_index,
} => {
if let Some(active) = active_item.as_ref() {
if !active_item_is_streaming_to_client {
continue;
}
let event = ReasoningContentDeltaEvent {
thread_id: sess.conversation_id.to_string(),
turn_id: turn_context.sub_id.clone(),
@@ -2186,6 +2234,9 @@ async fn try_run_sampling_request(
}
ResponseEvent::ReasoningSummaryPartAdded { summary_index } => {
if let Some(active) = active_item.as_ref() {
if !active_item_is_streaming_to_client {
continue;
}
let event =
EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {
item_id: active.id(),
@@ -2201,6 +2252,9 @@ async fn try_run_sampling_request(
content_index,
} => {
if let Some(active) = active_item.as_ref() {
if !active_item_is_streaming_to_client {
continue;
}
let event = ReasoningRawContentDeltaEvent {
thread_id: sess.conversation_id.to_string(),
turn_id: turn_context.sub_id.clone(),
@@ -2270,3 +2324,7 @@ pub(crate) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
}
None
}
#[cfg(test)]
#[path = "turn_tests.rs"]
mod tests;

View File

@@ -0,0 +1,67 @@
use super::*;
use codex_extension_api::ExtensionData;
use codex_extension_api::TurnItemContributionFuture;
use codex_extension_api::TurnItemContributor;
use codex_protocol::items::AgentMessageContent;
use pretty_assertions::assert_eq;
use std::sync::Arc;
struct RewriteAgentMessageContributor;
impl TurnItemContributor for RewriteAgentMessageContributor {
fn contribute<'a>(
&'a self,
_thread_store: &'a ExtensionData,
_turn_store: &'a ExtensionData,
item: &'a mut TurnItem,
) -> TurnItemContributionFuture<'a> {
Box::pin(async move {
if let TurnItem::AgentMessage(agent_message) = item {
agent_message.content = vec![AgentMessageContent::Text {
text: "plan contributed assistant text".to_string(),
}];
}
Ok(())
})
}
}
fn assistant_output_text(text: &str) -> ResponseItem {
ResponseItem::Message {
id: Some("msg-1".to_string()),
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
phase: None,
}
}
#[tokio::test]
async fn plan_mode_uses_contributed_turn_item_for_last_agent_message() {
let (mut session, turn_context) = crate::session::tests::make_session_and_context().await;
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor));
session.services.extensions = Arc::new(builder.build());
let turn_store = ExtensionData::new();
let mut state = PlanModeStreamState::new(&turn_context.sub_id);
let mut last_agent_message = None;
let item = assistant_output_text("original assistant text");
let handled = handle_assistant_item_done_in_plan_mode(
&session,
&turn_context,
&turn_store,
&item,
&mut state,
/*previously_active_item*/ None,
&mut last_agent_message,
)
.await;
assert!(handled);
assert_eq!(
last_agent_message.as_deref(),
Some("plan contributed assistant text")
);
}

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use codex_extension_api::ExtensionData;
use codex_protocol::config_types::ModeKind;
use codex_protocol::items::TurnItem;
use codex_utils_stream_parser::strip_citations;
@@ -20,6 +21,7 @@ use codex_memories_read::citations::parse_memory_citation;
use codex_memories_read::citations::thread_ids_from_memory_citation;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
use codex_protocol::memory_citation::MemoryCitation;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::MessagePhase;
@@ -31,6 +33,7 @@ use codex_utils_stream_parser::strip_proposed_plan_blocks;
use futures::Future;
use tracing::debug;
use tracing::instrument;
use tracing::warn;
const GENERATED_IMAGE_ARTIFACTS_DIR: &str = "generated_images";
@@ -127,22 +130,50 @@ pub(crate) async fn record_completed_response_item(
sess: &Session,
turn_context: &TurnContext,
item: &ResponseItem,
) {
record_completed_response_item_with_finalized_facts(
sess,
turn_context,
item,
/*finalized_facts*/ None,
)
.await;
}
pub(crate) async fn record_completed_response_item_with_finalized_facts(
sess: &Session,
turn_context: &TurnContext,
item: &ResponseItem,
finalized_facts: Option<&FinalizedTurnItemFacts>,
) {
sess.record_conversation_items(turn_context, std::slice::from_ref(item))
.await;
if completed_item_defers_mailbox_delivery_to_next_turn(
item,
turn_context.collaboration_mode.mode == ModeKind::Plan,
) {
let defers_mailbox_delivery = finalized_facts.map_or_else(
|| {
completed_item_defers_mailbox_delivery_to_next_turn(
item,
turn_context.collaboration_mode.mode == ModeKind::Plan,
)
},
|facts| facts.defers_mailbox_delivery_to_next_turn,
);
if defers_mailbox_delivery {
sess.defer_mailbox_delivery_to_next_turn(&turn_context.sub_id)
.await;
}
mark_thread_memory_mode_polluted_if_external_context(sess, turn_context, item).await;
let has_memory_citation = record_stage1_output_usage_and_detect_memory_citation(
sess.services.state_db.as_ref(),
item,
)
.await;
let has_memory_citation = if let Some(memory_citation) =
finalized_facts.and_then(|facts| facts.memory_citation.as_ref())
{
record_stage1_output_usage_for_memory_citation(
sess.services.state_db.as_ref(),
memory_citation,
)
.await
} else {
record_stage1_output_usage_and_detect_memory_citation(sess.services.state_db.as_ref(), item)
.await
};
if has_memory_citation {
sess.record_memory_citation_for_turn(&turn_context.sub_id)
.await;
@@ -188,7 +219,14 @@ async fn record_stage1_output_usage_and_detect_memory_citation(
let Some(memory_citation) = parse_memory_citation(citations) else {
return false;
};
let thread_ids = thread_ids_from_memory_citation(&memory_citation);
record_stage1_output_usage_for_memory_citation(state_db_ctx, &memory_citation).await
}
async fn record_stage1_output_usage_for_memory_citation(
state_db_ctx: Option<&state_db::StateDbHandle>,
memory_citation: &MemoryCitation,
) -> bool {
let thread_ids = thread_ids_from_memory_citation(memory_citation);
if thread_ids.is_empty() {
return true;
}
@@ -215,10 +253,91 @@ pub(crate) struct OutputItemResult {
pub(crate) struct HandleOutputCtx {
pub sess: Arc<Session>,
pub turn_context: Arc<TurnContext>,
pub turn_store: Arc<ExtensionData>,
pub tool_runtime: ToolCallRuntime,
pub cancellation_token: CancellationToken,
}
async fn apply_turn_item_contributors(
sess: &Session,
turn_store: &ExtensionData,
item: &mut TurnItem,
) {
let contributors = sess.services.extensions.turn_item_contributors().to_vec();
for contributor in contributors {
if let Err(err) = contributor
.contribute(&sess.services.thread_extension_data, turn_store, item)
.await
{
warn!("turn item contributor failed: {err}");
}
}
}
pub(crate) enum TurnItemContributorPolicy<'a> {
Skip,
Run(&'a ExtensionData),
}
pub(crate) struct FinalizedTurnItem {
pub(crate) turn_item: TurnItem,
pub(crate) facts: FinalizedTurnItemFacts,
}
#[derive(Clone, Default)]
pub(crate) struct FinalizedTurnItemFacts {
pub(crate) memory_citation: Option<MemoryCitation>,
pub(crate) last_agent_message: Option<String>,
pub(crate) defers_mailbox_delivery_to_next_turn: bool,
}
pub(crate) async fn finalize_non_tool_response_item(
sess: &Session,
turn_context: &TurnContext,
contributor_policy: TurnItemContributorPolicy<'_>,
item: &ResponseItem,
plan_mode: bool,
) -> Option<FinalizedTurnItem> {
let turn_item =
handle_non_tool_response_item(sess, turn_context, contributor_policy, item, plan_mode)
.await?;
let (memory_citation, last_agent_message, defers_mailbox_delivery_to_next_turn) =
match &turn_item {
TurnItem::AgentMessage(agent_message) => {
let combined = agent_message
.content
.iter()
.map(|entry| match entry {
codex_protocol::items::AgentMessageContent::Text { text } => text.as_str(),
})
.collect::<String>();
let last_agent_message = if combined.trim().is_empty() {
None
} else {
Some(combined)
};
let defers_mailbox_delivery_to_next_turn =
!matches!(agent_message.phase, Some(MessagePhase::Commentary))
&& last_agent_message.is_some();
(
agent_message.memory_citation.clone(),
last_agent_message,
defers_mailbox_delivery_to_next_turn,
)
}
TurnItem::ImageGeneration(_) => (None, None, true),
_ => (None, None, false),
};
Some(FinalizedTurnItem {
turn_item,
facts: FinalizedTurnItemFacts {
memory_citation,
last_agent_message,
defers_mailbox_delivery_to_next_turn,
},
})
}
#[instrument(level = "trace", skip_all)]
pub(crate) async fn handle_output_item_done(
ctx: &mut HandleOutputCtx,
@@ -258,16 +377,20 @@ pub(crate) async fn handle_output_item_done(
}
// No tool call: convert messages/reasoning into turn items and mark them as complete.
Ok(None) => {
let turn_item = handle_non_tool_response_item(
let finalized_turn_item = finalize_non_tool_response_item(
ctx.sess.as_ref(),
ctx.turn_context.as_ref(),
TurnItemContributorPolicy::Run(ctx.turn_store.as_ref()),
&item,
plan_mode,
)
.await;
if let Some(turn_item) = turn_item {
let finalized_facts = finalized_turn_item
.as_ref()
.map(|finalized| finalized.facts.clone());
if let Some(finalized_turn_item) = finalized_turn_item {
if previously_active_item.is_none() {
let mut started_item = turn_item.clone();
let mut started_item = finalized_turn_item.turn_item.clone();
if let TurnItem::ImageGeneration(item) = &mut started_item {
item.status = "in_progress".to_string();
item.revised_prompt = None;
@@ -280,14 +403,18 @@ pub(crate) async fn handle_output_item_done(
}
ctx.sess
.emit_turn_item_completed(&ctx.turn_context, turn_item)
.emit_turn_item_completed(&ctx.turn_context, finalized_turn_item.turn_item)
.await;
}
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;
let last_agent_message = last_assistant_message_from_item(&item, plan_mode);
record_completed_response_item_with_finalized_facts(
ctx.sess.as_ref(),
ctx.turn_context.as_ref(),
&item,
finalized_facts.as_ref(),
)
.await;
output.last_agent_message = last_agent_message;
output.last_agent_message = finalized_facts.and_then(|facts| facts.last_agent_message);
}
// The tool request should be answered directly (or was denied); push that response into the transcript.
Err(FunctionCallError::RespondToModel(message)) => {
@@ -323,6 +450,7 @@ pub(crate) async fn handle_output_item_done(
pub(crate) async fn handle_non_tool_response_item(
sess: &Session,
turn_context: &TurnContext,
contributor_policy: TurnItemContributorPolicy<'_>,
item: &ResponseItem,
plan_mode: bool,
) -> Option<TurnItem> {
@@ -334,6 +462,9 @@ pub(crate) async fn handle_non_tool_response_item(
| ResponseItem::WebSearchCall { .. }
| ResponseItem::ImageGenerationCall { .. } => {
let mut turn_item = parse_turn_item(item)?;
if let TurnItemContributorPolicy::Run(turn_store) = contributor_policy {
apply_turn_item_contributors(sess, turn_store, &mut turn_item).await;
}
if let TurnItem::AgentMessage(agent_message) = &mut turn_item {
let combined = agent_message
.content
@@ -346,7 +477,9 @@ pub(crate) async fn handle_non_tool_response_item(
strip_hidden_assistant_markup_and_parse_memory_citation(&combined, plan_mode);
agent_message.content =
vec![codex_protocol::items::AgentMessageContent::Text { text: stripped }];
agent_message.memory_citation = memory_citation;
if agent_message.memory_citation.is_none() {
agent_message.memory_citation = memory_citation;
}
}
if let TurnItem::ImageGeneration(image_item) = &mut turn_item {
let session_id = sess.conversation_id.to_string();

View File

@@ -1,12 +1,24 @@
use super::HandleOutputCtx;
use super::TurnItemContributorPolicy;
use super::completed_item_defers_mailbox_delivery_to_next_turn;
use super::finalize_non_tool_response_item;
use super::handle_non_tool_response_item;
use super::handle_output_item_done;
use super::image_generation_artifact_path;
use super::last_assistant_message_from_item;
use super::response_item_may_include_external_context;
use super::save_image_generation_result;
use crate::session::tests::make_session_and_context;
use crate::tools::ToolRouter;
use crate::tools::parallel::ToolCallRuntime;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_extension_api::ExtensionData;
use codex_extension_api::TurnItemContributionFuture;
use codex_extension_api::TurnItemContributor;
use codex_protocol::error::CodexErr;
use codex_protocol::items::AgentMessageContent;
use codex_protocol::items::TurnItem;
use codex_protocol::memory_citation::MemoryCitation;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::LocalShellAction;
@@ -16,6 +28,8 @@ use codex_protocol::models::MessagePhase;
use codex_protocol::models::ResponseItem;
use codex_utils_absolute_path::test_support::PathExt;
use pretty_assertions::assert_eq;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
fn assistant_output_text(text: &str) -> ResponseItem {
assistant_output_text_with_phase(text, /*phase*/ None)
@@ -117,10 +131,15 @@ async fn handle_non_tool_response_item_strips_citations_from_assistant_message()
"hello<oai-mem-citation><citation_entries>\nMEMORY.md:1-2|note=[x]\n</citation_entries>\n<rollout_ids>\n019cc2ea-1dff-7902-8d40-c8f6e5d83cc4\n</rollout_ids></oai-mem-citation> world",
);
let turn_item =
handle_non_tool_response_item(&session, &turn_context, &item, /*plan_mode*/ false)
.await
.expect("assistant message should parse");
let turn_item = handle_non_tool_response_item(
&session,
&turn_context,
TurnItemContributorPolicy::Skip,
&item,
/*plan_mode*/ false,
)
.await
.expect("assistant message should parse");
let TurnItem::AgentMessage(agent_message) = turn_item else {
panic!("expected agent message");
@@ -144,6 +163,199 @@ async fn handle_non_tool_response_item_strips_citations_from_assistant_message()
);
}
struct TestTurnItemContributor;
#[derive(Debug)]
struct TurnItemContributorRan;
impl TurnItemContributor for TestTurnItemContributor {
fn contribute<'a>(
&'a self,
_thread_store: &'a ExtensionData,
turn_store: &'a ExtensionData,
item: &'a mut TurnItem,
) -> TurnItemContributionFuture<'a> {
Box::pin(async move {
turn_store.insert(TurnItemContributorRan);
if let TurnItem::AgentMessage(agent_message) = item {
agent_message.memory_citation = Some(MemoryCitation {
entries: Vec::new(),
rollout_ids: Vec::new(),
});
}
Ok(())
})
}
}
struct RewriteAgentMessageContributor;
impl TurnItemContributor for RewriteAgentMessageContributor {
fn contribute<'a>(
&'a self,
_thread_store: &'a ExtensionData,
_turn_store: &'a ExtensionData,
item: &'a mut TurnItem,
) -> TurnItemContributionFuture<'a> {
Box::pin(async move {
if let TurnItem::AgentMessage(agent_message) = item {
agent_message.content = vec![AgentMessageContent::Text {
text: "contributed assistant text".to_string(),
}];
}
Ok(())
})
}
}
#[tokio::test]
async fn handle_non_tool_response_item_runs_turn_item_contributors_only_when_requested() {
let (mut session, turn_context) = make_session_and_context().await;
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
builder.turn_item_contributor(Arc::new(TestTurnItemContributor));
session.services.extensions = Arc::new(builder.build());
let turn_store = ExtensionData::new();
let item = assistant_output_text(
"hello<oai-mem-citation>ignored by memory parser</oai-mem-citation> world",
);
let provisional_turn_item = handle_non_tool_response_item(
&session,
&turn_context,
TurnItemContributorPolicy::Skip,
&item,
/*plan_mode*/ false,
)
.await
.expect("assistant message should parse");
assert!(turn_store.get::<TurnItemContributorRan>().is_none());
let TurnItem::AgentMessage(provisional_agent_message) = provisional_turn_item else {
panic!("expected agent message");
};
assert_eq!(provisional_agent_message.memory_citation, None);
let turn_item = handle_non_tool_response_item(
&session,
&turn_context,
TurnItemContributorPolicy::Run(&turn_store),
&item,
/*plan_mode*/ false,
)
.await
.expect("assistant message should parse");
assert!(turn_store.get::<TurnItemContributorRan>().is_some());
let TurnItem::AgentMessage(agent_message) = turn_item else {
panic!("expected agent message");
};
assert!(agent_message.memory_citation.is_some());
let text = agent_message
.content
.iter()
.map(|entry| match entry {
codex_protocol::items::AgentMessageContent::Text { text } => text.as_str(),
})
.collect::<String>();
assert_eq!(text, "hello world");
}
#[tokio::test]
async fn handle_output_item_done_returns_contributed_last_agent_message() {
let (mut session, turn_context) = make_session_and_context().await;
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor));
session.services.extensions = Arc::new(builder.build());
let session = Arc::new(session);
let turn_context = Arc::new(turn_context);
let router = Arc::new(ToolRouter::from_config(
&turn_context.tools_config,
crate::tools::router::ToolRouterParams {
mcp_tools: None,
deferred_mcp_tools: None,
discoverable_tools: None,
extension_tool_executors: Vec::new(),
dynamic_tools: turn_context.dynamic_tools.as_slice(),
},
));
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let tool_runtime = ToolCallRuntime::new(
router,
Arc::clone(&session),
Arc::clone(&turn_context),
tracker,
);
let item = assistant_output_text("original assistant text");
let mut ctx = HandleOutputCtx {
sess: session,
turn_context,
turn_store: Arc::new(ExtensionData::new()),
tool_runtime,
cancellation_token: CancellationToken::new(),
};
let output = handle_output_item_done(&mut ctx, item, /*previously_active_item*/ None)
.await
.expect("assistant message should complete");
assert_eq!(
output.last_agent_message.as_deref(),
Some("contributed assistant text")
);
}
#[tokio::test]
async fn finalized_turn_item_defers_mailbox_for_contributed_visible_text() {
let (mut session, turn_context) = make_session_and_context().await;
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor));
session.services.extensions = Arc::new(builder.build());
let turn_store = ExtensionData::new();
let item = assistant_output_text("<oai-mem-citation>hidden only</oai-mem-citation>");
let finalized = finalize_non_tool_response_item(
&session,
&turn_context,
TurnItemContributorPolicy::Run(&turn_store),
&item,
/*plan_mode*/ false,
)
.await
.expect("assistant message should parse");
assert_eq!(
finalized.facts.last_agent_message.as_deref(),
Some("contributed assistant text")
);
assert!(finalized.facts.defers_mailbox_delivery_to_next_turn);
}
#[tokio::test]
async fn finalized_turn_item_keeps_mailbox_open_for_commentary_text() {
let (mut session, turn_context) = make_session_and_context().await;
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor));
session.services.extensions = Arc::new(builder.build());
let turn_store = ExtensionData::new();
let item = assistant_output_text_with_phase("still working", Some(MessagePhase::Commentary));
let finalized = finalize_non_tool_response_item(
&session,
&turn_context,
TurnItemContributorPolicy::Run(&turn_store),
&item,
/*plan_mode*/ false,
)
.await
.expect("assistant message should parse");
assert_eq!(
finalized.facts.last_agent_message.as_deref(),
Some("contributed assistant text")
);
assert!(!finalized.facts.defers_mailbox_delivery_to_next_turn);
}
#[test]
fn last_assistant_message_from_item_strips_citations_and_plan_blocks() {
let item = assistant_output_text(

View File

@@ -8,6 +8,7 @@ use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use codex_extension_api::ExtensionData;
use futures::future::BoxFuture;
use tokio::select;
use tokio::sync::Notify;
@@ -153,17 +154,25 @@ fn bool_tag(value: bool) -> &'static str {
#[derive(Clone)]
pub(crate) struct SessionTaskContext {
session: Arc<Session>,
turn_extension_data: Arc<ExtensionData>,
}
impl SessionTaskContext {
pub(crate) fn new(session: Arc<Session>) -> Self {
Self { session }
pub(crate) fn new(session: Arc<Session>, turn_extension_data: Arc<ExtensionData>) -> Self {
Self {
session,
turn_extension_data,
}
}
pub(crate) fn clone_session(&self) -> Arc<Session> {
Arc::clone(&self.session)
}
pub(crate) fn turn_extension_data(&self) -> Arc<ExtensionData> {
Arc::clone(&self.turn_extension_data)
}
pub(crate) fn auth_manager(&self) -> Arc<AuthManager> {
Arc::clone(&self.session.services.auth_manager)
}
@@ -362,7 +371,10 @@ impl Session {
let turn = active.get_or_insert_with(ActiveTurn::default);
debug_assert!(turn.tasks.is_empty());
let done_clone = Arc::clone(&done);
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
let session_ctx = Arc::new(SessionTaskContext::new(
Arc::clone(self),
Arc::clone(&turn_extension_data),
));
let ctx = Arc::clone(&turn_context);
let task_for_run = Arc::clone(&task);
let task_cancellation_token = cancellation_token.child_token();
@@ -829,7 +841,10 @@ impl Session {
task.handle.abort();
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
let session_ctx = Arc::new(SessionTaskContext::new(
Arc::clone(self),
Arc::clone(&task.turn_extension_data),
));
session_task
.abort(session_ctx, Arc::clone(&task.turn_context))
.await;

View File

@@ -45,6 +45,7 @@ impl SessionTask for RegularTask {
cancellation_token: CancellationToken,
) -> Option<String> {
let sess = session.clone_session();
let turn_extension_data = session.turn_extension_data();
let run_turn_span = trace_span!("run_turn");
// Regular turns emit `TurnStarted` inline so first-turn lifecycle does
// not wait on startup prewarm resolution.
@@ -72,6 +73,7 @@ impl SessionTask for RegularTask {
let last_agent_message = run_turn(
Arc::clone(&sess),
Arc::clone(&ctx),
Arc::clone(&turn_extension_data),
next_input,
prewarmed_client_session.take(),
cancellation_token.child_token(),