feat: collab wait multiple IDs (#9294)

This commit is contained in:
jif-oai
2026-01-16 12:05:04 +01:00
committed by GitHub
parent c1ac5223e1
commit c576756c81
7 changed files with 302 additions and 132 deletions

View File

@@ -1771,12 +1771,12 @@ pub enum ThreadItem {
/// Thread ID of the agent issuing the collab request.
sender_thread_id: String,
/// Thread ID of the receiving agent, when applicable. In case of spawn operation,
/// this correspond to the newly spawned agent.
receiver_thread_id: Option<String>,
/// this corresponds to the newly spawned agent.
receiver_thread_ids: Vec<String>,
/// Prompt text sent as part of the collab tool call, when available.
prompt: Option<String>,
/// Last known status of the target agent, when available.
agent_state: Option<CollabAgentState>,
/// Last known statuses of the target agents, keyed by thread ID, when available.
agents_states: HashMap<String, CollabAgentState>,
},
#[serde(rename_all = "camelCase")]
#[ts(rename_all = "camelCase")]

View File

@@ -287,9 +287,9 @@ pub(crate) async fn apply_bespoke_event_handling(
tool: CollabAgentTool::SpawnAgent,
status: V2CollabToolCallStatus::InProgress,
sender_thread_id: begin_event.sender_thread_id.to_string(),
receiver_thread_id: None,
receiver_thread_ids: Vec::new(),
prompt: Some(begin_event.prompt),
agent_state: None,
agents_states: HashMap::new(),
};
let notification = ItemStartedNotification {
thread_id: conversation_id.to_string(),
@@ -301,19 +301,32 @@ pub(crate) async fn apply_bespoke_event_handling(
.await;
}
EventMsg::CollabAgentSpawnEnd(end_event) => {
let status = if end_event.new_thread_id.is_some() {
V2CollabToolCallStatus::Completed
} else {
V2CollabToolCallStatus::Failed
let has_receiver = end_event.new_thread_id.is_some();
let status = match &end_event.status {
codex_protocol::protocol::AgentStatus::Errored(_)
| codex_protocol::protocol::AgentStatus::NotFound => V2CollabToolCallStatus::Failed,
_ if has_receiver => V2CollabToolCallStatus::Completed,
_ => V2CollabToolCallStatus::Failed,
};
let (receiver_thread_ids, agents_states) = match end_event.new_thread_id {
Some(id) => {
let receiver_id = id.to_string();
let received_status = V2CollabAgentStatus::from(end_event.status.clone());
(
vec![receiver_id.clone()],
[(receiver_id, received_status)].into_iter().collect(),
)
}
None => (Vec::new(), HashMap::new()),
};
let item = ThreadItem::CollabAgentToolCall {
id: end_event.call_id,
tool: CollabAgentTool::SpawnAgent,
status,
sender_thread_id: end_event.sender_thread_id.to_string(),
receiver_thread_id: end_event.new_thread_id.map(|id| id.to_string()),
receiver_thread_ids,
prompt: Some(end_event.prompt),
agent_state: Some(V2CollabAgentStatus::from(end_event.status)),
agents_states,
};
let notification = ItemCompletedNotification {
thread_id: conversation_id.to_string(),
@@ -325,14 +338,15 @@ pub(crate) async fn apply_bespoke_event_handling(
.await;
}
EventMsg::CollabAgentInteractionBegin(begin_event) => {
let receiver_thread_ids = vec![begin_event.receiver_thread_id.to_string()];
let item = ThreadItem::CollabAgentToolCall {
id: begin_event.call_id,
tool: CollabAgentTool::SendInput,
status: V2CollabToolCallStatus::InProgress,
sender_thread_id: begin_event.sender_thread_id.to_string(),
receiver_thread_id: Some(begin_event.receiver_thread_id.to_string()),
receiver_thread_ids,
prompt: Some(begin_event.prompt),
agent_state: None,
agents_states: HashMap::new(),
};
let notification = ItemStartedNotification {
thread_id: conversation_id.to_string(),
@@ -344,19 +358,21 @@ pub(crate) async fn apply_bespoke_event_handling(
.await;
}
EventMsg::CollabAgentInteractionEnd(end_event) => {
let status = match end_event.status {
let status = match &end_event.status {
codex_protocol::protocol::AgentStatus::Errored(_)
| codex_protocol::protocol::AgentStatus::NotFound => V2CollabToolCallStatus::Failed,
_ => V2CollabToolCallStatus::Completed,
};
let receiver_id = end_event.receiver_thread_id.to_string();
let received_status = V2CollabAgentStatus::from(end_event.status);
let item = ThreadItem::CollabAgentToolCall {
id: end_event.call_id,
tool: CollabAgentTool::SendInput,
status,
sender_thread_id: end_event.sender_thread_id.to_string(),
receiver_thread_id: Some(end_event.receiver_thread_id.to_string()),
receiver_thread_ids: vec![receiver_id.clone()],
prompt: Some(end_event.prompt),
agent_state: Some(V2CollabAgentStatus::from(end_event.status)),
agents_states: [(receiver_id, received_status)].into_iter().collect(),
};
let notification = ItemCompletedNotification {
thread_id: conversation_id.to_string(),
@@ -368,14 +384,19 @@ pub(crate) async fn apply_bespoke_event_handling(
.await;
}
EventMsg::CollabWaitingBegin(begin_event) => {
let receiver_thread_ids = begin_event
.receiver_thread_ids
.iter()
.map(ToString::to_string)
.collect();
let item = ThreadItem::CollabAgentToolCall {
id: begin_event.call_id,
tool: CollabAgentTool::Wait,
status: V2CollabToolCallStatus::InProgress,
sender_thread_id: begin_event.sender_thread_id.to_string(),
receiver_thread_id: Some(begin_event.receiver_thread_id.to_string()),
receiver_thread_ids,
prompt: None,
agent_state: None,
agents_states: HashMap::new(),
};
let notification = ItemStartedNotification {
thread_id: conversation_id.to_string(),
@@ -387,19 +408,31 @@ pub(crate) async fn apply_bespoke_event_handling(
.await;
}
EventMsg::CollabWaitingEnd(end_event) => {
let status = match end_event.status {
codex_protocol::protocol::AgentStatus::Errored(_)
| codex_protocol::protocol::AgentStatus::NotFound => V2CollabToolCallStatus::Failed,
_ => V2CollabToolCallStatus::Completed,
let status = if end_event.statuses.values().any(|status| {
matches!(
status,
codex_protocol::protocol::AgentStatus::Errored(_)
| codex_protocol::protocol::AgentStatus::NotFound
)
}) {
V2CollabToolCallStatus::Failed
} else {
V2CollabToolCallStatus::Completed
};
let receiver_thread_ids = end_event.statuses.keys().map(ToString::to_string).collect();
let agents_states = end_event
.statuses
.iter()
.map(|(id, status)| (id.to_string(), V2CollabAgentStatus::from(status.clone())))
.collect();
let item = ThreadItem::CollabAgentToolCall {
id: end_event.call_id,
tool: CollabAgentTool::Wait,
status,
sender_thread_id: end_event.sender_thread_id.to_string(),
receiver_thread_id: Some(end_event.receiver_thread_id.to_string()),
receiver_thread_ids,
prompt: None,
agent_state: Some(V2CollabAgentStatus::from(end_event.status)),
agents_states,
};
let notification = ItemCompletedNotification {
thread_id: conversation_id.to_string(),
@@ -416,9 +449,9 @@ pub(crate) async fn apply_bespoke_event_handling(
tool: CollabAgentTool::CloseAgent,
status: V2CollabToolCallStatus::InProgress,
sender_thread_id: begin_event.sender_thread_id.to_string(),
receiver_thread_id: Some(begin_event.receiver_thread_id.to_string()),
receiver_thread_ids: vec![begin_event.receiver_thread_id.to_string()],
prompt: None,
agent_state: None,
agents_states: HashMap::new(),
};
let notification = ItemStartedNotification {
thread_id: conversation_id.to_string(),
@@ -430,19 +463,26 @@ pub(crate) async fn apply_bespoke_event_handling(
.await;
}
EventMsg::CollabCloseEnd(end_event) => {
let status = match end_event.status {
let status = match &end_event.status {
codex_protocol::protocol::AgentStatus::Errored(_)
| codex_protocol::protocol::AgentStatus::NotFound => V2CollabToolCallStatus::Failed,
_ => V2CollabToolCallStatus::Completed,
};
let receiver_id = end_event.receiver_thread_id.to_string();
let agents_states = [(
receiver_id.clone(),
V2CollabAgentStatus::from(end_event.status),
)]
.into_iter()
.collect();
let item = ThreadItem::CollabAgentToolCall {
id: end_event.call_id,
tool: CollabAgentTool::CloseAgent,
status,
sender_thread_id: end_event.sender_thread_id.to_string(),
receiver_thread_id: Some(end_event.receiver_thread_id.to_string()),
receiver_thread_ids: vec![receiver_id],
prompt: None,
agent_state: Some(V2CollabAgentStatus::from(end_event.status)),
agents_states,
};
let notification = ItemCompletedNotification {
thread_id: conversation_id.to_string(),

View File

@@ -254,20 +254,26 @@ mod send_input {
mod wait {
use super::*;
use crate::agent::status::is_final;
use futures::FutureExt;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch::Receiver;
use tokio::time::Instant;
use tokio::time::timeout_at;
#[derive(Debug, Deserialize)]
struct WaitArgs {
id: String,
ids: Vec<String>,
timeout_ms: Option<i64>,
}
#[derive(Debug, Serialize)]
struct WaitResult {
status: AgentStatus,
status: HashMap<ThreadId, AgentStatus>,
timed_out: bool,
}
@@ -278,7 +284,16 @@ mod wait {
arguments: String,
) -> Result<ToolOutput, FunctionCallError> {
let args: WaitArgs = parse_arguments(&arguments)?;
let receiver_thread_id = agent_id(&args.id)?;
if args.ids.is_empty() {
return Err(FunctionCallError::RespondToModel(
"ids must be non-empty".to_owned(),
));
}
let receiver_thread_ids = args
.ids
.iter()
.map(|id| agent_id(id))
.collect::<Result<Vec<_>, _>>()?;
// Validate timeout.
let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
@@ -296,105 +311,131 @@ mod wait {
&turn,
CollabWaitingBeginEvent {
sender_thread_id: session.conversation_id,
receiver_thread_id,
receiver_thread_ids: receiver_thread_ids.clone(),
call_id: call_id.clone(),
}
.into(),
)
.await;
let status_rx = match session
.services
.agent_control
.subscribe_status(receiver_thread_id)
.await
{
Ok(status_rx) => status_rx,
Err(err) => {
let status = session
.services
.agent_control
.get_status(receiver_thread_id)
.await;
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
receiver_thread_id,
call_id: call_id.clone(),
status,
}
.into(),
)
.await;
return Err(collab_agent_error(receiver_thread_id, err));
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
let mut initial_final_statuses = Vec::new();
for id in &receiver_thread_ids {
match session.services.agent_control.subscribe_status(*id).await {
Ok(rx) => {
let status = rx.borrow().clone();
if is_final(&status) {
initial_final_statuses.push((*id, status));
}
status_rxs.push((*id, rx));
}
Err(CodexErr::ThreadNotFound(_)) => {
initial_final_statuses.push((*id, AgentStatus::NotFound));
}
Err(err) => {
let mut statuses = HashMap::with_capacity(1);
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
call_id: call_id.clone(),
statuses,
}
.into(),
)
.await;
return Err(collab_agent_error(*id, err));
}
}
}
let statuses = if !initial_final_statuses.is_empty() {
initial_final_statuses
} else {
// Wait for the first agent to reach a final status.
let mut futures = FuturesUnordered::new();
for (id, rx) in status_rxs.into_iter() {
let session = session.clone();
futures.push(wait_for_final_status(session, id, rx));
}
let mut results = Vec::new();
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
loop {
match timeout_at(deadline, futures.next()).await {
Ok(Some(Some(result))) => {
results.push(result);
break;
}
Ok(Some(None)) => continue,
Ok(None) | Err(_) => break,
}
}
if !results.is_empty() {
// Drain the unlikely last elements to prevent race.
loop {
match futures.next().now_or_never() {
Some(Some(Some(result))) => results.push(result),
Some(Some(None)) => continue,
Some(None) | None => break,
}
}
}
results
};
let result =
wait_for_status(session.as_ref(), receiver_thread_id, timeout_ms, status_rx).await;
// Convert payload.
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
let result = WaitResult {
status: statuses_map.clone(),
timed_out: statuses.is_empty(),
};
// Final event emission.
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
receiver_thread_id,
call_id,
status: result.status.clone(),
statuses: statuses_map,
}
.into(),
)
.await;
if matches!(result.status, AgentStatus::NotFound) {
return Err(FunctionCallError::RespondToModel(format!(
"agent with id {receiver_thread_id} not found"
)));
}
let content = serde_json::to_string(&result).map_err(|err| {
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
})?;
let success = !result.timed_out && !matches!(result.status, AgentStatus::Errored(_));
Ok(ToolOutput::Function {
content,
success: Some(success),
success: None,
content_items: None,
})
}
async fn wait_for_status(
session: &Session,
agent_id: ThreadId,
timeout_ms: i64,
mut status_rx: tokio::sync::watch::Receiver<AgentStatus>,
) -> WaitResult {
// Get last known status.
let mut status = status_rx.borrow_and_update().clone();
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
async fn wait_for_final_status(
session: Arc<Session>,
thread_id: ThreadId,
mut status_rx: Receiver<AgentStatus>,
) -> Option<(ThreadId, AgentStatus)> {
let mut status = status_rx.borrow().clone();
if is_final(&status) {
return Some((thread_id, status));
}
let timed_out = loop {
loop {
if status_rx.changed().await.is_err() {
let latest = session.services.agent_control.get_status(thread_id).await;
return is_final(&latest).then_some((thread_id, latest));
}
status = status_rx.borrow().clone();
if is_final(&status) {
break false;
return Some((thread_id, status));
}
match timeout_at(deadline, status_rx.changed()).await {
Ok(Ok(())) => status = status_rx.borrow().clone(),
Ok(Err(_)) => {
let last_status = session.services.agent_control.get_status(agent_id).await;
if last_status != AgentStatus::NotFound {
// On-purpose we keep the last known status if the agent gets dropped. This
// event is not supposed to happen.
status = last_status;
}
break false;
}
Err(_) => break true,
}
};
WaitResult { status, timed_out }
}
}
}
@@ -560,7 +601,9 @@ mod tests {
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::ThreadId;
use pretty_assertions::assert_eq;
use serde::Deserialize;
use serde_json::json;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
@@ -772,6 +815,12 @@ mod tests {
.expect("shutdown should submit");
}
/// Test-side shape of the JSON payload returned by the `wait` tool.
///
/// The tests below deserialize the tool's `content` string into this struct so
/// that results can be compared structurally with `assert_eq!` instead of
/// matching on an exact JSON string.
#[derive(Debug, Deserialize, PartialEq, Eq)]
struct WaitResult {
/// Statuses reported by the tool, keyed by agent thread ID.
status: HashMap<ThreadId, AgentStatus>,
/// Whether the tool reported that the wait timed out.
timed_out: bool,
}
#[tokio::test]
async fn wait_rejects_non_positive_timeout() {
let (session, turn) = make_session_and_context().await;
@@ -779,7 +828,10 @@ mod tests {
Arc::new(session),
Arc::new(turn),
"wait",
function_payload(json!({"id": ThreadId::new().to_string(), "timeout_ms": 0})),
function_payload(json!({
"ids": [ThreadId::new().to_string()],
"timeout_ms": 0
})),
);
let Err(err) = CollabHandler.handle(invocation).await else {
panic!("non-positive timeout should be rejected");
@@ -797,7 +849,7 @@ mod tests {
Arc::new(session),
Arc::new(turn),
"wait",
function_payload(json!({"id": "invalid"})),
function_payload(json!({"ids": ["invalid"]})),
);
let Err(err) = CollabHandler.handle(invocation).await else {
panic!("invalid id should be rejected");
@@ -808,6 +860,65 @@ mod tests {
assert!(msg.starts_with("invalid agent id invalid:"));
}
/// Calling the `wait` tool with an empty `ids` array must be rejected with a
/// `FunctionCallError::RespondToModel` carrying the message
/// `ids must be non-empty`.
#[tokio::test]
async fn wait_rejects_empty_ids() {
let (session, turn) = make_session_and_context().await;
let invocation = invocation(
Arc::new(session),
Arc::new(turn),
"wait",
// An empty list is the input under test.
function_payload(json!({"ids": []})),
);
let Err(err) = CollabHandler.handle(invocation).await else {
panic!("empty ids should be rejected");
};
// The error variant and its message are both part of the expectation.
assert_eq!(
err,
FunctionCallError::RespondToModel("ids must be non-empty".to_string())
);
}
/// Waiting on thread IDs that this test never registers with the agent manager
/// must succeed and report `AgentStatus::NotFound` for each ID, with
/// `timed_out` set to `false` and no `success` flag on the tool output.
#[tokio::test]
async fn wait_returns_not_found_for_missing_agents() {
let (mut session, turn) = make_session_and_context().await;
let manager = thread_manager();
session.services.agent_control = manager.agent_control();
// Two fresh IDs; nothing in this test spawns agents for them.
let id_a = ThreadId::new();
let id_b = ThreadId::new();
let invocation = invocation(
Arc::new(session),
Arc::new(turn),
"wait",
function_payload(json!({
"ids": [id_a.to_string(), id_b.to_string()],
"timeout_ms": 1000
})),
);
let output = CollabHandler
.handle(invocation)
.await
.expect("wait should succeed");
let ToolOutput::Function {
content, success, ..
} = output
else {
panic!("expected function output");
};
// Parse the JSON payload rather than comparing strings, so the assertion
// does not depend on how the status map is ordered when serialized.
let result: WaitResult =
serde_json::from_str(&content).expect("wait result should be json");
assert_eq!(
result,
WaitResult {
status: HashMap::from([
(id_a, AgentStatus::NotFound),
(id_b, AgentStatus::NotFound),
]),
timed_out: false
}
);
// The handler is expected to leave the `success` flag unset.
assert_eq!(success, None);
}
#[tokio::test]
async fn wait_times_out_when_status_is_not_final() {
let (mut session, turn) = make_session_and_context().await;
@@ -820,7 +931,10 @@ mod tests {
Arc::new(session),
Arc::new(turn),
"wait",
function_payload(json!({"id": agent_id.to_string(), "timeout_ms": 10})),
function_payload(json!({
"ids": [agent_id.to_string()],
"timeout_ms": 10
})),
);
let output = CollabHandler
.handle(invocation)
@@ -832,8 +946,16 @@ mod tests {
else {
panic!("expected function output");
};
assert_eq!(content, r#"{"status":"pending_init","timed_out":true}"#);
assert_eq!(success, Some(false));
let result: WaitResult =
serde_json::from_str(&content).expect("wait result should be json");
assert_eq!(
result,
WaitResult {
status: HashMap::new(),
timed_out: true
}
);
assert_eq!(success, None);
let _ = thread
.thread
@@ -869,7 +991,10 @@ mod tests {
Arc::new(session),
Arc::new(turn),
"wait",
function_payload(json!({"id": agent_id.to_string(), "timeout_ms": 1000})),
function_payload(json!({
"ids": [agent_id.to_string()],
"timeout_ms": 1000
})),
);
let output = CollabHandler
.handle(invocation)
@@ -881,8 +1006,16 @@ mod tests {
else {
panic!("expected function output");
};
assert_eq!(content, r#"{"status":"shutdown","timed_out":false}"#);
assert_eq!(success, Some(true));
let result: WaitResult =
serde_json::from_str(&content).expect("wait result should be json");
assert_eq!(
result,
WaitResult {
status: HashMap::from([(agent_id, AgentStatus::Shutdown)]),
timed_out: false
}
);
assert_eq!(success, None);
}
#[tokio::test]

View File

@@ -503,9 +503,10 @@ fn create_send_input_tool() -> ToolSpec {
fn create_wait_tool() -> ToolSpec {
let mut properties = BTreeMap::new();
properties.insert(
"id".to_string(),
JsonSchema::String {
description: Some("Identifier of the agent to wait on.".to_string()),
"ids".to_string(),
JsonSchema::Array {
items: Box::new(JsonSchema::String { description: None }),
description: Some("Identifiers of the agents to wait on.".to_string()),
},
);
properties.insert(
@@ -519,11 +520,13 @@ fn create_wait_tool() -> ToolSpec {
ToolSpec::Function(ResponsesApiTool {
name: "wait".to_string(),
description: "Wait for an agent and return its status.".to_string(),
description:
"Wait for agents and return their statuses. If no agent is done, no status get returned."
.to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: Some(vec!["id".to_string()]),
required: Some(vec!["ids".to_string()]),
additional_properties: Some(false.into()),
},
})

View File

@@ -2177,8 +2177,8 @@ pub struct CollabAgentInteractionEndEvent {
pub struct CollabWaitingBeginEvent {
/// Thread ID of the sender.
pub sender_thread_id: ThreadId,
/// Thread ID of the receiver.
pub receiver_thread_id: ThreadId,
/// Thread IDs of the receivers.
pub receiver_thread_ids: Vec<ThreadId>,
/// ID of the waiting call.
pub call_id: String,
}
@@ -2187,12 +2187,10 @@ pub struct CollabWaitingBeginEvent {
pub struct CollabWaitingEndEvent {
/// Thread ID of the sender.
pub sender_thread_id: ThreadId,
/// Thread ID of the receiver.
pub receiver_thread_id: ThreadId,
/// ID of the waiting call.
pub call_id: String,
/// Last known status of the receiver agent reported to the sender agent.
pub status: AgentStatus,
/// Last known statuses of the receiver agents, keyed by thread ID, as reported to the sender agent.
pub statuses: HashMap<ThreadId, AgentStatus>,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, JsonSchema, TS)]

View File

@@ -59,12 +59,12 @@ pub(crate) fn waiting_begin(ev: CollabWaitingBeginEvent) -> PlainHistoryCell {
let CollabWaitingBeginEvent {
call_id,
sender_thread_id,
receiver_thread_id,
receiver_thread_ids,
} = ev;
let details = vec![
detail_line("call", call_id),
detail_line("sender", sender_thread_id),
detail_line("receiver", receiver_thread_id),
detail_line("receiver", format!("{receiver_thread_ids:?}")),
];
collab_event("Collab wait begin", details)
}
@@ -73,14 +73,12 @@ pub(crate) fn waiting_end(ev: CollabWaitingEndEvent) -> PlainHistoryCell {
let CollabWaitingEndEvent {
call_id,
sender_thread_id,
receiver_thread_id,
status,
statuses,
} = ev;
let details = vec![
detail_line("call", call_id),
detail_line("sender", sender_thread_id),
detail_line("receiver", receiver_thread_id),
status_line(&status),
detail_line("statuses", format!("{statuses:#?}")),
];
collab_event("Collab wait end", details)
}

View File

@@ -59,12 +59,12 @@ pub(crate) fn waiting_begin(ev: CollabWaitingBeginEvent) -> PlainHistoryCell {
let CollabWaitingBeginEvent {
call_id,
sender_thread_id,
receiver_thread_id,
receiver_thread_ids,
} = ev;
let details = vec![
detail_line("call", call_id),
detail_line("sender", sender_thread_id),
detail_line("receiver", receiver_thread_id),
detail_line("receiver", format!("{receiver_thread_ids:?}")),
];
collab_event("Collab wait begin", details)
}
@@ -73,14 +73,12 @@ pub(crate) fn waiting_end(ev: CollabWaitingEndEvent) -> PlainHistoryCell {
let CollabWaitingEndEvent {
call_id,
sender_thread_id,
receiver_thread_id,
status,
statuses,
} = ev;
let details = vec![
detail_line("call", call_id),
detail_line("sender", sender_thread_id),
detail_line("receiver", receiver_thread_id),
status_line(&status),
detail_line("statuses", format!("{statuses:#?}")),
];
collab_event("Collab wait end", details)
}