mirror of
https://github.com/openai/codex.git
synced 2026-02-01 22:47:52 +00:00
Compare commits
7 Commits
main
...
jif/warnin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4acffebebc | ||
|
|
cf0c52504e | ||
|
|
49ee35b86c | ||
|
|
4dd80700cb | ||
|
|
ae13e569c6 | ||
|
|
ebddb1071d | ||
|
|
ec2551ba36 |
@@ -1,14 +1,86 @@
|
||||
use crate::agent::AgentStatus;
|
||||
use crate::agent::guards::Guards;
|
||||
use crate::agent::status::is_final;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::thread_manager::ThreadManagerState;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use serde_json::json;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::Weak;
|
||||
use tokio::sync::watch;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// `timeout_ms` value placed in the arguments of the synthesized `wait` call built by
/// `synthetic_wait_items`.
const SYNTHETIC_WAIT_TIMEOUT_MS: i64 = 300_000;
|
||||
|
||||
#[derive(Default)]
|
||||
struct TurnCollabLedger {
|
||||
spawned: HashSet<ThreadId>,
|
||||
acknowledged: HashSet<ThreadId>,
|
||||
in_flight_waits: HashMap<ThreadId, usize>,
|
||||
}
|
||||
|
||||
impl TurnCollabLedger {
|
||||
fn is_idle(&self) -> bool {
|
||||
self.in_flight_waits.is_empty()
|
||||
&& self
|
||||
.spawned
|
||||
.iter()
|
||||
.all(|agent_id| self.acknowledged.contains(agent_id))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct CollabLedger {
|
||||
turns: HashMap<String, TurnCollabLedger>,
|
||||
}
|
||||
|
||||
/// RAII guard that marks a set of agents as being explicitly waited on.
|
||||
pub(crate) struct WaitGuard {
|
||||
collab: Arc<Mutex<CollabLedger>>,
|
||||
turn_id: String,
|
||||
agent_ids: Vec<ThreadId>,
|
||||
}
|
||||
|
||||
impl WaitGuard {
|
||||
fn new(collab: Arc<Mutex<CollabLedger>>, turn_id: &str, agent_ids: &[ThreadId]) -> Self {
|
||||
Self {
|
||||
collab,
|
||||
turn_id: turn_id.to_string(),
|
||||
agent_ids: agent_ids.to_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WaitGuard {
|
||||
fn drop(&mut self) {
|
||||
let mut collab = self
|
||||
.collab
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if let Some(turn) = collab.turns.get_mut(&self.turn_id) {
|
||||
for agent_id in &self.agent_ids {
|
||||
if let Some(count) = turn.in_flight_waits.get_mut(agent_id) {
|
||||
*count = count.saturating_sub(1);
|
||||
if *count == 0 {
|
||||
turn.in_flight_waits.remove(agent_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if turn.is_idle() {
|
||||
collab.turns.remove(&self.turn_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Control-plane handle for multi-agent operations.
|
||||
/// `AgentControl` is held by each session (via `SessionServices`). It provides capability to
|
||||
@@ -23,6 +95,7 @@ pub(crate) struct AgentControl {
|
||||
/// `ThreadManagerState -> CodexThread -> Session -> SessionServices -> ThreadManagerState`.
|
||||
manager: Weak<ThreadManagerState>,
|
||||
state: Arc<Guards>,
|
||||
collab: Arc<Mutex<CollabLedger>>,
|
||||
}
|
||||
|
||||
impl AgentControl {
|
||||
@@ -129,11 +202,148 @@ impl AgentControl {
|
||||
Ok(thread.subscribe_status())
|
||||
}
|
||||
|
||||
/// Records that `agent_id` was spawned during `turn_id`.
///
/// The turn entry is created on demand, and any earlier acknowledgement of the same agent
/// within this turn is cleared.
pub(crate) fn register_spawn(&self, turn_id: &str, agent_id: ThreadId) {
    let mut ledger = self.lock_collab();
    let turn_ledger = ledger.turns.entry(turn_id.to_owned()).or_default();
    turn_ledger.acknowledged.remove(&agent_id);
    turn_ledger.spawned.insert(agent_id);
}
|
||||
|
||||
pub(crate) fn wait_guard(&self, turn_id: &str, agent_ids: &[ThreadId]) -> WaitGuard {
|
||||
let mut collab = self.lock_collab();
|
||||
let turn = collab.turns.entry(turn_id.to_string()).or_default();
|
||||
for agent_id in agent_ids {
|
||||
*turn.in_flight_waits.entry(*agent_id).or_default() += 1;
|
||||
}
|
||||
WaitGuard::new(Arc::clone(&self.collab), turn_id, agent_ids)
|
||||
}
|
||||
|
||||
/// Marks `agent_ids` as acknowledged for `turn_id`.
///
/// Does nothing when the turn has no ledger entry. The entry is removed as soon as it
/// becomes idle.
pub(crate) fn acknowledge(&self, turn_id: &str, agent_ids: &[ThreadId]) {
    let mut ledger = self.lock_collab();
    let Some(turn_ledger) = ledger.turns.get_mut(turn_id) else {
        return;
    };
    for acked_id in agent_ids {
        turn_ledger.acknowledged.insert(*acked_id);
    }
    let turn_is_idle = turn_ledger.is_idle();
    if turn_is_idle {
        ledger.turns.remove(turn_id);
    }
}
|
||||
|
||||
pub(crate) fn clear_turn(&self, turn_id: &str) {
|
||||
let mut collab = self.lock_collab();
|
||||
collab.turns.remove(turn_id);
|
||||
}
|
||||
|
||||
pub(crate) async fn collect_unacknowledged_finals(
|
||||
&self,
|
||||
turn_id: &str,
|
||||
) -> Vec<(ThreadId, AgentStatus)> {
|
||||
let mut candidates = {
|
||||
let collab = self.lock_collab();
|
||||
collab
|
||||
.turns
|
||||
.get(turn_id)
|
||||
.map(|turn| {
|
||||
turn.spawned
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|agent_id| {
|
||||
!turn.acknowledged.contains(agent_id)
|
||||
&& !turn.in_flight_waits.contains_key(agent_id)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
};
|
||||
if candidates.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
candidates.sort_by_key(std::string::ToString::to_string);
|
||||
|
||||
let mut finals = Vec::new();
|
||||
for agent_id in candidates {
|
||||
let status = self.get_status(agent_id).await;
|
||||
if is_final(&status) {
|
||||
finals.push((agent_id, status));
|
||||
}
|
||||
}
|
||||
if finals.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut ready = Vec::new();
|
||||
let mut collab = self.lock_collab();
|
||||
if let Some(turn) = collab.turns.get_mut(turn_id) {
|
||||
for (agent_id, status) in finals {
|
||||
if turn.acknowledged.contains(&agent_id)
|
||||
|| turn.in_flight_waits.contains_key(&agent_id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
turn.acknowledged.insert(agent_id);
|
||||
ready.push((agent_id, status));
|
||||
}
|
||||
if turn.is_idle() {
|
||||
collab.turns.remove(turn_id);
|
||||
}
|
||||
}
|
||||
ready
|
||||
}
|
||||
|
||||
pub(crate) fn synthetic_wait_items(statuses: &[(ThreadId, AgentStatus)]) -> Vec<ResponseItem> {
|
||||
if statuses.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
let call_id = format!("synthetic-wait-{}", Uuid::new_v4());
|
||||
|
||||
let mut ids = statuses
|
||||
.iter()
|
||||
.map(|(agent_id, _)| agent_id.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
ids.sort();
|
||||
let arguments = json!({
|
||||
"ids": ids,
|
||||
"timeout_ms": SYNTHETIC_WAIT_TIMEOUT_MS,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let status_map = statuses
|
||||
.iter()
|
||||
.map(|(agent_id, status)| (agent_id.to_string(), status.clone()))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
let output = json!({
|
||||
"status": status_map,
|
||||
"timed_out": false,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let call = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "wait".to_string(),
|
||||
arguments,
|
||||
call_id: call_id.clone(),
|
||||
};
|
||||
let output = ResponseItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: output,
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
|
||||
vec![call, output]
|
||||
}
|
||||
|
||||
fn upgrade(&self) -> CodexResult<Arc<ThreadManagerState>> {
|
||||
self.manager
|
||||
.upgrade()
|
||||
.ok_or_else(|| CodexErr::UnsupportedOperation("thread manager dropped".to_string()))
|
||||
}
|
||||
|
||||
/// Locks the collaboration ledger, recovering the inner data if the mutex is poisoned.
fn lock_collab(&self) -> std::sync::MutexGuard<'_, CollabLedger> {
    match self.collab.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -3006,11 +3006,20 @@ pub(crate) async fn run_turn(
|
||||
// Note that pending_input would be something like a message the user
|
||||
// submitted through the UI while the model was running. Though the UI
|
||||
// may support this, the model might not.
|
||||
let pending_input = sess
|
||||
.get_pending_input()
|
||||
.await
|
||||
let collab_ready = sess
|
||||
.services
|
||||
.agent_control
|
||||
.collect_unacknowledged_finals(&turn_context.sub_id)
|
||||
.await;
|
||||
let collab_ready_items = AgentControl::synthetic_wait_items(&collab_ready);
|
||||
let pending_input = collab_ready_items
|
||||
.into_iter()
|
||||
.map(ResponseItem::from)
|
||||
.chain(
|
||||
sess.get_pending_input()
|
||||
.await
|
||||
.into_iter()
|
||||
.map(ResponseItem::from),
|
||||
)
|
||||
.collect::<Vec<ResponseItem>>();
|
||||
|
||||
// Construct the input that we will send to the model.
|
||||
|
||||
@@ -189,6 +189,7 @@ impl Session {
|
||||
false
|
||||
};
|
||||
drop(active);
|
||||
self.services.agent_control.clear_turn(&turn_context.sub_id);
|
||||
if should_close_processes {
|
||||
self.close_unified_exec_processes().await;
|
||||
}
|
||||
@@ -207,7 +208,11 @@ impl Session {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
match active.take() {
|
||||
Some(mut at) => {
|
||||
let turn_ids = at.tasks.keys().cloned().collect::<Vec<_>>();
|
||||
at.clear_pending().await;
|
||||
for turn_id in turn_ids {
|
||||
self.services.agent_control.clear_turn(&turn_id);
|
||||
}
|
||||
|
||||
at.drain_tasks()
|
||||
}
|
||||
|
||||
@@ -175,6 +175,10 @@ mod spawn {
|
||||
)
|
||||
.await;
|
||||
let new_thread_id = result?;
|
||||
session
|
||||
.services
|
||||
.agent_control
|
||||
.register_spawn(&turn.sub_id, new_thread_id);
|
||||
|
||||
let content = serde_json::to_string(&SpawnAgentResult {
|
||||
agent_id: new_thread_id.to_string(),
|
||||
@@ -337,114 +341,128 @@ mod wait {
|
||||
ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS),
|
||||
};
|
||||
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingBeginEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_ids: receiver_thread_ids.clone(),
|
||||
call_id: call_id.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let _wait_guard = session
|
||||
.services
|
||||
.agent_control
|
||||
.wait_guard(&turn.sub_id, &receiver_thread_ids);
|
||||
|
||||
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
|
||||
let mut initial_final_statuses = Vec::new();
|
||||
for id in &receiver_thread_ids {
|
||||
match session.services.agent_control.subscribe_status(*id).await {
|
||||
Ok(rx) => {
|
||||
let status = rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
initial_final_statuses.push((*id, status));
|
||||
async {
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingBeginEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_ids: receiver_thread_ids.clone(),
|
||||
call_id: call_id.clone(),
|
||||
}
|
||||
status_rxs.push((*id, rx));
|
||||
}
|
||||
Err(CodexErr::ThreadNotFound(_)) => {
|
||||
initial_final_statuses.push((*id, AgentStatus::NotFound));
|
||||
}
|
||||
Err(err) => {
|
||||
let mut statuses = HashMap::with_capacity(1);
|
||||
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id: call_id.clone(),
|
||||
statuses,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
return Err(collab_agent_error(*id, err));
|
||||
}
|
||||
}
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let statuses = if !initial_final_statuses.is_empty() {
|
||||
initial_final_statuses
|
||||
} else {
|
||||
// Wait for the first agent to reach a final status.
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for (id, rx) in status_rxs.into_iter() {
|
||||
let session = session.clone();
|
||||
futures.push(wait_for_final_status(session, id, rx));
|
||||
}
|
||||
let mut results = Vec::new();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
loop {
|
||||
match timeout_at(deadline, futures.next()).await {
|
||||
Ok(Some(Some(result))) => {
|
||||
results.push(result);
|
||||
break;
|
||||
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
|
||||
let mut initial_final_statuses = Vec::new();
|
||||
for id in &receiver_thread_ids {
|
||||
match session.services.agent_control.subscribe_status(*id).await {
|
||||
Ok(rx) => {
|
||||
let status = rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
initial_final_statuses.push((*id, status));
|
||||
}
|
||||
status_rxs.push((*id, rx));
|
||||
}
|
||||
Err(CodexErr::ThreadNotFound(_)) => {
|
||||
initial_final_statuses.push((*id, AgentStatus::NotFound));
|
||||
}
|
||||
Err(err) => {
|
||||
let mut statuses = HashMap::with_capacity(1);
|
||||
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id: call_id.clone(),
|
||||
statuses,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
Err(collab_agent_error(*id, err))?;
|
||||
}
|
||||
Ok(Some(None)) => continue,
|
||||
Ok(None) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
if !results.is_empty() {
|
||||
// Drain the unlikely last elements to prevent race.
|
||||
|
||||
let statuses = if !initial_final_statuses.is_empty() {
|
||||
initial_final_statuses
|
||||
} else {
|
||||
// Wait for the first agent to reach a final status.
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for (id, rx) in status_rxs.into_iter() {
|
||||
let session = session.clone();
|
||||
futures.push(wait_for_final_status(session, id, rx));
|
||||
}
|
||||
let mut results = Vec::new();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
loop {
|
||||
match futures.next().now_or_never() {
|
||||
Some(Some(Some(result))) => results.push(result),
|
||||
Some(Some(None)) => continue,
|
||||
Some(None) | None => break,
|
||||
match timeout_at(deadline, futures.next()).await {
|
||||
Ok(Some(Some(result))) => {
|
||||
results.push(result);
|
||||
break;
|
||||
}
|
||||
Ok(Some(None)) => continue,
|
||||
Ok(None) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
};
|
||||
|
||||
// Convert payload.
|
||||
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
|
||||
let result = WaitResult {
|
||||
status: statuses_map.clone(),
|
||||
timed_out: statuses.is_empty(),
|
||||
};
|
||||
|
||||
// Final event emission.
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id,
|
||||
statuses: statuses_map,
|
||||
if !results.is_empty() {
|
||||
// Drain the unlikely last elements to prevent race.
|
||||
loop {
|
||||
match futures.next().now_or_never() {
|
||||
Some(Some(Some(result))) => results.push(result),
|
||||
Some(Some(None)) => continue,
|
||||
Some(None) | None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
results
|
||||
};
|
||||
|
||||
let content = serde_json::to_string(&result).map_err(|err| {
|
||||
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
|
||||
})?;
|
||||
let collected_ids = statuses.iter().map(|(id, _)| *id).collect::<Vec<_>>();
|
||||
session
|
||||
.services
|
||||
.agent_control
|
||||
.acknowledge(&turn.sub_id, &collected_ids);
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
success: None,
|
||||
content_items: None,
|
||||
})
|
||||
// Convert payload.
|
||||
let statuses_map = statuses.into_iter().collect::<HashMap<_, _>>();
|
||||
let result = WaitResult {
|
||||
status: statuses_map.clone(),
|
||||
timed_out: collected_ids.is_empty(),
|
||||
};
|
||||
|
||||
// Final event emission.
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id,
|
||||
statuses: statuses_map,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let content = serde_json::to_string(&result).map_err(|err| {
|
||||
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
|
||||
})?;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
success: None,
|
||||
content_items: None,
|
||||
})
|
||||
}
|
||||
.await
|
||||
}
|
||||
|
||||
async fn wait_for_final_status(
|
||||
@@ -1002,6 +1020,45 @@ mod tests {
|
||||
assert_eq!(success, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_handles_missing_agents() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
let manager = thread_manager();
|
||||
session.services.agent_control = manager.agent_control();
|
||||
let id = ThreadId::new();
|
||||
let session = Arc::new(session);
|
||||
let turn = Arc::new(turn);
|
||||
let invocation = invocation(
|
||||
session.clone(),
|
||||
turn.clone(),
|
||||
"wait",
|
||||
function_payload(json!({
|
||||
"ids": [id.to_string()],
|
||||
"timeout_ms": 1000
|
||||
})),
|
||||
);
|
||||
let output = CollabHandler
|
||||
.handle(invocation)
|
||||
.await
|
||||
.expect("wait should succeed");
|
||||
let ToolOutput::Function {
|
||||
content, success, ..
|
||||
} = output
|
||||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
let result: WaitResult =
|
||||
serde_json::from_str(&content).expect("wait result should be json");
|
||||
assert_eq!(
|
||||
result,
|
||||
WaitResult {
|
||||
status: HashMap::from([(id, AgentStatus::NotFound)]),
|
||||
timed_out: false
|
||||
}
|
||||
);
|
||||
assert_eq!(success, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_times_out_when_status_is_not_final() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
|
||||
Reference in New Issue
Block a user