mirror of
https://github.com/openai/codex.git
synced 2026-05-29 23:40:29 +00:00
Wire turn item contributors into stream output (#22494)
## Summary - run registered TurnItemContributor hooks for parsed stream output items - plumb the active turn extension store into stream item handling - preserve existing memory citation parsing as fallback after contributors run ## Tests - cargo test -p codex-core stream_events_utils -- --nocapture - just fmt - just fix -p codex-core - git diff --check
This commit is contained in:
@@ -6944,6 +6944,7 @@ async fn handle_output_item_done_records_image_save_history_message() {
|
||||
let mut ctx = HandleOutputCtx {
|
||||
sess: Arc::clone(&session),
|
||||
turn_context: Arc::clone(&turn_context),
|
||||
turn_store: Arc::new(codex_extension_api::ExtensionData::new()),
|
||||
tool_runtime: test_tool_runtime(Arc::clone(&session), Arc::clone(&turn_context)),
|
||||
cancellation_token: CancellationToken::new(),
|
||||
};
|
||||
@@ -6996,6 +6997,7 @@ async fn handle_output_item_done_skips_image_save_message_when_save_fails() {
|
||||
let mut ctx = HandleOutputCtx {
|
||||
sess: Arc::clone(&session),
|
||||
turn_context: Arc::clone(&turn_context),
|
||||
turn_store: Arc::new(codex_extension_api::ExtensionData::new()),
|
||||
tool_runtime: test_tool_runtime(Arc::clone(&session), Arc::clone(&turn_context)),
|
||||
cancellation_token: CancellationToken::new(),
|
||||
};
|
||||
@@ -8883,6 +8885,7 @@ async fn tool_calls_reopen_mailbox_delivery_for_current_turn() {
|
||||
let mut ctx = HandleOutputCtx {
|
||||
sess: Arc::clone(&sess),
|
||||
turn_context: Arc::clone(&tc),
|
||||
turn_store: Arc::new(codex_extension_api::ExtensionData::new()),
|
||||
tool_runtime: test_tool_runtime(Arc::clone(&sess), Arc::clone(&tc)),
|
||||
cancellation_token: CancellationToken::new(),
|
||||
};
|
||||
|
||||
@@ -44,12 +44,14 @@ use crate::session::PreviousTurnSettings;
|
||||
use crate::session::session::Session;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::stream_events_utils::HandleOutputCtx;
|
||||
use crate::stream_events_utils::TurnItemContributorPolicy;
|
||||
use crate::stream_events_utils::finalize_non_tool_response_item;
|
||||
use crate::stream_events_utils::handle_non_tool_response_item;
|
||||
use crate::stream_events_utils::handle_output_item_done;
|
||||
use crate::stream_events_utils::last_assistant_message_from_item;
|
||||
use crate::stream_events_utils::mark_thread_memory_mode_polluted_if_external_context;
|
||||
use crate::stream_events_utils::raw_assistant_output_text_from_item;
|
||||
use crate::stream_events_utils::record_completed_response_item;
|
||||
use crate::stream_events_utils::record_completed_response_item_with_finalized_facts;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
@@ -139,6 +141,7 @@ use tracing::warn;
|
||||
pub(crate) async fn run_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_extension_data: Arc<codex_extension_api::ExtensionData>,
|
||||
input: Vec<UserInput>,
|
||||
prewarmed_client_session: Option<ModelClientSession>,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -454,6 +457,7 @@ pub(crate) async fn run_turn(
|
||||
match run_sampling_request(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_extension_data),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&mut client_session,
|
||||
turn_metadata_header.as_deref(),
|
||||
@@ -1009,6 +1013,7 @@ pub(crate) fn build_prompt(
|
||||
async fn run_sampling_request(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_store: Arc<codex_extension_api::ExtensionData>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
client_session: &mut ModelClientSession,
|
||||
turn_metadata_header: Option<&str>,
|
||||
@@ -1065,6 +1070,7 @@ async fn run_sampling_request(
|
||||
tool_runtime.clone(),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_store),
|
||||
client_session,
|
||||
turn_metadata_header,
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
@@ -1747,6 +1753,7 @@ async fn emit_turn_item_in_plan_mode(
|
||||
async fn handle_assistant_item_done_in_plan_mode(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_store: &codex_extension_api::ExtensionData,
|
||||
item: &ResponseItem,
|
||||
state: &mut PlanModeStreamState,
|
||||
previously_active_item: Option<&TurnItem>,
|
||||
@@ -1757,21 +1764,38 @@ async fn handle_assistant_item_done_in_plan_mode(
|
||||
{
|
||||
maybe_complete_plan_item_from_message(sess, turn_context, state, item).await;
|
||||
|
||||
if let Some(turn_item) =
|
||||
handle_non_tool_response_item(sess, turn_context, item, /*plan_mode*/ true).await
|
||||
let mut finalized_facts = None;
|
||||
if let Some(finalized_turn_item) = finalize_non_tool_response_item(
|
||||
sess,
|
||||
turn_context,
|
||||
TurnItemContributorPolicy::Run(turn_store),
|
||||
item,
|
||||
/*plan_mode*/ true,
|
||||
)
|
||||
.await
|
||||
{
|
||||
finalized_facts = Some(finalized_turn_item.facts.clone());
|
||||
emit_turn_item_in_plan_mode(
|
||||
sess,
|
||||
turn_context,
|
||||
turn_item,
|
||||
finalized_turn_item.turn_item,
|
||||
previously_active_item,
|
||||
state,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let final_last_agent_message = finalized_facts
|
||||
.as_ref()
|
||||
.and_then(|facts| facts.last_agent_message.clone());
|
||||
|
||||
record_completed_response_item(sess, turn_context, item).await;
|
||||
if let Some(agent_message) = last_assistant_message_from_item(item, /*plan_mode*/ true) {
|
||||
record_completed_response_item_with_finalized_facts(
|
||||
sess,
|
||||
turn_context,
|
||||
item,
|
||||
finalized_facts.as_ref(),
|
||||
)
|
||||
.await;
|
||||
if let Some(agent_message) = final_last_agent_message {
|
||||
*last_agent_message = Some(agent_message);
|
||||
}
|
||||
return true;
|
||||
@@ -1817,6 +1841,7 @@ async fn try_run_sampling_request(
|
||||
tool_runtime: ToolCallRuntime,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_store: Arc<codex_extension_api::ExtensionData>,
|
||||
client_session: &mut ModelClientSession,
|
||||
turn_metadata_header: Option<&str>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
@@ -1865,6 +1890,9 @@ async fn try_run_sampling_request(
|
||||
let plan_mode = turn_context.collaboration_mode.mode == ModeKind::Plan;
|
||||
let mut assistant_message_stream_parsers = AssistantMessageStreamParsers::new(plan_mode);
|
||||
let mut plan_mode_state = plan_mode.then(|| PlanModeStreamState::new(&turn_context.sub_id));
|
||||
let defer_streamed_turn_items_for_contributors =
|
||||
!sess.services.extensions.turn_item_contributors().is_empty();
|
||||
let mut active_item_is_streaming_to_client = false;
|
||||
let receiving_span = trace_span!("receiving_stream");
|
||||
let mut completed_response_id: Option<String> = None;
|
||||
let outcome: CodexResult<SamplingRequestResult> = loop {
|
||||
@@ -1917,7 +1945,13 @@ async fn try_run_sampling_request(
|
||||
sess.send_event(&turn_context, event).await;
|
||||
}
|
||||
let previously_active_item = active_item.take();
|
||||
if let Some(previous) = previously_active_item.as_ref()
|
||||
let previously_streamed_item = if active_item_is_streaming_to_client {
|
||||
previously_active_item
|
||||
} else {
|
||||
None
|
||||
};
|
||||
active_item_is_streaming_to_client = false;
|
||||
if let Some(previous) = previously_streamed_item.as_ref()
|
||||
&& matches!(previous, TurnItem::AgentMessage(_))
|
||||
{
|
||||
let item_id = previous.id();
|
||||
@@ -1934,9 +1968,10 @@ async fn try_run_sampling_request(
|
||||
&& handle_assistant_item_done_in_plan_mode(
|
||||
&sess,
|
||||
&turn_context,
|
||||
turn_store.as_ref(),
|
||||
&item,
|
||||
state,
|
||||
previously_active_item.as_ref(),
|
||||
previously_streamed_item.as_ref(),
|
||||
&mut last_agent_message,
|
||||
)
|
||||
.await
|
||||
@@ -1947,6 +1982,7 @@ async fn try_run_sampling_request(
|
||||
let mut ctx = HandleOutputCtx {
|
||||
sess: sess.clone(),
|
||||
turn_context: turn_context.clone(),
|
||||
turn_store: Arc::clone(&turn_store),
|
||||
tool_runtime: tool_runtime.clone(),
|
||||
cancellation_token: cancellation_token.child_token(),
|
||||
};
|
||||
@@ -1971,7 +2007,7 @@ async fn try_run_sampling_request(
|
||||
};
|
||||
|
||||
let output_result =
|
||||
match handle_output_item_done(&mut ctx, item, previously_active_item)
|
||||
match handle_output_item_done(&mut ctx, item, previously_streamed_item)
|
||||
.instrument(handle_responses)
|
||||
.await
|
||||
{
|
||||
@@ -2005,15 +2041,18 @@ async fn try_run_sampling_request(
|
||||
if let Some(turn_item) = handle_non_tool_response_item(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
TurnItemContributorPolicy::Skip,
|
||||
&item,
|
||||
plan_mode,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut turn_item = turn_item;
|
||||
let stream_item_to_client = !defer_streamed_turn_items_for_contributors;
|
||||
let mut seeded_parsed: Option<ParsedAssistantTextDelta> = None;
|
||||
let mut seeded_item_id: Option<String> = None;
|
||||
if matches!(turn_item, TurnItem::AgentMessage(_))
|
||||
if stream_item_to_client
|
||||
&& matches!(turn_item, TurnItem::AgentMessage(_))
|
||||
&& let Some(raw_text) = raw_assistant_output_text_from_item(&item)
|
||||
{
|
||||
let item_id = turn_item.id();
|
||||
@@ -2032,31 +2071,34 @@ async fn try_run_sampling_request(
|
||||
seeded_parsed = plan_mode.then_some(seeded);
|
||||
seeded_item_id = Some(item_id);
|
||||
}
|
||||
if let Some(state) = plan_mode_state.as_mut()
|
||||
&& matches!(turn_item, TurnItem::AgentMessage(_))
|
||||
{
|
||||
let item_id = turn_item.id();
|
||||
state
|
||||
.pending_agent_message_items
|
||||
.insert(item_id, turn_item.clone());
|
||||
} else {
|
||||
sess.emit_turn_item_started(&turn_context, &turn_item).await;
|
||||
}
|
||||
if let (Some(state), Some(item_id), Some(parsed)) = (
|
||||
plan_mode_state.as_mut(),
|
||||
seeded_item_id.as_deref(),
|
||||
seeded_parsed,
|
||||
) {
|
||||
emit_streamed_assistant_text_delta(
|
||||
&sess,
|
||||
&turn_context,
|
||||
Some(state),
|
||||
item_id,
|
||||
parsed,
|
||||
)
|
||||
.await;
|
||||
if stream_item_to_client {
|
||||
if let Some(state) = plan_mode_state.as_mut()
|
||||
&& matches!(turn_item, TurnItem::AgentMessage(_))
|
||||
{
|
||||
let item_id = turn_item.id();
|
||||
state
|
||||
.pending_agent_message_items
|
||||
.insert(item_id, turn_item.clone());
|
||||
} else {
|
||||
sess.emit_turn_item_started(&turn_context, &turn_item).await;
|
||||
}
|
||||
if let (Some(state), Some(item_id), Some(parsed)) = (
|
||||
plan_mode_state.as_mut(),
|
||||
seeded_item_id.as_deref(),
|
||||
seeded_parsed,
|
||||
) {
|
||||
emit_streamed_assistant_text_delta(
|
||||
&sess,
|
||||
&turn_context,
|
||||
Some(state),
|
||||
item_id,
|
||||
parsed,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
active_item = Some(turn_item);
|
||||
active_item_is_streaming_to_client = stream_item_to_client;
|
||||
}
|
||||
}
|
||||
ResponseEvent::ServerModel(server_model) => {
|
||||
@@ -2123,6 +2165,9 @@ async fn try_run_sampling_request(
|
||||
// In review child threads, suppress assistant text deltas; the
|
||||
// UI will show a selection popup from the final ReviewOutput.
|
||||
if let Some(active) = active_item.as_ref() {
|
||||
if !active_item_is_streaming_to_client {
|
||||
continue;
|
||||
}
|
||||
let item_id = active.id();
|
||||
if matches!(active, TurnItem::AgentMessage(_)) {
|
||||
let parsed = assistant_message_stream_parsers.parse_delta(&item_id, &delta);
|
||||
@@ -2171,6 +2216,9 @@ async fn try_run_sampling_request(
|
||||
summary_index,
|
||||
} => {
|
||||
if let Some(active) = active_item.as_ref() {
|
||||
if !active_item_is_streaming_to_client {
|
||||
continue;
|
||||
}
|
||||
let event = ReasoningContentDeltaEvent {
|
||||
thread_id: sess.conversation_id.to_string(),
|
||||
turn_id: turn_context.sub_id.clone(),
|
||||
@@ -2186,6 +2234,9 @@ async fn try_run_sampling_request(
|
||||
}
|
||||
ResponseEvent::ReasoningSummaryPartAdded { summary_index } => {
|
||||
if let Some(active) = active_item.as_ref() {
|
||||
if !active_item_is_streaming_to_client {
|
||||
continue;
|
||||
}
|
||||
let event =
|
||||
EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {
|
||||
item_id: active.id(),
|
||||
@@ -2201,6 +2252,9 @@ async fn try_run_sampling_request(
|
||||
content_index,
|
||||
} => {
|
||||
if let Some(active) = active_item.as_ref() {
|
||||
if !active_item_is_streaming_to_client {
|
||||
continue;
|
||||
}
|
||||
let event = ReasoningRawContentDeltaEvent {
|
||||
thread_id: sess.conversation_id.to_string(),
|
||||
turn_id: turn_context.sub_id.clone(),
|
||||
@@ -2270,3 +2324,7 @@ pub(crate) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "turn_tests.rs"]
|
||||
mod tests;
|
||||
|
||||
67
codex-rs/core/src/session/turn_tests.rs
Normal file
67
codex-rs/core/src/session/turn_tests.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use super::*;
|
||||
use codex_extension_api::ExtensionData;
|
||||
use codex_extension_api::TurnItemContributionFuture;
|
||||
use codex_extension_api::TurnItemContributor;
|
||||
use codex_protocol::items::AgentMessageContent;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::Arc;
|
||||
|
||||
struct RewriteAgentMessageContributor;
|
||||
|
||||
impl TurnItemContributor for RewriteAgentMessageContributor {
|
||||
fn contribute<'a>(
|
||||
&'a self,
|
||||
_thread_store: &'a ExtensionData,
|
||||
_turn_store: &'a ExtensionData,
|
||||
item: &'a mut TurnItem,
|
||||
) -> TurnItemContributionFuture<'a> {
|
||||
Box::pin(async move {
|
||||
if let TurnItem::AgentMessage(agent_message) = item {
|
||||
agent_message.content = vec![AgentMessageContent::Text {
|
||||
text: "plan contributed assistant text".to_string(),
|
||||
}];
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_output_text(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: Some("msg-1".to_string()),
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
phase: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn plan_mode_uses_contributed_turn_item_for_last_agent_message() {
|
||||
let (mut session, turn_context) = crate::session::tests::make_session_and_context().await;
|
||||
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
|
||||
builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor));
|
||||
session.services.extensions = Arc::new(builder.build());
|
||||
let turn_store = ExtensionData::new();
|
||||
let mut state = PlanModeStreamState::new(&turn_context.sub_id);
|
||||
let mut last_agent_message = None;
|
||||
let item = assistant_output_text("original assistant text");
|
||||
|
||||
let handled = handle_assistant_item_done_in_plan_mode(
|
||||
&session,
|
||||
&turn_context,
|
||||
&turn_store,
|
||||
&item,
|
||||
&mut state,
|
||||
/*previously_active_item*/ None,
|
||||
&mut last_agent_message,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(handled);
|
||||
assert_eq!(
|
||||
last_agent_message.as_deref(),
|
||||
Some("plan contributed assistant text")
|
||||
);
|
||||
}
|
||||
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
|
||||
use codex_extension_api::ExtensionData;
|
||||
use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::items::TurnItem;
|
||||
use codex_utils_stream_parser::strip_citations;
|
||||
@@ -20,6 +21,7 @@ use codex_memories_read::citations::parse_memory_citation;
|
||||
use codex_memories_read::citations::thread_ids_from_memory_citation;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result;
|
||||
use codex_protocol::memory_citation::MemoryCitation;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::MessagePhase;
|
||||
@@ -31,6 +33,7 @@ use codex_utils_stream_parser::strip_proposed_plan_blocks;
|
||||
use futures::Future;
|
||||
use tracing::debug;
|
||||
use tracing::instrument;
|
||||
use tracing::warn;
|
||||
|
||||
const GENERATED_IMAGE_ARTIFACTS_DIR: &str = "generated_images";
|
||||
|
||||
@@ -127,22 +130,50 @@ pub(crate) async fn record_completed_response_item(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
item: &ResponseItem,
|
||||
) {
|
||||
record_completed_response_item_with_finalized_facts(
|
||||
sess,
|
||||
turn_context,
|
||||
item,
|
||||
/*finalized_facts*/ None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn record_completed_response_item_with_finalized_facts(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
item: &ResponseItem,
|
||||
finalized_facts: Option<&FinalizedTurnItemFacts>,
|
||||
) {
|
||||
sess.record_conversation_items(turn_context, std::slice::from_ref(item))
|
||||
.await;
|
||||
if completed_item_defers_mailbox_delivery_to_next_turn(
|
||||
item,
|
||||
turn_context.collaboration_mode.mode == ModeKind::Plan,
|
||||
) {
|
||||
let defers_mailbox_delivery = finalized_facts.map_or_else(
|
||||
|| {
|
||||
completed_item_defers_mailbox_delivery_to_next_turn(
|
||||
item,
|
||||
turn_context.collaboration_mode.mode == ModeKind::Plan,
|
||||
)
|
||||
},
|
||||
|facts| facts.defers_mailbox_delivery_to_next_turn,
|
||||
);
|
||||
if defers_mailbox_delivery {
|
||||
sess.defer_mailbox_delivery_to_next_turn(&turn_context.sub_id)
|
||||
.await;
|
||||
}
|
||||
mark_thread_memory_mode_polluted_if_external_context(sess, turn_context, item).await;
|
||||
let has_memory_citation = record_stage1_output_usage_and_detect_memory_citation(
|
||||
sess.services.state_db.as_ref(),
|
||||
item,
|
||||
)
|
||||
.await;
|
||||
let has_memory_citation = if let Some(memory_citation) =
|
||||
finalized_facts.and_then(|facts| facts.memory_citation.as_ref())
|
||||
{
|
||||
record_stage1_output_usage_for_memory_citation(
|
||||
sess.services.state_db.as_ref(),
|
||||
memory_citation,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
record_stage1_output_usage_and_detect_memory_citation(sess.services.state_db.as_ref(), item)
|
||||
.await
|
||||
};
|
||||
if has_memory_citation {
|
||||
sess.record_memory_citation_for_turn(&turn_context.sub_id)
|
||||
.await;
|
||||
@@ -188,7 +219,14 @@ async fn record_stage1_output_usage_and_detect_memory_citation(
|
||||
let Some(memory_citation) = parse_memory_citation(citations) else {
|
||||
return false;
|
||||
};
|
||||
let thread_ids = thread_ids_from_memory_citation(&memory_citation);
|
||||
record_stage1_output_usage_for_memory_citation(state_db_ctx, &memory_citation).await
|
||||
}
|
||||
|
||||
async fn record_stage1_output_usage_for_memory_citation(
|
||||
state_db_ctx: Option<&state_db::StateDbHandle>,
|
||||
memory_citation: &MemoryCitation,
|
||||
) -> bool {
|
||||
let thread_ids = thread_ids_from_memory_citation(memory_citation);
|
||||
if thread_ids.is_empty() {
|
||||
return true;
|
||||
}
|
||||
@@ -215,10 +253,91 @@ pub(crate) struct OutputItemResult {
|
||||
pub(crate) struct HandleOutputCtx {
|
||||
pub sess: Arc<Session>,
|
||||
pub turn_context: Arc<TurnContext>,
|
||||
pub turn_store: Arc<ExtensionData>,
|
||||
pub tool_runtime: ToolCallRuntime,
|
||||
pub cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
async fn apply_turn_item_contributors(
|
||||
sess: &Session,
|
||||
turn_store: &ExtensionData,
|
||||
item: &mut TurnItem,
|
||||
) {
|
||||
let contributors = sess.services.extensions.turn_item_contributors().to_vec();
|
||||
for contributor in contributors {
|
||||
if let Err(err) = contributor
|
||||
.contribute(&sess.services.thread_extension_data, turn_store, item)
|
||||
.await
|
||||
{
|
||||
warn!("turn item contributor failed: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum TurnItemContributorPolicy<'a> {
|
||||
Skip,
|
||||
Run(&'a ExtensionData),
|
||||
}
|
||||
|
||||
pub(crate) struct FinalizedTurnItem {
|
||||
pub(crate) turn_item: TurnItem,
|
||||
pub(crate) facts: FinalizedTurnItemFacts,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct FinalizedTurnItemFacts {
|
||||
pub(crate) memory_citation: Option<MemoryCitation>,
|
||||
pub(crate) last_agent_message: Option<String>,
|
||||
pub(crate) defers_mailbox_delivery_to_next_turn: bool,
|
||||
}
|
||||
|
||||
pub(crate) async fn finalize_non_tool_response_item(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
contributor_policy: TurnItemContributorPolicy<'_>,
|
||||
item: &ResponseItem,
|
||||
plan_mode: bool,
|
||||
) -> Option<FinalizedTurnItem> {
|
||||
let turn_item =
|
||||
handle_non_tool_response_item(sess, turn_context, contributor_policy, item, plan_mode)
|
||||
.await?;
|
||||
let (memory_citation, last_agent_message, defers_mailbox_delivery_to_next_turn) =
|
||||
match &turn_item {
|
||||
TurnItem::AgentMessage(agent_message) => {
|
||||
let combined = agent_message
|
||||
.content
|
||||
.iter()
|
||||
.map(|entry| match entry {
|
||||
codex_protocol::items::AgentMessageContent::Text { text } => text.as_str(),
|
||||
})
|
||||
.collect::<String>();
|
||||
let last_agent_message = if combined.trim().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(combined)
|
||||
};
|
||||
let defers_mailbox_delivery_to_next_turn =
|
||||
!matches!(agent_message.phase, Some(MessagePhase::Commentary))
|
||||
&& last_agent_message.is_some();
|
||||
(
|
||||
agent_message.memory_citation.clone(),
|
||||
last_agent_message,
|
||||
defers_mailbox_delivery_to_next_turn,
|
||||
)
|
||||
}
|
||||
TurnItem::ImageGeneration(_) => (None, None, true),
|
||||
_ => (None, None, false),
|
||||
};
|
||||
Some(FinalizedTurnItem {
|
||||
turn_item,
|
||||
facts: FinalizedTurnItemFacts {
|
||||
memory_citation,
|
||||
last_agent_message,
|
||||
defers_mailbox_delivery_to_next_turn,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub(crate) async fn handle_output_item_done(
|
||||
ctx: &mut HandleOutputCtx,
|
||||
@@ -258,16 +377,20 @@ pub(crate) async fn handle_output_item_done(
|
||||
}
|
||||
// No tool call: convert messages/reasoning into turn items and mark them as complete.
|
||||
Ok(None) => {
|
||||
let turn_item = handle_non_tool_response_item(
|
||||
let finalized_turn_item = finalize_non_tool_response_item(
|
||||
ctx.sess.as_ref(),
|
||||
ctx.turn_context.as_ref(),
|
||||
TurnItemContributorPolicy::Run(ctx.turn_store.as_ref()),
|
||||
&item,
|
||||
plan_mode,
|
||||
)
|
||||
.await;
|
||||
if let Some(turn_item) = turn_item {
|
||||
let finalized_facts = finalized_turn_item
|
||||
.as_ref()
|
||||
.map(|finalized| finalized.facts.clone());
|
||||
if let Some(finalized_turn_item) = finalized_turn_item {
|
||||
if previously_active_item.is_none() {
|
||||
let mut started_item = turn_item.clone();
|
||||
let mut started_item = finalized_turn_item.turn_item.clone();
|
||||
if let TurnItem::ImageGeneration(item) = &mut started_item {
|
||||
item.status = "in_progress".to_string();
|
||||
item.revised_prompt = None;
|
||||
@@ -280,14 +403,18 @@ pub(crate) async fn handle_output_item_done(
|
||||
}
|
||||
|
||||
ctx.sess
|
||||
.emit_turn_item_completed(&ctx.turn_context, turn_item)
|
||||
.emit_turn_item_completed(&ctx.turn_context, finalized_turn_item.turn_item)
|
||||
.await;
|
||||
}
|
||||
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
|
||||
.await;
|
||||
let last_agent_message = last_assistant_message_from_item(&item, plan_mode);
|
||||
record_completed_response_item_with_finalized_facts(
|
||||
ctx.sess.as_ref(),
|
||||
ctx.turn_context.as_ref(),
|
||||
&item,
|
||||
finalized_facts.as_ref(),
|
||||
)
|
||||
.await;
|
||||
|
||||
output.last_agent_message = last_agent_message;
|
||||
output.last_agent_message = finalized_facts.and_then(|facts| facts.last_agent_message);
|
||||
}
|
||||
// The tool request should be answered directly (or was denied); push that response into the transcript.
|
||||
Err(FunctionCallError::RespondToModel(message)) => {
|
||||
@@ -323,6 +450,7 @@ pub(crate) async fn handle_output_item_done(
|
||||
pub(crate) async fn handle_non_tool_response_item(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
contributor_policy: TurnItemContributorPolicy<'_>,
|
||||
item: &ResponseItem,
|
||||
plan_mode: bool,
|
||||
) -> Option<TurnItem> {
|
||||
@@ -334,6 +462,9 @@ pub(crate) async fn handle_non_tool_response_item(
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::ImageGenerationCall { .. } => {
|
||||
let mut turn_item = parse_turn_item(item)?;
|
||||
if let TurnItemContributorPolicy::Run(turn_store) = contributor_policy {
|
||||
apply_turn_item_contributors(sess, turn_store, &mut turn_item).await;
|
||||
}
|
||||
if let TurnItem::AgentMessage(agent_message) = &mut turn_item {
|
||||
let combined = agent_message
|
||||
.content
|
||||
@@ -346,7 +477,9 @@ pub(crate) async fn handle_non_tool_response_item(
|
||||
strip_hidden_assistant_markup_and_parse_memory_citation(&combined, plan_mode);
|
||||
agent_message.content =
|
||||
vec![codex_protocol::items::AgentMessageContent::Text { text: stripped }];
|
||||
agent_message.memory_citation = memory_citation;
|
||||
if agent_message.memory_citation.is_none() {
|
||||
agent_message.memory_citation = memory_citation;
|
||||
}
|
||||
}
|
||||
if let TurnItem::ImageGeneration(image_item) = &mut turn_item {
|
||||
let session_id = sess.conversation_id.to_string();
|
||||
|
||||
@@ -1,12 +1,24 @@
|
||||
use super::HandleOutputCtx;
|
||||
use super::TurnItemContributorPolicy;
|
||||
use super::completed_item_defers_mailbox_delivery_to_next_turn;
|
||||
use super::finalize_non_tool_response_item;
|
||||
use super::handle_non_tool_response_item;
|
||||
use super::handle_output_item_done;
|
||||
use super::image_generation_artifact_path;
|
||||
use super::last_assistant_message_from_item;
|
||||
use super::response_item_may_include_external_context;
|
||||
use super::save_image_generation_result;
|
||||
use crate::session::tests::make_session_and_context;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use codex_extension_api::ExtensionData;
|
||||
use codex_extension_api::TurnItemContributionFuture;
|
||||
use codex_extension_api::TurnItemContributor;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::items::AgentMessageContent;
|
||||
use codex_protocol::items::TurnItem;
|
||||
use codex_protocol::memory_citation::MemoryCitation;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::LocalShellAction;
|
||||
@@ -16,6 +28,8 @@ use codex_protocol::models::MessagePhase;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_utils_absolute_path::test_support::PathExt;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::Arc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
fn assistant_output_text(text: &str) -> ResponseItem {
|
||||
assistant_output_text_with_phase(text, /*phase*/ None)
|
||||
@@ -117,10 +131,15 @@ async fn handle_non_tool_response_item_strips_citations_from_assistant_message()
|
||||
"hello<oai-mem-citation><citation_entries>\nMEMORY.md:1-2|note=[x]\n</citation_entries>\n<rollout_ids>\n019cc2ea-1dff-7902-8d40-c8f6e5d83cc4\n</rollout_ids></oai-mem-citation> world",
|
||||
);
|
||||
|
||||
let turn_item =
|
||||
handle_non_tool_response_item(&session, &turn_context, &item, /*plan_mode*/ false)
|
||||
.await
|
||||
.expect("assistant message should parse");
|
||||
let turn_item = handle_non_tool_response_item(
|
||||
&session,
|
||||
&turn_context,
|
||||
TurnItemContributorPolicy::Skip,
|
||||
&item,
|
||||
/*plan_mode*/ false,
|
||||
)
|
||||
.await
|
||||
.expect("assistant message should parse");
|
||||
|
||||
let TurnItem::AgentMessage(agent_message) = turn_item else {
|
||||
panic!("expected agent message");
|
||||
@@ -144,6 +163,199 @@ async fn handle_non_tool_response_item_strips_citations_from_assistant_message()
|
||||
);
|
||||
}
|
||||
|
||||
struct TestTurnItemContributor;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TurnItemContributorRan;
|
||||
|
||||
impl TurnItemContributor for TestTurnItemContributor {
|
||||
fn contribute<'a>(
|
||||
&'a self,
|
||||
_thread_store: &'a ExtensionData,
|
||||
turn_store: &'a ExtensionData,
|
||||
item: &'a mut TurnItem,
|
||||
) -> TurnItemContributionFuture<'a> {
|
||||
Box::pin(async move {
|
||||
turn_store.insert(TurnItemContributorRan);
|
||||
if let TurnItem::AgentMessage(agent_message) = item {
|
||||
agent_message.memory_citation = Some(MemoryCitation {
|
||||
entries: Vec::new(),
|
||||
rollout_ids: Vec::new(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct RewriteAgentMessageContributor;
|
||||
|
||||
impl TurnItemContributor for RewriteAgentMessageContributor {
|
||||
fn contribute<'a>(
|
||||
&'a self,
|
||||
_thread_store: &'a ExtensionData,
|
||||
_turn_store: &'a ExtensionData,
|
||||
item: &'a mut TurnItem,
|
||||
) -> TurnItemContributionFuture<'a> {
|
||||
Box::pin(async move {
|
||||
if let TurnItem::AgentMessage(agent_message) = item {
|
||||
agent_message.content = vec![AgentMessageContent::Text {
|
||||
text: "contributed assistant text".to_string(),
|
||||
}];
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_non_tool_response_item_runs_turn_item_contributors_only_when_requested() {
|
||||
let (mut session, turn_context) = make_session_and_context().await;
|
||||
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
|
||||
builder.turn_item_contributor(Arc::new(TestTurnItemContributor));
|
||||
session.services.extensions = Arc::new(builder.build());
|
||||
let turn_store = ExtensionData::new();
|
||||
let item = assistant_output_text(
|
||||
"hello<oai-mem-citation>ignored by memory parser</oai-mem-citation> world",
|
||||
);
|
||||
|
||||
let provisional_turn_item = handle_non_tool_response_item(
|
||||
&session,
|
||||
&turn_context,
|
||||
TurnItemContributorPolicy::Skip,
|
||||
&item,
|
||||
/*plan_mode*/ false,
|
||||
)
|
||||
.await
|
||||
.expect("assistant message should parse");
|
||||
|
||||
assert!(turn_store.get::<TurnItemContributorRan>().is_none());
|
||||
let TurnItem::AgentMessage(provisional_agent_message) = provisional_turn_item else {
|
||||
panic!("expected agent message");
|
||||
};
|
||||
assert_eq!(provisional_agent_message.memory_citation, None);
|
||||
|
||||
let turn_item = handle_non_tool_response_item(
|
||||
&session,
|
||||
&turn_context,
|
||||
TurnItemContributorPolicy::Run(&turn_store),
|
||||
&item,
|
||||
/*plan_mode*/ false,
|
||||
)
|
||||
.await
|
||||
.expect("assistant message should parse");
|
||||
|
||||
assert!(turn_store.get::<TurnItemContributorRan>().is_some());
|
||||
let TurnItem::AgentMessage(agent_message) = turn_item else {
|
||||
panic!("expected agent message");
|
||||
};
|
||||
assert!(agent_message.memory_citation.is_some());
|
||||
let text = agent_message
|
||||
.content
|
||||
.iter()
|
||||
.map(|entry| match entry {
|
||||
codex_protocol::items::AgentMessageContent::Text { text } => text.as_str(),
|
||||
})
|
||||
.collect::<String>();
|
||||
assert_eq!(text, "hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_output_item_done_returns_contributed_last_agent_message() {
|
||||
let (mut session, turn_context) = make_session_and_context().await;
|
||||
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
|
||||
builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor));
|
||||
session.services.extensions = Arc::new(builder.build());
|
||||
let session = Arc::new(session);
|
||||
let turn_context = Arc::new(turn_context);
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
crate::tools::router::ToolRouterParams {
|
||||
mcp_tools: None,
|
||||
deferred_mcp_tools: None,
|
||||
discoverable_tools: None,
|
||||
extension_tool_executors: Vec::new(),
|
||||
dynamic_tools: turn_context.dynamic_tools.as_slice(),
|
||||
},
|
||||
));
|
||||
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
router,
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
tracker,
|
||||
);
|
||||
let item = assistant_output_text("original assistant text");
|
||||
let mut ctx = HandleOutputCtx {
|
||||
sess: session,
|
||||
turn_context,
|
||||
turn_store: Arc::new(ExtensionData::new()),
|
||||
tool_runtime,
|
||||
cancellation_token: CancellationToken::new(),
|
||||
};
|
||||
|
||||
let output = handle_output_item_done(&mut ctx, item, /*previously_active_item*/ None)
|
||||
.await
|
||||
.expect("assistant message should complete");
|
||||
|
||||
assert_eq!(
|
||||
output.last_agent_message.as_deref(),
|
||||
Some("contributed assistant text")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn finalized_turn_item_defers_mailbox_for_contributed_visible_text() {
|
||||
let (mut session, turn_context) = make_session_and_context().await;
|
||||
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
|
||||
builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor));
|
||||
session.services.extensions = Arc::new(builder.build());
|
||||
let turn_store = ExtensionData::new();
|
||||
let item = assistant_output_text("<oai-mem-citation>hidden only</oai-mem-citation>");
|
||||
|
||||
let finalized = finalize_non_tool_response_item(
|
||||
&session,
|
||||
&turn_context,
|
||||
TurnItemContributorPolicy::Run(&turn_store),
|
||||
&item,
|
||||
/*plan_mode*/ false,
|
||||
)
|
||||
.await
|
||||
.expect("assistant message should parse");
|
||||
|
||||
assert_eq!(
|
||||
finalized.facts.last_agent_message.as_deref(),
|
||||
Some("contributed assistant text")
|
||||
);
|
||||
assert!(finalized.facts.defers_mailbox_delivery_to_next_turn);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn finalized_turn_item_keeps_mailbox_open_for_commentary_text() {
|
||||
let (mut session, turn_context) = make_session_and_context().await;
|
||||
let mut builder = codex_extension_api::ExtensionRegistryBuilder::new();
|
||||
builder.turn_item_contributor(Arc::new(RewriteAgentMessageContributor));
|
||||
session.services.extensions = Arc::new(builder.build());
|
||||
let turn_store = ExtensionData::new();
|
||||
let item = assistant_output_text_with_phase("still working", Some(MessagePhase::Commentary));
|
||||
|
||||
let finalized = finalize_non_tool_response_item(
|
||||
&session,
|
||||
&turn_context,
|
||||
TurnItemContributorPolicy::Run(&turn_store),
|
||||
&item,
|
||||
/*plan_mode*/ false,
|
||||
)
|
||||
.await
|
||||
.expect("assistant message should parse");
|
||||
|
||||
assert_eq!(
|
||||
finalized.facts.last_agent_message.as_deref(),
|
||||
Some("contributed assistant text")
|
||||
);
|
||||
assert!(!finalized.facts.defers_mailbox_delivery_to_next_turn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last_assistant_message_from_item_strips_citations_and_plan_blocks() {
|
||||
let item = assistant_output_text(
|
||||
|
||||
@@ -8,6 +8,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use codex_extension_api::ExtensionData;
|
||||
use futures::future::BoxFuture;
|
||||
use tokio::select;
|
||||
use tokio::sync::Notify;
|
||||
@@ -153,17 +154,25 @@ fn bool_tag(value: bool) -> &'static str {
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SessionTaskContext {
|
||||
session: Arc<Session>,
|
||||
turn_extension_data: Arc<ExtensionData>,
|
||||
}
|
||||
|
||||
impl SessionTaskContext {
|
||||
pub(crate) fn new(session: Arc<Session>) -> Self {
|
||||
Self { session }
|
||||
pub(crate) fn new(session: Arc<Session>, turn_extension_data: Arc<ExtensionData>) -> Self {
|
||||
Self {
|
||||
session,
|
||||
turn_extension_data,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn clone_session(&self) -> Arc<Session> {
|
||||
Arc::clone(&self.session)
|
||||
}
|
||||
|
||||
pub(crate) fn turn_extension_data(&self) -> Arc<ExtensionData> {
|
||||
Arc::clone(&self.turn_extension_data)
|
||||
}
|
||||
|
||||
pub(crate) fn auth_manager(&self) -> Arc<AuthManager> {
|
||||
Arc::clone(&self.session.services.auth_manager)
|
||||
}
|
||||
@@ -362,7 +371,10 @@ impl Session {
|
||||
let turn = active.get_or_insert_with(ActiveTurn::default);
|
||||
debug_assert!(turn.tasks.is_empty());
|
||||
let done_clone = Arc::clone(&done);
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(
|
||||
Arc::clone(self),
|
||||
Arc::clone(&turn_extension_data),
|
||||
));
|
||||
let ctx = Arc::clone(&turn_context);
|
||||
let task_for_run = Arc::clone(&task);
|
||||
let task_cancellation_token = cancellation_token.child_token();
|
||||
@@ -829,7 +841,10 @@ impl Session {
|
||||
|
||||
task.handle.abort();
|
||||
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(
|
||||
Arc::clone(self),
|
||||
Arc::clone(&task.turn_extension_data),
|
||||
));
|
||||
session_task
|
||||
.abort(session_ctx, Arc::clone(&task.turn_context))
|
||||
.await;
|
||||
|
||||
@@ -45,6 +45,7 @@ impl SessionTask for RegularTask {
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
let turn_extension_data = session.turn_extension_data();
|
||||
let run_turn_span = trace_span!("run_turn");
|
||||
// Regular turns emit `TurnStarted` inline so first-turn lifecycle does
|
||||
// not wait on startup prewarm resolution.
|
||||
@@ -72,6 +73,7 @@ impl SessionTask for RegularTask {
|
||||
let last_agent_message = run_turn(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&ctx),
|
||||
Arc::clone(&turn_extension_data),
|
||||
next_input,
|
||||
prewarmed_client_session.take(),
|
||||
cancellation_token.child_token(),
|
||||
|
||||
Reference in New Issue
Block a user