From 96ad7154da5f7bb4aee42648b164ebcdd6dcdc41 Mon Sep 17 00:00:00 2001 From: Abhinav Vedmala Date: Fri, 22 May 2026 12:35:28 -0700 Subject: [PATCH] simplify post-tool rewrite projection --- codex-rs/core/src/tools/context.rs | 14 ++---- codex-rs/core/src/tools/context_tests.rs | 49 +++++++++++++++++++ codex-rs/core/src/tools/registry_tests.rs | 57 ---------------------- codex-rs/hooks/src/events/post_tool_use.rs | 27 +++++----- 4 files changed, 69 insertions(+), 78 deletions(-) diff --git a/codex-rs/core/src/tools/context.rs b/codex-rs/core/src/tools/context.rs index e885b861eb..916a268215 100644 --- a/codex-rs/core/src/tools/context.rs +++ b/codex-rs/core/src/tools/context.rs @@ -293,13 +293,13 @@ impl ToolOutput for FunctionToolOutput { /// response item. pub(crate) struct ModelVisibleRewriteOutput { original_tool_output: Box, - updated_tool_output: JsonValue, + updated_tool_output: String, } impl ModelVisibleRewriteOutput { pub(crate) fn new( original_tool_output: Box, - updated_tool_output: JsonValue, + updated_tool_output: String, ) -> Self { Self { original_tool_output, @@ -318,14 +318,8 @@ impl ToolOutput for ModelVisibleRewriteOutput { } fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem { - FunctionToolOutput::from_text( - match &self.updated_tool_output { - JsonValue::String(text) => text.clone(), - _ => self.updated_tool_output.to_string(), - }, - Some(true), - ) - .to_response_item(call_id, payload) + FunctionToolOutput::from_text(self.updated_tool_output.clone(), Some(true)) + .to_response_item(call_id, payload) } fn code_mode_result(&self, payload: &ToolPayload) -> JsonValue { diff --git a/codex-rs/core/src/tools/context_tests.rs b/codex-rs/core/src/tools/context_tests.rs index 6dc4248313..98c95dfdeb 100644 --- a/codex-rs/core/src/tools/context_tests.rs +++ b/codex-rs/core/src/tools/context_tests.rs @@ -230,6 +230,55 @@ fn mcp_tool_output_response_item_preserves_content_items() { } } +#[test] +fn model_visible_rewrite_output_keeps_original_mcp_code_mode_result() { + let large_content = "large structured value ".repeat(1_000); + let output = ModelVisibleRewriteOutput::new( + Box::new(McpToolOutput { + result: CallToolResult { + content: vec![serde_json::json!({ + "type": "text", + "text": "ignored", + })], + structured_content: Some(serde_json::json!({ + "content": large_content, + })), + is_error: Some(false), + meta: None, + }, + tool_input: json!({}), + wall_time: std::time::Duration::from_millis(1250), + original_image_detail_supported: false, + truncation_policy: TruncationPolicy::Bytes(64), + }), + "rewritten".to_string(), + ); + let payload = ToolPayload::Function { + arguments: "{}".to_string(), + }; + + match output.to_response_item("mcp-call-1", &payload) { + ResponseInputItem::FunctionCallOutput { call_id, output } => { + assert_eq!(call_id, "mcp-call-1"); + assert_eq!(output.body.to_text().as_deref(), Some("rewritten")); + } + other => panic!("expected FunctionCallOutput, got {other:?}"), + } + assert_eq!( + output.code_mode_result(&payload), + serde_json::json!({ + "content": [{ + "type": "text", + "text": "ignored", + }], + "structuredContent": { + "content": "large structured value ".repeat(1_000), + }, + "isError": false, + }) + ); +} + #[test] fn mcp_tool_output_code_mode_result_stays_raw_call_tool_result() { let large_content = "large structured value ".repeat(1_000); diff --git a/codex-rs/core/src/tools/registry_tests.rs b/codex-rs/core/src/tools/registry_tests.rs index 1d9808748e..dc744321d1 100644 --- a/codex-rs/core/src/tools/registry_tests.rs +++ b/codex-rs/core/src/tools/registry_tests.rs @@ -1,13 +1,8 @@ use super::*; -use crate::tools::context::McpToolOutput; -use crate::tools::context::ModelVisibleRewriteOutput; use crate::tools::handlers::GetGoalHandler; use crate::tools::handlers::goal_spec::GET_GOAL_TOOL_NAME; use crate::tools::handlers::goal_spec::create_get_goal_tool; -use codex_protocol::mcp::CallToolResult; use pretty_assertions::assert_eq; -use serde_json::json; -use std::time::Duration; struct TestHandler { tool_name: codex_tools::ToolName, @@ -67,58 +62,6 @@ fn handler_looks_up_namespaced_aliases_explicitly() { ); } -#[test] -fn model_visible_rewrite_preserves_code_mode_result() { - let result = mcp_result_with_model_visible_rewrite(); - - match result.into_response() { - ResponseInputItem::FunctionCallOutput { call_id, output } => { - assert_eq!(call_id, "mcp-call-1"); - assert_eq!( - output.body.to_text().as_deref(), - Some(r#"{"echo":"rewritten"}"#) - ); - } - other => panic!("expected FunctionCallOutput, got {other:?}"), - } - - assert_eq!( - mcp_result_with_model_visible_rewrite().code_mode_result(), - json!({ - "content": [], - "structuredContent": { - "echo": "original", - }, - "isError": false, - }) - ); -} - -fn mcp_result_with_model_visible_rewrite() -> AnyToolResult { - AnyToolResult { - call_id: "mcp-call-1".to_string(), - payload: ToolPayload::Function { - arguments: "{}".to_string(), - }, - result: Box::new(ModelVisibleRewriteOutput::new( - Box::new(McpToolOutput { - result: CallToolResult { - content: Vec::new(), - structured_content: Some(json!({ "echo": "original" })), - is_error: Some(false), - meta: None, - }, - tool_input: json!({}), - wall_time: Duration::ZERO, - original_image_detail_supported: false, - truncation_policy: codex_utils_output_truncation::TruncationPolicy::Bytes(1024), - }), - json!({ "echo": "rewritten" }), - )), - post_tool_use_payload: None, - } -} - #[test] fn register_handler_adds_handler_and_spec() { let mut builder = ToolRegistryBuilder::new(); diff --git a/codex-rs/hooks/src/events/post_tool_use.rs b/codex-rs/hooks/src/events/post_tool_use.rs index d0e9d22f4a..b851d81a25 100644 --- a/codex-rs/hooks/src/events/post_tool_use.rs +++ b/codex-rs/hooks/src/events/post_tool_use.rs @@ -40,7 +40,7 @@ pub struct PostToolUseOutcome { pub stop_reason: Option, pub additional_contexts: Vec, pub feedback_message: Option, - pub updated_tool_output: Option, + pub updated_tool_output: Option, } #[derive(Debug, Default, PartialEq, Eq)] @@ -255,6 +255,7 @@ fn parse_completed( } } let can_rewrite_output = parsed.universal.continue_processing + && !parsed.should_block && parsed.invalid_reason.is_none() && parsed.invalid_block_reason.is_none(); if can_rewrite_output { @@ -341,15 +342,17 @@ fn select_updated_tool_output( results: &mut [dispatcher::ParsedHandler], tool_name: &str, original_tool_response: &Value, -) -> Option { +) -> Option { let is_mcp_tool = tool_name.starts_with("mcp__"); + let original_tool_response_kind = json_kind_name(original_tool_response); let mut selected = None; for result in results { - let candidate = if let Some(updated_tool_output) = result.data.updated_tool_output.take() { + let candidate = if let Some(updated_tool_output) = result.data.updated_tool_output.as_ref() + { Some(updated_tool_output) } else if is_mcp_tool { - result.data.updated_mcp_tool_output.take() + result.data.updated_mcp_tool_output.as_ref() } else if result.data.updated_mcp_tool_output.is_some() { result.completed.run.entries.push(HookOutputEntry { kind: HookOutputEntryKind::Warning, @@ -364,14 +367,16 @@ fn select_updated_tool_output( continue; }; - if is_mcp_tool || json_kind_name(original_tool_response) == json_kind_name(&candidate) { - selected = Some(candidate); + if is_mcp_tool || original_tool_response_kind == json_kind_name(candidate) { + selected = Some(match candidate { + Value::String(text) => text.clone(), + _ => candidate.to_string(), + }); } else { result.completed.run.entries.push(HookOutputEntry { kind: HookOutputEntryKind::Warning, text: format!( - "ignored updatedToolOutput: expected {} to match tool_response shape", - json_kind_name(original_tool_response) + "ignored updatedToolOutput: expected {original_tool_response_kind} to match tool_response shape" ), }); } @@ -651,7 +656,7 @@ mod tests { assert_eq!( super::select_updated_tool_output(&mut results, "Bash", &json!("old")), - Some(json!("second")) + Some("second".to_string()) ); } @@ -695,7 +700,7 @@ mod tests { assert_eq!( super::select_updated_tool_output(&mut results, "mcp__memory__lookup", &json!({})), - Some(json!({"source": "generic"})) + Some(r#"{"source":"generic"}"#.to_string()) ); } @@ -713,7 +718,7 @@ mod tests { assert_eq!( super::select_updated_tool_output(&mut results, "mcp__memory__lookup", &json!({})), - Some(json!({"source": "mcp"})) + Some(r#"{"source":"mcp"}"#.to_string()) ); }