mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
refactor tool namespaced identity usage
This commit is contained in:
@@ -1191,11 +1191,14 @@ impl McpConnectionManager {
|
||||
.with_context(|| format!("resources/read failed for `{server}` ({uri})"))
|
||||
}
|
||||
|
||||
pub async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
|
||||
self.list_all_tools()
|
||||
.await
|
||||
.get(tool_name)
|
||||
.map(|tool| (tool.server_name.clone(), tool.tool.name.to_string()))
|
||||
pub async fn resolve_tool_info(&self, name: &str, namespace: Option<&str>) -> Option<ToolInfo> {
|
||||
let qualified_name = match namespace {
|
||||
Some(namespace) if name.starts_with(namespace) => name.to_string(),
|
||||
Some(namespace) => format!("{namespace}{name}"),
|
||||
None => name.to_string(),
|
||||
};
|
||||
|
||||
self.list_all_tools().await.get(&qualified_name).cloned()
|
||||
}
|
||||
|
||||
pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
|
||||
|
||||
@@ -77,6 +77,7 @@ use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
|
||||
use codex_login::default_client::originator;
|
||||
use codex_mcp::McpConnectionManager;
|
||||
use codex_mcp::SandboxState;
|
||||
use codex_mcp::ToolInfo;
|
||||
use codex_mcp::codex_apps_tools_cache_key;
|
||||
#[cfg(test)]
|
||||
use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig;
|
||||
@@ -4412,25 +4413,16 @@ impl Session {
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn parse_mcp_tool_name(
|
||||
pub(crate) async fn resolve_mcp_tool_info(
|
||||
&self,
|
||||
name: &str,
|
||||
namespace: &Option<String>,
|
||||
) -> Option<(String, String)> {
|
||||
let tool_name = if let Some(namespace) = namespace {
|
||||
if name.starts_with(namespace.as_str()) {
|
||||
name
|
||||
} else {
|
||||
&format!("{namespace}{name}")
|
||||
}
|
||||
} else {
|
||||
name
|
||||
};
|
||||
namespace: Option<&str>,
|
||||
) -> Option<ToolInfo> {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.parse_tool_name(tool_name)
|
||||
.resolve_tool_info(name, namespace)
|
||||
.await
|
||||
}
|
||||
|
||||
|
||||
@@ -38,13 +38,14 @@ pub(crate) async fn emit_metric_for_tool_read(invocation: &ToolInvocation, succe
|
||||
}
|
||||
|
||||
let success = if success { "true" } else { "false" };
|
||||
let tool_name = invocation.tool_name.display();
|
||||
for kind in kinds {
|
||||
invocation.turn.session_telemetry.counter(
|
||||
MEMORIES_USAGE_METRIC,
|
||||
/*inc*/ 1,
|
||||
&[
|
||||
("kind", kind.as_tag()),
|
||||
("tool", invocation.tool_name.name()),
|
||||
("tool", &tool_name),
|
||||
("success", success),
|
||||
],
|
||||
);
|
||||
@@ -77,8 +78,11 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
|
||||
return None;
|
||||
};
|
||||
|
||||
match invocation.tool_name.name() {
|
||||
"shell" => serde_json::from_str::<ShellToolCallParams>(arguments)
|
||||
match (
|
||||
invocation.tool_name.namespace.as_deref(),
|
||||
invocation.tool_name.name.as_str(),
|
||||
) {
|
||||
(None, "shell") => serde_json::from_str::<ShellToolCallParams>(arguments)
|
||||
.ok()
|
||||
.map(|params| {
|
||||
(
|
||||
@@ -86,7 +90,7 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
|
||||
invocation.turn.resolve_path(params.workdir).to_path_buf(),
|
||||
)
|
||||
}),
|
||||
"shell_command" => serde_json::from_str::<ShellCommandToolCallParams>(arguments)
|
||||
(None, "shell_command") => serde_json::from_str::<ShellCommandToolCallParams>(arguments)
|
||||
.ok()
|
||||
.map(|params| {
|
||||
if !invocation.turn.tools_config.allow_login_shell && params.login == Some(true) {
|
||||
@@ -107,7 +111,7 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
|
||||
invocation.turn.resolve_path(params.workdir).to_path_buf(),
|
||||
)
|
||||
}),
|
||||
"exec_command" => serde_json::from_str::<ExecCommandArgs>(arguments)
|
||||
(None, "exec_command") => serde_json::from_str::<ExecCommandArgs>(arguments)
|
||||
.ok()
|
||||
.and_then(|params| {
|
||||
let command = crate::tools::handlers::unified_exec::get_command(
|
||||
@@ -122,7 +126,7 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
|
||||
invocation.turn.resolve_path(params.workdir).to_path_buf(),
|
||||
))
|
||||
}),
|
||||
_ => None,
|
||||
(Some(_), _) | (None, _) => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -221,7 +221,7 @@ pub(crate) async fn handle_output_item_done(
|
||||
tracing::info!(
|
||||
thread_id = %ctx.sess.conversation_id,
|
||||
"ToolCall: {} {}",
|
||||
call.tool_name.name(),
|
||||
call.tool_name.display(),
|
||||
payload_preview
|
||||
);
|
||||
|
||||
|
||||
@@ -73,7 +73,9 @@ impl ToolHandler for CodeModeExecuteHandler {
|
||||
} = invocation;
|
||||
|
||||
match payload {
|
||||
ToolPayload::Custom { input } if tool_name.name() == PUBLIC_TOOL_NAME => {
|
||||
ToolPayload::Custom { input }
|
||||
if tool_name.namespace.is_none() && tool_name.name.as_str() == PUBLIC_TOOL_NAME =>
|
||||
{
|
||||
self.execute(session, turn, call_id, input).await
|
||||
}
|
||||
_ => Err(FunctionCallError::RespondToModel(format!(
|
||||
|
||||
@@ -284,22 +284,25 @@ async fn call_nested_tool(
|
||||
)));
|
||||
}
|
||||
|
||||
let payload =
|
||||
if let Some((server, tool)) = exec.session.parse_mcp_tool_name(&tool_name, &None).await {
|
||||
match serialize_function_tool_arguments(&tool_name, input) {
|
||||
Ok(raw_arguments) => ToolPayload::Mcp {
|
||||
server,
|
||||
tool,
|
||||
raw_arguments,
|
||||
},
|
||||
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
|
||||
}
|
||||
} else {
|
||||
match build_nested_tool_payload(tool_runtime.find_spec(&tool_name), &tool_name, input) {
|
||||
Ok(payload) => payload,
|
||||
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
|
||||
}
|
||||
};
|
||||
let payload = if let Some(tool_info) = exec
|
||||
.session
|
||||
.resolve_mcp_tool_info(&tool_name, /*namespace*/ None)
|
||||
.await
|
||||
{
|
||||
match serialize_function_tool_arguments(&tool_name, input) {
|
||||
Ok(raw_arguments) => ToolPayload::Mcp {
|
||||
server: tool_info.server_name,
|
||||
tool: tool_info.tool.name.to_string(),
|
||||
raw_arguments,
|
||||
},
|
||||
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
|
||||
}
|
||||
} else {
|
||||
match build_nested_tool_payload(tool_runtime.find_spec(&tool_name), &tool_name, input) {
|
||||
Ok(payload) => payload,
|
||||
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
|
||||
}
|
||||
};
|
||||
|
||||
let call = ToolCall {
|
||||
tool_name: ToolName::plain(tool_name.clone()),
|
||||
|
||||
@@ -55,7 +55,9 @@ impl ToolHandler for CodeModeWaitHandler {
|
||||
} = invocation;
|
||||
|
||||
match payload {
|
||||
ToolPayload::Function { arguments } if tool_name.name() == WAIT_TOOL_NAME => {
|
||||
ToolPayload::Function { arguments }
|
||||
if tool_name.namespace.is_none() && tool_name.name.as_str() == WAIT_TOOL_NAME =>
|
||||
{
|
||||
let args: ExecWaitArgs = parse_arguments(&arguments)?;
|
||||
let exec = ExecContext { session, turn };
|
||||
let started_at = std::time::Instant::now();
|
||||
|
||||
@@ -206,7 +206,7 @@ impl ToolHandler for BatchJobHandler {
|
||||
}
|
||||
};
|
||||
|
||||
match tool_name.name() {
|
||||
match tool_name.name.as_str() {
|
||||
"spawn_agents_on_csv" => spawn_agents_on_csv::handle(session, turn, arguments).await,
|
||||
"report_agent_job_result" => report_agent_job_result::handle(session, arguments).await,
|
||||
other => Err(FunctionCallError::RespondToModel(format!(
|
||||
|
||||
@@ -218,7 +218,7 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
session: session.clone(),
|
||||
turn: turn.clone(),
|
||||
call_id: call_id.clone(),
|
||||
tool_name: tool_name.name().to_string(),
|
||||
tool_name: tool_name.display(),
|
||||
};
|
||||
let out = orchestrator
|
||||
.run(
|
||||
|
||||
@@ -50,19 +50,14 @@ impl ToolHandler for DynamicToolHandler {
|
||||
};
|
||||
|
||||
let args: Value = parse_arguments(&arguments)?;
|
||||
let response = request_dynamic_tool(
|
||||
&session,
|
||||
turn.as_ref(),
|
||||
call_id,
|
||||
tool_name.name().to_string(),
|
||||
args,
|
||||
)
|
||||
.await
|
||||
.ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(
|
||||
"dynamic tool call was cancelled before receiving a response".to_string(),
|
||||
)
|
||||
})?;
|
||||
let response =
|
||||
request_dynamic_tool(&session, turn.as_ref(), call_id, tool_name.display(), args)
|
||||
.await
|
||||
.ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(
|
||||
"dynamic tool call was cancelled before receiving a response".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let DynamicToolResponse {
|
||||
content_items,
|
||||
|
||||
@@ -205,7 +205,7 @@ impl ToolHandler for McpResourceHandler {
|
||||
|
||||
let arguments_value = parse_arguments(arguments.as_str())?;
|
||||
|
||||
match tool_name.name() {
|
||||
match tool_name.name.as_str() {
|
||||
"list_mcp_resources" => {
|
||||
handle_list_resources(
|
||||
Arc::clone(&session),
|
||||
|
||||
@@ -239,7 +239,7 @@ impl ToolHandler for ShellHandler {
|
||||
let exec_params =
|
||||
Self::to_exec_params(¶ms, turn.as_ref(), session.conversation_id);
|
||||
Self::run_exec_like(RunExecLikeArgs {
|
||||
tool_name: tool_name.name().to_string(),
|
||||
tool_name: tool_name.display(),
|
||||
exec_params,
|
||||
additional_permissions: params.additional_permissions.clone(),
|
||||
prefix_rule,
|
||||
@@ -256,7 +256,7 @@ impl ToolHandler for ShellHandler {
|
||||
let exec_params =
|
||||
Self::to_exec_params(¶ms, turn.as_ref(), session.conversation_id);
|
||||
Self::run_exec_like(RunExecLikeArgs {
|
||||
tool_name: tool_name.name().to_string(),
|
||||
tool_name: tool_name.display(),
|
||||
exec_params,
|
||||
additional_permissions: None,
|
||||
prefix_rule: None,
|
||||
@@ -271,7 +271,7 @@ impl ToolHandler for ShellHandler {
|
||||
}
|
||||
_ => Err(FunctionCallError::RespondToModel(format!(
|
||||
"unsupported payload for shell handler: {}",
|
||||
tool_name.name()
|
||||
tool_name.display()
|
||||
))),
|
||||
}
|
||||
}
|
||||
@@ -341,7 +341,7 @@ impl ToolHandler for ShellCommandHandler {
|
||||
let ToolPayload::Function { arguments } = payload else {
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"unsupported payload for shell_command handler: {}",
|
||||
tool_name.name()
|
||||
tool_name.display()
|
||||
)));
|
||||
};
|
||||
|
||||
@@ -364,7 +364,7 @@ impl ToolHandler for ShellCommandHandler {
|
||||
turn.tools_config.allow_login_shell,
|
||||
)?;
|
||||
ShellHandler::run_exec_like(RunExecLikeArgs {
|
||||
tool_name: tool_name.name().to_string(),
|
||||
tool_name: tool_name.display(),
|
||||
exec_params,
|
||||
additional_permissions: params.additional_permissions.clone(),
|
||||
prefix_rule,
|
||||
|
||||
@@ -121,7 +121,9 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
}
|
||||
|
||||
fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
|
||||
if invocation.tool_name.name() != "exec_command" {
|
||||
if invocation.tool_name.namespace.is_some()
|
||||
|| invocation.tool_name.name.as_str() != "exec_command"
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -186,7 +188,7 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
let manager: &UnifiedExecProcessManager = &session.services.unified_exec_manager;
|
||||
let context = UnifiedExecContext::new(session.clone(), turn.clone(), call_id.clone());
|
||||
|
||||
let response = match tool_name.name() {
|
||||
let response = match tool_name.name.as_str() {
|
||||
"exec_command" => {
|
||||
let cwd = resolve_workdir_base_path(&arguments, &context.turn.cwd)?;
|
||||
let args: ExecCommandArgs = parse_arguments_with_base_path(&arguments, &cwd)?;
|
||||
@@ -289,7 +291,7 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
context.turn.clone(),
|
||||
Some(&tracker),
|
||||
&context.call_id,
|
||||
tool_name.name(),
|
||||
&tool_name.name,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
|
||||
@@ -1572,14 +1572,14 @@ impl JsReplManager {
|
||||
},
|
||||
);
|
||||
|
||||
let payload = if let Some((server, tool)) = exec
|
||||
let payload = if let Some(tool_info) = exec
|
||||
.session
|
||||
.parse_mcp_tool_name(&req.tool_name, &None)
|
||||
.resolve_mcp_tool_info(&req.tool_name, /*namespace*/ None)
|
||||
.await
|
||||
{
|
||||
crate::tools::context::ToolPayload::Mcp {
|
||||
server,
|
||||
tool,
|
||||
server: tool_info.server_name,
|
||||
tool: tool_info.tool.name.to_string(),
|
||||
raw_arguments: req.arguments.clone(),
|
||||
}
|
||||
} else if is_freeform_tool(&router.specs(), &req.tool_name) {
|
||||
|
||||
@@ -78,18 +78,19 @@ impl ToolCallRuntime {
|
||||
source: ToolCallSource,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> impl std::future::Future<Output = Result<AnyToolResult, FunctionCallError>> {
|
||||
let supports_parallel = self.router.tool_supports_parallel(call.tool_name.name());
|
||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||
let router = Arc::clone(&self.router);
|
||||
let session = Arc::clone(&self.session);
|
||||
let turn = Arc::clone(&self.turn_context);
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
let started = Instant::now();
|
||||
let display_name = call.tool_name.display();
|
||||
|
||||
let dispatch_span = trace_span!(
|
||||
"dispatch_tool_call_with_code_mode_result",
|
||||
otel.name = call.tool_name.name(),
|
||||
tool_name = call.tool_name.name(),
|
||||
otel.name = display_name.as_str(),
|
||||
tool_name = display_name.as_str(),
|
||||
call_id = call.call_id.as_str(),
|
||||
aborted = false,
|
||||
);
|
||||
@@ -171,11 +172,15 @@ impl ToolCallRuntime {
|
||||
}
|
||||
|
||||
fn abort_message(call: &ToolCall, secs: f32) -> String {
|
||||
match call.tool_name.name() {
|
||||
"shell" | "container.exec" | "local_shell" | "shell_command" | "unified_exec" => {
|
||||
format!("Wall time: {secs:.1} seconds\naborted by user")
|
||||
}
|
||||
_ => format!("aborted by user after {secs:.1}s"),
|
||||
if call.tool_name.namespace.is_none()
|
||||
&& matches!(
|
||||
call.tool_name.name.as_str(),
|
||||
"shell" | "container.exec" | "local_shell" | "shell_command" | "unified_exec"
|
||||
)
|
||||
{
|
||||
format!("Wall time: {secs:.1} seconds\naborted by user")
|
||||
} else {
|
||||
format!("aborted by user after {secs:.1}s")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,25 +180,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_handler_key(tool_name: &ToolName) -> String {
|
||||
if let Some(namespace) = tool_name.namespace() {
|
||||
format!("{namespace}:{}", tool_name.name())
|
||||
} else {
|
||||
tool_name.name().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolRegistry {
|
||||
handlers: HashMap<String, Arc<dyn AnyToolHandler>>,
|
||||
handlers: HashMap<ToolName, Arc<dyn AnyToolHandler>>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
fn new(handlers: HashMap<String, Arc<dyn AnyToolHandler>>) -> Self {
|
||||
fn new(handlers: HashMap<ToolName, Arc<dyn AnyToolHandler>>) -> Self {
|
||||
Self { handlers }
|
||||
}
|
||||
|
||||
fn handler(&self, name: &ToolName) -> Option<Arc<dyn AnyToolHandler>> {
|
||||
self.handlers.get(&tool_handler_key(name)).map(Arc::clone)
|
||||
self.handlers.get(name).map(Arc::clone)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -219,6 +211,7 @@ impl ToolRegistry {
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<AnyToolResult, FunctionCallError> {
|
||||
let tool_name = invocation.tool_name.clone();
|
||||
let display_name = tool_name.display();
|
||||
let call_id_owned = invocation.call_id.clone();
|
||||
let otel = invocation.turn.session_telemetry.clone();
|
||||
let payload_for_response = invocation.payload.clone();
|
||||
@@ -265,7 +258,7 @@ impl ToolRegistry {
|
||||
None => {
|
||||
let message = unsupported_tool_call_message(&invocation.payload, &tool_name);
|
||||
otel.tool_result_with_tags(
|
||||
tool_name.name(),
|
||||
&display_name,
|
||||
&call_id_owned,
|
||||
log_payload.as_ref(),
|
||||
Duration::ZERO,
|
||||
@@ -280,10 +273,9 @@ impl ToolRegistry {
|
||||
};
|
||||
|
||||
if !handler.matches_kind(&invocation.payload) {
|
||||
let key = tool_handler_key(&tool_name);
|
||||
let message = format!("tool {key} invoked with incompatible payload");
|
||||
let message = format!("tool {display_name} invoked with incompatible payload");
|
||||
otel.tool_result_with_tags(
|
||||
tool_name.name(),
|
||||
&display_name,
|
||||
&call_id_owned,
|
||||
log_payload.as_ref(),
|
||||
Duration::ZERO,
|
||||
@@ -318,7 +310,7 @@ impl ToolRegistry {
|
||||
let started = Instant::now();
|
||||
let result = otel
|
||||
.log_tool_result_with_tags(
|
||||
tool_name.name(),
|
||||
&display_name,
|
||||
&call_id_owned,
|
||||
log_payload.as_ref(),
|
||||
&metric_tags,
|
||||
@@ -438,7 +430,7 @@ impl ToolRegistry {
|
||||
}
|
||||
|
||||
pub struct ToolRegistryBuilder {
|
||||
handlers: HashMap<String, Arc<dyn AnyToolHandler>>,
|
||||
handlers: HashMap<ToolName, Arc<dyn AnyToolHandler>>,
|
||||
specs: Vec<ConfiguredToolSpec>,
|
||||
}
|
||||
|
||||
@@ -468,10 +460,10 @@ impl ToolRegistryBuilder {
|
||||
H: ToolHandler + 'static,
|
||||
{
|
||||
let name = name.into();
|
||||
let key = tool_handler_key(&name);
|
||||
let display_name = name.display();
|
||||
let handler: Arc<dyn AnyToolHandler> = handler;
|
||||
if self.handlers.insert(key.clone(), handler.clone()).is_some() {
|
||||
warn!("overwriting handler for tool {key}");
|
||||
if self.handlers.insert(name, handler).is_some() {
|
||||
warn!("overwriting handler for tool {display_name}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -500,7 +492,7 @@ impl ToolRegistryBuilder {
|
||||
}
|
||||
|
||||
fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &ToolName) -> String {
|
||||
let tool_name = tool_handler_key(tool_name);
|
||||
let tool_name = tool_name.display();
|
||||
match payload {
|
||||
ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
|
||||
_ => format!("unsupported call: {tool_name}"),
|
||||
@@ -593,7 +585,7 @@ async fn dispatch_after_tool_use_hook(
|
||||
event: HookEventAfterToolUse {
|
||||
turn_id: turn.sub_id.clone(),
|
||||
call_id: invocation.call_id.clone(),
|
||||
tool_name: invocation.tool_name.name().to_string(),
|
||||
tool_name: invocation.tool_name.display(),
|
||||
tool_kind: hook_tool_kind(&tool_input),
|
||||
tool_input,
|
||||
executed: dispatch.executed,
|
||||
@@ -616,7 +608,7 @@ async fn dispatch_after_tool_use_hook(
|
||||
HookResult::FailedContinue(error) => {
|
||||
warn!(
|
||||
call_id = %invocation.call_id,
|
||||
tool_name = invocation.tool_name.name(),
|
||||
tool_name = %invocation.tool_name.display(),
|
||||
hook_name = %hook_name,
|
||||
error = %error,
|
||||
"after_tool_use hook failed; continuing"
|
||||
@@ -625,7 +617,7 @@ async fn dispatch_after_tool_use_hook(
|
||||
HookResult::FailedAbort(error) => {
|
||||
warn!(
|
||||
call_id = %invocation.call_id,
|
||||
tool_name = invocation.tool_name.name(),
|
||||
tool_name = %invocation.tool_name.display(),
|
||||
hook_name = %hook_name,
|
||||
error = %error,
|
||||
"after_tool_use hook failed; aborting operation"
|
||||
|
||||
@@ -24,11 +24,8 @@ fn handler_looks_up_namespaced_aliases_explicitly() {
|
||||
let plain_name = codex_tools::ToolName::plain(tool_name);
|
||||
let namespaced_name = codex_tools::ToolName::namespaced(namespace, tool_name);
|
||||
let registry = ToolRegistry::new(HashMap::from([
|
||||
(tool_handler_key(&plain_name), Arc::clone(&plain_handler)),
|
||||
(
|
||||
tool_handler_key(&namespaced_name),
|
||||
Arc::clone(&namespaced_handler),
|
||||
),
|
||||
(plain_name.clone(), Arc::clone(&plain_handler)),
|
||||
(namespaced_name.clone(), Arc::clone(&namespaced_handler)),
|
||||
]));
|
||||
|
||||
let plain = registry.handler(&plain_name);
|
||||
|
||||
@@ -104,11 +104,13 @@ impl ToolRouter {
|
||||
.map(|config| config.spec.clone())
|
||||
}
|
||||
|
||||
pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
|
||||
self.specs
|
||||
.iter()
|
||||
.filter(|config| config.supports_parallel_tool_calls)
|
||||
.any(|config| config.name() == tool_name)
|
||||
pub fn tool_supports_parallel(&self, tool_name: &ToolName) -> bool {
|
||||
tool_name.namespace.is_none()
|
||||
&& self
|
||||
.specs
|
||||
.iter()
|
||||
.filter(|config| config.supports_parallel_tool_calls)
|
||||
.any(|config| config.name() == tool_name.name.as_str())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all, err)]
|
||||
@@ -124,18 +126,20 @@ impl ToolRouter {
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
let mcp_tool = session.parse_mcp_tool_name(&name, &namespace).await;
|
||||
let mcp_tool = session
|
||||
.resolve_mcp_tool_info(&name, namespace.as_deref())
|
||||
.await;
|
||||
let tool_name = match namespace {
|
||||
Some(namespace) => ToolName::namespaced(namespace, name),
|
||||
None => ToolName::plain(name),
|
||||
};
|
||||
if let Some((server, tool)) = mcp_tool {
|
||||
if let Some(tool_info) = mcp_tool {
|
||||
Ok(Some(ToolCall {
|
||||
tool_name,
|
||||
call_id,
|
||||
payload: ToolPayload::Mcp {
|
||||
server,
|
||||
tool,
|
||||
server: tool_info.server_name,
|
||||
tool: tool_info.tool.name.to_string(),
|
||||
raw_arguments: arguments,
|
||||
},
|
||||
}))
|
||||
@@ -224,7 +228,8 @@ impl ToolRouter {
|
||||
payload,
|
||||
} = call;
|
||||
|
||||
let direct_js_repl_call = matches!(tool_name.name(), "js_repl" | "js_repl_reset");
|
||||
let direct_js_repl_call = tool_name.namespace.is_none()
|
||||
&& matches!(tool_name.name.as_str(), "js_repl" | "js_repl_reset");
|
||||
if source == ToolCallSource::Direct
|
||||
&& turn.tools_config.js_repl_tools_only
|
||||
&& !direct_js_repl_call
|
||||
|
||||
@@ -117,6 +117,84 @@ async fn js_repl_tools_only_allows_js_repl_source_calls() -> anyhow::Result<()>
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn js_repl_tools_only_blocks_namespaced_js_repl_tool() -> anyhow::Result<()> {
|
||||
let (session, mut turn) = make_session_and_context().await;
|
||||
turn.tools_config.js_repl_tools_only = true;
|
||||
|
||||
let session = Arc::new(session);
|
||||
let turn = Arc::new(turn);
|
||||
let router = ToolRouter::from_config(
|
||||
&turn.tools_config,
|
||||
ToolRouterParams {
|
||||
deferred_mcp_tools: None,
|
||||
mcp_tools: None,
|
||||
discoverable_tools: None,
|
||||
dynamic_tools: turn.dynamic_tools.as_slice(),
|
||||
},
|
||||
);
|
||||
|
||||
let call = ToolCall {
|
||||
tool_name: ToolName::namespaced("mcp__server__", "js_repl"),
|
||||
call_id: "call-namespaced-js-repl".to_string(),
|
||||
payload: ToolPayload::Mcp {
|
||||
server: "server".to_string(),
|
||||
tool: "js_repl".to_string(),
|
||||
raw_arguments: "{}".to_string(),
|
||||
},
|
||||
};
|
||||
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let err = router
|
||||
.dispatch_tool_call_with_code_mode_result(
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
call,
|
||||
ToolCallSource::Direct,
|
||||
)
|
||||
.await
|
||||
.err()
|
||||
.expect("namespaced js_repl calls should be blocked");
|
||||
let FunctionCallError::RespondToModel(message) = err else {
|
||||
panic!("expected RespondToModel, got {err:?}");
|
||||
};
|
||||
assert!(message.contains("direct tool calls are disabled"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parallel_support_does_not_match_namespaced_local_tool_names() -> anyhow::Result<()> {
|
||||
let (session, turn) = make_session_and_context().await;
|
||||
let mcp_tools = session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.list_all_tools()
|
||||
.await;
|
||||
let router = ToolRouter::from_config(
|
||||
&turn.tools_config,
|
||||
ToolRouterParams {
|
||||
deferred_mcp_tools: None,
|
||||
mcp_tools: Some(mcp_tools),
|
||||
discoverable_tools: None,
|
||||
dynamic_tools: turn.dynamic_tools.as_slice(),
|
||||
},
|
||||
);
|
||||
|
||||
let parallel_tool_name = ["shell", "local_shell", "exec_command", "shell_command"]
|
||||
.into_iter()
|
||||
.find(|name| router.tool_supports_parallel(&ToolName::plain(*name)))
|
||||
.expect("test session should expose a parallel shell-like tool");
|
||||
|
||||
assert!(
|
||||
!router.tool_supports_parallel(&ToolName::namespaced("mcp__server__", parallel_tool_name))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn build_tool_call_uses_namespace_for_registry_name() -> anyhow::Result<()> {
|
||||
let (session, _) = make_session_and_context().await;
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
/// provides one.
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
pub struct ToolName {
|
||||
name: String,
|
||||
namespace: Option<String>,
|
||||
pub name: String,
|
||||
pub namespace: Option<String>,
|
||||
}
|
||||
|
||||
impl ToolName {
|
||||
@@ -21,12 +21,11 @@ impl ToolName {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn namespace(&self) -> Option<&str> {
|
||||
self.namespace.as_deref()
|
||||
pub fn display(&self) -> String {
|
||||
match &self.namespace {
|
||||
Some(namespace) => format!("{namespace}{}", self.name),
|
||||
None => self.name.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user