Keep tool dispatch payload tracing out of handlers

This commit is contained in:
pakrym-oai
2026-05-08 14:09:17 -07:00
parent c22ab62634
commit 9198ee14f3
2 changed files with 28 additions and 49 deletions

View File

@@ -21,7 +21,6 @@ use crate::tools::tool_dispatch_trace::ToolDispatchTrace;
use crate::util::error_or_panic;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::protocol::EventMsg;
use codex_rollout_trace::ToolDispatchPayload;
use codex_tools::ConfiguredToolSpec;
use codex_tools::ToolName;
use codex_tools::ToolSpec;
@@ -60,10 +59,6 @@ pub trait ToolHandler: Send + Sync {
async { Vec::new() }
}
fn dispatch_payload(&self, invocation: &ToolInvocation) -> ToolDispatchPayload {
tool_dispatch_payload(&invocation.payload)
}
/// Returns `true` if the [ToolInvocation] *might* mutate the environment of the
/// user (through file system, OS operations, ...).
/// This function must remains defensive and return `true` if a doubt exist on the
@@ -182,8 +177,6 @@ trait AnyToolHandler: Send + Sync {
invocation: &'a ToolInvocation,
) -> BoxFuture<'a, ToolTelemetryTags>;
fn dispatch_payload(&self, invocation: &ToolInvocation) -> ToolDispatchPayload;
fn create_diff_consumer(&self) -> Option<Box<dyn ToolArgumentDiffConsumer>>;
fn handle_any<'a>(
&'a self,
@@ -218,10 +211,6 @@ where
Box::pin(ToolHandler::telemetry_tags(self, invocation))
}
fn dispatch_payload(&self, invocation: &ToolInvocation) -> ToolDispatchPayload {
ToolHandler::dispatch_payload(self, invocation)
}
fn create_diff_consumer(&self) -> Option<Box<dyn ToolArgumentDiffConsumer>> {
ToolHandler::create_diff_consumer(self)
}
@@ -328,12 +317,10 @@ impl ToolRegistry {
}
}
let dispatch_trace = ToolDispatchTrace::start(&invocation);
let handler = match self.handler(&tool_name) {
Some(handler) => handler,
None => {
let dispatch_trace = ToolDispatchTrace::start(&invocation, || {
tool_dispatch_payload(&invocation.payload)
});
let message = unsupported_tool_call_message(&invocation.payload, &tool_name);
otel.tool_result_with_tags(
tool_name_flat.as_ref(),
@@ -359,8 +346,6 @@ impl ToolRegistry {
.iter()
.map(|(key, value)| (*key, value.as_str())),
);
let dispatch_trace =
ToolDispatchTrace::start(&invocation, || handler.dispatch_payload(&invocation));
if !handler.matches_kind(&invocation.payload) {
let message = format!("tool {tool_name} invoked with incompatible payload");
@@ -579,29 +564,6 @@ fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &ToolName) ->
}
}
fn tool_dispatch_payload(payload: &ToolPayload) -> ToolDispatchPayload {
match payload {
ToolPayload::Function { arguments } => ToolDispatchPayload::Function {
arguments: arguments.clone(),
},
ToolPayload::ToolSearch { arguments } => ToolDispatchPayload::ToolSearch {
arguments: arguments.clone(),
},
ToolPayload::Custom { input } => ToolDispatchPayload::Custom {
input: input.clone(),
},
ToolPayload::LocalShell { params } => ToolDispatchPayload::LocalShell {
command: params.command.clone(),
workdir: params.workdir.clone(),
timeout_ms: params.timeout_ms,
sandbox_permissions: params.sandbox_permissions,
prefix_rule: params.prefix_rule.clone(),
additional_permissions: params.additional_permissions.clone(),
justification: params.justification.clone(),
},
}
}
#[cfg(test)]
#[path = "registry_tests.rs"]
mod tests;

View File

@@ -22,15 +22,12 @@ pub(crate) struct ToolDispatchTrace {
}
impl ToolDispatchTrace {
pub(crate) fn start(
invocation: &ToolInvocation,
payload: impl FnOnce() -> ToolDispatchPayload,
) -> Self {
pub(crate) fn start(invocation: &ToolInvocation) -> Self {
let context = invocation
.session
.services
.rollout_thread_trace
.start_tool_dispatch_trace(|| tool_dispatch_invocation(invocation, payload()));
.start_tool_dispatch_trace(|| tool_dispatch_invocation(invocation));
Self { context }
}
@@ -62,10 +59,7 @@ impl ToolDispatchTrace {
}
}
fn tool_dispatch_invocation(
invocation: &ToolInvocation,
payload: ToolDispatchPayload,
) -> Option<ToolDispatchInvocation> {
fn tool_dispatch_invocation(invocation: &ToolInvocation) -> Option<ToolDispatchInvocation> {
let requester = match &invocation.source {
ToolCallSource::Direct => ToolDispatchRequester::Model {
model_visible_call_id: invocation.call_id.clone(),
@@ -86,7 +80,7 @@ fn tool_dispatch_invocation(
tool_name: invocation.tool_name.name.clone(),
tool_namespace: invocation.tool_name.namespace.clone(),
requester,
payload,
payload: tool_dispatch_payload(&invocation.payload),
})
}
@@ -106,6 +100,29 @@ fn tool_dispatch_result(
}
}
fn tool_dispatch_payload(payload: &ToolPayload) -> ToolDispatchPayload {
match payload {
ToolPayload::Function { arguments } => ToolDispatchPayload::Function {
arguments: arguments.clone(),
},
ToolPayload::ToolSearch { arguments } => ToolDispatchPayload::ToolSearch {
arguments: arguments.clone(),
},
ToolPayload::Custom { input } => ToolDispatchPayload::Custom {
input: input.clone(),
},
ToolPayload::LocalShell { params } => ToolDispatchPayload::LocalShell {
command: params.command.clone(),
workdir: params.workdir.clone(),
timeout_ms: params.timeout_ms,
sandbox_permissions: params.sandbox_permissions,
prefix_rule: params.prefix_rule.clone(),
additional_permissions: params.additional_permissions.clone(),
justification: params.justification.clone(),
},
}
}
#[cfg(test)]
#[path = "tool_dispatch_trace_tests.rs"]
mod tests;