Compare commits

...

6 Commits

Author SHA1 Message Date
Abhinav Vedmala
898518e352 Merge origin/main into hook context developer bundling 2026-05-19 12:38:48 -07:00
Abhinav Vedmala
d1a763f6cd Preserve sticky SessionStart hook context 2026-05-19 12:28:01 -07:00
Abhinav Vedmala
596bc555e7 Merge remote-tracking branch 'origin/main' into abhinav/hook-context-developer-bundling
# Conflicts:
#	codex-rs/core/tests/suite/hooks.rs
2026-05-13 16:05:18 -07:00
Abhinav Vedmala
b48f5f1f91 Keep hook context out of rollback trimming 2026-05-13 16:01:12 -07:00
Abhinav Vedmala
6218bef50e codex: fix CI failure on PR #22536 2026-05-13 11:54:44 -07:00
Abhinav Vedmala
4ad8ea55bf merge hook developer context messages 2026-05-13 11:41:42 -07:00
7 changed files with 210 additions and 39 deletions

View File

@@ -5,12 +5,23 @@ pub(crate) struct HookAdditionalContext {
text: String,
}
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct StickyHookAdditionalContext {
text: String,
}
impl HookAdditionalContext {
pub(crate) fn new(text: impl Into<String>) -> Self {
Self { text: text.into() }
}
}
impl StickyHookAdditionalContext {
pub(crate) fn new(text: impl Into<String>) -> Self {
Self { text: text.into() }
}
}
impl ContextualUserFragment for HookAdditionalContext {
fn role() -> &'static str {
"developer"
@@ -21,7 +32,25 @@ impl ContextualUserFragment for HookAdditionalContext {
}
fn type_markers() -> (&'static str, &'static str) {
("", "")
("<hook_context>", "</hook_context>")
}
fn body(&self) -> String {
self.text.clone()
}
}
impl ContextualUserFragment for StickyHookAdditionalContext {
fn role() -> &'static str {
"developer"
}
fn markers(&self) -> (&'static str, &'static str) {
Self::type_markers()
}
fn type_markers() -> (&'static str, &'static str) {
("<hook_context_sticky>", "</hook_context_sticky>")
}
fn body(&self) -> String {

View File

@@ -43,6 +43,7 @@ pub(crate) use fragment::FragmentRegistrationProxy;
pub(crate) use goal_context::GoalContext;
pub(crate) use guardian_followup_review_reminder::GuardianFollowupReviewReminder;
pub(crate) use hook_additional_context::HookAdditionalContext;
pub(crate) use hook_additional_context::StickyHookAdditionalContext;
pub(crate) use image_generation_instructions::ImageGenerationInstructions;
pub(crate) use legacy_apply_patch_exec_command_warning::LegacyApplyPatchExecCommandWarning;
pub(crate) use legacy_model_mismatch_warning::LegacyModelMismatchWarning;

View File

@@ -27,6 +27,7 @@ use crate::web_search::web_search_action_detail;
const CONTEXTUAL_DEVELOPER_PREFIXES: &[&str] = &[
"<permissions instructions>",
"<model_switch>",
"<hook_context>",
COLLABORATION_MODE_OPEN_TAG,
REALTIME_CONVERSATION_OPEN_TAG,
"<personality_spec>",

View File

@@ -1,3 +1,5 @@
use super::has_non_contextual_dev_message_content;
use super::is_contextual_dev_message_content;
use super::parse_turn_item;
use crate::context::ContextualUserFragment;
use crate::context::GoalContext;
@@ -316,6 +318,21 @@ fn parses_hook_prompt_and_hides_other_contextual_fragments() {
}
}
#[test]
fn hook_context_fragments_distinguish_thread_scoped_from_sticky_state() {
let trimmable_context = vec![ContentItem::InputText {
text: "<hook_context>thread scoped note</hook_context>".to_string(),
}];
assert!(is_contextual_dev_message_content(&trimmable_context));
assert!(!has_non_contextual_dev_message_content(&trimmable_context));
let sticky_context = vec![ContentItem::InputText {
text: "<hook_context_sticky>session scoped note</hook_context_sticky>".to_string(),
}];
assert!(!is_contextual_dev_message_content(&sticky_context));
assert!(has_non_contextual_dev_message_content(&sticky_context));
}
#[test]
fn goal_context_does_not_parse_as_visible_turn_item() {
let item = ResponseItem::Message {

View File

@@ -34,6 +34,8 @@ use serde_json::Value;
use crate::context::ContextualUserFragment;
use crate::context::HookAdditionalContext;
use crate::context::StickyHookAdditionalContext;
use crate::context_manager::updates::build_developer_update_item;
use crate::event_mapping::parse_turn_item;
use crate::session::TurnInput;
use crate::session::session::Session;
@@ -111,15 +113,16 @@ pub(crate) async fn run_pending_session_start_hooks(
};
let hooks = sess.hooks();
let preview_runs = hooks.preview_session_start(&request);
run_context_injecting_hook(
let outcome = run_context_injecting_hook(
sess,
turn_context,
preview_runs,
hooks.run_session_start(request, Some(turn_context.sub_id.clone())),
)
.await
.record_additional_contexts(sess, turn_context)
.await
.await;
// SessionStart context is durable developer state; rollback should not trim it.
record_sticky_additional_contexts(sess, turn_context, outcome.additional_contexts).await;
outcome.should_stop
}
/// Runs matching `PreToolUse` hooks before a tool executes.
@@ -484,18 +487,6 @@ where
outcome.outcome
}
impl HookRuntimeOutcome {
async fn record_additional_contexts(
self,
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
) -> bool {
record_additional_contexts(sess, turn_context, self.additional_contexts).await;
self.should_stop
}
}
pub(crate) async fn record_additional_contexts(
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
@@ -510,12 +501,39 @@ pub(crate) async fn record_additional_contexts(
.await;
}
async fn record_sticky_additional_contexts(
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
additional_contexts: Vec<String>,
) {
// Use this for hook context that should survive rollback as persistent developer state.
let developer_messages = sticky_additional_context_messages(additional_contexts);
if developer_messages.is_empty() {
return;
}
sess.record_conversation_items(turn_context, developer_messages.as_slice())
.await;
}
fn additional_context_messages(additional_contexts: Vec<String>) -> Vec<ResponseItem> {
additional_contexts
let sections = additional_contexts
.into_iter()
.map(HookAdditionalContext::new)
.map(ContextualUserFragment::into)
.collect()
.map(|context| context.render())
.collect();
build_developer_update_item(sections).into_iter().collect()
}
fn sticky_additional_context_messages(additional_contexts: Vec<String>) -> Vec<ResponseItem> {
let sections = additional_contexts
.into_iter()
.map(StickyHookAdditionalContext::new)
.map(|context| context.render())
.collect();
build_developer_update_item(sections).into_iter().collect()
}
async fn emit_hook_started_events(
@@ -675,36 +693,39 @@ mod tests {
use codex_utils_absolute_path::test_support::test_path_buf;
#[test]
fn additional_context_messages_stay_separate_and_ordered() {
fn additional_context_messages_merge_and_preserve_order() {
let messages = additional_context_messages(vec![
"first tide note".to_string(),
"second tide note".to_string(),
]);
assert_eq!(messages.len(), 2);
assert_eq!(messages.len(), 1);
assert_eq!(
messages
.iter()
.map(|message| match message {
codex_protocol::models::ResponseItem::Message { role, content, .. } => {
let text = content
let texts = content
.iter()
.map(|item| match item {
ContentItem::InputText { text } => text.as_str(),
ContentItem::InputText { text } => text.clone(),
ContentItem::InputImage { .. } | ContentItem::OutputText { .. } => {
panic!("expected input text content, got {item:?}")
}
})
.collect::<String>();
(role.as_str(), text)
.collect::<Vec<_>>();
(role.as_str(), texts)
}
other => panic!("expected developer message, got {other:?}"),
})
.collect::<Vec<_>>(),
vec![
("developer", "first tide note".to_string()),
("developer", "second tide note".to_string()),
],
vec![(
"developer",
vec![
"<hook_context>first tide note</hook_context>".to_string(),
"<hook_context>second tide note</hook_context>".to_string(),
],
)],
);
}

View File

@@ -180,6 +180,50 @@ else:
Ok(())
}
fn write_parallel_session_start_hooks(home: &Path, contexts: &[&str]) -> Result<()> {
let hook_entries = contexts
.iter()
.enumerate()
.map(|(index, context)| {
let script_path = home.join(format!("session_start_hook_{index}.py"));
let context_json =
serde_json::to_string(context).context("serialize session start context")?;
let script = format!(
r#"import json
print(json.dumps({{
"hookSpecificOutput": {{
"hookEventName": "SessionStart",
"additionalContext": {context_json}
}}
}}))
"#
);
fs::write(&script_path, script).with_context(|| {
format!(
"write session start hook script fixture at {}",
script_path.display()
)
})?;
Ok(serde_json::json!({
"type": "command",
"command": format!("python3 {}", script_path.display()),
}))
})
.collect::<Result<Vec<_>>>()?;
let hooks = serde_json::json!({
"hooks": {
"SessionStart": [{
"hooks": hook_entries,
}]
}
});
fs::write(home.join("hooks.json"), hooks.to_string()).context("write hooks.json")?;
Ok(())
}
fn write_user_prompt_submit_hook(
home: &Path,
blocked_prompt: &str,
@@ -741,6 +785,11 @@ fn request_hook_prompt_texts(
fn spilled_hook_output_path(text: &str) -> Option<&str> {
text.lines()
.find_map(|line| line.strip_prefix("Full hook output saved to: "))
.map(|path| {
path.strip_suffix("</hook_context>")
.or_else(|| path.strip_suffix("</hook_context_sticky>"))
.unwrap_or(path)
})
}
fn read_stop_hook_inputs(home: &Path) -> Result<Vec<serde_json::Value>> {
@@ -1058,12 +1107,65 @@ async fn session_start_hook_spills_large_additional_context() -> Result<()> {
.find(|message| spilled_hook_output_path(message).is_some())
.context("spilled developer hook message")?;
assert!(developer_message.contains("tokens truncated"));
assert!(developer_message.contains("<hook_context_sticky>"));
let path = spilled_hook_output_path(developer_message).context("spill path")?;
assert_eq!(fs::read_to_string(path)?, additional_context);
Ok(())
}
#[tokio::test]
async fn parallel_session_start_additional_contexts_share_one_developer_message() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let response = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "merged hook context observed"),
ev_completed("resp-1"),
]),
)
.await;
let contexts = ["first tide note", "second tide note"];
let mut builder = test_codex()
.with_pre_build_hook(move |home| {
if let Err(error) = write_parallel_session_start_hooks(home, &contexts) {
panic!("failed to write parallel session start hook fixtures: {error}");
}
})
.with_config(trust_discovered_hooks);
let test = builder.build(&server).await?;
test.submit_turn("hello").await?;
let request = response.single_request();
let merged_developer_messages = request
.input()
.iter()
.filter(|item| item.get("role").and_then(Value::as_str) == Some("developer"))
.filter(|item| {
let joined = item["content"]
.as_array()
.into_iter()
.flatten()
.filter_map(|content| content.get("text").and_then(Value::as_str))
.collect::<Vec<_>>()
.join("\n");
contexts.iter().all(|context| {
joined.contains(&format!(
"<hook_context_sticky>{context}</hook_context_sticky>"
))
})
})
.count();
assert_eq!(merged_developer_messages, 1);
Ok(())
}
#[tokio::test]
async fn pre_tool_use_hook_spills_large_additional_context() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -1334,9 +1436,9 @@ async fn blocked_user_prompt_submit_persists_additional_context_for_next_turn()
let request = response.single_request();
assert!(
request
.message_input_texts("developer")
.contains(&BLOCKED_PROMPT_CONTEXT.to_string()),
request.message_input_texts("developer").contains(&format!(
"<hook_context>{BLOCKED_PROMPT_CONTEXT}</hook_context>"
)),
"second request should include developer context persisted from the blocked prompt",
);
assert!(
@@ -2103,7 +2205,7 @@ async fn pre_tool_use_records_additional_context_for_shell_command() -> Result<(
assert!(
requests[1]
.message_input_texts("developer")
.contains(&pre_context.to_string()),
.contains(&format!("<hook_context>{pre_context}</hook_context>")),
"follow-up request should include pre tool use additional context",
);
let output_item = requests[1].function_call_output(call_id);
@@ -2176,7 +2278,7 @@ async fn blocked_pre_tool_use_records_additional_context_for_shell_command() ->
assert!(
requests[1]
.message_input_texts("developer")
.contains(&pre_context.to_string()),
.contains(&format!("<hook_context>{pre_context}</hook_context>")),
"follow-up request should include blocked pre tool use additional context",
);
let output_item = requests[1].function_call_output(call_id);
@@ -3172,7 +3274,7 @@ async fn post_tool_use_records_additional_context_for_shell_command() -> Result<
assert!(
requests[1]
.message_input_texts("developer")
.contains(&post_context.to_string()),
.contains(&format!("<hook_context>{post_context}</hook_context>")),
"follow-up request should include post tool use additional context",
);
let output_item = requests[1].function_call_output(call_id);
@@ -3636,7 +3738,7 @@ async fn post_tool_use_records_additional_context_for_apply_patch() -> Result<()
assert!(
requests[1]
.message_input_texts("developer")
.contains(&post_context.to_string()),
.contains(&format!("<hook_context>{post_context}</hook_context>")),
"follow-up request should include apply_patch post tool use context",
);
assert!(
@@ -3711,7 +3813,7 @@ async fn post_tool_use_records_apply_patch_context_with_edit_alias() -> Result<(
assert!(
requests[1]
.message_input_texts("developer")
.contains(&post_context.to_string()),
.contains(&format!("<hook_context>{post_context}</hook_context>")),
"follow-up request should include apply_patch post tool use context",
);
assert!(

View File

@@ -411,7 +411,7 @@ async fn post_tool_use_records_mcp_tool_payload_and_context() -> Result<()> {
assert!(
final_request
.message_input_texts("developer")
.contains(&post_context.to_string()),
.contains(&format!("<hook_context>{post_context}</hook_context>")),
"follow-up request should include MCP post tool use additional context",
);
let output_item = final_request.function_call_output(call_id);