Compare commits

...

14 Commits

Author SHA1 Message Date
Albin Cassirer
a08e9c3d9e Hacked it so that we can run benchmarks on the impact of parallel tool calls.
Note that you have to first build a release binary and reference that with --codex-bin for this to work.
2025-10-04 18:18:57 -07:00
jimmyfraiture
5273aaba56 Clippy and FMT 2025-10-04 18:03:49 +01:00
jimmyfraiture
58567a7d7c Tool for testing 2025-10-04 18:00:45 +01:00
jimmyfraiture
bcacf3338c Comments 2 2025-10-04 17:42:39 +01:00
jimmyfraiture
a5aee84e22 Comments 1 2025-10-04 17:37:14 +01:00
jif-oai
50e2f12c9a Merge branch 'main' into jif/tools-2 2025-10-04 17:11:41 +01:00
jimmyfraiture
e98d80c7cf Fix merge 2025-10-03 18:14:30 +01:00
jimmyfraiture
c2c66f9c83 Merge remote-tracking branch 'origin/main' into jif/tools-2
# Conflicts:
#	codex-rs/core/src/model_family.rs
#	codex-rs/core/src/tools/spec.rs
2025-10-03 18:11:39 +01:00
jimmyfraiture
56470e9e09 Fix more tests 2025-10-03 17:03:56 +01:00
jimmyfraiture
1352b524cb v5 2025-10-03 16:36:56 +01:00
jimmyfraiture
bdd22c8cb4 v4 2025-10-03 16:02:49 +01:00
jimmyfraiture
4232edb818 v3 2025-10-03 15:36:46 +01:00
jimmyfraiture
d866bd167e v2 2025-10-03 15:24:34 +01:00
jimmyfraiture
2fff283df2 v1 2025-10-03 15:00:05 +01:00
33 changed files with 1580 additions and 247 deletions

View File

@@ -1351,6 +1351,7 @@ async fn derive_config_from_params(
include_view_image_tool: None,
show_raw_agent_reasoning: None,
tools_web_search_request: None,
parallel_tool_calls: None,
};
let cli_overrides = cli_overrides

View File

