agent control design

This commit is contained in:
jif-oai
2026-01-26 22:07:02 +00:00
parent cf0c52504e
commit 4acffebebc
8 changed files with 244 additions and 358 deletions

View File

@@ -1,151 +0,0 @@
use std::sync::Arc;
use codex_protocol::ThreadId;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use serde_json::json;
use uuid::Uuid;
use crate::agent::AgentStatus;
use crate::agent::status::is_final;
use crate::codex::Session;
use crate::codex::TurnContext;
/// Subscribe to a spawned sub-agent and warn the model once it reaches a final status.
pub(crate) fn spawn_collab_completion_warning_watcher(
session: Arc<Session>,
turn_context: Arc<TurnContext>,
agent_id: ThreadId,
) {
tokio::spawn(async move {
if let Some(status) = wait_for_final_status(session.as_ref(), agent_id).await
&& !crate::agent::is_collab_wait_suppressed(
session.as_ref(),
&turn_context.sub_id,
agent_id,
)
.await
{
let items = synthetic_wait_items(agent_id, status);
session
.record_conversation_items(&turn_context, &items)
.await;
}
});
}
async fn wait_for_final_status(session: &Session, agent_id: ThreadId) -> Option<AgentStatus> {
let mut status_rx = match session
.services
.agent_control
.subscribe_status(agent_id)
.await
{
Ok(rx) => rx,
Err(_) => {
let status = session.services.agent_control.get_status(agent_id).await;
return is_final(&status).then_some(status);
}
};
let mut status = status_rx.borrow().clone();
if is_final(&status) {
return Some(status);
}
loop {
if status_rx.changed().await.is_err() {
let latest = session.services.agent_control.get_status(agent_id).await;
return is_final(&latest).then_some(latest);
}
status = status_rx.borrow().clone();
if is_final(&status) {
return Some(status);
}
}
}
fn synthetic_wait_items(agent_id: ThreadId, status: AgentStatus) -> Vec<ResponseItem> {
tracing::info!("synthetic_wait_items: agent_id: {}, status: {:?}", agent_id, status);
let call_id = format!("synthetic-wait-{}", Uuid::new_v4());
let agent_id_str = agent_id.to_string();
let arguments = json!({
"ids": [agent_id_str.clone()],
"timeout_ms": 300_000,
})
.to_string();
let output = json!({
"status": { agent_id_str: status },
"timed_out": false,
})
.to_string();
let call = ResponseItem::FunctionCall {
id: None,
name: "wait".to_string(),
arguments,
call_id: call_id.clone(),
};
let output = ResponseItem::FunctionCallOutput {
call_id,
output: FunctionCallOutputPayload {
content: output,
..Default::default()
},
};
vec![call, output]
}
#[cfg(test)]
mod tests {
use super::synthetic_wait_items;
use crate::agent::AgentStatus;
use codex_protocol::ThreadId;
use codex_protocol::models::ResponseItem;
use pretty_assertions::assert_eq;
use serde_json::Value;
#[test]
fn synthetic_wait_items_look_like_a_real_wait_result() {
let agent_id =
ThreadId::from_string("00000000-0000-7000-0000-000000000001").expect("valid id");
let status = AgentStatus::Completed(Some("done".to_string()));
let items = synthetic_wait_items(agent_id, status.clone());
assert_eq!(items.len(), 2);
let (call_id, arguments_json) = match &items[0] {
ResponseItem::FunctionCall {
name,
call_id,
arguments,
..
} => {
assert_eq!(name, "wait");
(call_id.clone(), arguments.clone())
}
other => panic!("expected function call, got {other:?}"),
};
let args: Value = serde_json::from_str(&arguments_json).expect("arguments should be json");
let agent_id_string = agent_id.to_string();
assert_eq!(args["ids"][0].as_str(), Some(agent_id_string.as_str()));
match &items[1] {
ResponseItem::FunctionCallOutput {
call_id: out_id,
output,
} => {
assert_eq!(out_id, &call_id);
let out: Value =
serde_json::from_str(&output.content).expect("output should be json");
assert_eq!(out["timed_out"].as_bool(), Some(false));
let expected_status =
serde_json::to_value(status).expect("status should serialize");
assert_eq!(out["status"][agent_id_string.as_str()], expected_status);
}
other => panic!("expected function call output, got {other:?}"),
}
}
}

View File

