fix parallel tool calls (#7956)

This commit is contained in:
Ahmed Ibrahim
2025-12-15 17:28:27 -08:00
committed by GitHub
parent b093565bfb
commit d802b18716
9 changed files with 866 additions and 23 deletions

View File

@@ -1,6 +1,7 @@
#![cfg(not(target_os = "windows"))]
#![allow(clippy::unwrap_used)]
use std::fs;
use std::time::Duration;
use std::time::Instant;
@@ -13,16 +14,22 @@ use codex_protocol::user_input::UserInput;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_shell_command_call_with_args;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::streaming_sse::StreamingSseChunk;
use core_test_support::streaming_sse::start_streaming_sse_server;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use pretty_assertions::assert_eq;
use serde_json::Value;
use serde_json::json;
use tokio::sync::oneshot;
async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> {
let session_model = test.session_configured.model.clone();
@@ -280,3 +287,123 @@ async fn tool_results_grouped() -> anyhow::Result<()> {
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_tools_start_before_response_completed_when_stream_delayed() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let output_file = tempfile::NamedTempFile::new()?;
let output_path = output_file.path();
let first_response_id = "resp-1";
let second_response_id = "resp-2";
let command = format!(
"perl -MTime::HiRes -e 'print int(Time::HiRes::time()*1000), \"\\n\"' >> \"{}\"",
output_path.display()
);
let args = json!({
"command": command,
"timeout_ms": 1_000,
});
let first_chunk = sse(vec![
ev_response_created(first_response_id),
ev_shell_command_call_with_args("call-1", &args),
ev_shell_command_call_with_args("call-2", &args),
ev_shell_command_call_with_args("call-3", &args),
ev_shell_command_call_with_args("call-4", &args),
]);
let second_chunk = sse(vec![ev_completed(first_response_id)]);
let follow_up = sse(vec![
ev_assistant_message("msg-1", "done"),
ev_completed(second_response_id),
]);
let (first_gate_tx, first_gate_rx) = oneshot::channel();
let (completion_gate_tx, completion_gate_rx) = oneshot::channel();
let (follow_up_gate_tx, follow_up_gate_rx) = oneshot::channel();
let (streaming_server, completion_receivers) = start_streaming_sse_server(vec![
vec![
StreamingSseChunk {
gate: Some(first_gate_rx),
body: first_chunk,
},
StreamingSseChunk {
gate: Some(completion_gate_rx),
body: second_chunk,
},
],
vec![StreamingSseChunk {
gate: Some(follow_up_gate_rx),
body: follow_up,
}],
])
.await;
let mut builder = test_codex().with_model("gpt-5.1");
let test = builder
.build_with_streaming_server(&streaming_server)
.await?;
let session_model = test.session_configured.model.clone();
test.codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "stream delayed completion".into(),
}],
final_output_json_schema: None,
cwd: test.cwd.path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::DangerFullAccess,
model: session_model,
effort: None,
summary: ReasoningSummary::Auto,
})
.await?;
let _ = first_gate_tx.send(());
let _ = follow_up_gate_tx.send(());
let timestamps = tokio::time::timeout(Duration::from_secs(1), async {
loop {
let contents = fs::read_to_string(output_path)?;
let timestamps = contents
.lines()
.filter(|line| !line.trim().is_empty())
.map(|line| {
line.trim()
.parse::<i64>()
.map_err(|err| anyhow::anyhow!("invalid timestamp {line:?}: {err}"))
})
.collect::<Result<Vec<_>, _>>()?;
if timestamps.len() == 4 {
return Ok::<_, anyhow::Error>(timestamps);
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await??;
let _ = completion_gate_tx.send(());
wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let mut completion_iter = completion_receivers.into_iter();
let completed_at = completion_iter
.next()
.expect("completion receiver missing")
.await
.expect("completion timestamp missing");
let count = i64::try_from(timestamps.len()).expect("timestamp count fits in i64");
assert_eq!(count, 4);
for timestamp in timestamps {
assert!(
timestamp < completed_at,
"timestamp {timestamp} should be before completed {completed_at}"
);
}
streaming_server.shutdown().await;
Ok(())
}