mark terminal turns unsteerable

This commit is contained in:
Roy Han
2026-05-22 13:12:04 -07:00
parent 738d424509
commit 3b4cbc4fc0
6 changed files with 134 additions and 22 deletions

View File

@@ -234,6 +234,14 @@ pub(super) async fn user_input_or_turn_inner(
Some(items)
}
Err(SteerInputError::NoActiveTurn(items)) => {
if sess.has_terminal_active_turn().await {
sess.send_event_raw(Event {
id: sub_id,
msg: EventMsg::Error(SteerInputError::NoActiveTurn(items).to_error_event()),
})
.await;
return;
}
if let Some(responsesapi_client_metadata) = responsesapi_client_metadata {
current_context
.turn_metadata_state

View File

@@ -116,6 +116,16 @@ impl InputQueue {
turn_state.pending_input.items.clear();
}
pub(crate) async fn mark_terminal_and_clear_pending_for_turn_state(
&self,
turn_state: &Mutex<TurnState>,
) {
let mut turn_state = turn_state.lock().await;
turn_state.mark_terminal();
turn_state.clear_pending_waiters();
turn_state.pending_input.items.clear();
}
pub(crate) async fn defer_mailbox_delivery_to_next_turn(
&self,
active_turn: &Mutex<Option<ActiveTurn>>,
@@ -159,10 +169,14 @@ impl InputQueue {
&self,
turn_state: &Mutex<TurnState>,
input: TurnInput,
) {
) -> Result<(), TurnInput> {
let mut turn_state = turn_state.lock().await;
if turn_state.is_terminal() {
return Err(input);
}
turn_state.pending_input.items.push(input);
turn_state.accept_mailbox_delivery_for_current_turn();
Ok(())
}
pub(crate) async fn extend_pending_input_for_turn_state(
@@ -192,14 +206,14 @@ impl InputQueue {
let mut active = active_turn.lock().await;
match active.as_mut() {
Some(active_turn) => {
self.extend_pending_input_for_turn_state(
active_turn.turn_state.as_ref(),
input
.into_iter()
.map(TurnInput::ResponseInputItem)
.collect(),
)
.await;
let mut turn_state = active_turn.turn_state.lock().await;
if turn_state.is_terminal() {
return Err(input);
}
turn_state
.pending_input
.items
.extend(input.into_iter().map(TurnInput::ResponseInputItem));
Ok(())
}
None => Err(input),
@@ -219,6 +233,9 @@ impl InputQueue {
match active.as_mut() {
Some(active_turn) => {
let mut turn_state = active_turn.turn_state.lock().await;
if turn_state.is_terminal() {
return Vec::new();
}
(
turn_state.pending_input.items.split_off(0),
turn_state.accepts_mailbox_delivery_for_current_turn(),
@@ -254,6 +271,9 @@ impl InputQueue {
match active.as_ref() {
Some(active_turn) => {
let turn_state = active_turn.turn_state.lock().await;
if turn_state.is_terminal() {
return false;
}
(
!turn_state.pending_input.items.is_empty(),
turn_state.accepts_mailbox_delivery_for_current_turn(),

View File

@@ -3191,6 +3191,17 @@ impl Session {
return Err(SteerInputError::EmptyInput);
}
self.input_queue
.push_pending_input_and_accept_mailbox_delivery_for_turn_state(
active_turn.turn_state.as_ref(),
TurnInput::UserInput(input),
)
.await
.map_err(|input| match input {
TurnInput::UserInput(input) => SteerInputError::NoActiveTurn(input),
TurnInput::ResponseInputItem(_) => unreachable!("steer input must be user input"),
})?;
if let Some(responsesapi_client_metadata) = responsesapi_client_metadata
&& let Some((_, active_task)) = active_turn.tasks.first()
{
@@ -3200,15 +3211,22 @@ impl Session {
.set_responsesapi_client_metadata(responsesapi_client_metadata);
}
self.input_queue
.push_pending_input_and_accept_mailbox_delivery_for_turn_state(
active_turn.turn_state.as_ref(),
TurnInput::UserInput(input),
)
.await;
Ok(active_turn_id.clone())
}
pub(crate) async fn has_terminal_active_turn(&self) -> bool {
let turn_state = {
let active = self.active_turn.lock().await;
active
.as_ref()
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
let Some(turn_state) = turn_state else {
return false;
};
turn_state.lock().await.is_terminal()
}
/// Returns the input if there was no task running to inject into.
pub async fn inject_response_items(
&self,

View File

@@ -8120,6 +8120,62 @@ async fn steer_input_requires_active_turn() {
assert!(matches!(err, SteerInputError::NoActiveTurn(_)));
}
#[tokio::test]
async fn terminal_active_turn_rejects_new_pending_input() {
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
let input = vec![UserInput::Text {
text: "hello".to_string(),
text_elements: Vec::new(),
}];
sess.spawn_task(
Arc::clone(&tc),
input,
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: false,
},
)
.await;
let turn_state = {
let active_turn = sess.active_turn.lock().await;
Arc::clone(&active_turn.as_ref().expect("active turn").turn_state)
};
sess.input_queue
.mark_terminal_and_clear_pending_for_turn_state(turn_state.as_ref())
.await;
let steer_input = vec![UserInput::Text {
text: "late steer".to_string(),
text_elements: Vec::new(),
}];
let err = sess
.steer_input(
steer_input,
Some(&tc.sub_id),
/*responsesapi_client_metadata*/ None,
)
.await
.expect_err("terminal turn should reject steering");
assert!(matches!(err, SteerInputError::NoActiveTurn(_)));
let injected_item = ResponseInputItem::Message {
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "late injected input".to_string(),
}],
phase: None,
};
let rejected_items = sess
.inject_response_items(vec![injected_item.clone()])
.await
.expect_err("terminal turn should reject injected input");
assert_eq!(rejected_items, vec![injected_item]);
assert!(
!sess.input_queue.has_pending_input(&sess.active_turn).await,
"terminal turn should not accept new pending input"
);
}
#[tokio::test]
async fn steer_input_enforces_expected_turn_id() {
let (sess, tc, _rx) = make_session_and_context_with_rx().await;

View File

@@ -117,6 +117,7 @@ pub(crate) struct TurnState {
pending_dynamic_tools: HashMap<String, oneshot::Sender<DynamicToolResponse>>,
pub(crate) pending_input: TurnInputQueue,
mailbox_delivery_phase: MailboxDeliveryPhase,
terminal: bool,
granted_permissions: Option<AdditionalPermissionProfile>,
strict_auto_review_enabled: bool,
pub(crate) tool_calls: u64,
@@ -220,17 +221,29 @@ impl TurnState {
}
pub(crate) fn accept_mailbox_delivery_for_current_turn(&mut self) {
if self.terminal {
return;
}
self.set_mailbox_delivery_phase(MailboxDeliveryPhase::CurrentTurn);
}
pub(crate) fn accepts_mailbox_delivery_for_current_turn(&self) -> bool {
self.mailbox_delivery_phase == MailboxDeliveryPhase::CurrentTurn
!self.terminal && self.mailbox_delivery_phase == MailboxDeliveryPhase::CurrentTurn
}
pub(crate) fn set_mailbox_delivery_phase(&mut self, phase: MailboxDeliveryPhase) {
self.mailbox_delivery_phase = phase;
}
pub(crate) fn mark_terminal(&mut self) {
self.terminal = true;
self.set_mailbox_delivery_phase(MailboxDeliveryPhase::NextTurn);
}
pub(crate) fn is_terminal(&self) -> bool {
self.terminal
}
pub(crate) fn record_granted_permissions(&mut self, permissions: AdditionalPermissionProfile) {
self.granted_permissions =
merge_permission_profiles(self.granted_permissions.as_ref(), Some(&permissions));

View File

@@ -98,12 +98,9 @@ impl SessionTask for RegularTask {
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
if let Some(turn_state) = turn_state {
turn_state.lock().await.clear_pending_waiters();
drop(
sess.input_queue
.take_pending_input_for_turn_state(turn_state.as_ref())
.await,
);
sess.input_queue
.mark_terminal_and_clear_pending_for_turn_state(turn_state.as_ref())
.await;
}
return None;
}