mirror of
https://github.com/openai/codex.git
synced 2026-05-17 09:43:19 +00:00
Support PreToolUse updatedInput rewrites (#20527)
## Why
`PreToolUse` already exposes `updatedInput` in its hook output schema,
but Codex currently rejects it instead of applying the rewrite. That
leaves hook authors unable to make the documented pre-execution
adjustment to a tool call before it runs.
## What
- Accept `updatedInput` from `PreToolUse` hooks when paired with
`permissionDecision: "allow"`.
- Apply the rewritten input before dispatch so the tool executes the
updated payload, not the original one.
- Preserve the stable hook-facing compatibility shapes that
participating tool handlers expose:
- Bash-like tools (`shell`, `container.exec`, `local_shell`,
`shell_command`, `exec_command`) use `{ "command": ... }`.
- `apply_patch` exposes its patch body through the same command-shaped
hook contract.
- MCP tools expose their JSON argument object directly.
- Keep each participating tool handler responsible for translating
hook-facing `updatedInput` back into its concrete invocation shape.
## Verification
Direct Bash-like rewrite coverage:
- `pre_tool_use_rewrites_shell_before_execution`
- `pre_tool_use_rewrites_container_exec_before_execution`
- `pre_tool_use_rewrites_local_shell_before_execution`
- `pre_tool_use_rewrites_shell_command_before_execution`
- `pre_tool_use_rewrites_exec_command_before_execution`
These cases assert that each supported Bash-like surface runs only the
rewritten command while the hook still observes the original `{
"command": ... }` input.
`pre_tool_use_rewrites_apply_patch_before_execution`
- Model emits one patch.
- Hook swaps in a different patch.
- Asserts only the rewritten file is created, and the hook saw the
original patch.
`pre_tool_use_rewrites_code_mode_nested_exec_command_before_execution`
- Model runs one nested shell command from code mode.
- Hook rewrites it.
- Asserts only the rewritten command runs, and the hook saw the original
nested input.
`pre_tool_use_rewrites_mcp_tool_before_execution`
- Model calls the RMCP echo tool.
- Hook rewrites the MCP arguments.
- Asserts the MCP server receives and returns the rewritten message, not
the original one.
This commit is contained in:
@@ -44,6 +44,11 @@ pub(crate) struct HookRuntimeOutcome {
|
||||
pub additional_contexts: Vec<String>,
|
||||
}
|
||||
|
||||
pub(crate) enum PreToolUseHookResult {
|
||||
Continue { updated_input: Option<Value> },
|
||||
Blocked(String),
|
||||
}
|
||||
|
||||
pub(crate) enum PendingInputHookDisposition {
|
||||
Accepted(Box<PendingInputRecord>),
|
||||
Blocked { additional_contexts: Vec<String> },
|
||||
@@ -141,7 +146,7 @@ pub(crate) async fn run_pre_tool_use_hooks(
|
||||
tool_use_id: String,
|
||||
tool_name: &HookToolName,
|
||||
tool_input: &Value,
|
||||
) -> Option<String> {
|
||||
) -> PreToolUseHookResult {
|
||||
let request = PreToolUseRequest {
|
||||
session_id: sess.conversation_id,
|
||||
turn_id: turn_context.sub_id.clone(),
|
||||
@@ -163,25 +168,32 @@ pub(crate) async fn run_pre_tool_use_hooks(
|
||||
should_block,
|
||||
block_reason,
|
||||
additional_contexts,
|
||||
updated_input,
|
||||
} = hooks.run_pre_tool_use(request).await;
|
||||
emit_hook_completed_events(sess, turn_context, hook_events).await;
|
||||
record_additional_contexts(sess, turn_context, additional_contexts).await;
|
||||
|
||||
if should_block {
|
||||
block_reason.map(|reason| {
|
||||
if (tool_name.name() == "Bash" || tool_name.name() == "apply_patch")
|
||||
&& let Some(command) = tool_input.get("command").and_then(Value::as_str)
|
||||
{
|
||||
format!("Command blocked by PreToolUse hook: {reason}. Command: {command}")
|
||||
} else {
|
||||
format!(
|
||||
"Tool call blocked by PreToolUse hook: {reason}. Tool: {}",
|
||||
tool_name.name()
|
||||
)
|
||||
}
|
||||
})
|
||||
if !should_block {
|
||||
return PreToolUseHookResult::Continue { updated_input };
|
||||
}
|
||||
|
||||
let Some(reason) = block_reason else {
|
||||
return PreToolUseHookResult::Continue {
|
||||
updated_input: None,
|
||||
};
|
||||
};
|
||||
|
||||
if (tool_name.name() == "Bash" || tool_name.name() == "apply_patch")
|
||||
&& let Some(command) = tool_input.get("command").and_then(Value::as_str)
|
||||
{
|
||||
PreToolUseHookResult::Blocked(format!(
|
||||
"Command blocked by PreToolUse hook: {reason}. Command: {command}"
|
||||
))
|
||||
} else {
|
||||
None
|
||||
PreToolUseHookResult::Blocked(format!(
|
||||
"Tool call blocked by PreToolUse hook: {reason}. Tool: {}",
|
||||
tool_name.name()
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ use crate::tools::events::ToolEventCtx;
|
||||
use crate::tools::handlers::apply_granted_turn_permissions;
|
||||
use crate::tools::handlers::apply_patch_spec::create_apply_patch_freeform_tool;
|
||||
use crate::tools::handlers::resolve_tool_environment;
|
||||
use crate::tools::handlers::updated_hook_command;
|
||||
use crate::tools::hook_names::HookToolName;
|
||||
use crate::tools::orchestrator::ToolOrchestrator;
|
||||
use crate::tools::registry::PostToolUsePayload;
|
||||
@@ -325,6 +326,21 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
})
|
||||
}
|
||||
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
mut invocation: ToolInvocation,
|
||||
updated_input: serde_json::Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
let patch = updated_hook_command(&updated_input)?;
|
||||
invocation.payload = match invocation.payload {
|
||||
ToolPayload::Custom { .. } => ToolPayload::Custom {
|
||||
input: patch.to_string(),
|
||||
},
|
||||
payload => payload,
|
||||
};
|
||||
Ok(invocation)
|
||||
}
|
||||
|
||||
fn post_tool_use_payload(
|
||||
&self,
|
||||
invocation: &ToolInvocation,
|
||||
|
||||
@@ -15,6 +15,7 @@ use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolTelemetryTags;
|
||||
use codex_mcp::ToolInfo;
|
||||
use codex_tools::ToolName;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
|
||||
pub struct McpHandler {
|
||||
@@ -57,6 +58,28 @@ impl ToolHandler for McpHandler {
|
||||
})
|
||||
}
|
||||
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
mut invocation: ToolInvocation,
|
||||
updated_input: Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
invocation.payload = match invocation.payload {
|
||||
ToolPayload::Function { .. } => ToolPayload::Function {
|
||||
arguments: serde_json::to_string(&updated_input).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to serialize rewritten MCP arguments: {err}"
|
||||
))
|
||||
})?,
|
||||
},
|
||||
payload => {
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"tool {} does not support hook input rewriting for payload {payload:?}",
|
||||
self.tool_name()
|
||||
)));
|
||||
}
|
||||
};
|
||||
Ok(invocation)
|
||||
}
|
||||
fn post_tool_use_payload(
|
||||
&self,
|
||||
invocation: &ToolInvocation,
|
||||
@@ -118,7 +141,7 @@ impl ToolHandler for McpHandler {
|
||||
|
||||
fn mcp_hook_tool_input(raw_arguments: &str) -> Value {
|
||||
if raw_arguments.trim().is_empty() {
|
||||
return Value::Object(serde_json::Map::new());
|
||||
return Value::Object(Map::new());
|
||||
}
|
||||
|
||||
serde_json::from_str(raw_arguments).unwrap_or_else(|_| Value::String(raw_arguments.to_string()))
|
||||
@@ -148,7 +171,6 @@ mod tests {
|
||||
};
|
||||
let (session, turn) = make_session_and_context().await;
|
||||
let handler = McpHandler::new(tool_info("memory", "mcp__memory__", "create_entities"));
|
||||
|
||||
assert_eq!(
|
||||
handler.pre_tool_use_payload(&ToolInvocation {
|
||||
session: session.into(),
|
||||
@@ -172,6 +194,62 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_pre_tool_use_payload_keeps_builtin_like_tool_names_namespaced() {
|
||||
let payload = ToolPayload::Function {
|
||||
arguments: json!({ "message": "hello" }).to_string(),
|
||||
};
|
||||
let (session, turn) = make_session_and_context().await;
|
||||
let handler = McpHandler::new(tool_info("foo", "mcp__foo__", "exec_command"));
|
||||
|
||||
assert_eq!(
|
||||
handler.pre_tool_use_payload(&ToolInvocation {
|
||||
session: session.into(),
|
||||
turn: turn.into(),
|
||||
cancellation_token: tokio_util::sync::CancellationToken::new(),
|
||||
tracker: Arc::new(Mutex::new(TurnDiffTracker::new())),
|
||||
call_id: "call-mcp-pre-builtin-like".to_string(),
|
||||
tool_name: codex_tools::ToolName::namespaced("mcp__foo__", "exec_command"),
|
||||
source: ToolCallSource::Direct,
|
||||
payload,
|
||||
}),
|
||||
Some(PreToolUsePayload {
|
||||
tool_name: HookToolName::new("mcp__foo__exec_command"),
|
||||
tool_input: json!({ "message": "hello" }),
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_updated_input_rewrites_builtin_like_tool_names_as_mcp() {
|
||||
let payload = ToolPayload::Function {
|
||||
arguments: json!({ "message": "hello" }).to_string(),
|
||||
};
|
||||
let (session, turn) = make_session_and_context().await;
|
||||
let handler = McpHandler::new(tool_info("foo", "mcp__foo__", "exec_command"));
|
||||
|
||||
let invocation = handler
|
||||
.with_updated_hook_input(
|
||||
ToolInvocation {
|
||||
session: session.into(),
|
||||
turn: turn.into(),
|
||||
cancellation_token: tokio_util::sync::CancellationToken::new(),
|
||||
tracker: Arc::new(Mutex::new(TurnDiffTracker::new())),
|
||||
call_id: "call-mcp-rewrite-builtin-like".to_string(),
|
||||
tool_name: codex_tools::ToolName::namespaced("mcp__foo__", "exec_command"),
|
||||
source: ToolCallSource::Direct,
|
||||
payload,
|
||||
},
|
||||
json!({ "message": "rewritten" }),
|
||||
)
|
||||
.expect("MCP rewrite should succeed");
|
||||
|
||||
let ToolPayload::Function { arguments } = invocation.payload else {
|
||||
panic!("builtin-like MCP tool should stay function-shaped");
|
||||
};
|
||||
assert_eq!(arguments, json!({ "message": "rewritten" }).to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_post_tool_use_payload_uses_model_tool_name_args_and_result() {
|
||||
let payload = ToolPayload::Function {
|
||||
|
||||
@@ -37,6 +37,7 @@ use codex_sandboxing::policy_transforms::normalize_additional_permissions;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use codex_utils_absolute_path::AbsolutePathBufGuard;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
use std::path::Path;
|
||||
|
||||
@@ -76,7 +77,7 @@ pub(crate) use unified_exec::ExecCommandHandlerOptions;
|
||||
pub use unified_exec::WriteStdinHandler;
|
||||
pub use view_image::ViewImageHandler;
|
||||
|
||||
fn parse_arguments<T>(arguments: &str) -> Result<T, FunctionCallError>
|
||||
pub(crate) fn parse_arguments<T>(arguments: &str) -> Result<T, FunctionCallError>
|
||||
where
|
||||
T: for<'de> Deserialize<'de>,
|
||||
{
|
||||
@@ -85,6 +86,47 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn updated_hook_command(updated_input: &Value) -> Result<&str, FunctionCallError> {
|
||||
updated_input
|
||||
.get("command")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(
|
||||
"hook returned updatedInput without string field `command`".to_string(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn rewrite_function_arguments(
|
||||
arguments: &str,
|
||||
tool_name: &str,
|
||||
rewrite: impl FnOnce(&mut Map<String, Value>),
|
||||
) -> Result<String, FunctionCallError> {
|
||||
let mut arguments: Value = parse_arguments(arguments)?;
|
||||
let Value::Object(arguments) = &mut arguments else {
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"{tool_name} arguments must be an object"
|
||||
)));
|
||||
};
|
||||
rewrite(arguments);
|
||||
serde_json::to_string(&arguments).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to serialize rewritten {tool_name} arguments: {err}"
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn rewrite_function_string_argument(
|
||||
arguments: &str,
|
||||
tool_name: &str,
|
||||
field_name: &str,
|
||||
value: &str,
|
||||
) -> Result<String, FunctionCallError> {
|
||||
rewrite_function_arguments(arguments, tool_name, |arguments| {
|
||||
arguments.insert(field_name.to_string(), Value::String(value.to_string()));
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_arguments_with_base_path<T>(
|
||||
arguments: &str,
|
||||
base_path: &AbsolutePathBuf,
|
||||
|
||||
@@ -19,6 +19,8 @@ use crate::tools::handlers::apply_patch::intercept_apply_patch;
|
||||
use crate::tools::handlers::implicit_granted_permissions;
|
||||
use crate::tools::handlers::normalize_and_validate_additional_permissions;
|
||||
use crate::tools::handlers::parse_arguments;
|
||||
use crate::tools::handlers::rewrite_function_arguments;
|
||||
use crate::tools::handlers::updated_hook_command;
|
||||
use crate::tools::hook_names::HookToolName;
|
||||
use crate::tools::orchestrator::ToolOrchestrator;
|
||||
use crate::tools::registry::PostToolUsePayload;
|
||||
@@ -93,6 +95,32 @@ fn shell_function_pre_tool_use_payload(invocation: &ToolInvocation) -> Option<Pr
|
||||
})
|
||||
}
|
||||
|
||||
fn rewrite_shell_function_updated_hook_input(
|
||||
mut invocation: ToolInvocation,
|
||||
updated_input: JsonValue,
|
||||
tool_name: &str,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
let ToolPayload::Function { arguments } = invocation.payload else {
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"hook input rewrite received unsupported {tool_name} payload"
|
||||
)));
|
||||
};
|
||||
let command = shlex::split(updated_hook_command(&updated_input)?).ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(
|
||||
"hook returned shell input with an invalid command string".to_string(),
|
||||
)
|
||||
})?;
|
||||
invocation.payload = ToolPayload::Function {
|
||||
arguments: rewrite_function_arguments(&arguments, tool_name, |arguments| {
|
||||
arguments.insert(
|
||||
"command".to_string(),
|
||||
JsonValue::Array(command.into_iter().map(JsonValue::String).collect()),
|
||||
);
|
||||
})?,
|
||||
};
|
||||
Ok(invocation)
|
||||
}
|
||||
|
||||
fn shell_function_post_tool_use_payload(
|
||||
invocation: &ToolInvocation,
|
||||
result: &FunctionToolOutput,
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::runtimes::shell::ShellRuntimeBackend;
|
||||
|
||||
use super::RunExecLikeArgs;
|
||||
use super::rewrite_shell_function_updated_hook_input;
|
||||
use super::run_exec_like;
|
||||
use super::shell_function_post_tool_use_payload;
|
||||
use super::shell_function_pre_tool_use_payload;
|
||||
@@ -46,6 +47,14 @@ impl ToolHandler for ContainerExecHandler {
|
||||
shell_function_pre_tool_use_payload(invocation)
|
||||
}
|
||||
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
updated_input: serde_json::Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
rewrite_shell_function_updated_hook_input(invocation, updated_input, "container.exec")
|
||||
}
|
||||
|
||||
fn post_tool_use_payload(
|
||||
&self,
|
||||
invocation: &ToolInvocation,
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::tools::context::FunctionToolOutput;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::handlers::updated_hook_command;
|
||||
use crate::tools::hook_names::HookToolName;
|
||||
use crate::tools::registry::PostToolUsePayload;
|
||||
use crate::tools::registry::PreToolUsePayload;
|
||||
@@ -64,6 +65,26 @@ impl ToolHandler for LocalShellHandler {
|
||||
})
|
||||
}
|
||||
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
mut invocation: ToolInvocation,
|
||||
updated_input: serde_json::Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
let command = updated_hook_command(&updated_input)?;
|
||||
invocation.payload = match invocation.payload {
|
||||
ToolPayload::LocalShell { mut params } => {
|
||||
params.command = shlex::split(command).ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(
|
||||
"hook returned shell input with an invalid command string".to_string(),
|
||||
)
|
||||
})?;
|
||||
ToolPayload::LocalShell { params }
|
||||
}
|
||||
payload => payload,
|
||||
};
|
||||
Ok(invocation)
|
||||
}
|
||||
|
||||
fn post_tool_use_payload(
|
||||
&self,
|
||||
invocation: &ToolInvocation,
|
||||
|
||||
@@ -17,6 +17,8 @@ use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::handlers::parse_arguments_with_base_path;
|
||||
use crate::tools::handlers::resolve_workdir_base_path;
|
||||
use crate::tools::handlers::rewrite_function_string_argument;
|
||||
use crate::tools::handlers::updated_hook_command;
|
||||
use crate::tools::hook_names::HookToolName;
|
||||
use crate::tools::registry::PostToolUsePayload;
|
||||
use crate::tools::registry::PreToolUsePayload;
|
||||
@@ -175,6 +177,27 @@ impl ToolHandler for ShellCommandHandler {
|
||||
})
|
||||
}
|
||||
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
mut invocation: ToolInvocation,
|
||||
updated_input: serde_json::Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
let ToolPayload::Function { arguments } = invocation.payload else {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"hook input rewrite received unsupported shell_command payload".to_string(),
|
||||
));
|
||||
};
|
||||
invocation.payload = ToolPayload::Function {
|
||||
arguments: rewrite_function_string_argument(
|
||||
&arguments,
|
||||
"shell_command",
|
||||
"command",
|
||||
updated_hook_command(&updated_input)?,
|
||||
)?,
|
||||
};
|
||||
Ok(invocation)
|
||||
}
|
||||
|
||||
fn post_tool_use_payload(
|
||||
&self,
|
||||
invocation: &ToolInvocation,
|
||||
|
||||
@@ -22,6 +22,7 @@ use codex_tools::ToolSpec;
|
||||
use super::super::shell_spec::ShellToolOptions;
|
||||
use super::super::shell_spec::create_shell_tool;
|
||||
use super::RunExecLikeArgs;
|
||||
use super::rewrite_shell_function_updated_hook_input;
|
||||
use super::run_exec_like;
|
||||
use super::shell_function_post_tool_use_payload;
|
||||
use super::shell_function_pre_tool_use_payload;
|
||||
@@ -95,6 +96,14 @@ impl ToolHandler for ShellHandler {
|
||||
shell_function_pre_tool_use_payload(invocation)
|
||||
}
|
||||
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
updated_input: serde_json::Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
rewrite_shell_function_updated_hook_input(invocation, updated_input, "shell")
|
||||
}
|
||||
|
||||
fn post_tool_use_payload(
|
||||
&self,
|
||||
invocation: &ToolInvocation,
|
||||
|
||||
@@ -12,6 +12,8 @@ use crate::tools::handlers::normalize_and_validate_additional_permissions;
|
||||
use crate::tools::handlers::parse_arguments;
|
||||
use crate::tools::handlers::parse_arguments_with_base_path;
|
||||
use crate::tools::handlers::resolve_tool_environment;
|
||||
use crate::tools::handlers::rewrite_function_string_argument;
|
||||
use crate::tools::handlers::updated_hook_command;
|
||||
use crate::tools::hook_names::HookToolName;
|
||||
use crate::tools::registry::PostToolUsePayload;
|
||||
use crate::tools::registry::PreToolUsePayload;
|
||||
@@ -128,6 +130,27 @@ impl ToolHandler for ExecCommandHandler {
|
||||
})
|
||||
}
|
||||
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
mut invocation: ToolInvocation,
|
||||
updated_input: serde_json::Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
let ToolPayload::Function { arguments } = invocation.payload else {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"hook input rewrite received unsupported exec_command payload".to_string(),
|
||||
));
|
||||
};
|
||||
invocation.payload = ToolPayload::Function {
|
||||
arguments: rewrite_function_string_argument(
|
||||
&arguments,
|
||||
"exec_command",
|
||||
"cmd",
|
||||
updated_hook_command(&updated_input)?,
|
||||
)?,
|
||||
};
|
||||
Ok(invocation)
|
||||
}
|
||||
|
||||
fn post_tool_use_payload(
|
||||
&self,
|
||||
invocation: &ToolInvocation,
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::time::Duration;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::goals::GoalRuntimeEvent;
|
||||
use crate::hook_runtime::PreToolUseHookResult;
|
||||
use crate::hook_runtime::record_additional_contexts;
|
||||
use crate::hook_runtime::run_post_tool_use_hooks;
|
||||
use crate::hook_runtime::run_pre_tool_use_hooks;
|
||||
@@ -73,10 +74,6 @@ pub trait ToolHandler: Send + Sync {
|
||||
async { false }
|
||||
}
|
||||
|
||||
fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
|
||||
None
|
||||
}
|
||||
|
||||
fn post_tool_use_payload(
|
||||
&self,
|
||||
_invocation: &ToolInvocation,
|
||||
@@ -85,6 +82,24 @@ pub trait ToolHandler: Send + Sync {
|
||||
None
|
||||
}
|
||||
|
||||
fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Rebuilds a tool invocation from hook-facing `tool_input`.
|
||||
///
|
||||
/// Tools that opt into input-rewriting hooks should invert the same stable
|
||||
/// hook contract they expose from `pre_tool_use_payload`.
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
_invocation: ToolInvocation,
|
||||
_updated_input: Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
Err(FunctionCallError::RespondToModel(
|
||||
"tool does not support hook input rewriting".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Creates an optional consumer for streamed tool argument diffs.
|
||||
fn create_diff_consumer(&self) -> Option<Box<dyn ToolArgumentDiffConsumer>> {
|
||||
None
|
||||
@@ -175,6 +190,12 @@ trait AnyToolHandler: Send + Sync {
|
||||
|
||||
fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload>;
|
||||
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
updated_input: Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError>;
|
||||
|
||||
fn telemetry_tags<'a>(
|
||||
&'a self,
|
||||
invocation: &'a ToolInvocation,
|
||||
@@ -207,6 +228,14 @@ where
|
||||
ToolHandler::pre_tool_use_payload(self, invocation)
|
||||
}
|
||||
|
||||
fn with_updated_hook_input(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
updated_input: Value,
|
||||
) -> Result<ToolInvocation, FunctionCallError> {
|
||||
ToolHandler::with_updated_hook_input(self, invocation, updated_input)
|
||||
}
|
||||
|
||||
fn telemetry_tags<'a>(
|
||||
&'a self,
|
||||
invocation: &'a ToolInvocation,
|
||||
@@ -286,14 +315,12 @@ impl ToolRegistry {
|
||||
)]
|
||||
pub(crate) async fn dispatch_any(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
mut invocation: ToolInvocation,
|
||||
) -> Result<AnyToolResult, FunctionCallError> {
|
||||
let tool_name = invocation.tool_name.clone();
|
||||
let tool_name_flat = flat_tool_name(&tool_name);
|
||||
let call_id_owned = invocation.call_id.clone();
|
||||
let otel = invocation.turn.session_telemetry.clone();
|
||||
let payload_for_response = invocation.payload.clone();
|
||||
let log_payload = payload_for_response.log_payload();
|
||||
let base_tool_result_tags = [
|
||||
(
|
||||
"sandbox",
|
||||
@@ -325,6 +352,7 @@ impl ToolRegistry {
|
||||
Some(handler) => handler,
|
||||
None => {
|
||||
let message = unsupported_tool_call_message(&invocation.payload, &tool_name);
|
||||
let log_payload = invocation.payload.log_payload();
|
||||
otel.tool_result_with_tags(
|
||||
tool_name_flat.as_ref(),
|
||||
&call_id_owned,
|
||||
@@ -353,9 +381,9 @@ impl ToolRegistry {
|
||||
tool_result_tags.push((*key, value.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
if !handler.matches_kind(&invocation.payload) {
|
||||
let message = format!("tool {tool_name} invoked with incompatible payload");
|
||||
let log_payload = invocation.payload.log_payload();
|
||||
otel.tool_result_with_tags(
|
||||
tool_name_flat.as_ref(),
|
||||
&call_id_owned,
|
||||
@@ -371,8 +399,8 @@ impl ToolRegistry {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
if let Some(pre_tool_use_payload) = handler.pre_tool_use_payload(&invocation)
|
||||
&& let Some(message) = run_pre_tool_use_hooks(
|
||||
if let Some(pre_tool_use_payload) = handler.pre_tool_use_payload(&invocation) {
|
||||
match run_pre_tool_use_hooks(
|
||||
&invocation.session,
|
||||
&invocation.turn,
|
||||
invocation.call_id.clone(),
|
||||
@@ -380,15 +408,27 @@ impl ToolRegistry {
|
||||
&pre_tool_use_payload.tool_input,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let err = FunctionCallError::RespondToModel(message);
|
||||
dispatch_trace.record_failed(&err);
|
||||
return Err(err);
|
||||
{
|
||||
PreToolUseHookResult::Blocked(message) => {
|
||||
let err = FunctionCallError::RespondToModel(message);
|
||||
dispatch_trace.record_failed(&err);
|
||||
return Err(err);
|
||||
}
|
||||
PreToolUseHookResult::Continue {
|
||||
updated_input: Some(updated_input),
|
||||
} => {
|
||||
invocation = handler.with_updated_hook_input(invocation, updated_input)?;
|
||||
}
|
||||
PreToolUseHookResult::Continue {
|
||||
updated_input: None,
|
||||
} => {}
|
||||
}
|
||||
}
|
||||
|
||||
let is_mutating = handler.is_mutating(&invocation).await;
|
||||
let response_cell = tokio::sync::Mutex::new(None);
|
||||
let invocation_for_tool = invocation.clone();
|
||||
let log_payload = invocation.payload.log_payload();
|
||||
|
||||
let result = otel
|
||||
.log_tool_result_with_tags(
|
||||
|
||||
@@ -26,6 +26,7 @@ use core_test_support::managed_network_requirements_loader;
|
||||
use core_test_support::responses::ev_apply_patch_custom_tool_call;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_custom_tool_call;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_message_item_added;
|
||||
use core_test_support::responses::ev_output_text_delta;
|
||||
@@ -307,6 +308,54 @@ elif mode == "exit_2":
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_updating_pre_tool_use_hook(
|
||||
home: &Path,
|
||||
matcher: &str,
|
||||
updated_input: &Value,
|
||||
) -> Result<()> {
|
||||
let script_path = home.join("pre_tool_use_hook.py");
|
||||
let log_path = home.join("pre_tool_use_hook_log.jsonl");
|
||||
let updated_input_json =
|
||||
serde_json::to_string(updated_input).context("serialize updated pre tool input")?;
|
||||
let script = format!(
|
||||
r#"import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
payload = json.load(sys.stdin)
|
||||
|
||||
with Path(r"{log_path}").open("a", encoding="utf-8") as handle:
|
||||
handle.write(json.dumps(payload) + "\n")
|
||||
|
||||
print(json.dumps({{
|
||||
"hookSpecificOutput": {{
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "allow",
|
||||
"updatedInput": {updated_input_json}
|
||||
}}
|
||||
}}))
|
||||
"#,
|
||||
log_path = log_path.display(),
|
||||
updated_input_json = updated_input_json,
|
||||
);
|
||||
let hooks = serde_json::json!({
|
||||
"hooks": {
|
||||
"PreToolUse": [{
|
||||
"matcher": matcher,
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": format!("python3 {}", script_path.display()),
|
||||
"statusMessage": "rewriting pre tool input",
|
||||
}]
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
fs::write(&script_path, script).context("write updating pre tool use hook script")?;
|
||||
fs::write(home.join("hooks.json"), hooks.to_string()).context("write hooks.json")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_pre_tool_use_hook_toml(
|
||||
home: &Path,
|
||||
script_name: &str,
|
||||
@@ -2081,6 +2130,274 @@ async fn blocked_pre_tool_use_records_additional_context_for_shell_command() ->
|
||||
!marker.exists(),
|
||||
"blocked command should not create marker file"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum BashRewriteSurface {
|
||||
ContainerExec,
|
||||
ExecCommand,
|
||||
LocalShell,
|
||||
Shell,
|
||||
ShellCommand,
|
||||
}
|
||||
|
||||
impl BashRewriteSurface {
|
||||
fn slug(self) -> &'static str {
|
||||
match self {
|
||||
BashRewriteSurface::ContainerExec => "container-exec",
|
||||
BashRewriteSurface::ExecCommand => "exec-command",
|
||||
BashRewriteSurface::LocalShell => "local-shell",
|
||||
BashRewriteSurface::Shell => "shell",
|
||||
BashRewriteSurface::ShellCommand => "shell-command",
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_call(self, call_id: &str, command: &[String], command_text: &str) -> Result<Value> {
|
||||
match self {
|
||||
BashRewriteSurface::ContainerExec => Ok(ev_function_call(
|
||||
call_id,
|
||||
"container.exec",
|
||||
&serde_json::to_string(&serde_json::json!({ "command": command }))?,
|
||||
)),
|
||||
BashRewriteSurface::ExecCommand => Ok(ev_function_call(
|
||||
call_id,
|
||||
"exec_command",
|
||||
&serde_json::to_string(&serde_json::json!({ "cmd": command_text }))?,
|
||||
)),
|
||||
BashRewriteSurface::LocalShell => {
|
||||
Ok(core_test_support::responses::ev_local_shell_call(
|
||||
call_id,
|
||||
"completed",
|
||||
command.iter().map(String::as_str).collect(),
|
||||
))
|
||||
}
|
||||
BashRewriteSurface::Shell => Ok(ev_function_call(
|
||||
call_id,
|
||||
"shell",
|
||||
&serde_json::to_string(&serde_json::json!({ "command": command }))?,
|
||||
)),
|
||||
BashRewriteSurface::ShellCommand => Ok(ev_function_call(
|
||||
call_id,
|
||||
"shell_command",
|
||||
&serde_json::to_string(&serde_json::json!({ "command": command_text }))?,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn original_command(self, marker: &Path) -> (Vec<String>, String) {
|
||||
let command_text = format!("printf original > {}", marker.display());
|
||||
match self {
|
||||
BashRewriteSurface::ContainerExec
|
||||
| BashRewriteSurface::LocalShell
|
||||
| BashRewriteSurface::Shell => {
|
||||
let command = vec!["/bin/sh".to_string(), "-c".to_string(), command_text];
|
||||
let command_text = codex_shell_command::parse_command::shlex_join(&command);
|
||||
(command, command_text)
|
||||
}
|
||||
BashRewriteSurface::ExecCommand | BashRewriteSurface::ShellCommand => {
|
||||
(Vec::new(), command_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rewritten_command(self, marker: &Path) -> String {
|
||||
let command_text = format!("printf rewritten > {}", marker.display());
|
||||
match self {
|
||||
BashRewriteSurface::ContainerExec
|
||||
| BashRewriteSurface::LocalShell
|
||||
| BashRewriteSurface::Shell => codex_shell_command::parse_command::shlex_join(&[
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
command_text,
|
||||
]),
|
||||
BashRewriteSurface::ExecCommand | BashRewriteSurface::ShellCommand => command_text,
|
||||
}
|
||||
}
|
||||
|
||||
fn configure(self, config: &mut Config) {
|
||||
trust_discovered_hooks(config);
|
||||
if matches!(self, BashRewriteSurface::ExecCommand) {
|
||||
config.use_experimental_unified_exec_tool = true;
|
||||
if let Err(error) = config.features.enable(Feature::UnifiedExec) {
|
||||
panic!("test config should allow feature update: {error}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn assert_pre_tool_use_rewrites_bash_surface(surface: BashRewriteSurface) -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let slug = surface.slug();
|
||||
let call_id = format!("pretooluse-{slug}-rewrite");
|
||||
let original_marker = std::env::temp_dir().join(format!("pretooluse-{slug}-original-marker"));
|
||||
let rewritten_marker = std::env::temp_dir().join(format!("pretooluse-{slug}-rewritten-marker"));
|
||||
let (tool_command, original_command) = surface.original_command(&original_marker);
|
||||
let rewritten_command = surface.rewritten_command(&rewritten_marker);
|
||||
let responses = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
surface.tool_call(&call_id, &tool_command, &original_command)?,
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "hook rewrote it"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let updated_input = serde_json::json!({ "command": rewritten_command });
|
||||
let mut builder = test_codex()
|
||||
.with_pre_build_hook(move |home| {
|
||||
if let Err(error) = write_updating_pre_tool_use_hook(home, "^Bash$", &updated_input) {
|
||||
panic!("failed to write updating pre tool use hook fixture: {error}");
|
||||
}
|
||||
})
|
||||
.with_config(move |config| surface.configure(config));
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
if original_marker.exists() {
|
||||
fs::remove_file(&original_marker).context("remove stale original pre tool marker")?;
|
||||
}
|
||||
if rewritten_marker.exists() {
|
||||
fs::remove_file(&rewritten_marker).context("remove stale rewritten pre tool marker")?;
|
||||
}
|
||||
|
||||
test.submit_turn_with_permission_profile(
|
||||
&format!("run the rewritten {slug} command"),
|
||||
PermissionProfile::Disabled,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = responses.requests();
|
||||
assert_eq!(requests.len(), 2);
|
||||
requests[1].function_call_output(&call_id);
|
||||
assert!(
|
||||
!original_marker.exists(),
|
||||
"original {slug} command should not execute after rewrite"
|
||||
);
|
||||
assert_eq!(
|
||||
fs::read_to_string(&rewritten_marker).context("read rewritten pre tool marker")?,
|
||||
"rewritten"
|
||||
);
|
||||
|
||||
let hook_inputs = read_pre_tool_use_hook_inputs(test.codex_home_path())?;
|
||||
assert_eq!(hook_inputs.len(), 1);
|
||||
assert_eq!(hook_inputs[0]["tool_input"]["command"], original_command);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_tool_use_rewrites_shell_before_execution() -> Result<()> {
|
||||
assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::Shell).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_tool_use_rewrites_container_exec_before_execution() -> Result<()> {
|
||||
assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::ContainerExec).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_tool_use_rewrites_local_shell_before_execution() -> Result<()> {
|
||||
assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::LocalShell).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_tool_use_rewrites_shell_command_before_execution() -> Result<()> {
|
||||
assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::ShellCommand).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_tool_use_rewrites_exec_command_before_execution() -> Result<()> {
|
||||
assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::ExecCommand).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_tool_use_rewrites_code_mode_nested_exec_command_before_execution() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let call_id = "pretooluse-code-mode-rewrite";
|
||||
let original_marker = std::env::temp_dir().join("pretooluse-code-mode-original-marker");
|
||||
let rewritten_marker = std::env::temp_dir().join("pretooluse-code-mode-rewritten-marker");
|
||||
let original_command = format!("printf original > {}", original_marker.display());
|
||||
let rewritten_command = format!("printf rewritten > {}", rewritten_marker.display());
|
||||
let original_command_json =
|
||||
serde_json::to_string(&original_command).context("serialize original command")?;
|
||||
let code = format!(
|
||||
r#"
|
||||
const output = await tools.exec_command({{ cmd: {original_command_json} }});
|
||||
text(output.output);
|
||||
"#
|
||||
);
|
||||
let responses = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_custom_tool_call(call_id, "exec", &code),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "hook rewrote the nested command"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let updated_input = serde_json::json!({ "command": rewritten_command });
|
||||
let mut builder = test_codex()
|
||||
.with_model("test-gpt-5.1-codex")
|
||||
.with_pre_build_hook(move |home| {
|
||||
if let Err(error) = write_updating_pre_tool_use_hook(home, "^Bash$", &updated_input) {
|
||||
panic!("failed to write updating pre tool use hook fixture: {error}");
|
||||
}
|
||||
})
|
||||
.with_config(|config| {
|
||||
let _ = config.features.enable(Feature::CodeMode);
|
||||
trust_discovered_hooks(config);
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
if original_marker.exists() {
|
||||
fs::remove_file(&original_marker).context("remove stale original pre tool marker")?;
|
||||
}
|
||||
if rewritten_marker.exists() {
|
||||
fs::remove_file(&rewritten_marker).context("remove stale rewritten pre tool marker")?;
|
||||
}
|
||||
|
||||
test.submit_turn_with_permission_profile(
|
||||
"run the rewritten shell command from code mode",
|
||||
PermissionProfile::Disabled,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = responses.requests();
|
||||
assert_eq!(requests.len(), 2);
|
||||
requests[1].custom_tool_call_output(call_id);
|
||||
assert!(
|
||||
!original_marker.exists(),
|
||||
"original nested shell command should not execute after rewrite"
|
||||
);
|
||||
assert_eq!(
|
||||
fs::read_to_string(&rewritten_marker)
|
||||
.context("read rewritten code mode pre tool marker")?,
|
||||
"rewritten"
|
||||
);
|
||||
|
||||
let hook_inputs = read_pre_tool_use_hook_inputs(test.codex_home_path())?;
|
||||
assert_eq!(hook_inputs.len(), 1);
|
||||
assert_eq!(hook_inputs[0]["tool_input"]["command"], original_command);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -2675,6 +2992,80 @@ async fn pre_tool_use_blocks_apply_patch_before_execution() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_tool_use_rewrites_apply_patch_before_execution() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let call_id = "pretooluse-apply-patch-rewrite";
|
||||
let original_file = "pre_tool_use_apply_patch_original.txt";
|
||||
let rewritten_file = "pre_tool_use_apply_patch_rewritten.txt";
|
||||
let original_patch = format!(
|
||||
r#"*** Begin Patch
|
||||
*** Add File: {original_file}
|
||||
+original
|
||||
*** End Patch"#
|
||||
);
|
||||
let rewritten_patch = format!(
|
||||
r#"*** Begin Patch
|
||||
*** Add File: {rewritten_file}
|
||||
+rewritten
|
||||
*** End Patch"#
|
||||
);
|
||||
let responses = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_apply_patch_custom_tool_call(call_id, &original_patch),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "apply_patch rewritten"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let updated_input = serde_json::json!({ "command": rewritten_patch });
|
||||
let mut builder = test_codex()
|
||||
.with_pre_build_hook(move |home| {
|
||||
if let Err(error) =
|
||||
write_updating_pre_tool_use_hook(home, "^apply_patch$", &updated_input)
|
||||
{
|
||||
panic!("failed to write updating pre tool use hook fixture: {error}");
|
||||
}
|
||||
})
|
||||
.with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
trust_discovered_hooks(config);
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn("apply the rewritten patch").await?;
|
||||
|
||||
let requests = responses.requests();
|
||||
assert_eq!(requests.len(), 2);
|
||||
requests[1].custom_tool_call_output(call_id);
|
||||
assert!(
|
||||
!test.workspace_path(original_file).exists(),
|
||||
"original patch should not create its target file"
|
||||
);
|
||||
assert_eq!(
|
||||
fs::read_to_string(test.workspace_path(rewritten_file))
|
||||
.context("read rewritten apply_patch file")?,
|
||||
"rewritten\n"
|
||||
);
|
||||
|
||||
let hook_inputs = read_pre_tool_use_hook_inputs(test.codex_home_path())?;
|
||||
assert_eq!(hook_inputs.len(), 1);
|
||||
assert_eq!(hook_inputs[0]["tool_input"]["command"], original_patch);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_tool_use_blocks_apply_patch_with_write_alias() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
@@ -74,6 +74,50 @@ print(json.dumps({{
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_updating_pre_tool_use_hook(home: &Path, updated_message: &str) -> Result<()> {
|
||||
let script_path = home.join("pre_tool_use_hook.py");
|
||||
let log_path = home.join("pre_tool_use_hook_log.jsonl");
|
||||
let updated_message_json =
|
||||
serde_json::to_string(updated_message).context("serialize updated MCP message")?;
|
||||
let script = format!(
|
||||
r#"import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
payload = json.load(sys.stdin)
|
||||
|
||||
with Path(r"{log_path}").open("a", encoding="utf-8") as handle:
|
||||
handle.write(json.dumps(payload) + "\n")
|
||||
|
||||
print(json.dumps({{
|
||||
"hookSpecificOutput": {{
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "allow",
|
||||
"updatedInput": {{ "message": {updated_message_json} }}
|
||||
}}
|
||||
}}))
|
||||
"#,
|
||||
log_path = log_path.display(),
|
||||
updated_message_json = updated_message_json,
|
||||
);
|
||||
let hooks = serde_json::json!({
|
||||
"hooks": {
|
||||
"PreToolUse": [{
|
||||
"matcher": RMCP_HOOK_MATCHER,
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": format!("python3 {}", script_path.display()),
|
||||
"statusMessage": "rewriting MCP pre tool input",
|
||||
}]
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
fs::write(&script_path, script).context("write updating pre tool use hook script")?;
|
||||
fs::write(home.join("hooks.json"), hooks.to_string()).context("write hooks.json")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_post_tool_use_hook(home: &Path, additional_context: &str) -> Result<()> {
|
||||
let script_path = home.join("post_tool_use_hook.py");
|
||||
let log_path = home.join("post_tool_use_hook_log.jsonl");
|
||||
@@ -249,6 +293,76 @@ async fn pre_tool_use_blocks_mcp_tool_before_execution() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn pre_tool_use_rewrites_mcp_tool_before_execution() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let call_id = "pretooluse-rmcp-echo-rewrite";
|
||||
let rewritten_message = "rewritten mcp hook input";
|
||||
let arguments = json!({ "message": RMCP_ECHO_MESSAGE }).to_string();
|
||||
let call_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call_with_namespace(call_id, RMCP_NAMESPACE, "echo", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let final_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "mcp pre hook rewrote it"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let test = test_codex()
|
||||
.with_pre_build_hook(move |home| {
|
||||
if let Err(error) = write_updating_pre_tool_use_hook(home, rewritten_message) {
|
||||
panic!("failed to write MCP updating pre tool use hook fixture: {error}");
|
||||
}
|
||||
})
|
||||
.with_config(move |config| {
|
||||
enable_hooks_and_rmcp_server(config, rmcp_test_server_bin, AppToolApproval::Approve);
|
||||
})
|
||||
.build(&server)
|
||||
.await?;
|
||||
|
||||
test.submit_turn("call the rmcp echo tool with the MCP pre hook rewrite")
|
||||
.await?;
|
||||
|
||||
let final_request = final_mock.single_request();
|
||||
let output_item = final_request.function_call_output(call_id);
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("MCP tool output string");
|
||||
assert!(
|
||||
output.contains(&format!("ECHOING: {rewritten_message}")),
|
||||
"MCP tool should execute the rewritten input",
|
||||
);
|
||||
assert!(
|
||||
!output.contains(RMCP_ECHO_MESSAGE),
|
||||
"MCP tool should not execute the original input",
|
||||
);
|
||||
|
||||
let hook_inputs = read_hook_inputs(test.codex_home_path(), "pre_tool_use_hook_log.jsonl")?;
|
||||
assert_eq!(hook_inputs.len(), 1);
|
||||
assert_eq!(
|
||||
hook_inputs[0]["tool_input"],
|
||||
json!({ "message": RMCP_ECHO_MESSAGE }),
|
||||
);
|
||||
|
||||
call_mock.single_request();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn post_tool_use_records_mcp_tool_payload_and_context() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::path::Path;
|
||||
|
||||
use futures::future::join_all;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
|
||||
use codex_protocol::protocol::HookCompletedEvent;
|
||||
use codex_protocol::protocol::HookEventName;
|
||||
@@ -20,6 +21,7 @@ use crate::events::common::matches_matcher;
|
||||
pub(crate) struct ParsedHandler<T> {
|
||||
pub completed: HookCompletedEvent,
|
||||
pub data: T,
|
||||
pub completion_order: usize,
|
||||
}
|
||||
|
||||
pub(crate) fn select_handlers(
|
||||
@@ -90,18 +92,25 @@ pub(crate) async fn execute_handlers<T>(
|
||||
turn_id: Option<String>,
|
||||
parse: fn(&ConfiguredHandler, CommandRunResult, Option<String>) -> ParsedHandler<T>,
|
||||
) -> Vec<ParsedHandler<T>> {
|
||||
let results = join_all(
|
||||
handlers
|
||||
.iter()
|
||||
.map(|handler| run_command(shell, handler, &input_json, cwd)),
|
||||
)
|
||||
.await;
|
||||
let mut pending = FuturesUnordered::new();
|
||||
for (configured_order, handler) in handlers.into_iter().enumerate() {
|
||||
let input_json = input_json.clone();
|
||||
let turn_id = turn_id.clone();
|
||||
pending.push(async move {
|
||||
let result = run_command(shell, &handler, &input_json, cwd).await;
|
||||
(configured_order, parse(&handler, result, turn_id))
|
||||
});
|
||||
}
|
||||
|
||||
handlers
|
||||
.into_iter()
|
||||
.zip(results)
|
||||
.map(|(handler, result)| parse(&handler, result, turn_id.clone()))
|
||||
.collect()
|
||||
let mut completed = Vec::new();
|
||||
let mut completion_order = 0;
|
||||
while let Some((configured_order, mut parsed)) = pending.next().await {
|
||||
parsed.completion_order = completion_order;
|
||||
completion_order += 1;
|
||||
completed.push((configured_order, parsed));
|
||||
}
|
||||
completed.sort_by_key(|(configured_order, _)| *configured_order);
|
||||
completed.into_iter().map(|(_, parsed)| parsed).collect()
|
||||
}
|
||||
|
||||
pub(crate) fn completed_summary(
|
||||
|
||||
@@ -17,6 +17,7 @@ pub(crate) struct PreToolUseOutput {
|
||||
pub universal: UniversalOutput,
|
||||
pub block_reason: Option<String>,
|
||||
pub additional_context: Option<String>,
|
||||
pub updated_input: Option<serde_json::Value>,
|
||||
pub invalid_reason: Option<String>,
|
||||
}
|
||||
|
||||
@@ -139,11 +140,24 @@ pub(crate) fn parse_pre_tool_use(stdout: &str) -> Option<PreToolUseOutput> {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let updated_input = if invalid_reason.is_none() {
|
||||
hook_specific_output.and_then(|output| {
|
||||
matches!(
|
||||
output.permission_decision,
|
||||
Some(PreToolUsePermissionDecisionWire::Allow)
|
||||
)
|
||||
.then(|| output.updated_input.clone())
|
||||
.flatten()
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Some(PreToolUseOutput {
|
||||
universal,
|
||||
block_reason,
|
||||
additional_context,
|
||||
updated_input,
|
||||
invalid_reason,
|
||||
})
|
||||
}
|
||||
@@ -377,12 +391,19 @@ fn unsupported_post_tool_use_hook_specific_output(
|
||||
fn unsupported_pre_tool_use_hook_specific_output(
|
||||
output: &crate::schema::PreToolUseHookSpecificOutputWire,
|
||||
) -> Option<String> {
|
||||
if output.updated_input.is_some() {
|
||||
Some("PreToolUse hook returned unsupported updatedInput".to_string())
|
||||
if output.updated_input.is_some()
|
||||
&& !matches!(
|
||||
output.permission_decision,
|
||||
Some(PreToolUsePermissionDecisionWire::Allow)
|
||||
)
|
||||
{
|
||||
Some("PreToolUse hook returned updatedInput without permissionDecision:allow".to_string())
|
||||
} else {
|
||||
match output.permission_decision {
|
||||
Some(PreToolUsePermissionDecisionWire::Allow) => {
|
||||
Some("PreToolUse hook returned unsupported permissionDecision:allow".to_string())
|
||||
output.updated_input.is_none().then(|| {
|
||||
"PreToolUse hook returned unsupported permissionDecision:allow".to_string()
|
||||
})
|
||||
}
|
||||
Some(PreToolUsePermissionDecisionWire::Ask) => {
|
||||
Some("PreToolUse hook returned unsupported permissionDecision:ask".to_string())
|
||||
|
||||
@@ -299,6 +299,7 @@ fn parse_pre_completed(
|
||||
should_stop,
|
||||
stop_reason,
|
||||
},
|
||||
completion_order: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -401,6 +402,7 @@ fn parse_completed(
|
||||
should_stop,
|
||||
stop_reason,
|
||||
},
|
||||
completion_order: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -281,6 +281,7 @@ fn parse_completed(
|
||||
dispatcher::ParsedHandler {
|
||||
completed,
|
||||
data: PermissionRequestHandlerData { decision },
|
||||
completion_order: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -298,6 +298,7 @@ fn parse_completed(
|
||||
additional_contexts_for_model,
|
||||
feedback_messages_for_model,
|
||||
},
|
||||
completion_order: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ pub struct PreToolUseOutcome {
|
||||
pub should_block: bool,
|
||||
pub block_reason: Option<String>,
|
||||
pub additional_contexts: Vec<String>,
|
||||
pub updated_input: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, PartialEq, Eq)]
|
||||
@@ -45,6 +46,7 @@ struct PreToolUseHandlerData {
|
||||
should_block: bool,
|
||||
block_reason: Option<String>,
|
||||
additional_contexts_for_model: Vec<String>,
|
||||
updated_input: Option<Value>,
|
||||
}
|
||||
|
||||
pub(crate) fn preview(
|
||||
@@ -81,6 +83,7 @@ pub(crate) async fn run(
|
||||
should_block: false,
|
||||
block_reason: None,
|
||||
additional_contexts: Vec::new(),
|
||||
updated_input: None,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -116,6 +119,11 @@ pub(crate) async fn run(
|
||||
.iter()
|
||||
.map(|result| result.data.additional_contexts_for_model.as_slice()),
|
||||
);
|
||||
let updated_input = if should_block {
|
||||
None
|
||||
} else {
|
||||
latest_updated_input(&results)
|
||||
};
|
||||
|
||||
PreToolUseOutcome {
|
||||
hook_events: results
|
||||
@@ -127,9 +135,30 @@ pub(crate) async fn run(
|
||||
should_block,
|
||||
block_reason,
|
||||
additional_contexts,
|
||||
updated_input,
|
||||
}
|
||||
}
|
||||
|
||||
/// Chooses the rewrite from the hook that actually finished last.
|
||||
///
|
||||
/// Hook results stay in configured order for stable reporting, but the
|
||||
/// `PreToolUse` contract resolves competing rewrites by completion order.
|
||||
fn latest_updated_input(
|
||||
results: &[dispatcher::ParsedHandler<PreToolUseHandlerData>],
|
||||
) -> Option<Value> {
|
||||
results
|
||||
.iter()
|
||||
.filter_map(|result| {
|
||||
result
|
||||
.data
|
||||
.updated_input
|
||||
.clone()
|
||||
.map(|updated_input| (result.completion_order, updated_input))
|
||||
})
|
||||
.max_by_key(|(completion_order, _)| *completion_order)
|
||||
.map(|(_, updated_input)| updated_input)
|
||||
}
|
||||
|
||||
/// Serializes command stdin for a selected `PreToolUse` hook.
|
||||
///
|
||||
/// Handler selection may include internal matcher aliases, but hook stdin keeps
|
||||
@@ -161,6 +190,7 @@ fn parse_completed(
|
||||
let mut should_block = false;
|
||||
let mut block_reason = None;
|
||||
let mut additional_contexts_for_model = Vec::new();
|
||||
let mut updated_input = None;
|
||||
|
||||
match run_result.error.as_deref() {
|
||||
Some(error) => {
|
||||
@@ -204,6 +234,9 @@ fn parse_completed(
|
||||
text: reason,
|
||||
});
|
||||
}
|
||||
if !should_block {
|
||||
updated_input = parsed.updated_input;
|
||||
}
|
||||
}
|
||||
} else if output_parser::looks_like_json(&run_result.stdout) {
|
||||
status = HookRunStatus::Failed;
|
||||
@@ -258,7 +291,9 @@ fn parse_completed(
|
||||
should_block,
|
||||
block_reason,
|
||||
additional_contexts_for_model,
|
||||
updated_input,
|
||||
},
|
||||
completion_order: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,6 +303,7 @@ fn serialization_failure_outcome(hook_events: Vec<HookCompletedEvent>) -> PreToo
|
||||
should_block: false,
|
||||
block_reason: None,
|
||||
additional_contexts: Vec::new(),
|
||||
updated_input: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,6 +320,7 @@ mod tests {
|
||||
|
||||
use super::PreToolUseHandlerData;
|
||||
use super::command_input_json;
|
||||
use super::latest_updated_input;
|
||||
use super::parse_completed;
|
||||
use super::preview;
|
||||
use crate::engine::ConfiguredHandler;
|
||||
@@ -320,6 +357,7 @@ mod tests {
|
||||
should_block: true,
|
||||
block_reason: Some("do not run that".to_string()),
|
||||
additional_contexts_for_model: Vec::new(),
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked);
|
||||
@@ -332,6 +370,91 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn permission_decision_allow_can_update_input() {
|
||||
let parsed = parse_completed(
|
||||
&handler(),
|
||||
run_result(
|
||||
Some(0),
|
||||
r#"{"hookSpecificOutput":{"hookEventName":"PreToolUse","permissionDecision":"allow","updatedInput":{"command":"echo rewritten"}}}"#,
|
||||
"",
|
||||
),
|
||||
Some("turn-1".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
parsed.data,
|
||||
PreToolUseHandlerData {
|
||||
should_block: false,
|
||||
block_reason: None,
|
||||
additional_contexts_for_model: Vec::new(),
|
||||
updated_input: Some(serde_json::json!({ "command": "echo rewritten" })),
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Completed);
|
||||
assert_eq!(parsed.completed.run.entries, vec![]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last_completed_updated_input_wins() {
|
||||
let mut later_configured = parse_completed(
|
||||
&handler(),
|
||||
run_result(
|
||||
Some(0),
|
||||
r#"{"hookSpecificOutput":{"hookEventName":"PreToolUse","permissionDecision":"allow","updatedInput":{"command":"echo configured later"}}}"#,
|
||||
"",
|
||||
),
|
||||
Some("turn-1".to_string()),
|
||||
);
|
||||
later_configured.completion_order = 0;
|
||||
let mut earlier_configured = parse_completed(
|
||||
&handler(),
|
||||
run_result(
|
||||
Some(0),
|
||||
r#"{"hookSpecificOutput":{"hookEventName":"PreToolUse","permissionDecision":"allow","updatedInput":{"command":"echo finished later"}}}"#,
|
||||
"",
|
||||
),
|
||||
Some("turn-1".to_string()),
|
||||
);
|
||||
earlier_configured.completion_order = 1;
|
||||
|
||||
assert_eq!(
|
||||
latest_updated_input(&[later_configured, earlier_configured]),
|
||||
Some(serde_json::json!({ "command": "echo finished later" }))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn permission_decision_allow_without_updated_input_fails_open() {
|
||||
let parsed = parse_completed(
|
||||
&handler(),
|
||||
run_result(
|
||||
Some(0),
|
||||
r#"{"hookSpecificOutput":{"hookEventName":"PreToolUse","permissionDecision":"allow"}}"#,
|
||||
"",
|
||||
),
|
||||
Some("turn-1".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
parsed.data,
|
||||
PreToolUseHandlerData {
|
||||
should_block: false,
|
||||
block_reason: None,
|
||||
additional_contexts_for_model: Vec::new(),
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Failed);
|
||||
assert_eq!(
|
||||
parsed.completed.run.entries,
|
||||
vec![HookOutputEntry {
|
||||
kind: HookOutputEntryKind::Error,
|
||||
text: "PreToolUse hook returned unsupported permissionDecision:allow".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deprecated_block_decision_blocks_processing() {
|
||||
let parsed = parse_completed(
|
||||
@@ -350,6 +473,7 @@ mod tests {
|
||||
should_block: true,
|
||||
block_reason: Some("do not run that".to_string()),
|
||||
additional_contexts_for_model: Vec::new(),
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked);
|
||||
@@ -380,6 +504,7 @@ mod tests {
|
||||
should_block: true,
|
||||
block_reason: Some("do not run that".to_string()),
|
||||
additional_contexts_for_model: vec!["remember this".to_string()],
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked);
|
||||
@@ -416,6 +541,7 @@ mod tests {
|
||||
should_block: false,
|
||||
block_reason: None,
|
||||
additional_contexts_for_model: Vec::new(),
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Failed);
|
||||
@@ -442,6 +568,7 @@ mod tests {
|
||||
should_block: false,
|
||||
block_reason: None,
|
||||
additional_contexts_for_model: Vec::new(),
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Failed);
|
||||
@@ -472,6 +599,7 @@ mod tests {
|
||||
should_block: true,
|
||||
block_reason: Some("do not run that".to_string()),
|
||||
additional_contexts_for_model: vec!["nope".to_string()],
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked);
|
||||
@@ -504,6 +632,7 @@ mod tests {
|
||||
should_block: false,
|
||||
block_reason: None,
|
||||
additional_contexts_for_model: Vec::new(),
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Completed);
|
||||
@@ -524,6 +653,7 @@ mod tests {
|
||||
should_block: false,
|
||||
block_reason: None,
|
||||
additional_contexts_for_model: Vec::new(),
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Failed);
|
||||
@@ -550,6 +680,7 @@ mod tests {
|
||||
should_block: true,
|
||||
block_reason: Some("blocked by policy".to_string()),
|
||||
additional_contexts_for_model: Vec::new(),
|
||||
updated_input: None,
|
||||
}
|
||||
);
|
||||
assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked);
|
||||
|
||||
@@ -234,6 +234,7 @@ fn parse_completed(
|
||||
stop_reason,
|
||||
additional_contexts_for_model,
|
||||
},
|
||||
completion_order: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -259,6 +259,7 @@ fn parse_completed(
|
||||
block_reason,
|
||||
continuation_fragments,
|
||||
},
|
||||
completion_order: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -255,6 +255,7 @@ fn parse_completed(
|
||||
stop_reason,
|
||||
additional_contexts_for_model,
|
||||
},
|
||||
completion_order: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user