mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
agent control design
This commit is contained in:
@@ -1,151 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent::AgentStatus;
|
||||
use crate::agent::status::is_final;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
|
||||
/// Subscribe to a spawned sub-agent and warn the model once it reaches a final status.
|
||||
pub(crate) fn spawn_collab_completion_warning_watcher(
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
agent_id: ThreadId,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
if let Some(status) = wait_for_final_status(session.as_ref(), agent_id).await
|
||||
&& !crate::agent::is_collab_wait_suppressed(
|
||||
session.as_ref(),
|
||||
&turn_context.sub_id,
|
||||
agent_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let items = synthetic_wait_items(agent_id, status);
|
||||
session
|
||||
.record_conversation_items(&turn_context, &items)
|
||||
.await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn wait_for_final_status(session: &Session, agent_id: ThreadId) -> Option<AgentStatus> {
|
||||
let mut status_rx = match session
|
||||
.services
|
||||
.agent_control
|
||||
.subscribe_status(agent_id)
|
||||
.await
|
||||
{
|
||||
Ok(rx) => rx,
|
||||
Err(_) => {
|
||||
let status = session.services.agent_control.get_status(agent_id).await;
|
||||
return is_final(&status).then_some(status);
|
||||
}
|
||||
};
|
||||
|
||||
let mut status = status_rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
return Some(status);
|
||||
}
|
||||
|
||||
loop {
|
||||
if status_rx.changed().await.is_err() {
|
||||
let latest = session.services.agent_control.get_status(agent_id).await;
|
||||
return is_final(&latest).then_some(latest);
|
||||
}
|
||||
status = status_rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
return Some(status);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn synthetic_wait_items(agent_id: ThreadId, status: AgentStatus) -> Vec<ResponseItem> {
|
||||
tracing::info!("synthetic_wait_items: agent_id: {}, status: {:?}", agent_id, status);
|
||||
let call_id = format!("synthetic-wait-{}", Uuid::new_v4());
|
||||
let agent_id_str = agent_id.to_string();
|
||||
let arguments = json!({
|
||||
"ids": [agent_id_str.clone()],
|
||||
"timeout_ms": 300_000,
|
||||
})
|
||||
.to_string();
|
||||
let output = json!({
|
||||
"status": { agent_id_str: status },
|
||||
"timed_out": false,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let call = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "wait".to_string(),
|
||||
arguments,
|
||||
call_id: call_id.clone(),
|
||||
};
|
||||
let output = ResponseItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: output,
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
|
||||
vec![call, output]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::synthetic_wait_items;
|
||||
use crate::agent::AgentStatus;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
|
||||
#[test]
|
||||
fn synthetic_wait_items_look_like_a_real_wait_result() {
|
||||
let agent_id =
|
||||
ThreadId::from_string("00000000-0000-7000-0000-000000000001").expect("valid id");
|
||||
let status = AgentStatus::Completed(Some("done".to_string()));
|
||||
|
||||
let items = synthetic_wait_items(agent_id, status.clone());
|
||||
assert_eq!(items.len(), 2);
|
||||
|
||||
let (call_id, arguments_json) = match &items[0] {
|
||||
ResponseItem::FunctionCall {
|
||||
name,
|
||||
call_id,
|
||||
arguments,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(name, "wait");
|
||||
(call_id.clone(), arguments.clone())
|
||||
}
|
||||
other => panic!("expected function call, got {other:?}"),
|
||||
};
|
||||
|
||||
let args: Value = serde_json::from_str(&arguments_json).expect("arguments should be json");
|
||||
let agent_id_string = agent_id.to_string();
|
||||
assert_eq!(args["ids"][0].as_str(), Some(agent_id_string.as_str()));
|
||||
|
||||
match &items[1] {
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: out_id,
|
||||
output,
|
||||
} => {
|
||||
assert_eq!(out_id, &call_id);
|
||||
let out: Value =
|
||||
serde_json::from_str(&output.content).expect("output should be json");
|
||||
assert_eq!(out["timed_out"].as_bool(), Some(false));
|
||||
let expected_status =
|
||||
serde_json::to_value(status).expect("status should serialize");
|
||||
assert_eq!(out["status"][agent_id_string.as_str()], expected_status);
|
||||
}
|
||||
other => panic!("expected function call output, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_protocol::ThreadId;
|
||||
|
||||
use crate::codex::Session;
|
||||
|
||||
pub(crate) async fn begin_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
let turn_state = {
|
||||
let active = session.active_turn.lock().await;
|
||||
active
|
||||
.as_ref()
|
||||
.map(|active_turn| Arc::clone(&active_turn.turn_state))
|
||||
};
|
||||
if let Some(turn_state) = turn_state {
|
||||
let mut state = turn_state.lock().await;
|
||||
state.begin_wait(turn_id, agent_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn end_collab_wait(session: &Session, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
let turn_state = {
|
||||
let active = session.active_turn.lock().await;
|
||||
active
|
||||
.as_ref()
|
||||
.map(|active_turn| Arc::clone(&active_turn.turn_state))
|
||||
};
|
||||
if let Some(turn_state) = turn_state {
|
||||
let mut state = turn_state.lock().await;
|
||||
state.end_wait(turn_id, agent_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn mark_collab_wait_collected(
|
||||
session: &Session,
|
||||
turn_id: &str,
|
||||
agent_ids: &[ThreadId],
|
||||
) {
|
||||
let turn_state = {
|
||||
let active = session.active_turn.lock().await;
|
||||
active
|
||||
.as_ref()
|
||||
.map(|active_turn| Arc::clone(&active_turn.turn_state))
|
||||
};
|
||||
if let Some(turn_state) = turn_state {
|
||||
let mut state = turn_state.lock().await;
|
||||
state.mark_wait_collected(turn_id, agent_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn is_collab_wait_suppressed(
|
||||
session: &Session,
|
||||
turn_id: &str,
|
||||
agent_id: ThreadId,
|
||||
) -> bool {
|
||||
let turn_state = {
|
||||
let active = session.active_turn.lock().await;
|
||||
active
|
||||
.as_ref()
|
||||
.map(|active_turn| Arc::clone(&active_turn.turn_state))
|
||||
};
|
||||
if let Some(turn_state) = turn_state {
|
||||
let state = turn_state.lock().await;
|
||||
state.is_waiting_on(turn_id, agent_id) || state.is_wait_collected(turn_id, agent_id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,86 @@
|
||||
use crate::agent::AgentStatus;
|
||||
use crate::agent::guards::Guards;
|
||||
use crate::agent::status::is_final;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::thread_manager::ThreadManagerState;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use serde_json::json;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::Weak;
|
||||
use tokio::sync::watch;
|
||||
use uuid::Uuid;
|
||||
|
||||
const SYNTHETIC_WAIT_TIMEOUT_MS: i64 = 300_000;
|
||||
|
||||
#[derive(Default)]
|
||||
struct TurnCollabLedger {
|
||||
spawned: HashSet<ThreadId>,
|
||||
acknowledged: HashSet<ThreadId>,
|
||||
in_flight_waits: HashMap<ThreadId, usize>,
|
||||
}
|
||||
|
||||
impl TurnCollabLedger {
|
||||
fn is_idle(&self) -> bool {
|
||||
self.in_flight_waits.is_empty()
|
||||
&& self
|
||||
.spawned
|
||||
.iter()
|
||||
.all(|agent_id| self.acknowledged.contains(agent_id))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct CollabLedger {
|
||||
turns: HashMap<String, TurnCollabLedger>,
|
||||
}
|
||||
|
||||
/// RAII guard that marks a set of agents as being explicitly waited on.
|
||||
pub(crate) struct WaitGuard {
|
||||
collab: Arc<Mutex<CollabLedger>>,
|
||||
turn_id: String,
|
||||
agent_ids: Vec<ThreadId>,
|
||||
}
|
||||
|
||||
impl WaitGuard {
|
||||
fn new(collab: Arc<Mutex<CollabLedger>>, turn_id: &str, agent_ids: &[ThreadId]) -> Self {
|
||||
Self {
|
||||
collab,
|
||||
turn_id: turn_id.to_string(),
|
||||
agent_ids: agent_ids.to_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WaitGuard {
|
||||
fn drop(&mut self) {
|
||||
let mut collab = self
|
||||
.collab
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if let Some(turn) = collab.turns.get_mut(&self.turn_id) {
|
||||
for agent_id in &self.agent_ids {
|
||||
if let Some(count) = turn.in_flight_waits.get_mut(agent_id) {
|
||||
*count = count.saturating_sub(1);
|
||||
if *count == 0 {
|
||||
turn.in_flight_waits.remove(agent_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if turn.is_idle() {
|
||||
collab.turns.remove(&self.turn_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Control-plane handle for multi-agent operations.
|
||||
/// `AgentControl` is held by each session (via `SessionServices`). It provides capability to
|
||||
@@ -23,6 +95,7 @@ pub(crate) struct AgentControl {
|
||||
/// `ThreadManagerState -> CodexThread -> Session -> SessionServices -> ThreadManagerState`.
|
||||
manager: Weak<ThreadManagerState>,
|
||||
state: Arc<Guards>,
|
||||
collab: Arc<Mutex<CollabLedger>>,
|
||||
}
|
||||
|
||||
impl AgentControl {
|
||||
@@ -129,11 +202,148 @@ impl AgentControl {
|
||||
Ok(thread.subscribe_status())
|
||||
}
|
||||
|
||||
pub(crate) fn register_spawn(&self, turn_id: &str, agent_id: ThreadId) {
|
||||
let mut collab = self.lock_collab();
|
||||
let turn = collab.turns.entry(turn_id.to_string()).or_default();
|
||||
turn.spawned.insert(agent_id);
|
||||
turn.acknowledged.remove(&agent_id);
|
||||
}
|
||||
|
||||
pub(crate) fn wait_guard(&self, turn_id: &str, agent_ids: &[ThreadId]) -> WaitGuard {
|
||||
let mut collab = self.lock_collab();
|
||||
let turn = collab.turns.entry(turn_id.to_string()).or_default();
|
||||
for agent_id in agent_ids {
|
||||
*turn.in_flight_waits.entry(*agent_id).or_default() += 1;
|
||||
}
|
||||
WaitGuard::new(Arc::clone(&self.collab), turn_id, agent_ids)
|
||||
}
|
||||
|
||||
pub(crate) fn acknowledge(&self, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
let mut collab = self.lock_collab();
|
||||
if let Some(turn) = collab.turns.get_mut(turn_id) {
|
||||
turn.acknowledged.extend(agent_ids.iter().copied());
|
||||
if turn.is_idle() {
|
||||
collab.turns.remove(turn_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn clear_turn(&self, turn_id: &str) {
|
||||
let mut collab = self.lock_collab();
|
||||
collab.turns.remove(turn_id);
|
||||
}
|
||||
|
||||
pub(crate) async fn collect_unacknowledged_finals(
|
||||
&self,
|
||||
turn_id: &str,
|
||||
) -> Vec<(ThreadId, AgentStatus)> {
|
||||
let mut candidates = {
|
||||
let collab = self.lock_collab();
|
||||
collab
|
||||
.turns
|
||||
.get(turn_id)
|
||||
.map(|turn| {
|
||||
turn.spawned
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|agent_id| {
|
||||
!turn.acknowledged.contains(agent_id)
|
||||
&& !turn.in_flight_waits.contains_key(agent_id)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
};
|
||||
if candidates.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
candidates.sort_by_key(std::string::ToString::to_string);
|
||||
|
||||
let mut finals = Vec::new();
|
||||
for agent_id in candidates {
|
||||
let status = self.get_status(agent_id).await;
|
||||
if is_final(&status) {
|
||||
finals.push((agent_id, status));
|
||||
}
|
||||
}
|
||||
if finals.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut ready = Vec::new();
|
||||
let mut collab = self.lock_collab();
|
||||
if let Some(turn) = collab.turns.get_mut(turn_id) {
|
||||
for (agent_id, status) in finals {
|
||||
if turn.acknowledged.contains(&agent_id)
|
||||
|| turn.in_flight_waits.contains_key(&agent_id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
turn.acknowledged.insert(agent_id);
|
||||
ready.push((agent_id, status));
|
||||
}
|
||||
if turn.is_idle() {
|
||||
collab.turns.remove(turn_id);
|
||||
}
|
||||
}
|
||||
ready
|
||||
}
|
||||
|
||||
pub(crate) fn synthetic_wait_items(statuses: &[(ThreadId, AgentStatus)]) -> Vec<ResponseItem> {
|
||||
if statuses.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
let call_id = format!("synthetic-wait-{}", Uuid::new_v4());
|
||||
|
||||
let mut ids = statuses
|
||||
.iter()
|
||||
.map(|(agent_id, _)| agent_id.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
ids.sort();
|
||||
let arguments = json!({
|
||||
"ids": ids,
|
||||
"timeout_ms": SYNTHETIC_WAIT_TIMEOUT_MS,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let status_map = statuses
|
||||
.iter()
|
||||
.map(|(agent_id, status)| (agent_id.to_string(), status.clone()))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
let output = json!({
|
||||
"status": status_map,
|
||||
"timed_out": false,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let call = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "wait".to_string(),
|
||||
arguments,
|
||||
call_id: call_id.clone(),
|
||||
};
|
||||
let output = ResponseItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: output,
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
|
||||
vec![call, output]
|
||||
}
|
||||
|
||||
fn upgrade(&self) -> CodexResult<Arc<ThreadManagerState>> {
|
||||
self.manager
|
||||
.upgrade()
|
||||
.ok_or_else(|| CodexErr::UnsupportedOperation("thread manager dropped".to_string()))
|
||||
}
|
||||
|
||||
fn lock_collab(&self) -> std::sync::MutexGuard<'_, CollabLedger> {
|
||||
self.collab
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
mod collab_completion_warning;
|
||||
mod collab_wait_tracking;
|
||||
pub(crate) mod control;
|
||||
mod guards;
|
||||
pub(crate) mod role;
|
||||
pub(crate) mod status;
|
||||
|
||||
pub(crate) use codex_protocol::protocol::AgentStatus;
|
||||
pub(crate) use collab_completion_warning::spawn_collab_completion_warning_watcher;
|
||||
pub(crate) use collab_wait_tracking::begin_collab_wait;
|
||||
pub(crate) use collab_wait_tracking::end_collab_wait;
|
||||
pub(crate) use collab_wait_tracking::is_collab_wait_suppressed;
|
||||
pub(crate) use collab_wait_tracking::mark_collab_wait_collected;
|
||||
pub(crate) use control::AgentControl;
|
||||
pub(crate) use guards::MAX_THREAD_SPAWN_DEPTH;
|
||||
pub(crate) use guards::exceeds_thread_spawn_depth_limit;
|
||||
|
||||
@@ -3006,11 +3006,20 @@ pub(crate) async fn run_turn(
|
||||
// Note that pending_input would be something like a message the user
|
||||
// submitted through the UI while the model was running. Though the UI
|
||||
// may support this, the model might not.
|
||||
let pending_input = sess
|
||||
.get_pending_input()
|
||||
.await
|
||||
let collab_ready = sess
|
||||
.services
|
||||
.agent_control
|
||||
.collect_unacknowledged_finals(&turn_context.sub_id)
|
||||
.await;
|
||||
let collab_ready_items = AgentControl::synthetic_wait_items(&collab_ready);
|
||||
let pending_input = collab_ready_items
|
||||
.into_iter()
|
||||
.map(ResponseItem::from)
|
||||
.chain(
|
||||
sess.get_pending_input()
|
||||
.await
|
||||
.into_iter()
|
||||
.map(ResponseItem::from),
|
||||
)
|
||||
.collect::<Vec<ResponseItem>>();
|
||||
|
||||
// Construct the input that we will send to the model.
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
@@ -17,7 +16,6 @@ use tokio::sync::oneshot;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::ReviewDecision;
|
||||
use crate::tasks::SessionTask;
|
||||
use codex_protocol::ThreadId;
|
||||
|
||||
/// Metadata about the currently running turn.
|
||||
pub(crate) struct ActiveTurn {
|
||||
@@ -75,8 +73,6 @@ pub(crate) struct TurnState {
|
||||
pending_user_input: HashMap<String, oneshot::Sender<RequestUserInputResponse>>,
|
||||
pending_dynamic_tools: HashMap<String, oneshot::Sender<DynamicToolResponse>>,
|
||||
pending_input: Vec<ResponseInputItem>,
|
||||
active_waits: HashMap<String, HashMap<ThreadId, usize>>,
|
||||
collected_waits: HashMap<String, HashSet<ThreadId>>,
|
||||
}
|
||||
|
||||
impl TurnState {
|
||||
@@ -100,8 +96,6 @@ impl TurnState {
|
||||
self.pending_user_input.clear();
|
||||
self.pending_dynamic_tools.clear();
|
||||
self.pending_input.clear();
|
||||
self.active_waits.clear();
|
||||
self.collected_waits.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn insert_pending_user_input(
|
||||
@@ -151,57 +145,6 @@ impl TurnState {
|
||||
pub(crate) fn has_pending_input(&self) -> bool {
|
||||
!self.pending_input.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn begin_wait(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
if agent_ids.is_empty() {
|
||||
return;
|
||||
}
|
||||
let waits = self.active_waits.entry(turn_id.to_string()).or_default();
|
||||
for agent_id in agent_ids {
|
||||
*waits.entry(*agent_id).or_default() += 1;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn end_wait(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
if agent_ids.is_empty() {
|
||||
return;
|
||||
}
|
||||
let mut remove_turn = false;
|
||||
if let Some(waits) = self.active_waits.get_mut(turn_id) {
|
||||
for agent_id in agent_ids {
|
||||
if let Some(count) = waits.get_mut(agent_id) {
|
||||
*count = count.saturating_sub(1);
|
||||
if *count == 0 {
|
||||
waits.remove(agent_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
remove_turn = waits.is_empty();
|
||||
}
|
||||
if remove_turn {
|
||||
self.active_waits.remove(turn_id);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_waiting_on(&self, turn_id: &str, agent_id: ThreadId) -> bool {
|
||||
self.active_waits
|
||||
.get(turn_id)
|
||||
.is_some_and(|waits| waits.contains_key(&agent_id))
|
||||
}
|
||||
|
||||
pub(crate) fn mark_wait_collected(&mut self, turn_id: &str, agent_ids: &[ThreadId]) {
|
||||
if agent_ids.is_empty() {
|
||||
return;
|
||||
}
|
||||
let collected = self.collected_waits.entry(turn_id.to_string()).or_default();
|
||||
collected.extend(agent_ids.iter().copied());
|
||||
}
|
||||
|
||||
pub(crate) fn is_wait_collected(&self, turn_id: &str, agent_id: ThreadId) -> bool {
|
||||
self.collected_waits
|
||||
.get(turn_id)
|
||||
.is_some_and(|collected| collected.contains(&agent_id))
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
@@ -211,43 +154,3 @@ impl ActiveTurn {
|
||||
ts.clear_pending();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn thread_id(s: &str) -> ThreadId {
|
||||
ThreadId::from_string(s).expect("valid thread id")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wait_tracking_is_turn_scoped_and_reference_counted_and_collected() {
|
||||
let mut state = TurnState::default();
|
||||
let turn_a = "turn-a";
|
||||
let turn_b = "turn-b";
|
||||
let agent = thread_id("00000000-0000-7000-0000-000000000001");
|
||||
|
||||
state.begin_wait(turn_a, &[agent]);
|
||||
state.begin_wait(turn_a, &[agent]);
|
||||
state.begin_wait(turn_b, &[agent]);
|
||||
|
||||
assert_eq!(state.is_waiting_on(turn_a, agent), true);
|
||||
assert_eq!(state.is_waiting_on(turn_b, agent), true);
|
||||
|
||||
state.end_wait(turn_a, &[agent]);
|
||||
assert_eq!(state.is_waiting_on(turn_a, agent), true);
|
||||
|
||||
state.end_wait(turn_a, &[agent]);
|
||||
assert_eq!(state.is_waiting_on(turn_a, agent), false);
|
||||
assert_eq!(state.is_waiting_on(turn_b, agent), true);
|
||||
|
||||
state.mark_wait_collected(turn_a, &[agent]);
|
||||
assert_eq!(state.is_wait_collected(turn_a, agent), true);
|
||||
assert_eq!(state.is_wait_collected(turn_b, agent), false);
|
||||
|
||||
state.clear_pending();
|
||||
assert_eq!(state.is_waiting_on(turn_b, agent), false);
|
||||
assert_eq!(state.is_wait_collected(turn_a, agent), false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,6 +189,7 @@ impl Session {
|
||||
false
|
||||
};
|
||||
drop(active);
|
||||
self.services.agent_control.clear_turn(&turn_context.sub_id);
|
||||
if should_close_processes {
|
||||
self.close_unified_exec_processes().await;
|
||||
}
|
||||
@@ -207,7 +208,11 @@ impl Session {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
match active.take() {
|
||||
Some(mut at) => {
|
||||
let turn_ids = at.tasks.keys().cloned().collect::<Vec<_>>();
|
||||
at.clear_pending().await;
|
||||
for turn_id in turn_ids {
|
||||
self.services.agent_control.clear_turn(&turn_id);
|
||||
}
|
||||
|
||||
at.drain_tasks()
|
||||
}
|
||||
|
||||
@@ -175,11 +175,10 @@ mod spawn {
|
||||
)
|
||||
.await;
|
||||
let new_thread_id = result?;
|
||||
crate::agent::spawn_collab_completion_warning_watcher(
|
||||
session.clone(),
|
||||
turn.clone(),
|
||||
new_thread_id,
|
||||
);
|
||||
session
|
||||
.services
|
||||
.agent_control
|
||||
.register_spawn(&turn.sub_id, new_thread_id);
|
||||
|
||||
let content = serde_json::to_string(&SpawnAgentResult {
|
||||
agent_id: new_thread_id.to_string(),
|
||||
@@ -342,8 +341,12 @@ mod wait {
|
||||
ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS),
|
||||
};
|
||||
|
||||
crate::agent::begin_collab_wait(session.as_ref(), &turn.sub_id, &receiver_thread_ids).await;
|
||||
let result = async {
|
||||
let _wait_guard = session
|
||||
.services
|
||||
.agent_control
|
||||
.wait_guard(&turn.sub_id, &receiver_thread_ids);
|
||||
|
||||
async {
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
@@ -424,12 +427,10 @@ mod wait {
|
||||
};
|
||||
|
||||
let collected_ids = statuses.iter().map(|(id, _)| *id).collect::<Vec<_>>();
|
||||
crate::agent::mark_collab_wait_collected(
|
||||
session.as_ref(),
|
||||
&turn.sub_id,
|
||||
&collected_ids,
|
||||
)
|
||||
.await;
|
||||
session
|
||||
.services
|
||||
.agent_control
|
||||
.acknowledge(&turn.sub_id, &collected_ids);
|
||||
|
||||
// Convert payload.
|
||||
let statuses_map = statuses.into_iter().collect::<HashMap<_, _>>();
|
||||
@@ -461,9 +462,7 @@ mod wait {
|
||||
content_items: None,
|
||||
})
|
||||
}
|
||||
.await;
|
||||
crate::agent::end_collab_wait(session.as_ref(), &turn.sub_id, &receiver_thread_ids).await;
|
||||
result
|
||||
.await
|
||||
}
|
||||
|
||||
async fn wait_for_final_status(
|
||||
@@ -1022,18 +1021,13 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_marks_missing_agents_as_collected_and_suppressed() {
|
||||
async fn wait_handles_missing_agents() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
let manager = thread_manager();
|
||||
session.services.agent_control = manager.agent_control();
|
||||
let id = ThreadId::new();
|
||||
let session = Arc::new(session);
|
||||
let turn = Arc::new(turn);
|
||||
let turn_state = {
|
||||
let mut active = session.active_turn.lock().await;
|
||||
let active_turn = active.get_or_insert_with(crate::state::ActiveTurn::default);
|
||||
Arc::clone(&active_turn.turn_state)
|
||||
};
|
||||
let invocation = invocation(
|
||||
session.clone(),
|
||||
turn.clone(),
|
||||
@@ -1063,16 +1057,6 @@ mod tests {
|
||||
}
|
||||
);
|
||||
assert_eq!(success, None);
|
||||
|
||||
{
|
||||
let state = turn_state.lock().await;
|
||||
assert_eq!(state.is_waiting_on(&turn.sub_id, id), false);
|
||||
assert_eq!(state.is_wait_collected(&turn.sub_id, id), true);
|
||||
}
|
||||
|
||||
let suppressed =
|
||||
crate::agent::is_collab_wait_suppressed(session.as_ref(), &turn.sub_id, id).await;
|
||||
assert_eq!(suppressed, true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user