mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
feat: add wait tool implementation for collab (#9088)
Add the implementation of the `wait` tool. For this tool, every status other than `PendingInit` and `Running` is considered terminal. A `wait` tool call returns either when the given timeout elapses or when the agent reaches a terminal status. A few points to note:
* A channel is used instead of polling to prevent some races (just looping on `get_status()` could "miss" a terminal status).
* The order of operations is very important: we need to first subscribe and only then check the last known status, to prevent race conditions.
* If the channel gets dropped, we return an error on purpose.
This commit is contained in:
@@ -9,6 +9,7 @@ use codex_protocol::protocol::Op;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Weak;
|
||||
use tokio::sync::watch;
|
||||
|
||||
/// Control-plane handle for multi-agent operations.
|
||||
/// `AgentControl` is held by each session (via `SessionServices`). It provides capability to
|
||||
@@ -80,6 +81,16 @@ impl AgentControl {
|
||||
thread.agent_status().await
|
||||
}
|
||||
|
||||
/// Subscribe to status updates for `agent_id`, yielding the latest value and changes.
|
||||
pub(crate) async fn subscribe_status(
|
||||
&self,
|
||||
agent_id: ThreadId,
|
||||
) -> CodexResult<watch::Receiver<AgentStatus>> {
|
||||
let state = self.upgrade()?;
|
||||
let thread = state.get_thread(agent_id).await?;
|
||||
Ok(thread.subscribe_status())
|
||||
}
|
||||
|
||||
fn upgrade(&self) -> CodexResult<Arc<ThreadManagerState>> {
|
||||
self.manager
|
||||
.upgrade()
|
||||
@@ -275,6 +286,38 @@ mod tests {
|
||||
assert_eq!(status, AgentStatus::PendingInit);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_status_errors_for_missing_thread() {
|
||||
let harness = AgentControlHarness::new().await;
|
||||
let thread_id = ThreadId::new();
|
||||
let err = harness
|
||||
.control
|
||||
.subscribe_status(thread_id)
|
||||
.await
|
||||
.expect_err("subscribe_status should fail for missing thread");
|
||||
assert_matches!(err, CodexErr::ThreadNotFound(id) if id == thread_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_status_updates_on_shutdown() {
|
||||
let harness = AgentControlHarness::new().await;
|
||||
let (thread_id, thread) = harness.start_thread().await;
|
||||
let mut status_rx = harness
|
||||
.control
|
||||
.subscribe_status(thread_id)
|
||||
.await
|
||||
.expect("subscribe_status should succeed");
|
||||
assert_eq!(status_rx.borrow().clone(), AgentStatus::PendingInit);
|
||||
|
||||
let _ = thread
|
||||
.submit(Op::Shutdown {})
|
||||
.await
|
||||
.expect("shutdown should submit");
|
||||
|
||||
let _ = status_rx.changed().await;
|
||||
assert_eq!(status_rx.borrow().clone(), AgentStatus::Shutdown);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_prompt_submits_user_message() {
|
||||
let harness = AgentControlHarness::new().await;
|
||||
|
||||
@@ -13,3 +13,7 @@ pub(crate) fn agent_status_from_event(msg: &EventMsg) -> Option<AgentStatus> {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_final(status: &AgentStatus) -> bool {
|
||||
!matches!(status, AgentStatus::PendingInit | AgentStatus::Running)
|
||||
}
|
||||
|
||||
@@ -164,6 +164,7 @@ use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_utils_readiness::Readiness;
|
||||
use codex_utils_readiness::ReadinessFlag;
|
||||
use tokio::sync::watch;
|
||||
|
||||
/// The high-level interface to the Codex system.
|
||||
/// It operates as a queue pair where you send submissions and receive events.
|
||||
@@ -172,7 +173,7 @@ pub struct Codex {
|
||||
pub(crate) tx_sub: Sender<Submission>,
|
||||
pub(crate) rx_event: Receiver<Event>,
|
||||
// Last known status of the agent.
|
||||
pub(crate) agent_status: Arc<RwLock<AgentStatus>>,
|
||||
pub(crate) agent_status: watch::Receiver<AgentStatus>,
|
||||
}
|
||||
|
||||
/// Wrapper returned by [`Codex::spawn`] containing the spawned [`Codex`],
|
||||
@@ -275,7 +276,7 @@ impl Codex {
|
||||
|
||||
// Generate a unique ID for the lifetime of this Codex session.
|
||||
let session_source_clone = session_configuration.session_source.clone();
|
||||
let agent_status = Arc::new(RwLock::new(AgentStatus::PendingInit));
|
||||
let (agent_status_tx, agent_status_rx) = watch::channel(AgentStatus::PendingInit);
|
||||
|
||||
let session = Session::new(
|
||||
session_configuration,
|
||||
@@ -284,7 +285,7 @@ impl Codex {
|
||||
models_manager.clone(),
|
||||
exec_policy,
|
||||
tx_event.clone(),
|
||||
Arc::clone(&agent_status),
|
||||
agent_status_tx.clone(),
|
||||
conversation_history,
|
||||
session_source_clone,
|
||||
skills_manager,
|
||||
@@ -303,7 +304,7 @@ impl Codex {
|
||||
next_id: AtomicU64::new(0),
|
||||
tx_sub,
|
||||
rx_event,
|
||||
agent_status,
|
||||
agent_status: agent_status_rx,
|
||||
};
|
||||
|
||||
#[allow(deprecated)]
|
||||
@@ -345,8 +346,7 @@ impl Codex {
|
||||
}
|
||||
|
||||
pub(crate) async fn agent_status(&self) -> AgentStatus {
|
||||
let status = self.agent_status.read().await;
|
||||
status.clone()
|
||||
self.agent_status.borrow().clone()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -356,7 +356,7 @@ impl Codex {
|
||||
pub(crate) struct Session {
|
||||
conversation_id: ThreadId,
|
||||
tx_event: Sender<Event>,
|
||||
agent_status: Arc<RwLock<AgentStatus>>,
|
||||
agent_status: watch::Sender<AgentStatus>,
|
||||
state: Mutex<SessionState>,
|
||||
/// The set of enabled features should be invariant for the lifetime of the
|
||||
/// session.
|
||||
@@ -557,7 +557,7 @@ impl Session {
|
||||
models_manager: Arc<ModelsManager>,
|
||||
exec_policy: ExecPolicyManager,
|
||||
tx_event: Sender<Event>,
|
||||
agent_status: Arc<RwLock<AgentStatus>>,
|
||||
agent_status: watch::Sender<AgentStatus>,
|
||||
initial_history: InitialHistory,
|
||||
session_source: SessionSource,
|
||||
skills_manager: Arc<SkillsManager>,
|
||||
@@ -703,7 +703,7 @@ impl Session {
|
||||
let sess = Arc::new(Session {
|
||||
conversation_id,
|
||||
tx_event: tx_event.clone(),
|
||||
agent_status: Arc::clone(&agent_status),
|
||||
agent_status,
|
||||
state: Mutex::new(state),
|
||||
features: config.features.clone(),
|
||||
active_turn: Mutex::new(None),
|
||||
@@ -1026,8 +1026,7 @@ impl Session {
|
||||
pub(crate) async fn send_event_raw(&self, event: Event) {
|
||||
// Record the last known agent status.
|
||||
if let Some(status) = agent_status_from_event(&event.msg) {
|
||||
let mut guard = self.agent_status.write().await;
|
||||
*guard = status;
|
||||
self.agent_status.send_replace(status);
|
||||
}
|
||||
// Persist the event into rollout (recorder filters as needed)
|
||||
let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
|
||||
@@ -1045,8 +1044,7 @@ impl Session {
|
||||
pub(crate) async fn send_event_raw_flushed(&self, event: Event) {
|
||||
// Record the last known agent status.
|
||||
if let Some(status) = agent_status_from_event(&event.msg) {
|
||||
let mut guard = self.agent_status.write().await;
|
||||
*guard = status;
|
||||
self.agent_status.send_replace(status);
|
||||
}
|
||||
self.persist_rollout_items(&[RolloutItem::EventMsg(event.msg.clone())])
|
||||
.await;
|
||||
@@ -3494,7 +3492,7 @@ mod tests {
|
||||
));
|
||||
let agent_control = AgentControl::default();
|
||||
let exec_policy = ExecPolicyManager::default();
|
||||
let agent_status = Arc::new(RwLock::new(AgentStatus::PendingInit));
|
||||
let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit);
|
||||
let model = ModelsManager::get_model_offline(config.model.as_deref());
|
||||
let session_configuration = SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
@@ -3557,7 +3555,7 @@ mod tests {
|
||||
let session = Session {
|
||||
conversation_id,
|
||||
tx_event,
|
||||
agent_status: Arc::clone(&agent_status),
|
||||
agent_status: agent_status_tx,
|
||||
state: Mutex::new(state),
|
||||
features: config.features.clone(),
|
||||
active_turn: Mutex::new(None),
|
||||
@@ -3588,7 +3586,7 @@ mod tests {
|
||||
));
|
||||
let agent_control = AgentControl::default();
|
||||
let exec_policy = ExecPolicyManager::default();
|
||||
let agent_status = Arc::new(RwLock::new(AgentStatus::PendingInit));
|
||||
let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit);
|
||||
let model = ModelsManager::get_model_offline(config.model.as_deref());
|
||||
let session_configuration = SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
@@ -3651,7 +3649,7 @@ mod tests {
|
||||
let session = Arc::new(Session {
|
||||
conversation_id,
|
||||
tx_event,
|
||||
agent_status: Arc::clone(&agent_status),
|
||||
agent_status: agent_status_tx,
|
||||
state: Mutex::new(state),
|
||||
features: config.features.clone(),
|
||||
active_turn: Mutex::new(None),
|
||||
|
||||
@@ -87,7 +87,7 @@ pub(crate) async fn run_codex_thread_interactive(
|
||||
next_id: AtomicU64::new(0),
|
||||
tx_sub: tx_ops,
|
||||
rx_event: rx_sub,
|
||||
agent_status: Arc::clone(&codex.agent_status),
|
||||
agent_status: codex.agent_status.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -129,7 +129,7 @@ pub(crate) async fn run_codex_thread_one_shot(
|
||||
// Bridge events so we can observe completion and shut down automatically.
|
||||
let (tx_bridge, rx_bridge) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
|
||||
let ops_tx = io.tx_sub.clone();
|
||||
let agent_status = Arc::clone(&io.agent_status);
|
||||
let agent_status = io.agent_status.clone();
|
||||
let io_for_bridge = io;
|
||||
tokio::spawn(async move {
|
||||
while let Ok(event) = io_for_bridge.next_event().await {
|
||||
@@ -363,20 +363,23 @@ mod tests {
|
||||
use super::*;
|
||||
use async_channel::bounded;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::AgentStatus;
|
||||
use codex_protocol::protocol::RawResponseItemEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::sync::watch;
|
||||
|
||||
#[tokio::test]
|
||||
async fn forward_events_cancelled_while_send_blocked_shuts_down_delegate() {
|
||||
let (tx_events, rx_events) = bounded(1);
|
||||
let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY);
|
||||
let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit);
|
||||
let codex = Arc::new(Codex {
|
||||
next_id: AtomicU64::new(0),
|
||||
tx_sub,
|
||||
rx_event: rx_events,
|
||||
agent_status: Default::default(),
|
||||
agent_status,
|
||||
});
|
||||
|
||||
let (session, ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx().await;
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::protocol::Event;
|
||||
use crate::protocol::Op;
|
||||
use crate::protocol::Submission;
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::watch;
|
||||
|
||||
pub struct CodexThread {
|
||||
codex: Codex,
|
||||
@@ -38,6 +39,10 @@ impl CodexThread {
|
||||
self.codex.agent_status().await
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe_status(&self) -> watch::Receiver<AgentStatus> {
|
||||
self.codex.agent_status.clone()
|
||||
}
|
||||
|
||||
pub fn rollout_path(&self) -> PathBuf {
|
||||
self.rollout_path.clone()
|
||||
}
|
||||
|
||||
@@ -143,11 +143,14 @@ mod send_input {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
mod wait {
|
||||
use super::*;
|
||||
use crate::agent::status::is_final;
|
||||
use crate::codex::Session;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::timeout_at;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WaitArgs {
|
||||
@@ -168,40 +171,68 @@ mod wait {
|
||||
let args: WaitArgs = parse_arguments(&arguments)?;
|
||||
let agent_id = agent_id(&args.id)?;
|
||||
|
||||
// Validate timeout.
|
||||
let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
|
||||
if timeout_ms <= 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"timeout_ms must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
let timeout_ms = timeout_ms.min(MAX_WAIT_TIMEOUT_MS);
|
||||
// TODO(jif) actual implementation
|
||||
let outcome = WaitResult {
|
||||
status: Default::default(),
|
||||
timed_out: false,
|
||||
let timeout_ms = match timeout_ms {
|
||||
ms if ms <= 0 => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"timeout_ms must be greater than zero".to_owned(),
|
||||
));
|
||||
}
|
||||
ms => ms.min(MAX_WAIT_TIMEOUT_MS),
|
||||
};
|
||||
|
||||
if matches!(outcome.status, AgentStatus::NotFound) {
|
||||
let mut status_rx = session
|
||||
.services
|
||||
.agent_control
|
||||
.subscribe_status(agent_id)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
CodexErr::ThreadNotFound(id) => {
|
||||
FunctionCallError::RespondToModel(format!("agent with id {id} not found"))
|
||||
}
|
||||
err => FunctionCallError::Fatal(err.to_string()),
|
||||
})?;
|
||||
|
||||
// Get last known status.
|
||||
let mut status = status_rx.borrow_and_update().clone();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
|
||||
let timed_out = loop {
|
||||
if is_final(&status) {
|
||||
break false;
|
||||
}
|
||||
|
||||
match timeout_at(deadline, status_rx.changed()).await {
|
||||
Ok(Ok(())) => status = status_rx.borrow().clone(),
|
||||
Ok(Err(_)) => {
|
||||
let last_status = session.services.agent_control.get_status(agent_id).await;
|
||||
if last_status != AgentStatus::NotFound {
|
||||
// On-purpose we keep the last known status if the agent gets dropped. This
|
||||
// event is not supposed to happen.
|
||||
status = last_status;
|
||||
}
|
||||
break false;
|
||||
}
|
||||
Err(_) => break true,
|
||||
}
|
||||
};
|
||||
|
||||
if matches!(status, AgentStatus::NotFound) {
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"agent with id {agent_id} not found"
|
||||
)));
|
||||
}
|
||||
|
||||
let message = outcome.timed_out.then(|| {
|
||||
format!(
|
||||
"Timed out after {timeout_ms}ms waiting for agent {agent_id}. The agent may still be running."
|
||||
)
|
||||
});
|
||||
let result = WaitResult {
|
||||
status: outcome.status,
|
||||
timed_out: outcome.timed_out,
|
||||
};
|
||||
let result = WaitResult { status, timed_out };
|
||||
|
||||
let content = serde_json::to_string(&result).map_err(|err| {
|
||||
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
|
||||
})?;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
success: Some(!outcome.timed_out),
|
||||
success: Some(!result.timed_out),
|
||||
content_items: None,
|
||||
})
|
||||
}
|
||||
@@ -264,6 +295,7 @@ mod tests {
|
||||
use crate::config::types::ShellEnvironmentPolicy;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::Op;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use codex_protocol::ThreadId;
|
||||
@@ -271,7 +303,9 @@ mod tests {
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::timeout;
|
||||
|
||||
fn invocation(
|
||||
session: Arc<crate::codex::Session>,
|
||||
@@ -475,6 +509,83 @@ mod tests {
|
||||
assert!(msg.starts_with("invalid agent id invalid:"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_times_out_when_status_is_not_final() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
let manager = thread_manager();
|
||||
session.services.agent_control = manager.agent_control();
|
||||
let config = turn.client.config().as_ref().clone();
|
||||
let thread = manager.start_thread(config).await.expect("start thread");
|
||||
let agent_id = thread.thread_id;
|
||||
let invocation = invocation(
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait",
|
||||
function_payload(json!({"id": agent_id.to_string(), "timeout_ms": 10})),
|
||||
);
|
||||
let output = CollabHandler
|
||||
.handle(invocation)
|
||||
.await
|
||||
.expect("wait should succeed");
|
||||
let ToolOutput::Function {
|
||||
content, success, ..
|
||||
} = output
|
||||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
assert_eq!(content, r#"{"status":"pending_init","timed_out":true}"#);
|
||||
assert_eq!(success, Some(false));
|
||||
|
||||
let _ = thread
|
||||
.thread
|
||||
.submit(Op::Shutdown {})
|
||||
.await
|
||||
.expect("shutdown should submit");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_returns_final_status_without_timeout() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
let manager = thread_manager();
|
||||
session.services.agent_control = manager.agent_control();
|
||||
let config = turn.client.config().as_ref().clone();
|
||||
let thread = manager.start_thread(config).await.expect("start thread");
|
||||
let agent_id = thread.thread_id;
|
||||
let mut status_rx = manager
|
||||
.agent_control()
|
||||
.subscribe_status(agent_id)
|
||||
.await
|
||||
.expect("subscribe should succeed");
|
||||
|
||||
let _ = thread
|
||||
.thread
|
||||
.submit(Op::Shutdown {})
|
||||
.await
|
||||
.expect("shutdown should submit");
|
||||
let _ = timeout(Duration::from_secs(1), status_rx.changed())
|
||||
.await
|
||||
.expect("shutdown status should arrive");
|
||||
|
||||
let invocation = invocation(
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait",
|
||||
function_payload(json!({"id": agent_id.to_string(), "timeout_ms": 1000})),
|
||||
);
|
||||
let output = CollabHandler
|
||||
.handle(invocation)
|
||||
.await
|
||||
.expect("wait should succeed");
|
||||
let ToolOutput::Function {
|
||||
content, success, ..
|
||||
} = output
|
||||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
assert_eq!(content, r#"{"status":"shutdown","timed_out":false}"#);
|
||||
assert_eq!(success, Some(true));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_agent_reports_not_implemented() {
|
||||
let (session, turn) = make_session_and_context().await;
|
||||
|
||||
@@ -699,7 +699,7 @@ pub enum AgentStatus {
|
||||
Completed(Option<String>),
|
||||
/// Agent encountered an error.
|
||||
Errored(String),
|
||||
/// Agent has been shutdowned.
|
||||
/// Agent has been shutdown.
|
||||
Shutdown,
|
||||
/// Agent is not found.
|
||||
NotFound,
|
||||
|
||||
Reference in New Issue
Block a user