diff --git a/codex-rs/core/src/hook_runtime.rs b/codex-rs/core/src/hook_runtime.rs index 2bea729339..175813a544 100644 --- a/codex-rs/core/src/hook_runtime.rs +++ b/codex-rs/core/src/hook_runtime.rs @@ -44,6 +44,11 @@ pub(crate) struct HookRuntimeOutcome { pub additional_contexts: Vec, } +pub(crate) enum PreToolUseHookResult { + Continue { updated_input: Option }, + Blocked(String), +} + pub(crate) enum PendingInputHookDisposition { Accepted(Box), Blocked { additional_contexts: Vec }, @@ -141,7 +146,7 @@ pub(crate) async fn run_pre_tool_use_hooks( tool_use_id: String, tool_name: &HookToolName, tool_input: &Value, -) -> Option { +) -> PreToolUseHookResult { let request = PreToolUseRequest { session_id: sess.conversation_id, turn_id: turn_context.sub_id.clone(), @@ -163,25 +168,32 @@ pub(crate) async fn run_pre_tool_use_hooks( should_block, block_reason, additional_contexts, + updated_input, } = hooks.run_pre_tool_use(request).await; emit_hook_completed_events(sess, turn_context, hook_events).await; record_additional_contexts(sess, turn_context, additional_contexts).await; - if should_block { - block_reason.map(|reason| { - if (tool_name.name() == "Bash" || tool_name.name() == "apply_patch") - && let Some(command) = tool_input.get("command").and_then(Value::as_str) - { - format!("Command blocked by PreToolUse hook: {reason}. Command: {command}") - } else { - format!( - "Tool call blocked by PreToolUse hook: {reason}. Tool: {}", - tool_name.name() - ) - } - }) + if !should_block { + return PreToolUseHookResult::Continue { updated_input }; + } + + let Some(reason) = block_reason else { + return PreToolUseHookResult::Continue { + updated_input: None, + }; + }; + + if (tool_name.name() == "Bash" || tool_name.name() == "apply_patch") + && let Some(command) = tool_input.get("command").and_then(Value::as_str) + { + PreToolUseHookResult::Blocked(format!( + "Command blocked by PreToolUse hook: {reason}. Command: {command}" + )) } else { - None + PreToolUseHookResult::Blocked(format!( + "Tool call blocked by PreToolUse hook: {reason}. Tool: {}", + tool_name.name() + )) } } diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs index 211862aa2d..81498c8bea 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -24,6 +24,7 @@ use crate::tools::events::ToolEventCtx; use crate::tools::handlers::apply_granted_turn_permissions; use crate::tools::handlers::apply_patch_spec::create_apply_patch_freeform_tool; use crate::tools::handlers::resolve_tool_environment; +use crate::tools::handlers::updated_hook_command; use crate::tools::hook_names::HookToolName; use crate::tools::orchestrator::ToolOrchestrator; use crate::tools::registry::PostToolUsePayload; @@ -325,6 +326,21 @@ impl ToolHandler for ApplyPatchHandler { }) } + fn with_updated_hook_input( + &self, + mut invocation: ToolInvocation, + updated_input: serde_json::Value, + ) -> Result { + let patch = updated_hook_command(&updated_input)?; + invocation.payload = match invocation.payload { + ToolPayload::Custom { .. } => ToolPayload::Custom { + input: patch.to_string(), + }, + payload => payload, + }; + Ok(invocation) + } + fn post_tool_use_payload( &self, invocation: &ToolInvocation, diff --git a/codex-rs/core/src/tools/handlers/mcp.rs b/codex-rs/core/src/tools/handlers/mcp.rs index c9907f06c4..8bb8337337 100644 --- a/codex-rs/core/src/tools/handlers/mcp.rs +++ b/codex-rs/core/src/tools/handlers/mcp.rs @@ -15,6 +15,7 @@ use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolTelemetryTags; use codex_mcp::ToolInfo; use codex_tools::ToolName; +use serde_json::Map; use serde_json::Value; pub struct McpHandler { @@ -57,6 +58,28 @@ impl ToolHandler for McpHandler { }) } + fn with_updated_hook_input( + &self, + mut invocation: ToolInvocation, + updated_input: Value, + ) -> Result { + invocation.payload = match invocation.payload { + ToolPayload::Function { .. } => ToolPayload::Function { + arguments: serde_json::to_string(&updated_input).map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to serialize rewritten MCP arguments: {err}" + )) + })?, + }, + payload => { + return Err(FunctionCallError::RespondToModel(format!( + "tool {} does not support hook input rewriting for payload {payload:?}", + self.tool_name() + ))); + } + }; + Ok(invocation) + } fn post_tool_use_payload( &self, invocation: &ToolInvocation, @@ -118,7 +141,7 @@ impl ToolHandler for McpHandler { fn mcp_hook_tool_input(raw_arguments: &str) -> Value { if raw_arguments.trim().is_empty() { - return Value::Object(serde_json::Map::new()); + return Value::Object(Map::new()); } serde_json::from_str(raw_arguments).unwrap_or_else(|_| Value::String(raw_arguments.to_string())) @@ -148,7 +171,6 @@ mod tests { }; let (session, turn) = make_session_and_context().await; let handler = McpHandler::new(tool_info("memory", "mcp__memory__", "create_entities")); - assert_eq!( handler.pre_tool_use_payload(&ToolInvocation { session: session.into(), @@ -172,6 +194,62 @@ mod tests { ); } + #[tokio::test] + async fn mcp_pre_tool_use_payload_keeps_builtin_like_tool_names_namespaced() { + let payload = ToolPayload::Function { + arguments: json!({ "message": "hello" }).to_string(), + }; + let (session, turn) = make_session_and_context().await; + let handler = McpHandler::new(tool_info("foo", "mcp__foo__", "exec_command")); + + assert_eq!( + handler.pre_tool_use_payload(&ToolInvocation { + session: session.into(), + turn: turn.into(), + cancellation_token: tokio_util::sync::CancellationToken::new(), + tracker: Arc::new(Mutex::new(TurnDiffTracker::new())), + call_id: "call-mcp-pre-builtin-like".to_string(), + tool_name: codex_tools::ToolName::namespaced("mcp__foo__", "exec_command"), + source: ToolCallSource::Direct, + payload, + }), + Some(PreToolUsePayload { + tool_name: HookToolName::new("mcp__foo__exec_command"), + tool_input: json!({ "message": "hello" }), + }) + ); + } + + #[tokio::test] + async fn mcp_updated_input_rewrites_builtin_like_tool_names_as_mcp() { + let payload = ToolPayload::Function { + arguments: json!({ "message": "hello" }).to_string(), + }; + let (session, turn) = make_session_and_context().await; + let handler = McpHandler::new(tool_info("foo", "mcp__foo__", "exec_command")); + + let invocation = handler + .with_updated_hook_input( + ToolInvocation { + session: session.into(), + turn: turn.into(), + cancellation_token: tokio_util::sync::CancellationToken::new(), + tracker: Arc::new(Mutex::new(TurnDiffTracker::new())), + call_id: "call-mcp-rewrite-builtin-like".to_string(), + tool_name: codex_tools::ToolName::namespaced("mcp__foo__", "exec_command"), + source: ToolCallSource::Direct, + payload, + }, + json!({ "message": "rewritten" }), + ) + .expect("MCP rewrite should succeed"); + + let ToolPayload::Function { arguments } = invocation.payload else { + panic!("builtin-like MCP tool should stay function-shaped"); + }; + assert_eq!(arguments, json!({ "message": "rewritten" }).to_string()); + } + #[tokio::test] async fn mcp_post_tool_use_payload_uses_model_tool_name_args_and_result() { let payload = ToolPayload::Function { diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index 169a513e67..d689b0cf35 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -37,6 +37,7 @@ use codex_sandboxing::policy_transforms::normalize_additional_permissions; use codex_utils_absolute_path::AbsolutePathBuf; use codex_utils_absolute_path::AbsolutePathBufGuard; use serde::Deserialize; +use serde_json::Map; use serde_json::Value; use std::path::Path; @@ -76,7 +77,7 @@ pub(crate) use unified_exec::ExecCommandHandlerOptions; pub use unified_exec::WriteStdinHandler; pub use view_image::ViewImageHandler; -fn parse_arguments(arguments: &str) -> Result +pub(crate) fn parse_arguments(arguments: &str) -> Result where T: for<'de> Deserialize<'de>, { @@ -85,6 +86,47 @@ where }) } +fn updated_hook_command(updated_input: &Value) -> Result<&str, FunctionCallError> { + updated_input + .get("command") + .and_then(Value::as_str) + .ok_or_else(|| { + FunctionCallError::RespondToModel( + "hook returned updatedInput without string field `command`".to_string(), + ) + }) +} + +fn rewrite_function_arguments( + arguments: &str, + tool_name: &str, + rewrite: impl FnOnce(&mut Map), +) -> Result { + let mut arguments: Value = parse_arguments(arguments)?; + let Value::Object(arguments) = &mut arguments else { + return Err(FunctionCallError::RespondToModel(format!( + "{tool_name} arguments must be an object" + ))); + }; + rewrite(arguments); + serde_json::to_string(&arguments).map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to serialize rewritten {tool_name} arguments: {err}" + )) + }) +} + +fn rewrite_function_string_argument( + arguments: &str, + tool_name: &str, + field_name: &str, + value: &str, +) -> Result { + rewrite_function_arguments(arguments, tool_name, |arguments| { + arguments.insert(field_name.to_string(), Value::String(value.to_string())); + }) +} + fn parse_arguments_with_base_path( arguments: &str, base_path: &AbsolutePathBuf, diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index 8bffe13827..6f8ba4d632 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -19,6 +19,8 @@ use crate::tools::handlers::apply_patch::intercept_apply_patch; use crate::tools::handlers::implicit_granted_permissions; use crate::tools::handlers::normalize_and_validate_additional_permissions; use crate::tools::handlers::parse_arguments; +use crate::tools::handlers::rewrite_function_arguments; +use crate::tools::handlers::updated_hook_command; use crate::tools::hook_names::HookToolName; use crate::tools::orchestrator::ToolOrchestrator; use crate::tools::registry::PostToolUsePayload; @@ -93,6 +95,32 @@ fn shell_function_pre_tool_use_payload(invocation: &ToolInvocation) -> Option Result { + let ToolPayload::Function { arguments } = invocation.payload else { + return Err(FunctionCallError::RespondToModel(format!( + "hook input rewrite received unsupported {tool_name} payload" + ))); + }; + let command = shlex::split(updated_hook_command(&updated_input)?).ok_or_else(|| { + FunctionCallError::RespondToModel( + "hook returned shell input with an invalid command string".to_string(), + ) + })?; + invocation.payload = ToolPayload::Function { + arguments: rewrite_function_arguments(&arguments, tool_name, |arguments| { + arguments.insert( + "command".to_string(), + JsonValue::Array(command.into_iter().map(JsonValue::String).collect()), + ); + })?, + }; + Ok(invocation) +} + fn shell_function_post_tool_use_payload( invocation: &ToolInvocation, result: &FunctionToolOutput, diff --git a/codex-rs/core/src/tools/handlers/shell/container_exec.rs b/codex-rs/core/src/tools/handlers/shell/container_exec.rs index 969b34af1b..cb09a51501 100644 --- a/codex-rs/core/src/tools/handlers/shell/container_exec.rs +++ b/codex-rs/core/src/tools/handlers/shell/container_exec.rs @@ -14,6 +14,7 @@ use crate::tools::registry::ToolHandler; use crate::tools::runtimes::shell::ShellRuntimeBackend; use super::RunExecLikeArgs; +use super::rewrite_shell_function_updated_hook_input; use super::run_exec_like; use super::shell_function_post_tool_use_payload; use super::shell_function_pre_tool_use_payload; @@ -46,6 +47,14 @@ impl ToolHandler for ContainerExecHandler { shell_function_pre_tool_use_payload(invocation) } + fn with_updated_hook_input( + &self, + invocation: ToolInvocation, + updated_input: serde_json::Value, + ) -> Result { + rewrite_shell_function_updated_hook_input(invocation, updated_input, "container.exec") + } + fn post_tool_use_payload( &self, invocation: &ToolInvocation, diff --git a/codex-rs/core/src/tools/handlers/shell/local_shell.rs b/codex-rs/core/src/tools/handlers/shell/local_shell.rs index c8b44951f1..6265be4285 100644 --- a/codex-rs/core/src/tools/handlers/shell/local_shell.rs +++ b/codex-rs/core/src/tools/handlers/shell/local_shell.rs @@ -6,6 +6,7 @@ use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; +use crate::tools::handlers::updated_hook_command; use crate::tools::hook_names::HookToolName; use crate::tools::registry::PostToolUsePayload; use crate::tools::registry::PreToolUsePayload; @@ -64,6 +65,26 @@ impl ToolHandler for LocalShellHandler { }) } + fn with_updated_hook_input( + &self, + mut invocation: ToolInvocation, + updated_input: serde_json::Value, + ) -> Result { + let command = updated_hook_command(&updated_input)?; + invocation.payload = match invocation.payload { + ToolPayload::LocalShell { mut params } => { + params.command = shlex::split(command).ok_or_else(|| { + FunctionCallError::RespondToModel( + "hook returned shell input with an invalid command string".to_string(), + ) + })?; + ToolPayload::LocalShell { params } + } + payload => payload, + }; + Ok(invocation) + } + fn post_tool_use_payload( &self, invocation: &ToolInvocation, diff --git a/codex-rs/core/src/tools/handlers/shell/shell_command.rs b/codex-rs/core/src/tools/handlers/shell/shell_command.rs index 91238b40d9..f5ecd80664 100644 --- a/codex-rs/core/src/tools/handlers/shell/shell_command.rs +++ b/codex-rs/core/src/tools/handlers/shell/shell_command.rs @@ -17,6 +17,8 @@ use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; use crate::tools::handlers::parse_arguments_with_base_path; use crate::tools::handlers::resolve_workdir_base_path; +use crate::tools::handlers::rewrite_function_string_argument; +use crate::tools::handlers::updated_hook_command; use crate::tools::hook_names::HookToolName; use crate::tools::registry::PostToolUsePayload; use crate::tools::registry::PreToolUsePayload; @@ -175,6 +177,27 @@ impl ToolHandler for ShellCommandHandler { }) } + fn with_updated_hook_input( + &self, + mut invocation: ToolInvocation, + updated_input: serde_json::Value, + ) -> Result { + let ToolPayload::Function { arguments } = invocation.payload else { + return Err(FunctionCallError::RespondToModel( + "hook input rewrite received unsupported shell_command payload".to_string(), + )); + }; + invocation.payload = ToolPayload::Function { + arguments: rewrite_function_string_argument( + &arguments, + "shell_command", + "command", + updated_hook_command(&updated_input)?, + )?, + }; + Ok(invocation) + } + fn post_tool_use_payload( &self, invocation: &ToolInvocation, diff --git a/codex-rs/core/src/tools/handlers/shell/shell_handler.rs b/codex-rs/core/src/tools/handlers/shell/shell_handler.rs index 112fd22c39..39422734f5 100644 --- a/codex-rs/core/src/tools/handlers/shell/shell_handler.rs +++ b/codex-rs/core/src/tools/handlers/shell/shell_handler.rs @@ -22,6 +22,7 @@ use codex_tools::ToolSpec; use super::super::shell_spec::ShellToolOptions; use super::super::shell_spec::create_shell_tool; use super::RunExecLikeArgs; +use super::rewrite_shell_function_updated_hook_input; use super::run_exec_like; use super::shell_function_post_tool_use_payload; use super::shell_function_pre_tool_use_payload; @@ -95,6 +96,14 @@ impl ToolHandler for ShellHandler { shell_function_pre_tool_use_payload(invocation) } + fn with_updated_hook_input( + &self, + invocation: ToolInvocation, + updated_input: serde_json::Value, + ) -> Result { + rewrite_shell_function_updated_hook_input(invocation, updated_input, "shell") + } + fn post_tool_use_payload( &self, invocation: &ToolInvocation, diff --git a/codex-rs/core/src/tools/handlers/unified_exec/exec_command.rs b/codex-rs/core/src/tools/handlers/unified_exec/exec_command.rs index d18764a934..88f23762d6 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec/exec_command.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec/exec_command.rs @@ -12,6 +12,8 @@ use crate::tools::handlers::normalize_and_validate_additional_permissions; use crate::tools::handlers::parse_arguments; use crate::tools::handlers::parse_arguments_with_base_path; use crate::tools::handlers::resolve_tool_environment; +use crate::tools::handlers::rewrite_function_string_argument; +use crate::tools::handlers::updated_hook_command; use crate::tools::hook_names::HookToolName; use crate::tools::registry::PostToolUsePayload; use crate::tools::registry::PreToolUsePayload; @@ -128,6 +130,27 @@ impl ToolHandler for ExecCommandHandler { }) } + fn with_updated_hook_input( + &self, + mut invocation: ToolInvocation, + updated_input: serde_json::Value, + ) -> Result { + let ToolPayload::Function { arguments } = invocation.payload else { + return Err(FunctionCallError::RespondToModel( + "hook input rewrite received unsupported exec_command payload".to_string(), + )); + }; + invocation.payload = ToolPayload::Function { + arguments: rewrite_function_string_argument( + &arguments, + "exec_command", + "cmd", + updated_hook_command(&updated_input)?, + )?, + }; + Ok(invocation) + } + fn post_tool_use_payload( &self, invocation: &ToolInvocation, diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index f1f9db0dd2..25f2d11e9a 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -4,6 +4,7 @@ use std::time::Duration; use crate::function_tool::FunctionCallError; use crate::goals::GoalRuntimeEvent; +use crate::hook_runtime::PreToolUseHookResult; use crate::hook_runtime::record_additional_contexts; use crate::hook_runtime::run_post_tool_use_hooks; use crate::hook_runtime::run_pre_tool_use_hooks; @@ -73,10 +74,6 @@ pub trait ToolHandler: Send + Sync { async { false } } - fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option { - None - } - fn post_tool_use_payload( &self, _invocation: &ToolInvocation, @@ -85,6 +82,24 @@ pub trait ToolHandler: Send + Sync { None } + fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option { + None + } + + /// Rebuilds a tool invocation from hook-facing `tool_input`. + /// + /// Tools that opt into input-rewriting hooks should invert the same stable + /// hook contract they expose from `pre_tool_use_payload`. + fn with_updated_hook_input( + &self, + _invocation: ToolInvocation, + _updated_input: Value, + ) -> Result { + Err(FunctionCallError::RespondToModel( + "tool does not support hook input rewriting".to_string(), + )) + } + /// Creates an optional consumer for streamed tool argument diffs. fn create_diff_consumer(&self) -> Option> { None @@ -175,6 +190,12 @@ trait AnyToolHandler: Send + Sync { fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option; + fn with_updated_hook_input( + &self, + invocation: ToolInvocation, + updated_input: Value, + ) -> Result; + fn telemetry_tags<'a>( &'a self, invocation: &'a ToolInvocation, @@ -207,6 +228,14 @@ where ToolHandler::pre_tool_use_payload(self, invocation) } + fn with_updated_hook_input( + &self, + invocation: ToolInvocation, + updated_input: Value, + ) -> Result { + ToolHandler::with_updated_hook_input(self, invocation, updated_input) + } + fn telemetry_tags<'a>( &'a self, invocation: &'a ToolInvocation, @@ -286,14 +315,12 @@ impl ToolRegistry { )] pub(crate) async fn dispatch_any( &self, - invocation: ToolInvocation, + mut invocation: ToolInvocation, ) -> Result { let tool_name = invocation.tool_name.clone(); let tool_name_flat = flat_tool_name(&tool_name); let call_id_owned = invocation.call_id.clone(); let otel = invocation.turn.session_telemetry.clone(); - let payload_for_response = invocation.payload.clone(); - let log_payload = payload_for_response.log_payload(); let base_tool_result_tags = [ ( "sandbox", @@ -325,6 +352,7 @@ impl ToolRegistry { Some(handler) => handler, None => { let message = unsupported_tool_call_message(&invocation.payload, &tool_name); + let log_payload = invocation.payload.log_payload(); otel.tool_result_with_tags( tool_name_flat.as_ref(), &call_id_owned, @@ -353,9 +381,9 @@ impl ToolRegistry { tool_result_tags.push((*key, value.as_str())); } } - if !handler.matches_kind(&invocation.payload) { let message = format!("tool {tool_name} invoked with incompatible payload"); + let log_payload = invocation.payload.log_payload(); otel.tool_result_with_tags( tool_name_flat.as_ref(), &call_id_owned, @@ -371,8 +399,8 @@ impl ToolRegistry { return Err(err); } - if let Some(pre_tool_use_payload) = handler.pre_tool_use_payload(&invocation) - && let Some(message) = run_pre_tool_use_hooks( + if let Some(pre_tool_use_payload) = handler.pre_tool_use_payload(&invocation) { + match run_pre_tool_use_hooks( &invocation.session, &invocation.turn, invocation.call_id.clone(), @@ -380,15 +408,27 @@ impl ToolRegistry { &pre_tool_use_payload.tool_input, ) .await - { - let err = FunctionCallError::RespondToModel(message); - dispatch_trace.record_failed(&err); - return Err(err); + { + PreToolUseHookResult::Blocked(message) => { + let err = FunctionCallError::RespondToModel(message); + dispatch_trace.record_failed(&err); + return Err(err); + } + PreToolUseHookResult::Continue { + updated_input: Some(updated_input), + } => { + invocation = handler.with_updated_hook_input(invocation, updated_input)?; + } + PreToolUseHookResult::Continue { + updated_input: None, + } => {} + } } let is_mutating = handler.is_mutating(&invocation).await; let response_cell = tokio::sync::Mutex::new(None); let invocation_for_tool = invocation.clone(); + let log_payload = invocation.payload.log_payload(); let result = otel .log_tool_result_with_tags( diff --git a/codex-rs/core/tests/suite/hooks.rs b/codex-rs/core/tests/suite/hooks.rs index caf4642b0c..563719cec2 100644 --- a/codex-rs/core/tests/suite/hooks.rs +++ b/codex-rs/core/tests/suite/hooks.rs @@ -26,6 +26,7 @@ use core_test_support::managed_network_requirements_loader; use core_test_support::responses::ev_apply_patch_custom_tool_call; use core_test_support::responses::ev_assistant_message; use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_custom_tool_call; use core_test_support::responses::ev_function_call; use core_test_support::responses::ev_message_item_added; use core_test_support::responses::ev_output_text_delta; @@ -307,6 +308,54 @@ elif mode == "exit_2": Ok(()) } +fn write_updating_pre_tool_use_hook( + home: &Path, + matcher: &str, + updated_input: &Value, +) -> Result<()> { + let script_path = home.join("pre_tool_use_hook.py"); + let log_path = home.join("pre_tool_use_hook_log.jsonl"); + let updated_input_json = + serde_json::to_string(updated_input).context("serialize updated pre tool input")?; + let script = format!( + r#"import json +from pathlib import Path +import sys + +payload = json.load(sys.stdin) + +with Path(r"{log_path}").open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload) + "\n") + +print(json.dumps({{ + "hookSpecificOutput": {{ + "hookEventName": "PreToolUse", + "permissionDecision": "allow", + "updatedInput": {updated_input_json} + }} +}})) +"#, + log_path = log_path.display(), + updated_input_json = updated_input_json, + ); + let hooks = serde_json::json!({ + "hooks": { + "PreToolUse": [{ + "matcher": matcher, + "hooks": [{ + "type": "command", + "command": format!("python3 {}", script_path.display()), + "statusMessage": "rewriting pre tool input", + }] + }] + } + }); + + fs::write(&script_path, script).context("write updating pre tool use hook script")?; + fs::write(home.join("hooks.json"), hooks.to_string()).context("write hooks.json")?; + Ok(()) +} + fn write_pre_tool_use_hook_toml( home: &Path, script_name: &str, @@ -2081,6 +2130,274 @@ async fn blocked_pre_tool_use_records_additional_context_for_shell_command() -> !marker.exists(), "blocked command should not create marker file" ); + Ok(()) +} + +#[derive(Clone, Copy)] +enum BashRewriteSurface { + ContainerExec, + ExecCommand, + LocalShell, + Shell, + ShellCommand, +} + +impl BashRewriteSurface { + fn slug(self) -> &'static str { + match self { + BashRewriteSurface::ContainerExec => "container-exec", + BashRewriteSurface::ExecCommand => "exec-command", + BashRewriteSurface::LocalShell => "local-shell", + BashRewriteSurface::Shell => "shell", + BashRewriteSurface::ShellCommand => "shell-command", + } + } + + fn tool_call(self, call_id: &str, command: &[String], command_text: &str) -> Result { + match self { + BashRewriteSurface::ContainerExec => Ok(ev_function_call( + call_id, + "container.exec", + &serde_json::to_string(&serde_json::json!({ "command": command }))?, + )), + BashRewriteSurface::ExecCommand => Ok(ev_function_call( + call_id, + "exec_command", + &serde_json::to_string(&serde_json::json!({ "cmd": command_text }))?, + )), + BashRewriteSurface::LocalShell => { + Ok(core_test_support::responses::ev_local_shell_call( + call_id, + "completed", + command.iter().map(String::as_str).collect(), + )) + } + BashRewriteSurface::Shell => Ok(ev_function_call( + call_id, + "shell", + &serde_json::to_string(&serde_json::json!({ "command": command }))?, + )), + BashRewriteSurface::ShellCommand => Ok(ev_function_call( + call_id, + "shell_command", + &serde_json::to_string(&serde_json::json!({ "command": command_text }))?, + )), + } + } + + fn original_command(self, marker: &Path) -> (Vec, String) { + let command_text = format!("printf original > {}", marker.display()); + match self { + BashRewriteSurface::ContainerExec + | BashRewriteSurface::LocalShell + | BashRewriteSurface::Shell => { + let command = vec!["/bin/sh".to_string(), "-c".to_string(), command_text]; + let command_text = codex_shell_command::parse_command::shlex_join(&command); + (command, command_text) + } + BashRewriteSurface::ExecCommand | BashRewriteSurface::ShellCommand => { + (Vec::new(), command_text) + } + } + } + + fn rewritten_command(self, marker: &Path) -> String { + let command_text = format!("printf rewritten > {}", marker.display()); + match self { + BashRewriteSurface::ContainerExec + | BashRewriteSurface::LocalShell + | BashRewriteSurface::Shell => codex_shell_command::parse_command::shlex_join(&[ + "/bin/sh".to_string(), + "-c".to_string(), + command_text, + ]), + BashRewriteSurface::ExecCommand | BashRewriteSurface::ShellCommand => command_text, + } + } + + fn configure(self, config: &mut Config) { + trust_discovered_hooks(config); + if matches!(self, BashRewriteSurface::ExecCommand) { + config.use_experimental_unified_exec_tool = true; + if let Err(error) = config.features.enable(Feature::UnifiedExec) { + panic!("test config should allow feature update: {error}"); + } + } + } +} + +async fn assert_pre_tool_use_rewrites_bash_surface(surface: BashRewriteSurface) -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let slug = surface.slug(); + let call_id = format!("pretooluse-{slug}-rewrite"); + let original_marker = std::env::temp_dir().join(format!("pretooluse-{slug}-original-marker")); + let rewritten_marker = std::env::temp_dir().join(format!("pretooluse-{slug}-rewritten-marker")); + let (tool_command, original_command) = surface.original_command(&original_marker); + let rewritten_command = surface.rewritten_command(&rewritten_marker); + let responses = mount_sse_sequence( + &server, + vec![ + sse(vec![ + ev_response_created("resp-1"), + surface.tool_call(&call_id, &tool_command, &original_command)?, + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "hook rewrote it"), + ev_completed("resp-2"), + ]), + ], + ) + .await; + + let updated_input = serde_json::json!({ "command": rewritten_command }); + let mut builder = test_codex() + .with_pre_build_hook(move |home| { + if let Err(error) = write_updating_pre_tool_use_hook(home, "^Bash$", &updated_input) { + panic!("failed to write updating pre tool use hook fixture: {error}"); + } + }) + .with_config(move |config| surface.configure(config)); + let test = builder.build(&server).await?; + + if original_marker.exists() { + fs::remove_file(&original_marker).context("remove stale original pre tool marker")?; + } + if rewritten_marker.exists() { + fs::remove_file(&rewritten_marker).context("remove stale rewritten pre tool marker")?; + } + + test.submit_turn_with_permission_profile( + &format!("run the rewritten {slug} command"), + PermissionProfile::Disabled, + ) + .await?; + + let requests = responses.requests(); + assert_eq!(requests.len(), 2); + requests[1].function_call_output(&call_id); + assert!( + !original_marker.exists(), + "original {slug} command should not execute after rewrite" + ); + assert_eq!( + fs::read_to_string(&rewritten_marker).context("read rewritten pre tool marker")?, + "rewritten" + ); + + let hook_inputs = read_pre_tool_use_hook_inputs(test.codex_home_path())?; + assert_eq!(hook_inputs.len(), 1); + assert_eq!(hook_inputs[0]["tool_input"]["command"], original_command); + + Ok(()) +} + +#[tokio::test] +async fn pre_tool_use_rewrites_shell_before_execution() -> Result<()> { + assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::Shell).await +} + +#[tokio::test] +async fn pre_tool_use_rewrites_container_exec_before_execution() -> Result<()> { + assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::ContainerExec).await +} + +#[tokio::test] +async fn pre_tool_use_rewrites_local_shell_before_execution() -> Result<()> { + assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::LocalShell).await +} + +#[tokio::test] +async fn pre_tool_use_rewrites_shell_command_before_execution() -> Result<()> { + assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::ShellCommand).await +} + +#[tokio::test] +async fn pre_tool_use_rewrites_exec_command_before_execution() -> Result<()> { + assert_pre_tool_use_rewrites_bash_surface(BashRewriteSurface::ExecCommand).await +} + +#[tokio::test] +async fn pre_tool_use_rewrites_code_mode_nested_exec_command_before_execution() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let call_id = "pretooluse-code-mode-rewrite"; + let original_marker = std::env::temp_dir().join("pretooluse-code-mode-original-marker"); + let rewritten_marker = std::env::temp_dir().join("pretooluse-code-mode-rewritten-marker"); + let original_command = format!("printf original > {}", original_marker.display()); + let rewritten_command = format!("printf rewritten > {}", rewritten_marker.display()); + let original_command_json = + serde_json::to_string(&original_command).context("serialize original command")?; + let code = format!( + r#" +const output = await tools.exec_command({{ cmd: {original_command_json} }}); +text(output.output); +"# + ); + let responses = mount_sse_sequence( + &server, + vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_custom_tool_call(call_id, "exec", &code), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "hook rewrote the nested command"), + ev_completed("resp-2"), + ]), + ], + ) + .await; + + let updated_input = serde_json::json!({ "command": rewritten_command }); + let mut builder = test_codex() + .with_model("test-gpt-5.1-codex") + .with_pre_build_hook(move |home| { + if let Err(error) = write_updating_pre_tool_use_hook(home, "^Bash$", &updated_input) { + panic!("failed to write updating pre tool use hook fixture: {error}"); + } + }) + .with_config(|config| { + let _ = config.features.enable(Feature::CodeMode); + trust_discovered_hooks(config); + }); + let test = builder.build(&server).await?; + + if original_marker.exists() { + fs::remove_file(&original_marker).context("remove stale original pre tool marker")?; + } + if rewritten_marker.exists() { + fs::remove_file(&rewritten_marker).context("remove stale rewritten pre tool marker")?; + } + + test.submit_turn_with_permission_profile( + "run the rewritten shell command from code mode", + PermissionProfile::Disabled, + ) + .await?; + + let requests = responses.requests(); + assert_eq!(requests.len(), 2); + requests[1].custom_tool_call_output(call_id); + assert!( + !original_marker.exists(), + "original nested shell command should not execute after rewrite" + ); + assert_eq!( + fs::read_to_string(&rewritten_marker) + .context("read rewritten code mode pre tool marker")?, + "rewritten" + ); + + let hook_inputs = read_pre_tool_use_hook_inputs(test.codex_home_path())?; + assert_eq!(hook_inputs.len(), 1); + assert_eq!(hook_inputs[0]["tool_input"]["command"], original_command); Ok(()) } @@ -2675,6 +2992,80 @@ async fn pre_tool_use_blocks_apply_patch_before_execution() -> Result<()> { Ok(()) } +#[tokio::test] +async fn pre_tool_use_rewrites_apply_patch_before_execution() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let call_id = "pretooluse-apply-patch-rewrite"; + let original_file = "pre_tool_use_apply_patch_original.txt"; + let rewritten_file = "pre_tool_use_apply_patch_rewritten.txt"; + let original_patch = format!( + r#"*** Begin Patch +*** Add File: {original_file} ++original +*** End Patch"# + ); + let rewritten_patch = format!( + r#"*** Begin Patch +*** Add File: {rewritten_file} ++rewritten +*** End Patch"# + ); + let responses = mount_sse_sequence( + &server, + vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_apply_patch_custom_tool_call(call_id, &original_patch), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "apply_patch rewritten"), + ev_completed("resp-2"), + ]), + ], + ) + .await; + + let updated_input = serde_json::json!({ "command": rewritten_patch }); + let mut builder = test_codex() + .with_pre_build_hook(move |home| { + if let Err(error) = + write_updating_pre_tool_use_hook(home, "^apply_patch$", &updated_input) + { + panic!("failed to write updating pre tool use hook fixture: {error}"); + } + }) + .with_config(|config| { + config.include_apply_patch_tool = true; + trust_discovered_hooks(config); + }); + let test = builder.build(&server).await?; + + test.submit_turn("apply the rewritten patch").await?; + + let requests = responses.requests(); + assert_eq!(requests.len(), 2); + requests[1].custom_tool_call_output(call_id); + assert!( + !test.workspace_path(original_file).exists(), + "original patch should not create its target file" + ); + assert_eq!( + fs::read_to_string(test.workspace_path(rewritten_file)) + .context("read rewritten apply_patch file")?, + "rewritten\n" + ); + + let hook_inputs = read_pre_tool_use_hook_inputs(test.codex_home_path())?; + assert_eq!(hook_inputs.len(), 1); + assert_eq!(hook_inputs[0]["tool_input"]["command"], original_patch); + + Ok(()) +} + #[tokio::test] async fn pre_tool_use_blocks_apply_patch_with_write_alias() -> Result<()> { skip_if_no_network!(Ok(())); diff --git a/codex-rs/core/tests/suite/hooks_mcp.rs b/codex-rs/core/tests/suite/hooks_mcp.rs index 26e3053189..96ed732695 100644 --- a/codex-rs/core/tests/suite/hooks_mcp.rs +++ b/codex-rs/core/tests/suite/hooks_mcp.rs @@ -74,6 +74,50 @@ print(json.dumps({{ Ok(()) } +fn write_updating_pre_tool_use_hook(home: &Path, updated_message: &str) -> Result<()> { + let script_path = home.join("pre_tool_use_hook.py"); + let log_path = home.join("pre_tool_use_hook_log.jsonl"); + let updated_message_json = + serde_json::to_string(updated_message).context("serialize updated MCP message")?; + let script = format!( + r#"import json +from pathlib import Path +import sys + +payload = json.load(sys.stdin) + +with Path(r"{log_path}").open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload) + "\n") + +print(json.dumps({{ + "hookSpecificOutput": {{ + "hookEventName": "PreToolUse", + "permissionDecision": "allow", + "updatedInput": {{ "message": {updated_message_json} }} + }} +}})) +"#, + log_path = log_path.display(), + updated_message_json = updated_message_json, + ); + let hooks = serde_json::json!({ + "hooks": { + "PreToolUse": [{ + "matcher": RMCP_HOOK_MATCHER, + "hooks": [{ + "type": "command", + "command": format!("python3 {}", script_path.display()), + "statusMessage": "rewriting MCP pre tool input", + }] + }] + } + }); + + fs::write(&script_path, script).context("write updating pre tool use hook script")?; + fs::write(home.join("hooks.json"), hooks.to_string()).context("write hooks.json")?; + Ok(()) +} + fn write_post_tool_use_hook(home: &Path, additional_context: &str) -> Result<()> { let script_path = home.join("post_tool_use_hook.py"); let log_path = home.join("post_tool_use_hook_log.jsonl"); @@ -249,6 +293,76 @@ async fn pre_tool_use_blocks_mcp_tool_before_execution() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn pre_tool_use_rewrites_mcp_tool_before_execution() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let call_id = "pretooluse-rmcp-echo-rewrite"; + let rewritten_message = "rewritten mcp hook input"; + let arguments = json!({ "message": RMCP_ECHO_MESSAGE }).to_string(); + let call_mock = mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_function_call_with_namespace(call_id, RMCP_NAMESPACE, "echo", &arguments), + ev_completed("resp-1"), + ]), + ) + .await; + let final_mock = mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "mcp pre hook rewrote it"), + ev_completed("resp-2"), + ]), + ) + .await; + + let rmcp_test_server_bin = stdio_server_bin()?; + let test = test_codex() + .with_pre_build_hook(move |home| { + if let Err(error) = write_updating_pre_tool_use_hook(home, rewritten_message) { + panic!("failed to write MCP updating pre tool use hook fixture: {error}"); + } + }) + .with_config(move |config| { + enable_hooks_and_rmcp_server(config, rmcp_test_server_bin, AppToolApproval::Approve); + }) + .build(&server) + .await?; + + test.submit_turn("call the rmcp echo tool with the MCP pre hook rewrite") + .await?; + + let final_request = final_mock.single_request(); + let output_item = final_request.function_call_output(call_id); + let output = output_item + .get("output") + .and_then(Value::as_str) + .expect("MCP tool output string"); + assert!( + output.contains(&format!("ECHOING: {rewritten_message}")), + "MCP tool should execute the rewritten input", + ); + assert!( + !output.contains(RMCP_ECHO_MESSAGE), + "MCP tool should not execute the original input", + ); + + let hook_inputs = read_hook_inputs(test.codex_home_path(), "pre_tool_use_hook_log.jsonl")?; + assert_eq!(hook_inputs.len(), 1); + assert_eq!( + hook_inputs[0]["tool_input"], + json!({ "message": RMCP_ECHO_MESSAGE }), + ); + + call_mock.single_request(); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn post_tool_use_records_mcp_tool_payload_and_context() -> Result<()> { skip_if_no_network!(Ok(())); diff --git a/codex-rs/hooks/src/engine/dispatcher.rs b/codex-rs/hooks/src/engine/dispatcher.rs index 9c71bf49b5..7243fc6eab 100644 --- a/codex-rs/hooks/src/engine/dispatcher.rs +++ b/codex-rs/hooks/src/engine/dispatcher.rs @@ -1,6 +1,7 @@ use std::path::Path; -use futures::future::join_all; +use futures::StreamExt; +use futures::stream::FuturesUnordered; use codex_protocol::protocol::HookCompletedEvent; use codex_protocol::protocol::HookEventName; @@ -20,6 +21,7 @@ use crate::events::common::matches_matcher; pub(crate) struct ParsedHandler { pub completed: HookCompletedEvent, pub data: T, + pub completion_order: usize, } pub(crate) fn select_handlers( @@ -90,18 +92,25 @@ pub(crate) async fn execute_handlers( turn_id: Option, parse: fn(&ConfiguredHandler, CommandRunResult, Option) -> ParsedHandler, ) -> Vec> { - let results = join_all( - handlers - .iter() - .map(|handler| run_command(shell, handler, &input_json, cwd)), - ) - .await; + let mut pending = FuturesUnordered::new(); + for (configured_order, handler) in handlers.into_iter().enumerate() { + let input_json = input_json.clone(); + let turn_id = turn_id.clone(); + pending.push(async move { + let result = run_command(shell, &handler, &input_json, cwd).await; + (configured_order, parse(&handler, result, turn_id)) + }); + } - handlers - .into_iter() - .zip(results) - .map(|(handler, result)| parse(&handler, result, turn_id.clone())) - .collect() + let mut completed = Vec::new(); + let mut completion_order = 0; + while let Some((configured_order, mut parsed)) = pending.next().await { + parsed.completion_order = completion_order; + completion_order += 1; + completed.push((configured_order, parsed)); + } + completed.sort_by_key(|(configured_order, _)| *configured_order); + completed.into_iter().map(|(_, parsed)| parsed).collect() } pub(crate) fn completed_summary( diff --git a/codex-rs/hooks/src/engine/output_parser.rs b/codex-rs/hooks/src/engine/output_parser.rs index 3bccb10127..48ad1aff2d 100644 --- a/codex-rs/hooks/src/engine/output_parser.rs +++ b/codex-rs/hooks/src/engine/output_parser.rs @@ -17,6 +17,7 @@ pub(crate) struct PreToolUseOutput { pub universal: UniversalOutput, pub block_reason: Option, pub additional_context: Option, + pub updated_input: Option, pub invalid_reason: Option, } @@ -139,11 +140,24 @@ pub(crate) fn parse_pre_tool_use(stdout: &str) -> Option { } else { None }; + let updated_input = if invalid_reason.is_none() { + hook_specific_output.and_then(|output| { + matches!( + output.permission_decision, + Some(PreToolUsePermissionDecisionWire::Allow) + ) + .then(|| output.updated_input.clone()) + .flatten() + }) + } else { + None + }; Some(PreToolUseOutput { universal, block_reason, additional_context, + updated_input, invalid_reason, }) } @@ -377,12 +391,19 @@ fn unsupported_post_tool_use_hook_specific_output( fn unsupported_pre_tool_use_hook_specific_output( output: &crate::schema::PreToolUseHookSpecificOutputWire, ) -> Option { - if output.updated_input.is_some() { - Some("PreToolUse hook returned unsupported updatedInput".to_string()) + if output.updated_input.is_some() + && !matches!( + output.permission_decision, + Some(PreToolUsePermissionDecisionWire::Allow) + ) + { + Some("PreToolUse hook returned updatedInput without permissionDecision:allow".to_string()) } else { match output.permission_decision { Some(PreToolUsePermissionDecisionWire::Allow) => { - Some("PreToolUse hook returned unsupported permissionDecision:allow".to_string()) + output.updated_input.is_none().then(|| { + "PreToolUse hook returned unsupported permissionDecision:allow".to_string() + }) } Some(PreToolUsePermissionDecisionWire::Ask) => { Some("PreToolUse hook returned unsupported permissionDecision:ask".to_string()) diff --git a/codex-rs/hooks/src/events/compact.rs b/codex-rs/hooks/src/events/compact.rs index 67c13c34eb..469fdda232 100644 --- a/codex-rs/hooks/src/events/compact.rs +++ b/codex-rs/hooks/src/events/compact.rs @@ -299,6 +299,7 @@ fn parse_pre_completed( should_stop, stop_reason, }, + completion_order: 0, } } @@ -401,6 +402,7 @@ fn parse_completed( should_stop, stop_reason, }, + completion_order: 0, } } diff --git a/codex-rs/hooks/src/events/permission_request.rs b/codex-rs/hooks/src/events/permission_request.rs index 11ab4d2e47..79d0608236 100644 --- a/codex-rs/hooks/src/events/permission_request.rs +++ b/codex-rs/hooks/src/events/permission_request.rs @@ -281,6 +281,7 @@ fn parse_completed( dispatcher::ParsedHandler { completed, data: PermissionRequestHandlerData { decision }, + completion_order: 0, } } diff --git a/codex-rs/hooks/src/events/post_tool_use.rs b/codex-rs/hooks/src/events/post_tool_use.rs index 223efa7260..801c5f09e9 100644 --- a/codex-rs/hooks/src/events/post_tool_use.rs +++ b/codex-rs/hooks/src/events/post_tool_use.rs @@ -298,6 +298,7 @@ fn parse_completed( additional_contexts_for_model, feedback_messages_for_model, }, + completion_order: 0, } } diff --git a/codex-rs/hooks/src/events/pre_tool_use.rs b/codex-rs/hooks/src/events/pre_tool_use.rs index 77e6d3f3fa..b21daf063b 100644 --- a/codex-rs/hooks/src/events/pre_tool_use.rs +++ b/codex-rs/hooks/src/events/pre_tool_use.rs @@ -38,6 +38,7 @@ pub struct PreToolUseOutcome { pub should_block: bool, pub block_reason: Option, pub additional_contexts: Vec, + pub updated_input: Option, } #[derive(Debug, Default, PartialEq, Eq)] @@ -45,6 +46,7 @@ struct PreToolUseHandlerData { should_block: bool, block_reason: Option, additional_contexts_for_model: Vec, + updated_input: Option, } pub(crate) fn preview( @@ -81,6 +83,7 @@ pub(crate) async fn run( should_block: false, block_reason: None, additional_contexts: Vec::new(), + updated_input: None, }; } @@ -116,6 +119,11 @@ pub(crate) async fn run( .iter() .map(|result| result.data.additional_contexts_for_model.as_slice()), ); + let updated_input = if should_block { + None + } else { + latest_updated_input(&results) + }; PreToolUseOutcome { hook_events: results @@ -127,9 +135,30 @@ pub(crate) async fn run( should_block, block_reason, additional_contexts, + updated_input, } } +/// Chooses the rewrite from the hook that actually finished last. +/// +/// Hook results stay in configured order for stable reporting, but the +/// `PreToolUse` contract resolves competing rewrites by completion order. +fn latest_updated_input( + results: &[dispatcher::ParsedHandler], +) -> Option { + results + .iter() + .filter_map(|result| { + result + .data + .updated_input + .clone() + .map(|updated_input| (result.completion_order, updated_input)) + }) + .max_by_key(|(completion_order, _)| *completion_order) + .map(|(_, updated_input)| updated_input) +} + /// Serializes command stdin for a selected `PreToolUse` hook. /// /// Handler selection may include internal matcher aliases, but hook stdin keeps @@ -161,6 +190,7 @@ fn parse_completed( let mut should_block = false; let mut block_reason = None; let mut additional_contexts_for_model = Vec::new(); + let mut updated_input = None; match run_result.error.as_deref() { Some(error) => { @@ -204,6 +234,9 @@ fn parse_completed( text: reason, }); } + if !should_block { + updated_input = parsed.updated_input; + } } } else if output_parser::looks_like_json(&run_result.stdout) { status = HookRunStatus::Failed; @@ -258,7 +291,9 @@ fn parse_completed( should_block, block_reason, additional_contexts_for_model, + updated_input, }, + completion_order: 0, } } @@ -268,6 +303,7 @@ fn serialization_failure_outcome(hook_events: Vec) -> PreToo should_block: false, block_reason: None, additional_contexts: Vec::new(), + updated_input: None, } } @@ -284,6 +320,7 @@ mod tests { use super::PreToolUseHandlerData; use super::command_input_json; + use super::latest_updated_input; use super::parse_completed; use super::preview; use crate::engine::ConfiguredHandler; @@ -320,6 +357,7 @@ mod tests { should_block: true, block_reason: Some("do not run that".to_string()), additional_contexts_for_model: Vec::new(), + updated_input: None, } ); assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked); @@ -332,6 +370,91 @@ mod tests { ); } + #[test] + fn permission_decision_allow_can_update_input() { + let parsed = parse_completed( + &handler(), + run_result( + Some(0), + r#"{"hookSpecificOutput":{"hookEventName":"PreToolUse","permissionDecision":"allow","updatedInput":{"command":"echo rewritten"}}}"#, + "", + ), + Some("turn-1".to_string()), + ); + + assert_eq!( + parsed.data, + PreToolUseHandlerData { + should_block: false, + block_reason: None, + additional_contexts_for_model: Vec::new(), + updated_input: Some(serde_json::json!({ "command": "echo rewritten" })), + } + ); + assert_eq!(parsed.completed.run.status, HookRunStatus::Completed); + assert_eq!(parsed.completed.run.entries, vec![]); + } + + #[test] + fn last_completed_updated_input_wins() { + let mut later_configured = parse_completed( + &handler(), + run_result( + Some(0), + r#"{"hookSpecificOutput":{"hookEventName":"PreToolUse","permissionDecision":"allow","updatedInput":{"command":"echo configured later"}}}"#, + "", + ), + Some("turn-1".to_string()), + ); + later_configured.completion_order = 0; + let mut earlier_configured = parse_completed( + &handler(), + run_result( + Some(0), + r#"{"hookSpecificOutput":{"hookEventName":"PreToolUse","permissionDecision":"allow","updatedInput":{"command":"echo finished later"}}}"#, + "", + ), + Some("turn-1".to_string()), + ); + earlier_configured.completion_order = 1; + + assert_eq!( + latest_updated_input(&[later_configured, earlier_configured]), + Some(serde_json::json!({ "command": "echo finished later" })) + ); + } + + #[test] + fn permission_decision_allow_without_updated_input_fails_open() { + let parsed = parse_completed( + &handler(), + run_result( + Some(0), + r#"{"hookSpecificOutput":{"hookEventName":"PreToolUse","permissionDecision":"allow"}}"#, + "", + ), + Some("turn-1".to_string()), + ); + + assert_eq!( + parsed.data, + PreToolUseHandlerData { + should_block: false, + block_reason: None, + additional_contexts_for_model: Vec::new(), + updated_input: None, + } + ); + assert_eq!(parsed.completed.run.status, HookRunStatus::Failed); + assert_eq!( + parsed.completed.run.entries, + vec![HookOutputEntry { + kind: HookOutputEntryKind::Error, + text: "PreToolUse hook returned unsupported permissionDecision:allow".to_string(), + }] + ); + } + #[test] fn deprecated_block_decision_blocks_processing() { let parsed = parse_completed( @@ -350,6 +473,7 @@ mod tests { should_block: true, block_reason: Some("do not run that".to_string()), additional_contexts_for_model: Vec::new(), + updated_input: None, } ); assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked); @@ -380,6 +504,7 @@ mod tests { should_block: true, block_reason: Some("do not run that".to_string()), additional_contexts_for_model: vec!["remember this".to_string()], + updated_input: None, } ); assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked); @@ -416,6 +541,7 @@ mod tests { should_block: false, block_reason: None, additional_contexts_for_model: Vec::new(), + updated_input: None, } ); assert_eq!(parsed.completed.run.status, HookRunStatus::Failed); @@ -442,6 +568,7 @@ mod tests { should_block: false, block_reason: None, additional_contexts_for_model: Vec::new(), + updated_input: None, } ); assert_eq!(parsed.completed.run.status, HookRunStatus::Failed); @@ -472,6 +599,7 @@ mod tests { should_block: true, block_reason: Some("do not run that".to_string()), additional_contexts_for_model: vec!["nope".to_string()], + updated_input: None, } ); assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked); @@ -504,6 +632,7 @@ mod tests { should_block: false, block_reason: None, additional_contexts_for_model: Vec::new(), + updated_input: None, } ); assert_eq!(parsed.completed.run.status, HookRunStatus::Completed); @@ -524,6 +653,7 @@ mod tests { should_block: false, block_reason: None, additional_contexts_for_model: Vec::new(), + updated_input: None, } ); assert_eq!(parsed.completed.run.status, HookRunStatus::Failed); @@ -550,6 +680,7 @@ mod tests { should_block: true, block_reason: Some("blocked by policy".to_string()), additional_contexts_for_model: Vec::new(), + updated_input: None, } ); assert_eq!(parsed.completed.run.status, HookRunStatus::Blocked); diff --git a/codex-rs/hooks/src/events/session_start.rs b/codex-rs/hooks/src/events/session_start.rs index 195bb11257..88bd9c00a0 100644 --- a/codex-rs/hooks/src/events/session_start.rs +++ b/codex-rs/hooks/src/events/session_start.rs @@ -234,6 +234,7 @@ fn parse_completed( stop_reason, additional_contexts_for_model, }, + completion_order: 0, } } diff --git a/codex-rs/hooks/src/events/stop.rs b/codex-rs/hooks/src/events/stop.rs index 8fc176a474..3cd2d44270 100644 --- a/codex-rs/hooks/src/events/stop.rs +++ b/codex-rs/hooks/src/events/stop.rs @@ -259,6 +259,7 @@ fn parse_completed( block_reason, continuation_fragments, }, + completion_order: 0, } } diff --git a/codex-rs/hooks/src/events/user_prompt_submit.rs b/codex-rs/hooks/src/events/user_prompt_submit.rs index a10798ea62..eb152a1f48 100644 --- a/codex-rs/hooks/src/events/user_prompt_submit.rs +++ b/codex-rs/hooks/src/events/user_prompt_submit.rs @@ -255,6 +255,7 @@ fn parse_completed( stop_reason, additional_contexts_for_model, }, + completion_order: 0, } }