Compare commits

...

7 Commits

Author SHA1 Message Date
jif-oai
4acffebebc agent control design 2026-01-26 22:07:02 +00:00
jif-oai
cf0c52504e switch to tool call 2026-01-26 19:21:52 +00:00
jif-oai
49ee35b86c better message 2026-01-26 18:47:51 +00:00
jif-oai
4dd80700cb Merge remote-tracking branch 'origin/main' into jif/warning-wait-done 2026-01-26 18:28:37 +00:00
jif-oai
ae13e569c6 make it better 2026-01-26 18:19:22 +00:00
jif-oai
ebddb1071d feat: emit warning when sub agent is done 2026-01-26 18:07:28 +00:00
jif-oai
ec2551ba36 feat: emit warning when sub agent is done 2026-01-26 16:55:47 +00:00
4 changed files with 381 additions and 100 deletions

View File

@@ -1,14 +1,86 @@
use crate::agent::AgentStatus;
use crate::agent::guards::Guards;
use crate::agent::status::is_final;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::thread_manager::ThreadManagerState;
use codex_protocol::ThreadId;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::Op;
use codex_protocol::user_input::UserInput;
use serde_json::json;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::Weak;
use tokio::sync::watch;
use uuid::Uuid;
const SYNTHETIC_WAIT_TIMEOUT_MS: i64 = 300_000;
/// Bookkeeping for the sub-agents touched while a single turn is running.
#[derive(Default)]
struct TurnCollabLedger {
    /// Agents registered as spawned during this turn.
    spawned: HashSet<ThreadId>,
    /// Agents that have already been acknowledged for this turn.
    acknowledged: HashSet<ThreadId>,
    /// Number of wait guards currently alive for each agent.
    in_flight_waits: HashMap<ThreadId, usize>,
}

impl TurnCollabLedger {
    /// Returns `true` when no wait is in flight and every spawned agent has been
    /// acknowledged, i.e. nothing is left to track for this turn.
    fn is_idle(&self) -> bool {
        let no_waits_pending = self.in_flight_waits.is_empty();
        no_waits_pending && self.spawned.is_subset(&self.acknowledged)
    }
}
/// Collection of per-turn collaboration ledgers, keyed by turn id.
#[derive(Default)]
struct CollabLedger {
// One entry per turn that still has collaboration state being tracked.
turns: HashMap<String, TurnCollabLedger>,
}
/// RAII guard that marks a set of agents as being explicitly waited on.
///
/// Dropping the guard decrements the in-flight wait count of every agent it covers and
/// removes the turn's ledger entry once that turn has become idle. The guard therefore has
/// to be bound to a local for the whole duration of the wait; discarding it right away would
/// end the marker immediately, which is why the type is `#[must_use]`.
#[must_use = "dropping the guard immediately ends the wait it is supposed to mark"]
pub(crate) struct WaitGuard {
    /// Shared ledger that is updated when the guard is dropped.
    collab: Arc<Mutex<CollabLedger>>,
    /// Turn the wait belongs to.
    turn_id: String,
    /// Agents covered by this wait.
    agent_ids: Vec<ThreadId>,
}
impl WaitGuard {
    /// Builds a guard for `agent_ids` within `turn_id`.
    ///
    /// This only stores owned copies of its arguments; it does not modify the ledger.
    fn new(collab: Arc<Mutex<CollabLedger>>, turn_id: &str, agent_ids: &[ThreadId]) -> Self {
        let turn_id = String::from(turn_id);
        let agent_ids = Vec::from(agent_ids);
        Self {
            collab,
            turn_id,
            agent_ids,
        }
    }
}
impl Drop for WaitGuard {
    /// Releases the wait marker held by this guard.
    ///
    /// Each covered agent's in-flight count is decremented (never below zero) and its entry is
    /// removed once the count reaches zero. If the turn is idle afterwards, the whole turn
    /// entry is removed from the ledger.
    fn drop(&mut self) {
        // A poisoned mutex is recovered rather than propagated: `drop` must not panic.
        let mut ledger = match self.collab.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };

