From 1fc8aa0e169c74b571960f529baafc17d686beda Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 31 Mar 2026 13:06:08 +0200 Subject: [PATCH] feat: fork pattern v2 (#15771) Adds this: ``` properties.insert( "fork_turns".to_string(), JsonSchema::String { description: Some( "Optional MultiAgentV2 fork mode. Use `none`, `all`, or a positive integer string such as `3` to fork only the most recent turns." .to_string(), ), }, ); ``` --------- Co-authored-by: Codex --- codex-rs/core/src/agent/control.rs | 195 +++++++++++------- codex-rs/core/src/agent/control_tests.rs | 125 +++++++++++ .../core/src/thread_rollout_truncation.rs | 83 ++++++++ .../src/thread_rollout_truncation_tests.rs | 128 ++++++++++++ .../src/tools/handlers/multi_agents/spawn.rs | 2 + .../src/tools/handlers/multi_agents_tests.rs | 117 +++++++++++ .../tools/handlers/multi_agents_v2/spawn.rs | 48 ++++- codex-rs/core/src/tools/spec_tests.rs | 12 ++ codex-rs/tools/src/agent_tool.rs | 54 ++++- codex-rs/tools/src/agent_tool_tests.rs | 20 ++ 10 files changed, 702 insertions(+), 82 deletions(-) diff --git a/codex-rs/core/src/agent/control.rs b/codex-rs/core/src/agent/control.rs index 1d9ff45a60..c522acd1c1 100644 --- a/codex-rs/core/src/agent/control.rs +++ b/codex-rs/core/src/agent/control.rs @@ -15,6 +15,7 @@ use crate::session_prefix::format_subagent_notification_message; use crate::shell_snapshot::ShellSnapshot; use crate::state_db; use crate::thread_manager::ThreadManagerState; +use crate::thread_rollout_truncation::truncate_rollout_to_last_n_fork_turns; use codex_features::Feature; use codex_protocol::AgentPath; use codex_protocol::ThreadId; @@ -41,9 +42,16 @@ const AGENT_NAMES: &str = include_str!("agent_names.txt"); const FORKED_SPAWN_AGENT_OUTPUT_MESSAGE: &str = "You are the newly spawned agent. The prior conversation history was forked from your parent agent. Treat the next user message as your new task, and use the forked history only as background context."; const ROOT_LAST_TASK_MESSAGE: &str = "Main thread"; +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) enum SpawnAgentForkMode { + FullHistory, + LastNTurns(usize), +} + #[derive(Clone, Debug, Default)] pub(crate) struct SpawnAgentOptions { pub(crate) fork_parent_spawn_call_id: Option, + pub(crate) fork_mode: Option, } #[derive(Clone, Debug)] @@ -178,83 +186,32 @@ impl AgentControl { let notification_source = session_source.clone(); // The same `AgentControl` is sent to spawn the thread. - let new_thread = match session_source { - Some(session_source) => { - if let Some(call_id) = options.fork_parent_spawn_call_id.as_ref() { - let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - .. - }) = session_source.clone() - else { - return Err(CodexErr::Fatal( - "spawn_agent fork requires a thread-spawn session source".to_string(), - )); - }; - let parent_thread = state.get_thread(parent_thread_id).await.ok(); - if let Some(parent_thread) = parent_thread.as_ref() { - // `record_conversation_items` only queues rollout writes asynchronously. - // Flush/materialize the live parent before snapshotting JSONL for a fork. - parent_thread - .codex - .session - .ensure_rollout_materialized() - .await; - parent_thread.codex.session.flush_rollout().await; - } - let rollout_path = parent_thread - .as_ref() - .and_then(|parent_thread| parent_thread.rollout_path()) - .or(find_thread_path_by_id_str( - config.codex_home.as_path(), - &parent_thread_id.to_string(), - ) - .await?) - .ok_or_else(|| { - CodexErr::Fatal(format!( - "parent thread rollout unavailable for fork: {parent_thread_id}" - )) - })?; - let mut forked_rollout_items: Vec = - RolloutRecorder::get_rollout_history(&rollout_path) - .await? - .get_rollout_items(); - let mut output = FunctionCallOutputPayload::from_text( - FORKED_SPAWN_AGENT_OUTPUT_MESSAGE.to_string(), - ); - output.success = Some(true); - forked_rollout_items.push(RolloutItem::ResponseItem( - ResponseItem::FunctionCallOutput { - call_id: call_id.clone(), - output, - }, - )); - let initial_history = InitialHistory::Forked(forked_rollout_items); - state - .fork_thread_with_source( - config, - initial_history, - self.clone(), - session_source, - /*persist_extended_history*/ false, - inherited_shell_snapshot, - inherited_exec_policy, - ) - .await? - } else { - state - .spawn_new_thread_with_source( - config, - self.clone(), - session_source, - /*persist_extended_history*/ false, - /*metrics_service_name*/ None, - inherited_shell_snapshot, - inherited_exec_policy, - ) - .await? - } + let new_thread = match (session_source, options.fork_mode.as_ref()) { + (Some(session_source), Some(_)) => { + self.spawn_forked_thread( + &state, + config, + session_source, + &options, + inherited_shell_snapshot, + inherited_exec_policy, + ) + .await? } - None => state.spawn_new_thread(config, self.clone()).await?, + (Some(session_source), None) => { + state + .spawn_new_thread_with_source( + config, + self.clone(), + session_source, + /*persist_extended_history*/ false, + /*metrics_service_name*/ None, + inherited_shell_snapshot, + inherited_exec_policy, + ) + .await? + } + (None, _) => state.spawn_new_thread(config, self.clone()).await?, }; agent_metadata.agent_id = Some(new_thread.thread_id); reservation.commit(agent_metadata.clone()); @@ -294,6 +251,92 @@ impl AgentControl { }) } + async fn spawn_forked_thread( + &self, + state: &Arc, + config: crate::config::Config, + session_source: SessionSource, + options: &SpawnAgentOptions, + inherited_shell_snapshot: Option>, + inherited_exec_policy: Option>, + ) -> CodexResult { + let Some(call_id) = options.fork_parent_spawn_call_id.as_deref() else { + return Err(CodexErr::Fatal( + "spawn_agent fork requires a parent spawn call id".to_string(), + )); + }; + let Some(fork_mode) = options.fork_mode.as_ref() else { + return Err(CodexErr::Fatal( + "spawn_agent fork requires a fork mode".to_string(), + )); + }; + let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, .. + }) = &session_source + else { + return Err(CodexErr::Fatal( + "spawn_agent fork requires a thread-spawn session source".to_string(), + )); + }; + + let parent_thread_id = *parent_thread_id; + let parent_thread = state.get_thread(parent_thread_id).await.ok(); + if let Some(parent_thread) = parent_thread.as_ref() { + // `record_conversation_items` only queues rollout writes asynchronously. + // Flush/materialize the live parent before snapshotting JSONL for a fork. + parent_thread + .codex + .session + .ensure_rollout_materialized() + .await; + parent_thread.codex.session.flush_rollout().await; + } + + let rollout_path = parent_thread + .as_ref() + .and_then(|parent_thread| parent_thread.rollout_path()) + .or(find_thread_path_by_id_str( + config.codex_home.as_path(), + &parent_thread_id.to_string(), + ) + .await?) + .ok_or_else(|| { + CodexErr::Fatal(format!( + "parent thread rollout unavailable for fork: {parent_thread_id}" + )) + })?; + + let mut forked_rollout_items = RolloutRecorder::get_rollout_history(&rollout_path) + .await? + .get_rollout_items(); + if let SpawnAgentForkMode::LastNTurns(last_n_turns) = fork_mode { + forked_rollout_items = + truncate_rollout_to_last_n_fork_turns(&forked_rollout_items, *last_n_turns); + } + + let mut output = + FunctionCallOutputPayload::from_text(FORKED_SPAWN_AGENT_OUTPUT_MESSAGE.to_string()); + output.success = Some(true); + forked_rollout_items.push(RolloutItem::ResponseItem( + ResponseItem::FunctionCallOutput { + call_id: call_id.to_string(), + output, + }, + )); + + state + .fork_thread_with_source( + config, + InitialHistory::Forked(forked_rollout_items), + self.clone(), + session_source, + /*persist_extended_history*/ false, + inherited_shell_snapshot, + inherited_exec_policy, + ) + .await + } + /// Resume an existing agent thread from a recorded rollout file. pub(crate) async fn resume_agent_from_rollout( &self, diff --git a/codex-rs/core/src/agent/control_tests.rs b/codex-rs/core/src/agent/control_tests.rs index e440f98385..834ab2037b 100644 --- a/codex-rs/core/src/agent/control_tests.rs +++ b/codex-rs/core/src/agent/control_tests.rs @@ -602,6 +602,7 @@ async fn spawn_agent_can_fork_parent_thread_history() { })), SpawnAgentOptions { fork_parent_spawn_call_id: Some(parent_spawn_call_id), + fork_mode: Some(SpawnAgentForkMode::FullHistory), }, ) .await @@ -687,6 +688,7 @@ async fn spawn_agent_fork_injects_output_for_parent_spawn_call() { })), SpawnAgentOptions { fork_parent_spawn_call_id: Some(parent_spawn_call_id.clone()), + fork_mode: Some(SpawnAgentForkMode::FullHistory), }, ) .await @@ -759,6 +761,7 @@ async fn spawn_agent_fork_flushes_parent_rollout_before_loading_history() { })), SpawnAgentOptions { fork_parent_spawn_call_id: Some(parent_spawn_call_id.clone()), + fork_mode: Some(SpawnAgentForkMode::FullHistory), }, ) .await @@ -805,6 +808,128 @@ async fn spawn_agent_fork_flushes_parent_rollout_before_loading_history() { .expect("parent shutdown should submit"); } +#[tokio::test] +async fn spawn_agent_fork_last_n_turns_keeps_only_recent_turns() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + + parent_thread + .inject_user_message_without_turn("old parent context".to_string()) + .await; + let queued_communication = InterAgentCommunication::new( + AgentPath::root(), + AgentPath::try_from("/root/worker").expect("agent path"), + Vec::new(), + "queued message".to_string(), + /*trigger_turn*/ false, + ); + let queued_turn_context = parent_thread.codex.session.new_default_turn().await; + parent_thread + .codex + .session + .record_conversation_items( + queued_turn_context.as_ref(), + &[queued_communication.to_response_input_item().into()], + ) + .await; + + let triggered_communication = InterAgentCommunication::new( + AgentPath::root(), + AgentPath::try_from("/root/worker").expect("agent path"), + Vec::new(), + "triggered context".to_string(), + /*trigger_turn*/ true, + ); + let triggered_turn_context = parent_thread.codex.session.new_default_turn().await; + parent_thread + .codex + .session + .record_conversation_items( + triggered_turn_context.as_ref(), + &[triggered_communication.to_response_input_item().into()], + ) + .await; + + parent_thread + .inject_user_message_without_turn("current parent task".to_string()) + .await; + let spawn_turn_context = parent_thread.codex.session.new_default_turn().await; + let parent_spawn_call_id = "spawn-call-last-n".to_string(); + let parent_spawn_call = ResponseItem::FunctionCall { + id: None, + name: "spawn_agent".to_string(), + namespace: None, + arguments: "{}".to_string(), + call_id: parent_spawn_call_id.clone(), + }; + parent_thread + .codex + .session + .record_conversation_items(spawn_turn_context.as_ref(), &[parent_spawn_call]) + .await; + parent_thread + .codex + .session + .ensure_rollout_materialized() + .await; + parent_thread.codex.session.flush_rollout().await; + + let child_thread_id = harness + .control + .spawn_agent_with_metadata( + harness.config.clone(), + text_input("child task"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_path: None, + agent_nickname: None, + agent_role: None, + })), + SpawnAgentOptions { + fork_parent_spawn_call_id: Some(parent_spawn_call_id), + fork_mode: Some(SpawnAgentForkMode::LastNTurns(2)), + }, + ) + .await + .expect("forked spawn should keep only the last two turns") + .thread_id; + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + let history = child_thread.codex.session.clone_history().await; + + assert!(!history_contains_text( + history.raw_items(), + "old parent context" + )); + assert!(!history_contains_text( + history.raw_items(), + "queued message" + )); + assert!(history_contains_text( + history.raw_items(), + "triggered context" + )); + assert!(history_contains_text( + history.raw_items(), + "current parent task" + )); + + let _ = harness + .control + .shutdown_live_agent(child_thread_id) + .await + .expect("child shutdown should submit"); + let _ = parent_thread + .submit(Op::Shutdown {}) + .await + .expect("parent shutdown should submit"); +} + #[tokio::test] async fn spawn_agent_respects_max_threads_limit() { let max_threads = 1usize; diff --git a/codex-rs/core/src/thread_rollout_truncation.rs b/codex-rs/core/src/thread_rollout_truncation.rs index 5fa1881ffa..6931c44a9c 100644 --- a/codex-rs/core/src/thread_rollout_truncation.rs +++ b/codex-rs/core/src/thread_rollout_truncation.rs @@ -7,6 +7,7 @@ use crate::event_mapping; use codex_protocol::items::TurnItem; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::InterAgentCommunication; use codex_protocol::protocol::RolloutItem; /// Return the indices of user message boundaries in a rollout. @@ -40,6 +41,51 @@ pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec Vec { + let mut user_positions = Vec::new(); + let mut fork_turn_positions = Vec::new(); + for (idx, item) in items.iter().enumerate() { + match item { + RolloutItem::ResponseItem(item) if is_real_user_message_boundary(item) => { + user_positions.push(idx); + fork_turn_positions.push(idx); + } + RolloutItem::ResponseItem(item) if is_trigger_turn_boundary(item) => { + fork_turn_positions.push(idx); + } + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => { + let num_turns = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX); + if num_turns == 0 { + continue; + } + let Some(rollback_start_idx) = user_positions + .len() + .checked_sub(num_turns) + .map(|rollback_start| user_positions[rollback_start]) + .or_else(|| user_positions.first().copied()) + else { + continue; + }; + let new_user_len = user_positions.len().saturating_sub(num_turns); + user_positions.truncate(new_user_len); + fork_turn_positions.retain(|position| *position < rollback_start_idx); + } + _ => {} + } + } + fork_turn_positions +} + /// Return a prefix of `items` obtained by cutting strictly before the nth user message. /// /// The boundary index is 0-based from the start of `items` (so `n_from_start = 0` returns @@ -68,6 +114,43 @@ pub(crate) fn truncate_rollout_before_nth_user_message_from_start( items[..cut_idx].to_vec() } +/// Return a suffix of `items` that keeps the last `n_from_end` fork turns. +/// +/// If fewer than or equal to `n_from_end` fork turns exist, this returns the full rollout. +pub(crate) fn truncate_rollout_to_last_n_fork_turns( + items: &[RolloutItem], + n_from_end: usize, +) -> Vec { + if n_from_end == 0 { + return Vec::new(); + } + + let fork_turn_positions = fork_turn_positions_in_rollout(items); + if fork_turn_positions.len() <= n_from_end { + return items.to_vec(); + } + + let keep_idx = fork_turn_positions[fork_turn_positions.len() - n_from_end]; + items[keep_idx..].to_vec() +} + +fn is_real_user_message_boundary(item: &ResponseItem) -> bool { + matches!( + event_mapping::parse_turn_item(item), + Some(TurnItem::UserMessage(_)) + ) +} + +fn is_trigger_turn_boundary(item: &ResponseItem) -> bool { + let ResponseItem::Message { role, content, .. } = item else { + return false; + }; + + role == "assistant" + && InterAgentCommunication::from_message_content(content) + .is_some_and(|communication| communication.trigger_turn) +} + #[cfg(test)] #[path = "thread_rollout_truncation_tests.rs"] mod tests; diff --git a/codex-rs/core/src/thread_rollout_truncation_tests.rs b/codex-rs/core/src/thread_rollout_truncation_tests.rs index d251b722d9..bea66d1f80 100644 --- a/codex-rs/core/src/thread_rollout_truncation_tests.rs +++ b/codex-rs/core/src/thread_rollout_truncation_tests.rs @@ -1,7 +1,9 @@ use super::*; use crate::codex::make_session_and_context; +use codex_protocol::AgentPath; use codex_protocol::models::ContentItem; use codex_protocol::models::ReasoningItemReasoningSummary; +use codex_protocol::protocol::InterAgentCommunication; use codex_protocol::protocol::ThreadRolledBackEvent; use pretty_assertions::assert_eq; @@ -29,6 +31,17 @@ fn assistant_msg(text: &str) -> ResponseItem { } } +fn inter_agent_msg(text: &str, trigger_turn: bool) -> ResponseItem { + let communication = InterAgentCommunication::new( + AgentPath::root(), + AgentPath::try_from("/root/worker").expect("agent path"), + Vec::new(), + text.to_string(), + trigger_turn, + ); + communication.to_response_input_item().into() +} + #[test] fn truncates_rollout_from_start_before_nth_user_only() { let items = [ @@ -157,3 +170,118 @@ async fn ignores_session_prefix_messages_when_truncating_rollout_from_start() { serde_json::to_value(&expected).unwrap() ); } + +#[test] +fn truncates_rollout_to_last_n_fork_turns_counts_trigger_turn_messages() { + let rollout = vec![ + RolloutItem::ResponseItem(user_msg("u1")), + RolloutItem::ResponseItem(assistant_msg("a1")), + RolloutItem::ResponseItem(inter_agent_msg( + "queued message", + /*trigger_turn*/ false, + )), + RolloutItem::ResponseItem(assistant_msg("a2")), + RolloutItem::ResponseItem(inter_agent_msg( + "triggered task", + /*trigger_turn*/ true, + )), + RolloutItem::ResponseItem(assistant_msg("a3")), + RolloutItem::ResponseItem(user_msg("u2")), + RolloutItem::ResponseItem(assistant_msg("a4")), + ]; + + let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, /*n_from_end*/ 2); + let expected = rollout[4..].to_vec(); + + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&expected).unwrap() + ); +} + +#[test] +fn truncates_rollout_to_last_n_fork_turns_applies_thread_rollback_markers() { + let rollout = vec![ + RolloutItem::ResponseItem(user_msg("u1")), + RolloutItem::ResponseItem(assistant_msg("a1")), + RolloutItem::ResponseItem(inter_agent_msg( + "triggered task", + /*trigger_turn*/ true, + )), + RolloutItem::ResponseItem(assistant_msg("a2")), + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent { + num_turns: 1, + })), + RolloutItem::ResponseItem(user_msg("u2")), + RolloutItem::ResponseItem(assistant_msg("a3")), + ]; + + let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, /*n_from_end*/ 2); + + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&rollout).unwrap() + ); +} + +#[test] +fn fork_turn_positions_ignore_zero_turn_rollback_markers() { + let rollout = vec![ + RolloutItem::ResponseItem(user_msg("u1")), + RolloutItem::ResponseItem(inter_agent_msg( + "triggered task", + /*trigger_turn*/ true, + )), + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent { + num_turns: 0, + })), + RolloutItem::ResponseItem(user_msg("u2")), + ]; + + assert_eq!(fork_turn_positions_in_rollout(&rollout), vec![0, 1, 3]); +} + +#[test] +fn truncates_rollout_to_last_n_fork_turns_discards_trigger_boundaries_in_rolled_back_suffix() { + let rollout = vec![ + RolloutItem::ResponseItem(user_msg("u1")), + RolloutItem::ResponseItem(user_msg("u2")), + RolloutItem::ResponseItem(inter_agent_msg( + "triggered task", + /*trigger_turn*/ true, + )), + RolloutItem::ResponseItem(assistant_msg("a1")), + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent { + num_turns: 1, + })), + RolloutItem::ResponseItem(user_msg("u3")), + RolloutItem::ResponseItem(assistant_msg("a2")), + ]; + + let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, /*n_from_end*/ 2); + + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&rollout).unwrap() + ); +} + +#[test] +fn truncates_rollout_to_last_n_fork_turns_keeps_full_rollout_when_n_is_large() { + let rollout = vec![ + RolloutItem::ResponseItem(user_msg("u1")), + RolloutItem::ResponseItem(assistant_msg("a1")), + RolloutItem::ResponseItem(inter_agent_msg( + "triggered task", + /*trigger_turn*/ true, + )), + RolloutItem::ResponseItem(assistant_msg("a2")), + ]; + + let truncated = truncate_rollout_to_last_n_fork_turns(&rollout, /*n_from_end*/ 10); + + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&rollout).unwrap() + ); +} diff --git a/codex-rs/core/src/tools/handlers/multi_agents/spawn.rs b/codex-rs/core/src/tools/handlers/multi_agents/spawn.rs index 9bb7b6055c..308ec49d85 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents/spawn.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents/spawn.rs @@ -1,4 +1,5 @@ use super::*; +use crate::agent::control::SpawnAgentForkMode; use crate::agent::control::SpawnAgentOptions; use crate::agent::control::render_input_preview; use crate::agent::role::DEFAULT_ROLE_NAME; @@ -90,6 +91,7 @@ impl ToolHandler for Handler { )?), SpawnAgentOptions { fork_parent_spawn_call_id: args.fork_context.then(|| call_id.clone()), + fork_mode: args.fork_context.then_some(SpawnAgentForkMode::FullHistory), }, ) .await diff --git a/codex-rs/core/src/tools/handlers/multi_agents_tests.rs b/codex-rs/core/src/tools/handlers/multi_agents_tests.rs index e4e532dec1..9caf631dd8 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_tests.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_tests.rs @@ -480,6 +480,123 @@ async fn multi_agent_v2_spawn_returns_path_and_send_message_accepts_relative_pat })); } +#[tokio::test] +async fn multi_agent_v2_spawn_rejects_legacy_fork_context() { + let (mut session, mut turn) = make_session_and_context().await; + let manager = thread_manager(); + let root = manager + .start_thread((*turn.config).clone()) + .await + .expect("root thread should start"); + session.services.agent_control = manager.agent_control(); + session.conversation_id = root.thread_id; + let mut config = (*turn.config).clone(); + config + .features + .enable(Feature::MultiAgentV2) + .expect("test config should allow feature update"); + turn.config = Arc::new(config); + + let err = SpawnAgentHandlerV2 + .handle(invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({ + "message": "inspect this repo", + "task_name": "worker", + "fork_context": true + })), + )) + .await + .expect_err("legacy fork_context should be rejected"); + + assert_eq!( + err, + FunctionCallError::RespondToModel( + "fork_context is not supported in MultiAgentV2; use fork_turns instead".to_string() + ) + ); +} + +#[tokio::test] +async fn multi_agent_v2_spawn_rejects_invalid_fork_turns_string() { + let (mut session, mut turn) = make_session_and_context().await; + let manager = thread_manager(); + let root = manager + .start_thread((*turn.config).clone()) + .await + .expect("root thread should start"); + session.services.agent_control = manager.agent_control(); + session.conversation_id = root.thread_id; + let mut config = (*turn.config).clone(); + config + .features + .enable(Feature::MultiAgentV2) + .expect("test config should allow feature update"); + turn.config = Arc::new(config); + + let err = SpawnAgentHandlerV2 + .handle(invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({ + "message": "inspect this repo", + "task_name": "worker", + "fork_turns": "banana" + })), + )) + .await + .expect_err("invalid fork_turns should be rejected"); + + assert_eq!( + err, + FunctionCallError::RespondToModel( + "fork_turns must be `none`, `all`, or a positive integer string".to_string() + ) + ); +} + +#[tokio::test] +async fn multi_agent_v2_spawn_rejects_zero_fork_turns() { + let (mut session, mut turn) = make_session_and_context().await; + let manager = thread_manager(); + let root = manager + .start_thread((*turn.config).clone()) + .await + .expect("root thread should start"); + session.services.agent_control = manager.agent_control(); + session.conversation_id = root.thread_id; + let mut config = (*turn.config).clone(); + config + .features + .enable(Feature::MultiAgentV2) + .expect("test config should allow feature update"); + turn.config = Arc::new(config); + + let err = SpawnAgentHandlerV2 + .handle(invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({ + "message": "inspect this repo", + "task_name": "worker", + "fork_turns": "0" + })), + )) + .await + .expect_err("zero turn count should be rejected"); + + assert_eq!( + err, + FunctionCallError::RespondToModel( + "fork_turns must be `none`, `all`, or a positive integer string".to_string() + ) + ); +} + #[tokio::test] async fn multi_agent_v2_send_message_accepts_root_target_from_child() { let (mut session, mut turn) = make_session_and_context().await; diff --git a/codex-rs/core/src/tools/handlers/multi_agents_v2/spawn.rs b/codex-rs/core/src/tools/handlers/multi_agents_v2/spawn.rs index ffe128b43e..ac0beb6750 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_v2/spawn.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_v2/spawn.rs @@ -1,4 +1,5 @@ use super::*; +use crate::agent::control::SpawnAgentForkMode; use crate::agent::control::SpawnAgentOptions; use crate::agent::control::render_input_preview; use crate::agent::next_thread_spawn_depth; @@ -32,6 +33,7 @@ impl ToolHandler for Handler { } = invocation; let arguments = function_arguments(payload)?; let args: SpawnAgentArgs = parse_arguments(&arguments)?; + let fork_mode = args.fork_mode()?; let role_name = args .agent_type .as_deref() @@ -112,7 +114,8 @@ impl ToolHandler for Handler { }, Some(spawn_source), SpawnAgentOptions { - fork_parent_spawn_call_id: args.fork_context.then(|| call_id.clone()), + fork_parent_spawn_call_id: fork_mode.as_ref().map(|_| call_id.clone()), + fork_mode, }, ) .await @@ -204,8 +207,47 @@ struct SpawnAgentArgs { agent_type: Option, model: Option, reasoning_effort: Option, - #[serde(default)] - fork_context: bool, + fork_turns: Option, + fork_context: Option, +} + +impl SpawnAgentArgs { + fn fork_mode(&self) -> Result, FunctionCallError> { + if self.fork_context.is_some() { + return Err(FunctionCallError::RespondToModel( + "fork_context is not supported in MultiAgentV2; use fork_turns instead".to_string(), + )); + } + + let Some(fork_turns) = self + .fork_turns + .as_deref() + .map(str::trim) + .filter(|fork_turns| !fork_turns.is_empty()) + else { + return Ok(None); + }; + + if fork_turns.eq_ignore_ascii_case("none") { + return Ok(None); + } + if fork_turns.eq_ignore_ascii_case("all") { + return Ok(Some(SpawnAgentForkMode::FullHistory)); + } + + let last_n_turns = fork_turns.parse::().map_err(|_| { + FunctionCallError::RespondToModel( + "fork_turns must be `none`, `all`, or a positive integer string".to_string(), + ) + })?; + if last_n_turns == 0 { + return Err(FunctionCallError::RespondToModel( + "fork_turns must be `none`, `all`, or a positive integer string".to_string(), + )); + } + + Ok(Some(SpawnAgentForkMode::LastNTurns(last_n_turns))) + } } #[derive(Debug, Serialize)] diff --git a/codex-rs/core/src/tools/spec_tests.rs b/codex-rs/core/src/tools/spec_tests.rs index c914ee5172..7283512917 100644 --- a/codex-rs/core/src/tools/spec_tests.rs +++ b/codex-rs/core/src/tools/spec_tests.rs @@ -431,6 +431,16 @@ fn test_build_specs_collab_tools_enabled() { ); assert_lacks_tool_name(&tools, "spawn_agents_on_csv"); assert_lacks_tool_name(&tools, "list_agents"); + + let spawn_agent = find_tool(&tools, "spawn_agent"); + let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = &spawn_agent.spec else { + panic!("spawn_agent should be a function tool"); + }; + let JsonSchema::Object { properties, .. } = parameters else { + panic!("spawn_agent should use object params"); + }; + assert!(properties.contains_key("fork_context")); + assert!(!properties.contains_key("fork_turns")); } #[test] @@ -487,6 +497,8 @@ fn test_build_specs_multi_agent_v2_uses_task_names_and_hides_resume() { panic!("spawn_agent should use object params"); }; assert!(properties.contains_key("task_name")); + assert!(properties.contains_key("fork_turns")); + assert!(!properties.contains_key("fork_context")); assert_eq!(required.as_ref(), Some(&vec!["task_name".to_string()])); let output_schema = output_schema .as_ref() diff --git a/codex-rs/tools/src/agent_tool.rs b/codex-rs/tools/src/agent_tool.rs index 03b3df5c19..a278440b6b 100644 --- a/codex-rs/tools/src/agent_tool.rs +++ b/codex-rs/tools/src/agent_tool.rs @@ -23,7 +23,7 @@ pub fn create_spawn_agent_tool_v1(options: SpawnAgentToolOptions<'_>) -> ToolSpe let available_models_description = spawn_agent_models_description(options.available_models); let return_value_description = "Returns the spawned agent id plus the user-facing nickname when available."; - let properties = spawn_agent_common_properties(&options.agent_type_description); + let properties = spawn_agent_common_properties_v1(&options.agent_type_description); ToolSpec::Function(ResponsesApiTool { name: "spawn_agent".to_string(), @@ -45,7 +45,7 @@ pub fn create_spawn_agent_tool_v1(options: SpawnAgentToolOptions<'_>) -> ToolSpe pub fn create_spawn_agent_tool_v2(options: SpawnAgentToolOptions<'_>) -> ToolSpec { let available_models_description = spawn_agent_models_description(options.available_models); let return_value_description = "Returns the canonical task name for the spawned agent, plus the user-facing nickname when available."; - let mut properties = spawn_agent_common_properties(&options.agent_type_description); + let mut properties = spawn_agent_common_properties_v2(&options.agent_type_description); properties.insert( "task_name".to_string(), JsonSchema::String { @@ -544,7 +544,7 @@ fn create_collab_input_items_schema() -> JsonSchema { } } -fn spawn_agent_common_properties(agent_type_description: &str) -> BTreeMap { +fn spawn_agent_common_properties_v1(agent_type_description: &str) -> BTreeMap { BTreeMap::from([ ( "message".to_string(), @@ -592,6 +592,54 @@ fn spawn_agent_common_properties(agent_type_description: &str) -> BTreeMap BTreeMap { + BTreeMap::from([ + ( + "message".to_string(), + JsonSchema::String { + description: Some( + "Initial plain-text task for the new agent. Use either message or items." + .to_string(), + ), + }, + ), + ("items".to_string(), create_collab_input_items_schema()), + ( + "agent_type".to_string(), + JsonSchema::String { + description: Some(agent_type_description.to_string()), + }, + ), + ( + "fork_turns".to_string(), + JsonSchema::String { + description: Some( + "Optional MultiAgentV2 fork mode. Use `none`, `all`, or a positive integer string such as `3` to fork only the most recent turns." + .to_string(), + ), + }, + ), + ( + "model".to_string(), + JsonSchema::String { + description: Some( + "Optional model override for the new agent. Replaces the inherited model." + .to_string(), + ), + }, + ), + ( + "reasoning_effort".to_string(), + JsonSchema::String { + description: Some( + "Optional reasoning effort override for the new agent. Replaces the inherited reasoning effort." + .to_string(), + ), + }, + ), + ]) +} + fn spawn_agent_tool_description( available_models_description: &str, return_value_description: &str, diff --git a/codex-rs/tools/src/agent_tool_tests.rs b/codex-rs/tools/src/agent_tool_tests.rs index f483b97e08..ce312ba362 100644 --- a/codex-rs/tools/src/agent_tool_tests.rs +++ b/codex-rs/tools/src/agent_tool_tests.rs @@ -56,6 +56,8 @@ fn spawn_agent_tool_v2_requires_task_name_and_lists_visible_models() { assert!(description.contains("visible display (`visible-model`)")); assert!(!description.contains("hidden display (`hidden-model`)")); assert!(properties.contains_key("task_name")); + assert!(properties.contains_key("fork_turns")); + assert!(!properties.contains_key("fork_context")); assert_eq!( properties.get("agent_type"), Some(&JsonSchema::String { @@ -69,6 +71,24 @@ fn spawn_agent_tool_v2_requires_task_name_and_lists_visible_models() { ); } +#[test] +fn spawn_agent_tool_v1_keeps_legacy_fork_context_field() { + let tool = create_spawn_agent_tool_v1(SpawnAgentToolOptions { + available_models: &[], + agent_type_description: "role help".to_string(), + }); + + let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = tool else { + panic!("spawn_agent should be a function tool"); + }; + let JsonSchema::Object { properties, .. } = parameters else { + panic!("spawn_agent should use object params"); + }; + + assert!(properties.contains_key("fork_context")); + assert!(!properties.contains_key("fork_turns")); +} + #[test] fn send_message_tool_requires_items_and_uses_submission_output() { let ToolSpec::Function(ResponsesApiTool {