use std::sync::Arc; use std::time::Instant; use tokio::sync::RwLock; use tokio_util::either::Either; use tokio_util::sync::CancellationToken; use tokio_util::task::AbortOnDropHandle; use tracing::Instrument; use tracing::instrument; use tracing::trace_span; use crate::function_tool::FunctionCallError; use crate::session::session::Session; use crate::session::turn_context::TurnContext; use crate::tools::context::AbortedToolOutput; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolPayload; use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolArgumentDiffConsumer; use crate::tools::router::ToolCall; use crate::tools::router::ToolCallSource; use crate::tools::router::ToolRouter; use codex_protocol::error::CodexErr; use codex_protocol::models::ResponseInputItem; use codex_tools::ToolSpec; #[derive(Clone)] pub(crate) struct ToolCallRuntime { router: Arc, session: Arc, turn_context: Arc, tracker: SharedTurnDiffTracker, parallel_execution: Arc>, } impl ToolCallRuntime { pub(crate) fn new( router: Arc, session: Arc, turn_context: Arc, tracker: SharedTurnDiffTracker, ) -> Self { Self { router, session, turn_context, tracker, parallel_execution: Arc::new(RwLock::new(())), } } pub(crate) fn find_spec(&self, tool_name: &codex_tools::ToolName) -> Option { self.router.find_spec(tool_name) } pub(crate) fn create_diff_consumer( &self, tool_name: &codex_tools::ToolName, ) -> Option> { self.router.create_diff_consumer(tool_name) } #[instrument(level = "trace", skip_all)] pub(crate) fn handle_tool_call( self, call: ToolCall, cancellation_token: CancellationToken, ) -> impl std::future::Future> { let error_call = call.clone(); let future = self.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token); async move { match future.await { Ok(response) => Ok(response.into_response()), Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), Err(other) => Ok(Self::failure_response(error_call, other)), } } .in_current_span() } #[instrument(level = "trace", skip_all)] pub(crate) fn handle_tool_call_with_source( self, call: ToolCall, source: ToolCallSource, cancellation_token: CancellationToken, ) -> impl std::future::Future> { let supports_parallel = self.router.tool_supports_parallel(&call); let router = Arc::clone(&self.router); let session = Arc::clone(&self.session); let turn = Arc::clone(&self.turn_context); let tracker = Arc::clone(&self.tracker); let lock = Arc::clone(&self.parallel_execution); let invocation_cancellation_token = cancellation_token.clone(); let started = Instant::now(); let dispatch_span = trace_span!( "dispatch_tool_call_with_code_mode_result", otel.name = %call.tool_name, tool_name = %call.tool_name, call_id = call.call_id.as_str(), aborted = false, ); let handle: AbortOnDropHandle> = AbortOnDropHandle::new(tokio::spawn(async move { tokio::select! { _ = cancellation_token.cancelled() => { let secs = started.elapsed().as_secs_f32().max(0.1); dispatch_span.record("aborted", true); Ok(Self::aborted_response(&call, secs)) }, res = async { let _guard = if supports_parallel { Either::Left(lock.read().await) } else { Either::Right(lock.write().await) }; router .dispatch_tool_call_with_code_mode_result( session, turn, invocation_cancellation_token, tracker, call.clone(), source, ) .instrument(dispatch_span.clone()) .await } => res, } })); async move { handle.await.map_err(|err| { FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}")) })? } .in_current_span() } } impl ToolCallRuntime { fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem { let message = err.to_string(); match call.payload { ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput { call_id: call.call_id, status: "completed".to_string(), execution: "client".to_string(), tools: Vec::new(), }, ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput { call_id: call.call_id, name: None, output: codex_protocol::models::FunctionCallOutputPayload { body: codex_protocol::models::FunctionCallOutputBody::Text(message), success: Some(false), }, }, _ => ResponseInputItem::FunctionCallOutput { call_id: call.call_id, output: codex_protocol::models::FunctionCallOutputPayload { body: codex_protocol::models::FunctionCallOutputBody::Text(message), success: Some(false), }, }, } } fn aborted_response(call: &ToolCall, secs: f32) -> AnyToolResult { AnyToolResult { call_id: call.call_id.clone(), payload: call.payload.clone(), result: Box::new(AbortedToolOutput { message: Self::abort_message(call, secs), }), post_tool_use_payload: None, } } fn abort_message(call: &ToolCall, secs: f32) -> String { if call.tool_name.namespace.is_none() && matches!( call.tool_name.name.as_str(), "shell" | "container.exec" | "local_shell" | "shell_command" | "unified_exec" ) { format!("Wall time: {secs:.1} seconds\naborted by user") } else { format!("aborted by user after {secs:.1}s") } } }