mirror of
https://github.com/openai/codex.git
synced 2026-04-30 01:16:54 +00:00
[hooks] userpromptsubmit - hook before user's prompt is executed (#14626)
- this allows blocking the user's prompts from executing, and also
prevents them from entering history
- handles the edge case where you can both prevent the user's prompt AND
add n amount of additionalContexts
- refactors some old code into common.rs where hooks overlap
functionality
- refactors additionalContext being previously added to user messages,
instead we use developer messages for them
- handles queued messages correctly
Sample hook for testing - if you write "[block-user-submit]" this hook
will stop the thread:
example run
```
› sup
• Running UserPromptSubmit hook: reading the observatory notes
UserPromptSubmit hook (completed)
warning: wizard-tower UserPromptSubmit demo inspected: sup
hook context: Wizard Tower UserPromptSubmit demo fired. For this reply only, include the exact
phrase 'observatory lanterns lit' exactly once near the end.
• Just riding the cosmic wave and ready to help, my friend. What are we building today? observatory
lanterns lit
› and [block-user-submit]
• Running UserPromptSubmit hook: reading the observatory notes
UserPromptSubmit hook (stopped)
warning: wizard-tower UserPromptSubmit demo blocked the prompt on purpose.
stop: Wizard Tower demo block: remove [block-user-submit] to continue.
```
.codex/config.toml
```
[features]
codex_hooks = true
```
.codex/hooks.json
```
{
"hooks": {
"UserPromptSubmit": [
{
"hooks": [
{
"type": "command",
"command": "/usr/bin/python3 .codex/hooks/user_prompt_submit_demo.py",
"timeoutSec": 10,
"statusMessage": "reading the observatory notes"
}
]
}
]
}
}
```
.codex/hooks/user_prompt_submit_demo.py
```
#!/usr/bin/env python3
import json
import sys
from pathlib import Path
def prompt_from_payload(payload: dict) -> str:
prompt = payload.get("prompt")
if isinstance(prompt, str) and prompt.strip():
return prompt.strip()
event = payload.get("event")
if isinstance(event, dict):
user_prompt = event.get("user_prompt")
if isinstance(user_prompt, str):
return user_prompt.strip()
return ""
def main() -> int:
payload = json.load(sys.stdin)
prompt = prompt_from_payload(payload)
cwd = Path(payload.get("cwd", ".")).name or "wizard-tower"
if "[block-user-submit]" in prompt:
print(
json.dumps(
{
"systemMessage": (
f"{cwd} UserPromptSubmit demo blocked the prompt on purpose."
),
"decision": "block",
"reason": (
"Wizard Tower demo block: remove [block-user-submit] to continue."
),
}
)
)
return 0
prompt_preview = prompt or "(empty prompt)"
if len(prompt_preview) > 80:
prompt_preview = f"{prompt_preview[:77]}..."
print(
json.dumps(
{
"systemMessage": (
f"{cwd} UserPromptSubmit demo inspected: {prompt_preview}"
),
"hookSpecificOutput": {
"hookEventName": "UserPromptSubmit",
"additionalContext": (
"Wizard Tower UserPromptSubmit demo fired. "
"For this reply only, include the exact phrase "
"'observatory lanterns lit' exactly once near the end."
),
},
}
)
)
return 0
if __name__ == "__main__":
raise SystemExit(main())
```
This commit is contained in:
@@ -6,21 +6,34 @@ use anyhow::Result;
|
||||
use codex_core::features::Feature;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_message_item_added;
|
||||
use core_test_support::responses::ev_output_text_delta;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::streaming_sse::StreamingSseChunk;
|
||||
use core_test_support::streaming_sse::start_streaming_sse_server;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::sleep;
|
||||
|
||||
const FIRST_CONTINUATION_PROMPT: &str = "Retry with exactly the phrase meow meow meow.";
|
||||
const SECOND_CONTINUATION_PROMPT: &str = "Now tighten it to just: meow.";
|
||||
const BLOCKED_PROMPT_CONTEXT: &str = "Remember the blocked lighthouse note.";
|
||||
|
||||
fn write_stop_hook(home: &Path, block_prompts: &[&str]) -> Result<()> {
|
||||
let script_path = home.join("stop_hook.py");
|
||||
@@ -69,6 +82,87 @@ else:
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_user_prompt_submit_hook(
|
||||
home: &Path,
|
||||
blocked_prompt: &str,
|
||||
additional_context: &str,
|
||||
) -> Result<()> {
|
||||
let script_path = home.join("user_prompt_submit_hook.py");
|
||||
let blocked_prompt_json =
|
||||
serde_json::to_string(blocked_prompt).context("serialize blocked prompt for test")?;
|
||||
let additional_context_json = serde_json::to_string(additional_context)
|
||||
.context("serialize user prompt submit additional context for test")?;
|
||||
let script = format!(
|
||||
r#"import json
|
||||
import sys
|
||||
|
||||
payload = json.load(sys.stdin)
|
||||
|
||||
if payload.get("prompt") == {blocked_prompt_json}:
|
||||
print(json.dumps({{
|
||||
"decision": "block",
|
||||
"reason": "blocked by hook",
|
||||
"hookSpecificOutput": {{
|
||||
"hookEventName": "UserPromptSubmit",
|
||||
"additionalContext": {additional_context_json}
|
||||
}}
|
||||
}}))
|
||||
"#,
|
||||
);
|
||||
let hooks = serde_json::json!({
|
||||
"hooks": {
|
||||
"UserPromptSubmit": [{
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": format!("python3 {}", script_path.display()),
|
||||
"statusMessage": "running user prompt submit hook",
|
||||
}]
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
fs::write(&script_path, script).context("write user prompt submit hook script")?;
|
||||
fs::write(home.join("hooks.json"), hooks.to_string()).context("write hooks.json")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_session_start_hook_recording_transcript(home: &Path) -> Result<()> {
|
||||
let script_path = home.join("session_start_hook.py");
|
||||
let log_path = home.join("session_start_hook_log.jsonl");
|
||||
let script = format!(
|
||||
r#"import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
payload = json.load(sys.stdin)
|
||||
transcript_path = payload.get("transcript_path")
|
||||
record = {{
|
||||
"transcript_path": transcript_path,
|
||||
"exists": Path(transcript_path).exists() if transcript_path else False,
|
||||
}}
|
||||
|
||||
with Path(r"{log_path}").open("a", encoding="utf-8") as handle:
|
||||
handle.write(json.dumps(record) + "\n")
|
||||
"#,
|
||||
log_path = log_path.display(),
|
||||
);
|
||||
let hooks = serde_json::json!({
|
||||
"hooks": {
|
||||
"SessionStart": [{
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": format!("python3 {}", script_path.display()),
|
||||
"statusMessage": "running session start hook",
|
||||
}]
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
fs::write(&script_path, script).context("write session start hook script")?;
|
||||
fs::write(home.join("hooks.json"), hooks.to_string()).context("write hooks.json")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rollout_developer_texts(text: &str) -> Result<Vec<String>> {
|
||||
let mut texts = Vec::new();
|
||||
for line in text.lines() {
|
||||
@@ -99,6 +193,49 @@ fn read_stop_hook_inputs(home: &Path) -> Result<Vec<serde_json::Value>> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn read_session_start_hook_inputs(home: &Path) -> Result<Vec<serde_json::Value>> {
|
||||
fs::read_to_string(home.join("session_start_hook_log.jsonl"))
|
||||
.context("read session start hook log")?
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.map(|line| serde_json::from_str(line).context("parse session start hook log line"))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn ev_message_item_done(id: &str, text: &str) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"id": id,
|
||||
"content": [{"type": "output_text", "text": text}]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn sse_event(event: Value) -> String {
|
||||
sse(vec![event])
|
||||
}
|
||||
|
||||
fn request_message_input_texts(body: &[u8], role: &str) -> Vec<String> {
|
||||
let body: Value = match serde_json::from_slice(body) {
|
||||
Ok(body) => body,
|
||||
Err(error) => panic!("parse request body: {error}"),
|
||||
};
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter(|item| item.get("type").and_then(Value::as_str) == Some("message"))
|
||||
.filter(|item| item.get("role").and_then(Value::as_str) == Some(role))
|
||||
.filter_map(|item| item.get("content").and_then(Value::as_array))
|
||||
.flatten()
|
||||
.filter(|span| span.get("type").and_then(Value::as_str) == Some("input_text"))
|
||||
.filter_map(|span| span.get("text").and_then(Value::as_str).map(str::to_owned))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn stop_hook_can_block_multiple_times_in_same_turn() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -193,6 +330,51 @@ async fn stop_hook_can_block_multiple_times_in_same_turn() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn session_start_hook_sees_materialized_transcript_path() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let _response = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "hello from the reef"),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex()
|
||||
.with_pre_build_hook(|home| {
|
||||
if let Err(error) = write_session_start_hook_recording_transcript(home) {
|
||||
panic!("failed to write session start hook test fixture: {error}");
|
||||
}
|
||||
})
|
||||
.with_config(|config| {
|
||||
config
|
||||
.features
|
||||
.enable(Feature::CodexHooks)
|
||||
.expect("test config should allow feature update");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn("hello").await?;
|
||||
|
||||
let hook_inputs = read_session_start_hook_inputs(test.codex_home_path())?;
|
||||
assert_eq!(hook_inputs.len(), 1);
|
||||
assert_eq!(
|
||||
hook_inputs[0]
|
||||
.get("transcript_path")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::is_empty),
|
||||
Some(false)
|
||||
);
|
||||
assert_eq!(hook_inputs[0].get("exists"), Some(&Value::Bool(true)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn resumed_thread_keeps_stop_continuation_prompt_in_history() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -269,3 +451,179 @@ async fn resumed_thread_keeps_stop_continuation_prompt_in_history() -> Result<()
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn blocked_user_prompt_submit_persists_additional_context_for_next_turn() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let response = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "second prompt handled"),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex()
|
||||
.with_pre_build_hook(|home| {
|
||||
if let Err(error) =
|
||||
write_user_prompt_submit_hook(home, "blocked first prompt", BLOCKED_PROMPT_CONTEXT)
|
||||
{
|
||||
panic!("failed to write user prompt submit hook test fixture: {error}");
|
||||
}
|
||||
})
|
||||
.with_config(|config| {
|
||||
config
|
||||
.features
|
||||
.enable(Feature::CodexHooks)
|
||||
.expect("test config should allow feature update");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn("blocked first prompt").await?;
|
||||
test.submit_turn("second prompt").await?;
|
||||
|
||||
let request = response.single_request();
|
||||
assert!(
|
||||
request
|
||||
.message_input_texts("developer")
|
||||
.contains(&BLOCKED_PROMPT_CONTEXT.to_string()),
|
||||
"second request should include developer context persisted from the blocked prompt",
|
||||
);
|
||||
assert!(
|
||||
request
|
||||
.message_input_texts("user")
|
||||
.iter()
|
||||
.all(|text| !text.contains("blocked first prompt")),
|
||||
"blocked prompt should not be sent to the model",
|
||||
);
|
||||
assert!(
|
||||
request
|
||||
.message_input_texts("user")
|
||||
.iter()
|
||||
.any(|text| text.contains("second prompt")),
|
||||
"second request should include the accepted prompt",
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn blocked_queued_prompt_does_not_strand_earlier_accepted_prompt() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let (gate_completed_tx, gate_completed_rx) = oneshot::channel();
|
||||
let first_chunks = vec![
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(ev_response_created("resp-1")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(ev_message_item_added("msg-1", "")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(ev_output_text_delta("first ")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(ev_message_item_done("msg-1", "first response")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: Some(gate_completed_rx),
|
||||
body: sse_event(ev_completed("resp-1")),
|
||||
},
|
||||
];
|
||||
let second_chunks = vec![StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-2", "accepted queued prompt handled"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
}];
|
||||
let (server, _completions) =
|
||||
start_streaming_sse_server(vec![first_chunks, second_chunks]).await;
|
||||
|
||||
let mut builder = test_codex()
|
||||
.with_model("gpt-5.1")
|
||||
.with_pre_build_hook(|home| {
|
||||
if let Err(error) =
|
||||
write_user_prompt_submit_hook(home, "blocked queued prompt", BLOCKED_PROMPT_CONTEXT)
|
||||
{
|
||||
panic!("failed to write user prompt submit hook test fixture: {error}");
|
||||
}
|
||||
})
|
||||
.with_config(|config| {
|
||||
config
|
||||
.features
|
||||
.enable(Feature::CodexHooks)
|
||||
.expect("test config should allow feature update");
|
||||
});
|
||||
let test = builder.build_with_streaming_server(&server).await?;
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "initial prompt".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::AgentMessageContentDelta(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
for text in ["accepted queued prompt", "blocked queued prompt"] {
|
||||
test.codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: text.to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
let _ = gate_completed_tx.send(());
|
||||
|
||||
let requests = tokio::time::timeout(Duration::from_secs(30), async {
|
||||
loop {
|
||||
let requests = server.requests().await;
|
||||
if requests.len() >= 2 {
|
||||
break requests;
|
||||
}
|
||||
sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("second request should arrive")
|
||||
.into_iter()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
|
||||
assert_eq!(requests.len(), 2);
|
||||
|
||||
let second_user_texts = request_message_input_texts(&requests[1], "user");
|
||||
assert!(
|
||||
second_user_texts.contains(&"accepted queued prompt".to_string()),
|
||||
"second request should include the accepted queued prompt",
|
||||
);
|
||||
assert!(
|
||||
!second_user_texts.contains(&"blocked queued prompt".to_string()),
|
||||
"second request should not include the blocked queued prompt",
|
||||
);
|
||||
|
||||
server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user