Files
codex/codex-rs/core/src/tools/handlers/collab.rs
Dylan Hurd 675f165c56 fix(core) Preserve base_instructions in SessionMeta (#9427)
## Summary
This PR consolidates base_instructions onto SessionMeta /
SessionConfiguration, so we ensure `base_instructions` is set once per
session and should be (mostly) immutable, unless:
- overridden by config on resume / fork
- sub-agent tasks, like review or collab


In a future PR, we should convert all references to `base_instructions`
to consistently use the typed struct, so it's less likely that we put
other strings there. See #9423. However, this PR is already quite
complex, so I'm deferring that to a follow-up.

## Testing
- [x] Added a resume test to assert that instructions are preserved. In
particular, `resume_switches_models_preserves_base_instructions` fails
against main.

Existing test coverage that asserts base instructions are preserved
across multiple requests in a session:
- Manual compact keeps baseline instructions:
core/tests/suite/compact.rs:199
- Auto-compact keeps baseline instructions:
core/tests/suite/compact.rs:1142
- Prompt caching reuses the same instructions across two requests:
core/tests/suite/prompt_caching.rs:150 and
core/tests/suite/prompt_caching.rs:157
- Prompt caching with explicit expected string across two requests:
core/tests/suite/prompt_caching.rs:213 and
core/tests/suite/prompt_caching.rs:222
- Resume with model switch keeps original instructions:
core/tests/suite/resume.rs:136
- Compact/resume/fork uses request 0 instructions for later expected
payloads: core/tests/suite/compact_resume_fork.rs:215
2026-01-19 21:59:36 -08:00

1111 lines
37 KiB
Rust

use crate::agent::AgentStatus;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::config::Config;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::handlers::parse_arguments;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use async_trait::async_trait;
use codex_protocol::ThreadId;
use codex_protocol::models::BaseInstructions;
use codex_protocol::protocol::CollabAgentInteractionBeginEvent;
use codex_protocol::protocol::CollabAgentInteractionEndEvent;
use codex_protocol::protocol::CollabAgentSpawnBeginEvent;
use codex_protocol::protocol::CollabAgentSpawnEndEvent;
use codex_protocol::protocol::CollabCloseBeginEvent;
use codex_protocol::protocol::CollabCloseEndEvent;
use codex_protocol::protocol::CollabWaitingBeginEvent;
use codex_protocol::protocol::CollabWaitingEndEvent;
use serde::Deserialize;
use serde::Serialize;
/// Tool handler for the collab tools. Its `ToolHandler::handle` implementation
/// dispatches `spawn_agent`, `send_input`, `wait`, and `close_agent` calls.
pub struct CollabHandler;
/// Timeout, in milliseconds, the `wait` tool uses when a call omits `timeout_ms`.
pub(crate) const DEFAULT_WAIT_TIMEOUT_MS: i64 = 30_000;
/// Largest timeout, in milliseconds, the `wait` tool honors; larger requests
/// are clamped down to this value.
pub(crate) const MAX_WAIT_TIMEOUT_MS: i64 = 300_000;
/// Arguments deserialized from a `close_agent` tool call.
#[derive(Debug, Deserialize)]
struct CloseAgentArgs {
/// Agent identifier as text; `close_agent::handle` parses it with `agent_id`.
id: String,
}
#[async_trait]
impl ToolHandler for CollabHandler {
    /// Reports this handler as a function-style tool.
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    /// Returns `true` only for `ToolPayload::Function` payloads.
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(payload, ToolPayload::Function { .. })
    }

    /// Routes a collab tool call to the module that implements it.
    ///
    /// # Errors
    ///
    /// Returns `FunctionCallError::RespondToModel` when the payload is not a
    /// function call or when `tool_name` is not one of the collab tools;
    /// otherwise forwards whatever the selected tool returns.
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            tool_name,
            payload,
            call_id,
            ..
        } = invocation;

        // The payload kind is checked before the tool name is even looked at.
        let ToolPayload::Function { arguments } = payload else {
            return Err(FunctionCallError::RespondToModel(
                "collab handler received unsupported payload".to_string(),
            ));
        };

        match tool_name.as_str() {
            "spawn_agent" => spawn::handle(session, turn, call_id, arguments).await,
            "send_input" => send_input::handle(session, turn, call_id, arguments).await,
            "wait" => wait::handle(session, turn, call_id, arguments).await,
            "close_agent" => close_agent::handle(session, turn, call_id, arguments).await,
            other => {
                let message = format!("unsupported collab tool {other}");
                Err(FunctionCallError::RespondToModel(message))
            }
        }
    }
}
mod spawn {
    use super::*;
    use crate::agent::AgentRole;
    use std::sync::Arc;

    /// Arguments deserialized from a `spawn_agent` tool call.
    #[derive(Debug, Deserialize)]
    struct SpawnAgentArgs {
        /// Initial prompt for the new agent; rejected when blank.
        message: String,
        /// Optional role; `AgentRole::Default` is used when absent.
        agent_type: Option<AgentRole>,
    }

