mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
feat: emit warning when sub agent is done
This commit is contained in:
@@ -14,7 +14,14 @@ pub(crate) fn spawn_collab_completion_warning_watcher(
|
||||
agent_id: ThreadId,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
if let Some(status) = wait_for_final_status(session.as_ref(), agent_id).await {
|
||||
if let Some(status) = wait_for_final_status(session.as_ref(), agent_id).await
|
||||
&& !crate::agent::is_collab_wait_active(
|
||||
session.as_ref(),
|
||||
&turn_context.sub_id,
|
||||
agent_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let message = completion_warning_message(agent_id, &status);
|
||||
session.record_model_warning(message, &turn_context).await;
|
||||
}
|
||||
|
||||
32
codex-rs/core/src/agent/collab_wait_tracking.rs
Normal file
32
codex-rs/core/src/agent/collab_wait_tracking.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use codex_protocol::ThreadId;
|
||||
|
||||
use crate::codex::Session;
|
||||
|
||||
pub(crate) async fn begin_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
let active = session.active_turn.lock().await;
|
||||
if let Some(active_turn) = active.as_ref() {
|
||||
let mut state = active_turn.turn_state.lock().await;
|
||||
state.begin_wait(turn_id, agent_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn end_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
let active = session.active_turn.lock().await;
|
||||
if let Some(active_turn) = active.as_ref() {
|
||||
let mut state = active_turn.turn_state.lock().await;
|
||||
state.end_wait(turn_id, agent_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn is_collab_wait_active(
|
||||
session: &Session,
|
||||
turn_id: &str,
|
||||
agent_id: ThreadId,
|
||||
) -> bool {
|
||||
let active = session.active_turn.lock().await;
|
||||
if let Some(active_turn) = active.as_ref() {
|
||||
let state = active_turn.turn_state.lock().await;
|
||||
return state.is_waiting_on(turn_id, agent_id);
|
||||
}
|
||||
false
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
mod collab_completion_warning;
|
||||
mod collab_wait_tracking;
|
||||
pub(crate) mod control;
|
||||
mod guards;
|
||||
pub(crate) mod role;
|
||||
@@ -6,6 +7,9 @@ pub(crate) mod status;
|
||||
|
||||
pub(crate) use codex_protocol::protocol::AgentStatus;
|
||||
pub(crate) use collab_completion_warning::spawn_collab_completion_warning_watcher;
|
||||
pub(crate) use collab_wait_tracking::begin_collab_wait;
|
||||
pub(crate) use collab_wait_tracking::end_collab_wait;
|
||||
pub(crate) use collab_wait_tracking::is_collab_wait_active;
|
||||
pub(crate) use control::AgentControl;
|
||||
pub(crate) use guards::MAX_THREAD_SPAWN_DEPTH;
|
||||
pub(crate) use guards::exceeds_thread_spawn_depth_limit;
|
||||
|
||||
@@ -16,6 +16,7 @@ use tokio::sync::oneshot;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::ReviewDecision;
|
||||
use crate::tasks::SessionTask;
|
||||
use codex_protocol::ThreadId;
|
||||
|
||||
/// Metadata about the currently running turn.
|
||||
pub(crate) struct ActiveTurn {
|
||||
@@ -73,6 +74,7 @@ pub(crate) struct TurnState {
|
||||
pending_user_input: HashMap<String, oneshot::Sender<RequestUserInputResponse>>,
|
||||
pending_dynamic_tools: HashMap<String, oneshot::Sender<DynamicToolResponse>>,
|
||||
pending_input: Vec<ResponseInputItem>,
|
||||
active_waits: HashMap<String, HashMap<ThreadId, usize>>,
|
||||
}
|
||||
|
||||
impl TurnState {
|
||||
@@ -96,6 +98,7 @@ impl TurnState {
|
||||
self.pending_user_input.clear();
|
||||
self.pending_dynamic_tools.clear();
|
||||
self.pending_input.clear();
|
||||
self.active_waits.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn insert_pending_user_input(
|
||||
@@ -145,6 +148,43 @@ impl TurnState {
|
||||
pub(crate) fn has_pending_input(&self) -> bool {
|
||||
!self.pending_input.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn begin_wait(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
if agent_ids.is_empty() {
|
||||
return;
|
||||
}
|
||||
let waits = self.active_waits.entry(turn_id.to_string()).or_default();
|
||||
for agent_id in agent_ids {
|
||||
*waits.entry(*agent_id).or_default() += 1;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn end_wait(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
if agent_ids.is_empty() {
|
||||
return;
|
||||
}
|
||||
let mut remove_turn = false;
|
||||
if let Some(waits) = self.active_waits.get_mut(turn_id) {
|
||||
for agent_id in agent_ids {
|
||||
if let Some(count) = waits.get_mut(agent_id) {
|
||||
*count = count.saturating_sub(1);
|
||||
if *count == 0 {
|
||||
waits.remove(agent_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
remove_turn = waits.is_empty();
|
||||
}
|
||||
if remove_turn {
|
||||
self.active_waits.remove(turn_id);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_waiting_on(&self, turn_id: &str, agent_id: ThreadId) -> bool {
|
||||
self.active_waits
|
||||
.get(turn_id)
|
||||
.is_some_and(|waits| waits.contains_key(&agent_id))
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
@@ -154,3 +194,35 @@ impl ActiveTurn {
|
||||
ts.clear_pending();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn thread_id(s: &str) -> ThreadId {
|
||||
ThreadId::from_string(s).expect("valid thread id")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wait_tracking_is_turn_scoped_and_reference_counted() {
|
||||
let mut state = TurnState::default();
|
||||
let turn_a = "turn-a";
|
||||
let turn_b = "turn-b";
|
||||
let agent = thread_id("00000000-0000-7000-0000-000000000001");
|
||||
|
||||
state.begin_wait(turn_a, &[agent]);
|
||||
state.begin_wait(turn_a, &[agent]);
|
||||
state.begin_wait(turn_b, &[agent]);
|
||||
|
||||
assert_eq!(state.is_waiting_on(turn_a, agent), true);
|
||||
assert_eq!(state.is_waiting_on(turn_b, agent), true);
|
||||
|
||||
state.end_wait(turn_a, &[agent]);
|
||||
assert_eq!(state.is_waiting_on(turn_a, agent), true);
|
||||
|
||||
state.end_wait(turn_a, &[agent]);
|
||||
assert_eq!(state.is_waiting_on(turn_a, agent), false);
|
||||
assert_eq!(state.is_waiting_on(turn_b, agent), true);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -333,114 +333,120 @@ mod wait {
|
||||
ms => ms.min(MAX_WAIT_TIMEOUT_MS),
|
||||
};
|
||||
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingBeginEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_ids: receiver_thread_ids.clone(),
|
||||
call_id: call_id.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
|
||||
let mut initial_final_statuses = Vec::new();
|
||||
for id in &receiver_thread_ids {
|
||||
match session.services.agent_control.subscribe_status(*id).await {
|
||||
Ok(rx) => {
|
||||
let status = rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
initial_final_statuses.push((*id, status));
|
||||
crate::agent::begin_collab_wait(session.as_ref(), &turn.sub_id, &receiver_thread_ids).await;
|
||||
let result = async {
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingBeginEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_ids: receiver_thread_ids.clone(),
|
||||
call_id: call_id.clone(),
|
||||
}
|
||||
status_rxs.push((*id, rx));
|
||||
}
|
||||
Err(CodexErr::ThreadNotFound(_)) => {
|
||||
initial_final_statuses.push((*id, AgentStatus::NotFound));
|
||||
}
|
||||
Err(err) => {
|
||||
let mut statuses = HashMap::with_capacity(1);
|
||||
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id: call_id.clone(),
|
||||
statuses,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
return Err(collab_agent_error(*id, err));
|
||||
}
|
||||
}
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let statuses = if !initial_final_statuses.is_empty() {
|
||||
initial_final_statuses
|
||||
} else {
|
||||
// Wait for the first agent to reach a final status.
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for (id, rx) in status_rxs.into_iter() {
|
||||
let session = session.clone();
|
||||
futures.push(wait_for_final_status(session, id, rx));
|
||||
}
|
||||
let mut results = Vec::new();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
loop {
|
||||
match timeout_at(deadline, futures.next()).await {
|
||||
Ok(Some(Some(result))) => {
|
||||
results.push(result);
|
||||
break;
|
||||
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
|
||||
let mut initial_final_statuses = Vec::new();
|
||||
for id in &receiver_thread_ids {
|
||||
match session.services.agent_control.subscribe_status(*id).await {
|
||||
Ok(rx) => {
|
||||
let status = rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
initial_final_statuses.push((*id, status));
|
||||
}
|
||||
status_rxs.push((*id, rx));
|
||||
}
|
||||
Err(CodexErr::ThreadNotFound(_)) => {
|
||||
initial_final_statuses.push((*id, AgentStatus::NotFound));
|
||||
}
|
||||
Err(err) => {
|
||||
let mut statuses = HashMap::with_capacity(1);
|
||||
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id: call_id.clone(),
|
||||
statuses,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
Err(collab_agent_error(*id, err))?;
|
||||
}
|
||||
Ok(Some(None)) => continue,
|
||||
Ok(None) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
if !results.is_empty() {
|
||||
// Drain the unlikely last elements to prevent race.
|
||||
|
||||
let statuses = if !initial_final_statuses.is_empty() {
|
||||
initial_final_statuses
|
||||
} else {
|
||||
// Wait for the first agent to reach a final status.
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for (id, rx) in status_rxs.into_iter() {
|
||||
let session = session.clone();
|
||||
futures.push(wait_for_final_status(session, id, rx));
|
||||
}
|
||||
let mut results = Vec::new();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
loop {
|
||||
match futures.next().now_or_never() {
|
||||
Some(Some(Some(result))) => results.push(result),
|
||||
Some(Some(None)) => continue,
|
||||
Some(None) | None => break,
|
||||
match timeout_at(deadline, futures.next()).await {
|
||||
Ok(Some(Some(result))) => {
|
||||
results.push(result);
|
||||
break;
|
||||
}
|
||||
Ok(Some(None)) => continue,
|
||||
Ok(None) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
};
|
||||
|
||||
// Convert payload.
|
||||
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
|
||||
let result = WaitResult {
|
||||
status: statuses_map.clone(),
|
||||
timed_out: statuses.is_empty(),
|
||||
};
|
||||
|
||||
// Final event emission.
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id,
|
||||
statuses: statuses_map,
|
||||
if !results.is_empty() {
|
||||
// Drain the unlikely last elements to prevent race.
|
||||
loop {
|
||||
match futures.next().now_or_never() {
|
||||
Some(Some(Some(result))) => results.push(result),
|
||||
Some(Some(None)) => continue,
|
||||
Some(None) | None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
results
|
||||
};
|
||||
|
||||
let content = serde_json::to_string(&result).map_err(|err| {
|
||||
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
|
||||
})?;
|
||||
// Convert payload.
|
||||
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
|
||||
let result = WaitResult {
|
||||
status: statuses_map.clone(),
|
||||
timed_out: statuses.is_empty(),
|
||||
};
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
success: None,
|
||||
content_items: None,
|
||||
})
|
||||
// Final event emission.
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id,
|
||||
statuses: statuses_map,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let content = serde_json::to_string(&result).map_err(|err| {
|
||||
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
|
||||
})?;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
success: None,
|
||||
content_items: None,
|
||||
})
|
||||
}
|
||||
.await;
|
||||
crate::agent::end_collab_wait(session.as_ref(), &turn.sub_id, &receiver_thread_ids).await;
|
||||
result
|
||||
}
|
||||
|
||||
async fn wait_for_final_status(
|
||||
|
||||
Reference in New Issue
Block a user