mirror of
https://github.com/openai/codex.git
synced 2026-04-29 08:56:38 +00:00
## Summary

This PR consolidates `base_instructions` onto SessionMeta / SessionConfiguration, so we ensure `base_instructions` is set once per session and should be (mostly) immutable, unless:

- overridden by config on resume / fork
- sub-agent tasks, like review or collab

In a future PR, we should convert all references to `base_instructions` to consistently use the typed struct, so it's less likely that we put other strings there. See #9423. However, this PR is already quite complex, so I'm deferring that to a follow-up.

## Testing

- [x] Added a resume test to assert that instructions are preserved. In particular, `resume_switches_models_preserves_base_instructions` fails against main.

Existing test coverage that asserts base instructions are preserved across multiple requests in a session:

- Manual compact keeps baseline instructions: core/tests/suite/compact.rs:199
- Auto-compact keeps baseline instructions: core/tests/suite/compact.rs:1142
- Prompt caching reuses the same instructions across two requests: core/tests/suite/prompt_caching.rs:150 and core/tests/suite/prompt_caching.rs:157
- Prompt caching with explicit expected string across two requests: core/tests/suite/prompt_caching.rs:213 and core/tests/suite/prompt_caching.rs:222
- Resume with model switch keeps original instructions: core/tests/suite/resume.rs:136
- Compact/resume/fork uses request 0 instructions for later expected payloads: core/tests/suite/compact_resume_fork.rs:215
1111 lines
37 KiB
Rust
1111 lines
37 KiB
Rust
use crate::agent::AgentStatus;
|
|
use crate::codex::Session;
|
|
use crate::codex::TurnContext;
|
|
use crate::config::Config;
|
|
use crate::error::CodexErr;
|
|
use crate::function_tool::FunctionCallError;
|
|
use crate::tools::context::ToolInvocation;
|
|
use crate::tools::context::ToolOutput;
|
|
use crate::tools::context::ToolPayload;
|
|
use crate::tools::handlers::parse_arguments;
|
|
use crate::tools::registry::ToolHandler;
|
|
use crate::tools::registry::ToolKind;
|
|
use async_trait::async_trait;
|
|
use codex_protocol::ThreadId;
|
|
use codex_protocol::models::BaseInstructions;
|
|
use codex_protocol::protocol::CollabAgentInteractionBeginEvent;
|
|
use codex_protocol::protocol::CollabAgentInteractionEndEvent;
|
|
use codex_protocol::protocol::CollabAgentSpawnBeginEvent;
|
|
use codex_protocol::protocol::CollabAgentSpawnEndEvent;
|
|
use codex_protocol::protocol::CollabCloseBeginEvent;
|
|
use codex_protocol::protocol::CollabCloseEndEvent;
|
|
use codex_protocol::protocol::CollabWaitingBeginEvent;
|
|
use codex_protocol::protocol::CollabWaitingEndEvent;
|
|
use serde::Deserialize;
|
|
use serde::Serialize;
|
|
|
|
/// Tool handler for the collab tools; its `ToolHandler::handle` implementation
/// dispatches on the tool name (`spawn_agent`, `send_input`, `wait`, `close_agent`).
pub struct CollabHandler;
|
|
|
|
/// Timeout, in milliseconds, used by the `wait` tool when the caller omits `timeout_ms`.
pub(crate) const DEFAULT_WAIT_TIMEOUT_MS: i64 = 30_000;
|
|
/// Upper bound, in milliseconds, that the `wait` tool clamps a requested `timeout_ms` to.
pub(crate) const MAX_WAIT_TIMEOUT_MS: i64 = 300_000;
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct CloseAgentArgs {
|
|
id: String,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ToolHandler for CollabHandler {
|
|
fn kind(&self) -> ToolKind {
|
|
ToolKind::Function
|
|
}
|
|
|
|
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
|
matches!(payload, ToolPayload::Function { .. })
|
|
}
|
|
|
|
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
|
let ToolInvocation {
|
|
session,
|
|
turn,
|
|
tool_name,
|
|
payload,
|
|
call_id,
|
|
..
|
|
} = invocation;
|
|
|
|
let arguments = match payload {
|
|
ToolPayload::Function { arguments } => arguments,
|
|
_ => {
|
|
return Err(FunctionCallError::RespondToModel(
|
|
"collab handler received unsupported payload".to_string(),
|
|
));
|
|
}
|
|
};
|
|
|
|
match tool_name.as_str() {
|
|
"spawn_agent" => spawn::handle(session, turn, call_id, arguments).await,
|
|
"send_input" => send_input::handle(session, turn, call_id, arguments).await,
|
|
"wait" => wait::handle(session, turn, call_id, arguments).await,
|
|
"close_agent" => close_agent::handle(session, turn, call_id, arguments).await,
|
|
other => Err(FunctionCallError::RespondToModel(format!(
|
|
"unsupported collab tool {other}"
|
|
))),
|
|
}
|
|
}
|
|
}
|
|
|
|
mod spawn {
|
|
use super::*;
|
|
use crate::agent::AgentRole;
|
|
use std::sync::Arc;
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct SpawnAgentArgs {
|
|
message: String,
|
|
agent_type: Option<AgentRole>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct SpawnAgentResult {
|
|
agent_id: String,
|
|
}
|
|
|
|
pub async fn handle(
|
|
session: Arc<Session>,
|
|
turn: Arc<TurnContext>,
|
|
call_id: String,
|
|
arguments: String,
|
|
) -> Result<ToolOutput, FunctionCallError> {
|
|
let args: SpawnAgentArgs = parse_arguments(&arguments)?;
|
|
let agent_role = args.agent_type.unwrap_or(AgentRole::Default);
|
|
let prompt = args.message;
|
|
if prompt.trim().is_empty() {
|
|
return Err(FunctionCallError::RespondToModel(
|
|
"Empty message can't be sent to an agent".to_string(),
|
|
));
|
|
}
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabAgentSpawnBeginEvent {
|
|
call_id: call_id.clone(),
|
|
sender_thread_id: session.conversation_id,
|
|
prompt: prompt.clone(),
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
let mut config =
|
|
build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?;
|
|
agent_role
|
|
.apply_to_config(&mut config)
|
|
.map_err(FunctionCallError::RespondToModel)?;
|
|
|
|
let result = session
|
|
.services
|
|
.agent_control
|
|
.spawn_agent(config, prompt.clone())
|
|
.await
|
|
.map_err(collab_spawn_error);
|
|
let (new_thread_id, status) = match &result {
|
|
Ok(thread_id) => (
|
|
Some(*thread_id),
|
|
session.services.agent_control.get_status(*thread_id).await,
|
|
),
|
|
Err(_) => (None, AgentStatus::NotFound),
|
|
};
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabAgentSpawnEndEvent {
|
|
call_id,
|
|
sender_thread_id: session.conversation_id,
|
|
new_thread_id,
|
|
prompt,
|
|
status,
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
let new_thread_id = result?;
|
|
|
|
let content = serde_json::to_string(&SpawnAgentResult {
|
|
agent_id: new_thread_id.to_string(),
|
|
})
|
|
.map_err(|err| {
|
|
FunctionCallError::Fatal(format!("failed to serialize spawn_agent result: {err}"))
|
|
})?;
|
|
|
|
Ok(ToolOutput::Function {
|
|
content,
|
|
success: Some(true),
|
|
content_items: None,
|
|
})
|
|
}
|
|
}
|
|
|
|
mod send_input {
|
|
use super::*;
|
|
use std::sync::Arc;
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct SendInputArgs {
|
|
id: String,
|
|
message: String,
|
|
#[serde(default)]
|
|
interrupt: bool,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct SendInputResult {
|
|
submission_id: String,
|
|
}
|
|
|
|
pub async fn handle(
|
|
session: Arc<Session>,
|
|
turn: Arc<TurnContext>,
|
|
call_id: String,
|
|
arguments: String,
|
|
) -> Result<ToolOutput, FunctionCallError> {
|
|
let args: SendInputArgs = parse_arguments(&arguments)?;
|
|
let receiver_thread_id = agent_id(&args.id)?;
|
|
let prompt = args.message;
|
|
if prompt.trim().is_empty() {
|
|
return Err(FunctionCallError::RespondToModel(
|
|
"Empty message can't be sent to an agent".to_string(),
|
|
));
|
|
}
|
|
if args.interrupt {
|
|
session
|
|
.services
|
|
.agent_control
|
|
.interrupt_agent(receiver_thread_id)
|
|
.await
|
|
.map_err(|err| collab_agent_error(receiver_thread_id, err))?;
|
|
}
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabAgentInteractionBeginEvent {
|
|
call_id: call_id.clone(),
|
|
sender_thread_id: session.conversation_id,
|
|
receiver_thread_id,
|
|
prompt: prompt.clone(),
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
let result = session
|
|
.services
|
|
.agent_control
|
|
.send_prompt(receiver_thread_id, prompt.clone())
|
|
.await
|
|
.map_err(|err| collab_agent_error(receiver_thread_id, err));
|
|
let status = session
|
|
.services
|
|
.agent_control
|
|
.get_status(receiver_thread_id)
|
|
.await;
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabAgentInteractionEndEvent {
|
|
call_id,
|
|
sender_thread_id: session.conversation_id,
|
|
receiver_thread_id,
|
|
prompt,
|
|
status,
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
let submission_id = result?;
|
|
|
|
let content = serde_json::to_string(&SendInputResult { submission_id }).map_err(|err| {
|
|
FunctionCallError::Fatal(format!("failed to serialize send_input result: {err}"))
|
|
})?;
|
|
|
|
Ok(ToolOutput::Function {
|
|
content,
|
|
success: Some(true),
|
|
content_items: None,
|
|
})
|
|
}
|
|
}
|
|
|
|
mod wait {
|
|
use super::*;
|
|
use crate::agent::status::is_final;
|
|
use futures::FutureExt;
|
|
use futures::StreamExt;
|
|
use futures::stream::FuturesUnordered;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
use tokio::sync::watch::Receiver;
|
|
use tokio::time::Instant;
|
|
|
|
use tokio::time::timeout_at;
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct WaitArgs {
|
|
ids: Vec<String>,
|
|
timeout_ms: Option<i64>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct WaitResult {
|
|
status: HashMap<ThreadId, AgentStatus>,
|
|
timed_out: bool,
|
|
}
|
|
|
|
pub async fn handle(
|
|
session: Arc<Session>,
|
|
turn: Arc<TurnContext>,
|
|
call_id: String,
|
|
arguments: String,
|
|
) -> Result<ToolOutput, FunctionCallError> {
|
|
let args: WaitArgs = parse_arguments(&arguments)?;
|
|
if args.ids.is_empty() {
|
|
return Err(FunctionCallError::RespondToModel(
|
|
"ids must be non-empty".to_owned(),
|
|
));
|
|
}
|
|
let receiver_thread_ids = args
|
|
.ids
|
|
.iter()
|
|
.map(|id| agent_id(id))
|
|
.collect::<Result<Vec<_>, _>>()?;
|
|
|
|
// Validate timeout.
|
|
let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
|
|
let timeout_ms = match timeout_ms {
|
|
ms if ms <= 0 => {
|
|
return Err(FunctionCallError::RespondToModel(
|
|
"timeout_ms must be greater than zero".to_owned(),
|
|
));
|
|
}
|
|
ms => ms.min(MAX_WAIT_TIMEOUT_MS),
|
|
};
|
|
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabWaitingBeginEvent {
|
|
sender_thread_id: session.conversation_id,
|
|
receiver_thread_ids: receiver_thread_ids.clone(),
|
|
call_id: call_id.clone(),
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
|
|
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
|
|
let mut initial_final_statuses = Vec::new();
|
|
for id in &receiver_thread_ids {
|
|
match session.services.agent_control.subscribe_status(*id).await {
|
|
Ok(rx) => {
|
|
let status = rx.borrow().clone();
|
|
if is_final(&status) {
|
|
initial_final_statuses.push((*id, status));
|
|
}
|
|
status_rxs.push((*id, rx));
|
|
}
|
|
Err(CodexErr::ThreadNotFound(_)) => {
|
|
initial_final_statuses.push((*id, AgentStatus::NotFound));
|
|
}
|
|
Err(err) => {
|
|
let mut statuses = HashMap::with_capacity(1);
|
|
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabWaitingEndEvent {
|
|
sender_thread_id: session.conversation_id,
|
|
call_id: call_id.clone(),
|
|
statuses,
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
return Err(collab_agent_error(*id, err));
|
|
}
|
|
}
|
|
}
|
|
|
|
let statuses = if !initial_final_statuses.is_empty() {
|
|
initial_final_statuses
|
|
} else {
|
|
// Wait for the first agent to reach a final status.
|
|
let mut futures = FuturesUnordered::new();
|
|
for (id, rx) in status_rxs.into_iter() {
|
|
let session = session.clone();
|
|
futures.push(wait_for_final_status(session, id, rx));
|
|
}
|
|
let mut results = Vec::new();
|
|
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
|
loop {
|
|
match timeout_at(deadline, futures.next()).await {
|
|
Ok(Some(Some(result))) => {
|
|
results.push(result);
|
|
break;
|
|
}
|
|
Ok(Some(None)) => continue,
|
|
Ok(None) | Err(_) => break,
|
|
}
|
|
}
|
|
if !results.is_empty() {
|
|
// Drain the unlikely last elements to prevent race.
|
|
loop {
|
|
match futures.next().now_or_never() {
|
|
Some(Some(Some(result))) => results.push(result),
|
|
Some(Some(None)) => continue,
|
|
Some(None) | None => break,
|
|
}
|
|
}
|
|
}
|
|
results
|
|
};
|
|
|
|
// Convert payload.
|
|
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
|
|
let result = WaitResult {
|
|
status: statuses_map.clone(),
|
|
timed_out: statuses.is_empty(),
|
|
};
|
|
|
|
// Final event emission.
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabWaitingEndEvent {
|
|
sender_thread_id: session.conversation_id,
|
|
call_id,
|
|
statuses: statuses_map,
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
|
|
let content = serde_json::to_string(&result).map_err(|err| {
|
|
FunctionCallError::Fatal(format!("failed to serialize wait result: {err}"))
|
|
})?;
|
|
|
|
Ok(ToolOutput::Function {
|
|
content,
|
|
success: None,
|
|
content_items: None,
|
|
})
|
|
}
|
|
|
|
async fn wait_for_final_status(
|
|
session: Arc<Session>,
|
|
thread_id: ThreadId,
|
|
mut status_rx: Receiver<AgentStatus>,
|
|
) -> Option<(ThreadId, AgentStatus)> {
|
|
let mut status = status_rx.borrow().clone();
|
|
if is_final(&status) {
|
|
return Some((thread_id, status));
|
|
}
|
|
|
|
loop {
|
|
if status_rx.changed().await.is_err() {
|
|
let latest = session.services.agent_control.get_status(thread_id).await;
|
|
return is_final(&latest).then_some((thread_id, latest));
|
|
}
|
|
status = status_rx.borrow().clone();
|
|
if is_final(&status) {
|
|
return Some((thread_id, status));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub mod close_agent {
|
|
use super::*;
|
|
use std::sync::Arc;
|
|
|
|
#[derive(Debug, Deserialize, Serialize)]
|
|
pub(super) struct CloseAgentResult {
|
|
pub(super) status: AgentStatus,
|
|
}
|
|
|
|
pub async fn handle(
|
|
session: Arc<Session>,
|
|
turn: Arc<TurnContext>,
|
|
call_id: String,
|
|
arguments: String,
|
|
) -> Result<ToolOutput, FunctionCallError> {
|
|
let args: CloseAgentArgs = parse_arguments(&arguments)?;
|
|
let agent_id = agent_id(&args.id)?;
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabCloseBeginEvent {
|
|
call_id: call_id.clone(),
|
|
sender_thread_id: session.conversation_id,
|
|
receiver_thread_id: agent_id,
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
let status = match session
|
|
.services
|
|
.agent_control
|
|
.subscribe_status(agent_id)
|
|
.await
|
|
{
|
|
Ok(mut status_rx) => status_rx.borrow_and_update().clone(),
|
|
Err(err) => {
|
|
let status = session.services.agent_control.get_status(agent_id).await;
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabCloseEndEvent {
|
|
call_id: call_id.clone(),
|
|
sender_thread_id: session.conversation_id,
|
|
receiver_thread_id: agent_id,
|
|
status,
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
return Err(collab_agent_error(agent_id, err));
|
|
}
|
|
};
|
|
let result = if !matches!(status, AgentStatus::Shutdown) {
|
|
session
|
|
.services
|
|
.agent_control
|
|
.shutdown_agent(agent_id)
|
|
.await
|
|
.map_err(|err| collab_agent_error(agent_id, err))
|
|
.map(|_| ())
|
|
} else {
|
|
Ok(())
|
|
};
|
|
session
|
|
.send_event(
|
|
&turn,
|
|
CollabCloseEndEvent {
|
|
call_id,
|
|
sender_thread_id: session.conversation_id,
|
|
receiver_thread_id: agent_id,
|
|
status: status.clone(),
|
|
}
|
|
.into(),
|
|
)
|
|
.await;
|
|
result?;
|
|
|
|
let content = serde_json::to_string(&CloseAgentResult { status }).map_err(|err| {
|
|
FunctionCallError::Fatal(format!("failed to serialize close_agent result: {err}"))
|
|
})?;
|
|
|
|
Ok(ToolOutput::Function {
|
|
content,
|
|
success: Some(true),
|
|
content_items: None,
|
|
})
|
|
}
|
|
}
|
|
|
|
fn agent_id(id: &str) -> Result<ThreadId, FunctionCallError> {
|
|
ThreadId::from_string(id)
|
|
.map_err(|e| FunctionCallError::RespondToModel(format!("invalid agent id {id}: {e:?}")))
|
|
}
|
|
|
|
fn collab_spawn_error(err: CodexErr) -> FunctionCallError {
|
|
match err {
|
|
CodexErr::UnsupportedOperation(_) => {
|
|
FunctionCallError::RespondToModel("collab manager unavailable".to_string())
|
|
}
|
|
err => FunctionCallError::RespondToModel(format!("collab spawn failed: {err}")),
|
|
}
|
|
}
|
|
|
|
fn collab_agent_error(agent_id: ThreadId, err: CodexErr) -> FunctionCallError {
|
|
match err {
|
|
CodexErr::ThreadNotFound(id) => {
|
|
FunctionCallError::RespondToModel(format!("agent with id {id} not found"))
|
|
}
|
|
CodexErr::InternalAgentDied => {
|
|
FunctionCallError::RespondToModel(format!("agent with id {agent_id} is closed"))
|
|
}
|
|
CodexErr::UnsupportedOperation(_) => {
|
|
FunctionCallError::RespondToModel("collab manager unavailable".to_string())
|
|
}
|
|
err => FunctionCallError::RespondToModel(format!("collab tool failed: {err}")),
|
|
}
|
|
}
|
|
|
|
fn build_agent_spawn_config(
|
|
base_instructions: &BaseInstructions,
|
|
turn: &TurnContext,
|
|
) -> Result<Config, FunctionCallError> {
|
|
let base_config = turn.client.config();
|
|
let mut config = (*base_config).clone();
|
|
config.base_instructions = Some(base_instructions.text.clone());
|
|
config.model = Some(turn.client.get_model());
|
|
config.model_provider = turn.client.get_provider();
|
|
config.model_reasoning_effort = turn.client.get_reasoning_effort();
|
|
config.model_reasoning_summary = turn.client.get_reasoning_summary();
|
|
config.developer_instructions = turn.developer_instructions.clone();
|
|
config.compact_prompt = turn.compact_prompt.clone();
|
|
config.user_instructions = turn.user_instructions.clone();
|
|
config.shell_environment_policy = turn.shell_environment_policy.clone();
|
|
config.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone();
|
|
config.cwd = turn.cwd.clone();
|
|
config
|
|
.approval_policy
|
|
.set(turn.approval_policy)
|
|
.map_err(|err| {
|
|
FunctionCallError::RespondToModel(format!("approval_policy is invalid: {err}"))
|
|
})?;
|
|
config
|
|
.sandbox_policy
|
|
.set(turn.sandbox_policy.clone())
|
|
.map_err(|err| {
|
|
FunctionCallError::RespondToModel(format!("sandbox_policy is invalid: {err}"))
|
|
})?;
|
|
Ok(config)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CodexAuth;
    use crate::ThreadManager;
    use crate::built_in_model_providers;
    use crate::codex::make_session_and_context;
    use crate::config::types::ShellEnvironmentPolicy;
    use crate::function_tool::FunctionCallError;
    use crate::protocol::AskForApproval;
    use crate::protocol::Op;
    use crate::protocol::SandboxPolicy;
    use crate::turn_diff_tracker::TurnDiffTracker;
    use codex_protocol::ThreadId;
    use pretty_assertions::assert_eq;
    use serde::Deserialize;
    use serde_json::json;
    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Mutex;
    use tokio::time::timeout;

    /// Builds a `ToolInvocation` with a fixed call id and a fresh diff tracker.
    fn invocation(
        session: Arc<crate::codex::Session>,
        turn: Arc<TurnContext>,
        tool_name: &str,
        payload: ToolPayload,
    ) -> ToolInvocation {
        let tracker = Arc::new(Mutex::new(TurnDiffTracker::default()));
        ToolInvocation {
            session,
            turn,
            tracker,
            call_id: "call-1".to_string(),
            tool_name: tool_name.to_string(),
            payload,
        }
    }

    /// Wraps JSON arguments in a function-call payload.
    fn function_payload(args: serde_json::Value) -> ToolPayload {
        let arguments = args.to_string();
        ToolPayload::Function { arguments }
    }

    /// Creates a thread manager backed by the built-in "openai" provider and a dummy API key.
    fn thread_manager() -> ThreadManager {
        let auth = CodexAuth::from_api_key("dummy");
        let provider = built_in_model_providers()["openai"].clone();
        ThreadManager::with_models_provider(auth, provider)
    }

    /// Creates a session whose agent control comes from a fresh thread manager.
    /// The manager is returned so callers can keep it for the rest of the test.
    async fn session_with_manager() -> (crate::codex::Session, TurnContext, ThreadManager) {
        let (mut session, turn) = make_session_and_context().await;
        let manager = thread_manager();
        session.services.agent_control = manager.agent_control();
        (session, turn, manager)
    }

    /// Runs the handler and returns its error, panicking with `panic_message` on success.
    async fn expect_handler_error(
        invocation: ToolInvocation,
        panic_message: &str,
    ) -> FunctionCallError {
        match CollabHandler.handle(invocation).await {
            Err(err) => err,
            Ok(_) => panic!("{panic_message}"),
        }
    }

    /// Extracts `content` and `success` from a function tool output.
    fn expect_function_output(output: ToolOutput) -> (String, Option<bool>) {
        let ToolOutput::Function {
            content, success, ..
        } = output
        else {
            panic!("expected function output");
        };
        (content, success)
    }

    /// Test-side mirror of the `wait` tool's JSON result.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct WaitResult {
        status: HashMap<ThreadId, AgentStatus>,
        timed_out: bool,
    }

    /// Deserializes the JSON content returned by the `wait` tool.
    fn parse_wait_result(content: &str) -> WaitResult {
        serde_json::from_str(content).expect("wait result should be json")
    }

    #[tokio::test]
    async fn handler_rejects_non_function_payloads() {
        let (session, turn) = make_session_and_context().await;
        let payload = ToolPayload::Custom {
            input: "hello".to_string(),
        };
        let err = expect_handler_error(
            invocation(Arc::new(session), Arc::new(turn), "spawn_agent", payload),
            "payload should be rejected",
        )
        .await;
        assert_eq!(
            err,
            FunctionCallError::RespondToModel(
                "collab handler received unsupported payload".to_string()
            )
        );
    }

    #[tokio::test]
    async fn handler_rejects_unknown_tool() {
        let (session, turn) = make_session_and_context().await;
        let err = expect_handler_error(
            invocation(
                Arc::new(session),
                Arc::new(turn),
                "unknown_tool",
                function_payload(json!({})),
            ),
            "tool should be rejected",
        )
        .await;
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("unsupported collab tool unknown_tool".to_string())
        );
    }

    #[tokio::test]
    async fn spawn_agent_rejects_empty_message() {
        let (session, turn) = make_session_and_context().await;
        let err = expect_handler_error(
            invocation(
                Arc::new(session),
                Arc::new(turn),
                "spawn_agent",
                function_payload(json!({"message": " "})),
            ),
            "empty message should be rejected",
        )
        .await;
        assert_eq!(
            err,
            FunctionCallError::RespondToModel(
                "Empty message can't be sent to an agent".to_string()
            )
        );
    }

    #[tokio::test]
    async fn spawn_agent_errors_when_manager_dropped() {
        // No thread manager is attached to this session's agent control.
        let (session, turn) = make_session_and_context().await;
        let err = expect_handler_error(
            invocation(
                Arc::new(session),
                Arc::new(turn),
                "spawn_agent",
                function_payload(json!({"message": "hello"})),
            ),
            "spawn should fail without a manager",
        )
        .await;
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("collab manager unavailable".to_string())
        );
    }

    #[tokio::test]
    async fn send_input_rejects_empty_message() {
        let (session, turn) = make_session_and_context().await;
        let err = expect_handler_error(
            invocation(
                Arc::new(session),
                Arc::new(turn),
                "send_input",
                function_payload(json!({"id": ThreadId::new().to_string(), "message": ""})),
            ),
            "empty message should be rejected",
        )
        .await;
        assert_eq!(
            err,
            FunctionCallError::RespondToModel(
                "Empty message can't be sent to an agent".to_string()
            )
        );
    }

    #[tokio::test]
    async fn send_input_rejects_invalid_id() {
        let (session, turn) = make_session_and_context().await;
        let err = expect_handler_error(
            invocation(
                Arc::new(session),
                Arc::new(turn),
                "send_input",
                function_payload(json!({"id": "not-a-uuid", "message": "hi"})),
            ),
            "invalid id should be rejected",
        )
        .await;
        let FunctionCallError::RespondToModel(msg) = err else {
            panic!("expected respond-to-model error");
        };
        assert!(msg.starts_with("invalid agent id not-a-uuid:"));
    }

    #[tokio::test]
    async fn send_input_reports_missing_agent() {
        // `_manager` is held until the end of the test, as in the original setup.
        let (session, turn, _manager) = session_with_manager().await;
        let agent_id = ThreadId::new();
        let err = expect_handler_error(
            invocation(
                Arc::new(session),
                Arc::new(turn),
                "send_input",
                function_payload(json!({"id": agent_id.to_string(), "message": "hi"})),
            ),
            "missing agent should be reported",
        )
        .await;
        assert_eq!(
            err,
            FunctionCallError::RespondToModel(format!("agent with id {agent_id} not found"))
        );
    }

    #[tokio::test]
    async fn send_input_interrupts_before_prompt() {
        let (session, turn, manager) = session_with_manager().await;
        let config = turn.client.config().as_ref().clone();
        let thread = manager.start_thread(config).await.expect("start thread");
        let agent_id = thread.thread_id;
        let args = json!({
            "id": agent_id.to_string(),
            "message": "hi",
            "interrupt": true
        });
        CollabHandler
            .handle(invocation(
                Arc::new(session),
                Arc::new(turn),
                "send_input",
                function_payload(args),
            ))
            .await
            .expect("send_input should succeed");

        // The interrupt must be captured before the user input for this agent.
        let ops = manager.captured_ops();
        let ops_for_agent: Vec<&Op> = ops
            .iter()
            .filter_map(|(id, op)| (*id == agent_id).then_some(op))
            .collect();
        assert_eq!(ops_for_agent.len(), 2);
        assert!(matches!(ops_for_agent[0], Op::Interrupt));
        assert!(matches!(ops_for_agent[1], Op::UserInput { .. }));

        let _ = thread
            .thread
            .submit(Op::Shutdown {})
            .await
            .expect("shutdown should submit");
    }

    #[tokio::test]
    async fn wait_rejects_non_positive_timeout() {
        let (session, turn) = make_session_and_context().await;
        let args = json!({
            "ids": [ThreadId::new().to_string()],
            "timeout_ms": 0
        });
        let err = expect_handler_error(
            invocation(
                Arc::new(session),
                Arc::new(turn),
                "wait",
                function_payload(args),
            ),
            "non-positive timeout should be rejected",
        )
        .await;
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("timeout_ms must be greater than zero".to_string())
        );
    }

    #[tokio::test]
    async fn wait_rejects_invalid_id() {
        let (session, turn) = make_session_and_context().await;
        let err = expect_handler_error(
            invocation(
                Arc::new(session),
                Arc::new(turn),
                "wait",
                function_payload(json!({"ids": ["invalid"]})),
            ),
            "invalid id should be rejected",
        )
        .await;
        let FunctionCallError::RespondToModel(msg) = err else {
            panic!("expected respond-to-model error");
        };
        assert!(msg.starts_with("invalid agent id invalid:"));
    }

    #[tokio::test]
    async fn wait_rejects_empty_ids() {
        let (session, turn) = make_session_and_context().await;
        let err = expect_handler_error(
            invocation(
                Arc::new(session),
                Arc::new(turn),
                "wait",
                function_payload(json!({"ids": []})),
            ),
            "empty ids should be rejected",
        )
        .await;
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("ids must be non-empty".to_string())
        );
    }

    #[tokio::test]
    async fn wait_returns_not_found_for_missing_agents() {
        // `_manager` is held until the end of the test, as in the original setup.
        let (session, turn, _manager) = session_with_manager().await;
        let id_a = ThreadId::new();
        let id_b = ThreadId::new();
        let args = json!({
            "ids": [id_a.to_string(), id_b.to_string()],
            "timeout_ms": 1000
        });
        let output = CollabHandler
            .handle(invocation(
                Arc::new(session),
                Arc::new(turn),
                "wait",
                function_payload(args),
            ))
            .await
            .expect("wait should succeed");
        let (content, success) = expect_function_output(output);
        assert_eq!(
            parse_wait_result(&content),
            WaitResult {
                status: HashMap::from([
                    (id_a, AgentStatus::NotFound),
                    (id_b, AgentStatus::NotFound),
                ]),
                timed_out: false
            }
        );
        assert_eq!(success, None);
    }

    #[tokio::test]
    async fn wait_times_out_when_status_is_not_final() {
        let (session, turn, manager) = session_with_manager().await;
        let config = turn.client.config().as_ref().clone();
        let thread = manager.start_thread(config).await.expect("start thread");
        let agent_id = thread.thread_id;
        let args = json!({
            "ids": [agent_id.to_string()],
            "timeout_ms": 10
        });
        let output = CollabHandler
            .handle(invocation(
                Arc::new(session),
                Arc::new(turn),
                "wait",
                function_payload(args),
            ))
            .await
            .expect("wait should succeed");
        let (content, success) = expect_function_output(output);
        assert_eq!(
            parse_wait_result(&content),
            WaitResult {
                status: HashMap::new(),
                timed_out: true
            }
        );
        assert_eq!(success, None);

        let _ = thread
            .thread
            .submit(Op::Shutdown {})
            .await
            .expect("shutdown should submit");
    }

    #[tokio::test]
    async fn wait_returns_final_status_without_timeout() {
        let (session, turn, manager) = session_with_manager().await;
        let config = turn.client.config().as_ref().clone();
        let thread = manager.start_thread(config).await.expect("start thread");
        let agent_id = thread.thread_id;
        let mut status_rx = manager
            .agent_control()
            .subscribe_status(agent_id)
            .await
            .expect("subscribe should succeed");

        // Shut the agent down and wait for the status change before calling `wait`.
        let _ = thread
            .thread
            .submit(Op::Shutdown {})
            .await
            .expect("shutdown should submit");
        let _ = timeout(Duration::from_secs(1), status_rx.changed())
            .await
            .expect("shutdown status should arrive");

        let args = json!({
            "ids": [agent_id.to_string()],
            "timeout_ms": 1000
        });
        let output = CollabHandler
            .handle(invocation(
                Arc::new(session),
                Arc::new(turn),
                "wait",
                function_payload(args),
            ))
            .await
            .expect("wait should succeed");
        let (content, success) = expect_function_output(output);
        assert_eq!(
            parse_wait_result(&content),
            WaitResult {
                status: HashMap::from([(agent_id, AgentStatus::Shutdown)]),
                timed_out: false
            }
        );
        assert_eq!(success, None);
    }

    #[tokio::test]
    async fn close_agent_submits_shutdown_and_returns_status() {
        let (session, turn, manager) = session_with_manager().await;
        let config = turn.client.config().as_ref().clone();
        let thread = manager.start_thread(config).await.expect("start thread");
        let agent_id = thread.thread_id;
        let status_before = manager.agent_control().get_status(agent_id).await;

        let output = CollabHandler
            .handle(invocation(
                Arc::new(session),
                Arc::new(turn),
                "close_agent",
                function_payload(json!({"id": agent_id.to_string()})),
            ))
            .await
            .expect("close_agent should succeed");
        let (content, success) = expect_function_output(output);
        let result: close_agent::CloseAgentResult =
            serde_json::from_str(&content).expect("close_agent result should be json");
        assert_eq!(result.status, status_before);
        assert_eq!(success, Some(true));

        let ops = manager.captured_ops();
        let submitted_shutdown = ops
            .iter()
            .any(|(id, op)| *id == agent_id && matches!(op, Op::Shutdown));
        assert_eq!(submitted_shutdown, true);

        let status_after = manager.agent_control().get_status(agent_id).await;
        assert_eq!(status_after, AgentStatus::NotFound);
    }

    #[tokio::test]
    async fn build_agent_spawn_config_uses_turn_context_values() {
        let (_session, mut turn) = make_session_and_context().await;
        let base_instructions = BaseInstructions {
            text: "base".to_string(),
        };
        turn.developer_instructions = Some("dev".to_string());
        turn.compact_prompt = Some("compact".to_string());
        turn.user_instructions = Some("user".to_string());
        turn.shell_environment_policy = ShellEnvironmentPolicy {
            use_profile: true,
            ..ShellEnvironmentPolicy::default()
        };
        let temp_dir = tempfile::tempdir().expect("temp dir");
        turn.cwd = temp_dir.path().to_path_buf();
        turn.codex_linux_sandbox_exe = Some(PathBuf::from("/bin/echo"));
        turn.approval_policy = AskForApproval::Never;
        turn.sandbox_policy = SandboxPolicy::DangerFullAccess;

        let config = build_agent_spawn_config(&base_instructions, &turn).expect("spawn config");

        // Rebuild the expected config field by field from the same turn context.
        let mut expected = (*turn.client.config()).clone();
        expected.base_instructions = Some(base_instructions.text);
        expected.model = Some(turn.client.get_model());
        expected.model_provider = turn.client.get_provider();
        expected.model_reasoning_effort = turn.client.get_reasoning_effort();
        expected.model_reasoning_summary = turn.client.get_reasoning_summary();
        expected.developer_instructions = turn.developer_instructions.clone();
        expected.compact_prompt = turn.compact_prompt.clone();
        expected.user_instructions = turn.user_instructions.clone();
        expected.shell_environment_policy = turn.shell_environment_policy.clone();
        expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone();
        expected.cwd = turn.cwd.clone();
        expected
            .approval_policy
            .set(turn.approval_policy)
            .expect("approval policy set");
        expected
            .sandbox_policy
            .set(turn.sandbox_policy)
            .expect("sandbox policy set");
        assert_eq!(config, expected);
    }
}
|