Compare commits

...

8 Commits

Author SHA1 Message Date
Joe Gershenson
0e570952f4 test: preserve abort event coverage 2026-04-21 01:25:41 -07:00
Joe Gershenson
bbc664079a test: cover blocked session-start interrupt recovery 2026-04-21 01:20:56 -07:00
Joe Gershenson
0e331a685d codex: harden turn-start interrupt recovery 2026-04-21 01:14:29 -07:00
Joe Gershenson
5e7458e515 codex: document interrupt hook ordering tradeoff 2026-04-21 01:01:25 -07:00
Joe Gershenson
c0e4efdfb5 codex: serialize interrupt turn start recovery 2026-04-21 00:58:06 -07:00
Joe Gershenson
dd7aaab825 codex: harden interrupt prompt drain 2026-04-21 00:40:08 -07:00
Joe Gershenson
051accdb6a codex: address PR review feedback (#18817) 2026-04-21 00:20:45 -07:00
Joe Gershenson
553376ffcb Fix interrupt prompt history race 2026-04-20 23:49:29 -07:00
9 changed files with 393 additions and 70 deletions

View File

@@ -17,6 +17,7 @@ use codex_hooks::UserPromptSubmitRequest;
use codex_otel::HOOK_RUN_DURATION_METRIC;
use codex_otel::HOOK_RUN_METRIC;
use codex_protocol::items::TurnItem;
use codex_protocol::items::UserMessageItem;
use codex_protocol::models::DeveloperInstructions;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
@@ -30,8 +31,10 @@ use codex_protocol::protocol::HookSource;
use codex_protocol::protocol::HookStartedEvent;
use codex_protocol::user_input::UserInput;
use serde_json::Value;
use tokio::sync::OwnedSemaphorePermit;
use crate::event_mapping::parse_turn_item;
use crate::session::PreviousTurnSettings;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::tools::sandboxing::PermissionRequestPayload;
@@ -57,6 +60,11 @@ pub(crate) enum PendingInputRecord {
},
}
pub(crate) enum TurnStartTranscriptDrainMode {
RegularTurn,
InterruptRecovery,
}
struct ContextInjectingHookOutcome {
hook_events: Vec<HookCompletedEvent>,
outcome: HookRuntimeOutcome,
@@ -268,6 +276,107 @@ pub(crate) async fn inspect_pending_input(
}
}
pub(crate) async fn drain_turn_start_transcript_inputs(
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
mode: TurnStartTranscriptDrainMode,
) -> bool {
let Ok(permit) = turn_context
.transcript_serialization_lock
.clone()
.acquire_owned()
.await
else {
return false;
};
drain_turn_start_transcript_inputs_with_permit(sess, turn_context, mode, permit).await
}
pub(crate) async fn drain_turn_start_transcript_inputs_with_permit(
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
mode: TurnStartTranscriptDrainMode,
_permit: OwnedSemaphorePermit,
) -> bool {
let has_queued_start_input = !turn_context.lock_turn_start_transcript_inputs().is_empty();
if !has_queued_start_input && matches!(mode, TurnStartTranscriptDrainMode::InterruptRecovery) {
return true;
}
// Keep the normal turn-start ordering in one serialized region: context first,
// then the user prompt, then hook-provided context and previous-turn settings.
sess.record_context_updates_and_set_reference_context_item(turn_context.as_ref())
.await;
if run_pending_session_start_hooks(sess, turn_context).await {
turn_context.lock_turn_start_transcript_inputs().clear();
return false;
}
let mut recorded_start_input = false;
loop {
let input = {
let inputs = turn_context.lock_turn_start_transcript_inputs();
inputs.first().map(|queued| queued.input.clone())
};
let Some(input) = input else {
break;
};
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input.clone());
let response_item: ResponseItem = initial_input_for_turn.into();
let user_prompt_submit_outcome = run_user_prompt_submit_hooks(
sess,
turn_context,
UserMessageItem::new(&input).message(),
)
.await;
if user_prompt_submit_outcome.should_stop {
record_additional_contexts(
sess,
turn_context,
user_prompt_submit_outcome.additional_contexts,
)
.await;
let mut inputs = turn_context.lock_turn_start_transcript_inputs();
if inputs.first().is_some_and(|queued| queued.input == input) {
inputs.remove(0);
}
return false;
}
sess.record_user_prompt_and_emit_turn_item(
turn_context.as_ref(),
input.as_slice(),
response_item,
)
.await;
record_additional_contexts(
sess,
turn_context,
user_prompt_submit_outcome.additional_contexts,
)
.await;
let mut inputs = turn_context.lock_turn_start_transcript_inputs();
if inputs.first().is_some_and(|queued| queued.input == input) {
inputs.remove(0);
}
recorded_start_input = true;
}
if recorded_start_input {
sess.set_previous_turn_settings(Some(PreviousTurnSettings {
model: turn_context.model_info.slug.clone(),
realtime_active: Some(turn_context.realtime_active),
}))
.await;
}
true
}
pub(crate) async fn record_pending_input(
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,

View File

@@ -128,6 +128,7 @@ use rmcp::model::RequestId;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::Semaphore;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tokio::task::JoinHandle;

View File

@@ -139,6 +139,8 @@ pub(super) async fn spawn_review_thread(
turn_metadata_state,
turn_skills: TurnSkillsContext::new(parent_turn_context.turn_skills.outcome.clone()),
turn_timing_state: Arc::new(TurnTimingState::default()),
turn_start_transcript_inputs: Arc::new(std::sync::Mutex::new(Vec::new())),
transcript_serialization_lock: Arc::new(Semaphore::new(1)),
};
// Seed the child task with the review prompt as the initial user message.

View File

@@ -144,7 +144,6 @@ use sha2::Sha512;
use std::path::Path;
use std::time::Duration;
use tokio::sync::Semaphore;
use tokio::time::sleep;
use tokio::time::timeout;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use wiremock::Mock;
@@ -306,12 +305,23 @@ async fn interrupting_regular_turn_waiting_on_startup_prewarm_emits_turn_aborted
),
)
.await;
sess.spawn_task(
Arc::clone(&tc),
Vec::new(),
crate::tasks::RegularTask::new(),
)
.await;
let input = vec![UserInput::Text {
text: "hello before prewarm".to_string(),
text_elements: Vec::new(),
}];
let mut expected_history = sess.build_initial_context(tc.as_ref()).await;
expected_history.push(ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "hello before prewarm".to_string(),
}],
end_turn: None,
phase: None,
});
expected_history.push(crate::tasks::interrupted_turn_history_marker());
sess.spawn_task(Arc::clone(&tc), input, crate::tasks::RegularTask::new())
.await;
let first = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv())
.await
@@ -324,23 +334,112 @@ async fn interrupting_regular_turn_waiting_on_startup_prewarm_emits_turn_aborted
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
let (turn_id, reason, completed_at, duration_ms) =
tokio::time::timeout(std::time::Duration::from_secs(2), async {
loop {
let event = rx.recv().await.expect("channel open");
if let EventMsg::TurnAborted(TurnAbortedEvent {
turn_id,
reason,
completed_at,
duration_ms,
}) = event.msg
{
return (turn_id, reason, completed_at, duration_ms);
}
}
})
.await
.expect("expected turn aborted event")
.expect("channel open");
let EventMsg::TurnAborted(TurnAbortedEvent {
turn_id,
reason,
completed_at,
duration_ms,
}) = second.msg
else {
panic!("expected turn aborted event");
};
.expect("expected turn aborted event");
assert_eq!(turn_id, Some(tc.sub_id.clone()));
assert_eq!(reason, TurnAbortReason::Interrupted);
assert!(completed_at.is_some());
assert!(duration_ms.is_some());
let history = sess.clone_history().await;
assert_eq!(history.raw_items(), expected_history.as_slice());
assert!(tc.lock_turn_start_transcript_inputs().is_empty());
assert_eq!(
sess.previous_turn_settings().await,
Some(PreviousTurnSettings {
model: tc.model_info.slug.clone(),
realtime_active: Some(tc.realtime_active),
})
);
}
#[tokio::test]
async fn session_start_stop_clears_queued_start_input() -> anyhow::Result<()> {
let (mut sess, tc) = make_session_and_context().await;
std::fs::create_dir_all(&tc.config.codex_home)?;
let hook_script = tc.config.codex_home.join("block_session_start_hook.py");
std::fs::write(
&hook_script,
r#"import json
import sys
json.load(sys.stdin)
print(json.dumps({"continue": False, "stopReason": "blocked for test"}))
"#,
)?;
std::fs::write(
tc.config.codex_home.join("hooks.json"),
json!({
"hooks": {
"SessionStart": [{
"hooks": [{
"type": "command",
"command": format!("python3 {}", hook_script.display()),
}]
}]
}
})
.to_string(),
)?;
sess.services.hooks = Hooks::new(HooksConfig {
feature_enabled: true,
config_layer_stack: Some(tc.config.config_layer_stack.clone()),
..HooksConfig::default()
});
let sess = Arc::new(sess);
let tc = Arc::new(tc);
let input = vec![UserInput::Text {
text: "prompt blocked by session start".to_string(),
text_elements: Vec::new(),
}];
sess.state
.lock()
.await
.set_pending_session_start_source(Some(codex_hooks::SessionStartSource::Startup));
tc.lock_turn_start_transcript_inputs()
.push(crate::session::turn_context::TurnStartTranscriptInput { input });
assert!(
!crate::hook_runtime::drain_turn_start_transcript_inputs(
&sess,
&tc,
crate::hook_runtime::TurnStartTranscriptDrainMode::RegularTurn,
)
.await
);
assert!(tc.lock_turn_start_transcript_inputs().is_empty());
assert!(
crate::hook_runtime::drain_turn_start_transcript_inputs(
&sess,
&tc,
crate::hook_runtime::TurnStartTranscriptDrainMode::InterruptRecovery,
)
.await
);
let history = sess.clone_history().await;
assert!(
user_input_texts(history.raw_items())
.iter()
.all(|text| !text.contains("prompt blocked by session start"))
);
Ok(())
}
fn test_model_client_session() -> crate::client::ModelClientSession {
@@ -5669,9 +5768,7 @@ impl SessionTask for NeverEndingTask {
cancellation_token.cancelled().await;
return None;
}
loop {
sleep(Duration::from_secs(60)).await;
}
std::future::pending::<Option<String>>().await
}
}
@@ -5742,6 +5839,62 @@ async fn abort_gracefully_emits_turn_aborted_only() {
assert!(rx.try_recv().is_err());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[test_log::test]
async fn abort_non_regular_task_keeps_pending_session_start_source() {
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
sess.state
.lock()
.await
.set_pending_session_start_source(Some(codex_hooks::SessionStartSource::Startup));
let input = vec![UserInput::Text {
text: "review this".to_string(),
text_elements: Vec::new(),
}];
sess.spawn_task(
Arc::clone(&tc),
input,
NeverEndingTask {
kind: TaskKind::Review,
listen_to_cancellation_token: true,
},
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
assert!(matches!(
sess.take_pending_session_start_source().await,
Some(codex_hooks::SessionStartSource::Startup)
));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[test_log::test]
async fn abort_regular_task_without_start_input_keeps_pending_session_start_source() {
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
sess.state
.lock()
.await
.set_pending_session_start_source(Some(codex_hooks::SessionStartSource::Startup));
sess.spawn_task(
Arc::clone(&tc),
Vec::new(),
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: true,
},
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
assert!(matches!(
sess.take_pending_session_start_source().await,
Some(codex_hooks::SessionStartSource::Startup)
));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input() {
let (sess, tc, rx) = make_session_and_context_with_rx().await;

View File

@@ -18,12 +18,13 @@ use crate::compact_remote::run_inline_remote_auto_compact_task;
use crate::connectors;
use crate::feedback_tags;
use crate::hook_runtime::PendingInputHookDisposition;
use crate::hook_runtime::TurnStartTranscriptDrainMode;
use crate::hook_runtime::drain_turn_start_transcript_inputs;
use crate::hook_runtime::emit_hook_completed_events;
use crate::hook_runtime::inspect_pending_input;
use crate::hook_runtime::record_additional_contexts;
use crate::hook_runtime::record_pending_input;
use crate::hook_runtime::run_pending_session_start_hooks;
use crate::hook_runtime::run_user_prompt_submit_hooks;
use crate::injection::ToolMentionKind;
use crate::injection::app_id_from_path;
use crate::injection::tool_kind_for_path;
@@ -37,7 +38,6 @@ use crate::mentions::collect_tool_mentions_from_messages;
use crate::parse_turn_item;
use crate::plugins::build_plugin_injections;
use crate::resolve_skill_dependencies_for_turn;
use crate::session::PreviousTurnSettings;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::stream_events_utils::HandleOutputCtx;
@@ -74,7 +74,6 @@ use codex_protocol::error::CodexErr;
use codex_protocol::error::Result as CodexResult;
use codex_protocol::items::PlanItem;
use codex_protocol::items::TurnItem;
use codex_protocol::items::UserMessageItem;
use codex_protocol::items::build_hook_prompt_message;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
@@ -158,9 +157,6 @@ pub(crate) async fn run_turn(
let skills_outcome = Some(turn_context.turn_skills.outcome.as_ref());
sess.record_context_updates_and_set_reference_context_item(turn_context.as_ref())
.await;
let loaded_plugins = sess
.services
.plugins_manager
@@ -285,33 +281,15 @@ pub(crate) async fn run_turn(
})
.collect::<Vec<_>>();
if run_pending_session_start_hooks(&sess, &turn_context).await {
if !drain_turn_start_transcript_inputs(
&sess,
&turn_context,
TurnStartTranscriptDrainMode::RegularTurn,
)
.await
{
return None;
}
let additional_contexts = if input.is_empty() {
Vec::new()
} else {
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input.clone());
let response_item: ResponseItem = initial_input_for_turn.clone().into();
let user_prompt_submit_outcome = run_user_prompt_submit_hooks(
&sess,
&turn_context,
UserMessageItem::new(&input).message(),
)
.await;
if user_prompt_submit_outcome.should_stop {
record_additional_contexts(
&sess,
&turn_context,
user_prompt_submit_outcome.additional_contexts,
)
.await;
return None;
}
sess.record_user_prompt_and_emit_turn_item(turn_context.as_ref(), &input, response_item)
.await;
user_prompt_submit_outcome.additional_contexts
};
sess.services
.analytics_events_client
.track_app_mentioned(tracking.clone(), mentioned_app_invocations);
@@ -322,17 +300,6 @@ pub(crate) async fn run_turn(
}
sess.merge_connector_selection(explicitly_enabled_connectors.clone())
.await;
record_additional_contexts(&sess, &turn_context, additional_contexts).await;
if !input.is_empty() {
// Track the previous-turn baseline from the regular user-turn path only so
// standalone tasks (compact/shell/review/undo) cannot suppress future
// model/realtime injections.
sess.set_previous_turn_settings(Some(PreviousTurnSettings {
model: turn_context.model_info.slug.clone(),
realtime_active: Some(turn_context.realtime_active),
}))
.await;
}
let agent_task = match sess.ensure_agent_task_registered().await {
Ok(agent_task) => agent_task,
Err(error) => {

View File

@@ -24,6 +24,11 @@ impl TurnSkillsContext {
}
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct TurnStartTranscriptInput {
pub(crate) input: Vec<UserInput>,
}
/// The context needed for a single turn of the thread.
#[derive(Debug)]
pub(crate) struct TurnContext {
@@ -71,8 +76,18 @@ pub(crate) struct TurnContext {
pub(crate) turn_metadata_state: Arc<TurnMetadataState>,
pub(crate) turn_skills: TurnSkillsContext,
pub(crate) turn_timing_state: Arc<TurnTimingState>,
pub(crate) turn_start_transcript_inputs: Arc<std::sync::Mutex<Vec<TurnStartTranscriptInput>>>,
pub(crate) transcript_serialization_lock: Arc<Semaphore>,
}
impl TurnContext {
pub(crate) fn lock_turn_start_transcript_inputs(
&self,
) -> std::sync::MutexGuard<'_, Vec<TurnStartTranscriptInput>> {
self.turn_start_transcript_inputs
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
pub(crate) fn model_context_window(&self) -> Option<i64> {
let effective_context_window_percent = self.model_info.effective_context_window_percent;
self.model_info
@@ -197,6 +212,8 @@ impl TurnContext {
turn_metadata_state: self.turn_metadata_state.clone(),
turn_skills: self.turn_skills.clone(),
turn_timing_state: Arc::clone(&self.turn_timing_state),
turn_start_transcript_inputs: Arc::clone(&self.turn_start_transcript_inputs),
transcript_serialization_lock: Arc::clone(&self.transcript_serialization_lock),
}
}
@@ -443,6 +460,8 @@ impl Session {
turn_metadata_state,
turn_skills: TurnSkillsContext::new(skills_outcome),
turn_timing_state: Arc::new(TurnTimingState::default()),
turn_start_transcript_inputs: Arc::new(std::sync::Mutex::new(Vec::new())),
transcript_serialization_lock: Arc::new(Semaphore::new(1)),
}
}

View File

@@ -70,6 +70,7 @@ pub(crate) struct RunningTask {
pub(crate) done: Arc<Notify>,
pub(crate) kind: TaskKind,
pub(crate) task: Arc<dyn AnySessionTask>,
pub(crate) records_turn_start_transcript: bool,
pub(crate) cancellation_token: CancellationToken,
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
pub(crate) turn_context: Arc<TurnContext>,

View File

@@ -22,11 +22,15 @@ use tracing::warn;
use crate::contextual_user_message::TURN_ABORTED_CLOSE_TAG;
use crate::contextual_user_message::TURN_ABORTED_OPEN_TAG;
use crate::hook_runtime::PendingInputHookDisposition;
use crate::hook_runtime::TurnStartTranscriptDrainMode;
use crate::hook_runtime::drain_turn_start_transcript_inputs;
use crate::hook_runtime::drain_turn_start_transcript_inputs_with_permit;
use crate::hook_runtime::inspect_pending_input;
use crate::hook_runtime::record_additional_contexts;
use crate::hook_runtime::record_pending_input;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::session::turn_context::TurnStartTranscriptInput;
use crate::state::ActiveTurn;
use crate::state::RunningTask;
use crate::state::TaskKind;
@@ -160,6 +164,11 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
/// Returns the tracing name for a spawned task span.
fn span_name(&self) -> &'static str;
/// Whether this task owns the normal user-turn transcript start sequence.
fn records_turn_start_transcript(&self) -> bool {
false
}
/// Executes the task until completion or cancellation.
///
/// Implementations typically stream protocol events using `session` and
@@ -197,6 +206,8 @@ pub(crate) trait AnySessionTask: Send + Sync + 'static {
fn span_name(&self) -> &'static str;
fn records_turn_start_transcript(&self) -> bool;
fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
@@ -224,6 +235,10 @@ where
SessionTask::span_name(self)
}
fn records_turn_start_transcript(&self) -> bool {
SessionTask::records_turn_start_transcript(self)
}
fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
@@ -270,6 +285,7 @@ impl Session {
let task: Arc<dyn AnySessionTask> = Arc::new(task);
let task_kind = task.kind();
let span_name = task.span_name();
let records_turn_start_transcript = task.records_turn_start_transcript();
let started_at = Instant::now();
turn_context
.turn_timing_state
@@ -299,6 +315,13 @@ impl Session {
}
}
if records_turn_start_transcript && !input.is_empty() {
turn_context
.lock_turn_start_transcript_inputs()
.push(TurnStartTranscriptInput {
input: input.clone(),
});
}
let mut active = self.active_turn.lock().await;
let turn = active.get_or_insert_with(ActiveTurn::default);
debug_assert!(turn.tasks.is_empty());
@@ -358,6 +381,7 @@ impl Session {
handle: Arc::new(AbortOnDropHandle::new(handle)),
kind: task_kind,
task,
records_turn_start_transcript,
cancellation_token,
turn_context: Arc::clone(&turn_context),
_timer: timer,
@@ -618,15 +642,58 @@ impl Session {
.cancel_git_enrichment_task();
let session_task = task.task;
select! {
_ = task.done.notified() => {
},
_ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => {
warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS);
let mut handle_aborted = false;
// The startup prompt queue is serialized by the regular turn itself. If the serializer is
// idle, grab it and abort the task before recovering the queued prompt; if the serializer
// is already active, wait for it to finish before force-aborting so history cannot be left
// half-written. This can wait on turn-start hooks, but it keeps hook policy and hook-added
// context ahead of the model-visible interrupt marker.
if reason == TurnAbortReason::Interrupted && task.records_turn_start_transcript {
let has_turn_start_transcript_input = {
let inputs = task.turn_context.lock_turn_start_transcript_inputs();
!inputs.is_empty()
};
if has_turn_start_transcript_input {
match task
.turn_context
.transcript_serialization_lock
.clone()
.try_acquire_owned()
{
Ok(permit) => {
task.handle.abort();
handle_aborted = true;
let _ = drain_turn_start_transcript_inputs_with_permit(
self,
&task.turn_context,
TurnStartTranscriptDrainMode::InterruptRecovery,
permit,
)
.await;
}
Err(_) => {
let _ = drain_turn_start_transcript_inputs(
self,
&task.turn_context,
TurnStartTranscriptDrainMode::InterruptRecovery,
)
.await;
}
}
}
}
task.handle.abort();
if !handle_aborted {
select! {
_ = task.done.notified() => {
},
_ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => {
warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS);
}
}
task.handle.abort();
}
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
session_task

View File

@@ -33,6 +33,10 @@ impl SessionTask for RegularTask {
"session_task.turn"
}
fn records_turn_start_transcript(&self) -> bool {
true
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,