mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
make it better
This commit is contained in:
@@ -15,7 +15,7 @@ pub(crate) fn spawn_collab_completion_warning_watcher(
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
if let Some(status) = wait_for_final_status(session.as_ref(), agent_id).await
|
||||
&& !crate::agent::is_collab_wait_active(
|
||||
&& !crate::agent::is_collab_wait_suppressed(
|
||||
session.as_ref(),
|
||||
&turn_context.sub_id,
|
||||
agent_id,
|
||||
|
||||
@@ -1,32 +1,67 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_protocol::ThreadId;
|
||||
|
||||
use crate::codex::Session;
|
||||
|
||||
pub(crate) async fn begin_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
let active = session.active_turn.lock().await;
|
||||
if let Some(active_turn) = active.as_ref() {
|
||||
let mut state = active_turn.turn_state.lock().await;
|
||||
let turn_state = {
|
||||
let active = session.active_turn.lock().await;
|
||||
active
|
||||
.as_ref()
|
||||
.map(|active_turn| Arc::clone(&active_turn.turn_state))
|
||||
};
|
||||
if let Some(turn_state) = turn_state {
|
||||
let mut state = turn_state.lock().await;
|
||||
state.begin_wait(turn_id, agent_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn end_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
let active = session.active_turn.lock().await;
|
||||
if let Some(active_turn) = active.as_ref() {
|
||||
let mut state = active_turn.turn_state.lock().await;
|
||||
let turn_state = {
|
||||
let active = session.active_turn.lock().await;
|
||||
active
|
||||
.as_ref()
|
||||
.map(|active_turn| Arc::clone(&active_turn.turn_state))
|
||||
};
|
||||
if let Some(turn_state) = turn_state {
|
||||
let mut state = turn_state.lock().await;
|
||||
state.end_wait(turn_id, agent_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn is_collab_wait_active(
|
||||
pub(crate) async fn mark_collab_wait_collected(
|
||||
session: &Session,
|
||||
turn_id: &str,
|
||||
agent_ids: &[ThreadId],
|
||||
) {
|
||||
let turn_state = {
|
||||
let active = session.active_turn.lock().await;
|
||||
active
|
||||
.as_ref()
|
||||
.map(|active_turn| Arc::clone(&active_turn.turn_state))
|
||||
};
|
||||
if let Some(turn_state) = turn_state {
|
||||
let mut state = turn_state.lock().await;
|
||||
state.mark_wait_collected(turn_id, agent_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn is_collab_wait_suppressed(
|
||||
session: &Session,
|
||||
turn_id: &str,
|
||||
agent_id: ThreadId,
|
||||
) -> bool {
|
||||
let active = session.active_turn.lock().await;
|
||||
if let Some(active_turn) = active.as_ref() {
|
||||
let state = active_turn.turn_state.lock().await;
|
||||
return state.is_waiting_on(turn_id, agent_id);
|
||||
let turn_state = {
|
||||
let active = session.active_turn.lock().await;
|
||||
active
|
||||
.as_ref()
|
||||
.map(|active_turn| Arc::clone(&active_turn.turn_state))
|
||||
};
|
||||
if let Some(turn_state) = turn_state {
|
||||
let state = turn_state.lock().await;
|
||||
state.is_waiting_on(turn_id, agent_id) || state.is_wait_collected(turn_id, agent_id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
@@ -9,7 +9,8 @@ pub(crate) use codex_protocol::protocol::AgentStatus;
|
||||
pub(crate) use collab_completion_warning::spawn_collab_completion_warning_watcher;
|
||||
pub(crate) use collab_wait_tracking::begin_collab_wait;
|
||||
pub(crate) use collab_wait_tracking::end_collab_wait;
|
||||
pub(crate) use collab_wait_tracking::is_collab_wait_active;
|
||||
pub(crate) use collab_wait_tracking::is_collab_wait_suppressed;
|
||||
pub(crate) use collab_wait_tracking::mark_collab_wait_collected;
|
||||
pub(crate) use control::AgentControl;
|
||||
pub(crate) use guards::MAX_THREAD_SPAWN_DEPTH;
|
||||
pub(crate) use guards::exceeds_thread_spawn_depth_limit;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
@@ -75,6 +76,7 @@ pub(crate) struct TurnState {
|
||||
pending_dynamic_tools: HashMap<String, oneshot::Sender<DynamicToolResponse>>,
|
||||
pending_input: Vec<ResponseInputItem>,
|
||||
active_waits: HashMap<String, HashMap<ThreadId, usize>>,
|
||||
collected_waits: HashMap<String, HashSet<ThreadId>>,
|
||||
}
|
||||
|
||||
impl TurnState {
|
||||
@@ -99,6 +101,7 @@ impl TurnState {
|
||||
self.pending_dynamic_tools.clear();
|
||||
self.pending_input.clear();
|
||||
self.active_waits.clear();
|
||||
self.collected_waits.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn insert_pending_user_input(
|
||||
@@ -185,6 +188,20 @@ impl TurnState {
|
||||
.get(turn_id)
|
||||
.is_some_and(|waits| waits.contains_key(&agent_id))
|
||||
}
|
||||
|
||||
pub(crate) fn mark_wait_collected(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
if agent_ids.is_empty() {
|
||||
return;
|
||||
}
|
||||
let collected = self.collected_waits.entry(turn_id.to_string()).or_default();
|
||||
collected.extend(agent_ids.iter().copied());
|
||||
}
|
||||
|
||||
pub(crate) fn is_wait_collected(&self, turn_id: &str, agent_id: ThreadId) -> bool {
|
||||
self.collected_waits
|
||||
.get(turn_id)
|
||||
.is_some_and(|collected| collected.contains(&agent_id))
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
@@ -205,7 +222,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wait_tracking_is_turn_scoped_and_reference_counted() {
|
||||
fn wait_tracking_is_turn_scoped_and_reference_counted_and_collected() {
|
||||
let mut state = TurnState::default();
|
||||
let turn_a = "turn-a";
|
||||
let turn_b = "turn-b";
|
||||
@@ -224,5 +241,13 @@ mod tests {
|
||||
state.end_wait(turn_a, &[agent]);
|
||||
assert_eq!(state.is_waiting_on(turn_a, agent), false);
|
||||
assert_eq!(state.is_waiting_on(turn_b, agent), true);
|
||||
|
||||
state.mark_wait_collected(turn_a, &[agent]);
|
||||
assert_eq!(state.is_wait_collected(turn_a, agent), true);
|
||||
assert_eq!(state.is_wait_collected(turn_b, agent), false);
|
||||
|
||||
state.clear_pending();
|
||||
assert_eq!(state.is_waiting_on(turn_b, agent), false);
|
||||
assert_eq!(state.is_wait_collected(turn_a, agent), false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,11 +414,19 @@ mod wait {
|
||||
results
|
||||
};
|
||||
|
||||
let collected_ids = statuses.iter().map(|(id, _)| *id).collect::<Vec<_>>();
|
||||
crate::agent::mark_collab_wait_collected(
|
||||
session.as_ref(),
|
||||
&turn.sub_id,
|
||||
&collected_ids,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Convert payload.
|
||||
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
|
||||
let statuses_map = statuses.into_iter().collect::<HashMap<_, _>>();
|
||||
let result = WaitResult {
|
||||
status: statuses_map.clone(),
|
||||
timed_out: statuses.is_empty(),
|
||||
timed_out: collected_ids.is_empty(),
|
||||
};
|
||||
|
||||
// Final event emission.
|
||||
@@ -997,6 +1005,60 @@ mod tests {
|
||||
assert_eq!(success, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_marks_missing_agents_as_collected_and_suppressed() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
let manager = thread_manager();
|
||||
session.services.agent_control = manager.agent_control();
|
||||
let id = ThreadId::new();
|
||||
let session = Arc::new(session);
|
||||
let turn = Arc::new(turn);
|
||||
let turn_state = {
|
||||
let mut active = session.active_turn.lock().await;
|
||||
let active_turn = active.get_or_insert_with(crate::state::ActiveTurn::default);
|
||||
Arc::clone(&active_turn.turn_state)
|
||||
};
|
||||
let invocation = invocation(
|
||||
session.clone(),
|
||||
turn.clone(),
|
||||
"wait",
|
||||
function_payload(json!({
|
||||
"ids": [id.to_string()],
|
||||
"timeout_ms": 1000
|
||||
})),
|
||||
);
|
||||
let output = CollabHandler
|
||||
.handle(invocation)
|
||||
.await
|
||||
.expect("wait should succeed");
|
||||
let ToolOutput::Function {
|
||||
content, success, ..
|
||||
} = output
|
||||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
let result: WaitResult =
|
||||
serde_json::from_str(&content).expect("wait result should be json");
|
||||
assert_eq!(
|
||||
result,
|
||||
WaitResult {
|
||||
status: HashMap::from([(id, AgentStatus::NotFound)]),
|
||||
timed_out: false
|
||||
}
|
||||
);
|
||||
assert_eq!(success, None);
|
||||
|
||||
{
|
||||
let state = turn_state.lock().await;
|
||||
assert_eq!(state.is_waiting_on(&turn.sub_id, id), false);
|
||||
assert_eq!(state.is_wait_collected(&turn.sub_id, id), true);
|
||||
}
|
||||
|
||||
let suppressed =
|
||||
crate::agent::is_collab_wait_suppressed(session.as_ref(), &turn.sub_id, id).await;
|
||||
assert_eq!(suppressed, true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_times_out_when_status_is_not_final() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
|
||||
Reference in New Issue
Block a user