diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 4dcc5d4455..ed21f299cc 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1968,8 +1968,18 @@ impl Session { let mut active = self.active_turn.lock().await; match active.as_mut() { Some(at) => { + let pending_skills = input + .iter() + .filter_map(|item| match item { + UserInput::Skill { name, path } => Some((name.clone(), path.clone())), + _ => None, + }) + .collect::>(); let mut ts = at.turn_state.lock().await; - ts.push_pending_input(input.into()); + ts.push_pending_input(ResponseInputItem::from(input)); + if !pending_skills.is_empty() { + ts.push_pending_skill_mentions(pending_skills); + } Ok(()) } None => Err(input), @@ -3080,10 +3090,40 @@ pub(crate) async fn run_turn( .map(ResponseItem::from) .collect::>(); + let pending_skill_inputs = { + let mut active = sess.active_turn.lock().await; + if let Some(at) = active.as_mut() { + let mut ts = at.turn_state.lock().await; + ts.take_pending_skill_mentions() + .into_iter() + .map(|(name, path)| UserInput::Skill { name, path }) + .collect::>() + } else { + Vec::new() + } + }; + let SkillInjections { + items: pending_skill_items, + warnings: pending_skill_warnings, + } = build_skill_injections( + &pending_skill_inputs, + skills_outcome.as_ref(), + Some(&otel_manager), + ) + .await; + for message in pending_skill_warnings { + sess.send_event(&turn_context, EventMsg::Warning(WarningEvent { message })) + .await; + } + // Construct the input that we will send to the model. let sampling_request_input: Vec = { sess.record_conversation_items(&turn_context, &pending_input) .await; + if !pending_skill_items.is_empty() { + sess.record_conversation_items(&turn_context, &pending_skill_items) + .await; + } sess.clone_history().await.for_prompt() }; @@ -3672,6 +3712,7 @@ mod tests { use crate::protocol::TokenCountEvent; use crate::protocol::TokenUsage; use crate::protocol::TokenUsageInfo; + use crate::state::ActiveTurn; use crate::state::TaskKind; use crate::tasks::SessionTask; use crate::tasks::SessionTaskContext; @@ -3685,7 +3726,9 @@ mod tests { use crate::turn_diff_tracker::TurnDiffTracker; use codex_app_server_protocol::AuthMode; use codex_protocol::models::ContentItem; + use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; + use codex_protocol::user_input::UserInput; use std::path::Path; use std::time::Duration; use tokio::time::sleep; @@ -4672,6 +4715,40 @@ mod tests { } } + #[tokio::test] + async fn inject_input_preserves_skills_for_pending_turns() { + let (session, _turn_context) = make_session_and_context().await; + { + let mut active = session.active_turn.lock().await; + *active = Some(ActiveTurn::default()); + } + + let skill_name = String::from("test-skill"); + let skill_path = PathBuf::from("/tmp/test-skill/SKILL.md"); + let input = vec![UserInput::Skill { + name: skill_name.clone(), + path: skill_path.clone(), + }]; + let expected_pending = ResponseInputItem::from(input.clone()); + + session + .inject_input(input) + .await + .expect("inject into active turn"); + + let turn_state = { + let mut active = session.active_turn.lock().await; + let at = active.as_mut().expect("active turn present"); + Arc::clone(&at.turn_state) + }; + let mut ts = turn_state.lock().await; + assert_eq!( + vec![(skill_name, skill_path)], + ts.take_pending_skill_mentions() + ); + assert_eq!(vec![expected_pending], ts.take_pending_input()); + } + #[derive(Clone, Copy)] struct NeverEndingTask { kind: TaskKind, diff --git a/codex-rs/core/src/state/turn.rs b/codex-rs/core/src/state/turn.rs index ccc50d066b..9d36f43202 100644 --- a/codex-rs/core/src/state/turn.rs +++ b/codex-rs/core/src/state/turn.rs @@ -2,6 +2,7 @@ use indexmap::IndexMap; use std::collections::HashMap; +use std::path::PathBuf; use std::sync::Arc; use tokio::sync::Mutex; use tokio::sync::Notify; @@ -73,6 +74,7 @@ pub(crate) struct TurnState { pending_user_input: HashMap>, pending_dynamic_tools: HashMap>, pending_input: Vec, + pending_skill_mentions: Vec<(String, PathBuf)>, } impl TurnState { @@ -96,6 +98,7 @@ impl TurnState { self.pending_user_input.clear(); self.pending_dynamic_tools.clear(); self.pending_input.clear(); + self.pending_skill_mentions.clear(); } pub(crate) fn insert_pending_user_input( @@ -142,6 +145,14 @@ impl TurnState { } } + pub(crate) fn push_pending_skill_mentions(&mut self, skills: Vec<(String, PathBuf)>) { + self.pending_skill_mentions.extend(skills); + } + + pub(crate) fn take_pending_skill_mentions(&mut self) -> Vec<(String, PathBuf)> { + std::mem::take(&mut self.pending_skill_mentions) + } + pub(crate) fn has_pending_input(&self) -> bool { !self.pending_input.is_empty() }