feat: emit warning when sub agent is done

This commit is contained in:
jif-oai
2026-01-26 18:07:28 +00:00
parent ec2551ba36
commit ebddb1071d
5 changed files with 219 additions and 98 deletions

View File

@@ -14,7 +14,14 @@ pub(crate) fn spawn_collab_completion_warning_watcher(
agent_id: ThreadId,
) {
tokio::spawn(async move {
if let Some(status) = wait_for_final_status(session.as_ref(), agent_id).await {
if let Some(status) = wait_for_final_status(session.as_ref(), agent_id).await
&& !crate::agent::is_collab_wait_active(
session.as_ref(),
&turn_context.sub_id,
agent_id,
)
.await
{
let message = completion_warning_message(agent_id, &status);
session.record_model_warning(message, &turn_context).await;
}

View File

@@ -0,0 +1,32 @@
use codex_protocol::ThreadId;
use crate::codex::Session;
pub(crate) async fn begin_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
let active = session.active_turn.lock().await;
if let Some(active_turn) = active.as_ref() {
let mut state = active_turn.turn_state.lock().await;
state.begin_wait(turn_id, agent_ids);
}
}
pub(crate) async fn end_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
let active = session.active_turn.lock().await;
if let Some(active_turn) = active.as_ref() {
let mut state = active_turn.turn_state.lock().await;
state.end_wait(turn_id, agent_ids);
}
}
pub(crate) async fn is_collab_wait_active(
session: &Session,
turn_id: &str,
agent_id: ThreadId,
) -> bool {
let active = session.active_turn.lock().await;
if let Some(active_turn) = active.as_ref() {
let state = active_turn.turn_state.lock().await;
return state.is_waiting_on(turn_id, agent_id);
}
false
}

View File

@@ -1,4 +1,5 @@
mod collab_completion_warning;
mod collab_wait_tracking;
pub(crate) mod control;
mod guards;
pub(crate) mod role;
@@ -6,6 +7,9 @@ pub(crate) mod status;
pub(crate) use codex_protocol::protocol::AgentStatus;
pub(crate) use collab_completion_warning::spawn_collab_completion_warning_watcher;
pub(crate) use collab_wait_tracking::begin_collab_wait;
pub(crate) use collab_wait_tracking::end_collab_wait;
pub(crate) use collab_wait_tracking::is_collab_wait_active;
pub(crate) use control::AgentControl;
pub(crate) use guards::MAX_THREAD_SPAWN_DEPTH;
pub(crate) use guards::exceeds_thread_spawn_depth_limit;

View File

@@ -16,6 +16,7 @@ use tokio::sync::oneshot;
use crate::codex::TurnContext;
use crate::protocol::ReviewDecision;
use crate::tasks::SessionTask;
use codex_protocol::ThreadId;
/// Metadata about the currently running turn.
pub(crate) struct ActiveTurn {
@@ -73,6 +74,7 @@ pub(crate) struct TurnState {
pending_user_input: HashMap<String, oneshot::Sender<RequestUserInputResponse>>,
pending_dynamic_tools: HashMap<String, oneshot::Sender<DynamicToolResponse>>,
pending_input: Vec<ResponseInputItem>,
active_waits: HashMap<String, HashMap<ThreadId, usize>>,
}
impl TurnState {
@@ -96,6 +98,7 @@ impl TurnState {
self.pending_user_input.clear();
self.pending_dynamic_tools.clear();
self.pending_input.clear();
self.active_waits.clear();
}
pub(crate) fn insert_pending_user_input(
@@ -145,6 +148,43 @@ impl TurnState {
pub(crate) fn has_pending_input(&self) -> bool {
!self.pending_input.is_empty()
}
pub(crate) fn begin_wait(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
if agent_ids.is_empty() {
return;
}
let waits = self.active_waits.entry(turn_id.to_string()).or_default();
for agent_id in agent_ids {
*waits.entry(*agent_id).or_default() += 1;
}
}
pub(crate) fn end_wait(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
if agent_ids.is_empty() {
return;
}
let mut remove_turn = false;
if let Some(waits) = self.active_waits.get_mut(turn_id) {
for agent_id in agent_ids {
if let Some(count) = waits.get_mut(agent_id) {
*count = count.saturating_sub(1);
if *count == 0 {
waits.remove(agent_id);
}
}
}
remove_turn = waits.is_empty();
}
if remove_turn {
self.active_waits.remove(turn_id);
}
}
pub(crate) fn is_waiting_on(&self, turn_id: &str, agent_id: ThreadId) -> bool {
self.active_waits
.get(turn_id)
.is_some_and(|waits| waits.contains_key(&agent_id))
}
}
impl ActiveTurn {
@@ -154,3 +194,35 @@ impl ActiveTurn {
ts.clear_pending();
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
fn thread_id(s: &str) -> ThreadId {
ThreadId::from_string(s).expect("valid thread id")
}
#[test]
fn wait_tracking_is_turn_scoped_and_reference_counted() {
let mut state = TurnState::default();
let turn_a = "turn-a";
let turn_b = "turn-b";
let agent = thread_id("00000000-0000-7000-0000-000000000001");
state.begin_wait(turn_a, &[agent]);
state.begin_wait(turn_a, &[agent]);
state.begin_wait(turn_b, &[agent]);
assert_eq!(state.is_waiting_on(turn_a, agent), true);
assert_eq!(state.is_waiting_on(turn_b, agent), true);
state.end_wait(turn_a, &[agent]);
assert_eq!(state.is_waiting_on(turn_a, agent), true);
state.end_wait(turn_a, &[agent]);
assert_eq!(state.is_waiting_on(turn_a, agent), false);
assert_eq!(state.is_waiting_on(turn_b, agent), true);
}
}

View File

@@ -333,114 +333,120 @@ mod wait {
ms => ms.min(MAX_WAIT_TIMEOUT_MS),
};
session
.send_event(
&turn,
CollabWaitingBeginEvent {
sender_thread_id: session.conversation_id,
receiver_thread_ids: receiver_thread_ids.clone(),
call_id: call_id.clone(),
}
.into(),
)
.await;
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
let mut initial_final_statuses = Vec::new();
for id in &receiver_thread_ids {
match session.services.agent_control.subscribe_status(*id).await {
Ok(rx) => {
let status = rx.borrow().clone();
if is_final(&status) {
initial_final_statuses.push((*id, status));
crate::agent::begin_collab_wait(session.as_ref(), &turn.sub_id, &receiver_thread_ids).await;
let result = async {
session
.send_event(
&turn,
CollabWaitingBeginEvent {
sender_thread_id: session.conversation_id,
receiver_thread_ids: receiver_thread_ids.clone(),
call_id: call_id.clone(),
}
status_rxs.push((*id, rx));
}
Err(CodexErr::ThreadNotFound(_)) => {
initial_final_statuses.push((*id, AgentStatus::NotFound));
}
Err(err) => {
let mut statuses = HashMap::with_capacity(1);
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
call_id: call_id.clone(),
statuses,
}
.into(),
)
.await;
return Err(collab_agent_error(*id, err));
}
}
}
.into(),
)
.await;
let statuses = if !initial_final_statuses.is_empty() {
initial_final_statuses
} else {
// Wait for the first agent to reach a final status.
let mut futures = FuturesUnordered::new();
for (id, rx) in status_rxs.into_iter() {
let session = session.clone();
futures.push(wait_for_final_status(session, id, rx));
}
let mut results = Vec::new();
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
loop {
match timeout_at(deadline, futures.next()).await {
Ok(Some(Some(result))) => {
results.push(result);
break;
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
let mut initial_final_statuses = Vec::new();
for id in &receiver_thread_ids {
match session.services.agent_control.subscribe_status(*id).await {
Ok(rx) => {
let status = rx.borrow().clone();
if is_final(&status) {
initial_final_statuses.push((*id, status));
}
status_rxs.push((*id, rx));
}
Err(CodexErr::ThreadNotFound(_)) => {
initial_final_statuses.push((*id, AgentStatus::NotFound));
}
Err(err) => {
let mut statuses = HashMap::with_capacity(1);
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
call_id: call_id.clone(),
statuses,
}
.into(),
)
.await;
Err(collab_agent_error(*id, err))?;
}
Ok(Some(None)) => continue,
Ok(None) | Err(_) => break,
}
}
if !results.is_empty() {
// Drain the unlikely last elements to prevent race.
let statuses = if !initial_final_statuses.is_empty() {
initial_final_statuses
} else {
// Wait for the first agent to reach a final status.
let mut futures = FuturesUnordered::new();
for (id, rx) in status_rxs.into_iter() {
let session = session.clone();
futures.push(wait_for_final_status(session, id, rx));
}
let mut results = Vec::new();
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
loop {
match futures.next().now_or_never() {
Some(Some(Some(result))) => results.push(result),
Some(Some(None)) => continue,
Some(None) | None => break,
match timeout_at(deadline, futures.next()).await {
Ok(Some(Some(result))) => {
results.push(result);
break;
}
Ok(Some(None)) => continue,
Ok(None) | Err(_) => break,
}
}
}
results
};
// Convert payload.
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
let result = WaitResult {
status: statuses_map.clone(),
timed_out: statuses.is_empty(),
};
// Final event emission.
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
call_id,
statuses: statuses_map,
if !results.is_empty() {
// Drain the unlikely last elements to prevent race.
loop {
match futures.next().now_or_never() {
Some(Some(Some(result))) => results.push(result),
Some(Some(None)) => continue,
Some(None) | None => break,
}
}
}
.into(),
)
.await;
results
};
let content = serde_json::to_string(&result).map_err(|err| {
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
})?;
// Convert payload.
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
let result = WaitResult {
status: statuses_map.clone(),
timed_out: statuses.is_empty(),
};
Ok(ToolOutput::Function {
content,
success: None,
content_items: None,
})
// Final event emission.
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
call_id,
statuses: statuses_map,
}
.into(),
)
.await;
let content = serde_json::to_string(&result).map_err(|err| {
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
})?;
Ok(ToolOutput::Function {
content,
success: None,
content_items: None,
})
}
.await;
crate::agent::end_collab_wait(session.as_ref(), &turn.sub_id, &receiver_thread_ids).await;
result
}
async fn wait_for_final_status(