diff --git a/codex-rs/core/src/session/handlers.rs b/codex-rs/core/src/session/handlers.rs index 3d57bd7074..8798db200c 100644 --- a/codex-rs/core/src/session/handlers.rs +++ b/codex-rs/core/src/session/handlers.rs @@ -234,6 +234,14 @@ pub(super) async fn user_input_or_turn_inner( Some(items) } Err(SteerInputError::NoActiveTurn(items)) => { + if sess.has_terminal_active_turn().await { + sess.send_event_raw(Event { + id: sub_id, + msg: EventMsg::Error(SteerInputError::NoActiveTurn(items).to_error_event()), + }) + .await; + return; + } if let Some(responsesapi_client_metadata) = responsesapi_client_metadata { current_context .turn_metadata_state diff --git a/codex-rs/core/src/session/input_queue.rs b/codex-rs/core/src/session/input_queue.rs index 5f92322c8d..561604386d 100644 --- a/codex-rs/core/src/session/input_queue.rs +++ b/codex-rs/core/src/session/input_queue.rs @@ -116,6 +116,16 @@ impl InputQueue { turn_state.pending_input.items.clear(); } + pub(crate) async fn mark_terminal_and_clear_pending_for_turn_state( + &self, + turn_state: &Mutex, + ) { + let mut turn_state = turn_state.lock().await; + turn_state.mark_terminal(); + turn_state.clear_pending_waiters(); + turn_state.pending_input.items.clear(); + } + pub(crate) async fn defer_mailbox_delivery_to_next_turn( &self, active_turn: &Mutex>, @@ -159,10 +169,14 @@ impl InputQueue { &self, turn_state: &Mutex, input: TurnInput, - ) { + ) -> Result<(), TurnInput> { let mut turn_state = turn_state.lock().await; + if turn_state.is_terminal() { + return Err(input); + } turn_state.pending_input.items.push(input); turn_state.accept_mailbox_delivery_for_current_turn(); + Ok(()) } pub(crate) async fn extend_pending_input_for_turn_state( @@ -192,14 +206,14 @@ impl InputQueue { let mut active = active_turn.lock().await; match active.as_mut() { Some(active_turn) => { - self.extend_pending_input_for_turn_state( - active_turn.turn_state.as_ref(), - input - .into_iter() - .map(TurnInput::ResponseInputItem) - .collect(), - ) - .await; + let mut turn_state = active_turn.turn_state.lock().await; + if turn_state.is_terminal() { + return Err(input); + } + turn_state + .pending_input + .items + .extend(input.into_iter().map(TurnInput::ResponseInputItem)); Ok(()) } None => Err(input), @@ -219,6 +233,9 @@ impl InputQueue { match active.as_mut() { Some(active_turn) => { let mut turn_state = active_turn.turn_state.lock().await; + if turn_state.is_terminal() { + return Vec::new(); + } ( turn_state.pending_input.items.split_off(0), turn_state.accepts_mailbox_delivery_for_current_turn(), @@ -254,6 +271,9 @@ impl InputQueue { match active.as_ref() { Some(active_turn) => { let turn_state = active_turn.turn_state.lock().await; + if turn_state.is_terminal() { + return false; + } ( !turn_state.pending_input.items.is_empty(), turn_state.accepts_mailbox_delivery_for_current_turn(), diff --git a/codex-rs/core/src/session/mod.rs b/codex-rs/core/src/session/mod.rs index 675e794aee..be08958ce4 100644 --- a/codex-rs/core/src/session/mod.rs +++ b/codex-rs/core/src/session/mod.rs @@ -3191,6 +3191,17 @@ impl Session { return Err(SteerInputError::EmptyInput); } + self.input_queue + .push_pending_input_and_accept_mailbox_delivery_for_turn_state( + active_turn.turn_state.as_ref(), + TurnInput::UserInput(input), + ) + .await + .map_err(|input| match input { + TurnInput::UserInput(input) => SteerInputError::NoActiveTurn(input), + TurnInput::ResponseInputItem(_) => unreachable!("steer input must be user input"), + })?; + if let Some(responsesapi_client_metadata) = responsesapi_client_metadata && let Some((_, active_task)) = active_turn.tasks.first() { @@ -3200,15 +3211,22 @@ impl Session { .set_responsesapi_client_metadata(responsesapi_client_metadata); } - self.input_queue - .push_pending_input_and_accept_mailbox_delivery_for_turn_state( - active_turn.turn_state.as_ref(), - TurnInput::UserInput(input), - ) - .await; Ok(active_turn_id.clone()) } + pub(crate) async fn has_terminal_active_turn(&self) -> bool { + let turn_state = { + let active = self.active_turn.lock().await; + active + .as_ref() + .map(|active_turn| Arc::clone(&active_turn.turn_state)) + }; + let Some(turn_state) = turn_state else { + return false; + }; + turn_state.lock().await.is_terminal() + } + /// Returns the input if there was no task running to inject into. pub async fn inject_response_items( &self, diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index d4d0c28660..a373e9d790 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -8120,6 +8120,62 @@ async fn steer_input_requires_active_turn() { assert!(matches!(err, SteerInputError::NoActiveTurn(_))); } +#[tokio::test] +async fn terminal_active_turn_rejects_new_pending_input() { + let (sess, tc, _rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + let turn_state = { + let active_turn = sess.active_turn.lock().await; + Arc::clone(&active_turn.as_ref().expect("active turn").turn_state) + }; + sess.input_queue + .mark_terminal_and_clear_pending_for_turn_state(turn_state.as_ref()) + .await; + + let steer_input = vec![UserInput::Text { + text: "late steer".to_string(), + text_elements: Vec::new(), + }]; + let err = sess + .steer_input( + steer_input, + Some(&tc.sub_id), + /*responsesapi_client_metadata*/ None, + ) + .await + .expect_err("terminal turn should reject steering"); + assert!(matches!(err, SteerInputError::NoActiveTurn(_))); + + let injected_item = ResponseInputItem::Message { + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "late injected input".to_string(), + }], + phase: None, + }; + let rejected_items = sess + .inject_response_items(vec![injected_item.clone()]) + .await + .expect_err("terminal turn should reject injected input"); + assert_eq!(rejected_items, vec![injected_item]); + assert!( + !sess.input_queue.has_pending_input(&sess.active_turn).await, + "terminal turn should not accept new pending input" + ); +} + #[tokio::test] async fn steer_input_enforces_expected_turn_id() { let (sess, tc, _rx) = make_session_and_context_with_rx().await; diff --git a/codex-rs/core/src/state/turn.rs b/codex-rs/core/src/state/turn.rs index 86438ad56a..4fe22494e5 100644 --- a/codex-rs/core/src/state/turn.rs +++ b/codex-rs/core/src/state/turn.rs @@ -117,6 +117,7 @@ pub(crate) struct TurnState { pending_dynamic_tools: HashMap>, pub(crate) pending_input: TurnInputQueue, mailbox_delivery_phase: MailboxDeliveryPhase, + terminal: bool, granted_permissions: Option, strict_auto_review_enabled: bool, pub(crate) tool_calls: u64, @@ -220,17 +221,29 @@ impl TurnState { } pub(crate) fn accept_mailbox_delivery_for_current_turn(&mut self) { + if self.terminal { + return; + } self.set_mailbox_delivery_phase(MailboxDeliveryPhase::CurrentTurn); } pub(crate) fn accepts_mailbox_delivery_for_current_turn(&self) -> bool { - self.mailbox_delivery_phase == MailboxDeliveryPhase::CurrentTurn + !self.terminal && self.mailbox_delivery_phase == MailboxDeliveryPhase::CurrentTurn } pub(crate) fn set_mailbox_delivery_phase(&mut self, phase: MailboxDeliveryPhase) { self.mailbox_delivery_phase = phase; } + pub(crate) fn mark_terminal(&mut self) { + self.terminal = true; + self.set_mailbox_delivery_phase(MailboxDeliveryPhase::NextTurn); + } + + pub(crate) fn is_terminal(&self) -> bool { + self.terminal + } + pub(crate) fn record_granted_permissions(&mut self, permissions: AdditionalPermissionProfile) { self.granted_permissions = merge_permission_profiles(self.granted_permissions.as_ref(), Some(&permissions)); diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index 1258f7ea53..3f6383e81c 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -98,12 +98,9 @@ impl SessionTask for RegularTask { .map(|active_turn| Arc::clone(&active_turn.turn_state)) }; if let Some(turn_state) = turn_state { - turn_state.lock().await.clear_pending_waiters(); - drop( - sess.input_queue - .take_pending_input_for_turn_state(turn_state.as_ref()) - .await, - ); + sess.input_queue + .mark_terminal_and_clear_pending_for_turn_state(turn_state.as_ref()) + .await; } return None; }