@@ -1,67 +0,0 @@
use std::sync::Arc;
use codex_protocol::ThreadId;
use crate::codex::Session;
pub(crate) async fn begin_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
let turn_state = {
let active = session.active_turn.lock().await;
active
.as_ref()
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
if let Some(turn_state) = turn_state {
let mut state = turn_state.lock().await;
state.begin_wait(turn_id, agent_ids);
}
}
pub(crate) async fn end_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
let turn_state = {
let active = session.active_turn.lock().await;
active
.as_ref()
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
if let Some(turn_state) = turn_state {
let mut state = turn_state.lock().await;
state.end_wait(turn_id, agent_ids);
}
}
pub(crate) async fn mark_collab_wait_collected(
session: &Session,
turn_id: &str,
agent_ids: &[ThreadId],
) {
let turn_state = {
let active = session.active_turn.lock().await;
active
.as_ref()
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
if let Some(turn_state) = turn_state {
let mut state = turn_state.lock().await;
state.mark_wait_collected(turn_id, agent_ids);
}
}
pub(crate) async fn is_collab_wait_suppressed(
session: &Session,
turn_id: &str,
agent_id: ThreadId,
) -> bool {
let turn_state = {
let active = session.active_turn.lock().await;
active
.as_ref()
.map(|active_turn| Arc::clone(&active_turn.turn_state))
};
if let Some(turn_state) = turn_state {
let state = turn_state.lock().await;
state.is_waiting_on(turn_id, agent_id) || state.is_wait_collected(turn_id, agent_id)
} else {
false
}
}

View File

@@ -1,14 +1,86 @@
use crate::agent::AgentStatus;
use crate::agent::guards::Guards;
use crate::agent::status::is_final;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::thread_manager::ThreadManagerState;
use codex_protocol::ThreadId;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::Op;
use codex_protocol::user_input::UserInput;
use serde_json::json;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::Weak;
use tokio::sync::watch;
use uuid::Uuid;
const SYNTHETIC_WAIT_TIMEOUT_MS: i64 = 300_000;
#[derive(Default)]
struct TurnCollabLedger {
spawned: HashSet<ThreadId>,
acknowledged: HashSet<ThreadId>,
in_flight_waits: HashMap<ThreadId, usize>,
}
impl TurnCollabLedger {
fn is_idle(&self) -> bool {
self.in_flight_waits.is_empty()
&& self
.spawned
.iter()
.all(|agent_id| self.acknowledged.contains(agent_id))
}
}
#[derive(Default)]
struct CollabLedger {
turns: HashMap<String, TurnCollabLedger>,
}
/// RAII guard that marks a set of agents as being explicitly waited on.
pub(crate) struct WaitGuard {
collab: Arc<Mutex<CollabLedger>>,
turn_id: String,
agent_ids: Vec<ThreadId>,
}
impl WaitGuard {
fn new(collab: Arc<Mutex<CollabLedger>>, turn_id: &str, agent_ids: &[ThreadId]) -> Self {
Self {
collab,
turn_id: turn_id.to_string(),
agent_ids: agent_ids.to_vec(),
}
}
}
impl Drop for WaitGuard {
fn drop(&mut self) {
let mut collab = self
.collab
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if let Some(turn) = collab.turns.get_mut(&self.turn_id) {
for agent_id in &self.agent_ids {
if let Some(count) = turn.in_flight_waits.get_mut(agent_id) {
*count = count.saturating_sub(1);
if *count == 0 {
turn.in_flight_waits.remove(agent_id);
}
}
}
if turn.is_idle() {
collab.turns.remove(&self.turn_id);
}
}
}
}
/// Control-plane handle for multi-agent operations.
/// `AgentControl` is held by each session (via `SessionServices`). It provides capability to
@@ -23,6 +95,7 @@ pub(crate) struct AgentControl {
/// `ThreadManagerState -> CodexThread -> Session -> SessionServices -> ThreadManagerState`.
manager: Weak<ThreadManagerState>,
state: Arc<Guards>,
collab: Arc<Mutex<CollabLedger>>,
}
impl AgentControl {
@@ -129,11 +202,148 @@ impl AgentControl {
Ok(thread.subscribe_status())
}
pub(crate) fn register_spawn(&self, turn_id: &str, agent_id: ThreadId) {
let mut collab = self.lock_collab();
let turn = collab.turns.entry(turn_id.to_string()).or_default();
turn.spawned.insert(agent_id);
turn.acknowledged.remove(&agent_id);
}
pub(crate) fn wait_guard(&self, turn_id: &str, agent_ids: &[ThreadId]) -> WaitGuard {
let mut collab = self.lock_collab();
let turn = collab.turns.entry(turn_id.to_string()).or_default();
for agent_id in agent_ids {
*turn.in_flight_waits.entry(*agent_id).or_default() += 1;
}
WaitGuard::new(Arc::clone(&self.collab), turn_id, agent_ids)
}
pub(crate) fn acknowledge(&self, turn_id: &str, agent_ids: &[ThreadId]) {
let mut collab = self.lock_collab();
if let Some(turn) = collab.turns.get_mut(turn_id) {
turn.acknowledged.extend(agent_ids.iter().copied());
if turn.is_idle() {
collab.turns.remove(turn_id);
}
}
}
pub(crate) fn clear_turn(&self, turn_id: &str) {
let mut collab = self.lock_collab();
collab.turns.remove(turn_id);
}
pub(crate) async fn collect_unacknowledged_finals(
&self,
turn_id: &str,
) -> Vec<(ThreadId, AgentStatus)> {
let mut candidates = {
let collab = self.lock_collab();
collab
.turns
.get(turn_id)
.map(|turn| {
turn.spawned
.iter()
.copied()
.filter(|agent_id| {
!turn.acknowledged.contains(agent_id)
&& !turn.in_flight_waits.contains_key(agent_id)
})
.collect::<Vec<_>>()
})
.unwrap_or_default()
};
if candidates.is_empty() {
return Vec::new();
}
candidates.sort_by_key(std::string::ToString::to_string);
let mut finals = Vec::new();
for agent_id in candidates {
let status = self.get_status(agent_id).await;
if is_final(&status) {
finals.push((agent_id, status));
}
}
if finals.is_empty() {
return Vec::new();
}
let mut ready = Vec::new();
let mut collab = self.lock_collab();
if let Some(turn) = collab.turns.get_mut(turn_id) {
for (agent_id, status) in finals {
if turn.acknowledged.contains(&agent_id)
|| turn.in_flight_waits.contains_key(&agent_id)
{
continue;
}
turn.acknowledged.insert(agent_id);
ready.push((agent_id, status));
}
if turn.is_idle() {
collab.turns.remove(turn_id);
}
}
ready
}
pub(crate) fn synthetic_wait_items(statuses: &[(ThreadId, AgentStatus)]) -> Vec<ResponseItem> {
if statuses.is_empty() {
return Vec::new();
}
let call_id = format!("synthetic-wait-{}", Uuid::new_v4());
let mut ids = statuses
.iter()
.map(|(agent_id, _)| agent_id.to_string())
.collect::<Vec<_>>();
ids.sort();
let arguments = json!({
"ids": ids,
"timeout_ms": SYNTHETIC_WAIT_TIMEOUT_MS,
})
.to_string();
let status_map = statuses
.iter()
.map(|(agent_id, status)| (agent_id.to_string(), status.clone()))
.collect::<BTreeMap<_, _>>();
let output = json!({
"status": status_map,
"timed_out": false,
})
.to_string();
let call = ResponseItem::FunctionCall {
id: None,
name: "wait".to_string(),
arguments,
call_id: call_id.clone(),
};
let output = ResponseItem::FunctionCallOutput {
call_id,
output: FunctionCallOutputPayload {
content: output,
..Default::default()
},
};
vec![call, output]
}
fn upgrade(&self) -> CodexResult<Arc<ThreadManagerState>> {
self.manager
.upgrade()
.ok_or_else(|| CodexErr::UnsupportedOperation("thread manager dropped".to_string()))
}
fn lock_collab(&self) -> std::sync::MutexGuard<'_, CollabLedger> {
self.collab
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
}
#[cfg(test)]

View File

@@ -1,16 +1,9 @@
mod collab_completion_warning;
mod collab_wait_tracking;
pub(crate) mod control;
mod guards;
pub(crate) mod role;
pub(crate) mod status;
pub(crate) use codex_protocol::protocol::AgentStatus;
pub(crate) use collab_completion_warning::spawn_collab_completion_warning_watcher;
pub(crate) use collab_wait_tracking::begin_collab_wait;
pub(crate) use collab_wait_tracking::end_collab_wait;
pub(crate) use collab_wait_tracking::is_collab_wait_suppressed;
pub(crate) use collab_wait_tracking::mark_collab_wait_collected;
pub(crate) use control::AgentControl;
pub(crate) use guards::MAX_THREAD_SPAWN_DEPTH;
pub(crate) use guards::exceeds_thread_spawn_depth_limit;

View File

@@ -3006,11 +3006,20 @@ pub(crate) async fn run_turn(
// Note that pending_input would be something like a message the user
// submitted through the UI while the model was running. Though the UI
// may support this, the model might not.
let pending_input = sess
.get_pending_input()
.await
let collab_ready = sess
.services
.agent_control
.collect_unacknowledged_finals(&turn_context.sub_id)
.await;
let collab_ready_items = AgentControl::synthetic_wait_items(&collab_ready);
let pending_input = collab_ready_items
.into_iter()
.map(ResponseItem::from)
.chain(
sess.get_pending_input()
.await
.into_iter()
.map(ResponseItem::from),
)
.collect::<Vec<ResponseItem>>();
// Construct the input that we will send to the model.

View File

@@ -2,7 +2,6 @@
use indexmap::IndexMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::Notify;
@@ -17,7 +16,6 @@ use tokio::sync::oneshot;
use crate::codex::TurnContext;
use crate::protocol::ReviewDecision;
use crate::tasks::SessionTask;
use codex_protocol::ThreadId;
/// Metadata about the currently running turn.
pub(crate) struct ActiveTurn {
@@ -75,8 +73,6 @@ pub(crate) struct TurnState {
pending_user_input: HashMap<String, oneshot::Sender<RequestUserInputResponse>>,
pending_dynamic_tools: HashMap<String, oneshot::Sender<DynamicToolResponse>>,
pending_input: Vec<ResponseInputItem>,
active_waits: HashMap<String, HashMap<ThreadId, usize>>,
collected_waits: HashMap<String, HashSet<ThreadId>>,
}
impl TurnState {
@@ -100,8 +96,6 @@ impl TurnState {
self.pending_user_input.clear();
self.pending_dynamic_tools.clear();
self.pending_input.clear();
self.active_waits.clear();
self.collected_waits.clear();
}
pub(crate) fn insert_pending_user_input(
@@ -151,57 +145,6 @@ impl TurnState {
pub(crate) fn has_pending_input(&self) -> bool {
!self.pending_input.is_empty()
}
pub(crate) fn begin_wait(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
if agent_ids.is_empty() {
return;
}
let waits = self.active_waits.entry(turn_id.to_string()).or_default();
for agent_id in agent_ids {
*waits.entry(*agent_id).or_default() += 1;
}
}
pub(crate) fn end_wait(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
if agent_ids.is_empty() {
return;
}
let mut remove_turn = false;
if let Some(waits) = self.active_waits.get_mut(turn_id) {
for agent_id in agent_ids {
if let Some(count) = waits.get_mut(agent_id) {
*count = count.saturating_sub(1);
if *count == 0 {
waits.remove(agent_id);
}
}
}
remove_turn = waits.is_empty();
}
if remove_turn {
self.active_waits.remove(turn_id);
}
}
pub(crate) fn is_waiting_on(&self, turn_id: &str, agent_id: ThreadId) -> bool {
self.active_waits
.get(turn_id)
.is_some_and(|waits| waits.contains_key(&agent_id))
}
pub(crate) fn mark_wait_collected(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
if agent_ids.is_empty() {
return;
}
let collected = self.collected_waits.entry(turn_id.to_string()).or_default();
collected.extend(agent_ids.iter().copied());
}
pub(crate) fn is_wait_collected(&self, turn_id: &str, agent_id: ThreadId) -> bool {
self.collected_waits
.get(turn_id)
.is_some_and(|collected| collected.contains(&agent_id))
}
}
impl ActiveTurn {
@@ -211,43 +154,3 @@ impl ActiveTurn {
ts.clear_pending();
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
fn thread_id(s: &str) -> ThreadId {
ThreadId::from_string(s).expect("valid thread id")
}
#[test]
fn wait_tracking_is_turn_scoped_and_reference_counted_and_collected() {
let mut state = TurnState::default();
let turn_a = "turn-a";
let turn_b = "turn-b";
let agent = thread_id("00000000-0000-7000-0000-000000000001");
state.begin_wait(turn_a, &[agent]);
state.begin_wait(turn_a, &[agent]);
state.begin_wait(turn_b, &[agent]);
assert_eq!(state.is_waiting_on(turn_a, agent), true);
assert_eq!(state.is_waiting_on(turn_b, agent), true);
state.end_wait(turn_a, &[agent]);
assert_eq!(state.is_waiting_on(turn_a, agent), true);
state.end_wait(turn_a, &[agent]);
assert_eq!(state.is_waiting_on(turn_a, agent), false);
assert_eq!(state.is_waiting_on(turn_b, agent), true);
state.mark_wait_collected(turn_a, &[agent]);
assert_eq!(state.is_wait_collected(turn_a, agent), true);
assert_eq!(state.is_wait_collected(turn_b, agent), false);
state.clear_pending();
assert_eq!(state.is_waiting_on(turn_b, agent), false);
assert_eq!(state.is_wait_collected(turn_a, agent), false);
}
}

View File

@@ -189,6 +189,7 @@ impl Session {
false
};
drop(active);
self.services.agent_control.clear_turn(&turn_context.sub_id);
if should_close_processes {
self.close_unified_exec_processes().await;
}
@@ -207,7 +208,11 @@ impl Session {
let mut active = self.active_turn.lock().await;
match active.take() {
Some(mut at) => {
let turn_ids = at.tasks.keys().cloned().collect::<Vec<_>>();
at.clear_pending().await;
for turn_id in turn_ids {
self.services.agent_control.clear_turn(&turn_id);
}
at.drain_tasks()
}

View File

@@ -175,11 +175,10 @@ mod spawn {
)
.await;
let new_thread_id = result?;
crate::agent::spawn_collab_completion_warning_watcher(
session.clone(),
turn.clone(),
new_thread_id,
);
session
.services
.agent_control
.register_spawn(&turn.sub_id, new_thread_id);
let content = serde_json::to_string(&SpawnAgentResult {
agent_id: new_thread_id.to_string(),
@@ -342,8 +341,12 @@ mod wait {
ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS),
};
crate::agent::begin_collab_wait(session.as_ref(), &turn.sub_id, &receiver_thread_ids).await;
let result = async {
let _wait_guard = session
.services
.agent_control
.wait_guard(&turn.sub_id, &receiver_thread_ids);
async {
session
.send_event(
&turn,
@@ -424,12 +427,10 @@ mod wait {
};
let collected_ids = statuses.iter().map(|(id, _)| *id).collect::<Vec<_>>();
crate::agent::mark_collab_wait_collected(
session.as_ref(),
&turn.sub_id,
&collected_ids,
)
.await;
session
.services
.agent_control
.acknowledge(&turn.sub_id, &collected_ids);
// Convert payload.
let statuses_map = statuses.into_iter().collect::<HashMap<_, _>>();
@@ -461,9 +462,7 @@ mod wait {
content_items: None,
})
}
.await;
crate::agent::end_collab_wait(session.as_ref(), &turn.sub_id, &receiver_thread_ids).await;
result
.await
}
async fn wait_for_final_status(
@@ -1022,18 +1021,13 @@ mod tests {
}
#[tokio::test]
async fn wait_marks_missing_agents_as_collected_and_suppressed() {
async fn wait_handles_missing_agents() {
let (mut session, turn) = make_session_and_context().await;
let manager = thread_manager();
session.services.agent_control = manager.agent_control();
let id = ThreadId::new();
let session = Arc::new(session);
let turn = Arc::new(turn);
let turn_state = {
let mut active = session.active_turn.lock().await;
let active_turn = active.get_or_insert_with(crate::state::ActiveTurn::default);
Arc::clone(&active_turn.turn_state)
};
let invocation = invocation(
session.clone(),
turn.clone(),
@@ -1063,16 +1057,6 @@ mod tests {
}
);
assert_eq!(success, None);
{
let state = turn_state.lock().await;
assert_eq!(state.is_waiting_on(&turn.sub_id, id), false);
assert_eq!(state.is_wait_collected(&turn.sub_id, id), true);
}
let suppressed =
crate::agent::is_collab_wait_suppressed(session.as_ref(), &turn.sub_id, id).await;
assert_eq!(suppressed, true);
}
#[tokio::test]