mirror of
https://github.com/openai/codex.git
synced 2026-06-01 19:02:59 +00:00
Reuse tool runtime for code mode worker (#14496)
## Summary - create the turn-scoped `ToolCallRuntime` before starting the code mode worker so the worker reuses the same runtime and router - thread the shared runtime through the code mode service/worker path and use it for nested tool calls - model aborted tool calls as a concrete `ToolOutput` so aborted responses still produce valid tool output shapes ## Testing - `just fmt` - `cargo test -p codex-core` (still running locally)
This commit is contained in:
@@ -5551,11 +5551,6 @@ pub(crate) async fn run_turn(
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let _code_mode_worker = sess
|
||||
.services
|
||||
.code_mode_service
|
||||
.start_turn_worker(&sess, &turn_context, &turn_diff_tracker)
|
||||
.await;
|
||||
let mut server_model_warning_emitted_for_turn = false;
|
||||
|
||||
// `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse
|
||||
@@ -6161,10 +6156,26 @@ async fn run_sampling_request(
|
||||
turn_context.as_ref(),
|
||||
base_instructions,
|
||||
);
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
);
|
||||
let _code_mode_worker = sess
|
||||
.services
|
||||
.code_mode_service
|
||||
.start_turn_worker(
|
||||
&sess,
|
||||
&turn_context,
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
)
|
||||
.await;
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
let err = match try_run_sampling_request(
|
||||
Arc::clone(&router),
|
||||
tool_runtime.clone(),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
client_session,
|
||||
@@ -6919,7 +6930,7 @@ async fn drain_in_flight(
|
||||
)
|
||||
)]
|
||||
async fn try_run_sampling_request(
|
||||
router: Arc<ToolRouter>,
|
||||
tool_runtime: ToolCallRuntime,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
client_session: &mut ModelClientSession,
|
||||
@@ -6950,13 +6961,6 @@ async fn try_run_sampling_request(
|
||||
.instrument(trace_span!("stream_request"))
|
||||
.or_cancel(&cancellation_token)
|
||||
.await??;
|
||||
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
);
|
||||
let mut in_flight: FuturesOrdered<BoxFuture<'static, CodexResult<ResponseInputItem>>> =
|
||||
FuturesOrdered::new();
|
||||
let mut needs_follow_up = false;
|
||||
|
||||
@@ -4,7 +4,6 @@ use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::FunctionToolOutput;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
@@ -25,14 +24,9 @@ impl CodeModeExecuteHandler {
|
||||
&self,
|
||||
session: std::sync::Arc<Session>,
|
||||
turn: std::sync::Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
code: String,
|
||||
) -> Result<FunctionToolOutput, FunctionCallError> {
|
||||
let exec = ExecContext {
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
};
|
||||
let exec = ExecContext { session, turn };
|
||||
let enabled_tools = build_enabled_tools(&exec).await;
|
||||
let service = &exec.session.services.code_mode_service;
|
||||
let stored_values = service.stored_values().await;
|
||||
@@ -94,7 +88,6 @@ impl ToolHandler for CodeModeExecuteHandler {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
tool_name,
|
||||
payload,
|
||||
..
|
||||
@@ -102,7 +95,7 @@ impl ToolHandler for CodeModeExecuteHandler {
|
||||
|
||||
match payload {
|
||||
ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => {
|
||||
self.execute(session, turn, tracker, input).await
|
||||
self.execute(session, turn, input).await
|
||||
}
|
||||
_ => Err(FunctionCallError::RespondToModel(format!(
|
||||
"{PUBLIC_TOOL_NAME} expects raw JavaScript source text"
|
||||
|
||||
@@ -18,8 +18,8 @@ use crate::tools::ToolRouter;
|
||||
use crate::tools::code_mode_description::augment_tool_spec_for_code_mode;
|
||||
use crate::tools::code_mode_description::code_mode_tool_reference;
|
||||
use crate::tools::context::FunctionToolOutput;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolCallSource;
|
||||
use crate::tools::router::ToolRouterParams;
|
||||
@@ -42,7 +42,6 @@ pub(crate) const DEFAULT_WAIT_YIELD_TIME_MS: u64 = 10_000;
|
||||
pub(super) struct ExecContext {
|
||||
pub(super) session: Arc<Session>,
|
||||
pub(super) turn: Arc<TurnContext>,
|
||||
pub(super) tracker: SharedTurnDiffTracker,
|
||||
}
|
||||
|
||||
pub(crate) use execute_handler::CodeModeExecuteHandler;
|
||||
@@ -270,8 +269,10 @@ async fn build_nested_router(exec: &ExecContext) -> ToolRouter {
|
||||
|
||||
async fn call_nested_tool(
|
||||
exec: ExecContext,
|
||||
tool_runtime: ToolCallRuntime,
|
||||
tool_name: String,
|
||||
input: Option<JsonValue>,
|
||||
cancellation_token: tokio_util::sync::CancellationToken,
|
||||
) -> JsonValue {
|
||||
if tool_name == PUBLIC_TOOL_NAME {
|
||||
return JsonValue::String(format!("{PUBLIC_TOOL_NAME} cannot invoke itself"));
|
||||
@@ -302,14 +303,8 @@ async fn call_nested_tool(
|
||||
tool_namespace: None,
|
||||
payload,
|
||||
};
|
||||
let result = router
|
||||
.dispatch_tool_call_with_code_mode_result(
|
||||
exec.session.clone(),
|
||||
exec.turn.clone(),
|
||||
exec.tracker.clone(),
|
||||
call,
|
||||
ToolCallSource::CodeMode,
|
||||
)
|
||||
let result = tool_runtime
|
||||
.handle_tool_call_with_source(call, ToolCallSource::CodeMode, cancellation_token)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
||||
@@ -9,8 +9,10 @@ use tracing::warn;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::features::Feature;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::js_repl::resolve_compatible_node;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
|
||||
use super::ExecContext;
|
||||
use super::PUBLIC_TOOL_NAME;
|
||||
@@ -65,7 +67,8 @@ impl CodeModeService {
|
||||
&self,
|
||||
session: &Arc<Session>,
|
||||
turn: &Arc<TurnContext>,
|
||||
tracker: &SharedTurnDiffTracker,
|
||||
router: Arc<ToolRouter>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
) -> Option<CodeModeWorker> {
|
||||
if !turn.features.enabled(Feature::CodeMode) {
|
||||
return None;
|
||||
@@ -73,8 +76,9 @@ impl CodeModeService {
|
||||
let exec = ExecContext {
|
||||
session: Arc::clone(session),
|
||||
turn: Arc::clone(turn),
|
||||
tracker: Arc::clone(tracker),
|
||||
};
|
||||
let tool_runtime =
|
||||
ToolCallRuntime::new(router, Arc::clone(session), Arc::clone(turn), tracker);
|
||||
let mut process_slot = match self.ensure_started().await {
|
||||
Ok(process_slot) => process_slot,
|
||||
Err(err) => {
|
||||
@@ -88,7 +92,7 @@ impl CodeModeService {
|
||||
);
|
||||
return None;
|
||||
};
|
||||
Some(process.worker(exec))
|
||||
Some(process.worker(exec, tool_runtime))
|
||||
}
|
||||
|
||||
pub(crate) async fn allocate_session_id(&self) -> i32 {
|
||||
|
||||
@@ -54,7 +54,6 @@ impl ToolHandler for CodeModeWaitHandler {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
tool_name,
|
||||
payload,
|
||||
..
|
||||
@@ -63,11 +62,7 @@ impl ToolHandler for CodeModeWaitHandler {
|
||||
match payload {
|
||||
ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => {
|
||||
let args: ExecWaitArgs = parse_arguments(&arguments)?;
|
||||
let exec = ExecContext {
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
};
|
||||
let exec = ExecContext { session, turn };
|
||||
let request_id = exec
|
||||
.session
|
||||
.services
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
|
||||
use super::ExecContext;
|
||||
@@ -7,6 +8,7 @@ use super::call_nested_tool;
|
||||
use super::process::CodeModeProcess;
|
||||
use super::process::write_message;
|
||||
use super::protocol::HostToNodeMessage;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
pub(crate) struct CodeModeWorker {
|
||||
shutdown_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
@@ -20,7 +22,11 @@ impl Drop for CodeModeWorker {
|
||||
}
|
||||
|
||||
impl CodeModeProcess {
|
||||
pub(super) fn worker(&self, exec: ExecContext) -> CodeModeWorker {
|
||||
pub(super) fn worker(
|
||||
&self,
|
||||
exec: ExecContext,
|
||||
tool_runtime: ToolCallRuntime,
|
||||
) -> CodeModeWorker {
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
||||
let stdin = self.stdin.clone();
|
||||
let tool_call_rx = self.tool_call_rx.clone();
|
||||
@@ -37,13 +43,20 @@ impl CodeModeProcess {
|
||||
break;
|
||||
};
|
||||
let exec = exec.clone();
|
||||
let tool_runtime = tool_runtime.clone();
|
||||
let stdin = stdin.clone();
|
||||
tokio::spawn(async move {
|
||||
let response = HostToNodeMessage::Response {
|
||||
request_id: tool_call.request_id,
|
||||
id: tool_call.id,
|
||||
code_mode_result: call_nested_tool(exec, tool_call.name, tool_call.input)
|
||||
.await,
|
||||
code_mode_result: call_nested_tool(
|
||||
exec,
|
||||
tool_runtime,
|
||||
tool_call.name,
|
||||
tool_call.input,
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await,
|
||||
};
|
||||
if let Err(err) = write_message(&stdin, &response).await {
|
||||
warn!("failed to write {PUBLIC_TOOL_NAME} tool response: {err}");
|
||||
|
||||
@@ -199,6 +199,43 @@ impl ToolOutput for FunctionToolOutput {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AbortedToolOutput {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl ToolOutput for AbortedToolOutput {
|
||||
fn log_preview(&self) -> String {
|
||||
telemetry_preview(&self.message)
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
match payload {
|
||||
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
|
||||
call_id: call_id.to_string(),
|
||||
status: "completed".to_string(),
|
||||
execution: "client".to_string(),
|
||||
tools: Vec::new(),
|
||||
},
|
||||
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
|
||||
call_id: call_id.to_string(),
|
||||
output: CallToolResult::from_error_text(self.message.clone()),
|
||||
},
|
||||
_ => function_tool_response(
|
||||
call_id,
|
||||
payload,
|
||||
vec![FunctionCallOutputContentItem::InputText {
|
||||
text: self.message.clone(),
|
||||
}],
|
||||
None,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ExecCommandToolOutput {
|
||||
pub event_call_id: String,
|
||||
@@ -299,7 +336,7 @@ impl ExecCommandToolOutput {
|
||||
}
|
||||
}
|
||||
|
||||
fn response_input_to_code_mode_result(response: ResponseInputItem) -> JsonValue {
|
||||
pub(crate) fn response_input_to_code_mode_result(response: ResponseInputItem) -> JsonValue {
|
||||
match response {
|
||||
ResponseInputItem::Message { content, .. } => content_items_to_code_mode_result(
|
||||
&content
|
||||
|
||||
@@ -13,12 +13,12 @@ use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::AbortedToolOutput;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::AnyToolResult;
|
||||
use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolCallSource;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -52,8 +52,19 @@ impl ToolCallRuntime {
|
||||
call: ToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||
let future =
|
||||
self.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token);
|
||||
async move { future.await.map(AnyToolResult::into_response) }.in_current_span()
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub(crate) fn handle_tool_call_with_source(
|
||||
self,
|
||||
call: ToolCall,
|
||||
source: ToolCallSource,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> impl std::future::Future<Output = Result<AnyToolResult, CodexErr>> {
|
||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||
let router = Arc::clone(&self.router);
|
||||
let session = Arc::clone(&self.session);
|
||||
let turn = Arc::clone(&self.turn_context);
|
||||
@@ -69,7 +80,7 @@ impl ToolCallRuntime {
|
||||
aborted = false,
|
||||
);
|
||||
|
||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||
let handle: AbortOnDropHandle<Result<AnyToolResult, FunctionCallError>> =
|
||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = cancellation_token.cancelled() => {
|
||||
@@ -85,12 +96,12 @@ impl ToolCallRuntime {
|
||||
};
|
||||
|
||||
router
|
||||
.dispatch_tool_call(
|
||||
.dispatch_tool_call_with_code_mode_result(
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
call.clone(),
|
||||
crate::tools::router::ToolCallSource::Direct,
|
||||
source,
|
||||
)
|
||||
.instrument(dispatch_span.clone())
|
||||
.await
|
||||
@@ -113,34 +124,13 @@ impl ToolCallRuntime {
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
fn aborted_response(call: &ToolCall, secs: f32) -> ResponseInputItem {
|
||||
match &call.payload {
|
||||
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text(Self::abort_message(call, secs)),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
status: "completed".to_string(),
|
||||
execution: "client".to_string(),
|
||||
tools: Vec::new(),
|
||||
},
|
||||
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
output: codex_protocol::mcp::CallToolResult::from_error_text(Self::abort_message(
|
||||
call, secs,
|
||||
)),
|
||||
},
|
||||
_ => ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text(Self::abort_message(call, secs)),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
fn aborted_response(call: &ToolCall, secs: f32) -> AnyToolResult {
|
||||
AnyToolResult {
|
||||
call_id: call.call_id.clone(),
|
||||
payload: call.payload.clone(),
|
||||
result: Box::new(AbortedToolOutput {
|
||||
message: Self::abort_message(call, secs),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user