    /// JSON payload returned to the model after a successful spawn.
    #[derive(Debug, Serialize)]
    struct SpawnAgentResult {
        /// String form of the new agent's thread id.
        agent_id: String,
    }

    /// Implements the `spawn_agent` tool.
    ///
    /// Emits a `CollabAgentSpawnBeginEvent` once the arguments are validated
    /// and a `CollabAgentSpawnEndEvent` after the spawn attempt, whether it
    /// succeeded or not. On success the output is a JSON object holding the
    /// new agent id.
    ///
    /// # Errors
    ///
    /// Responds to the model when the message is blank, when the spawn config
    /// cannot be built, when the role cannot be applied, or when the spawn
    /// itself fails. Serialization failure of the result is fatal.
    pub async fn handle(
        session: Arc<Session>,
        turn: Arc<TurnContext>,
        call_id: String,
        arguments: String,
    ) -> Result<ToolOutput, FunctionCallError> {
        let args: SpawnAgentArgs = parse_arguments(&arguments)?;
        let SpawnAgentArgs {
            message: prompt,
            agent_type,
        } = args;
        let role = agent_type.unwrap_or(AgentRole::Default);

        if prompt.trim().is_empty() {
            return Err(FunctionCallError::RespondToModel(
                "Empty message can't be sent to an agent".to_string(),
            ));
        }

        let begin_event = CollabAgentSpawnBeginEvent {
            call_id: call_id.clone(),
            sender_thread_id: session.conversation_id,
            prompt: prompt.clone(),
        };
        session.send_event(&turn, begin_event.into()).await;

        let mut config =
            build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?;
        role.apply_to_config(&mut config)
            .map_err(FunctionCallError::RespondToModel)?;

        let agent_control = &session.services.agent_control;
        let spawn_result = agent_control
            .spawn_agent(config, prompt.clone())
            .await
            .map_err(collab_spawn_error);

        // The end event is sent even when the spawn failed; in that case there
        // is no thread id and the status is reported as `NotFound`.
        let new_thread_id = spawn_result.as_ref().ok().copied();
        let status = match new_thread_id {
            Some(thread_id) => agent_control.get_status(thread_id).await,
            None => AgentStatus::NotFound,
        };
        let end_event = CollabAgentSpawnEndEvent {
            call_id,
            sender_thread_id: session.conversation_id,
            new_thread_id,
            prompt,
            status,
        };
        session.send_event(&turn, end_event.into()).await;

        // Only now surface a spawn failure to the caller.
        let payload = SpawnAgentResult {
            agent_id: spawn_result?.to_string(),
        };
        let content = serde_json::to_string(&payload).map_err(|err| {
            FunctionCallError::Fatal(format!("failed to serialize spawn_agent result: {err}"))
        })?;

        Ok(ToolOutput::Function {
            content,
            success: Some(true),
            content_items: None,
        })
    }
}
mod send_input {
    use super::*;
    use std::sync::Arc;

    /// Arguments deserialized from a `send_input` tool call.
    #[derive(Debug, Deserialize)]
    struct SendInputArgs {
        /// Target agent identifier, parsed with `agent_id`.
        id: String,
        /// Prompt to deliver; rejected when blank.
        message: String,
        /// When set, the agent is interrupted before the prompt is sent.
        /// Defaults to `false` if the field is missing.
        #[serde(default)]
        interrupt: bool,
    }

    /// JSON payload returned to the model after the prompt was submitted.
    #[derive(Debug, Serialize)]
    struct SendInputResult {
        /// Value returned by `send_prompt` for the submitted prompt.
        submission_id: String,
    }

