Extract tool building (#11337)

Make it clear what input go into building tools and allow for easy reuse
for pre-warm request
This commit is contained in:
pakrym-oai
2026-02-10 11:45:23 -08:00
committed by GitHub
parent 9c4656000f
commit e4b5384539
3 changed files with 177 additions and 104 deletions

View File

@@ -152,7 +152,7 @@ use crate::mcp_connection_manager::filter_mcp_tools_by_name;
use crate::memories;
use crate::mentions::build_connector_slug_counts;
use crate::mentions::build_skill_name_counts;
use crate::mentions::collect_explicit_app_paths;
use crate::mentions::collect_explicit_app_ids;
use crate::mentions::collect_tool_mentions_from_messages;
use crate::project_doc::get_user_instructions;
use crate::proposed_plan_parser::ProposedPlanParser;
@@ -201,6 +201,7 @@ use crate::shell;
use crate::shell_snapshot::ShellSnapshot;
use crate::skills::SkillError;
use crate::skills::SkillInjections;
use crate::skills::SkillLoadOutcome;
use crate::skills::SkillMetadata;
use crate::skills::SkillsManager;
use crate::skills::build_skill_injections;
@@ -3882,10 +3883,6 @@ pub(crate) async fn run_turn(
.await,
);
let (skill_name_counts, skill_name_counts_lower) = skills_outcome.as_ref().map_or_else(
|| (HashMap::new(), HashMap::new()),
|outcome| build_skill_name_counts(&outcome.skills, &outcome.disabled_paths),
);
let connector_slug_counts = if turn_context.config.features.enabled(Feature::Apps) {
let mcp_tools = match sess
.services
@@ -3909,12 +3906,10 @@ pub(crate) async fn run_turn(
&input,
&outcome.skills,
&outcome.disabled_paths,
&skill_name_counts,
&connector_slug_counts,
)
});
let explicit_app_paths = collect_explicit_app_paths(&input);
let explicitly_enabled_connectors = collect_explicit_app_ids(&input);
let config = turn_context.config.clone();
if config
.features
@@ -4016,10 +4011,6 @@ pub(crate) async fn run_turn(
})
.map(|user_message| user_message.message())
.collect::<Vec<String>>();
let tool_selection = SamplingRequestToolSelection {
explicit_app_paths: &explicit_app_paths,
skill_name_counts_lower: &skill_name_counts_lower,
};
match run_sampling_request(
Arc::clone(&sess),
Arc::clone(&turn_context),
@@ -4027,7 +4018,8 @@ pub(crate) async fn run_turn(
&mut client_session,
turn_metadata_header.as_deref(),
sampling_request_input,
tool_selection,
&explicitly_enabled_connectors,
skills_outcome.as_ref(),
cancellation_token.child_token(),
)
.await
@@ -4127,11 +4119,11 @@ async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>)
fn filter_connectors_for_input(
connectors: Vec<connectors::AppInfo>,
input: &[ResponseItem],
explicit_app_paths: &[String],
explicitly_enabled_connectors: &HashSet<String>,
skill_name_counts_lower: &HashMap<String, usize>,
) -> Vec<connectors::AppInfo> {
let user_messages = collect_user_messages(input);
if user_messages.is_empty() && explicit_app_paths.is_empty() {
if user_messages.is_empty() && explicitly_enabled_connectors.is_empty() {
return Vec::new();
}
@@ -4143,10 +4135,10 @@ fn filter_connectors_for_input(
.collect::<HashSet<String>>();
let connector_slug_counts = build_connector_slug_counts(&connectors);
let mut allowed_connector_ids: HashSet<String> = HashSet::new();
for path in explicit_app_paths
let mut allowed_connector_ids = explicitly_enabled_connectors.clone();
for path in mentions
.paths
.iter()
.chain(mentions.paths.iter())
.filter(|path| tool_kind_for_path(path) == ToolMentionKind::App)
{
if let Some(connector_id) = app_id_from_path(path) {
@@ -4217,11 +4209,6 @@ fn codex_apps_connector_id(tool: &crate::mcp_connection_manager::ToolInfo) -> Op
tool.connector_id.as_deref()
}
struct SamplingRequestToolSelection<'a> {
explicit_app_paths: &'a [String],
skill_name_counts_lower: &'a HashMap<String, usize>,
}
#[allow(clippy::too_many_arguments)]
#[instrument(level = "trace",
skip_all,
@@ -4238,58 +4225,19 @@ async fn run_sampling_request(
client_session: &mut ModelClientSession,
turn_metadata_header: Option<&str>,
input: Vec<ResponseItem>,
tool_selection: SamplingRequestToolSelection<'_>,
explicitly_enabled_connectors: &HashSet<String>,
skills_outcome: Option<&SkillLoadOutcome>,
cancellation_token: CancellationToken,
) -> CodexResult<SamplingRequestResult> {
let mut mcp_tools = sess
.services
.mcp_connection_manager
.read()
.await
.list_all_tools()
.or_cancel(&cancellation_token)
.await?;
let connectors_for_tools = if turn_context.config.features.enabled(Feature::Apps) {
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
Some(filter_connectors_for_input(
connectors,
&input,
tool_selection.explicit_app_paths,
tool_selection.skill_name_counts_lower,
))
} else {
None
};
if turn_context.config.features.enabled(Feature::SearchTool) {
let mut selected_mcp_tools =
if let Some(selected_tools) = sess.get_mcp_tool_selection().await {
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tools)
} else {
HashMap::new()
};
if let Some(connectors) = connectors_for_tools.as_ref() {
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, connectors);
selected_mcp_tools.extend(apps_mcp_tools);
}
mcp_tools = selected_mcp_tools;
} else if let Some(connectors) = connectors_for_tools.as_ref() {
mcp_tools = filter_codex_apps_mcp_tools(mcp_tools, connectors);
}
let router = Arc::new(ToolRouter::from_config(
&turn_context.tools_config,
Some(
mcp_tools
.into_iter()
.map(|(name, tool)| (name, tool.tool))
.collect(),
),
turn_context.dynamic_tools.as_slice(),
));
let router = built_tools(
sess.as_ref(),
turn_context.as_ref(),
&input,
explicitly_enabled_connectors,
skills_outcome,
&cancellation_token,
)
.await?;
let model_supports_parallel = turn_context.model_info.supports_parallel_tool_calls;
@@ -4383,6 +4331,68 @@ async fn run_sampling_request(
}
}
async fn built_tools(
sess: &Session,
turn_context: &TurnContext,
input: &[ResponseItem],
explicitly_enabled_connectors: &HashSet<String>,
skills_outcome: Option<&SkillLoadOutcome>,
cancellation_token: &CancellationToken,
) -> CodexResult<Arc<ToolRouter>> {
let mut mcp_tools = sess
.services
.mcp_connection_manager
.read()
.await
.list_all_tools()
.or_cancel(cancellation_token)
.await?;
let connectors_for_tools = if turn_context.config.features.enabled(Feature::Apps) {
let skill_name_counts_lower = skills_outcome.map_or_else(HashMap::new, |outcome| {
build_skill_name_counts(&outcome.skills, &outcome.disabled_paths).1
});
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
Some(filter_connectors_for_input(
connectors,
input,
explicitly_enabled_connectors,
&skill_name_counts_lower,
))
} else {
None
};
if turn_context.config.features.enabled(Feature::SearchTool) {
let mut selected_mcp_tools =
if let Some(selected_tools) = sess.get_mcp_tool_selection().await {
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tools)
} else {
HashMap::new()
};
if let Some(connectors) = connectors_for_tools.as_ref() {
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, connectors);
selected_mcp_tools.extend(apps_mcp_tools);
}
mcp_tools = selected_mcp_tools;
} else if let Some(connectors) = connectors_for_tools.as_ref() {
mcp_tools = filter_codex_apps_mcp_tools(mcp_tools, connectors);
}
Ok(Arc::new(ToolRouter::from_config(
&turn_context.tools_config,
Some(
mcp_tools
.into_iter()
.map(|(name, tool)| (name, tool.tool))
.collect(),
),
turn_context.dynamic_tools.as_slice(),
)))
}
#[derive(Debug)]
struct SamplingRequestResult {
needs_follow_up: bool,
@@ -5284,13 +5294,13 @@ mod tests {
make_connector("two", "Foo-Bar"),
];
let input = vec![user_message("use $foo-bar")];
let explicit_app_paths = Vec::new();
let explicitly_enabled_connectors = HashSet::new();
let skill_name_counts_lower = HashMap::new();
let selected = filter_connectors_for_input(
connectors,
&input,
&explicit_app_paths,
&explicitly_enabled_connectors,
&skill_name_counts_lower,
);
@@ -5301,13 +5311,13 @@ mod tests {
fn filter_connectors_for_input_skips_when_skill_name_conflicts() {
let connectors = vec![make_connector("one", "Todoist")];
let input = vec![user_message("use $todoist")];
let explicit_app_paths = Vec::new();
let explicitly_enabled_connectors = HashSet::new();
let skill_name_counts_lower = HashMap::from([("todoist".to_string(), 1)]);
let selected = filter_connectors_for_input(
connectors,
&input,
&explicit_app_paths,
&explicitly_enabled_connectors,
&skill_name_counts_lower,
);
@@ -5339,10 +5349,11 @@ mod tests {
let mut selected_mcp_tools =
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tool_names);
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
let explicitly_enabled_connectors = HashSet::new();
let connectors = filter_connectors_for_input(
connectors,
&[user_message("run the selected tools")],
&[],
&explicitly_enabled_connectors,
&HashMap::new(),
);
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, &connectors);
@@ -5381,10 +5392,11 @@ mod tests {
let mut selected_mcp_tools =
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tool_names);
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
let explicitly_enabled_connectors = HashSet::new();
let connectors = filter_connectors_for_input(
connectors,
&[user_message("use $calendar and then echo the response")],
&[],
&explicitly_enabled_connectors,
&HashMap::new(),
);
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, &connectors);

View File

@@ -6,7 +6,10 @@ use codex_protocol::user_input::UserInput;
use crate::connectors;
use crate::skills::SkillMetadata;
use crate::skills::injection::ToolMentionKind;
use crate::skills::injection::app_id_from_path;
use crate::skills::injection::extract_tool_mentions;
use crate::skills::injection::tool_kind_for_path;
pub(crate) struct CollectedToolMentions {
pub(crate) plain_names: HashSet<String>,
@@ -24,13 +27,24 @@ pub(crate) fn collect_tool_mentions_from_messages(messages: &[String]) -> Collec
CollectedToolMentions { plain_names, paths }
}
pub(crate) fn collect_explicit_app_paths(input: &[UserInput]) -> Vec<String> {
pub(crate) fn collect_explicit_app_ids(input: &[UserInput]) -> HashSet<String> {
let messages = input
.iter()
.filter_map(|item| match item {
UserInput::Text { text, .. } => Some(text.clone()),
_ => None,
})
.collect::<Vec<String>>();
input
.iter()
.filter_map(|item| match item {
UserInput::Mention { path, .. } => Some(path.clone()),
_ => None,
})
.chain(collect_tool_mentions_from_messages(&messages).paths)
.filter(|path| tool_kind_for_path(path.as_str()) == ToolMentionKind::App)
.filter_map(|path| app_id_from_path(path.as_str()).map(str::to_string))
.collect()
}
@@ -62,3 +76,69 @@ pub(crate) fn build_connector_slug_counts(
}
counts
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use codex_protocol::user_input::UserInput;
use pretty_assertions::assert_eq;
use super::collect_explicit_app_ids;
fn text_input(text: &str) -> UserInput {
UserInput::Text {
text: text.to_string(),
text_elements: Vec::new(),
}
}
#[test]
fn collect_explicit_app_ids_from_linked_text_mentions() {
let input = vec![text_input("use [$calendar](app://calendar)")];
let app_ids = collect_explicit_app_ids(&input);
assert_eq!(app_ids, HashSet::from(["calendar".to_string()]));
}
#[test]
fn collect_explicit_app_ids_dedupes_structured_and_linked_mentions() {
let input = vec![
text_input("use [$calendar](app://calendar)"),
UserInput::Mention {
name: "calendar".to_string(),
path: "app://calendar".to_string(),
},
];
let app_ids = collect_explicit_app_ids(&input);
assert_eq!(app_ids, HashSet::from(["calendar".to_string()]));
}
#[test]
fn collect_explicit_app_ids_ignores_non_app_paths() {
let input = vec![
text_input(
"use [$docs](mcp://docs) and [$skill](skill://team/skill) and [$file](/tmp/file.txt)",
),
UserInput::Mention {
name: "docs".to_string(),
path: "mcp://docs".to_string(),
},
UserInput::Mention {
name: "skill".to_string(),
path: "skill://team/skill".to_string(),
},
UserInput::Mention {
name: "file".to_string(),
path: "/tmp/file.txt".to_string(),
},
];
let app_ids = collect_explicit_app_ids(&input);
assert_eq!(app_ids, HashSet::<String>::new());
}
}

View File

@@ -6,6 +6,7 @@ use crate::analytics_client::AnalyticsEventsClient;
use crate::analytics_client::SkillInvocation;
use crate::analytics_client::TrackEventsContext;
use crate::instructions::SkillInstructions;
use crate::mentions::build_skill_name_counts;
use crate::skills::SkillMetadata;
use codex_otel::OtelManager;
use codex_protocol::models::ResponseItem;
@@ -93,13 +94,14 @@ pub(crate) fn collect_explicit_skill_mentions(
inputs: &[UserInput],
skills: &[SkillMetadata],
disabled_paths: &HashSet<PathBuf>,
skill_name_counts: &HashMap<String, usize>,
connector_slug_counts: &HashMap<String, usize>,
) -> Vec<SkillMetadata> {
let skill_name_counts = build_skill_name_counts(skills, disabled_paths).0;
let selection_context = SkillSelectionContext {
skills,
disabled_paths,
skill_name_counts,
skill_name_counts: &skill_name_counts,
connector_slug_counts,
};
let mut selected: Vec<SkillMetadata> = Vec::new();
@@ -489,34 +491,13 @@ mod tests {
assert_eq!(mentions.paths, set(expected_paths));
}
fn build_skill_name_counts(
skills: &[SkillMetadata],
disabled_paths: &HashSet<PathBuf>,
) -> HashMap<String, usize> {
let mut counts = HashMap::new();
for skill in skills {
if disabled_paths.contains(&skill.path) {
continue;
}
*counts.entry(skill.name.clone()).or_insert(0) += 1;
}
counts
}
fn collect_mentions(
inputs: &[UserInput],
skills: &[SkillMetadata],
disabled_paths: &HashSet<PathBuf>,
connector_slug_counts: &HashMap<String, usize>,
) -> Vec<SkillMetadata> {
let skill_name_counts = build_skill_name_counts(skills, disabled_paths);
collect_explicit_skill_mentions(
inputs,
skills,
disabled_paths,
&skill_name_counts,
connector_slug_counts,
)
collect_explicit_skill_mentions(inputs, skills, disabled_paths, connector_slug_counts)
}
#[test]