mirror of
https://github.com/openai/codex.git
synced 2026-04-30 17:36:40 +00:00
285 lines
9.1 KiB
Rust
285 lines
9.1 KiB
Rust
use crate::codex::Session;
|
|
use crate::codex::TurnContext;
|
|
use crate::function_tool::FunctionCallError;
|
|
use crate::sandboxing::SandboxPermissions;
|
|
use crate::tools::context::SharedTurnDiffTracker;
|
|
use crate::tools::context::ToolInvocation;
|
|
use crate::tools::context::ToolPayload;
|
|
use crate::tools::registry::AnyToolResult;
|
|
use crate::tools::registry::ToolRegistry;
|
|
use crate::tools::spec::build_specs_with_discoverable_tools;
|
|
use codex_mcp::ToolInfo;
|
|
use codex_protocol::dynamic_tools::DynamicToolSpec;
|
|
use codex_protocol::models::LocalShellAction;
|
|
use codex_protocol::models::ResponseItem;
|
|
use codex_protocol::models::SearchToolCallParams;
|
|
use codex_protocol::models::ShellToolCallParams;
|
|
use codex_tools::ConfiguredToolSpec;
|
|
use codex_tools::DiscoverableTool;
|
|
use codex_tools::ToolNamespace;
|
|
use codex_tools::ToolSpec;
|
|
use codex_tools::ToolsConfig;
|
|
use rmcp::model::Tool;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tracing::instrument;
|
|
|
|
pub use crate::tools::context::ToolCallSource;
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct ToolCall {
|
|
pub tool_name: String,
|
|
pub tool_namespace: Option<String>,
|
|
pub call_id: String,
|
|
pub payload: ToolPayload,
|
|
}
|
|
|
|
pub struct ToolRouter {
|
|
registry: ToolRegistry,
|
|
specs: Vec<ConfiguredToolSpec>,
|
|
model_visible_specs: Vec<ToolSpec>,
|
|
}
|
|
|
|
pub(crate) struct ToolRouterParams<'a> {
|
|
pub(crate) mcp_tools: Option<HashMap<String, Tool>>,
|
|
pub(crate) tool_namespaces: Option<HashMap<String, ToolNamespace>>,
|
|
pub(crate) app_tools: Option<HashMap<String, ToolInfo>>,
|
|
pub(crate) discoverable_tools: Option<Vec<DiscoverableTool>>,
|
|
pub(crate) dynamic_tools: &'a [DynamicToolSpec],
|
|
}
|
|
|
|
pub(crate) struct McpToolRouterInputs {
|
|
pub(crate) mcp_tools: HashMap<String, Tool>,
|
|
pub(crate) tool_namespaces: HashMap<String, ToolNamespace>,
|
|
}
|
|
|
|
pub(crate) fn map_mcp_tool_infos(mcp_tools: &HashMap<String, ToolInfo>) -> McpToolRouterInputs {
|
|
McpToolRouterInputs {
|
|
mcp_tools: mcp_tools
|
|
.iter()
|
|
.map(|(name, tool)| (name.clone(), tool.tool.clone()))
|
|
.collect(),
|
|
tool_namespaces: mcp_tools
|
|
.iter()
|
|
.map(|(name, tool)| {
|
|
(
|
|
name.clone(),
|
|
ToolNamespace {
|
|
name: tool.tool_namespace.clone(),
|
|
description: tool.server_instructions.clone(),
|
|
},
|
|
)
|
|
})
|
|
.collect(),
|
|
}
|
|
}
|
|
|
|
impl ToolRouter {
|
|
pub fn from_config(config: &ToolsConfig, params: ToolRouterParams<'_>) -> Self {
|
|
let ToolRouterParams {
|
|
mcp_tools,
|
|
tool_namespaces,
|
|
app_tools,
|
|
discoverable_tools,
|
|
dynamic_tools,
|
|
} = params;
|
|
let builder = build_specs_with_discoverable_tools(
|
|
config,
|
|
mcp_tools,
|
|
app_tools,
|
|
tool_namespaces,
|
|
discoverable_tools,
|
|
dynamic_tools,
|
|
);
|
|
let (specs, registry) = builder.build();
|
|
let model_visible_specs = if config.code_mode_only_enabled {
|
|
specs
|
|
.iter()
|
|
.filter_map(|configured_tool| {
|
|
if !codex_code_mode::is_code_mode_nested_tool(configured_tool.name()) {
|
|
Some(configured_tool.spec.clone())
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect()
|
|
} else {
|
|
specs
|
|
.iter()
|
|
.map(|configured_tool| configured_tool.spec.clone())
|
|
.collect()
|
|
};
|
|
|
|
Self {
|
|
registry,
|
|
specs,
|
|
model_visible_specs,
|
|
}
|
|
}
|
|
|
|
/// Returns a cloned copy of the spec of every configured tool, in stored order.
pub fn specs(&self) -> Vec<ToolSpec> {
    let mut all_specs = Vec::with_capacity(self.specs.len());
    for configured in &self.specs {
        all_specs.push(configured.spec.clone());
    }
    all_specs
}
|
|
|
|
/// Returns a cloned copy of the model-visible spec list computed in `from_config`.
pub fn model_visible_specs(&self) -> Vec<ToolSpec> {
    self.model_visible_specs.to_vec()
}
|
|
|
|
/// Looks up the first configured tool whose name equals `tool_name` and
/// returns a clone of its spec, or `None` when no tool matches.
pub fn find_spec(&self, tool_name: &str) -> Option<ToolSpec> {
    self.specs
        .iter()
        .find_map(|configured| (configured.name() == tool_name).then(|| configured.spec.clone()))
}
|
|
|
|
pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
|
|
self.specs
|
|
.iter()
|
|
.filter(|config| config.supports_parallel_tool_calls)
|
|
.any(|config| config.name() == tool_name)
|
|
}
|
|
|
|
#[instrument(level = "trace", skip_all, err)]
|
|
pub async fn build_tool_call(
|
|
session: &Session,
|
|
item: ResponseItem,
|
|
) -> Result<Option<ToolCall>, FunctionCallError> {
|
|
match item {
|
|
ResponseItem::FunctionCall {
|
|
name,
|
|
namespace,
|
|
arguments,
|
|
call_id,
|
|
..
|
|
} => {
|
|
if let Some((server, tool)) = session.parse_mcp_tool_name(&name, &namespace).await {
|
|
Ok(Some(ToolCall {
|
|
tool_name: name,
|
|
tool_namespace: namespace,
|
|
call_id,
|
|
payload: ToolPayload::Mcp {
|
|
server,
|
|
tool,
|
|
raw_arguments: arguments,
|
|
},
|
|
}))
|
|
} else {
|
|
Ok(Some(ToolCall {
|
|
tool_name: name,
|
|
tool_namespace: namespace,
|
|
call_id,
|
|
payload: ToolPayload::Function { arguments },
|
|
}))
|
|
}
|
|
}
|
|
ResponseItem::ToolSearchCall {
|
|
call_id: Some(call_id),
|
|
execution,
|
|
arguments,
|
|
..
|
|
} if execution == "client" => {
|
|
let arguments: SearchToolCallParams =
|
|
serde_json::from_value(arguments).map_err(|err| {
|
|
FunctionCallError::RespondToModel(format!(
|
|
"failed to parse tool_search arguments: {err}"
|
|
))
|
|
})?;
|
|
Ok(Some(ToolCall {
|
|
tool_name: "tool_search".to_string(),
|
|
tool_namespace: None,
|
|
call_id,
|
|
payload: ToolPayload::ToolSearch { arguments },
|
|
}))
|
|
}
|
|
ResponseItem::ToolSearchCall { .. } => Ok(None),
|
|
ResponseItem::CustomToolCall {
|
|
name,
|
|
input,
|
|
call_id,
|
|
..
|
|
} => Ok(Some(ToolCall {
|
|
tool_name: name,
|
|
tool_namespace: None,
|
|
call_id,
|
|
payload: ToolPayload::Custom { input },
|
|
})),
|
|
ResponseItem::LocalShellCall {
|
|
id,
|
|
call_id,
|
|
action,
|
|
..
|
|
} => {
|
|
let call_id = call_id
|
|
.or(id)
|
|
.ok_or(FunctionCallError::MissingLocalShellCallId)?;
|
|
|
|
match action {
|
|
LocalShellAction::Exec(exec) => {
|
|
let params = ShellToolCallParams {
|
|
command: exec.command,
|
|
workdir: exec.working_directory,
|
|
timeout_ms: exec.timeout_ms,
|
|
sandbox_permissions: Some(SandboxPermissions::UseDefault),
|
|
additional_permissions: None,
|
|
prefix_rule: None,
|
|
justification: None,
|
|
};
|
|
Ok(Some(ToolCall {
|
|
tool_name: "local_shell".to_string(),
|
|
tool_namespace: None,
|
|
call_id,
|
|
payload: ToolPayload::LocalShell { params },
|
|
}))
|
|
}
|
|
}
|
|
}
|
|
_ => Ok(None),
|
|
}
|
|
}
|
|
|
|
#[instrument(level = "trace", skip_all, err)]
|
|
pub async fn dispatch_tool_call_with_code_mode_result(
|
|
&self,
|
|
session: Arc<Session>,
|
|
turn: Arc<TurnContext>,
|
|
tracker: SharedTurnDiffTracker,
|
|
call: ToolCall,
|
|
source: ToolCallSource,
|
|
) -> Result<AnyToolResult, FunctionCallError> {
|
|
let ToolCall {
|
|
tool_name,
|
|
tool_namespace,
|
|
call_id,
|
|
payload,
|
|
} = call;
|
|
|
|
if source == ToolCallSource::Direct
|
|
&& turn.tools_config.js_repl_tools_only
|
|
&& !matches!(tool_name.as_str(), "js_repl" | "js_repl_reset")
|
|
{
|
|
return Err(FunctionCallError::RespondToModel(
|
|
"direct tool calls are disabled; use js_repl and codex.tool(...) instead"
|
|
.to_string(),
|
|
));
|
|
}
|
|
|
|
let invocation = ToolInvocation {
|
|
session,
|
|
turn,
|
|
tracker,
|
|
call_id,
|
|
tool_name,
|
|
tool_namespace,
|
|
payload,
|
|
};
|
|
|
|
self.registry.dispatch_any(invocation).await
|
|
}
|
|
}
|
|
// Test-only module whose source is loaded from `router_tests.rs`.
#[cfg(test)]
#[path = "router_tests.rs"]
mod tests;
|