mirror of
https://github.com/openai/codex.git
synced 2026-03-03 05:03:20 +00:00
Compare commits
44 Commits
fix/notify
...
jif/parall
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
601f5715c6 | ||
|
|
39cfc8465c | ||
|
|
e0850b71f8 | ||
|
|
0d05d03b94 | ||
|
|
a300bb3cd7 | ||
|
|
cec903f257 | ||
|
|
7384625722 | ||
|
|
6508ce79c0 | ||
|
|
60de27031e | ||
|
|
ad29d2bf39 | ||
|
|
a6d3e2a334 | ||
|
|
1a58229438 | ||
|
|
4c335d5e42 | ||
|
|
495f94cbd5 | ||
|
|
0a24fed30b | ||
|
|
9921a061ed | ||
|
|
0a75a69ae6 | ||
|
|
c3d2c83d0b | ||
|
|
e23cbcaaf6 | ||
|
|
0718b00b80 | ||
|
|
d3687a7a65 | ||
|
|
5305f247aa | ||
|
|
fcbf2b6c0a | ||
|
|
46d53a2430 | ||
|
|
ee5f5e85cd | ||
|
|
94a66e7d8b | ||
|
|
9e04b908e1 | ||
|
|
150765dbe3 | ||
|
|
dccf499850 | ||
|
|
3def127178 | ||
|
|
5c00e1596a | ||
|
|
9c194dc0f9 | ||
|
|
4533dceafa | ||
|
|
43c0abb31e | ||
|
|
8c09db17c3 | ||
|
|
1d87628d41 | ||
|
|
4656160e31 | ||
|
|
2dd226891a | ||
|
|
ed45f85209 | ||
|
|
5b74f10a7b | ||
|
|
7b6d8b60c9 | ||
|
|
caab5a19ee | ||
|
|
a29380cdff | ||
|
|
805de19381 |
@@ -515,6 +515,8 @@ pub struct Tools {
|
||||
pub web_search: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub view_image: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_read_only: Option<bool>,
|
||||
}
|
||||
|
||||
/// MCP representation of a [`codex_core::config_types::SandboxWorkspaceWrite`].
|
||||
|
||||
@@ -99,6 +99,7 @@ async fn get_config_toml_parses_all_fields() {
|
||||
tools: Some(Tools {
|
||||
web_search: Some(false),
|
||||
view_image: Some(true),
|
||||
parallel_read_only: Some(false),
|
||||
}),
|
||||
profile: Some("test".to_string()),
|
||||
profiles: HashMap::from([(
|
||||
|
||||
@@ -120,6 +120,8 @@ impl ModelClient {
|
||||
})
|
||||
}
|
||||
|
||||
// Parallel read-only scheduling is controlled internally; no provider hint needed.
|
||||
|
||||
/// Dispatches to either the Responses or Chat implementation depending on
|
||||
/// the provider config. Public callers always invoke `stream()` – the
|
||||
/// specialised helpers are private to avoid accidental misuse.
|
||||
@@ -228,7 +230,6 @@ impl ModelClient {
|
||||
input: &input_with_instructions,
|
||||
tools: &tools_json,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning,
|
||||
store: azure_workaround,
|
||||
stream: true,
|
||||
@@ -1051,15 +1052,11 @@ mod tests {
|
||||
name: "test".to_string(),
|
||||
base_url: Some("https://test.com".to_string()),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let otel_event_manager = otel_event_manager();
|
||||
@@ -1114,15 +1111,11 @@ mod tests {
|
||||
name: "test".to_string(),
|
||||
base_url: Some("https://test.com".to_string()),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let otel_event_manager = otel_event_manager();
|
||||
@@ -1150,15 +1143,11 @@ mod tests {
|
||||
name: "test".to_string(),
|
||||
base_url: Some("https://test.com".to_string()),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let otel_event_manager = otel_event_manager();
|
||||
@@ -1257,15 +1246,11 @@ mod tests {
|
||||
name: "test".to_string(),
|
||||
base_url: Some("https://test.com".to_string()),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let otel_event_manager = otel_event_manager();
|
||||
|
||||
@@ -149,7 +149,6 @@ pub(crate) struct ResponsesApiRequest<'a> {
|
||||
pub(crate) input: &'a Vec<ResponseItem>,
|
||||
pub(crate) tools: &'a [serde_json::Value],
|
||||
pub(crate) tool_choice: &'static str,
|
||||
pub(crate) parallel_tool_calls: bool,
|
||||
pub(crate) reasoning: Option<Reasoning>,
|
||||
pub(crate) store: bool,
|
||||
pub(crate) stream: bool,
|
||||
@@ -327,7 +326,6 @@ mod tests {
|
||||
input: &input,
|
||||
tools: &tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning: None,
|
||||
store: false,
|
||||
stream: true,
|
||||
@@ -368,7 +366,6 @@ mod tests {
|
||||
input: &input,
|
||||
tools: &tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning: None,
|
||||
store: false,
|
||||
stream: true,
|
||||
@@ -404,7 +401,6 @@ mod tests {
|
||||
input: &input,
|
||||
tools: &tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning: None,
|
||||
store: false,
|
||||
stream: true,
|
||||
@@ -416,4 +412,6 @@ mod tests {
|
||||
let v = serde_json::to_value(&req).expect("json");
|
||||
assert!(v.get("text").is_none());
|
||||
}
|
||||
|
||||
// parallel_tool_calls flag removed: scheduling is internal.
|
||||
}
|
||||
|
||||
@@ -101,6 +101,8 @@ use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::format_exec_output_str;
|
||||
use crate::tools::executor::ProcessedResponseItem;
|
||||
use crate::tools::executor::ToolCallExecutor;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_instructions::UserInstructions;
|
||||
@@ -235,7 +237,7 @@ use crate::state::SessionState;
|
||||
/// A session has at most 1 running task at a time, and can be interrupted by user input.
|
||||
pub(crate) struct Session {
|
||||
conversation_id: ConversationId,
|
||||
tx_event: Sender<Event>,
|
||||
pub(crate) tx_event: Sender<Event>,
|
||||
state: Mutex<SessionState>,
|
||||
pub(crate) active_turn: Mutex<Option<ActiveTurn>>,
|
||||
pub(crate) services: SessionServices,
|
||||
@@ -446,6 +448,7 @@ impl Session {
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
enable_parallel_read_only: config.enable_parallel_read_only_tools,
|
||||
}),
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
@@ -1088,7 +1091,7 @@ impl Session {
|
||||
&self.services.user_shell
|
||||
}
|
||||
|
||||
fn show_raw_agent_reasoning(&self) -> bool {
|
||||
pub(crate) fn show_raw_agent_reasoning(&self) -> bool {
|
||||
self.services.show_raw_agent_reasoning
|
||||
}
|
||||
}
|
||||
@@ -1178,6 +1181,7 @@ async fn submission_loop(
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
enable_parallel_read_only: config.enable_parallel_read_only_tools,
|
||||
});
|
||||
|
||||
let new_turn_context = TurnContext {
|
||||
@@ -1282,6 +1286,7 @@ async fn submission_loop(
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config
|
||||
.use_experimental_unified_exec_tool,
|
||||
enable_parallel_read_only: config.enable_parallel_read_only_tools,
|
||||
}),
|
||||
user_instructions: turn_context.user_instructions.clone(),
|
||||
base_instructions: turn_context.base_instructions.clone(),
|
||||
@@ -1513,6 +1518,7 @@ async fn spawn_review_thread(
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: false,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
enable_parallel_read_only: config.enable_parallel_read_only_tools,
|
||||
});
|
||||
|
||||
let base_instructions = REVIEW_PROMPT.to_string();
|
||||
@@ -1681,8 +1687,8 @@ pub(crate) async fn run_task(
|
||||
})
|
||||
.collect();
|
||||
match run_turn(
|
||||
&sess,
|
||||
turn_context.as_ref(),
|
||||
sess.clone(),
|
||||
turn_context.clone(),
|
||||
&mut turn_diff_tracker,
|
||||
sub_id.clone(),
|
||||
turn_input,
|
||||
@@ -1906,14 +1912,17 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
|
||||
}
|
||||
|
||||
async fn run_turn(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
|
||||
let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools));
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(mcp_tools),
|
||||
));
|
||||
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
@@ -1924,10 +1933,12 @@ async fn run_turn(
|
||||
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
let tool_executor =
|
||||
ToolCallExecutor::new(router.clone(), sess.clone(), turn_context.clone());
|
||||
match try_run_turn(
|
||||
&router,
|
||||
sess,
|
||||
turn_context,
|
||||
tool_executor,
|
||||
sess.clone(),
|
||||
turn_context.clone(),
|
||||
turn_diff_tracker,
|
||||
&sub_id,
|
||||
&prompt,
|
||||
@@ -1979,31 +1990,22 @@ async fn run_turn(
|
||||
}
|
||||
}
|
||||
|
||||
/// When the model is prompted, it returns a stream of events. Some of these
|
||||
/// events map to a `ResponseItem`. A `ResponseItem` may need to be
|
||||
/// "handled" such that it produces a `ResponseInputItem` that needs to be
|
||||
/// sent back to the model on the next turn.
|
||||
#[derive(Debug)]
|
||||
struct ProcessedResponseItem {
|
||||
item: ResponseItem,
|
||||
response: Option<ResponseInputItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TurnRunResult {
|
||||
processed_items: Vec<ProcessedResponseItem>,
|
||||
total_token_usage: Option<TokenUsage>,
|
||||
}
|
||||
|
||||
async fn try_run_turn(
|
||||
router: &crate::tools::ToolRouter,
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
mut tool_executor: ToolCallExecutor,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
// call_ids that are part of this response.
|
||||
let sess_ref = sess.as_ref();
|
||||
let turn_context_ref = turn_context.as_ref();
|
||||
|
||||
let completed_call_ids = prompt
|
||||
.input
|
||||
.iter()
|
||||
@@ -2018,9 +2020,6 @@ async fn try_run_turn(
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// call_ids that were pending but are not part of this response.
|
||||
// This usually happens because the user interrupted the model before we responded to one of its tool calls
|
||||
// and then the user sent a follow-up message.
|
||||
let missing_calls = {
|
||||
prompt
|
||||
.input
|
||||
@@ -2059,22 +2058,20 @@ async fn try_run_turn(
|
||||
};
|
||||
|
||||
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: turn_context.client.get_model(),
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
cwd: turn_context_ref.cwd.clone(),
|
||||
approval_policy: turn_context_ref.approval_policy,
|
||||
sandbox_policy: turn_context_ref.sandbox_policy.clone(),
|
||||
model: turn_context_ref.client.get_model(),
|
||||
effort: turn_context_ref.client.get_reasoning_effort(),
|
||||
summary: turn_context_ref.client.get_reasoning_summary(),
|
||||
});
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context.client.clone().stream(&prompt).await?;
|
||||
sess_ref.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
let mut output = Vec::new();
|
||||
let mut stream = turn_context_ref.client.clone().stream(&prompt).await?;
|
||||
|
||||
loop {
|
||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||
tool_executor.drain_ready()?;
|
||||
|
||||
let event = stream.next().await;
|
||||
let Some(event) = event else {
|
||||
// Channel closed without yielding a final Completed event or explicit error.
|
||||
@@ -2097,19 +2094,12 @@ async fn try_run_turn(
|
||||
match event {
|
||||
ResponseEvent::Created => {}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
let response = handle_response_item(
|
||||
router,
|
||||
sess,
|
||||
turn_context,
|
||||
turn_diff_tracker,
|
||||
sub_id,
|
||||
item.clone(),
|
||||
)
|
||||
.await?;
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
tool_executor
|
||||
.handle_output_item(item, turn_diff_tracker, sub_id)
|
||||
.await?;
|
||||
}
|
||||
ResponseEvent::WebSearchCallBegin { call_id } => {
|
||||
let _ = sess
|
||||
let _ = sess_ref
|
||||
.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.to_string(),
|
||||
@@ -2118,15 +2108,16 @@ async fn try_run_turn(
|
||||
.await;
|
||||
}
|
||||
ResponseEvent::RateLimits(snapshot) => {
|
||||
// Update internal state with latest rate limits, but defer sending until
|
||||
// token usage is available to avoid duplicate TokenCount events.
|
||||
sess.update_rate_limits(sub_id, snapshot).await;
|
||||
sess_ref.update_rate_limits(sub_id, snapshot).await;
|
||||
}
|
||||
ResponseEvent::Completed {
|
||||
response_id: _,
|
||||
token_usage,
|
||||
} => {
|
||||
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
|
||||
tool_executor.flush().await?;
|
||||
|
||||
sess_ref
|
||||
.update_token_usage_info(sub_id, turn_context_ref, token_usage.as_ref())
|
||||
.await;
|
||||
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
@@ -2136,11 +2127,11 @@ async fn try_run_turn(
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess_ref.send_event(event).await;
|
||||
}
|
||||
|
||||
let result = TurnRunResult {
|
||||
processed_items: output,
|
||||
processed_items: tool_executor.take_processed_items()?,
|
||||
total_token_usage: token_usage.clone(),
|
||||
};
|
||||
|
||||
@@ -2154,7 +2145,7 @@ async fn try_run_turn(
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess_ref.send_event(event).await;
|
||||
} else {
|
||||
trace!("suppressing OutputTextDelta in review mode");
|
||||
}
|
||||
@@ -2164,114 +2155,30 @@ async fn try_run_turn(
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess_ref.send_event(event).await;
|
||||
}
|
||||
ResponseEvent::ReasoningSummaryPartAdded => {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess_ref.send_event(event).await;
|
||||
}
|
||||
ResponseEvent::ReasoningContentDelta(delta) => {
|
||||
if sess.show_raw_agent_reasoning() {
|
||||
if sess_ref.show_raw_agent_reasoning() {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningRawContentDelta(
|
||||
AgentReasoningRawContentDeltaEvent { delta },
|
||||
),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess_ref.send_event(event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_response_item(
|
||||
router: &crate::tools::ToolRouter,
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
item: ResponseItem,
|
||||
) -> CodexResult<Option<ResponseInputItem>> {
|
||||
debug!(?item, "Output item");
|
||||
|
||||
match ToolRouter::build_tool_call(sess, item.clone()) {
|
||||
Ok(Some(call)) => {
|
||||
let payload_preview = call.payload.log_payload().into_owned();
|
||||
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||
match router
|
||||
.dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(Some(response)),
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
Err(other) => unreachable!("non-fatal tool error returned: {other:?}"),
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
match &item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => {
|
||||
let msgs = match &item {
|
||||
ResponseItem::Message { .. } if turn_context.is_review_mode => {
|
||||
trace!("suppressing assistant Message in review mode");
|
||||
Vec::new()
|
||||
}
|
||||
_ => map_response_item_to_event_messages(
|
||||
&item,
|
||||
sess.show_raw_agent_reasoning(),
|
||||
),
|
||||
};
|
||||
for msg in msgs {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. } => {
|
||||
debug!("unexpected tool output from stream");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let msg = "LocalShellCall without call_id or id";
|
||||
turn_context
|
||||
.client
|
||||
.get_otel_event_manager()
|
||||
.log_tool_failed("local_shell", msg);
|
||||
error!(msg);
|
||||
|
||||
Ok(Some(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg.to_string(),
|
||||
success: None,
|
||||
},
|
||||
}))
|
||||
}
|
||||
Err(FunctionCallError::RespondToModel(msg)) => {
|
||||
Ok(Some(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg,
|
||||
success: None,
|
||||
},
|
||||
}))
|
||||
}
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
|
||||
responses.iter().rev().find_map(|item| {
|
||||
if let ResponseItem::Message { role, content, .. } = item {
|
||||
@@ -2686,6 +2593,7 @@ mod tests {
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
enable_parallel_read_only: config.enable_parallel_read_only_tools,
|
||||
});
|
||||
let turn_context = TurnContext {
|
||||
client,
|
||||
@@ -2759,6 +2667,7 @@ mod tests {
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
enable_parallel_read_only: config.enable_parallel_read_only_tools,
|
||||
});
|
||||
let turn_context = Arc::new(TurnContext {
|
||||
client,
|
||||
@@ -2935,6 +2844,86 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_response_item(
|
||||
router: &ToolRouter,
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
item: ResponseItem,
|
||||
) -> CodexResult<Option<ResponseInputItem>> {
|
||||
debug!(?item, "Output item");
|
||||
|
||||
match router.build_tool_call(sess, item.clone()) {
|
||||
Ok(Some(call)) => {
|
||||
let payload_preview = call.payload.log_payload().into_owned();
|
||||
info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||
match router
|
||||
.dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(Some(response)),
|
||||
Err(err) => Err(match err {
|
||||
FunctionCallError::Fatal(message) => CodexErr::Fatal(message),
|
||||
other => CodexErr::Fatal(other.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
match &item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => {
|
||||
let msgs = match &item {
|
||||
ResponseItem::Message { .. } if turn_context.is_review_mode => {
|
||||
trace!("suppressing assistant Message in review mode");
|
||||
Vec::new()
|
||||
}
|
||||
_ => map_response_item_to_event_messages(
|
||||
&item,
|
||||
sess.show_raw_agent_reasoning(),
|
||||
),
|
||||
};
|
||||
for msg in msgs {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. } => {
|
||||
debug!("unexpected tool output from stream");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let msg = "LocalShellCall without call_id or id";
|
||||
turn_context
|
||||
.client
|
||||
.get_otel_event_manager()
|
||||
.log_tool_failed("local_shell", msg);
|
||||
error!(msg);
|
||||
|
||||
Ok(Some(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg.to_string(),
|
||||
success: None,
|
||||
},
|
||||
}))
|
||||
}
|
||||
Err(err) => Err(match err {
|
||||
FunctionCallError::Fatal(message) => CodexErr::Fatal(message),
|
||||
other => CodexErr::Fatal(other.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_rollout(
|
||||
session: &Session,
|
||||
turn_context: &TurnContext,
|
||||
|
||||
@@ -209,6 +209,9 @@ pub struct Config {
|
||||
|
||||
/// OTEL configuration (exporter type, endpoint, headers, etc.).
|
||||
pub otel: crate::config_types::OtelConfig,
|
||||
|
||||
/// Enable read-only tools to run in parallel.
|
||||
pub enable_parallel_read_only_tools: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -773,6 +776,9 @@ pub struct ToolsToml {
|
||||
/// Enable the `view_image` tool that lets the agent attach local images.
|
||||
#[serde(default)]
|
||||
pub view_image: Option<bool>,
|
||||
|
||||
#[serde(default)]
|
||||
pub parallel_read_only: Option<bool>,
|
||||
}
|
||||
|
||||
impl From<ToolsToml> for Tools {
|
||||
@@ -780,6 +786,7 @@ impl From<ToolsToml> for Tools {
|
||||
Self {
|
||||
web_search: tools_toml.web_search,
|
||||
view_image: tools_toml.view_image,
|
||||
parallel_read_only: tools_toml.parallel_read_only,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -976,6 +983,12 @@ impl Config {
|
||||
.or(cfg.tools.as_ref().and_then(|t| t.view_image))
|
||||
.unwrap_or(true);
|
||||
|
||||
let enable_parallel_read_only_tools = cfg
|
||||
.tools
|
||||
.as_ref()
|
||||
.and_then(|t| t.parallel_read_only)
|
||||
.unwrap_or(false);
|
||||
|
||||
let model = model
|
||||
.or(config_profile.model)
|
||||
.or(cfg.model)
|
||||
@@ -1090,6 +1103,7 @@ impl Config {
|
||||
.unwrap_or(false),
|
||||
use_experimental_use_rmcp_client: cfg.experimental_use_rmcp_client.unwrap_or(false),
|
||||
include_view_image_tool,
|
||||
enable_parallel_read_only_tools,
|
||||
active_profile: active_profile_name,
|
||||
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
|
||||
tui_notifications: cfg
|
||||
@@ -1677,9 +1691,8 @@ model = "gpt-5-codex"
|
||||
cwd: TempDir,
|
||||
codex_home: TempDir,
|
||||
cfg: ConfigToml,
|
||||
model_provider_map: HashMap<String, ModelProviderInfo>,
|
||||
openai_provider: ModelProviderInfo,
|
||||
openai_chat_completions_provider: ModelProviderInfo,
|
||||
model_provider_map: HashMap<String, ModelProviderInfo>,
|
||||
}
|
||||
|
||||
impl PrecedenceTestFixture {
|
||||
@@ -1752,20 +1765,16 @@ model_verbosity = "high"
|
||||
base_url: Some("https://api.openai.com/v1".to_string()),
|
||||
env_key: Some("OPENAI_API_KEY".to_string()),
|
||||
wire_api: crate::WireApi::Chat,
|
||||
env_key_instructions: None,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(4),
|
||||
stream_max_retries: Some(10),
|
||||
stream_idle_timeout_ms: Some(300_000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
let model_provider_map = {
|
||||
let mut model_provider_map = built_in_model_providers();
|
||||
model_provider_map.insert(
|
||||
"openai-chat-completions".to_string(),
|
||||
openai_chat_completions_provider.clone(),
|
||||
openai_chat_completions_provider,
|
||||
);
|
||||
model_provider_map
|
||||
};
|
||||
@@ -1779,9 +1788,8 @@ model_verbosity = "high"
|
||||
cwd: cwd_temp_dir,
|
||||
codex_home: codex_home_temp_dir,
|
||||
cfg,
|
||||
model_provider_map,
|
||||
openai_provider,
|
||||
openai_chat_completions_provider,
|
||||
model_provider_map,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1849,6 +1857,7 @@ model_verbosity = "high"
|
||||
use_experimental_unified_exec_tool: false,
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
enable_parallel_read_only_tools: false,
|
||||
active_profile: Some("o3".to_string()),
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
@@ -1881,7 +1890,7 @@ model_verbosity = "high"
|
||||
model_max_output_tokens: Some(4_096),
|
||||
model_auto_compact_token_limit: None,
|
||||
model_provider_id: "openai-chat-completions".to_string(),
|
||||
model_provider: fixture.openai_chat_completions_provider.clone(),
|
||||
model_provider: gpt3_profile_config.model_provider.clone(),
|
||||
approval_policy: AskForApproval::UnlessTrusted,
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
@@ -1889,7 +1898,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
model_providers: gpt3_profile_config.model_providers.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
codex_home: fixture.codex_home(),
|
||||
@@ -1910,6 +1919,7 @@ model_verbosity = "high"
|
||||
use_experimental_unified_exec_tool: false,
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
enable_parallel_read_only_tools: false,
|
||||
active_profile: Some("gpt3".to_string()),
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
@@ -1965,7 +1975,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
model_providers: zdr_profile_config.model_providers.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
codex_home: fixture.codex_home(),
|
||||
@@ -1986,6 +1996,7 @@ model_verbosity = "high"
|
||||
use_experimental_unified_exec_tool: false,
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
enable_parallel_read_only_tools: false,
|
||||
active_profile: Some("zdr".to_string()),
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
@@ -2027,7 +2038,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
model_providers: gpt5_profile_config.model_providers.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
codex_home: fixture.codex_home(),
|
||||
@@ -2048,6 +2059,7 @@ model_verbosity = "high"
|
||||
use_experimental_unified_exec_tool: false,
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
enable_parallel_read_only_tools: false,
|
||||
active_profile: Some("gpt5".to_string()),
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
|
||||
@@ -41,6 +41,9 @@ pub struct ModelFamily {
|
||||
|
||||
// Instructions to use for querying the model
|
||||
pub base_instructions: String,
|
||||
|
||||
/// If the model supports parallel tool calls for read-only tools.
|
||||
pub supports_parallel_read_only_tools: bool,
|
||||
}
|
||||
|
||||
macro_rules! model_family {
|
||||
@@ -57,6 +60,7 @@ macro_rules! model_family {
|
||||
uses_local_shell_tool: false,
|
||||
apply_patch_tool_type: None,
|
||||
base_instructions: BASE_INSTRUCTIONS.to_string(),
|
||||
supports_parallel_read_only_tools: false,
|
||||
};
|
||||
// apply overrides
|
||||
$(
|
||||
@@ -105,12 +109,14 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
supports_reasoning_summaries: true,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
|
||||
supports_parallel_read_only_tools: true,
|
||||
)
|
||||
} else if slug.starts_with("gpt-5") {
|
||||
model_family!(
|
||||
slug, "gpt-5",
|
||||
supports_reasoning_summaries: true,
|
||||
needs_special_apply_patch_instructions: true,
|
||||
supports_parallel_read_only_tools: true,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
@@ -127,5 +133,6 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily {
|
||||
uses_local_shell_tool: false,
|
||||
apply_patch_tool_type: None,
|
||||
base_instructions: BASE_INSTRUCTIONS.to_string(),
|
||||
supports_parallel_read_only_tools: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ pub enum WireApi {
|
||||
}
|
||||
|
||||
/// Serializable representation of a provider definition.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
|
||||
pub struct ModelProviderInfo {
|
||||
/// Friendly display name.
|
||||
pub name: String,
|
||||
@@ -86,6 +86,10 @@ pub struct ModelProviderInfo {
|
||||
/// and API key (if needed) comes from the "env_key" environment variable.
|
||||
#[serde(default)]
|
||||
pub requires_openai_auth: bool,
|
||||
|
||||
/// Does the model support parallel tool calls.
|
||||
#[serde(default)]
|
||||
pub supports_parallel_tool_calls: bool,
|
||||
}
|
||||
|
||||
impl ModelProviderInfo {
|
||||
@@ -297,6 +301,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: true,
|
||||
supports_parallel_tool_calls: true,
|
||||
},
|
||||
),
|
||||
(BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
|
||||
@@ -341,6 +346,7 @@ pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_parallel_tool_calls: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,16 +376,8 @@ base_url = "http://localhost:11434/v1"
|
||||
let expected_provider = ModelProviderInfo {
|
||||
name: "Ollama".into(),
|
||||
base_url: Some("http://localhost:11434/v1".into()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
@@ -398,17 +396,11 @@ query_params = { api-version = "2025-04-01-preview" }
|
||||
name: "Azure".into(),
|
||||
base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
|
||||
env_key: Some("AZURE_OPENAI_API_KEY".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: Some(maplit::hashmap! {
|
||||
"api-version".to_string() => "2025-04-01-preview".to_string(),
|
||||
}),
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
@@ -428,19 +420,14 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
||||
name: "Example".into(),
|
||||
base_url: Some("https://example.com".into()),
|
||||
env_key: Some("API_KEY".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: None,
|
||||
http_headers: Some(maplit::hashmap! {
|
||||
"X-Example-Header".to_string() => "example-value".to_string(),
|
||||
}),
|
||||
env_http_headers: Some(maplit::hashmap! {
|
||||
"X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
|
||||
}),
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
@@ -453,16 +440,8 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
||||
ModelProviderInfo {
|
||||
name: "test".into(),
|
||||
base_url: Some(base_url.into()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -485,16 +464,8 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
||||
let named_provider = ModelProviderInfo {
|
||||
name: "Azure".into(),
|
||||
base_url: Some("https://example.com".into()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(named_provider.is_azure_responses_endpoint());
|
||||
|
||||
|
||||
323
codex-rs/core/src/tools/executor.rs
Normal file
323
codex-rs/core/src/tools/executor.rs
Normal file
@@ -0,0 +1,323 @@
|
||||
use std::panic::AssertUnwindSafe;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::event_mapping::map_response_item_to_event_messages;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::protocol::Event;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use crate::tools::router::ToolCall;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ProcessedResponseItem {
|
||||
pub item: ResponseItem,
|
||||
pub response: Option<ResponseInputItem>,
|
||||
}
|
||||
|
||||
pub(crate) struct ToolCallExecutor {
|
||||
router: Arc<ToolRouter>,
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
allow_parallel_read_only: bool,
|
||||
read_only_tasks: JoinSet<(usize, Result<ResponseInputItem, FunctionCallError>)>,
|
||||
processed_items: Vec<ProcessedResponseItem>,
|
||||
}
|
||||
|
||||
impl ToolCallExecutor {
|
||||
pub(crate) fn new(
|
||||
router: Arc<ToolRouter>,
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
) -> Self {
|
||||
let allow_parallel_read_only =
|
||||
router.has_read_only_tools() && turn_context.tools_config.enable_parallel_read_only;
|
||||
|
||||
Self {
|
||||
router,
|
||||
session,
|
||||
turn_context,
|
||||
allow_parallel_read_only,
|
||||
read_only_tasks: JoinSet::new(),
|
||||
processed_items: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn drain_ready(&mut self) -> CodexResult<()> {
|
||||
while let Some(res) = self.read_only_tasks.try_join_next() {
|
||||
match res {
|
||||
Ok((idx, response)) => self.assign_result(idx, response)?,
|
||||
Err(join_err) => {
|
||||
warn!(
|
||||
?join_err,
|
||||
"parallel read-only task aborted before completion"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn flush(&mut self) -> CodexResult<()> {
|
||||
while let Some(res) = self.read_only_tasks.join_next().await {
|
||||
match res {
|
||||
Ok((idx, response)) => self.assign_result(idx, response)?,
|
||||
Err(join_err) => {
|
||||
warn!(
|
||||
?join_err,
|
||||
"parallel read-only task aborted before completion"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_output_item(
|
||||
&mut self,
|
||||
item: ResponseItem,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
) -> CodexResult<()> {
|
||||
match self
|
||||
.router
|
||||
.build_tool_call(self.session.as_ref(), item.clone())
|
||||
{
|
||||
Ok(Some(call)) => {
|
||||
let payload_preview = call.payload.log_payload().into_owned();
|
||||
info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||
|
||||
let idx = self.processed_items.len();
|
||||
self.processed_items.push(ProcessedResponseItem {
|
||||
item,
|
||||
response: None,
|
||||
});
|
||||
|
||||
if self.allow_parallel_read_only && call.capabilities.read_only {
|
||||
self.schedule_parallel_task(idx, call, sub_id);
|
||||
} else {
|
||||
self.flush().await?;
|
||||
let response = self
|
||||
.router
|
||||
.dispatch_tool_call(
|
||||
self.session.as_ref(),
|
||||
self.turn_context.as_ref(),
|
||||
turn_diff_tracker,
|
||||
sub_id,
|
||||
call,
|
||||
)
|
||||
.await
|
||||
.map_err(Self::map_dispatch_error)?;
|
||||
if let Some(slot) = self.processed_items.get_mut(idx) {
|
||||
slot.response = Some(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
self.emit_response_item_events(sub_id, &item).await?;
|
||||
self.processed_items.push(ProcessedResponseItem {
|
||||
item,
|
||||
response: None,
|
||||
});
|
||||
}
|
||||
Err(FunctionCallError::RespondToModel(msg)) => {
|
||||
if msg == "LocalShellCall without call_id or id" {
|
||||
self.turn_context
|
||||
.client
|
||||
.get_otel_event_manager()
|
||||
.log_tool_failed("local_shell", &msg);
|
||||
error!(msg);
|
||||
}
|
||||
|
||||
self.flush().await?;
|
||||
self.processed_items.push(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg,
|
||||
success: None,
|
||||
},
|
||||
}),
|
||||
});
|
||||
}
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let msg = "LocalShellCall without call_id or id";
|
||||
self.turn_context
|
||||
.client
|
||||
.get_otel_event_manager()
|
||||
.log_tool_failed("local_shell", msg);
|
||||
error!(msg);
|
||||
|
||||
self.flush().await?;
|
||||
self.processed_items.push(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg.to_string(),
|
||||
success: None,
|
||||
},
|
||||
}),
|
||||
});
|
||||
}
|
||||
Err(err) => {
|
||||
self.flush().await?;
|
||||
return Err(Self::map_dispatch_error(err));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn assign_result(
|
||||
&mut self,
|
||||
idx: usize,
|
||||
response: Result<ResponseInputItem, FunctionCallError>,
|
||||
) -> CodexResult<()> {
|
||||
match response {
|
||||
Ok(response) => {
|
||||
if let Some(slot) = self.processed_items.get_mut(idx) {
|
||||
slot.response = Some(response);
|
||||
} else {
|
||||
warn!(idx, "parallel tool completion missing output slot");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => Err(Self::map_dispatch_error(err)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn take_processed_items(mut self) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
self.drain_ready()?;
|
||||
Ok(self.processed_items)
|
||||
}
|
||||
|
||||
fn schedule_parallel_task(&mut self, idx: usize, call: ToolCall, sub_id: &str) {
|
||||
let router_for_task = self.router.clone();
|
||||
let session_for_task = self.session.clone();
|
||||
let turn_context_for_task = self.turn_context.clone();
|
||||
let sub_id_for_task = sub_id.to_string();
|
||||
|
||||
self.read_only_tasks.spawn(async move {
|
||||
let mut tracker = TurnDiffTracker::new();
|
||||
let payload_for_fallback = call.payload.clone();
|
||||
let call_id_for_fallback = call.call_id.clone();
|
||||
let tool_name_for_msg = call.tool_name.clone();
|
||||
let fut = async {
|
||||
router_for_task
|
||||
.dispatch_tool_call(
|
||||
session_for_task.as_ref(),
|
||||
turn_context_for_task.as_ref(),
|
||||
&mut tracker,
|
||||
&sub_id_for_task,
|
||||
call,
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
let response = match AssertUnwindSafe(fut).catch_unwind().await {
|
||||
Ok(resp) => resp,
|
||||
Err(panic) => {
|
||||
let msg = Self::panic_to_message(panic);
|
||||
let message = format!("{tool_name_for_msg} failed: {msg}");
|
||||
Ok(Self::fallback_response(
|
||||
call_id_for_fallback,
|
||||
payload_for_fallback,
|
||||
message,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
(idx, response)
|
||||
});
|
||||
}
|
||||
|
||||
async fn emit_response_item_events(
|
||||
&self,
|
||||
sub_id: &str,
|
||||
item: &ResponseItem,
|
||||
) -> CodexResult<()> {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => {
|
||||
let msgs = match item {
|
||||
ResponseItem::Message { .. } if self.turn_context.is_review_mode => {
|
||||
trace!("suppressing assistant Message in review mode");
|
||||
Vec::new()
|
||||
}
|
||||
_ => map_response_item_to_event_messages(
|
||||
item,
|
||||
self.session.show_raw_agent_reasoning(),
|
||||
),
|
||||
};
|
||||
for msg in msgs {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
self.session.send_event(event).await;
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
|
||||
debug!("unexpected tool output from stream");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fallback_response(
|
||||
call_id: String,
|
||||
payload: ToolPayload,
|
||||
message: String,
|
||||
) -> ResponseInputItem {
|
||||
match payload {
|
||||
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
||||
call_id,
|
||||
output: message,
|
||||
},
|
||||
_ => ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: message,
|
||||
success: Some(false),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn panic_to_message(payload: Box<dyn std::any::Any + Send>) -> String {
|
||||
if let Some(s) = payload.downcast_ref::<&str>() {
|
||||
(*s).to_string()
|
||||
} else if let Some(s) = payload.downcast_ref::<String>() {
|
||||
s.clone()
|
||||
} else {
|
||||
"panic without message".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn map_dispatch_error(err: FunctionCallError) -> CodexErr {
|
||||
match err {
|
||||
FunctionCallError::Fatal(message) => CodexErr::Fatal(message),
|
||||
_ => CodexErr::Fatal(err.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod context;
|
||||
pub mod executor;
|
||||
pub(crate) mod handlers;
|
||||
pub mod registry;
|
||||
pub mod router;
|
||||
|
||||
@@ -19,6 +19,27 @@ pub enum ToolKind {
|
||||
Mcp,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
pub struct ToolCapabilities {
|
||||
pub read_only: bool,
|
||||
}
|
||||
|
||||
impl ToolCapabilities {
|
||||
pub const fn mutating() -> Self {
|
||||
Self { read_only: false }
|
||||
}
|
||||
|
||||
pub const fn read_only() -> Self {
|
||||
Self { read_only: true }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ToolEntry {
|
||||
handler: Arc<dyn ToolHandler>,
|
||||
capabilities: ToolCapabilities,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ToolHandler: Send + Sync {
|
||||
fn kind(&self) -> ToolKind;
|
||||
@@ -36,17 +57,24 @@ pub trait ToolHandler: Send + Sync {
|
||||
-> Result<ToolOutput, FunctionCallError>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolRegistry {
|
||||
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
||||
handlers: HashMap<String, ToolEntry>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new(handlers: HashMap<String, Arc<dyn ToolHandler>>) -> Self {
|
||||
fn new(handlers: HashMap<String, ToolEntry>) -> Self {
|
||||
Self { handlers }
|
||||
}
|
||||
|
||||
pub fn handler(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
|
||||
self.handlers.get(name).map(Arc::clone)
|
||||
pub fn capabilities(&self, name: &str) -> Option<ToolCapabilities> {
|
||||
self.handlers.get(name).map(|entry| entry.capabilities)
|
||||
}
|
||||
|
||||
pub fn has_read_only_tools(&self) -> bool {
|
||||
self.handlers
|
||||
.values()
|
||||
.any(|entry| entry.capabilities.read_only)
|
||||
}
|
||||
|
||||
// TODO(jif) for dynamic tools.
|
||||
@@ -67,8 +95,8 @@ impl ToolRegistry {
|
||||
let payload_for_response = invocation.payload.clone();
|
||||
let log_payload = payload_for_response.log_payload();
|
||||
|
||||
let handler = match self.handler(tool_name.as_ref()) {
|
||||
Some(handler) => handler,
|
||||
let entry = match self.handlers.get(tool_name.as_str()) {
|
||||
Some(entry) => entry,
|
||||
None => {
|
||||
let message =
|
||||
unsupported_tool_call_message(&invocation.payload, tool_name.as_ref());
|
||||
@@ -84,6 +112,8 @@ impl ToolRegistry {
|
||||
}
|
||||
};
|
||||
|
||||
let handler = Arc::clone(&entry.handler);
|
||||
|
||||
if !handler.matches_kind(&invocation.payload) {
|
||||
let message = format!("tool {tool_name} invoked with incompatible payload");
|
||||
otel.tool_result(
|
||||
@@ -128,7 +158,7 @@ impl ToolRegistry {
|
||||
Ok(_) => {
|
||||
let mut guard = output_cell.lock().await;
|
||||
let output = guard.take().ok_or_else(|| {
|
||||
FunctionCallError::Fatal("tool produced no output".to_string())
|
||||
FunctionCallError::RespondToModel("tool produced no output".to_string())
|
||||
})?;
|
||||
Ok(output.into_response(&call_id_owned, &payload_for_response))
|
||||
}
|
||||
@@ -138,7 +168,7 @@ impl ToolRegistry {
|
||||
}
|
||||
|
||||
pub struct ToolRegistryBuilder {
|
||||
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
||||
handlers: HashMap<String, ToolEntry>,
|
||||
specs: Vec<ToolSpec>,
|
||||
}
|
||||
|
||||
@@ -155,10 +185,33 @@ impl ToolRegistryBuilder {
|
||||
}
|
||||
|
||||
pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
|
||||
self.register_with_capabilities(name, handler, ToolCapabilities::mutating());
|
||||
}
|
||||
|
||||
pub fn register_read_only_handler(
|
||||
&mut self,
|
||||
name: impl Into<String>,
|
||||
handler: Arc<dyn ToolHandler>,
|
||||
) {
|
||||
self.register_with_capabilities(name, handler, ToolCapabilities::read_only());
|
||||
}
|
||||
|
||||
pub fn register_with_capabilities(
|
||||
&mut self,
|
||||
name: impl Into<String>,
|
||||
handler: Arc<dyn ToolHandler>,
|
||||
capabilities: ToolCapabilities,
|
||||
) {
|
||||
let name = name.into();
|
||||
if self
|
||||
.handlers
|
||||
.insert(name.clone(), handler.clone())
|
||||
.insert(
|
||||
name.clone(),
|
||||
ToolEntry {
|
||||
handler: handler.clone(),
|
||||
capabilities,
|
||||
},
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
warn!("overwriting handler for tool {name}");
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::codex::TurnContext;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ToolCapabilities;
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use crate::tools::spec::ToolsConfig;
|
||||
use crate::tools::spec::build_specs;
|
||||
@@ -20,8 +21,10 @@ pub struct ToolCall {
|
||||
pub tool_name: String,
|
||||
pub call_id: String,
|
||||
pub payload: ToolPayload,
|
||||
pub capabilities: ToolCapabilities,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolRouter {
|
||||
registry: ToolRegistry,
|
||||
specs: Vec<ToolSpec>,
|
||||
@@ -41,7 +44,12 @@ impl ToolRouter {
|
||||
&self.specs
|
||||
}
|
||||
|
||||
pub fn has_read_only_tools(&self) -> bool {
|
||||
self.registry.has_read_only_tools()
|
||||
}
|
||||
|
||||
pub fn build_tool_call(
|
||||
&self,
|
||||
session: &Session,
|
||||
item: ResponseItem,
|
||||
) -> Result<Option<ToolCall>, FunctionCallError> {
|
||||
@@ -53,7 +61,7 @@ impl ToolRouter {
|
||||
..
|
||||
} => {
|
||||
if let Some((server, tool)) = session.parse_mcp_tool_name(&name) {
|
||||
Ok(Some(ToolCall {
|
||||
Ok(Some(self.attach_capabilities(ToolCall {
|
||||
tool_name: name,
|
||||
call_id,
|
||||
payload: ToolPayload::Mcp {
|
||||
@@ -61,18 +69,20 @@ impl ToolRouter {
|
||||
tool,
|
||||
raw_arguments: arguments,
|
||||
},
|
||||
}))
|
||||
capabilities: ToolCapabilities::mutating(),
|
||||
})))
|
||||
} else {
|
||||
let payload = if name == "unified_exec" {
|
||||
ToolPayload::UnifiedExec { arguments }
|
||||
} else {
|
||||
ToolPayload::Function { arguments }
|
||||
};
|
||||
Ok(Some(ToolCall {
|
||||
Ok(Some(self.attach_capabilities(ToolCall {
|
||||
tool_name: name,
|
||||
call_id,
|
||||
payload,
|
||||
}))
|
||||
capabilities: ToolCapabilities::mutating(),
|
||||
})))
|
||||
}
|
||||
}
|
||||
ResponseItem::CustomToolCall {
|
||||
@@ -80,11 +90,12 @@ impl ToolRouter {
|
||||
input,
|
||||
call_id,
|
||||
..
|
||||
} => Ok(Some(ToolCall {
|
||||
} => Ok(Some(self.attach_capabilities(ToolCall {
|
||||
tool_name: name,
|
||||
call_id,
|
||||
payload: ToolPayload::Custom { input },
|
||||
})),
|
||||
capabilities: ToolCapabilities::mutating(),
|
||||
}))),
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
call_id,
|
||||
@@ -104,11 +115,12 @@ impl ToolRouter {
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
Ok(Some(ToolCall {
|
||||
Ok(Some(self.attach_capabilities(ToolCall {
|
||||
tool_name: "local_shell".to_string(),
|
||||
call_id,
|
||||
payload: ToolPayload::LocalShell { params },
|
||||
}))
|
||||
capabilities: ToolCapabilities::mutating(),
|
||||
})))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -116,6 +128,13 @@ impl ToolRouter {
|
||||
}
|
||||
}
|
||||
|
||||
fn attach_capabilities(&self, mut call: ToolCall) -> ToolCall {
|
||||
if let Some(capabilities) = self.registry.capabilities(call.tool_name.as_str()) {
|
||||
call.capabilities = capabilities;
|
||||
}
|
||||
call
|
||||
}
|
||||
|
||||
pub async fn dispatch_tool_call(
|
||||
&self,
|
||||
session: &Session,
|
||||
@@ -128,6 +147,7 @@ impl ToolRouter {
|
||||
tool_name,
|
||||
call_id,
|
||||
payload,
|
||||
..
|
||||
} = call;
|
||||
let payload_outputs_custom = matches!(payload, ToolPayload::Custom { .. });
|
||||
let failure_call_id = call_id.clone();
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::tools::handlers::PLAN_TOOL;
|
||||
use crate::tools::handlers::apply_patch::ApplyPatchToolType;
|
||||
use crate::tools::handlers::apply_patch::create_apply_patch_freeform_tool;
|
||||
use crate::tools::handlers::apply_patch::create_apply_patch_json_tool;
|
||||
use crate::tools::registry::ToolCapabilities;
|
||||
use crate::tools::registry::ToolRegistryBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -28,6 +29,7 @@ pub(crate) struct ToolsConfig {
|
||||
pub web_search_request: bool,
|
||||
pub include_view_image_tool: bool,
|
||||
pub experimental_unified_exec_tool: bool,
|
||||
pub enable_parallel_read_only: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct ToolsConfigParams<'a> {
|
||||
@@ -38,6 +40,7 @@ pub(crate) struct ToolsConfigParams<'a> {
|
||||
pub(crate) use_streamable_shell_tool: bool,
|
||||
pub(crate) include_view_image_tool: bool,
|
||||
pub(crate) experimental_unified_exec_tool: bool,
|
||||
pub(crate) enable_parallel_read_only: bool,
|
||||
}
|
||||
|
||||
impl ToolsConfig {
|
||||
@@ -50,6 +53,7 @@ impl ToolsConfig {
|
||||
use_streamable_shell_tool,
|
||||
include_view_image_tool,
|
||||
experimental_unified_exec_tool,
|
||||
enable_parallel_read_only,
|
||||
} = params;
|
||||
let shell_type = if *use_streamable_shell_tool {
|
||||
ConfigShellToolType::Streamable
|
||||
@@ -78,6 +82,7 @@ impl ToolsConfig {
|
||||
web_search_request: *include_web_search_request,
|
||||
include_view_image_tool: *include_view_image_tool,
|
||||
experimental_unified_exec_tool: *experimental_unified_exec_tool,
|
||||
enable_parallel_read_only: *enable_parallel_read_only,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -567,7 +572,7 @@ pub(crate) fn build_specs(
|
||||
}
|
||||
|
||||
builder.push_spec(create_read_file_tool());
|
||||
builder.register_handler("read_file", read_file_handler);
|
||||
builder.register_read_only_handler("read_file", read_file_handler);
|
||||
|
||||
if config.web_search_request {
|
||||
builder.push_spec(ToolSpec::WebSearch {});
|
||||
@@ -575,7 +580,7 @@ pub(crate) fn build_specs(
|
||||
|
||||
if config.include_view_image_tool {
|
||||
builder.push_spec(create_view_image_tool());
|
||||
builder.register_handler("view_image", view_image_handler);
|
||||
builder.register_read_only_handler("view_image", view_image_handler);
|
||||
}
|
||||
|
||||
if let Some(mcp_tools) = mcp_tools {
|
||||
@@ -583,10 +588,21 @@ pub(crate) fn build_specs(
|
||||
entries.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
for (name, tool) in entries.into_iter() {
|
||||
let capabilities = if tool
|
||||
.annotations
|
||||
.as_ref()
|
||||
.and_then(|ann| ann.read_only_hint)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
ToolCapabilities::read_only()
|
||||
} else {
|
||||
ToolCapabilities::mutating()
|
||||
};
|
||||
|
||||
match mcp_tool_to_openai_tool(name.clone(), tool.clone()) {
|
||||
Ok(converted_tool) => {
|
||||
builder.push_spec(ToolSpec::Function(converted_tool));
|
||||
builder.register_handler(name, mcp_handler.clone());
|
||||
builder.register_with_capabilities(name, mcp_handler.clone(), capabilities);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to convert {name:?} MCP tool to OpenAI tool: {e:?}");
|
||||
@@ -643,6 +659,7 @@ mod tests {
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
enable_parallel_read_only: false,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, Some(HashMap::new())).build();
|
||||
|
||||
@@ -669,6 +686,7 @@ mod tests {
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
enable_parallel_read_only: false,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, Some(HashMap::new())).build();
|
||||
|
||||
@@ -695,6 +713,7 @@ mod tests {
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
enable_parallel_read_only: false,
|
||||
});
|
||||
let (tools, _) = build_specs(
|
||||
&config,
|
||||
@@ -801,6 +820,7 @@ mod tests {
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
enable_parallel_read_only: false,
|
||||
});
|
||||
|
||||
// Intentionally construct a map with keys that would sort alphabetically.
|
||||
@@ -878,6 +898,7 @@ mod tests {
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
enable_parallel_read_only: false,
|
||||
});
|
||||
|
||||
let (tools, _) = build_specs(
|
||||
@@ -946,6 +967,7 @@ mod tests {
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
enable_parallel_read_only: false,
|
||||
});
|
||||
|
||||
let (tools, _) = build_specs(
|
||||
@@ -1009,6 +1031,7 @@ mod tests {
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
enable_parallel_read_only: false,
|
||||
});
|
||||
|
||||
let (tools, _) = build_specs(
|
||||
@@ -1075,6 +1098,7 @@ mod tests {
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
enable_parallel_read_only: false,
|
||||
});
|
||||
|
||||
let (tools, _) = build_specs(
|
||||
@@ -1153,6 +1177,7 @@ mod tests {
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
enable_parallel_read_only: false,
|
||||
});
|
||||
let (tools, _) = build_specs(
|
||||
&config,
|
||||
|
||||
@@ -48,16 +48,12 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
|
||||
let provider = ModelProviderInfo {
|
||||
name: "mock".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let codex_home = match TempDir::new() {
|
||||
|
||||
@@ -46,16 +46,11 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
|
||||
let provider = ModelProviderInfo {
|
||||
name: "mock".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let codex_home = match TempDir::new() {
|
||||
|
||||
@@ -14,6 +14,7 @@ use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
@@ -26,6 +27,7 @@ use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use futures::StreamExt;
|
||||
@@ -647,16 +649,11 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
||||
let provider = ModelProviderInfo {
|
||||
name: "azure".into(),
|
||||
base_url: Some(format!("{}/openai", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
@@ -1036,17 +1033,12 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
|
||||
"api-version".to_string(),
|
||||
"2025-04-01-preview".to_string(),
|
||||
)])),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
http_headers: Some(std::collections::HashMap::from([(
|
||||
"Custom-Header".to_string(),
|
||||
"Value".to_string(),
|
||||
)])),
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Init session
|
||||
@@ -1113,17 +1105,12 @@ async fn env_var_overrides_loaded_auth() {
|
||||
"api-version".to_string(),
|
||||
"2025-04-01-preview".to_string(),
|
||||
)])),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
http_headers: Some(std::collections::HashMap::from([(
|
||||
"Custom-Header".to_string(),
|
||||
"Value".to_string(),
|
||||
)])),
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Init session
|
||||
@@ -1291,3 +1278,51 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
"request 3 tail mismatch",
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parallel_tool_calls_field_not_sent() {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let template = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp_parallel"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(template)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let mut provider = built_in_model_providers()["openai"].clone();
|
||||
provider.base_url = Some(format!("{}/v1", server.uri()));
|
||||
provider.supports_parallel_tool_calls = true; // no longer affects payload
|
||||
|
||||
let provider_clone = provider.clone();
|
||||
let TestCodex { codex, .. } = test_codex()
|
||||
.with_config(move |config| {
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("model family");
|
||||
config.enable_parallel_read_only_tools = true;
|
||||
config.model_provider = provider_clone;
|
||||
config.model_provider_id = "openai".to_string();
|
||||
})
|
||||
.build(&server)
|
||||
.await
|
||||
.expect("build codex");
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.expect("requests")[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
// Field removed: executor handles scheduling internally
|
||||
assert!(request_body.get("parallel_tool_calls").is_none());
|
||||
}
|
||||
|
||||
@@ -169,7 +169,6 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
|
||||
],
|
||||
"tools": tool_calls,
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": false,
|
||||
"reasoning": {
|
||||
"summary": "auto"
|
||||
},
|
||||
@@ -238,7 +237,6 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
|
||||
],
|
||||
"tools": [],
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": false,
|
||||
"reasoning": {
|
||||
"summary": "auto"
|
||||
},
|
||||
@@ -303,7 +301,6 @@ SUMMARY_ONLY_CONTEXT"
|
||||
],
|
||||
"tools": tool_calls,
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": false,
|
||||
"reasoning": {
|
||||
"summary": "auto"
|
||||
},
|
||||
@@ -388,7 +385,6 @@ SUMMARY_ONLY_CONTEXT"
|
||||
],
|
||||
"tools": tool_calls,
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": false,
|
||||
"reasoning": {
|
||||
"summary": "auto"
|
||||
},
|
||||
@@ -473,7 +469,6 @@ SUMMARY_ONLY_CONTEXT"
|
||||
],
|
||||
"tools": tool_calls,
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": false,
|
||||
"reasoning": {
|
||||
"summary": "auto"
|
||||
},
|
||||
|
||||
@@ -14,6 +14,7 @@ mod live_cli;
|
||||
mod model_overrides;
|
||||
mod model_tools;
|
||||
mod otel;
|
||||
mod parallel_read_only;
|
||||
mod prompt_caching;
|
||||
mod read_file;
|
||||
mod review;
|
||||
|
||||
301
codex-rs/core/tests/suite/parallel_read_only.rs
Normal file
301
codex-rs/core/tests/suite/parallel_read_only.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use std::ffi::CString;
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Context;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_apply_patch_function_call;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use tokio::fs::OpenOptions;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn read_only_tools_execute_before_apply_patch() -> anyhow::Result<()> {
|
||||
// Bail out early if the sandbox does not allow network traffic, because the
|
||||
// mocked Codex server still communicates over HTTP.
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
// Stand up a mock Codex backend that will stream tool calls and responses.
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.enable_parallel_read_only_tools = true;
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder.build(&server).await?;
|
||||
|
||||
// Create two FIFOs that the mocked read-only tools will try to read from in
|
||||
// order to simulate long running, blocking I/O.
|
||||
let fifo_one = cwd.path().join("parallel_fifo_one");
|
||||
let fifo_two = cwd.path().join("parallel_fifo_two");
|
||||
create_fifo(&fifo_one)?;
|
||||
create_fifo(&fifo_two)?;
|
||||
|
||||
let read_call_one = "read-file-1";
|
||||
let read_call_two = "read-file-2";
|
||||
let patch_call = "apply-patch";
|
||||
|
||||
let read_args_one = serde_json::json!({
|
||||
"file_path": fifo_one.to_string_lossy(),
|
||||
"offset": 1,
|
||||
"limit": 1,
|
||||
})
|
||||
.to_string();
|
||||
let read_args_two = serde_json::json!({
|
||||
"file_path": fifo_two.to_string_lossy(),
|
||||
"offset": 1,
|
||||
"limit": 1,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let patch_path = "parallel_patch_output.txt";
|
||||
let patch_content = format!(
|
||||
"*** Begin Patch\n*** Add File: {patch_path}\n+parallel apply_patch executed\n*** End Patch"
|
||||
);
|
||||
|
||||
// Queue the first SSE response that drives the session: fire two read-only
|
||||
// tool calls, then schedule the apply-patch call.
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-parallel"}
|
||||
}),
|
||||
ev_function_call(read_call_one, "read_file", &read_args_one),
|
||||
ev_function_call(read_call_two, "read_file", &read_args_two),
|
||||
ev_apply_patch_function_call(patch_call, &patch_content),
|
||||
ev_completed("resp-parallel"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
// Queue a follow-up response so the session can complete once all tools run.
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "all done"),
|
||||
ev_completed("resp-final"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
// Start timers that resolve when each FIFO gets a writer, helping us measure
|
||||
// when the corresponding read-only tool begins execution.
|
||||
let start = Instant::now();
|
||||
let wait_one = tokio::spawn(wait_for_writer(fifo_one.clone(), start));
|
||||
let wait_two = tokio::spawn(wait_for_writer(fifo_two.clone(), start));
|
||||
|
||||
// Capture all Codex events so we can verify tool ordering once execution finishes.
|
||||
let events = Arc::new(Mutex::new(Vec::new()));
|
||||
let (done_tx, done_rx) = oneshot::channel();
|
||||
let events_task = events.clone();
|
||||
let codex_for_events = codex.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let event = codex_for_events.next_event().await.expect("event");
|
||||
let msg = event.msg;
|
||||
let is_done = matches!(msg, EventMsg::TaskComplete(_));
|
||||
{
|
||||
let mut log = events_task.lock().await;
|
||||
log.push(msg);
|
||||
}
|
||||
if is_done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = done_tx.send(());
|
||||
});
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
// Trigger the user turn that causes Codex to invoke the two read-only tools
|
||||
// and subsequently the apply-patch tool.
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "please process the tools in parallel".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let wait_timeout = Duration::from_secs(5);
|
||||
|
||||
let (mut writer_one, elapsed_one) = tokio::time::timeout(wait_timeout, async {
|
||||
wait_one.await.expect("wait fifo one task panicked")
|
||||
})
|
||||
.await
|
||||
.expect("timeout waiting for first read-only tool")?;
|
||||
|
||||
let (mut writer_two, elapsed_two) = tokio::time::timeout(wait_timeout, async {
|
||||
wait_two.await.expect("wait fifo two task panicked")
|
||||
})
|
||||
.await
|
||||
.expect("timeout waiting for second read-only tool")?;
|
||||
|
||||
// Ensure the two read-only tools started within 200ms of each other so that
|
||||
// they can be considered parallel.
|
||||
let delta = elapsed_one.abs_diff(elapsed_two);
|
||||
assert!(
|
||||
delta < Duration::from_millis(200),
|
||||
"expected read-only tools to start in parallel (delta {delta:?})"
|
||||
);
|
||||
|
||||
writer_one.write_all(b"fifo one line\n").await?;
|
||||
writer_one.shutdown().await?;
|
||||
drop(writer_one);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
{
|
||||
// Confirm that apply_patch has not started while the second read-only
|
||||
// tool is still blocked waiting for input.
|
||||
let log = events.lock().await;
|
||||
assert!(
|
||||
!log.iter()
|
||||
.any(|msg| matches!(msg, EventMsg::PatchApplyBegin(_))),
|
||||
"apply_patch began before the second read-only tool completed"
|
||||
);
|
||||
}
|
||||
|
||||
writer_two.write_all(b"fifo two line\n").await?;
|
||||
writer_two.shutdown().await?;
|
||||
drop(writer_two);
|
||||
|
||||
// Wait for the event collector to observe task completion so we can inspect
|
||||
// the final event log.
|
||||
done_rx.await.expect("event collector finished");
|
||||
|
||||
let events_log = events.lock().await;
|
||||
let patch_begin_index = events_log
|
||||
.iter()
|
||||
.position(|msg| match msg {
|
||||
EventMsg::PatchApplyBegin(begin) => begin.call_id == patch_call,
|
||||
_ => false,
|
||||
})
|
||||
.expect("expected PatchApplyBegin event");
|
||||
let patch_end_index = events_log
|
||||
.iter()
|
||||
.position(|msg| match msg {
|
||||
EventMsg::PatchApplyEnd(end) => end.call_id == patch_call,
|
||||
_ => false,
|
||||
})
|
||||
.expect("expected PatchApplyEnd event");
|
||||
assert!(
|
||||
patch_begin_index < patch_end_index,
|
||||
"PatchApplyEnd occurred before PatchApplyBegin"
|
||||
);
|
||||
|
||||
// Record whether apply_patch succeeded so the assertions below can verify
|
||||
// either the patched file or the reported stderr output.
|
||||
let patch_end_success = events_log.iter().find_map(|msg| match msg {
|
||||
EventMsg::PatchApplyEnd(end) if end.call_id == patch_call => {
|
||||
Some((end.success, end.stderr.clone()))
|
||||
}
|
||||
_ => None,
|
||||
});
|
||||
let (patch_success, patch_stderr) = patch_end_success.expect("expected PatchApplyEnd details");
|
||||
drop(events_log);
|
||||
|
||||
if patch_success {
|
||||
let patched_file = cwd.path().join(patch_path);
|
||||
let patched_contents = std::fs::read_to_string(&patched_file)?;
|
||||
assert!(
|
||||
patched_contents.contains("parallel apply_patch executed"),
|
||||
"unexpected patch contents: {patched_contents:?}"
|
||||
);
|
||||
} else {
|
||||
assert!(
|
||||
patch_stderr.contains("codex-run-as-apply-patch"),
|
||||
"unexpected apply_patch stderr: {patch_stderr:?}"
|
||||
);
|
||||
}
|
||||
|
||||
// Check that the mock server observed outputs from every tool invocation.
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(
|
||||
!requests.is_empty(),
|
||||
"expected at least one request recorded"
|
||||
);
|
||||
|
||||
let mut seen_outputs = std::collections::HashSet::new();
|
||||
for request in requests {
|
||||
let body = request
|
||||
.body_json::<serde_json::Value>()
|
||||
.expect("request json");
|
||||
if let Some(items) = body.get("input").and_then(|v| v.as_array()) {
|
||||
for item in items {
|
||||
if item.get("type").and_then(|v| v.as_str()) == Some("function_call_output")
|
||||
&& let Some(call_id) = item.get("call_id").and_then(|v| v.as_str())
|
||||
{
|
||||
seen_outputs.insert(call_id.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
seen_outputs.contains(read_call_one),
|
||||
"missing read-only tool output for {read_call_one}"
|
||||
);
|
||||
assert!(
|
||||
seen_outputs.contains(read_call_two),
|
||||
"missing read-only tool output for {read_call_two}"
|
||||
);
|
||||
assert!(
|
||||
seen_outputs.contains(patch_call),
|
||||
"missing apply_patch tool output"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_fifo(path: &Path) -> anyhow::Result<()> {
|
||||
let c_path =
|
||||
CString::new(path.as_os_str().as_bytes()).context("fifo path contained null byte")?;
|
||||
let res = unsafe { libc::mkfifo(c_path.as_ptr(), 0o600) };
|
||||
if res != 0 {
|
||||
return Err(std::io::Error::last_os_error().into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_writer(
|
||||
path: PathBuf,
|
||||
origin: Instant,
|
||||
) -> anyhow::Result<(tokio::fs::File, Duration)> {
|
||||
let file = OpenOptions::new()
|
||||
.write(true)
|
||||
.open(&path)
|
||||
.await
|
||||
.with_context(|| format!("open fifo {path:?} for writing"))?;
|
||||
Ok((file, origin.elapsed()))
|
||||
}
|
||||
@@ -65,15 +65,11 @@ async fn continue_after_stream_error() {
|
||||
name: "mock-openai".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(1),
|
||||
stream_max_retries: Some(1),
|
||||
stream_idle_timeout_ms: Some(2_000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let TestCodex { codex, .. } = test_codex()
|
||||
|
||||
@@ -72,16 +72,12 @@ async fn retries_on_early_close() {
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
// exercise retry path: first attempt yields incomplete stream, so allow 1 retry
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(1),
|
||||
stream_idle_timeout_ms: Some(2000),
|
||||
requires_openai_auth: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let TestCodex { codex, .. } = test_codex()
|
||||
|
||||
Reference in New Issue
Block a user