mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
Extract tool building (#11337)
Make it clear what input go into building tools and allow for easy reuse for pre-warm request
This commit is contained in:
@@ -152,7 +152,7 @@ use crate::mcp_connection_manager::filter_mcp_tools_by_name;
|
||||
use crate::memories;
|
||||
use crate::mentions::build_connector_slug_counts;
|
||||
use crate::mentions::build_skill_name_counts;
|
||||
use crate::mentions::collect_explicit_app_paths;
|
||||
use crate::mentions::collect_explicit_app_ids;
|
||||
use crate::mentions::collect_tool_mentions_from_messages;
|
||||
use crate::project_doc::get_user_instructions;
|
||||
use crate::proposed_plan_parser::ProposedPlanParser;
|
||||
@@ -201,6 +201,7 @@ use crate::shell;
|
||||
use crate::shell_snapshot::ShellSnapshot;
|
||||
use crate::skills::SkillError;
|
||||
use crate::skills::SkillInjections;
|
||||
use crate::skills::SkillLoadOutcome;
|
||||
use crate::skills::SkillMetadata;
|
||||
use crate::skills::SkillsManager;
|
||||
use crate::skills::build_skill_injections;
|
||||
@@ -3882,10 +3883,6 @@ pub(crate) async fn run_turn(
|
||||
.await,
|
||||
);
|
||||
|
||||
let (skill_name_counts, skill_name_counts_lower) = skills_outcome.as_ref().map_or_else(
|
||||
|| (HashMap::new(), HashMap::new()),
|
||||
|outcome| build_skill_name_counts(&outcome.skills, &outcome.disabled_paths),
|
||||
);
|
||||
let connector_slug_counts = if turn_context.config.features.enabled(Feature::Apps) {
|
||||
let mcp_tools = match sess
|
||||
.services
|
||||
@@ -3909,12 +3906,10 @@ pub(crate) async fn run_turn(
|
||||
&input,
|
||||
&outcome.skills,
|
||||
&outcome.disabled_paths,
|
||||
&skill_name_counts,
|
||||
&connector_slug_counts,
|
||||
)
|
||||
});
|
||||
let explicit_app_paths = collect_explicit_app_paths(&input);
|
||||
|
||||
let explicitly_enabled_connectors = collect_explicit_app_ids(&input);
|
||||
let config = turn_context.config.clone();
|
||||
if config
|
||||
.features
|
||||
@@ -4016,10 +4011,6 @@ pub(crate) async fn run_turn(
|
||||
})
|
||||
.map(|user_message| user_message.message())
|
||||
.collect::<Vec<String>>();
|
||||
let tool_selection = SamplingRequestToolSelection {
|
||||
explicit_app_paths: &explicit_app_paths,
|
||||
skill_name_counts_lower: &skill_name_counts_lower,
|
||||
};
|
||||
match run_sampling_request(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
@@ -4027,7 +4018,8 @@ pub(crate) async fn run_turn(
|
||||
&mut client_session,
|
||||
turn_metadata_header.as_deref(),
|
||||
sampling_request_input,
|
||||
tool_selection,
|
||||
&explicitly_enabled_connectors,
|
||||
skills_outcome.as_ref(),
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
.await
|
||||
@@ -4127,11 +4119,11 @@ async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>)
|
||||
fn filter_connectors_for_input(
|
||||
connectors: Vec<connectors::AppInfo>,
|
||||
input: &[ResponseItem],
|
||||
explicit_app_paths: &[String],
|
||||
explicitly_enabled_connectors: &HashSet<String>,
|
||||
skill_name_counts_lower: &HashMap<String, usize>,
|
||||
) -> Vec<connectors::AppInfo> {
|
||||
let user_messages = collect_user_messages(input);
|
||||
if user_messages.is_empty() && explicit_app_paths.is_empty() {
|
||||
if user_messages.is_empty() && explicitly_enabled_connectors.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
@@ -4143,10 +4135,10 @@ fn filter_connectors_for_input(
|
||||
.collect::<HashSet<String>>();
|
||||
|
||||
let connector_slug_counts = build_connector_slug_counts(&connectors);
|
||||
let mut allowed_connector_ids: HashSet<String> = HashSet::new();
|
||||
for path in explicit_app_paths
|
||||
let mut allowed_connector_ids = explicitly_enabled_connectors.clone();
|
||||
for path in mentions
|
||||
.paths
|
||||
.iter()
|
||||
.chain(mentions.paths.iter())
|
||||
.filter(|path| tool_kind_for_path(path) == ToolMentionKind::App)
|
||||
{
|
||||
if let Some(connector_id) = app_id_from_path(path) {
|
||||
@@ -4217,11 +4209,6 @@ fn codex_apps_connector_id(tool: &crate::mcp_connection_manager::ToolInfo) -> Op
|
||||
tool.connector_id.as_deref()
|
||||
}
|
||||
|
||||
struct SamplingRequestToolSelection<'a> {
|
||||
explicit_app_paths: &'a [String],
|
||||
skill_name_counts_lower: &'a HashMap<String, usize>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[instrument(level = "trace",
|
||||
skip_all,
|
||||
@@ -4238,58 +4225,19 @@ async fn run_sampling_request(
|
||||
client_session: &mut ModelClientSession,
|
||||
turn_metadata_header: Option<&str>,
|
||||
input: Vec<ResponseItem>,
|
||||
tool_selection: SamplingRequestToolSelection<'_>,
|
||||
explicitly_enabled_connectors: &HashSet<String>,
|
||||
skills_outcome: Option<&SkillLoadOutcome>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<SamplingRequestResult> {
|
||||
let mut mcp_tools = sess
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.list_all_tools()
|
||||
.or_cancel(&cancellation_token)
|
||||
.await?;
|
||||
|
||||
let connectors_for_tools = if turn_context.config.features.enabled(Feature::Apps) {
|
||||
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
|
||||
Some(filter_connectors_for_input(
|
||||
connectors,
|
||||
&input,
|
||||
tool_selection.explicit_app_paths,
|
||||
tool_selection.skill_name_counts_lower,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if turn_context.config.features.enabled(Feature::SearchTool) {
|
||||
let mut selected_mcp_tools =
|
||||
if let Some(selected_tools) = sess.get_mcp_tool_selection().await {
|
||||
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tools)
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
if let Some(connectors) = connectors_for_tools.as_ref() {
|
||||
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, connectors);
|
||||
selected_mcp_tools.extend(apps_mcp_tools);
|
||||
}
|
||||
|
||||
mcp_tools = selected_mcp_tools;
|
||||
} else if let Some(connectors) = connectors_for_tools.as_ref() {
|
||||
mcp_tools = filter_codex_apps_mcp_tools(mcp_tools, connectors);
|
||||
}
|
||||
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(
|
||||
mcp_tools
|
||||
.into_iter()
|
||||
.map(|(name, tool)| (name, tool.tool))
|
||||
.collect(),
|
||||
),
|
||||
turn_context.dynamic_tools.as_slice(),
|
||||
));
|
||||
let router = built_tools(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
&input,
|
||||
explicitly_enabled_connectors,
|
||||
skills_outcome,
|
||||
&cancellation_token,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let model_supports_parallel = turn_context.model_info.supports_parallel_tool_calls;
|
||||
|
||||
@@ -4383,6 +4331,68 @@ async fn run_sampling_request(
|
||||
}
|
||||
}
|
||||
|
||||
async fn built_tools(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
input: &[ResponseItem],
|
||||
explicitly_enabled_connectors: &HashSet<String>,
|
||||
skills_outcome: Option<&SkillLoadOutcome>,
|
||||
cancellation_token: &CancellationToken,
|
||||
) -> CodexResult<Arc<ToolRouter>> {
|
||||
let mut mcp_tools = sess
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.list_all_tools()
|
||||
.or_cancel(cancellation_token)
|
||||
.await?;
|
||||
|
||||
let connectors_for_tools = if turn_context.config.features.enabled(Feature::Apps) {
|
||||
let skill_name_counts_lower = skills_outcome.map_or_else(HashMap::new, |outcome| {
|
||||
build_skill_name_counts(&outcome.skills, &outcome.disabled_paths).1
|
||||
});
|
||||
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
|
||||
Some(filter_connectors_for_input(
|
||||
connectors,
|
||||
input,
|
||||
explicitly_enabled_connectors,
|
||||
&skill_name_counts_lower,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if turn_context.config.features.enabled(Feature::SearchTool) {
|
||||
let mut selected_mcp_tools =
|
||||
if let Some(selected_tools) = sess.get_mcp_tool_selection().await {
|
||||
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tools)
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
if let Some(connectors) = connectors_for_tools.as_ref() {
|
||||
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, connectors);
|
||||
selected_mcp_tools.extend(apps_mcp_tools);
|
||||
}
|
||||
|
||||
mcp_tools = selected_mcp_tools;
|
||||
} else if let Some(connectors) = connectors_for_tools.as_ref() {
|
||||
mcp_tools = filter_codex_apps_mcp_tools(mcp_tools, connectors);
|
||||
}
|
||||
|
||||
Ok(Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(
|
||||
mcp_tools
|
||||
.into_iter()
|
||||
.map(|(name, tool)| (name, tool.tool))
|
||||
.collect(),
|
||||
),
|
||||
turn_context.dynamic_tools.as_slice(),
|
||||
)))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SamplingRequestResult {
|
||||
needs_follow_up: bool,
|
||||
@@ -5284,13 +5294,13 @@ mod tests {
|
||||
make_connector("two", "Foo-Bar"),
|
||||
];
|
||||
let input = vec![user_message("use $foo-bar")];
|
||||
let explicit_app_paths = Vec::new();
|
||||
let explicitly_enabled_connectors = HashSet::new();
|
||||
let skill_name_counts_lower = HashMap::new();
|
||||
|
||||
let selected = filter_connectors_for_input(
|
||||
connectors,
|
||||
&input,
|
||||
&explicit_app_paths,
|
||||
&explicitly_enabled_connectors,
|
||||
&skill_name_counts_lower,
|
||||
);
|
||||
|
||||
@@ -5301,13 +5311,13 @@ mod tests {
|
||||
fn filter_connectors_for_input_skips_when_skill_name_conflicts() {
|
||||
let connectors = vec![make_connector("one", "Todoist")];
|
||||
let input = vec![user_message("use $todoist")];
|
||||
let explicit_app_paths = Vec::new();
|
||||
let explicitly_enabled_connectors = HashSet::new();
|
||||
let skill_name_counts_lower = HashMap::from([("todoist".to_string(), 1)]);
|
||||
|
||||
let selected = filter_connectors_for_input(
|
||||
connectors,
|
||||
&input,
|
||||
&explicit_app_paths,
|
||||
&explicitly_enabled_connectors,
|
||||
&skill_name_counts_lower,
|
||||
);
|
||||
|
||||
@@ -5339,10 +5349,11 @@ mod tests {
|
||||
let mut selected_mcp_tools =
|
||||
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tool_names);
|
||||
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
|
||||
let explicitly_enabled_connectors = HashSet::new();
|
||||
let connectors = filter_connectors_for_input(
|
||||
connectors,
|
||||
&[user_message("run the selected tools")],
|
||||
&[],
|
||||
&explicitly_enabled_connectors,
|
||||
&HashMap::new(),
|
||||
);
|
||||
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, &connectors);
|
||||
@@ -5381,10 +5392,11 @@ mod tests {
|
||||
let mut selected_mcp_tools =
|
||||
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tool_names);
|
||||
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
|
||||
let explicitly_enabled_connectors = HashSet::new();
|
||||
let connectors = filter_connectors_for_input(
|
||||
connectors,
|
||||
&[user_message("use $calendar and then echo the response")],
|
||||
&[],
|
||||
&explicitly_enabled_connectors,
|
||||
&HashMap::new(),
|
||||
);
|
||||
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, &connectors);
|
||||
|
||||
@@ -6,7 +6,10 @@ use codex_protocol::user_input::UserInput;
|
||||
|
||||
use crate::connectors;
|
||||
use crate::skills::SkillMetadata;
|
||||
use crate::skills::injection::ToolMentionKind;
|
||||
use crate::skills::injection::app_id_from_path;
|
||||
use crate::skills::injection::extract_tool_mentions;
|
||||
use crate::skills::injection::tool_kind_for_path;
|
||||
|
||||
pub(crate) struct CollectedToolMentions {
|
||||
pub(crate) plain_names: HashSet<String>,
|
||||
@@ -24,13 +27,24 @@ pub(crate) fn collect_tool_mentions_from_messages(messages: &[String]) -> Collec
|
||||
CollectedToolMentions { plain_names, paths }
|
||||
}
|
||||
|
||||
pub(crate) fn collect_explicit_app_paths(input: &[UserInput]) -> Vec<String> {
|
||||
pub(crate) fn collect_explicit_app_ids(input: &[UserInput]) -> HashSet<String> {
|
||||
let messages = input
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
UserInput::Text { text, .. } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
input
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
UserInput::Mention { path, .. } => Some(path.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.chain(collect_tool_mentions_from_messages(&messages).paths)
|
||||
.filter(|path| tool_kind_for_path(path.as_str()) == ToolMentionKind::App)
|
||||
.filter_map(|path| app_id_from_path(path.as_str()).map(str::to_string))
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -62,3 +76,69 @@ pub(crate) fn build_connector_slug_counts(
|
||||
}
|
||||
counts
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::collect_explicit_app_ids;
|
||||
|
||||
fn text_input(text: &str) -> UserInput {
|
||||
UserInput::Text {
|
||||
text: text.to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_explicit_app_ids_from_linked_text_mentions() {
|
||||
let input = vec")];
|
||||
|
||||
let app_ids = collect_explicit_app_ids(&input);
|
||||
|
||||
assert_eq!(app_ids, HashSet::from(["calendar".to_string()]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_explicit_app_ids_dedupes_structured_and_linked_mentions() {
|
||||
let input = vec"),
|
||||
UserInput::Mention {
|
||||
name: "calendar".to_string(),
|
||||
path: "app://calendar".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let app_ids = collect_explicit_app_ids(&input);
|
||||
|
||||
assert_eq!(app_ids, HashSet::from(["calendar".to_string()]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_explicit_app_ids_ignores_non_app_paths() {
|
||||
let input = vec and [$skill](skill://team/skill) and [$file](/tmp/file.txt)",
|
||||
),
|
||||
UserInput::Mention {
|
||||
name: "docs".to_string(),
|
||||
path: "mcp://docs".to_string(),
|
||||
},
|
||||
UserInput::Mention {
|
||||
name: "skill".to_string(),
|
||||
path: "skill://team/skill".to_string(),
|
||||
},
|
||||
UserInput::Mention {
|
||||
name: "file".to_string(),
|
||||
path: "/tmp/file.txt".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let app_ids = collect_explicit_app_ids(&input);
|
||||
|
||||
assert_eq!(app_ids, HashSet::<String>::new());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::analytics_client::AnalyticsEventsClient;
|
||||
use crate::analytics_client::SkillInvocation;
|
||||
use crate::analytics_client::TrackEventsContext;
|
||||
use crate::instructions::SkillInstructions;
|
||||
use crate::mentions::build_skill_name_counts;
|
||||
use crate::skills::SkillMetadata;
|
||||
use codex_otel::OtelManager;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -93,13 +94,14 @@ pub(crate) fn collect_explicit_skill_mentions(
|
||||
inputs: &[UserInput],
|
||||
skills: &[SkillMetadata],
|
||||
disabled_paths: &HashSet<PathBuf>,
|
||||
skill_name_counts: &HashMap<String, usize>,
|
||||
connector_slug_counts: &HashMap<String, usize>,
|
||||
) -> Vec<SkillMetadata> {
|
||||
let skill_name_counts = build_skill_name_counts(skills, disabled_paths).0;
|
||||
|
||||
let selection_context = SkillSelectionContext {
|
||||
skills,
|
||||
disabled_paths,
|
||||
skill_name_counts,
|
||||
skill_name_counts: &skill_name_counts,
|
||||
connector_slug_counts,
|
||||
};
|
||||
let mut selected: Vec<SkillMetadata> = Vec::new();
|
||||
@@ -489,34 +491,13 @@ mod tests {
|
||||
assert_eq!(mentions.paths, set(expected_paths));
|
||||
}
|
||||
|
||||
fn build_skill_name_counts(
|
||||
skills: &[SkillMetadata],
|
||||
disabled_paths: &HashSet<PathBuf>,
|
||||
) -> HashMap<String, usize> {
|
||||
let mut counts = HashMap::new();
|
||||
for skill in skills {
|
||||
if disabled_paths.contains(&skill.path) {
|
||||
continue;
|
||||
}
|
||||
*counts.entry(skill.name.clone()).or_insert(0) += 1;
|
||||
}
|
||||
counts
|
||||
}
|
||||
|
||||
fn collect_mentions(
|
||||
inputs: &[UserInput],
|
||||
skills: &[SkillMetadata],
|
||||
disabled_paths: &HashSet<PathBuf>,
|
||||
connector_slug_counts: &HashMap<String, usize>,
|
||||
) -> Vec<SkillMetadata> {
|
||||
let skill_name_counts = build_skill_name_counts(skills, disabled_paths);
|
||||
collect_explicit_skill_mentions(
|
||||
inputs,
|
||||
skills,
|
||||
disabled_paths,
|
||||
&skill_name_counts,
|
||||
connector_slug_counts,
|
||||
)
|
||||
collect_explicit_skill_mentions(inputs, skills, disabled_paths, connector_slug_counts)
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user