Default hooks for function tools

This commit is contained in:
Abhinav Vedmala
2026-05-20 14:17:29 -07:00
parent d1e3d54192
commit 64062e4bcd
8 changed files with 258 additions and 121 deletions

View File

@@ -110,4 +110,8 @@ impl ToolExecutor<ToolInvocation> for CodeModeWaitHandler {
}
}
impl CoreToolRuntime for CodeModeWaitHandler {}
impl CoreToolRuntime for CodeModeWaitHandler {
fn supports_default_function_tool_hooks(&self) -> bool {
false
}
}

View File

@@ -1,20 +1,14 @@
use std::sync::Arc;
use codex_tools::ToolCall as ExtensionToolCall;
use codex_tools::ToolName;
use codex_tools::ToolSpec;
use serde_json::Value;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::flat_tool_name;
use crate::tools::hook_names::HookToolName;
use crate::tools::registry::CoreToolRuntime;
use crate::tools::registry::PostToolUsePayload;
use crate::tools::registry::PreToolUsePayload;
use crate::tools::registry::ToolExecutor;
use codex_tools::ToolCall as ExtensionToolCall;
use codex_tools::ToolName;
use codex_tools::ToolSpec;
pub(crate) struct ExtensionToolAdapter(Arc<dyn codex_tools::ToolExecutor<ExtensionToolCall>>);
@@ -61,29 +55,6 @@ impl CoreToolRuntime for ExtensionToolAdapter {
fn matches_kind(&self, payload: &ToolPayload) -> bool {
self.arguments_from_payload(payload).is_some()
}
fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
let arguments = self.arguments_from_payload(&invocation.payload)?;
Some(PreToolUsePayload {
tool_name: HookToolName::new(flat_tool_name(&self.tool_name()).into_owned()),
tool_input: extension_tool_hook_input(arguments),
})
}
fn post_tool_use_payload(
&self,
invocation: &ToolInvocation,
result: &dyn ToolOutput,
) -> Option<PostToolUsePayload> {
let arguments = self.arguments_from_payload(&invocation.payload)?;
Some(PostToolUsePayload {
tool_name: HookToolName::new(flat_tool_name(&self.tool_name()).into_owned()),
tool_use_id: invocation.call_id.clone(),
tool_input: extension_tool_hook_input(arguments),
tool_response: result
.post_tool_use_response(&invocation.call_id, &invocation.payload)?,
})
}
}
fn to_extension_call(invocation: &ToolInvocation) -> ExtensionToolCall {
@@ -96,14 +67,6 @@ fn to_extension_call(invocation: &ToolInvocation) -> ExtensionToolCall {
}
}
fn extension_tool_hook_input(arguments: &str) -> Value {
if arguments.trim().is_empty() {
return Value::Object(serde_json::Map::new());
}
serde_json::from_str(arguments).unwrap_or_else(|_| Value::String(arguments.to_string()))
}
#[cfg(test)]
mod tests {
use std::sync::Arc;

View File

@@ -9,10 +9,7 @@ use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::context::boxed_tool_output;
use crate::tools::flat_tool_name;
use crate::tools::hook_names::HookToolName;
use crate::tools::registry::CoreToolRuntime;
use crate::tools::registry::PostToolUsePayload;
use crate::tools::registry::PreToolUsePayload;
use crate::tools::registry::ToolExecutor;
use crate::tools::registry::ToolExposure;
use crate::tools::registry::ToolTelemetryTags;
@@ -24,8 +21,6 @@ use codex_tools::ToolName;
use codex_tools::ToolSearchSourceInfo;
use codex_tools::ToolSpec;
use codex_tools::mcp_tool_to_responses_api_tool;
use serde_json::Map;
use serde_json::Value;
pub struct McpHandler {
tool_info: ToolInfo,
@@ -169,66 +164,6 @@ impl CoreToolRuntime for McpHandler {
tags
})
}
fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
let ToolPayload::Function { arguments } = &invocation.payload else {
return None;
};
Some(PreToolUsePayload {
tool_name: HookToolName::new(self.tool_name().to_string()),
tool_input: mcp_hook_tool_input(arguments),
})
}
fn with_updated_hook_input(
&self,
mut invocation: ToolInvocation,
updated_input: Value,
) -> Result<ToolInvocation, FunctionCallError> {
invocation.payload = match invocation.payload {
ToolPayload::Function { .. } => ToolPayload::Function {
arguments: serde_json::to_string(&updated_input).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to serialize rewritten MCP arguments: {err}"
))
})?,
},
payload => {
return Err(FunctionCallError::RespondToModel(format!(
"tool {} does not support hook input rewriting for payload {payload:?}",
self.tool_name()
)));
}
};
Ok(invocation)
}
fn post_tool_use_payload(
&self,
invocation: &ToolInvocation,
result: &dyn crate::tools::context::ToolOutput,
) -> Option<PostToolUsePayload> {
let ToolPayload::Function { .. } = &invocation.payload else {
return None;
};
let tool_response =
result.post_tool_use_response(&invocation.call_id, &invocation.payload)?;
Some(PostToolUsePayload {
tool_name: HookToolName::new(self.tool_name().to_string()),
tool_use_id: invocation.call_id.clone(),
tool_input: result.post_tool_use_input(&invocation.payload)?,
tool_response,
})
}
}
fn mcp_hook_tool_input(raw_arguments: &str) -> Value {
if raw_arguments.trim().is_empty() {
return Value::Object(Map::new());
}
serde_json::from_str(raw_arguments).unwrap_or_else(|_| Value::String(raw_arguments.to_string()))
}
fn build_mcp_search_text(info: &ToolInfo) -> String {
@@ -288,6 +223,9 @@ mod tests {
use super::*;
use crate::session::tests::make_session_and_context;
use crate::tools::context::ToolCallSource;
use crate::tools::hook_names::HookToolName;
use crate::tools::registry::PostToolUsePayload;
use crate::tools::registry::PreToolUsePayload;
use crate::turn_diff_tracker::TurnDiffTracker;
use pretty_assertions::assert_eq;
use serde_json::json;
@@ -443,11 +381,6 @@ mod tests {
);
}
#[test]
fn mcp_hook_tool_input_defaults_empty_args_to_object() {
assert_eq!(mcp_hook_tool_input(" "), json!({}));
}
fn tool_info(server_name: &str, callable_namespace: &str, tool_name: &str) -> ToolInfo {
ToolInfo {
server_name: server_name.to_string(),

View File

@@ -84,4 +84,8 @@ impl ToolExecutor<ToolInvocation> for RequestPermissionsHandler {
}
}
impl CoreToolRuntime for RequestPermissionsHandler {}
impl CoreToolRuntime for RequestPermissionsHandler {
fn supports_default_function_tool_hooks(&self) -> bool {
false
}
}

View File

@@ -101,6 +101,10 @@ impl CoreToolRuntime for WriteStdinHandler {
matches!(payload, ToolPayload::Function { .. })
}
fn supports_default_function_tool_hooks(&self) -> bool {
false
}
fn post_tool_use_payload(
&self,
invocation: &ToolInvocation,

View File

@@ -204,6 +204,7 @@ impl ToolCallRuntime {
result: Box::new(AbortedToolOutput {
message: Self::abort_message(call, secs),
}),
model_visible_override: None,
post_tool_use_payload: None,
}
}

View File

@@ -26,6 +26,7 @@ use crate::tools::tool_dispatch_trace::ToolDispatchTrace;
use crate::tools::tool_search_entry::ToolSearchInfo;
use crate::util::error_or_panic;
use codex_extension_api::ToolCallOutcome;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::protocol::EventMsg;
use codex_tools::ToolName;
@@ -64,14 +65,22 @@ pub(crate) trait CoreToolRuntime: ToolExecutor<ToolInvocation> {
fn post_tool_use_payload(
&self,
_invocation: &ToolInvocation,
_result: &dyn ToolOutput,
invocation: &ToolInvocation,
result: &dyn ToolOutput,
) -> Option<PostToolUsePayload> {
None
if !self.supports_default_function_tool_hooks() {
return None;
}
default_function_post_tool_use_payload(invocation, result)
}
fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
None
fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
if !self.supports_default_function_tool_hooks() {
return None;
}
default_function_pre_tool_use_payload(invocation)
}
/// Rebuilds a tool invocation from hook-facing `tool_input`.
@@ -80,14 +89,28 @@ pub(crate) trait CoreToolRuntime: ToolExecutor<ToolInvocation> {
/// hook contract they expose from `pre_tool_use_payload`.
fn with_updated_hook_input(
&self,
_invocation: ToolInvocation,
_updated_input: Value,
invocation: ToolInvocation,
updated_input: Value,
) -> Result<ToolInvocation, FunctionCallError> {
if self.supports_default_function_tool_hooks() {
return rewrite_function_tool_hook_input(invocation, updated_input);
}
Err(FunctionCallError::RespondToModel(
"tool does not support hook input rewriting".to_string(),
))
}
/// Returns whether this tool uses the generic function-tool hook contract.
///
/// Most local function tools expose their JSON arguments directly to hooks.
/// Tools with compatibility-specific hook contracts can override the hook
/// payload methods instead, while function tools that should not run hooks
/// can opt out here.
fn supports_default_function_tool_hooks(&self) -> bool {
true
}
/// Creates an optional consumer for streamed tool argument diffs.
fn create_diff_consumer(&self) -> Option<Box<dyn ToolArgumentDiffConsumer>> {
None
@@ -111,6 +134,7 @@ pub(crate) struct AnyToolResult {
pub(crate) call_id: String,
pub(crate) payload: ToolPayload,
pub(crate) result: Box<dyn ToolOutput>,
pub(crate) model_visible_override: Option<FunctionToolOutput>,
pub(crate) post_tool_use_payload: Option<PostToolUsePayload>,
}
@@ -120,9 +144,13 @@ impl AnyToolResult {
call_id,
payload,
result,
model_visible_override,
..
} = self;
result.to_response_item(&call_id, &payload)
model_visible_override.map_or_else(
|| result.to_response_item(&call_id, &payload),
|output| output.to_response_item(&call_id, &payload),
)
}
pub(crate) fn code_mode_result(self) -> serde_json::Value {
@@ -234,6 +262,10 @@ impl CoreToolRuntime for ExposureOverride {
.with_updated_hook_input(invocation, updated_input)
}
fn supports_default_function_tool_hooks(&self) -> bool {
self.handler.supports_default_function_tool_hooks()
}
fn telemetry_tags<'a>(
&'a self,
invocation: &'a ToolInvocation,
@@ -539,7 +571,7 @@ impl ToolRegistry {
if let Some(replacement_text) = replacement_text {
let mut guard = response_cell.lock().await;
if let Some(result) = guard.as_mut() {
result.result = Box::new(FunctionToolOutput::from_text(
result.model_visible_override = Some(FunctionToolOutput::from_text(
replacement_text,
/*success*/ None,
));
@@ -614,10 +646,89 @@ async fn handle_any_tool(
call_id,
payload,
result: output,
model_visible_override: None,
post_tool_use_payload,
})
}
fn default_function_pre_tool_use_payload(invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
let ToolPayload::Function { arguments } = &invocation.payload else {
return None;
};
Some(PreToolUsePayload {
tool_name: function_hook_tool_name(invocation),
tool_input: function_hook_tool_input(arguments),
})
}
fn default_function_post_tool_use_payload(
invocation: &ToolInvocation,
result: &dyn ToolOutput,
) -> Option<PostToolUsePayload> {
let ToolPayload::Function { arguments } = &invocation.payload else {
return None;
};
Some(PostToolUsePayload {
tool_name: function_hook_tool_name(invocation),
tool_use_id: result.post_tool_use_id(&invocation.call_id),
tool_input: result
.post_tool_use_input(&invocation.payload)
.unwrap_or_else(|| function_hook_tool_input(arguments)),
tool_response: result
.post_tool_use_response(&invocation.call_id, &invocation.payload)
.or_else(|| model_visible_function_tool_response(invocation, result))?,
})
}
fn rewrite_function_tool_hook_input(
mut invocation: ToolInvocation,
updated_input: Value,
) -> Result<ToolInvocation, FunctionCallError> {
let ToolPayload::Function { .. } = &invocation.payload else {
return Err(FunctionCallError::RespondToModel(
"hook input rewrite received unsupported function tool payload".to_string(),
));
};
let arguments = serde_json::to_string(&updated_input).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to serialize rewritten {} arguments: {err}",
flat_tool_name(&invocation.tool_name)
))
})?;
invocation.payload = ToolPayload::Function { arguments };
Ok(invocation)
}
fn function_hook_tool_name(invocation: &ToolInvocation) -> HookToolName {
HookToolName::new(flat_tool_name(&invocation.tool_name).into_owned())
}
fn function_hook_tool_input(arguments: &str) -> Value {
if arguments.trim().is_empty() {
return Value::Object(serde_json::Map::new());
}
serde_json::from_str(arguments).unwrap_or_else(|_| Value::String(arguments.to_string()))
}
fn model_visible_function_tool_response(
invocation: &ToolInvocation,
result: &dyn ToolOutput,
) -> Option<Value> {
let ResponseInputItem::FunctionCallOutput {
output: FunctionCallOutputPayload { body, .. },
..
} = result.to_response_item(&invocation.call_id, &invocation.payload)
else {
return None;
};
serde_json::to_value(body).ok()
}
fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &ToolName) -> String {
match payload {
ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),

View File

@@ -153,6 +153,123 @@ fn handler_looks_up_namespaced_aliases_explicitly() {
);
}
#[tokio::test]
async fn function_tools_expose_default_hook_payloads_and_rewrites() -> anyhow::Result<()> {
let (session, turn) = crate::session::tests::make_session_and_context().await;
let tool_name = codex_tools::ToolName::namespaced("functions.", "echo");
let handler = TestHandler {
tool_name: tool_name.clone(),
};
let invocation = ToolInvocation {
payload: ToolPayload::Function {
arguments: serde_json::json!({ "message": "hello" }).to_string(),
},
..test_invocation(Arc::new(session), Arc::new(turn), "call-1", tool_name)
};
let output =
crate::tools::context::FunctionToolOutput::from_text("echoed".to_string(), Some(true));
assert_eq!(
handler.pre_tool_use_payload(&invocation),
Some(PreToolUsePayload {
tool_name: HookToolName::new("functions.echo"),
tool_input: serde_json::json!({ "message": "hello" }),
})
);
assert_eq!(
handler.post_tool_use_payload(&invocation, &output),
Some(PostToolUsePayload {
tool_name: HookToolName::new("functions.echo"),
tool_use_id: "call-1".to_string(),
tool_input: serde_json::json!({ "message": "hello" }),
tool_response: serde_json::json!("echoed"),
})
);
let invocation = handler
.with_updated_hook_input(invocation, serde_json::json!({ "message": "rewritten" }))?;
let ToolPayload::Function { arguments } = invocation.payload else {
panic!("generic rewritten function payload should remain function-shaped");
};
assert_eq!(
serde_json::from_str::<serde_json::Value>(&arguments)?,
serde_json::json!({ "message": "rewritten" })
);
Ok(())
}
#[tokio::test]
async fn function_hook_input_defaults_empty_arguments_to_object() {
let (session, turn) = crate::session::tests::make_session_and_context().await;
let tool_name = codex_tools::ToolName::plain("echo");
let handler = TestHandler {
tool_name: tool_name.clone(),
};
let invocation = ToolInvocation {
payload: ToolPayload::Function {
arguments: " ".to_string(),
},
..test_invocation(Arc::new(session), Arc::new(turn), "call-1", tool_name)
};
assert_eq!(
handler.pre_tool_use_payload(&invocation),
Some(PreToolUsePayload {
tool_name: HookToolName::new("echo"),
tool_input: serde_json::json!({}),
})
);
}
#[test]
fn model_visible_override_keeps_code_mode_result_typed() {
let result = AnyToolResult {
call_id: "call-1".to_string(),
payload: ToolPayload::Function {
arguments: "{}".to_string(),
},
result: Box::new(codex_tools::JsonToolOutput::new(
serde_json::json!({ "typed": true }),
)),
model_visible_override: Some(crate::tools::context::FunctionToolOutput::from_text(
"hook feedback".to_string(),
/*success*/ None,
)),
post_tool_use_payload: None,
};
assert_eq!(
result.into_response(),
ResponseInputItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: codex_protocol::models::FunctionCallOutputPayload::from_text(
"hook feedback".to_string()
),
}
);
let result = AnyToolResult {
call_id: "call-1".to_string(),
payload: ToolPayload::Function {
arguments: "{}".to_string(),
},
result: Box::new(codex_tools::JsonToolOutput::new(
serde_json::json!({ "typed": true }),
)),
model_visible_override: Some(crate::tools::context::FunctionToolOutput::from_text(
"hook feedback".to_string(),
/*success*/ None,
)),
post_tool_use_payload: None,
};
assert_eq!(
result.code_mode_result(),
serde_json::json!({ "typed": true })
);
}
#[tokio::test]
async fn dispatch_notifies_tool_lifecycle_contributors() -> anyhow::Result<()> {
let (mut session, turn) = crate::session::tests::make_session_and_context().await;