From 64062e4bcda800c0243a9b0fd4e47092c9ee10a0 Mon Sep 17 00:00:00 2001 From: Abhinav Vedmala Date: Wed, 20 May 2026 14:17:29 -0700 Subject: [PATCH] Default hooks for function tools --- .../core/src/tools/code_mode/wait_handler.rs | 6 +- .../src/tools/handlers/extension_tools.rs | 43 +----- codex-rs/core/src/tools/handlers/mcp.rs | 73 +--------- .../src/tools/handlers/request_permissions.rs | 6 +- .../handlers/unified_exec/write_stdin.rs | 4 + codex-rs/core/src/tools/parallel.rs | 1 + codex-rs/core/src/tools/registry.rs | 129 ++++++++++++++++-- codex-rs/core/src/tools/registry_tests.rs | 117 ++++++++++++++++ 8 files changed, 258 insertions(+), 121 deletions(-) diff --git a/codex-rs/core/src/tools/code_mode/wait_handler.rs b/codex-rs/core/src/tools/code_mode/wait_handler.rs index 725535339f..efd4ebe4e9 100644 --- a/codex-rs/core/src/tools/code_mode/wait_handler.rs +++ b/codex-rs/core/src/tools/code_mode/wait_handler.rs @@ -110,4 +110,8 @@ impl ToolExecutor for CodeModeWaitHandler { } } -impl CoreToolRuntime for CodeModeWaitHandler {} +impl CoreToolRuntime for CodeModeWaitHandler { + fn supports_default_function_tool_hooks(&self) -> bool { + false + } +} diff --git a/codex-rs/core/src/tools/handlers/extension_tools.rs b/codex-rs/core/src/tools/handlers/extension_tools.rs index dd2ac42ea4..ce0f41b7da 100644 --- a/codex-rs/core/src/tools/handlers/extension_tools.rs +++ b/codex-rs/core/src/tools/handlers/extension_tools.rs @@ -1,20 +1,14 @@ use std::sync::Arc; -use codex_tools::ToolCall as ExtensionToolCall; -use codex_tools::ToolName; -use codex_tools::ToolSpec; -use serde_json::Value; - use crate::function_tool::FunctionCallError; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; -use crate::tools::flat_tool_name; -use crate::tools::hook_names::HookToolName; use crate::tools::registry::CoreToolRuntime; -use crate::tools::registry::PostToolUsePayload; -use crate::tools::registry::PreToolUsePayload; use crate::tools::registry::ToolExecutor; +use codex_tools::ToolCall as ExtensionToolCall; +use codex_tools::ToolName; +use codex_tools::ToolSpec; pub(crate) struct ExtensionToolAdapter(Arc>); @@ -61,29 +55,6 @@ impl CoreToolRuntime for ExtensionToolAdapter { fn matches_kind(&self, payload: &ToolPayload) -> bool { self.arguments_from_payload(payload).is_some() } - - fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { - let arguments = self.arguments_from_payload(&invocation.payload)?; - Some(PreToolUsePayload { - tool_name: HookToolName::new(flat_tool_name(&self.tool_name()).into_owned()), - tool_input: extension_tool_hook_input(arguments), - }) - } - - fn post_tool_use_payload( - &self, - invocation: &ToolInvocation, - result: &dyn ToolOutput, - ) -> Option { - let arguments = self.arguments_from_payload(&invocation.payload)?; - Some(PostToolUsePayload { - tool_name: HookToolName::new(flat_tool_name(&self.tool_name()).into_owned()), - tool_use_id: invocation.call_id.clone(), - tool_input: extension_tool_hook_input(arguments), - tool_response: result - .post_tool_use_response(&invocation.call_id, &invocation.payload)?, - }) - } } fn to_extension_call(invocation: &ToolInvocation) -> ExtensionToolCall { @@ -96,14 +67,6 @@ fn to_extension_call(invocation: &ToolInvocation) -> ExtensionToolCall { } } -fn extension_tool_hook_input(arguments: &str) -> Value { - if arguments.trim().is_empty() { - return Value::Object(serde_json::Map::new()); - } - - serde_json::from_str(arguments).unwrap_or_else(|_| Value::String(arguments.to_string())) -} - #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/codex-rs/core/src/tools/handlers/mcp.rs b/codex-rs/core/src/tools/handlers/mcp.rs index 6f4f79505b..f32dc727e6 100644 --- a/codex-rs/core/src/tools/handlers/mcp.rs +++ b/codex-rs/core/src/tools/handlers/mcp.rs @@ -9,10 +9,7 @@ use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::context::boxed_tool_output; use crate::tools::flat_tool_name; -use crate::tools::hook_names::HookToolName; use crate::tools::registry::CoreToolRuntime; -use crate::tools::registry::PostToolUsePayload; -use crate::tools::registry::PreToolUsePayload; use crate::tools::registry::ToolExecutor; use crate::tools::registry::ToolExposure; use crate::tools::registry::ToolTelemetryTags; @@ -24,8 +21,6 @@ use codex_tools::ToolName; use codex_tools::ToolSearchSourceInfo; use codex_tools::ToolSpec; use codex_tools::mcp_tool_to_responses_api_tool; -use serde_json::Map; -use serde_json::Value; pub struct McpHandler { tool_info: ToolInfo, @@ -169,66 +164,6 @@ impl CoreToolRuntime for McpHandler { tags }) } - - fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { - let ToolPayload::Function { arguments } = &invocation.payload else { - return None; - }; - - Some(PreToolUsePayload { - tool_name: HookToolName::new(self.tool_name().to_string()), - tool_input: mcp_hook_tool_input(arguments), - }) - } - - fn with_updated_hook_input( - &self, - mut invocation: ToolInvocation, - updated_input: Value, - ) -> Result { - invocation.payload = match invocation.payload { - ToolPayload::Function { .. } => ToolPayload::Function { - arguments: serde_json::to_string(&updated_input).map_err(|err| { - FunctionCallError::RespondToModel(format!( - "failed to serialize rewritten MCP arguments: {err}" - )) - })?, - }, - payload => { - return Err(FunctionCallError::RespondToModel(format!( - "tool {} does not support hook input rewriting for payload {payload:?}", - self.tool_name() - ))); - } - }; - Ok(invocation) - } - fn post_tool_use_payload( - &self, - invocation: &ToolInvocation, - result: &dyn crate::tools::context::ToolOutput, - ) -> Option { - let ToolPayload::Function { .. } = &invocation.payload else { - return None; - }; - - let tool_response = - result.post_tool_use_response(&invocation.call_id, &invocation.payload)?; - Some(PostToolUsePayload { - tool_name: HookToolName::new(self.tool_name().to_string()), - tool_use_id: invocation.call_id.clone(), - tool_input: result.post_tool_use_input(&invocation.payload)?, - tool_response, - }) - } -} - -fn mcp_hook_tool_input(raw_arguments: &str) -> Value { - if raw_arguments.trim().is_empty() { - return Value::Object(Map::new()); - } - - serde_json::from_str(raw_arguments).unwrap_or_else(|_| Value::String(raw_arguments.to_string())) } fn build_mcp_search_text(info: &ToolInfo) -> String { @@ -288,6 +223,9 @@ mod tests { use super::*; use crate::session::tests::make_session_and_context; use crate::tools::context::ToolCallSource; + use crate::tools::hook_names::HookToolName; + use crate::tools::registry::PostToolUsePayload; + use crate::tools::registry::PreToolUsePayload; use crate::turn_diff_tracker::TurnDiffTracker; use pretty_assertions::assert_eq; use serde_json::json; @@ -443,11 +381,6 @@ mod tests { ); } - #[test] - fn mcp_hook_tool_input_defaults_empty_args_to_object() { - assert_eq!(mcp_hook_tool_input(" "), json!({})); - } - fn tool_info(server_name: &str, callable_namespace: &str, tool_name: &str) -> ToolInfo { ToolInfo { server_name: server_name.to_string(), diff --git a/codex-rs/core/src/tools/handlers/request_permissions.rs b/codex-rs/core/src/tools/handlers/request_permissions.rs index 007c8d2208..afb094acfd 100644 --- a/codex-rs/core/src/tools/handlers/request_permissions.rs +++ b/codex-rs/core/src/tools/handlers/request_permissions.rs @@ -84,4 +84,8 @@ impl ToolExecutor for RequestPermissionsHandler { } } -impl CoreToolRuntime for RequestPermissionsHandler {} +impl CoreToolRuntime for RequestPermissionsHandler { + fn supports_default_function_tool_hooks(&self) -> bool { + false + } +} diff --git a/codex-rs/core/src/tools/handlers/unified_exec/write_stdin.rs b/codex-rs/core/src/tools/handlers/unified_exec/write_stdin.rs index 77565eaac2..82651e6c1d 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec/write_stdin.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec/write_stdin.rs @@ -101,6 +101,10 @@ impl CoreToolRuntime for WriteStdinHandler { matches!(payload, ToolPayload::Function { .. }) } + fn supports_default_function_tool_hooks(&self) -> bool { + false + } + fn post_tool_use_payload( &self, invocation: &ToolInvocation, diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 15954869e2..9070695c47 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -204,6 +204,7 @@ impl ToolCallRuntime { result: Box::new(AbortedToolOutput { message: Self::abort_message(call, secs), }), + model_visible_override: None, post_tool_use_payload: None, } } diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 96837497f8..e0631ecff2 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -26,6 +26,7 @@ use crate::tools::tool_dispatch_trace::ToolDispatchTrace; use crate::tools::tool_search_entry::ToolSearchInfo; use crate::util::error_or_panic; use codex_extension_api::ToolCallOutcome; +use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; use codex_protocol::protocol::EventMsg; use codex_tools::ToolName; @@ -64,14 +65,22 @@ pub(crate) trait CoreToolRuntime: ToolExecutor { fn post_tool_use_payload( &self, - _invocation: &ToolInvocation, - _result: &dyn ToolOutput, + invocation: &ToolInvocation, + result: &dyn ToolOutput, ) -> Option { - None + if !self.supports_default_function_tool_hooks() { + return None; + } + + default_function_post_tool_use_payload(invocation, result) } - fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option { - None + fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { + if !self.supports_default_function_tool_hooks() { + return None; + } + + default_function_pre_tool_use_payload(invocation) } /// Rebuilds a tool invocation from hook-facing `tool_input`. @@ -80,14 +89,28 @@ pub(crate) trait CoreToolRuntime: ToolExecutor { /// hook contract they expose from `pre_tool_use_payload`. fn with_updated_hook_input( &self, - _invocation: ToolInvocation, - _updated_input: Value, + invocation: ToolInvocation, + updated_input: Value, ) -> Result { + if self.supports_default_function_tool_hooks() { + return rewrite_function_tool_hook_input(invocation, updated_input); + } + Err(FunctionCallError::RespondToModel( "tool does not support hook input rewriting".to_string(), )) } + /// Returns whether this tool uses the generic function-tool hook contract. + /// + /// Most local function tools expose their JSON arguments directly to hooks. + /// Tools with compatibility-specific hook contracts can override the hook + /// payload methods instead, while function tools that should not run hooks + /// can opt out here. + fn supports_default_function_tool_hooks(&self) -> bool { + true + } + /// Creates an optional consumer for streamed tool argument diffs. fn create_diff_consumer(&self) -> Option> { None @@ -111,6 +134,7 @@ pub(crate) struct AnyToolResult { pub(crate) call_id: String, pub(crate) payload: ToolPayload, pub(crate) result: Box, + pub(crate) model_visible_override: Option, pub(crate) post_tool_use_payload: Option, } @@ -120,9 +144,13 @@ impl AnyToolResult { call_id, payload, result, + model_visible_override, .. } = self; - result.to_response_item(&call_id, &payload) + model_visible_override.map_or_else( + || result.to_response_item(&call_id, &payload), + |output| output.to_response_item(&call_id, &payload), + ) } pub(crate) fn code_mode_result(self) -> serde_json::Value { @@ -234,6 +262,10 @@ impl CoreToolRuntime for ExposureOverride { .with_updated_hook_input(invocation, updated_input) } + fn supports_default_function_tool_hooks(&self) -> bool { + self.handler.supports_default_function_tool_hooks() + } + fn telemetry_tags<'a>( &'a self, invocation: &'a ToolInvocation, @@ -539,7 +571,7 @@ impl ToolRegistry { if let Some(replacement_text) = replacement_text { let mut guard = response_cell.lock().await; if let Some(result) = guard.as_mut() { - result.result = Box::new(FunctionToolOutput::from_text( + result.model_visible_override = Some(FunctionToolOutput::from_text( replacement_text, /*success*/ None, )); @@ -614,10 +646,89 @@ async fn handle_any_tool( call_id, payload, result: output, + model_visible_override: None, post_tool_use_payload, }) } +fn default_function_pre_tool_use_payload(invocation: &ToolInvocation) -> Option { + let ToolPayload::Function { arguments } = &invocation.payload else { + return None; + }; + + Some(PreToolUsePayload { + tool_name: function_hook_tool_name(invocation), + tool_input: function_hook_tool_input(arguments), + }) +} + +fn default_function_post_tool_use_payload( + invocation: &ToolInvocation, + result: &dyn ToolOutput, +) -> Option { + let ToolPayload::Function { arguments } = &invocation.payload else { + return None; + }; + + Some(PostToolUsePayload { + tool_name: function_hook_tool_name(invocation), + tool_use_id: result.post_tool_use_id(&invocation.call_id), + tool_input: result + .post_tool_use_input(&invocation.payload) + .unwrap_or_else(|| function_hook_tool_input(arguments)), + tool_response: result + .post_tool_use_response(&invocation.call_id, &invocation.payload) + .or_else(|| model_visible_function_tool_response(invocation, result))?, + }) +} + +fn rewrite_function_tool_hook_input( + mut invocation: ToolInvocation, + updated_input: Value, +) -> Result { + let ToolPayload::Function { .. } = &invocation.payload else { + return Err(FunctionCallError::RespondToModel( + "hook input rewrite received unsupported function tool payload".to_string(), + )); + }; + + let arguments = serde_json::to_string(&updated_input).map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to serialize rewritten {} arguments: {err}", + flat_tool_name(&invocation.tool_name) + )) + })?; + invocation.payload = ToolPayload::Function { arguments }; + Ok(invocation) +} + +fn function_hook_tool_name(invocation: &ToolInvocation) -> HookToolName { + HookToolName::new(flat_tool_name(&invocation.tool_name).into_owned()) +} + +fn function_hook_tool_input(arguments: &str) -> Value { + if arguments.trim().is_empty() { + return Value::Object(serde_json::Map::new()); + } + + serde_json::from_str(arguments).unwrap_or_else(|_| Value::String(arguments.to_string())) +} + +fn model_visible_function_tool_response( + invocation: &ToolInvocation, + result: &dyn ToolOutput, +) -> Option { + let ResponseInputItem::FunctionCallOutput { + output: FunctionCallOutputPayload { body, .. }, + .. + } = result.to_response_item(&invocation.call_id, &invocation.payload) + else { + return None; + }; + + serde_json::to_value(body).ok() +} + fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &ToolName) -> String { match payload { ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"), diff --git a/codex-rs/core/src/tools/registry_tests.rs b/codex-rs/core/src/tools/registry_tests.rs index e3ecfc8f98..8cfe37efff 100644 --- a/codex-rs/core/src/tools/registry_tests.rs +++ b/codex-rs/core/src/tools/registry_tests.rs @@ -153,6 +153,123 @@ fn handler_looks_up_namespaced_aliases_explicitly() { ); } +#[tokio::test] +async fn function_tools_expose_default_hook_payloads_and_rewrites() -> anyhow::Result<()> { + let (session, turn) = crate::session::tests::make_session_and_context().await; + let tool_name = codex_tools::ToolName::namespaced("functions.", "echo"); + let handler = TestHandler { + tool_name: tool_name.clone(), + }; + let invocation = ToolInvocation { + payload: ToolPayload::Function { + arguments: serde_json::json!({ "message": "hello" }).to_string(), + }, + ..test_invocation(Arc::new(session), Arc::new(turn), "call-1", tool_name) + }; + let output = + crate::tools::context::FunctionToolOutput::from_text("echoed".to_string(), Some(true)); + + assert_eq!( + handler.pre_tool_use_payload(&invocation), + Some(PreToolUsePayload { + tool_name: HookToolName::new("functions.echo"), + tool_input: serde_json::json!({ "message": "hello" }), + }) + ); + assert_eq!( + handler.post_tool_use_payload(&invocation, &output), + Some(PostToolUsePayload { + tool_name: HookToolName::new("functions.echo"), + tool_use_id: "call-1".to_string(), + tool_input: serde_json::json!({ "message": "hello" }), + tool_response: serde_json::json!("echoed"), + }) + ); + + let invocation = handler + .with_updated_hook_input(invocation, serde_json::json!({ "message": "rewritten" }))?; + let ToolPayload::Function { arguments } = invocation.payload else { + panic!("generic rewritten function payload should remain function-shaped"); + }; + assert_eq!( + serde_json::from_str::(&arguments)?, + serde_json::json!({ "message": "rewritten" }) + ); + + Ok(()) +} + +#[tokio::test] +async fn function_hook_input_defaults_empty_arguments_to_object() { + let (session, turn) = crate::session::tests::make_session_and_context().await; + let tool_name = codex_tools::ToolName::plain("echo"); + let handler = TestHandler { + tool_name: tool_name.clone(), + }; + let invocation = ToolInvocation { + payload: ToolPayload::Function { + arguments: " ".to_string(), + }, + ..test_invocation(Arc::new(session), Arc::new(turn), "call-1", tool_name) + }; + + assert_eq!( + handler.pre_tool_use_payload(&invocation), + Some(PreToolUsePayload { + tool_name: HookToolName::new("echo"), + tool_input: serde_json::json!({}), + }) + ); +} + +#[test] +fn model_visible_override_keeps_code_mode_result_typed() { + let result = AnyToolResult { + call_id: "call-1".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + result: Box::new(codex_tools::JsonToolOutput::new( + serde_json::json!({ "typed": true }), + )), + model_visible_override: Some(crate::tools::context::FunctionToolOutput::from_text( + "hook feedback".to_string(), + /*success*/ None, + )), + post_tool_use_payload: None, + }; + + assert_eq!( + result.into_response(), + ResponseInputItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: codex_protocol::models::FunctionCallOutputPayload::from_text( + "hook feedback".to_string() + ), + } + ); + + let result = AnyToolResult { + call_id: "call-1".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + result: Box::new(codex_tools::JsonToolOutput::new( + serde_json::json!({ "typed": true }), + )), + model_visible_override: Some(crate::tools::context::FunctionToolOutput::from_text( + "hook feedback".to_string(), + /*success*/ None, + )), + post_tool_use_payload: None, + }; + + assert_eq!( + result.code_mode_result(), + serde_json::json!({ "typed": true }) + ); +} + #[tokio::test] async fn dispatch_notifies_tool_lifecycle_contributors() -> anyhow::Result<()> { let (mut session, turn) = crate::session::tests::make_session_and_context().await;