feat: add wait tool implementation for collab (#9088)

Add implementation for the `wait` tool.

For this we consider all statuses other than `PendingInit` and
`Running` as terminal. The `wait` tool call returns either after a
given timeout or when the agent reaches a terminal status.

A few points to note:
* A watch channel is used instead of polling to prevent races (simply
looping on `get_status()` could "miss" a terminal status)
* The order of operations is very important, we need to first subscribe
and then check the last known status to prevent race conditions
* If the channel gets dropped, we return an error on purpose
This commit is contained in:
jif-oai
2026-01-12 12:16:24 +00:00
committed by GitHub
parent 86f81ca010
commit 623707ab58
7 changed files with 207 additions and 43 deletions

View File

@@ -9,6 +9,7 @@ use codex_protocol::protocol::Op;
use codex_protocol::user_input::UserInput;
use std::sync::Arc;
use std::sync::Weak;
use tokio::sync::watch;
/// Control-plane handle for multi-agent operations.
/// `AgentControl` is held by each session (via `SessionServices`). It provides capability to
@@ -80,6 +81,16 @@ impl AgentControl {
thread.agent_status().await
}
/// Subscribe to status updates for `agent_id`, yielding the latest value and changes.
pub(crate) async fn subscribe_status(
&self,
agent_id: ThreadId,
) -> CodexResult<watch::Receiver<AgentStatus>> {
let state = self.upgrade()?;
let thread = state.get_thread(agent_id).await?;
Ok(thread.subscribe_status())
}
fn upgrade(&self) -> CodexResult<Arc<ThreadManagerState>> {
self.manager
.upgrade()
@@ -275,6 +286,38 @@ mod tests {
assert_eq!(status, AgentStatus::PendingInit);
}
#[tokio::test]
async fn subscribe_status_errors_for_missing_thread() {
let harness = AgentControlHarness::new().await;
let thread_id = ThreadId::new();
let err = harness
.control
.subscribe_status(thread_id)
.await
.expect_err("subscribe_status should fail for missing thread");
assert_matches!(err, CodexErr::ThreadNotFound(id) if id == thread_id);
}
#[tokio::test]
async fn subscribe_status_updates_on_shutdown() {
let harness = AgentControlHarness::new().await;
let (thread_id, thread) = harness.start_thread().await;
let mut status_rx = harness
.control
.subscribe_status(thread_id)
.await
.expect("subscribe_status should succeed");
assert_eq!(status_rx.borrow().clone(), AgentStatus::PendingInit);
let _ = thread
.submit(Op::Shutdown {})
.await
.expect("shutdown should submit");
let _ = status_rx.changed().await;
assert_eq!(status_rx.borrow().clone(), AgentStatus::Shutdown);
}
#[tokio::test]
async fn send_prompt_submits_user_message() {
let harness = AgentControlHarness::new().await;

View File

@@ -13,3 +13,7 @@ pub(crate) fn agent_status_from_event(msg: &EventMsg) -> Option<AgentStatus> {
_ => None,
}
}
pub(crate) fn is_final(status: &AgentStatus) -> bool {
!matches!(status, AgentStatus::PendingInit | AgentStatus::Running)
}

View File

@@ -164,6 +164,7 @@ use codex_protocol::protocol::InitialHistory;
use codex_protocol::user_input::UserInput;
use codex_utils_readiness::Readiness;
use codex_utils_readiness::ReadinessFlag;
use tokio::sync::watch;
/// The high-level interface to the Codex system.
/// It operates as a queue pair where you send submissions and receive events.
@@ -172,7 +173,7 @@ pub struct Codex {
pub(crate) tx_sub: Sender<Submission>,
pub(crate) rx_event: Receiver<Event>,
// Last known status of the agent.
pub(crate) agent_status: Arc<RwLock<AgentStatus>>,
pub(crate) agent_status: watch::Receiver<AgentStatus>,
}
/// Wrapper returned by [`Codex::spawn`] containing the spawned [`Codex`],
@@ -275,7 +276,7 @@ impl Codex {
// Generate a unique ID for the lifetime of this Codex session.
let session_source_clone = session_configuration.session_source.clone();
let agent_status = Arc::new(RwLock::new(AgentStatus::PendingInit));
let (agent_status_tx, agent_status_rx) = watch::channel(AgentStatus::PendingInit);
let session = Session::new(
session_configuration,
@@ -284,7 +285,7 @@ impl Codex {
models_manager.clone(),
exec_policy,
tx_event.clone(),
Arc::clone(&agent_status),
agent_status_tx.clone(),
conversation_history,
session_source_clone,
skills_manager,
@@ -303,7 +304,7 @@ impl Codex {
next_id: AtomicU64::new(0),
tx_sub,
rx_event,
agent_status,
agent_status: agent_status_rx,
};
#[allow(deprecated)]
@@ -345,8 +346,7 @@ impl Codex {
}
pub(crate) async fn agent_status(&self) -> AgentStatus {
let status = self.agent_status.read().await;
status.clone()
self.agent_status.borrow().clone()
}
}
@@ -356,7 +356,7 @@ impl Codex {
pub(crate) struct Session {
conversation_id: ThreadId,
tx_event: Sender<Event>,
agent_status: Arc<RwLock<AgentStatus>>,
agent_status: watch::Sender<AgentStatus>,
state: Mutex<SessionState>,
/// The set of enabled features should be invariant for the lifetime of the
/// session.
@@ -557,7 +557,7 @@ impl Session {
models_manager: Arc<ModelsManager>,
exec_policy: ExecPolicyManager,
tx_event: Sender<Event>,
agent_status: Arc<RwLock<AgentStatus>>,
agent_status: watch::Sender<AgentStatus>,
initial_history: InitialHistory,
session_source: SessionSource,
skills_manager: Arc<SkillsManager>,
@@ -703,7 +703,7 @@ impl Session {
let sess = Arc::new(Session {
conversation_id,
tx_event: tx_event.clone(),
agent_status: Arc::clone(&agent_status),
agent_status,
state: Mutex::new(state),
features: config.features.clone(),
active_turn: Mutex::new(None),
@@ -1026,8 +1026,7 @@ impl Session {
pub(crate) async fn send_event_raw(&self, event: Event) {
// Record the last known agent status.
if let Some(status) = agent_status_from_event(&event.msg) {
let mut guard = self.agent_status.write().await;
*guard = status;
self.agent_status.send_replace(status);
}
// Persist the event into rollout (recorder filters as needed)
let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
@@ -1045,8 +1044,7 @@ impl Session {
pub(crate) async fn send_event_raw_flushed(&self, event: Event) {
// Record the last known agent status.
if let Some(status) = agent_status_from_event(&event.msg) {
let mut guard = self.agent_status.write().await;
*guard = status;
self.agent_status.send_replace(status);
}
self.persist_rollout_items(&[RolloutItem::EventMsg(event.msg.clone())])
.await;
@@ -3494,7 +3492,7 @@ mod tests {
));
let agent_control = AgentControl::default();
let exec_policy = ExecPolicyManager::default();
let agent_status = Arc::new(RwLock::new(AgentStatus::PendingInit));
let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit);
let model = ModelsManager::get_model_offline(config.model.as_deref());
let session_configuration = SessionConfiguration {
provider: config.model_provider.clone(),
@@ -3557,7 +3555,7 @@ mod tests {
let session = Session {
conversation_id,
tx_event,
agent_status: Arc::clone(&agent_status),
agent_status: agent_status_tx,
state: Mutex::new(state),
features: config.features.clone(),
active_turn: Mutex::new(None),
@@ -3588,7 +3586,7 @@ mod tests {
));
let agent_control = AgentControl::default();
let exec_policy = ExecPolicyManager::default();
let agent_status = Arc::new(RwLock::new(AgentStatus::PendingInit));
let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit);
let model = ModelsManager::get_model_offline(config.model.as_deref());
let session_configuration = SessionConfiguration {
provider: config.model_provider.clone(),
@@ -3651,7 +3649,7 @@ mod tests {
let session = Arc::new(Session {
conversation_id,
tx_event,
agent_status: Arc::clone(&agent_status),
agent_status: agent_status_tx,
state: Mutex::new(state),
features: config.features.clone(),
active_turn: Mutex::new(None),

View File

@@ -87,7 +87,7 @@ pub(crate) async fn run_codex_thread_interactive(
next_id: AtomicU64::new(0),
tx_sub: tx_ops,
rx_event: rx_sub,
agent_status: Arc::clone(&codex.agent_status),
agent_status: codex.agent_status.clone(),
})
}
@@ -129,7 +129,7 @@ pub(crate) async fn run_codex_thread_one_shot(
// Bridge events so we can observe completion and shut down automatically.
let (tx_bridge, rx_bridge) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
let ops_tx = io.tx_sub.clone();
let agent_status = Arc::clone(&io.agent_status);
let agent_status = io.agent_status.clone();
let io_for_bridge = io;
tokio::spawn(async move {
while let Ok(event) = io_for_bridge.next_event().await {
@@ -363,20 +363,23 @@ mod tests {
use super::*;
use async_channel::bounded;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::AgentStatus;
use codex_protocol::protocol::RawResponseItemEvent;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::TurnAbortedEvent;
use pretty_assertions::assert_eq;
use tokio::sync::watch;
#[tokio::test]
async fn forward_events_cancelled_while_send_blocked_shuts_down_delegate() {
let (tx_events, rx_events) = bounded(1);
let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY);
let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit);
let codex = Arc::new(Codex {
next_id: AtomicU64::new(0),
tx_sub,
rx_event: rx_events,
agent_status: Default::default(),
agent_status,
});
let (session, ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx().await;

View File

@@ -5,6 +5,7 @@ use crate::protocol::Event;
use crate::protocol::Op;
use crate::protocol::Submission;
use std::path::PathBuf;
use tokio::sync::watch;
pub struct CodexThread {
codex: Codex,
@@ -38,6 +39,10 @@ impl CodexThread {
self.codex.agent_status().await
}
pub(crate) fn subscribe_status(&self) -> watch::Receiver<AgentStatus> {
self.codex.agent_status.clone()
}
pub fn rollout_path(&self) -> PathBuf {
self.rollout_path.clone()
}

View File

@@ -143,11 +143,14 @@ mod send_input {
}
}
#[allow(unused_variables)]
mod wait {
use super::*;
use crate::agent::status::is_final;
use crate::codex::Session;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use tokio::time::timeout_at;
#[derive(Debug, Deserialize)]
struct WaitArgs {
@@ -168,40 +171,68 @@ mod wait {
let args: WaitArgs = parse_arguments(&arguments)?;
let agent_id = agent_id(&args.id)?;
// Validate timeout.
let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
if timeout_ms <= 0 {
return Err(FunctionCallError::RespondToModel(
"timeout_ms must be greater than zero".to_string(),
));
}
let timeout_ms = timeout_ms.min(MAX_WAIT_TIMEOUT_MS);
// TODO(jif) actual implementation
let outcome = WaitResult {
status: Default::default(),
timed_out: false,
let timeout_ms = match timeout_ms {
ms if ms <= 0 => {
return Err(FunctionCallError::RespondToModel(
"timeout_ms must be greater than zero".to_owned(),
));
}
ms => ms.min(MAX_WAIT_TIMEOUT_MS),
};
if matches!(outcome.status, AgentStatus::NotFound) {
let mut status_rx = session
.services
.agent_control
.subscribe_status(agent_id)
.await
.map_err(|err| match err {
CodexErr::ThreadNotFound(id) => {
FunctionCallError::RespondToModel(format!("agent with id {id} not found"))
}
err => FunctionCallError::Fatal(err.to_string()),
})?;
// Get last known status.
let mut status = status_rx.borrow_and_update().clone();
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
let timed_out = loop {
if is_final(&status) {
break false;
}
match timeout_at(deadline, status_rx.changed()).await {
Ok(Ok(())) => status = status_rx.borrow().clone(),
Ok(Err(_)) => {
let last_status = session.services.agent_control.get_status(agent_id).await;
if last_status != AgentStatus::NotFound {
// On-purpose we keep the last known status if the agent gets dropped. This
// event is not supposed to happen.
status = last_status;
}
break false;
}
Err(_) => break true,
}
};
if matches!(status, AgentStatus::NotFound) {
return Err(FunctionCallError::RespondToModel(format!(
"agent with id {agent_id} not found"
)));
}
let message = outcome.timed_out.then(|| {
format!(
"Timed out after {timeout_ms}ms waiting for agent {agent_id}. The agent may still be running."
)
});
let result = WaitResult {
status: outcome.status,
timed_out: outcome.timed_out,
};
let result = WaitResult { status, timed_out };
let content = serde_json::to_string(&result).map_err(|err| {
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
})?;
Ok(ToolOutput::Function {
content,
success: Some(!outcome.timed_out),
success: Some(!result.timed_out),
content_items: None,
})
}
@@ -264,6 +295,7 @@ mod tests {
use crate::config::types::ShellEnvironmentPolicy;
use crate::function_tool::FunctionCallError;
use crate::protocol::AskForApproval;
use crate::protocol::Op;
use crate::protocol::SandboxPolicy;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::ThreadId;
@@ -271,7 +303,9 @@ mod tests {
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::timeout;
fn invocation(
session: Arc<crate::codex::Session>,
@@ -475,6 +509,83 @@ mod tests {
assert!(msg.starts_with("invalid agent id invalid:"));
}
#[tokio::test]
async fn wait_times_out_when_status_is_not_final() {
let (mut session, turn) = make_session_and_context().await;
let manager = thread_manager();
session.services.agent_control = manager.agent_control();
let config = turn.client.config().as_ref().clone();
let thread = manager.start_thread(config).await.expect("start thread");
let agent_id = thread.thread_id;
let invocation = invocation(
Arc::new(session),
Arc::new(turn),
"wait",
function_payload(json!({"id": agent_id.to_string(), "timeout_ms": 10})),
);
let output = CollabHandler
.handle(invocation)
.await
.expect("wait should succeed");
let ToolOutput::Function {
content, success, ..
} = output
else {
panic!("expected function output");
};
assert_eq!(content, r#"{"status":"pending_init","timed_out":true}"#);
assert_eq!(success, Some(false));
let _ = thread
.thread
.submit(Op::Shutdown {})
.await
.expect("shutdown should submit");
}
#[tokio::test]
async fn wait_returns_final_status_without_timeout() {
let (mut session, turn) = make_session_and_context().await;
let manager = thread_manager();
session.services.agent_control = manager.agent_control();
let config = turn.client.config().as_ref().clone();
let thread = manager.start_thread(config).await.expect("start thread");
let agent_id = thread.thread_id;
let mut status_rx = manager
.agent_control()
.subscribe_status(agent_id)
.await
.expect("subscribe should succeed");
let _ = thread
.thread
.submit(Op::Shutdown {})
.await
.expect("shutdown should submit");
let _ = timeout(Duration::from_secs(1), status_rx.changed())
.await
.expect("shutdown status should arrive");
let invocation = invocation(
Arc::new(session),
Arc::new(turn),
"wait",
function_payload(json!({"id": agent_id.to_string(), "timeout_ms": 1000})),
);
let output = CollabHandler
.handle(invocation)
.await
.expect("wait should succeed");
let ToolOutput::Function {
content, success, ..
} = output
else {
panic!("expected function output");
};
assert_eq!(content, r#"{"status":"shutdown","timed_out":false}"#);
assert_eq!(success, Some(true));
}
#[tokio::test]
async fn close_agent_reports_not_implemented() {
let (session, turn) = make_session_and_context().await;

View File

@@ -699,7 +699,7 @@ pub enum AgentStatus {
Completed(Option<String>),
/// Agent encountered an error.
Errored(String),
/// Agent has been shutdowned.
/// Agent has been shutdown.
Shutdown,
/// Agent is not found.
NotFound,