Files
codex/codex-rs/core/src/tools/parallel.rs
Anton Panasenko 77b0c75267 feat: migrate search_tool to the bring-your-own tool of the Responses API (#14274)
## Why

To support the new bring-your-own search tool in the Responses
API (https://developers.openai.com/api/docs/guides/tools-tool-search#client-executed-tool-search),
we are migrating our BM25 search tool to the official way of executing search
on the client and communicating additional tools to the model.

## What
- replace the legacy `search_tool_bm25` flow with client-executed
`tool_search`
- add protocol, SSE, history, and normalization support for
`tool_search_call` and `tool_search_output`
- return namespaced Codex Apps search results and wire namespaced
follow-up tool calls back into MCP dispatch
2026-03-11 17:51:51 -07:00

156 lines
5.5 KiB
Rust

use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tokio_util::either::Either;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use tracing::Instrument;
use tracing::instrument;
use tracing::trace_span;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolPayload;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
/// Dispatches model-requested tool calls against one session and turn context.
///
/// Every field is a shared handle, so cloning the runtime is cheap. All clones
/// refer to the same session, turn context, router, diff tracker, and
/// parallel-execution lock.
#[derive(Clone)]
pub(crate) struct ToolCallRuntime {
    /// Session passed through to each tool dispatch.
    session: Arc<Session>,
    /// Turn context passed through to each tool dispatch.
    turn_context: Arc<TurnContext>,
    /// Reports whether a tool supports parallel execution and dispatches calls.
    router: Arc<ToolRouter>,
    /// Diff tracker passed through to each tool dispatch.
    tracker: SharedTurnDiffTracker,
    /// Gate serializing tool calls: parallel-safe tools take the read side,
    /// all other tools take the exclusive write side.
    parallel_execution: Arc<RwLock<()>>,
}
impl ToolCallRuntime {
    /// Builds a runtime around the given router, session, turn context, and
    /// diff tracker.
    ///
    /// A fresh parallel-execution lock is created here. Clones of the returned
    /// runtime share that lock.
    pub(crate) fn new(
        router: Arc<ToolRouter>,
        session: Arc<Session>,
        turn_context: Arc<TurnContext>,
        tracker: SharedTurnDiffTracker,
    ) -> Self {
        let parallel_execution = Arc::new(RwLock::new(()));
        Self {
            session,
            turn_context,
            router,
            tracker,
            parallel_execution,
        }
    }

    /// Runs `call` on a spawned Tokio task.
    ///
    /// Returns a future that resolves to the response item for the call.
    ///
    /// Tools for which the router reports parallel support take the shared (read)
    /// side of `parallel_execution`. Every other tool takes the exclusive (write)
    /// side, so it never overlaps another call on this runtime or its clones.
    ///
    /// If `cancellation_token` fires before dispatch finishes, the following happens:
    /// - the in-flight dispatch (or lock wait) is dropped;
    /// - the `aborted` field of the `dispatch_tool_call` span is set to `true`;
    /// - a synthetic aborted response is returned.
    ///
    /// The spawned task is owned by an `AbortOnDropHandle`, so dropping the
    /// returned future aborts the task.
    ///
    /// # Errors
    /// Any `FunctionCallError` from dispatch, and any failure to join the task,
    /// is reported as `CodexErr::Fatal`.
    #[instrument(level = "trace", skip_all)]
    pub(crate) fn handle_tool_call(
        self,
        call: ToolCall,
        cancellation_token: CancellationToken,
    ) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
        let run_exclusive = !self.router.tool_supports_parallel(&call.tool_name);
        let Self {
            session,
            turn_context: turn,
            router,
            tracker,
            parallel_execution: gate,
        } = self;
        let started = Instant::now();
        let dispatch_span = trace_span!(
            "dispatch_tool_call",
            otel.name = call.tool_name.as_str(),
            tool_name = call.tool_name.as_str(),
            call_id = call.call_id.as_str(),
            aborted = false,
        );

        let task = tokio::spawn(async move {
            tokio::select! {
                _ = cancellation_token.cancelled() => {
                    // Report at least 0.1s so the abort message never shows 0.0.
                    let elapsed_secs = started.elapsed().as_secs_f32().max(0.1);
                    dispatch_span.record("aborted", true);
                    Ok(Self::aborted_response(&call, elapsed_secs))
                },
                res = async {
                    // Bound to a named variable (not `_`) so the lock guard is
                    // held until the dispatch below completes.
                    let _permit = if run_exclusive {
                        Either::Right(gate.write().await)
                    } else {
                        Either::Left(gate.read().await)
                    };
                    router
                        .dispatch_tool_call(
                            session,
                            turn,
                            tracker,
                            call.clone(),
                            crate::tools::router::ToolCallSource::Direct,
                        )
                        .instrument(dispatch_span.clone())
                        .await
                } => res,
            }
        });
        let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
            AbortOnDropHandle::new(task);

        async move {
            handle
                .await
                .map_err(|join_err| {
                    CodexErr::Fatal(format!("tool task failed to receive: {join_err:?}"))
                })
                .and_then(|outcome| {
                    outcome.map_err(|tool_err| match tool_err {
                        FunctionCallError::Fatal(message) => CodexErr::Fatal(message),
                        other => CodexErr::Fatal(other.to_string()),
                    })
                })
        }
        .in_current_span()
    }
}
impl ToolCallRuntime {
    /// Builds the response item sent back for a call that was cancelled.
    ///
    /// The item's shape follows the call's payload kind:
    /// - Tool searches report `"completed"` / `"client"` with an empty tool list.
    /// - MCP calls get an error `CallToolResult` carrying the abort message.
    /// - Custom calls get a text custom-tool output.
    /// - Everything else gets a text function-call output.
    ///
    /// `secs` is the elapsed time reported in the abort message.
    fn aborted_response(call: &ToolCall, secs: f32) -> ResponseInputItem {
        let call_id = call.call_id.clone();
        // Shared by the custom and function-call arms. The message is built
        // lazily, only in the arms that need it.
        let text_payload = || FunctionCallOutputPayload {
            body: FunctionCallOutputBody::Text(Self::abort_message(call, secs)),
            ..Default::default()
        };

        match &call.payload {
            ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
                call_id,
                status: "completed".to_string(),
                execution: "client".to_string(),
                tools: Vec::new(),
            },
            ToolPayload::Mcp { .. } => {
                let message = Self::abort_message(call, secs);
                ResponseInputItem::McpToolCallOutput {
                    call_id,
                    output: codex_protocol::mcp::CallToolResult::from_error_text(message),
                }
            }
            ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
                call_id,
                output: text_payload(),
            },
            _ => ResponseInputItem::FunctionCallOutput {
                call_id,
                output: text_payload(),
            },
        }
    }

    /// Formats the human-readable abort text for `call`.
    ///
    /// Shell-style tools get a "Wall time" header line. All other tools get a
    /// single-line message. Elapsed seconds are printed with one decimal place.
    fn abort_message(call: &ToolCall, secs: f32) -> String {
        let is_shell_like = matches!(
            call.tool_name.as_str(),
            "shell" | "container.exec" | "local_shell" | "shell_command" | "unified_exec"
        );
        if is_shell_like {
            format!("Wall time: {secs:.1} seconds\naborted by user")
        } else {
            format!("aborted by user after {secs:.1}s")
        }
    }
}