make it better

This commit is contained in:
jif-oai
2026-01-26 18:19:22 +00:00
parent ebddb1071d
commit ae13e569c6
5 changed files with 140 additions and 17 deletions

View File

@@ -15,7 +15,7 @@ pub(crate) fn spawn_collab_completion_warning_watcher(
) {
tokio::spawn(async move {
if let Some(status) = wait_for_final_status(session.as_ref(), agent_id).await
&& !crate::agent::is_collab_wait_active(
&& !crate::agent::is_collab_wait_suppressed(
session.as_ref(),
&turn_context.sub_id,
agent_id,

View File

@@ -1,32 +1,67 @@
use std::sync::Arc;
use codex_protocol::ThreadId;
use crate::codex::Session;
pub(crate) async fn begin_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
let active = session.active_turn.lock().await;
if let Some(active_turn) = active.as_ref() {
let mut state = active_turn.turn_state.lock().await;
let turn_state = {
let active = session.active_turn.lock().await;
active
.as_ref()
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
if let Some(turn_state) = turn_state {
let mut state = turn_state.lock().await;
state.begin_wait(turn_id, agent_ids);
}
}
pub(crate) async fn end_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
let active = session.active_turn.lock().await;
if let Some(active_turn) = active.as_ref() {
let mut state = active_turn.turn_state.lock().await;
let turn_state = {
let active = session.active_turn.lock().await;
active
.as_ref()
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
if let Some(turn_state) = turn_state {
let mut state = turn_state.lock().await;
state.end_wait(turn_id, agent_ids);
}
}
pub(crate) async fn is_collab_wait_active(
pub(crate) async fn mark_collab_wait_collected(
session: &Session,
turn_id: &str,
agent_ids: &[ThreadId],
) {
let turn_state = {
let active = session.active_turn.lock().await;
active
.as_ref()
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
if let Some(turn_state) = turn_state {
let mut state = turn_state.lock().await;
state.mark_wait_collected(turn_id, agent_ids);
}
}
pub(crate) async fn is_collab_wait_suppressed(
session: &Session,
turn_id: &str,
agent_id: ThreadId,
) -> bool {
let active = session.active_turn.lock().await;
if let Some(active_turn) = active.as_ref() {
let state = active_turn.turn_state.lock().await;
return state.is_waiting_on(turn_id, agent_id);
let turn_state = {
let active = session.active_turn.lock().await;
active
.as_ref()
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
if let Some(turn_state) = turn_state {
let state = turn_state.lock().await;
state.is_waiting_on(turn_id, agent_id) || state.is_wait_collected(turn_id, agent_id)
} else {
false
}
false
}

View File

@@ -9,7 +9,8 @@ pub(crate) use codex_protocol::protocol::AgentStatus;
pub(crate) use collab_completion_warning::spawn_collab_completion_warning_watcher;
pub(crate) use collab_wait_tracking::begin_collab_wait;
pub(crate) use collab_wait_tracking::end_collab_wait;
pub(crate) use collab_wait_tracking::is_collab_wait_active;
pub(crate) use collab_wait_tracking::is_collab_wait_suppressed;
pub(crate) use collab_wait_tracking::mark_collab_wait_collected;
pub(crate) use control::AgentControl;
pub(crate) use guards::MAX_THREAD_SPAWN_DEPTH;
pub(crate) use guards::exceeds_thread_spawn_depth_limit;

View File

@@ -2,6 +2,7 @@
use indexmap::IndexMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::Notify;
@@ -75,6 +76,7 @@ pub(crate) struct TurnState {
pending_dynamic_tools: HashMap<String, oneshot::Sender<DynamicToolResponse>>,
pending_input: Vec<ResponseInputItem>,
active_waits: HashMap<String, HashMap<ThreadId, usize>>,
collected_waits: HashMap<String, HashSet<ThreadId>>,
}
impl TurnState {
@@ -99,6 +101,7 @@ impl TurnState {
self.pending_dynamic_tools.clear();
self.pending_input.clear();
self.active_waits.clear();
self.collected_waits.clear();
}
pub(crate) fn insert_pending_user_input(
@@ -185,6 +188,20 @@ impl TurnState {
.get(turn_id)
.is_some_and(|waits| waits.contains_key(&agent_id))
}
pub(crate) fn mark_wait_collected(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
if agent_ids.is_empty() {
return;
}
let collected = self.collected_waits.entry(turn_id.to_string()).or_default();
collected.extend(agent_ids.iter().copied());
}
pub(crate) fn is_wait_collected(&self, turn_id: &str, agent_id: ThreadId) -> bool {
self.collected_waits
.get(turn_id)
.is_some_and(|collected| collected.contains(&agent_id))
}
}
impl ActiveTurn {
@@ -205,7 +222,7 @@ mod tests {
}
#[test]
fn wait_tracking_is_turn_scoped_and_reference_counted() {
fn wait_tracking_is_turn_scoped_and_reference_counted_and_collected() {
let mut state = TurnState::default();
let turn_a = "turn-a";
let turn_b = "turn-b";
@@ -224,5 +241,13 @@ mod tests {
state.end_wait(turn_a, &[agent]);
assert_eq!(state.is_waiting_on(turn_a, agent), false);
assert_eq!(state.is_waiting_on(turn_b, agent), true);
state.mark_wait_collected(turn_a, &[agent]);
assert_eq!(state.is_wait_collected(turn_a, agent), true);
assert_eq!(state.is_wait_collected(turn_b, agent), false);
state.clear_pending();
assert_eq!(state.is_waiting_on(turn_b, agent), false);
assert_eq!(state.is_wait_collected(turn_a, agent), false);
}
}

View File

@@ -414,11 +414,19 @@ mod wait {
results
};
let collected_ids = statuses.iter().map(|(id, _)| *id).collect::<Vec<_>>();
crate::agent::mark_collab_wait_collected(
session.as_ref(),
&turn.sub_id,
&collected_ids,
)
.await;
// Convert payload.
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
let statuses_map = statuses.into_iter().collect::<HashMap<_, _>>();
let result = WaitResult {
status: statuses_map.clone(),
timed_out: statuses.is_empty(),
timed_out: collected_ids.is_empty(),
};
// Final event emission.
@@ -997,6 +1005,60 @@ mod tests {
assert_eq!(success, None);
}
#[tokio::test]
async fn wait_marks_missing_agents_as_collected_and_suppressed() {
let (mut session, turn) = make_session_and_context().await;
let manager = thread_manager();
session.services.agent_control = manager.agent_control();
let id = ThreadId::new();
let session = Arc::new(session);
let turn = Arc::new(turn);
let turn_state = {
let mut active = session.active_turn.lock().await;
let active_turn = active.get_or_insert_with(crate::state::ActiveTurn::default);
Arc::clone(&active_turn.turn_state)
};
let invocation = invocation(
session.clone(),
turn.clone(),
"wait",
function_payload(json!({
"ids": [id.to_string()],
"timeout_ms": 1000
})),
);
let output = CollabHandler
.handle(invocation)
.await
.expect("wait should succeed");
let ToolOutput::Function {
content, success, ..
} = output
else {
panic!("expected function output");
};
let result: WaitResult =
serde_json::from_str(&content).expect("wait result should be json");
assert_eq!(
result,
WaitResult {
status: HashMap::from([(id, AgentStatus::NotFound)]),
timed_out: false
}
);
assert_eq!(success, None);
{
let state = turn_state.lock().await;
assert_eq!(state.is_waiting_on(&turn.sub_id, id), false);
assert_eq!(state.is_wait_collected(&turn.sub_id, id), true);
}
let suppressed =
crate::agent::is_collab_wait_suppressed(session.as_ref(), &turn.sub_id, id).await;
assert_eq!(suppressed, true);
}
#[tokio::test]
async fn wait_times_out_when_status_is_not_final() {
let (mut session, turn) = make_session_and_context().await;