diff --git a/codex-rs/core/tests/suite/model_runtime_selectors.rs b/codex-rs/core/tests/suite/model_runtime_selectors.rs index b6f2ea3c7d..80910ed25f 100644 --- a/codex-rs/core/tests/suite/model_runtime_selectors.rs +++ b/codex-rs/core/tests/suite/model_runtime_selectors.rs @@ -9,6 +9,7 @@ use codex_protocol::openai_models::ModelInfo; use codex_protocol::openai_models::ModelPreset; use codex_protocol::openai_models::ModelVisibility; use codex_protocol::openai_models::ModelsResponse; +use codex_protocol::openai_models::MultiAgentVersion; use codex_protocol::openai_models::ToolMode; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::Op; @@ -56,6 +57,52 @@ fn tool_names(body: &Value) -> Vec { .unwrap_or_default() } +fn namespace_child_tool_names(body: &Value, namespace: &str) -> Vec { + body.get("tools") + .and_then(Value::as_array) + .and_then(|tools| { + tools.iter().find_map(|tool| { + if tool.get("type").and_then(Value::as_str) == Some("namespace") + && tool.get("name").and_then(Value::as_str) == Some(namespace) + { + tool.get("tools").and_then(Value::as_array).map(|children| { + children + .iter() + .filter_map(|child| { + child + .get("name") + .and_then(Value::as_str) + .map(str::to_string) + }) + .collect() + }) + } else { + None + } + }) + }) + .unwrap_or_default() +} + +fn selected_tool_names(body: &Value, selected: &[&str]) -> Vec { + tool_names(body) + .into_iter() + .filter(|name| selected.contains(&name.as_str())) + .collect() +} + +fn tool_description<'a>(body: &'a Value, name: &str) -> Option<&'a str> { + body.get("tools") + .and_then(Value::as_array) + .and_then(|tools| { + tools + .iter() + .find(|tool| tool.get("name").and_then(Value::as_str) == Some(name)) + }) + .and_then(|tool| tool.get("description")) + .and_then(Value::as_str) +} + async fn wait_for_model_available(manager: &SharedModelsManager, slug: &str) -> ModelPreset { let deadline = Instant::now() + Duration::from_secs(2); loop { @@ -171,3 +218,80 @@ async fn remote_tool_mode_selector_overrides_feature_flags() -> Result<()> { Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn remote_multi_agent_version_selector_overrides_feature_flags() -> Result<()> { + skip_if_no_network!(Ok(())); + + let mut v1_model = remote_model("test-multi-agent-v1"); + v1_model.multi_agent_version = Some(MultiAgentVersion::V1); + let v1_body = response_body_for_remote_model(v1_model, |config| { + config + .features + .enable(Feature::MultiAgentV2) + .expect("test config should allow feature update"); + }) + .await?; + assert_eq!( + namespace_child_tool_names(&v1_body, "multi_agent_v1"), + vec![ + "spawn_agent".to_string(), + "send_input".to_string(), + "resume_agent".to_string(), + "wait_agent".to_string(), + "close_agent".to_string(), + ] + ); + assert_eq!( + selected_tool_names(&v1_body, &["send_message", "followup_task", "list_agents"]), + Vec::::new() + ); + + let mut v2_model = remote_model("test-multi-agent-v2"); + v2_model.multi_agent_version = Some(MultiAgentVersion::V2); + let v2_body = response_body_for_remote_model(v2_model, |config| { + config + .features + .disable(Feature::Collab) + .expect("test config should allow feature update"); + config + .features + .disable(Feature::MultiAgentV2) + .expect("test config should allow feature update"); + config.multi_agent_v2.max_concurrent_threads_per_session = 17; + }) + .await?; + assert_eq!( + selected_tool_names( + &v2_body, + &[ + "spawn_agent", + "send_message", + "followup_task", + "wait_agent", + "close_agent", + "list_agents", + ], + ), + vec![ + "spawn_agent".to_string(), + "send_message".to_string(), + "followup_task".to_string(), + "wait_agent".to_string(), + "close_agent".to_string(), + "list_agents".to_string(), + ] + ); + assert_eq!( + namespace_child_tool_names(&v2_body, "multi_agent_v1"), + Vec::::new() + ); + assert!( + tool_description(&v2_body, "spawn_agent").is_some_and( + |description| description.contains("max_concurrent_threads_per_session = 17") + ), + "v2 spawn_agent should advertise the configured concurrency cap: {v2_body:?}" + ); + + Ok(()) +}