    /// Implements the `send_input` tool.
    ///
    /// Validation runs first (id, then message). An optional interrupt is
    /// issued next, then the begin event, the prompt submission, and the end
    /// event carrying the agent's status after the submission attempt.
    ///
    /// # Errors
    ///
    /// Responds to the model when the id is invalid, the message is blank, the
    /// interrupt fails, or the prompt cannot be delivered. Serialization
    /// failure of the result is fatal.
    pub async fn handle(
        session: Arc<Session>,
        turn: Arc<TurnContext>,
        call_id: String,
        arguments: String,
    ) -> Result<ToolOutput, FunctionCallError> {
        let args: SendInputArgs = parse_arguments(&arguments)?;
        let SendInputArgs {
            id,
            message: prompt,
            interrupt,
        } = args;

        let receiver_thread_id = agent_id(&id)?;
        if prompt.trim().is_empty() {
            return Err(FunctionCallError::RespondToModel(
                "Empty message can't be sent to an agent".to_string(),
            ));
        }

        let agent_control = &session.services.agent_control;

        // An interrupt failure aborts the call before any event is emitted.
        if interrupt {
            agent_control
                .interrupt_agent(receiver_thread_id)
                .await
                .map_err(|err| collab_agent_error(receiver_thread_id, err))?;
        }

        let begin_event = CollabAgentInteractionBeginEvent {
            call_id: call_id.clone(),
            sender_thread_id: session.conversation_id,
            receiver_thread_id,
            prompt: prompt.clone(),
        };
        session.send_event(&turn, begin_event.into()).await;

        let send_result = agent_control
            .send_prompt(receiver_thread_id, prompt.clone())
            .await
            .map_err(|err| collab_agent_error(receiver_thread_id, err));

        // The end event is emitted regardless of whether the send succeeded.
        let status = agent_control.get_status(receiver_thread_id).await;
        let end_event = CollabAgentInteractionEndEvent {
            call_id,
            sender_thread_id: session.conversation_id,
            receiver_thread_id,
            prompt,
            status,
        };
        session.send_event(&turn, end_event.into()).await;

        let payload = SendInputResult {
            submission_id: send_result?,
        };
        let content = serde_json::to_string(&payload).map_err(|err| {
            FunctionCallError::Fatal(format!("failed to serialize send_input result: {err}"))
        })?;

        Ok(ToolOutput::Function {
            content,
            success: Some(true),
            content_items: None,
        })
    }
}
mod wait {
    use super::*;
    use crate::agent::status::is_final;
    use futures::FutureExt;
    use futures::StreamExt;
    use futures::stream::FuturesUnordered;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::watch::Receiver;
    use tokio::time::Instant;
    use tokio::time::timeout_at;

    /// Arguments deserialized from a `wait` tool call.
    #[derive(Debug, Deserialize)]
    struct WaitArgs {
        /// Agent identifiers to wait on; must be non-empty.
        ids: Vec<String>,
        /// Optional timeout in milliseconds. Defaults to
        /// `DEFAULT_WAIT_TIMEOUT_MS` and is clamped to `MAX_WAIT_TIMEOUT_MS`.
        timeout_ms: Option<i64>,
    }

    /// JSON payload returned to the model by the `wait` tool.
    #[derive(Debug, Serialize)]
    struct WaitResult {
        /// Final statuses collected, keyed by agent id.
        status: HashMap<ThreadId, AgentStatus>,
        /// `true` when no final status was collected.
        timed_out: bool,
    }

    /// Implements the `wait` tool.
    ///
    /// If any requested agent is already in a final status (or cannot be found,
    /// which is reported as `AgentStatus::NotFound`), those statuses are
    /// returned immediately. Otherwise the call waits until the first agent
    /// reaches a final status or the timeout elapses, then also collects any
    /// other agents that are already finished at that moment.
    ///
    /// A `CollabWaitingBeginEvent` is emitted after validation and a
    /// `CollabWaitingEndEvent` before returning, including on the subscribe
    /// error path.
    ///
    /// # Errors
    ///
    /// Responds to the model when `ids` is empty, an id is invalid,
    /// `timeout_ms` is not positive, or subscribing to an agent fails with
    /// anything other than `CodexErr::ThreadNotFound`. Serialization failure
    /// of the result is fatal.
    pub async fn handle(
        session: Arc<Session>,
        turn: Arc<TurnContext>,
        call_id: String,
        arguments: String,
    ) -> Result<ToolOutput, FunctionCallError> {
        let args: WaitArgs = parse_arguments(&arguments)?;
        if args.ids.is_empty() {
            return Err(FunctionCallError::RespondToModel(
                "ids must be non-empty".to_owned(),
            ));
        }
        let receiver_thread_ids = args
            .ids
            .iter()
            .map(|id| agent_id(id))
            .collect::<Result<Vec<_>, _>>()?;

        // Validate the timeout: reject non-positive values, clamp large ones.
        let requested_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
        if requested_ms <= 0 {
            return Err(FunctionCallError::RespondToModel(
                "timeout_ms must be greater than zero".to_owned(),
            ));
        }
        let timeout_ms = requested_ms.min(MAX_WAIT_TIMEOUT_MS);

        session
            .send_event(
                &turn,
                CollabWaitingBeginEvent {
                    sender_thread_id: session.conversation_id,
                    receiver_thread_ids: receiver_thread_ids.clone(),
                    call_id: call_id.clone(),
                }
                .into(),
            )
            .await;

        // Subscribe to every agent, noting the ones that are already finished.
        let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
        let mut initial_final_statuses = Vec::new();
        for id in &receiver_thread_ids {
            match session.services.agent_control.subscribe_status(*id).await {
                Ok(rx) => {
                    let status = rx.borrow().clone();
                    if is_final(&status) {
                        initial_final_statuses.push((*id, status));
                    }
                    status_rxs.push((*id, rx));
                }
                Err(CodexErr::ThreadNotFound(_)) => {
                    // A missing agent counts as a final outcome for the caller.
                    initial_final_statuses.push((*id, AgentStatus::NotFound));
                }
                Err(err) => {
                    // Close the begin/end event pair before bailing out.
                    let mut statuses = HashMap::with_capacity(1);
                    statuses.insert(*id, session.services.agent_control.get_status(*id).await);
                    session
                        .send_event(
                            &turn,
                            CollabWaitingEndEvent {
                                sender_thread_id: session.conversation_id,
                                call_id: call_id.clone(),
                                statuses,
                            }
                            .into(),
                        )
                        .await;
                    return Err(collab_agent_error(*id, err));
                }
            }
        }

        let statuses = if !initial_final_statuses.is_empty() {
            initial_final_statuses
        } else {
            // Wait for the first agent to reach a final status.
            let mut futures = FuturesUnordered::new();
            for (id, rx) in status_rxs {
                futures.push(wait_for_final_status(session.clone(), id, rx));
            }
            let mut results = Vec::new();
            // `timeout_ms` is strictly positive here, so `unsigned_abs` is a
            // lossless conversion to `u64`.
            let deadline = Instant::now() + Duration::from_millis(timeout_ms.unsigned_abs());
            // Stops on timeout (`Err`) or when every future has completed
            // (`Ok(None)`). Futures resolving to `None` are skipped.
            while let Ok(Some(outcome)) = timeout_at(deadline, futures.next()).await {
                if let Some(result) = outcome {
                    results.push(result);
                    break;
                }
            }
            if !results.is_empty() {
                // Drain the unlikely last elements to prevent race: pick up
                // anything that is ready right now without waiting further.
                while let Some(Some(outcome)) = futures.next().now_or_never() {
                    results.extend(outcome);
                }
            }
            results
        };

        // Convert payload. The vector is consumed directly; only the emptiness
        // flag has to be captured beforehand.
        let timed_out = statuses.is_empty();
        let statuses_map: HashMap<_, _> = statuses.into_iter().collect();
        let result = WaitResult {
            status: statuses_map.clone(),
            timed_out,
        };

        // Final event emission.
        session
            .send_event(
                &turn,
                CollabWaitingEndEvent {
                    sender_thread_id: session.conversation_id,
                    call_id,
                    statuses: statuses_map,
                }
                .into(),
            )
            .await;

        let content = serde_json::to_string(&result).map_err(|err| {
            FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
        })?;

        Ok(ToolOutput::Function {
            content,
            success: None,
            content_items: None,
        })
    }

