mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
## Why

`ToolHandler` was still paying a large compile-time tax from `#[async_trait]` on every concrete handler impl, even though the only object-safe boundary the registry actually stores is the internal `AnyToolHandler` adapter.

This PR removes that macro-generated async wrapper layer from concrete `ToolHandler` impls while keeping the existing object-safe shim in `AnyToolHandler`. In practice, that gets essentially the same compile-time win as the larger type-erasure refactor in #16627, but with a much smaller diff and without changing the public shape of `ToolHandler<Output = T>`.

That tradeoff matters here because this is a broad `codex-core` hotspot and reviewers should be able to judge the compile-time impact from hard numbers, not vibes.

## Headline result

On a clean `codex-core` package rebuild (`cargo clean -p codex-core` before each command), rustc `total` dropped from **187.15s to 68.98s** versus the shared `0bd31dc382bd` baseline: **-63.1%**.

The biggest hot passes dropped by roughly **71-72%**:

| Metric | Baseline `0bd31dc382bd` | This PR `41f7ac0adeac` | Delta |
|---|---:|---:|---:|
| `total` | 187.15s | 68.98s | **-63.1%** |
| `generate_crate_metadata` | 84.53s | 24.49s | **-71.0%** |
| `MIR_borrow_checking` | 84.13s | 24.58s | **-70.8%** |
| `monomorphization_collector_graph_walk` | 79.74s | 22.19s | **-72.2%** |
| `evaluate_obligation` self-time | 180.62s | 46.91s | **-74.0%** |

Important caveat: `-Z time-passes` timings are nested, so `generate_crate_metadata` and `monomorphization_collector_graph_walk` are mostly overlapping, not additive.

## Why this PR over #16627

#16627 already proved that the `ToolHandler` stack was the right hotspot, but it got there by making `ToolHandler` object-safe and changing every handler to return `BoxFuture<Result<AnyToolResult, _>>` directly.

This PR keeps the lower-churn shape:

- `ToolHandler` remains generic over `type Output`.
- Concrete handlers use native RPITIT futures with explicit `Send` bounds.
- `AnyToolHandler` remains the only object-safe adapter and still does the boxing at the registry boundary, as before.
- The implementation diff is only **33 files, +28/-77**.

The measurements are at least comparable, and in this run this PR is slightly faster than #16627 on the pass-level total:

| Metric | #16627 | This PR | Delta |
|---|---:|---:|---:|
| `total` | 79.90s | 68.98s | **-13.7%** |
| `generate_crate_metadata` | 25.88s | 24.49s | **-5.4%** |
| `monomorphization_collector_graph_walk` | 23.54s | 22.19s | **-5.7%** |
| `evaluate_obligation` self-time | 43.29s | 46.91s | +8.4% |

## Profile data

### Crate-level timings

`cargo +nightly build -p codex-core --lib -Z unstable-options --timings=json` after `cargo clean -p codex-core`.

Baseline data below is reused from the shared parent `0bd31dc382bd` profile because this PR and #16627 are both one commit on top of that same parent.

| Crate | Baseline `duration` | This PR `duration` | Delta | Baseline `rmeta_time` | This PR `rmeta_time` | Delta |
|---|---:|---:|---:|---:|---:|---:|
| `codex_core` | 187.380776583s | 69.171113833s | **-63.1%** | 174.474507208s | 55.873015583s | **-68.0%** |
| `starlark` | 17.90s | 16.773824125s | -6.3% | n/a | 8.8999965s | n/a |

### Pass-level timings

`cargo +nightly rustc -p codex-core --lib -- -Z time-passes -Z time-passes-format=json` after `cargo clean -p codex-core`.

| Pass | Baseline | This PR | Delta |
|---|---:|---:|---:|
| `total` | 187.150662083s | 68.978770375s | **-63.1%** |
| `generate_crate_metadata` | 84.531864625s | 24.487462958s | **-71.0%** |
| `MIR_borrow_checking` | 84.131389375s | 24.575553875s | **-70.8%** |
| `monomorphization_collector_graph_walk` | 79.737515042s | 22.190207417s | **-72.2%** |
| `codegen_crate` | 12.362532292s | 12.695237625s | +2.7% |
| `type_check_crate` | 4.4765405s | 5.442019542s | +21.6% |
| `coherence_checking` | 3.311121208s | 4.239935292s | +28.0% |
| process `real` / `user` / `sys` | 187.70s / 201.87s / 4.99s | 69.52s / 85.90s / 2.92s | n/a |

### Self-profile query summary

`cargo +nightly rustc -p codex-core --lib -- -Z self-profile=... -Z self-profile-events=default,query-keys,args,llvm,artifact-sizes` after `cargo clean -p codex-core`, summarized with `measureme summarize -p 0.5`.

| Query / phase | Baseline self time | This PR self time | Delta | Baseline total time | This PR total time | Baseline item count | This PR item count | Baseline cache hits | This PR cache hits |
|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| `evaluate_obligation` | 180.62s | 46.91s | **-74.0%** | 182.08s | 48.37s | 572,234 | 388,659 | 1,130,998 | 1,058,553 |
| `mir_borrowck` | 1.42s | 1.49s | +4.9% | 93.77s | 29.59s | n/a | 6,184 | n/a | 15,298 |
| `typeck` | 1.84s | 1.87s | +1.6% | 2.38s | 2.44s | n/a | 9,367 | n/a | 79,247 |
| `LLVM_module_codegen_emit_obj` | n/a | 17.12s | n/a | 17.01s | 17.12s | n/a | 256 | n/a | 0 |
| `LLVM_passes` | n/a | 13.07s | n/a | 12.95s | 13.07s | n/a | 1 | n/a | 0 |
| `codegen_module` | n/a | 12.33s | n/a | 12.22s | 13.64s | n/a | 256 | n/a | 0 |
| `items_of_instance` | n/a | 676.00ms | n/a | n/a | 24.96s | n/a | 99,990 | n/a | 0 |
| `type_op_prove_predicate` | n/a | 660.79ms | n/a | n/a | 24.78s | n/a | 78,762 | n/a | 235,877 |

| Summary | Baseline | This PR |
|---|---:|---:|
| `evaluate_obligation` % of total CPU | 70.821% | 38.880% |
| self-profile total CPU time | 255.042999997s | 120.661175956s |
| process `real` / `user` / `sys` | 220.96s / 235.02s / 7.09s | 86.35s / 103.66s / 3.54s |

### Artifact sizes

From the same `measureme summarize` output:

| Artifact | Baseline | This PR | Delta |
|---|---:|---:|---:|
| `crate_metadata` | 26,534,471 bytes | 26,545,248 bytes | +10,777 |
| `dep_graph` | 253,181,425 bytes | 239,240,806 bytes | -13,940,619 |
| `linked_artifact` | 565,366,624 bytes | 562,673,176 bytes | -2,693,448 |
| `object_file` | 513,127,264 bytes | 510,464,096 bytes | -2,663,168 |
| `query_cache` | 137,440,945 bytes | 136,982,566 bytes | -458,379 |
| `cgu_instructions` | 3,586,307 bytes | 3,575,121 bytes | -11,186 |
| `codegen_unit_size_estimate` | 2,084,846 bytes | 2,078,773 bytes | -6,073 |
| `work_product_index` | 19,565 bytes | 19,565 bytes | 0 |

### Baseline hotspots before this change

These are the top normalized obligation buckets from the shared baseline profile:

| Obligation bucket | Samples | Duration |
|---|---:|---:|
| `outlives:tasks::review::ReviewTask` | 1,067 | 6.33s |
| `outlives:tools::handlers::unified_exec::UnifiedExecHandler` | 896 | 5.63s |
| `trait:T as tools::registry::ToolHandler` | 876 | 5.45s |
| `outlives:tools::handlers::shell::ShellHandler` | 888 | 5.37s |
| `outlives:tools::handlers::shell::ShellCommandHandler` | 870 | 5.29s |
| `outlives:tools::runtimes::shell::unix_escalation::CoreShellActionProvider` | 637 | 3.73s |
| `outlives:tools::handlers::mcp::McpHandler` | 695 | 3.61s |
| `outlives:tasks::regular::RegularTask` | 726 | 3.57s |

Top `items_of_instance` entries before this change were mostly concrete async handler/task impls:

| Instance | Duration |
|---|---:|
| `tasks::regular::{impl#2}::run` | 3.79s |
| `tools::handlers::mcp::{impl#0}::handle` | 3.27s |
| `tools::runtimes::shell::unix_escalation::{impl#2}::determine_action` | 3.09s |
| `tools::handlers::agent_jobs::{impl#11}::handle` | 3.07s |
| `tools::handlers::multi_agents::spawn::{impl#1}::handle` | 2.84s |
| `tasks::review::{impl#4}::run` | 2.82s |
| `tools::handlers::multi_agents_v2::spawn::{impl#2}::handle` | 2.80s |
| `tools::handlers::multi_agents::resume_agent::{impl#1}::handle` | 2.73s |
| `tools::handlers::unified_exec::{impl#2}::handle` | 2.54s |
| `tasks::compact::{impl#4}::run` | 2.45s |

## What changed

Relevant pre-change registry shape: [`codex-rs/core/src/tools/registry.rs`](0bd31dc382/codex-rs/core/src/tools/registry.rs (L38-L219))

Current registry shape in this PR: [`codex-rs/core/src/tools/registry.rs`](41f7ac0ade/codex-rs/core/src/tools/registry.rs (L38-L203))

- `ToolHandler::{is_mutating, handle}` now return native `impl Future + Send` futures instead of using `#[async_trait]`.
- `AnyToolHandler` remains the object-safe adapter and boxes those futures at the registry boundary with explicit lifetimes.
- Concrete handlers and the registry test handler drop `#[async_trait]` but otherwise keep their async method bodies intact.
- Representative examples: [`codex-rs/core/src/tools/handlers/shell.rs`](41f7ac0ade/codex-rs/core/src/tools/handlers/shell.rs (L223-L379)), [`codex-rs/core/src/tools/handlers/unified_exec.rs`](41f7ac0ade/codex-rs/core/src/tools/handlers/unified_exec.rs), [`codex-rs/core/src/tools/registry_tests.rs`](41f7ac0ade/codex-rs/core/src/tools/registry_tests.rs)

## Tradeoff

This is intentionally less invasive than #16627: it does **not** move result boxing into every concrete handler and does **not** change `ToolHandler` into an object-safe trait.

Instead, it keeps the existing registry-level type-erasure boundary and only removes the macro-generated async wrapper layer from concrete impls. So the runtime boxing story stays basically the same as before, while the compile-time savings are still large.

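The shape described above can be illustrated with a minimal, std-only sketch. Everything in it is a hypothetical stand-in, not the real `codex-core` code: `Handler` plays the role of `ToolHandler`, `AnyHandler` the role of `AnyToolHandler`, `Render` the role of `ToolOutput`, and `block_on` is a toy executor that only exists so the example runs without an async runtime. The local `BoxFuture` alias is assumed to mirror the one from the `futures` crate.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Local alias, assumed to mirror `futures::future::BoxFuture`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

// Hypothetical stand-in for `ToolOutput`.
trait Render {
    fn render(&self) -> String;
}

impl Render for String {
    fn render(&self) -> String {
        self.clone()
    }
}

// Generic trait (stand-in for `ToolHandler`). It is not object-safe because
// `handle` returns `impl Future`; no `#[async_trait]` is involved.
trait Handler: Send + Sync {
    type Output: Render + Send + 'static;

    fn handle(&self, input: String)
        -> impl Future<Output = Result<Self::Output, String>> + Send;
}

// Object-safe adapter (stand-in for `AnyToolHandler`). This is the only place
// where the future and the output get boxed.
trait AnyHandler: Send + Sync {
    fn handle_any<'a>(
        &'a self,
        input: String,
    ) -> BoxFuture<'a, Result<Box<dyn Render + Send>, String>>;
}

impl<T: Handler> AnyHandler for T {
    fn handle_any<'a>(
        &'a self,
        input: String,
    ) -> BoxFuture<'a, Result<Box<dyn Render + Send>, String>> {
        Box::pin(async move {
            self.handle(input)
                .await
                .map(|output| Box::new(output) as Box<dyn Render + Send>)
        })
    }
}

struct Echo;

impl Handler for Echo {
    type Output = String;

    // A plain `async fn` body can implement the `impl Future + Send` method.
    async fn handle(&self, input: String) -> Result<String, String> {
        Ok(format!("echo: {input}"))
    }
}

// Toy executor: busy-polls with a no-op waker. Fine for futures that are
// ready immediately, not a replacement for a real runtime.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(value) = fut.as_mut().poll(&mut cx) {
            return value;
        }
    }
}

// Stores the handler behind the object-safe adapter, as a registry would.
fn run_echo(input: &str) -> Result<String, String> {
    let handler: Arc<dyn AnyHandler> = Arc::new(Echo);
    block_on(handler.handle_any(input.to_string())).map(|out| out.render())
}

fn main() {
    println!("{:?}", run_echo("hi"));
}
```

The point of the sketch is that the concrete impl never mentions `Box` or `Pin`; boxing happens once, in the blanket `impl<T: Handler> AnyHandler for T`.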
## Verification

Existing verification for this branch still applies:

- Ran `cargo test -p codex-core`; this change compiled and the suite reached the known unrelated `config::tests::*guardian*` failures, with no local diff under `codex-rs/core/src/config/`.

Profiling commands used for the tables above:

- `cargo clean -p codex-core`
- `cargo +nightly build -p codex-core --lib -Z unstable-options --timings=json`
- `cargo +nightly rustc -p codex-core --lib -- -Z time-passes -Z time-passes-format=json`
- `cargo +nightly rustc -p codex-core --lib -- -Z self-profile=... -Z self-profile-events=default,query-keys,args,llvm,artifact-sizes`
- `measureme summarize -p 0.5`
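
For anyone re-running the profiles, the Delta columns in the timing tables appear to be plain relative changes against the baseline, rounded to one decimal place. A small sketch of that arithmetic, using the rounded headline numbers; the helper name `percent_delta` is made up for illustration:

```rust
/// Relative change from `baseline` to `current`, in percent,
/// rounded to one decimal place.
fn percent_delta(baseline: f64, current: f64) -> f64 {
    let raw = (current - baseline) / baseline * 100.0;
    (raw * 10.0).round() / 10.0
}

fn main() {
    // (metric, baseline seconds, this-PR seconds) from the headline table.
    let rows = [
        ("total", 187.15, 68.98),
        ("generate_crate_metadata", 84.53, 24.49),
        ("MIR_borrow_checking", 84.13, 24.58),
        ("monomorphization_collector_graph_walk", 79.74, 22.19),
        ("evaluate_obligation self-time", 180.62, 46.91),
    ];
    for (name, baseline, current) in rows {
        println!("{name}: {:+.1}%", percent_delta(baseline, current));
    }
}
```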
658 lines
22 KiB
Rust
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use crate::function_tool::FunctionCallError;
use crate::hook_runtime::record_additional_contexts;
use crate::hook_runtime::run_post_tool_use_hooks;
use crate::hook_runtime::run_pre_tool_use_hooks;
use crate::memories::usage::emit_metric_for_tool_read;
use crate::sandbox_tags::sandbox_tag;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use codex_hooks::HookEvent;
use codex_hooks::HookEventAfterToolUse;
use codex_hooks::HookPayload;
use codex_hooks::HookResult;
use codex_hooks::HookToolInput;
use codex_hooks::HookToolInputLocalShell;
use codex_hooks::HookToolKind;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::protocol::SandboxPolicy;
use codex_tools::ConfiguredToolSpec;
use codex_tools::ToolSpec;
use codex_utils_readiness::Readiness;
use futures::future::BoxFuture;
use serde_json::Value;
use tracing::warn;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolKind {
    Function,
    Mcp,
}

pub trait ToolHandler: Send + Sync {
    type Output: ToolOutput + 'static;

    fn kind(&self) -> ToolKind;

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            (self.kind(), payload),
            (ToolKind::Function, ToolPayload::Function { .. })
                | (ToolKind::Function, ToolPayload::ToolSearch { .. })
                | (ToolKind::Mcp, ToolPayload::Mcp { .. })
        )
    }

    /// Returns `true` if the [ToolInvocation] *might* mutate the environment of the
    /// user (through file system, OS operations, ...).
    /// This function must remain defensive and return `true` if any doubt exists about
    /// the exact effect of a ToolInvocation.
    fn is_mutating(
        &self,
        _invocation: &ToolInvocation,
    ) -> impl std::future::Future<Output = bool> + Send {
        async { false }
    }

    fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
        None
    }

    fn post_tool_use_payload(
        &self,
        _call_id: &str,
        _payload: &ToolPayload,
        _result: &dyn ToolOutput,
    ) -> Option<PostToolUsePayload> {
        None
    }

    /// Performs the actual [ToolInvocation] and returns a [ToolOutput] containing
    /// the final output to return to the model.
    fn handle(
        &self,
        invocation: ToolInvocation,
    ) -> impl std::future::Future<Output = Result<Self::Output, FunctionCallError>> + Send;
}

pub(crate) struct AnyToolResult {
    pub(crate) call_id: String,
    pub(crate) payload: ToolPayload,
    pub(crate) result: Box<dyn ToolOutput>,
}

impl AnyToolResult {
    pub(crate) fn into_response(self) -> ResponseInputItem {
        let Self {
            call_id,
            payload,
            result,
            ..
        } = self;
        result.to_response_item(&call_id, &payload)
    }

    pub(crate) fn code_mode_result(self) -> serde_json::Value {
        let Self {
            payload, result, ..
        } = self;
        result.code_mode_result(&payload)
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct PreToolUsePayload {
    pub(crate) command: String,
}

#[derive(Debug, Clone, PartialEq)]
pub(crate) struct PostToolUsePayload {
    pub(crate) command: String,
    pub(crate) tool_response: Value,
}

trait AnyToolHandler: Send + Sync {
    fn matches_kind(&self, payload: &ToolPayload) -> bool;

    fn is_mutating<'a>(&'a self, invocation: &'a ToolInvocation) -> BoxFuture<'a, bool>;

    fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload>;

    fn post_tool_use_payload(
        &self,
        call_id: &str,
        payload: &ToolPayload,
        result: &dyn ToolOutput,
    ) -> Option<PostToolUsePayload>;

    fn handle_any<'a>(
        &'a self,
        invocation: ToolInvocation,
    ) -> BoxFuture<'a, Result<AnyToolResult, FunctionCallError>>;
}

impl<T> AnyToolHandler for T
where
    T: ToolHandler,
{
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        ToolHandler::matches_kind(self, payload)
    }

    fn is_mutating<'a>(&'a self, invocation: &'a ToolInvocation) -> BoxFuture<'a, bool> {
        Box::pin(ToolHandler::is_mutating(self, invocation))
    }

    fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
        ToolHandler::pre_tool_use_payload(self, invocation)
    }

    fn post_tool_use_payload(
        &self,
        call_id: &str,
        payload: &ToolPayload,
        result: &dyn ToolOutput,
    ) -> Option<PostToolUsePayload> {
        ToolHandler::post_tool_use_payload(self, call_id, payload, result)
    }

    fn handle_any<'a>(
        &'a self,
        invocation: ToolInvocation,
    ) -> BoxFuture<'a, Result<AnyToolResult, FunctionCallError>> {
        Box::pin(async move {
            let call_id = invocation.call_id.clone();
            let payload = invocation.payload.clone();
            let output = self.handle(invocation).await?;
            Ok(AnyToolResult {
                call_id,
                payload,
                result: Box::new(output),
            })
        })
    }
}

pub(crate) fn tool_handler_key(tool_name: &str, namespace: Option<&str>) -> String {
    if let Some(namespace) = namespace {
        format!("{namespace}:{tool_name}")
    } else {
        tool_name.to_string()
    }
}

pub struct ToolRegistry {
    handlers: HashMap<String, Arc<dyn AnyToolHandler>>,
}

impl ToolRegistry {
    fn new(handlers: HashMap<String, Arc<dyn AnyToolHandler>>) -> Self {
        Self { handlers }
    }

    fn handler(&self, name: &str, namespace: Option<&str>) -> Option<Arc<dyn AnyToolHandler>> {
        self.handlers
            .get(&tool_handler_key(name, namespace))
            .map(Arc::clone)
    }

    #[cfg(test)]
    pub(crate) fn has_handler(&self, name: &str, namespace: Option<&str>) -> bool {
        self.handler(name, namespace).is_some()
    }

    // TODO(jif) for dynamic tools.
    // pub fn register(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
    //     let name = name.into();
    //     if self.handlers.insert(name.clone(), handler).is_some() {
    //         warn!("overwriting handler for tool {name}");
    //     }
    // }

    pub(crate) async fn dispatch_any(
        &self,
        invocation: ToolInvocation,
    ) -> Result<AnyToolResult, FunctionCallError> {
        let tool_name = invocation.tool_name.clone();
        let tool_namespace = invocation.tool_namespace.clone();
        let call_id_owned = invocation.call_id.clone();
        let otel = invocation.turn.session_telemetry.clone();
        let payload_for_response = invocation.payload.clone();
        let log_payload = payload_for_response.log_payload();
        let metric_tags = [
            (
                "sandbox",
                sandbox_tag(
                    &invocation.turn.sandbox_policy,
                    invocation.turn.windows_sandbox_level,
                ),
            ),
            (
                "sandbox_policy",
                sandbox_policy_tag(&invocation.turn.sandbox_policy),
            ),
        ];
        let (mcp_server, mcp_server_origin) = match &invocation.payload {
            ToolPayload::Mcp { server, .. } => {
                let manager = invocation
                    .session
                    .services
                    .mcp_connection_manager
                    .read()
                    .await;
                let origin = manager.server_origin(server).map(str::to_owned);
                (Some(server.clone()), origin)
            }
            _ => (None, None),
        };
        let mcp_server_ref = mcp_server.as_deref();
        let mcp_server_origin_ref = mcp_server_origin.as_deref();

        {
            let mut active = invocation.session.active_turn.lock().await;
            if let Some(active_turn) = active.as_mut() {
                let mut turn_state = active_turn.turn_state.lock().await;
                turn_state.tool_calls = turn_state.tool_calls.saturating_add(1);
            }
        }

        let handler = match self.handler(tool_name.as_ref(), tool_namespace.as_deref()) {
            Some(handler) => handler,
            None => {
                let message = unsupported_tool_call_message(
                    &invocation.payload,
                    tool_name.as_ref(),
                    tool_namespace.as_deref(),
                );
                otel.tool_result_with_tags(
                    tool_name.as_ref(),
                    &call_id_owned,
                    log_payload.as_ref(),
                    Duration::ZERO,
                    /*success*/ false,
                    &message,
                    &metric_tags,
                    mcp_server_ref,
                    mcp_server_origin_ref,
                );
                return Err(FunctionCallError::RespondToModel(message));
            }
        };

        if !handler.matches_kind(&invocation.payload) {
            let message = format!("tool {tool_name} invoked with incompatible payload");
            otel.tool_result_with_tags(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                Duration::ZERO,
                /*success*/ false,
                &message,
                &metric_tags,
                mcp_server_ref,
                mcp_server_origin_ref,
            );
            return Err(FunctionCallError::Fatal(message));
        }

        if let Some(pre_tool_use_payload) = handler.pre_tool_use_payload(&invocation)
            && let Some(reason) = run_pre_tool_use_hooks(
                &invocation.session,
                &invocation.turn,
                invocation.call_id.clone(),
                pre_tool_use_payload.command.clone(),
            )
            .await
        {
            return Err(FunctionCallError::RespondToModel(format!(
                "Command blocked by PreToolUse hook: {reason}. Command: {}",
                pre_tool_use_payload.command
            )));
        }

        let is_mutating = handler.is_mutating(&invocation).await;
        let response_cell = tokio::sync::Mutex::new(None);
        let invocation_for_tool = invocation.clone();

        let started = Instant::now();
        let result = otel
            .log_tool_result_with_tags(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                &metric_tags,
                mcp_server_ref,
                mcp_server_origin_ref,
                || {
                    let handler = handler.clone();
                    let response_cell = &response_cell;
                    async move {
                        if is_mutating {
                            tracing::trace!("waiting for tool gate");
                            invocation_for_tool.turn.tool_call_gate.wait_ready().await;
                            tracing::trace!("tool gate released");
                        }
                        match handler.handle_any(invocation_for_tool).await {
                            Ok(result) => {
                                let preview = result.result.log_preview();
                                let success = result.result.success_for_logging();
                                let mut guard = response_cell.lock().await;
                                *guard = Some(result);
                                Ok((preview, success))
                            }
                            Err(err) => Err(err),
                        }
                    }
                },
            )
            .await;
        let duration = started.elapsed();
        let (output_preview, success) = match &result {
            Ok((preview, success)) => (preview.clone(), *success),
            Err(err) => (err.to_string(), false),
        };
        emit_metric_for_tool_read(&invocation, success).await;
        let post_tool_use_payload = if success {
            let guard = response_cell.lock().await;
            guard.as_ref().and_then(|result| {
                handler.post_tool_use_payload(
                    &result.call_id,
                    &result.payload,
                    result.result.as_ref(),
                )
            })
        } else {
            None
        };
        let post_tool_use_outcome = if let Some(post_tool_use_payload) = post_tool_use_payload {
            Some(
                run_post_tool_use_hooks(
                    &invocation.session,
                    &invocation.turn,
                    invocation.call_id.clone(),
                    post_tool_use_payload.command,
                    post_tool_use_payload.tool_response,
                )
                .await,
            )
        } else {
            None
        };
        // Deprecated: this is the legacy AfterToolUse hook. Prefer the new PostToolUse
        let hook_abort_error = dispatch_after_tool_use_hook(AfterToolUseHookDispatch {
            invocation: &invocation,
            output_preview,
            success,
            executed: true,
            duration,
            mutating: is_mutating,
        })
        .await;

        if let Some(err) = hook_abort_error {
            return Err(err);
        }

        if let Some(outcome) = &post_tool_use_outcome {
            record_additional_contexts(
                &invocation.session,
                &invocation.turn,
                outcome.additional_contexts.clone(),
            )
            .await;

            let replacement_text = if outcome.should_stop {
                Some(
                    outcome
                        .feedback_message
                        .clone()
                        .or_else(|| outcome.stop_reason.clone())
                        .unwrap_or_else(|| "PostToolUse hook stopped execution".to_string()),
                )
            } else {
                outcome.feedback_message.clone()
            };
            if let Some(replacement_text) = replacement_text {
                let mut guard = response_cell.lock().await;
                if let Some(result) = guard.as_mut() {
                    result.result = Box::new(FunctionToolOutput::from_text(
                        replacement_text,
                        /*success*/ None,
                    ));
                }
            }
        }

        match result {
            Ok(_) => {
                let mut guard = response_cell.lock().await;
                let result = guard.take().ok_or_else(|| {
                    FunctionCallError::Fatal("tool produced no output".to_string())
                })?;
                Ok(result)
            }
            Err(err) => Err(err),
        }
    }
}

pub struct ToolRegistryBuilder {
    handlers: HashMap<String, Arc<dyn AnyToolHandler>>,
    specs: Vec<ConfiguredToolSpec>,
}

impl ToolRegistryBuilder {
    pub fn new() -> Self {
        Self {
            handlers: HashMap::new(),
            specs: Vec::new(),
        }
    }

    pub fn push_spec(&mut self, spec: ToolSpec) {
        self.push_spec_with_parallel_support(spec, /*supports_parallel_tool_calls*/ false);
    }

    pub fn push_spec_with_parallel_support(
        &mut self,
        spec: ToolSpec,
        supports_parallel_tool_calls: bool,
    ) {
        self.specs
            .push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
    }

    pub fn register_handler<H>(&mut self, name: impl Into<String>, handler: Arc<H>)
    where
        H: ToolHandler + 'static,
    {
        let name = name.into();
        let handler: Arc<dyn AnyToolHandler> = handler;
        if self
            .handlers
            .insert(name.clone(), handler.clone())
            .is_some()
        {
            warn!("overwriting handler for tool {name}");
        }
    }

    // TODO(jif) for dynamic tools.
    // pub fn register_many<I>(&mut self, names: I, handler: Arc<dyn ToolHandler>)
    // where
    //     I: IntoIterator,
    //     I::Item: Into<String>,
    // {
    //     for name in names {
    //         let name = name.into();
    //         if self
    //             .handlers
    //             .insert(name.clone(), handler.clone())
    //             .is_some()
    //         {
    //             warn!("overwriting handler for tool {name}");
    //         }
    //     }
    // }

    pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
        let registry = ToolRegistry::new(self.handlers);
        (self.specs, registry)
    }
}

fn unsupported_tool_call_message(
    payload: &ToolPayload,
    tool_name: &str,
    namespace: Option<&str>,
) -> String {
    let tool_name = tool_handler_key(tool_name, namespace);
    match payload {
        ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
        _ => format!("unsupported call: {tool_name}"),
    }
}

fn sandbox_policy_tag(policy: &SandboxPolicy) -> &'static str {
    match policy {
        SandboxPolicy::ReadOnly { .. } => "read-only",
        SandboxPolicy::WorkspaceWrite { .. } => "workspace-write",
        SandboxPolicy::DangerFullAccess => "danger-full-access",
        SandboxPolicy::ExternalSandbox { .. } => "external-sandbox",
    }
}

// Hooks use a separate wire-facing input type so hook payload JSON stays stable
// and decoupled from core's internal tool runtime representation.
impl From<&ToolPayload> for HookToolInput {
    fn from(payload: &ToolPayload) -> Self {
        match payload {
            ToolPayload::Function { arguments } => HookToolInput::Function {
                arguments: arguments.clone(),
            },
            ToolPayload::ToolSearch { arguments } => HookToolInput::Function {
                arguments: serde_json::json!({
                    "query": arguments.query,
                    "limit": arguments.limit,
                })
                .to_string(),
            },
            ToolPayload::Custom { input } => HookToolInput::Custom {
                input: input.clone(),
            },
            ToolPayload::LocalShell { params } => HookToolInput::LocalShell {
                params: HookToolInputLocalShell {
                    command: params.command.clone(),
                    workdir: params.workdir.clone(),
                    timeout_ms: params.timeout_ms,
                    sandbox_permissions: params.sandbox_permissions,
                    prefix_rule: params.prefix_rule.clone(),
                    justification: params.justification.clone(),
                },
            },
            ToolPayload::Mcp {
                server,
                tool,
                raw_arguments,
            } => HookToolInput::Mcp {
                server: server.clone(),
                tool: tool.clone(),
                arguments: raw_arguments.clone(),
            },
        }
    }
}

fn hook_tool_kind(tool_input: &HookToolInput) -> HookToolKind {
    match tool_input {
        HookToolInput::Function { .. } => HookToolKind::Function,
        HookToolInput::Custom { .. } => HookToolKind::Custom,
        HookToolInput::LocalShell { .. } => HookToolKind::LocalShell,
        HookToolInput::Mcp { .. } => HookToolKind::Mcp,
    }
}

struct AfterToolUseHookDispatch<'a> {
    invocation: &'a ToolInvocation,
    output_preview: String,
    success: bool,
    executed: bool,
    duration: Duration,
    mutating: bool,
}

async fn dispatch_after_tool_use_hook(
    dispatch: AfterToolUseHookDispatch<'_>,
) -> Option<FunctionCallError> {
    let AfterToolUseHookDispatch { invocation, .. } = dispatch;
    let session = invocation.session.as_ref();
    let turn = invocation.turn.as_ref();
    let tool_input = HookToolInput::from(&invocation.payload);
    let hook_outcomes = session
        .hooks()
        .dispatch(HookPayload {
            session_id: session.conversation_id,
            cwd: turn.cwd.to_path_buf(),
            client: turn.app_server_client_name.clone(),
            triggered_at: chrono::Utc::now(),
            hook_event: HookEvent::AfterToolUse {
                event: HookEventAfterToolUse {
                    turn_id: turn.sub_id.clone(),
                    call_id: invocation.call_id.clone(),
                    tool_name: invocation.tool_name.clone(),
                    tool_kind: hook_tool_kind(&tool_input),
                    tool_input,
                    executed: dispatch.executed,
                    success: dispatch.success,
                    duration_ms: u64::try_from(dispatch.duration.as_millis()).unwrap_or(u64::MAX),
                    mutating: dispatch.mutating,
                    sandbox: sandbox_tag(&turn.sandbox_policy, turn.windows_sandbox_level)
                        .to_string(),
                    sandbox_policy: sandbox_policy_tag(&turn.sandbox_policy).to_string(),
                    output_preview: dispatch.output_preview.clone(),
                },
            },
        })
        .await;

    for hook_outcome in hook_outcomes {
        let hook_name = hook_outcome.hook_name;
        match hook_outcome.result {
            HookResult::Success => {}
            HookResult::FailedContinue(error) => {
                warn!(
                    call_id = %invocation.call_id,
                    tool_name = %invocation.tool_name,
                    hook_name = %hook_name,
                    error = %error,
                    "after_tool_use hook failed; continuing"
                );
            }
            HookResult::FailedAbort(error) => {
                warn!(
                    call_id = %invocation.call_id,
                    tool_name = %invocation.tool_name,
                    hook_name = %hook_name,
                    error = %error,
                    "after_tool_use hook failed; aborting operation"
                );
                return Some(FunctionCallError::Fatal(format!(
                    "after_tool_use hook '{hook_name}' failed and aborted operation: {error}"
                )));
            }
        }
    }

    None
}

#[cfg(test)]
#[path = "registry_tests.rs"]
mod tests;