@@ -228,7 +228,7 @@ impl ModelClient {
input: &input_with_instructions,
tools: &tools_json,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: prompt.parallel_tool_calls,
reasoning,
store: azure_workaround,
stream: true,

View File

@@ -31,6 +31,9 @@ pub struct Prompt {
/// external MCP servers.
pub(crate) tools: Vec<ToolSpec>,
/// Whether parallel tool calls are permitted for this prompt.
pub(crate) parallel_tool_calls: bool,
/// Optional override for the built-in BASE_INSTRUCTIONS.
pub base_instructions_override: Option<String>,
@@ -182,6 +185,17 @@ pub(crate) mod tools {
Freeform(FreeformTool),
}
impl ToolSpec {
/// Canonical tool name as advertised to the model: the configured name
/// for `Function`/`Freeform` tools, or a fixed identifier for the
/// built-in `LocalShell`/`WebSearch` variants.
pub(crate) fn name(&self) -> &str {
match self {
ToolSpec::Function(tool) => tool.name.as_str(),
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(tool) => tool.name.as_str(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FreeformTool {
pub(crate) name: String,
@@ -327,7 +341,7 @@ mod tests {
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
@@ -368,7 +382,7 @@ mod tests {
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
@@ -404,7 +418,7 @@ mod tests {
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,

View File

@@ -100,7 +100,9 @@ use crate::tasks::CompactTask;
use crate::tasks::RegularTask;
use crate::tasks::ReviewTask;
use crate::tools::ToolRouter;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::format_exec_output_str;
use crate::tools::parallel::ToolCallRuntime;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::UserInstructions;
@@ -258,6 +260,7 @@ pub(crate) struct TurnContext {
pub(crate) tools_config: ToolsConfig,
pub(crate) is_review_mode: bool,
pub(crate) final_output_json_schema: Option<Value>,
pub(crate) parallel_tool_calls_override: Option<bool>,
}
impl TurnContext {
@@ -455,6 +458,7 @@ impl Session {
cwd,
is_review_mode: false,
final_output_json_schema: None,
parallel_tool_calls_override: config.force_parallel_tool_calls,
};
let services = SessionServices {
mcp_connection_manager,
@@ -807,7 +811,7 @@ impl Session {
async fn on_exec_command_begin(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
turn_diff_tracker: SharedTurnDiffTracker,
exec_command_context: ExecCommandContext,
) {
let ExecCommandContext {
@@ -823,7 +827,10 @@ impl Session {
user_explicitly_approved_this_action,
changes,
}) => {
turn_diff_tracker.on_patch_begin(&changes);
{
let mut tracker = turn_diff_tracker.lock().await;
tracker.on_patch_begin(&changes);
}
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
call_id,
@@ -850,7 +857,7 @@ impl Session {
async fn on_exec_command_end(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: &str,
call_id: &str,
output: &ExecToolCallOutput,
@@ -898,7 +905,10 @@ impl Session {
// If this is an apply_patch, after we emit the end patch, emit a second event
// with the full turn diff if there is one.
if is_apply_patch {
let unified_diff = turn_diff_tracker.get_unified_diff();
let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
tracker.get_unified_diff()
};
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event {
@@ -915,7 +925,7 @@ impl Session {
/// Returns the output of the exec tool call.
pub(crate) async fn run_exec_with_events(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
turn_diff_tracker: SharedTurnDiffTracker,
prepared: PreparedExec,
approval_policy: AskForApproval,
) -> Result<ExecToolCallOutput, ExecError> {
@@ -924,7 +934,7 @@ impl Session {
let sub_id = context.sub_id.clone();
let call_id = context.call_id.clone();
self.on_exec_command_begin(turn_diff_tracker, context.clone())
self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone())
.await;
let result = self
@@ -1191,6 +1201,7 @@ async fn submission_loop(
cwd: new_cwd.clone(),
is_review_mode: false,
final_output_json_schema: None,
parallel_tool_calls_override: config.force_parallel_tool_calls,
};
// Install the new persistent context for subsequent tasks/turns.
@@ -1251,6 +1262,7 @@ async fn submission_loop(
if let Some(model_info) = get_model_info(&model_family) {
per_turn_config.model_context_window = Some(model_info.context_window);
}
let per_turn_parallel_override = per_turn_config.force_parallel_tool_calls;
let otel_event_manager =
turn_context.client.get_otel_event_manager().with_model(
@@ -1291,6 +1303,7 @@ async fn submission_loop(
cwd,
is_review_mode: false,
final_output_json_schema,
parallel_tool_calls_override: per_turn_parallel_override,
};
// if the environment context has changed, record it in the conversation history
@@ -1539,6 +1552,7 @@ async fn spawn_review_thread(
per_turn_config.model_family.slug.as_str(),
);
let per_turn_parallel_override = per_turn_config.force_parallel_tool_calls;
let per_turn_config = Arc::new(per_turn_config);
let client = ModelClient::new(
per_turn_config.clone(),
@@ -1561,6 +1575,7 @@ async fn spawn_review_thread(
cwd: parent_turn_context.cwd.clone(),
is_review_mode: true,
final_output_json_schema: None,
parallel_tool_calls_override: per_turn_parallel_override,
};
// Seed the child task with the review prompt as the initial user message.
@@ -1633,7 +1648,7 @@ pub(crate) async fn run_task(
let mut last_agent_message: Option<String> = None;
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
// many turns, from the perspective of the user, it is a single turn.
let mut turn_diff_tracker = TurnDiffTracker::new();
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let mut auto_compact_recently_attempted = false;
loop {
@@ -1681,9 +1696,9 @@ pub(crate) async fn run_task(
})
.collect();
match run_turn(
&sess,
turn_context.as_ref(),
&mut turn_diff_tracker,
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
sub_id.clone(),
turn_input,
)
@@ -1906,18 +1921,29 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
}
async fn run_turn(
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: String,
input: Vec<ResponseItem>,
) -> CodexResult<TurnRunResult> {
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools));
let router = Arc::new(ToolRouter::from_config(
&turn_context.tools_config,
Some(mcp_tools),
));
let model_supports_parallel = turn_context
.client
.get_model_family()
.supports_parallel_tool_calls;
let parallel_tool_calls = turn_context
.parallel_tool_calls_override
.unwrap_or(model_supports_parallel);
let prompt = Prompt {
input,
tools: router.specs().to_vec(),
tools: router.specs(),
parallel_tool_calls,
base_instructions_override: turn_context.base_instructions.clone(),
output_schema: turn_context.final_output_json_schema.clone(),
};
@@ -1925,10 +1951,10 @@ async fn run_turn(
let mut retries = 0;
loop {
match try_run_turn(
&router,
sess,
turn_context,
turn_diff_tracker,
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
&sub_id,
&prompt,
)
@@ -1984,9 +2010,9 @@ async fn run_turn(
/// "handled" such that it produces a `ResponseInputItem` that needs to be
/// sent back to the model on the next turn.
#[derive(Debug)]
struct ProcessedResponseItem {
item: ResponseItem,
response: Option<ResponseInputItem>,
pub(crate) struct ProcessedResponseItem {
pub(crate) item: ResponseItem,
pub(crate) response: Option<ResponseInputItem>,
}
#[derive(Debug)]
@@ -1996,10 +2022,10 @@ struct TurnRunResult {
}
async fn try_run_turn(
router: &crate::tools::ToolRouter,
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
router: Arc<ToolRouter>,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<TurnRunResult> {
@@ -2070,24 +2096,34 @@ async fn try_run_turn(
let mut stream = turn_context.client.clone().stream(&prompt).await?;
let mut output = Vec::new();
let mut tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
sub_id.to_string(),
);
loop {
// Poll the next item from the model stream. We must inspect *both* Ok and Err
// cases so that transient stream failures (e.g., dropped SSE connection before
// `response.completed`) bubble up and trigger the caller's retry logic.
let event = stream.next().await;
let Some(event) = event else {
// Channel closed without yielding a final Completed event or explicit error.
// Treat as a disconnected stream so the caller can retry.
return Err(CodexErr::Stream(
"stream closed before response.completed".into(),
None,
));
let event = match event {
Some(event) => event,
None => {
tool_runtime.abort_all();
return Err(CodexErr::Stream(
"stream closed before response.completed".into(),
None,
));
}
};
let event = match event {
Ok(ev) => ev,
Err(e) => {
tool_runtime.abort_all();
// Propagate the underlying stream error to the caller (run_turn), which
// will apply the configured `stream_max_retries` policy.
return Err(e);
@@ -2097,16 +2133,66 @@ async fn try_run_turn(
match event {
ResponseEvent::Created => {}
ResponseEvent::OutputItemDone(item) => {
let response = handle_response_item(
router,
sess,
turn_context,
turn_diff_tracker,
sub_id,
item.clone(),
)
.await?;
output.push(ProcessedResponseItem { item, response });
match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) {
Ok(Some(call)) => {
let payload_preview = call.payload.log_payload().into_owned();
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
let index = output.len();
output.push(ProcessedResponseItem {
item,
response: None,
});
tool_runtime
.handle_tool_call(call, index, output.as_mut_slice())
.await?;
}
Ok(None) => {
let response = handle_non_tool_response_item(
Arc::clone(&sess),
Arc::clone(&turn_context),
sub_id,
item.clone(),
)
.await?;
output.push(ProcessedResponseItem { item, response });
}
Err(FunctionCallError::MissingLocalShellCallId) => {
let msg = "LocalShellCall without call_id or id";
turn_context
.client
.get_otel_event_manager()
.log_tool_failed("local_shell", msg);
error!(msg);
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: msg.to_string(),
success: None,
},
};
output.push(ProcessedResponseItem {
item,
response: Some(response),
});
}
Err(FunctionCallError::RespondToModel(message)) => {
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: message,
success: None,
},
};
output.push(ProcessedResponseItem {
item,
response: Some(response),
});
}
Err(FunctionCallError::Fatal(message)) => {
return Err(CodexErr::Fatal(message));
}
}
}
ResponseEvent::WebSearchCallBegin { call_id } => {
let _ = sess
@@ -2126,10 +2212,15 @@ async fn try_run_turn(
response_id: _,
token_usage,
} => {
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
.await;
let unified_diff = turn_diff_tracker.get_unified_diff();
tool_runtime.resolve_pending(output.as_mut_slice()).await?;
let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
tracker.get_unified_diff()
};
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event {
@@ -2188,88 +2279,40 @@ async fn try_run_turn(
}
}
async fn handle_response_item(
router: &crate::tools::ToolRouter,
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
async fn handle_non_tool_response_item(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: &str,
item: ResponseItem,
) -> CodexResult<Option<ResponseInputItem>> {
debug!(?item, "Output item");
match ToolRouter::build_tool_call(sess, item.clone()) {
Ok(Some(call)) => {
let payload_preview = call.payload.log_payload().into_owned();
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
match router
.dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call)
.await
{
Ok(response) => Ok(Some(response)),
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
Err(other) => unreachable!("non-fatal tool error returned: {other:?}"),
match &item {
ResponseItem::Message { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. } => {
let msgs = match &item {
ResponseItem::Message { .. } if turn_context.is_review_mode => {
trace!("suppressing assistant Message in review mode");
Vec::new()
}
_ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
};
for msg in msgs {
let event = Event {
id: sub_id.to_string(),
msg,
};
sess.send_event(event).await;
}
}
Ok(None) => {
match &item {
ResponseItem::Message { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. } => {
let msgs = match &item {
ResponseItem::Message { .. } if turn_context.is_review_mode => {
trace!("suppressing assistant Message in review mode");
Vec::new()
}
_ => map_response_item_to_event_messages(
&item,
sess.show_raw_agent_reasoning(),
),
};
for msg in msgs {
let event = Event {
id: sub_id.to_string(),
msg,
};
sess.send_event(event).await;
}
}
ResponseItem::FunctionCallOutput { .. }
| ResponseItem::CustomToolCallOutput { .. } => {
debug!("unexpected tool output from stream");
}
_ => {}
}
Ok(None)
ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
debug!("unexpected tool output from stream");
}
Err(FunctionCallError::MissingLocalShellCallId) => {
let msg = "LocalShellCall without call_id or id";
turn_context
.client
.get_otel_event_manager()
.log_tool_failed("local_shell", msg);
error!(msg);
Ok(Some(ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: msg.to_string(),
success: None,
},
}))
}
Err(FunctionCallError::RespondToModel(msg)) => {
Ok(Some(ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: msg,
success: None,
},
}))
}
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
_ => {}
}
Ok(None)
}
pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
@@ -2698,6 +2741,7 @@ mod tests {
tools_config,
is_review_mode: false,
final_output_json_schema: None,
parallel_tool_calls_override: config.force_parallel_tool_calls,
};
let services = SessionServices {
mcp_connection_manager: McpConnectionManager::default(),
@@ -2771,6 +2815,7 @@ mod tests {
tools_config,
is_review_mode: false,
final_output_json_schema: None,
parallel_tool_calls_override: config.force_parallel_tool_calls,
});
let services = SessionServices {
mcp_connection_manager: McpConnectionManager::default(),
@@ -2901,13 +2946,10 @@ mod tests {
#[tokio::test]
async fn fatal_tool_error_stops_turn_and_reports_error() {
let (session, turn_context, _rx) = make_session_and_context_with_rx();
let session_ref = session.as_ref();
let turn_context_ref = turn_context.as_ref();
let router = ToolRouter::from_config(
&turn_context_ref.tools_config,
Some(session_ref.services.mcp_connection_manager.list_all_tools()),
&turn_context.tools_config,
Some(session.services.mcp_connection_manager.list_all_tools()),
);
let mut tracker = TurnDiffTracker::new();
let item = ResponseItem::CustomToolCall {
id: None,
status: None,
@@ -2916,22 +2958,26 @@ mod tests {
input: "{}".to_string(),
};
let err = handle_response_item(
&router,
session_ref,
turn_context_ref,
&mut tracker,
"sub-id",
item,
)
.await
.expect_err("expected fatal error");
let call = ToolRouter::build_tool_call(session.as_ref(), item.clone())
.expect("build tool call")
.expect("tool call present");
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let err = router
.dispatch_tool_call(
Arc::clone(&session),
Arc::clone(&turn_context),
tracker,
"sub-id".to_string(),
call,
)
.await
.expect_err("expected fatal error");
match err {
CodexErr::Fatal(message) => {
FunctionCallError::Fatal(message) => {
assert_eq!(message, "tool shell invoked with incompatible payload");
}
other => panic!("expected CodexErr::Fatal, got {other:?}"),
other => panic!("expected FunctionCallError::Fatal, got {other:?}"),
}
}
@@ -3045,9 +3091,11 @@ mod tests {
use crate::turn_diff_tracker::TurnDiffTracker;
use std::collections::HashMap;
let (session, mut turn_context) = make_session_and_context();
let (session, mut turn_context_raw) = make_session_and_context();
// Ensure policy is NOT OnRequest so the early rejection path triggers
turn_context.approval_policy = AskForApproval::OnFailure;
turn_context_raw.approval_policy = AskForApproval::OnFailure;
let session = Arc::new(session);
let mut turn_context = Arc::new(turn_context_raw);
let params = ExecParams {
command: if cfg!(windows) {
@@ -3075,7 +3123,7 @@ mod tests {
..params.clone()
};
let mut turn_diff_tracker = TurnDiffTracker::new();
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let tool_name = "shell";
let sub_id = "test-sub".to_string();
@@ -3084,9 +3132,9 @@ mod tests {
let resp = handle_container_exec_with_params(
tool_name,
params,
&session,
&turn_context,
&mut turn_diff_tracker,
Arc::clone(&session),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
sub_id,
call_id,
)
@@ -3105,14 +3153,16 @@ mod tests {
// Now retry the same command WITHOUT escalated permissions; should succeed.
// Force DangerFullAccess to avoid platform sandbox dependencies in tests.
turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess;
Arc::get_mut(&mut turn_context)
.expect("unique turn context Arc")
.sandbox_policy = SandboxPolicy::DangerFullAccess;
let resp2 = handle_container_exec_with_params(
tool_name,
params2,
&session,
&turn_context,
&mut turn_diff_tracker,
Arc::clone(&session),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
"test-sub".to_string(),
"test-call-2".to_string(),
)

View File

@@ -194,6 +194,9 @@ pub struct Config {
pub tools_web_search_request: bool,
/// Override for whether tool calls run in parallel. When `None`, falls back to model defaults.
pub force_parallel_tool_calls: Option<bool>,
pub use_experimental_streamable_shell_tool: bool,
/// If set to `true`, used only the experimental unified exec tool.
@@ -877,6 +880,7 @@ pub struct ConfigOverrides {
pub include_view_image_tool: Option<bool>,
pub show_raw_agent_reasoning: Option<bool>,
pub tools_web_search_request: Option<bool>,
pub parallel_tool_calls: Option<bool>,
}
impl Config {
@@ -905,6 +909,7 @@ impl Config {
include_view_image_tool,
show_raw_agent_reasoning,
tools_web_search_request: override_tools_web_search_request,
parallel_tool_calls,
} = overrides;
let active_profile_name = config_profile_key
@@ -1084,6 +1089,7 @@ impl Config {
include_plan_tool: include_plan_tool.unwrap_or(false),
include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
tools_web_search_request,
force_parallel_tool_calls: parallel_tool_calls,
use_experimental_streamable_shell_tool: cfg
.experimental_use_exec_command_tool
.unwrap_or(false),
@@ -1880,6 +1886,7 @@ model_verbosity = "high"
include_plan_tool: false,
include_apply_patch_tool: false,
tools_web_search_request: false,
force_parallel_tool_calls: None,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
use_experimental_use_rmcp_client: false,
@@ -1941,6 +1948,7 @@ model_verbosity = "high"
include_plan_tool: false,
include_apply_patch_tool: false,
tools_web_search_request: false,
force_parallel_tool_calls: None,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
use_experimental_use_rmcp_client: false,
@@ -2017,6 +2025,7 @@ model_verbosity = "high"
include_plan_tool: false,
include_apply_patch_tool: false,
tools_web_search_request: false,
force_parallel_tool_calls: None,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
use_experimental_use_rmcp_client: false,
@@ -2079,6 +2088,7 @@ model_verbosity = "high"
include_plan_tool: false,
include_apply_patch_tool: false,
tools_web_search_request: false,
force_parallel_tool_calls: None,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
use_experimental_use_rmcp_client: false,

View File

@@ -35,6 +35,10 @@ pub struct ModelFamily {
// See https://platform.openai.com/docs/guides/tools-local-shell
pub uses_local_shell_tool: bool,
/// Whether this model supports parallel tool calls when using the
/// Responses API.
pub supports_parallel_tool_calls: bool,
/// Present if the model performs better when `apply_patch` is provided as
/// a tool call instead of just a bash command
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
@@ -58,6 +62,7 @@ macro_rules! model_family {
supports_reasoning_summaries: false,
reasoning_summary_format: ReasoningSummaryFormat::None,
uses_local_shell_tool: false,
supports_parallel_tool_calls: false,
apply_patch_tool_type: None,
base_instructions: BASE_INSTRUCTIONS.to_string(),
experimental_supported_tools: Vec::new(),
@@ -103,6 +108,18 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
} else if slug.starts_with("gpt-3.5") {
model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
} else if slug.starts_with("test-gpt-5-codex") {
model_family!(
slug, slug,
supports_reasoning_summaries: true,
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
experimental_supported_tools: vec![
"read_file".to_string(),
"test_sync_tool".to_string()
],
supports_parallel_tool_calls: true,
)
} else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") {
model_family!(
slug, slug,
@@ -110,12 +127,14 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
experimental_supported_tools: vec!["read_file".to_string()],
supports_parallel_tool_calls: true,
)
} else if slug.starts_with("gpt-5") {
model_family!(
slug, "gpt-5",
supports_reasoning_summaries: true,
needs_special_apply_patch_instructions: true,
supports_parallel_tool_calls: true,
)
} else {
None
@@ -130,6 +149,7 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily {
supports_reasoning_summaries: false,
reasoning_summary_format: ReasoningSummaryFormat::None,
uses_local_shell_tool: false,
supports_parallel_tool_calls: false,
apply_patch_tool_type: None,
base_instructions: BASE_INSTRUCTIONS.to_string(),
experimental_supported_tools: Vec::new(),

View File

@@ -14,12 +14,17 @@ use mcp_types::CallToolResult;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
pub struct ToolInvocation<'a> {
pub session: &'a Session,
pub turn: &'a TurnContext,
pub tracker: &'a mut TurnDiffTracker,
pub sub_id: &'a str,
pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;
#[derive(Clone)]
pub struct ToolInvocation {
pub session: Arc<Session>,
pub turn: Arc<TurnContext>,
pub tracker: SharedTurnDiffTracker,
pub sub_id: String,
pub call_id: String,
pub tool_name: String,
pub payload: ToolPayload,

View File

@@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use crate::client_common::tools::FreeformTool;
use crate::client_common::tools::FreeformToolFormat;
@@ -36,10 +37,7 @@ impl ToolHandler for ApplyPatchHandler {
)
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
@@ -79,10 +77,10 @@ impl ToolHandler for ApplyPatchHandler {
let content = handle_container_exec_with_params(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
sub_id.to_string(),
Arc::clone(&session),
Arc::clone(&turn),
Arc::clone(&tracker),
sub_id.clone(),
call_id.clone(),
)
.await?;

View File

@@ -19,10 +19,7 @@ impl ToolHandler for ExecStreamHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
tool_name,

View File

@@ -16,10 +16,7 @@ impl ToolHandler for McpHandler {
ToolKind::Mcp
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
sub_id,
@@ -45,8 +42,8 @@ impl ToolHandler for McpHandler {
let arguments_str = raw_arguments;
let response = handle_mcp_tool_call(
session,
sub_id,
session.as_ref(),
&sub_id,
call_id.clone(),
server,
tool,

View File

@@ -4,6 +4,7 @@ mod mcp;
mod plan;
mod read_file;
mod shell;
mod test_sync;
mod unified_exec;
mod view_image;
@@ -15,5 +16,6 @@ pub use mcp::McpHandler;
pub use plan::PlanHandler;
pub use read_file::ReadFileHandler;
pub use shell::ShellHandler;
pub use test_sync::TestSyncHandler;
pub use unified_exec::UnifiedExecHandler;
pub use view_image::ViewImageHandler;

View File

@@ -65,10 +65,7 @@ impl ToolHandler for PlanHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
sub_id,
@@ -86,7 +83,8 @@ impl ToolHandler for PlanHandler {
}
};
let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?;
let content =
handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;
Ok(ToolOutput::Function {
content,

View File

@@ -7,6 +7,7 @@ use serde::Deserialize;
use tokio::fs::File;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::time::sleep;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
@@ -19,6 +20,11 @@ pub struct ReadFileHandler;
const MAX_LINE_LENGTH: usize = 500;
#[path = "read_file_test_support.rs"]
mod test_support;
use test_support::test_delay_for_path;
fn default_offset() -> usize {
1
}
@@ -42,10 +48,7 @@ impl ToolHandler for ReadFileHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, .. } = invocation;
let arguments = match payload {
@@ -88,6 +91,10 @@ impl ToolHandler for ReadFileHandler {
));
}
if let Some(delay) = test_delay_for_path(&path) {
sleep(delay).await;
}
let collected = read_file_slice(&path, offset, limit).await?;
Ok(ToolOutput::Function {
content: collected.join("\n"),

View File

@@ -0,0 +1,32 @@
use std::path::Path;
use std::time::Duration;
/// Looks up a test-only artificial read delay for `path`.
///
/// Reads `CODEX_TEST_READ_FILE_DELAYS`, a `;`-separated list of
/// `path=millis` pairs, and returns the first positive delay whose path
/// matches exactly. Entries that are empty, lack an `=`, fail to parse,
/// or specify `0` ms are skipped. Returns `None` when the variable is
/// unset, empty, or has no matching entry.
pub(crate) fn test_delay_for_path(path: &Path) -> Option<Duration> {
    let config = std::env::var("CODEX_TEST_READ_FILE_DELAYS").ok()?;
    if config.is_empty() {
        return None;
    }
    let target = path.to_string_lossy();
    config
        .split(';')
        // Malformed or empty entries yield no pair and fall out here.
        .filter_map(|entry| entry.split_once('='))
        .filter(|(candidate, _)| *candidate == target)
        // Unparsable delays are skipped rather than treated as errors.
        .filter_map(|(_, delay_ms)| delay_ms.parse::<u64>().ok())
        // A zero delay means "no delay configured"; keep scanning.
        .find(|&ms| ms > 0)
        .map(Duration::from_millis)
}

View File

@@ -1,5 +1,6 @@
use async_trait::async_trait;
use codex_protocol::models::ShellToolCallParams;
use std::sync::Arc;
use crate::codex::TurnContext;
use crate::exec::ExecParams;
@@ -40,10 +41,7 @@ impl ToolHandler for ShellHandler {
)
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
@@ -62,14 +60,14 @@ impl ToolHandler for ShellHandler {
"failed to parse function arguments: {e:?}"
))
})?;
let exec_params = Self::to_exec_params(params, turn);
let exec_params = Self::to_exec_params(params, turn.as_ref());
let content = handle_container_exec_with_params(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
sub_id.to_string(),
Arc::clone(&session),
Arc::clone(&turn),
Arc::clone(&tracker),
sub_id.clone(),
call_id.clone(),
)
.await?;
@@ -79,14 +77,14 @@ impl ToolHandler for ShellHandler {
})
}
ToolPayload::LocalShell { params } => {
let exec_params = Self::to_exec_params(params, turn);
let exec_params = Self::to_exec_params(params, turn.as_ref());
let content = handle_container_exec_with_params(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
sub_id.to_string(),
Arc::clone(&session),
Arc::clone(&turn),
Arc::clone(&tracker),
sub_id.clone(),
call_id.clone(),
)
.await?;

View File

@@ -0,0 +1,158 @@
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use async_trait::async_trait;
use serde::Deserialize;
use tokio::sync::Barrier;
use tokio::time::sleep;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
/// Test-only tool handler used to exercise parallel tool-call scheduling
/// by sleeping and/or rendezvousing on named barriers.
pub struct TestSyncHandler;
// Default barrier wait budget when the caller omits `timeout_ms`.
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
// Process-wide registry of named barriers, keyed by barrier id, shared by
// all concurrent invocations of this tool.
static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();
/// A registered barrier plus the participant count it was created with,
/// so later joiners with a mismatched count can be rejected.
struct BarrierState {
barrier: Arc<Barrier>,
participants: usize,
}
/// Barrier rendezvous request: wait until `participants` callers with the
/// same `id` arrive, or give up after `timeout_ms` (defaults to 1 s).
#[derive(Debug, Deserialize)]
struct BarrierArgs {
id: String,
participants: usize,
#[serde(default = "default_timeout_ms")]
timeout_ms: u64,
}
/// Arguments for `test_sync_tool`: optional sleeps before/after an
/// optional barrier rendezvous. All fields default to "do nothing".
#[derive(Debug, Deserialize)]
struct TestSyncArgs {
#[serde(default)]
sleep_before_ms: Option<u64>,
#[serde(default)]
sleep_after_ms: Option<u64>,
#[serde(default)]
barrier: Option<BarrierArgs>,
}
fn default_timeout_ms() -> u64 {
DEFAULT_TIMEOUT_MS
}
fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
}
#[async_trait]
impl ToolHandler for TestSyncHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    /// Runs the synchronization tool: optional pre-sleep, optional barrier
    /// rendezvous, optional post-sleep, then a fixed `"ok"` result.
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        // Only plain function payloads carry JSON arguments we can parse.
        let ToolPayload::Function { arguments } = invocation.payload else {
            return Err(FunctionCallError::RespondToModel(
                "test_sync_tool handler received unsupported payload".to_string(),
            ));
        };
        let parsed: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        // A zero delay is treated the same as an absent one: no sleep at all.
        if let Some(ms) = parsed.sleep_before_ms.filter(|ms| *ms > 0) {
            sleep(Duration::from_millis(ms)).await;
        }
        if let Some(barrier) = parsed.barrier {
            wait_on_barrier(barrier).await?;
        }
        if let Some(ms) = parsed.sleep_after_ms.filter(|ms| *ms > 0) {
            sleep(Duration::from_millis(ms)).await;
        }

        Ok(ToolOutput::Function {
            content: "ok".to_string(),
            success: Some(true),
        })
    }
}
/// Blocks the calling tool task on a named barrier shared across concurrent
/// tool calls, so tests can prove that calls overlap in time.
///
/// Registers the barrier in the global map on first arrival, joins it on
/// subsequent arrivals (participant counts must match), and waits at most
/// `timeout_ms` for all participants. The leader removes the entry afterwards.
///
/// Errors with `RespondToModel` on invalid arguments, a participant-count
/// mismatch, or a wait timeout.
async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
    if args.participants == 0 {
        return Err(FunctionCallError::RespondToModel(
            "barrier participants must be greater than zero".to_string(),
        ));
    }
    if args.timeout_ms == 0 {
        return Err(FunctionCallError::RespondToModel(
            "barrier timeout must be greater than zero".to_string(),
        ));
    }
    let barrier_id = args.id.clone();
    // Resolve (or create) the shared barrier while holding the map lock;
    // the lock is dropped before waiting so other participants can arrive.
    let barrier = {
        let mut map = barrier_map().lock().await;
        match map.entry(barrier_id.clone()) {
            Entry::Occupied(entry) => {
                let state = entry.get();
                // All arrivals must agree on the participant count, otherwise
                // the barrier could never (or prematurely) release.
                if state.participants != args.participants {
                    let existing = state.participants;
                    return Err(FunctionCallError::RespondToModel(format!(
                        "barrier {barrier_id} already registered with {existing} participants"
                    )));
                }
                state.barrier.clone()
            }
            Entry::Vacant(entry) => {
                let barrier = Arc::new(Barrier::new(args.participants));
                entry.insert(BarrierState {
                    barrier: barrier.clone(),
                    participants: args.participants,
                });
                barrier
            }
        }
    };
    let timeout = Duration::from_millis(args.timeout_ms);
    // NOTE(review): on timeout the map entry is left behind; a later run
    // reusing the same id with a different participant count will be rejected.
    let wait_result = tokio::time::timeout(timeout, barrier.wait())
        .await
        .map_err(|_| {
            FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
        })?;
    // Exactly one waiter is elected leader; it cleans up the registry entry.
    // The ptr_eq guard avoids removing a newer barrier registered under the
    // same id between our wait and this cleanup.
    if wait_result.is_leader() {
        let mut map = barrier_map().lock().await;
        if let Some(state) = map.get(&barrier_id)
            && Arc::ptr_eq(&state.barrier, &barrier)
        {
            map.remove(&barrier_id);
        }
    }
    Ok(())
}

View File

@@ -33,10 +33,7 @@ impl ToolHandler for UnifiedExecHandler {
)
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session, payload, ..
} = invocation;

View File

@@ -26,10 +26,7 @@ impl ToolHandler for ViewImageHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,

View File

@@ -1,5 +1,6 @@
pub mod context;
pub(crate) mod handlers;
pub mod parallel;
pub mod registry;
pub mod router;
pub mod spec;
@@ -21,7 +22,7 @@ use crate::executor::linkers::PreparedExec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ApplyPatchCommandContext;
use crate::tools::context::ExecCommandContext;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::tools::context::SharedTurnDiffTracker;
use codex_apply_patch::MaybeApplyPatchVerified;
use codex_apply_patch::maybe_parse_apply_patch_verified;
use codex_protocol::protocol::AskForApproval;
@@ -29,6 +30,7 @@ use codex_utils_string::take_bytes_at_char_boundary;
use codex_utils_string::take_last_bytes_at_char_boundary;
pub use router::ToolRouter;
use serde::Serialize;
use std::sync::Arc;
use tracing::trace;
// Model-formatting limits: clients get full streams; only content sent to the model is truncated.
@@ -48,9 +50,9 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
pub(crate) async fn handle_container_exec_with_params(
tool_name: &str,
params: ExecParams,
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: String,
call_id: String,
) -> Result<String, FunctionCallError> {
@@ -68,7 +70,15 @@ pub(crate) async fn handle_container_exec_with_params(
// check if this was a patch, and apply it if so
let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) {
MaybeApplyPatchVerified::Body(changes) => {
match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await {
match apply_patch::apply_patch(
sess.as_ref(),
turn_context.as_ref(),
&sub_id,
&call_id,
changes,
)
.await
{
InternalApplyPatchInvocation::Output(item) => return item,
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
Some(apply_patch_exec)
@@ -139,7 +149,7 @@ pub(crate) async fn handle_container_exec_with_params(
let output_result = sess
.run_exec_with_events(
turn_diff_tracker,
turn_diff_tracker.clone(),
prepared_exec,
turn_context.approval_policy,
)

View File

@@ -0,0 +1,137 @@
use std::sync::Arc;
use tokio::task::JoinHandle;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_protocol::models::ResponseInputItem;
use crate::codex::ProcessedResponseItem;
/// A parallel tool call that has been spawned but whose result has not yet
/// been written back into the turn's output slots.
struct PendingToolCall {
    // Position in the turn's `ProcessedResponseItem` slice where the task's
    // response must eventually be stored.
    index: usize,
    // Spawned Tokio task driving the tool call to completion.
    handle: JoinHandle<Result<ResponseInputItem, FunctionCallError>>,
}
/// Per-turn dispatcher that runs parallel-capable tool calls concurrently
/// (as spawned tasks) and everything else serially, preserving the output
/// slot each call was assigned by the model stream.
pub(crate) struct ToolCallRuntime {
    router: Arc<ToolRouter>,
    session: Arc<Session>,
    turn_context: Arc<TurnContext>,
    tracker: SharedTurnDiffTracker,
    sub_id: String,
    // Spawned parallel calls not yet awaited via `resolve_pending`.
    pending_calls: Vec<PendingToolCall>,
}
impl ToolCallRuntime {
    /// Creates a runtime for one turn, holding shared handles to the router,
    /// session, turn context, and diff tracker that every dispatch needs.
    pub(crate) fn new(
        router: Arc<ToolRouter>,
        session: Arc<Session>,
        turn_context: Arc<TurnContext>,
        tracker: SharedTurnDiffTracker,
        sub_id: String,
    ) -> Self {
        Self {
            router,
            session,
            turn_context,
            tracker,
            sub_id,
            pending_calls: Vec::new(),
        }
    }
    /// Routes one tool call: parallel-capable tools are spawned and resolved
    /// later; all other tools first drain every pending parallel call, then
    /// run inline and write their response into `output[output_index]`.
    pub(crate) async fn handle_tool_call(
        &mut self,
        call: ToolCall,
        output_index: usize,
        output: &mut [ProcessedResponseItem],
    ) -> Result<(), CodexErr> {
        let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
        if supports_parallel {
            self.spawn_parallel(call, output_index);
        } else {
            // Serialize against any in-flight parallel calls before running
            // a tool that does not support concurrency.
            self.resolve_pending(output).await?;
            let response = self.dispatch_serial(call).await?;
            let slot = output.get_mut(output_index).ok_or_else(|| {
                CodexErr::Fatal(format!("tool output index {output_index} out of bounds"))
            })?;
            slot.response = Some(response);
        }
        Ok(())
    }
    /// Aborts every still-running parallel task without awaiting results.
    pub(crate) fn abort_all(&mut self) {
        while let Some(pending) = self.pending_calls.pop() {
            pending.handle.abort();
        }
    }
    /// Awaits all spawned parallel calls, storing each response in its
    /// recorded output slot. Any task-level or fatal tool error aborts the
    /// remaining tasks and surfaces as `CodexErr::Fatal`.
    pub(crate) async fn resolve_pending(
        &mut self,
        output: &mut [ProcessedResponseItem],
    ) -> Result<(), CodexErr> {
        while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() {
            match handle.await {
                Ok(Ok(response)) => {
                    // NOTE(review): an out-of-range index is silently ignored
                    // here, while the serial path treats it as fatal.
                    if let Some(slot) = output.get_mut(index) {
                        slot.response = Some(response);
                    }
                }
                Ok(Err(FunctionCallError::Fatal(message))) => {
                    self.abort_all();
                    return Err(CodexErr::Fatal(message));
                }
                Ok(Err(other)) => {
                    self.abort_all();
                    return Err(CodexErr::Fatal(other.to_string()));
                }
                Err(join_err) => {
                    self.abort_all();
                    return Err(CodexErr::Fatal(format!(
                        "tool task failed to join: {join_err}"
                    )));
                }
            }
        }
        Ok(())
    }
    /// Spawns a tool call on the Tokio runtime and records the output slot it
    /// must fill; cloned `Arc` handles keep shared state alive in the task.
    fn spawn_parallel(&mut self, call: ToolCall, index: usize) {
        let router = Arc::clone(&self.router);
        let session = Arc::clone(&self.session);
        let turn = Arc::clone(&self.turn_context);
        let tracker = Arc::clone(&self.tracker);
        let sub_id = self.sub_id.clone();
        let handle = tokio::spawn(async move {
            router
                .dispatch_tool_call(session, turn, tracker, sub_id, call)
                .await
        });
        self.pending_calls.push(PendingToolCall { index, handle });
    }
    /// Runs a tool call inline on the current task, mapping every tool error
    /// to `CodexErr::Fatal` for the caller.
    async fn dispatch_serial(&self, call: ToolCall) -> Result<ResponseInputItem, CodexErr> {
        match self
            .router
            .dispatch_tool_call(
                Arc::clone(&self.session),
                Arc::clone(&self.turn_context),
                Arc::clone(&self.tracker),
                self.sub_id.clone(),
                call,
            )
            .await
        {
            Ok(response) => Ok(response),
            Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
            Err(other) => Err(CodexErr::Fatal(other.to_string())),
        }
    }
}

View File

@@ -32,8 +32,7 @@ pub trait ToolHandler: Send + Sync {
)
}
async fn handle(&self, invocation: ToolInvocation<'_>)
-> Result<ToolOutput, FunctionCallError>;
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
}
pub struct ToolRegistry {
@@ -57,9 +56,9 @@ impl ToolRegistry {
// }
// }
pub async fn dispatch<'a>(
pub async fn dispatch(
&self,
invocation: ToolInvocation<'a>,
invocation: ToolInvocation,
) -> Result<ResponseInputItem, FunctionCallError> {
let tool_name = invocation.tool_name.clone();
let call_id_owned = invocation.call_id.clone();
@@ -137,9 +136,24 @@ impl ToolRegistry {
}
}
#[derive(Debug, Clone)]
pub struct ConfiguredToolSpec {
pub spec: ToolSpec,
pub supports_parallel_tool_calls: bool,
}
impl ConfiguredToolSpec {
pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
Self {
spec,
supports_parallel_tool_calls,
}
}
}
pub struct ToolRegistryBuilder {
handlers: HashMap<String, Arc<dyn ToolHandler>>,
specs: Vec<ToolSpec>,
specs: Vec<ConfiguredToolSpec>,
}
impl ToolRegistryBuilder {
@@ -151,7 +165,16 @@ impl ToolRegistryBuilder {
}
pub fn push_spec(&mut self, spec: ToolSpec) {
self.specs.push(spec);
self.push_spec_with_parallel_support(spec, false);
}
pub fn push_spec_with_parallel_support(
&mut self,
spec: ToolSpec,
supports_parallel_tool_calls: bool,
) {
self.specs
.push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
}
pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
@@ -183,7 +206,7 @@ impl ToolRegistryBuilder {
// }
// }
pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) {
pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
let registry = ToolRegistry::new(self.handlers);
(self.specs, registry)
}

View File

@@ -1,15 +1,17 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ConfiguredToolSpec;
use crate::tools::registry::ToolRegistry;
use crate::tools::spec::ToolsConfig;
use crate::tools::spec::build_specs;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::models::LocalShellAction;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
@@ -24,7 +26,7 @@ pub struct ToolCall {
pub struct ToolRouter {
registry: ToolRegistry,
specs: Vec<ToolSpec>,
specs: Vec<ConfiguredToolSpec>,
}
impl ToolRouter {
@@ -34,11 +36,22 @@ impl ToolRouter {
) -> Self {
let builder = build_specs(config, mcp_tools);
let (specs, registry) = builder.build();
Self { registry, specs }
}
pub fn specs(&self) -> &[ToolSpec] {
&self.specs
pub fn specs(&self) -> Vec<ToolSpec> {
self.specs
.iter()
.map(|config| config.spec.clone())
.collect()
}
pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
self.specs
.iter()
.filter(|config| config.supports_parallel_tool_calls)
.any(|config| config.spec.name() == tool_name)
}
pub fn build_tool_call(
@@ -118,10 +131,10 @@ impl ToolRouter {
pub async fn dispatch_tool_call(
&self,
session: &Session,
turn: &TurnContext,
tracker: &mut TurnDiffTracker,
sub_id: &str,
session: Arc<Session>,
turn: Arc<TurnContext>,
tracker: SharedTurnDiffTracker,
sub_id: String,
call: ToolCall,
) -> Result<ResponseInputItem, FunctionCallError> {
let ToolCall {

View File

@@ -258,6 +258,68 @@ fn create_view_image_tool() -> ToolSpec {
})
}
/// Builds the spec for `test_sync_tool`, the integration-test helper whose
/// optional sleeps and barrier let tests observe tool-call scheduling.
fn create_test_sync_tool() -> ToolSpec {
    // Nested `barrier` object: id + participants are mandatory, timeout is not.
    let barrier_schema = {
        let mut fields = BTreeMap::new();
        fields.insert(
            "id".to_string(),
            JsonSchema::String {
                description: Some(
                    "Identifier shared by concurrent calls that should rendezvous".to_string(),
                ),
            },
        );
        fields.insert(
            "participants".to_string(),
            JsonSchema::Number {
                description: Some(
                    "Number of tool calls that must arrive before the barrier opens".to_string(),
                ),
            },
        );
        fields.insert(
            "timeout_ms".to_string(),
            JsonSchema::Number {
                description: Some(
                    "Maximum time in milliseconds to wait at the barrier".to_string(),
                ),
            },
        );
        JsonSchema::Object {
            properties: fields,
            required: Some(vec!["id".to_string(), "participants".to_string()]),
            additional_properties: Some(false.into()),
        }
    };

    // Top-level parameters: every field, including the barrier, is optional.
    let mut parameters = BTreeMap::new();
    parameters.insert(
        "sleep_before_ms".to_string(),
        JsonSchema::Number {
            description: Some("Optional delay in milliseconds before any other action".to_string()),
        },
    );
    parameters.insert(
        "sleep_after_ms".to_string(),
        JsonSchema::Number {
            description: Some(
                "Optional delay in milliseconds after completing the barrier".to_string(),
            ),
        },
    );
    parameters.insert("barrier".to_string(), barrier_schema);

    ToolSpec::Function(ResponsesApiTool {
        name: "test_sync_tool".to_string(),
        description: "Internal synchronization helper used by Codex integration tests.".to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties: parameters,
            required: None,
            additional_properties: Some(false.into()),
        },
    })
}
fn create_read_file_tool() -> ToolSpec {
let mut properties = BTreeMap::new();
properties.insert(
@@ -507,6 +569,7 @@ pub(crate) fn build_specs(
use crate::tools::handlers::PlanHandler;
use crate::tools::handlers::ReadFileHandler;
use crate::tools::handlers::ShellHandler;
use crate::tools::handlers::TestSyncHandler;
use crate::tools::handlers::UnifiedExecHandler;
use crate::tools::handlers::ViewImageHandler;
use std::sync::Arc;
@@ -573,16 +636,26 @@ pub(crate) fn build_specs(
.any(|tool| tool == "read_file")
{
let read_file_handler = Arc::new(ReadFileHandler);
builder.push_spec(create_read_file_tool());
builder.push_spec_with_parallel_support(create_read_file_tool(), true);
builder.register_handler("read_file", read_file_handler);
}
if config
.experimental_supported_tools
.iter()
.any(|tool| tool == "test_sync_tool")
{
let test_sync_handler = Arc::new(TestSyncHandler);
builder.push_spec_with_parallel_support(create_test_sync_tool(), true);
builder.register_handler("test_sync_tool", test_sync_handler);
}
if config.web_search_request {
builder.push_spec(ToolSpec::WebSearch {});
}
if config.include_view_image_tool {
builder.push_spec(create_view_image_tool());
builder.push_spec_with_parallel_support(create_view_image_tool(), true);
builder.register_handler("view_image", view_image_handler);
}
@@ -610,20 +683,25 @@ pub(crate) fn build_specs(
mod tests {
use crate::client_common::tools::FreeformTool;
use crate::model_family::find_family_for_model;
use crate::tools::registry::ConfiguredToolSpec;
use mcp_types::ToolInputSchema;
use pretty_assertions::assert_eq;
use super::*;
fn assert_eq_tool_names(tools: &[ToolSpec], expected_names: &[&str]) {
fn tool_name(tool: &ToolSpec) -> &str {
match tool {
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
}
}
fn assert_eq_tool_names(tools: &[ConfiguredToolSpec], expected_names: &[&str]) {
let tool_names = tools
.iter()
.map(|tool| match tool {
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
})
.map(|tool| tool_name(&tool.spec))
.collect::<Vec<_>>();
assert_eq!(
@@ -639,6 +717,16 @@ mod tests {
}
}
fn find_tool<'a>(
tools: &'a [ConfiguredToolSpec],
expected_name: &str,
) -> &'a ConfiguredToolSpec {
tools
.iter()
.find(|tool| tool_name(&tool.spec) == expected_name)
.unwrap_or_else(|| panic!("expected tool {expected_name}"))
}
#[test]
fn test_build_specs() {
let model_family = find_family_for_model("codex-mini-latest")
@@ -698,6 +786,52 @@ mod tests {
assert_eq_tool_names(&tools, &["unified_exec", "read_file"]);
}
#[test]
fn test_parallel_support_flags() {
    // gpt-5-codex with unified_exec enabled registers both a serial-only tool
    // (unified_exec) and a parallel-capable one (read_file), covering both
    // sides of the `supports_parallel_tool_calls` flag.
    let model_family = find_family_for_model("gpt-5-codex")
        // Fixed: the message previously referenced "codex-mini-latest",
        // which is not the model under test.
        .expect("gpt-5-codex should be a valid model family");
    let config = ToolsConfig::new(&ToolsConfigParams {
        model_family: &model_family,
        include_plan_tool: false,
        include_apply_patch_tool: false,
        include_web_search_request: false,
        use_streamable_shell_tool: false,
        include_view_image_tool: false,
        experimental_unified_exec_tool: true,
    });
    let (tools, _) = build_specs(&config, None).build();
    assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls);
    assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
}
#[test]
fn test_test_model_family_includes_sync_tool() {
    // The test-only model family must expose both experimental tools used by
    // the parallelism integration tests.
    let model_family = find_family_for_model("test-gpt-5-codex")
        .expect("test-gpt-5-codex should be a valid model family");
    let config = ToolsConfig::new(&ToolsConfigParams {
        model_family: &model_family,
        include_plan_tool: false,
        include_apply_patch_tool: false,
        include_web_search_request: false,
        use_streamable_shell_tool: false,
        include_view_image_tool: false,
        experimental_unified_exec_tool: false,
    });
    let (tools, _) = build_specs(&config, None).build();
    let registered: Vec<&str> = tools.iter().map(|tool| tool_name(&tool.spec)).collect();
    for expected in ["test_sync_tool", "read_file"] {
        assert!(
            registered.contains(&expected),
            "expected {expected} to be registered for test-gpt-5-codex"
        );
    }
}
#[test]
fn test_build_specs_mcp_tools() {
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
@@ -760,7 +894,7 @@ mod tests {
);
assert_eq!(
tools[3],
tools[3].spec,
ToolSpec::Function(ResponsesApiTool {
name: "test_server/do_something_cool".to_string(),
parameters: JsonSchema::Object {
@@ -929,7 +1063,7 @@ mod tests {
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/search".to_string(),
parameters: JsonSchema::Object {
@@ -995,7 +1129,7 @@ mod tests {
],
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/paginate".to_string(),
parameters: JsonSchema::Object {
@@ -1059,7 +1193,7 @@ mod tests {
],
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/tags".to_string(),
parameters: JsonSchema::Object {
@@ -1126,7 +1260,7 @@ mod tests {
],
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/value".to_string(),
parameters: JsonSchema::Object {
@@ -1231,7 +1365,7 @@ mod tests {
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "test_server/do_something_cool".to_string(),
parameters: JsonSchema::Object {

View File

@@ -3,14 +3,14 @@ use std::time::Duration;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event_with_timeout;
use serde_json::json;
use wiremock::matchers::body_string_contains;
/// Integration test: spawn a longrunning shell tool via a mocked Responses SSE
/// function call, then interrupt the session and expect TurnAborted.
@@ -27,10 +27,13 @@ async fn interrupt_long_running_tool_emits_turn_aborted() {
"timeout_ms": 60_000
})
.to_string();
let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]);
let body = sse(vec![
ev_function_call("call_sleep", "shell", &args),
ev_completed("done"),
]);
let server = start_mock_server().await;
mount_sse_once_match(&server, body_string_contains("start sleep"), body).await;
mount_sse_once(&server, body).await;
let codex = test_codex().build(&server).await.unwrap().codex;

View File

@@ -171,7 +171,7 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
],
"tools": tool_calls,
"tool_choice": "auto",
"parallel_tool_calls": false,
"parallel_tool_calls": true,
"reasoning": {
"summary": "auto"
},
@@ -305,7 +305,7 @@ SUMMARY_ONLY_CONTEXT"
],
"tools": tool_calls,
"tool_choice": "auto",
"parallel_tool_calls": false,
"parallel_tool_calls": true,
"reasoning": {
"summary": "auto"
},
@@ -390,7 +390,7 @@ SUMMARY_ONLY_CONTEXT"
],
"tools": tool_calls,
"tool_choice": "auto",
"parallel_tool_calls": false,
"parallel_tool_calls": true,
"reasoning": {
"summary": "auto"
},
@@ -475,7 +475,7 @@ SUMMARY_ONLY_CONTEXT"
],
"tools": tool_calls,
"tool_choice": "auto",
"parallel_tool_calls": false,
"parallel_tool_calls": true,
"reasoning": {
"summary": "auto"
},

View File

@@ -23,6 +23,7 @@ mod seatbelt;
mod stream_error_allows_next_turn;
mod stream_no_completed;
mod tool_harness;
mod tool_parallelism;
mod tools;
mod unified_exec;
mod user_notification;

View File

@@ -0,0 +1,182 @@
#![cfg(not(target_os = "windows"))]
#![allow(clippy::unwrap_used)]
use std::time::Duration;
use std::time::Instant;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use serde_json::json;
/// Submits `prompt` as a user turn with fully permissive approval/sandbox
/// settings and blocks until the session reports `TaskComplete`.
async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> {
    let model = test.session_configured.model.clone();
    test.codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: prompt.into(),
            }],
            final_output_json_schema: None,
            cwd: test.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;
    // Drain events until the turn finishes; intermediate events are ignored.
    while !matches!(
        test.codex.next_event().await?.msg,
        EventMsg::TaskComplete(_)
    ) {}
    Ok(())
}
/// Runs one turn and reports how long it took end to end.
async fn run_turn_and_measure(test: &TestCodex, prompt: &str) -> anyhow::Result<Duration> {
    let started_at = Instant::now();
    run_turn(test, prompt).await?;
    Ok(started_at.elapsed())
}
/// Builds a test Codex session pinned to the `test-gpt-5-codex` model family,
/// which registers the `test_sync_tool` used by these tests.
#[allow(clippy::expect_used)]
async fn build_codex_with_test_tool(server: &wiremock::MockServer) -> anyhow::Result<TestCodex> {
    test_codex()
        .with_config(|config| {
            config.model = "test-gpt-5-codex".to_string();
            config.model_family = find_family_for_model("test-gpt-5-codex")
                .expect("test-gpt-5-codex model family");
        })
        .build(server)
        .await
}
/// Panics unless `actual` is under the 500 ms budget that these tests use to
/// recognize overlapping (parallel) tool execution.
fn assert_parallel_duration(actual: Duration) {
    let budget = Duration::from_millis(500);
    assert!(
        actual < budget,
        "expected parallel execution to finish quickly, got {actual:?}"
    );
}
/// Panics unless `actual` is at least the 500 ms budget, the signature these
/// tests use for back-to-back (serial) tool execution.
fn assert_serial_duration(actual: Duration) {
    let budget = Duration::from_millis(500);
    assert!(
        actual >= budget,
        "expected serial execution to take longer, got {actual:?}"
    );
}
/// Two `test_sync_tool` calls share a two-participant barrier, so the turn
/// can only finish inside the parallel budget if both calls run concurrently
/// (each also sleeps 300 ms after the barrier releases).
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn read_file_tools_run_in_parallel() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
    let server = start_mock_server().await;
    let test = build_codex_with_test_tool(&server).await?;
    let parallel_args = json!({
        "sleep_after_ms": 300,
        "barrier": {
            "id": "parallel-test-sync",
            "participants": 2,
            "timeout_ms": 1_000,
        }
    })
    .to_string();
    // First mocked response emits two tool calls in the same turn.
    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "test_sync_tool", &parallel_args),
        ev_function_call("call-2", "test_sync_tool", &parallel_args),
        ev_completed("resp-1"),
    ]);
    // Second response ends the turn after the tool outputs are sent back.
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;
    let duration = run_turn_and_measure(&test, "exercise sync tool").await?;
    assert_parallel_duration(duration);
    Ok(())
}
/// Two `shell` calls that each sleep 0.3 s must run back to back (shell is
/// not registered as parallel-capable), so the turn takes at least the serial
/// budget.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn non_parallel_tools_run_serially() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
    let server = start_mock_server().await;
    let test = test_codex().build(&server).await?;
    let shell_args = json!({
        "command": ["/bin/sh", "-c", "sleep 0.3"],
        "timeout_ms": 1_000,
    });
    let args_one = serde_json::to_string(&shell_args)?;
    let args_two = serde_json::to_string(&shell_args)?;
    // One turn containing two shell invocations.
    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "shell", &args_one),
        ev_function_call("call-2", "shell", &args_two),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;
    let duration = run_turn_and_measure(&test, "run shell twice").await?;
    assert_serial_duration(duration);
    Ok(())
}
/// Mixing a parallel-capable call (`test_sync_tool`) with a serial one
/// (`shell`) in the same turn forces the runtime to drain the pending
/// parallel call before running shell, so total time stays in the serial
/// budget. No barrier here: only one sync-tool call is issued.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mixed_tools_fall_back_to_serial() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
    let server = start_mock_server().await;
    let test = build_codex_with_test_tool(&server).await?;
    let sync_args = json!({
        "sleep_after_ms": 300
    })
    .to_string();
    let shell_args = serde_json::to_string(&json!({
        "command": ["/bin/sh", "-c", "sleep 0.3"],
        "timeout_ms": 1_000,
    }))?;
    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "test_sync_tool", &sync_args),
        ev_function_call("call-2", "shell", &shell_args),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;
    let duration = run_turn_and_measure(&test, "mix tools").await?;
    assert_serial_duration(duration);
    Ok(())
}

View File

@@ -71,6 +71,10 @@ pub struct Cli {
#[arg(long = "include-plan-tool", default_value_t = false)]
pub include_plan_tool: bool,
/// Override parallel tool call behaviour. Defaults to model capability when unset.
#[arg(long = "parallel-tool-calls", value_enum)]
pub parallel_tool_calls: Option<ParallelToolCalls>,
/// Specifies file where the last message from the agent should be written.
#[arg(long = "output-last-message", short = 'o', value_name = "FILE")]
pub last_message_file: Option<PathBuf>,
@@ -111,3 +115,11 @@ pub enum Color {
#[default]
Auto,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
#[value(rename_all = "kebab-case")]
pub enum ParallelToolCalls {
Auto,
On,
Off,
}

View File

@@ -42,6 +42,7 @@ use tracing_subscriber::EnvFilter;
use tracing_subscriber::prelude::*;
use crate::cli::Command as ExecCommand;
use crate::cli::ParallelToolCalls;
use crate::event_processor::CodexStatus;
use crate::event_processor::EventProcessor;
use codex_core::default_client::set_default_originator;
@@ -69,6 +70,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
prompt,
output_schema: output_schema_path,
include_plan_tool,
parallel_tool_calls,
config_overrides,
} = cli;
@@ -164,6 +166,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
};
// Load configuration and determine approval policy
let parallel_tool_calls_override = match parallel_tool_calls {
Some(ParallelToolCalls::On) => Some(true),
Some(ParallelToolCalls::Off) => Some(false),
_ => None,
};
let overrides = ConfigOverrides {
model,
review_model: None,
@@ -181,6 +189,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
include_view_image_tool: None,
show_raw_agent_reasoning: oss.then_some(true),
tools_web_search_request: None,
parallel_tool_calls: parallel_tool_calls_override,
};
// Parse `-c` overrides.
let cli_kv_overrides = match config_overrides.parse_overrides() {

View File

@@ -164,6 +164,7 @@ impl CodexToolCallParam {
include_view_image_tool: None,
show_raw_agent_reasoning: None,
tools_web_search_request: None,
parallel_tool_calls: None,
};
let cli_overrides = cli_overrides

View File

@@ -0,0 +1,526 @@
#!/usr/bin/env python3
"""Benchmark codex exec runs with and without parallel tool calls."""
from __future__ import annotations
import argparse
import json
import math
import os
import shlex
import shutil
import statistics
import subprocess
import sys
import tempfile
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Iterable, Sequence
# Model slug used for both modes unless overridden on the command line.
DEFAULT_MODEL = "gpt-5-codex"
@dataclass
class ModeConfig:
    """Settings for one benchmark mode (e.g. parallel vs. serial runs)."""

    label: str
    model: str
    extra_args: tuple[str, ...]
    env_pairs: tuple[tuple[str, str], ...]
    parallel_flag: str | None
    enabled: bool = True
@dataclass
class RunResult:
    """Outcome of a single codex exec invocation, with its captured output files."""

    index: int
    duration_s: float
    returncode: int
    stdout_path: Path
    stderr_path: Path
    metadata_path: Path
@dataclass
class ModeResult:
    """All runs executed for one mode, plus where their outputs were written."""

    config: ModeConfig
    outputs_dir: Path
    runs: list[RunResult]
    @property
    def durations(self) -> list[float]:
        """Wall-clock duration of each run, in seconds, in run order."""
        return [run.duration_s for run in self.runs]
class BenchmarkError(RuntimeError):
    """Raised for invalid benchmark configuration or runtime failures."""
    pass
@dataclass
class ProgressTracker:
    """Thread-safe run counter that prints one progress line per completion."""

    total_runs: int
    completed: int = 0
    lock: threading.Lock = field(default_factory=threading.Lock, repr=False)

    def advance(self, mode_label: str, run_index: int) -> None:
        """Record one finished run and print its progress line."""
        with self.lock:
            self.completed += 1
            # Avoid dividing by zero when no runs were scheduled.
            if self.total_runs:
                percentage = self.completed / self.total_runs * 100
            else:
                percentage = 100.0
            print(
                f"[{self.completed:>3}/{self.total_runs:<3} | {percentage:5.1f}%] "
                f"mode={mode_label} run={run_index:03d}",
                flush=True,
            )
def parse_key_value_pairs(pairs: Iterable[str]) -> tuple[tuple[str, str], ...]:
    """Split each ``KEY=VALUE`` string on its first ``=`` into a pair.

    Raises BenchmarkError for any entry that contains no ``=`` at all.
    """
    result: list[tuple[str, str]] = []
    for entry in pairs:
        key, separator, value = entry.partition("=")
        if not separator:
            raise BenchmarkError(f"Expected KEY=VALUE format, got: {entry}")
        result.append((key, value))
    return tuple(result)
def parse_args(argv: Sequence[str]) -> argparse.Namespace:
    """Build the CLI parser and parse argv (argv[0] is skipped as the program name).

    NOTE: the --workdir, --output-root, and --label defaults are evaluated when
    this function runs (Path(__file__), tempfile.gettempdir(), datetime.now()),
    so every invocation of parse_args snapshots a fresh timestamp label.
    """
    parser = argparse.ArgumentParser(
        description="Run codex exec repeatedly for parallel vs serial tool call models.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("prompt", help="Prompt passed to codex exec. Use quotes to preserve spaces.")
    parser.add_argument(
        "-n",
        "--runs",
        type=int,
        default=5,
        help="Number of executions per mode (parallel and serial).",
    )
    parser.add_argument(
        "--codex-bin",
        default="codex",
        help="Path to codex binary. If relative, resolved against the working directory.",
    )
    parser.add_argument(
        "--workdir",
        # Defaults to the repository root: one directory above this script's parent.
        default=str(Path(__file__).resolve().parents[1]),
        help="Working directory passed to codex exec commands.",
    )
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL,
        help="Model slug shared by both modes when explicit overrides are not provided.",
    )
    parser.add_argument(
        "--parallel-model",
        default=None,
        help="Model slug used only for parallel runs; defaults to --model when omitted.",
    )
    parser.add_argument(
        "--serial-model",
        default=None,
        help="Model slug used only for serial runs; defaults to --model when omitted.",
    )
    parser.add_argument(
        "--parallel-extra",
        default="",
        help="Additional CLI args passed only to parallel runs (quoted string).",
    )
    parser.add_argument(
        "--serial-extra",
        default="",
        help="Additional CLI args passed only to serial runs (quoted string).",
    )
    parser.add_argument(
        "--parallel-env",
        action="append",
        default=[],
        help="Environment overrides KEY=VALUE applied to parallel runs (repeatable).",
    )
    parser.add_argument(
        "--serial-env",
        action="append",
        default=[],
        help="Environment overrides KEY=VALUE applied to serial runs (repeatable).",
    )
    parser.add_argument(
        "--output-root",
        default=str(Path(tempfile.gettempdir()) / "codex_parallel_benchmark"),
        help="Directory under which experiment outputs and plots are stored.",
    )
    parser.add_argument(
        "--label",
        # Timestamp default keeps repeated runs from clobbering each other.
        default=datetime.now().strftime("%Y%m%d-%H%M%S"),
        help="Label used to create a unique run directory under output-root.",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Print summary JSON in addition to the human-readable report.",
    )
    parser.add_argument(
        "--skip-parallel",
        action="store_true",
        help="Skip runs flagged as parallel (only serial runs execute).",
    )
    parser.add_argument(
        "--skip-serial",
        action="store_true",
        help="Skip runs flagged as serial (only parallel runs execute).",
    )
    parser.add_argument(
        "--parallel-runs",
        action="store_true",
        help="Execute all codex exec runs concurrently instead of sequentially.",
    )
    parser.add_argument(
        "--max-workers",
        type=int,
        default=None,
        help="Maximum number of in-flight codex exec runs when --parallel-runs is set.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the commands that would run without executing them.",
    )
    # Caller passes sys.argv, so drop the program name before parsing.
    return parser.parse_args(argv[1:])
def ensure_binary(path: str) -> str:
    """Resolve `path` to an absolute codex binary location.

    An existing file path wins and is fully resolved; otherwise the name is
    looked up on PATH. Raises BenchmarkError when neither succeeds.
    """
    as_path = Path(path)
    if as_path.is_file():
        return str(as_path.resolve())
    found = shutil.which(path)
    if found:
        return found
    raise BenchmarkError(f"Unable to locate codex binary: {path}")
def expand_args(arg_string: str) -> tuple[str, ...]:
    """Split a shell-quoted argument string into tokens; blank input yields ()."""
    stripped = arg_string.strip()
    return tuple(shlex.split(stripped)) if stripped else ()
def build_mode_configs(args: argparse.Namespace) -> list[ModeConfig]:
    """Build both mode configurations from parsed CLI args, then drop skipped ones.

    Both configs are constructed before filtering so that malformed
    --parallel-env/--serial-env values are rejected even for skipped modes.
    Raises BenchmarkError when every mode is skipped.
    """
    configs = [
        ModeConfig(
            label="parallel_on",
            model=args.parallel_model or args.model,
            extra_args=expand_args(args.parallel_extra),
            env_pairs=parse_key_value_pairs(args.parallel_env),
            parallel_flag="on",
            enabled=not args.skip_parallel,
        ),
        ModeConfig(
            label="parallel_off",
            model=args.serial_model or args.model,
            extra_args=expand_args(args.serial_extra),
            env_pairs=parse_key_value_pairs(args.serial_env),
            parallel_flag="off",
            enabled=not args.skip_serial,
        ),
    ]
    active = [config for config in configs if config.enabled]
    if not active:
        raise BenchmarkError("All modes skipped; enable at least one mode to run the benchmark.")
    return active
def run_command(
    codex_bin: str,
    workdir: Path,
    prompt: str,
    mode: ModeConfig,
    run_index: int,
    output_dir: Path,
    dry_run: bool,
) -> RunResult:
    """Execute one codex exec run (or simulate it under dry_run).

    Creates output_dir/<mode.label>/run_<index>/ with stdout.txt, stderr.txt
    and metadata.json, and returns a RunResult describing the run. Dry runs
    skip the subprocess, keep a NaN in-memory duration (filtered by the stats
    code), and print the command instead of executing it.
    """
    workdir.mkdir(parents=True, exist_ok=True)
    mode_dir = output_dir / mode.label
    run_dir = mode_dir / f"run_{run_index:03d}"
    run_dir.mkdir(parents=True, exist_ok=True)
    command = [codex_bin, "exec", "--model", mode.model]
    if mode.parallel_flag:
        command.extend(["--parallel-tool-calls", mode.parallel_flag])
    command.extend((*mode.extra_args, prompt))
    env = os.environ.copy()
    for key, value in mode.env_pairs:
        env[key] = value
    stdout_path = run_dir / "stdout.txt"
    stderr_path = run_dir / "stderr.txt"
    metadata_path = run_dir / "metadata.json"
    start_dt = datetime.now()
    if dry_run:
        # NaN marks "not measured"; downstream stats filter non-finite values.
        duration_s = float("nan")
        returncode = 0
        stdout = ""
        stderr = ""
    else:
        start = time.perf_counter()
        result = subprocess.run(
            command,
            cwd=str(workdir),
            capture_output=True,
            text=True,
            env=env,
            check=False,  # non-zero exits are recorded, not raised
        )
        duration_s = time.perf_counter() - start
        returncode = result.returncode
        stdout = result.stdout
        stderr = result.stderr
    stdout_path.write_text(stdout)
    stderr_path.write_text(stderr)
    metadata = {
        "command": command,
        "env_overrides": dict(mode.env_pairs),
        "model": mode.model,
        "label": mode.label,
        "prompt": prompt,
        "run_index": run_index,
        # Serialize NaN (dry runs) as null: json.dumps would otherwise emit a
        # bare `NaN` token, which is not valid JSON (RFC 8259) and breaks
        # strict parsers reading metadata.json.
        "duration_seconds": duration_s if math.isfinite(duration_s) else None,
        "returncode": returncode,
        "started_at": start_dt.isoformat(),
    }
    metadata_path.write_text(json.dumps(metadata, indent=2))
    if dry_run:
        command_str = " ".join(shlex.quote(element) for element in command)
        print(f"[DRY-RUN] {command_str}")
    return RunResult(
        index=run_index,
        duration_s=duration_s,
        returncode=returncode,
        stdout_path=stdout_path,
        stderr_path=stderr_path,
        metadata_path=metadata_path,
    )
def execute_runs(
    *,
    codex_bin: str,
    workdir: Path,
    prompt: str,
    modes: Sequence[ModeConfig],
    runs_per_mode: int,
    output_dir: Path,
    dry_run: bool,
    progress: ProgressTracker,
    parallel_runs: bool,
    max_workers: int | None,
) -> list[ModeResult]:
    """Run every (mode, run_index) combination and collect ModeResults.

    With parallel_runs, all runs across all modes go through one thread pool;
    otherwise runs execute sequentially, mode by mode. Either way the returned
    ModeResults follow the order of `modes`, with runs ordered by index.
    """
    if not modes:
        return []
    if parallel_runs:
        total_runs = runs_per_mode * len(modes)
        # No explicit cap means one worker per run (fully concurrent).
        worker_count = max_workers or total_runs
        if worker_count < 1:
            raise BenchmarkError("max workers must be a positive integer")
        runs_by_mode: dict[str, list[RunResult]] = {mode.label: [] for mode in modes}
        future_to_mode: dict[Future[RunResult], tuple[ModeConfig, int]] = {}
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            for mode in modes:
                for idx in range(1, runs_per_mode + 1):
                    future = executor.submit(
                        run_command,
                        codex_bin,
                        workdir,
                        prompt,
                        mode,
                        idx,
                        output_dir,
                        dry_run,
                    )
                    future_to_mode[future] = (mode, idx)
            # as_completed yields in finish order; .result() re-raises any
            # exception thrown inside run_command.
            for future in as_completed(future_to_mode):
                mode, _ = future_to_mode[future]
                result = future.result()
                runs_by_mode[mode.label].append(result)
                progress.advance(mode.label, result.index)
        mode_results: list[ModeResult] = []
        for mode in modes:
            # Completion order is nondeterministic, so restore index order here.
            runs = sorted(runs_by_mode[mode.label], key=lambda run: run.index)
            mode_results.append(ModeResult(config=mode, outputs_dir=output_dir / mode.label, runs=runs))
        return mode_results
    # Sequential path: runs already arrive in index order.
    mode_results = []
    for mode in modes:
        runs: list[RunResult] = []
        for idx in range(1, runs_per_mode + 1):
            result = run_command(
                codex_bin=codex_bin,
                workdir=workdir,
                prompt=prompt,
                mode=mode,
                run_index=idx,
                output_dir=output_dir,
                dry_run=dry_run,
            )
            runs.append(result)
            progress.advance(mode.label, idx)
        mode_results.append(ModeResult(config=mode, outputs_dir=output_dir / mode.label, runs=runs))
    return mode_results
def compute_stats(values: Sequence[float]) -> dict[str, float | int]:
    """Summarize finite durations: count, min, max, mean, median, stdev (n > 1).

    Non-finite entries (NaN from dry runs, inf) are dropped first; when none
    remain the summary is just {"count": 0}.
    """
    finite = [value for value in values if math.isfinite(value)]
    count = len(finite)
    if count == 0:
        return {"count": 0}
    summary: dict[str, float | int] = {"count": count}
    summary["min"] = min(finite)
    summary["max"] = max(finite)
    summary["mean"] = statistics.mean(finite)
    summary["median"] = statistics.median(finite)
    if count > 1:
        # stdev needs at least two samples.
        summary["stdev"] = statistics.stdev(finite)
    return summary
def summarize(mode_results: list[ModeResult]) -> dict[str, dict[str, float | int]]:
    """Map each mode label to the duration statistics of its runs."""
    return {result.config.label: compute_stats(result.durations) for result in mode_results}
def write_summary(
    output_dir: Path,
    summary: dict[str, dict[str, float | int]],
    mode_results: list[ModeResult],
) -> Path:
    """Write summary.json (aggregate stats plus per-run details) and return its path.

    Non-finite durations — dry runs record NaN — are serialized as null:
    json.dumps would otherwise emit a bare `NaN` token, which is not valid
    JSON per RFC 8259, leaving summary.json unreadable by strict parsers.
    """
    payload = {
        "output_dir": str(output_dir),
        "summary": summary,
        "runs": {
            result.config.label: [
                {
                    "index": run.index,
                    # null instead of NaN keeps the file valid JSON.
                    "duration_seconds": run.duration_s if math.isfinite(run.duration_s) else None,
                    "returncode": run.returncode,
                    "stdout_path": str(run.stdout_path),
                    "stderr_path": str(run.stderr_path),
                }
                for run in result.runs
            ]
            for result in mode_results
        },
    }
    summary_path = output_dir / "summary.json"
    summary_path.write_text(json.dumps(payload, indent=2))
    return summary_path
def attempt_plot(output_dir: Path, mode_results: list[ModeResult]) -> Path | None:
    """Save a box plot of run durations; returns the PNG path, or None.

    Returns None when there is no finite data to plot or matplotlib is
    unavailable/fails to import — plotting is strictly optional.
    """
    has_finite = any(
        math.isfinite(duration)
        for result in mode_results
        for duration in result.durations
    )
    if not has_finite:
        return None
    try:
        import matplotlib

        # Headless backend: must be selected before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except Exception as exc:  # pragma: no cover - plotting is optional
        print(f"[WARN] Unable to create plot ({exc}); continue without chart.")
        return None
    fig, ax = plt.subplots(figsize=(8, 4))
    labels = [result.config.label for result in mode_results]
    data = [
        [value for value in result.durations if math.isfinite(value)]
        for result in mode_results
    ]
    # NOTE(review): Axes.boxplot's `labels=` kwarg was renamed `tick_labels=`
    # in Matplotlib 3.9 — confirm which versions this must support.
    ax.boxplot(data, labels=labels, showmeans=True)
    ax.set_ylabel("Duration (seconds)")
    ax.set_title("codex exec durations by mode")
    ax.grid(True, axis="y", linestyle="--", alpha=0.4)
    plot_path = output_dir / "duration_boxplot.png"
    fig.tight_layout()
    fig.savefig(plot_path)
    plt.close(fig)
    return plot_path
def format_report(summary: dict[str, dict[str, float | int]], output_dir: Path, plot_path: Path | None) -> str:
    """Render the human-readable benchmark report as a multi-line string.

    Stats within each mode are listed in sorted key order; floats are shown
    with four decimal places.
    """
    lines = ["Benchmark summary:"]
    for label, stats in summary.items():
        lines.append(f"  {label}:")
        for key in sorted(stats):
            value = stats[key]
            rendered = f"{value:.4f}" if isinstance(value, float) else str(value)
            lines.append(f"    {key}: {rendered}")
    lines.append(f"Outputs stored in: {output_dir}")
    if plot_path:
        lines.append(f"Plot saved to: {plot_path}")
    return "\n".join(lines)
def main(argv: Sequence[str]) -> int:
    """Entry point: run the benchmark and return a process exit code (0/1).

    BenchmarkError is the only handled failure; it is printed to stderr and
    mapped to exit code 1. Other exceptions propagate.
    """
    args = parse_args(argv)
    try:
        codex_bin = ensure_binary(args.codex_bin)
        workdir = Path(args.workdir).resolve()
        output_root = Path(args.output_root).resolve()
        # All artifacts for this invocation live under <output-root>/<label>.
        run_dir = output_root / args.label
        run_dir.mkdir(parents=True, exist_ok=True)
        modes = build_mode_configs(args)
        total_runs = len(modes) * args.runs
        # Validate here so the error surfaces before any run starts.
        if args.max_workers is not None and args.max_workers < 1:
            raise BenchmarkError("--max-workers must be a positive integer")
        progress = ProgressTracker(total_runs=total_runs)
        parallel_runs = args.parallel_runs
        mode_results = execute_runs(
            codex_bin=codex_bin,
            workdir=workdir,
            prompt=args.prompt,
            modes=modes,
            runs_per_mode=args.runs,
            output_dir=run_dir,
            dry_run=args.dry_run,
            progress=progress,
            parallel_runs=parallel_runs,
            max_workers=args.max_workers,
        )
        summary = summarize(mode_results)
        summary_path = write_summary(run_dir, summary, mode_results)
        plot_path = attempt_plot(run_dir, mode_results)
        report = format_report(summary, run_dir, plot_path)
        print(report)
        if args.json:
            # Machine-readable companion to the text report, on request.
            payload = {
                "summary": summary,
                "output_dir": str(run_dir),
                "plot_path": str(plot_path) if plot_path else None,
                "summary_path": str(summary_path),
            }
            print(json.dumps(payload, indent=2))
        return 0
    except BenchmarkError as error:
        print(f"[ERROR] {error}", file=sys.stderr)
        return 1
if __name__ == "__main__":
    # SystemExit propagates main()'s integer return value as the exit status.
    raise SystemExit(main(sys.argv))

View File

@@ -146,6 +146,7 @@ pub async fn run_main(
include_view_image_tool: None,
show_raw_agent_reasoning: cli.oss.then_some(true),
tools_web_search_request: cli.web_search.then_some(true),
parallel_tool_calls: None,
};
let raw_overrides = cli.config_overrides.raw_overrides.clone();
let overrides_cli = codex_common::CliConfigOverrides { raw_overrides };