refactor tool namespaced identity usage

This commit is contained in:
Sayan Sisodiya
2026-04-10 19:19:31 -07:00
parent c9d21ff2d3
commit df30d6c467
20 changed files with 205 additions and 126 deletions

View File

@@ -1191,11 +1191,14 @@ impl McpConnectionManager {
.with_context(|| format!("resources/read failed for `{server}` ({uri})"))
}
pub async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
self.list_all_tools()
.await
.get(tool_name)
.map(|tool| (tool.server_name.clone(), tool.tool.name.to_string()))
pub async fn resolve_tool_info(&self, name: &str, namespace: Option<&str>) -> Option<ToolInfo> {
let qualified_name = match namespace {
Some(namespace) if name.starts_with(namespace) => name.to_string(),
Some(namespace) => format!("{namespace}{name}"),
None => name.to_string(),
};
self.list_all_tools().await.get(&qualified_name).cloned()
}
pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {

View File

@@ -77,6 +77,7 @@ use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
use codex_login::default_client::originator;
use codex_mcp::McpConnectionManager;
use codex_mcp::SandboxState;
use codex_mcp::ToolInfo;
use codex_mcp::codex_apps_tools_cache_key;
#[cfg(test)]
use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig;
@@ -4412,25 +4413,16 @@ impl Session {
.await
}
pub(crate) async fn parse_mcp_tool_name(
pub(crate) async fn resolve_mcp_tool_info(
&self,
name: &str,
namespace: &Option<String>,
) -> Option<(String, String)> {
let tool_name = if let Some(namespace) = namespace {
if name.starts_with(namespace.as_str()) {
name
} else {
&format!("{namespace}{name}")
}
} else {
name
};
namespace: Option<&str>,
) -> Option<ToolInfo> {
self.services
.mcp_connection_manager
.read()
.await
.parse_tool_name(tool_name)
.resolve_tool_info(name, namespace)
.await
}

View File

@@ -38,13 +38,14 @@ pub(crate) async fn emit_metric_for_tool_read(invocation: &ToolInvocation, succe
}
let success = if success { "true" } else { "false" };
let tool_name = invocation.tool_name.display();
for kind in kinds {
invocation.turn.session_telemetry.counter(
MEMORIES_USAGE_METRIC,
/*inc*/ 1,
&[
("kind", kind.as_tag()),
("tool", invocation.tool_name.name()),
("tool", &tool_name),
("success", success),
],
);
@@ -77,8 +78,11 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
return None;
};
match invocation.tool_name.name() {
"shell" => serde_json::from_str::<ShellToolCallParams>(arguments)
match (
invocation.tool_name.namespace.as_deref(),
invocation.tool_name.name.as_str(),
) {
(None, "shell") => serde_json::from_str::<ShellToolCallParams>(arguments)
.ok()
.map(|params| {
(
@@ -86,7 +90,7 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
invocation.turn.resolve_path(params.workdir).to_path_buf(),
)
}),
"shell_command" => serde_json::from_str::<ShellCommandToolCallParams>(arguments)
(None, "shell_command") => serde_json::from_str::<ShellCommandToolCallParams>(arguments)
.ok()
.map(|params| {
if !invocation.turn.tools_config.allow_login_shell && params.login == Some(true) {
@@ -107,7 +111,7 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
invocation.turn.resolve_path(params.workdir).to_path_buf(),
)
}),
"exec_command" => serde_json::from_str::<ExecCommandArgs>(arguments)
(None, "exec_command") => serde_json::from_str::<ExecCommandArgs>(arguments)
.ok()
.and_then(|params| {
let command = crate::tools::handlers::unified_exec::get_command(
@@ -122,7 +126,7 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
invocation.turn.resolve_path(params.workdir).to_path_buf(),
))
}),
_ => None,
(Some(_), _) | (None, _) => None,
}
}

View File

@@ -221,7 +221,7 @@ pub(crate) async fn handle_output_item_done(
tracing::info!(
thread_id = %ctx.sess.conversation_id,
"ToolCall: {} {}",
call.tool_name.name(),
call.tool_name.display(),
payload_preview
);

View File

@@ -73,7 +73,9 @@ impl ToolHandler for CodeModeExecuteHandler {
} = invocation;
match payload {
ToolPayload::Custom { input } if tool_name.name() == PUBLIC_TOOL_NAME => {
ToolPayload::Custom { input }
if tool_name.namespace.is_none() && tool_name.name.as_str() == PUBLIC_TOOL_NAME =>
{
self.execute(session, turn, call_id, input).await
}
_ => Err(FunctionCallError::RespondToModel(format!(

View File

@@ -284,22 +284,25 @@ async fn call_nested_tool(
)));
}
let payload =
if let Some((server, tool)) = exec.session.parse_mcp_tool_name(&tool_name, &None).await {
match serialize_function_tool_arguments(&tool_name, input) {
Ok(raw_arguments) => ToolPayload::Mcp {
server,
tool,
raw_arguments,
},
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
}
} else {
match build_nested_tool_payload(tool_runtime.find_spec(&tool_name), &tool_name, input) {
Ok(payload) => payload,
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
}
};
let payload = if let Some(tool_info) = exec
.session
.resolve_mcp_tool_info(&tool_name, /*namespace*/ None)
.await
{
match serialize_function_tool_arguments(&tool_name, input) {
Ok(raw_arguments) => ToolPayload::Mcp {
server: tool_info.server_name,
tool: tool_info.tool.name.to_string(),
raw_arguments,
},
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
}
} else {
match build_nested_tool_payload(tool_runtime.find_spec(&tool_name), &tool_name, input) {
Ok(payload) => payload,
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
}
};
let call = ToolCall {
tool_name: ToolName::plain(tool_name.clone()),

View File

@@ -55,7 +55,9 @@ impl ToolHandler for CodeModeWaitHandler {
} = invocation;
match payload {
ToolPayload::Function { arguments } if tool_name.name() == WAIT_TOOL_NAME => {
ToolPayload::Function { arguments }
if tool_name.namespace.is_none() && tool_name.name.as_str() == WAIT_TOOL_NAME =>
{
let args: ExecWaitArgs = parse_arguments(&arguments)?;
let exec = ExecContext { session, turn };
let started_at = std::time::Instant::now();

View File

@@ -206,7 +206,7 @@ impl ToolHandler for BatchJobHandler {
}
};
match tool_name.name() {
match tool_name.name.as_str() {
"spawn_agents_on_csv" => spawn_agents_on_csv::handle(session, turn, arguments).await,
"report_agent_job_result" => report_agent_job_result::handle(session, arguments).await,
other => Err(FunctionCallError::RespondToModel(format!(

View File

@@ -218,7 +218,7 @@ impl ToolHandler for ApplyPatchHandler {
session: session.clone(),
turn: turn.clone(),
call_id: call_id.clone(),
tool_name: tool_name.name().to_string(),
tool_name: tool_name.display(),
};
let out = orchestrator
.run(

View File

@@ -50,19 +50,14 @@ impl ToolHandler for DynamicToolHandler {
};
let args: Value = parse_arguments(&arguments)?;
let response = request_dynamic_tool(
&session,
turn.as_ref(),
call_id,
tool_name.name().to_string(),
args,
)
.await
.ok_or_else(|| {
FunctionCallError::RespondToModel(
"dynamic tool call was cancelled before receiving a response".to_string(),
)
})?;
let response =
request_dynamic_tool(&session, turn.as_ref(), call_id, tool_name.display(), args)
.await
.ok_or_else(|| {
FunctionCallError::RespondToModel(
"dynamic tool call was cancelled before receiving a response".to_string(),
)
})?;
let DynamicToolResponse {
content_items,

View File

@@ -205,7 +205,7 @@ impl ToolHandler for McpResourceHandler {
let arguments_value = parse_arguments(arguments.as_str())?;
match tool_name.name() {
match tool_name.name.as_str() {
"list_mcp_resources" => {
handle_list_resources(
Arc::clone(&session),

View File

@@ -239,7 +239,7 @@ impl ToolHandler for ShellHandler {
let exec_params =
Self::to_exec_params(&params, turn.as_ref(), session.conversation_id);
Self::run_exec_like(RunExecLikeArgs {
tool_name: tool_name.name().to_string(),
tool_name: tool_name.display(),
exec_params,
additional_permissions: params.additional_permissions.clone(),
prefix_rule,
@@ -256,7 +256,7 @@ impl ToolHandler for ShellHandler {
let exec_params =
Self::to_exec_params(&params, turn.as_ref(), session.conversation_id);
Self::run_exec_like(RunExecLikeArgs {
tool_name: tool_name.name().to_string(),
tool_name: tool_name.display(),
exec_params,
additional_permissions: None,
prefix_rule: None,
@@ -271,7 +271,7 @@ impl ToolHandler for ShellHandler {
}
_ => Err(FunctionCallError::RespondToModel(format!(
"unsupported payload for shell handler: {}",
tool_name.name()
tool_name.display()
))),
}
}
@@ -341,7 +341,7 @@ impl ToolHandler for ShellCommandHandler {
let ToolPayload::Function { arguments } = payload else {
return Err(FunctionCallError::RespondToModel(format!(
"unsupported payload for shell_command handler: {}",
tool_name.name()
tool_name.display()
)));
};
@@ -364,7 +364,7 @@ impl ToolHandler for ShellCommandHandler {
turn.tools_config.allow_login_shell,
)?;
ShellHandler::run_exec_like(RunExecLikeArgs {
tool_name: tool_name.name().to_string(),
tool_name: tool_name.display(),
exec_params,
additional_permissions: params.additional_permissions.clone(),
prefix_rule,

View File

@@ -121,7 +121,9 @@ impl ToolHandler for UnifiedExecHandler {
}
fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
if invocation.tool_name.name() != "exec_command" {
if invocation.tool_name.namespace.is_some()
|| invocation.tool_name.name.as_str() != "exec_command"
{
return None;
}
@@ -186,7 +188,7 @@ impl ToolHandler for UnifiedExecHandler {
let manager: &UnifiedExecProcessManager = &session.services.unified_exec_manager;
let context = UnifiedExecContext::new(session.clone(), turn.clone(), call_id.clone());
let response = match tool_name.name() {
let response = match tool_name.name.as_str() {
"exec_command" => {
let cwd = resolve_workdir_base_path(&arguments, &context.turn.cwd)?;
let args: ExecCommandArgs = parse_arguments_with_base_path(&arguments, &cwd)?;
@@ -289,7 +291,7 @@ impl ToolHandler for UnifiedExecHandler {
context.turn.clone(),
Some(&tracker),
&context.call_id,
tool_name.name(),
&tool_name.name,
)
.await?
{

View File

@@ -1572,14 +1572,14 @@ impl JsReplManager {
},
);
let payload = if let Some((server, tool)) = exec
let payload = if let Some(tool_info) = exec
.session
.parse_mcp_tool_name(&req.tool_name, &None)
.resolve_mcp_tool_info(&req.tool_name, /*namespace*/ None)
.await
{
crate::tools::context::ToolPayload::Mcp {
server,
tool,
server: tool_info.server_name,
tool: tool_info.tool.name.to_string(),
raw_arguments: req.arguments.clone(),
}
} else if is_freeform_tool(&router.specs(), &req.tool_name) {

View File

@@ -78,18 +78,19 @@ impl ToolCallRuntime {
source: ToolCallSource,
cancellation_token: CancellationToken,
) -> impl std::future::Future<Output = Result<AnyToolResult, FunctionCallError>> {
let supports_parallel = self.router.tool_supports_parallel(call.tool_name.name());
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
let router = Arc::clone(&self.router);
let session = Arc::clone(&self.session);
let turn = Arc::clone(&self.turn_context);
let tracker = Arc::clone(&self.tracker);
let lock = Arc::clone(&self.parallel_execution);
let started = Instant::now();
let display_name = call.tool_name.display();
let dispatch_span = trace_span!(
"dispatch_tool_call_with_code_mode_result",
otel.name = call.tool_name.name(),
tool_name = call.tool_name.name(),
otel.name = display_name.as_str(),
tool_name = display_name.as_str(),
call_id = call.call_id.as_str(),
aborted = false,
);
@@ -171,11 +172,15 @@ impl ToolCallRuntime {
}
fn abort_message(call: &ToolCall, secs: f32) -> String {
match call.tool_name.name() {
"shell" | "container.exec" | "local_shell" | "shell_command" | "unified_exec" => {
format!("Wall time: {secs:.1} seconds\naborted by user")
}
_ => format!("aborted by user after {secs:.1}s"),
if call.tool_name.namespace.is_none()
&& matches!(
call.tool_name.name.as_str(),
"shell" | "container.exec" | "local_shell" | "shell_command" | "unified_exec"
)
{
format!("Wall time: {secs:.1} seconds\naborted by user")
} else {
format!("aborted by user after {secs:.1}s")
}
}
}

View File

@@ -180,25 +180,17 @@ where
}
}
fn tool_handler_key(tool_name: &ToolName) -> String {
if let Some(namespace) = tool_name.namespace() {
format!("{namespace}:{}", tool_name.name())
} else {
tool_name.name().to_string()
}
}
pub struct ToolRegistry {
handlers: HashMap<String, Arc<dyn AnyToolHandler>>,
handlers: HashMap<ToolName, Arc<dyn AnyToolHandler>>,
}
impl ToolRegistry {
fn new(handlers: HashMap<String, Arc<dyn AnyToolHandler>>) -> Self {
fn new(handlers: HashMap<ToolName, Arc<dyn AnyToolHandler>>) -> Self {
Self { handlers }
}
fn handler(&self, name: &ToolName) -> Option<Arc<dyn AnyToolHandler>> {
self.handlers.get(&tool_handler_key(name)).map(Arc::clone)
self.handlers.get(name).map(Arc::clone)
}
#[cfg(test)]
@@ -219,6 +211,7 @@ impl ToolRegistry {
invocation: ToolInvocation,
) -> Result<AnyToolResult, FunctionCallError> {
let tool_name = invocation.tool_name.clone();
let display_name = tool_name.display();
let call_id_owned = invocation.call_id.clone();
let otel = invocation.turn.session_telemetry.clone();
let payload_for_response = invocation.payload.clone();
@@ -265,7 +258,7 @@ impl ToolRegistry {
None => {
let message = unsupported_tool_call_message(&invocation.payload, &tool_name);
otel.tool_result_with_tags(
tool_name.name(),
&display_name,
&call_id_owned,
log_payload.as_ref(),
Duration::ZERO,
@@ -280,10 +273,9 @@ impl ToolRegistry {
};
if !handler.matches_kind(&invocation.payload) {
let key = tool_handler_key(&tool_name);
let message = format!("tool {key} invoked with incompatible payload");
let message = format!("tool {display_name} invoked with incompatible payload");
otel.tool_result_with_tags(
tool_name.name(),
&display_name,
&call_id_owned,
log_payload.as_ref(),
Duration::ZERO,
@@ -318,7 +310,7 @@ impl ToolRegistry {
let started = Instant::now();
let result = otel
.log_tool_result_with_tags(
tool_name.name(),
&display_name,
&call_id_owned,
log_payload.as_ref(),
&metric_tags,
@@ -438,7 +430,7 @@ impl ToolRegistry {
}
pub struct ToolRegistryBuilder {
handlers: HashMap<String, Arc<dyn AnyToolHandler>>,
handlers: HashMap<ToolName, Arc<dyn AnyToolHandler>>,
specs: Vec<ConfiguredToolSpec>,
}
@@ -468,10 +460,10 @@ impl ToolRegistryBuilder {
H: ToolHandler + 'static,
{
let name = name.into();
let key = tool_handler_key(&name);
let display_name = name.display();
let handler: Arc<dyn AnyToolHandler> = handler;
if self.handlers.insert(key.clone(), handler.clone()).is_some() {
warn!("overwriting handler for tool {key}");
if self.handlers.insert(name, handler).is_some() {
warn!("overwriting handler for tool {display_name}");
}
}
@@ -500,7 +492,7 @@ impl ToolRegistryBuilder {
}
fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &ToolName) -> String {
let tool_name = tool_handler_key(tool_name);
let tool_name = tool_name.display();
match payload {
ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
_ => format!("unsupported call: {tool_name}"),
@@ -593,7 +585,7 @@ async fn dispatch_after_tool_use_hook(
event: HookEventAfterToolUse {
turn_id: turn.sub_id.clone(),
call_id: invocation.call_id.clone(),
tool_name: invocation.tool_name.name().to_string(),
tool_name: invocation.tool_name.display(),
tool_kind: hook_tool_kind(&tool_input),
tool_input,
executed: dispatch.executed,
@@ -616,7 +608,7 @@ async fn dispatch_after_tool_use_hook(
HookResult::FailedContinue(error) => {
warn!(
call_id = %invocation.call_id,
tool_name = invocation.tool_name.name(),
tool_name = %invocation.tool_name.display(),
hook_name = %hook_name,
error = %error,
"after_tool_use hook failed; continuing"
@@ -625,7 +617,7 @@ async fn dispatch_after_tool_use_hook(
HookResult::FailedAbort(error) => {
warn!(
call_id = %invocation.call_id,
tool_name = invocation.tool_name.name(),
tool_name = %invocation.tool_name.display(),
hook_name = %hook_name,
error = %error,
"after_tool_use hook failed; aborting operation"

View File

@@ -24,11 +24,8 @@ fn handler_looks_up_namespaced_aliases_explicitly() {
let plain_name = codex_tools::ToolName::plain(tool_name);
let namespaced_name = codex_tools::ToolName::namespaced(namespace, tool_name);
let registry = ToolRegistry::new(HashMap::from([
(tool_handler_key(&plain_name), Arc::clone(&plain_handler)),
(
tool_handler_key(&namespaced_name),
Arc::clone(&namespaced_handler),
),
(plain_name.clone(), Arc::clone(&plain_handler)),
(namespaced_name.clone(), Arc::clone(&namespaced_handler)),
]));
let plain = registry.handler(&plain_name);

View File

@@ -104,11 +104,13 @@ impl ToolRouter {
.map(|config| config.spec.clone())
}
pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
self.specs
.iter()
.filter(|config| config.supports_parallel_tool_calls)
.any(|config| config.name() == tool_name)
pub fn tool_supports_parallel(&self, tool_name: &ToolName) -> bool {
tool_name.namespace.is_none()
&& self
.specs
.iter()
.filter(|config| config.supports_parallel_tool_calls)
.any(|config| config.name() == tool_name.name.as_str())
}
#[instrument(level = "trace", skip_all, err)]
@@ -124,18 +126,20 @@ impl ToolRouter {
call_id,
..
} => {
let mcp_tool = session.parse_mcp_tool_name(&name, &namespace).await;
let mcp_tool = session
.resolve_mcp_tool_info(&name, namespace.as_deref())
.await;
let tool_name = match namespace {
Some(namespace) => ToolName::namespaced(namespace, name),
None => ToolName::plain(name),
};
if let Some((server, tool)) = mcp_tool {
if let Some(tool_info) = mcp_tool {
Ok(Some(ToolCall {
tool_name,
call_id,
payload: ToolPayload::Mcp {
server,
tool,
server: tool_info.server_name,
tool: tool_info.tool.name.to_string(),
raw_arguments: arguments,
},
}))
@@ -224,7 +228,8 @@ impl ToolRouter {
payload,
} = call;
let direct_js_repl_call = matches!(tool_name.name(), "js_repl" | "js_repl_reset");
let direct_js_repl_call = tool_name.namespace.is_none()
&& matches!(tool_name.name.as_str(), "js_repl" | "js_repl_reset");
if source == ToolCallSource::Direct
&& turn.tools_config.js_repl_tools_only
&& !direct_js_repl_call

View File

@@ -117,6 +117,84 @@ async fn js_repl_tools_only_allows_js_repl_source_calls() -> anyhow::Result<()>
Ok(())
}
#[tokio::test]
async fn js_repl_tools_only_blocks_namespaced_js_repl_tool() -> anyhow::Result<()> {
let (session, mut turn) = make_session_and_context().await;
turn.tools_config.js_repl_tools_only = true;
let session = Arc::new(session);
let turn = Arc::new(turn);
let router = ToolRouter::from_config(
&turn.tools_config,
ToolRouterParams {
deferred_mcp_tools: None,
mcp_tools: None,
discoverable_tools: None,
dynamic_tools: turn.dynamic_tools.as_slice(),
},
);
let call = ToolCall {
tool_name: ToolName::namespaced("mcp__server__", "js_repl"),
call_id: "call-namespaced-js-repl".to_string(),
payload: ToolPayload::Mcp {
server: "server".to_string(),
tool: "js_repl".to_string(),
raw_arguments: "{}".to_string(),
},
};
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let err = router
.dispatch_tool_call_with_code_mode_result(
session,
turn,
tracker,
call,
ToolCallSource::Direct,
)
.await
.err()
.expect("namespaced js_repl calls should be blocked");
let FunctionCallError::RespondToModel(message) = err else {
panic!("expected RespondToModel, got {err:?}");
};
assert!(message.contains("direct tool calls are disabled"));
Ok(())
}
#[tokio::test]
async fn parallel_support_does_not_match_namespaced_local_tool_names() -> anyhow::Result<()> {
let (session, turn) = make_session_and_context().await;
let mcp_tools = session
.services
.mcp_connection_manager
.read()
.await
.list_all_tools()
.await;
let router = ToolRouter::from_config(
&turn.tools_config,
ToolRouterParams {
deferred_mcp_tools: None,
mcp_tools: Some(mcp_tools),
discoverable_tools: None,
dynamic_tools: turn.dynamic_tools.as_slice(),
},
);
let parallel_tool_name = ["shell", "local_shell", "exec_command", "shell_command"]
.into_iter()
.find(|name| router.tool_supports_parallel(&ToolName::plain(*name)))
.expect("test session should expose a parallel shell-like tool");
assert!(
!router.tool_supports_parallel(&ToolName::namespaced("mcp__server__", parallel_tool_name))
);
Ok(())
}
#[tokio::test]
async fn build_tool_call_uses_namespace_for_registry_name() -> anyhow::Result<()> {
let (session, _) = make_session_and_context().await;

View File

@@ -2,8 +2,8 @@
/// provides one.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ToolName {
name: String,
namespace: Option<String>,
pub name: String,
pub namespace: Option<String>,
}
impl ToolName {
@@ -21,12 +21,11 @@ impl ToolName {
}
}
pub fn name(&self) -> &str {
&self.name
}
pub fn namespace(&self) -> Option<&str> {
self.namespace.as_deref()
pub fn display(&self) -> String {
match &self.namespace {
Some(namespace) => format!("{namespace}{}", self.name),
None => self.name.clone(),
}
}
}