Reuse tool runtime for code mode worker (#14496)

## Summary
- create the turn-scoped `ToolCallRuntime` before starting the code mode
worker so the worker reuses the same runtime and router
- thread the shared runtime through the code mode service/worker path
and use it for nested tool calls
- model aborted tool calls as a concrete `ToolOutput` so aborted
responses still produce valid tool output shapes

## Testing
- `just fmt`
- `cargo test -p codex-core` (still running locally)
This commit is contained in:
pakrym-oai
2026-03-12 12:48:32 -07:00
committed by GitHub
parent d3e6680531
commit 09ba6b47ae
8 changed files with 112 additions and 81 deletions

View File

@@ -5551,11 +5551,6 @@ pub(crate) async fn run_turn(
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
// many turns, from the perspective of the user, it is a single turn.
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let _code_mode_worker = sess
.services
.code_mode_service
.start_turn_worker(&sess, &turn_context, &turn_diff_tracker)
.await;
let mut server_model_warning_emitted_for_turn = false;
// `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse
@@ -6161,10 +6156,26 @@ async fn run_sampling_request(
turn_context.as_ref(),
base_instructions,
);
let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
);
let _code_mode_worker = sess
.services
.code_mode_service
.start_turn_worker(
&sess,
&turn_context,
Arc::clone(&router),
Arc::clone(&turn_diff_tracker),
)
.await;
let mut retries = 0;
loop {
let err = match try_run_sampling_request(
Arc::clone(&router),
tool_runtime.clone(),
Arc::clone(&sess),
Arc::clone(&turn_context),
client_session,
@@ -6919,7 +6930,7 @@ async fn drain_in_flight(
)
)]
async fn try_run_sampling_request(
router: Arc<ToolRouter>,
tool_runtime: ToolCallRuntime,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
client_session: &mut ModelClientSession,
@@ -6950,13 +6961,6 @@ async fn try_run_sampling_request(
.instrument(trace_span!("stream_request"))
.or_cancel(&cancellation_token)
.await??;
let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
);
let mut in_flight: FuturesOrdered<BoxFuture<'static, CodexResult<ResponseInputItem>>> =
FuturesOrdered::new();
let mut needs_follow_up = false;

View File

@@ -4,7 +4,6 @@ use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
@@ -25,14 +24,9 @@ impl CodeModeExecuteHandler {
&self,
session: std::sync::Arc<Session>,
turn: std::sync::Arc<TurnContext>,
tracker: SharedTurnDiffTracker,
code: String,
) -> Result<FunctionToolOutput, FunctionCallError> {
let exec = ExecContext {
session,
turn,
tracker,
};
let exec = ExecContext { session, turn };
let enabled_tools = build_enabled_tools(&exec).await;
let service = &exec.session.services.code_mode_service;
let stored_values = service.stored_values().await;
@@ -94,7 +88,6 @@ impl ToolHandler for CodeModeExecuteHandler {
let ToolInvocation {
session,
turn,
tracker,
tool_name,
payload,
..
@@ -102,7 +95,7 @@ impl ToolHandler for CodeModeExecuteHandler {
match payload {
ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => {
self.execute(session, turn, tracker, input).await
self.execute(session, turn, input).await
}
_ => Err(FunctionCallError::RespondToModel(format!(
"{PUBLIC_TOOL_NAME} expects raw JavaScript source text"

View File

@@ -18,8 +18,8 @@ use crate::tools::ToolRouter;
use crate::tools::code_mode_description::augment_tool_spec_for_code_mode;
use crate::tools::code_mode_description::code_mode_tool_reference;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolPayload;
use crate::tools::parallel::ToolCallRuntime;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolCallSource;
use crate::tools::router::ToolRouterParams;
@@ -42,7 +42,6 @@ pub(crate) const DEFAULT_WAIT_YIELD_TIME_MS: u64 = 10_000;
pub(super) struct ExecContext {
pub(super) session: Arc<Session>,
pub(super) turn: Arc<TurnContext>,
pub(super) tracker: SharedTurnDiffTracker,
}
pub(crate) use execute_handler::CodeModeExecuteHandler;
@@ -270,8 +269,10 @@ async fn build_nested_router(exec: &ExecContext) -> ToolRouter {
async fn call_nested_tool(
exec: ExecContext,
tool_runtime: ToolCallRuntime,
tool_name: String,
input: Option<JsonValue>,
cancellation_token: tokio_util::sync::CancellationToken,
) -> JsonValue {
if tool_name == PUBLIC_TOOL_NAME {
return JsonValue::String(format!("{PUBLIC_TOOL_NAME} cannot invoke itself"));
@@ -302,14 +303,8 @@ async fn call_nested_tool(
tool_namespace: None,
payload,
};
let result = router
.dispatch_tool_call_with_code_mode_result(
exec.session.clone(),
exec.turn.clone(),
exec.tracker.clone(),
call,
ToolCallSource::CodeMode,
)
let result = tool_runtime
.handle_tool_call_with_source(call, ToolCallSource::CodeMode, cancellation_token)
.await;
match result {

View File

@@ -9,8 +9,10 @@ use tracing::warn;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::features::Feature;
use crate::tools::ToolRouter;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::js_repl::resolve_compatible_node;
use crate::tools::parallel::ToolCallRuntime;
use super::ExecContext;
use super::PUBLIC_TOOL_NAME;
@@ -65,7 +67,8 @@ impl CodeModeService {
&self,
session: &Arc<Session>,
turn: &Arc<TurnContext>,
tracker: &SharedTurnDiffTracker,
router: Arc<ToolRouter>,
tracker: SharedTurnDiffTracker,
) -> Option<CodeModeWorker> {
if !turn.features.enabled(Feature::CodeMode) {
return None;
@@ -73,8 +76,9 @@ impl CodeModeService {
let exec = ExecContext {
session: Arc::clone(session),
turn: Arc::clone(turn),
tracker: Arc::clone(tracker),
};
let tool_runtime =
ToolCallRuntime::new(router, Arc::clone(session), Arc::clone(turn), tracker);
let mut process_slot = match self.ensure_started().await {
Ok(process_slot) => process_slot,
Err(err) => {
@@ -88,7 +92,7 @@ impl CodeModeService {
);
return None;
};
Some(process.worker(exec))
Some(process.worker(exec, tool_runtime))
}
pub(crate) async fn allocate_session_id(&self) -> i32 {

View File

@@ -54,7 +54,6 @@ impl ToolHandler for CodeModeWaitHandler {
let ToolInvocation {
session,
turn,
tracker,
tool_name,
payload,
..
@@ -63,11 +62,7 @@ impl ToolHandler for CodeModeWaitHandler {
match payload {
ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => {
let args: ExecWaitArgs = parse_arguments(&arguments)?;
let exec = ExecContext {
session,
turn,
tracker,
};
let exec = ExecContext { session, turn };
let request_id = exec
.session
.services

View File

@@ -1,4 +1,5 @@
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tracing::warn;
use super::ExecContext;
@@ -7,6 +8,7 @@ use super::call_nested_tool;
use super::process::CodeModeProcess;
use super::process::write_message;
use super::protocol::HostToNodeMessage;
use crate::tools::parallel::ToolCallRuntime;
pub(crate) struct CodeModeWorker {
shutdown_tx: Option<oneshot::Sender<()>>,
}
@@ -20,7 +22,11 @@ impl Drop for CodeModeWorker {
}
impl CodeModeProcess {
pub(super) fn worker(&self, exec: ExecContext) -> CodeModeWorker {
pub(super) fn worker(
&self,
exec: ExecContext,
tool_runtime: ToolCallRuntime,
) -> CodeModeWorker {
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
let stdin = self.stdin.clone();
let tool_call_rx = self.tool_call_rx.clone();
@@ -37,13 +43,20 @@ impl CodeModeProcess {
break;
};
let exec = exec.clone();
let tool_runtime = tool_runtime.clone();
let stdin = stdin.clone();
tokio::spawn(async move {
let response = HostToNodeMessage::Response {
request_id: tool_call.request_id,
id: tool_call.id,
code_mode_result: call_nested_tool(exec, tool_call.name, tool_call.input)
.await,
code_mode_result: call_nested_tool(
exec,
tool_runtime,
tool_call.name,
tool_call.input,
CancellationToken::new(),
)
.await,
};
if let Err(err) = write_message(&stdin, &response).await {
warn!("failed to write {PUBLIC_TOOL_NAME} tool response: {err}");

View File

@@ -199,6 +199,43 @@ impl ToolOutput for FunctionToolOutput {
}
}
pub struct AbortedToolOutput {
pub message: String,
}
impl ToolOutput for AbortedToolOutput {
fn log_preview(&self) -> String {
telemetry_preview(&self.message)
}
fn success_for_logging(&self) -> bool {
false
}
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
match payload {
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
call_id: call_id.to_string(),
status: "completed".to_string(),
execution: "client".to_string(),
tools: Vec::new(),
},
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
call_id: call_id.to_string(),
output: CallToolResult::from_error_text(self.message.clone()),
},
_ => function_tool_response(
call_id,
payload,
vec![FunctionCallOutputContentItem::InputText {
text: self.message.clone(),
}],
None,
),
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct ExecCommandToolOutput {
pub event_call_id: String,
@@ -299,7 +336,7 @@ impl ExecCommandToolOutput {
}
}
fn response_input_to_code_mode_result(response: ResponseInputItem) -> JsonValue {
pub(crate) fn response_input_to_code_mode_result(response: ResponseInputItem) -> JsonValue {
match response {
ResponseInputItem::Message { content, .. } => content_items_to_code_mode_result(
&content

View File

@@ -13,12 +13,12 @@ use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::AbortedToolOutput;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolPayload;
use crate::tools::registry::AnyToolResult;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolCallSource;
use crate::tools::router::ToolRouter;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
#[derive(Clone)]
@@ -52,8 +52,19 @@ impl ToolCallRuntime {
call: ToolCall,
cancellation_token: CancellationToken,
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
let future =
self.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token);
async move { future.await.map(AnyToolResult::into_response) }.in_current_span()
}
#[instrument(level = "trace", skip_all)]
pub(crate) fn handle_tool_call_with_source(
self,
call: ToolCall,
source: ToolCallSource,
cancellation_token: CancellationToken,
) -> impl std::future::Future<Output = Result<AnyToolResult, CodexErr>> {
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
let router = Arc::clone(&self.router);
let session = Arc::clone(&self.session);
let turn = Arc::clone(&self.turn_context);
@@ -69,7 +80,7 @@ impl ToolCallRuntime {
aborted = false,
);
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
let handle: AbortOnDropHandle<Result<AnyToolResult, FunctionCallError>> =
AbortOnDropHandle::new(tokio::spawn(async move {
tokio::select! {
_ = cancellation_token.cancelled() => {
@@ -85,12 +96,12 @@ impl ToolCallRuntime {
};
router
.dispatch_tool_call(
.dispatch_tool_call_with_code_mode_result(
session,
turn,
tracker,
call.clone(),
crate::tools::router::ToolCallSource::Direct,
source,
)
.instrument(dispatch_span.clone())
.await
@@ -113,34 +124,13 @@ impl ToolCallRuntime {
}
impl ToolCallRuntime {
fn aborted_response(call: &ToolCall, secs: f32) -> ResponseInputItem {
match &call.payload {
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
call_id: call.call_id.clone(),
output: FunctionCallOutputPayload {
body: FunctionCallOutputBody::Text(Self::abort_message(call, secs)),
..Default::default()
},
},
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
call_id: call.call_id.clone(),
status: "completed".to_string(),
execution: "client".to_string(),
tools: Vec::new(),
},
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
call_id: call.call_id.clone(),
output: codex_protocol::mcp::CallToolResult::from_error_text(Self::abort_message(
call, secs,
)),
},
_ => ResponseInputItem::FunctionCallOutput {
call_id: call.call_id.clone(),
output: FunctionCallOutputPayload {
body: FunctionCallOutputBody::Text(Self::abort_message(call, secs)),
..Default::default()
},
},
fn aborted_response(call: &ToolCall, secs: f32) -> AnyToolResult {
AnyToolResult {
call_id: call.call_id.clone(),
payload: call.payload.clone(),
result: Box::new(AbortedToolOutput {
message: Self::abort_message(call, secs),
}),
}
}