mirror of
https://github.com/openai/codex.git
synced 2026-03-13 18:23:49 +00:00
Compare commits
1 Commits
pr13432
...
dev/cc/mul
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
584baeb550 |
@@ -11,6 +11,7 @@ use app_test_support::McpProcess;
|
||||
use app_test_support::create_final_assistant_message_sse_response;
|
||||
use app_test_support::create_mock_responses_server_sequence;
|
||||
use app_test_support::create_mock_responses_server_sequence_unchecked;
|
||||
use app_test_support::create_shell_command_sse_response;
|
||||
use app_test_support::to_response;
|
||||
use codex_app_server_protocol::CommandAction;
|
||||
use codex_app_server_protocol::CommandExecutionApprovalDecision;
|
||||
@@ -34,11 +35,9 @@ use codex_core::features::Feature;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::Path;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[cfg(windows)]
|
||||
@@ -63,14 +62,19 @@ async fn turn_start_shell_zsh_fork_executes_command_v2() -> Result<()> {
|
||||
};
|
||||
eprintln!("using zsh path for zsh-fork test: {}", zsh_path.display());
|
||||
|
||||
// Keep the exec command in flight until we interrupt it. A fast command
|
||||
// Keep the shell command in flight until we interrupt it. A fast command
|
||||
// like `echo hi` can finish before the interrupt arrives on faster runners,
|
||||
// which turns this into a test for post-command follow-up behavior instead
|
||||
// of interrupting an active zsh-fork command.
|
||||
let release_marker_escaped = release_marker.to_string_lossy().replace('\'', r#"'\''"#);
|
||||
let wait_for_interrupt =
|
||||
format!("while [ ! -f '{release_marker_escaped}' ]; do sleep 0.01; done");
|
||||
let response = create_zsh_fork_exec_command_sse_response(&wait_for_interrupt, "call-zsh-fork")?;
|
||||
let response = create_shell_command_sse_response(
|
||||
vec!["/bin/sh".to_string(), "-c".to_string(), wait_for_interrupt],
|
||||
None,
|
||||
Some(5000),
|
||||
"call-zsh-fork",
|
||||
)?;
|
||||
let no_op_response = responses::sse(vec![
|
||||
responses::ev_response_created("resp-2"),
|
||||
responses::ev_completed("resp-2"),
|
||||
@@ -87,7 +91,7 @@ async fn turn_start_shell_zsh_fork_executes_command_v2() -> Result<()> {
|
||||
"never",
|
||||
&BTreeMap::from([
|
||||
(Feature::ShellZshFork, true),
|
||||
(Feature::UnifiedExec, true),
|
||||
(Feature::UnifiedExec, false),
|
||||
(Feature::ShellSnapshot, false),
|
||||
]),
|
||||
&zsh_path,
|
||||
@@ -159,7 +163,7 @@ async fn turn_start_shell_zsh_fork_executes_command_v2() -> Result<()> {
|
||||
assert_eq!(id, "call-zsh-fork");
|
||||
assert_eq!(status, CommandExecutionStatus::InProgress);
|
||||
assert!(command.starts_with(&zsh_path.display().to_string()));
|
||||
assert!(command.contains(" -lc "));
|
||||
assert!(command.contains("/bin/sh -c"));
|
||||
assert!(command.contains("sleep 0.01"));
|
||||
assert!(command.contains(&release_marker.display().to_string()));
|
||||
assert_eq!(cwd, workspace);
|
||||
@@ -187,8 +191,14 @@ async fn turn_start_shell_zsh_fork_exec_approval_decline_v2() -> Result<()> {
|
||||
eprintln!("using zsh path for zsh-fork test: {}", zsh_path.display());
|
||||
|
||||
let responses = vec![
|
||||
create_zsh_fork_exec_command_sse_response(
|
||||
"python3 -c 'print(42)'",
|
||||
create_shell_command_sse_response(
|
||||
vec![
|
||||
"python3".to_string(),
|
||||
"-c".to_string(),
|
||||
"print(42)".to_string(),
|
||||
],
|
||||
None,
|
||||
Some(5000),
|
||||
"call-zsh-fork-decline",
|
||||
)?,
|
||||
create_final_assistant_message_sse_response("done")?,
|
||||
@@ -200,7 +210,7 @@ async fn turn_start_shell_zsh_fork_exec_approval_decline_v2() -> Result<()> {
|
||||
"untrusted",
|
||||
&BTreeMap::from([
|
||||
(Feature::ShellZshFork, true),
|
||||
(Feature::UnifiedExec, true),
|
||||
(Feature::UnifiedExec, false),
|
||||
(Feature::ShellSnapshot, false),
|
||||
]),
|
||||
&zsh_path,
|
||||
@@ -316,8 +326,14 @@ async fn turn_start_shell_zsh_fork_exec_approval_cancel_v2() -> Result<()> {
|
||||
};
|
||||
eprintln!("using zsh path for zsh-fork test: {}", zsh_path.display());
|
||||
|
||||
let responses = vec![create_zsh_fork_exec_command_sse_response(
|
||||
"python3 -c 'print(42)'",
|
||||
let responses = vec![create_shell_command_sse_response(
|
||||
vec![
|
||||
"python3".to_string(),
|
||||
"-c".to_string(),
|
||||
"print(42)".to_string(),
|
||||
],
|
||||
None,
|
||||
Some(5000),
|
||||
"call-zsh-fork-cancel",
|
||||
)?];
|
||||
let server = create_mock_responses_server_sequence(responses).await;
|
||||
@@ -327,7 +343,7 @@ async fn turn_start_shell_zsh_fork_exec_approval_cancel_v2() -> Result<()> {
|
||||
"untrusted",
|
||||
&BTreeMap::from([
|
||||
(Feature::ShellZshFork, true),
|
||||
(Feature::UnifiedExec, true),
|
||||
(Feature::UnifiedExec, false),
|
||||
(Feature::ShellSnapshot, false),
|
||||
]),
|
||||
&zsh_path,
|
||||
@@ -425,204 +441,6 @@ async fn turn_start_shell_zsh_fork_exec_approval_cancel_v2() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_start_shell_zsh_fork_interrupt_kills_approved_subcommand_v2() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let tmp = TempDir::new()?;
|
||||
let codex_home = tmp.path().join("codex_home");
|
||||
std::fs::create_dir(&codex_home)?;
|
||||
let workspace = tmp.path().join("workspace");
|
||||
std::fs::create_dir(&workspace)?;
|
||||
let launch_marker = workspace.join("approved-subcommand.started");
|
||||
let leaked_marker = workspace.join("approved-subcommand.leaked");
|
||||
let launch_marker_display = launch_marker.display().to_string();
|
||||
assert!(
|
||||
!launch_marker_display.contains('\''),
|
||||
"test workspace path should not contain single quotes: {launch_marker_display}"
|
||||
);
|
||||
let leaked_marker_display = leaked_marker.display().to_string();
|
||||
assert!(
|
||||
!leaked_marker_display.contains('\''),
|
||||
"test workspace path should not contain single quotes: {leaked_marker_display}"
|
||||
);
|
||||
|
||||
let Some(zsh_path) = find_test_zsh_path()? else {
|
||||
eprintln!("skipping zsh fork interrupt cleanup test: no zsh executable found");
|
||||
return Ok(());
|
||||
};
|
||||
if !supports_exec_wrapper_intercept(&zsh_path) {
|
||||
eprintln!(
|
||||
"skipping zsh fork interrupt cleanup test: zsh does not support EXEC_WRAPPER intercepts ({})",
|
||||
zsh_path.display()
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
let zsh_path_display = zsh_path.display().to_string();
|
||||
eprintln!("using zsh path for zsh-fork test: {zsh_path_display}");
|
||||
|
||||
let shell_command = format!(
|
||||
"/bin/sh -c 'echo started > \"{launch_marker_display}\" && /bin/sleep 0.5 && echo leaked > \"{leaked_marker_display}\" && exec /bin/sleep 100'"
|
||||
);
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"cmd": shell_command,
|
||||
"yield_time_ms": 30_000,
|
||||
}))?;
|
||||
let response = responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(
|
||||
"call-zsh-fork-interrupt-cleanup",
|
||||
"exec_command",
|
||||
&tool_call_arguments,
|
||||
),
|
||||
responses::ev_completed("resp-1"),
|
||||
]);
|
||||
let no_op_response = responses::sse(vec![
|
||||
responses::ev_response_created("resp-2"),
|
||||
responses::ev_completed("resp-2"),
|
||||
]);
|
||||
let server =
|
||||
create_mock_responses_server_sequence_unchecked(vec![response, no_op_response]).await;
|
||||
create_config_toml(
|
||||
&codex_home,
|
||||
&server.uri(),
|
||||
"untrusted",
|
||||
&BTreeMap::from([
|
||||
(Feature::ShellZshFork, true),
|
||||
(Feature::UnifiedExec, true),
|
||||
(Feature::ShellSnapshot, false),
|
||||
]),
|
||||
&zsh_path,
|
||||
)?;
|
||||
|
||||
let mut mcp = create_zsh_test_mcp_process(&codex_home, &workspace).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let start_id = mcp
|
||||
.send_thread_start_request(ThreadStartParams {
|
||||
model: Some("mock-model".to_string()),
|
||||
cwd: Some(workspace.to_string_lossy().into_owned()),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let start_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(start_id)),
|
||||
)
|
||||
.await??;
|
||||
let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(start_resp)?;
|
||||
|
||||
let turn_id = mcp
|
||||
.send_turn_start_request(TurnStartParams {
|
||||
thread_id: thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: "run the long-lived command".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
cwd: Some(workspace.clone()),
|
||||
approval_policy: Some(codex_app_server_protocol::AskForApproval::UnlessTrusted),
|
||||
sandbox_policy: Some(codex_app_server_protocol::SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![workspace.clone().try_into()?],
|
||||
read_only_access: codex_app_server_protocol::ReadOnlyAccess::FullAccess,
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: false,
|
||||
exclude_slash_tmp: false,
|
||||
}),
|
||||
model: Some("mock-model".to_string()),
|
||||
effort: Some(codex_protocol::openai_models::ReasoningEffort::Medium),
|
||||
summary: Some(codex_protocol::config_types::ReasoningSummary::Auto),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let turn_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(turn_id)),
|
||||
)
|
||||
.await??;
|
||||
let TurnStartResponse { turn } = to_response::<TurnStartResponse>(turn_resp)?;
|
||||
|
||||
let mut saw_target_approval = false;
|
||||
while !saw_target_approval {
|
||||
let server_req = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_request_message(),
|
||||
)
|
||||
.await??;
|
||||
let ServerRequest::CommandExecutionRequestApproval { request_id, params } = server_req
|
||||
else {
|
||||
panic!("expected CommandExecutionRequestApproval request");
|
||||
};
|
||||
let approval_command = params.command.clone().unwrap_or_default();
|
||||
saw_target_approval = approval_command.contains("/bin/sh")
|
||||
&& approval_command.contains(&launch_marker_display)
|
||||
&& !approval_command.contains(&zsh_path_display);
|
||||
mcp.send_response(
|
||||
request_id,
|
||||
serde_json::to_value(CommandExecutionRequestApprovalResponse {
|
||||
decision: CommandExecutionApprovalDecision::Accept,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let started_command = timeout(DEFAULT_READ_TIMEOUT, async {
|
||||
loop {
|
||||
let notif = mcp
|
||||
.read_stream_until_notification_message("item/started")
|
||||
.await?;
|
||||
let started: ItemStartedNotification =
|
||||
serde_json::from_value(notif.params.clone().expect("item/started params"))?;
|
||||
if let ThreadItem::CommandExecution { .. } = started.item {
|
||||
return Ok::<ThreadItem, anyhow::Error>(started.item);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
let ThreadItem::CommandExecution {
|
||||
id,
|
||||
process_id,
|
||||
status,
|
||||
command,
|
||||
cwd,
|
||||
..
|
||||
} = started_command
|
||||
else {
|
||||
unreachable!("loop ensures we break on command execution items");
|
||||
};
|
||||
assert_eq!(id, "call-zsh-fork-interrupt-cleanup");
|
||||
assert_eq!(status, CommandExecutionStatus::InProgress);
|
||||
assert!(command.starts_with(&zsh_path.display().to_string()));
|
||||
assert!(command.contains(" -lc "));
|
||||
assert!(command.contains(&launch_marker_display));
|
||||
assert_eq!(cwd, workspace);
|
||||
assert!(process_id.is_some(), "process id should be present");
|
||||
|
||||
timeout(DEFAULT_READ_TIMEOUT, async {
|
||||
loop {
|
||||
if launch_marker.exists() {
|
||||
return Ok::<(), anyhow::Error>(());
|
||||
}
|
||||
sleep(std::time::Duration::from_millis(20)).await;
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
|
||||
mcp.interrupt_turn_and_wait_for_aborted(
|
||||
thread.id.clone(),
|
||||
turn.id.clone(),
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
)
|
||||
.await?;
|
||||
|
||||
sleep(std::time::Duration::from_millis(750)).await;
|
||||
assert!(
|
||||
!leaked_marker.exists(),
|
||||
"expected interrupt to stop approved subcommand before it wrote {leaked_marker_display}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_start_shell_zsh_fork_subcommand_decline_marks_parent_declined_v2() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -654,15 +472,16 @@ async fn turn_start_shell_zsh_fork_subcommand_decline_marks_parent_declined_v2()
|
||||
first_file.display(),
|
||||
second_file.display()
|
||||
);
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"cmd": shell_command,
|
||||
"yield_time_ms": 5000,
|
||||
let tool_call_arguments = serde_json::to_string(&serde_json::json!({
|
||||
"command": shell_command,
|
||||
"workdir": serde_json::Value::Null,
|
||||
"timeout_ms": 5000
|
||||
}))?;
|
||||
let response = responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(
|
||||
"call-zsh-fork-subcommand-decline",
|
||||
"exec_command",
|
||||
"shell_command",
|
||||
&tool_call_arguments,
|
||||
),
|
||||
responses::ev_completed("resp-1"),
|
||||
@@ -683,7 +502,7 @@ async fn turn_start_shell_zsh_fork_subcommand_decline_marks_parent_declined_v2()
|
||||
"untrusted",
|
||||
&BTreeMap::from([
|
||||
(Feature::ShellZshFork, true),
|
||||
(Feature::UnifiedExec, true),
|
||||
(Feature::UnifiedExec, false),
|
||||
(Feature::ShellSnapshot, false),
|
||||
]),
|
||||
&zsh_path,
|
||||
@@ -925,21 +744,6 @@ async fn create_zsh_test_mcp_process(codex_home: &Path, zdotdir: &Path) -> Resul
|
||||
McpProcess::new_with_env(codex_home, &[("ZDOTDIR", Some(zdotdir.as_str()))]).await
|
||||
}
|
||||
|
||||
fn create_zsh_fork_exec_command_sse_response(
|
||||
command: &str,
|
||||
call_id: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"cmd": command,
|
||||
"yield_time_ms": 5000,
|
||||
}))?;
|
||||
Ok(responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(call_id, "exec_command", &tool_call_arguments),
|
||||
responses::ev_completed("resp-1"),
|
||||
]))
|
||||
}
|
||||
|
||||
fn create_config_toml(
|
||||
codex_home: &Path,
|
||||
server_uri: &str,
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationFunctionCallOutputItem;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItem;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItemContent;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItemPayload;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationMessageItem;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeAudioFrame;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeEvent;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeEventParser;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionMode;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptDelta;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptEntry;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudio;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioInput;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionFunctionTool;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
|
||||
use crate::error::ApiError;
|
||||
@@ -25,7 +21,6 @@ use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
@@ -46,23 +41,6 @@ use tracing::trace;
|
||||
use tungstenite::protocol::WebSocketConfig;
|
||||
use url::Url;
|
||||
|
||||
const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000;
|
||||
const REALTIME_AUDIO_VOICE: &str = "fathom";
|
||||
const REALTIME_V1_SESSION_TYPE: &str = "quicksilver";
|
||||
const REALTIME_V2_SESSION_TYPE: &str = "realtime";
|
||||
const REALTIME_V2_CODEX_TOOL_NAME: &str = "codex";
|
||||
const REALTIME_V2_CODEX_TOOL_DESCRIPTION: &str = "Delegate work to Codex and return the result.";
|
||||
|
||||
fn normalized_session_mode(
|
||||
event_parser: RealtimeEventParser,
|
||||
session_mode: RealtimeSessionMode,
|
||||
) -> RealtimeSessionMode {
|
||||
match event_parser {
|
||||
RealtimeEventParser::V1 => RealtimeSessionMode::Conversational,
|
||||
RealtimeEventParser::RealtimeV2 => session_mode,
|
||||
}
|
||||
}
|
||||
|
||||
struct WsStream {
|
||||
tx_command: mpsc::Sender<WsCommand>,
|
||||
pump_task: tokio::task::JoinHandle<()>,
|
||||
@@ -219,7 +197,6 @@ pub struct RealtimeWebsocketConnection {
|
||||
pub struct RealtimeWebsocketWriter {
|
||||
stream: Arc<WsStream>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
event_parser: RealtimeEventParser,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -281,7 +258,6 @@ impl RealtimeWebsocketConnection {
|
||||
writer: RealtimeWebsocketWriter {
|
||||
stream: Arc::clone(&stream),
|
||||
is_closed: Arc::clone(&is_closed),
|
||||
event_parser,
|
||||
},
|
||||
events: RealtimeWebsocketEvents {
|
||||
rx_message: Arc::new(Mutex::new(rx_message)),
|
||||
@@ -300,19 +276,15 @@ impl RealtimeWebsocketWriter {
|
||||
}
|
||||
|
||||
pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> {
|
||||
let content_kind = match self.event_parser {
|
||||
RealtimeEventParser::V1 => "text",
|
||||
RealtimeEventParser::RealtimeV2 => "input_text",
|
||||
};
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItemPayload::Message(ConversationMessageItem {
|
||||
item: ConversationItem {
|
||||
kind: "message".to_string(),
|
||||
role: "user".to_string(),
|
||||
content: vec![ConversationItemContent {
|
||||
kind: content_kind.to_string(),
|
||||
kind: "text".to_string(),
|
||||
text,
|
||||
}],
|
||||
}),
|
||||
},
|
||||
})
|
||||
.await
|
||||
}
|
||||
@@ -322,80 +294,29 @@ impl RealtimeWebsocketWriter {
|
||||
handoff_id: String,
|
||||
output_text: String,
|
||||
) -> Result<(), ApiError> {
|
||||
let message = match self.event_parser {
|
||||
RealtimeEventParser::V1 => RealtimeOutboundMessage::ConversationHandoffAppend {
|
||||
handoff_id,
|
||||
output_text,
|
||||
},
|
||||
RealtimeEventParser::RealtimeV2 => RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItemPayload::FunctionCallOutput(
|
||||
ConversationFunctionCallOutputItem {
|
||||
kind: "function_call_output".to_string(),
|
||||
call_id: handoff_id,
|
||||
output: output_text,
|
||||
},
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
self.send_json(message).await
|
||||
self.send_json(RealtimeOutboundMessage::ConversationHandoffAppend {
|
||||
handoff_id,
|
||||
output_text,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_session_update(
|
||||
&self,
|
||||
instructions: String,
|
||||
session_mode: RealtimeSessionMode,
|
||||
) -> Result<(), ApiError> {
|
||||
let session_mode = normalized_session_mode(self.event_parser, session_mode);
|
||||
let (session_kind, session_instructions, output_audio) = match session_mode {
|
||||
RealtimeSessionMode::Conversational => {
|
||||
let kind = match self.event_parser {
|
||||
RealtimeEventParser::V1 => REALTIME_V1_SESSION_TYPE.to_string(),
|
||||
RealtimeEventParser::RealtimeV2 => REALTIME_V2_SESSION_TYPE.to_string(),
|
||||
};
|
||||
(
|
||||
kind,
|
||||
Some(instructions),
|
||||
Some(SessionAudioOutput {
|
||||
voice: REALTIME_AUDIO_VOICE.to_string(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
RealtimeSessionMode::Transcription => ("transcription".to_string(), None, None),
|
||||
};
|
||||
let tools = match self.event_parser {
|
||||
RealtimeEventParser::RealtimeV2 => Some(vec![SessionFunctionTool {
|
||||
kind: "function".to_string(),
|
||||
name: REALTIME_V2_CODEX_TOOL_NAME.to_string(),
|
||||
description: REALTIME_V2_CODEX_TOOL_DESCRIPTION.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Prompt text for the delegated Codex task."
|
||||
}
|
||||
},
|
||||
"required": ["prompt"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
}]),
|
||||
RealtimeEventParser::V1 => None,
|
||||
};
|
||||
pub async fn send_session_update(&self, instructions: String) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::SessionUpdate {
|
||||
session: SessionUpdateSession {
|
||||
kind: session_kind,
|
||||
instructions: session_instructions,
|
||||
kind: "quicksilver".to_string(),
|
||||
instructions,
|
||||
audio: SessionAudio {
|
||||
input: SessionAudioInput {
|
||||
format: SessionAudioFormat {
|
||||
kind: "audio/pcm".to_string(),
|
||||
rate: REALTIME_AUDIO_SAMPLE_RATE,
|
||||
rate: 24_000,
|
||||
},
|
||||
},
|
||||
output: output_audio,
|
||||
output: SessionAudioOutput {
|
||||
voice: "fathom".to_string(),
|
||||
},
|
||||
},
|
||||
tools,
|
||||
},
|
||||
})
|
||||
.await
|
||||
@@ -544,8 +465,6 @@ impl RealtimeWebsocketClient {
|
||||
self.provider.base_url.as_str(),
|
||||
self.provider.query_params.as_ref(),
|
||||
config.model.as_deref(),
|
||||
config.event_parser,
|
||||
config.session_mode,
|
||||
)?;
|
||||
|
||||
let mut request = ws_url
|
||||
@@ -587,7 +506,7 @@ impl RealtimeWebsocketClient {
|
||||
);
|
||||
connection
|
||||
.writer
|
||||
.send_session_update(config.instructions, config.session_mode)
|
||||
.send_session_update(config.instructions)
|
||||
.await?;
|
||||
Ok(connection)
|
||||
}
|
||||
@@ -632,8 +551,6 @@ fn websocket_url_from_api_url(
|
||||
api_url: &str,
|
||||
query_params: Option<&HashMap<String, String>>,
|
||||
model: Option<&str>,
|
||||
event_parser: RealtimeEventParser,
|
||||
_session_mode: RealtimeSessionMode,
|
||||
) -> Result<Url, ApiError> {
|
||||
let mut url = Url::parse(api_url)
|
||||
.map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?;
|
||||
@@ -653,20 +570,9 @@ fn websocket_url_from_api_url(
|
||||
}
|
||||
}
|
||||
|
||||
let intent = match event_parser {
|
||||
RealtimeEventParser::V1 => Some("quicksilver"),
|
||||
RealtimeEventParser::RealtimeV2 => None,
|
||||
};
|
||||
let has_extra_query_params = query_params.is_some_and(|query_params| {
|
||||
query_params
|
||||
.iter()
|
||||
.any(|(key, _)| key != "intent" && !(key == "model" && model.is_some()))
|
||||
});
|
||||
if intent.is_some() || model.is_some() || has_extra_query_params {
|
||||
{
|
||||
let mut query = url.query_pairs_mut();
|
||||
if let Some(intent) = intent {
|
||||
query.append_pair("intent", intent);
|
||||
}
|
||||
query.append_pair("intent", "quicksilver");
|
||||
if let Some(model) = model {
|
||||
query.append_pair("model", model);
|
||||
}
|
||||
@@ -947,14 +853,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_http_base_defaults_to_ws_path() {
|
||||
let url = websocket_url_from_api_url(
|
||||
"http://127.0.0.1:8011",
|
||||
None,
|
||||
None,
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
let url =
|
||||
websocket_url_from_api_url("http://127.0.0.1:8011", None, None).expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"ws://127.0.0.1:8011/v1/realtime?intent=quicksilver"
|
||||
@@ -963,14 +863,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_ws_base_defaults_to_ws_path() {
|
||||
let url = websocket_url_from_api_url(
|
||||
"wss://example.com",
|
||||
None,
|
||||
Some("realtime-test-model"),
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
let url =
|
||||
websocket_url_from_api_url("wss://example.com", None, Some("realtime-test-model"))
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/v1/realtime?intent=quicksilver&model=realtime-test-model"
|
||||
@@ -979,14 +874,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_v1_base_appends_realtime_path() {
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://api.openai.com/v1",
|
||||
None,
|
||||
Some("snapshot"),
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
let url = websocket_url_from_api_url("https://api.openai.com/v1", None, Some("snapshot"))
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://api.openai.com/v1/realtime?intent=quicksilver&model=snapshot"
|
||||
@@ -995,14 +884,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_nested_v1_base_appends_realtime_path() {
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://example.com/openai/v1",
|
||||
None,
|
||||
Some("snapshot"),
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
let url =
|
||||
websocket_url_from_api_url("https://example.com/openai/v1", None, Some("snapshot"))
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/openai/v1/realtime?intent=quicksilver&model=snapshot"
|
||||
@@ -1018,8 +902,6 @@ mod tests {
|
||||
("intent".to_string(), "ignored".to_string()),
|
||||
])),
|
||||
Some("snapshot"),
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
@@ -1028,54 +910,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_url_v1_ignores_transcription_mode() {
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://example.com",
|
||||
None,
|
||||
None,
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Transcription,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/v1/realtime?intent=quicksilver"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_url_omits_intent_for_realtime_v2_conversational_mode() {
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://example.com/v1/realtime?foo=bar",
|
||||
Some(&HashMap::from([
|
||||
("trace".to_string(), "1".to_string()),
|
||||
("intent".to_string(), "ignored".to_string()),
|
||||
])),
|
||||
Some("snapshot"),
|
||||
RealtimeEventParser::RealtimeV2,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/v1/realtime?foo=bar&model=snapshot&trace=1"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_url_omits_intent_for_realtime_v2_transcription_mode() {
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://example.com",
|
||||
None,
|
||||
None,
|
||||
RealtimeEventParser::RealtimeV2,
|
||||
RealtimeSessionMode::Transcription,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(url.as_str(), "wss://example.com/v1/realtime");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
@@ -1241,7 +1075,6 @@ mod tests {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -1362,352 +1195,6 @@ mod tests {
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_v2_session_update_includes_codex_tool_and_handoff_output_item() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["type"],
|
||||
Value::String("realtime".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["type"],
|
||||
Value::String("function".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["name"],
|
||||
Value::String("codex".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["parameters"]["required"],
|
||||
json!(["prompt"])
|
||||
);
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_v2", "instructions": "backend prompt"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
|
||||
let second = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("second msg")
|
||||
.expect("second msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let second_json: Value = serde_json::from_str(&second).expect("json");
|
||||
assert_eq!(second_json["type"], "conversation.item.create");
|
||||
assert_eq!(
|
||||
second_json["item"]["type"],
|
||||
Value::String("message".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
second_json["item"]["content"][0]["type"],
|
||||
Value::String("input_text".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
second_json["item"]["content"][0]["text"],
|
||||
Value::String("delegate this".to_string())
|
||||
);
|
||||
|
||||
let third = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("third msg")
|
||||
.expect("third msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let third_json: Value = serde_json::from_str(&third).expect("json");
|
||||
assert_eq!(third_json["type"], "conversation.item.create");
|
||||
assert_eq!(
|
||||
third_json["item"]["type"],
|
||||
Value::String("function_call_output".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
third_json["item"]["call_id"],
|
||||
Value::String("call_1".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
third_json["item"]["output"],
|
||||
Value::String("delegated result".to_string())
|
||||
);
|
||||
});
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: format!("http://{addr}"),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: crate::provider::RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::RealtimeV2,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let created = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
created,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_v2".to_string(),
|
||||
instructions: Some("backend prompt".to_string()),
|
||||
}
|
||||
);
|
||||
|
||||
connection
|
||||
.send_conversation_item_create("delegate this".to_string())
|
||||
.await
|
||||
.expect("send text item");
|
||||
connection
|
||||
.send_conversation_handoff_append("call_1".to_string(), "delegated result".to_string())
|
||||
.await
|
||||
.expect("send handoff output");
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transcription_mode_session_update_omits_output_audio_and_instructions() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["type"],
|
||||
Value::String("transcription".to_string())
|
||||
);
|
||||
assert!(first_json["session"].get("instructions").is_none());
|
||||
assert!(first_json["session"]["audio"].get("output").is_none());
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["name"],
|
||||
Value::String("codex".to_string())
|
||||
);
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_transcription"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
|
||||
let second = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("second msg")
|
||||
.expect("second msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let second_json: Value = serde_json::from_str(&second).expect("json");
|
||||
assert_eq!(second_json["type"], "input_audio_buffer.append");
|
||||
});
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: format!("http://{addr}"),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: crate::provider::RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::RealtimeV2,
|
||||
session_mode: RealtimeSessionMode::Transcription,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let created = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
created,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_transcription".to_string(),
|
||||
instructions: None,
|
||||
}
|
||||
);
|
||||
|
||||
connection
|
||||
.send_audio_frame(RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
sample_rate: 24_000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(480),
|
||||
})
|
||||
.await
|
||||
.expect("send audio");
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn v1_transcription_mode_is_treated_as_conversational() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["type"],
|
||||
Value::String("quicksilver".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["instructions"],
|
||||
Value::String("backend prompt".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["output"]["voice"],
|
||||
Value::String("fathom".to_string())
|
||||
);
|
||||
assert!(first_json["session"].get("tools").is_none());
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_v1_mode"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
});
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: format!("http://{addr}"),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: crate::provider::RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Transcription,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let created = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
created,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_v1_mode".to_string(),
|
||||
instructions: None,
|
||||
}
|
||||
);
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_does_not_block_while_next_event_waits_for_inbound_data() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
@@ -1771,7 +1258,6 @@ mod tests {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
pub mod methods;
|
||||
pub mod protocol;
|
||||
mod protocol_common;
|
||||
mod protocol_v1;
|
||||
mod protocol_v2;
|
||||
|
||||
pub use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
@@ -12,4 +10,3 @@ pub use methods::RealtimeWebsocketEvents;
|
||||
pub use methods::RealtimeWebsocketWriter;
|
||||
pub use protocol::RealtimeEventParser;
|
||||
pub use protocol::RealtimeSessionConfig;
|
||||
pub use protocol::RealtimeSessionMode;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::endpoint::realtime_websocket::protocol_v1::parse_realtime_event_v1;
|
||||
use crate::endpoint::realtime_websocket::protocol_v2::parse_realtime_event_v2;
|
||||
pub use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
pub use codex_protocol::protocol::RealtimeEvent;
|
||||
@@ -7,6 +6,7 @@ pub use codex_protocol::protocol::RealtimeTranscriptDelta;
|
||||
pub use codex_protocol::protocol::RealtimeTranscriptEntry;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RealtimeEventParser {
|
||||
@@ -14,19 +14,12 @@ pub enum RealtimeEventParser {
|
||||
RealtimeV2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RealtimeSessionMode {
|
||||
Conversational,
|
||||
Transcription,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RealtimeSessionConfig {
|
||||
pub instructions: String,
|
||||
pub model: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
pub event_parser: RealtimeEventParser,
|
||||
pub session_mode: RealtimeSessionMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -42,25 +35,21 @@ pub(super) enum RealtimeOutboundMessage {
|
||||
#[serde(rename = "session.update")]
|
||||
SessionUpdate { session: SessionUpdateSession },
|
||||
#[serde(rename = "conversation.item.create")]
|
||||
ConversationItemCreate { item: ConversationItemPayload },
|
||||
ConversationItemCreate { item: ConversationItem },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionUpdateSession {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) instructions: Option<String>,
|
||||
pub(super) instructions: String,
|
||||
pub(super) audio: SessionAudio,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) tools: Option<Vec<SessionFunctionTool>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudio {
|
||||
pub(super) input: SessionAudioInput,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) output: Option<SessionAudioOutput>,
|
||||
pub(super) output: SessionAudioOutput,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -81,28 +70,13 @@ pub(super) struct SessionAudioOutput {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct ConversationMessageItem {
|
||||
pub(super) struct ConversationItem {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) role: String,
|
||||
pub(super) content: Vec<ConversationItemContent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub(super) enum ConversationItemPayload {
|
||||
Message(ConversationMessageItem),
|
||||
FunctionCallOutput(ConversationFunctionCallOutputItem),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct ConversationFunctionCallOutputItem {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) call_id: String,
|
||||
pub(super) output: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct ConversationItemContent {
|
||||
#[serde(rename = "type")]
|
||||
@@ -110,15 +84,6 @@ pub(super) struct ConversationItemContent {
|
||||
pub(super) text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionFunctionTool {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) name: String,
|
||||
pub(super) description: String,
|
||||
pub(super) parameters: Value,
|
||||
}
|
||||
|
||||
pub(super) fn parse_realtime_event(
|
||||
payload: &str,
|
||||
event_parser: RealtimeEventParser,
|
||||
@@ -128,3 +93,125 @@ pub(super) fn parse_realtime_event(
|
||||
RealtimeEventParser::RealtimeV2 => parse_realtime_event_v2(payload),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
|
||||
let parsed: Value = match serde_json::from_str(payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => {
|
||||
debug!("failed to parse realtime event: {err}, data: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let message_type = match parsed.get("type").and_then(Value::as_str) {
|
||||
Some(message_type) => message_type,
|
||||
None => {
|
||||
debug!("received realtime event without type field: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
match message_type {
|
||||
"session.updated" => {
|
||||
let session_id = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
let instructions = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("instructions"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
|
||||
session_id,
|
||||
instructions,
|
||||
})
|
||||
}
|
||||
"conversation.output_audio.delta" => {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| parsed.get("data").and_then(Value::as_str))
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok())?;
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u16::try_from(v).ok())?;
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok()),
|
||||
}))
|
||||
}
|
||||
"conversation.input_transcript.delta" => parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"conversation.output_transcript.delta" => parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::OutputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"conversation.item.added" => parsed
|
||||
.get("item")
|
||||
.cloned()
|
||||
.map(RealtimeEvent::ConversationItemAdded),
|
||||
"conversation.item.done" => parsed
|
||||
.get("item")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|item| item.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
|
||||
"conversation.handoff.requested" => {
|
||||
let handoff_id = parsed
|
||||
.get("handoff_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let item_id = parsed
|
||||
.get("item_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let input_transcript = parsed
|
||||
.get("input_transcript")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id,
|
||||
item_id,
|
||||
input_transcript,
|
||||
active_transcript: Vec::new(),
|
||||
}))
|
||||
}
|
||||
"error" => parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.get("error")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|error| error.get("message"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.or_else(|| parsed.get("error").map(std::string::ToString::to_string))
|
||||
.map(RealtimeEvent::Error),
|
||||
_ => {
|
||||
debug!("received unsupported realtime event type: {message_type}, data: {payload}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeTranscriptDelta;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
|
||||
pub(super) fn parse_realtime_payload(payload: &str, parser_name: &str) -> Option<(Value, String)> {
|
||||
let parsed: Value = match serde_json::from_str(payload) {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
debug!("failed to parse {parser_name} event: {err}, data: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let message_type = match parsed.get("type").and_then(Value::as_str) {
|
||||
Some(message_type) => message_type.to_string(),
|
||||
None => {
|
||||
debug!("received {parser_name} event without type field: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
Some((parsed, message_type))
|
||||
}
|
||||
|
||||
pub(super) fn parse_session_updated_event(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
let session_id = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let instructions = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("instructions"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
Some(RealtimeEvent::SessionUpdated {
|
||||
session_id,
|
||||
instructions,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn parse_transcript_delta_event(
|
||||
parsed: &Value,
|
||||
field: &str,
|
||||
) -> Option<RealtimeTranscriptDelta> {
|
||||
parsed
|
||||
.get(field)
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeTranscriptDelta { delta })
|
||||
}
|
||||
|
||||
pub(super) fn parse_error_event(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.get("error")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|error| error.get("message"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.or_else(|| parsed.get("error").map(ToString::to_string))
|
||||
.map(RealtimeEvent::Error)
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_error_event;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_realtime_payload;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_session_updated_event;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_transcript_delta_event;
|
||||
use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeHandoffRequested;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
|
||||
pub(super) fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
|
||||
let (parsed, message_type) = parse_realtime_payload(payload, "realtime v1")?;
|
||||
match message_type.as_str() {
|
||||
"session.updated" => parse_session_updated_event(&parsed),
|
||||
"conversation.output_audio.delta" => {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| parsed.get("data").and_then(Value::as_str))
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok())?;
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u16::try_from(value).ok())?;
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok()),
|
||||
}))
|
||||
}
|
||||
"conversation.input_transcript.delta" => {
|
||||
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::InputTranscriptDelta)
|
||||
}
|
||||
"conversation.output_transcript.delta" => {
|
||||
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::OutputTranscriptDelta)
|
||||
}
|
||||
"conversation.item.added" => parsed
|
||||
.get("item")
|
||||
.cloned()
|
||||
.map(RealtimeEvent::ConversationItemAdded),
|
||||
"conversation.item.done" => parsed
|
||||
.get("item")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|item| item.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
|
||||
"conversation.handoff.requested" => {
|
||||
let handoff_id = parsed
|
||||
.get("handoff_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let item_id = parsed
|
||||
.get("item_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let input_transcript = parsed
|
||||
.get("input_transcript")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id,
|
||||
item_id,
|
||||
input_transcript,
|
||||
active_transcript: Vec::new(),
|
||||
}))
|
||||
}
|
||||
"error" => parse_error_event(&parsed),
|
||||
_ => {
|
||||
debug!("received unsupported realtime v1 event type: {message_type}, data: {payload}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,130 +1,157 @@
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_error_event;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_realtime_payload;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_session_updated_event;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_transcript_delta_event;
|
||||
use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeHandoffRequested;
|
||||
use serde_json::Map as JsonMap;
|
||||
use codex_protocol::protocol::RealtimeTranscriptDelta;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
|
||||
const CODEX_TOOL_NAME: &str = "codex";
|
||||
const DEFAULT_AUDIO_SAMPLE_RATE: u32 = 24_000;
|
||||
const DEFAULT_AUDIO_CHANNELS: u16 = 1;
|
||||
const TOOL_ARGUMENT_KEYS: [&str; 5] = ["input_transcript", "input", "text", "prompt", "query"];
|
||||
|
||||
pub(super) fn parse_realtime_event_v2(payload: &str) -> Option<RealtimeEvent> {
|
||||
let (parsed, message_type) = parse_realtime_payload(payload, "realtime v2")?;
|
||||
let parsed: Value = match serde_json::from_str(payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => {
|
||||
debug!("failed to parse realtime v2 event: {err}, data: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
match message_type.as_str() {
|
||||
"session.updated" => parse_session_updated_event(&parsed),
|
||||
"response.output_audio.delta" => parse_output_audio_delta_event(&parsed),
|
||||
"conversation.item.input_audio_transcription.delta" => {
|
||||
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::InputTranscriptDelta)
|
||||
let message_type = match parsed.get("type").and_then(Value::as_str) {
|
||||
Some(message_type) => message_type,
|
||||
None => {
|
||||
debug!("received realtime v2 event without type field: {payload}");
|
||||
return None;
|
||||
}
|
||||
"conversation.item.input_audio_transcription.completed" => {
|
||||
parse_transcript_delta_event(&parsed, "transcript")
|
||||
.map(RealtimeEvent::InputTranscriptDelta)
|
||||
};
|
||||
|
||||
match message_type {
|
||||
"session.updated" => {
|
||||
let session_id = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
let instructions = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("instructions"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
|
||||
session_id,
|
||||
instructions,
|
||||
})
|
||||
}
|
||||
"response.output_text.delta" | "response.output_audio_transcript.delta" => {
|
||||
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::OutputTranscriptDelta)
|
||||
"response.output_audio.delta" => {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok())
|
||||
.unwrap_or(24_000);
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u16::try_from(value).ok())
|
||||
.unwrap_or(1);
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok()),
|
||||
}))
|
||||
}
|
||||
"conversation.item.input_audio_transcription.delta" => parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"conversation.item.input_audio_transcription.completed" => parsed
|
||||
.get("transcript")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"response.output_text.delta" | "response.output_audio_transcript.delta" => parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::OutputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"conversation.item.added" => parsed
|
||||
.get("item")
|
||||
.cloned()
|
||||
.map(RealtimeEvent::ConversationItemAdded),
|
||||
"conversation.item.done" => parse_conversation_item_done_event(&parsed),
|
||||
"error" => parse_error_event(&parsed),
|
||||
"conversation.item.done" => {
|
||||
let item = parsed.get("item")?.as_object()?;
|
||||
let item_type = item.get("type").and_then(Value::as_str);
|
||||
let item_name = item.get("name").and_then(Value::as_str);
|
||||
|
||||
if item_type == Some("function_call") && item_name == Some("codex") {
|
||||
let call_id = item
|
||||
.get("call_id")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| item.get("id").and_then(Value::as_str))?;
|
||||
let item_id = item
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or(call_id)
|
||||
.to_string();
|
||||
let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
|
||||
let mut input_transcript = String::new();
|
||||
if !arguments.is_empty() {
|
||||
if let Ok(arguments_json) = serde_json::from_str::<Value>(arguments)
|
||||
&& let Some(arguments_object) = arguments_json.as_object()
|
||||
{
|
||||
for key in ["input_transcript", "input", "text", "prompt", "query"] {
|
||||
if let Some(value) = arguments_object.get(key).and_then(Value::as_str) {
|
||||
let trimmed = value.trim();
|
||||
if !trimmed.is_empty() {
|
||||
input_transcript = trimmed.to_string();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if input_transcript.is_empty() {
|
||||
input_transcript = arguments.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
return Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id: call_id.to_string(),
|
||||
item_id,
|
||||
input_transcript,
|
||||
active_transcript: Vec::new(),
|
||||
}));
|
||||
}
|
||||
|
||||
item.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id })
|
||||
}
|
||||
"error" => parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.get("error")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|error| error.get("message"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.or_else(|| parsed.get("error").map(ToString::to_string))
|
||||
.map(RealtimeEvent::Error),
|
||||
_ => {
|
||||
debug!("received unsupported realtime v2 event type: {message_type}, data: {payload}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_output_audio_delta_event(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok())
|
||||
.unwrap_or(DEFAULT_AUDIO_SAMPLE_RATE);
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u16::try_from(value).ok())
|
||||
.unwrap_or(DEFAULT_AUDIO_CHANNELS);
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok()),
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_conversation_item_done_event(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
let item = parsed.get("item")?.as_object()?;
|
||||
if let Some(handoff) = parse_handoff_requested_event(item) {
|
||||
return Some(handoff);
|
||||
}
|
||||
|
||||
item.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id })
|
||||
}
|
||||
|
||||
fn parse_handoff_requested_event(item: &JsonMap<String, Value>) -> Option<RealtimeEvent> {
|
||||
let item_type = item.get("type").and_then(Value::as_str);
|
||||
let item_name = item.get("name").and_then(Value::as_str);
|
||||
if item_type != Some("function_call") || item_name != Some(CODEX_TOOL_NAME) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let call_id = item
|
||||
.get("call_id")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| item.get("id").and_then(Value::as_str))?;
|
||||
let item_id = item
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or(call_id)
|
||||
.to_string();
|
||||
let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
|
||||
|
||||
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id: call_id.to_string(),
|
||||
item_id,
|
||||
input_transcript: extract_input_transcript(arguments),
|
||||
active_transcript: Vec::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn extract_input_transcript(arguments: &str) -> String {
|
||||
if arguments.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
if let Ok(arguments_json) = serde_json::from_str::<Value>(arguments)
|
||||
&& let Some(arguments_object) = arguments_json.as_object()
|
||||
{
|
||||
for key in TOOL_ARGUMENT_KEYS {
|
||||
if let Some(value) = arguments_object.get(key).and_then(Value::as_str) {
|
||||
let trimmed = value.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
arguments.to_string()
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ pub use crate::endpoint::memories::MemoriesClient;
|
||||
pub use crate::endpoint::models::ModelsClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeEventParser;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeSessionMode;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketConnection;
|
||||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
|
||||
@@ -6,7 +6,6 @@ use codex_api::RealtimeAudioFrame;
|
||||
use codex_api::RealtimeEvent;
|
||||
use codex_api::RealtimeEventParser;
|
||||
use codex_api::RealtimeSessionConfig;
|
||||
use codex_api::RealtimeSessionMode;
|
||||
use codex_api::RealtimeWebsocketClient;
|
||||
use codex_api::provider::Provider;
|
||||
use codex_api::provider::RetryConfig;
|
||||
@@ -143,7 +142,6 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -237,7 +235,6 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -302,7 +299,6 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -364,7 +360,6 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -429,7 +424,6 @@ async fn realtime_ws_e2e_realtime_v2_parser_emits_handoff_requested() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::RealtimeV2,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
|
||||
@@ -1342,13 +1342,6 @@
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
"RealtimeWsMode": {
|
||||
"enum": [
|
||||
"conversational",
|
||||
"transcription"
|
||||
],
|
||||
"type": "string"
|
||||
},
|
||||
"ReasoningEffort": {
|
||||
"description": "See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning",
|
||||
"enum": [
|
||||
@@ -1823,14 +1816,6 @@
|
||||
"description": "Experimental / do not use. Overrides only the realtime conversation websocket transport base URL (the `Op::RealtimeConversation` `/v1/realtime` connection) without changing normal provider HTTP requests.",
|
||||
"type": "string"
|
||||
},
|
||||
"experimental_realtime_ws_mode": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/RealtimeWsMode"
|
||||
}
|
||||
],
|
||||
"description": "Experimental / do not use. Selects the realtime websocket intent mode. `conversational` is speech-to-speech while `transcription` is transcript-only."
|
||||
},
|
||||
"experimental_realtime_ws_model": {
|
||||
"description": "Experimental / do not use. Selects the realtime websocket model/snapshot used for the `Op::RealtimeConversation` connection.",
|
||||
"type": "string"
|
||||
|
||||
@@ -172,6 +172,8 @@ use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
#[cfg(test)]
|
||||
use crate::exec::StreamOutput;
|
||||
use crate::network_proxy_registry::NetworkProxyRegistry;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use codex_config::CONFIG_TOML_FILE;
|
||||
|
||||
mod rollout_reconstruction;
|
||||
@@ -276,6 +278,7 @@ use crate::skills::collect_explicit_skill_mentions;
|
||||
use crate::skills::injection::ToolMentionKind;
|
||||
use crate::skills::injection::app_id_from_path;
|
||||
use crate::skills::injection::tool_kind_for_path;
|
||||
use crate::skills::model::SkillManagedNetworkOverride;
|
||||
use crate::skills::resolve_skill_dependencies_for_turn;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::state::SessionServices;
|
||||
@@ -1182,6 +1185,61 @@ impl Session {
|
||||
Ok((network_proxy, session_network_proxy))
|
||||
}
|
||||
|
||||
/// Fetch the network proxy for `scope`, starting one on first use.
///
/// Caching is delegated to `services.network_proxies`: the closure below
/// only runs when no proxy for `scope` exists yet. Returns `Ok(None)` when
/// the registry has no proxy spec configured for this session.
pub(crate) async fn get_or_start_network_proxy(
    self: &Arc<Self>,
    scope: NetworkProxyScope,
    sandbox_policy: &SandboxPolicy,
    managed_network_override: Option<SkillManagedNetworkOverride>,
) -> anyhow::Result<Option<NetworkProxy>> {
    // Clone the Arc so the start closure owns a handle independent of `self`.
    let session = Arc::clone(self);
    let started = self
        .services
        .network_proxies
        .get_or_start(
            scope.clone(),
            move |spec, managed_enabled, audit_metadata| {
                // The async block must own everything it touches, so clone all
                // captured state before moving into it.
                let session = Arc::clone(&session);
                let managed_network_override = managed_network_override.clone();
                let scope = scope.clone();
                let sandbox_policy = sandbox_policy.clone();
                async move {
                    // Build a scope-specific policy decider only when the
                    // session has a decider session registered.
                    let network_policy_decider = session
                        .services
                        .network_policy_decider_session
                        .as_ref()
                        .map(|network_policy_decider_session| {
                            build_network_policy_decider(
                                Arc::clone(&session.services.network_approval),
                                Arc::clone(network_policy_decider_session),
                                scope,
                            )
                        });
                    // Apply skill-level allow/deny domain overrides, if any,
                    // before starting the proxy.
                    let spec = if let Some(managed_network_override) =
                        managed_network_override.as_ref()
                    {
                        spec.with_skill_managed_network_override(managed_network_override)
                    } else {
                        spec
                    };
                    spec.start_proxy(
                        &sandbox_policy,
                        network_policy_decider,
                        session
                            .services
                            .network_blocked_request_observer
                            .as_ref()
                            .map(Arc::clone),
                        managed_enabled,
                        audit_metadata,
                    )
                    .await
                }
            },
        )
        .await?;
    // Hand back only the proxy handle; the started wrapper stays cached in
    // the registry.
    Ok(started.map(|started| started.proxy()))
}
|
||||
|
||||
/// Don't expand the number of mutated arguments on config. We are in the process of getting rid of it.
|
||||
pub(crate) fn build_per_turn_config(session_configuration: &SessionConfiguration) -> Config {
|
||||
// todo(aibrahim): store this state somewhere else so we don't need to mut config
|
||||
@@ -1651,9 +1709,10 @@ impl Session {
|
||||
build_network_policy_decider(
|
||||
Arc::clone(&network_approval),
|
||||
Arc::clone(network_policy_decider_session),
|
||||
NetworkProxyScope::SessionDefault,
|
||||
)
|
||||
});
|
||||
let (network_proxy, session_network_proxy) =
|
||||
let (default_network_proxy, session_network_proxy) =
|
||||
if let Some(spec) = config.permissions.network.as_ref() {
|
||||
let (network_proxy, session_network_proxy) = Self::start_managed_network_proxy(
|
||||
spec,
|
||||
@@ -1661,13 +1720,19 @@ impl Session {
|
||||
network_policy_decider.as_ref().map(Arc::clone),
|
||||
blocked_request_observer.as_ref().map(Arc::clone),
|
||||
managed_network_requirements_enabled,
|
||||
network_proxy_audit_metadata,
|
||||
network_proxy_audit_metadata.clone(),
|
||||
)
|
||||
.await?;
|
||||
(Some(network_proxy), Some(session_network_proxy))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let network_proxies = NetworkProxyRegistry::new(
|
||||
config.permissions.network.clone(),
|
||||
managed_network_requirements_enabled,
|
||||
network_proxy_audit_metadata.clone(),
|
||||
default_network_proxy,
|
||||
);
|
||||
|
||||
let mut hook_shell_argv = default_shell.derive_exec_args("", false);
|
||||
let hook_shell_program = hook_shell_argv.remove(0);
|
||||
@@ -1725,7 +1790,9 @@ impl Session {
|
||||
mcp_manager: Arc::clone(&mcp_manager),
|
||||
file_watcher,
|
||||
agent_control,
|
||||
network_proxy,
|
||||
network_proxies,
|
||||
network_policy_decider_session,
|
||||
network_blocked_request_observer: blocked_request_observer,
|
||||
network_approval: Arc::clone(&network_approval),
|
||||
state_db: state_db_ctx.clone(),
|
||||
model_client: ModelClient::new(
|
||||
@@ -1764,7 +1831,9 @@ impl Session {
|
||||
js_repl,
|
||||
next_internal_sub_id: AtomicU64::new(0),
|
||||
});
|
||||
if let Some(network_policy_decider_session) = network_policy_decider_session {
|
||||
if let Some(network_policy_decider_session) =
|
||||
sess.services.network_policy_decider_session.as_ref()
|
||||
{
|
||||
let mut guard = network_policy_decider_session.write().await;
|
||||
*guard = Arc::downgrade(&sess);
|
||||
}
|
||||
@@ -2314,8 +2383,10 @@ impl Session {
|
||||
model_info,
|
||||
&self.services.models_manager,
|
||||
self.services
|
||||
.network_proxy
|
||||
.as_ref()
|
||||
.network_proxies
|
||||
.get(&NetworkProxyScope::SessionDefault)
|
||||
.await
|
||||
.as_deref()
|
||||
.map(StartedNetworkProxy::proxy),
|
||||
sub_id,
|
||||
Arc::clone(&self.js_repl),
|
||||
@@ -2687,6 +2758,7 @@ impl Session {
|
||||
&self,
|
||||
amendment: &NetworkPolicyAmendment,
|
||||
network_approval_context: &NetworkApprovalContext,
|
||||
scope: &NetworkProxyScope,
|
||||
) -> anyhow::Result<()> {
|
||||
let host =
|
||||
Self::validated_network_policy_amendment_host(amendment, network_approval_context)?;
|
||||
@@ -2700,7 +2772,7 @@ impl Session {
|
||||
let execpolicy_amendment =
|
||||
execpolicy_network_rule_amendment(amendment, network_approval_context, &host);
|
||||
|
||||
if let Some(started_network_proxy) = self.services.network_proxy.as_ref() {
|
||||
if let Some(started_network_proxy) = self.services.network_proxies.get(scope).await {
|
||||
let proxy = started_network_proxy.proxy();
|
||||
match amendment.action {
|
||||
NetworkPolicyRuleAction::Allow => proxy
|
||||
|
||||
@@ -11,9 +11,11 @@ use crate::exec::ExecToolCallOutput;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::mcp_connection_manager::ToolInfo;
|
||||
use crate::models_manager::model_info;
|
||||
use crate::network_proxy_registry::NetworkProxyRegistry;
|
||||
use crate::shell::default_user_shell;
|
||||
use crate::tools::format_exec_output_str;
|
||||
|
||||
use codex_network_proxy::NetworkProxyAuditMetadata;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
@@ -2152,7 +2154,14 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
|
||||
mcp_manager,
|
||||
file_watcher,
|
||||
agent_control,
|
||||
network_proxy: None,
|
||||
network_proxies: NetworkProxyRegistry::new(
|
||||
None,
|
||||
false,
|
||||
NetworkProxyAuditMetadata::default(),
|
||||
None,
|
||||
),
|
||||
network_policy_decider_session: None,
|
||||
network_blocked_request_observer: None,
|
||||
network_approval: Arc::clone(&network_approval),
|
||||
state_db: None,
|
||||
model_client: ModelClient::new(
|
||||
@@ -2794,7 +2803,14 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
|
||||
mcp_manager,
|
||||
file_watcher,
|
||||
agent_control,
|
||||
network_proxy: None,
|
||||
network_proxies: NetworkProxyRegistry::new(
|
||||
None,
|
||||
false,
|
||||
NetworkProxyAuditMetadata::default(),
|
||||
None,
|
||||
),
|
||||
network_policy_decider_session: None,
|
||||
network_blocked_request_observer: None,
|
||||
network_approval: Arc::clone(&network_approval),
|
||||
state_db: None,
|
||||
model_client: ModelClient::new(
|
||||
|
||||
@@ -4129,7 +4129,6 @@ fn test_precedence_fixture_with_o3_profile() -> std::io::Result<()> {
|
||||
experimental_realtime_start_instructions: None,
|
||||
experimental_realtime_ws_base_url: None,
|
||||
experimental_realtime_ws_model: None,
|
||||
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
|
||||
experimental_realtime_ws_backend_prompt: None,
|
||||
experimental_realtime_ws_startup_context: None,
|
||||
base_instructions: None,
|
||||
@@ -4266,7 +4265,6 @@ fn test_precedence_fixture_with_gpt3_profile() -> std::io::Result<()> {
|
||||
experimental_realtime_start_instructions: None,
|
||||
experimental_realtime_ws_base_url: None,
|
||||
experimental_realtime_ws_model: None,
|
||||
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
|
||||
experimental_realtime_ws_backend_prompt: None,
|
||||
experimental_realtime_ws_startup_context: None,
|
||||
base_instructions: None,
|
||||
@@ -4401,7 +4399,6 @@ fn test_precedence_fixture_with_zdr_profile() -> std::io::Result<()> {
|
||||
experimental_realtime_start_instructions: None,
|
||||
experimental_realtime_ws_base_url: None,
|
||||
experimental_realtime_ws_model: None,
|
||||
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
|
||||
experimental_realtime_ws_backend_prompt: None,
|
||||
experimental_realtime_ws_startup_context: None,
|
||||
base_instructions: None,
|
||||
@@ -4522,7 +4519,6 @@ fn test_precedence_fixture_with_gpt5_profile() -> std::io::Result<()> {
|
||||
experimental_realtime_start_instructions: None,
|
||||
experimental_realtime_ws_base_url: None,
|
||||
experimental_realtime_ws_model: None,
|
||||
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
|
||||
experimental_realtime_ws_backend_prompt: None,
|
||||
experimental_realtime_ws_startup_context: None,
|
||||
base_instructions: None,
|
||||
@@ -5570,34 +5566,6 @@ experimental_realtime_ws_model = "realtime-test-model"
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn experimental_realtime_ws_mode_loads_from_config_toml() -> std::io::Result<()> {
|
||||
let cfg: ConfigToml = toml::from_str(
|
||||
r#"
|
||||
experimental_realtime_ws_mode = "transcription"
|
||||
"#,
|
||||
)
|
||||
.expect("TOML deserialization should succeed");
|
||||
|
||||
assert_eq!(
|
||||
cfg.experimental_realtime_ws_mode,
|
||||
Some(RealtimeWsMode::Transcription)
|
||||
);
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
config.experimental_realtime_ws_mode,
|
||||
RealtimeWsMode::Transcription
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn realtime_audio_loads_from_config_toml() -> std::io::Result<()> {
|
||||
let cfg: ConfigToml = toml::from_str(
|
||||
|
||||
@@ -463,9 +463,6 @@ pub struct Config {
|
||||
/// Experimental / do not use. Selects the realtime websocket model/snapshot
|
||||
/// used for the `Op::RealtimeConversation` connection.
|
||||
pub experimental_realtime_ws_model: Option<String>,
|
||||
/// Experimental / do not use. Selects the realtime websocket intent mode.
|
||||
/// `conversational` is speech-to-speech while `transcription` is transcript-only.
|
||||
pub experimental_realtime_ws_mode: RealtimeWsMode,
|
||||
/// Experimental / do not use. Overrides only the realtime conversation
|
||||
/// websocket transport instructions (the `Op::RealtimeConversation`
|
||||
/// `/ws` session.update instructions) without changing normal prompts.
|
||||
@@ -1241,9 +1238,6 @@ pub struct ConfigToml {
|
||||
/// Experimental / do not use. Selects the realtime websocket model/snapshot
|
||||
/// used for the `Op::RealtimeConversation` connection.
|
||||
pub experimental_realtime_ws_model: Option<String>,
|
||||
/// Experimental / do not use. Selects the realtime websocket intent mode.
|
||||
/// `conversational` is speech-to-speech while `transcription` is transcript-only.
|
||||
pub experimental_realtime_ws_mode: Option<RealtimeWsMode>,
|
||||
/// Experimental / do not use. Overrides only the realtime conversation
|
||||
/// websocket transport instructions (the `Op::RealtimeConversation`
|
||||
/// `/ws` session.update instructions) without changing normal prompts.
|
||||
@@ -1389,14 +1383,6 @@ pub struct RealtimeAudioConfig {
|
||||
pub speaker: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default, PartialEq, Eq, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RealtimeWsMode {
|
||||
#[default]
|
||||
Conversational,
|
||||
Transcription,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq, JsonSchema)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct RealtimeAudioToml {
|
||||
@@ -2476,7 +2462,6 @@ impl Config {
|
||||
}),
|
||||
experimental_realtime_ws_base_url: cfg.experimental_realtime_ws_base_url,
|
||||
experimental_realtime_ws_model: cfg.experimental_realtime_ws_model,
|
||||
experimental_realtime_ws_mode: cfg.experimental_realtime_ws_mode.unwrap_or_default(),
|
||||
experimental_realtime_ws_backend_prompt: cfg.experimental_realtime_ws_backend_prompt,
|
||||
experimental_realtime_ws_startup_context: cfg.experimental_realtime_ws_startup_context,
|
||||
experimental_realtime_start_instructions: cfg.experimental_realtime_start_instructions,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::config_loader::NetworkConstraints;
|
||||
use crate::skills::model::SkillManagedNetworkOverride;
|
||||
use async_trait::async_trait;
|
||||
use codex_network_proxy::BlockedRequestObserver;
|
||||
use codex_network_proxy::ConfigReloader;
|
||||
@@ -81,6 +82,28 @@ impl NetworkProxySpec {
|
||||
self.config.network.enable_socks5
|
||||
}
|
||||
|
||||
pub(crate) fn with_skill_managed_network_override(
|
||||
&self,
|
||||
managed_network_override: &SkillManagedNetworkOverride,
|
||||
) -> Self {
|
||||
let mut spec = self.clone();
|
||||
|
||||
if let Some(allowed_domains) = managed_network_override.allowed_domains.clone() {
|
||||
spec.config.network.allowed_domains = allowed_domains.clone();
|
||||
if spec.constraints.allowed_domains.is_some() {
|
||||
spec.constraints.allowed_domains = Some(allowed_domains);
|
||||
}
|
||||
}
|
||||
if let Some(denied_domains) = managed_network_override.denied_domains.clone() {
|
||||
spec.config.network.denied_domains = denied_domains.clone();
|
||||
if spec.constraints.denied_domains.is_some() {
|
||||
spec.constraints.denied_domains = Some(denied_domains);
|
||||
}
|
||||
}
|
||||
|
||||
spec
|
||||
}
|
||||
|
||||
pub(crate) fn from_config_and_constraints(
|
||||
config: NetworkProxyConfig,
|
||||
requirements: Option<NetworkConstraints>,
|
||||
|
||||
@@ -1,6 +1,94 @@
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
/// Overriding only `allowed_domains` must replace the allow list in both the
/// proxy config and the (already-set) constraints, while leaving every other
/// network setting untouched.
#[test]
fn skill_managed_network_override_replaces_allowed_domains_and_keeps_other_settings() {
    // Fully-populated baseline config so the whole-struct assertion below
    // proves the override does not disturb any unrelated field.
    let mut config = NetworkProxyConfig::default();
    config.network.enabled = true;
    config.network.proxy_url = "http://127.0.0.1:4128".to_string();
    config.network.socks_url = "socks5://127.0.0.1:5128".to_string();
    config.network.enable_socks5 = true;
    config.network.enable_socks5_udp = true;
    config.network.allowed_domains = vec!["default.example.com".to_string()];
    config.network.denied_domains = vec!["blocked.example.com".to_string()];
    config.network.allow_upstream_proxy = true;
    config.network.dangerously_allow_all_unix_sockets = false;
    config.network.dangerously_allow_non_loopback_proxy = false;
    config.network.mode = codex_network_proxy::NetworkMode::Full;
    config.network.allow_unix_sockets = vec!["/tmp/default.sock".to_string()];
    config.network.allow_local_binding = true;
    config.network.mitm = false;
    // Constraints intentionally carry allow/deny lists so the override is
    // expected to rewrite them as well.
    let spec = NetworkProxySpec {
        config,
        constraints: NetworkProxyConstraints {
            allowed_domains: Some(vec!["default.example.com".to_string()]),
            denied_domains: Some(vec!["blocked.example.com".to_string()]),
            allowlist_expansion_enabled: Some(true),
            denylist_expansion_enabled: Some(false),
            allow_upstream_proxy: Some(true),
            allow_unix_sockets: Some(vec!["/tmp/default.sock".to_string()]),
            allow_local_binding: Some(true),
            ..NetworkProxyConstraints::default()
        },
        hard_deny_allowlist_misses: true,
    };
    // Override supplies only an allow list; the deny list must be untouched.
    let managed_network_override = crate::skills::model::SkillManagedNetworkOverride {
        allowed_domains: Some(vec!["skill.example.com".to_string()]),
        denied_domains: None,
    };

    let overridden = spec.with_skill_managed_network_override(&managed_network_override);

    // Expected result differs from the baseline only in the two allow lists.
    let mut expected = spec.clone();
    expected.config.network.allowed_domains = vec!["skill.example.com".to_string()];
    expected.constraints.allowed_domains = Some(vec!["skill.example.com".to_string()]);
    assert_eq!(overridden, expected);
}
|
||||
|
||||
/// Overriding only `denied_domains` must replace the deny list in both the
/// proxy config and the (already-set) constraints, while the default allow
/// list and every other setting stay unchanged.
#[test]
fn skill_managed_network_override_replaces_denied_domains_and_keeps_default_allowed_domains() {
    // Fully-populated baseline config so the whole-struct assertion below
    // proves the override does not disturb any unrelated field.
    let mut config = NetworkProxyConfig::default();
    config.network.enabled = true;
    config.network.proxy_url = "http://127.0.0.1:4128".to_string();
    config.network.socks_url = "socks5://127.0.0.1:5128".to_string();
    config.network.enable_socks5 = true;
    config.network.enable_socks5_udp = true;
    config.network.allowed_domains = vec!["default.example.com".to_string()];
    config.network.denied_domains = vec!["blocked.example.com".to_string()];
    config.network.allow_upstream_proxy = true;
    config.network.dangerously_allow_all_unix_sockets = false;
    config.network.dangerously_allow_non_loopback_proxy = false;
    config.network.mode = codex_network_proxy::NetworkMode::Full;
    config.network.allow_unix_sockets = vec!["/tmp/default.sock".to_string()];
    config.network.allow_local_binding = true;
    config.network.mitm = false;
    // Constraints carry allow/deny lists; only the deny list should change.
    let spec = NetworkProxySpec {
        config,
        constraints: NetworkProxyConstraints {
            allowed_domains: Some(vec!["default.example.com".to_string()]),
            denied_domains: Some(vec!["blocked.example.com".to_string()]),
            allowlist_expansion_enabled: Some(true),
            denylist_expansion_enabled: Some(false),
            allow_upstream_proxy: Some(true),
            allow_unix_sockets: Some(vec!["/tmp/default.sock".to_string()]),
            allow_local_binding: Some(true),
            ..NetworkProxyConstraints::default()
        },
        hard_deny_allowlist_misses: false,
    };
    // Override supplies only a deny list; the allow list must be untouched.
    let managed_network_override = crate::skills::model::SkillManagedNetworkOverride {
        allowed_domains: None,
        denied_domains: Some(vec!["skill-blocked.example.com".to_string()]),
    };

    let overridden = spec.with_skill_managed_network_override(&managed_network_override);

    // Expected result differs from the baseline only in the two deny lists.
    let mut expected = spec.clone();
    expected.config.network.denied_domains = vec!["skill-blocked.example.com".to_string()];
    expected.constraints.denied_domains = Some(vec!["skill-blocked.example.com".to_string()]);
    assert_eq!(overridden, expected);
}
|
||||
|
||||
#[test]
|
||||
fn build_state_with_audit_metadata_threads_metadata_to_state() {
|
||||
let spec = NetworkProxySpec {
|
||||
|
||||
@@ -39,6 +39,7 @@ use crate::config::Constrained;
|
||||
use crate::config::NetworkProxySpec;
|
||||
use crate::event_mapping::is_contextual_user_message_content;
|
||||
use crate::features::Feature;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use crate::protocol::Op;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::truncate::approx_bytes_for_tokens;
|
||||
@@ -550,7 +551,12 @@ async fn run_guardian_subagent(
|
||||
schema: Value,
|
||||
cancel_token: CancellationToken,
|
||||
) -> anyhow::Result<GuardianAssessment> {
|
||||
let live_network_config = match session.services.network_proxy.as_ref() {
|
||||
let live_network_config = match session
|
||||
.services
|
||||
.network_proxies
|
||||
.get(&NetworkProxyScope::SessionDefault)
|
||||
.await
|
||||
{
|
||||
Some(network_proxy) => Some(network_proxy.proxy().current_cfg().await?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
@@ -51,6 +51,7 @@ mod mcp_tool_approval_templates;
|
||||
pub mod models_manager;
|
||||
mod network_policy_decision;
|
||||
pub mod network_proxy_loader;
|
||||
mod network_proxy_registry;
|
||||
mod original_image_detail;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD;
|
||||
|
||||
@@ -103,7 +103,6 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
|
||||
¶ms,
|
||||
invocation.session.user_shell(),
|
||||
invocation.turn.tools_config.allow_login_shell,
|
||||
invocation.turn.tools_config.unified_exec_backend,
|
||||
)
|
||||
.ok()?;
|
||||
Some((command, invocation.turn.resolve_path(params.workdir)))
|
||||
|
||||
77
codex-rs/core/src/network_proxy_registry.rs
Normal file
77
codex-rs/core/src/network_proxy_registry.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use crate::config::NetworkProxySpec;
|
||||
use crate::config::StartedNetworkProxy;
|
||||
use anyhow::Result;
|
||||
use codex_network_proxy::NetworkProxyAuditMetadata;
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Identifies which logical owner a started network proxy belongs to; used
/// as the cache key in [`NetworkProxyRegistry`].
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub(crate) enum NetworkProxyScope {
    /// The proxy started for the session as a whole (the default scope).
    SessionDefault,
    /// A proxy started on behalf of a specific skill, keyed by the path to
    /// that skill's SKILLS.md file.
    Skill { path_to_skills_md: PathBuf },
}
|
||||
|
||||
/// Lazily starts and caches one network proxy per [`NetworkProxyScope`].
pub(crate) struct NetworkProxyRegistry {
    /// Proxy configuration template used to start new proxies; `None` means
    /// no proxy is configured and `get_or_start` always yields `None`.
    spec: Option<NetworkProxySpec>,
    /// Passed through to the `start` callback when launching a proxy.
    managed_network_requirements_enabled: bool,
    /// Audit metadata handed to every proxy started from this registry.
    audit_metadata: NetworkProxyAuditMetadata,
    /// Started proxies by scope. The async mutex is held across proxy
    /// startup so concurrent callers for the same scope serialize.
    proxies: Mutex<HashMap<NetworkProxyScope, Arc<StartedNetworkProxy>>>,
}
|
||||
|
||||
impl NetworkProxyRegistry {
|
||||
pub(crate) fn new(
|
||||
spec: Option<NetworkProxySpec>,
|
||||
managed_network_requirements_enabled: bool,
|
||||
audit_metadata: NetworkProxyAuditMetadata,
|
||||
default_proxy: Option<StartedNetworkProxy>,
|
||||
) -> Self {
|
||||
let mut proxies = HashMap::new();
|
||||
if let Some(default_proxy) = default_proxy {
|
||||
proxies.insert(NetworkProxyScope::SessionDefault, Arc::new(default_proxy));
|
||||
}
|
||||
|
||||
Self {
|
||||
spec,
|
||||
managed_network_requirements_enabled,
|
||||
audit_metadata,
|
||||
proxies: Mutex::new(proxies),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get(&self, scope: &NetworkProxyScope) -> Option<Arc<StartedNetworkProxy>> {
|
||||
self.proxies.lock().await.get(scope).cloned()
|
||||
}
|
||||
|
||||
pub(crate) async fn get_or_start<F, Fut>(
|
||||
&self,
|
||||
scope: NetworkProxyScope,
|
||||
start: F,
|
||||
) -> Result<Option<Arc<StartedNetworkProxy>>>
|
||||
where
|
||||
F: FnOnce(NetworkProxySpec, bool, NetworkProxyAuditMetadata) -> Fut,
|
||||
Fut: Future<Output = std::io::Result<StartedNetworkProxy>>,
|
||||
{
|
||||
let mut proxies = self.proxies.lock().await;
|
||||
if let Some(existing) = proxies.get(&scope).cloned() {
|
||||
return Ok(Some(existing));
|
||||
}
|
||||
|
||||
let Some(spec) = self.spec.clone() else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let started = Arc::new(
|
||||
start(
|
||||
spec,
|
||||
self.managed_network_requirements_enabled,
|
||||
self.audit_metadata.clone(),
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
proxies.insert(scope, Arc::clone(&started));
|
||||
Ok(Some(started))
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@ use codex_api::RealtimeAudioFrame;
|
||||
use codex_api::RealtimeEvent;
|
||||
use codex_api::RealtimeEventParser;
|
||||
use codex_api::RealtimeSessionConfig;
|
||||
use codex_api::RealtimeSessionMode;
|
||||
use codex_api::RealtimeWebsocketClient;
|
||||
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketEvents;
|
||||
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketWriter;
|
||||
@@ -117,7 +116,10 @@ impl RealtimeConversationManager {
|
||||
&self,
|
||||
api_provider: ApiProvider,
|
||||
extra_headers: Option<HeaderMap>,
|
||||
session_config: RealtimeSessionConfig,
|
||||
prompt: String,
|
||||
model: Option<String>,
|
||||
session_id: Option<String>,
|
||||
event_parser: RealtimeEventParser,
|
||||
) -> CodexResult<(Receiver<RealtimeEvent>, Arc<AtomicBool>)> {
|
||||
let previous_state = {
|
||||
let mut guard = self.state.lock().await;
|
||||
@@ -129,6 +131,12 @@ impl RealtimeConversationManager {
|
||||
let _ = state.task.await;
|
||||
}
|
||||
|
||||
let session_config = RealtimeSessionConfig {
|
||||
instructions: prompt,
|
||||
model,
|
||||
session_id,
|
||||
event_parser,
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(api_provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
@@ -299,26 +307,23 @@ pub(crate) async fn handle_start(
|
||||
} else {
|
||||
RealtimeEventParser::V1
|
||||
};
|
||||
let session_mode = match config.experimental_realtime_ws_mode {
|
||||
crate::config::RealtimeWsMode::Conversational => RealtimeSessionMode::Conversational,
|
||||
crate::config::RealtimeWsMode::Transcription => RealtimeSessionMode::Transcription,
|
||||
};
|
||||
|
||||
let requested_session_id = params
|
||||
.session_id
|
||||
.or_else(|| Some(sess.conversation_id.to_string()));
|
||||
let session_config = RealtimeSessionConfig {
|
||||
instructions: prompt,
|
||||
model,
|
||||
session_id: requested_session_id.clone(),
|
||||
event_parser,
|
||||
session_mode,
|
||||
};
|
||||
let extra_headers =
|
||||
realtime_request_headers(requested_session_id.as_deref(), realtime_api_key.as_str())?;
|
||||
info!("starting realtime conversation");
|
||||
let (events_rx, realtime_active) = match sess
|
||||
.conversation
|
||||
.start(api_provider, extra_headers, session_config)
|
||||
.start(
|
||||
api_provider,
|
||||
extra_headers,
|
||||
prompt,
|
||||
model,
|
||||
requested_session_id.clone(),
|
||||
event_parser,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(events_rx) => events_rx,
|
||||
|
||||
@@ -6,12 +6,13 @@ use crate::RolloutRecorder;
|
||||
use crate::agent::AgentControl;
|
||||
use crate::analytics_client::AnalyticsEventsClient;
|
||||
use crate::client::ModelClient;
|
||||
use crate::config::StartedNetworkProxy;
|
||||
use crate::codex::Session;
|
||||
use crate::exec_policy::ExecPolicyManager;
|
||||
use crate::file_watcher::FileWatcher;
|
||||
use crate::mcp::McpManager;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::models_manager::manager::ModelsManager;
|
||||
use crate::network_proxy_registry::NetworkProxyRegistry;
|
||||
use crate::plugins::PluginsManager;
|
||||
use crate::skills::SkillsManager;
|
||||
use crate::state_db::StateDbHandle;
|
||||
@@ -21,9 +22,11 @@ use crate::tools::runtimes::ExecveSessionApproval;
|
||||
use crate::tools::sandboxing::ApprovalStore;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
use codex_hooks::Hooks;
|
||||
use codex_network_proxy::BlockedRequestObserver;
|
||||
use codex_otel::SessionTelemetry;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Weak;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::watch;
|
||||
@@ -55,7 +58,9 @@ pub(crate) struct SessionServices {
|
||||
pub(crate) mcp_manager: Arc<McpManager>,
|
||||
pub(crate) file_watcher: Arc<FileWatcher>,
|
||||
pub(crate) agent_control: AgentControl,
|
||||
pub(crate) network_proxy: Option<StartedNetworkProxy>,
|
||||
pub(crate) network_proxies: NetworkProxyRegistry,
|
||||
pub(crate) network_policy_decider_session: Option<Arc<RwLock<Weak<Session>>>>,
|
||||
pub(crate) network_blocked_request_observer: Option<Arc<dyn BlockedRequestObserver>>,
|
||||
pub(crate) network_approval: Arc<NetworkApprovalService>,
|
||||
pub(crate) state_db: Option<StateDbHandle>,
|
||||
/// Session-scoped model client shared across turns.
|
||||
|
||||
@@ -46,6 +46,7 @@ use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
|
||||
use crate::features::Feature;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
pub(crate) use compact::CompactTask;
|
||||
pub(crate) use ghost_snapshot::GhostSnapshotTask;
|
||||
pub(crate) use regular::RegularTask;
|
||||
@@ -292,7 +293,12 @@ impl Session {
|
||||
"false"
|
||||
},
|
||||
);
|
||||
let network_proxy_active = match self.services.network_proxy.as_ref() {
|
||||
let network_proxy_active = match self
|
||||
.services
|
||||
.network_proxies
|
||||
.get(&NetworkProxyScope::SessionDefault)
|
||||
.await
|
||||
{
|
||||
Some(started_network_proxy) => {
|
||||
match started_network_proxy.proxy().current_cfg().await {
|
||||
Ok(config) => config.network.enabled,
|
||||
|
||||
@@ -105,11 +105,842 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
pub mod close_agent;
|
||||
mod resume_agent;
|
||||
mod send_input;
|
||||
mod spawn;
|
||||
pub(crate) mod wait;
|
||||
mod spawn {
|
||||
use super::*;
|
||||
use crate::agent::control::SpawnAgentOptions;
|
||||
use crate::agent::role::DEFAULT_ROLE_NAME;
|
||||
use crate::agent::role::apply_role_to_config;
|
||||
|
||||
use crate::agent::exceeds_thread_spawn_depth_limit;
|
||||
use crate::agent::next_thread_spawn_depth;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = SpawnAgentResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: SpawnAgentArgs = parse_arguments(&arguments)?;
|
||||
let role_name = args
|
||||
.agent_type
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|role| !role.is_empty());
|
||||
let input_items = parse_collab_input(args.message, args.items)?;
|
||||
let prompt = input_preview(&input_items);
|
||||
let session_source = turn.session_source.clone();
|
||||
let child_depth = next_thread_spawn_depth(&session_source);
|
||||
let max_depth = turn.config.agent_max_depth;
|
||||
if exceeds_thread_spawn_depth_limit(child_depth, max_depth) {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"Agent depth limit reached. Solve the task yourself.".to_string(),
|
||||
));
|
||||
}
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabAgentSpawnBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
prompt: prompt.clone(),
|
||||
model: args.model.clone().unwrap_or_default(),
|
||||
reasoning_effort: args.reasoning_effort.unwrap_or_default(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let mut config =
|
||||
build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?;
|
||||
apply_requested_spawn_agent_model_overrides(
|
||||
&session,
|
||||
turn.as_ref(),
|
||||
&mut config,
|
||||
args.model.as_deref(),
|
||||
args.reasoning_effort,
|
||||
)
|
||||
.await?;
|
||||
apply_role_to_config(&mut config, role_name)
|
||||
.await
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
apply_spawn_agent_runtime_overrides(&mut config, turn.as_ref())?;
|
||||
apply_spawn_agent_overrides(&mut config, child_depth);
|
||||
|
||||
let result = session
|
||||
.services
|
||||
.agent_control
|
||||
.spawn_agent_with_options(
|
||||
config,
|
||||
input_items,
|
||||
Some(thread_spawn_source(
|
||||
session.conversation_id,
|
||||
child_depth,
|
||||
role_name,
|
||||
)),
|
||||
SpawnAgentOptions {
|
||||
fork_parent_spawn_call_id: args.fork_context.then(|| call_id.clone()),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(collab_spawn_error);
|
||||
let (new_thread_id, status) = match &result {
|
||||
Ok(thread_id) => (
|
||||
Some(*thread_id),
|
||||
session.services.agent_control.get_status(*thread_id).await,
|
||||
),
|
||||
Err(_) => (None, AgentStatus::NotFound),
|
||||
};
|
||||
let (new_agent_nickname, new_agent_role) = match new_thread_id {
|
||||
Some(thread_id) => session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(thread_id)
|
||||
.await
|
||||
.unwrap_or((None, None)),
|
||||
None => (None, None),
|
||||
};
|
||||
let nickname = new_agent_nickname.clone();
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabAgentSpawnEndEvent {
|
||||
call_id,
|
||||
sender_thread_id: session.conversation_id,
|
||||
new_thread_id,
|
||||
new_agent_nickname,
|
||||
new_agent_role,
|
||||
prompt,
|
||||
model: args.model.clone().unwrap_or_default(),
|
||||
reasoning_effort: args.reasoning_effort.unwrap_or_default(),
|
||||
status,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let new_thread_id = result?;
|
||||
let role_tag = role_name.unwrap_or(DEFAULT_ROLE_NAME);
|
||||
turn.session_telemetry
|
||||
.counter("codex.multi_agent.spawn", 1, &[("role", role_tag)]);
|
||||
|
||||
Ok(SpawnAgentResult {
|
||||
agent_id: new_thread_id.to_string(),
|
||||
nickname,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SpawnAgentArgs {
|
||||
message: Option<String>,
|
||||
items: Option<Vec<UserInput>>,
|
||||
agent_type: Option<String>,
|
||||
model: Option<String>,
|
||||
reasoning_effort: Option<ReasoningEffort>,
|
||||
#[serde(default)]
|
||||
fork_context: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(crate) struct SpawnAgentResult {
|
||||
agent_id: String,
|
||||
nickname: Option<String>,
|
||||
}
|
||||
|
||||
impl ToolOutput for SpawnAgentResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "spawn_agent")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, Some(true), "spawn_agent")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "spawn_agent")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod send_input {
|
||||
use super::*;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = SendInputResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: SendInputArgs = parse_arguments(&arguments)?;
|
||||
let receiver_thread_id = agent_id(&args.id)?;
|
||||
let input_items = parse_collab_input(args.message, args.items)?;
|
||||
let prompt = input_preview(&input_items);
|
||||
let (receiver_agent_nickname, receiver_agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(receiver_thread_id)
|
||||
.await
|
||||
.unwrap_or((None, None));
|
||||
if args.interrupt {
|
||||
session
|
||||
.services
|
||||
.agent_control
|
||||
.interrupt_agent(receiver_thread_id)
|
||||
.await
|
||||
.map_err(|err| collab_agent_error(receiver_thread_id, err))?;
|
||||
}
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabAgentInteractionBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
prompt: prompt.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let result = session
|
||||
.services
|
||||
.agent_control
|
||||
.send_input(receiver_thread_id, input_items)
|
||||
.await
|
||||
.map_err(|err| collab_agent_error(receiver_thread_id, err));
|
||||
let status = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_status(receiver_thread_id)
|
||||
.await;
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabAgentInteractionEndEvent {
|
||||
call_id,
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
receiver_agent_nickname,
|
||||
receiver_agent_role,
|
||||
prompt,
|
||||
status,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let submission_id = result?;
|
||||
|
||||
Ok(SendInputResult { submission_id })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SendInputArgs {
|
||||
id: String,
|
||||
message: Option<String>,
|
||||
items: Option<Vec<UserInput>>,
|
||||
#[serde(default)]
|
||||
interrupt: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(crate) struct SendInputResult {
|
||||
submission_id: String,
|
||||
}
|
||||
|
||||
impl ToolOutput for SendInputResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "send_input")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, Some(true), "send_input")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "send_input")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod resume_agent {
|
||||
use super::*;
|
||||
use crate::agent::next_thread_spawn_depth;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = ResumeAgentResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: ResumeAgentArgs = parse_arguments(&arguments)?;
|
||||
let receiver_thread_id = agent_id(&args.id)?;
|
||||
let (receiver_agent_nickname, receiver_agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(receiver_thread_id)
|
||||
.await
|
||||
.unwrap_or((None, None));
|
||||
let child_depth = next_thread_spawn_depth(&turn.session_source);
|
||||
let max_depth = turn.config.agent_max_depth;
|
||||
if exceeds_thread_spawn_depth_limit(child_depth, max_depth) {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"Agent depth limit reached. Solve the task yourself.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabResumeBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
receiver_agent_nickname: receiver_agent_nickname.clone(),
|
||||
receiver_agent_role: receiver_agent_role.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut status = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_status(receiver_thread_id)
|
||||
.await;
|
||||
let error = if matches!(status, AgentStatus::NotFound) {
|
||||
match try_resume_closed_agent(&session, &turn, receiver_thread_id, child_depth)
|
||||
.await
|
||||
{
|
||||
Ok(resumed_status) => {
|
||||
status = resumed_status;
|
||||
None
|
||||
}
|
||||
Err(err) => {
|
||||
status = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_status(receiver_thread_id)
|
||||
.await;
|
||||
Some(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (receiver_agent_nickname, receiver_agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(receiver_thread_id)
|
||||
.await
|
||||
.unwrap_or((receiver_agent_nickname, receiver_agent_role));
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabResumeEndEvent {
|
||||
call_id,
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
receiver_agent_nickname,
|
||||
receiver_agent_role,
|
||||
status: status.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(err) = error {
|
||||
return Err(err);
|
||||
}
|
||||
turn.session_telemetry
|
||||
.counter("codex.multi_agent.resume", 1, &[]);
|
||||
|
||||
Ok(ResumeAgentResult { status })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResumeAgentArgs {
|
||||
id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
|
||||
pub(crate) struct ResumeAgentResult {
|
||||
pub(crate) status: AgentStatus,
|
||||
}
|
||||
|
||||
impl ToolOutput for ResumeAgentResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "resume_agent")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, Some(true), "resume_agent")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "resume_agent")
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_resume_closed_agent(
|
||||
session: &Arc<Session>,
|
||||
turn: &Arc<TurnContext>,
|
||||
receiver_thread_id: ThreadId,
|
||||
child_depth: i32,
|
||||
) -> Result<AgentStatus, FunctionCallError> {
|
||||
let config = build_agent_resume_config(turn.as_ref(), child_depth)?;
|
||||
let resumed_thread_id = session
|
||||
.services
|
||||
.agent_control
|
||||
.resume_agent_from_rollout(
|
||||
config,
|
||||
receiver_thread_id,
|
||||
thread_spawn_source(session.conversation_id, child_depth, None),
|
||||
)
|
||||
.await
|
||||
.map_err(|err| collab_agent_error(receiver_thread_id, err))?;
|
||||
|
||||
Ok(session
|
||||
.services
|
||||
.agent_control
|
||||
.get_status(resumed_thread_id)
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod wait {
|
||||
use super::*;
|
||||
use crate::agent::status::is_final;
|
||||
use futures::FutureExt;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::watch::Receiver;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use tokio::time::timeout_at;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = WaitResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: WaitArgs = parse_arguments(&arguments)?;
|
||||
if args.ids.is_empty() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"ids must be non-empty".to_owned(),
|
||||
));
|
||||
}
|
||||
let receiver_thread_ids = args
|
||||
.ids
|
||||
.iter()
|
||||
.map(|id| agent_id(id))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut receiver_agents = Vec::with_capacity(receiver_thread_ids.len());
|
||||
for receiver_thread_id in &receiver_thread_ids {
|
||||
let (agent_nickname, agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(*receiver_thread_id)
|
||||
.await
|
||||
.unwrap_or((None, None));
|
||||
receiver_agents.push(CollabAgentRef {
|
||||
thread_id: *receiver_thread_id,
|
||||
agent_nickname,
|
||||
agent_role,
|
||||
});
|
||||
}
|
||||
|
||||
let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
|
||||
let timeout_ms = match timeout_ms {
|
||||
ms if ms <= 0 => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"timeout_ms must be greater than zero".to_owned(),
|
||||
));
|
||||
}
|
||||
ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS),
|
||||
};
|
||||
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingBeginEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_ids: receiver_thread_ids.clone(),
|
||||
receiver_agents: receiver_agents.clone(),
|
||||
call_id: call_id.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
|
||||
let mut initial_final_statuses = Vec::new();
|
||||
for id in &receiver_thread_ids {
|
||||
match session.services.agent_control.subscribe_status(*id).await {
|
||||
Ok(rx) => {
|
||||
let status = rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
initial_final_statuses.push((*id, status));
|
||||
}
|
||||
status_rxs.push((*id, rx));
|
||||
}
|
||||
Err(CodexErr::ThreadNotFound(_)) => {
|
||||
initial_final_statuses.push((*id, AgentStatus::NotFound));
|
||||
}
|
||||
Err(err) => {
|
||||
let mut statuses = HashMap::with_capacity(1);
|
||||
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id: call_id.clone(),
|
||||
agent_statuses: build_wait_agent_statuses(
|
||||
&statuses,
|
||||
&receiver_agents,
|
||||
),
|
||||
statuses,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
return Err(collab_agent_error(*id, err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let statuses = if !initial_final_statuses.is_empty() {
|
||||
initial_final_statuses
|
||||
} else {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for (id, rx) in status_rxs.into_iter() {
|
||||
let session = session.clone();
|
||||
futures.push(wait_for_final_status(session, id, rx));
|
||||
}
|
||||
let mut results = Vec::new();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
loop {
|
||||
match timeout_at(deadline, futures.next()).await {
|
||||
Ok(Some(Some(result))) => {
|
||||
results.push(result);
|
||||
break;
|
||||
}
|
||||
Ok(Some(None)) => continue,
|
||||
Ok(None) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
if !results.is_empty() {
|
||||
loop {
|
||||
match futures.next().now_or_never() {
|
||||
Some(Some(Some(result))) => results.push(result),
|
||||
Some(Some(None)) => continue,
|
||||
Some(None) | None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
};
|
||||
|
||||
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
|
||||
let agent_statuses = build_wait_agent_statuses(&statuses_map, &receiver_agents);
|
||||
let result = WaitResult {
|
||||
status: statuses_map.clone(),
|
||||
timed_out: statuses.is_empty(),
|
||||
};
|
||||
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id,
|
||||
agent_statuses,
|
||||
statuses: statuses_map,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WaitArgs {
|
||||
ids: Vec<String>,
|
||||
timeout_ms: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
|
||||
pub(crate) struct WaitResult {
|
||||
pub(crate) status: HashMap<ThreadId, AgentStatus>,
|
||||
pub(crate) timed_out: bool,
|
||||
}
|
||||
|
||||
impl ToolOutput for WaitResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "wait")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, None, "wait")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "wait")
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_final_status(
|
||||
session: Arc<Session>,
|
||||
thread_id: ThreadId,
|
||||
mut status_rx: Receiver<AgentStatus>,
|
||||
) -> Option<(ThreadId, AgentStatus)> {
|
||||
let mut status = status_rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
return Some((thread_id, status));
|
||||
}
|
||||
|
||||
loop {
|
||||
if status_rx.changed().await.is_err() {
|
||||
let latest = session.services.agent_control.get_status(thread_id).await;
|
||||
return is_final(&latest).then_some((thread_id, latest));
|
||||
}
|
||||
status = status_rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
return Some((thread_id, status));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod close_agent {
|
||||
use super::*;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = CloseAgentResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: CloseAgentArgs = parse_arguments(&arguments)?;
|
||||
let agent_id = agent_id(&args.id)?;
|
||||
let (receiver_agent_nickname, receiver_agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(agent_id)
|
||||
.await
|
||||
.unwrap_or((None, None));
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabCloseBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id: agent_id,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let status = match session
|
||||
.services
|
||||
.agent_control
|
||||
.subscribe_status(agent_id)
|
||||
.await
|
||||
{
|
||||
Ok(mut status_rx) => status_rx.borrow_and_update().clone(),
|
||||
Err(err) => {
|
||||
let status = session.services.agent_control.get_status(agent_id).await;
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabCloseEndEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id: agent_id,
|
||||
receiver_agent_nickname: receiver_agent_nickname.clone(),
|
||||
receiver_agent_role: receiver_agent_role.clone(),
|
||||
status,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
return Err(collab_agent_error(agent_id, err));
|
||||
}
|
||||
};
|
||||
let result = if !matches!(status, AgentStatus::Shutdown) {
|
||||
session
|
||||
.services
|
||||
.agent_control
|
||||
.shutdown_agent(agent_id)
|
||||
.await
|
||||
.map_err(|err| collab_agent_error(agent_id, err))
|
||||
.map(|_| ())
|
||||
} else {
|
||||
Ok(())
|
||||
};
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabCloseEndEvent {
|
||||
call_id,
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id: agent_id,
|
||||
receiver_agent_nickname,
|
||||
receiver_agent_role,
|
||||
status: status.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
result?;
|
||||
|
||||
Ok(CloseAgentResult { status })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub(crate) struct CloseAgentResult {
|
||||
pub(crate) status: AgentStatus,
|
||||
}
|
||||
|
||||
impl ToolOutput for CloseAgentResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "close_agent")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, Some(true), "close_agent")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "close_agent")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn agent_id(id: &str) -> Result<ThreadId, FunctionCallError> {
|
||||
ThreadId::from_string(id)
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = CloseAgentResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: CloseAgentArgs = parse_arguments(&arguments)?;
|
||||
let agent_id = agent_id(&args.id)?;
|
||||
let (receiver_agent_nickname, receiver_agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(agent_id)
|
||||
.await
|
||||
.unwrap_or((None, None));
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabCloseBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id: agent_id,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let status = match session
|
||||
.services
|
||||
.agent_control
|
||||
.subscribe_status(agent_id)
|
||||
.await
|
||||
{
|
||||
Ok(mut status_rx) => status_rx.borrow_and_update().clone(),
|
||||
Err(err) => {
|
||||
let status = session.services.agent_control.get_status(agent_id).await;
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabCloseEndEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id: agent_id,
|
||||
receiver_agent_nickname: receiver_agent_nickname.clone(),
|
||||
receiver_agent_role: receiver_agent_role.clone(),
|
||||
status,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
return Err(collab_agent_error(agent_id, err));
|
||||
}
|
||||
};
|
||||
let result = if !matches!(status, AgentStatus::Shutdown) {
|
||||
session
|
||||
.services
|
||||
.agent_control
|
||||
.shutdown_agent(agent_id)
|
||||
.await
|
||||
.map_err(|err| collab_agent_error(agent_id, err))
|
||||
.map(|_| ())
|
||||
} else {
|
||||
Ok(())
|
||||
};
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabCloseEndEvent {
|
||||
call_id,
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id: agent_id,
|
||||
receiver_agent_nickname,
|
||||
receiver_agent_role,
|
||||
status: status.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
result?;
|
||||
|
||||
Ok(CloseAgentResult { status })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub(crate) struct CloseAgentResult {
|
||||
pub(crate) status: AgentStatus,
|
||||
}
|
||||
|
||||
impl ToolOutput for CloseAgentResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "close_agent")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, Some(true), "close_agent")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "close_agent")
|
||||
}
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
use super::*;
|
||||
use crate::agent::next_thread_spawn_depth;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = ResumeAgentResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: ResumeAgentArgs = parse_arguments(&arguments)?;
|
||||
let receiver_thread_id = agent_id(&args.id)?;
|
||||
let (receiver_agent_nickname, receiver_agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(receiver_thread_id)
|
||||
.await
|
||||
.unwrap_or((None, None));
|
||||
let child_depth = next_thread_spawn_depth(&turn.session_source);
|
||||
let max_depth = turn.config.agent_max_depth;
|
||||
if exceeds_thread_spawn_depth_limit(child_depth, max_depth) {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"Agent depth limit reached. Solve the task yourself.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabResumeBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
receiver_agent_nickname: receiver_agent_nickname.clone(),
|
||||
receiver_agent_role: receiver_agent_role.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut status = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_status(receiver_thread_id)
|
||||
.await;
|
||||
let error = if matches!(status, AgentStatus::NotFound) {
|
||||
match try_resume_closed_agent(&session, &turn, receiver_thread_id, child_depth).await {
|
||||
Ok(resumed_status) => {
|
||||
status = resumed_status;
|
||||
None
|
||||
}
|
||||
Err(err) => {
|
||||
status = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_status(receiver_thread_id)
|
||||
.await;
|
||||
Some(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (receiver_agent_nickname, receiver_agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(receiver_thread_id)
|
||||
.await
|
||||
.unwrap_or((receiver_agent_nickname, receiver_agent_role));
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabResumeEndEvent {
|
||||
call_id,
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
receiver_agent_nickname,
|
||||
receiver_agent_role,
|
||||
status: status.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(err) = error {
|
||||
return Err(err);
|
||||
}
|
||||
turn.session_telemetry
|
||||
.counter("codex.multi_agent.resume", 1, &[]);
|
||||
|
||||
Ok(ResumeAgentResult { status })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResumeAgentArgs {
|
||||
id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
|
||||
pub(crate) struct ResumeAgentResult {
|
||||
pub(crate) status: AgentStatus,
|
||||
}
|
||||
|
||||
impl ToolOutput for ResumeAgentResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "resume_agent")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, Some(true), "resume_agent")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "resume_agent")
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_resume_closed_agent(
|
||||
session: &Arc<Session>,
|
||||
turn: &Arc<TurnContext>,
|
||||
receiver_thread_id: ThreadId,
|
||||
child_depth: i32,
|
||||
) -> Result<AgentStatus, FunctionCallError> {
|
||||
let config = build_agent_resume_config(turn.as_ref(), child_depth)?;
|
||||
let resumed_thread_id = session
|
||||
.services
|
||||
.agent_control
|
||||
.resume_agent_from_rollout(
|
||||
config,
|
||||
receiver_thread_id,
|
||||
thread_spawn_source(session.conversation_id, child_depth, None),
|
||||
)
|
||||
.await
|
||||
.map_err(|err| collab_agent_error(receiver_thread_id, err))?;
|
||||
|
||||
Ok(session
|
||||
.services
|
||||
.agent_control
|
||||
.get_status(resumed_thread_id)
|
||||
.await)
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = SendInputResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: SendInputArgs = parse_arguments(&arguments)?;
|
||||
let receiver_thread_id = agent_id(&args.id)?;
|
||||
let input_items = parse_collab_input(args.message, args.items)?;
|
||||
let prompt = input_preview(&input_items);
|
||||
let (receiver_agent_nickname, receiver_agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(receiver_thread_id)
|
||||
.await
|
||||
.unwrap_or((None, None));
|
||||
if args.interrupt {
|
||||
session
|
||||
.services
|
||||
.agent_control
|
||||
.interrupt_agent(receiver_thread_id)
|
||||
.await
|
||||
.map_err(|err| collab_agent_error(receiver_thread_id, err))?;
|
||||
}
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabAgentInteractionBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
prompt: prompt.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let result = session
|
||||
.services
|
||||
.agent_control
|
||||
.send_input(receiver_thread_id, input_items)
|
||||
.await
|
||||
.map_err(|err| collab_agent_error(receiver_thread_id, err));
|
||||
let status = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_status(receiver_thread_id)
|
||||
.await;
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabAgentInteractionEndEvent {
|
||||
call_id,
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_id,
|
||||
receiver_agent_nickname,
|
||||
receiver_agent_role,
|
||||
prompt,
|
||||
status,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let submission_id = result?;
|
||||
|
||||
Ok(SendInputResult { submission_id })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SendInputArgs {
|
||||
id: String,
|
||||
message: Option<String>,
|
||||
items: Option<Vec<UserInput>>,
|
||||
#[serde(default)]
|
||||
interrupt: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(crate) struct SendInputResult {
|
||||
submission_id: String,
|
||||
}
|
||||
|
||||
impl ToolOutput for SendInputResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "send_input")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, Some(true), "send_input")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "send_input")
|
||||
}
|
||||
}
|
||||
@@ -1,173 +0,0 @@
|
||||
use super::*;
|
||||
use crate::agent::control::SpawnAgentOptions;
|
||||
use crate::agent::role::DEFAULT_ROLE_NAME;
|
||||
use crate::agent::role::apply_role_to_config;
|
||||
|
||||
use crate::agent::exceeds_thread_spawn_depth_limit;
|
||||
use crate::agent::next_thread_spawn_depth;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = SpawnAgentResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: SpawnAgentArgs = parse_arguments(&arguments)?;
|
||||
let role_name = args
|
||||
.agent_type
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|role| !role.is_empty());
|
||||
let input_items = parse_collab_input(args.message, args.items)?;
|
||||
let prompt = input_preview(&input_items);
|
||||
let session_source = turn.session_source.clone();
|
||||
let child_depth = next_thread_spawn_depth(&session_source);
|
||||
let max_depth = turn.config.agent_max_depth;
|
||||
if exceeds_thread_spawn_depth_limit(child_depth, max_depth) {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"Agent depth limit reached. Solve the task yourself.".to_string(),
|
||||
));
|
||||
}
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabAgentSpawnBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
sender_thread_id: session.conversation_id,
|
||||
prompt: prompt.clone(),
|
||||
model: args.model.clone().unwrap_or_default(),
|
||||
reasoning_effort: args.reasoning_effort.unwrap_or_default(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let mut config =
|
||||
build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?;
|
||||
apply_requested_spawn_agent_model_overrides(
|
||||
&session,
|
||||
turn.as_ref(),
|
||||
&mut config,
|
||||
args.model.as_deref(),
|
||||
args.reasoning_effort,
|
||||
)
|
||||
.await?;
|
||||
apply_role_to_config(&mut config, role_name)
|
||||
.await
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
apply_spawn_agent_runtime_overrides(&mut config, turn.as_ref())?;
|
||||
apply_spawn_agent_overrides(&mut config, child_depth);
|
||||
|
||||
let result = session
|
||||
.services
|
||||
.agent_control
|
||||
.spawn_agent_with_options(
|
||||
config,
|
||||
input_items,
|
||||
Some(thread_spawn_source(
|
||||
session.conversation_id,
|
||||
child_depth,
|
||||
role_name,
|
||||
)),
|
||||
SpawnAgentOptions {
|
||||
fork_parent_spawn_call_id: args.fork_context.then(|| call_id.clone()),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(collab_spawn_error);
|
||||
let (new_thread_id, status) = match &result {
|
||||
Ok(thread_id) => (
|
||||
Some(*thread_id),
|
||||
session.services.agent_control.get_status(*thread_id).await,
|
||||
),
|
||||
Err(_) => (None, AgentStatus::NotFound),
|
||||
};
|
||||
let (new_agent_nickname, new_agent_role) = match new_thread_id {
|
||||
Some(thread_id) => session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(thread_id)
|
||||
.await
|
||||
.unwrap_or((None, None)),
|
||||
None => (None, None),
|
||||
};
|
||||
let nickname = new_agent_nickname.clone();
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabAgentSpawnEndEvent {
|
||||
call_id,
|
||||
sender_thread_id: session.conversation_id,
|
||||
new_thread_id,
|
||||
new_agent_nickname,
|
||||
new_agent_role,
|
||||
prompt,
|
||||
model: args.model.clone().unwrap_or_default(),
|
||||
reasoning_effort: args.reasoning_effort.unwrap_or_default(),
|
||||
status,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
let new_thread_id = result?;
|
||||
let role_tag = role_name.unwrap_or(DEFAULT_ROLE_NAME);
|
||||
turn.session_telemetry
|
||||
.counter("codex.multi_agent.spawn", 1, &[("role", role_tag)]);
|
||||
|
||||
Ok(SpawnAgentResult {
|
||||
agent_id: new_thread_id.to_string(),
|
||||
nickname,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SpawnAgentArgs {
|
||||
message: Option<String>,
|
||||
items: Option<Vec<UserInput>>,
|
||||
agent_type: Option<String>,
|
||||
model: Option<String>,
|
||||
reasoning_effort: Option<ReasoningEffort>,
|
||||
#[serde(default)]
|
||||
fork_context: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(crate) struct SpawnAgentResult {
|
||||
agent_id: String,
|
||||
nickname: Option<String>,
|
||||
}
|
||||
|
||||
impl ToolOutput for SpawnAgentResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "spawn_agent")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, Some(true), "spawn_agent")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "spawn_agent")
|
||||
}
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
use super::*;
|
||||
use crate::agent::status::is_final;
|
||||
use futures::FutureExt;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::watch::Receiver;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use tokio::time::timeout_at;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for Handler {
|
||||
type Output = WaitResult;
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
let arguments = function_arguments(payload)?;
|
||||
let args: WaitArgs = parse_arguments(&arguments)?;
|
||||
if args.ids.is_empty() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"ids must be non-empty".to_owned(),
|
||||
));
|
||||
}
|
||||
let receiver_thread_ids = args
|
||||
.ids
|
||||
.iter()
|
||||
.map(|id| agent_id(id))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut receiver_agents = Vec::with_capacity(receiver_thread_ids.len());
|
||||
for receiver_thread_id in &receiver_thread_ids {
|
||||
let (agent_nickname, agent_role) = session
|
||||
.services
|
||||
.agent_control
|
||||
.get_agent_nickname_and_role(*receiver_thread_id)
|
||||
.await
|
||||
.unwrap_or((None, None));
|
||||
receiver_agents.push(CollabAgentRef {
|
||||
thread_id: *receiver_thread_id,
|
||||
agent_nickname,
|
||||
agent_role,
|
||||
});
|
||||
}
|
||||
|
||||
let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
|
||||
let timeout_ms = match timeout_ms {
|
||||
ms if ms <= 0 => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"timeout_ms must be greater than zero".to_owned(),
|
||||
));
|
||||
}
|
||||
ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS),
|
||||
};
|
||||
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingBeginEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
receiver_thread_ids: receiver_thread_ids.clone(),
|
||||
receiver_agents: receiver_agents.clone(),
|
||||
call_id: call_id.clone(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
|
||||
let mut initial_final_statuses = Vec::new();
|
||||
for id in &receiver_thread_ids {
|
||||
match session.services.agent_control.subscribe_status(*id).await {
|
||||
Ok(rx) => {
|
||||
let status = rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
initial_final_statuses.push((*id, status));
|
||||
}
|
||||
status_rxs.push((*id, rx));
|
||||
}
|
||||
Err(CodexErr::ThreadNotFound(_)) => {
|
||||
initial_final_statuses.push((*id, AgentStatus::NotFound));
|
||||
}
|
||||
Err(err) => {
|
||||
let mut statuses = HashMap::with_capacity(1);
|
||||
statuses.insert(*id, session.services.agent_control.get_status(*id).await);
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id: call_id.clone(),
|
||||
agent_statuses: build_wait_agent_statuses(
|
||||
&statuses,
|
||||
&receiver_agents,
|
||||
),
|
||||
statuses,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
return Err(collab_agent_error(*id, err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let statuses = if !initial_final_statuses.is_empty() {
|
||||
initial_final_statuses
|
||||
} else {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for (id, rx) in status_rxs.into_iter() {
|
||||
let session = session.clone();
|
||||
futures.push(wait_for_final_status(session, id, rx));
|
||||
}
|
||||
let mut results = Vec::new();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
loop {
|
||||
match timeout_at(deadline, futures.next()).await {
|
||||
Ok(Some(Some(result))) => {
|
||||
results.push(result);
|
||||
break;
|
||||
}
|
||||
Ok(Some(None)) => continue,
|
||||
Ok(None) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
if !results.is_empty() {
|
||||
loop {
|
||||
match futures.next().now_or_never() {
|
||||
Some(Some(Some(result))) => results.push(result),
|
||||
Some(Some(None)) => continue,
|
||||
Some(None) | None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
};
|
||||
|
||||
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
|
||||
let agent_statuses = build_wait_agent_statuses(&statuses_map, &receiver_agents);
|
||||
let result = WaitResult {
|
||||
status: statuses_map.clone(),
|
||||
timed_out: statuses.is_empty(),
|
||||
};
|
||||
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
CollabWaitingEndEvent {
|
||||
sender_thread_id: session.conversation_id,
|
||||
call_id,
|
||||
agent_statuses,
|
||||
statuses: statuses_map,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WaitArgs {
|
||||
ids: Vec<String>,
|
||||
timeout_ms: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
|
||||
pub(crate) struct WaitResult {
|
||||
pub(crate) status: HashMap<ThreadId, AgentStatus>,
|
||||
pub(crate) timed_out: bool,
|
||||
}
|
||||
|
||||
impl ToolOutput for WaitResult {
|
||||
fn log_preview(&self) -> String {
|
||||
tool_output_json_text(self, "wait")
|
||||
}
|
||||
|
||||
fn success_for_logging(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
|
||||
tool_output_response_item(call_id, payload, self, None, "wait")
|
||||
}
|
||||
|
||||
fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
|
||||
tool_output_code_mode_result(self, "wait")
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_final_status(
|
||||
session: Arc<Session>,
|
||||
thread_id: ThreadId,
|
||||
mut status_rx: Receiver<AgentStatus>,
|
||||
) -> Option<(ThreadId, AgentStatus)> {
|
||||
let mut status = status_rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
return Some((thread_id, status));
|
||||
}
|
||||
|
||||
loop {
|
||||
if status_rx.changed().await.is_err() {
|
||||
let latest = session.services.agent_control.get_status(thread_id).await;
|
||||
return is_final(&latest).then_some((thread_id, latest));
|
||||
}
|
||||
status = status_rx.borrow().clone();
|
||||
if is_final(&status) {
|
||||
return Some((thread_id, status));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,6 @@ use crate::tools::handlers::parse_arguments_with_base_path;
|
||||
use crate::tools::handlers::resolve_workdir_base_path;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
use crate::tools::spec::UnifiedExecBackendConfig;
|
||||
use crate::unified_exec::ExecCommandRequest;
|
||||
use crate::unified_exec::UnifiedExecContext;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
@@ -109,7 +108,6 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
¶ms,
|
||||
invocation.session.user_shell(),
|
||||
invocation.turn.tools_config.allow_login_shell,
|
||||
invocation.turn.tools_config.unified_exec_backend,
|
||||
) {
|
||||
Ok(command) => command,
|
||||
Err(_) => return true,
|
||||
@@ -157,7 +155,6 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
&args,
|
||||
session.user_shell(),
|
||||
turn.tools_config.allow_login_shell,
|
||||
turn.tools_config.unified_exec_backend,
|
||||
)
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
|
||||
@@ -324,23 +321,12 @@ pub(crate) fn get_command(
|
||||
args: &ExecCommandArgs,
|
||||
session_shell: Arc<Shell>,
|
||||
allow_login_shell: bool,
|
||||
unified_exec_backend: UnifiedExecBackendConfig,
|
||||
) -> Result<Vec<String>, String> {
|
||||
if unified_exec_backend == UnifiedExecBackendConfig::ZshFork && args.shell.is_some() {
|
||||
return Err(
|
||||
"shell override is not supported when the zsh-fork backend is enabled.".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let model_shell = if unified_exec_backend == UnifiedExecBackendConfig::ZshFork {
|
||||
None
|
||||
} else {
|
||||
args.shell.as_ref().map(|shell_str| {
|
||||
let mut shell = get_shell_by_model_provided_path(&PathBuf::from(shell_str));
|
||||
shell.shell_snapshot = crate::shell::empty_shell_snapshot_receiver();
|
||||
shell
|
||||
})
|
||||
};
|
||||
let model_shell = args.shell.as_ref().map(|shell_str| {
|
||||
let mut shell = get_shell_by_model_provided_path(&PathBuf::from(shell_str));
|
||||
shell.shell_snapshot = crate::shell::empty_shell_snapshot_receiver();
|
||||
shell
|
||||
});
|
||||
|
||||
let shell = model_shell.as_ref().unwrap_or(session_shell.as_ref());
|
||||
let use_login_shell = match args.login {
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
use super::*;
|
||||
use crate::shell::ShellType;
|
||||
use crate::shell::default_user_shell;
|
||||
use crate::shell::empty_shell_snapshot_receiver;
|
||||
use crate::tools::handlers::parse_arguments_with_base_path;
|
||||
use crate::tools::handlers::resolve_workdir_base_path;
|
||||
use crate::tools::spec::UnifiedExecBackendConfig;
|
||||
use codex_protocol::models::FileSystemPermissions;
|
||||
use codex_protocol::models::PermissionProfile;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
@@ -22,13 +18,8 @@ fn test_get_command_uses_default_shell_when_unspecified() -> anyhow::Result<()>
|
||||
|
||||
assert!(args.shell.is_none());
|
||||
|
||||
let command = get_command(
|
||||
&args,
|
||||
Arc::new(default_user_shell()),
|
||||
true,
|
||||
UnifiedExecBackendConfig::Direct,
|
||||
)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let command =
|
||||
get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?;
|
||||
|
||||
assert_eq!(command.len(), 3);
|
||||
assert_eq!(command[2], "echo hello");
|
||||
@@ -43,13 +34,8 @@ fn test_get_command_respects_explicit_bash_shell() -> anyhow::Result<()> {
|
||||
|
||||
assert_eq!(args.shell.as_deref(), Some("/bin/bash"));
|
||||
|
||||
let command = get_command(
|
||||
&args,
|
||||
Arc::new(default_user_shell()),
|
||||
true,
|
||||
UnifiedExecBackendConfig::Direct,
|
||||
)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let command =
|
||||
get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?;
|
||||
|
||||
assert_eq!(command.last(), Some(&"echo hello".to_string()));
|
||||
if command
|
||||
@@ -69,13 +55,8 @@ fn test_get_command_respects_explicit_powershell_shell() -> anyhow::Result<()> {
|
||||
|
||||
assert_eq!(args.shell.as_deref(), Some("powershell"));
|
||||
|
||||
let command = get_command(
|
||||
&args,
|
||||
Arc::new(default_user_shell()),
|
||||
true,
|
||||
UnifiedExecBackendConfig::Direct,
|
||||
)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let command =
|
||||
get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?;
|
||||
|
||||
assert_eq!(command[2], "echo hello");
|
||||
Ok(())
|
||||
@@ -89,13 +70,8 @@ fn test_get_command_respects_explicit_cmd_shell() -> anyhow::Result<()> {
|
||||
|
||||
assert_eq!(args.shell.as_deref(), Some("cmd"));
|
||||
|
||||
let command = get_command(
|
||||
&args,
|
||||
Arc::new(default_user_shell()),
|
||||
true,
|
||||
UnifiedExecBackendConfig::Direct,
|
||||
)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let command =
|
||||
get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?;
|
||||
|
||||
assert_eq!(command[2], "echo hello");
|
||||
Ok(())
|
||||
@@ -106,13 +82,8 @@ fn test_get_command_rejects_explicit_login_when_disallowed() -> anyhow::Result<(
|
||||
let json = r#"{"cmd": "echo hello", "login": true}"#;
|
||||
|
||||
let args: ExecCommandArgs = parse_arguments(json)?;
|
||||
let err = get_command(
|
||||
&args,
|
||||
Arc::new(default_user_shell()),
|
||||
false,
|
||||
UnifiedExecBackendConfig::Direct,
|
||||
)
|
||||
.expect_err("explicit login should be rejected");
|
||||
let err = get_command(&args, Arc::new(default_user_shell()), false)
|
||||
.expect_err("explicit login should be rejected");
|
||||
|
||||
assert!(
|
||||
err.contains("login shell is disabled by config"),
|
||||
@@ -121,30 +92,6 @@ fn test_get_command_rejects_explicit_login_when_disallowed() -> anyhow::Result<(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_command_rejects_model_shell_override_for_zsh_fork_backend() -> anyhow::Result<()> {
|
||||
let json = r#"{"cmd": "echo hello", "shell": "/bin/bash"}"#;
|
||||
let args: ExecCommandArgs = parse_arguments(json)?;
|
||||
|
||||
let session_shell = Arc::new(Shell {
|
||||
shell_type: ShellType::Zsh,
|
||||
shell_path: PathBuf::from("/tmp/configured-zsh-fork-shell"),
|
||||
shell_snapshot: empty_shell_snapshot_receiver(),
|
||||
});
|
||||
let err = get_command(
|
||||
&args,
|
||||
session_shell,
|
||||
true,
|
||||
UnifiedExecBackendConfig::ZshFork,
|
||||
)
|
||||
.expect_err("shell override should be rejected for zsh-fork backend");
|
||||
assert!(
|
||||
err.contains("shell override is not supported"),
|
||||
"unexpected error: {err}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exec_command_args_resolve_relative_additional_permissions_against_workdir() -> anyhow::Result<()>
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::guardian::GuardianApprovalRequest;
|
||||
use crate::guardian::review_approval_request;
|
||||
use crate::guardian::routes_approval_to_guardian;
|
||||
use crate::network_policy_decision::denied_network_policy_message;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use crate::tools::sandboxing::ToolError;
|
||||
use codex_network_proxy::BlockedRequest;
|
||||
use codex_network_proxy::BlockedRequestObserver;
|
||||
@@ -76,14 +77,20 @@ impl ActiveNetworkApproval {
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
struct HostApprovalKey {
|
||||
scope: NetworkProxyScope,
|
||||
host: String,
|
||||
protocol: &'static str,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
impl HostApprovalKey {
|
||||
fn from_request(request: &NetworkPolicyRequest, protocol: NetworkApprovalProtocol) -> Self {
|
||||
fn from_request(
|
||||
request: &NetworkPolicyRequest,
|
||||
protocol: NetworkApprovalProtocol,
|
||||
scope: NetworkProxyScope,
|
||||
) -> Self {
|
||||
Self {
|
||||
scope,
|
||||
host: request.host.to_ascii_lowercase(),
|
||||
protocol: protocol_key_label(protocol),
|
||||
port: request.port,
|
||||
@@ -279,6 +286,7 @@ impl NetworkApprovalService {
|
||||
&self,
|
||||
session: Arc<Session>,
|
||||
request: NetworkPolicyRequest,
|
||||
scope: NetworkProxyScope,
|
||||
) -> NetworkDecision {
|
||||
const REASON_NOT_ALLOWED: &str = "not_allowed";
|
||||
|
||||
@@ -288,7 +296,7 @@ impl NetworkApprovalService {
|
||||
NetworkProtocol::Socks5Tcp => NetworkApprovalProtocol::Socks5Tcp,
|
||||
NetworkProtocol::Socks5Udp => NetworkApprovalProtocol::Socks5Udp,
|
||||
};
|
||||
let key = HostApprovalKey::from_request(&request, protocol);
|
||||
let key = HostApprovalKey::from_request(&request, protocol, scope.clone());
|
||||
|
||||
{
|
||||
let denied_hosts = self.session_denied_hosts.lock().await;
|
||||
@@ -387,6 +395,7 @@ impl NetworkApprovalService {
|
||||
.persist_network_policy_amendment(
|
||||
&network_policy_amendment,
|
||||
&network_approval_context,
|
||||
&scope,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -417,6 +426,7 @@ impl NetworkApprovalService {
|
||||
.persist_network_policy_amendment(
|
||||
&network_policy_amendment,
|
||||
&network_approval_context,
|
||||
&scope,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -506,16 +516,18 @@ pub(crate) fn build_blocked_request_observer(
|
||||
pub(crate) fn build_network_policy_decider(
|
||||
network_approval: Arc<NetworkApprovalService>,
|
||||
network_policy_decider_session: Arc<RwLock<std::sync::Weak<Session>>>,
|
||||
scope: NetworkProxyScope,
|
||||
) -> Arc<dyn NetworkPolicyDecider> {
|
||||
Arc::new(move |request: NetworkPolicyRequest| {
|
||||
let network_approval = Arc::clone(&network_approval);
|
||||
let network_policy_decider_session = Arc::clone(&network_policy_decider_session);
|
||||
let scope = scope.clone();
|
||||
async move {
|
||||
let Some(session) = network_policy_decider_session.read().await.upgrade() else {
|
||||
return NetworkDecision::ask("not_allowed");
|
||||
};
|
||||
network_approval
|
||||
.handle_inline_policy_request(session, request)
|
||||
.handle_inline_policy_request(session, request, scope)
|
||||
.await
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use super::*;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use codex_network_proxy::BlockedRequestArgs;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -7,6 +8,7 @@ use pretty_assertions::assert_eq;
|
||||
async fn pending_approvals_are_deduped_per_host_protocol_and_port() {
|
||||
let service = NetworkApprovalService::default();
|
||||
let key = HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "http",
|
||||
port: 443,
|
||||
@@ -24,11 +26,13 @@ async fn pending_approvals_are_deduped_per_host_protocol_and_port() {
|
||||
async fn pending_approvals_do_not_dedupe_across_ports() {
|
||||
let service = NetworkApprovalService::default();
|
||||
let first_key = HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 443,
|
||||
};
|
||||
let second_key = HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 8443,
|
||||
@@ -49,16 +53,19 @@ async fn session_approved_hosts_preserve_protocol_and_port_scope() {
|
||||
let mut approved_hosts = source.session_approved_hosts.lock().await;
|
||||
approved_hosts.extend([
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 443,
|
||||
},
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 8443,
|
||||
},
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "http",
|
||||
port: 80,
|
||||
@@ -82,16 +89,19 @@ async fn session_approved_hosts_preserve_protocol_and_port_scope() {
|
||||
copied,
|
||||
vec![
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "http",
|
||||
port: 80,
|
||||
},
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 443,
|
||||
},
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 8443,
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::features::Feature;
|
||||
use crate::guardian::GuardianApprovalRequest;
|
||||
use crate::guardian::review_approval_request;
|
||||
use crate::guardian::routes_approval_to_guardian;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use crate::sandboxing::ExecRequest;
|
||||
use crate::sandboxing::SandboxPermissions;
|
||||
use crate::shell::ShellType;
|
||||
@@ -54,6 +55,7 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -91,7 +93,7 @@ pub(super) async fn try_run_zsh_fork(
|
||||
req: &ShellRequest,
|
||||
attempt: &SandboxAttempt<'_>,
|
||||
ctx: &ToolCtx,
|
||||
shell_command: &[String],
|
||||
command: &[String],
|
||||
) -> Result<Option<ExecToolCallOutput>, ToolError> {
|
||||
let Some(shell_zsh_path) = ctx.session.services.shell_zsh_path.as_ref() else {
|
||||
tracing::warn!("ZshFork backend specified, but shell_zsh_path is not configured.");
|
||||
@@ -106,10 +108,8 @@ pub(super) async fn try_run_zsh_fork(
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let ParsedShellCommand { script, login, .. } = extract_shell_script(shell_command)?;
|
||||
|
||||
let spec = build_command_spec(
|
||||
shell_command,
|
||||
command,
|
||||
&req.cwd,
|
||||
&req.env,
|
||||
req.timeout_ms.into(),
|
||||
@@ -121,7 +121,7 @@ pub(super) async fn try_run_zsh_fork(
|
||||
.env_for(spec, req.network.as_ref())
|
||||
.map_err(|err| ToolError::Codex(err.into()))?;
|
||||
let crate::sandboxing::ExecRequest {
|
||||
command: sandbox_command,
|
||||
command,
|
||||
cwd: sandbox_cwd,
|
||||
env: sandbox_env,
|
||||
network: sandbox_network,
|
||||
@@ -135,14 +135,17 @@ pub(super) async fn try_run_zsh_fork(
|
||||
justification,
|
||||
arg0,
|
||||
} = sandbox_exec_request;
|
||||
let host_zsh_path =
|
||||
resolve_host_zsh_path(sandbox_env.get("PATH").map(String::as_str), &sandbox_cwd);
|
||||
let ParsedShellCommand { script, login, .. } = extract_shell_script(&command)?;
|
||||
let effective_timeout = Duration::from_millis(
|
||||
req.timeout_ms
|
||||
.unwrap_or(crate::exec::DEFAULT_EXEC_COMMAND_TIMEOUT_MS),
|
||||
);
|
||||
let exec_policy = Arc::new(RwLock::new(
|
||||
ctx.session.services.exec_policy.current().as_ref().clone(),
|
||||
));
|
||||
let command_executor = CoreShellCommandExecutor {
|
||||
command: sandbox_command,
|
||||
session: Some(Arc::clone(&ctx.session)),
|
||||
command,
|
||||
cwd: sandbox_cwd,
|
||||
sandbox_policy,
|
||||
file_system_sandbox_policy,
|
||||
@@ -163,8 +166,6 @@ pub(super) async fn try_run_zsh_fork(
|
||||
.clone(),
|
||||
codex_linux_sandbox_exe: ctx.turn.codex_linux_sandbox_exe.clone(),
|
||||
use_legacy_landlock: ctx.turn.features.use_legacy_landlock(),
|
||||
shell_zsh_path: ctx.session.services.shell_zsh_path.clone(),
|
||||
host_zsh_path: host_zsh_path.clone(),
|
||||
};
|
||||
let main_execve_wrapper_exe = ctx
|
||||
.session
|
||||
@@ -193,6 +194,7 @@ pub(super) async fn try_run_zsh_fork(
|
||||
req.additional_permissions_preapproved,
|
||||
);
|
||||
let escalation_policy = CoreShellActionProvider {
|
||||
policy: Arc::clone(&exec_policy),
|
||||
session: Arc::clone(&ctx.session),
|
||||
turn: Arc::clone(&ctx.turn),
|
||||
call_id: ctx.call_id.clone(),
|
||||
@@ -205,7 +207,6 @@ pub(super) async fn try_run_zsh_fork(
|
||||
approval_sandbox_permissions,
|
||||
prompt_permissions: req.additional_permissions.clone(),
|
||||
stopwatch: stopwatch.clone(),
|
||||
host_zsh_path,
|
||||
};
|
||||
|
||||
let escalate_server = EscalateServer::new(
|
||||
@@ -226,7 +227,6 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
|
||||
req: &crate::tools::runtimes::unified_exec::UnifiedExecRequest,
|
||||
_attempt: &SandboxAttempt<'_>,
|
||||
ctx: &ToolCtx,
|
||||
shell_command: &[String],
|
||||
exec_request: ExecRequest,
|
||||
) -> Result<Option<PreparedUnifiedExecZshFork>, ToolError> {
|
||||
let Some(shell_zsh_path) = ctx.session.services.shell_zsh_path.as_ref() else {
|
||||
@@ -242,7 +242,7 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let parsed = match extract_shell_script(shell_command) {
|
||||
let parsed = match extract_shell_script(&exec_request.command) {
|
||||
Ok(parsed) => parsed,
|
||||
Err(err) => {
|
||||
tracing::warn!("ZshFork unified exec fallback: {err:?}");
|
||||
@@ -258,35 +258,23 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let ExecRequest {
|
||||
command,
|
||||
cwd,
|
||||
env,
|
||||
network,
|
||||
expiration: _expiration,
|
||||
sandbox,
|
||||
windows_sandbox_level,
|
||||
sandbox_permissions,
|
||||
sandbox_policy,
|
||||
file_system_sandbox_policy,
|
||||
network_sandbox_policy,
|
||||
justification,
|
||||
arg0,
|
||||
} = &exec_request;
|
||||
let host_zsh_path = resolve_host_zsh_path(env.get("PATH").map(String::as_str), cwd);
|
||||
let exec_policy = Arc::new(RwLock::new(
|
||||
ctx.session.services.exec_policy.current().as_ref().clone(),
|
||||
));
|
||||
let command_executor = CoreShellCommandExecutor {
|
||||
command: command.clone(),
|
||||
cwd: cwd.clone(),
|
||||
sandbox_policy: sandbox_policy.clone(),
|
||||
file_system_sandbox_policy: file_system_sandbox_policy.clone(),
|
||||
network_sandbox_policy: *network_sandbox_policy,
|
||||
sandbox: *sandbox,
|
||||
env: env.clone(),
|
||||
network: network.clone(),
|
||||
windows_sandbox_level: *windows_sandbox_level,
|
||||
sandbox_permissions: *sandbox_permissions,
|
||||
justification: justification.clone(),
|
||||
arg0: arg0.clone(),
|
||||
session: Some(Arc::clone(&ctx.session)),
|
||||
command: exec_request.command.clone(),
|
||||
cwd: exec_request.cwd.clone(),
|
||||
sandbox_policy: exec_request.sandbox_policy.clone(),
|
||||
file_system_sandbox_policy: exec_request.file_system_sandbox_policy.clone(),
|
||||
network_sandbox_policy: exec_request.network_sandbox_policy,
|
||||
sandbox: exec_request.sandbox,
|
||||
env: exec_request.env.clone(),
|
||||
network: exec_request.network.clone(),
|
||||
windows_sandbox_level: exec_request.windows_sandbox_level,
|
||||
sandbox_permissions: exec_request.sandbox_permissions,
|
||||
justification: exec_request.justification.clone(),
|
||||
arg0: exec_request.arg0.clone(),
|
||||
sandbox_policy_cwd: ctx.turn.cwd.clone(),
|
||||
macos_seatbelt_profile_extensions: ctx
|
||||
.turn
|
||||
@@ -296,8 +284,6 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
|
||||
.clone(),
|
||||
codex_linux_sandbox_exe: ctx.turn.codex_linux_sandbox_exe.clone(),
|
||||
use_legacy_landlock: ctx.turn.features.use_legacy_landlock(),
|
||||
shell_zsh_path: ctx.session.services.shell_zsh_path.clone(),
|
||||
host_zsh_path: host_zsh_path.clone(),
|
||||
};
|
||||
let main_execve_wrapper_exe = ctx
|
||||
.session
|
||||
@@ -310,6 +296,7 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
|
||||
)
|
||||
})?;
|
||||
let escalation_policy = CoreShellActionProvider {
|
||||
policy: Arc::clone(&exec_policy),
|
||||
session: Arc::clone(&ctx.session),
|
||||
turn: Arc::clone(&ctx.turn),
|
||||
call_id: ctx.call_id.clone(),
|
||||
@@ -325,7 +312,6 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
|
||||
),
|
||||
prompt_permissions: req.additional_permissions.clone(),
|
||||
stopwatch: Stopwatch::unlimited(),
|
||||
host_zsh_path,
|
||||
};
|
||||
|
||||
let escalate_server = EscalateServer::new(
|
||||
@@ -345,6 +331,7 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
|
||||
}
|
||||
|
||||
struct CoreShellActionProvider {
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
session: Arc<crate::codex::Session>,
|
||||
turn: Arc<crate::codex::TurnContext>,
|
||||
call_id: String,
|
||||
@@ -357,7 +344,6 @@ struct CoreShellActionProvider {
|
||||
approval_sandbox_permissions: SandboxPermissions,
|
||||
prompt_permissions: Option<PermissionProfile>,
|
||||
stopwatch: Stopwatch,
|
||||
host_zsh_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
@@ -395,66 +381,6 @@ fn execve_prompt_is_rejected_by_policy(
|
||||
}
|
||||
}
|
||||
|
||||
fn paths_match(lhs: &Path, rhs: &Path) -> bool {
|
||||
lhs == rhs
|
||||
|| match (lhs.canonicalize(), rhs.canonicalize()) {
|
||||
(Ok(lhs), Ok(rhs)) => lhs == rhs,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_host_zsh_path(_path_env: Option<&str>, _cwd: &Path) -> Option<PathBuf> {
|
||||
fn canonicalize_best_effort(path: PathBuf) -> PathBuf {
|
||||
path.canonicalize().unwrap_or(path)
|
||||
}
|
||||
|
||||
fn is_executable_file(path: &Path) -> bool {
|
||||
std::fs::metadata(path).is_ok_and(|metadata| {
|
||||
metadata.is_file() && {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
metadata.permissions().mode() & 0o111 != 0
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
true
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn find_zsh_in_dirs(dirs: impl IntoIterator<Item = PathBuf>) -> Option<PathBuf> {
|
||||
dirs.into_iter().find_map(|dir| {
|
||||
let candidate = dir.join("zsh");
|
||||
is_executable_file(&candidate).then(|| canonicalize_best_effort(candidate))
|
||||
})
|
||||
}
|
||||
|
||||
// Keep nested-zsh rewrites limited to canonical host shell installations.
|
||||
// PATH shadowing from repos, Nix environments, or tool shims should not be
|
||||
// treated as the host shell.
|
||||
find_zsh_in_dirs(
|
||||
["/bin", "/usr/bin", "/usr/local/bin", "/opt/homebrew/bin"]
|
||||
.into_iter()
|
||||
.map(PathBuf::from),
|
||||
)
|
||||
}
|
||||
|
||||
fn is_unconfigured_zsh_exec(
|
||||
program: &AbsolutePathBuf,
|
||||
shell_zsh_path: Option<&Path>,
|
||||
host_zsh_path: Option<&Path>,
|
||||
) -> bool {
|
||||
let Some(shell_zsh_path) = shell_zsh_path else {
|
||||
return false;
|
||||
};
|
||||
let Some(host_zsh_path) = host_zsh_path else {
|
||||
return false;
|
||||
};
|
||||
paths_match(program.as_path(), host_zsh_path) && !paths_match(program.as_path(), shell_zsh_path)
|
||||
}
|
||||
|
||||
impl CoreShellActionProvider {
|
||||
fn decision_driven_by_policy(matched_rules: &[RuleMatch], decision: Decision) -> bool {
|
||||
matched_rules.iter().any(|rule_match| {
|
||||
@@ -566,10 +492,6 @@ impl CoreShellActionProvider {
|
||||
command,
|
||||
workdir,
|
||||
None,
|
||||
// Intercepted exec prompts happen after the original tool call has
|
||||
// started, so we do not attach an execpolicy amendment payload here.
|
||||
// Amendments are currently surfaced only from the top-level tool
|
||||
// request path.
|
||||
None,
|
||||
None,
|
||||
additional_permissions,
|
||||
@@ -585,26 +507,7 @@ impl CoreShellActionProvider {
|
||||
/// an absolute path. The idea is that we check to see whether it matches
|
||||
/// any skills.
|
||||
async fn find_skill(&self, program: &AbsolutePathBuf) -> Option<SkillMetadata> {
|
||||
let force_reload = false;
|
||||
let skills_outcome = self
|
||||
.session
|
||||
.services
|
||||
.skills_manager
|
||||
.skills_for_cwd(&self.turn.cwd, force_reload)
|
||||
.await;
|
||||
|
||||
let program_path = program.as_path();
|
||||
for skill in skills_outcome.skills {
|
||||
// We intentionally ignore "enabled" status here for now.
|
||||
let Some(skill_root) = skill.path_to_skills_md.parent() else {
|
||||
continue;
|
||||
};
|
||||
if program_path.starts_with(skill_root.join("scripts")) {
|
||||
return Some(skill);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
find_skill_for_program(self.session.as_ref(), self.turn.cwd.as_path(), program).await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -713,6 +616,32 @@ impl CoreShellActionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
async fn find_skill_for_program(
|
||||
session: &crate::codex::Session,
|
||||
cwd: &Path,
|
||||
program: &AbsolutePathBuf,
|
||||
) -> Option<SkillMetadata> {
|
||||
let force_reload = false;
|
||||
let skills_outcome = session
|
||||
.services
|
||||
.skills_manager
|
||||
.skills_for_cwd(cwd, force_reload)
|
||||
.await;
|
||||
|
||||
let program_path = program.as_path();
|
||||
for skill in skills_outcome.skills {
|
||||
// We intentionally ignore "enabled" status here for now.
|
||||
let Some(skill_root) = skill.path_to_skills_md.parent() else {
|
||||
continue;
|
||||
};
|
||||
if program_path.starts_with(skill_root.join("scripts")) {
|
||||
return Some(skill);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// Shell-wrapper parsing is weaker than direct exec interception because it can
|
||||
// only see the script text, not the final resolved executable path. Keep it
|
||||
// disabled by default so path-sensitive rules rely on the later authoritative
|
||||
@@ -794,35 +723,28 @@ impl EscalationPolicy for CoreShellActionProvider {
|
||||
.await;
|
||||
}
|
||||
|
||||
let policy = self.session.services.exec_policy.current();
|
||||
let evaluation = evaluate_intercepted_exec_policy(
|
||||
policy.as_ref(),
|
||||
program,
|
||||
argv,
|
||||
InterceptedExecPolicyContext {
|
||||
approval_policy: self.approval_policy,
|
||||
sandbox_policy: &self.sandbox_policy,
|
||||
file_system_sandbox_policy: &self.file_system_sandbox_policy,
|
||||
sandbox_permissions: self.approval_sandbox_permissions,
|
||||
enable_shell_wrapper_parsing: ENABLE_INTERCEPTED_EXEC_POLICY_SHELL_WRAPPER_PARSING,
|
||||
},
|
||||
);
|
||||
let evaluation = {
|
||||
let policy = self.policy.read().await;
|
||||
evaluate_intercepted_exec_policy(
|
||||
&policy,
|
||||
program,
|
||||
argv,
|
||||
InterceptedExecPolicyContext {
|
||||
approval_policy: self.approval_policy,
|
||||
sandbox_policy: &self.sandbox_policy,
|
||||
file_system_sandbox_policy: &self.file_system_sandbox_policy,
|
||||
sandbox_permissions: self.approval_sandbox_permissions,
|
||||
enable_shell_wrapper_parsing:
|
||||
ENABLE_INTERCEPTED_EXEC_POLICY_SHELL_WRAPPER_PARSING,
|
||||
},
|
||||
)
|
||||
};
|
||||
// When true, means the Evaluation was due to *.rules, not the
|
||||
// fallback function.
|
||||
let decision_driven_by_policy =
|
||||
Self::decision_driven_by_policy(&evaluation.matched_rules, evaluation.decision);
|
||||
// Keep zsh-fork interception alive across nested shells: if an
|
||||
// intercepted exec targets the known host `zsh` path instead of the
|
||||
// configured zsh-fork binary, force it through escalation so the
|
||||
// executor can rewrite the program path back to the configured shell.
|
||||
let force_zsh_fork_reexec = is_unconfigured_zsh_exec(
|
||||
program,
|
||||
self.session.services.shell_zsh_path.as_deref(),
|
||||
self.host_zsh_path.as_deref(),
|
||||
);
|
||||
let needs_escalation = self.sandbox_permissions.requires_escalated_permissions()
|
||||
|| decision_driven_by_policy
|
||||
|| force_zsh_fork_reexec;
|
||||
let needs_escalation =
|
||||
self.sandbox_permissions.requires_escalated_permissions() || decision_driven_by_policy;
|
||||
|
||||
let decision_source = if decision_driven_by_policy {
|
||||
DecisionSource::PrefixRule
|
||||
@@ -954,6 +876,7 @@ fn commands_for_intercepted_exec_policy(
|
||||
}
|
||||
|
||||
struct CoreShellCommandExecutor {
|
||||
session: Option<Arc<crate::codex::Session>>,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
@@ -971,14 +894,13 @@ struct CoreShellCommandExecutor {
|
||||
macos_seatbelt_profile_extensions: Option<MacOsSeatbeltProfileExtensions>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
use_legacy_landlock: bool,
|
||||
shell_zsh_path: Option<PathBuf>,
|
||||
host_zsh_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
struct PrepareSandboxedExecParams<'a> {
|
||||
command: Vec<String>,
|
||||
workdir: &'a AbsolutePathBuf,
|
||||
env: HashMap<String, String>,
|
||||
network: Option<codex_network_proxy::NetworkProxy>,
|
||||
sandbox_policy: &'a SandboxPolicy,
|
||||
file_system_sandbox_policy: &'a FileSystemSandboxPolicy,
|
||||
network_sandbox_policy: NetworkSandboxPolicy,
|
||||
@@ -1045,8 +967,7 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
env: HashMap<String, String>,
|
||||
execution: EscalationExecution,
|
||||
) -> anyhow::Result<PreparedExec> {
|
||||
let program = self.rewrite_intercepted_program_for_zsh_fork(program);
|
||||
let command = join_program_and_argv(&program, argv);
|
||||
let command = join_program_and_argv(program, argv);
|
||||
let Some(first_arg) = argv.first() else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"intercepted exec request must contain argv[0]"
|
||||
@@ -1061,10 +982,12 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
arg0: Some(first_arg.clone()),
|
||||
},
|
||||
EscalationExecution::TurnDefault => {
|
||||
let network = self.network_for_program(program).await?;
|
||||
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
|
||||
command,
|
||||
workdir,
|
||||
env,
|
||||
network,
|
||||
sandbox_policy: &self.sandbox_policy,
|
||||
file_system_sandbox_policy: &self.file_system_sandbox_policy,
|
||||
network_sandbox_policy: self.network_sandbox_policy,
|
||||
@@ -1078,12 +1001,14 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
EscalationExecution::Permissions(EscalationPermissions::PermissionProfile(
|
||||
permission_profile,
|
||||
)) => {
|
||||
let network = self.network_for_program(program).await?;
|
||||
// Merge additive permissions into the existing turn/request sandbox policy.
|
||||
// On macOS, additional profile extensions are unioned with the turn defaults.
|
||||
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
|
||||
command,
|
||||
workdir,
|
||||
env,
|
||||
network,
|
||||
sandbox_policy: &self.sandbox_policy,
|
||||
file_system_sandbox_policy: &self.file_system_sandbox_policy,
|
||||
network_sandbox_policy: self.network_sandbox_policy,
|
||||
@@ -1095,11 +1020,13 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
})?
|
||||
}
|
||||
EscalationExecution::Permissions(EscalationPermissions::Permissions(permissions)) => {
|
||||
let network = self.network_for_program(program).await?;
|
||||
// Use a fully specified sandbox policy instead of merging into the turn policy.
|
||||
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
|
||||
command,
|
||||
workdir,
|
||||
env,
|
||||
network,
|
||||
sandbox_policy: &permissions.sandbox_policy,
|
||||
file_system_sandbox_policy: &permissions.file_system_sandbox_policy,
|
||||
network_sandbox_policy: permissions.network_sandbox_policy,
|
||||
@@ -1117,33 +1044,26 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
}
|
||||
|
||||
impl CoreShellCommandExecutor {
|
||||
fn rewrite_intercepted_program_for_zsh_fork(
|
||||
async fn network_for_program(
|
||||
&self,
|
||||
program: &AbsolutePathBuf,
|
||||
) -> AbsolutePathBuf {
|
||||
let Some(shell_zsh_path) = self.shell_zsh_path.as_ref() else {
|
||||
return program.clone();
|
||||
) -> anyhow::Result<Option<codex_network_proxy::NetworkProxy>> {
|
||||
let Some(session) = self.session.as_ref() else {
|
||||
return Ok(self.network.clone());
|
||||
};
|
||||
if !is_unconfigured_zsh_exec(
|
||||
program,
|
||||
Some(shell_zsh_path.as_path()),
|
||||
self.host_zsh_path.as_deref(),
|
||||
) {
|
||||
return program.clone();
|
||||
}
|
||||
match AbsolutePathBuf::from_absolute_path(shell_zsh_path) {
|
||||
Ok(rewritten) => rewritten,
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"failed to rewrite intercepted zsh path {} to configured shell {}: {err}",
|
||||
program.display(),
|
||||
shell_zsh_path.display(),
|
||||
);
|
||||
program.clone()
|
||||
}
|
||||
}
|
||||
let Some(skill) =
|
||||
find_skill_for_program(session.as_ref(), &self.sandbox_policy_cwd, program).await
|
||||
else {
|
||||
return Ok(self.network.clone());
|
||||
};
|
||||
let (scope, managed_network_override) = network_proxy_scope_for_skill(&skill);
|
||||
|
||||
session
|
||||
.get_or_start_network_proxy(scope, &self.sandbox_policy, managed_network_override)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn prepare_sandboxed_exec(
|
||||
&self,
|
||||
params: PrepareSandboxedExecParams<'_>,
|
||||
@@ -1152,6 +1072,7 @@ impl CoreShellCommandExecutor {
|
||||
command,
|
||||
workdir,
|
||||
env,
|
||||
network,
|
||||
sandbox_policy,
|
||||
file_system_sandbox_policy,
|
||||
network_sandbox_policy,
|
||||
@@ -1168,7 +1089,7 @@ impl CoreShellCommandExecutor {
|
||||
network_sandbox_policy,
|
||||
SandboxablePreference::Auto,
|
||||
self.windows_sandbox_level,
|
||||
self.network.is_some(),
|
||||
network.is_some(),
|
||||
);
|
||||
let mut exec_request =
|
||||
sandbox_manager.transform(crate::sandboxing::SandboxTransformRequest {
|
||||
@@ -1190,8 +1111,8 @@ impl CoreShellCommandExecutor {
|
||||
file_system_policy: file_system_sandbox_policy,
|
||||
network_policy: network_sandbox_policy,
|
||||
sandbox,
|
||||
enforce_managed_network: self.network.is_some(),
|
||||
network: self.network.as_ref(),
|
||||
enforce_managed_network: network.is_some(),
|
||||
network: network.as_ref(),
|
||||
sandbox_policy_cwd: &self.sandbox_policy_cwd,
|
||||
#[cfg(target_os = "macos")]
|
||||
macos_seatbelt_profile_extensions,
|
||||
@@ -1212,6 +1133,23 @@ impl CoreShellCommandExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
fn network_proxy_scope_for_skill(
|
||||
skill: &SkillMetadata,
|
||||
) -> (
|
||||
NetworkProxyScope,
|
||||
Option<crate::skills::model::SkillManagedNetworkOverride>,
|
||||
) {
|
||||
match skill.managed_network_override.clone() {
|
||||
Some(managed_network_override) => (
|
||||
NetworkProxyScope::Skill {
|
||||
path_to_skills_md: skill.path_to_skills_md.clone(),
|
||||
},
|
||||
Some(managed_network_override),
|
||||
),
|
||||
None => (NetworkProxyScope::SessionDefault, None),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
struct ParsedShellCommand {
|
||||
program: String,
|
||||
@@ -1220,21 +1158,23 @@ struct ParsedShellCommand {
|
||||
}
|
||||
|
||||
fn extract_shell_script(command: &[String]) -> Result<ParsedShellCommand, ToolError> {
|
||||
if let [program, flag, script, ..] = command {
|
||||
if flag == "-c" {
|
||||
return Ok(ParsedShellCommand {
|
||||
program: program.to_owned(),
|
||||
script: script.to_owned(),
|
||||
login: false,
|
||||
});
|
||||
// Commands reaching zsh-fork can be wrapped by environment/sandbox helpers, so
|
||||
// we search for the first `-c`/`-lc` triple anywhere in the argv rather
|
||||
// than assuming it is the first positional form.
|
||||
if let Some((program, script, login)) = command.windows(3).find_map(|parts| match parts {
|
||||
[program, flag, script] if flag == "-c" => {
|
||||
Some((program.to_owned(), script.to_owned(), false))
|
||||
}
|
||||
if flag == "-lc" {
|
||||
return Ok(ParsedShellCommand {
|
||||
program: program.to_owned(),
|
||||
script: script.to_owned(),
|
||||
login: true,
|
||||
});
|
||||
[program, flag, script] if flag == "-lc" => {
|
||||
Some((program.to_owned(), script.to_owned(), true))
|
||||
}
|
||||
_ => None,
|
||||
}) {
|
||||
return Ok(ParsedShellCommand {
|
||||
program,
|
||||
script,
|
||||
login,
|
||||
});
|
||||
}
|
||||
|
||||
Err(ToolError::Rejected(
|
||||
|
||||
@@ -6,10 +6,8 @@ use super::ParsedShellCommand;
|
||||
use super::commands_for_intercepted_exec_policy;
|
||||
use super::evaluate_intercepted_exec_policy;
|
||||
use super::extract_shell_script;
|
||||
use super::is_unconfigured_zsh_exec;
|
||||
use super::join_program_and_argv;
|
||||
use super::map_exec_result;
|
||||
use super::resolve_host_zsh_path;
|
||||
#[cfg(target_os = "macos")]
|
||||
use crate::config::Constrained;
|
||||
#[cfg(target_os = "macos")]
|
||||
@@ -17,6 +15,7 @@ use crate::config::Permissions;
|
||||
#[cfg(target_os = "macos")]
|
||||
use crate::config::types::ShellEnvironmentPolicy;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::GranularApprovalConfig;
|
||||
use crate::protocol::ReadOnlyAccess;
|
||||
@@ -52,7 +51,6 @@ use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use pretty_assertions::assert_eq;
|
||||
#[cfg(target_os = "macos")]
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -101,6 +99,35 @@ fn test_skill_metadata(permission_profile: Option<PermissionProfile>) -> SkillMe
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn network_proxy_scope_for_skill_without_override_reuses_session_default() {
|
||||
let skill = test_skill_metadata(None);
|
||||
|
||||
assert_eq!(
|
||||
super::network_proxy_scope_for_skill(&skill),
|
||||
(NetworkProxyScope::SessionDefault, None),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn network_proxy_scope_for_skill_with_override_uses_skill_scope() {
|
||||
let mut skill = test_skill_metadata(None);
|
||||
skill.managed_network_override = Some(crate::skills::model::SkillManagedNetworkOverride {
|
||||
allowed_domains: Some(vec!["skill.example.com".to_string()]),
|
||||
denied_domains: Some(vec!["blocked.skill.example.com".to_string()]),
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
super::network_proxy_scope_for_skill(&skill),
|
||||
(
|
||||
NetworkProxyScope::Skill {
|
||||
path_to_skills_md: PathBuf::from("/tmp/skill/SKILL.md"),
|
||||
},
|
||||
skill.managed_network_override.clone(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execve_prompt_rejection_uses_skill_approval_for_skill_scripts() {
|
||||
let decision_source = super::DecisionSource::SkillScript {
|
||||
@@ -206,16 +233,39 @@ fn extract_shell_script_preserves_login_flag() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_shell_script_rejects_wrapped_command_prefixes() {
|
||||
let err = extract_shell_script(&[
|
||||
"/usr/bin/env".into(),
|
||||
"CODEX_EXECVE_WRAPPER=1".into(),
|
||||
"/bin/zsh".into(),
|
||||
"-lc".into(),
|
||||
"echo hello".into(),
|
||||
])
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, super::ToolError::Rejected(_)));
|
||||
fn extract_shell_script_supports_wrapped_command_prefixes() {
|
||||
assert_eq!(
|
||||
extract_shell_script(&[
|
||||
"/usr/bin/env".into(),
|
||||
"CODEX_EXECVE_WRAPPER=1".into(),
|
||||
"/bin/zsh".into(),
|
||||
"-lc".into(),
|
||||
"echo hello".into()
|
||||
])
|
||||
.unwrap(),
|
||||
ParsedShellCommand {
|
||||
program: "/bin/zsh".to_string(),
|
||||
script: "echo hello".to_string(),
|
||||
login: true,
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
extract_shell_script(&[
|
||||
"sandbox-exec".into(),
|
||||
"-p".into(),
|
||||
"sandbox_policy".into(),
|
||||
"/bin/zsh".into(),
|
||||
"-c".into(),
|
||||
"pwd".into(),
|
||||
])
|
||||
.unwrap(),
|
||||
ParsedShellCommand {
|
||||
program: "/bin/zsh".to_string(),
|
||||
script: "pwd".to_string(),
|
||||
login: false,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -254,80 +304,6 @@ fn join_program_and_argv_replaces_original_argv_zero() {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_unconfigured_zsh_exec_matches_non_configured_zsh_paths() {
|
||||
let program = AbsolutePathBuf::try_from(host_absolute_path(&["bin", "zsh"])).unwrap();
|
||||
let host = PathBuf::from(host_absolute_path(&["bin", "zsh"]));
|
||||
let configured = PathBuf::from(host_absolute_path(&["tmp", "codex-zsh"]));
|
||||
assert!(is_unconfigured_zsh_exec(
|
||||
&program,
|
||||
Some(configured.as_path()),
|
||||
Some(host.as_path()),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_unconfigured_zsh_exec_ignores_non_zsh_or_configured_paths() {
|
||||
let configured = PathBuf::from(host_absolute_path(&["tmp", "codex-zsh"]));
|
||||
let host = PathBuf::from(host_absolute_path(&["bin", "zsh"]));
|
||||
let configured_program = AbsolutePathBuf::try_from(configured.clone()).unwrap();
|
||||
assert!(!is_unconfigured_zsh_exec(
|
||||
&configured_program,
|
||||
Some(configured.as_path()),
|
||||
Some(host.as_path()),
|
||||
));
|
||||
|
||||
let non_zsh =
|
||||
AbsolutePathBuf::try_from(host_absolute_path(&["usr", "bin", "python3"])).unwrap();
|
||||
assert!(!is_unconfigured_zsh_exec(
|
||||
&non_zsh,
|
||||
Some(configured.as_path()),
|
||||
Some(host.as_path()),
|
||||
));
|
||||
assert!(!is_unconfigured_zsh_exec(
|
||||
&non_zsh,
|
||||
None,
|
||||
Some(host.as_path()),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_unconfigured_zsh_exec_does_not_match_non_host_zsh_named_binaries() {
|
||||
let program = AbsolutePathBuf::try_from(host_absolute_path(&["tmp", "repo", "zsh"])).unwrap();
|
||||
let configured = PathBuf::from(host_absolute_path(&["tmp", "codex-zsh"]));
|
||||
let host = PathBuf::from(host_absolute_path(&["bin", "zsh"]));
|
||||
assert!(!is_unconfigured_zsh_exec(
|
||||
&program,
|
||||
Some(configured.as_path()),
|
||||
Some(host.as_path()),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_host_zsh_path_ignores_repo_local_path_shadowing() {
|
||||
let shadow_dir = tempfile::tempdir().expect("create shadow dir");
|
||||
let cwd_dir = tempfile::tempdir().expect("create cwd dir");
|
||||
let fake_zsh = shadow_dir.path().join("zsh");
|
||||
std::fs::write(&fake_zsh, "#!/bin/sh\nexit 0\n").expect("write fake zsh");
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let mut permissions = std::fs::metadata(&fake_zsh)
|
||||
.expect("metadata for fake zsh")
|
||||
.permissions();
|
||||
permissions.set_mode(0o755);
|
||||
std::fs::set_permissions(&fake_zsh, permissions).expect("chmod fake zsh");
|
||||
}
|
||||
|
||||
let path_env =
|
||||
std::env::join_paths([shadow_dir.path(), Path::new("/usr/bin"), Path::new("/bin")])
|
||||
.expect("join PATH")
|
||||
.into_string()
|
||||
.expect("PATH should be UTF-8");
|
||||
let resolved = resolve_host_zsh_path(Some(&path_env), cwd_dir.path());
|
||||
assert_ne!(resolved.as_deref(), Some(fake_zsh.as_path()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn commands_for_intercepted_exec_policy_parses_plain_shell_wrappers() {
|
||||
let program = AbsolutePathBuf::try_from(host_absolute_path(&["bin", "bash"])).unwrap();
|
||||
@@ -705,6 +681,7 @@ host_executable(name = "git", paths = ["{allowed_git_literal}"])
|
||||
async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions() {
|
||||
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
|
||||
let executor = CoreShellCommandExecutor {
|
||||
session: None,
|
||||
command: vec!["echo".to_string(), "ok".to_string()],
|
||||
cwd: cwd.to_path_buf(),
|
||||
env: HashMap::new(),
|
||||
@@ -724,8 +701,6 @@ async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions
|
||||
}),
|
||||
codex_linux_sandbox_exe: None,
|
||||
use_legacy_landlock: false,
|
||||
shell_zsh_path: None,
|
||||
host_zsh_path: None,
|
||||
};
|
||||
|
||||
let prepared = executor
|
||||
@@ -759,6 +734,7 @@ async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions
|
||||
async fn prepare_escalated_exec_permissions_preserve_macos_seatbelt_extensions() {
|
||||
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
|
||||
let executor = CoreShellCommandExecutor {
|
||||
session: None,
|
||||
command: vec!["echo".to_string(), "ok".to_string()],
|
||||
cwd: cwd.to_path_buf(),
|
||||
env: HashMap::new(),
|
||||
@@ -775,8 +751,6 @@ async fn prepare_escalated_exec_permissions_preserve_macos_seatbelt_extensions()
|
||||
macos_seatbelt_profile_extensions: None,
|
||||
codex_linux_sandbox_exe: None,
|
||||
use_legacy_landlock: false,
|
||||
shell_zsh_path: None,
|
||||
host_zsh_path: None,
|
||||
};
|
||||
|
||||
let permissions = Permissions {
|
||||
@@ -835,6 +809,7 @@ async fn prepare_escalated_exec_permission_profile_unions_turn_and_requested_mac
|
||||
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
|
||||
let sandbox_policy = SandboxPolicy::new_read_only_policy();
|
||||
let executor = CoreShellCommandExecutor {
|
||||
session: None,
|
||||
command: vec!["echo".to_string(), "ok".to_string()],
|
||||
cwd: cwd.to_path_buf(),
|
||||
env: HashMap::new(),
|
||||
@@ -854,8 +829,6 @@ async fn prepare_escalated_exec_permission_profile_unions_turn_and_requested_mac
|
||||
}),
|
||||
codex_linux_sandbox_exe: None,
|
||||
use_legacy_landlock: false,
|
||||
shell_zsh_path: None,
|
||||
host_zsh_path: None,
|
||||
};
|
||||
|
||||
let prepared = executor
|
||||
|
||||
@@ -36,10 +36,9 @@ pub(crate) async fn maybe_prepare_unified_exec(
|
||||
req: &UnifiedExecRequest,
|
||||
attempt: &SandboxAttempt<'_>,
|
||||
ctx: &ToolCtx,
|
||||
shell_command: &[String],
|
||||
exec_request: ExecRequest,
|
||||
) -> Result<Option<PreparedUnifiedExecSpawn>, ToolError> {
|
||||
imp::maybe_prepare_unified_exec(req, attempt, ctx, shell_command, exec_request).await
|
||||
imp::maybe_prepare_unified_exec(req, attempt, ctx, exec_request).await
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
@@ -47,34 +46,16 @@ mod imp {
|
||||
use super::*;
|
||||
use crate::tools::runtimes::shell::unix_escalation;
|
||||
use crate::unified_exec::SpawnLifecycle;
|
||||
use codex_shell_escalation::ESCALATE_SOCKET_ENV_VAR;
|
||||
use codex_shell_escalation::EscalationSession;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ZshForkSpawnLifecycle {
|
||||
escalation_session: Option<EscalationSession>,
|
||||
escalation_session: EscalationSession,
|
||||
}
|
||||
|
||||
impl SpawnLifecycle for ZshForkSpawnLifecycle {
|
||||
fn inherited_fds(&self) -> Vec<i32> {
|
||||
self.escalation_session
|
||||
.as_ref()
|
||||
.and_then(|escalation_session| {
|
||||
escalation_session.env().get(ESCALATE_SOCKET_ENV_VAR)
|
||||
})
|
||||
.and_then(|fd| fd.parse().ok())
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn after_spawn(&mut self) {
|
||||
if let Some(escalation_session) = self.escalation_session.as_ref() {
|
||||
escalation_session.close_client_socket();
|
||||
}
|
||||
}
|
||||
|
||||
fn after_exit(&mut self) {
|
||||
self.escalation_session = None;
|
||||
self.escalation_session.close_client_socket();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,17 +72,10 @@ mod imp {
|
||||
req: &UnifiedExecRequest,
|
||||
attempt: &SandboxAttempt<'_>,
|
||||
ctx: &ToolCtx,
|
||||
shell_command: &[String],
|
||||
exec_request: ExecRequest,
|
||||
) -> Result<Option<PreparedUnifiedExecSpawn>, ToolError> {
|
||||
let Some(prepared) = unix_escalation::prepare_unified_exec_zsh_fork(
|
||||
req,
|
||||
attempt,
|
||||
ctx,
|
||||
shell_command,
|
||||
exec_request,
|
||||
)
|
||||
.await?
|
||||
let Some(prepared) =
|
||||
unix_escalation::prepare_unified_exec_zsh_fork(req, attempt, ctx, exec_request).await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
@@ -109,7 +83,7 @@ mod imp {
|
||||
Ok(Some(PreparedUnifiedExecSpawn {
|
||||
exec_request: prepared.exec_request,
|
||||
spawn_lifecycle: Box::new(ZshForkSpawnLifecycle {
|
||||
escalation_session: Some(prepared.escalation_session),
|
||||
escalation_session: prepared.escalation_session,
|
||||
}),
|
||||
}))
|
||||
}
|
||||
@@ -133,10 +107,9 @@ mod imp {
|
||||
req: &UnifiedExecRequest,
|
||||
attempt: &SandboxAttempt<'_>,
|
||||
ctx: &ToolCtx,
|
||||
shell_command: &[String],
|
||||
exec_request: ExecRequest,
|
||||
) -> Result<Option<PreparedUnifiedExecSpawn>, ToolError> {
|
||||
let _ = (req, attempt, ctx, shell_command, exec_request);
|
||||
let _ = (req, attempt, ctx, exec_request);
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,11 +222,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
|
||||
let exec_env = attempt
|
||||
.env_for(spec, req.network.as_ref())
|
||||
.map_err(|err| ToolError::Codex(err.into()))?;
|
||||
match zsh_fork_backend::maybe_prepare_unified_exec(
|
||||
req, attempt, ctx, &command, exec_env,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
match zsh_fork_backend::maybe_prepare_unified_exec(req, attempt, ctx, exec_env).await? {
|
||||
Some(prepared) => {
|
||||
return self
|
||||
.manager
|
||||
|
||||
@@ -311,6 +311,8 @@ impl ToolsConfig {
|
||||
);
|
||||
let shell_type = if !features.enabled(Feature::ShellTool) {
|
||||
ConfigShellToolType::Disabled
|
||||
} else if features.enabled(Feature::ShellZshFork) {
|
||||
ConfigShellToolType::ShellCommand
|
||||
} else if features.enabled(Feature::UnifiedExec) && unified_exec_allowed {
|
||||
// If ConPTY not supported (for old Windows versions), fallback on ShellCommand.
|
||||
if codex_utils_pty::conpty_supported() {
|
||||
@@ -321,8 +323,6 @@ impl ToolsConfig {
|
||||
} else if model_info.shell_type == ConfigShellToolType::UnifiedExec && !unified_exec_allowed
|
||||
{
|
||||
ConfigShellToolType::ShellCommand
|
||||
} else if features.enabled(Feature::ShellZshFork) {
|
||||
ConfigShellToolType::ShellCommand
|
||||
} else {
|
||||
model_info.shell_type
|
||||
};
|
||||
@@ -578,7 +578,6 @@ fn create_approval_parameters(
|
||||
fn create_exec_command_tool(
|
||||
allow_login_shell: bool,
|
||||
exec_permission_approvals_enabled: bool,
|
||||
unified_exec_backend: UnifiedExecBackendConfig,
|
||||
) -> ToolSpec {
|
||||
let mut properties = BTreeMap::from([
|
||||
(
|
||||
@@ -596,6 +595,12 @@ fn create_exec_command_tool(
|
||||
),
|
||||
},
|
||||
),
|
||||
(
|
||||
"shell".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Shell binary to launch. Defaults to the user's default shell.".to_string()),
|
||||
},
|
||||
),
|
||||
(
|
||||
"tty".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
@@ -623,16 +628,6 @@ fn create_exec_command_tool(
|
||||
},
|
||||
),
|
||||
]);
|
||||
if unified_exec_backend != UnifiedExecBackendConfig::ZshFork {
|
||||
properties.insert(
|
||||
"shell".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Shell binary to launch. Defaults to the user's default shell.".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
}
|
||||
if allow_login_shell {
|
||||
properties.insert(
|
||||
"login".to_string(),
|
||||
@@ -2527,7 +2522,6 @@ pub(crate) fn build_specs_with_discoverable_tools(
|
||||
create_exec_command_tool(
|
||||
config.allow_login_shell,
|
||||
exec_permission_approvals_enabled,
|
||||
config.unified_exec_backend,
|
||||
),
|
||||
true,
|
||||
config.code_mode_enabled,
|
||||
|
||||
@@ -444,7 +444,7 @@ fn test_full_toolset_specs_for_gpt5_codex_unified_exec_web_search() {
|
||||
// Build expected from the same helpers used by the builder.
|
||||
let mut expected: BTreeMap<String, ToolSpec> = BTreeMap::from([]);
|
||||
for spec in [
|
||||
create_exec_command_tool(true, false, UnifiedExecBackendConfig::Direct),
|
||||
create_exec_command_tool(true, false),
|
||||
create_write_stdin_tool(),
|
||||
PLAN_TOOL.clone(),
|
||||
create_request_user_input_tool(CollaborationModesConfig::default()),
|
||||
@@ -1350,7 +1350,7 @@ fn test_build_specs_default_shell_present() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shell_zsh_fork_uses_unified_exec_when_enabled() {
|
||||
fn shell_zsh_fork_prefers_shell_command_over_unified_exec() {
|
||||
let config = test_config();
|
||||
let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config);
|
||||
let mut features = Features::with_defaults();
|
||||
@@ -1368,7 +1368,7 @@ fn shell_zsh_fork_uses_unified_exec_when_enabled() {
|
||||
windows_sandbox_level: WindowsSandboxLevel::Disabled,
|
||||
});
|
||||
|
||||
assert_eq!(tools_config.shell_type, ConfigShellToolType::UnifiedExec);
|
||||
assert_eq!(tools_config.shell_type, ConfigShellToolType::ShellCommand);
|
||||
assert_eq!(
|
||||
tools_config.shell_command_backend,
|
||||
ShellCommandBackendConfig::ZshFork
|
||||
@@ -1377,19 +1377,6 @@ fn shell_zsh_fork_uses_unified_exec_when_enabled() {
|
||||
tools_config.unified_exec_backend,
|
||||
UnifiedExecBackendConfig::ZshFork
|
||||
);
|
||||
|
||||
let (tools, _) = build_specs(&tools_config, Some(HashMap::new()), None, &[]).build();
|
||||
let exec_spec = find_tool(&tools, "exec_command");
|
||||
let ToolSpec::Function(exec_tool) = &exec_spec.spec else {
|
||||
panic!("exec_command should be a function tool spec");
|
||||
};
|
||||
let JsonSchema::Object { properties, .. } = &exec_tool.parameters else {
|
||||
panic!("exec_command parameters should be an object schema");
|
||||
};
|
||||
assert!(
|
||||
!properties.contains_key("shell"),
|
||||
"exec_command should omit `shell` when zsh-fork backend forces the configured shell",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#![allow(clippy::module_inception)]
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -27,23 +26,7 @@ use super::UnifiedExecError;
|
||||
use super::head_tail_buffer::HeadTailBuffer;
|
||||
|
||||
pub(crate) trait SpawnLifecycle: std::fmt::Debug + Send + Sync {
|
||||
/// Returns file descriptors that must stay open across the child `exec()`.
|
||||
///
|
||||
/// The returned descriptors must already be valid in the parent process and
|
||||
/// stay valid until `after_spawn()` runs, which is the first point where
|
||||
/// the parent may release its copies.
|
||||
fn inherited_fds(&self) -> Vec<i32> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
fn after_spawn(&mut self) {}
|
||||
|
||||
/// Releases resources that needed to stay alive until the child process was
|
||||
/// fully launched or until the session was torn down.
|
||||
///
|
||||
/// This hook must tolerate being called during normal process exit as well
|
||||
/// as early termination paths, and it is guaranteed to run at most once.
|
||||
fn after_exit(&mut self) {}
|
||||
}
|
||||
|
||||
pub(crate) type SpawnLifecycleHandle = Box<dyn SpawnLifecycle>;
|
||||
@@ -74,8 +57,7 @@ pub(crate) struct UnifiedExecProcess {
|
||||
output_drained: Arc<Notify>,
|
||||
output_task: JoinHandle<()>,
|
||||
sandbox_type: SandboxType,
|
||||
_spawn_lifecycle: Arc<StdMutex<SpawnLifecycleHandle>>,
|
||||
spawn_lifecycle_released: Arc<AtomicBool>,
|
||||
_spawn_lifecycle: SpawnLifecycleHandle,
|
||||
}
|
||||
|
||||
impl UnifiedExecProcess {
|
||||
@@ -83,7 +65,7 @@ impl UnifiedExecProcess {
|
||||
process_handle: ExecCommandSession,
|
||||
initial_output_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
sandbox_type: SandboxType,
|
||||
spawn_lifecycle: Arc<StdMutex<SpawnLifecycleHandle>>,
|
||||
spawn_lifecycle: SpawnLifecycleHandle,
|
||||
) -> Self {
|
||||
let output_buffer = Arc::new(Mutex::new(HeadTailBuffer::default()));
|
||||
let output_notify = Arc::new(Notify::new());
|
||||
@@ -128,19 +110,6 @@ impl UnifiedExecProcess {
|
||||
output_task,
|
||||
sandbox_type,
|
||||
_spawn_lifecycle: spawn_lifecycle,
|
||||
spawn_lifecycle_released: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
fn release_spawn_lifecycle(&self) {
|
||||
if self
|
||||
.spawn_lifecycle_released
|
||||
.swap(true, std::sync::atomic::Ordering::AcqRel)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if let Ok(mut lifecycle) = self._spawn_lifecycle.lock() {
|
||||
lifecycle.after_exit();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,7 +148,6 @@ impl UnifiedExecProcess {
|
||||
}
|
||||
|
||||
pub(super) fn terminate(&self) {
|
||||
self.release_spawn_lifecycle();
|
||||
self.output_closed.store(true, Ordering::Release);
|
||||
self.output_closed_notify.notify_waiters();
|
||||
self.process_handle.terminate();
|
||||
@@ -255,19 +223,12 @@ impl UnifiedExecProcess {
|
||||
mut exit_rx,
|
||||
} = spawned;
|
||||
let output_rx = codex_utils_pty::combine_output_receivers(stdout_rx, stderr_rx);
|
||||
let spawn_lifecycle = Arc::new(StdMutex::new(spawn_lifecycle));
|
||||
let managed = Self::new(
|
||||
process_handle,
|
||||
output_rx,
|
||||
sandbox_type,
|
||||
Arc::clone(&spawn_lifecycle),
|
||||
);
|
||||
let managed = Self::new(process_handle, output_rx, sandbox_type, spawn_lifecycle);
|
||||
|
||||
let exit_ready = matches!(exit_rx.try_recv(), Ok(_) | Err(TryRecvError::Closed));
|
||||
|
||||
if exit_ready {
|
||||
managed.signal_exit();
|
||||
managed.release_spawn_lifecycle();
|
||||
managed.check_for_sandbox_denial().await?;
|
||||
return Ok(managed);
|
||||
}
|
||||
@@ -277,22 +238,14 @@ impl UnifiedExecProcess {
|
||||
.is_ok()
|
||||
{
|
||||
managed.signal_exit();
|
||||
managed.release_spawn_lifecycle();
|
||||
managed.check_for_sandbox_denial().await?;
|
||||
return Ok(managed);
|
||||
}
|
||||
|
||||
tokio::spawn({
|
||||
let cancellation_token = managed.cancellation_token.clone();
|
||||
let spawn_lifecycle = Arc::clone(&spawn_lifecycle);
|
||||
let spawn_lifecycle_released = Arc::clone(&managed.spawn_lifecycle_released);
|
||||
async move {
|
||||
let _ = exit_rx.await;
|
||||
if !spawn_lifecycle_released.swap(true, Ordering::AcqRel)
|
||||
&& let Ok(mut lifecycle) = spawn_lifecycle.lock()
|
||||
{
|
||||
lifecycle.after_exit();
|
||||
}
|
||||
cancellation_token.cancel();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -537,27 +537,24 @@ impl UnifiedExecProcessManager {
|
||||
.command
|
||||
.split_first()
|
||||
.ok_or(UnifiedExecError::MissingCommandLine)?;
|
||||
let inherited_fds = spawn_lifecycle.inherited_fds();
|
||||
|
||||
let spawn_result = if tty {
|
||||
codex_utils_pty::pty::spawn_process_with_inherited_fds(
|
||||
codex_utils_pty::pty::spawn_process(
|
||||
program,
|
||||
args,
|
||||
env.cwd.as_path(),
|
||||
&env.env,
|
||||
&env.arg0,
|
||||
codex_utils_pty::TerminalSize::default(),
|
||||
&inherited_fds,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
codex_utils_pty::pipe::spawn_process_no_stdin_with_inherited_fds(
|
||||
codex_utils_pty::pipe::spawn_process_no_stdin(
|
||||
program,
|
||||
args,
|
||||
env.cwd.as_path(),
|
||||
&env.env,
|
||||
&env.arg0,
|
||||
&inherited_fds,
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use anyhow::Context;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
|
||||
pub async fn wait_for_pid_file(path: &Path) -> anyhow::Result<String> {
|
||||
@@ -25,7 +24,6 @@ pub async fn wait_for_pid_file(path: &Path) -> anyhow::Result<String> {
|
||||
pub fn process_is_alive(pid: &str) -> anyhow::Result<bool> {
|
||||
let status = std::process::Command::new("kill")
|
||||
.args(["-0", pid])
|
||||
.stderr(Stdio::null())
|
||||
.status()
|
||||
.context("failed to probe process liveness with kill -0")?;
|
||||
Ok(status.success())
|
||||
|
||||
@@ -18,10 +18,6 @@ pub struct ZshForkRuntime {
|
||||
}
|
||||
|
||||
impl ZshForkRuntime {
|
||||
pub fn zsh_path(&self) -> &Path {
|
||||
&self.zsh_path
|
||||
}
|
||||
|
||||
fn apply_to_config(
|
||||
&self,
|
||||
config: &mut Config,
|
||||
@@ -95,29 +91,6 @@ where
|
||||
builder.build(server).await
|
||||
}
|
||||
|
||||
pub async fn build_unified_exec_zsh_fork_test<F>(
|
||||
server: &wiremock::MockServer,
|
||||
runtime: ZshForkRuntime,
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
pre_build_hook: F,
|
||||
) -> Result<TestCodex>
|
||||
where
|
||||
F: FnOnce(&Path) + Send + 'static,
|
||||
{
|
||||
let mut builder = test_codex()
|
||||
.with_pre_build_hook(pre_build_hook)
|
||||
.with_config(move |config| {
|
||||
runtime.apply_to_config(config, approval_policy, sandbox_policy);
|
||||
config.use_experimental_unified_exec_tool = true;
|
||||
config
|
||||
.features
|
||||
.enable(Feature::UnifiedExec)
|
||||
.expect("test config should allow feature update");
|
||||
});
|
||||
builder.build(server).await
|
||||
}
|
||||
|
||||
fn find_test_zsh_path() -> Result<Option<PathBuf>> {
|
||||
let repo_root = codex_utils_cargo_bin::repo_root()?;
|
||||
let dotslash_zsh = repo_root.join("codex-rs/app-server/tests/suite/zsh");
|
||||
|
||||
@@ -35,7 +35,6 @@ use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use core_test_support::zsh_fork::build_unified_exec_zsh_fork_test;
|
||||
use core_test_support::zsh_fork::build_zsh_fork_test;
|
||||
use core_test_support::zsh_fork::restrictive_workspace_write_policy;
|
||||
use core_test_support::zsh_fork::zsh_fork_runtime;
|
||||
@@ -124,7 +123,7 @@ impl ActionKind {
|
||||
let (path, _) = target.resolve_for_patch(test);
|
||||
let _ = fs::remove_file(&path);
|
||||
let command = format!("printf {content:?} > {path:?} && cat {path:?}");
|
||||
let event = shell_event(call_id, &command, 5_000, sandbox_permissions)?;
|
||||
let event = shell_event(call_id, &command, 1_000, sandbox_permissions)?;
|
||||
Ok((event, Some(command)))
|
||||
}
|
||||
ActionKind::FetchUrl {
|
||||
@@ -1986,158 +1985,6 @@ async fn approving_execpolicy_amendment_persists_policy_and_skips_future_prompts
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[cfg(unix)]
|
||||
async fn unified_exec_zsh_fork_execpolicy_amendment_skips_later_subcommands() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork execpolicy amendment test")? else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let approval_policy = AskForApproval::UnlessTrusted;
|
||||
let sandbox_policy = SandboxPolicy::new_read_only_policy();
|
||||
let server = start_mock_server().await;
|
||||
let test = build_unified_exec_zsh_fork_test(
|
||||
&server,
|
||||
runtime,
|
||||
approval_policy,
|
||||
sandbox_policy.clone(),
|
||||
|_| {},
|
||||
)
|
||||
.await?;
|
||||
let allow_prefix_path = test.cwd.path().join("allow-prefix-zsh-fork.txt");
|
||||
let _ = fs::remove_file(&allow_prefix_path);
|
||||
|
||||
let call_id = "allow-prefix-zsh-fork";
|
||||
let command = "touch allow-prefix-zsh-fork.txt && touch allow-prefix-zsh-fork.txt";
|
||||
let event = exec_command_event(
|
||||
call_id,
|
||||
command,
|
||||
Some(1_000),
|
||||
SandboxPermissions::UseDefault,
|
||||
None,
|
||||
)?;
|
||||
let _ = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-zsh-fork-allow-prefix-1"),
|
||||
event,
|
||||
ev_completed("resp-zsh-fork-allow-prefix-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let results = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-zsh-fork-allow-prefix-1", "done"),
|
||||
ev_completed("resp-zsh-fork-allow-prefix-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"allow-prefix-zsh-fork",
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let expected_execpolicy_amendment = ExecPolicyAmendment::new(vec![
|
||||
"touch".to_string(),
|
||||
"allow-prefix-zsh-fork.txt".to_string(),
|
||||
]);
|
||||
let mut saw_parent_approval = false;
|
||||
let mut saw_subcommand_approval = false;
|
||||
loop {
|
||||
let event = wait_for_event(&test.codex, |event| {
|
||||
matches!(
|
||||
event,
|
||||
EventMsg::ExecApprovalRequest(_) | EventMsg::TurnComplete(_)
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
||||
match event {
|
||||
EventMsg::TurnComplete(_) => break,
|
||||
EventMsg::ExecApprovalRequest(approval) => {
|
||||
let command_parts = approval.command.clone();
|
||||
let last_arg = command_parts.last().map(String::as_str).unwrap_or_default();
|
||||
if last_arg == command {
|
||||
assert!(
|
||||
!saw_parent_approval,
|
||||
"unexpected duplicate parent approval: {command_parts:?}"
|
||||
);
|
||||
saw_parent_approval = true;
|
||||
test.codex
|
||||
.submit(Op::ExecApproval {
|
||||
id: approval.effective_approval_id(),
|
||||
turn_id: None,
|
||||
decision: ReviewDecision::Approved,
|
||||
})
|
||||
.await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
let is_touch_subcommand = command_parts
|
||||
.iter()
|
||||
.any(|part| part == "allow-prefix-zsh-fork.txt")
|
||||
&& command_parts
|
||||
.first()
|
||||
.is_some_and(|part| part.ends_with("/touch") || part == "touch");
|
||||
if is_touch_subcommand {
|
||||
assert!(
|
||||
!saw_subcommand_approval,
|
||||
"execpolicy amendment should suppress later matching subcommand approvals: {command_parts:?}"
|
||||
);
|
||||
saw_subcommand_approval = true;
|
||||
assert_eq!(
|
||||
approval.proposed_execpolicy_amendment,
|
||||
Some(expected_execpolicy_amendment.clone())
|
||||
);
|
||||
test.codex
|
||||
.submit(Op::ExecApproval {
|
||||
id: approval.effective_approval_id(),
|
||||
turn_id: None,
|
||||
decision: ReviewDecision::ApprovedExecpolicyAmendment {
|
||||
proposed_execpolicy_amendment: expected_execpolicy_amendment
|
||||
.clone(),
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
test.codex
|
||||
.submit(Op::ExecApproval {
|
||||
id: approval.effective_approval_id(),
|
||||
turn_id: None,
|
||||
decision: ReviewDecision::Approved,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert!(saw_parent_approval, "expected parent unified-exec approval");
|
||||
assert!(
|
||||
saw_subcommand_approval,
|
||||
"expected at least one intercepted touch approval"
|
||||
);
|
||||
|
||||
let result = parse_result(&results.single_request().function_call_output(call_id));
|
||||
assert_eq!(result.exit_code.unwrap_or(0), 0);
|
||||
assert!(
|
||||
allow_prefix_path.exists(),
|
||||
"expected touch command to complete after approving the first intercepted subcommand; output: {}",
|
||||
result.stdout
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[cfg(unix)]
|
||||
async fn matched_prefix_rule_runs_unsandboxed_under_zsh_fork() -> Result<()> {
|
||||
|
||||
@@ -20,7 +20,6 @@ use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_match;
|
||||
use core_test_support::zsh_fork::build_unified_exec_zsh_fork_test;
|
||||
use core_test_support::zsh_fork::build_zsh_fork_test;
|
||||
use core_test_support::zsh_fork::restrictive_workspace_write_policy;
|
||||
use core_test_support::zsh_fork::zsh_fork_runtime;
|
||||
@@ -51,13 +50,6 @@ fn shell_command_arguments(command: &str) -> Result<String> {
|
||||
}))?)
|
||||
}
|
||||
|
||||
fn exec_command_arguments(command: &str) -> Result<String> {
|
||||
Ok(serde_json::to_string(&json!({
|
||||
"cmd": command,
|
||||
"yield_time_ms": 500,
|
||||
}))?)
|
||||
}
|
||||
|
||||
async fn submit_turn_with_policies(
|
||||
test: &TestCodex,
|
||||
prompt: &str,
|
||||
@@ -97,38 +89,6 @@ echo 'zsh-fork-stderr' >&2
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn write_repo_skill_with_shell_script_contents(
|
||||
repo_root: &Path,
|
||||
name: &str,
|
||||
script_name: &str,
|
||||
script_contents: &str,
|
||||
) -> Result<PathBuf> {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let skill_dir = repo_root.join(".agents").join("skills").join(name);
|
||||
let scripts_dir = skill_dir.join("scripts");
|
||||
fs::create_dir_all(&scripts_dir)?;
|
||||
fs::write(repo_root.join(".git"), "gitdir: here")?;
|
||||
fs::write(
|
||||
skill_dir.join("SKILL.md"),
|
||||
format!(
|
||||
r#"---
|
||||
name: {name}
|
||||
description: {name} skill
|
||||
---
|
||||
"#
|
||||
),
|
||||
)?;
|
||||
|
||||
let script_path = scripts_dir.join(script_name);
|
||||
fs::write(&script_path, script_contents)?;
|
||||
let mut permissions = fs::metadata(&script_path)?.permissions();
|
||||
permissions.set_mode(0o755);
|
||||
fs::set_permissions(&script_path, permissions)?;
|
||||
Ok(script_path)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn write_skill_with_shell_script_contents(
|
||||
home: &Path,
|
||||
@@ -581,168 +541,6 @@ permissions:
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_zsh_fork_prompts_for_skill_script_execution() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork skill prompt test")? else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let tool_call_id = "uexec-zsh-fork-skill-call";
|
||||
let test = build_unified_exec_zsh_fork_test(
|
||||
&server,
|
||||
runtime,
|
||||
AskForApproval::OnRequest,
|
||||
SandboxPolicy::new_workspace_write_policy(),
|
||||
|home| {
|
||||
write_skill_with_shell_script(home, "mbolin-test-skill", "hello-mbolin.sh").unwrap();
|
||||
write_skill_metadata(
|
||||
home,
|
||||
"mbolin-test-skill",
|
||||
r#"
|
||||
permissions:
|
||||
file_system:
|
||||
read:
|
||||
- "./data"
|
||||
write:
|
||||
- "./output"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let (script_path_str, command) = skill_script_command(&test, "hello-mbolin.sh")?;
|
||||
let arguments = exec_command_arguments(&command)?;
|
||||
let mocks =
|
||||
mount_function_call_agent_response(&server, tool_call_id, &arguments, "exec_command").await;
|
||||
|
||||
submit_turn_with_policies(
|
||||
&test,
|
||||
"use $mbolin-test-skill",
|
||||
AskForApproval::OnRequest,
|
||||
SandboxPolicy::new_workspace_write_policy(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let approval = wait_for_exec_approval_request(&test)
|
||||
.await
|
||||
.expect("expected exec approval request before completion");
|
||||
assert_eq!(approval.call_id, tool_call_id);
|
||||
assert_eq!(approval.command, vec![script_path_str.clone()]);
|
||||
assert_eq!(
|
||||
approval.available_decisions,
|
||||
Some(vec![
|
||||
ReviewDecision::Approved,
|
||||
ReviewDecision::ApprovedForSession,
|
||||
ReviewDecision::Abort,
|
||||
])
|
||||
);
|
||||
assert_eq!(
|
||||
approval.additional_permissions,
|
||||
Some(PermissionProfile {
|
||||
file_system: Some(FileSystemPermissions {
|
||||
read: Some(vec![absolute_path(
|
||||
&test.codex_home_path().join("skills/mbolin-test-skill/data"),
|
||||
)]),
|
||||
write: Some(vec![absolute_path(
|
||||
&test
|
||||
.codex_home_path()
|
||||
.join("skills/mbolin-test-skill/output"),
|
||||
)]),
|
||||
}),
|
||||
..Default::default()
|
||||
})
|
||||
);
|
||||
|
||||
test.codex
|
||||
.submit(Op::ExecApproval {
|
||||
id: approval.effective_approval_id(),
|
||||
turn_id: None,
|
||||
decision: ReviewDecision::Denied,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_turn_complete(&test).await;
|
||||
|
||||
let call_output = mocks
|
||||
.completion
|
||||
.single_request()
|
||||
.function_call_output(tool_call_id);
|
||||
let output = call_output["output"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
output.contains("Execution denied: User denied execution"),
|
||||
"expected rejection marker in function_call_output: {output:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_zsh_fork_keeps_skill_loading_pinned_to_turn_cwd() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork turn cwd skill test")? else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let tool_call_id = "uexec-zsh-fork-repo-skill-call";
|
||||
let test = build_unified_exec_zsh_fork_test(
|
||||
&server,
|
||||
runtime,
|
||||
AskForApproval::OnRequest,
|
||||
SandboxPolicy::new_workspace_write_policy(),
|
||||
|_| {},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let repo_root = test.cwd_path().join("repo");
|
||||
let script_path = write_repo_skill_with_shell_script_contents(
|
||||
&repo_root,
|
||||
"repo-skill",
|
||||
"repo-skill.sh",
|
||||
"#!/bin/sh\necho 'repo-skill-output'\n",
|
||||
)?;
|
||||
let script_path_quoted = shlex::try_join([script_path.to_string_lossy().as_ref()])?;
|
||||
let repo_root_quoted = shlex::try_join([repo_root.to_string_lossy().as_ref()])?;
|
||||
let command = format!("cd {repo_root_quoted} && {script_path_quoted}");
|
||||
let arguments = exec_command_arguments(&command)?;
|
||||
let mocks =
|
||||
mount_function_call_agent_response(&server, tool_call_id, &arguments, "exec_command").await;
|
||||
|
||||
submit_turn_with_policies(
|
||||
&test,
|
||||
"run the repo skill after changing directories",
|
||||
AskForApproval::OnRequest,
|
||||
SandboxPolicy::new_workspace_write_policy(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let approval = wait_for_exec_approval_request(&test).await;
|
||||
assert!(
|
||||
approval.is_none(),
|
||||
"changing directories inside unified exec should not load repo-local skills from the shell cwd",
|
||||
);
|
||||
|
||||
let call_output = mocks
|
||||
.completion
|
||||
.single_request()
|
||||
.function_call_output(tool_call_id);
|
||||
let output = call_output["output"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
output.contains("repo-skill-output"),
|
||||
"expected repo skill script to run without skill-specific approval when only the shell cwd changes: {output:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Permissionless skills should inherit the turn sandbox without prompting.
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsStr;
|
||||
use std::fs;
|
||||
#[cfg(unix)]
|
||||
use std::path::PathBuf;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -34,10 +32,6 @@ use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_match;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
#[cfg(unix)]
|
||||
use core_test_support::zsh_fork::build_unified_exec_zsh_fork_test;
|
||||
#[cfg(unix)]
|
||||
use core_test_support::zsh_fork::zsh_fork_runtime;
|
||||
use pretty_assertions::assert_eq;
|
||||
use regex_lite::Regex;
|
||||
use serde_json::Value;
|
||||
@@ -161,27 +155,6 @@ fn collect_tool_outputs(bodies: &[Value]) -> Result<HashMap<String, ParsedUnifie
|
||||
Ok(outputs)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn process_text_binary_path(pid: &str) -> Result<PathBuf> {
|
||||
let output = std::process::Command::new("lsof")
|
||||
.args(["-Fn", "-a", "-p", pid, "-d", "txt"])
|
||||
.output()
|
||||
.with_context(|| format!("failed to inspect process {pid} executable mapping with lsof"))?;
|
||||
if !output.status.success() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"lsof failed for pid {pid} with status {:?}",
|
||||
output.status.code()
|
||||
));
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).context("lsof output was not UTF-8")?;
|
||||
let path = stdout
|
||||
.lines()
|
||||
.find_map(|line| line.strip_prefix('n'))
|
||||
.ok_or_else(|| anyhow::anyhow!("lsof did not report a text binary path for pid {pid}"))?;
|
||||
Ok(PathBuf::from(path))
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_intercepts_apply_patch_exec_command() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -1350,6 +1323,7 @@ async fn exec_command_reports_chunk_and_exit_metadata() -> Result<()> {
|
||||
.into_iter()
|
||||
.map(|request| request.body_json())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
let metadata = outputs
|
||||
.get(call_id)
|
||||
@@ -1471,6 +1445,7 @@ async fn unified_exec_defaults_to_pipe() -> Result<()> {
|
||||
.into_iter()
|
||||
.map(|request| request.body_json())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
let output = outputs
|
||||
.get(call_id)
|
||||
@@ -1564,6 +1539,7 @@ async fn unified_exec_can_enable_tty() -> Result<()> {
|
||||
.into_iter()
|
||||
.map(|request| request.body_json())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
let output = outputs
|
||||
.get(call_id)
|
||||
@@ -1648,6 +1624,7 @@ async fn unified_exec_respects_early_exit_notifications() -> Result<()> {
|
||||
.into_iter()
|
||||
.map(|request| request.body_json())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
let output = outputs
|
||||
.get(call_id)
|
||||
@@ -1846,350 +1823,6 @@ async fn write_stdin_returns_exit_metadata_and_clears_session() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[cfg(unix)]
|
||||
async fn unified_exec_zsh_fork_keeps_python_repl_attached_to_zsh_session() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork tty session test")? else {
|
||||
return Ok(());
|
||||
};
|
||||
let configured_zsh_path =
|
||||
fs::canonicalize(runtime.zsh_path()).unwrap_or_else(|_| runtime.zsh_path().to_path_buf());
|
||||
|
||||
let python = match which("python3") {
|
||||
Ok(path) => path,
|
||||
Err(_) => {
|
||||
eprintln!("python3 not found in PATH, skipping zsh-fork python repl test.");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let test = build_unified_exec_zsh_fork_test(
|
||||
&server,
|
||||
runtime,
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::new_workspace_write_policy(),
|
||||
|_| {},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let start_call_id = "uexec-zsh-fork-python-start";
|
||||
let send_call_id = "uexec-zsh-fork-python-pid";
|
||||
let exit_call_id = "uexec-zsh-fork-python-exit";
|
||||
|
||||
let start_command = format!("{}; :", python.display());
|
||||
let start_args = serde_json::json!({
|
||||
"cmd": start_command,
|
||||
"yield_time_ms": 500,
|
||||
"tty": true,
|
||||
});
|
||||
let send_args = serde_json::json!({
|
||||
"chars": "import os; print('CODEX_PY_PID=' + str(os.getpid()))\r\n",
|
||||
"session_id": 1000,
|
||||
"yield_time_ms": 500,
|
||||
});
|
||||
let exit_args = serde_json::json!({
|
||||
"chars": "import sys; sys.exit(0)\r\n",
|
||||
"session_id": 1000,
|
||||
"yield_time_ms": 500,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(
|
||||
start_call_id,
|
||||
"exec_command",
|
||||
&serde_json::to_string(&start_args)?,
|
||||
),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_function_call(
|
||||
send_call_id,
|
||||
"write_stdin",
|
||||
&serde_json::to_string(&send_args)?,
|
||||
),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "python is running"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-4"),
|
||||
ev_function_call(
|
||||
exit_call_id,
|
||||
"write_stdin",
|
||||
&serde_json::to_string(&exit_args)?,
|
||||
),
|
||||
ev_completed("resp-4"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-2", "all done"),
|
||||
ev_completed("resp-5"),
|
||||
]),
|
||||
];
|
||||
let request_log = mount_sse_sequence(&server, responses).await;
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "test unified exec zsh-fork tty behavior".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: test.cwd_path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::new_workspace_write_policy(),
|
||||
model: test.session_configured.model.clone(),
|
||||
effort: None,
|
||||
summary: None,
|
||||
service_tier: None,
|
||||
collaboration_mode: None,
|
||||
personality: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TurnComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let requests = request_log.requests();
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
let bodies = requests
|
||||
.into_iter()
|
||||
.map(|request| request.body_json())
|
||||
.collect::<Vec<_>>();
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
|
||||
let start_output = outputs
|
||||
.get(start_call_id)
|
||||
.expect("missing start output for exec_command");
|
||||
let process_id = start_output
|
||||
.process_id
|
||||
.clone()
|
||||
.expect("expected process id from exec_command");
|
||||
assert!(
|
||||
start_output.exit_code.is_none(),
|
||||
"initial exec_command should leave the PTY session running"
|
||||
);
|
||||
|
||||
let send_output = outputs
|
||||
.get(send_call_id)
|
||||
.expect("missing write_stdin output");
|
||||
let normalized = send_output.output.replace("\r\n", "\n");
|
||||
let python_pid = Regex::new(r"CODEX_PY_PID=(\d+)")
|
||||
.expect("valid python pid marker regex")
|
||||
.captures(&normalized)
|
||||
.and_then(|captures| captures.get(1))
|
||||
.map(|value| value.as_str().to_string())
|
||||
.with_context(|| format!("missing python pid in output {normalized:?}"))?;
|
||||
assert!(
|
||||
process_is_alive(&python_pid)?,
|
||||
"python process should still be alive after printing its pid, got output {normalized:?}"
|
||||
);
|
||||
assert_eq!(send_output.process_id.as_deref(), Some(process_id.as_str()));
|
||||
assert!(
|
||||
send_output.exit_code.is_none(),
|
||||
"write_stdin should not report an exit code while the process is still running"
|
||||
);
|
||||
|
||||
let zsh_pid = std::process::Command::new("ps")
|
||||
.args(["-o", "ppid=", "-p", &python_pid])
|
||||
.output()
|
||||
.context("failed to look up python parent pid")?;
|
||||
let zsh_pid = String::from_utf8(zsh_pid.stdout)
|
||||
.context("python parent pid output is not UTF-8")?
|
||||
.trim()
|
||||
.to_string();
|
||||
assert!(
|
||||
!zsh_pid.is_empty(),
|
||||
"expected python parent pid to identify the zsh session"
|
||||
);
|
||||
assert!(
|
||||
process_is_alive(&zsh_pid)?,
|
||||
"expected zsh parent process {zsh_pid} to still be alive"
|
||||
);
|
||||
|
||||
let zsh_command = std::process::Command::new("ps")
|
||||
.args(["-o", "command=", "-p", &zsh_pid])
|
||||
.output()
|
||||
.context("failed to look up zsh parent command")?;
|
||||
let zsh_command =
|
||||
String::from_utf8(zsh_command.stdout).context("zsh parent command output is not UTF-8")?;
|
||||
assert!(
|
||||
zsh_command.contains("zsh"),
|
||||
"expected python parent command to be zsh, got {zsh_command:?}"
|
||||
);
|
||||
let zsh_text_binary = process_text_binary_path(&zsh_pid)?;
|
||||
let zsh_text_binary = fs::canonicalize(&zsh_text_binary).unwrap_or(zsh_text_binary);
|
||||
assert_eq!(
|
||||
zsh_text_binary, configured_zsh_path,
|
||||
"python parent shell should run with configured zsh-fork binary, got {:?} ({zsh_command:?})",
|
||||
zsh_text_binary,
|
||||
);
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "shut down the python repl".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: test.cwd_path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::new_workspace_write_policy(),
|
||||
model: test.session_configured.model.clone(),
|
||||
effort: None,
|
||||
summary: None,
|
||||
service_tier: None,
|
||||
collaboration_mode: None,
|
||||
personality: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TurnComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let requests = request_log.requests();
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
let bodies = requests
|
||||
.into_iter()
|
||||
.map(|request| request.body_json())
|
||||
.collect::<Vec<_>>();
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
let exit_output = outputs
|
||||
.get(exit_call_id)
|
||||
.expect("missing exit output after requesting python shutdown");
|
||||
assert!(
|
||||
exit_output.exit_code.is_none() || exit_output.exit_code == Some(0),
|
||||
"exit request should either leave cleanup to the background watcher or report success directly, got {exit_output:?}"
|
||||
);
|
||||
wait_for_process_exit(&python_pid).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[cfg(unix)]
|
||||
async fn unified_exec_zsh_fork_rewrites_nested_zsh_exec_to_configured_binary() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork nested zsh rewrite test")? else {
|
||||
return Ok(());
|
||||
};
|
||||
let configured_zsh_path =
|
||||
fs::canonicalize(runtime.zsh_path()).unwrap_or_else(|_| runtime.zsh_path().to_path_buf());
|
||||
let host_zsh = match which("zsh") {
|
||||
Ok(path) => path,
|
||||
Err(_) => {
|
||||
eprintln!("zsh not found in PATH, skipping nested zsh rewrite test.");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let test = build_unified_exec_zsh_fork_test(
|
||||
&server,
|
||||
runtime,
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::new_workspace_write_policy(),
|
||||
|_| {},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let start_call_id = "uexec-zsh-fork-nested-start";
|
||||
let nested_command = format!(
|
||||
"exec {} -lc 'echo CODEX_NESTED_ZSH_PID=$$; sleep 3; :'",
|
||||
host_zsh.display(),
|
||||
);
|
||||
let start_args = serde_json::json!({
|
||||
"cmd": nested_command,
|
||||
"yield_time_ms": 500,
|
||||
"tty": true,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-nested-1"),
|
||||
ev_function_call(
|
||||
start_call_id,
|
||||
"exec_command",
|
||||
&serde_json::to_string(&start_args)?,
|
||||
),
|
||||
ev_completed("resp-nested-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-nested-1", "done"),
|
||||
ev_completed("resp-nested-2"),
|
||||
]),
|
||||
];
|
||||
let request_log = mount_sse_sequence(&server, responses).await;
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "test nested zsh rewrite behavior".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: test.cwd_path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::new_workspace_write_policy(),
|
||||
model: test.session_configured.model.clone(),
|
||||
effort: None,
|
||||
summary: None,
|
||||
service_tier: None,
|
||||
collaboration_mode: None,
|
||||
personality: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TurnComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let requests = request_log.requests();
|
||||
let bodies = requests
|
||||
.into_iter()
|
||||
.map(|request| request.body_json())
|
||||
.collect::<Vec<_>>();
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
|
||||
let start_output = outputs
|
||||
.get(start_call_id)
|
||||
.expect("missing start output for nested zsh exec_command");
|
||||
let normalized = start_output.output.replace("\r\n", "\n");
|
||||
let nested_zsh_pid = Regex::new(r"CODEX_NESTED_ZSH_PID=(\d+)")
|
||||
.expect("valid nested zsh pid regex")
|
||||
.captures(&normalized)
|
||||
.and_then(|captures| captures.get(1))
|
||||
.map(|value| value.as_str().to_string())
|
||||
.with_context(|| format!("missing nested zsh pid marker in output {normalized:?}"))?;
|
||||
assert!(
|
||||
process_is_alive(&nested_zsh_pid)?,
|
||||
"nested zsh process should be running before release, got output {normalized:?}"
|
||||
);
|
||||
|
||||
let nested_text_binary = process_text_binary_path(&nested_zsh_pid)?;
|
||||
let nested_text_binary = fs::canonicalize(&nested_text_binary).unwrap_or(nested_text_binary);
|
||||
assert_eq!(
|
||||
nested_text_binary, configured_zsh_path,
|
||||
"nested zsh exec should be rewritten to configured zsh-fork binary, got {:?}",
|
||||
nested_text_binary,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_emits_end_event_when_session_dies_via_stdin() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
@@ -28,8 +28,6 @@ pub use unix::ShellCommandExecutor;
|
||||
#[cfg(unix)]
|
||||
pub use unix::Stopwatch;
|
||||
#[cfg(unix)]
|
||||
pub use unix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
#[cfg(unix)]
|
||||
pub use unix::main_execve_wrapper;
|
||||
#[cfg(unix)]
|
||||
pub use unix::run_shell_escalation_execve_wrapper;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::io;
|
||||
use std::os::fd::AsFd;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd as _;
|
||||
use std::os::fd::OwnedFd;
|
||||
|
||||
use anyhow::Context as _;
|
||||
@@ -28,12 +28,6 @@ fn get_escalate_client() -> anyhow::Result<AsyncDatagramSocket> {
|
||||
Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?)
|
||||
}
|
||||
|
||||
fn duplicate_fd_for_transfer(fd: impl AsFd, name: &str) -> anyhow::Result<OwnedFd> {
|
||||
fd.as_fd()
|
||||
.try_clone_to_owned()
|
||||
.with_context(|| format!("failed to duplicate {name} for escalation transfer"))
|
||||
}
|
||||
|
||||
pub async fn run_shell_escalation_execve_wrapper(
|
||||
file: String,
|
||||
argv: Vec<String>,
|
||||
@@ -68,18 +62,11 @@ pub async fn run_shell_escalation_execve_wrapper(
|
||||
.context("failed to receive EscalateResponse")?;
|
||||
match message.action {
|
||||
EscalateAction::Escalate => {
|
||||
// Duplicate stdio before transferring ownership to the server. The
|
||||
// wrapper must keep using its own stdin/stdout/stderr until the
|
||||
// escalated child takes over.
|
||||
let destination_fds = [
|
||||
io::stdin().as_raw_fd(),
|
||||
io::stdout().as_raw_fd(),
|
||||
io::stderr().as_raw_fd(),
|
||||
];
|
||||
// TODO: maybe we should send ALL open FDs (except the escalate client)?
|
||||
let fds_to_send = [
|
||||
duplicate_fd_for_transfer(io::stdin(), "stdin")?,
|
||||
duplicate_fd_for_transfer(io::stdout(), "stdout")?,
|
||||
duplicate_fd_for_transfer(io::stderr(), "stderr")?,
|
||||
unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) },
|
||||
unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) },
|
||||
unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) },
|
||||
];
|
||||
|
||||
// TODO: also forward signals over the super-exec socket
|
||||
@@ -87,7 +74,7 @@ pub async fn run_shell_escalation_execve_wrapper(
|
||||
client
|
||||
.send_with_fds(
|
||||
SuperExecMessage {
|
||||
fds: destination_fds.into_iter().collect(),
|
||||
fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(),
|
||||
},
|
||||
&fds_to_send,
|
||||
)
|
||||
@@ -128,23 +115,3 @@ pub async fn run_shell_escalation_execve_wrapper(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::unix::net::UnixStream;
|
||||
|
||||
#[test]
|
||||
fn duplicate_fd_for_transfer_does_not_close_original() {
|
||||
let (left, _right) = UnixStream::pair().expect("socket pair");
|
||||
let original_fd = left.as_raw_fd();
|
||||
|
||||
let duplicate = duplicate_fd_for_transfer(&left, "test fd").expect("duplicate fd");
|
||||
assert_ne!(duplicate.as_raw_fd(), original_fd);
|
||||
|
||||
drop(duplicate);
|
||||
|
||||
assert_ne!(unsafe { libc::fcntl(original_fd, libc::F_GETFD) }, -1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,6 +319,16 @@ async fn handle_escalate_session_with_policy(
|
||||
));
|
||||
}
|
||||
|
||||
if msg
|
||||
.fds
|
||||
.iter()
|
||||
.any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd))
|
||||
{
|
||||
return Err(anyhow::anyhow!(
|
||||
"overlapping fds not yet supported in SuperExecMessage"
|
||||
));
|
||||
}
|
||||
|
||||
let PreparedExec {
|
||||
command,
|
||||
cwd,
|
||||
@@ -802,93 +812,6 @@ mod tests {
|
||||
server_task.await?
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
struct RestoredFd {
|
||||
target_fd: i32,
|
||||
original_fd: std::os::fd::OwnedFd,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl RestoredFd {
|
||||
fn close_temporarily(target_fd: i32) -> anyhow::Result<Self> {
|
||||
let original_fd = unsafe { libc::dup(target_fd) };
|
||||
if original_fd == -1 {
|
||||
return Err(std::io::Error::last_os_error().into());
|
||||
}
|
||||
if unsafe { libc::close(target_fd) } == -1 {
|
||||
let err = std::io::Error::last_os_error();
|
||||
unsafe {
|
||||
libc::close(original_fd);
|
||||
}
|
||||
return Err(err.into());
|
||||
}
|
||||
Ok(Self {
|
||||
target_fd,
|
||||
original_fd: unsafe { std::os::fd::OwnedFd::from_raw_fd(original_fd) },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl Drop for RestoredFd {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
libc::dup2(self.original_fd.as_raw_fd(), self.target_fd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[cfg(unix)]
|
||||
async fn handle_escalate_session_accepts_received_fds_that_overlap_destinations()
|
||||
-> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
let stdin_restore = RestoredFd::close_temporarily(libc::STDIN_FILENO)?;
|
||||
let send_fd = std::os::fd::OwnedFd::from(std::fs::File::open("/dev/null")?);
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
Arc::new(DeterministicEscalationPolicy {
|
||||
decision: EscalationDecision::escalate(EscalationExecution::Unsandboxed),
|
||||
}),
|
||||
Arc::new(ForwardingShellCommandExecutor),
|
||||
CancellationToken::new(),
|
||||
CancellationToken::new(),
|
||||
));
|
||||
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: PathBuf::from("/bin/sh"),
|
||||
argv: vec!["sh".to_string(), "-c".to_string(), "exit 0".to_string()],
|
||||
workdir: AbsolutePathBuf::current_dir()?,
|
||||
env: HashMap::new(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let response = client.receive::<EscalateResponse>().await?;
|
||||
assert_eq!(
|
||||
EscalateResponse {
|
||||
action: EscalateAction::Escalate,
|
||||
},
|
||||
response
|
||||
);
|
||||
|
||||
client
|
||||
.send_with_fds(
|
||||
SuperExecMessage {
|
||||
fds: vec![libc::STDIN_FILENO],
|
||||
},
|
||||
&[send_fd],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let result = client.receive::<SuperExecResult>().await?;
|
||||
assert_eq!(0, result.exit_code);
|
||||
drop(stdin_restore);
|
||||
|
||||
server_task.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_passes_permissions_to_executor() -> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
|
||||
@@ -102,15 +102,11 @@ async fn spawn_process_with_stdin_mode(
|
||||
env: &HashMap<String, String>,
|
||||
arg0: &Option<String>,
|
||||
stdin_mode: PipeStdinMode,
|
||||
inherited_fds: &[i32],
|
||||
) -> Result<SpawnedProcess> {
|
||||
if program.is_empty() {
|
||||
anyhow::bail!("missing program for pipe spawn");
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let _ = inherited_fds;
|
||||
|
||||
let mut command = Command::new(program);
|
||||
#[cfg(unix)]
|
||||
if let Some(arg0) = arg0 {
|
||||
@@ -119,14 +115,11 @@ async fn spawn_process_with_stdin_mode(
|
||||
#[cfg(target_os = "linux")]
|
||||
let parent_pid = unsafe { libc::getpid() };
|
||||
#[cfg(unix)]
|
||||
let inherited_fds = inherited_fds.to_vec();
|
||||
#[cfg(unix)]
|
||||
unsafe {
|
||||
command.pre_exec(move || {
|
||||
crate::process_group::detach_from_tty()?;
|
||||
#[cfg(target_os = "linux")]
|
||||
crate::process_group::set_parent_death_signal(parent_pid)?;
|
||||
crate::pty::close_random_fds_except(&inherited_fds);
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
@@ -257,7 +250,7 @@ pub async fn spawn_process(
|
||||
env: &HashMap<String, String>,
|
||||
arg0: &Option<String>,
|
||||
) -> Result<SpawnedProcess> {
|
||||
spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Piped, &[]).await
|
||||
spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Piped).await
|
||||
}
|
||||
|
||||
/// Spawn a process using regular pipes, but close stdin immediately.
|
||||
@@ -268,27 +261,5 @@ pub async fn spawn_process_no_stdin(
|
||||
env: &HashMap<String, String>,
|
||||
arg0: &Option<String>,
|
||||
) -> Result<SpawnedProcess> {
|
||||
spawn_process_no_stdin_with_inherited_fds(program, args, cwd, env, arg0, &[]).await
|
||||
}
|
||||
|
||||
/// Spawn a process using regular pipes, close stdin immediately, and preserve
|
||||
/// selected inherited file descriptors across exec on Unix.
|
||||
pub async fn spawn_process_no_stdin_with_inherited_fds(
|
||||
program: &str,
|
||||
args: &[String],
|
||||
cwd: &Path,
|
||||
env: &HashMap<String, String>,
|
||||
arg0: &Option<String>,
|
||||
inherited_fds: &[i32],
|
||||
) -> Result<SpawnedProcess> {
|
||||
spawn_process_with_stdin_mode(
|
||||
program,
|
||||
args,
|
||||
cwd,
|
||||
env,
|
||||
arg0,
|
||||
PipeStdinMode::Null,
|
||||
inherited_fds,
|
||||
)
|
||||
.await
|
||||
spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Null).await
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use core::fmt;
|
||||
use std::io;
|
||||
#[cfg(unix)]
|
||||
use std::os::fd::RawFd;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
@@ -43,24 +41,9 @@ impl From<TerminalSize> for PtySize {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) trait PtyHandleKeepAlive: Send {}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl<T: Send + ?Sized> PtyHandleKeepAlive for T {}
|
||||
|
||||
pub(crate) enum PtyMasterHandle {
|
||||
Resizable(Box<dyn MasterPty + Send>),
|
||||
#[cfg(unix)]
|
||||
Opaque {
|
||||
raw_fd: RawFd,
|
||||
_handle: Box<dyn PtyHandleKeepAlive>,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct PtyHandles {
|
||||
pub _slave: Option<Box<dyn SlavePty + Send>>,
|
||||
pub(crate) _master: PtyMasterHandle,
|
||||
pub _master: Box<dyn MasterPty + Send>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for PtyHandles {
|
||||
@@ -148,11 +131,7 @@ impl ProcessHandle {
|
||||
let handles = handles
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("process is not attached to a PTY"))?;
|
||||
match &handles._master {
|
||||
PtyMasterHandle::Resizable(master) => master.resize(size.into()),
|
||||
#[cfg(unix)]
|
||||
PtyMasterHandle::Opaque { raw_fd, .. } => resize_raw_pty(*raw_fd, size),
|
||||
}
|
||||
handles._master.resize(size.into())
|
||||
}
|
||||
|
||||
/// Close the child's stdin channel.
|
||||
@@ -205,21 +184,6 @@ impl Drop for ProcessHandle {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn resize_raw_pty(raw_fd: RawFd, size: TerminalSize) -> anyhow::Result<()> {
|
||||
let mut winsize = libc::winsize {
|
||||
ws_row: size.rows,
|
||||
ws_col: size.cols,
|
||||
ws_xpixel: 0,
|
||||
ws_ypixel: 0,
|
||||
};
|
||||
let result = unsafe { libc::ioctl(raw_fd, libc::TIOCSWINSZ, &mut winsize) };
|
||||
if result == -1 {
|
||||
return Err(std::io::Error::last_os_error().into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Combine split stdout/stderr receivers into a single broadcast receiver.
|
||||
pub fn combine_output_receivers(
|
||||
mut stdout_rx: mpsc::Receiver<Vec<u8>>,
|
||||
|
||||
@@ -1,20 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
#[cfg(unix)]
|
||||
use std::fs::File;
|
||||
use std::io::ErrorKind;
|
||||
#[cfg(unix)]
|
||||
use std::os::fd::AsRawFd;
|
||||
#[cfg(unix)]
|
||||
use std::os::fd::FromRawFd;
|
||||
#[cfg(unix)]
|
||||
use std::os::fd::RawFd;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::process::CommandExt;
|
||||
use std::path::Path;
|
||||
#[cfg(unix)]
|
||||
use std::process::Command as StdCommand;
|
||||
#[cfg(unix)]
|
||||
use std::process::Stdio;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
@@ -31,7 +17,6 @@ use tokio::task::JoinHandle;
|
||||
use crate::process::ChildTerminator;
|
||||
use crate::process::ProcessHandle;
|
||||
use crate::process::PtyHandles;
|
||||
use crate::process::PtyMasterHandle;
|
||||
use crate::process::SpawnedProcess;
|
||||
use crate::process::TerminalSize;
|
||||
|
||||
@@ -74,18 +59,6 @@ impl ChildTerminator for PtyChildTerminator {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
struct RawPidTerminator {
|
||||
process_group_id: u32,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl ChildTerminator for RawPidTerminator {
|
||||
fn kill(&mut self) -> std::io::Result<()> {
|
||||
crate::process_group::kill_process_group(self.process_group_id)
|
||||
}
|
||||
}
|
||||
|
||||
fn platform_native_pty_system() -> Box<dyn portable_pty::PtySystem + Send> {
|
||||
#[cfg(windows)]
|
||||
{
|
||||
@@ -106,45 +79,11 @@ pub async fn spawn_process(
|
||||
env: &HashMap<String, String>,
|
||||
arg0: &Option<String>,
|
||||
size: TerminalSize,
|
||||
) -> Result<SpawnedProcess> {
|
||||
spawn_process_with_inherited_fds(program, args, cwd, env, arg0, size, &[]).await
|
||||
}
|
||||
|
||||
/// Spawn a process attached to a PTY, preserving any inherited file
|
||||
/// descriptors listed in `inherited_fds` across exec on Unix.
|
||||
pub async fn spawn_process_with_inherited_fds(
|
||||
program: &str,
|
||||
args: &[String],
|
||||
cwd: &Path,
|
||||
env: &HashMap<String, String>,
|
||||
arg0: &Option<String>,
|
||||
size: TerminalSize,
|
||||
inherited_fds: &[i32],
|
||||
) -> Result<SpawnedProcess> {
|
||||
if program.is_empty() {
|
||||
anyhow::bail!("missing program for PTY spawn");
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let _ = inherited_fds;
|
||||
|
||||
#[cfg(unix)]
|
||||
if !inherited_fds.is_empty() {
|
||||
return spawn_process_preserving_fds(program, args, cwd, env, arg0, size, inherited_fds)
|
||||
.await;
|
||||
}
|
||||
|
||||
spawn_process_portable(program, args, cwd, env, arg0, size).await
|
||||
}
|
||||
|
||||
async fn spawn_process_portable(
|
||||
program: &str,
|
||||
args: &[String],
|
||||
cwd: &Path,
|
||||
env: &HashMap<String, String>,
|
||||
arg0: &Option<String>,
|
||||
size: TerminalSize,
|
||||
) -> Result<SpawnedProcess> {
|
||||
let pty_system = platform_native_pty_system();
|
||||
let pair = pty_system.openpty(size.into())?;
|
||||
|
||||
@@ -225,7 +164,7 @@ async fn spawn_process_portable(
|
||||
} else {
|
||||
None
|
||||
},
|
||||
_master: PtyMasterHandle::Resizable(pair.master),
|
||||
_master: pair.master,
|
||||
};
|
||||
|
||||
let handle = ProcessHandle::new(
|
||||
@@ -251,225 +190,3 @@ async fn spawn_process_portable(
|
||||
exit_rx,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn spawn_process_preserving_fds(
|
||||
program: &str,
|
||||
args: &[String],
|
||||
cwd: &Path,
|
||||
env: &HashMap<String, String>,
|
||||
arg0: &Option<String>,
|
||||
size: TerminalSize,
|
||||
inherited_fds: &[RawFd],
|
||||
) -> Result<SpawnedProcess> {
|
||||
let (master, slave) = open_unix_pty(size)?;
|
||||
let mut command = StdCommand::new(program);
|
||||
if let Some(arg0) = arg0 {
|
||||
command.arg0(arg0);
|
||||
}
|
||||
command.current_dir(cwd);
|
||||
command.env_clear();
|
||||
for arg in args {
|
||||
command.arg(arg);
|
||||
}
|
||||
for (key, value) in env {
|
||||
command.env(key, value);
|
||||
}
|
||||
|
||||
let stdin = slave.try_clone()?;
|
||||
let stdout = slave.try_clone()?;
|
||||
let stderr = slave.try_clone()?;
|
||||
let inherited_fds = inherited_fds.to_vec();
|
||||
|
||||
unsafe {
|
||||
command
|
||||
.stdin(Stdio::from(stdin))
|
||||
.stdout(Stdio::from(stdout))
|
||||
.stderr(Stdio::from(stderr))
|
||||
.pre_exec(move || {
|
||||
for signo in &[
|
||||
libc::SIGCHLD,
|
||||
libc::SIGHUP,
|
||||
libc::SIGINT,
|
||||
libc::SIGQUIT,
|
||||
libc::SIGTERM,
|
||||
libc::SIGALRM,
|
||||
] {
|
||||
libc::signal(*signo, libc::SIG_DFL);
|
||||
}
|
||||
|
||||
let empty_set: libc::sigset_t = std::mem::zeroed();
|
||||
libc::sigprocmask(libc::SIG_SETMASK, &empty_set, std::ptr::null_mut());
|
||||
|
||||
if libc::setsid() == -1 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_lossless)]
|
||||
if libc::ioctl(0, libc::TIOCSCTTY as _, 0) == -1 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
close_random_fds_except(&inherited_fds);
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
let mut child = command.spawn()?;
|
||||
drop(slave);
|
||||
let process_group_id = child.id();
|
||||
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||
let (stdout_tx, stdout_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||
let (_stderr_tx, stderr_rx) = mpsc::channel::<Vec<u8>>(1);
|
||||
let mut reader = master.try_clone()?;
|
||||
let reader_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || {
|
||||
let mut buf = [0u8; 8_192];
|
||||
loop {
|
||||
match std::io::Read::read(&mut reader, &mut buf) {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
let _ = stdout_tx.blocking_send(buf[..n].to_vec());
|
||||
}
|
||||
Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
|
||||
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
continue;
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let writer = Arc::new(tokio::sync::Mutex::new(master.try_clone()?));
|
||||
let writer_handle: JoinHandle<()> = tokio::spawn({
|
||||
let writer = Arc::clone(&writer);
|
||||
async move {
|
||||
while let Some(bytes) = writer_rx.recv().await {
|
||||
let mut guard = writer.lock().await;
|
||||
use std::io::Write;
|
||||
let _ = guard.write_all(&bytes);
|
||||
let _ = guard.flush();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let (exit_tx, exit_rx) = oneshot::channel::<i32>();
|
||||
let exit_status = Arc::new(AtomicBool::new(false));
|
||||
let wait_exit_status = Arc::clone(&exit_status);
|
||||
let exit_code = Arc::new(StdMutex::new(None));
|
||||
let wait_exit_code = Arc::clone(&exit_code);
|
||||
let wait_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || {
|
||||
let code = match child.wait() {
|
||||
Ok(status) => status.code().unwrap_or(-1),
|
||||
Err(_) => -1,
|
||||
};
|
||||
wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||
if let Ok(mut guard) = wait_exit_code.lock() {
|
||||
*guard = Some(code);
|
||||
}
|
||||
let _ = exit_tx.send(code);
|
||||
});
|
||||
|
||||
let handles = PtyHandles {
|
||||
_slave: None,
|
||||
_master: PtyMasterHandle::Opaque {
|
||||
raw_fd: master.as_raw_fd(),
|
||||
_handle: Box::new(master),
|
||||
},
|
||||
};
|
||||
|
||||
let handle = ProcessHandle::new(
|
||||
writer_tx,
|
||||
Box::new(RawPidTerminator { process_group_id }),
|
||||
reader_handle,
|
||||
Vec::new(),
|
||||
writer_handle,
|
||||
wait_handle,
|
||||
exit_status,
|
||||
exit_code,
|
||||
Some(handles),
|
||||
);
|
||||
|
||||
Ok(SpawnedProcess {
|
||||
session: handle,
|
||||
stdout_rx,
|
||||
stderr_rx,
|
||||
exit_rx,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn open_unix_pty(size: TerminalSize) -> Result<(File, File)> {
|
||||
let mut master: RawFd = -1;
|
||||
let mut slave: RawFd = -1;
|
||||
let mut size = libc::winsize {
|
||||
ws_row: size.rows,
|
||||
ws_col: size.cols,
|
||||
ws_xpixel: 0,
|
||||
ws_ypixel: 0,
|
||||
};
|
||||
let winp = std::ptr::addr_of_mut!(size);
|
||||
|
||||
let result = unsafe {
|
||||
libc::openpty(
|
||||
&mut master,
|
||||
&mut slave,
|
||||
std::ptr::null_mut(),
|
||||
std::ptr::null_mut(),
|
||||
winp,
|
||||
)
|
||||
};
|
||||
if result != 0 {
|
||||
anyhow::bail!("failed to openpty: {:?}", std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
set_cloexec(master)?;
|
||||
set_cloexec(slave)?;
|
||||
|
||||
Ok(unsafe { (File::from_raw_fd(master), File::from_raw_fd(slave)) })
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn set_cloexec(fd: RawFd) -> std::io::Result<()> {
|
||||
let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) };
|
||||
if flags == -1 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
let result = unsafe { libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC) };
|
||||
if result == -1 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) fn close_random_fds_except(preserved_fds: &[RawFd]) {
|
||||
if let Ok(dir) = std::fs::read_dir("/dev/fd") {
|
||||
let mut fds = Vec::new();
|
||||
for entry in dir {
|
||||
let num = entry
|
||||
.ok()
|
||||
.map(|entry| entry.file_name())
|
||||
.and_then(|name| name.into_string().ok())
|
||||
.and_then(|name| name.parse::<RawFd>().ok());
|
||||
if let Some(num) = num {
|
||||
if num <= 2 || preserved_fds.contains(&num) {
|
||||
continue;
|
||||
}
|
||||
// Keep CLOEXEC descriptors open so std::process can still use
|
||||
// its internal exec-error pipe to report spawn failures.
|
||||
let flags = unsafe { libc::fcntl(num, libc::F_GETFD) };
|
||||
if flags == -1 || flags & libc::FD_CLOEXEC != 0 {
|
||||
continue;
|
||||
}
|
||||
fds.push(num);
|
||||
}
|
||||
}
|
||||
for fd in fds {
|
||||
unsafe {
|
||||
libc::close(fd);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,10 +4,6 @@ use std::path::Path;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::combine_output_receivers;
|
||||
#[cfg(unix)]
|
||||
use crate::pipe::spawn_process_no_stdin_with_inherited_fds;
|
||||
#[cfg(unix)]
|
||||
use crate::pty::spawn_process_with_inherited_fds;
|
||||
use crate::spawn_pipe_process;
|
||||
use crate::spawn_pipe_process_no_stdin;
|
||||
use crate::spawn_pty_process;
|
||||
@@ -139,42 +135,6 @@ async fn collect_output_until_exit(
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn wait_for_output_contains(
|
||||
output_rx: &mut tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
needle: &str,
|
||||
timeout_ms: u64,
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
let mut collected = Vec::new();
|
||||
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
|
||||
|
||||
while tokio::time::Instant::now() < deadline {
|
||||
let now = tokio::time::Instant::now();
|
||||
let remaining = deadline.saturating_duration_since(now);
|
||||
match tokio::time::timeout(remaining, output_rx.recv()).await {
|
||||
Ok(Ok(chunk)) => {
|
||||
collected.extend_from_slice(&chunk);
|
||||
if String::from_utf8_lossy(&collected).contains(needle) {
|
||||
return Ok(collected);
|
||||
}
|
||||
}
|
||||
Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => continue,
|
||||
Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => {
|
||||
anyhow::bail!(
|
||||
"PTY output closed while waiting for {needle:?}: {:?}",
|
||||
String::from_utf8_lossy(&collected)
|
||||
);
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"timed out waiting for {needle:?} in PTY output: {:?}",
|
||||
String::from_utf8_lossy(&collected)
|
||||
);
|
||||
}
|
||||
|
||||
async fn wait_for_python_repl_ready(
|
||||
output_rx: &mut tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
timeout_ms: u64,
|
||||
@@ -210,58 +170,6 @@ async fn wait_for_python_repl_ready(
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn wait_for_python_repl_ready_via_probe(
|
||||
writer: &tokio::sync::mpsc::Sender<Vec<u8>>,
|
||||
output_rx: &mut tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
timeout_ms: u64,
|
||||
newline: &str,
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
let mut collected = Vec::new();
|
||||
let marker = "__codex_pty_ready__";
|
||||
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
|
||||
let probe_window = tokio::time::Duration::from_millis(if cfg!(windows) { 750 } else { 250 });
|
||||
|
||||
while tokio::time::Instant::now() < deadline {
|
||||
writer
|
||||
.send(format!("print('{marker}'){newline}").into_bytes())
|
||||
.await?;
|
||||
|
||||
let probe_deadline = tokio::time::Instant::now() + probe_window;
|
||||
loop {
|
||||
let now = tokio::time::Instant::now();
|
||||
if now >= deadline || now >= probe_deadline {
|
||||
break;
|
||||
}
|
||||
let remaining = std::cmp::min(
|
||||
deadline.saturating_duration_since(now),
|
||||
probe_deadline.saturating_duration_since(now),
|
||||
);
|
||||
match tokio::time::timeout(remaining, output_rx.recv()).await {
|
||||
Ok(Ok(chunk)) => {
|
||||
collected.extend_from_slice(&chunk);
|
||||
if String::from_utf8_lossy(&collected).contains(marker) {
|
||||
return Ok(collected);
|
||||
}
|
||||
}
|
||||
Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => continue,
|
||||
Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => {
|
||||
anyhow::bail!(
|
||||
"PTY output closed while waiting for Python REPL readiness: {:?}",
|
||||
String::from_utf8_lossy(&collected)
|
||||
);
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"timed out waiting for Python REPL readiness in PTY: {:?}",
|
||||
String::from_utf8_lossy(&collected)
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn process_exists(pid: i32) -> anyhow::Result<bool> {
|
||||
let result = unsafe { libc::kill(pid, 0) };
|
||||
@@ -301,26 +209,16 @@ async fn wait_for_marker_pid(
|
||||
collected.extend_from_slice(&chunk);
|
||||
|
||||
let text = String::from_utf8_lossy(&collected);
|
||||
let mut offset = 0;
|
||||
while let Some(pos) = text[offset..].find(marker) {
|
||||
let marker_start = offset + pos;
|
||||
let suffix = &text[marker_start + marker.len()..];
|
||||
let digits_len = suffix
|
||||
if let Some(marker_idx) = text.find(marker) {
|
||||
let suffix = &text[marker_idx + marker.len()..];
|
||||
let digits: String = suffix
|
||||
.chars()
|
||||
.skip_while(|ch| !ch.is_ascii_digit())
|
||||
.take_while(char::is_ascii_digit)
|
||||
.map(char::len_utf8)
|
||||
.sum::<usize>();
|
||||
if digits_len == 0 {
|
||||
offset = marker_start + marker.len();
|
||||
continue;
|
||||
.collect();
|
||||
if !digits.is_empty() {
|
||||
return Ok(digits.parse()?);
|
||||
}
|
||||
|
||||
let pid_str = &suffix[..digits_len];
|
||||
let trailing = &suffix[digits_len..];
|
||||
if trailing.is_empty() {
|
||||
break;
|
||||
}
|
||||
return Ok(pid_str.parse()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -671,276 +569,3 @@ async fn pty_terminate_kills_background_children_in_same_process_group() -> anyh
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn pty_spawn_can_preserve_inherited_fds() -> anyhow::Result<()> {
|
||||
use std::io::Read;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd;
|
||||
|
||||
let mut fds = [0; 2];
|
||||
let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
|
||||
if result != 0 {
|
||||
return Err(std::io::Error::last_os_error().into());
|
||||
}
|
||||
|
||||
let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
|
||||
let write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
|
||||
|
||||
let mut env_map: HashMap<String, String> = std::env::vars().collect();
|
||||
env_map.insert(
|
||||
"PRESERVED_FD".to_string(),
|
||||
write_end.as_raw_fd().to_string(),
|
||||
);
|
||||
|
||||
let script = "printf __preserved__ >\"/dev/fd/$PRESERVED_FD\"";
|
||||
let spawned = spawn_process_with_inherited_fds(
|
||||
"/bin/sh",
|
||||
&["-c".to_string(), script.to_string()],
|
||||
Path::new("."),
|
||||
&env_map,
|
||||
&None,
|
||||
TerminalSize::default(),
|
||||
&[write_end.as_raw_fd()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
drop(write_end);
|
||||
|
||||
let (_session, output_rx, exit_rx) = combine_spawned_output(spawned);
|
||||
let (_, code) = collect_output_until_exit(output_rx, exit_rx, 2_000).await;
|
||||
assert_eq!(code, 0, "expected preserved-fd PTY child to exit cleanly");
|
||||
|
||||
let mut pipe_output = String::new();
|
||||
read_end.read_to_string(&mut pipe_output)?;
|
||||
assert_eq!(pipe_output, "__preserved__");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn pty_preserving_inherited_fds_keeps_python_repl_running() -> anyhow::Result<()> {
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd;
|
||||
|
||||
let Some(python) = find_python() else {
|
||||
eprintln!(
|
||||
"python not found; skipping pty_preserving_inherited_fds_keeps_python_repl_running"
|
||||
);
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let mut fds = [0; 2];
|
||||
let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
|
||||
if result != 0 {
|
||||
return Err(std::io::Error::last_os_error().into());
|
||||
}
|
||||
|
||||
let read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
|
||||
let preserved_fd = unsafe { std::fs::File::from_raw_fd(fds[1]) };
|
||||
|
||||
let mut env_map: HashMap<String, String> = std::env::vars().collect();
|
||||
env_map.insert(
|
||||
"PRESERVED_FD".to_string(),
|
||||
preserved_fd.as_raw_fd().to_string(),
|
||||
);
|
||||
|
||||
let spawned = spawn_process_with_inherited_fds(
|
||||
&python,
|
||||
&[],
|
||||
Path::new("."),
|
||||
&env_map,
|
||||
&None,
|
||||
TerminalSize::default(),
|
||||
&[preserved_fd.as_raw_fd()],
|
||||
)
|
||||
.await?;
|
||||
drop(read_end);
|
||||
drop(preserved_fd);
|
||||
|
||||
let (session, mut output_rx, exit_rx) = combine_spawned_output(spawned);
|
||||
let writer = session.writer_sender();
|
||||
let newline = "\n";
|
||||
let mut output =
|
||||
wait_for_python_repl_ready_via_probe(&writer, &mut output_rx, 5_000, newline).await?;
|
||||
let marker = "__codex_preserved_py_pid:";
|
||||
writer
|
||||
.send(format!("import os; print('{marker}' + str(os.getpid())){newline}").into_bytes())
|
||||
.await?;
|
||||
|
||||
let python_pid = match wait_for_marker_pid(&mut output_rx, marker, 2_000).await {
|
||||
Ok(pid) => pid,
|
||||
Err(err) => {
|
||||
session.terminate();
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
assert!(
|
||||
process_exists(python_pid)?,
|
||||
"expected python pid {python_pid} to stay alive after prompt output"
|
||||
);
|
||||
|
||||
writer.send(format!("exit(){newline}").into_bytes()).await?;
|
||||
let (remaining_output, code) = collect_output_until_exit(output_rx, exit_rx, 5_000).await;
|
||||
output.extend_from_slice(&remaining_output);
|
||||
|
||||
assert_eq!(code, 0, "expected python to exit cleanly");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn pty_spawn_with_inherited_fds_reports_exec_failures() -> anyhow::Result<()> {
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd;
|
||||
|
||||
let mut fds = [0; 2];
|
||||
let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
|
||||
if result != 0 {
|
||||
return Err(std::io::Error::last_os_error().into());
|
||||
}
|
||||
|
||||
let read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
|
||||
let write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
|
||||
|
||||
let env_map: HashMap<String, String> = std::env::vars().collect();
|
||||
let spawn_result = spawn_process_with_inherited_fds(
|
||||
"/definitely/missing/command",
|
||||
&[],
|
||||
Path::new("."),
|
||||
&env_map,
|
||||
&None,
|
||||
TerminalSize::default(),
|
||||
&[write_end.as_raw_fd()],
|
||||
)
|
||||
.await;
|
||||
|
||||
drop(read_end);
|
||||
drop(write_end);
|
||||
|
||||
let err = match spawn_result {
|
||||
Ok(spawned) => {
|
||||
spawned.session.terminate();
|
||||
anyhow::bail!("missing executable unexpectedly spawned");
|
||||
}
|
||||
Err(err) => err,
|
||||
};
|
||||
let err_text = err.to_string();
|
||||
assert!(
|
||||
err_text.contains("No such file")
|
||||
|| err_text.contains("not found")
|
||||
|| err_text.contains("os error 2"),
|
||||
"expected spawn error for missing executable, got: {err_text}",
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn pty_spawn_with_inherited_fds_supports_resize() -> anyhow::Result<()> {
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd;
|
||||
|
||||
let mut fds = [0; 2];
|
||||
let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
|
||||
if result != 0 {
|
||||
return Err(std::io::Error::last_os_error().into());
|
||||
}
|
||||
|
||||
let read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
|
||||
let write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
|
||||
|
||||
let env_map: HashMap<String, String> = std::env::vars().collect();
|
||||
let script =
|
||||
"stty -echo; printf 'start:%s\\n' \"$(stty size)\"; IFS= read _line; printf 'after:%s\\n' \"$(stty size)\"";
|
||||
let spawned = spawn_process_with_inherited_fds(
|
||||
"/bin/sh",
|
||||
&["-c".to_string(), script.to_string()],
|
||||
Path::new("."),
|
||||
&env_map,
|
||||
&None,
|
||||
TerminalSize {
|
||||
rows: 31,
|
||||
cols: 101,
|
||||
},
|
||||
&[write_end.as_raw_fd()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let (session, mut output_rx, exit_rx) = combine_spawned_output(spawned);
|
||||
let writer = session.writer_sender();
|
||||
let mut output = wait_for_output_contains(&mut output_rx, "start:31 101\r\n", 5_000).await?;
|
||||
|
||||
session.resize(TerminalSize {
|
||||
rows: 45,
|
||||
cols: 132,
|
||||
})?;
|
||||
writer.send(b"go\n".to_vec()).await?;
|
||||
session.close_stdin();
|
||||
|
||||
let (remaining_output, code) = collect_output_until_exit(output_rx, exit_rx, 5_000).await;
|
||||
output.extend_from_slice(&remaining_output);
|
||||
let text = String::from_utf8_lossy(&output);
|
||||
let normalized = text.replace("\r\n", "\n");
|
||||
|
||||
assert!(
|
||||
normalized.contains("after:45 132\n"),
|
||||
"expected resized PTY dimensions in output: {text:?}"
|
||||
);
|
||||
assert_eq!(code, 0, "expected shell to exit cleanly after resize");
|
||||
|
||||
drop(read_end);
|
||||
drop(write_end);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn pipe_spawn_no_stdin_can_preserve_inherited_fds() -> anyhow::Result<()> {
|
||||
use std::io::Read;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd;
|
||||
|
||||
let mut fds = [0; 2];
|
||||
let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
|
||||
if result != 0 {
|
||||
return Err(std::io::Error::last_os_error().into());
|
||||
}
|
||||
|
||||
let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
|
||||
let write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
|
||||
|
||||
let mut env_map: HashMap<String, String> = std::env::vars().collect();
|
||||
env_map.insert(
|
||||
"PRESERVED_FD".to_string(),
|
||||
write_end.as_raw_fd().to_string(),
|
||||
);
|
||||
|
||||
let script = "printf __pipe_preserved__ >\"/dev/fd/$PRESERVED_FD\"";
|
||||
let spawned = spawn_process_no_stdin_with_inherited_fds(
|
||||
"/bin/sh",
|
||||
&["-c".to_string(), script.to_string()],
|
||||
Path::new("."),
|
||||
&env_map,
|
||||
&None,
|
||||
&[write_end.as_raw_fd()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
drop(write_end);
|
||||
|
||||
let (_session, output_rx, exit_rx) = combine_spawned_output(spawned);
|
||||
let (_, code) = collect_output_until_exit(output_rx, exit_rx, 2_000).await;
|
||||
assert_eq!(code, 0, "expected preserved-fd pipe child to exit cleanly");
|
||||
|
||||
let mut pipe_output = String::new();
|
||||
read_end.read_to_string(&mut pipe_output)?;
|
||||
assert_eq!(pipe_output, "__pipe_preserved__");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user