diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 1056f8afb0..b067467f2e 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -543,24 +543,17 @@ impl ToolRegistry { pub struct ToolRegistryBuilder { handlers: HashMap>, specs: Vec, - code_mode_enabled: bool, } impl ToolRegistryBuilder { - pub fn new(code_mode_enabled: bool) -> Self { + pub fn new() -> Self { Self { handlers: HashMap::new(), specs: Vec::new(), - code_mode_enabled, } } pub(crate) fn push_spec(&mut self, spec: ToolSpec) { - let spec = if self.code_mode_enabled { - codex_tools::augment_tool_spec_for_code_mode(spec) - } else { - spec - }; self.specs.push(spec); } @@ -597,7 +590,11 @@ impl ToolRegistryBuilder { self.handlers.insert(name, handler); } - pub fn register_tool_bundle(&mut self, bundle: ExtensionToolBundle) { + pub fn register_tool_bundle( + &mut self, + bundle: ExtensionToolBundle, + transform_spec: impl FnOnce(ToolSpec) -> ToolSpec, + ) { let tool_name = ToolName::plain(bundle.tool_name()); if self.handlers.contains_key(&tool_name) { warn!("Skipping extension tool `{tool_name}`: handler already registered"); @@ -613,7 +610,7 @@ impl ToolRegistryBuilder { return; } }; - self.push_spec(spec.clone()); + self.push_spec(transform_spec(spec.clone())); let handler: Arc = Arc::new(BundledToolHandler::new(bundle, spec)); self.handlers.insert(tool_name, handler); diff --git a/codex-rs/core/src/tools/registry_tests.rs b/codex-rs/core/src/tools/registry_tests.rs index d6acb80fb4..dc744321d1 100644 --- a/codex-rs/core/src/tools/registry_tests.rs +++ b/codex-rs/core/src/tools/registry_tests.rs @@ -63,16 +63,13 @@ fn handler_looks_up_namespaced_aliases_explicitly() { } #[test] -fn register_handler_adds_handler_and_augments_specs_for_code_mode() { - let mut builder = ToolRegistryBuilder::new(/*code_mode_enabled*/ true); +fn register_handler_adds_handler_and_spec() { + let mut builder = ToolRegistryBuilder::new(); builder.register_handler(Arc::new(GetGoalHandler)); let (specs, registry) = builder.build(); assert_eq!(specs.len(), 1); - assert_eq!( - specs[0], - codex_tools::augment_tool_spec_for_code_mode(create_get_goal_tool()) - ); + assert_eq!(specs[0], create_get_goal_tool()); assert!(registry.has_handler(&codex_tools::ToolName::plain(GET_GOAL_TOOL_NAME))); } diff --git a/codex-rs/core/src/tools/spec_plan.rs b/codex-rs/core/src/tools/spec_plan.rs index 0d50c8bbe6..abcb2d3aba 100644 --- a/codex-rs/core/src/tools/spec_plan.rs +++ b/codex-rs/core/src/tools/spec_plan.rs @@ -27,6 +27,7 @@ use crate::tools::handlers::ViewImageHandler; use crate::tools::handlers::WriteStdinHandler; use crate::tools::handlers::agent_jobs::ReportAgentJobResultHandler; use crate::tools::handlers::agent_jobs::SpawnAgentsOnCsvHandler; +use crate::tools::handlers::extension_tools::extension_tool_spec; use crate::tools::handlers::multi_agents::CloseAgentHandler; use crate::tools::handlers::multi_agents::ResumeAgentHandler; use crate::tools::handlers::multi_agents::SendInputHandler; @@ -67,7 +68,7 @@ pub fn build_tool_registry_builder( config: &ToolsConfig, params: ToolRegistryBuildParams<'_>, ) -> ToolRegistryBuilder { - let mut builder = ToolRegistryBuilder::new(config.code_mode_enabled); + let mut builder = ToolRegistryBuilder::new(); let all_deferred_tools = params .deferred_mcp_tools .into_iter() @@ -98,10 +99,16 @@ pub fn build_tool_registry_builder( ) }) .collect::>(); - let code_mode_nested_tool_specs = handlers + let mut code_mode_nested_tool_specs = handlers .iter() .filter_map(|handler| handler.spec()) .collect::>(); + code_mode_nested_tool_specs.extend( + params + .extension_tool_bundles + .iter() + .filter_map(|bundle| extension_tool_spec(bundle.spec()).ok()), + ); let mut enabled_tools = collect_code_mode_exec_prompt_tool_definitions(code_mode_nested_tool_specs.iter()); enabled_tools @@ -143,6 +150,11 @@ pub fn build_tool_registry_builder( if !config.namespace_tools && matches!(spec, ToolSpec::Namespace(_)) { continue; } + let spec = if config.code_mode_enabled { + codex_tools::augment_tool_spec_for_code_mode(spec) + } else { + spec + }; builder.push_spec(spec); } @@ -178,7 +190,13 @@ pub fn build_tool_registry_builder( } for bundle in params.extension_tool_bundles.iter().cloned() { - builder.register_tool_bundle(bundle); + builder.register_tool_bundle(bundle, |spec| { + if config.code_mode_enabled { + codex_tools::augment_tool_spec_for_code_mode(spec) + } else { + spec + } + }); } builder