From b2aa8613e9d10b79d451da1b96d7468a599dcd3c Mon Sep 17 00:00:00 2001 From: Abhinav Vedmala Date: Sat, 2 May 2026 12:03:52 -0700 Subject: [PATCH] add hook trust integration coverage --- .../app-server/tests/suite/v2/hooks_list.rs | 248 ++++++++++++++++++ codex-rs/hooks/src/engine/mod_tests.rs | 101 ------- 2 files changed, 248 insertions(+), 101 deletions(-) diff --git a/codex-rs/app-server/tests/suite/v2/hooks_list.rs b/codex-rs/app-server/tests/suite/v2/hooks_list.rs index 49430451da..ca7bd5d492 100644 --- a/codex-rs/app-server/tests/suite/v2/hooks_list.rs +++ b/codex-rs/app-server/tests/suite/v2/hooks_list.rs @@ -476,6 +476,254 @@ async fn config_batch_write_toggles_user_hook() -> Result<()> { Ok(()) } +#[tokio::test] +async fn config_batch_write_updates_hook_trust_for_loaded_session() -> Result<()> { + skip_if_windows!(Ok(())); + + let responses = vec![ + create_final_assistant_message_sse_response("Warmup")?, + create_final_assistant_message_sse_response("Untrusted turn")?, + create_final_assistant_message_sse_response("Trusted turn")?, + create_final_assistant_message_sse_response("Modified turn")?, + ]; + let server = create_mock_responses_server_sequence_unchecked(responses).await; + let codex_home = TempDir::new()?; + let hook_script_path = codex_home.path().join("user_prompt_submit_hook.py"); + let hook_log_path = codex_home.path().join("user_prompt_submit_hook_log.jsonl"); + std::fs::write( + &hook_script_path, + format!( + r#"import json +from pathlib import Path +import sys + +payload = json.load(sys.stdin) +with Path(r"{hook_log_path}").open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload) + "\n") +"#, + hook_log_path = hook_log_path.display(), + ), + )?; + std::fs::write( + codex_home.path().join("config.toml"), + format!( + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "read-only" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "responses" +request_max_retries = 0 +stream_max_retries = 0 + +[hooks] + +[[hooks.UserPromptSubmit]] + +[[hooks.UserPromptSubmit.hooks]] +type = "command" +command = "python3 {hook_script_path}" +"#, + server_uri = server.uri(), + hook_script_path = hook_script_path.display(), + ), + )?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_TIMEOUT, mcp.initialize()).await??; + + let hook_list_id = mcp + .send_hooks_list_request(HooksListParams { + cwds: vec![codex_home.path().to_path_buf()], + }) + .await?; + let response: JSONRPCResponse = timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(hook_list_id)), + ) + .await??; + let HooksListResponse { data } = to_response(response)?; + let hook = data[0].hooks[0].clone(); + assert_eq!(hook.trust_status, HookTrustStatus::Untrusted); + + let thread_start_id = mcp + .send_thread_start_request(ThreadStartParams { + model: Some("mock-model".to_string()), + ..Default::default() + }) + .await?; + let response: JSONRPCResponse = timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(thread_start_id)), + ) + .await??; + let ThreadStartResponse { thread, .. } = to_response(response)?; + + let first_turn_id = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: "first turn".to_string(), + text_elements: Vec::new(), + }], + ..Default::default() + }) + .await?; + timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(first_turn_id)), + ) + .await??; + timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + assert!(!std::fs::exists(&hook_log_path)?); + + let write_id = mcp + .send_config_batch_write_request(ConfigBatchWriteParams { + edits: vec![ConfigEdit { + key_path: "hooks.state".to_string(), + value: serde_json::json!({ + hook.key.clone(): { + "trusted_hash": hook.current_hash.clone() + } + }), + merge_strategy: MergeStrategy::Upsert, + }], + file_path: None, + expected_version: None, + reload_user_config: true, + }) + .await?; + let response: JSONRPCResponse = timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(write_id)), + ) + .await??; + let _: codex_app_server_protocol::ConfigWriteResponse = to_response(response)?; + + let hook_list_id = mcp + .send_hooks_list_request(HooksListParams { + cwds: vec![codex_home.path().to_path_buf()], + }) + .await?; + let response: JSONRPCResponse = timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(hook_list_id)), + ) + .await??; + let HooksListResponse { data } = to_response(response)?; + let trusted_hook = &data[0].hooks[0]; + assert_eq!(trusted_hook.key, hook.key); + assert_eq!(trusted_hook.current_hash, hook.current_hash); + assert_eq!(trusted_hook.trust_status, HookTrustStatus::Trusted); + + let second_turn_id = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: "second turn".to_string(), + text_elements: Vec::new(), + }], + ..Default::default() + }) + .await?; + timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(second_turn_id)), + ) + .await??; + timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + assert_eq!( + std::fs::read_to_string(&hook_log_path)? + .lines() + .filter(|line| !line.is_empty()) + .count(), + 1 + ); + + let write_id = mcp + .send_config_batch_write_request(ConfigBatchWriteParams { + edits: vec![ConfigEdit { + key_path: "hooks.UserPromptSubmit".to_string(), + value: serde_json::json!([{ + "hooks": [{ + "type": "command", + "command": format!("python3 {}", hook_script_path.display()), + "statusMessage": "modified hook", + }], + }]), + merge_strategy: MergeStrategy::Replace, + }], + file_path: None, + expected_version: None, + reload_user_config: true, + }) + .await?; + let response: JSONRPCResponse = timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(write_id)), + ) + .await??; + let _: codex_app_server_protocol::ConfigWriteResponse = to_response(response)?; + + let hook_list_id = mcp + .send_hooks_list_request(HooksListParams { + cwds: vec![codex_home.path().to_path_buf()], + }) + .await?; + let response: JSONRPCResponse = timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(hook_list_id)), + ) + .await??; + let HooksListResponse { data } = to_response(response)?; + let modified_hook = &data[0].hooks[0]; + assert_eq!(modified_hook.key, hook.key); + assert_ne!(modified_hook.current_hash, hook.current_hash); + assert_eq!(modified_hook.trust_status, HookTrustStatus::Modified); + + let third_turn_id = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id, + input: vec![V2UserInput::Text { + text: "third turn".to_string(), + text_elements: Vec::new(), + }], + ..Default::default() + }) + .await?; + timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(third_turn_id)), + ) + .await??; + timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + assert_eq!( + std::fs::read_to_string(&hook_log_path)? + .lines() + .filter(|line| !line.is_empty()) + .count(), + 1 + ); + Ok(()) +} + #[tokio::test] async fn config_batch_write_disables_hook_for_loaded_session() -> Result<()> { skip_if_windows!(Ok(())); diff --git a/codex-rs/hooks/src/engine/mod_tests.rs b/codex-rs/hooks/src/engine/mod_tests.rs index ea6bbe6a22..32739165f1 100644 --- a/codex-rs/hooks/src/engine/mod_tests.rs +++ b/codex-rs/hooks/src/engine/mod_tests.rs @@ -12,7 +12,6 @@ use codex_config::Constrained; use codex_config::ConstrainedWithSource; use codex_config::HookEventsToml; use codex_config::HookHandlerConfig; -use codex_config::HookStateToml; use codex_config::ManagedHooksRequirementsToml; use codex_config::MatcherGroup; use codex_config::RequirementSource; @@ -365,85 +364,6 @@ fn user_disablement_does_not_filter_managed_layer_hooks() { ); } -#[test] -fn unmanaged_hook_trust_status_tracks_stored_hash() { - let temp = tempdir().expect("create temp dir"); - let config_path = - AbsolutePathBuf::try_from(temp.path().join("config.toml")).expect("absolute path"); - let key = format!("{}:pre_tool_use:0:0", config_path.display()); - - let untrusted_stack = ConfigLayerStack::new( - vec![ConfigLayerEntry::new( - ConfigLayerSource::User { - file: config_path.clone(), - }, - config_with_pre_tool_use_hook("python3 /tmp/user.py"), - )], - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - ) - .expect("config layer stack"); - let untrusted = - super::discovery::discover_handlers(Some(&untrusted_stack), Vec::new(), Vec::new()); - assert_eq!(untrusted.hook_entries.len(), 1); - assert_eq!( - untrusted.hook_entries[0].trust_status, - HookTrustStatus::Untrusted - ); - assert_eq!(untrusted.handlers, Vec::new()); - - let current_hash = untrusted.hook_entries[0].current_hash.clone(); - let trusted_stack = ConfigLayerStack::new( - vec![ConfigLayerEntry::new( - ConfigLayerSource::User { file: config_path }, - config_with_pre_tool_use_hook_and_state( - "python3 /tmp/user.py", - &key, - HookStateToml { - enabled: None, - trusted_hash: Some(current_hash), - }, - ), - )], - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - ) - .expect("config layer stack"); - let trusted = super::discovery::discover_handlers(Some(&trusted_stack), Vec::new(), Vec::new()); - assert_eq!(trusted.hook_entries.len(), 1); - assert_eq!( - trusted.hook_entries[0].trust_status, - HookTrustStatus::Trusted - ); - assert_eq!(trusted.handlers.len(), 1); - - let changed_stack = ConfigLayerStack::new( - vec![ConfigLayerEntry::new( - ConfigLayerSource::User { - file: trusted.hook_entries[0].source_path.clone(), - }, - config_with_pre_tool_use_hook_and_state( - "python3 /tmp/user.py", - &key, - HookStateToml { - enabled: None, - trusted_hash: Some("sha256:old".to_string()), - }, - ), - )], - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - ) - .expect("config layer stack"); - let changed = super::discovery::discover_handlers(Some(&changed_stack), Vec::new(), Vec::new()); - assert_eq!(changed.hook_entries.len(), 1); - assert_eq!( - changed.hook_entries[0].trust_status, - HookTrustStatus::Modified - ); - assert_eq!(changed.handlers, Vec::new()); -} - fn config_with_hook_state(key: &str, enabled: bool) -> TomlValue { serde_json::from_value(serde_json::json!({ "hooks": { @@ -457,27 +377,6 @@ fn config_with_hook_state(key: &str, enabled: bool) -> TomlValue { .expect("config TOML should deserialize") } -fn config_with_pre_tool_use_hook_and_state( - command: &str, - key: &str, - state: HookStateToml, -) -> TomlValue { - serde_json::from_value(serde_json::json!({ - "hooks": { - "state": { - (key): serde_json::to_value(state).expect("hook state should serialize"), - }, - "PreToolUse": [{ - "hooks": [{ - "type": "command", - "command": command, - }], - }], - }, - })) - .expect("config TOML should deserialize") -} - fn config_with_pre_tool_use_hook_and_states( command: &str, disabled_keys: [&str; N],