From 3bb8e69dd33fee1022825154cacc81fb40278750 Mon Sep 17 00:00:00 2001 From: Matthew Zeng Date: Tue, 27 Jan 2026 19:02:45 -0800 Subject: [PATCH] [skills] Auto install MCP dependencies when running skills with dependency specs. (#9982) Auto install MCP dependencies when running skills with dependency specs. --- .../app-server-protocol/src/protocol/v2.rs | 62 ++- .../app-server/src/codex_message_processor.rs | 16 + codex-rs/cli/src/mcp_cmd.rs | 47 +- codex-rs/core/config.schema.json | 6 + codex-rs/core/src/codex.rs | 125 ++++- codex-rs/core/src/default_client.rs | 16 + codex-rs/core/src/features.rs | 8 + codex-rs/core/src/mcp/auth.rs | 41 ++ codex-rs/core/src/mcp/mod.rs | 4 + codex-rs/core/src/mcp/skill_dependencies.rs | 518 ++++++++++++++++++ codex-rs/core/src/skills/injection.rs | 501 ++++++++++++++++- codex-rs/core/src/skills/loader.rs | 267 ++++++++- codex-rs/core/src/skills/mod.rs | 1 + codex-rs/core/src/skills/model.rs | 16 + codex-rs/core/src/state/session.rs | 14 + codex-rs/protocol/src/protocol.rs | 30 +- codex-rs/tui/src/bottom_pane/mod.rs | 1 + codex-rs/tui/src/chatwidget/skills.rs | 19 + 18 files changed, 1591 insertions(+), 101 deletions(-) create mode 100644 codex-rs/core/src/mcp/skill_dependencies.rs diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index e4b3b90589..0116bd21e9 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -27,10 +27,12 @@ use codex_protocol::protocol::NetworkAccess as CoreNetworkAccess; use codex_protocol::protocol::RateLimitSnapshot as CoreRateLimitSnapshot; use codex_protocol::protocol::RateLimitWindow as CoreRateLimitWindow; use codex_protocol::protocol::SessionSource as CoreSessionSource; +use codex_protocol::protocol::SkillDependencies as CoreSkillDependencies; use codex_protocol::protocol::SkillErrorInfo as CoreSkillErrorInfo; use codex_protocol::protocol::SkillInterface as CoreSkillInterface; use
codex_protocol::protocol::SkillMetadata as CoreSkillMetadata; use codex_protocol::protocol::SkillScope as CoreSkillScope; +use codex_protocol::protocol::SkillToolDependency as CoreSkillToolDependency; use codex_protocol::protocol::SubAgentSource as CoreSubAgentSource; use codex_protocol::protocol::TokenUsage as CoreTokenUsage; use codex_protocol::protocol::TokenUsageInfo as CoreTokenUsageInfo; @@ -1395,11 +1397,14 @@ pub struct SkillMetadata { pub description: String, #[serde(default, skip_serializing_if = "Option::is_none")] #[ts(optional)] - /// Legacy short_description from SKILL.md. Prefer SKILL.toml interface.short_description. + /// Legacy short_description from SKILL.md. Prefer SKILL.json interface.short_description. pub short_description: Option, #[serde(default, skip_serializing_if = "Option::is_none")] #[ts(optional)] pub interface: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub dependencies: Option, pub path: PathBuf, pub scope: SkillScope, pub enabled: bool, @@ -1423,6 +1428,35 @@ pub struct SkillInterface { pub default_prompt: Option, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct SkillDependencies { + pub tools: Vec, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct SkillToolDependency { + #[serde(rename = "type")] + #[ts(rename = "type")] + pub r#type: String, + pub value: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub description: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub transport: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub command: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub url: Option, +} + 
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -1462,6 +1496,7 @@ impl From for SkillMetadata { description: value.description, short_description: value.short_description, interface: value.interface.map(SkillInterface::from), + dependencies: value.dependencies.map(SkillDependencies::from), path: value.path, scope: value.scope.into(), enabled: true, @@ -1482,6 +1517,31 @@ impl From for SkillInterface { } } +impl From for SkillDependencies { + fn from(value: CoreSkillDependencies) -> Self { + Self { + tools: value + .tools + .into_iter() + .map(SkillToolDependency::from) + .collect(), + } + } +} + +impl From for SkillToolDependency { + fn from(value: CoreSkillToolDependency) -> Self { + Self { + r#type: value.r#type, + value: value.value, + description: value.description, + transport: value.transport, + command: value.command, + url: value.url, + } + } +} + impl From for SkillScope { fn from(value: CoreSkillScope) -> Self { match value { diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 589b88d9a2..e7acf3d290 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -4522,6 +4522,22 @@ fn skills_to_info( default_prompt: interface.default_prompt, } }), + dependencies: skill.dependencies.clone().map(|dependencies| { + codex_app_server_protocol::SkillDependencies { + tools: dependencies + .tools + .into_iter() + .map(|tool| codex_app_server_protocol::SkillToolDependency { + r#type: tool.r#type, + value: tool.value, + description: tool.description, + transport: tool.transport, + command: tool.command, + url: tool.url, + }) + .collect(), + } + }), path: skill.path.clone(), scope: skill.scope.into(), enabled, diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 7cc42c4c49..83de37e027 100644 --- a/codex-rs/cli/src/mcp_cmd.rs 
+++ b/codex-rs/cli/src/mcp_cmd.rs @@ -13,11 +13,12 @@ use codex_core::config::find_codex_home; use codex_core::config::load_global_mcp_servers; use codex_core::config::types::McpServerConfig; use codex_core::config::types::McpServerTransportConfig; +use codex_core::mcp::auth::McpOAuthLoginSupport; use codex_core::mcp::auth::compute_auth_statuses; +use codex_core::mcp::auth::oauth_login_support; use codex_core::protocol::McpAuthStatus; use codex_rmcp_client::delete_oauth_tokens; use codex_rmcp_client::perform_oauth_login; -use codex_rmcp_client::supports_oauth_login; /// Subcommands: /// - `list` — list configured servers (with `--json`) @@ -260,33 +261,25 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re println!("Added global MCP server '{name}'."); - if let McpServerTransportConfig::StreamableHttp { - url, - bearer_token_env_var: None, - http_headers, - env_http_headers, - } = transport - { - match supports_oauth_login(&url).await { - Ok(true) => { - println!("Detected OAuth support. Starting OAuth flow…"); - perform_oauth_login( - &name, - &url, - config.mcp_oauth_credentials_store_mode, - http_headers.clone(), - env_http_headers.clone(), - &Vec::new(), - config.mcp_oauth_callback_port, - ) - .await?; - println!("Successfully logged in."); - } - Ok(false) => {} - Err(_) => println!( - "MCP server may or may not require login. Run `codex mcp login {name}` to login." - ), + match oauth_login_support(&transport).await { + McpOAuthLoginSupport::Supported(oauth_config) => { + println!("Detected OAuth support. Starting OAuth flow…"); + perform_oauth_login( + &name, + &oauth_config.url, + config.mcp_oauth_credentials_store_mode, + oauth_config.http_headers, + oauth_config.env_http_headers, + &Vec::new(), + config.mcp_oauth_callback_port, + ) + .await?; + println!("Successfully logged in."); } + McpOAuthLoginSupport::Unsupported => {} + McpOAuthLoginSupport::Unknown(_) => println!( + "MCP server may or may not require login. 
Run `codex mcp login {name}` to login." + ), } Ok(()) diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 1e164b24cb..1fe5f9e75d 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -198,6 +198,9 @@ "shell_tool": { "type": "boolean" }, + "skill_mcp_dependency_install": { + "type": "boolean" + }, "steer": { "type": "boolean" }, @@ -1190,6 +1193,9 @@ "shell_tool": { "type": "boolean" }, + "skill_mcp_dependency_install": { + "type": "boolean" + }, "steer": { "type": "boolean" }, diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index c2c25967d2..d4f2822a73 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -115,6 +115,7 @@ use crate::instructions::UserInstructions; use crate::mcp::CODEX_APPS_MCP_SERVER_NAME; use crate::mcp::auth::compute_auth_statuses; use crate::mcp::effective_mcp_servers; +use crate::mcp::maybe_prompt_and_install_mcp_dependencies; use crate::mcp::with_codex_apps_mcp; use crate::mcp_connection_manager::McpConnectionManager; use crate::model_provider_info::CHAT_WIRE_API_DEPRECATION_SUMMARY; @@ -138,9 +139,11 @@ use crate::protocol::RequestUserInputEvent; use crate::protocol::ReviewDecision; use crate::protocol::SandboxPolicy; use crate::protocol::SessionConfiguredEvent; +use crate::protocol::SkillDependencies as ProtocolSkillDependencies; use crate::protocol::SkillErrorInfo; use crate::protocol::SkillInterface as ProtocolSkillInterface; use crate::protocol::SkillMetadata as ProtocolSkillMetadata; +use crate::protocol::SkillToolDependency as ProtocolSkillToolDependency; use crate::protocol::StreamErrorEvent; use crate::protocol::Submission; use crate::protocol::TokenCountEvent; @@ -158,6 +161,7 @@ use crate::skills::SkillInjections; use crate::skills::SkillMetadata; use crate::skills::SkillsManager; use crate::skills::build_skill_injections; +use crate::skills::collect_explicit_skill_mentions; use crate::state::ActiveTurn; use 
crate::state::SessionServices; use crate::state::SessionState; @@ -1857,6 +1861,19 @@ impl Session { self.send_token_count_event(turn_context).await; } + pub(crate) async fn mcp_dependency_prompted(&self) -> HashSet { + let state = self.state.lock().await; + state.mcp_dependency_prompted() + } + + pub(crate) async fn record_mcp_dependency_prompted(&self, names: I) + where + I: IntoIterator, + { + let mut state = self.state.lock().await; + state.record_mcp_dependency_prompted(names); + } + pub(crate) async fn set_server_reasoning_included(&self, included: bool) { let mut state = self.state.lock().await; state.set_server_reasoning_included(included); @@ -2101,35 +2118,12 @@ impl Session { Arc::clone(&self.services.user_shell) } - async fn refresh_mcp_servers_if_requested(&self, turn_context: &TurnContext) { - let refresh_config = { self.pending_mcp_server_refresh_config.lock().await.take() }; - let Some(refresh_config) = refresh_config else { - return; - }; - - let McpServerRefreshConfig { - mcp_servers, - mcp_oauth_credentials_store_mode, - } = refresh_config; - - let mcp_servers = - match serde_json::from_value::>(mcp_servers) { - Ok(servers) => servers, - Err(err) => { - warn!("failed to parse MCP server refresh config: {err}"); - return; - } - }; - let store_mode = match serde_json::from_value::( - mcp_oauth_credentials_store_mode, - ) { - Ok(mode) => mode, - Err(err) => { - warn!("failed to parse MCP OAuth refresh config: {err}"); - return; - } - }; - + async fn refresh_mcp_servers_inner( + &self, + turn_context: &TurnContext, + mcp_servers: HashMap, + store_mode: OAuthCredentialsStoreMode, + ) { let auth = self.services.auth_manager.auth().await; let config = self.get_config().await; let mcp_servers = with_codex_apps_mcp( @@ -2162,6 +2156,49 @@ impl Session { *manager = refreshed_manager; } + async fn refresh_mcp_servers_if_requested(&self, turn_context: &TurnContext) { + let refresh_config = { self.pending_mcp_server_refresh_config.lock().await.take() }; + let 
Some(refresh_config) = refresh_config else { + return; + }; + + let McpServerRefreshConfig { + mcp_servers, + mcp_oauth_credentials_store_mode, + } = refresh_config; + + let mcp_servers = + match serde_json::from_value::>(mcp_servers) { + Ok(servers) => servers, + Err(err) => { + warn!("failed to parse MCP server refresh config: {err}"); + return; + } + }; + let store_mode = match serde_json::from_value::( + mcp_oauth_credentials_store_mode, + ) { + Ok(mode) => mode, + Err(err) => { + warn!("failed to parse MCP OAuth refresh config: {err}"); + return; + } + }; + + self.refresh_mcp_servers_inner(turn_context, mcp_servers, store_mode) + .await; + } + + pub(crate) async fn refresh_mcp_servers_now( + &self, + turn_context: &TurnContext, + mcp_servers: HashMap, + store_mode: OAuthCredentialsStoreMode, + ) { + self.refresh_mcp_servers_inner(turn_context, mcp_servers, store_mode) + .await; + } + async fn mcp_startup_cancellation_token(&self) -> CancellationToken { self.services .mcp_startup_cancellation_token @@ -2985,6 +3022,22 @@ fn skills_to_info( brand_color: interface.brand_color, default_prompt: interface.default_prompt, }), + dependencies: skill.dependencies.clone().map(|dependencies| { + ProtocolSkillDependencies { + tools: dependencies + .tools + .into_iter() + .map(|tool| ProtocolSkillToolDependency { + r#type: tool.r#type, + value: tool.value, + description: tool.description, + transport: tool.transport, + command: tool.command, + url: tool.url, + }) + .collect(), + } + }), path: skill.path.clone(), scope: skill.scope, enabled: !disabled_paths.contains(&skill.path), @@ -3044,11 +3097,23 @@ pub(crate) async fn run_turn( .await, ); + let mentioned_skills = skills_outcome.as_ref().map_or_else(Vec::new, |outcome| { + collect_explicit_skill_mentions(&input, &outcome.skills, &outcome.disabled_paths) + }); + + maybe_prompt_and_install_mcp_dependencies( + sess.as_ref(), + turn_context.as_ref(), + &cancellation_token, + &mentioned_skills, + ) + .await; + let 
otel_manager = turn_context.client.get_otel_manager(); let SkillInjections { items: skill_items, warnings: skill_warnings, - } = build_skill_injections(&input, skills_outcome.as_ref(), Some(&otel_manager)).await; + } = build_skill_injections(&mentioned_skills, Some(&otel_manager)).await; for message in skill_warnings { sess.send_event(&turn_context, EventMsg::Warning(WarningEvent { message })) diff --git a/codex-rs/core/src/default_client.rs b/codex-rs/core/src/default_client.rs index 4ded10a3d9..67a5daf1b1 100644 --- a/codex-rs/core/src/default_client.rs +++ b/codex-rs/core/src/default_client.rs @@ -95,6 +95,12 @@ pub fn originator() -> Originator { get_originator_value(None) } +pub fn is_first_party_originator(originator_value: &str) -> bool { + originator_value == DEFAULT_ORIGINATOR + || originator_value == "codex_vscode" + || originator_value.starts_with("Codex ") +} + pub fn get_codex_user_agent() -> String { let build_version = env!("CARGO_PKG_VERSION"); let os_info = os_info::get(); @@ -185,6 +191,7 @@ fn is_sandboxed() -> bool { mod tests { use super::*; use core_test_support::skip_if_no_network; + use pretty_assertions::assert_eq; #[test] fn test_get_codex_user_agent() { @@ -194,6 +201,15 @@ mod tests { assert!(user_agent.starts_with(&prefix)); } + #[test] + fn is_first_party_originator_matches_known_values() { + assert_eq!(is_first_party_originator(DEFAULT_ORIGINATOR), true); + assert_eq!(is_first_party_originator("codex_vscode"), true); + assert_eq!(is_first_party_originator("Codex Something Else"), true); + assert_eq!(is_first_party_originator("codex_cli"), false); + assert_eq!(is_first_party_originator("Other"), false); + } + #[tokio::test] async fn test_create_client_sets_default_headers() { skip_if_no_network!(); diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs index 7fde52687c..cc56970276 100644 --- a/codex-rs/core/src/features.rs +++ b/codex-rs/core/src/features.rs @@ -109,6 +109,8 @@ pub enum Feature { Collab, /// Enable 
connectors (apps). Connectors, + /// Allow prompting and installing missing MCP dependencies. + SkillMcpDependencyInstall, /// Steer feature flag - when enabled, Enter submits immediately instead of queuing. Steer, /// Enable collaboration modes (Plan, Code, Pair Programming, Execute). @@ -449,6 +451,12 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::UnderDevelopment, default_enabled: false, }, + FeatureSpec { + id: Feature::SkillMcpDependencyInstall, + key: "skill_mcp_dependency_install", + stage: Stage::Stable, + default_enabled: true, + }, FeatureSpec { id: Feature::Steer, key: "steer", diff --git a/codex-rs/core/src/mcp/auth.rs b/codex-rs/core/src/mcp/auth.rs index e321a857bb..f095c930dc 100644 --- a/codex-rs/core/src/mcp/auth.rs +++ b/codex-rs/core/src/mcp/auth.rs @@ -4,12 +4,53 @@ use anyhow::Result; use codex_protocol::protocol::McpAuthStatus; use codex_rmcp_client::OAuthCredentialsStoreMode; use codex_rmcp_client::determine_streamable_http_auth_status; +use codex_rmcp_client::supports_oauth_login; use futures::future::join_all; use tracing::warn; use crate::config::types::McpServerConfig; use crate::config::types::McpServerTransportConfig; +#[derive(Debug, Clone)] +pub struct McpOAuthLoginConfig { + pub url: String, + pub http_headers: Option>, + pub env_http_headers: Option>, +} + +#[derive(Debug)] +pub enum McpOAuthLoginSupport { + Supported(McpOAuthLoginConfig), + Unsupported, + Unknown(anyhow::Error), +} + +pub async fn oauth_login_support(transport: &McpServerTransportConfig) -> McpOAuthLoginSupport { + let McpServerTransportConfig::StreamableHttp { + url, + bearer_token_env_var, + http_headers, + env_http_headers, + } = transport + else { + return McpOAuthLoginSupport::Unsupported; + }; + + if bearer_token_env_var.is_some() { + return McpOAuthLoginSupport::Unsupported; + } + + match supports_oauth_login(url).await { + Ok(true) => McpOAuthLoginSupport::Supported(McpOAuthLoginConfig { + url: url.clone(), + http_headers: http_headers.clone(), + 
env_http_headers: env_http_headers.clone(), + }), + Ok(false) => McpOAuthLoginSupport::Unsupported, + Err(err) => McpOAuthLoginSupport::Unknown(err), + } +} + #[derive(Debug, Clone)] pub struct McpAuthStatusEntry { pub config: McpServerConfig, diff --git a/codex-rs/core/src/mcp/mod.rs b/codex-rs/core/src/mcp/mod.rs index 1bb08a75be..fa1de0f045 100644 --- a/codex-rs/core/src/mcp/mod.rs +++ b/codex-rs/core/src/mcp/mod.rs @@ -1,4 +1,8 @@ pub mod auth; +mod skill_dependencies; + +pub(crate) use skill_dependencies::maybe_prompt_and_install_mcp_dependencies; + use std::collections::HashMap; use std::env; use std::path::PathBuf; diff --git a/codex-rs/core/src/mcp/skill_dependencies.rs b/codex-rs/core/src/mcp/skill_dependencies.rs new file mode 100644 index 0000000000..c295a96863 --- /dev/null +++ b/codex-rs/core/src/mcp/skill_dependencies.rs @@ -0,0 +1,518 @@ +use std::collections::HashMap; +use std::collections::HashSet; + +use codex_protocol::protocol::AskForApproval; +use codex_protocol::protocol::SandboxPolicy; +use codex_protocol::request_user_input::RequestUserInputArgs; +use codex_protocol::request_user_input::RequestUserInputQuestion; +use codex_protocol::request_user_input::RequestUserInputQuestionOption; +use codex_protocol::request_user_input::RequestUserInputResponse; +use codex_rmcp_client::perform_oauth_login; +use tokio_util::sync::CancellationToken; +use tracing::warn; + +use super::auth::McpOAuthLoginSupport; +use super::auth::oauth_login_support; +use super::effective_mcp_servers; +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::config::Config; +use crate::config::edit::ConfigEditsBuilder; +use crate::config::load_global_mcp_servers; +use crate::config::types::McpServerConfig; +use crate::config::types::McpServerTransportConfig; +use crate::default_client::is_first_party_originator; +use crate::default_client::originator; +use crate::features::Feature; +use crate::skills::SkillMetadata; +use 
crate::skills::model::SkillToolDependency; + +const SKILL_MCP_DEPENDENCY_PROMPT_ID: &str = "skill_mcp_dependency_install"; +const MCP_DEPENDENCY_OPTION_INSTALL: &str = "Install"; +const MCP_DEPENDENCY_OPTION_SKIP: &str = "Continue anyway"; + +fn is_full_access_mode(turn_context: &TurnContext) -> bool { + matches!(turn_context.approval_policy, AskForApproval::Never) + && matches!( + turn_context.sandbox_policy, + SandboxPolicy::DangerFullAccess | SandboxPolicy::ExternalSandbox { .. } + ) +} + +fn format_missing_mcp_dependencies(missing: &HashMap) -> String { + let mut names = missing.keys().cloned().collect::>(); + names.sort(); + names.join(", ") +} + +async fn filter_prompted_mcp_dependencies( + sess: &Session, + missing: &HashMap, +) -> HashMap { + let prompted = sess.mcp_dependency_prompted().await; + if prompted.is_empty() { + return missing.clone(); + } + + missing + .iter() + .filter(|(name, config)| !prompted.contains(&canonical_mcp_server_key(name, config))) + .map(|(name, config)| (name.clone(), config.clone())) + .collect() +} + +async fn should_install_mcp_dependencies( + sess: &Session, + turn_context: &TurnContext, + missing: &HashMap, + cancellation_token: &CancellationToken, +) -> bool { + if is_full_access_mode(turn_context) { + return true; + } + + let server_list = format_missing_mcp_dependencies(missing); + let question = RequestUserInputQuestion { + id: SKILL_MCP_DEPENDENCY_PROMPT_ID.to_string(), + header: "Install MCP servers?".to_string(), + question: format!( + "The following MCP servers are required by the selected skills but are not installed yet: {server_list}. Install them now?" + ), + is_other: false, + options: Some(vec![ + RequestUserInputQuestionOption { + label: MCP_DEPENDENCY_OPTION_INSTALL.to_string(), + description: + "Install and enable the missing MCP servers in your global config." 
+ .to_string(), + }, + RequestUserInputQuestionOption { + label: MCP_DEPENDENCY_OPTION_SKIP.to_string(), + description: "Skip installation for now and do not show again for these MCP servers in this session." + .to_string(), + }, + ]), + }; + let args = RequestUserInputArgs { + questions: vec![question], + }; + let sub_id = &turn_context.sub_id; + let call_id = format!("mcp-deps-{sub_id}"); + let response_fut = sess.request_user_input(turn_context, call_id, args); + let response = tokio::select! { + biased; + _ = cancellation_token.cancelled() => { + let empty = RequestUserInputResponse { + answers: HashMap::new(), + }; + sess.notify_user_input_response(sub_id, empty.clone()).await; + empty + } + response = response_fut => response.unwrap_or_else(|| RequestUserInputResponse { + answers: HashMap::new(), + }), + }; + + let install = response + .answers + .get(SKILL_MCP_DEPENDENCY_PROMPT_ID) + .is_some_and(|answer| { + answer + .answers + .iter() + .any(|entry| entry == MCP_DEPENDENCY_OPTION_INSTALL) + }); + + let prompted_keys = missing + .iter() + .map(|(name, config)| canonical_mcp_server_key(name, config)); + sess.record_mcp_dependency_prompted(prompted_keys).await; + + install +} + +pub(crate) async fn maybe_prompt_and_install_mcp_dependencies( + sess: &Session, + turn_context: &TurnContext, + cancellation_token: &CancellationToken, + mentioned_skills: &[SkillMetadata], +) { + let originator_value = originator().value; + if !is_first_party_originator(originator_value.as_str()) { + // Only support first-party clients for now. 
+ return; + } + + let config = turn_context.client.config(); + if mentioned_skills.is_empty() || !config.features.enabled(Feature::SkillMcpDependencyInstall) { + return; + } + + let installed = config.mcp_servers.get().clone(); + let missing = collect_missing_mcp_dependencies(mentioned_skills, &installed); + if missing.is_empty() { + return; + } + + let unprompted_missing = filter_prompted_mcp_dependencies(sess, &missing).await; + if unprompted_missing.is_empty() { + return; + } + + if should_install_mcp_dependencies(sess, turn_context, &unprompted_missing, cancellation_token) + .await + { + maybe_install_mcp_dependencies(sess, turn_context, config.as_ref(), mentioned_skills).await; + } +} + +pub(crate) async fn maybe_install_mcp_dependencies( + sess: &Session, + turn_context: &TurnContext, + config: &Config, + mentioned_skills: &[SkillMetadata], +) { + if mentioned_skills.is_empty() || !config.features.enabled(Feature::SkillMcpDependencyInstall) { + return; + } + + let codex_home = config.codex_home.clone(); + let installed = config.mcp_servers.get().clone(); + let missing = collect_missing_mcp_dependencies(mentioned_skills, &installed); + if missing.is_empty() { + return; + } + + let mut servers = match load_global_mcp_servers(&codex_home).await { + Ok(servers) => servers, + Err(err) => { + warn!("failed to load MCP servers while installing skill dependencies: {err}"); + return; + } + }; + + let mut updated = false; + let mut added = Vec::new(); + for (name, config) in missing { + if servers.contains_key(&name) { + continue; + } + servers.insert(name.clone(), config.clone()); + added.push((name, config)); + updated = true; + } + + if !updated { + return; + } + + if let Err(err) = ConfigEditsBuilder::new(&codex_home) + .replace_mcp_servers(&servers) + .apply() + .await + { + warn!("failed to persist MCP dependencies for mentioned skills: {err}"); + return; + } + + for (name, server_config) in added { + let oauth_config = match 
oauth_login_support(&server_config.transport).await { + McpOAuthLoginSupport::Supported(config) => config, + McpOAuthLoginSupport::Unsupported => continue, + McpOAuthLoginSupport::Unknown(err) => { + warn!("MCP server may or may not require login for dependency {name}: {err}"); + continue; + } + }; + + sess.notify_background_event( + turn_context, + format!( + "Authenticating MCP {name}... Follow instructions in your browser if prompted." + ), + ) + .await; + + if let Err(err) = perform_oauth_login( + &name, + &oauth_config.url, + config.mcp_oauth_credentials_store_mode, + oauth_config.http_headers, + oauth_config.env_http_headers, + &[], + config.mcp_oauth_callback_port, + ) + .await + { + warn!("failed to login to MCP dependency {name}: {err}"); + } + } + + // Refresh from the effective merged MCP map (global + repo + managed) and + // overlay the updated global servers so we don't drop repo-scoped servers. + let auth = sess.services.auth_manager.auth().await; + let mut refresh_servers = effective_mcp_servers(config, auth.as_ref()); + for (name, server_config) in &servers { + refresh_servers + .entry(name.clone()) + .or_insert_with(|| server_config.clone()); + } + sess.refresh_mcp_servers_now( + turn_context, + refresh_servers, + config.mcp_oauth_credentials_store_mode, + ) + .await; +} + +fn canonical_mcp_key(transport: &str, identifier: &str, fallback: &str) -> String { + let identifier = identifier.trim(); + if identifier.is_empty() { + fallback.to_string() + } else { + format!("mcp__{transport}__{identifier}") + } +} + +fn canonical_mcp_server_key(name: &str, config: &McpServerConfig) -> String { + match &config.transport { + McpServerTransportConfig::Stdio { command, .. } => { + canonical_mcp_key("stdio", command, name) + } + McpServerTransportConfig::StreamableHttp { url, .. 
} => { + canonical_mcp_key("streamable_http", url, name) + } + } +} + +fn canonical_mcp_dependency_key(dependency: &SkillToolDependency) -> Result { + let transport = dependency.transport.as_deref().unwrap_or("streamable_http"); + if transport.eq_ignore_ascii_case("streamable_http") { + let url = dependency + .url + .as_ref() + .ok_or_else(|| "missing url for streamable_http dependency".to_string())?; + return Ok(canonical_mcp_key("streamable_http", url, &dependency.value)); + } + if transport.eq_ignore_ascii_case("stdio") { + let command = dependency + .command + .as_ref() + .ok_or_else(|| "missing command for stdio dependency".to_string())?; + return Ok(canonical_mcp_key("stdio", command, &dependency.value)); + } + Err(format!("unsupported transport {transport}")) +} + +pub(crate) fn collect_missing_mcp_dependencies( + mentioned_skills: &[SkillMetadata], + installed: &HashMap, +) -> HashMap { + let mut missing = HashMap::new(); + let installed_keys: HashSet = installed + .iter() + .map(|(name, config)| canonical_mcp_server_key(name, config)) + .collect(); + let mut seen_canonical_keys = HashSet::new(); + + for skill in mentioned_skills { + let Some(dependencies) = skill.dependencies.as_ref() else { + continue; + }; + + for tool in &dependencies.tools { + if !tool.r#type.eq_ignore_ascii_case("mcp") { + continue; + } + let dependency_key = match canonical_mcp_dependency_key(tool) { + Ok(key) => key, + Err(err) => { + let dependency = tool.value.as_str(); + let skill_name = skill.name.as_str(); + warn!( + "unable to auto-install MCP dependency {dependency} for skill {skill_name}: {err}", + ); + continue; + } + }; + if installed_keys.contains(&dependency_key) + || seen_canonical_keys.contains(&dependency_key) + { + continue; + } + + let config = match mcp_dependency_to_server_config(tool) { + Ok(config) => config, + Err(err) => { + let dependency = dependency_key.as_str(); + let skill_name = skill.name.as_str(); + warn!( + "unable to auto-install MCP dependency 
{dependency} for skill {skill_name}: {err}", + ); + continue; + } + }; + + missing.insert(tool.value.clone(), config); + seen_canonical_keys.insert(dependency_key); + } + } + + missing +} + +fn mcp_dependency_to_server_config( + dependency: &SkillToolDependency, +) -> Result { + let transport = dependency.transport.as_deref().unwrap_or("streamable_http"); + if transport.eq_ignore_ascii_case("streamable_http") { + let url = dependency + .url + .as_ref() + .ok_or_else(|| "missing url for streamable_http dependency".to_string())?; + return Ok(McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: url.clone(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + }); + } + + if transport.eq_ignore_ascii_case("stdio") { + let command = dependency + .command + .as_ref() + .ok_or_else(|| "missing command for stdio dependency".to_string())?; + return Ok(McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: command.clone(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: None, + }, + enabled: true, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + }); + } + + Err(format!("unsupported transport {transport}")) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::skills::model::SkillDependencies; + use codex_protocol::protocol::SkillScope; + use pretty_assertions::assert_eq; + use std::path::PathBuf; + + fn skill_with_tools(tools: Vec) -> SkillMetadata { + SkillMetadata { + name: "skill".to_string(), + description: "skill".to_string(), + short_description: None, + interface: None, + dependencies: Some(SkillDependencies { tools }), + path: PathBuf::from("skill"), + scope: SkillScope::User, + } + } + + #[test] + fn 
collect_missing_respects_canonical_installed_key() { + let url = "https://example.com/mcp".to_string(); + let skills = vec![skill_with_tools(vec![SkillToolDependency { + r#type: "mcp".to_string(), + value: "github".to_string(), + description: None, + transport: Some("streamable_http".to_string()), + command: None, + url: Some(url.clone()), + }])]; + let installed = HashMap::from([( + "alias".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url, + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + }, + )]); + + assert_eq!( + collect_missing_mcp_dependencies(&skills, &installed), + HashMap::new() + ); + } + + #[test] + fn collect_missing_dedupes_by_canonical_key_but_preserves_original_name() { + let url = "https://example.com/one".to_string(); + let skills = vec![skill_with_tools(vec![ + SkillToolDependency { + r#type: "mcp".to_string(), + value: "alias-one".to_string(), + description: None, + transport: Some("streamable_http".to_string()), + command: None, + url: Some(url.clone()), + }, + SkillToolDependency { + r#type: "mcp".to_string(), + value: "alias-two".to_string(), + description: None, + transport: Some("streamable_http".to_string()), + command: None, + url: Some(url.clone()), + }, + ])]; + + let expected = HashMap::from([( + "alias-one".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url, + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + }, + )]); + + assert_eq!( + collect_missing_mcp_dependencies(&skills, &HashMap::new()), + expected + ); + } +} diff --git 
a/codex-rs/core/src/skills/injection.rs b/codex-rs/core/src/skills/injection.rs index 9aa12d775c..65ec4dd510 100644 --- a/codex-rs/core/src/skills/injection.rs +++ b/codex-rs/core/src/skills/injection.rs @@ -2,7 +2,6 @@ use std::collections::HashSet; use std::path::PathBuf; use crate::instructions::SkillInstructions; -use crate::skills::SkillLoadOutcome; use crate::skills::SkillMetadata; use codex_otel::OtelManager; use codex_protocol::models::ResponseItem; @@ -16,20 +15,9 @@ pub(crate) struct SkillInjections { } pub(crate) async fn build_skill_injections( - inputs: &[UserInput], - skills: Option<&SkillLoadOutcome>, + mentioned_skills: &[SkillMetadata], otel: Option<&OtelManager>, ) -> SkillInjections { - if inputs.is_empty() { - return SkillInjections::default(); - } - - let Some(outcome) = skills else { - return SkillInjections::default(); - }; - - let mentioned_skills = - collect_explicit_skill_mentions(inputs, &outcome.skills, &outcome.disabled_paths); if mentioned_skills.is_empty() { return SkillInjections::default(); } @@ -42,15 +30,15 @@ pub(crate) async fn build_skill_injections( for skill in mentioned_skills { match fs::read_to_string(&skill.path).await { Ok(contents) => { - emit_skill_injected_metric(otel, &skill, "ok"); + emit_skill_injected_metric(otel, skill, "ok"); result.items.push(ResponseItem::from(SkillInstructions { - name: skill.name, + name: skill.name.clone(), path: skill.path.to_string_lossy().into_owned(), contents, })); } Err(err) => { - emit_skill_injected_metric(otel, &skill, "error"); + emit_skill_injected_metric(otel, skill, "error"); let message = format!( "Failed to load skill {name} at {path}: {err:#}", name = skill.name, @@ -76,23 +64,488 @@ fn emit_skill_injected_metric(otel: Option<&OtelManager>, skill: &SkillMetadata, ); } -fn collect_explicit_skill_mentions( +/// Collect explicitly mentioned skills from `$name` text mentions. 
+/// +/// Text inputs are scanned once to extract `$skill-name` tokens, then we iterate `skills` +/// in their existing order to preserve prior ordering semantics. +/// +/// Complexity: `O(S + T + N_t * S)` time, `O(S)` space, where: +/// `S` = number of skills, `T` = total text length, `N_t` = number of text inputs. +pub(crate) fn collect_explicit_skill_mentions( inputs: &[UserInput], skills: &[SkillMetadata], disabled_paths: &HashSet, ) -> Vec { let mut selected: Vec = Vec::new(); - let mut seen: HashSet = HashSet::new(); + let mut seen_names: HashSet = HashSet::new(); + let mut seen_paths: HashSet = HashSet::new(); for input in inputs { - if let UserInput::Skill { name, path } = input - && seen.insert(name.clone()) - && let Some(skill) = skills.iter().find(|s| s.name == *name && s.path == *path) - && !disabled_paths.contains(&skill.path) - { - selected.push(skill.clone()); + if let UserInput::Text { text, .. } = input { + let mentioned_names = extract_skill_mentions(text); + select_skills_from_mentions( + skills, + disabled_paths, + &mentioned_names, + &mut seen_names, + &mut seen_paths, + &mut selected, + ); } } selected } + +struct SkillMentions<'a> { + names: HashSet<&'a str>, + paths: HashSet<&'a str>, +} + +impl<'a> SkillMentions<'a> { + fn is_empty(&self) -> bool { + self.names.is_empty() && self.paths.is_empty() + } +} + +/// Extract `$skill-name` mentions from a single text input. +/// +/// Supports explicit resource links in the form `[$skill-name](resource path)`. When a +/// resource path is present, it is captured for exact path matching while also tracking +/// the name for fallback matching. 
+fn extract_skill_mentions(text: &str) -> SkillMentions<'_> { + let text_bytes = text.as_bytes(); + let mut mentioned_names: HashSet<&str> = HashSet::new(); + let mut mentioned_paths: HashSet<&str> = HashSet::new(); + + let mut index = 0; + while index < text_bytes.len() { + let byte = text_bytes[index]; + if byte == b'[' + && let Some((name, path, end_index)) = + parse_linked_skill_mention(text, text_bytes, index) + { + if !is_common_env_var(name) { + mentioned_names.insert(name); + mentioned_paths.insert(path); + } + index = end_index; + continue; + } + + if byte != b'$' { + index += 1; + continue; + } + + let name_start = index + 1; + let Some(first_name_byte) = text_bytes.get(name_start) else { + index += 1; + continue; + }; + if !is_skill_name_char(*first_name_byte) { + index += 1; + continue; + } + + let mut name_end = name_start + 1; + while let Some(next_byte) = text_bytes.get(name_end) + && is_skill_name_char(*next_byte) + { + name_end += 1; + } + + let name = &text[name_start..name_end]; + if !is_common_env_var(name) { + mentioned_names.insert(name); + } + index = name_end; + } + + SkillMentions { + names: mentioned_names, + paths: mentioned_paths, + } +} + +/// Select mentioned skills while preserving the order of `skills`. 
+fn select_skills_from_mentions( + skills: &[SkillMetadata], + disabled_paths: &HashSet, + mentions: &SkillMentions<'_>, + seen_names: &mut HashSet, + seen_paths: &mut HashSet, + selected: &mut Vec, +) { + if mentions.is_empty() { + return; + } + + for skill in skills { + if disabled_paths.contains(&skill.path) || seen_paths.contains(&skill.path) { + continue; + } + + let path_str = skill.path.to_string_lossy(); + if mentions.paths.contains(path_str.as_ref()) { + seen_paths.insert(skill.path.clone()); + seen_names.insert(skill.name.clone()); + selected.push(skill.clone()); + } + } + + for skill in skills { + if disabled_paths.contains(&skill.path) || seen_paths.contains(&skill.path) { + continue; + } + + if mentions.names.contains(skill.name.as_str()) && seen_names.insert(skill.name.clone()) { + seen_paths.insert(skill.path.clone()); + selected.push(skill.clone()); + } + } +} + +fn parse_linked_skill_mention<'a>( + text: &'a str, + text_bytes: &[u8], + start: usize, +) -> Option<(&'a str, &'a str, usize)> { + let dollar_index = start + 1; + if text_bytes.get(dollar_index) != Some(&b'$') { + return None; + } + + let name_start = dollar_index + 1; + let first_name_byte = text_bytes.get(name_start)?; + if !is_skill_name_char(*first_name_byte) { + return None; + } + + let mut name_end = name_start + 1; + while let Some(next_byte) = text_bytes.get(name_end) + && is_skill_name_char(*next_byte) + { + name_end += 1; + } + + if text_bytes.get(name_end) != Some(&b']') { + return None; + } + + let mut path_start = name_end + 1; + while let Some(next_byte) = text_bytes.get(path_start) + && next_byte.is_ascii_whitespace() + { + path_start += 1; + } + if text_bytes.get(path_start) != Some(&b'(') { + return None; + } + + let mut path_end = path_start + 1; + while let Some(next_byte) = text_bytes.get(path_end) + && *next_byte != b')' + { + path_end += 1; + } + if text_bytes.get(path_end) != Some(&b')') { + return None; + } + + let path = text[path_start + 1..path_end].trim(); + if 
path.is_empty() { + return None; + } + + let name = &text[name_start..name_end]; + Some((name, path, path_end + 1)) +} + +fn is_common_env_var(name: &str) -> bool { + let upper = name.to_ascii_uppercase(); + matches!( + upper.as_str(), + "PATH" + | "HOME" + | "USER" + | "SHELL" + | "PWD" + | "TMPDIR" + | "TEMP" + | "TMP" + | "LANG" + | "TERM" + | "XDG_CONFIG_HOME" + ) +} + +#[cfg(test)] +fn text_mentions_skill(text: &str, skill_name: &str) -> bool { + if skill_name.is_empty() { + return false; + } + + let text_bytes = text.as_bytes(); + let skill_bytes = skill_name.as_bytes(); + + for (index, byte) in text_bytes.iter().copied().enumerate() { + if byte != b'$' { + continue; + } + + let name_start = index + 1; + let Some(rest) = text_bytes.get(name_start..) else { + continue; + }; + if !rest.starts_with(skill_bytes) { + continue; + } + + let after_index = name_start + skill_bytes.len(); + let after = text_bytes.get(after_index).copied(); + if after.is_none_or(|b| !is_skill_name_char(b)) { + return true; + } + } + + false +} + +fn is_skill_name_char(byte: u8) -> bool { + matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-') +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use std::collections::HashSet; + + fn make_skill(name: &str, path: &str) -> SkillMetadata { + SkillMetadata { + name: name.to_string(), + description: format!("{name} skill"), + short_description: None, + interface: None, + dependencies: None, + path: PathBuf::from(path), + scope: codex_protocol::protocol::SkillScope::User, + } + } + + fn set<'a>(items: &'a [&'a str]) -> HashSet<&'a str> { + items.iter().copied().collect() + } + + fn assert_mentions(text: &str, expected_names: &[&str], expected_paths: &[&str]) { + let mentions = extract_skill_mentions(text); + assert_eq!(mentions.names, set(expected_names)); + assert_eq!(mentions.paths, set(expected_paths)); + } + + #[test] + fn text_mentions_skill_requires_exact_boundary() { + assert_eq!( + true, 
+ text_mentions_skill("use $notion-research-doc please", "notion-research-doc") + ); + assert_eq!( + true, + text_mentions_skill("($notion-research-doc)", "notion-research-doc") + ); + assert_eq!( + true, + text_mentions_skill("$notion-research-doc.", "notion-research-doc") + ); + assert_eq!( + false, + text_mentions_skill("$notion-research-docs", "notion-research-doc") + ); + assert_eq!( + false, + text_mentions_skill("$notion-research-doc_extra", "notion-research-doc") + ); + } + + #[test] + fn text_mentions_skill_handles_end_boundary_and_near_misses() { + assert_eq!(true, text_mentions_skill("$alpha-skill", "alpha-skill")); + assert_eq!(false, text_mentions_skill("$alpha-skillx", "alpha-skill")); + assert_eq!( + true, + text_mentions_skill("$alpha-skillx and later $alpha-skill ", "alpha-skill") + ); + } + + #[test] + fn text_mentions_skill_handles_many_dollars_without_looping() { + let prefix = "$".repeat(256); + let text = format!("{prefix} not-a-mention"); + assert_eq!(false, text_mentions_skill(&text, "alpha-skill")); + } + + #[test] + fn extract_skill_mentions_handles_plain_and_linked_mentions() { + assert_mentions( + "use $alpha and [$beta](/tmp/beta)", + &["alpha", "beta"], + &["/tmp/beta"], + ); + } + + #[test] + fn extract_skill_mentions_skips_common_env_vars() { + assert_mentions("use $PATH and $alpha", &["alpha"], &[]); + assert_mentions("use [$HOME](/tmp/skill)", &[], &[]); + assert_mentions("use $XDG_CONFIG_HOME and $beta", &["beta"], &[]); + } + + #[test] + fn extract_skill_mentions_requires_link_syntax() { + assert_mentions("[beta](/tmp/beta)", &[], &[]); + assert_mentions("[$beta] /tmp/beta", &["beta"], &[]); + assert_mentions("[$beta]()", &["beta"], &[]); + } + + #[test] + fn extract_skill_mentions_trims_linked_paths_and_allows_spacing() { + assert_mentions("use [$beta] ( /tmp/beta )", &["beta"], &["/tmp/beta"]); + } + + #[test] + fn extract_skill_mentions_stops_at_non_name_chars() { + assert_mentions( + "use $alpha.skill and $beta_extra", + 
&["alpha", "beta_extra"], + &[], + ); + } + + #[test] + fn collect_explicit_skill_mentions_text_respects_skill_order() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let beta = make_skill("beta-skill", "/tmp/beta"); + let skills = vec![beta.clone(), alpha.clone()]; + let inputs = vec![UserInput::Text { + text: "first $alpha-skill then $beta-skill".to_string(), + text_elements: Vec::new(), + }]; + + let selected = collect_explicit_skill_mentions(&inputs, &skills, &HashSet::new()); + + // Text scanning should not change the previous selection ordering semantics. + assert_eq!(selected, vec![beta, alpha]); + } + + #[test] + fn collect_explicit_skill_mentions_ignores_structured_inputs() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let beta = make_skill("beta-skill", "/tmp/beta"); + let skills = vec![alpha.clone(), beta]; + let inputs = vec![ + UserInput::Text { + text: "please run $alpha-skill".to_string(), + text_elements: Vec::new(), + }, + UserInput::Skill { + name: "beta-skill".to_string(), + path: PathBuf::from("/tmp/beta"), + }, + ]; + + let selected = collect_explicit_skill_mentions(&inputs, &skills, &HashSet::new()); + + assert_eq!(selected, vec![alpha]); + } + + #[test] + fn collect_explicit_skill_mentions_dedupes_by_path() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let skills = vec![alpha.clone()]; + let inputs = vec![UserInput::Text { + text: "use [$alpha-skill](/tmp/alpha) and [$alpha-skill](/tmp/alpha)".to_string(), + text_elements: Vec::new(), + }]; + + let selected = collect_explicit_skill_mentions(&inputs, &skills, &HashSet::new()); + + assert_eq!(selected, vec![alpha]); + } + + #[test] + fn collect_explicit_skill_mentions_dedupes_by_name() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha.clone(), beta]; + let inputs = vec![UserInput::Text { + text: "use $demo-skill and again $demo-skill".to_string(), + text_elements: Vec::new(), 
+ }]; + + let selected = collect_explicit_skill_mentions(&inputs, &skills, &HashSet::new()); + + assert_eq!(selected, vec![alpha]); + } + + #[test] + fn collect_explicit_skill_mentions_prefers_linked_path_over_name() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha, beta.clone()]; + let inputs = vec![UserInput::Text { + text: "use $demo-skill and [$demo-skill](/tmp/beta)".to_string(), + text_elements: Vec::new(), + }]; + + let selected = collect_explicit_skill_mentions(&inputs, &skills, &HashSet::new()); + + assert_eq!(selected, vec![beta]); + } + + #[test] + fn collect_explicit_skill_mentions_falls_back_when_linked_path_disabled() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha, beta.clone()]; + let inputs = vec![UserInput::Text { + text: "use [$demo-skill](/tmp/alpha)".to_string(), + text_elements: Vec::new(), + }]; + let disabled = HashSet::from([PathBuf::from("/tmp/alpha")]); + + let selected = collect_explicit_skill_mentions(&inputs, &skills, &disabled); + + assert_eq!(selected, vec![beta]); + } + + #[test] + fn collect_explicit_skill_mentions_prefers_resource_path() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha, beta.clone()]; + let inputs = vec![UserInput::Text { + text: "use [$demo-skill](/tmp/beta)".to_string(), + text_elements: Vec::new(), + }]; + + let selected = collect_explicit_skill_mentions(&inputs, &skills, &HashSet::new()); + + assert_eq!(selected, vec![beta]); + } + + #[test] + fn collect_explicit_skill_mentions_falls_back_to_name_when_path_missing() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha.clone(), beta]; + let inputs = vec![UserInput::Text { + text: "use [$demo-skill](/tmp/missing)".to_string(), + 
text_elements: Vec::new(), + }]; + + let selected = collect_explicit_skill_mentions(&inputs, &skills, &HashSet::new()); + + assert_eq!(selected, vec![alpha]); + } +} diff --git a/codex-rs/core/src/skills/loader.rs b/codex-rs/core/src/skills/loader.rs index a411dd8bc3..581f7d0c1f 100644 --- a/codex-rs/core/src/skills/loader.rs +++ b/codex-rs/core/src/skills/loader.rs @@ -1,10 +1,12 @@ use crate::config::Config; use crate::config_loader::ConfigLayerStack; use crate::config_loader::ConfigLayerStackOrdering; +use crate::skills::model::SkillDependencies; use crate::skills::model::SkillError; use crate::skills::model::SkillInterface; use crate::skills::model::SkillLoadOutcome; use crate::skills::model::SkillMetadata; +use crate::skills::model::SkillToolDependency; use crate::skills::system::system_cache_root_dir; use codex_app_server_protocol::ConfigLayerSource; use codex_protocol::protocol::SkillScope; @@ -38,6 +40,8 @@ struct SkillFrontmatterMetadata { struct SkillMetadataFile { #[serde(default)] interface: Option, + #[serde(default)] + dependencies: Option, } #[derive(Debug, Default, Deserialize)] @@ -50,6 +54,23 @@ struct Interface { default_prompt: Option, } +#[derive(Debug, Default, Deserialize)] +struct Dependencies { + #[serde(default)] + tools: Vec, +} + +#[derive(Debug, Default, Deserialize)] +struct DependencyTool { + #[serde(rename = "type")] + kind: Option, + value: Option, + description: Option, + transport: Option, + command: Option, + url: Option, +} + const SKILLS_FILENAME: &str = "SKILL.md"; const SKILLS_JSON_FILENAME: &str = "SKILL.json"; const SKILLS_DIR_NAME: &str = "skills"; @@ -57,6 +78,12 @@ const MAX_NAME_LEN: usize = 64; const MAX_DESCRIPTION_LEN: usize = 1024; const MAX_SHORT_DESCRIPTION_LEN: usize = MAX_DESCRIPTION_LEN; const MAX_DEFAULT_PROMPT_LEN: usize = MAX_DESCRIPTION_LEN; +const MAX_DEPENDENCY_TYPE_LEN: usize = MAX_NAME_LEN; +const MAX_DEPENDENCY_TRANSPORT_LEN: usize = MAX_NAME_LEN; +const MAX_DEPENDENCY_VALUE_LEN: usize = 
MAX_DESCRIPTION_LEN; +const MAX_DEPENDENCY_DESCRIPTION_LEN: usize = MAX_DESCRIPTION_LEN; +const MAX_DEPENDENCY_COMMAND_LEN: usize = MAX_DESCRIPTION_LEN; +const MAX_DEPENDENCY_URL_LEN: usize = MAX_DESCRIPTION_LEN; // Traversal depth from the skills root. const MAX_SCAN_DEPTH: usize = 6; const MAX_SKILLS_DIRS_PER_ROOT: usize = 2000; @@ -345,7 +372,7 @@ fn parse_skill_file(path: &Path, scope: SkillScope) -> Result Result Option { - // Fail open: optional interface metadata should not block loading SKILL.md. - let skill_dir = skill_path.parent()?; - let interface_path = skill_dir.join(SKILLS_JSON_FILENAME); - if !interface_path.exists() { - return None; +fn load_skill_metadata(skill_path: &Path) -> (Option, Option) { + // Fail open: optional metadata should not block loading SKILL.md. + let Some(skill_dir) = skill_path.parent() else { + return (None, None); + }; + let metadata_path = skill_dir.join(SKILLS_JSON_FILENAME); + if !metadata_path.exists() { + return (None, None); } - let contents = match fs::read_to_string(&interface_path) { + let contents = match fs::read_to_string(&metadata_path) { Ok(contents) => contents, Err(error) => { tracing::warn!( - "ignoring {path}: failed to read SKILL.json: {error}", - path = interface_path.display() + "ignoring {path}: failed to read {label}: {error}", + path = metadata_path.display(), + label = SKILLS_JSON_FILENAME ); - return None; + return (None, None); } }; + let parsed: SkillMetadataFile = match serde_json::from_str(&contents) { Ok(parsed) => parsed, Err(error) => { tracing::warn!( - "ignoring {path}: invalid JSON: {error}", - path = interface_path.display() + "ignoring {path}: invalid {label}: {error}", + path = metadata_path.display(), + label = SKILLS_JSON_FILENAME ); - return None; + return (None, None); } }; - let interface = parsed.interface?; + ( + resolve_interface(parsed.interface, skill_dir), + resolve_dependencies(parsed.dependencies), + ) +} + +fn resolve_interface(interface: Option, skill_dir: &Path) -> Option 
{ + let interface = interface?; let interface = SkillInterface { display_name: resolve_str( interface.display_name, @@ -428,6 +468,58 @@ fn load_skill_interface(skill_path: &Path) -> Option { if has_fields { Some(interface) } else { None } } +fn resolve_dependencies(dependencies: Option) -> Option { + let dependencies = dependencies?; + let tools: Vec = dependencies + .tools + .into_iter() + .filter_map(resolve_dependency_tool) + .collect(); + if tools.is_empty() { + None + } else { + Some(SkillDependencies { tools }) + } +} + +fn resolve_dependency_tool(tool: DependencyTool) -> Option { + let r#type = resolve_required_str( + tool.kind, + MAX_DEPENDENCY_TYPE_LEN, + "dependencies.tools.type", + )?; + let value = resolve_required_str( + tool.value, + MAX_DEPENDENCY_VALUE_LEN, + "dependencies.tools.value", + )?; + let description = resolve_str( + tool.description, + MAX_DEPENDENCY_DESCRIPTION_LEN, + "dependencies.tools.description", + ); + let transport = resolve_str( + tool.transport, + MAX_DEPENDENCY_TRANSPORT_LEN, + "dependencies.tools.transport", + ); + let command = resolve_str( + tool.command, + MAX_DEPENDENCY_COMMAND_LEN, + "dependencies.tools.command", + ); + let url = resolve_str(tool.url, MAX_DEPENDENCY_URL_LEN, "dependencies.tools.url"); + + Some(SkillToolDependency { + r#type, + value, + description, + transport, + command, + url, + }) +} + fn resolve_asset_path( skill_dir: &Path, field: &'static str, @@ -511,6 +603,18 @@ fn resolve_str(value: Option, max_len: usize, field: &'static str) -> Op Some(value) } +fn resolve_required_str( + value: Option, + max_len: usize, + field: &'static str, +) -> Option { + let Some(value) = value else { + tracing::warn!("ignoring {field}: value is missing"); + return None; + }; + resolve_str(Some(value), max_len, field) +} + fn resolve_color_str(value: Option, field: &'static str) -> Option { let value = value?; let value = value.trim(); @@ -755,14 +859,118 @@ mod tests { path } - fn write_skill_interface_at(skill_dir: 
&Path, contents: &str) -> PathBuf { - let path = skill_dir.join(SKILLS_JSON_FILENAME); + fn write_skill_metadata_at(skill_dir: &Path, filename: &str, contents: &str) -> PathBuf { + let path = skill_dir.join(filename); fs::write(&path, contents).unwrap(); path } + fn write_skill_interface_at(skill_dir: &Path, contents: &str) -> PathBuf { + write_skill_metadata_at(skill_dir, SKILLS_JSON_FILENAME, contents) + } + #[tokio::test] - async fn loads_skill_interface_metadata_happy_path() { + async fn loads_skill_dependencies_metadata_from_json() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "dep-skill", "from json"); + let skill_dir = skill_path.parent().expect("skill dir"); + + write_skill_metadata_at( + skill_dir, + SKILLS_JSON_FILENAME, + r#" +{ + "dependencies": { + "tools": [ + { + "type": "env_var", + "value": "GITHUB_TOKEN", + "description": "GitHub API token with repo scopes" + }, + { + "type": "mcp", + "value": "github", + "description": "GitHub MCP server", + "transport": "streamable_http", + "url": "https://example.com/mcp" + }, + { + "type": "cli", + "value": "gh", + "description": "GitHub CLI" + }, + { + "type": "mcp", + "value": "local-gh", + "description": "Local GH MCP server", + "transport": "stdio", + "command": "gh-mcp" + } + ] + } +} +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "dep-skill".to_string(), + description: "from json".to_string(), + short_description: None, + interface: None, + dependencies: Some(SkillDependencies { + tools: vec![ + SkillToolDependency { + r#type: "env_var".to_string(), + value: "GITHUB_TOKEN".to_string(), + description: Some("GitHub API token with repo scopes".to_string()), + transport: None, + command: None, + url: None, + }, + SkillToolDependency { + r#type: 
"mcp".to_string(), + value: "github".to_string(), + description: Some("GitHub MCP server".to_string()), + transport: Some("streamable_http".to_string()), + command: None, + url: Some("https://example.com/mcp".to_string()), + }, + SkillToolDependency { + r#type: "cli".to_string(), + value: "gh".to_string(), + description: Some("GitHub CLI".to_string()), + transport: None, + command: None, + url: None, + }, + SkillToolDependency { + r#type: "mcp".to_string(), + value: "local-gh".to_string(), + description: Some("Local GH MCP server".to_string()), + transport: Some("stdio".to_string()), + command: Some("gh-mcp".to_string()), + url: None, + }, + ], + }), + path: normalized(&skill_path), + scope: SkillScope::User, + }] + ); + } + + #[tokio::test] + async fn loads_skill_interface_metadata_from_json() { let codex_home = tempfile::tempdir().expect("tempdir"); let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); let skill_dir = skill_path.parent().expect("skill dir"); @@ -806,6 +1014,7 @@ mod tests { brand_color: Some("#3B82F6".to_string()), default_prompt: Some("default prompt".to_string()), }), + dependencies: None, path: normalized(skill_path.as_path()), scope: SkillScope::User, }] @@ -854,6 +1063,7 @@ mod tests { brand_color: None, default_prompt: None, }), + dependencies: None, path: normalized(&skill_path), scope: SkillScope::User, }] @@ -892,6 +1102,7 @@ mod tests { description: "from json".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::User, }] @@ -943,6 +1154,7 @@ mod tests { brand_color: None, default_prompt: None, }), + dependencies: None, path: normalized(&skill_path), scope: SkillScope::User, }] @@ -982,6 +1194,7 @@ mod tests { description: "from json".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::User, }] @@ -1024,6 +1237,7 @@ mod tests { description: "from link".to_string(), 
short_description: None, interface: None, + dependencies: None, path: normalized(&shared_skill_path), scope: SkillScope::User, }] @@ -1082,6 +1296,7 @@ mod tests { description: "still loads".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::User, }] @@ -1116,6 +1331,7 @@ mod tests { description: "from link".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&shared_skill_path), scope: SkillScope::Admin, }] @@ -1154,6 +1370,7 @@ mod tests { description: "from link".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&linked_skill_path), scope: SkillScope::Repo, }] @@ -1215,6 +1432,7 @@ mod tests { description: "loads".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&within_depth_path), scope: SkillScope::User, }] @@ -1240,6 +1458,7 @@ mod tests { description: "does things carefully".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::User, }] @@ -1269,6 +1488,7 @@ mod tests { description: "long description".to_string(), short_description: Some("short summary".to_string()), interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::User, }] @@ -1379,6 +1599,7 @@ mod tests { description: "from repo".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::Repo, }] @@ -1430,6 +1651,7 @@ mod tests { description: "from nested".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&nested_skill_path), scope: SkillScope::Repo, }, @@ -1438,6 +1660,7 @@ mod tests { description: "from root".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&root_skill_path), scope: SkillScope::Repo, }, @@ -1475,6 +1698,7 @@ mod tests 
{ description: "from cwd".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::Repo, }] @@ -1510,6 +1734,7 @@ mod tests { description: "from repo".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::Repo, }] @@ -1549,6 +1774,7 @@ mod tests { description: "from repo".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&repo_skill_path), scope: SkillScope::Repo, }, @@ -1557,6 +1783,7 @@ mod tests { description: "from user".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&user_skill_path), scope: SkillScope::User, }, @@ -1619,6 +1846,7 @@ mod tests { description: first_description.to_string(), short_description: None, interface: None, + dependencies: None, path: first_path, scope: SkillScope::Repo, }, @@ -1627,6 +1855,7 @@ mod tests { description: second_description.to_string(), short_description: None, interface: None, + dependencies: None, path: second_path, scope: SkillScope::Repo, }, @@ -1696,6 +1925,7 @@ mod tests { description: "from repo".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::Repo, }] @@ -1752,6 +1982,7 @@ mod tests { description: "from system".to_string(), short_description: None, interface: None, + dependencies: None, path: normalized(&skill_path), scope: SkillScope::System, }] diff --git a/codex-rs/core/src/skills/mod.rs b/codex-rs/core/src/skills/mod.rs index cf7c180502..ae4ffa2920 100644 --- a/codex-rs/core/src/skills/mod.rs +++ b/codex-rs/core/src/skills/mod.rs @@ -7,6 +7,7 @@ pub mod system; pub(crate) use injection::SkillInjections; pub(crate) use injection::build_skill_injections; +pub(crate) use injection::collect_explicit_skill_mentions; pub use loader::load_skills; pub use manager::SkillsManager; pub use model::SkillError; diff 
--git a/codex-rs/core/src/skills/model.rs b/codex-rs/core/src/skills/model.rs index fe3357f9d9..92ecbd84b9 100644 --- a/codex-rs/core/src/skills/model.rs +++ b/codex-rs/core/src/skills/model.rs @@ -9,6 +9,7 @@ pub struct SkillMetadata { pub description: String, pub short_description: Option, pub interface: Option, + pub dependencies: Option, pub path: PathBuf, pub scope: SkillScope, } @@ -23,6 +24,21 @@ pub struct SkillInterface { pub default_prompt: Option, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SkillDependencies { + pub tools: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SkillToolDependency { + pub r#type: String, + pub value: String, + pub description: Option, + pub transport: Option, + pub command: Option, + pub url: Option, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct SkillError { pub path: PathBuf, diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index 067f7b3378..9d3d96db99 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -1,6 +1,7 @@ //! Session-wide mutable state. use codex_protocol::models::ResponseItem; +use std::collections::HashSet; use crate::codex::SessionConfiguration; use crate::context_manager::ContextManager; @@ -15,6 +16,7 @@ pub(crate) struct SessionState { pub(crate) history: ContextManager, pub(crate) latest_rate_limits: Option, pub(crate) server_reasoning_included: bool, + pub(crate) mcp_dependency_prompted: HashSet, /// Whether the session's initial context has been seeded into history. 
/// /// TODO(owen): This is a temporary solution to avoid updating a thread's updated_at @@ -31,6 +33,7 @@ impl SessionState { history, latest_rate_limits: None, server_reasoning_included: false, + mcp_dependency_prompted: HashSet::new(), initial_context_seeded: false, } } @@ -98,6 +101,17 @@ impl SessionState { pub(crate) fn server_reasoning_included(&self) -> bool { self.server_reasoning_included } + + pub(crate) fn record_mcp_dependency_prompted(&mut self, names: I) + where + I: IntoIterator, + { + self.mcp_dependency_prompted.extend(names); + } + + pub(crate) fn mcp_dependency_prompted(&self) -> HashSet { + self.mcp_dependency_prompted.clone() + } } // Sometimes new snapshots don't include credits or plan information. diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index c2d282e995..b9a4f2120b 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -2093,11 +2093,14 @@ pub struct SkillMetadata { pub description: String, #[serde(default, skip_serializing_if = "Option::is_none")] #[ts(optional)] - /// Legacy short_description from SKILL.md. Prefer SKILL.toml interface.short_description. + /// Legacy short_description from SKILL.md. Prefer SKILL.json interface.short_description. 
pub short_description: Option, #[serde(default, skip_serializing_if = "Option::is_none")] #[ts(optional)] pub interface: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub dependencies: Option, pub path: PathBuf, pub scope: SkillScope, pub enabled: bool, @@ -2119,6 +2122,31 @@ pub struct SkillInterface { pub default_prompt: Option, } +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS, PartialEq, Eq)] +pub struct SkillDependencies { + pub tools: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS, PartialEq, Eq)] +pub struct SkillToolDependency { + #[serde(rename = "type")] + #[ts(rename = "type")] + pub r#type: String, + pub value: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub description: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub transport: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub command: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub url: Option, +} + #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] pub struct SkillErrorInfo { pub path: PathBuf, diff --git a/codex-rs/tui/src/bottom_pane/mod.rs b/codex-rs/tui/src/bottom_pane/mod.rs index bbd59c58bd..e0a7319a42 100644 --- a/codex-rs/tui/src/bottom_pane/mod.rs +++ b/codex-rs/tui/src/bottom_pane/mod.rs @@ -1154,6 +1154,7 @@ mod tests { description: "test skill".to_string(), short_description: None, interface: None, + dependencies: None, path: PathBuf::from("test-skill"), scope: SkillScope::User, }]), diff --git a/codex-rs/tui/src/chatwidget/skills.rs b/codex-rs/tui/src/chatwidget/skills.rs index d72ca60455..d6893a9e3c 100644 --- a/codex-rs/tui/src/chatwidget/skills.rs +++ b/codex-rs/tui/src/chatwidget/skills.rs @@ -15,8 +15,10 @@ use crate::skills_helpers::skill_display_name; use codex_core::protocol::ListSkillsResponseEvent; use 
codex_core::protocol::SkillMetadata as ProtocolSkillMetadata; use codex_core::protocol::SkillsListEntry; +use codex_core::skills::model::SkillDependencies; use codex_core::skills::model::SkillInterface; use codex_core::skills::model::SkillMetadata; +use codex_core::skills::model::SkillToolDependency; impl ChatWidget { pub(crate) fn open_skills_list(&mut self) { @@ -168,6 +170,23 @@ fn protocol_skill_to_core(skill: &ProtocolSkillMetadata) -> SkillMetadata { brand_color: interface.brand_color, default_prompt: interface.default_prompt, }), + dependencies: skill + .dependencies + .clone() + .map(|dependencies| SkillDependencies { + tools: dependencies + .tools + .into_iter() + .map(|tool| SkillToolDependency { + r#type: tool.r#type, + value: tool.value, + description: tool.description, + transport: tool.transport, + command: tool.command, + url: tool.url, + }) + .collect(), + }), path: skill.path.clone(), scope: skill.scope, }