        // Nothing to release when the turn is no longer tracked.
        let Some(turn) = ledger.turns.get_mut(&self.turn_id) else {
            return;
        };

        for agent_id in &self.agent_ids {
            let Some(count) = turn.in_flight_waits.get_mut(agent_id) else {
                continue;
            };
            *count = count.saturating_sub(1);
            let exhausted = *count == 0;
            if exhausted {
                turn.in_flight_waits.remove(agent_id);
            }
        }

        let turn_finished = turn.is_idle();
        if turn_finished {
            ledger.turns.remove(&self.turn_id);
        }
    }
}
/// Control-plane handle for multi-agent operations.
/// `AgentControl` is held by each session (via `SessionServices`). It provides capability to
@@ -23,6 +95,7 @@ pub(crate) struct AgentControl {
/// `ThreadManagerState -> CodexThread -> Session -> SessionServices -> ThreadManagerState`.
manager: Weak<ThreadManagerState>,
state: Arc<Guards>,
collab: Arc<Mutex<CollabLedger>>,
}
impl AgentControl {
@@ -129,11 +202,148 @@ impl AgentControl {
Ok(thread.subscribe_status())
}
/// Records that `agent_id` was spawned during `turn_id`.
///
/// The turn entry is created on demand. Any earlier acknowledgement of the same agent in
/// this turn is discarded.
pub(crate) fn register_spawn(&self, turn_id: &str, agent_id: ThreadId) {
    let mut ledger = self.lock_collab();
    let entry = ledger.turns.entry(turn_id.to_owned()).or_default();
    entry.acknowledged.remove(&agent_id);
    entry.spawned.insert(agent_id);
}
/// Marks `agent_ids` as being explicitly waited on within `turn_id`.
///
/// Increments the in-flight wait count of every listed agent (creating the turn entry when
/// needed) and returns the guard that undoes those increments when dropped.
pub(crate) fn wait_guard(&self, turn_id: &str, agent_ids: &[ThreadId]) -> WaitGuard {
    let mut ledger = self.lock_collab();
    let turn = ledger.turns.entry(turn_id.to_owned()).or_default();
    for &agent_id in agent_ids {
        let count = turn.in_flight_waits.entry(agent_id).or_insert(0);
        *count += 1;
    }
    WaitGuard::new(Arc::clone(&self.collab), turn_id, agent_ids)
}
/// Marks `agent_ids` as acknowledged for `turn_id`.
///
/// Does nothing when the turn is not tracked. The turn entry is removed once it becomes idle.
pub(crate) fn acknowledge(&self, turn_id: &str, agent_ids: &[ThreadId]) {
    let mut ledger = self.lock_collab();
    let Some(turn) = ledger.turns.get_mut(turn_id) else {
        return;
    };
    for &agent_id in agent_ids {
        turn.acknowledged.insert(agent_id);
    }
    if turn.is_idle() {
        ledger.turns.remove(turn_id);
    }
}
/// Forgets all collaboration state recorded for `turn_id`, if any.
pub(crate) fn clear_turn(&self, turn_id: &str) {
    // The lock guard is a temporary and is released at the end of this statement.
    self.lock_collab().turns.remove(turn_id);
}
/// Returns the agents spawned in `turn_id` that have reached a final status without having
/// been acknowledged or explicitly waited on, and marks every returned agent as acknowledged.
///
/// The work is split into three phases so that the ledger mutex is never held across an
/// `.await`:
/// 1. snapshot the candidate agent ids under the lock;
/// 2. query each candidate's status without the lock;
/// 3. re-acquire the lock and re-check each final agent, since it may have been acknowledged
///    or started being waited on in the meantime.
///
/// Returns an empty vector when the turn is not tracked or nothing is ready. Results are
/// ordered by the string form of the agent id.
pub(crate) async fn collect_unacknowledged_finals(
    &self,
    turn_id: &str,
) -> Vec<(ThreadId, AgentStatus)> {
    // Phase 1: snapshot candidates; the guard is dropped at the end of this block.
    let mut candidates = {
        let collab = self.lock_collab();
        collab
            .turns
            .get(turn_id)
            .map(|turn| {
                turn.spawned
                    .iter()
                    .copied()
                    .filter(|agent_id| {
                        !turn.acknowledged.contains(agent_id)
                            && !turn.in_flight_waits.contains_key(agent_id)
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default()
    };
    if candidates.is_empty() {
        return Vec::new();
    }
    // Deterministic order. The cached variant renders each id to a string exactly once
    // instead of once per comparison; like `sort_by_key`, it is a stable sort.
    candidates.sort_by_cached_key(std::string::ToString::to_string);

    // Phase 2: look up statuses without holding the ledger lock.
    let mut finals = Vec::new();
    for agent_id in candidates {
        let status = self.get_status(agent_id).await;
        if is_final(&status) {
            finals.push((agent_id, status));
        }
    }
    if finals.is_empty() {
        return Vec::new();
    }

    // Phase 3: re-validate under the lock and acknowledge what is reported.
    let mut ready = Vec::new();
    let mut collab = self.lock_collab();
    if let Some(turn) = collab.turns.get_mut(turn_id) {
        for (agent_id, status) in finals {
            if turn.acknowledged.contains(&agent_id)
                || turn.in_flight_waits.contains_key(&agent_id)
            {
                continue;
            }
            turn.acknowledged.insert(agent_id);
            ready.push((agent_id, status));
        }
        if turn.is_idle() {
            collab.turns.remove(turn_id);
        }
    }
    ready
}
/// Builds a fabricated `wait` function call together with its output for `statuses`.
///
/// Returns an empty vector for empty input. Otherwise returns exactly two items sharing a
/// freshly generated `synthetic-wait-<uuid>` call id: the call (arguments list the agent ids
/// in sorted order plus the synthetic timeout) followed by its output (the status of each
/// agent keyed by id, with `timed_out` set to `false`).
pub(crate) fn synthetic_wait_items(statuses: &[(ThreadId, AgentStatus)]) -> Vec<ResponseItem> {
    if statuses.is_empty() {
        return Vec::new();
    }

    let mut sorted_ids: Vec<String> = statuses
        .iter()
        .map(|(agent_id, _)| agent_id.to_string())
        .collect();
    sorted_ids.sort();

    let status_by_id: BTreeMap<String, AgentStatus> = statuses
        .iter()
        .map(|(agent_id, status)| (agent_id.to_string(), status.clone()))
        .collect();

    let call_id = format!("synthetic-wait-{}", Uuid::new_v4());
    let arguments = json!({
        "ids": sorted_ids,
        "timeout_ms": SYNTHETIC_WAIT_TIMEOUT_MS,
    })
    .to_string();
    let content = json!({
        "status": status_by_id,
        "timed_out": false,
    })
    .to_string();

    vec![
        ResponseItem::FunctionCall {
            id: None,
            name: "wait".to_string(),
            arguments,
            call_id: call_id.clone(),
        },
        ResponseItem::FunctionCallOutput {
            call_id,
            output: FunctionCallOutputPayload {
                content,
                ..Default::default()
            },
        },
    ]
}
/// Upgrades the weak reference to the thread manager state.
///
/// Returns `CodexErr::UnsupportedOperation` when the state has already been dropped.
fn upgrade(&self) -> CodexResult<Arc<ThreadManagerState>> {
    let Some(manager) = self.manager.upgrade() else {
        return Err(CodexErr::UnsupportedOperation(
            "thread manager dropped".to_string(),
        ));
    };
    Ok(manager)
}
/// Locks the collaboration ledger, recovering the inner data if the mutex was poisoned.
fn lock_collab(&self) -> std::sync::MutexGuard<'_, CollabLedger> {
    match self.collab.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
}
#[cfg(test)]

View File

@@ -3006,11 +3006,20 @@ pub(crate) async fn run_turn(
// Note that pending_input would be something like a message the user
// submitted through the UI while the model was running. Though the UI
// may support this, the model might not.
let pending_input = sess
.get_pending_input()
.await
let collab_ready = sess
.services
.agent_control
.collect_unacknowledged_finals(&turn_context.sub_id)
.await;
let collab_ready_items = AgentControl::synthetic_wait_items(&collab_ready);
let pending_input = collab_ready_items
.into_iter()
.map(ResponseItem::from)
.chain(
sess.get_pending_input()
.await
.into_iter()
.map(ResponseItem::from),
)
.collect::<Vec<ResponseItem>>();
// Construct the input that we will send to the model.

View File

@@ -189,6 +189,7 @@ impl Session {
false
};
drop(active);
self.services.agent_control.clear_turn(&turn_context.sub_id);
if should_close_processes {
self.close_unified_exec_processes().await;
}
@@ -207,7 +208,11 @@ impl Session {
let mut active = self.active_turn.lock().await;
match active.take() {
Some(mut at) => {
let turn_ids = at.tasks.keys().cloned().collect::<Vec<_>>();
at.clear_pending().await;
for turn_id in turn_ids {
self.services.agent_control.clear_turn(&turn_id);
}
at.drain_tasks()
}

View File

@@ -175,6 +175,10 @@ mod spawn {
)
.await;
let new_thread_id = result?;
session
.services
.agent_control
.register_spawn(&turn.sub_id, new_thread_id);
let content = serde_json::to_string(&SpawnAgentResult {
agent_id: new_thread_id.to_string(),
@@ -337,114 +341,128 @@ mod wait {
ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS),
};
session
.send_event(
&turn,
CollabWaitingBeginEvent {
sender_thread_id: session.conversation_id,
receiver_thread_ids: receiver_thread_ids.clone(),
call_id: call_id.clone(),
}
.into(),
)
.await;
let _wait_guard = session
.services
.agent_control
.wait_guard(&turn.sub_id, &receiver_thread_ids);
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
let mut initial_final_statuses = Vec::new();
for id in &receiver_thread_ids {
match session.services.agent_control.subscribe_status(*id).await {
Ok(rx) => {
let status = rx.borrow().clone();
if is_final(&status) {
initial_final_statuses.push((*id, status));
async {
session
.send_event(
&turn,
CollabWaitingBeginEvent {
sender_thread_id: session.conversation_id,
receiver_thread_ids: receiver_thread_ids.clone(),
call_id: call_id.clone(),
}
status_rxs.push((*id, rx));
}
Err(CodexErr::ThreadNotFound(_)) => {
initial_final_statuses.push((*id, AgentStatus::NotFound));
}
Err(err) => {
let mut statuses = HashMap::with_capacity(1);
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
call_id: call_id.clone(),
statuses,
}
.into(),
)
.await;
return Err(collab_agent_error(*id, err));
}
}
}
.into(),
)
.await;
let statuses = if !initial_final_statuses.is_empty() {
initial_final_statuses
} else {
// Wait for the first agent to reach a final status.
let mut futures = FuturesUnordered::new();
for (id, rx) in status_rxs.into_iter() {
let session = session.clone();
futures.push(wait_for_final_status(session, id, rx));
}
let mut results = Vec::new();
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
loop {
match timeout_at(deadline, futures.next()).await {
Ok(Some(Some(result))) => {
results.push(result);
break;
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
let mut initial_final_statuses = Vec::new();
for id in &receiver_thread_ids {
match session.services.agent_control.subscribe_status(*id).await {
Ok(rx) => {
let status = rx.borrow().clone();
if is_final(&status) {
initial_final_statuses.push((*id, status));
}
status_rxs.push((*id, rx));
}
Err(CodexErr::ThreadNotFound(_)) => {
initial_final_statuses.push((*id, AgentStatus::NotFound));
}
Err(err) => {
let mut statuses = HashMap::with_capacity(1);
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
call_id: call_id.clone(),
statuses,
}
.into(),
)
.await;
Err(collab_agent_error(*id, err))?;
}
Ok(Some(None)) => continue,
Ok(None) | Err(_) => break,
}
}
if !results.is_empty() {
// Drain the unlikely last elements to prevent race.
let statuses = if !initial_final_statuses.is_empty() {
initial_final_statuses
} else {
// Wait for the first agent to reach a final status.
let mut futures = FuturesUnordered::new();
for (id, rx) in status_rxs.into_iter() {
let session = session.clone();
futures.push(wait_for_final_status(session, id, rx));
}
let mut results = Vec::new();
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
loop {
match futures.next().now_or_never() {
Some(Some(Some(result))) => results.push(result),
Some(Some(None)) => continue,
Some(None) | None => break,
match timeout_at(deadline, futures.next()).await {
Ok(Some(Some(result))) => {
results.push(result);
break;
}
Ok(Some(None)) => continue,
Ok(None) | Err(_) => break,
}
}
}
results
};
// Convert payload.
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
let result = WaitResult {
status: statuses_map.clone(),
timed_out: statuses.is_empty(),
};
// Final event emission.
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
call_id,
statuses: statuses_map,
if !results.is_empty() {
// Drain the unlikely last elements to prevent race.
loop {
match futures.next().now_or_never() {
Some(Some(Some(result))) => results.push(result),
Some(Some(None)) => continue,
Some(None) | None => break,
}
}
}
.into(),
)
.await;
results
};
let content = serde_json::to_string(&result).map_err(|err| {
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
})?;
let collected_ids = statuses.iter().map(|(id, _)| *id).collect::<Vec<_>>();
session
.services
.agent_control
.acknowledge(&turn.sub_id, &collected_ids);
Ok(ToolOutput::Function {
content,
success: None,
content_items: None,
})
// Convert payload.
let statuses_map = statuses.into_iter().collect::<HashMap<_, _>>();
let result = WaitResult {
status: statuses_map.clone(),
timed_out: collected_ids.is_empty(),
};
// Final event emission.
session
.send_event(
&turn,
CollabWaitingEndEvent {
sender_thread_id: session.conversation_id,
call_id,
statuses: statuses_map,
}
.into(),
)
.await;
let content = serde_json::to_string(&result).map_err(|err| {
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
})?;
Ok(ToolOutput::Function {
content,
success: None,
content_items: None,
})
}
.await
}
async fn wait_for_final_status(
@@ -1002,6 +1020,45 @@ mod tests {
assert_eq!(success, None);
}
/// Waiting on an id with no matching agent succeeds and reports `AgentStatus::NotFound`
/// for that id without timing out.
#[tokio::test]
async fn wait_handles_missing_agents() {
    let (mut session, turn) = make_session_and_context().await;
    // Kept in a named binding so it lives until the end of the test.
    // NOTE(review): the agent control appears to hold only a weak handle to the manager's
    // state, so dropping the manager early would presumably break the lookup — confirm.
    let manager = thread_manager();
    session.services.agent_control = manager.agent_control();
    let session = Arc::new(session);
    let turn = Arc::new(turn);

    // A fresh id that was never registered with the manager.
    let missing_id = ThreadId::new();
    let arguments = json!({
        "ids": [missing_id.to_string()],
        "timeout_ms": 1000
    });
    let request = invocation(
        session.clone(),
        turn.clone(),
        "wait",
        function_payload(arguments),
    );

    let output = CollabHandler
        .handle(request)
        .await
        .expect("wait should succeed");
    let (content, success) = match output {
        ToolOutput::Function {
            content, success, ..
        } => (content, success),
        _ => panic!("expected function output"),
    };

    let parsed: WaitResult =
        serde_json::from_str(&content).expect("wait result should be json");
    let expected = WaitResult {
        status: HashMap::from([(missing_id, AgentStatus::NotFound)]),
        timed_out: false,
    };
    assert_eq!(parsed, expected);
    assert_eq!(success, None);
}
#[tokio::test]
async fn wait_times_out_when_status_is_not_final() {
let (mut session, turn) = make_session_and_context().await;