mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
feat: collab wait multiple IDs (#9294)
This commit is contained in:
@@ -1771,12 +1771,12 @@ pub enum ThreadItem {
|
||||
/// Thread ID of the agent issuing the collab request.
|
||||
sender_thread_id: String,
|
||||
/// Thread ID of the receiving agent, when applicable. In case of spawn operation,
|
||||
/// this correspond to the newly spawned agent.
|
||||
receiver_thread_id: Option<String>,
|
||||
/// this corresponds to the newly spawned agent.
|
||||
receiver_thread_ids: Vec<String>,
|
||||
/// Prompt text sent as part of the collab tool call, when available.
|
||||
prompt: Option<String>,
|
||||
/// Last known status of the target agent, when available.
|
||||
agent_state: Option<CollabAgentState>,
|
||||
/// Last known status of the target agents, when available.
|
||||
agents_states: HashMap<String, CollabAgentState>,
|
||||
},
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(rename_all = "camelCase")]
|
||||
|
||||
@@ -287,9 +287,9 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
tool: CollabAgentTool::SpawnAgent,
|
||||
status: V2CollabToolCallStatus::InProgress,
|
||||
sender_thread_id: begin_event.sender_thread_id.to_string(),
|
||||
receiver_thread_id: None,
|
||||
receiver_thread_ids: Vec::new(),
|
||||
prompt: Some(begin_event.prompt),
|
||||
agent_state: None,
|
||||
agents_states: HashMap::new(),
|
||||
};
|
||||
let notification = ItemStartedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
@@ -301,19 +301,32 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.await;
|
||||
}
|
||||
EventMsg::CollabAgentSpawnEnd(end_event) => {
|
||||
let status = if end_event.new_thread_id.is_some() {
|
||||
V2CollabToolCallStatus::Completed
|
||||
} else {
|
||||
V2CollabToolCallStatus::Failed
|
||||
let has_receiver = end_event.new_thread_id.is_some();
|
||||
let status = match &end_event.status {
|
||||
codex_protocol::protocol::AgentStatus::Errored(_)
|
||||
| codex_protocol::protocol::AgentStatus::NotFound => V2CollabToolCallStatus::Failed,
|
||||
_ if has_receiver => V2CollabToolCallStatus::Completed,
|
||||
_ => V2CollabToolCallStatus::Failed,
|
||||
};
|
||||
let (receiver_thread_ids, agents_states) = match end_event.new_thread_id {
|
||||
Some(id) => {
|
||||
let receiver_id = id.to_string();
|
||||
let received_status = V2CollabAgentStatus::from(end_event.status.clone());
|
||||
(
|
||||
vec![receiver_id.clone()],
|
||||
[(receiver_id, received_status)].into_iter().collect(),
|
||||
)
|
||||
}
|
||||
None => (Vec::new(), HashMap::new()),
|
||||
};
|
||||
let item = ThreadItem::CollabAgentToolCall {
|
||||
id: end_event.call_id,
|
||||
tool: CollabAgentTool::SpawnAgent,
|
||||
status,
|
||||
sender_thread_id: end_event.sender_thread_id.to_string(),
|
||||
receiver_thread_id: end_event.new_thread_id.map(|id| id.to_string()),
|
||||
receiver_thread_ids,
|
||||
prompt: Some(end_event.prompt),
|
||||
agent_state: Some(V2CollabAgentStatus::from(end_event.status)),
|
||||
agents_states,
|
||||
};
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
@@ -325,14 +338,15 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.await;
|
||||
}
|
||||
EventMsg::CollabAgentInteractionBegin(begin_event) => {
|
||||
let receiver_thread_ids = vec![begin_event.receiver_thread_id.to_string()];
|
||||
let item = ThreadItem::CollabAgentToolCall {
|
||||
id: begin_event.call_id,
|
||||
tool: CollabAgentTool::SendInput,
|
||||
status: V2CollabToolCallStatus::InProgress,
|
||||
sender_thread_id: begin_event.sender_thread_id.to_string(),
|
||||
receiver_thread_id: Some(begin_event.receiver_thread_id.to_string()),
|
||||
receiver_thread_ids,
|
||||
prompt: Some(begin_event.prompt),
|
||||
agent_state: None,
|
||||
agents_states: HashMap::new(),
|
||||
};
|
||||
let notification = ItemStartedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
@@ -344,19 +358,21 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.await;
|
||||
}
|
||||
EventMsg::CollabAgentInteractionEnd(end_event) => {
|
||||
let status = match end_event.status {
|
||||
let status = match &end_event.status {
|
||||
codex_protocol::protocol::AgentStatus::Errored(_)
|
||||
| codex_protocol::protocol::AgentStatus::NotFound => V2CollabToolCallStatus::Failed,
|
||||
_ => V2CollabToolCallStatus::Completed,
|
||||
};
|
||||
let receiver_id = end_event.receiver_thread_id.to_string();
|
||||
let received_status = V2CollabAgentStatus::from(end_event.status);
|
||||
let item = ThreadItem::CollabAgentToolCall {
|
||||
id: end_event.call_id,
|
||||
tool: CollabAgentTool::SendInput,
|
||||
status,
|
||||
sender_thread_id: end_event.sender_thread_id.to_string(),
|
||||
receiver_thread_id: Some(end_event.receiver_thread_id.to_string()),
|
||||
receiver_thread_ids: vec![receiver_id.clone()],
|
||||
prompt: Some(end_event.prompt),
|
||||
agent_state: Some(V2CollabAgentStatus::from(end_event.status)),
|
||||
agents_states: [(receiver_id, received_status)].into_iter().collect(),
|
||||
};
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
@@ -368,14 +384,19 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.await;
|
||||
}
|
||||
EventMsg::CollabWaitingBegin(begin_event) => {
|
||||
let receiver_thread_ids = begin_event
|
||||
.receiver_thread_ids
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect();
|
||||
let item = ThreadItem::CollabAgentToolCall {
|
||||
id: begin_event.call_id,
|
||||
tool: CollabAgentTool::Wait,
|
||||
status: V2CollabToolCallStatus::InProgress,
|
||||
sender_thread_id: begin_event.sender_thread_id.to_string(),
|
||||
receiver_thread_id: Some(begin_event.receiver_thread_id.to_string()),
|
||||
receiver_thread_ids,
|
||||
prompt: None,
|
||||
agent_state: None,
|
||||
agents_states: HashMap::new(),
|
||||
};
|
||||
let notification = ItemStartedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
@@ -387,19 +408,31 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.await;
|
||||
}
|
||||
EventMsg::CollabWaitingEnd(end_event) => {
|
||||
let status = match end_event.status {
|
||||
codex_protocol::protocol::AgentStatus::Errored(_)
|
||||
| codex_protocol::protocol::AgentStatus::NotFound => V2CollabToolCallStatus::Failed,
|
||||
_ => V2CollabToolCallStatus::Completed,
|
||||
let status = if end_event.statuses.values().any(|status| {
|
||||
matches!(
|
||||
status,
|
||||
codex_protocol::protocol::AgentStatus::Errored(_)
|
||||
| codex_protocol::protocol::AgentStatus::NotFound
|
||||
)
|
||||
}) {
|
||||
V2CollabToolCallStatus::Failed
|
||||
} else {
|
||||
V2CollabToolCallStatus::Completed
|
||||
};
|
||||
let receiver_thread_ids = end_event.statuses.keys().map(ToString::to_string).collect();
|
||||
let agents_states = end_event
|
||||
.statuses
|
||||
.iter()
|
||||
.map(|(id, status)| (id.to_string(), V2CollabAgentStatus::from(status.clone())))
|
||||
.collect();
|
||||
let item = ThreadItem::CollabAgentToolCall {
|
||||
id: end_event.call_id,
|
||||
tool: CollabAgentTool::Wait,
|
||||
status,
|
||||
sender_thread_id: end_event.sender_thread_id.to_string(),
|
||||
receiver_thread_id: Some(end_event.receiver_thread_id.to_string()),
|
||||
receiver_thread_ids,
|
||||
prompt: None,
|
||||
agent_state: Some(V2CollabAgentStatus::from(end_event.status)),
|
||||
agents_states,
|
||||
};
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
@@ -416,9 +449,9 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
tool: CollabAgentTool::CloseAgent,
|
||||
status: V2CollabToolCallStatus::InProgress,
|
||||
sender_thread_id: begin_event.sender_thread_id.to_string(),
|
||||
receiver_thread_id: Some(begin_event.receiver_thread_id.to_string()),
|
||||
receiver_thread_ids: vec![begin_event.receiver_thread_id.to_string()],
|
||||
prompt: None,
|
||||
agent_state: None,
|
||||
agents_states: HashMap::new(),
|
||||
};
|
||||
let notification = ItemStartedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
@@ -430,19 +463,26 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.await;
|
||||
}
|
||||
EventMsg::CollabCloseEnd(end_event) => {
|
||||
let status = match end_event.status {
|
||||
let status = match &end_event.status {
|
||||
codex_protocol::protocol::AgentStatus::Errored(_)
|
||||
| codex_protocol::protocol::AgentStatus::NotFound => V2CollabToolCallStatus::Failed,
|
||||
_ => V2CollabToolCallStatus::Completed,
|
||||
};
|
||||
let receiver_id = end_event.receiver_thread_id.to_string();
|
||||
let agents_states = [(
|
||||
receiver_id.clone(),
|
||||
V2CollabAgentStatus::from(end_event.status),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
let item = ThreadItem::CollabAgentToolCall {
|
||||
id: end_event.call_id,
|
||||
tool: CollabAgentTool::CloseAgent,
|
||||
status,
|
||||
sender_thread_id: end_event.sender_thread_id.to_string(),
|
||||
receiver_thread_id: Some(end_event.receiver_thread_id.to_string()),
|
||||
receiver_thread_ids: vec![receiver_id],
|
||||
prompt: None,
|
||||
agent_state: Some(V2CollabAgentStatus::from(end_event.status)),
|
||||
agents_states,
|
||||
};
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
|
||||
@@ -254,20 +254,26 @@ mod send_input {
|
||||
mod wait {
|
||||
use super::*;
|
||||
use crate::agent::status::is_final;
|
||||
use futures::FutureExt;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::watch::Receiver;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use tokio::time::timeout_at;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WaitArgs {
|
||||
id: String,
|
||||
ids: Vec<String>,
|
||||
timeout_ms: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct WaitResult {
|
||||
status: AgentStatus,
|
||||
status: HashMap<ThreadId, AgentStatus>,
|
||||
timed_out: bool,
|
||||
}
|
||||
|
||||
@@ -278,7 +284,16 @@ mod wait {
|
||||
arguments: String,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
let args: WaitArgs = parse_arguments(&arguments)?;
|
||||
let receiver_thread_id = agent_id(&args.id)?;
|
||||
if args.ids.is_empty() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"ids must be non-empty".to_owned(),
|
||||
));
|
||||
}
|
||||
let receiver_thread_ids = args
|
||||
.ids
|
||||
.iter()
|
||||
.map(|id| agent_id(id))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// Validate timeout.
|
||||
let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
|
||||
@@ -296,105 +311,131 @@ mod wait {
|
||||
&turn,
|
||||
CollabWaitingBeginEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
receiver_thread_ids: receiver_thread_ids.clone(),
|
||||
call_id: call_id.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let status_rx = match session
|
||||
.services
|
||||
.agent_control
|
||||
.subscribe_status(receiver_thread_id)
|
||||
.await
|
||||
{
|
||||
Ok(status_rx) => status_rx,
|
||||
Err(err) => {
|
||||
let status = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_status(receiver_thread_id)
|
||||
.await;
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
call_id: call_id.clone(),
|
||||
status,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
return Err(collab_agent_error(receiver_thread_id, err));
|
||||
|
||||
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
|
||||
let mut initial_final_statuses = Vec::new();
|
||||
for id in &receiver_thread_ids {
|
||||
match session.services.agent_control.subscribe_status(*id).await {
|
||||
Ok(rx) => {
|
||||
let status = rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
initial_final_statuses.push((*id, status));
|
||||
}
|
||||
status_rxs.push((*id, rx));
|
||||
}
|
||||
Err(CodexErr::ThreadNotFound(_)) => {
|
||||
initial_final_statuses.push((*id, AgentStatus::NotFound));
|
||||
}
|
||||
Err(err) => {
|
||||
let mut statuses = HashMap::with_capacity(1);
|
||||
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id: call_id.clone(),
|
||||
statuses,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
return Err(collab_agent_error(*id, err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let statuses = if !initial_final_statuses.is_empty() {
|
||||
initial_final_statuses
|
||||
} else {
|
||||
// Wait for the first agent to reach a final status.
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for (id, rx) in status_rxs.into_iter() {
|
||||
let session = session.clone();
|
||||
futures.push(wait_for_final_status(session, id, rx));
|
||||
}
|
||||
let mut results = Vec::new();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
loop {
|
||||
match timeout_at(deadline, futures.next()).await {
|
||||
Ok(Some(Some(result))) => {
|
||||
results.push(result);
|
||||
break;
|
||||
}
|
||||
Ok(Some(None)) => continue,
|
||||
Ok(None) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
if !results.is_empty() {
|
||||
// Drain the unlikely last elements to prevent race.
|
||||
loop {
|
||||
match futures.next().now_or_never() {
|
||||
Some(Some(Some(result))) => results.push(result),
|
||||
Some(Some(None)) => continue,
|
||||
Some(None) | None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
};
|
||||
let result =
|
||||
wait_for_status(session.as_ref(), receiver_thread_id, timeout_ms, status_rx).await;
|
||||
|
||||
// Convert payload.
|
||||
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
|
||||
let result = WaitResult {
|
||||
status: statuses_map.clone(),
|
||||
timed_out: statuses.is_empty(),
|
||||
};
|
||||
|
||||
// Final event emission.
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
call_id,
|
||||
status: result.status.clone(),
|
||||
statuses: statuses_map,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if matches!(result.status, AgentStatus::NotFound) {
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"agent with id {receiver_thread_id} not found"
|
||||
)));
|
||||
}
|
||||
|
||||
let content = serde_json::to_string(&result).map_err(|err| {
|
||||
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
|
||||
})?;
|
||||
|
||||
let success = !result.timed_out && !matches!(result.status, AgentStatus::Errored(_));
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
success: Some(success),
|
||||
success: None,
|
||||
content_items: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn wait_for_status(
|
||||
session: &Session,
|
||||
agent_id: ThreadId,
|
||||
timeout_ms: i64,
|
||||
mut status_rx: tokio::sync::watch::Receiver<AgentStatus>,
|
||||
) -> WaitResult {
|
||||
// Get last known status.
|
||||
let mut status = status_rx.borrow_and_update().clone();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
async fn wait_for_final_status(
|
||||
session: Arc<Session>,
|
||||
thread_id: ThreadId,
|
||||
mut status_rx: Receiver<AgentStatus>,
|
||||
) -> Option<(ThreadId, AgentStatus)> {
|
||||
let mut status = status_rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
return Some((thread_id, status));
|
||||
}
|
||||
|
||||
let timed_out = loop {
|
||||
loop {
|
||||
if status_rx.changed().await.is_err() {
|
||||
let latest = session.services.agent_control.get_status(thread_id).await;
|
||||
return is_final(&latest).then_some((thread_id, latest));
|
||||
}
|
||||
status = status_rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
break false;
|
||||
return Some((thread_id, status));
|
||||
}
|
||||
|
||||
match timeout_at(deadline, status_rx.changed()).await {
|
||||
Ok(Ok(())) => status = status_rx.borrow().clone(),
|
||||
Ok(Err(_)) => {
|
||||
let last_status = session.services.agent_control.get_status(agent_id).await;
|
||||
if last_status != AgentStatus::NotFound {
|
||||
// On-purpose we keep the last known status if the agent gets dropped. This
|
||||
// event is not supposed to happen.
|
||||
status = last_status;
|
||||
}
|
||||
break false;
|
||||
}
|
||||
Err(_) => break true,
|
||||
}
|
||||
};
|
||||
|
||||
WaitResult { status, timed_out }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -560,7 +601,9 @@ mod tests {
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use codex_protocol::ThreadId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -772,6 +815,12 @@ mod tests {
|
||||
.expect("shutdown should submit");
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, PartialEq, Eq)]
|
||||
struct WaitResult {
|
||||
status: HashMap<ThreadId, AgentStatus>,
|
||||
timed_out: bool,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_rejects_non_positive_timeout() {
|
||||
let (session, turn) = make_session_and_context().await;
|
||||
@@ -779,7 +828,10 @@ mod tests {
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait",
|
||||
function_payload(json!({"id": ThreadId::new().to_string(), "timeout_ms": 0})),
|
||||
function_payload(json!({
|
||||
"ids": [ThreadId::new().to_string()],
|
||||
"timeout_ms": 0
|
||||
})),
|
||||
);
|
||||
let Err(err) = CollabHandler.handle(invocation).await else {
|
||||
panic!("non-positive timeout should be rejected");
|
||||
@@ -797,7 +849,7 @@ mod tests {
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait",
|
||||
function_payload(json!({"id": "invalid"})),
|
||||
function_payload(json!({"ids": ["invalid"]})),
|
||||
);
|
||||
let Err(err) = CollabHandler.handle(invocation).await else {
|
||||
panic!("invalid id should be rejected");
|
||||
@@ -808,6 +860,65 @@ mod tests {
|
||||
assert!(msg.starts_with("invalid agent id invalid:"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_rejects_empty_ids() {
|
||||
let (session, turn) = make_session_and_context().await;
|
||||
let invocation = invocation(
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait",
|
||||
function_payload(json!({"ids": []})),
|
||||
);
|
||||
let Err(err) = CollabHandler.handle(invocation).await else {
|
||||
panic!("empty ids should be rejected");
|
||||
};
|
||||
assert_eq!(
|
||||
err,
|
||||
FunctionCallError::RespondToModel("ids must be non-empty".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_returns_not_found_for_missing_agents() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
let manager = thread_manager();
|
||||
session.services.agent_control = manager.agent_control();
|
||||
let id_a = ThreadId::new();
|
||||
let id_b = ThreadId::new();
|
||||
let invocation = invocation(
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait",
|
||||
function_payload(json!({
|
||||
"ids": [id_a.to_string(), id_b.to_string()],
|
||||
"timeout_ms": 1000
|
||||
})),
|
||||
);
|
||||
let output = CollabHandler
|
||||
.handle(invocation)
|
||||
.await
|
||||
.expect("wait should succeed");
|
||||
let ToolOutput::Function {
|
||||
content, success, ..
|
||||
} = output
|
||||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
let result: WaitResult =
|
||||
serde_json::from_str(&content).expect("wait result should be json");
|
||||
assert_eq!(
|
||||
result,
|
||||
WaitResult {
|
||||
status: HashMap::from([
|
||||
(id_a, AgentStatus::NotFound),
|
||||
(id_b, AgentStatus::NotFound),
|
||||
]),
|
||||
timed_out: false
|
||||
}
|
||||
);
|
||||
assert_eq!(success, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_times_out_when_status_is_not_final() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
@@ -820,7 +931,10 @@ mod tests {
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait",
|
||||
function_payload(json!({"id": agent_id.to_string(), "timeout_ms": 10})),
|
||||
function_payload(json!({
|
||||
"ids": [agent_id.to_string()],
|
||||
"timeout_ms": 10
|
||||
})),
|
||||
);
|
||||
let output = CollabHandler
|
||||
.handle(invocation)
|
||||
@@ -832,8 +946,16 @@ mod tests {
|
||||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
assert_eq!(content, r#"{"status":"pending_init","timed_out":true}"#);
|
||||
assert_eq!(success, Some(false));
|
||||
let result: WaitResult =
|
||||
serde_json::from_str(&content).expect("wait result should be json");
|
||||
assert_eq!(
|
||||
result,
|
||||
WaitResult {
|
||||
status: HashMap::new(),
|
||||
timed_out: true
|
||||
}
|
||||
);
|
||||
assert_eq!(success, None);
|
||||
|
||||
let _ = thread
|
||||
.thread
|
||||
@@ -869,7 +991,10 @@ mod tests {
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait",
|
||||
function_payload(json!({"id": agent_id.to_string(), "timeout_ms": 1000})),
|
||||
function_payload(json!({
|
||||
"ids": [agent_id.to_string()],
|
||||
"timeout_ms": 1000
|
||||
})),
|
||||
);
|
||||
let output = CollabHandler
|
||||
.handle(invocation)
|
||||
@@ -881,8 +1006,16 @@ mod tests {
|
||||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
assert_eq!(content, r#"{"status":"shutdown","timed_out":false}"#);
|
||||
assert_eq!(success, Some(true));
|
||||
let result: WaitResult =
|
||||
serde_json::from_str(&content).expect("wait result should be json");
|
||||
assert_eq!(
|
||||
result,
|
||||
WaitResult {
|
||||
status: HashMap::from([(agent_id, AgentStatus::Shutdown)]),
|
||||
timed_out: false
|
||||
}
|
||||
);
|
||||
assert_eq!(success, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -503,9 +503,10 @@ fn create_send_input_tool() -> ToolSpec {
|
||||
fn create_wait_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"id".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Identifier of the agent to wait on.".to_string()),
|
||||
"ids".to_string(),
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String { description: None }),
|
||||
description: Some("Identifiers of the agents to wait on.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
@@ -519,11 +520,13 @@ fn create_wait_tool() -> ToolSpec {
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "wait".to_string(),
|
||||
description: "Wait for an agent and return its status.".to_string(),
|
||||
description:
|
||||
"Wait for agents and return their statuses. If no agent is done, no status get returned."
|
||||
.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["id".to_string()]),
|
||||
required: Some(vec!["ids".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
|
||||
@@ -2177,8 +2177,8 @@ pub struct CollabAgentInteractionEndEvent {
|
||||
pub struct CollabWaitingBeginEvent {
|
||||
/// Thread ID of the sender.
|
||||
pub sender_thread_id: ThreadId,
|
||||
/// Thread ID of the receiver.
|
||||
pub receiver_thread_id: ThreadId,
|
||||
/// Thread ID of the receivers.
|
||||
pub receiver_thread_ids: Vec<ThreadId>,
|
||||
/// ID of the waiting call.
|
||||
pub call_id: String,
|
||||
}
|
||||
@@ -2187,12 +2187,10 @@ pub struct CollabWaitingBeginEvent {
|
||||
pub struct CollabWaitingEndEvent {
|
||||
/// Thread ID of the sender.
|
||||
pub sender_thread_id: ThreadId,
|
||||
/// Thread ID of the receiver.
|
||||
pub receiver_thread_id: ThreadId,
|
||||
/// ID of the waiting call.
|
||||
pub call_id: String,
|
||||
/// Last known status of the receiver agent reported to the sender agent.
|
||||
pub status: AgentStatus,
|
||||
/// Last known status of the receiver agents reported to the sender agent.
|
||||
pub statuses: HashMap<ThreadId, AgentStatus>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, JsonSchema, TS)]
|
||||
|
||||
@@ -59,12 +59,12 @@ pub(crate) fn waiting_begin(ev: CollabWaitingBeginEvent) -> PlainHistoryCell {
|
||||
let CollabWaitingBeginEvent {
|
||||
call_id,
|
||||
sender_thread_id,
|
||||
receiver_thread_id,
|
||||
receiver_thread_ids,
|
||||
} = ev;
|
||||
let details = vec![
|
||||
detail_line("call", call_id),
|
||||
detail_line("sender", sender_thread_id),
|
||||
detail_line("receiver", receiver_thread_id),
|
||||
detail_line("receiver", format!("{receiver_thread_ids:?}")),
|
||||
];
|
||||
collab_event("Collab wait begin", details)
|
||||
}
|
||||
@@ -73,14 +73,12 @@ pub(crate) fn waiting_end(ev: CollabWaitingEndEvent) -> PlainHistoryCell {
|
||||
let CollabWaitingEndEvent {
|
||||
call_id,
|
||||
sender_thread_id,
|
||||
receiver_thread_id,
|
||||
status,
|
||||
statuses,
|
||||
} = ev;
|
||||
let details = vec![
|
||||
detail_line("call", call_id),
|
||||
detail_line("sender", sender_thread_id),
|
||||
detail_line("receiver", receiver_thread_id),
|
||||
status_line(&status),
|
||||
detail_line("statuses", format!("{statuses:#?}")),
|
||||
];
|
||||
collab_event("Collab wait end", details)
|
||||
}
|
||||
|
||||
@@ -59,12 +59,12 @@ pub(crate) fn waiting_begin(ev: CollabWaitingBeginEvent) -> PlainHistoryCell {
|
||||
let CollabWaitingBeginEvent {
|
||||
call_id,
|
||||
sender_thread_id,
|
||||
receiver_thread_id,
|
||||
receiver_thread_ids,
|
||||
} = ev;
|
||||
let details = vec![
|
||||
detail_line("call", call_id),
|
||||
detail_line("sender", sender_thread_id),
|
||||
detail_line("receiver", receiver_thread_id),
|
||||
detail_line("receiver", format!("{receiver_thread_ids:?}")),
|
||||
];
|
||||
collab_event("Collab wait begin", details)
|
||||
}
|
||||
@@ -73,14 +73,12 @@ pub(crate) fn waiting_end(ev: CollabWaitingEndEvent) -> PlainHistoryCell {
|
||||
let CollabWaitingEndEvent {
|
||||
call_id,
|
||||
sender_thread_id,
|
||||
receiver_thread_id,
|
||||
status,
|
||||
statuses,
|
||||
} = ev;
|
||||
let details = vec![
|
||||
detail_line("call", call_id),
|
||||
detail_line("sender", sender_thread_id),
|
||||
detail_line("receiver", receiver_thread_id),
|
||||
status_line(&status),
|
||||
detail_line("statuses", format!("{statuses:#?}")),
|
||||
];
|
||||
collab_event("Collab wait end", details)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user