diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 4dcc5d4455..15099dc6f2 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -429,6 +429,7 @@ pub(crate) struct Session { features: Features, pending_mcp_server_refresh_config: Mutex>, pub(crate) active_turn: Mutex>, + pending_follow_up: Mutex>, pub(crate) services: SessionServices, next_internal_sub_id: AtomicU64, } @@ -842,6 +843,7 @@ impl Session { features: config.features.clone(), pending_mcp_server_refresh_config: Mutex::new(None), active_turn: Mutex::new(None), + pending_follow_up: Mutex::new(Vec::new()), services, next_internal_sub_id: AtomicU64::new(0), }); @@ -1994,18 +1996,33 @@ impl Session { } } - pub async fn get_pending_input(&self) -> Vec { - let mut active = self.active_turn.lock().await; - match active.as_mut() { - Some(at) => { - let mut ts = at.turn_state.lock().await; - ts.take_pending_input() - } - None => Vec::with_capacity(0), + pub(crate) async fn push_follow_up_items(&self, mut items: Vec) { + if items.is_empty() { + return; } + let mut pending_follow_up = self.pending_follow_up.lock().await; + pending_follow_up.append(&mut items); + } + + pub(crate) async fn take_follow_up_items(&self) -> Vec { + let mut pending_follow_up = self.pending_follow_up.lock().await; + std::mem::take(&mut *pending_follow_up) + } + + pub async fn get_pending_input(&self) -> Vec { + let mut items = self.take_follow_up_items().await; + let mut active = self.active_turn.lock().await; + if let Some(at) = active.as_mut() { + let mut ts = at.turn_state.lock().await; + items.extend(ts.take_pending_input()); + } + items } pub async fn has_pending_input(&self) -> bool { + if !self.pending_follow_up.lock().await.is_empty() { + return true; + } let active = self.active_turn.lock().await; match active.as_ref() { Some(at) => { @@ -3672,6 +3689,8 @@ mod tests { use crate::protocol::TokenCountEvent; use crate::protocol::TokenUsage; use crate::protocol::TokenUsageInfo; + use crate::state::ActiveTurn; + use crate::state::RunningTask; use crate::state::TaskKind; use crate::tasks::SessionTask; use crate::tasks::SessionTaskContext; @@ -3685,10 +3704,15 @@ mod tests { use crate::turn_diff_tracker::TurnDiffTracker; use codex_app_server_protocol::AuthMode; use codex_protocol::models::ContentItem; + use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; + use codex_protocol::user_input::UserInput; use std::path::Path; use std::time::Duration; + use tokio::sync::Notify; use tokio::time::sleep; + use tokio_util::sync::CancellationToken; + use tokio_util::task::AbortOnDropHandle; use mcp_types::ContentBlock; use mcp_types::TextContent; @@ -4480,6 +4504,7 @@ mod tests { features: config.features.clone(), pending_mcp_server_refresh_config: Mutex::new(None), active_turn: Mutex::new(None), + pending_follow_up: Mutex::new(Vec::new()), services, next_internal_sub_id: AtomicU64::new(0), }; @@ -4591,6 +4616,7 @@ mod tests { features: config.features.clone(), pending_mcp_server_refresh_config: Mutex::new(None), active_turn: Mutex::new(None), + pending_follow_up: Mutex::new(Vec::new()), services, next_internal_sub_id: AtomicU64::new(0), }); @@ -4672,6 +4698,46 @@ mod tests { } } + #[tokio::test] + async fn drains_pending_input_on_finish_into_follow_up_buffer() { + let (sess, turn_context, _rx) = make_session_and_context_with_rx().await; + let pending_input = ResponseInputItem::from(vec![UserInput::Text { + text: "follow up".to_string(), + text_elements: Vec::new(), + }]); + let done = Arc::new(Notify::new()); + let cancellation_token = CancellationToken::new(); + let handle = tokio::spawn(async {}); + let running_task = RunningTask { + done, + kind: TaskKind::Regular, + task: Arc::new(NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }), + cancellation_token, + handle: Arc::new(AbortOnDropHandle::new(handle)), + turn_context: Arc::clone(&turn_context), + _timer: None, + }; + let mut turn = ActiveTurn::default(); + turn.add_task(running_task); + { + let mut ts = turn.turn_state.lock().await; + ts.push_pending_input(pending_input.clone()); + } + { + let mut active = sess.active_turn.lock().await; + *active = Some(turn); + } + + sess.on_task_finished(Arc::clone(&turn_context), None).await; + + assert_eq!(true, sess.has_pending_input().await); + assert_eq!(vec![pending_input], sess.get_pending_input().await); + assert_eq!(false, sess.has_pending_input().await); + } + #[derive(Clone, Copy)] struct NeverEndingTask { kind: TaskKind, diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 08d23f7987..0456ac2f58 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -182,15 +182,22 @@ impl Session { last_agent_message: Option, ) { let mut active = self.active_turn.lock().await; - let should_close_processes = if let Some(at) = active.as_mut() + let (should_close_processes, pending_follow_up) = if let Some(at) = active.as_mut() && at.remove_task(&turn_context.sub_id) { + let pending_follow_up = { + let mut ts = at.turn_state.lock().await; + ts.take_pending_input() + }; *active = None; - true + (true, pending_follow_up) } else { - false + (false, Vec::new()) }; drop(active); + if should_close_processes && !pending_follow_up.is_empty() { + self.push_follow_up_items(pending_follow_up).await; + } if should_close_processes { self.close_unified_exec_processes().await; }