codex/codex-rs/core/src/tools/registry.rs
Michael Bolin 3c7f013f97 core: cut codex-core compile time 63% with native async ToolHandler (#16630)
## Why

`ToolHandler` was still paying a large compile-time tax from
`#[async_trait]` on every concrete handler impl, even though the only
object-safe boundary the registry actually stores is the internal
`AnyToolHandler` adapter.

This PR removes that macro-generated async wrapper layer from concrete
`ToolHandler` impls while keeping the existing object-safe shim in
`AnyToolHandler`. In practice, that gets essentially the same
compile-time win as the larger type-erasure refactor in #16627, but with
a much smaller diff and without changing the public shape of
`ToolHandler<Output = T>`.
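Here is a minimal sketch of the two trait shapes. All names are hypothetical and do not come from the codex codebase: `MacroStyleHandler`, `NativeHandler`, `LenHandler`, and the toy `block_on` executor. The first trait approximates the signature `#[async_trait]` gives an `async fn`: every impl returns a boxed, type-erased `Send` future. The real macro expansion differs in detail. The second trait uses a return-position `impl Future + Send` in the trait, and an impl can satisfy that with a plain `async fn` (stable since Rust 1.75). The sketch only illustrates the shape. It does not reproduce the compile-time effect.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Boxed future type, shaped like `futures::future::BoxFuture`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Approximately the signature `#[async_trait]` produces for `async fn handle`:
/// every impl returns a heap-allocated, type-erased future.
trait MacroStyleHandler: Send + Sync {
    fn handle<'a>(&'a self, input: String) -> BoxFuture<'a, usize>;
}

/// Native form: the trait names the future via RPITIT with an explicit `Send`
/// bound, and each impl keeps its own concrete future type.
trait NativeHandler: Send + Sync {
    fn handle(&self, input: String) -> impl Future<Output = usize> + Send;
}

/// Hypothetical handler implementing both shapes.
struct LenHandler;

impl MacroStyleHandler for LenHandler {
    fn handle<'a>(&'a self, input: String) -> BoxFuture<'a, usize> {
        // Roughly what the macro generates: the body wrapped in Box::pin.
        Box::pin(async move { input.len() })
    }
}

impl NativeHandler for LenHandler {
    // A plain `async fn` satisfies `-> impl Future<Output = usize> + Send`.
    async fn handle(&self, input: String) -> usize {
        input.len()
    }
}

/// Toy executor for the example: busy-polls a future to completion.
/// Only suitable for futures that make progress without an external reactor.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
            return output;
        }
        std::thread::yield_now();
    }
}
```

For example, `block_on(NativeHandler::handle(&LenHandler, "abcd".to_string()))` yields `4`. The fully qualified call is needed only because `LenHandler` implements both traits.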

That tradeoff matters here because this is a broad `codex-core` hotspot
and reviewers should be able to judge the compile-time impact from hard
numbers, not vibes.

## Headline result

On a clean `codex-core` package rebuild (`cargo clean -p codex-core`
before each command), rustc `total` dropped from **187.15s to 68.98s**
versus the shared `0bd31dc382bd` baseline: **-63.1%**.

The biggest hot passes dropped by roughly **71-72%**:

| Metric | Baseline `0bd31dc382bd` | This PR `41f7ac0adeac` | Delta |
|---|---:|---:|---:|
| `total` | 187.15s | 68.98s | **-63.1%** |
| `generate_crate_metadata` | 84.53s | 24.49s | **-71.0%** |
| `MIR_borrow_checking` | 84.13s | 24.58s | **-70.8%** |
| `monomorphization_collector_graph_walk` | 79.74s | 22.19s | **-72.2%** |
| `evaluate_obligation` self-time | 180.62s | 46.91s | **-74.0%** |

Important caveat: `-Z time-passes` timings are nested, so
`generate_crate_metadata` and `monomorphization_collector_graph_walk`
are mostly overlapping, not additive.
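The Delta columns in these tables appear to be computed as (new - baseline) / baseline. The small helper below (hypothetical names `delta_pct` and `format_delta`) reproduces the rounded values from the headline table. As the caveat says, each delta is per-row only. Adding rows together would double-count nested passes.

```rust
/// Percent change used by the Delta columns: (new - baseline) / baseline * 100.
fn delta_pct(baseline: f64, new: f64) -> f64 {
    (new - baseline) / baseline * 100.0
}

/// Formats a delta the way the tables print it, with an explicit sign and one
/// decimal place, e.g. "-63.1%" or "+8.4%".
fn format_delta(baseline: f64, new: f64) -> String {
    format!("{:+.1}%", delta_pct(baseline, new))
}
```

For example, `format_delta(187.15, 68.98)` gives `"-63.1%"`, matching the `total` row.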

## Why this PR over #16627

#16627 already proved that the `ToolHandler` stack was the right
hotspot, but it got there by making `ToolHandler` object-safe and
changing every handler to return `BoxFuture<Result<AnyToolResult, _>>`
directly.

This PR keeps the lower-churn shape:

- `ToolHandler` remains generic over `type Output`.
- Concrete handlers use native RPITIT futures with explicit `Send`
bounds.
- `AnyToolHandler` remains the only object-safe adapter and still does
the boxing at the registry boundary, as before.
- The implementation diff is only **33 files, +28/-77**.
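Here is a stripped-down sketch of that shape. It uses hypothetical `Handler`, `AnyHandler`, and `Registry` types, not the real `ToolHandler`, `AnyToolHandler`, and `ToolRegistry`. The typed trait keeps an associated `Output` and an `impl Future + Send` return. A blanket impl of the object-safe adapter is the only place a future gets boxed. As a simplification, the adapter flattens output to a `String` via `Debug`, whereas the real adapter stores a `Box<dyn ToolOutput>`.

```rust
use std::collections::HashMap;
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Typed trait: not object-safe (associated type plus `impl Future` return).
trait Handler: Send + Sync {
    type Output: std::fmt::Debug + Send + 'static;
    fn handle(&self, input: String) -> impl Future<Output = Self::Output> + Send;
}

/// Object-safe adapter stored in the registry.
trait AnyHandler: Send + Sync {
    fn handle_any<'a>(&'a self, input: String) -> BoxFuture<'a, String>;
}

/// Blanket impl: boxing happens here, once per call, at the registry boundary.
impl<T: Handler> AnyHandler for T {
    fn handle_any<'a>(&'a self, input: String) -> BoxFuture<'a, String> {
        Box::pin(async move {
            let output = self.handle(input).await;
            format!("{output:?}")
        })
    }
}

struct Registry {
    handlers: HashMap<String, Arc<dyn AnyHandler>>,
}

impl Registry {
    fn register<H: Handler + 'static>(&mut self, name: &str, handler: H) {
        // Arc<H> coerces to Arc<dyn AnyHandler> through the blanket impl.
        self.handlers.insert(name.to_string(), Arc::new(handler));
    }

    async fn dispatch(&self, name: &str, input: String) -> Option<String> {
        let handler = self.handlers.get(name)?;
        Some(handler.handle_any(input).await)
    }
}

/// Two hypothetical handlers with different `Output` types.
struct Upper;

impl Handler for Upper {
    type Output = String;
    async fn handle(&self, input: String) -> String {
        input.to_uppercase()
    }
}

struct Len;

impl Handler for Len {
    type Output = usize;
    async fn handle(&self, input: String) -> usize {
        input.len()
    }
}

/// Toy executor: busy-polls; only for futures that need no external reactor.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
            return output;
        }
        std::thread::yield_now();
    }
}
```

Handlers with different `Output` types share one `HashMap` only because of the adapter. The typed trait itself never has to be object-safe.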

The measurements are at least comparable: in this run, this PR's pass-level total is 10.92s (13.7%) lower than #16627's, although its `evaluate_obligation` self-time is 3.62s (8.4%) higher:

| Metric | #16627 | This PR | Delta |
|---|---:|---:|---:|
| `total` | 79.90s | 68.98s | **-13.7%** |
| `generate_crate_metadata` | 25.88s | 24.49s | **-5.4%** |
| `monomorphization_collector_graph_walk` | 23.54s | 22.19s | **-5.7%** |
| `evaluate_obligation` self-time | 43.29s | 46.91s | +8.4% |

## Profile data

### Crate-level timings

`cargo +nightly build -p codex-core --lib -Z unstable-options --timings=json` after `cargo clean -p codex-core`.

Baseline data below is reused from the shared parent `0bd31dc382bd`
profile because this PR and #16627 are both one commit on top of that
same parent.

| Crate | Baseline `duration` | This PR `duration` | Delta | Baseline `rmeta_time` | This PR `rmeta_time` | Delta |
|---|---:|---:|---:|---:|---:|---:|
| `codex_core` | 187.380776583s | 69.171113833s | **-63.1%** | 174.474507208s | 55.873015583s | **-68.0%** |
| `starlark` | 17.90s | 16.773824125s | -6.3% | n/a | 8.8999965s | n/a |

### Pass-level timings

`cargo +nightly rustc -p codex-core --lib -- -Z time-passes -Z time-passes-format=json` after `cargo clean -p codex-core`.

| Pass | Baseline | This PR | Delta |
|---|---:|---:|---:|
| `total` | 187.150662083s | 68.978770375s | **-63.1%** |
| `generate_crate_metadata` | 84.531864625s | 24.487462958s | **-71.0%** |
| `MIR_borrow_checking` | 84.131389375s | 24.575553875s | **-70.8%** |
| `monomorphization_collector_graph_walk` | 79.737515042s | 22.190207417s | **-72.2%** |
| `codegen_crate` | 12.362532292s | 12.695237625s | +2.7% |
| `type_check_crate` | 4.4765405s | 5.442019542s | +21.6% |
| `coherence_checking` | 3.311121208s | 4.239935292s | +28.0% |
| process `real` / `user` / `sys` | 187.70s / 201.87s / 4.99s | 69.52s / 85.90s / 2.92s | n/a |

### Self-profile query summary

`cargo +nightly rustc -p codex-core --lib -- -Z self-profile=... -Z self-profile-events=default,query-keys,args,llvm,artifact-sizes` after `cargo clean -p codex-core`, summarized with `measureme summarize -p 0.5`.

| Query / phase | Baseline self time | This PR self time | Delta | Baseline total time | This PR total time | Baseline item count | This PR item count | Baseline cache hits | This PR cache hits |
|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| `evaluate_obligation` | 180.62s | 46.91s | **-74.0%** | 182.08s | 48.37s | 572,234 | 388,659 | 1,130,998 | 1,058,553 |
| `mir_borrowck` | 1.42s | 1.49s | +4.9% | 93.77s | 29.59s | n/a | 6,184 | n/a | 15,298 |
| `typeck` | 1.84s | 1.87s | +1.6% | 2.38s | 2.44s | n/a | 9,367 | n/a | 79,247 |
| `LLVM_module_codegen_emit_obj` | n/a | 17.12s | n/a | 17.01s | 17.12s | n/a | 256 | n/a | 0 |
| `LLVM_passes` | n/a | 13.07s | n/a | 12.95s | 13.07s | n/a | 1 | n/a | 0 |
| `codegen_module` | n/a | 12.33s | n/a | 12.22s | 13.64s | n/a | 256 | n/a | 0 |
| `items_of_instance` | n/a | 676.00ms | n/a | n/a | 24.96s | n/a | 99,990 | n/a | 0 |
| `type_op_prove_predicate` | n/a | 660.79ms | n/a | n/a | 24.78s | n/a | 78,762 | n/a | 235,877 |

| Summary | Baseline | This PR |
|---|---:|---:|
| `evaluate_obligation` % of total CPU | 70.821% | 38.880% |
| self-profile total CPU time | 255.042999997s | 120.661175956s |
| process `real` / `user` / `sys` | 220.96s / 235.02s / 7.09s | 86.35s / 103.66s / 3.54s |
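The "% of total CPU" row can be cross-checked from the rounded `evaluate_obligation` self times in the query table. The check divides each self time by the self-profile total CPU time. The helper name `share_pct` is hypothetical. With the rounded inputs, the result lands within about 0.01 percentage points of the reported 70.821% and 38.880%. The small residual presumably comes from rounding in the displayed self times.

```rust
/// Share of self-profile CPU time attributed to one query, in percent.
fn share_pct(query_self_time_s: f64, total_cpu_s: f64) -> f64 {
    query_self_time_s / total_cpu_s * 100.0
}
```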

### Artifact sizes

From the same `measureme summarize` output:

| Artifact | Baseline | This PR | Delta |
|---|---:|---:|---:|
| `crate_metadata` | 26,534,471 bytes | 26,545,248 bytes | +10,777 |
| `dep_graph` | 253,181,425 bytes | 239,240,806 bytes | -13,940,619 |
| `linked_artifact` | 565,366,624 bytes | 562,673,176 bytes | -2,693,448 |
| `object_file` | 513,127,264 bytes | 510,464,096 bytes | -2,663,168 |
| `query_cache` | 137,440,945 bytes | 136,982,566 bytes | -458,379 |
| `cgu_instructions` | 3,586,307 bytes | 3,575,121 bytes | -11,186 |
| `codegen_unit_size_estimate` | 2,084,846 bytes | 2,078,773 bytes | -6,073 |
| `work_product_index` | 19,565 bytes | 19,565 bytes | 0 |

### Baseline hotspots before this change

These are the top normalized obligation buckets from the shared baseline
profile:

| Obligation bucket | Samples | Duration |
|---|---:|---:|
| `outlives:tasks::review::ReviewTask` | 1,067 | 6.33s |
| `outlives:tools::handlers::unified_exec::UnifiedExecHandler` | 896 | 5.63s |
| `trait:T as tools::registry::ToolHandler` | 876 | 5.45s |
| `outlives:tools::handlers::shell::ShellHandler` | 888 | 5.37s |
| `outlives:tools::handlers::shell::ShellCommandHandler` | 870 | 5.29s |
| `outlives:tools::runtimes::shell::unix_escalation::CoreShellActionProvider` | 637 | 3.73s |
| `outlives:tools::handlers::mcp::McpHandler` | 695 | 3.61s |
| `outlives:tasks::regular::RegularTask` | 726 | 3.57s |

Top `items_of_instance` entries before this change were mostly concrete
async handler/task impls:

| Instance | Duration |
|---|---:|
| `tasks::regular::{impl#2}::run` | 3.79s |
| `tools::handlers::mcp::{impl#0}::handle` | 3.27s |
| `tools::runtimes::shell::unix_escalation::{impl#2}::determine_action` | 3.09s |
| `tools::handlers::agent_jobs::{impl#11}::handle` | 3.07s |
| `tools::handlers::multi_agents::spawn::{impl#1}::handle` | 2.84s |
| `tasks::review::{impl#4}::run` | 2.82s |
| `tools::handlers::multi_agents_v2::spawn::{impl#2}::handle` | 2.80s |
| `tools::handlers::multi_agents::resume_agent::{impl#1}::handle` | 2.73s |
| `tools::handlers::unified_exec::{impl#2}::handle` | 2.54s |
| `tasks::compact::{impl#4}::run` | 2.45s |

## What changed

Relevant pre-change registry shape:
[`codex-rs/core/src/tools/registry.rs`](0bd31dc382/codex-rs/core/src/tools/registry.rs (L38-L219))

Current registry shape in this PR:
[`codex-rs/core/src/tools/registry.rs`](41f7ac0ade/codex-rs/core/src/tools/registry.rs (L38-L203))

- `ToolHandler::{is_mutating, handle}` now return native `impl Future + Send` futures instead of using `#[async_trait]`.
- `AnyToolHandler` remains the object-safe adapter and boxes those
futures at the registry boundary with explicit lifetimes.
- Concrete handlers and the registry test handler drop `#[async_trait]`
but otherwise keep their async method bodies intact.
- Representative examples:
[`codex-rs/core/src/tools/handlers/shell.rs`](41f7ac0ade/codex-rs/core/src/tools/handlers/shell.rs (L223-L379)),
[`codex-rs/core/src/tools/handlers/unified_exec.rs`](41f7ac0ade/codex-rs/core/src/tools/handlers/unified_exec.rs),
[`codex-rs/core/src/tools/registry_tests.rs`](41f7ac0ade/codex-rs/core/src/tools/registry_tests.rs)
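Here is a minimal sketch of what a converted handler looks like. It uses a hypothetical simplified trait `ToolLike`, a hypothetical `Invocation` type, and `String` errors, standing in for `ToolHandler`, `ToolInvocation`, and `FunctionCallError`. The trait's default `is_mutating` is an ordinary `async { false }` block behind an `impl Future + Send` return. Impls use plain `async fn` bodies with no attribute. Because the trait requires `Send`, the compiler rejects an impl whose future holds a non-`Send` value across an `.await`. That is the guarantee `#[async_trait]` previously expressed through its boxed `Send` return type.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical stand-in for `ToolInvocation`.
struct Invocation {
    command: String,
}

trait ToolLike: Send + Sync {
    type Output: Send + 'static;

    /// Default body: a ready `async` block; still `Send`, no macro needed.
    fn is_mutating(&self, _invocation: &Invocation) -> impl Future<Output = bool> + Send {
        async { false }
    }

    fn handle(
        &self,
        invocation: Invocation,
    ) -> impl Future<Output = Result<Self::Output, String>> + Send;
}

/// Hypothetical handler overriding both methods with plain `async fn`s.
struct ShellLike;

impl ToolLike for ShellLike {
    type Output = String;

    async fn is_mutating(&self, invocation: &Invocation) -> bool {
        // Defensive toy rule: only commands starting with `ls` count as read-only.
        !invocation.command.starts_with("ls")
    }

    async fn handle(&self, invocation: Invocation) -> Result<String, String> {
        if invocation.command.is_empty() {
            return Err("empty command".to_string());
        }
        Ok(format!("ran: {}", invocation.command))
    }
}

/// Hypothetical handler that keeps the default `is_mutating`.
struct ReadOnly;

impl ToolLike for ReadOnly {
    type Output = ();

    async fn handle(&self, _invocation: Invocation) -> Result<(), String> {
        Ok(())
    }
}

/// Toy executor: busy-polls; only for futures that need no external reactor.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
            return output;
        }
        std::thread::yield_now();
    }
}
```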

## Tradeoff

This is intentionally less invasive than #16627: it does **not** move
result boxing into every concrete handler and does **not** change
`ToolHandler` into an object-safe trait.

Instead, it keeps the existing registry-level type-erasure boundary and
only removes the macro-generated async wrapper layer from concrete
impls. So the runtime boxing story stays basically the same as before,
while the compile-time savings are still large.

## Verification

Existing verification for this branch still applies:

- Ran `cargo test -p codex-core`; this change compiled and the suite
reached the known unrelated `config::tests::*guardian*` failures, with
no local diff under `codex-rs/core/src/config/`.

Profiling commands used for the tables above:

- `cargo clean -p codex-core`
- `cargo +nightly build -p codex-core --lib -Z unstable-options --timings=json`
- `cargo +nightly rustc -p codex-core --lib -- -Z time-passes -Z time-passes-format=json`
- `cargo +nightly rustc -p codex-core --lib -- -Z self-profile=... -Z self-profile-events=default,query-keys,args,llvm,artifact-sizes`
- `measureme summarize -p 0.5`
2026-04-02 16:03:52 -07:00


use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use crate::function_tool::FunctionCallError;
use crate::hook_runtime::record_additional_contexts;
use crate::hook_runtime::run_post_tool_use_hooks;
use crate::hook_runtime::run_pre_tool_use_hooks;
use crate::memories::usage::emit_metric_for_tool_read;
use crate::sandbox_tags::sandbox_tag;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;

use codex_hooks::HookEvent;
use codex_hooks::HookEventAfterToolUse;
use codex_hooks::HookPayload;
use codex_hooks::HookResult;
use codex_hooks::HookToolInput;
use codex_hooks::HookToolInputLocalShell;
use codex_hooks::HookToolKind;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::protocol::SandboxPolicy;
use codex_tools::ConfiguredToolSpec;
use codex_tools::ToolSpec;
use codex_utils_readiness::Readiness;
use futures::future::BoxFuture;
use serde_json::Value;
use tracing::warn;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolKind {
    Function,
    Mcp,
}

pub trait ToolHandler: Send + Sync {
    type Output: ToolOutput + 'static;

    fn kind(&self) -> ToolKind;

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            (self.kind(), payload),
            (ToolKind::Function, ToolPayload::Function { .. })
                | (ToolKind::Function, ToolPayload::ToolSearch { .. })
                | (ToolKind::Mcp, ToolPayload::Mcp { .. })
        )
    }

    /// Returns `true` if the [ToolInvocation] *might* mutate the environment of the
    /// user (through file system, OS operations, ...).
    /// This function must remain defensive and return `true` if there is any doubt
    /// about the exact effect of a ToolInvocation.
    fn is_mutating(
        &self,
        _invocation: &ToolInvocation,
    ) -> impl std::future::Future<Output = bool> + Send {
        async { false }
    }

    fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
        None
    }

    fn post_tool_use_payload(
        &self,
        _call_id: &str,
        _payload: &ToolPayload,
        _result: &dyn ToolOutput,
    ) -> Option<PostToolUsePayload> {
        None
    }

    /// Performs the actual [ToolInvocation] and returns a [ToolOutput] containing
    /// the final output to return to the model.
    fn handle(
        &self,
        invocation: ToolInvocation,
    ) -> impl std::future::Future<Output = Result<Self::Output, FunctionCallError>> + Send;
}

pub(crate) struct AnyToolResult {
    pub(crate) call_id: String,
    pub(crate) payload: ToolPayload,
    pub(crate) result: Box<dyn ToolOutput>,
}

impl AnyToolResult {
    pub(crate) fn into_response(self) -> ResponseInputItem {
        let Self {
            call_id,
            payload,
            result,
            ..
        } = self;
        result.to_response_item(&call_id, &payload)
    }

    pub(crate) fn code_mode_result(self) -> serde_json::Value {
        let Self {
            payload, result, ..
        } = self;
        result.code_mode_result(&payload)
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct PreToolUsePayload {
    pub(crate) command: String,
}

#[derive(Debug, Clone, PartialEq)]
pub(crate) struct PostToolUsePayload {
    pub(crate) command: String,
    pub(crate) tool_response: Value,
}

trait AnyToolHandler: Send + Sync {
    fn matches_kind(&self, payload: &ToolPayload) -> bool;

    fn is_mutating<'a>(&'a self, invocation: &'a ToolInvocation) -> BoxFuture<'a, bool>;

    fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload>;

    fn post_tool_use_payload(
        &self,
        call_id: &str,
        payload: &ToolPayload,
        result: &dyn ToolOutput,
    ) -> Option<PostToolUsePayload>;

    fn handle_any<'a>(
        &'a self,
        invocation: ToolInvocation,
    ) -> BoxFuture<'a, Result<AnyToolResult, FunctionCallError>>;
}

impl<T> AnyToolHandler for T
where
    T: ToolHandler,
{
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        ToolHandler::matches_kind(self, payload)
    }

    fn is_mutating<'a>(&'a self, invocation: &'a ToolInvocation) -> BoxFuture<'a, bool> {
        Box::pin(ToolHandler::is_mutating(self, invocation))
    }

    fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
        ToolHandler::pre_tool_use_payload(self, invocation)
    }

    fn post_tool_use_payload(
        &self,
        call_id: &str,
        payload: &ToolPayload,
        result: &dyn ToolOutput,
    ) -> Option<PostToolUsePayload> {
        ToolHandler::post_tool_use_payload(self, call_id, payload, result)
    }

    fn handle_any<'a>(
        &'a self,
        invocation: ToolInvocation,
    ) -> BoxFuture<'a, Result<AnyToolResult, FunctionCallError>> {
        Box::pin(async move {
            let call_id = invocation.call_id.clone();
            let payload = invocation.payload.clone();
            let output = self.handle(invocation).await?;
            Ok(AnyToolResult {
                call_id,
                payload,
                result: Box::new(output),
            })
        })
    }
}

pub(crate) fn tool_handler_key(tool_name: &str, namespace: Option<&str>) -> String {
    if let Some(namespace) = namespace {
        format!("{namespace}:{tool_name}")
    } else {
        tool_name.to_string()
    }
}

pub struct ToolRegistry {
    handlers: HashMap<String, Arc<dyn AnyToolHandler>>,
}

impl ToolRegistry {
    fn new(handlers: HashMap<String, Arc<dyn AnyToolHandler>>) -> Self {
        Self { handlers }
    }

    fn handler(&self, name: &str, namespace: Option<&str>) -> Option<Arc<dyn AnyToolHandler>> {
        self.handlers
            .get(&tool_handler_key(name, namespace))
            .map(Arc::clone)
    }

    #[cfg(test)]
    pub(crate) fn has_handler(&self, name: &str, namespace: Option<&str>) -> bool {
        self.handler(name, namespace).is_some()
    }

    // TODO(jif) for dynamic tools.
    // pub fn register(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
    //     let name = name.into();
    //     if self.handlers.insert(name.clone(), handler).is_some() {
    //         warn!("overwriting handler for tool {name}");
    //     }
    // }

    pub(crate) async fn dispatch_any(
        &self,
        invocation: ToolInvocation,
    ) -> Result<AnyToolResult, FunctionCallError> {
        let tool_name = invocation.tool_name.clone();
        let tool_namespace = invocation.tool_namespace.clone();
        let call_id_owned = invocation.call_id.clone();
        let otel = invocation.turn.session_telemetry.clone();
        let payload_for_response = invocation.payload.clone();
        let log_payload = payload_for_response.log_payload();
        let metric_tags = [
            (
                "sandbox",
                sandbox_tag(
                    &invocation.turn.sandbox_policy,
                    invocation.turn.windows_sandbox_level,
                ),
            ),
            (
                "sandbox_policy",
                sandbox_policy_tag(&invocation.turn.sandbox_policy),
            ),
        ];
        let (mcp_server, mcp_server_origin) = match &invocation.payload {
            ToolPayload::Mcp { server, .. } => {
                let manager = invocation
                    .session
                    .services
                    .mcp_connection_manager
                    .read()
                    .await;
                let origin = manager.server_origin(server).map(str::to_owned);
                (Some(server.clone()), origin)
            }
            _ => (None, None),
        };
        let mcp_server_ref = mcp_server.as_deref();
        let mcp_server_origin_ref = mcp_server_origin.as_deref();

        {
            let mut active = invocation.session.active_turn.lock().await;
            if let Some(active_turn) = active.as_mut() {
                let mut turn_state = active_turn.turn_state.lock().await;
                turn_state.tool_calls = turn_state.tool_calls.saturating_add(1);
            }
        }

        let handler = match self.handler(tool_name.as_ref(), tool_namespace.as_deref()) {
            Some(handler) => handler,
            None => {
                let message = unsupported_tool_call_message(
                    &invocation.payload,
                    tool_name.as_ref(),
                    tool_namespace.as_deref(),
                );
                otel.tool_result_with_tags(
                    tool_name.as_ref(),
                    &call_id_owned,
                    log_payload.as_ref(),
                    Duration::ZERO,
                    /*success*/ false,
                    &message,
                    &metric_tags,
                    mcp_server_ref,
                    mcp_server_origin_ref,
                );
                return Err(FunctionCallError::RespondToModel(message));
            }
        };

        if !handler.matches_kind(&invocation.payload) {
            let message = format!("tool {tool_name} invoked with incompatible payload");
            otel.tool_result_with_tags(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                Duration::ZERO,
                /*success*/ false,
                &message,
                &metric_tags,
                mcp_server_ref,
                mcp_server_origin_ref,
            );
            return Err(FunctionCallError::Fatal(message));
        }

        if let Some(pre_tool_use_payload) = handler.pre_tool_use_payload(&invocation)
            && let Some(reason) = run_pre_tool_use_hooks(
                &invocation.session,
                &invocation.turn,
                invocation.call_id.clone(),
                pre_tool_use_payload.command.clone(),
            )
            .await
        {
            return Err(FunctionCallError::RespondToModel(format!(
                "Command blocked by PreToolUse hook: {reason}. Command: {}",
                pre_tool_use_payload.command
            )));
        }

        let is_mutating = handler.is_mutating(&invocation).await;
        let response_cell = tokio::sync::Mutex::new(None);
        let invocation_for_tool = invocation.clone();

        let started = Instant::now();
        let result = otel
            .log_tool_result_with_tags(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                &metric_tags,
                mcp_server_ref,
                mcp_server_origin_ref,
                || {
                    let handler = handler.clone();
                    let response_cell = &response_cell;
                    async move {
                        if is_mutating {
                            tracing::trace!("waiting for tool gate");
                            invocation_for_tool.turn.tool_call_gate.wait_ready().await;
                            tracing::trace!("tool gate released");
                        }
                        match handler.handle_any(invocation_for_tool).await {
                            Ok(result) => {
                                let preview = result.result.log_preview();
                                let success = result.result.success_for_logging();
                                let mut guard = response_cell.lock().await;
                                *guard = Some(result);
                                Ok((preview, success))
                            }
                            Err(err) => Err(err),
                        }
                    }
                },
            )
            .await;
        let duration = started.elapsed();
        let (output_preview, success) = match &result {
            Ok((preview, success)) => (preview.clone(), *success),
            Err(err) => (err.to_string(), false),
        };
        emit_metric_for_tool_read(&invocation, success).await;
        let post_tool_use_payload = if success {
            let guard = response_cell.lock().await;
            guard.as_ref().and_then(|result| {
                handler.post_tool_use_payload(
                    &result.call_id,
                    &result.payload,
                    result.result.as_ref(),
                )
            })
        } else {
            None
        };
        let post_tool_use_outcome = if let Some(post_tool_use_payload) = post_tool_use_payload {
            Some(
                run_post_tool_use_hooks(
                    &invocation.session,
                    &invocation.turn,
                    invocation.call_id.clone(),
                    post_tool_use_payload.command,
                    post_tool_use_payload.tool_response,
                )
                .await,
            )
        } else {
            None
        };
        // Deprecated: this is the legacy AfterToolUse hook. Prefer the new
        // PostToolUse hook.
        let hook_abort_error = dispatch_after_tool_use_hook(AfterToolUseHookDispatch {
            invocation: &invocation,
            output_preview,
            success,
            executed: true,
            duration,
            mutating: is_mutating,
        })
        .await;

        if let Some(err) = hook_abort_error {
            return Err(err);
        }

        if let Some(outcome) = &post_tool_use_outcome {
            record_additional_contexts(
                &invocation.session,
                &invocation.turn,
                outcome.additional_contexts.clone(),
            )
            .await;

            let replacement_text = if outcome.should_stop {
                Some(
                    outcome
                        .feedback_message
                        .clone()
                        .or_else(|| outcome.stop_reason.clone())
                        .unwrap_or_else(|| "PostToolUse hook stopped execution".to_string()),
                )
            } else {
                outcome.feedback_message.clone()
            };
            if let Some(replacement_text) = replacement_text {
                let mut guard = response_cell.lock().await;
                if let Some(result) = guard.as_mut() {
                    result.result = Box::new(FunctionToolOutput::from_text(
                        replacement_text,
                        /*success*/ None,
                    ));
                }
            }
        }

        match result {
            Ok(_) => {
                let mut guard = response_cell.lock().await;
                let result = guard.take().ok_or_else(|| {
                    FunctionCallError::Fatal("tool produced no output".to_string())
                })?;
                Ok(result)
            }
            Err(err) => Err(err),
        }
    }
}

pub struct ToolRegistryBuilder {
    handlers: HashMap<String, Arc<dyn AnyToolHandler>>,
    specs: Vec<ConfiguredToolSpec>,
}

impl ToolRegistryBuilder {
    pub fn new() -> Self {
        Self {
            handlers: HashMap::new(),
            specs: Vec::new(),
        }
    }

    pub fn push_spec(&mut self, spec: ToolSpec) {
        self.push_spec_with_parallel_support(spec, /*supports_parallel_tool_calls*/ false);
    }

    pub fn push_spec_with_parallel_support(
        &mut self,
        spec: ToolSpec,
        supports_parallel_tool_calls: bool,
    ) {
        self.specs
            .push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
    }

    pub fn register_handler<H>(&mut self, name: impl Into<String>, handler: Arc<H>)
    where
        H: ToolHandler + 'static,
    {
        let name = name.into();
        let handler: Arc<dyn AnyToolHandler> = handler;
        if self
            .handlers
            .insert(name.clone(), handler.clone())
            .is_some()
        {
            warn!("overwriting handler for tool {name}");
        }
    }

    // TODO(jif) for dynamic tools.
    // pub fn register_many<I>(&mut self, names: I, handler: Arc<dyn ToolHandler>)
    // where
    //     I: IntoIterator,
    //     I::Item: Into<String>,
    // {
    //     for name in names {
    //         let name = name.into();
    //         if self
    //             .handlers
    //             .insert(name.clone(), handler.clone())
    //             .is_some()
    //         {
    //             warn!("overwriting handler for tool {name}");
    //         }
    //     }
    // }

    pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
        let registry = ToolRegistry::new(self.handlers);
        (self.specs, registry)
    }
}

fn unsupported_tool_call_message(
    payload: &ToolPayload,
    tool_name: &str,
    namespace: Option<&str>,
) -> String {
    let tool_name = tool_handler_key(tool_name, namespace);
    match payload {
        ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
        _ => format!("unsupported call: {tool_name}"),
    }
}

fn sandbox_policy_tag(policy: &SandboxPolicy) -> &'static str {
    match policy {
        SandboxPolicy::ReadOnly { .. } => "read-only",
        SandboxPolicy::WorkspaceWrite { .. } => "workspace-write",
        SandboxPolicy::DangerFullAccess => "danger-full-access",
        SandboxPolicy::ExternalSandbox { .. } => "external-sandbox",
    }
}

// Hooks use a separate wire-facing input type so hook payload JSON stays stable
// and decoupled from core's internal tool runtime representation.
impl From<&ToolPayload> for HookToolInput {
    fn from(payload: &ToolPayload) -> Self {
        match payload {
            ToolPayload::Function { arguments } => HookToolInput::Function {
                arguments: arguments.clone(),
            },
            ToolPayload::ToolSearch { arguments } => HookToolInput::Function {
                arguments: serde_json::json!({
                    "query": arguments.query,
                    "limit": arguments.limit,
                })
                .to_string(),
            },
            ToolPayload::Custom { input } => HookToolInput::Custom {
                input: input.clone(),
            },
            ToolPayload::LocalShell { params } => HookToolInput::LocalShell {
                params: HookToolInputLocalShell {
                    command: params.command.clone(),
                    workdir: params.workdir.clone(),
                    timeout_ms: params.timeout_ms,
                    sandbox_permissions: params.sandbox_permissions,
                    prefix_rule: params.prefix_rule.clone(),
                    justification: params.justification.clone(),
                },
            },
            ToolPayload::Mcp {
                server,
                tool,
                raw_arguments,
            } => HookToolInput::Mcp {
                server: server.clone(),
                tool: tool.clone(),
                arguments: raw_arguments.clone(),
            },
        }
    }
}

fn hook_tool_kind(tool_input: &HookToolInput) -> HookToolKind {
    match tool_input {
        HookToolInput::Function { .. } => HookToolKind::Function,
        HookToolInput::Custom { .. } => HookToolKind::Custom,
        HookToolInput::LocalShell { .. } => HookToolKind::LocalShell,
        HookToolInput::Mcp { .. } => HookToolKind::Mcp,
    }
}

struct AfterToolUseHookDispatch<'a> {
    invocation: &'a ToolInvocation,
    output_preview: String,
    success: bool,
    executed: bool,
    duration: Duration,
    mutating: bool,
}

async fn dispatch_after_tool_use_hook(
    dispatch: AfterToolUseHookDispatch<'_>,
) -> Option<FunctionCallError> {
    let AfterToolUseHookDispatch { invocation, .. } = dispatch;
    let session = invocation.session.as_ref();
    let turn = invocation.turn.as_ref();
    let tool_input = HookToolInput::from(&invocation.payload);
    let hook_outcomes = session
        .hooks()
        .dispatch(HookPayload {
            session_id: session.conversation_id,
            cwd: turn.cwd.to_path_buf(),
            client: turn.app_server_client_name.clone(),
            triggered_at: chrono::Utc::now(),
            hook_event: HookEvent::AfterToolUse {
                event: HookEventAfterToolUse {
                    turn_id: turn.sub_id.clone(),
                    call_id: invocation.call_id.clone(),
                    tool_name: invocation.tool_name.clone(),
                    tool_kind: hook_tool_kind(&tool_input),
                    tool_input,
                    executed: dispatch.executed,
                    success: dispatch.success,
                    duration_ms: u64::try_from(dispatch.duration.as_millis()).unwrap_or(u64::MAX),
                    mutating: dispatch.mutating,
                    sandbox: sandbox_tag(&turn.sandbox_policy, turn.windows_sandbox_level)
                        .to_string(),
                    sandbox_policy: sandbox_policy_tag(&turn.sandbox_policy).to_string(),
                    output_preview: dispatch.output_preview.clone(),
                },
            },
        })
        .await;

    for hook_outcome in hook_outcomes {
        let hook_name = hook_outcome.hook_name;
        match hook_outcome.result {
            HookResult::Success => {}
            HookResult::FailedContinue(error) => {
                warn!(
                    call_id = %invocation.call_id,
                    tool_name = %invocation.tool_name,
                    hook_name = %hook_name,
                    error = %error,
                    "after_tool_use hook failed; continuing"
                );
            }
            HookResult::FailedAbort(error) => {
                warn!(
                    call_id = %invocation.call_id,
                    tool_name = %invocation.tool_name,
                    hook_name = %hook_name,
                    error = %error,
                    "after_tool_use hook failed; aborting operation"
                );
                return Some(FunctionCallError::Fatal(format!(
                    "after_tool_use hook '{hook_name}' failed and aborted operation: {error}"
                )));
            }
        }
    }

    None
}
#[cfg(test)]
#[path = "registry_tests.rs"]
mod tests;