    /// Resolves once `thread_id` reaches a final status.
    ///
    /// Returns `Some((thread_id, status))` for the first final status seen on
    /// `status_rx`. If `status_rx.changed()` fails, the status is re-read
    /// through the agent control; the result is `None` when even that status
    /// is not final.
    async fn wait_for_final_status(
        session: Arc<Session>,
        thread_id: ThreadId,
        mut status_rx: Receiver<AgentStatus>,
    ) -> Option<(ThreadId, AgentStatus)> {
        loop {
            // Check the current value first so an already-final status returns
            // without waiting for a change notification.
            let status = status_rx.borrow().clone();
            if is_final(&status) {
                return Some((thread_id, status));
            }
            if status_rx.changed().await.is_err() {
                let latest = session.services.agent_control.get_status(thread_id).await;
                return is_final(&latest).then_some((thread_id, latest));
            }
        }
    }
}
pub mod close_agent {
    use super::*;
    use std::sync::Arc;

    /// JSON payload returned to the model by the `close_agent` tool.
    #[derive(Debug, Deserialize, Serialize)]
    pub(super) struct CloseAgentResult {
        /// Agent status read before the shutdown request was issued.
        pub(super) status: AgentStatus,
    }

    /// Implements the `close_agent` tool.
    ///
    /// Reads the agent's current status, requests a shutdown unless the status
    /// is already `AgentStatus::Shutdown`, and reports the status that was read
    /// beforehand. A `CollabCloseBeginEvent` / `CollabCloseEndEvent` pair
    /// brackets the operation, including on failure.
    ///
    /// # Errors
    ///
    /// Responds to the model when the id is invalid, when subscribing to the
    /// agent's status fails, or when the shutdown request fails. Serialization
    /// failure of the result is fatal.
    pub async fn handle(
        session: Arc<Session>,
        turn: Arc<TurnContext>,
        call_id: String,
        arguments: String,
    ) -> Result<ToolOutput, FunctionCallError> {
        let args: CloseAgentArgs = parse_arguments(&arguments)?;
        let target_id = agent_id(&args.id)?;
        let agent_control = &session.services.agent_control;

        let begin_event = CollabCloseBeginEvent {
            call_id: call_id.clone(),
            sender_thread_id: session.conversation_id,
            receiver_thread_id: target_id,
        };
        session.send_event(&turn, begin_event.into()).await;

        // The receiver is only used to read the current value and is dropped
        // as soon as this `match` finishes.
        let status = match agent_control.subscribe_status(target_id).await {
            Ok(mut status_rx) => status_rx.borrow_and_update().clone(),
            Err(err) => {
                let status = agent_control.get_status(target_id).await;
                let end_event = CollabCloseEndEvent {
                    call_id: call_id.clone(),
                    sender_thread_id: session.conversation_id,
                    receiver_thread_id: target_id,
                    status,
                };
                session.send_event(&turn, end_event.into()).await;
                return Err(collab_agent_error(target_id, err));
            }
        };

        // Skip the shutdown request for agents that are already shut down.
        let shutdown_result = if matches!(status, AgentStatus::Shutdown) {
            Ok(())
        } else {
            agent_control
                .shutdown_agent(target_id)
                .await
                .map(|_| ())
                .map_err(|err| collab_agent_error(target_id, err))
        };

        // The end event carries the pre-shutdown status and is sent even when
        // the shutdown request failed.
        let end_event = CollabCloseEndEvent {
            call_id,
            sender_thread_id: session.conversation_id,
            receiver_thread_id: target_id,
            status: status.clone(),
        };
        session.send_event(&turn, end_event.into()).await;
        shutdown_result?;

        let payload = CloseAgentResult { status };
        let content = serde_json::to_string(&payload).map_err(|err| {
            FunctionCallError::Fatal(format!("failed to serialize close_agent result: {err}"))
        })?;

        Ok(ToolOutput::Function {
            content,
            success: Some(true),
            content_items: None,
        })
    }
}
/// Parses the textual agent id supplied by the model into a `ThreadId`.
///
/// # Errors
///
/// Returns `FunctionCallError::RespondToModel` naming the offending id when
/// `ThreadId::from_string` rejects it.
fn agent_id(id: &str) -> Result<ThreadId, FunctionCallError> {
    match ThreadId::from_string(id) {
        Ok(thread_id) => Ok(thread_id),
        Err(e) => {
            let message = format!("invalid agent id {id}: {e:?}");
            Err(FunctionCallError::RespondToModel(message))
        }
    }
}
/// Converts a spawn failure into an error message for the model.
///
/// `CodexErr::UnsupportedOperation` becomes a fixed "manager unavailable"
/// message; every other error is reported with its `Display` text.
fn collab_spawn_error(err: CodexErr) -> FunctionCallError {
    let message = if matches!(err, CodexErr::UnsupportedOperation(_)) {
        "collab manager unavailable".to_string()
    } else {
        format!("collab spawn failed: {err}")
    };
    FunctionCallError::RespondToModel(message)
}
/// Converts a failure of an operation on an existing agent into an error
/// message for the model.
///
/// `agent_id` is only used for the `InternalAgentDied` message; the
/// `ThreadNotFound` message uses the id carried by the error itself.
fn collab_agent_error(agent_id: ThreadId, err: CodexErr) -> FunctionCallError {
    let message = match err {
        CodexErr::ThreadNotFound(id) => format!("agent with id {id} not found"),
        CodexErr::InternalAgentDied => format!("agent with id {agent_id} is closed"),
        CodexErr::UnsupportedOperation(_) => "collab manager unavailable".to_string(),
        err => format!("collab tool failed: {err}"),
    };
    FunctionCallError::RespondToModel(message)
}
/// Builds the `Config` used to spawn a sub-agent.
///
/// Starts from a clone of the turn client's config and overrides it with the
/// given base instructions plus the model settings, instructions, environment
/// and policies of the current turn.
///
/// # Errors
///
/// Returns `FunctionCallError::RespondToModel` when the turn's approval or
/// sandbox policy is rejected by the config.
fn build_agent_spawn_config(
    base_instructions: &BaseInstructions,
    turn: &TurnContext,
) -> Result<Config, FunctionCallError> {
    let client = &turn.client;
    let mut config: Config = (*client.config()).clone();

    // Model settings come from the turn's client.
    config.model = Some(client.get_model());
    config.model_provider = client.get_provider();
    config.model_reasoning_effort = client.get_reasoning_effort();
    config.model_reasoning_summary = client.get_reasoning_summary();

    // Instructions and prompts.
    config.base_instructions = Some(base_instructions.text.clone());
    config.developer_instructions = turn.developer_instructions.clone();
    config.user_instructions = turn.user_instructions.clone();
    config.compact_prompt = turn.compact_prompt.clone();

    // Execution environment.
    config.cwd = turn.cwd.clone();
    config.shell_environment_policy = turn.shell_environment_policy.clone();
    config.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone();

    // Policies go through `set`, which can reject a value.
    if let Err(err) = config.approval_policy.set(turn.approval_policy) {
        return Err(FunctionCallError::RespondToModel(format!(
            "approval_policy is invalid: {err}"
        )));
    }
    if let Err(err) = config.sandbox_policy.set(turn.sandbox_policy.clone()) {
        return Err(FunctionCallError::RespondToModel(format!(
            "sandbox_policy is invalid: {err}"
        )));
    }

    Ok(config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CodexAuth;
    use crate::ThreadManager;
    use crate::built_in_model_providers;
    use crate::codex::make_session_and_context;
    use crate::config::types::ShellEnvironmentPolicy;
    use crate::function_tool::FunctionCallError;
    use crate::protocol::AskForApproval;
    use crate::protocol::Op;
    use crate::protocol::SandboxPolicy;
    use crate::turn_diff_tracker::TurnDiffTracker;
    use codex_protocol::ThreadId;
    use pretty_assertions::assert_eq;
    use serde::Deserialize;
    use serde_json::json;
    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Mutex;
    use tokio::time::timeout;

    /// Builds a `ToolInvocation` for `tool_name` with a fixed call id and a
    /// fresh diff tracker.
    fn invocation(
        session: Arc<crate::codex::Session>,
        turn: Arc<TurnContext>,
        tool_name: &str,
        payload: ToolPayload,
    ) -> ToolInvocation {
        ToolInvocation {
            session,
            turn,
            tracker: Arc::new(Mutex::new(TurnDiffTracker::default())),
            call_id: "call-1".to_string(),
            tool_name: tool_name.to_string(),
            payload,
        }
    }

    /// Wraps JSON arguments in a function-call payload.
    fn function_payload(args: serde_json::Value) -> ToolPayload {
        ToolPayload::Function {
            arguments: args.to_string(),
        }
    }

    /// Creates a `ThreadManager` using a dummy API key and the built-in
    /// "openai" model provider.
    fn thread_manager() -> ThreadManager {
        ThreadManager::with_models_provider(
            CodexAuth::from_api_key("dummy"),
            built_in_model_providers()["openai"].clone(),
        )
    }

    /// A non-function payload is rejected before the tool name is considered.
    #[tokio::test]
    async fn handler_rejects_non_function_payloads() {
        let (session, turn) = make_session_and_context().await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "spawn_agent",
            ToolPayload::Custom {
                input: "hello".to_string(),
            },
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("payload should be rejected");
        };
        assert_eq!(
            err,
            FunctionCallError::RespondToModel(
                "collab handler received unsupported payload".to_string()
            )
        );
    }

    /// A tool name outside the collab set yields an "unsupported" error.
    #[tokio::test]
    async fn handler_rejects_unknown_tool() {
        let (session, turn) = make_session_and_context().await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "unknown_tool",
            function_payload(json!({})),
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("tool should be rejected");
        };
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("unsupported collab tool unknown_tool".to_string())
        );
    }

    /// `spawn_agent` refuses a whitespace-only message.
    #[tokio::test]
    async fn spawn_agent_rejects_empty_message() {
        let (session, turn) = make_session_and_context().await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "spawn_agent",
            function_payload(json!({"message": " "})),
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("empty message should be rejected");
        };
        assert_eq!(
            err,
            FunctionCallError::RespondToModel(
                "Empty message can't be sent to an agent".to_string()
            )
        );
    }

    /// Without a thread manager attached to the session, `spawn_agent` reports
    /// the collab manager as unavailable.
    #[tokio::test]
    async fn spawn_agent_errors_when_manager_dropped() {
        let (session, turn) = make_session_and_context().await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "spawn_agent",
            function_payload(json!({"message": "hello"})),
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("spawn should fail without a manager");
        };
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("collab manager unavailable".to_string())
        );
    }

    /// `send_input` refuses an empty message even when the id is valid.
    #[tokio::test]
    async fn send_input_rejects_empty_message() {
        let (session, turn) = make_session_and_context().await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "send_input",
            function_payload(json!({"id": ThreadId::new().to_string(), "message": ""})),
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("empty message should be rejected");
        };
        assert_eq!(
            err,
            FunctionCallError::RespondToModel(
                "Empty message can't be sent to an agent".to_string()
            )
        );
    }

    /// `send_input` reports an id that does not parse as a thread id.
    #[tokio::test]
    async fn send_input_rejects_invalid_id() {
        let (session, turn) = make_session_and_context().await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "send_input",
            function_payload(json!({"id": "not-a-uuid", "message": "hi"})),
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("invalid id should be rejected");
        };
        let FunctionCallError::RespondToModel(msg) = err else {
            panic!("expected respond-to-model error");
        };
        // Only the prefix is checked; the tail is the parser's debug output.
        assert!(msg.starts_with("invalid agent id not-a-uuid:"));
    }

    /// `send_input` to an id the manager does not know yields "not found".
    #[tokio::test]
    async fn send_input_reports_missing_agent() {
        let (mut session, turn) = make_session_and_context().await;
        let manager = thread_manager();
        session.services.agent_control = manager.agent_control();
        let agent_id = ThreadId::new();
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "send_input",
            function_payload(json!({"id": agent_id.to_string(), "message": "hi"})),
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("missing agent should be reported");
        };
        assert_eq!(
            err,
            FunctionCallError::RespondToModel(format!("agent with id {agent_id} not found"))
        );
    }

    /// With `interrupt: true`, the interrupt op is submitted before the
    /// user-input op.
    #[tokio::test]
    async fn send_input_interrupts_before_prompt() {
        let (mut session, turn) = make_session_and_context().await;
        let manager = thread_manager();
        session.services.agent_control = manager.agent_control();
        let config = turn.client.config().as_ref().clone();
        let thread = manager.start_thread(config).await.expect("start thread");
        let agent_id = thread.thread_id;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "send_input",
            function_payload(json!({
                "id": agent_id.to_string(),
                "message": "hi",
                "interrupt": true
            })),
        );
        CollabHandler
            .handle(invocation)
            .await
            .expect("send_input should succeed");

        let ops = manager.captured_ops();
        let ops_for_agent: Vec<&Op> = ops
            .iter()
            .filter_map(|(id, op)| (*id == agent_id).then_some(op))
            .collect();
        assert_eq!(ops_for_agent.len(), 2);
        assert!(matches!(ops_for_agent[0], Op::Interrupt));
        assert!(matches!(ops_for_agent[1], Op::UserInput { .. }));

        // Clean up the thread started for this test.
        let _ = thread
            .thread
            .submit(Op::Shutdown {})
            .await
            .expect("shutdown should submit");
    }

    /// Deserializable mirror of the `wait` tool's JSON output.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct WaitResult {
        status: HashMap<ThreadId, AgentStatus>,
        timed_out: bool,
    }

    /// `wait` rejects a zero timeout.
    #[tokio::test]
    async fn wait_rejects_non_positive_timeout() {
        let (session, turn) = make_session_and_context().await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "wait",
            function_payload(json!({
                "ids": [ThreadId::new().to_string()],
                "timeout_ms": 0
            })),
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("non-positive timeout should be rejected");
        };
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("timeout_ms must be greater than zero".to_string())
        );
    }

    /// `wait` reports an id that does not parse as a thread id.
    #[tokio::test]
    async fn wait_rejects_invalid_id() {
        let (session, turn) = make_session_and_context().await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "wait",
            function_payload(json!({"ids": ["invalid"]})),
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("invalid id should be rejected");
        };
        let FunctionCallError::RespondToModel(msg) = err else {
            panic!("expected respond-to-model error");
        };
        // Only the prefix is checked; the tail is the parser's debug output.
        assert!(msg.starts_with("invalid agent id invalid:"));
    }

    /// `wait` rejects an empty id list.
    #[tokio::test]
    async fn wait_rejects_empty_ids() {
        let (session, turn) = make_session_and_context().await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "wait",
            function_payload(json!({"ids": []})),
        );
        let Err(err) = CollabHandler.handle(invocation).await else {
            panic!("empty ids should be rejected");
        };
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("ids must be non-empty".to_string())
        );
    }

    /// Unknown agents are reported as `NotFound` without timing out.
    #[tokio::test]
    async fn wait_returns_not_found_for_missing_agents() {
        let (mut session, turn) = make_session_and_context().await;
        let manager = thread_manager();
        session.services.agent_control = manager.agent_control();
        let id_a = ThreadId::new();
        let id_b = ThreadId::new();
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "wait",
            function_payload(json!({
                "ids": [id_a.to_string(), id_b.to_string()],
                "timeout_ms": 1000
            })),
        );
        let output = CollabHandler
            .handle(invocation)
            .await
            .expect("wait should succeed");
        let ToolOutput::Function {
            content, success, ..
        } = output
        else {
            panic!("expected function output");
        };
        let result: WaitResult =
            serde_json::from_str(&content).expect("wait result should be json");
        assert_eq!(
            result,
            WaitResult {
                status: HashMap::from([
                    (id_a, AgentStatus::NotFound),
                    (id_b, AgentStatus::NotFound),
                ]),
                timed_out: false
            }
        );
        assert_eq!(success, None);
    }

    /// A running agent that never reaches a final status makes `wait` time out
    /// with an empty status map.
    #[tokio::test]
    async fn wait_times_out_when_status_is_not_final() {
        let (mut session, turn) = make_session_and_context().await;
        let manager = thread_manager();
        session.services.agent_control = manager.agent_control();
        let config = turn.client.config().as_ref().clone();
        let thread = manager.start_thread(config).await.expect("start thread");
        let agent_id = thread.thread_id;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "wait",
            function_payload(json!({
                "ids": [agent_id.to_string()],
                "timeout_ms": 10
            })),
        );
        let output = CollabHandler
            .handle(invocation)
            .await
            .expect("wait should succeed");
        let ToolOutput::Function {
            content, success, ..
        } = output
        else {
            panic!("expected function output");
        };
        let result: WaitResult =
            serde_json::from_str(&content).expect("wait result should be json");
        assert_eq!(
            result,
            WaitResult {
                status: HashMap::new(),
                timed_out: true
            }
        );
        assert_eq!(success, None);

        // Clean up the thread started for this test.
        let _ = thread
            .thread
            .submit(Op::Shutdown {})
            .await
            .expect("shutdown should submit");
    }

    /// An agent that is already shut down is reported immediately.
    #[tokio::test]
    async fn wait_returns_final_status_without_timeout() {
        let (mut session, turn) = make_session_and_context().await;
        let manager = thread_manager();
        session.services.agent_control = manager.agent_control();
        let config = turn.client.config().as_ref().clone();
        let thread = manager.start_thread(config).await.expect("start thread");
        let agent_id = thread.thread_id;

        // Shut the agent down and wait for the status change before calling
        // the tool, so the status is already final when `wait` subscribes.
        let mut status_rx = manager
            .agent_control()
            .subscribe_status(agent_id)
            .await
            .expect("subscribe should succeed");
        let _ = thread
            .thread
            .submit(Op::Shutdown {})
            .await
            .expect("shutdown should submit");
        let _ = timeout(Duration::from_secs(1), status_rx.changed())
            .await
            .expect("shutdown status should arrive");

        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "wait",
            function_payload(json!({
                "ids": [agent_id.to_string()],
                "timeout_ms": 1000
            })),
        );
        let output = CollabHandler
            .handle(invocation)
            .await
            .expect("wait should succeed");
        let ToolOutput::Function {
            content, success, ..
        } = output
        else {
            panic!("expected function output");
        };
        let result: WaitResult =
            serde_json::from_str(&content).expect("wait result should be json");
        assert_eq!(
            result,
            WaitResult {
                status: HashMap::from([(agent_id, AgentStatus::Shutdown)]),
                timed_out: false
            }
        );
        assert_eq!(success, None);
    }

    /// `close_agent` submits a shutdown op and returns the status observed
    /// before the shutdown.
    #[tokio::test]
    async fn close_agent_submits_shutdown_and_returns_status() {
        let (mut session, turn) = make_session_and_context().await;
        let manager = thread_manager();
        session.services.agent_control = manager.agent_control();
        let config = turn.client.config().as_ref().clone();
        let thread = manager.start_thread(config).await.expect("start thread");
        let agent_id = thread.thread_id;
        let status_before = manager.agent_control().get_status(agent_id).await;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "close_agent",
            function_payload(json!({"id": agent_id.to_string()})),
        );
        let output = CollabHandler
            .handle(invocation)
            .await
            .expect("close_agent should succeed");
        let ToolOutput::Function {
            content, success, ..
        } = output
        else {
            panic!("expected function output");
        };
        let result: close_agent::CloseAgentResult =
            serde_json::from_str(&content).expect("close_agent result should be json");
        assert_eq!(result.status, status_before);
        assert_eq!(success, Some(true));

        let ops = manager.captured_ops();
        let submitted_shutdown = ops
            .iter()
            .any(|(id, op)| *id == agent_id && matches!(op, Op::Shutdown));
        assert!(
            submitted_shutdown,
            "close_agent should submit Op::Shutdown to the agent"
        );

        let status_after = manager.agent_control().get_status(agent_id).await;
        assert_eq!(status_after, AgentStatus::NotFound);
    }

    /// The spawn config copies every overridden field from the turn context
    /// and the supplied base instructions.
    #[tokio::test]
    async fn build_agent_spawn_config_uses_turn_context_values() {
        let (_session, mut turn) = make_session_and_context().await;
        let base_instructions = BaseInstructions {
            text: "base".to_string(),
        };
        turn.developer_instructions = Some("dev".to_string());
        turn.compact_prompt = Some("compact".to_string());
        turn.user_instructions = Some("user".to_string());
        turn.shell_environment_policy = ShellEnvironmentPolicy {
            use_profile: true,
            ..ShellEnvironmentPolicy::default()
        };
        let temp_dir = tempfile::tempdir().expect("temp dir");
        turn.cwd = temp_dir.path().to_path_buf();
        turn.codex_linux_sandbox_exe = Some(PathBuf::from("/bin/echo"));
        turn.approval_policy = AskForApproval::Never;
        turn.sandbox_policy = SandboxPolicy::DangerFullAccess;

        let config = build_agent_spawn_config(&base_instructions, &turn).expect("spawn config");

        // Build the expected config independently from the same inputs.
        let mut expected = (*turn.client.config()).clone();
        expected.base_instructions = Some(base_instructions.text);
        expected.model = Some(turn.client.get_model());
        expected.model_provider = turn.client.get_provider();
        expected.model_reasoning_effort = turn.client.get_reasoning_effort();
        expected.model_reasoning_summary = turn.client.get_reasoning_summary();
        expected.developer_instructions = turn.developer_instructions.clone();
        expected.compact_prompt = turn.compact_prompt.clone();
        expected.user_instructions = turn.user_instructions.clone();
        expected.shell_environment_policy = turn.shell_environment_policy.clone();
        expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone();
        expected.cwd = turn.cwd.clone();
        expected
            .approval_policy
            .set(turn.approval_policy)
            .expect("approval policy set");
        expected
            .sandbox_policy
            .set(turn.sandbox_policy)
            .expect("sandbox policy set");
        assert_eq!(config, expected);
    }
}