Compare commits

..

1 Commits

Author SHA1 Message Date
celia-oai
584baeb550 changes 2026-03-12 21:46:56 -07:00
53 changed files with 1787 additions and 4103 deletions

View File

@@ -11,6 +11,7 @@ use app_test_support::McpProcess;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_responses_server_sequence;
use app_test_support::create_mock_responses_server_sequence_unchecked;
use app_test_support::create_shell_command_sse_response;
use app_test_support::to_response;
use codex_app_server_protocol::CommandAction;
use codex_app_server_protocol::CommandExecutionApprovalDecision;
@@ -34,11 +35,9 @@ use codex_core::features::Feature;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::BTreeMap;
use std::path::Path;
use tempfile::TempDir;
use tokio::time::sleep;
use tokio::time::timeout;
#[cfg(windows)]
@@ -63,14 +62,19 @@ async fn turn_start_shell_zsh_fork_executes_command_v2() -> Result<()> {
};
eprintln!("using zsh path for zsh-fork test: {}", zsh_path.display());
// Keep the exec command in flight until we interrupt it. A fast command
// Keep the shell command in flight until we interrupt it. A fast command
// like `echo hi` can finish before the interrupt arrives on faster runners,
// which turns this into a test for post-command follow-up behavior instead
// of interrupting an active zsh-fork command.
let release_marker_escaped = release_marker.to_string_lossy().replace('\'', r#"'\''"#);
let wait_for_interrupt =
format!("while [ ! -f '{release_marker_escaped}' ]; do sleep 0.01; done");
let response = create_zsh_fork_exec_command_sse_response(&wait_for_interrupt, "call-zsh-fork")?;
let response = create_shell_command_sse_response(
vec!["/bin/sh".to_string(), "-c".to_string(), wait_for_interrupt],
None,
Some(5000),
"call-zsh-fork",
)?;
let no_op_response = responses::sse(vec![
responses::ev_response_created("resp-2"),
responses::ev_completed("resp-2"),
@@ -87,7 +91,7 @@ async fn turn_start_shell_zsh_fork_executes_command_v2() -> Result<()> {
"never",
&BTreeMap::from([
(Feature::ShellZshFork, true),
(Feature::UnifiedExec, true),
(Feature::UnifiedExec, false),
(Feature::ShellSnapshot, false),
]),
&zsh_path,
@@ -159,7 +163,7 @@ async fn turn_start_shell_zsh_fork_executes_command_v2() -> Result<()> {
assert_eq!(id, "call-zsh-fork");
assert_eq!(status, CommandExecutionStatus::InProgress);
assert!(command.starts_with(&zsh_path.display().to_string()));
assert!(command.contains(" -lc "));
assert!(command.contains("/bin/sh -c"));
assert!(command.contains("sleep 0.01"));
assert!(command.contains(&release_marker.display().to_string()));
assert_eq!(cwd, workspace);
@@ -187,8 +191,14 @@ async fn turn_start_shell_zsh_fork_exec_approval_decline_v2() -> Result<()> {
eprintln!("using zsh path for zsh-fork test: {}", zsh_path.display());
let responses = vec![
create_zsh_fork_exec_command_sse_response(
"python3 -c 'print(42)'",
create_shell_command_sse_response(
vec![
"python3".to_string(),
"-c".to_string(),
"print(42)".to_string(),
],
None,
Some(5000),
"call-zsh-fork-decline",
)?,
create_final_assistant_message_sse_response("done")?,
@@ -200,7 +210,7 @@ async fn turn_start_shell_zsh_fork_exec_approval_decline_v2() -> Result<()> {
"untrusted",
&BTreeMap::from([
(Feature::ShellZshFork, true),
(Feature::UnifiedExec, true),
(Feature::UnifiedExec, false),
(Feature::ShellSnapshot, false),
]),
&zsh_path,
@@ -316,8 +326,14 @@ async fn turn_start_shell_zsh_fork_exec_approval_cancel_v2() -> Result<()> {
};
eprintln!("using zsh path for zsh-fork test: {}", zsh_path.display());
let responses = vec![create_zsh_fork_exec_command_sse_response(
"python3 -c 'print(42)'",
let responses = vec![create_shell_command_sse_response(
vec![
"python3".to_string(),
"-c".to_string(),
"print(42)".to_string(),
],
None,
Some(5000),
"call-zsh-fork-cancel",
)?];
let server = create_mock_responses_server_sequence(responses).await;
@@ -327,7 +343,7 @@ async fn turn_start_shell_zsh_fork_exec_approval_cancel_v2() -> Result<()> {
"untrusted",
&BTreeMap::from([
(Feature::ShellZshFork, true),
(Feature::UnifiedExec, true),
(Feature::UnifiedExec, false),
(Feature::ShellSnapshot, false),
]),
&zsh_path,
@@ -425,204 +441,6 @@ async fn turn_start_shell_zsh_fork_exec_approval_cancel_v2() -> Result<()> {
Ok(())
}
#[tokio::test]
async fn turn_start_shell_zsh_fork_interrupt_kills_approved_subcommand_v2() -> Result<()> {
skip_if_no_network!(Ok(()));
let tmp = TempDir::new()?;
let codex_home = tmp.path().join("codex_home");
std::fs::create_dir(&codex_home)?;
let workspace = tmp.path().join("workspace");
std::fs::create_dir(&workspace)?;
let launch_marker = workspace.join("approved-subcommand.started");
let leaked_marker = workspace.join("approved-subcommand.leaked");
let launch_marker_display = launch_marker.display().to_string();
assert!(
!launch_marker_display.contains('\''),
"test workspace path should not contain single quotes: {launch_marker_display}"
);
let leaked_marker_display = leaked_marker.display().to_string();
assert!(
!leaked_marker_display.contains('\''),
"test workspace path should not contain single quotes: {leaked_marker_display}"
);
let Some(zsh_path) = find_test_zsh_path()? else {
eprintln!("skipping zsh fork interrupt cleanup test: no zsh executable found");
return Ok(());
};
if !supports_exec_wrapper_intercept(&zsh_path) {
eprintln!(
"skipping zsh fork interrupt cleanup test: zsh does not support EXEC_WRAPPER intercepts ({})",
zsh_path.display()
);
return Ok(());
}
let zsh_path_display = zsh_path.display().to_string();
eprintln!("using zsh path for zsh-fork test: {zsh_path_display}");
let shell_command = format!(
"/bin/sh -c 'echo started > \"{launch_marker_display}\" && /bin/sleep 0.5 && echo leaked > \"{leaked_marker_display}\" && exec /bin/sleep 100'"
);
let tool_call_arguments = serde_json::to_string(&json!({
"cmd": shell_command,
"yield_time_ms": 30_000,
}))?;
let response = responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_function_call(
"call-zsh-fork-interrupt-cleanup",
"exec_command",
&tool_call_arguments,
),
responses::ev_completed("resp-1"),
]);
let no_op_response = responses::sse(vec![
responses::ev_response_created("resp-2"),
responses::ev_completed("resp-2"),
]);
let server =
create_mock_responses_server_sequence_unchecked(vec![response, no_op_response]).await;
create_config_toml(
&codex_home,
&server.uri(),
"untrusted",
&BTreeMap::from([
(Feature::ShellZshFork, true),
(Feature::UnifiedExec, true),
(Feature::ShellSnapshot, false),
]),
&zsh_path,
)?;
let mut mcp = create_zsh_test_mcp_process(&codex_home, &workspace).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
let start_id = mcp
.send_thread_start_request(ThreadStartParams {
model: Some("mock-model".to_string()),
cwd: Some(workspace.to_string_lossy().into_owned()),
..Default::default()
})
.await?;
let start_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(start_id)),
)
.await??;
let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(start_resp)?;
let turn_id = mcp
.send_turn_start_request(TurnStartParams {
thread_id: thread.id.clone(),
input: vec![V2UserInput::Text {
text: "run the long-lived command".to_string(),
text_elements: Vec::new(),
}],
cwd: Some(workspace.clone()),
approval_policy: Some(codex_app_server_protocol::AskForApproval::UnlessTrusted),
sandbox_policy: Some(codex_app_server_protocol::SandboxPolicy::WorkspaceWrite {
writable_roots: vec![workspace.clone().try_into()?],
read_only_access: codex_app_server_protocol::ReadOnlyAccess::FullAccess,
network_access: false,
exclude_tmpdir_env_var: false,
exclude_slash_tmp: false,
}),
model: Some("mock-model".to_string()),
effort: Some(codex_protocol::openai_models::ReasoningEffort::Medium),
summary: Some(codex_protocol::config_types::ReasoningSummary::Auto),
..Default::default()
})
.await?;
let turn_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(turn_id)),
)
.await??;
let TurnStartResponse { turn } = to_response::<TurnStartResponse>(turn_resp)?;
let mut saw_target_approval = false;
while !saw_target_approval {
let server_req = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_request_message(),
)
.await??;
let ServerRequest::CommandExecutionRequestApproval { request_id, params } = server_req
else {
panic!("expected CommandExecutionRequestApproval request");
};
let approval_command = params.command.clone().unwrap_or_default();
saw_target_approval = approval_command.contains("/bin/sh")
&& approval_command.contains(&launch_marker_display)
&& !approval_command.contains(&zsh_path_display);
mcp.send_response(
request_id,
serde_json::to_value(CommandExecutionRequestApprovalResponse {
decision: CommandExecutionApprovalDecision::Accept,
})?,
)
.await?;
}
let started_command = timeout(DEFAULT_READ_TIMEOUT, async {
loop {
let notif = mcp
.read_stream_until_notification_message("item/started")
.await?;
let started: ItemStartedNotification =
serde_json::from_value(notif.params.clone().expect("item/started params"))?;
if let ThreadItem::CommandExecution { .. } = started.item {
return Ok::<ThreadItem, anyhow::Error>(started.item);
}
}
})
.await??;
let ThreadItem::CommandExecution {
id,
process_id,
status,
command,
cwd,
..
} = started_command
else {
unreachable!("loop ensures we break on command execution items");
};
assert_eq!(id, "call-zsh-fork-interrupt-cleanup");
assert_eq!(status, CommandExecutionStatus::InProgress);
assert!(command.starts_with(&zsh_path.display().to_string()));
assert!(command.contains(" -lc "));
assert!(command.contains(&launch_marker_display));
assert_eq!(cwd, workspace);
assert!(process_id.is_some(), "process id should be present");
timeout(DEFAULT_READ_TIMEOUT, async {
loop {
if launch_marker.exists() {
return Ok::<(), anyhow::Error>(());
}
sleep(std::time::Duration::from_millis(20)).await;
}
})
.await??;
mcp.interrupt_turn_and_wait_for_aborted(
thread.id.clone(),
turn.id.clone(),
DEFAULT_READ_TIMEOUT,
)
.await?;
sleep(std::time::Duration::from_millis(750)).await;
assert!(
!leaked_marker.exists(),
"expected interrupt to stop approved subcommand before it wrote {leaked_marker_display}"
);
Ok(())
}
#[tokio::test]
async fn turn_start_shell_zsh_fork_subcommand_decline_marks_parent_declined_v2() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -654,15 +472,16 @@ async fn turn_start_shell_zsh_fork_subcommand_decline_marks_parent_declined_v2()
first_file.display(),
second_file.display()
);
let tool_call_arguments = serde_json::to_string(&json!({
"cmd": shell_command,
"yield_time_ms": 5000,
let tool_call_arguments = serde_json::to_string(&serde_json::json!({
"command": shell_command,
"workdir": serde_json::Value::Null,
"timeout_ms": 5000
}))?;
let response = responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_function_call(
"call-zsh-fork-subcommand-decline",
"exec_command",
"shell_command",
&tool_call_arguments,
),
responses::ev_completed("resp-1"),
@@ -683,7 +502,7 @@ async fn turn_start_shell_zsh_fork_subcommand_decline_marks_parent_declined_v2()
"untrusted",
&BTreeMap::from([
(Feature::ShellZshFork, true),
(Feature::UnifiedExec, true),
(Feature::UnifiedExec, false),
(Feature::ShellSnapshot, false),
]),
&zsh_path,
@@ -925,21 +744,6 @@ async fn create_zsh_test_mcp_process(codex_home: &Path, zdotdir: &Path) -> Resul
McpProcess::new_with_env(codex_home, &[("ZDOTDIR", Some(zdotdir.as_str()))]).await
}
fn create_zsh_fork_exec_command_sse_response(
command: &str,
call_id: &str,
) -> anyhow::Result<String> {
let tool_call_arguments = serde_json::to_string(&json!({
"cmd": command,
"yield_time_ms": 5000,
}))?;
Ok(responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_function_call(call_id, "exec_command", &tool_call_arguments),
responses::ev_completed("resp-1"),
]))
}
fn create_config_toml(
codex_home: &Path,
server_uri: &str,

View File

@@ -1,20 +1,16 @@
use crate::endpoint::realtime_websocket::protocol::ConversationFunctionCallOutputItem;
use crate::endpoint::realtime_websocket::protocol::ConversationItem;
use crate::endpoint::realtime_websocket::protocol::ConversationItemContent;
use crate::endpoint::realtime_websocket::protocol::ConversationItemPayload;
use crate::endpoint::realtime_websocket::protocol::ConversationMessageItem;
use crate::endpoint::realtime_websocket::protocol::RealtimeAudioFrame;
use crate::endpoint::realtime_websocket::protocol::RealtimeEvent;
use crate::endpoint::realtime_websocket::protocol::RealtimeEventParser;
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionMode;
use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptDelta;
use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptEntry;
use crate::endpoint::realtime_websocket::protocol::SessionAudio;
use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat;
use crate::endpoint::realtime_websocket::protocol::SessionAudioInput;
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput;
use crate::endpoint::realtime_websocket::protocol::SessionFunctionTool;
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
use crate::error::ApiError;
@@ -25,7 +21,6 @@ use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use http::HeaderValue;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
@@ -46,23 +41,6 @@ use tracing::trace;
use tungstenite::protocol::WebSocketConfig;
use url::Url;
const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000;
const REALTIME_AUDIO_VOICE: &str = "fathom";
const REALTIME_V1_SESSION_TYPE: &str = "quicksilver";
const REALTIME_V2_SESSION_TYPE: &str = "realtime";
const REALTIME_V2_CODEX_TOOL_NAME: &str = "codex";
const REALTIME_V2_CODEX_TOOL_DESCRIPTION: &str = "Delegate work to Codex and return the result.";
fn normalized_session_mode(
event_parser: RealtimeEventParser,
session_mode: RealtimeSessionMode,
) -> RealtimeSessionMode {
match event_parser {
RealtimeEventParser::V1 => RealtimeSessionMode::Conversational,
RealtimeEventParser::RealtimeV2 => session_mode,
}
}
struct WsStream {
tx_command: mpsc::Sender<WsCommand>,
pump_task: tokio::task::JoinHandle<()>,
@@ -219,7 +197,6 @@ pub struct RealtimeWebsocketConnection {
pub struct RealtimeWebsocketWriter {
stream: Arc<WsStream>,
is_closed: Arc<AtomicBool>,
event_parser: RealtimeEventParser,
}
#[derive(Clone)]
@@ -281,7 +258,6 @@ impl RealtimeWebsocketConnection {
writer: RealtimeWebsocketWriter {
stream: Arc::clone(&stream),
is_closed: Arc::clone(&is_closed),
event_parser,
},
events: RealtimeWebsocketEvents {
rx_message: Arc::new(Mutex::new(rx_message)),
@@ -300,19 +276,15 @@ impl RealtimeWebsocketWriter {
}
pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> {
let content_kind = match self.event_parser {
RealtimeEventParser::V1 => "text",
RealtimeEventParser::RealtimeV2 => "input_text",
};
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
item: ConversationItemPayload::Message(ConversationMessageItem {
item: ConversationItem {
kind: "message".to_string(),
role: "user".to_string(),
content: vec![ConversationItemContent {
kind: content_kind.to_string(),
kind: "text".to_string(),
text,
}],
}),
},
})
.await
}
@@ -322,80 +294,29 @@ impl RealtimeWebsocketWriter {
handoff_id: String,
output_text: String,
) -> Result<(), ApiError> {
let message = match self.event_parser {
RealtimeEventParser::V1 => RealtimeOutboundMessage::ConversationHandoffAppend {
handoff_id,
output_text,
},
RealtimeEventParser::RealtimeV2 => RealtimeOutboundMessage::ConversationItemCreate {
item: ConversationItemPayload::FunctionCallOutput(
ConversationFunctionCallOutputItem {
kind: "function_call_output".to_string(),
call_id: handoff_id,
output: output_text,
},
),
},
};
self.send_json(message).await
self.send_json(RealtimeOutboundMessage::ConversationHandoffAppend {
handoff_id,
output_text,
})
.await
}
pub async fn send_session_update(
&self,
instructions: String,
session_mode: RealtimeSessionMode,
) -> Result<(), ApiError> {
let session_mode = normalized_session_mode(self.event_parser, session_mode);
let (session_kind, session_instructions, output_audio) = match session_mode {
RealtimeSessionMode::Conversational => {
let kind = match self.event_parser {
RealtimeEventParser::V1 => REALTIME_V1_SESSION_TYPE.to_string(),
RealtimeEventParser::RealtimeV2 => REALTIME_V2_SESSION_TYPE.to_string(),
};
(
kind,
Some(instructions),
Some(SessionAudioOutput {
voice: REALTIME_AUDIO_VOICE.to_string(),
}),
)
}
RealtimeSessionMode::Transcription => ("transcription".to_string(), None, None),
};
let tools = match self.event_parser {
RealtimeEventParser::RealtimeV2 => Some(vec![SessionFunctionTool {
kind: "function".to_string(),
name: REALTIME_V2_CODEX_TOOL_NAME.to_string(),
description: REALTIME_V2_CODEX_TOOL_DESCRIPTION.to_string(),
parameters: json!({
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Prompt text for the delegated Codex task."
}
},
"required": ["prompt"],
"additionalProperties": false
}),
}]),
RealtimeEventParser::V1 => None,
};
pub async fn send_session_update(&self, instructions: String) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::SessionUpdate {
session: SessionUpdateSession {
kind: session_kind,
instructions: session_instructions,
kind: "quicksilver".to_string(),
instructions,
audio: SessionAudio {
input: SessionAudioInput {
format: SessionAudioFormat {
kind: "audio/pcm".to_string(),
rate: REALTIME_AUDIO_SAMPLE_RATE,
rate: 24_000,
},
},
output: output_audio,
output: SessionAudioOutput {
voice: "fathom".to_string(),
},
},
tools,
},
})
.await
@@ -544,8 +465,6 @@ impl RealtimeWebsocketClient {
self.provider.base_url.as_str(),
self.provider.query_params.as_ref(),
config.model.as_deref(),
config.event_parser,
config.session_mode,
)?;
let mut request = ws_url
@@ -587,7 +506,7 @@ impl RealtimeWebsocketClient {
);
connection
.writer
.send_session_update(config.instructions, config.session_mode)
.send_session_update(config.instructions)
.await?;
Ok(connection)
}
@@ -632,8 +551,6 @@ fn websocket_url_from_api_url(
api_url: &str,
query_params: Option<&HashMap<String, String>>,
model: Option<&str>,
event_parser: RealtimeEventParser,
_session_mode: RealtimeSessionMode,
) -> Result<Url, ApiError> {
let mut url = Url::parse(api_url)
.map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?;
@@ -653,20 +570,9 @@ fn websocket_url_from_api_url(
}
}
let intent = match event_parser {
RealtimeEventParser::V1 => Some("quicksilver"),
RealtimeEventParser::RealtimeV2 => None,
};
let has_extra_query_params = query_params.is_some_and(|query_params| {
query_params
.iter()
.any(|(key, _)| key != "intent" && !(key == "model" && model.is_some()))
});
if intent.is_some() || model.is_some() || has_extra_query_params {
{
let mut query = url.query_pairs_mut();
if let Some(intent) = intent {
query.append_pair("intent", intent);
}
query.append_pair("intent", "quicksilver");
if let Some(model) = model {
query.append_pair("model", model);
}
@@ -947,14 +853,8 @@ mod tests {
#[test]
fn websocket_url_from_http_base_defaults_to_ws_path() {
let url = websocket_url_from_api_url(
"http://127.0.0.1:8011",
None,
None,
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
let url =
websocket_url_from_api_url("http://127.0.0.1:8011", None, None).expect("build ws url");
assert_eq!(
url.as_str(),
"ws://127.0.0.1:8011/v1/realtime?intent=quicksilver"
@@ -963,14 +863,9 @@ mod tests {
#[test]
fn websocket_url_from_ws_base_defaults_to_ws_path() {
let url = websocket_url_from_api_url(
"wss://example.com",
None,
Some("realtime-test-model"),
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
let url =
websocket_url_from_api_url("wss://example.com", None, Some("realtime-test-model"))
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?intent=quicksilver&model=realtime-test-model"
@@ -979,14 +874,8 @@ mod tests {
#[test]
fn websocket_url_from_v1_base_appends_realtime_path() {
let url = websocket_url_from_api_url(
"https://api.openai.com/v1",
None,
Some("snapshot"),
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
let url = websocket_url_from_api_url("https://api.openai.com/v1", None, Some("snapshot"))
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://api.openai.com/v1/realtime?intent=quicksilver&model=snapshot"
@@ -995,14 +884,9 @@ mod tests {
#[test]
fn websocket_url_from_nested_v1_base_appends_realtime_path() {
let url = websocket_url_from_api_url(
"https://example.com/openai/v1",
None,
Some("snapshot"),
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
let url =
websocket_url_from_api_url("https://example.com/openai/v1", None, Some("snapshot"))
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/openai/v1/realtime?intent=quicksilver&model=snapshot"
@@ -1018,8 +902,6 @@ mod tests {
("intent".to_string(), "ignored".to_string()),
])),
Some("snapshot"),
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
assert_eq!(
@@ -1028,54 +910,6 @@ mod tests {
);
}
#[test]
fn websocket_url_v1_ignores_transcription_mode() {
let url = websocket_url_from_api_url(
"https://example.com",
None,
None,
RealtimeEventParser::V1,
RealtimeSessionMode::Transcription,
)
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?intent=quicksilver"
);
}
#[test]
fn websocket_url_omits_intent_for_realtime_v2_conversational_mode() {
let url = websocket_url_from_api_url(
"https://example.com/v1/realtime?foo=bar",
Some(&HashMap::from([
("trace".to_string(), "1".to_string()),
("intent".to_string(), "ignored".to_string()),
])),
Some("snapshot"),
RealtimeEventParser::RealtimeV2,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?foo=bar&model=snapshot&trace=1"
);
}
#[test]
fn websocket_url_omits_intent_for_realtime_v2_transcription_mode() {
let url = websocket_url_from_api_url(
"https://example.com",
None,
None,
RealtimeEventParser::RealtimeV2,
RealtimeSessionMode::Transcription,
)
.expect("build ws url");
assert_eq!(url.as_str(), "wss://example.com/v1/realtime");
}
#[tokio::test]
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
@@ -1241,7 +1075,6 @@ mod tests {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -1362,352 +1195,6 @@ mod tests {
server.await.expect("server task");
}
#[tokio::test]
async fn realtime_v2_session_update_includes_codex_tool_and_handoff_output_item() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("local addr");
let server = tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("accept");
let mut ws = accept_async(stream).await.expect("accept ws");
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["type"],
Value::String("realtime".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["type"],
Value::String("function".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["name"],
Value::String("codex".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["parameters"]["required"],
json!(["prompt"])
);
ws.send(Message::Text(
json!({
"type": "session.updated",
"session": {"id": "sess_v2", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.updated");
let second = ws
.next()
.await
.expect("second msg")
.expect("second msg ok")
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "conversation.item.create");
assert_eq!(
second_json["item"]["type"],
Value::String("message".to_string())
);
assert_eq!(
second_json["item"]["content"][0]["type"],
Value::String("input_text".to_string())
);
assert_eq!(
second_json["item"]["content"][0]["text"],
Value::String("delegate this".to_string())
);
let third = ws
.next()
.await
.expect("third msg")
.expect("third msg ok")
.into_text()
.expect("text");
let third_json: Value = serde_json::from_str(&third).expect("json");
assert_eq!(third_json["type"], "conversation.item.create");
assert_eq!(
third_json["item"]["type"],
Value::String("function_call_output".to_string())
);
assert_eq!(
third_json["item"]["call_id"],
Value::String("call_1".to_string())
);
assert_eq!(
third_json["item"]["output"],
Value::String("delegated result".to_string())
);
});
let provider = Provider {
name: "test".to_string(),
base_url: format!("http://{addr}"),
query_params: Some(HashMap::new()),
headers: HeaderMap::new(),
retry: crate::provider::RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: false,
retry_transport: false,
},
stream_idle_timeout: Duration::from_secs(5),
};
let client = RealtimeWebsocketClient::new(provider);
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::RealtimeV2,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let created = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
created,
RealtimeEvent::SessionUpdated {
session_id: "sess_v2".to_string(),
instructions: Some("backend prompt".to_string()),
}
);
connection
.send_conversation_item_create("delegate this".to_string())
.await
.expect("send text item");
connection
.send_conversation_handoff_append("call_1".to_string(), "delegated result".to_string())
.await
.expect("send handoff output");
connection.close().await.expect("close");
server.await.expect("server task");
}
#[tokio::test]
async fn transcription_mode_session_update_omits_output_audio_and_instructions() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("local addr");
let server = tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("accept");
let mut ws = accept_async(stream).await.expect("accept ws");
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["type"],
Value::String("transcription".to_string())
);
assert!(first_json["session"].get("instructions").is_none());
assert!(first_json["session"]["audio"].get("output").is_none());
assert_eq!(
first_json["session"]["tools"][0]["name"],
Value::String("codex".to_string())
);
ws.send(Message::Text(
json!({
"type": "session.updated",
"session": {"id": "sess_transcription"}
})
.to_string()
.into(),
))
.await
.expect("send session.updated");
let second = ws
.next()
.await
.expect("second msg")
.expect("second msg ok")
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "input_audio_buffer.append");
});
let provider = Provider {
name: "test".to_string(),
base_url: format!("http://{addr}"),
query_params: Some(HashMap::new()),
headers: HeaderMap::new(),
retry: crate::provider::RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: false,
retry_transport: false,
},
stream_idle_timeout: Duration::from_secs(5),
};
let client = RealtimeWebsocketClient::new(provider);
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::RealtimeV2,
session_mode: RealtimeSessionMode::Transcription,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let created = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
created,
RealtimeEvent::SessionUpdated {
session_id: "sess_transcription".to_string(),
instructions: None,
}
);
connection
.send_audio_frame(RealtimeAudioFrame {
data: "AQID".to_string(),
sample_rate: 24_000,
num_channels: 1,
samples_per_channel: Some(480),
})
.await
.expect("send audio");
connection.close().await.expect("close");
server.await.expect("server task");
}
#[tokio::test]
async fn v1_transcription_mode_is_treated_as_conversational() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("local addr");
let server = tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("accept");
let mut ws = accept_async(stream).await.expect("accept ws");
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["type"],
Value::String("quicksilver".to_string())
);
assert_eq!(
first_json["session"]["instructions"],
Value::String("backend prompt".to_string())
);
assert_eq!(
first_json["session"]["audio"]["output"]["voice"],
Value::String("fathom".to_string())
);
assert!(first_json["session"].get("tools").is_none());
ws.send(Message::Text(
json!({
"type": "session.updated",
"session": {"id": "sess_v1_mode"}
})
.to_string()
.into(),
))
.await
.expect("send session.updated");
});
let provider = Provider {
name: "test".to_string(),
base_url: format!("http://{addr}"),
query_params: Some(HashMap::new()),
headers: HeaderMap::new(),
retry: crate::provider::RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: false,
retry_transport: false,
},
stream_idle_timeout: Duration::from_secs(5),
};
let client = RealtimeWebsocketClient::new(provider);
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Transcription,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let created = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
created,
RealtimeEvent::SessionUpdated {
session_id: "sess_v1_mode".to_string(),
instructions: None,
}
);
connection.close().await.expect("close");
server.await.expect("server task");
}
#[tokio::test]
async fn send_does_not_block_while_next_event_waits_for_inbound_data() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
@@ -1771,7 +1258,6 @@ mod tests {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),

View File

@@ -1,7 +1,5 @@
pub mod methods;
pub mod protocol;
mod protocol_common;
mod protocol_v1;
mod protocol_v2;
pub use codex_protocol::protocol::RealtimeAudioFrame;
@@ -12,4 +10,3 @@ pub use methods::RealtimeWebsocketEvents;
pub use methods::RealtimeWebsocketWriter;
pub use protocol::RealtimeEventParser;
pub use protocol::RealtimeSessionConfig;
pub use protocol::RealtimeSessionMode;

View File

@@ -1,4 +1,3 @@
use crate::endpoint::realtime_websocket::protocol_v1::parse_realtime_event_v1;
use crate::endpoint::realtime_websocket::protocol_v2::parse_realtime_event_v2;
pub use codex_protocol::protocol::RealtimeAudioFrame;
pub use codex_protocol::protocol::RealtimeEvent;
@@ -7,6 +6,7 @@ pub use codex_protocol::protocol::RealtimeTranscriptDelta;
pub use codex_protocol::protocol::RealtimeTranscriptEntry;
use serde::Serialize;
use serde_json::Value;
use tracing::debug;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RealtimeEventParser {
@@ -14,19 +14,12 @@ pub enum RealtimeEventParser {
RealtimeV2,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RealtimeSessionMode {
Conversational,
Transcription,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RealtimeSessionConfig {
pub instructions: String,
pub model: Option<String>,
pub session_id: Option<String>,
pub event_parser: RealtimeEventParser,
pub session_mode: RealtimeSessionMode,
}
#[derive(Debug, Clone, Serialize)]
@@ -42,25 +35,21 @@ pub(super) enum RealtimeOutboundMessage {
#[serde(rename = "session.update")]
SessionUpdate { session: SessionUpdateSession },
#[serde(rename = "conversation.item.create")]
ConversationItemCreate { item: ConversationItemPayload },
ConversationItemCreate { item: ConversationItem },
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionUpdateSession {
#[serde(rename = "type")]
pub(super) kind: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) instructions: Option<String>,
pub(super) instructions: String,
pub(super) audio: SessionAudio,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) tools: Option<Vec<SessionFunctionTool>>,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionAudio {
pub(super) input: SessionAudioInput,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) output: Option<SessionAudioOutput>,
pub(super) output: SessionAudioOutput,
}
#[derive(Debug, Clone, Serialize)]
@@ -81,28 +70,13 @@ pub(super) struct SessionAudioOutput {
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct ConversationMessageItem {
pub(super) struct ConversationItem {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) role: String,
pub(super) content: Vec<ConversationItemContent>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub(super) enum ConversationItemPayload {
Message(ConversationMessageItem),
FunctionCallOutput(ConversationFunctionCallOutputItem),
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct ConversationFunctionCallOutputItem {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) call_id: String,
pub(super) output: String,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct ConversationItemContent {
#[serde(rename = "type")]
@@ -110,15 +84,6 @@ pub(super) struct ConversationItemContent {
pub(super) text: String,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionFunctionTool {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) name: String,
pub(super) description: String,
pub(super) parameters: Value,
}
pub(super) fn parse_realtime_event(
payload: &str,
event_parser: RealtimeEventParser,
@@ -128,3 +93,125 @@ pub(super) fn parse_realtime_event(
RealtimeEventParser::RealtimeV2 => parse_realtime_event_v2(payload),
}
}
/// Parses a realtime v1 websocket text frame into a [`RealtimeEvent`].
///
/// Returns `None` (after a debug-level log) for malformed JSON, payloads
/// missing a string `type` field, unrecognized event types, or events that
/// lack the fields their variant requires.
fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
    // Decode the raw frame; invalid JSON is logged and dropped.
    let parsed: Value = match serde_json::from_str(payload) {
        Ok(msg) => msg,
        Err(err) => {
            debug!("failed to parse realtime event: {err}, data: {payload}");
            return None;
        }
    };
    // Every event is dispatched on its string `type` tag.
    let message_type = match parsed.get("type").and_then(Value::as_str) {
        Some(message_type) => message_type,
        None => {
            debug!("received realtime event without type field: {payload}");
            return None;
        }
    };
    match message_type {
        // `session.id` is required; `session.instructions` is optional and
        // forwarded verbatim when present.
        "session.updated" => {
            let session_id = parsed
                .get("session")
                .and_then(Value::as_object)
                .and_then(|session| session.get("id"))
                .and_then(Value::as_str)
                .map(str::to_string);
            let instructions = parsed
                .get("session")
                .and_then(Value::as_object)
                .and_then(|session| session.get("instructions"))
                .and_then(Value::as_str)
                .map(str::to_string);
            session_id.map(|session_id| RealtimeEvent::SessionUpdated {
                session_id,
                instructions,
            })
        }
        // Audio frames require `delta` (with `data` as fallback key),
        // `sample_rate`, and a channel count (`channels` or `num_channels`);
        // `samples_per_channel` is optional.
        "conversation.output_audio.delta" => {
            let data = parsed
                .get("delta")
                .and_then(Value::as_str)
                .or_else(|| parsed.get("data").and_then(Value::as_str))
                .map(str::to_string)?;
            let sample_rate = parsed
                .get("sample_rate")
                .and_then(Value::as_u64)
                .and_then(|v| u32::try_from(v).ok())?;
            let num_channels = parsed
                .get("channels")
                .or_else(|| parsed.get("num_channels"))
                .and_then(Value::as_u64)
                .and_then(|v| u16::try_from(v).ok())?;
            Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
                data,
                sample_rate,
                num_channels,
                samples_per_channel: parsed
                    .get("samples_per_channel")
                    .and_then(Value::as_u64)
                    .and_then(|v| u32::try_from(v).ok()),
            }))
        }
        // Transcript deltas carry their text in the `delta` field.
        "conversation.input_transcript.delta" => parsed
            .get("delta")
            .and_then(Value::as_str)
            .map(str::to_string)
            .map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
        "conversation.output_transcript.delta" => parsed
            .get("delta")
            .and_then(Value::as_str)
            .map(str::to_string)
            .map(|delta| RealtimeEvent::OutputTranscriptDelta(RealtimeTranscriptDelta { delta })),
        // The raw `item` JSON is forwarded unmodified.
        "conversation.item.added" => parsed
            .get("item")
            .cloned()
            .map(RealtimeEvent::ConversationItemAdded),
        // Only the item's `id` is surfaced for completion events.
        "conversation.item.done" => parsed
            .get("item")
            .and_then(Value::as_object)
            .and_then(|item| item.get("id"))
            .and_then(Value::as_str)
            .map(str::to_string)
            .map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
        // Handoffs require all three string fields; the active transcript is
        // not part of the wire payload and starts empty.
        "conversation.handoff.requested" => {
            let handoff_id = parsed
                .get("handoff_id")
                .and_then(Value::as_str)
                .map(str::to_string)?;
            let item_id = parsed
                .get("item_id")
                .and_then(Value::as_str)
                .map(str::to_string)?;
            let input_transcript = parsed
                .get("input_transcript")
                .and_then(Value::as_str)
                .map(str::to_string)?;
            Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
                handoff_id,
                item_id,
                input_transcript,
                active_transcript: Vec::new(),
            }))
        }
        // Error text is resolved in order: top-level `message`, then
        // `error.message`, then the stringified `error` value itself.
        "error" => parsed
            .get("message")
            .and_then(Value::as_str)
            .map(str::to_string)
            .or_else(|| {
                parsed
                    .get("error")
                    .and_then(Value::as_object)
                    .and_then(|error| error.get("message"))
                    .and_then(Value::as_str)
                    .map(str::to_string)
            })
            .or_else(|| parsed.get("error").map(std::string::ToString::to_string))
            .map(RealtimeEvent::Error),
        _ => {
            debug!("received unsupported realtime event type: {message_type}, data: {payload}");
            None
        }
    }
}

View File

@@ -1,71 +0,0 @@
use codex_protocol::protocol::RealtimeEvent;
use codex_protocol::protocol::RealtimeTranscriptDelta;
use serde_json::Value;
use tracing::debug;
/// Decodes a raw websocket text payload into JSON and extracts its `type` tag.
///
/// Returns `None` (after logging at debug level) when the payload is not
/// valid JSON or carries no string `type` field. `parser_name` only labels
/// the log lines so v1/v2 failures can be told apart.
pub(super) fn parse_realtime_payload(payload: &str, parser_name: &str) -> Option<(Value, String)> {
    let decoded: Result<Value, _> = serde_json::from_str(payload);
    let parsed = match decoded {
        Ok(value) => value,
        Err(err) => {
            debug!("failed to parse {parser_name} event: {err}, data: {payload}");
            return None;
        }
    };
    let Some(kind) = parsed.get("type").and_then(Value::as_str) else {
        debug!("received {parser_name} event without type field: {payload}");
        return None;
    };
    let message_type = kind.to_string();
    Some((parsed, message_type))
}
/// Builds a `SessionUpdated` event from a `session.updated` payload.
///
/// Requires `session` to be an object with a string `id`; the optional
/// `session.instructions` string is passed through verbatim when present.
pub(super) fn parse_session_updated_event(parsed: &Value) -> Option<RealtimeEvent> {
    let session = parsed.get("session").and_then(Value::as_object)?;
    let session_id = session.get("id").and_then(Value::as_str)?.to_string();
    let instructions = session
        .get("instructions")
        .and_then(Value::as_str)
        .map(str::to_string);
    Some(RealtimeEvent::SessionUpdated {
        session_id,
        instructions,
    })
}
/// Reads the string stored under `field` and wraps it in a transcript delta.
///
/// Returns `None` when the field is absent or not a string.
pub(super) fn parse_transcript_delta_event(
    parsed: &Value,
    field: &str,
) -> Option<RealtimeTranscriptDelta> {
    let text = parsed.get(field)?.as_str()?;
    Some(RealtimeTranscriptDelta {
        delta: text.to_string(),
    })
}
pub(super) fn parse_error_event(parsed: &Value) -> Option<RealtimeEvent> {
parsed
.get("message")
.and_then(Value::as_str)
.map(str::to_string)
.or_else(|| {
parsed
.get("error")
.and_then(Value::as_object)
.and_then(|error| error.get("message"))
.and_then(Value::as_str)
.map(str::to_string)
})
.or_else(|| parsed.get("error").map(ToString::to_string))
.map(RealtimeEvent::Error)
}

View File

@@ -1,83 +0,0 @@
use crate::endpoint::realtime_websocket::protocol_common::parse_error_event;
use crate::endpoint::realtime_websocket::protocol_common::parse_realtime_payload;
use crate::endpoint::realtime_websocket::protocol_common::parse_session_updated_event;
use crate::endpoint::realtime_websocket::protocol_common::parse_transcript_delta_event;
use codex_protocol::protocol::RealtimeAudioFrame;
use codex_protocol::protocol::RealtimeEvent;
use codex_protocol::protocol::RealtimeHandoffRequested;
use serde_json::Value;
use tracing::debug;
/// Parses a realtime v1 websocket text frame into a [`RealtimeEvent`].
///
/// JSON decoding and `type` dispatch are shared with the v2 parser via
/// `parse_realtime_payload`; unknown event types are logged and dropped, and
/// events missing required fields yield `None`.
pub(super) fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
    let (parsed, message_type) = parse_realtime_payload(payload, "realtime v1")?;
    match message_type.as_str() {
        "session.updated" => parse_session_updated_event(&parsed),
        // Audio frames require `delta` (with `data` as fallback key),
        // `sample_rate`, and a channel count (`channels` or `num_channels`);
        // `samples_per_channel` is optional.
        "conversation.output_audio.delta" => {
            let data = parsed
                .get("delta")
                .and_then(Value::as_str)
                .or_else(|| parsed.get("data").and_then(Value::as_str))
                .map(str::to_string)?;
            let sample_rate = parsed
                .get("sample_rate")
                .and_then(Value::as_u64)
                .and_then(|value| u32::try_from(value).ok())?;
            let num_channels = parsed
                .get("channels")
                .or_else(|| parsed.get("num_channels"))
                .and_then(Value::as_u64)
                .and_then(|value| u16::try_from(value).ok())?;
            Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
                data,
                sample_rate,
                num_channels,
                samples_per_channel: parsed
                    .get("samples_per_channel")
                    .and_then(Value::as_u64)
                    .and_then(|value| u32::try_from(value).ok()),
            }))
        }
        "conversation.input_transcript.delta" => {
            parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::InputTranscriptDelta)
        }
        "conversation.output_transcript.delta" => {
            parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::OutputTranscriptDelta)
        }
        // The raw `item` JSON is forwarded unmodified.
        "conversation.item.added" => parsed
            .get("item")
            .cloned()
            .map(RealtimeEvent::ConversationItemAdded),
        // Only the item's `id` is surfaced for completion events.
        "conversation.item.done" => parsed
            .get("item")
            .and_then(Value::as_object)
            .and_then(|item| item.get("id"))
            .and_then(Value::as_str)
            .map(str::to_string)
            .map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
        // Handoffs require all three string fields; the active transcript is
        // not part of the wire payload and starts empty.
        "conversation.handoff.requested" => {
            let handoff_id = parsed
                .get("handoff_id")
                .and_then(Value::as_str)
                .map(str::to_string)?;
            let item_id = parsed
                .get("item_id")
                .and_then(Value::as_str)
                .map(str::to_string)?;
            let input_transcript = parsed
                .get("input_transcript")
                .and_then(Value::as_str)
                .map(str::to_string)?;
            Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
                handoff_id,
                item_id,
                input_transcript,
                active_transcript: Vec::new(),
            }))
        }
        "error" => parse_error_event(&parsed),
        _ => {
            debug!("received unsupported realtime v1 event type: {message_type}, data: {payload}");
            None
        }
    }
}

View File

@@ -1,130 +1,157 @@
use crate::endpoint::realtime_websocket::protocol_common::parse_error_event;
use crate::endpoint::realtime_websocket::protocol_common::parse_realtime_payload;
use crate::endpoint::realtime_websocket::protocol_common::parse_session_updated_event;
use crate::endpoint::realtime_websocket::protocol_common::parse_transcript_delta_event;
use codex_protocol::protocol::RealtimeAudioFrame;
use codex_protocol::protocol::RealtimeEvent;
use codex_protocol::protocol::RealtimeHandoffRequested;
use serde_json::Map as JsonMap;
use codex_protocol::protocol::RealtimeTranscriptDelta;
use serde_json::Value;
use tracing::debug;
const CODEX_TOOL_NAME: &str = "codex";
const DEFAULT_AUDIO_SAMPLE_RATE: u32 = 24_000;
const DEFAULT_AUDIO_CHANNELS: u16 = 1;
const TOOL_ARGUMENT_KEYS: [&str; 5] = ["input_transcript", "input", "text", "prompt", "query"];
pub(super) fn parse_realtime_event_v2(payload: &str) -> Option<RealtimeEvent> {
let (parsed, message_type) = parse_realtime_payload(payload, "realtime v2")?;
let parsed: Value = match serde_json::from_str(payload) {
Ok(msg) => msg,
Err(err) => {
debug!("failed to parse realtime v2 event: {err}, data: {payload}");
return None;
}
};
match message_type.as_str() {
"session.updated" => parse_session_updated_event(&parsed),
"response.output_audio.delta" => parse_output_audio_delta_event(&parsed),
"conversation.item.input_audio_transcription.delta" => {
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::InputTranscriptDelta)
let message_type = match parsed.get("type").and_then(Value::as_str) {
Some(message_type) => message_type,
None => {
debug!("received realtime v2 event without type field: {payload}");
return None;
}
"conversation.item.input_audio_transcription.completed" => {
parse_transcript_delta_event(&parsed, "transcript")
.map(RealtimeEvent::InputTranscriptDelta)
};
match message_type {
"session.updated" => {
let session_id = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("id"))
.and_then(Value::as_str)
.map(str::to_string);
let instructions = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("instructions"))
.and_then(Value::as_str)
.map(str::to_string);
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
session_id,
instructions,
})
}
"response.output_text.delta" | "response.output_audio_transcript.delta" => {
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::OutputTranscriptDelta)
"response.output_audio.delta" => {
let data = parsed
.get("delta")
.and_then(Value::as_str)
.map(str::to_string)?;
let sample_rate = parsed
.get("sample_rate")
.and_then(Value::as_u64)
.and_then(|value| u32::try_from(value).ok())
.unwrap_or(24_000);
let num_channels = parsed
.get("channels")
.or_else(|| parsed.get("num_channels"))
.and_then(Value::as_u64)
.and_then(|value| u16::try_from(value).ok())
.unwrap_or(1);
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
data,
sample_rate,
num_channels,
samples_per_channel: parsed
.get("samples_per_channel")
.and_then(Value::as_u64)
.and_then(|value| u32::try_from(value).ok()),
}))
}
"conversation.item.input_audio_transcription.delta" => parsed
.get("delta")
.and_then(Value::as_str)
.map(str::to_string)
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
"conversation.item.input_audio_transcription.completed" => parsed
.get("transcript")
.and_then(Value::as_str)
.map(str::to_string)
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
"response.output_text.delta" | "response.output_audio_transcript.delta" => parsed
.get("delta")
.and_then(Value::as_str)
.map(str::to_string)
.map(|delta| RealtimeEvent::OutputTranscriptDelta(RealtimeTranscriptDelta { delta })),
"conversation.item.added" => parsed
.get("item")
.cloned()
.map(RealtimeEvent::ConversationItemAdded),
"conversation.item.done" => parse_conversation_item_done_event(&parsed),
"error" => parse_error_event(&parsed),
"conversation.item.done" => {
let item = parsed.get("item")?.as_object()?;
let item_type = item.get("type").and_then(Value::as_str);
let item_name = item.get("name").and_then(Value::as_str);
if item_type == Some("function_call") && item_name == Some("codex") {
let call_id = item
.get("call_id")
.and_then(Value::as_str)
.or_else(|| item.get("id").and_then(Value::as_str))?;
let item_id = item
.get("id")
.and_then(Value::as_str)
.unwrap_or(call_id)
.to_string();
let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
let mut input_transcript = String::new();
if !arguments.is_empty() {
if let Ok(arguments_json) = serde_json::from_str::<Value>(arguments)
&& let Some(arguments_object) = arguments_json.as_object()
{
for key in ["input_transcript", "input", "text", "prompt", "query"] {
if let Some(value) = arguments_object.get(key).and_then(Value::as_str) {
let trimmed = value.trim();
if !trimmed.is_empty() {
input_transcript = trimmed.to_string();
break;
}
}
}
}
if input_transcript.is_empty() {
input_transcript = arguments.to_string();
}
}
return Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id: call_id.to_string(),
item_id,
input_transcript,
active_transcript: Vec::new(),
}));
}
item.get("id")
.and_then(Value::as_str)
.map(str::to_string)
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id })
}
"error" => parsed
.get("message")
.and_then(Value::as_str)
.map(str::to_string)
.or_else(|| {
parsed
.get("error")
.and_then(Value::as_object)
.and_then(|error| error.get("message"))
.and_then(Value::as_str)
.map(str::to_string)
})
.or_else(|| parsed.get("error").map(ToString::to_string))
.map(RealtimeEvent::Error),
_ => {
debug!("received unsupported realtime v2 event type: {message_type}, data: {payload}");
None
}
}
}
/// Translates a `response.output_audio.delta` payload into an `AudioOut` frame.
///
/// The `delta` string is required; sample rate and channel count fall back to
/// the realtime defaults when missing or not representable in their target
/// integer widths, and `samples_per_channel` is forwarded only when present.
fn parse_output_audio_delta_event(parsed: &Value) -> Option<RealtimeEvent> {
    let read_u64 = |key: &str| parsed.get(key).and_then(Value::as_u64);
    let data = parsed.get("delta")?.as_str()?.to_string();
    let sample_rate = match read_u64("sample_rate").map(u32::try_from) {
        Some(Ok(rate)) => rate,
        _ => DEFAULT_AUDIO_SAMPLE_RATE,
    };
    // `channels` takes precedence at the value level: a present but
    // non-numeric `channels` still masks `num_channels`.
    let channel_value = parsed.get("channels").or_else(|| parsed.get("num_channels"));
    let num_channels = match channel_value.and_then(Value::as_u64).map(u16::try_from) {
        Some(Ok(channels)) => channels,
        _ => DEFAULT_AUDIO_CHANNELS,
    };
    let samples_per_channel =
        read_u64("samples_per_channel").and_then(|value| u32::try_from(value).ok());
    Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
        data,
        sample_rate,
        num_channels,
        samples_per_channel,
    }))
}
/// Handles a `conversation.item.done` payload.
///
/// A codex function-call item becomes a handoff request; any other item is
/// reported as `ConversationItemDone` keyed by its string `id`.
fn parse_conversation_item_done_event(parsed: &Value) -> Option<RealtimeEvent> {
    let item = parsed.get("item").and_then(Value::as_object)?;
    parse_handoff_requested_event(item).or_else(|| {
        let item_id = item.get("id")?.as_str()?.to_string();
        Some(RealtimeEvent::ConversationItemDone { item_id })
    })
}
/// Recognizes a completed codex function call and turns it into a handoff.
///
/// Returns `None` unless the item is a `function_call` named after the codex
/// tool with a usable `call_id` (falling back to `id`). The transcript is
/// extracted from the call's `arguments` string; the active transcript is not
/// part of the wire payload and starts empty.
fn parse_handoff_requested_event(item: &JsonMap<String, Value>) -> Option<RealtimeEvent> {
    let text_field = |key: &str| item.get(key).and_then(Value::as_str);
    let is_codex_call =
        text_field("type") == Some("function_call") && text_field("name") == Some(CODEX_TOOL_NAME);
    if !is_codex_call {
        return None;
    }
    let call_id = text_field("call_id").or_else(|| text_field("id"))?;
    let item_id = text_field("id").unwrap_or(call_id).to_string();
    let arguments = text_field("arguments").unwrap_or("");
    Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
        handoff_id: call_id.to_string(),
        item_id,
        input_transcript: extract_input_transcript(arguments),
        active_transcript: Vec::new(),
    }))
}
/// Pulls the user-facing transcript text out of a tool-call arguments string.
///
/// When `arguments` parses as a JSON object, the first non-blank value found
/// under `TOOL_ARGUMENT_KEYS` (trimmed) wins; otherwise the raw arguments
/// string is returned as-is. Empty input yields an empty string.
fn extract_input_transcript(arguments: &str) -> String {
    if arguments.is_empty() {
        return String::new();
    }
    let from_known_key = serde_json::from_str::<Value>(arguments)
        .ok()
        .and_then(|parsed| {
            let object = parsed.as_object()?;
            TOOL_ARGUMENT_KEYS.iter().find_map(|key| {
                let trimmed = object.get(*key)?.as_str()?.trim();
                (!trimmed.is_empty()).then(|| trimmed.to_string())
            })
        });
    from_known_key.unwrap_or_else(|| arguments.to_string())
}

View File

@@ -29,7 +29,6 @@ pub use crate::endpoint::memories::MemoriesClient;
pub use crate::endpoint::models::ModelsClient;
pub use crate::endpoint::realtime_websocket::RealtimeEventParser;
pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
pub use crate::endpoint::realtime_websocket::RealtimeSessionMode;
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient;
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketConnection;
pub use crate::endpoint::responses::ResponsesClient;

View File

@@ -6,7 +6,6 @@ use codex_api::RealtimeAudioFrame;
use codex_api::RealtimeEvent;
use codex_api::RealtimeEventParser;
use codex_api::RealtimeSessionConfig;
use codex_api::RealtimeSessionMode;
use codex_api::RealtimeWebsocketClient;
use codex_api::provider::Provider;
use codex_api::provider::RetryConfig;
@@ -143,7 +142,6 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -237,7 +235,6 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -302,7 +299,6 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -364,7 +360,6 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -429,7 +424,6 @@ async fn realtime_ws_e2e_realtime_v2_parser_emits_handoff_requested() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::RealtimeV2,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),

View File

@@ -1342,13 +1342,6 @@
},
"type": "object"
},
"RealtimeWsMode": {
"enum": [
"conversational",
"transcription"
],
"type": "string"
},
"ReasoningEffort": {
"description": "See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning",
"enum": [
@@ -1823,14 +1816,6 @@
"description": "Experimental / do not use. Overrides only the realtime conversation websocket transport base URL (the `Op::RealtimeConversation` `/v1/realtime` connection) without changing normal provider HTTP requests.",
"type": "string"
},
"experimental_realtime_ws_mode": {
"allOf": [
{
"$ref": "#/definitions/RealtimeWsMode"
}
],
"description": "Experimental / do not use. Selects the realtime websocket intent mode. `conversational` is speech-to-speech while `transcription` is transcript-only."
},
"experimental_realtime_ws_model": {
"description": "Experimental / do not use. Selects the realtime websocket model/snapshot used for the `Op::RealtimeConversation` connection.",
"type": "string"

View File

@@ -172,6 +172,8 @@ use crate::error::CodexErr;
use crate::error::Result as CodexResult;
#[cfg(test)]
use crate::exec::StreamOutput;
use crate::network_proxy_registry::NetworkProxyRegistry;
use crate::network_proxy_registry::NetworkProxyScope;
use codex_config::CONFIG_TOML_FILE;
mod rollout_reconstruction;
@@ -276,6 +278,7 @@ use crate::skills::collect_explicit_skill_mentions;
use crate::skills::injection::ToolMentionKind;
use crate::skills::injection::app_id_from_path;
use crate::skills::injection::tool_kind_for_path;
use crate::skills::model::SkillManagedNetworkOverride;
use crate::skills::resolve_skill_dependencies_for_turn;
use crate::state::ActiveTurn;
use crate::state::SessionServices;
@@ -1182,6 +1185,61 @@ impl Session {
Ok((network_proxy, session_network_proxy))
}
/// Returns the network proxy for `scope`, starting one on demand.
///
/// Delegates to the session's `network_proxies` registry via `get_or_start`;
/// the closure passed here builds and starts the proxy and is presumably only
/// invoked when no proxy is registered for `scope` yet — confirm against
/// `NetworkProxyRegistry::get_or_start`. Returns `Ok(None)` when the registry
/// yields no started proxy for this scope.
pub(crate) async fn get_or_start_network_proxy(
    self: &Arc<Self>,
    scope: NetworkProxyScope,
    sandbox_policy: &SandboxPolicy,
    managed_network_override: Option<SkillManagedNetworkOverride>,
) -> anyhow::Result<Option<NetworkProxy>> {
    // Owned session handle moved into the start closure below.
    let session = Arc::clone(self);
    let started = self
        .services
        .network_proxies
        .get_or_start(
            scope.clone(),
            move |spec, managed_enabled, audit_metadata| {
                // Clone per invocation: the async block must own its captures.
                let session = Arc::clone(&session);
                let managed_network_override = managed_network_override.clone();
                let scope = scope.clone();
                let sandbox_policy = sandbox_policy.clone();
                async move {
                    // A policy decider is wired in only when this session has
                    // a network-policy-decider session configured.
                    let network_policy_decider = session
                        .services
                        .network_policy_decider_session
                        .as_ref()
                        .map(|network_policy_decider_session| {
                            build_network_policy_decider(
                                Arc::clone(&session.services.network_approval),
                                Arc::clone(network_policy_decider_session),
                                scope,
                            )
                        });
                    // Apply the skill-managed network override to the spec,
                    // if one was supplied, before starting the proxy.
                    let spec = if let Some(managed_network_override) =
                        managed_network_override.as_ref()
                    {
                        spec.with_skill_managed_network_override(managed_network_override)
                    } else {
                        spec
                    };
                    spec.start_proxy(
                        &sandbox_policy,
                        network_policy_decider,
                        session
                            .services
                            .network_blocked_request_observer
                            .as_ref()
                            .map(Arc::clone),
                        managed_enabled,
                        audit_metadata,
                    )
                    .await
                }
            },
        )
        .await?;
    Ok(started.map(|started| started.proxy()))
}
/// Don't expand the number of mutated arguments on config. We are in the process of getting rid of it.
pub(crate) fn build_per_turn_config(session_configuration: &SessionConfiguration) -> Config {
// todo(aibrahim): store this state somewhere else so we don't need to mut config
@@ -1651,9 +1709,10 @@ impl Session {
build_network_policy_decider(
Arc::clone(&network_approval),
Arc::clone(network_policy_decider_session),
NetworkProxyScope::SessionDefault,
)
});
let (network_proxy, session_network_proxy) =
let (default_network_proxy, session_network_proxy) =
if let Some(spec) = config.permissions.network.as_ref() {
let (network_proxy, session_network_proxy) = Self::start_managed_network_proxy(
spec,
@@ -1661,13 +1720,19 @@ impl Session {
network_policy_decider.as_ref().map(Arc::clone),
blocked_request_observer.as_ref().map(Arc::clone),
managed_network_requirements_enabled,
network_proxy_audit_metadata,
network_proxy_audit_metadata.clone(),
)
.await?;
(Some(network_proxy), Some(session_network_proxy))
} else {
(None, None)
};
let network_proxies = NetworkProxyRegistry::new(
config.permissions.network.clone(),
managed_network_requirements_enabled,
network_proxy_audit_metadata.clone(),
default_network_proxy,
);
let mut hook_shell_argv = default_shell.derive_exec_args("", false);
let hook_shell_program = hook_shell_argv.remove(0);
@@ -1725,7 +1790,9 @@ impl Session {
mcp_manager: Arc::clone(&mcp_manager),
file_watcher,
agent_control,
network_proxy,
network_proxies,
network_policy_decider_session,
network_blocked_request_observer: blocked_request_observer,
network_approval: Arc::clone(&network_approval),
state_db: state_db_ctx.clone(),
model_client: ModelClient::new(
@@ -1764,7 +1831,9 @@ impl Session {
js_repl,
next_internal_sub_id: AtomicU64::new(0),
});
if let Some(network_policy_decider_session) = network_policy_decider_session {
if let Some(network_policy_decider_session) =
sess.services.network_policy_decider_session.as_ref()
{
let mut guard = network_policy_decider_session.write().await;
*guard = Arc::downgrade(&sess);
}
@@ -2314,8 +2383,10 @@ impl Session {
model_info,
&self.services.models_manager,
self.services
.network_proxy
.as_ref()
.network_proxies
.get(&NetworkProxyScope::SessionDefault)
.await
.as_deref()
.map(StartedNetworkProxy::proxy),
sub_id,
Arc::clone(&self.js_repl),
@@ -2687,6 +2758,7 @@ impl Session {
&self,
amendment: &NetworkPolicyAmendment,
network_approval_context: &NetworkApprovalContext,
scope: &NetworkProxyScope,
) -> anyhow::Result<()> {
let host =
Self::validated_network_policy_amendment_host(amendment, network_approval_context)?;
@@ -2700,7 +2772,7 @@ impl Session {
let execpolicy_amendment =
execpolicy_network_rule_amendment(amendment, network_approval_context, &host);
if let Some(started_network_proxy) = self.services.network_proxy.as_ref() {
if let Some(started_network_proxy) = self.services.network_proxies.get(scope).await {
let proxy = started_network_proxy.proxy();
match amendment.action {
NetworkPolicyRuleAction::Allow => proxy

View File

@@ -11,9 +11,11 @@ use crate::exec::ExecToolCallOutput;
use crate::function_tool::FunctionCallError;
use crate::mcp_connection_manager::ToolInfo;
use crate::models_manager::model_info;
use crate::network_proxy_registry::NetworkProxyRegistry;
use crate::shell::default_user_shell;
use crate::tools::format_exec_output_str;
use codex_network_proxy::NetworkProxyAuditMetadata;
use codex_protocol::ThreadId;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
@@ -2152,7 +2154,14 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
mcp_manager,
file_watcher,
agent_control,
network_proxy: None,
network_proxies: NetworkProxyRegistry::new(
None,
false,
NetworkProxyAuditMetadata::default(),
None,
),
network_policy_decider_session: None,
network_blocked_request_observer: None,
network_approval: Arc::clone(&network_approval),
state_db: None,
model_client: ModelClient::new(
@@ -2794,7 +2803,14 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
mcp_manager,
file_watcher,
agent_control,
network_proxy: None,
network_proxies: NetworkProxyRegistry::new(
None,
false,
NetworkProxyAuditMetadata::default(),
None,
),
network_policy_decider_session: None,
network_blocked_request_observer: None,
network_approval: Arc::clone(&network_approval),
state_db: None,
model_client: ModelClient::new(

View File

@@ -4129,7 +4129,6 @@ fn test_precedence_fixture_with_o3_profile() -> std::io::Result<()> {
experimental_realtime_start_instructions: None,
experimental_realtime_ws_base_url: None,
experimental_realtime_ws_model: None,
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
experimental_realtime_ws_backend_prompt: None,
experimental_realtime_ws_startup_context: None,
base_instructions: None,
@@ -4266,7 +4265,6 @@ fn test_precedence_fixture_with_gpt3_profile() -> std::io::Result<()> {
experimental_realtime_start_instructions: None,
experimental_realtime_ws_base_url: None,
experimental_realtime_ws_model: None,
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
experimental_realtime_ws_backend_prompt: None,
experimental_realtime_ws_startup_context: None,
base_instructions: None,
@@ -4401,7 +4399,6 @@ fn test_precedence_fixture_with_zdr_profile() -> std::io::Result<()> {
experimental_realtime_start_instructions: None,
experimental_realtime_ws_base_url: None,
experimental_realtime_ws_model: None,
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
experimental_realtime_ws_backend_prompt: None,
experimental_realtime_ws_startup_context: None,
base_instructions: None,
@@ -4522,7 +4519,6 @@ fn test_precedence_fixture_with_gpt5_profile() -> std::io::Result<()> {
experimental_realtime_start_instructions: None,
experimental_realtime_ws_base_url: None,
experimental_realtime_ws_model: None,
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
experimental_realtime_ws_backend_prompt: None,
experimental_realtime_ws_startup_context: None,
base_instructions: None,
@@ -5570,34 +5566,6 @@ experimental_realtime_ws_model = "realtime-test-model"
Ok(())
}
#[test]
// Verifies that `experimental_realtime_ws_mode` round-trips from config TOML:
// it deserializes into `ConfigToml` as `Some(Transcription)` and survives the
// merge into the final `Config` as the concrete mode.
fn experimental_realtime_ws_mode_loads_from_config_toml() -> std::io::Result<()> {
    // Raw TOML with only the mode key set.
    let cfg: ConfigToml = toml::from_str(
        r#"
experimental_realtime_ws_mode = "transcription"
"#,
    )
    .expect("TOML deserialization should succeed");
    assert_eq!(
        cfg.experimental_realtime_ws_mode,
        Some(RealtimeWsMode::Transcription)
    );
    // Load through the full config pipeline with defaults in a scratch home.
    let codex_home = TempDir::new()?;
    let config = Config::load_from_base_config_with_overrides(
        cfg,
        ConfigOverrides::default(),
        codex_home.path().to_path_buf(),
    )?;
    assert_eq!(
        config.experimental_realtime_ws_mode,
        RealtimeWsMode::Transcription
    );
    Ok(())
}
#[test]
fn realtime_audio_loads_from_config_toml() -> std::io::Result<()> {
let cfg: ConfigToml = toml::from_str(

View File

@@ -463,9 +463,6 @@ pub struct Config {
/// Experimental / do not use. Selects the realtime websocket model/snapshot
/// used for the `Op::RealtimeConversation` connection.
pub experimental_realtime_ws_model: Option<String>,
/// Experimental / do not use. Selects the realtime websocket intent mode.
/// `conversational` is speech-to-speech while `transcription` is transcript-only.
pub experimental_realtime_ws_mode: RealtimeWsMode,
/// Experimental / do not use. Overrides only the realtime conversation
/// websocket transport instructions (the `Op::RealtimeConversation`
/// `/ws` session.update instructions) without changing normal prompts.
@@ -1241,9 +1238,6 @@ pub struct ConfigToml {
/// Experimental / do not use. Selects the realtime websocket model/snapshot
/// used for the `Op::RealtimeConversation` connection.
pub experimental_realtime_ws_model: Option<String>,
/// Experimental / do not use. Selects the realtime websocket intent mode.
/// `conversational` is speech-to-speech while `transcription` is transcript-only.
pub experimental_realtime_ws_mode: Option<RealtimeWsMode>,
/// Experimental / do not use. Overrides only the realtime conversation
/// websocket transport instructions (the `Op::RealtimeConversation`
/// `/ws` session.update instructions) without changing normal prompts.
@@ -1389,14 +1383,6 @@ pub struct RealtimeAudioConfig {
pub speaker: Option<String>,
}
/// Experimental / do not use. Realtime websocket intent mode:
/// `conversational` is speech-to-speech while `transcription` is
/// transcript-only. Serialized in snake_case; defaults to `Conversational`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum RealtimeWsMode {
    /// Speech-to-speech realtime session (the default).
    #[default]
    Conversational,
    /// Transcript-only realtime session.
    Transcription,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq, JsonSchema)]
#[schemars(deny_unknown_fields)]
pub struct RealtimeAudioToml {
@@ -2476,7 +2462,6 @@ impl Config {
}),
experimental_realtime_ws_base_url: cfg.experimental_realtime_ws_base_url,
experimental_realtime_ws_model: cfg.experimental_realtime_ws_model,
experimental_realtime_ws_mode: cfg.experimental_realtime_ws_mode.unwrap_or_default(),
experimental_realtime_ws_backend_prompt: cfg.experimental_realtime_ws_backend_prompt,
experimental_realtime_ws_startup_context: cfg.experimental_realtime_ws_startup_context,
experimental_realtime_start_instructions: cfg.experimental_realtime_start_instructions,

View File

@@ -1,4 +1,5 @@
use crate::config_loader::NetworkConstraints;
use crate::skills::model::SkillManagedNetworkOverride;
use async_trait::async_trait;
use codex_network_proxy::BlockedRequestObserver;
use codex_network_proxy::ConfigReloader;
@@ -81,6 +82,28 @@ impl NetworkProxySpec {
self.config.network.enable_socks5
}
/// Returns a copy of this spec with a skill's managed-network override
/// applied.
///
/// For each domain list the override provides, the proxy config's list is
/// replaced outright; the matching constraint is replaced only when a
/// constraint was already present (`None` constraints stay `None`).
/// Lists the override leaves unset are untouched.
pub(crate) fn with_skill_managed_network_override(
    &self,
    managed_network_override: &SkillManagedNetworkOverride,
) -> Self {
    let mut spec = self.clone();
    if let Some(allowed_domains) = &managed_network_override.allowed_domains {
        // Clone into the constraint only when one is actually set, so the
        // no-constraint path pays for a single clone instead of two.
        if spec.constraints.allowed_domains.is_some() {
            spec.constraints.allowed_domains = Some(allowed_domains.clone());
        }
        spec.config.network.allowed_domains = allowed_domains.clone();
    }
    if let Some(denied_domains) = &managed_network_override.denied_domains {
        if spec.constraints.denied_domains.is_some() {
            spec.constraints.denied_domains = Some(denied_domains.clone());
        }
        spec.config.network.denied_domains = denied_domains.clone();
    }
    spec
}
pub(crate) fn from_config_and_constraints(
config: NetworkProxyConfig,
requirements: Option<NetworkConstraints>,

View File

@@ -1,6 +1,94 @@
use super::*;
use pretty_assertions::assert_eq;
/// Applying a skill override that only sets `allowed_domains` must replace
/// both the proxy config's allow list and the existing allow constraint,
/// leaving every other field (including the deny list) untouched.
#[test]
fn skill_managed_network_override_replaces_allowed_domains_and_keeps_other_settings() {
    // Fully-populated config so the equality assertion below proves that
    // nothing besides the allow lists changed.
    let mut config = NetworkProxyConfig::default();
    config.network.enabled = true;
    config.network.proxy_url = "http://127.0.0.1:4128".to_string();
    config.network.socks_url = "socks5://127.0.0.1:5128".to_string();
    config.network.enable_socks5 = true;
    config.network.enable_socks5_udp = true;
    config.network.allowed_domains = vec!["default.example.com".to_string()];
    config.network.denied_domains = vec!["blocked.example.com".to_string()];
    config.network.allow_upstream_proxy = true;
    config.network.dangerously_allow_all_unix_sockets = false;
    config.network.dangerously_allow_non_loopback_proxy = false;
    config.network.mode = codex_network_proxy::NetworkMode::Full;
    config.network.allow_unix_sockets = vec!["/tmp/default.sock".to_string()];
    config.network.allow_local_binding = true;
    config.network.mitm = false;
    // Constraints are all `Some`, so the override is expected to update the
    // allow constraint in place.
    let spec = NetworkProxySpec {
        config,
        constraints: NetworkProxyConstraints {
            allowed_domains: Some(vec!["default.example.com".to_string()]),
            denied_domains: Some(vec!["blocked.example.com".to_string()]),
            allowlist_expansion_enabled: Some(true),
            denylist_expansion_enabled: Some(false),
            allow_upstream_proxy: Some(true),
            allow_unix_sockets: Some(vec!["/tmp/default.sock".to_string()]),
            allow_local_binding: Some(true),
            ..NetworkProxyConstraints::default()
        },
        hard_deny_allowlist_misses: true,
    };
    // Override only the allow list; denied domains are left alone.
    let managed_network_override = crate::skills::model::SkillManagedNetworkOverride {
        allowed_domains: Some(vec!["skill.example.com".to_string()]),
        denied_domains: None,
    };
    let overridden = spec.with_skill_managed_network_override(&managed_network_override);
    // Expected result differs from the input only in the two allow lists.
    let mut expected = spec.clone();
    expected.config.network.allowed_domains = vec!["skill.example.com".to_string()];
    expected.constraints.allowed_domains = Some(vec!["skill.example.com".to_string()]);
    assert_eq!(overridden, expected);
}
/// Applying a skill override that only sets `denied_domains` must replace
/// both the proxy config's deny list and the existing deny constraint,
/// while the default allow lists stay untouched.
#[test]
fn skill_managed_network_override_replaces_denied_domains_and_keeps_default_allowed_domains() {
    // Fully-populated config so the equality assertion below proves that
    // nothing besides the deny lists changed.
    let mut config = NetworkProxyConfig::default();
    config.network.enabled = true;
    config.network.proxy_url = "http://127.0.0.1:4128".to_string();
    config.network.socks_url = "socks5://127.0.0.1:5128".to_string();
    config.network.enable_socks5 = true;
    config.network.enable_socks5_udp = true;
    config.network.allowed_domains = vec!["default.example.com".to_string()];
    config.network.denied_domains = vec!["blocked.example.com".to_string()];
    config.network.allow_upstream_proxy = true;
    config.network.dangerously_allow_all_unix_sockets = false;
    config.network.dangerously_allow_non_loopback_proxy = false;
    config.network.mode = codex_network_proxy::NetworkMode::Full;
    config.network.allow_unix_sockets = vec!["/tmp/default.sock".to_string()];
    config.network.allow_local_binding = true;
    config.network.mitm = false;
    let spec = NetworkProxySpec {
        config,
        constraints: NetworkProxyConstraints {
            allowed_domains: Some(vec!["default.example.com".to_string()]),
            denied_domains: Some(vec!["blocked.example.com".to_string()]),
            allowlist_expansion_enabled: Some(true),
            denylist_expansion_enabled: Some(false),
            allow_upstream_proxy: Some(true),
            allow_unix_sockets: Some(vec!["/tmp/default.sock".to_string()]),
            allow_local_binding: Some(true),
            ..NetworkProxyConstraints::default()
        },
        hard_deny_allowlist_misses: false,
    };
    // Override only the deny list; allowed domains are left alone.
    let managed_network_override = crate::skills::model::SkillManagedNetworkOverride {
        allowed_domains: None,
        denied_domains: Some(vec!["skill-blocked.example.com".to_string()]),
    };
    let overridden = spec.with_skill_managed_network_override(&managed_network_override);
    // Expected result differs from the input only in the two deny lists.
    let mut expected = spec.clone();
    expected.config.network.denied_domains = vec!["skill-blocked.example.com".to_string()];
    expected.constraints.denied_domains = Some(vec!["skill-blocked.example.com".to_string()]);
    assert_eq!(overridden, expected);
}
#[test]
fn build_state_with_audit_metadata_threads_metadata_to_state() {
let spec = NetworkProxySpec {

View File

@@ -39,6 +39,7 @@ use crate::config::Constrained;
use crate::config::NetworkProxySpec;
use crate::event_mapping::is_contextual_user_message_content;
use crate::features::Feature;
use crate::network_proxy_registry::NetworkProxyScope;
use crate::protocol::Op;
use crate::protocol::SandboxPolicy;
use crate::truncate::approx_bytes_for_tokens;
@@ -550,7 +551,12 @@ async fn run_guardian_subagent(
schema: Value,
cancel_token: CancellationToken,
) -> anyhow::Result<GuardianAssessment> {
let live_network_config = match session.services.network_proxy.as_ref() {
let live_network_config = match session
.services
.network_proxies
.get(&NetworkProxyScope::SessionDefault)
.await
{
Some(network_proxy) => Some(network_proxy.proxy().current_cfg().await?),
None => None,
};

View File

@@ -51,6 +51,7 @@ mod mcp_tool_approval_templates;
pub mod models_manager;
mod network_policy_decision;
pub mod network_proxy_loader;
mod network_proxy_registry;
mod original_image_detail;
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD;

View File

@@ -103,7 +103,6 @@ fn shell_command_for_invocation(invocation: &ToolInvocation) -> Option<(Vec<Stri
&params,
invocation.session.user_shell(),
invocation.turn.tools_config.allow_login_shell,
invocation.turn.tools_config.unified_exec_backend,
)
.ok()?;
Some((command, invocation.turn.resolve_path(params.workdir)))

View File

@@ -0,0 +1,77 @@
use crate::config::NetworkProxySpec;
use crate::config::StartedNetworkProxy;
use anyhow::Result;
use codex_network_proxy::NetworkProxyAuditMetadata;
use std::collections::HashMap;
use std::future::Future;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Key identifying which network proxy a caller wants from the registry.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub(crate) enum NetworkProxyScope {
    /// The proxy started for the session as a whole.
    SessionDefault,
    /// A proxy scoped to a single skill, keyed by the path to its SKILLS.md.
    Skill { path_to_skills_md: PathBuf },
}
/// Lazily-started collection of network proxies, at most one per
/// [`NetworkProxyScope`].
pub(crate) struct NetworkProxyRegistry {
    /// Template used to start new proxies; `None` means proxying is
    /// disabled and `get_or_start` returns `Ok(None)`.
    spec: Option<NetworkProxySpec>,
    /// Forwarded verbatim to the `start` callback in `get_or_start`.
    managed_network_requirements_enabled: bool,
    /// Cloned into each proxy start; audit metadata shared by all scopes.
    audit_metadata: NetworkProxyAuditMetadata,
    /// Started proxies by scope. Guarded by an async mutex so concurrent
    /// `get_or_start` calls cannot start duplicates for the same scope.
    proxies: Mutex<HashMap<NetworkProxyScope, Arc<StartedNetworkProxy>>>,
}
impl NetworkProxyRegistry {
    /// Creates a registry, seeding the `SessionDefault` slot when a proxy
    /// was already started for the session.
    pub(crate) fn new(
        spec: Option<NetworkProxySpec>,
        managed_network_requirements_enabled: bool,
        audit_metadata: NetworkProxyAuditMetadata,
        default_proxy: Option<StartedNetworkProxy>,
    ) -> Self {
        let mut proxies = HashMap::new();
        if let Some(default_proxy) = default_proxy {
            proxies.insert(NetworkProxyScope::SessionDefault, Arc::new(default_proxy));
        }
        Self {
            spec,
            managed_network_requirements_enabled,
            audit_metadata,
            proxies: Mutex::new(proxies),
        }
    }

    /// Returns the already-started proxy for `scope`, if any. Never starts
    /// one; use `get_or_start` for that.
    pub(crate) async fn get(&self, scope: &NetworkProxyScope) -> Option<Arc<StartedNetworkProxy>> {
        self.proxies.lock().await.get(scope).cloned()
    }

    /// Returns the proxy for `scope`, invoking `start` to create it on
    /// first use.
    ///
    /// Returns `Ok(None)` when the registry has no spec (proxying disabled).
    /// NOTE: the registry lock is held across the `start` future, which
    /// serializes all proxy startups and guarantees at most one proxy per
    /// scope — but it also blocks every other registry call while `start`
    /// is in flight.
    pub(crate) async fn get_or_start<F, Fut>(
        &self,
        scope: NetworkProxyScope,
        start: F,
    ) -> Result<Option<Arc<StartedNetworkProxy>>>
    where
        F: FnOnce(NetworkProxySpec, bool, NetworkProxyAuditMetadata) -> Fut,
        Fut: Future<Output = std::io::Result<StartedNetworkProxy>>,
    {
        let mut proxies = self.proxies.lock().await;
        if let Some(existing) = proxies.get(&scope).cloned() {
            return Ok(Some(existing));
        }
        let Some(spec) = self.spec.clone() else {
            return Ok(None);
        };
        let started = Arc::new(
            start(
                spec,
                self.managed_network_requirements_enabled,
                self.audit_metadata.clone(),
            )
            .await?,
        );
        proxies.insert(scope, Arc::clone(&started));
        Ok(Some(started))
    }
}

View File

@@ -15,7 +15,6 @@ use codex_api::RealtimeAudioFrame;
use codex_api::RealtimeEvent;
use codex_api::RealtimeEventParser;
use codex_api::RealtimeSessionConfig;
use codex_api::RealtimeSessionMode;
use codex_api::RealtimeWebsocketClient;
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketEvents;
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketWriter;
@@ -117,7 +116,10 @@ impl RealtimeConversationManager {
&self,
api_provider: ApiProvider,
extra_headers: Option<HeaderMap>,
session_config: RealtimeSessionConfig,
prompt: String,
model: Option<String>,
session_id: Option<String>,
event_parser: RealtimeEventParser,
) -> CodexResult<(Receiver<RealtimeEvent>, Arc<AtomicBool>)> {
let previous_state = {
let mut guard = self.state.lock().await;
@@ -129,6 +131,12 @@ impl RealtimeConversationManager {
let _ = state.task.await;
}
let session_config = RealtimeSessionConfig {
instructions: prompt,
model,
session_id,
event_parser,
};
let client = RealtimeWebsocketClient::new(api_provider);
let connection = client
.connect(
@@ -299,26 +307,23 @@ pub(crate) async fn handle_start(
} else {
RealtimeEventParser::V1
};
let session_mode = match config.experimental_realtime_ws_mode {
crate::config::RealtimeWsMode::Conversational => RealtimeSessionMode::Conversational,
crate::config::RealtimeWsMode::Transcription => RealtimeSessionMode::Transcription,
};
let requested_session_id = params
.session_id
.or_else(|| Some(sess.conversation_id.to_string()));
let session_config = RealtimeSessionConfig {
instructions: prompt,
model,
session_id: requested_session_id.clone(),
event_parser,
session_mode,
};
let extra_headers =
realtime_request_headers(requested_session_id.as_deref(), realtime_api_key.as_str())?;
info!("starting realtime conversation");
let (events_rx, realtime_active) = match sess
.conversation
.start(api_provider, extra_headers, session_config)
.start(
api_provider,
extra_headers,
prompt,
model,
requested_session_id.clone(),
event_parser,
)
.await
{
Ok(events_rx) => events_rx,

View File

@@ -6,12 +6,13 @@ use crate::RolloutRecorder;
use crate::agent::AgentControl;
use crate::analytics_client::AnalyticsEventsClient;
use crate::client::ModelClient;
use crate::config::StartedNetworkProxy;
use crate::codex::Session;
use crate::exec_policy::ExecPolicyManager;
use crate::file_watcher::FileWatcher;
use crate::mcp::McpManager;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::models_manager::manager::ModelsManager;
use crate::network_proxy_registry::NetworkProxyRegistry;
use crate::plugins::PluginsManager;
use crate::skills::SkillsManager;
use crate::state_db::StateDbHandle;
@@ -21,9 +22,11 @@ use crate::tools::runtimes::ExecveSessionApproval;
use crate::tools::sandboxing::ApprovalStore;
use crate::unified_exec::UnifiedExecProcessManager;
use codex_hooks::Hooks;
use codex_network_proxy::BlockedRequestObserver;
use codex_otel::SessionTelemetry;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::path::PathBuf;
use std::sync::Weak;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::watch;
@@ -55,7 +58,9 @@ pub(crate) struct SessionServices {
pub(crate) mcp_manager: Arc<McpManager>,
pub(crate) file_watcher: Arc<FileWatcher>,
pub(crate) agent_control: AgentControl,
pub(crate) network_proxy: Option<StartedNetworkProxy>,
pub(crate) network_proxies: NetworkProxyRegistry,
pub(crate) network_policy_decider_session: Option<Arc<RwLock<Weak<Session>>>>,
pub(crate) network_blocked_request_observer: Option<Arc<dyn BlockedRequestObserver>>,
pub(crate) network_approval: Arc<NetworkApprovalService>,
pub(crate) state_db: Option<StateDbHandle>,
/// Session-scoped model client shared across turns.

View File

@@ -46,6 +46,7 @@ use codex_protocol::protocol::RolloutItem;
use codex_protocol::user_input::UserInput;
use crate::features::Feature;
use crate::network_proxy_registry::NetworkProxyScope;
pub(crate) use compact::CompactTask;
pub(crate) use ghost_snapshot::GhostSnapshotTask;
pub(crate) use regular::RegularTask;
@@ -292,7 +293,12 @@ impl Session {
"false"
},
);
let network_proxy_active = match self.services.network_proxy.as_ref() {
let network_proxy_active = match self
.services
.network_proxies
.get(&NetworkProxyScope::SessionDefault)
.await
{
Some(started_network_proxy) => {
match started_network_proxy.proxy().current_cfg().await {
Ok(config) => config.network.enabled,

View File

@@ -105,11 +105,842 @@ where
})
}
pub mod close_agent;
mod resume_agent;
mod send_input;
mod spawn;
pub(crate) mod wait;
/// `spawn_agent` tool: spawns a child agent thread with optional role,
/// model, and reasoning-effort overrides, emitting begin/end collab events
/// around the attempt.
mod spawn {
    use super::*;
    use crate::agent::control::SpawnAgentOptions;
    use crate::agent::role::DEFAULT_ROLE_NAME;
    use crate::agent::role::apply_role_to_config;
    use crate::agent::exceeds_thread_spawn_depth_limit;
    use crate::agent::next_thread_spawn_depth;

    pub(crate) struct Handler;

    #[async_trait]
    impl ToolHandler for Handler {
        type Output = SpawnAgentResult;

        fn kind(&self) -> ToolKind {
            ToolKind::Function
        }

        fn matches_kind(&self, payload: &ToolPayload) -> bool {
            matches!(payload, ToolPayload::Function { .. })
        }

        async fn handle(
            &self,
            invocation: ToolInvocation,
        ) -> Result<Self::Output, FunctionCallError> {
            let ToolInvocation {
                session,
                turn,
                payload,
                call_id,
                ..
            } = invocation;
            let arguments = function_arguments(payload)?;
            let args: SpawnAgentArgs = parse_arguments(&arguments)?;
            // Empty or whitespace-only role names are treated as "no role".
            let role_name = args
                .agent_type
                .as_deref()
                .map(str::trim)
                .filter(|role| !role.is_empty());
            let input_items = parse_collab_input(args.message, args.items)?;
            let prompt = input_preview(&input_items);
            let session_source = turn.session_source.clone();
            let child_depth = next_thread_spawn_depth(&session_source);
            let max_depth = turn.config.agent_max_depth;
            // Refuse to spawn past the configured nesting depth; this check
            // happens before the begin event so no event pair is emitted.
            if exceeds_thread_spawn_depth_limit(child_depth, max_depth) {
                return Err(FunctionCallError::RespondToModel(
                    "Agent depth limit reached. Solve the task yourself.".to_string(),
                ));
            }
            session
                .send_event(
                    &turn,
                    CollabAgentSpawnBeginEvent {
                        call_id: call_id.clone(),
                        sender_thread_id: session.conversation_id,
                        prompt: prompt.clone(),
                        model: args.model.clone().unwrap_or_default(),
                        reasoning_effort: args.reasoning_effort.unwrap_or_default(),
                    }
                    .into(),
                )
                .await;
            // Build the child config, then layer overrides in order:
            // requested model/effort, role, runtime overrides, spawn overrides.
            let mut config =
                build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?;
            apply_requested_spawn_agent_model_overrides(
                &session,
                turn.as_ref(),
                &mut config,
                args.model.as_deref(),
                args.reasoning_effort,
            )
            .await?;
            apply_role_to_config(&mut config, role_name)
                .await
                .map_err(FunctionCallError::RespondToModel)?;
            apply_spawn_agent_runtime_overrides(&mut config, turn.as_ref())?;
            apply_spawn_agent_overrides(&mut config, child_depth);
            let result = session
                .services
                .agent_control
                .spawn_agent_with_options(
                    config,
                    input_items,
                    Some(thread_spawn_source(
                        session.conversation_id,
                        child_depth,
                        role_name,
                    )),
                    SpawnAgentOptions {
                        fork_parent_spawn_call_id: args.fork_context.then(|| call_id.clone()),
                    },
                )
                .await
                .map_err(collab_spawn_error);
            // The end event is emitted whether or not the spawn succeeded;
            // on failure it carries no thread id and `NotFound` status.
            let (new_thread_id, status) = match &result {
                Ok(thread_id) => (
                    Some(*thread_id),
                    session.services.agent_control.get_status(*thread_id).await,
                ),
                Err(_) => (None, AgentStatus::NotFound),
            };
            let (new_agent_nickname, new_agent_role) = match new_thread_id {
                Some(thread_id) => session
                    .services
                    .agent_control
                    .get_agent_nickname_and_role(thread_id)
                    .await
                    .unwrap_or((None, None)),
                None => (None, None),
            };
            let nickname = new_agent_nickname.clone();
            session
                .send_event(
                    &turn,
                    CollabAgentSpawnEndEvent {
                        call_id,
                        sender_thread_id: session.conversation_id,
                        new_thread_id,
                        new_agent_nickname,
                        new_agent_role,
                        prompt,
                        model: args.model.clone().unwrap_or_default(),
                        reasoning_effort: args.reasoning_effort.unwrap_or_default(),
                        status,
                    }
                    .into(),
                )
                .await;
            // Propagate a spawn error only after the end event went out.
            let new_thread_id = result?;
            let role_tag = role_name.unwrap_or(DEFAULT_ROLE_NAME);
            turn.session_telemetry
                .counter("codex.multi_agent.spawn", 1, &[("role", role_tag)]);
            Ok(SpawnAgentResult {
                agent_id: new_thread_id.to_string(),
                nickname,
            })
        }
    }

    /// Arguments parsed from the `spawn_agent` tool call.
    #[derive(Debug, Deserialize)]
    struct SpawnAgentArgs {
        message: Option<String>,
        items: Option<Vec<UserInput>>,
        agent_type: Option<String>,
        model: Option<String>,
        reasoning_effort: Option<ReasoningEffort>,
        #[serde(default)]
        fork_context: bool,
    }

    /// Tool output: the spawned agent's thread id plus its nickname, if any.
    #[derive(Debug, Serialize)]
    pub(crate) struct SpawnAgentResult {
        agent_id: String,
        nickname: Option<String>,
    }

    impl ToolOutput for SpawnAgentResult {
        fn log_preview(&self) -> String {
            tool_output_json_text(self, "spawn_agent")
        }

        fn success_for_logging(&self) -> bool {
            true
        }

        fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
            tool_output_response_item(call_id, payload, self, Some(true), "spawn_agent")
        }

        fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
            tool_output_code_mode_result(self, "spawn_agent")
        }
    }
}
/// `send_input` tool: forwards user input to another agent thread,
/// optionally interrupting it first, wrapped in begin/end collab events.
mod send_input {
    use super::*;

    pub(crate) struct Handler;

    #[async_trait]
    impl ToolHandler for Handler {
        type Output = SendInputResult;

        fn kind(&self) -> ToolKind {
            ToolKind::Function
        }

        fn matches_kind(&self, payload: &ToolPayload) -> bool {
            matches!(payload, ToolPayload::Function { .. })
        }

        async fn handle(
            &self,
            invocation: ToolInvocation,
        ) -> Result<Self::Output, FunctionCallError> {
            let ToolInvocation {
                session,
                turn,
                payload,
                call_id,
                ..
            } = invocation;
            let arguments = function_arguments(payload)?;
            let args: SendInputArgs = parse_arguments(&arguments)?;
            let receiver_thread_id = agent_id(&args.id)?;
            let input_items = parse_collab_input(args.message, args.items)?;
            let prompt = input_preview(&input_items);
            let (receiver_agent_nickname, receiver_agent_role) = session
                .services
                .agent_control
                .get_agent_nickname_and_role(receiver_thread_id)
                .await
                .unwrap_or((None, None));
            // An interrupt failure aborts before the begin event is emitted.
            if args.interrupt {
                session
                    .services
                    .agent_control
                    .interrupt_agent(receiver_thread_id)
                    .await
                    .map_err(|err| collab_agent_error(receiver_thread_id, err))?;
            }
            session
                .send_event(
                    &turn,
                    CollabAgentInteractionBeginEvent {
                        call_id: call_id.clone(),
                        sender_thread_id: session.conversation_id,
                        receiver_thread_id,
                        prompt: prompt.clone(),
                    }
                    .into(),
                )
                .await;
            // Keep the error until after the end event so begin/end pair up.
            let result = session
                .services
                .agent_control
                .send_input(receiver_thread_id, input_items)
                .await
                .map_err(|err| collab_agent_error(receiver_thread_id, err));
            let status = session
                .services
                .agent_control
                .get_status(receiver_thread_id)
                .await;
            session
                .send_event(
                    &turn,
                    CollabAgentInteractionEndEvent {
                        call_id,
                        sender_thread_id: session.conversation_id,
                        receiver_thread_id,
                        receiver_agent_nickname,
                        receiver_agent_role,
                        prompt,
                        status,
                    }
                    .into(),
                )
                .await;
            let submission_id = result?;
            Ok(SendInputResult { submission_id })
        }
    }

    /// Arguments parsed from the `send_input` tool call.
    #[derive(Debug, Deserialize)]
    struct SendInputArgs {
        id: String,
        message: Option<String>,
        items: Option<Vec<UserInput>>,
        #[serde(default)]
        interrupt: bool,
    }

    /// Tool output: the id of the submission queued on the receiver.
    #[derive(Debug, Serialize)]
    pub(crate) struct SendInputResult {
        submission_id: String,
    }

    impl ToolOutput for SendInputResult {
        fn log_preview(&self) -> String {
            tool_output_json_text(self, "send_input")
        }

        fn success_for_logging(&self) -> bool {
            true
        }

        fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
            tool_output_response_item(call_id, payload, self, Some(true), "send_input")
        }

        fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
            tool_output_code_mode_result(self, "send_input")
        }
    }
}
/// `resume_agent` tool: reports an agent's status and, when the thread is
/// not found, attempts to resume it from its recorded rollout.
mod resume_agent {
    use super::*;
    use crate::agent::next_thread_spawn_depth;

    pub(crate) struct Handler;

    #[async_trait]
    impl ToolHandler for Handler {
        type Output = ResumeAgentResult;

        fn kind(&self) -> ToolKind {
            ToolKind::Function
        }

        fn matches_kind(&self, payload: &ToolPayload) -> bool {
            matches!(payload, ToolPayload::Function { .. })
        }

        async fn handle(
            &self,
            invocation: ToolInvocation,
        ) -> Result<Self::Output, FunctionCallError> {
            let ToolInvocation {
                session,
                turn,
                payload,
                call_id,
                ..
            } = invocation;
            let arguments = function_arguments(payload)?;
            let args: ResumeAgentArgs = parse_arguments(&arguments)?;
            let receiver_thread_id = agent_id(&args.id)?;
            let (receiver_agent_nickname, receiver_agent_role) = session
                .services
                .agent_control
                .get_agent_nickname_and_role(receiver_thread_id)
                .await
                .unwrap_or((None, None));
            let child_depth = next_thread_spawn_depth(&turn.session_source);
            let max_depth = turn.config.agent_max_depth;
            // Resuming counts against the same depth limit as spawning.
            if exceeds_thread_spawn_depth_limit(child_depth, max_depth) {
                return Err(FunctionCallError::RespondToModel(
                    "Agent depth limit reached. Solve the task yourself.".to_string(),
                ));
            }
            session
                .send_event(
                    &turn,
                    CollabResumeBeginEvent {
                        call_id: call_id.clone(),
                        sender_thread_id: session.conversation_id,
                        receiver_thread_id,
                        receiver_agent_nickname: receiver_agent_nickname.clone(),
                        receiver_agent_role: receiver_agent_role.clone(),
                    }
                    .into(),
                )
                .await;
            let mut status = session
                .services
                .agent_control
                .get_status(receiver_thread_id)
                .await;
            // Only a NotFound thread triggers a rollout resume; the error
            // (if any) is deferred until after the end event is emitted.
            let error = if matches!(status, AgentStatus::NotFound) {
                match try_resume_closed_agent(&session, &turn, receiver_thread_id, child_depth)
                    .await
                {
                    Ok(resumed_status) => {
                        status = resumed_status;
                        None
                    }
                    Err(err) => {
                        status = session
                            .services
                            .agent_control
                            .get_status(receiver_thread_id)
                            .await;
                        Some(err)
                    }
                }
            } else {
                None
            };
            // The resume may have restored nickname/role; re-read them,
            // falling back to the values fetched before the attempt.
            let (receiver_agent_nickname, receiver_agent_role) = session
                .services
                .agent_control
                .get_agent_nickname_and_role(receiver_thread_id)
                .await
                .unwrap_or((receiver_agent_nickname, receiver_agent_role));
            session
                .send_event(
                    &turn,
                    CollabResumeEndEvent {
                        call_id,
                        sender_thread_id: session.conversation_id,
                        receiver_thread_id,
                        receiver_agent_nickname,
                        receiver_agent_role,
                        status: status.clone(),
                    }
                    .into(),
                )
                .await;
            if let Some(err) = error {
                return Err(err);
            }
            turn.session_telemetry
                .counter("codex.multi_agent.resume", 1, &[]);
            Ok(ResumeAgentResult { status })
        }
    }

    /// Arguments parsed from the `resume_agent` tool call.
    #[derive(Debug, Deserialize)]
    struct ResumeAgentArgs {
        id: String,
    }

    /// Tool output: the agent's status after the resume attempt.
    #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
    pub(crate) struct ResumeAgentResult {
        pub(crate) status: AgentStatus,
    }

    impl ToolOutput for ResumeAgentResult {
        fn log_preview(&self) -> String {
            tool_output_json_text(self, "resume_agent")
        }

        fn success_for_logging(&self) -> bool {
            true
        }

        fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
            tool_output_response_item(call_id, payload, self, Some(true), "resume_agent")
        }

        fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
            tool_output_code_mode_result(self, "resume_agent")
        }
    }

    /// Rebuilds a closed agent from its rollout and returns its new status.
    async fn try_resume_closed_agent(
        session: &Arc<Session>,
        turn: &Arc<TurnContext>,
        receiver_thread_id: ThreadId,
        child_depth: i32,
    ) -> Result<AgentStatus, FunctionCallError> {
        let config = build_agent_resume_config(turn.as_ref(), child_depth)?;
        let resumed_thread_id = session
            .services
            .agent_control
            .resume_agent_from_rollout(
                config,
                receiver_thread_id,
                thread_spawn_source(session.conversation_id, child_depth, None),
            )
            .await
            .map_err(|err| collab_agent_error(receiver_thread_id, err))?;
        Ok(session
            .services
            .agent_control
            .get_status(resumed_thread_id)
            .await)
    }
}
/// `wait` tool: blocks until at least one of the given agent threads
/// reaches a final status, or the (clamped) timeout elapses.
pub(crate) mod wait {
    use super::*;
    use crate::agent::status::is_final;
    use futures::FutureExt;
    use futures::StreamExt;
    use futures::stream::FuturesUnordered;
    use std::collections::HashMap;
    use std::time::Duration;
    use tokio::sync::watch::Receiver;
    use tokio::time::Instant;
    use tokio::time::timeout_at;

    pub(crate) struct Handler;

    #[async_trait]
    impl ToolHandler for Handler {
        type Output = WaitResult;

        fn kind(&self) -> ToolKind {
            ToolKind::Function
        }

        fn matches_kind(&self, payload: &ToolPayload) -> bool {
            matches!(payload, ToolPayload::Function { .. })
        }

        async fn handle(
            &self,
            invocation: ToolInvocation,
        ) -> Result<Self::Output, FunctionCallError> {
            let ToolInvocation {
                session,
                turn,
                payload,
                call_id,
                ..
            } = invocation;
            let arguments = function_arguments(payload)?;
            let args: WaitArgs = parse_arguments(&arguments)?;
            if args.ids.is_empty() {
                return Err(FunctionCallError::RespondToModel(
                    "ids must be non-empty".to_owned(),
                ));
            }
            let receiver_thread_ids = args
                .ids
                .iter()
                .map(|id| agent_id(id))
                .collect::<Result<Vec<_>, _>>()?;
            // Resolve nickname/role for each waited-on agent, for the events.
            let mut receiver_agents = Vec::with_capacity(receiver_thread_ids.len());
            for receiver_thread_id in &receiver_thread_ids {
                let (agent_nickname, agent_role) = session
                    .services
                    .agent_control
                    .get_agent_nickname_and_role(*receiver_thread_id)
                    .await
                    .unwrap_or((None, None));
                receiver_agents.push(CollabAgentRef {
                    thread_id: *receiver_thread_id,
                    agent_nickname,
                    agent_role,
                });
            }
            // Non-positive timeouts are rejected; valid ones are clamped to
            // the [MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS] range.
            let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
            let timeout_ms = match timeout_ms {
                ms if ms <= 0 => {
                    return Err(FunctionCallError::RespondToModel(
                        "timeout_ms must be greater than zero".to_owned(),
                    ));
                }
                ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS),
            };
            session
                .send_event(
                    &turn,
                    CollabWaitingBeginEvent {
                        sender_thread_id: session.conversation_id,
                        receiver_thread_ids: receiver_thread_ids.clone(),
                        receiver_agents: receiver_agents.clone(),
                        call_id: call_id.clone(),
                    }
                    .into(),
                )
                .await;
            // Subscribe to each agent's status channel. Agents already in a
            // final state (or unknown threads) are recorded immediately.
            let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
            let mut initial_final_statuses = Vec::new();
            for id in &receiver_thread_ids {
                match session.services.agent_control.subscribe_status(*id).await {
                    Ok(rx) => {
                        let status = rx.borrow().clone();
                        if is_final(&status) {
                            initial_final_statuses.push((*id, status));
                        }
                        status_rxs.push((*id, rx));
                    }
                    Err(CodexErr::ThreadNotFound(_)) => {
                        initial_final_statuses.push((*id, AgentStatus::NotFound));
                    }
                    Err(err) => {
                        // Unexpected subscribe failure: emit the end event
                        // for symmetry, then surface the error.
                        let mut statuses = HashMap::with_capacity(1);
                        statuses.insert(*id, session.services.agent_control.get_status(*id).await);
                        session
                            .send_event(
                                &turn,
                                CollabWaitingEndEvent {
                                    sender_thread_id: session.conversation_id,
                                    call_id: call_id.clone(),
                                    agent_statuses: build_wait_agent_statuses(
                                        &statuses,
                                        &receiver_agents,
                                    ),
                                    statuses,
                                }
                                .into(),
                            )
                            .await;
                        return Err(collab_agent_error(*id, err));
                    }
                }
            }
            let statuses = if !initial_final_statuses.is_empty() {
                // At least one agent was already final: return right away.
                initial_final_statuses
            } else {
                let mut futures = FuturesUnordered::new();
                for (id, rx) in status_rxs.into_iter() {
                    let session = session.clone();
                    futures.push(wait_for_final_status(session, id, rx));
                }
                let mut results = Vec::new();
                let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
                // Wait (up to the deadline) for the first agent to finish.
                // `Ok(Some(None))` means one watcher ended without a final
                // status; keep polling the rest.
                loop {
                    match timeout_at(deadline, futures.next()).await {
                        Ok(Some(Some(result))) => {
                            results.push(result);
                            break;
                        }
                        Ok(Some(None)) => continue,
                        Ok(None) | Err(_) => break,
                    }
                }
                // Drain any other already-completed watchers without
                // waiting, so simultaneous finishers are all reported.
                if !results.is_empty() {
                    loop {
                        match futures.next().now_or_never() {
                            Some(Some(Some(result))) => results.push(result),
                            Some(Some(None)) => continue,
                            Some(None) | None => break,
                        }
                    }
                }
                results
            };
            let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
            let agent_statuses = build_wait_agent_statuses(&statuses_map, &receiver_agents);
            // Empty statuses means the deadline expired before any agent
            // reached a final state.
            let result = WaitResult {
                status: statuses_map.clone(),
                timed_out: statuses.is_empty(),
            };
            session
                .send_event(
                    &turn,
                    CollabWaitingEndEvent {
                        sender_thread_id: session.conversation_id,
                        call_id,
                        agent_statuses,
                        statuses: statuses_map,
                    }
                    .into(),
                )
                .await;
            Ok(result)
        }
    }

    /// Arguments parsed from the `wait` tool call.
    #[derive(Debug, Deserialize)]
    struct WaitArgs {
        ids: Vec<String>,
        timeout_ms: Option<i64>,
    }

    /// Tool output: final statuses observed, and whether the wait timed out.
    #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
    pub(crate) struct WaitResult {
        pub(crate) status: HashMap<ThreadId, AgentStatus>,
        pub(crate) timed_out: bool,
    }

    impl ToolOutput for WaitResult {
        fn log_preview(&self) -> String {
            tool_output_json_text(self, "wait")
        }

        fn success_for_logging(&self) -> bool {
            true
        }

        fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
            tool_output_response_item(call_id, payload, self, None, "wait")
        }

        fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
            tool_output_code_mode_result(self, "wait")
        }
    }

    /// Watches one agent's status channel until it reports a final status.
    /// Returns `None` when the channel closes without ever turning final.
    async fn wait_for_final_status(
        session: Arc<Session>,
        thread_id: ThreadId,
        mut status_rx: Receiver<AgentStatus>,
    ) -> Option<(ThreadId, AgentStatus)> {
        let mut status = status_rx.borrow().clone();
        if is_final(&status) {
            return Some((thread_id, status));
        }
        loop {
            if status_rx.changed().await.is_err() {
                // Sender dropped: fall back to one last direct status read.
                let latest = session.services.agent_control.get_status(thread_id).await;
                return is_final(&latest).then_some((thread_id, latest));
            }
            status = status_rx.borrow().clone();
            if is_final(&status) {
                return Some((thread_id, status));
            }
        }
    }
}
/// `close_agent` tool: shuts down an agent thread (unless it is already
/// shut down), wrapped in begin/end collab events.
pub mod close_agent {
    use super::*;

    pub(crate) struct Handler;

    #[async_trait]
    impl ToolHandler for Handler {
        type Output = CloseAgentResult;

        fn kind(&self) -> ToolKind {
            ToolKind::Function
        }

        fn matches_kind(&self, payload: &ToolPayload) -> bool {
            matches!(payload, ToolPayload::Function { .. })
        }

        async fn handle(
            &self,
            invocation: ToolInvocation,
        ) -> Result<Self::Output, FunctionCallError> {
            let ToolInvocation {
                session,
                turn,
                payload,
                call_id,
                ..
            } = invocation;
            let arguments = function_arguments(payload)?;
            let args: CloseAgentArgs = parse_arguments(&arguments)?;
            let agent_id = agent_id(&args.id)?;
            let (receiver_agent_nickname, receiver_agent_role) = session
                .services
                .agent_control
                .get_agent_nickname_and_role(agent_id)
                .await
                .unwrap_or((None, None));
            session
                .send_event(
                    &turn,
                    CollabCloseBeginEvent {
                        call_id: call_id.clone(),
                        sender_thread_id: session.conversation_id,
                        receiver_thread_id: agent_id,
                    }
                    .into(),
                )
                .await;
            // Read the current status via the watch channel; a subscribe
            // failure still emits the end event before erroring out.
            let status = match session
                .services
                .agent_control
                .subscribe_status(agent_id)
                .await
            {
                Ok(mut status_rx) => status_rx.borrow_and_update().clone(),
                Err(err) => {
                    let status = session.services.agent_control.get_status(agent_id).await;
                    session
                        .send_event(
                            &turn,
                            CollabCloseEndEvent {
                                call_id: call_id.clone(),
                                sender_thread_id: session.conversation_id,
                                receiver_thread_id: agent_id,
                                receiver_agent_nickname: receiver_agent_nickname.clone(),
                                receiver_agent_role: receiver_agent_role.clone(),
                                status,
                            }
                            .into(),
                        )
                        .await;
                    return Err(collab_agent_error(agent_id, err));
                }
            };
            // Shutting down an already-shut-down agent is a no-op success.
            let result = if !matches!(status, AgentStatus::Shutdown) {
                session
                    .services
                    .agent_control
                    .shutdown_agent(agent_id)
                    .await
                    .map_err(|err| collab_agent_error(agent_id, err))
                    .map(|_| ())
            } else {
                Ok(())
            };
            // The end event carries the pre-shutdown status; the shutdown
            // error (if any) is surfaced only after the event is emitted.
            session
                .send_event(
                    &turn,
                    CollabCloseEndEvent {
                        call_id,
                        sender_thread_id: session.conversation_id,
                        receiver_thread_id: agent_id,
                        receiver_agent_nickname,
                        receiver_agent_role,
                        status: status.clone(),
                    }
                    .into(),
                )
                .await;
            result?;
            Ok(CloseAgentResult { status })
        }
    }

    /// Tool output: the agent's status as observed before the shutdown.
    #[derive(Debug, Deserialize, Serialize)]
    pub(crate) struct CloseAgentResult {
        pub(crate) status: AgentStatus,
    }

    impl ToolOutput for CloseAgentResult {
        fn log_preview(&self) -> String {
            tool_output_json_text(self, "close_agent")
        }

        fn success_for_logging(&self) -> bool {
            true
        }

        fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
            tool_output_response_item(call_id, payload, self, Some(true), "close_agent")
        }

        fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
            tool_output_code_mode_result(self, "close_agent")
        }
    }
}
fn agent_id(id: &str) -> Result<ThreadId, FunctionCallError> {
ThreadId::from_string(id)

View File

@@ -1,123 +0,0 @@
use super::*;
/// Handles the `close_agent` tool call (shuts down a collaborator agent).
pub(crate) struct Handler;
#[async_trait]
impl ToolHandler for Handler {
    type Output = CloseAgentResult;
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(payload, ToolPayload::Function { .. })
    }
    /// Emits `CollabCloseBeginEvent`, shuts the target down unless it is
    /// already `Shutdown`, and always emits a matching `CollabCloseEndEvent`
    /// — including on the subscribe-error path — before returning.
    async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            payload,
            call_id,
            ..
        } = invocation;
        let arguments = function_arguments(payload)?;
        let args: CloseAgentArgs = parse_arguments(&arguments)?;
        let agent_id = agent_id(&args.id)?;
        // Nickname/role are best-effort metadata for the events; a failed
        // lookup degrades to (None, None) rather than aborting the close.
        let (receiver_agent_nickname, receiver_agent_role) = session
            .services
            .agent_control
            .get_agent_nickname_and_role(agent_id)
            .await
            .unwrap_or((None, None));
        session
            .send_event(
                &turn,
                CollabCloseBeginEvent {
                    call_id: call_id.clone(),
                    sender_thread_id: session.conversation_id,
                    receiver_thread_id: agent_id,
                }
                .into(),
            )
            .await;
        // Snapshot the current status; on subscribe failure, fall back to a
        // one-shot read so the end event still carries something, then err.
        let status = match session
            .services
            .agent_control
            .subscribe_status(agent_id)
            .await
        {
            Ok(mut status_rx) => status_rx.borrow_and_update().clone(),
            Err(err) => {
                let status = session.services.agent_control.get_status(agent_id).await;
                session
                    .send_event(
                        &turn,
                        CollabCloseEndEvent {
                            call_id: call_id.clone(),
                            sender_thread_id: session.conversation_id,
                            receiver_thread_id: agent_id,
                            receiver_agent_nickname: receiver_agent_nickname.clone(),
                            receiver_agent_role: receiver_agent_role.clone(),
                            status,
                        }
                        .into(),
                    )
                    .await;
                return Err(collab_agent_error(agent_id, err));
            }
        };
        // Only issue a shutdown when not already shut down; defer any
        // shutdown failure so the end event is emitted first.
        let result = if !matches!(status, AgentStatus::Shutdown) {
            session
                .services
                .agent_control
                .shutdown_agent(agent_id)
                .await
                .map_err(|err| collab_agent_error(agent_id, err))
                .map(|_| ())
        } else {
            Ok(())
        };
        // The end event and returned result carry the pre-shutdown status.
        session
            .send_event(
                &turn,
                CollabCloseEndEvent {
                    call_id,
                    sender_thread_id: session.conversation_id,
                    receiver_thread_id: agent_id,
                    receiver_agent_nickname,
                    receiver_agent_role,
                    status: status.clone(),
                }
                .into(),
            )
            .await;
        // Propagate a deferred shutdown failure only after the end event.
        result?;
        Ok(CloseAgentResult { status })
    }
}
/// Result payload of the `close_agent` tool: the agent's status as observed
/// just before the shutdown request was issued.
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct CloseAgentResult {
    // Status snapshot taken before shutdown (see the handler above).
    pub(crate) status: AgentStatus,
}
impl ToolOutput for CloseAgentResult {
    fn log_preview(&self) -> String {
        tool_output_json_text(self, "close_agent")
    }
    // close_agent tool calls are always logged as successful; failures are
    // reported through `FunctionCallError` before a result is produced.
    fn success_for_logging(&self) -> bool {
        true
    }
    fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
        tool_output_response_item(call_id, payload, self, Some(true), "close_agent")
    }
    fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
        tool_output_code_mode_result(self, "close_agent")
    }
}

View File

@@ -1,163 +0,0 @@
use super::*;
use crate::agent::next_thread_spawn_depth;
/// Handles the `resume_agent` tool call (wakes a previously closed agent).
pub(crate) struct Handler;
#[async_trait]
impl ToolHandler for Handler {
    type Output = ResumeAgentResult;
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(payload, ToolPayload::Function { .. })
    }
    /// Emits `CollabResumeBeginEvent`, attempts a rollout-based resume when
    /// the agent is `NotFound`, and always emits `CollabResumeEndEvent`
    /// before propagating any resume error.
    async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            payload,
            call_id,
            ..
        } = invocation;
        let arguments = function_arguments(payload)?;
        let args: ResumeAgentArgs = parse_arguments(&arguments)?;
        let receiver_thread_id = agent_id(&args.id)?;
        // Best-effort metadata for events; lookup failure yields (None, None).
        let (receiver_agent_nickname, receiver_agent_role) = session
            .services
            .agent_control
            .get_agent_nickname_and_role(receiver_thread_id)
            .await
            .unwrap_or((None, None));
        // Resuming counts against the spawn-depth limit like a fresh spawn.
        let child_depth = next_thread_spawn_depth(&turn.session_source);
        let max_depth = turn.config.agent_max_depth;
        if exceeds_thread_spawn_depth_limit(child_depth, max_depth) {
            return Err(FunctionCallError::RespondToModel(
                "Agent depth limit reached. Solve the task yourself.".to_string(),
            ));
        }
        session
            .send_event(
                &turn,
                CollabResumeBeginEvent {
                    call_id: call_id.clone(),
                    sender_thread_id: session.conversation_id,
                    receiver_thread_id,
                    receiver_agent_nickname: receiver_agent_nickname.clone(),
                    receiver_agent_role: receiver_agent_role.clone(),
                }
                .into(),
            )
            .await;
        let mut status = session
            .services
            .agent_control
            .get_status(receiver_thread_id)
            .await;
        // Only a `NotFound` agent is resumed from its rollout. A resume error
        // is deferred so the end event (with a refreshed status) fires first.
        let error = if matches!(status, AgentStatus::NotFound) {
            match try_resume_closed_agent(&session, &turn, receiver_thread_id, child_depth).await {
                Ok(resumed_status) => {
                    status = resumed_status;
                    None
                }
                Err(err) => {
                    status = session
                        .services
                        .agent_control
                        .get_status(receiver_thread_id)
                        .await;
                    Some(err)
                }
            }
        } else {
            None
        };
        // Re-fetch nickname/role: the resume may have (re)registered the
        // agent; fall back to the values fetched earlier.
        let (receiver_agent_nickname, receiver_agent_role) = session
            .services
            .agent_control
            .get_agent_nickname_and_role(receiver_thread_id)
            .await
            .unwrap_or((receiver_agent_nickname, receiver_agent_role));
        session
            .send_event(
                &turn,
                CollabResumeEndEvent {
                    call_id,
                    sender_thread_id: session.conversation_id,
                    receiver_thread_id,
                    receiver_agent_nickname,
                    receiver_agent_role,
                    status: status.clone(),
                }
                .into(),
            )
            .await;
        if let Some(err) = error {
            return Err(err);
        }
        // Telemetry is recorded only on the success path.
        turn.session_telemetry
            .counter("codex.multi_agent.resume", 1, &[]);
        Ok(ResumeAgentResult { status })
    }
}
/// Arguments for `resume_agent`: the target agent's thread id as a string.
#[derive(Debug, Deserialize)]
struct ResumeAgentArgs {
    id: String,
}
/// Result payload of `resume_agent`: the agent's status after the resume
/// attempt.
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
pub(crate) struct ResumeAgentResult {
    pub(crate) status: AgentStatus,
}
impl ToolOutput for ResumeAgentResult {
    fn log_preview(&self) -> String {
        tool_output_json_text(self, "resume_agent")
    }
    // Failures surface as `FunctionCallError` before a result exists, so any
    // produced result is logged as a success.
    fn success_for_logging(&self) -> bool {
        true
    }
    fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
        tool_output_response_item(call_id, payload, self, Some(true), "resume_agent")
    }
    fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
        tool_output_code_mode_result(self, "resume_agent")
    }
}
/// Resumes a closed agent thread from its persisted rollout and returns the
/// resulting status.
///
/// Builds a resume config at `child_depth`, asks `agent_control` to restore
/// the thread, and reads the restored agent's status. Errors from the restore
/// are wrapped via `collab_agent_error` against the original thread id.
async fn try_resume_closed_agent(
    session: &Arc<Session>,
    turn: &Arc<TurnContext>,
    receiver_thread_id: ThreadId,
    child_depth: i32,
) -> Result<AgentStatus, FunctionCallError> {
    let config = build_agent_resume_config(turn.as_ref(), child_depth)?;
    let resumed_thread_id = session
        .services
        .agent_control
        .resume_agent_from_rollout(
            config,
            receiver_thread_id,
            // No role override on resume (None) — the rollout carries it.
            thread_spawn_source(session.conversation_id, child_depth, None),
        )
        .await
        .map_err(|err| collab_agent_error(receiver_thread_id, err))?;
    Ok(session
        .services
        .agent_control
        .get_status(resumed_thread_id)
        .await)
}

View File

@@ -1,118 +0,0 @@
use super::*;
/// Handles the `send_input` tool call (delivers input to another agent).
pub(crate) struct Handler;
#[async_trait]
impl ToolHandler for Handler {
    type Output = SendInputResult;
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(payload, ToolPayload::Function { .. })
    }
    /// Optionally interrupts the receiver first, then brackets the input
    /// delivery with interaction begin/end events. The send error, if any,
    /// is deferred until after the end event is emitted.
    async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            payload,
            call_id,
            ..
        } = invocation;
        let arguments = function_arguments(payload)?;
        let args: SendInputArgs = parse_arguments(&arguments)?;
        let receiver_thread_id = agent_id(&args.id)?;
        let input_items = parse_collab_input(args.message, args.items)?;
        let prompt = input_preview(&input_items);
        // Best-effort metadata for events; lookup failure yields (None, None).
        let (receiver_agent_nickname, receiver_agent_role) = session
            .services
            .agent_control
            .get_agent_nickname_and_role(receiver_thread_id)
            .await
            .unwrap_or((None, None));
        // An interrupt failure aborts before the begin event is emitted.
        if args.interrupt {
            session
                .services
                .agent_control
                .interrupt_agent(receiver_thread_id)
                .await
                .map_err(|err| collab_agent_error(receiver_thread_id, err))?;
        }
        session
            .send_event(
                &turn,
                CollabAgentInteractionBeginEvent {
                    call_id: call_id.clone(),
                    sender_thread_id: session.conversation_id,
                    receiver_thread_id,
                    prompt: prompt.clone(),
                }
                .into(),
            )
            .await;
        let result = session
            .services
            .agent_control
            .send_input(receiver_thread_id, input_items)
            .await
            .map_err(|err| collab_agent_error(receiver_thread_id, err));
        // Status is read after the send so the end event reflects it.
        let status = session
            .services
            .agent_control
            .get_status(receiver_thread_id)
            .await;
        session
            .send_event(
                &turn,
                CollabAgentInteractionEndEvent {
                    call_id,
                    sender_thread_id: session.conversation_id,
                    receiver_thread_id,
                    receiver_agent_nickname,
                    receiver_agent_role,
                    prompt,
                    status,
                }
                .into(),
            )
            .await;
        // Propagate a deferred send failure only after the end event.
        let submission_id = result?;
        Ok(SendInputResult { submission_id })
    }
}
/// Arguments for `send_input`: target agent id, the input (either a plain
/// `message` or structured `items`), and whether to interrupt the receiver
/// before delivering the input.
#[derive(Debug, Deserialize)]
struct SendInputArgs {
    id: String,
    message: Option<String>,
    items: Option<Vec<UserInput>>,
    // Defaults to false: deliver without interrupting the receiver.
    #[serde(default)]
    interrupt: bool,
}
/// Result payload of `send_input`: the id of the submission created on the
/// receiving agent.
#[derive(Debug, Serialize)]
pub(crate) struct SendInputResult {
    submission_id: String,
}
impl ToolOutput for SendInputResult {
    fn log_preview(&self) -> String {
        tool_output_json_text(self, "send_input")
    }
    // Failures surface as `FunctionCallError` before a result exists, so any
    // produced result is logged as a success.
    fn success_for_logging(&self) -> bool {
        true
    }
    fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
        tool_output_response_item(call_id, payload, self, Some(true), "send_input")
    }
    fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
        tool_output_code_mode_result(self, "send_input")
    }
}

View File

@@ -1,173 +0,0 @@
use super::*;
use crate::agent::control::SpawnAgentOptions;
use crate::agent::role::DEFAULT_ROLE_NAME;
use crate::agent::role::apply_role_to_config;
use crate::agent::exceeds_thread_spawn_depth_limit;
use crate::agent::next_thread_spawn_depth;
/// Handles the `spawn_agent` tool call (creates a collaborator agent thread).
pub(crate) struct Handler;
#[async_trait]
impl ToolHandler for Handler {
    type Output = SpawnAgentResult;
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(payload, ToolPayload::Function { .. })
    }
    /// Enforces the spawn-depth limit, builds the child config (model, role,
    /// runtime overrides), spawns the agent, and brackets the attempt with
    /// spawn begin/end events. A spawn error is deferred until after the end
    /// event is emitted.
    async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            payload,
            call_id,
            ..
        } = invocation;
        let arguments = function_arguments(payload)?;
        let args: SpawnAgentArgs = parse_arguments(&arguments)?;
        // Blank or whitespace-only agent_type is treated as "no role".
        let role_name = args
            .agent_type
            .as_deref()
            .map(str::trim)
            .filter(|role| !role.is_empty());
        let input_items = parse_collab_input(args.message, args.items)?;
        let prompt = input_preview(&input_items);
        let session_source = turn.session_source.clone();
        let child_depth = next_thread_spawn_depth(&session_source);
        let max_depth = turn.config.agent_max_depth;
        if exceeds_thread_spawn_depth_limit(child_depth, max_depth) {
            return Err(FunctionCallError::RespondToModel(
                "Agent depth limit reached. Solve the task yourself.".to_string(),
            ));
        }
        session
            .send_event(
                &turn,
                CollabAgentSpawnBeginEvent {
                    call_id: call_id.clone(),
                    sender_thread_id: session.conversation_id,
                    prompt: prompt.clone(),
                    model: args.model.clone().unwrap_or_default(),
                    reasoning_effort: args.reasoning_effort.unwrap_or_default(),
                }
                .into(),
            )
            .await;
        // Config layering order: base -> requested model overrides -> role ->
        // runtime overrides -> depth-based spawn overrides.
        let mut config =
            build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?;
        apply_requested_spawn_agent_model_overrides(
            &session,
            turn.as_ref(),
            &mut config,
            args.model.as_deref(),
            args.reasoning_effort,
        )
        .await?;
        apply_role_to_config(&mut config, role_name)
            .await
            .map_err(FunctionCallError::RespondToModel)?;
        apply_spawn_agent_runtime_overrides(&mut config, turn.as_ref())?;
        apply_spawn_agent_overrides(&mut config, child_depth);
        let result = session
            .services
            .agent_control
            .spawn_agent_with_options(
                config,
                input_items,
                Some(thread_spawn_source(
                    session.conversation_id,
                    child_depth,
                    role_name,
                )),
                SpawnAgentOptions {
                    // When fork_context is set, tag the child with this call id.
                    fork_parent_spawn_call_id: args.fork_context.then(|| call_id.clone()),
                },
            )
            .await
            .map_err(collab_spawn_error);
        // Derive the end-event fields from the outcome; a failed spawn is
        // reported as (no thread, NotFound).
        let (new_thread_id, status) = match &result {
            Ok(thread_id) => (
                Some(*thread_id),
                session.services.agent_control.get_status(*thread_id).await,
            ),
            Err(_) => (None, AgentStatus::NotFound),
        };
        let (new_agent_nickname, new_agent_role) = match new_thread_id {
            Some(thread_id) => session
                .services
                .agent_control
                .get_agent_nickname_and_role(thread_id)
                .await
                .unwrap_or((None, None)),
            None => (None, None),
        };
        let nickname = new_agent_nickname.clone();
        session
            .send_event(
                &turn,
                CollabAgentSpawnEndEvent {
                    call_id,
                    sender_thread_id: session.conversation_id,
                    new_thread_id,
                    new_agent_nickname,
                    new_agent_role,
                    prompt,
                    model: args.model.clone().unwrap_or_default(),
                    reasoning_effort: args.reasoning_effort.unwrap_or_default(),
                    status,
                }
                .into(),
            )
            .await;
        // Propagate a deferred spawn failure only after the end event.
        let new_thread_id = result?;
        let role_tag = role_name.unwrap_or(DEFAULT_ROLE_NAME);
        turn.session_telemetry
            .counter("codex.multi_agent.spawn", 1, &[("role", role_tag)]);
        Ok(SpawnAgentResult {
            agent_id: new_thread_id.to_string(),
            nickname,
        })
    }
}
/// Arguments for `spawn_agent`: initial input (`message` or `items`),
/// optional role (`agent_type`), model/effort overrides, and whether to fork
/// the parent's context into the child.
#[derive(Debug, Deserialize)]
struct SpawnAgentArgs {
    message: Option<String>,
    items: Option<Vec<UserInput>>,
    agent_type: Option<String>,
    model: Option<String>,
    reasoning_effort: Option<ReasoningEffort>,
    // Defaults to false: spawn without forking the parent's context.
    #[serde(default)]
    fork_context: bool,
}
/// Result payload of `spawn_agent`: the new agent's thread id (stringified)
/// and its nickname, when one was assigned.
#[derive(Debug, Serialize)]
pub(crate) struct SpawnAgentResult {
    agent_id: String,
    nickname: Option<String>,
}
impl ToolOutput for SpawnAgentResult {
    fn log_preview(&self) -> String {
        tool_output_json_text(self, "spawn_agent")
    }
    // Failures surface as `FunctionCallError` before a result exists, so any
    // produced result is logged as a success.
    fn success_for_logging(&self) -> bool {
        true
    }
    fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
        tool_output_response_item(call_id, payload, self, Some(true), "spawn_agent")
    }
    fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
        tool_output_code_mode_result(self, "spawn_agent")
    }
}

View File

@@ -1,228 +0,0 @@
use super::*;
use crate::agent::status::is_final;
use futures::FutureExt;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::watch::Receiver;
use tokio::time::Instant;
use tokio::time::timeout_at;
/// Handles the `wait` tool call (blocks until collaborator agents reach a
/// final status or a timeout elapses).
pub(crate) struct Handler;
#[async_trait]
impl ToolHandler for Handler {
    type Output = WaitResult;
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(payload, ToolPayload::Function { .. })
    }
    /// Subscribes to each requested agent's status channel and waits (up to
    /// the clamped timeout) for final statuses, returning whatever completed.
    /// `timed_out == true` means no agent reached a final status in time.
    async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            payload,
            call_id,
            ..
        } = invocation;
        let arguments = function_arguments(payload)?;
        let args: WaitArgs = parse_arguments(&arguments)?;
        if args.ids.is_empty() {
            return Err(FunctionCallError::RespondToModel(
                "ids must be non-empty".to_owned(),
            ));
        }
        // Any unparsable id fails the whole call up front.
        let receiver_thread_ids = args
            .ids
            .iter()
            .map(|id| agent_id(id))
            .collect::<Result<Vec<_>, _>>()?;
        let mut receiver_agents = Vec::with_capacity(receiver_thread_ids.len());
        for receiver_thread_id in &receiver_thread_ids {
            let (agent_nickname, agent_role) = session
                .services
                .agent_control
                .get_agent_nickname_and_role(*receiver_thread_id)
                .await
                .unwrap_or((None, None));
            receiver_agents.push(CollabAgentRef {
                thread_id: *receiver_thread_id,
                agent_nickname,
                agent_role,
            });
        }
        // Non-positive timeouts are rejected; positive values are clamped to
        // the [MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS] range.
        let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS);
        let timeout_ms = match timeout_ms {
            ms if ms <= 0 => {
                return Err(FunctionCallError::RespondToModel(
                    "timeout_ms must be greater than zero".to_owned(),
                ));
            }
            ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS),
        };
        session
            .send_event(
                &turn,
                CollabWaitingBeginEvent {
                    sender_thread_id: session.conversation_id,
                    receiver_thread_ids: receiver_thread_ids.clone(),
                    receiver_agents: receiver_agents.clone(),
                    call_id: call_id.clone(),
                }
                .into(),
            )
            .await;
        // Subscribe to every agent. Agents already in a final state (or not
        // found) short-circuit the wait; any other subscribe error emits the
        // end event and aborts.
        let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len());
        let mut initial_final_statuses = Vec::new();
        for id in &receiver_thread_ids {
            match session.services.agent_control.subscribe_status(*id).await {
                Ok(rx) => {
                    let status = rx.borrow().clone();
                    if is_final(&status) {
                        initial_final_statuses.push((*id, status));
                    }
                    status_rxs.push((*id, rx));
                }
                Err(CodexErr::ThreadNotFound(_)) => {
                    initial_final_statuses.push((*id, AgentStatus::NotFound));
                }
                Err(err) => {
                    let mut statuses = HashMap::with_capacity(1);
                    statuses.insert(*id, session.services.agent_control.get_status(*id).await);
                    session
                        .send_event(
                            &turn,
                            CollabWaitingEndEvent {
                                sender_thread_id: session.conversation_id,
                                call_id: call_id.clone(),
                                agent_statuses: build_wait_agent_statuses(
                                    &statuses,
                                    &receiver_agents,
                                ),
                                statuses,
                            }
                            .into(),
                        )
                        .await;
                    return Err(collab_agent_error(*id, err));
                }
            }
        }
        let statuses = if !initial_final_statuses.is_empty() {
            initial_final_statuses
        } else {
            let mut futures = FuturesUnordered::new();
            for (id, rx) in status_rxs.into_iter() {
                let session = session.clone();
                futures.push(wait_for_final_status(session, id, rx));
            }
            let mut results = Vec::new();
            let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
            // Wait (bounded by the deadline) for the first final status;
            // `Ok(Some(None))` means a watcher ended without a final status.
            loop {
                match timeout_at(deadline, futures.next()).await {
                    Ok(Some(Some(result))) => {
                        results.push(result);
                        break;
                    }
                    Ok(Some(None)) => continue,
                    Ok(None) | Err(_) => break,
                }
            }
            // After the first hit, drain any watchers that are already done
            // without further waiting (now_or_never).
            if !results.is_empty() {
                loop {
                    match futures.next().now_or_never() {
                        Some(Some(Some(result))) => results.push(result),
                        Some(Some(None)) => continue,
                        Some(None) | None => break,
                    }
                }
            }
            results
        };
        let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
        let agent_statuses = build_wait_agent_statuses(&statuses_map, &receiver_agents);
        // An empty result set means the deadline elapsed with no agent final.
        let result = WaitResult {
            status: statuses_map.clone(),
            timed_out: statuses.is_empty(),
        };
        session
            .send_event(
                &turn,
                CollabWaitingEndEvent {
                    sender_thread_id: session.conversation_id,
                    call_id,
                    agent_statuses,
                    statuses: statuses_map,
                }
                .into(),
            )
            .await;
        Ok(result)
    }
}
/// Arguments for `wait`: the agent thread ids to wait on (must be non-empty)
/// and an optional timeout in milliseconds (clamped by the handler).
#[derive(Debug, Deserialize)]
struct WaitArgs {
    ids: Vec<String>,
    timeout_ms: Option<i64>,
}
/// Result payload of `wait`: the final statuses collected per thread, and
/// whether the wait ended because the timeout elapsed with no final status.
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
pub(crate) struct WaitResult {
    pub(crate) status: HashMap<ThreadId, AgentStatus>,
    pub(crate) timed_out: bool,
}
impl ToolOutput for WaitResult {
    fn log_preview(&self) -> String {
        tool_output_json_text(self, "wait")
    }
    // A timed-out wait is still a successfully executed tool call.
    fn success_for_logging(&self) -> bool {
        true
    }
    fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
        // `None` success flag: the caller decides how to interpret a timeout.
        tool_output_response_item(call_id, payload, self, None, "wait")
    }
    fn code_mode_result(&self, _payload: &ToolPayload) -> JsonValue {
        tool_output_code_mode_result(self, "wait")
    }
}
/// Watches one agent's status channel until it reports a final status.
///
/// Returns `Some((thread_id, status))` once a final status is observed. If
/// the channel closes first, falls back to a one-shot status read and
/// returns it only when that snapshot is itself final; otherwise `None`.
async fn wait_for_final_status(
    session: Arc<Session>,
    thread_id: ThreadId,
    mut status_rx: Receiver<AgentStatus>,
) -> Option<(ThreadId, AgentStatus)> {
    loop {
        // Check the current value first so an already-final status (including
        // the initial one) is returned without awaiting a change.
        let current = status_rx.borrow().clone();
        if is_final(&current) {
            return Some((thread_id, current));
        }
        if status_rx.changed().await.is_err() {
            // Sender dropped: take one last snapshot via the control service.
            let latest = session.services.agent_control.get_status(thread_id).await;
            return is_final(&latest).then_some((thread_id, latest));
        }
    }
}

View File

@@ -19,7 +19,6 @@ use crate::tools::handlers::parse_arguments_with_base_path;
use crate::tools::handlers::resolve_workdir_base_path;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::tools::spec::UnifiedExecBackendConfig;
use crate::unified_exec::ExecCommandRequest;
use crate::unified_exec::UnifiedExecContext;
use crate::unified_exec::UnifiedExecProcessManager;
@@ -109,7 +108,6 @@ impl ToolHandler for UnifiedExecHandler {
&params,
invocation.session.user_shell(),
invocation.turn.tools_config.allow_login_shell,
invocation.turn.tools_config.unified_exec_backend,
) {
Ok(command) => command,
Err(_) => return true,
@@ -157,7 +155,6 @@ impl ToolHandler for UnifiedExecHandler {
&args,
session.user_shell(),
turn.tools_config.allow_login_shell,
turn.tools_config.unified_exec_backend,
)
.map_err(FunctionCallError::RespondToModel)?;
@@ -324,23 +321,12 @@ pub(crate) fn get_command(
args: &ExecCommandArgs,
session_shell: Arc<Shell>,
allow_login_shell: bool,
unified_exec_backend: UnifiedExecBackendConfig,
) -> Result<Vec<String>, String> {
if unified_exec_backend == UnifiedExecBackendConfig::ZshFork && args.shell.is_some() {
return Err(
"shell override is not supported when the zsh-fork backend is enabled.".to_string(),
);
}
let model_shell = if unified_exec_backend == UnifiedExecBackendConfig::ZshFork {
None
} else {
args.shell.as_ref().map(|shell_str| {
let mut shell = get_shell_by_model_provided_path(&PathBuf::from(shell_str));
shell.shell_snapshot = crate::shell::empty_shell_snapshot_receiver();
shell
})
};
let model_shell = args.shell.as_ref().map(|shell_str| {
let mut shell = get_shell_by_model_provided_path(&PathBuf::from(shell_str));
shell.shell_snapshot = crate::shell::empty_shell_snapshot_receiver();
shell
});
let shell = model_shell.as_ref().unwrap_or(session_shell.as_ref());
let use_login_shell = match args.login {

View File

@@ -1,16 +1,12 @@
use super::*;
use crate::shell::ShellType;
use crate::shell::default_user_shell;
use crate::shell::empty_shell_snapshot_receiver;
use crate::tools::handlers::parse_arguments_with_base_path;
use crate::tools::handlers::resolve_workdir_base_path;
use crate::tools::spec::UnifiedExecBackendConfig;
use codex_protocol::models::FileSystemPermissions;
use codex_protocol::models::PermissionProfile;
use codex_utils_absolute_path::AbsolutePathBuf;
use pretty_assertions::assert_eq;
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
use tempfile::tempdir;
@@ -22,13 +18,8 @@ fn test_get_command_uses_default_shell_when_unspecified() -> anyhow::Result<()>
assert!(args.shell.is_none());
let command = get_command(
&args,
Arc::new(default_user_shell()),
true,
UnifiedExecBackendConfig::Direct,
)
.map_err(anyhow::Error::msg)?;
let command =
get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?;
assert_eq!(command.len(), 3);
assert_eq!(command[2], "echo hello");
@@ -43,13 +34,8 @@ fn test_get_command_respects_explicit_bash_shell() -> anyhow::Result<()> {
assert_eq!(args.shell.as_deref(), Some("/bin/bash"));
let command = get_command(
&args,
Arc::new(default_user_shell()),
true,
UnifiedExecBackendConfig::Direct,
)
.map_err(anyhow::Error::msg)?;
let command =
get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?;
assert_eq!(command.last(), Some(&"echo hello".to_string()));
if command
@@ -69,13 +55,8 @@ fn test_get_command_respects_explicit_powershell_shell() -> anyhow::Result<()> {
assert_eq!(args.shell.as_deref(), Some("powershell"));
let command = get_command(
&args,
Arc::new(default_user_shell()),
true,
UnifiedExecBackendConfig::Direct,
)
.map_err(anyhow::Error::msg)?;
let command =
get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?;
assert_eq!(command[2], "echo hello");
Ok(())
@@ -89,13 +70,8 @@ fn test_get_command_respects_explicit_cmd_shell() -> anyhow::Result<()> {
assert_eq!(args.shell.as_deref(), Some("cmd"));
let command = get_command(
&args,
Arc::new(default_user_shell()),
true,
UnifiedExecBackendConfig::Direct,
)
.map_err(anyhow::Error::msg)?;
let command =
get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?;
assert_eq!(command[2], "echo hello");
Ok(())
@@ -106,13 +82,8 @@ fn test_get_command_rejects_explicit_login_when_disallowed() -> anyhow::Result<(
let json = r#"{"cmd": "echo hello", "login": true}"#;
let args: ExecCommandArgs = parse_arguments(json)?;
let err = get_command(
&args,
Arc::new(default_user_shell()),
false,
UnifiedExecBackendConfig::Direct,
)
.expect_err("explicit login should be rejected");
let err = get_command(&args, Arc::new(default_user_shell()), false)
.expect_err("explicit login should be rejected");
assert!(
err.contains("login shell is disabled by config"),
@@ -121,30 +92,6 @@ fn test_get_command_rejects_explicit_login_when_disallowed() -> anyhow::Result<(
Ok(())
}
#[test]
fn get_command_rejects_model_shell_override_for_zsh_fork_backend() -> anyhow::Result<()> {
let json = r#"{"cmd": "echo hello", "shell": "/bin/bash"}"#;
let args: ExecCommandArgs = parse_arguments(json)?;
let session_shell = Arc::new(Shell {
shell_type: ShellType::Zsh,
shell_path: PathBuf::from("/tmp/configured-zsh-fork-shell"),
shell_snapshot: empty_shell_snapshot_receiver(),
});
let err = get_command(
&args,
session_shell,
true,
UnifiedExecBackendConfig::ZshFork,
)
.expect_err("shell override should be rejected for zsh-fork backend");
assert!(
err.contains("shell override is not supported"),
"unexpected error: {err}"
);
Ok(())
}
#[test]
fn exec_command_args_resolve_relative_additional_permissions_against_workdir() -> anyhow::Result<()>
{

View File

@@ -4,6 +4,7 @@ use crate::guardian::GuardianApprovalRequest;
use crate::guardian::review_approval_request;
use crate::guardian::routes_approval_to_guardian;
use crate::network_policy_decision::denied_network_policy_message;
use crate::network_proxy_registry::NetworkProxyScope;
use crate::tools::sandboxing::ToolError;
use codex_network_proxy::BlockedRequest;
use codex_network_proxy::BlockedRequestObserver;
@@ -76,14 +77,20 @@ impl ActiveNetworkApproval {
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct HostApprovalKey {
scope: NetworkProxyScope,
host: String,
protocol: &'static str,
port: u16,
}
impl HostApprovalKey {
fn from_request(request: &NetworkPolicyRequest, protocol: NetworkApprovalProtocol) -> Self {
fn from_request(
request: &NetworkPolicyRequest,
protocol: NetworkApprovalProtocol,
scope: NetworkProxyScope,
) -> Self {
Self {
scope,
host: request.host.to_ascii_lowercase(),
protocol: protocol_key_label(protocol),
port: request.port,
@@ -279,6 +286,7 @@ impl NetworkApprovalService {
&self,
session: Arc<Session>,
request: NetworkPolicyRequest,
scope: NetworkProxyScope,
) -> NetworkDecision {
const REASON_NOT_ALLOWED: &str = "not_allowed";
@@ -288,7 +296,7 @@ impl NetworkApprovalService {
NetworkProtocol::Socks5Tcp => NetworkApprovalProtocol::Socks5Tcp,
NetworkProtocol::Socks5Udp => NetworkApprovalProtocol::Socks5Udp,
};
let key = HostApprovalKey::from_request(&request, protocol);
let key = HostApprovalKey::from_request(&request, protocol, scope.clone());
{
let denied_hosts = self.session_denied_hosts.lock().await;
@@ -387,6 +395,7 @@ impl NetworkApprovalService {
.persist_network_policy_amendment(
&network_policy_amendment,
&network_approval_context,
&scope,
)
.await
{
@@ -417,6 +426,7 @@ impl NetworkApprovalService {
.persist_network_policy_amendment(
&network_policy_amendment,
&network_approval_context,
&scope,
)
.await
{
@@ -506,16 +516,18 @@ pub(crate) fn build_blocked_request_observer(
pub(crate) fn build_network_policy_decider(
network_approval: Arc<NetworkApprovalService>,
network_policy_decider_session: Arc<RwLock<std::sync::Weak<Session>>>,
scope: NetworkProxyScope,
) -> Arc<dyn NetworkPolicyDecider> {
Arc::new(move |request: NetworkPolicyRequest| {
let network_approval = Arc::clone(&network_approval);
let network_policy_decider_session = Arc::clone(&network_policy_decider_session);
let scope = scope.clone();
async move {
let Some(session) = network_policy_decider_session.read().await.upgrade() else {
return NetworkDecision::ask("not_allowed");
};
network_approval
.handle_inline_policy_request(session, request)
.handle_inline_policy_request(session, request, scope)
.await
}
})

View File

@@ -1,4 +1,5 @@
use super::*;
use crate::network_proxy_registry::NetworkProxyScope;
use codex_network_proxy::BlockedRequestArgs;
use codex_protocol::protocol::AskForApproval;
use pretty_assertions::assert_eq;
@@ -7,6 +8,7 @@ use pretty_assertions::assert_eq;
async fn pending_approvals_are_deduped_per_host_protocol_and_port() {
let service = NetworkApprovalService::default();
let key = HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "http",
port: 443,
@@ -24,11 +26,13 @@ async fn pending_approvals_are_deduped_per_host_protocol_and_port() {
async fn pending_approvals_do_not_dedupe_across_ports() {
let service = NetworkApprovalService::default();
let first_key = HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 443,
};
let second_key = HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 8443,
@@ -49,16 +53,19 @@ async fn session_approved_hosts_preserve_protocol_and_port_scope() {
let mut approved_hosts = source.session_approved_hosts.lock().await;
approved_hosts.extend([
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 443,
},
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 8443,
},
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "http",
port: 80,
@@ -82,16 +89,19 @@ async fn session_approved_hosts_preserve_protocol_and_port_scope() {
copied,
vec![
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "http",
port: 80,
},
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 443,
},
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 8443,

View File

@@ -9,6 +9,7 @@ use crate::features::Feature;
use crate::guardian::GuardianApprovalRequest;
use crate::guardian::review_approval_request;
use crate::guardian::routes_approval_to_guardian;
use crate::network_proxy_registry::NetworkProxyScope;
use crate::sandboxing::ExecRequest;
use crate::sandboxing::SandboxPermissions;
use crate::shell::ShellType;
@@ -54,6 +55,7 @@ use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
@@ -91,7 +93,7 @@ pub(super) async fn try_run_zsh_fork(
req: &ShellRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx,
shell_command: &[String],
command: &[String],
) -> Result<Option<ExecToolCallOutput>, ToolError> {
let Some(shell_zsh_path) = ctx.session.services.shell_zsh_path.as_ref() else {
tracing::warn!("ZshFork backend specified, but shell_zsh_path is not configured.");
@@ -106,10 +108,8 @@ pub(super) async fn try_run_zsh_fork(
return Ok(None);
}
let ParsedShellCommand { script, login, .. } = extract_shell_script(shell_command)?;
let spec = build_command_spec(
shell_command,
command,
&req.cwd,
&req.env,
req.timeout_ms.into(),
@@ -121,7 +121,7 @@ pub(super) async fn try_run_zsh_fork(
.env_for(spec, req.network.as_ref())
.map_err(|err| ToolError::Codex(err.into()))?;
let crate::sandboxing::ExecRequest {
command: sandbox_command,
command,
cwd: sandbox_cwd,
env: sandbox_env,
network: sandbox_network,
@@ -135,14 +135,17 @@ pub(super) async fn try_run_zsh_fork(
justification,
arg0,
} = sandbox_exec_request;
let host_zsh_path =
resolve_host_zsh_path(sandbox_env.get("PATH").map(String::as_str), &sandbox_cwd);
let ParsedShellCommand { script, login, .. } = extract_shell_script(&command)?;
let effective_timeout = Duration::from_millis(
req.timeout_ms
.unwrap_or(crate::exec::DEFAULT_EXEC_COMMAND_TIMEOUT_MS),
);
let exec_policy = Arc::new(RwLock::new(
ctx.session.services.exec_policy.current().as_ref().clone(),
));
let command_executor = CoreShellCommandExecutor {
command: sandbox_command,
session: Some(Arc::clone(&ctx.session)),
command,
cwd: sandbox_cwd,
sandbox_policy,
file_system_sandbox_policy,
@@ -163,8 +166,6 @@ pub(super) async fn try_run_zsh_fork(
.clone(),
codex_linux_sandbox_exe: ctx.turn.codex_linux_sandbox_exe.clone(),
use_legacy_landlock: ctx.turn.features.use_legacy_landlock(),
shell_zsh_path: ctx.session.services.shell_zsh_path.clone(),
host_zsh_path: host_zsh_path.clone(),
};
let main_execve_wrapper_exe = ctx
.session
@@ -193,6 +194,7 @@ pub(super) async fn try_run_zsh_fork(
req.additional_permissions_preapproved,
);
let escalation_policy = CoreShellActionProvider {
policy: Arc::clone(&exec_policy),
session: Arc::clone(&ctx.session),
turn: Arc::clone(&ctx.turn),
call_id: ctx.call_id.clone(),
@@ -205,7 +207,6 @@ pub(super) async fn try_run_zsh_fork(
approval_sandbox_permissions,
prompt_permissions: req.additional_permissions.clone(),
stopwatch: stopwatch.clone(),
host_zsh_path,
};
let escalate_server = EscalateServer::new(
@@ -226,7 +227,6 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
req: &crate::tools::runtimes::unified_exec::UnifiedExecRequest,
_attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx,
shell_command: &[String],
exec_request: ExecRequest,
) -> Result<Option<PreparedUnifiedExecZshFork>, ToolError> {
let Some(shell_zsh_path) = ctx.session.services.shell_zsh_path.as_ref() else {
@@ -242,7 +242,7 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
return Ok(None);
}
let parsed = match extract_shell_script(shell_command) {
let parsed = match extract_shell_script(&exec_request.command) {
Ok(parsed) => parsed,
Err(err) => {
tracing::warn!("ZshFork unified exec fallback: {err:?}");
@@ -258,35 +258,23 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
return Ok(None);
}
let ExecRequest {
command,
cwd,
env,
network,
expiration: _expiration,
sandbox,
windows_sandbox_level,
sandbox_permissions,
sandbox_policy,
file_system_sandbox_policy,
network_sandbox_policy,
justification,
arg0,
} = &exec_request;
let host_zsh_path = resolve_host_zsh_path(env.get("PATH").map(String::as_str), cwd);
let exec_policy = Arc::new(RwLock::new(
ctx.session.services.exec_policy.current().as_ref().clone(),
));
let command_executor = CoreShellCommandExecutor {
command: command.clone(),
cwd: cwd.clone(),
sandbox_policy: sandbox_policy.clone(),
file_system_sandbox_policy: file_system_sandbox_policy.clone(),
network_sandbox_policy: *network_sandbox_policy,
sandbox: *sandbox,
env: env.clone(),
network: network.clone(),
windows_sandbox_level: *windows_sandbox_level,
sandbox_permissions: *sandbox_permissions,
justification: justification.clone(),
arg0: arg0.clone(),
session: Some(Arc::clone(&ctx.session)),
command: exec_request.command.clone(),
cwd: exec_request.cwd.clone(),
sandbox_policy: exec_request.sandbox_policy.clone(),
file_system_sandbox_policy: exec_request.file_system_sandbox_policy.clone(),
network_sandbox_policy: exec_request.network_sandbox_policy,
sandbox: exec_request.sandbox,
env: exec_request.env.clone(),
network: exec_request.network.clone(),
windows_sandbox_level: exec_request.windows_sandbox_level,
sandbox_permissions: exec_request.sandbox_permissions,
justification: exec_request.justification.clone(),
arg0: exec_request.arg0.clone(),
sandbox_policy_cwd: ctx.turn.cwd.clone(),
macos_seatbelt_profile_extensions: ctx
.turn
@@ -296,8 +284,6 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
.clone(),
codex_linux_sandbox_exe: ctx.turn.codex_linux_sandbox_exe.clone(),
use_legacy_landlock: ctx.turn.features.use_legacy_landlock(),
shell_zsh_path: ctx.session.services.shell_zsh_path.clone(),
host_zsh_path: host_zsh_path.clone(),
};
let main_execve_wrapper_exe = ctx
.session
@@ -310,6 +296,7 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
)
})?;
let escalation_policy = CoreShellActionProvider {
policy: Arc::clone(&exec_policy),
session: Arc::clone(&ctx.session),
turn: Arc::clone(&ctx.turn),
call_id: ctx.call_id.clone(),
@@ -325,7 +312,6 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
),
prompt_permissions: req.additional_permissions.clone(),
stopwatch: Stopwatch::unlimited(),
host_zsh_path,
};
let escalate_server = EscalateServer::new(
@@ -345,6 +331,7 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
}
struct CoreShellActionProvider {
policy: Arc<RwLock<Policy>>,
session: Arc<crate::codex::Session>,
turn: Arc<crate::codex::TurnContext>,
call_id: String,
@@ -357,7 +344,6 @@ struct CoreShellActionProvider {
approval_sandbox_permissions: SandboxPermissions,
prompt_permissions: Option<PermissionProfile>,
stopwatch: Stopwatch,
host_zsh_path: Option<PathBuf>,
}
#[allow(clippy::large_enum_variant)]
@@ -395,66 +381,6 @@ fn execve_prompt_is_rejected_by_policy(
}
}
fn paths_match(lhs: &Path, rhs: &Path) -> bool {
lhs == rhs
|| match (lhs.canonicalize(), rhs.canonicalize()) {
(Ok(lhs), Ok(rhs)) => lhs == rhs,
_ => false,
}
}
fn resolve_host_zsh_path(_path_env: Option<&str>, _cwd: &Path) -> Option<PathBuf> {
fn canonicalize_best_effort(path: PathBuf) -> PathBuf {
path.canonicalize().unwrap_or(path)
}
fn is_executable_file(path: &Path) -> bool {
std::fs::metadata(path).is_ok_and(|metadata| {
metadata.is_file() && {
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
metadata.permissions().mode() & 0o111 != 0
}
#[cfg(not(unix))]
{
true
}
}
})
}
fn find_zsh_in_dirs(dirs: impl IntoIterator<Item = PathBuf>) -> Option<PathBuf> {
dirs.into_iter().find_map(|dir| {
let candidate = dir.join("zsh");
is_executable_file(&candidate).then(|| canonicalize_best_effort(candidate))
})
}
// Keep nested-zsh rewrites limited to canonical host shell installations.
// PATH shadowing from repos, Nix environments, or tool shims should not be
// treated as the host shell.
find_zsh_in_dirs(
["/bin", "/usr/bin", "/usr/local/bin", "/opt/homebrew/bin"]
.into_iter()
.map(PathBuf::from),
)
}
fn is_unconfigured_zsh_exec(
program: &AbsolutePathBuf,
shell_zsh_path: Option<&Path>,
host_zsh_path: Option<&Path>,
) -> bool {
let Some(shell_zsh_path) = shell_zsh_path else {
return false;
};
let Some(host_zsh_path) = host_zsh_path else {
return false;
};
paths_match(program.as_path(), host_zsh_path) && !paths_match(program.as_path(), shell_zsh_path)
}
impl CoreShellActionProvider {
fn decision_driven_by_policy(matched_rules: &[RuleMatch], decision: Decision) -> bool {
matched_rules.iter().any(|rule_match| {
@@ -566,10 +492,6 @@ impl CoreShellActionProvider {
command,
workdir,
None,
// Intercepted exec prompts happen after the original tool call has
// started, so we do not attach an execpolicy amendment payload here.
// Amendments are currently surfaced only from the top-level tool
// request path.
None,
None,
additional_permissions,
@@ -585,26 +507,7 @@ impl CoreShellActionProvider {
/// an absolute path. The idea is that we check to see whether it matches
/// any skills.
async fn find_skill(&self, program: &AbsolutePathBuf) -> Option<SkillMetadata> {
let force_reload = false;
let skills_outcome = self
.session
.services
.skills_manager
.skills_for_cwd(&self.turn.cwd, force_reload)
.await;
let program_path = program.as_path();
for skill in skills_outcome.skills {
// We intentionally ignore "enabled" status here for now.
let Some(skill_root) = skill.path_to_skills_md.parent() else {
continue;
};
if program_path.starts_with(skill_root.join("scripts")) {
return Some(skill);
}
}
None
find_skill_for_program(self.session.as_ref(), self.turn.cwd.as_path(), program).await
}
#[allow(clippy::too_many_arguments)]
@@ -713,6 +616,32 @@ impl CoreShellActionProvider {
}
}
async fn find_skill_for_program(
session: &crate::codex::Session,
cwd: &Path,
program: &AbsolutePathBuf,
) -> Option<SkillMetadata> {
let force_reload = false;
let skills_outcome = session
.services
.skills_manager
.skills_for_cwd(cwd, force_reload)
.await;
let program_path = program.as_path();
for skill in skills_outcome.skills {
// We intentionally ignore "enabled" status here for now.
let Some(skill_root) = skill.path_to_skills_md.parent() else {
continue;
};
if program_path.starts_with(skill_root.join("scripts")) {
return Some(skill);
}
}
None
}
// Shell-wrapper parsing is weaker than direct exec interception because it can
// only see the script text, not the final resolved executable path. Keep it
// disabled by default so path-sensitive rules rely on the later authoritative
@@ -794,35 +723,28 @@ impl EscalationPolicy for CoreShellActionProvider {
.await;
}
let policy = self.session.services.exec_policy.current();
let evaluation = evaluate_intercepted_exec_policy(
policy.as_ref(),
program,
argv,
InterceptedExecPolicyContext {
approval_policy: self.approval_policy,
sandbox_policy: &self.sandbox_policy,
file_system_sandbox_policy: &self.file_system_sandbox_policy,
sandbox_permissions: self.approval_sandbox_permissions,
enable_shell_wrapper_parsing: ENABLE_INTERCEPTED_EXEC_POLICY_SHELL_WRAPPER_PARSING,
},
);
let evaluation = {
let policy = self.policy.read().await;
evaluate_intercepted_exec_policy(
&policy,
program,
argv,
InterceptedExecPolicyContext {
approval_policy: self.approval_policy,
sandbox_policy: &self.sandbox_policy,
file_system_sandbox_policy: &self.file_system_sandbox_policy,
sandbox_permissions: self.approval_sandbox_permissions,
enable_shell_wrapper_parsing:
ENABLE_INTERCEPTED_EXEC_POLICY_SHELL_WRAPPER_PARSING,
},
)
};
// When true, means the Evaluation was due to *.rules, not the
// fallback function.
let decision_driven_by_policy =
Self::decision_driven_by_policy(&evaluation.matched_rules, evaluation.decision);
// Keep zsh-fork interception alive across nested shells: if an
// intercepted exec targets the known host `zsh` path instead of the
// configured zsh-fork binary, force it through escalation so the
// executor can rewrite the program path back to the configured shell.
let force_zsh_fork_reexec = is_unconfigured_zsh_exec(
program,
self.session.services.shell_zsh_path.as_deref(),
self.host_zsh_path.as_deref(),
);
let needs_escalation = self.sandbox_permissions.requires_escalated_permissions()
|| decision_driven_by_policy
|| force_zsh_fork_reexec;
let needs_escalation =
self.sandbox_permissions.requires_escalated_permissions() || decision_driven_by_policy;
let decision_source = if decision_driven_by_policy {
DecisionSource::PrefixRule
@@ -954,6 +876,7 @@ fn commands_for_intercepted_exec_policy(
}
struct CoreShellCommandExecutor {
session: Option<Arc<crate::codex::Session>>,
command: Vec<String>,
cwd: PathBuf,
sandbox_policy: SandboxPolicy,
@@ -971,14 +894,13 @@ struct CoreShellCommandExecutor {
macos_seatbelt_profile_extensions: Option<MacOsSeatbeltProfileExtensions>,
codex_linux_sandbox_exe: Option<PathBuf>,
use_legacy_landlock: bool,
shell_zsh_path: Option<PathBuf>,
host_zsh_path: Option<PathBuf>,
}
struct PrepareSandboxedExecParams<'a> {
command: Vec<String>,
workdir: &'a AbsolutePathBuf,
env: HashMap<String, String>,
network: Option<codex_network_proxy::NetworkProxy>,
sandbox_policy: &'a SandboxPolicy,
file_system_sandbox_policy: &'a FileSystemSandboxPolicy,
network_sandbox_policy: NetworkSandboxPolicy,
@@ -1045,8 +967,7 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
env: HashMap<String, String>,
execution: EscalationExecution,
) -> anyhow::Result<PreparedExec> {
let program = self.rewrite_intercepted_program_for_zsh_fork(program);
let command = join_program_and_argv(&program, argv);
let command = join_program_and_argv(program, argv);
let Some(first_arg) = argv.first() else {
return Err(anyhow::anyhow!(
"intercepted exec request must contain argv[0]"
@@ -1061,10 +982,12 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
arg0: Some(first_arg.clone()),
},
EscalationExecution::TurnDefault => {
let network = self.network_for_program(program).await?;
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
command,
workdir,
env,
network,
sandbox_policy: &self.sandbox_policy,
file_system_sandbox_policy: &self.file_system_sandbox_policy,
network_sandbox_policy: self.network_sandbox_policy,
@@ -1078,12 +1001,14 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
EscalationExecution::Permissions(EscalationPermissions::PermissionProfile(
permission_profile,
)) => {
let network = self.network_for_program(program).await?;
// Merge additive permissions into the existing turn/request sandbox policy.
// On macOS, additional profile extensions are unioned with the turn defaults.
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
command,
workdir,
env,
network,
sandbox_policy: &self.sandbox_policy,
file_system_sandbox_policy: &self.file_system_sandbox_policy,
network_sandbox_policy: self.network_sandbox_policy,
@@ -1095,11 +1020,13 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
})?
}
EscalationExecution::Permissions(EscalationPermissions::Permissions(permissions)) => {
let network = self.network_for_program(program).await?;
// Use a fully specified sandbox policy instead of merging into the turn policy.
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
command,
workdir,
env,
network,
sandbox_policy: &permissions.sandbox_policy,
file_system_sandbox_policy: &permissions.file_system_sandbox_policy,
network_sandbox_policy: permissions.network_sandbox_policy,
@@ -1117,33 +1044,26 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
}
impl CoreShellCommandExecutor {
fn rewrite_intercepted_program_for_zsh_fork(
async fn network_for_program(
&self,
program: &AbsolutePathBuf,
) -> AbsolutePathBuf {
let Some(shell_zsh_path) = self.shell_zsh_path.as_ref() else {
return program.clone();
) -> anyhow::Result<Option<codex_network_proxy::NetworkProxy>> {
let Some(session) = self.session.as_ref() else {
return Ok(self.network.clone());
};
if !is_unconfigured_zsh_exec(
program,
Some(shell_zsh_path.as_path()),
self.host_zsh_path.as_deref(),
) {
return program.clone();
}
match AbsolutePathBuf::from_absolute_path(shell_zsh_path) {
Ok(rewritten) => rewritten,
Err(err) => {
tracing::warn!(
"failed to rewrite intercepted zsh path {} to configured shell {}: {err}",
program.display(),
shell_zsh_path.display(),
);
program.clone()
}
}
let Some(skill) =
find_skill_for_program(session.as_ref(), &self.sandbox_policy_cwd, program).await
else {
return Ok(self.network.clone());
};
let (scope, managed_network_override) = network_proxy_scope_for_skill(&skill);
session
.get_or_start_network_proxy(scope, &self.sandbox_policy, managed_network_override)
.await
}
#[allow(clippy::too_many_arguments)]
fn prepare_sandboxed_exec(
&self,
params: PrepareSandboxedExecParams<'_>,
@@ -1152,6 +1072,7 @@ impl CoreShellCommandExecutor {
command,
workdir,
env,
network,
sandbox_policy,
file_system_sandbox_policy,
network_sandbox_policy,
@@ -1168,7 +1089,7 @@ impl CoreShellCommandExecutor {
network_sandbox_policy,
SandboxablePreference::Auto,
self.windows_sandbox_level,
self.network.is_some(),
network.is_some(),
);
let mut exec_request =
sandbox_manager.transform(crate::sandboxing::SandboxTransformRequest {
@@ -1190,8 +1111,8 @@ impl CoreShellCommandExecutor {
file_system_policy: file_system_sandbox_policy,
network_policy: network_sandbox_policy,
sandbox,
enforce_managed_network: self.network.is_some(),
network: self.network.as_ref(),
enforce_managed_network: network.is_some(),
network: network.as_ref(),
sandbox_policy_cwd: &self.sandbox_policy_cwd,
#[cfg(target_os = "macos")]
macos_seatbelt_profile_extensions,
@@ -1212,6 +1133,23 @@ impl CoreShellCommandExecutor {
}
}
fn network_proxy_scope_for_skill(
skill: &SkillMetadata,
) -> (
NetworkProxyScope,
Option<crate::skills::model::SkillManagedNetworkOverride>,
) {
match skill.managed_network_override.clone() {
Some(managed_network_override) => (
NetworkProxyScope::Skill {
path_to_skills_md: skill.path_to_skills_md.clone(),
},
Some(managed_network_override),
),
None => (NetworkProxyScope::SessionDefault, None),
}
}
#[derive(Debug, Eq, PartialEq)]
struct ParsedShellCommand {
program: String,
@@ -1220,21 +1158,23 @@ struct ParsedShellCommand {
}
fn extract_shell_script(command: &[String]) -> Result<ParsedShellCommand, ToolError> {
if let [program, flag, script, ..] = command {
if flag == "-c" {
return Ok(ParsedShellCommand {
program: program.to_owned(),
script: script.to_owned(),
login: false,
});
// Commands reaching zsh-fork can be wrapped by environment/sandbox helpers, so
// we search for the first `-c`/`-lc` triple anywhere in the argv rather
// than assuming it is the first positional form.
if let Some((program, script, login)) = command.windows(3).find_map(|parts| match parts {
[program, flag, script] if flag == "-c" => {
Some((program.to_owned(), script.to_owned(), false))
}
if flag == "-lc" {
return Ok(ParsedShellCommand {
program: program.to_owned(),
script: script.to_owned(),
login: true,
});
[program, flag, script] if flag == "-lc" => {
Some((program.to_owned(), script.to_owned(), true))
}
_ => None,
}) {
return Ok(ParsedShellCommand {
program,
script,
login,
});
}
Err(ToolError::Rejected(

View File

@@ -6,10 +6,8 @@ use super::ParsedShellCommand;
use super::commands_for_intercepted_exec_policy;
use super::evaluate_intercepted_exec_policy;
use super::extract_shell_script;
use super::is_unconfigured_zsh_exec;
use super::join_program_and_argv;
use super::map_exec_result;
use super::resolve_host_zsh_path;
#[cfg(target_os = "macos")]
use crate::config::Constrained;
#[cfg(target_os = "macos")]
@@ -17,6 +15,7 @@ use crate::config::Permissions;
#[cfg(target_os = "macos")]
use crate::config::types::ShellEnvironmentPolicy;
use crate::exec::SandboxType;
use crate::network_proxy_registry::NetworkProxyScope;
use crate::protocol::AskForApproval;
use crate::protocol::GranularApprovalConfig;
use crate::protocol::ReadOnlyAccess;
@@ -52,7 +51,6 @@ use codex_utils_absolute_path::AbsolutePathBuf;
use pretty_assertions::assert_eq;
#[cfg(target_os = "macos")]
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::time::Duration;
@@ -101,6 +99,35 @@ fn test_skill_metadata(permission_profile: Option<PermissionProfile>) -> SkillMe
}
}
#[test]
fn network_proxy_scope_for_skill_without_override_reuses_session_default() {
let skill = test_skill_metadata(None);
assert_eq!(
super::network_proxy_scope_for_skill(&skill),
(NetworkProxyScope::SessionDefault, None),
);
}
#[test]
fn network_proxy_scope_for_skill_with_override_uses_skill_scope() {
let mut skill = test_skill_metadata(None);
skill.managed_network_override = Some(crate::skills::model::SkillManagedNetworkOverride {
allowed_domains: Some(vec!["skill.example.com".to_string()]),
denied_domains: Some(vec!["blocked.skill.example.com".to_string()]),
});
assert_eq!(
super::network_proxy_scope_for_skill(&skill),
(
NetworkProxyScope::Skill {
path_to_skills_md: PathBuf::from("/tmp/skill/SKILL.md"),
},
skill.managed_network_override.clone(),
),
);
}
#[test]
fn execve_prompt_rejection_uses_skill_approval_for_skill_scripts() {
let decision_source = super::DecisionSource::SkillScript {
@@ -206,16 +233,39 @@ fn extract_shell_script_preserves_login_flag() {
}
#[test]
fn extract_shell_script_rejects_wrapped_command_prefixes() {
let err = extract_shell_script(&[
"/usr/bin/env".into(),
"CODEX_EXECVE_WRAPPER=1".into(),
"/bin/zsh".into(),
"-lc".into(),
"echo hello".into(),
])
.unwrap_err();
assert!(matches!(err, super::ToolError::Rejected(_)));
fn extract_shell_script_supports_wrapped_command_prefixes() {
assert_eq!(
extract_shell_script(&[
"/usr/bin/env".into(),
"CODEX_EXECVE_WRAPPER=1".into(),
"/bin/zsh".into(),
"-lc".into(),
"echo hello".into()
])
.unwrap(),
ParsedShellCommand {
program: "/bin/zsh".to_string(),
script: "echo hello".to_string(),
login: true,
}
);
assert_eq!(
extract_shell_script(&[
"sandbox-exec".into(),
"-p".into(),
"sandbox_policy".into(),
"/bin/zsh".into(),
"-c".into(),
"pwd".into(),
])
.unwrap(),
ParsedShellCommand {
program: "/bin/zsh".to_string(),
script: "pwd".to_string(),
login: false,
}
);
}
#[test]
@@ -254,80 +304,6 @@ fn join_program_and_argv_replaces_original_argv_zero() {
);
}
#[test]
fn is_unconfigured_zsh_exec_matches_non_configured_zsh_paths() {
let program = AbsolutePathBuf::try_from(host_absolute_path(&["bin", "zsh"])).unwrap();
let host = PathBuf::from(host_absolute_path(&["bin", "zsh"]));
let configured = PathBuf::from(host_absolute_path(&["tmp", "codex-zsh"]));
assert!(is_unconfigured_zsh_exec(
&program,
Some(configured.as_path()),
Some(host.as_path()),
));
}
#[test]
fn is_unconfigured_zsh_exec_ignores_non_zsh_or_configured_paths() {
let configured = PathBuf::from(host_absolute_path(&["tmp", "codex-zsh"]));
let host = PathBuf::from(host_absolute_path(&["bin", "zsh"]));
let configured_program = AbsolutePathBuf::try_from(configured.clone()).unwrap();
assert!(!is_unconfigured_zsh_exec(
&configured_program,
Some(configured.as_path()),
Some(host.as_path()),
));
let non_zsh =
AbsolutePathBuf::try_from(host_absolute_path(&["usr", "bin", "python3"])).unwrap();
assert!(!is_unconfigured_zsh_exec(
&non_zsh,
Some(configured.as_path()),
Some(host.as_path()),
));
assert!(!is_unconfigured_zsh_exec(
&non_zsh,
None,
Some(host.as_path()),
));
}
#[test]
fn is_unconfigured_zsh_exec_does_not_match_non_host_zsh_named_binaries() {
let program = AbsolutePathBuf::try_from(host_absolute_path(&["tmp", "repo", "zsh"])).unwrap();
let configured = PathBuf::from(host_absolute_path(&["tmp", "codex-zsh"]));
let host = PathBuf::from(host_absolute_path(&["bin", "zsh"]));
assert!(!is_unconfigured_zsh_exec(
&program,
Some(configured.as_path()),
Some(host.as_path()),
));
}
#[test]
fn resolve_host_zsh_path_ignores_repo_local_path_shadowing() {
let shadow_dir = tempfile::tempdir().expect("create shadow dir");
let cwd_dir = tempfile::tempdir().expect("create cwd dir");
let fake_zsh = shadow_dir.path().join("zsh");
std::fs::write(&fake_zsh, "#!/bin/sh\nexit 0\n").expect("write fake zsh");
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut permissions = std::fs::metadata(&fake_zsh)
.expect("metadata for fake zsh")
.permissions();
permissions.set_mode(0o755);
std::fs::set_permissions(&fake_zsh, permissions).expect("chmod fake zsh");
}
let path_env =
std::env::join_paths([shadow_dir.path(), Path::new("/usr/bin"), Path::new("/bin")])
.expect("join PATH")
.into_string()
.expect("PATH should be UTF-8");
let resolved = resolve_host_zsh_path(Some(&path_env), cwd_dir.path());
assert_ne!(resolved.as_deref(), Some(fake_zsh.as_path()));
}
#[test]
fn commands_for_intercepted_exec_policy_parses_plain_shell_wrappers() {
let program = AbsolutePathBuf::try_from(host_absolute_path(&["bin", "bash"])).unwrap();
@@ -705,6 +681,7 @@ host_executable(name = "git", paths = ["{allowed_git_literal}"])
async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions() {
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
let executor = CoreShellCommandExecutor {
session: None,
command: vec!["echo".to_string(), "ok".to_string()],
cwd: cwd.to_path_buf(),
env: HashMap::new(),
@@ -724,8 +701,6 @@ async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions
}),
codex_linux_sandbox_exe: None,
use_legacy_landlock: false,
shell_zsh_path: None,
host_zsh_path: None,
};
let prepared = executor
@@ -759,6 +734,7 @@ async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions
async fn prepare_escalated_exec_permissions_preserve_macos_seatbelt_extensions() {
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
let executor = CoreShellCommandExecutor {
session: None,
command: vec!["echo".to_string(), "ok".to_string()],
cwd: cwd.to_path_buf(),
env: HashMap::new(),
@@ -775,8 +751,6 @@ async fn prepare_escalated_exec_permissions_preserve_macos_seatbelt_extensions()
macos_seatbelt_profile_extensions: None,
codex_linux_sandbox_exe: None,
use_legacy_landlock: false,
shell_zsh_path: None,
host_zsh_path: None,
};
let permissions = Permissions {
@@ -835,6 +809,7 @@ async fn prepare_escalated_exec_permission_profile_unions_turn_and_requested_mac
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
let sandbox_policy = SandboxPolicy::new_read_only_policy();
let executor = CoreShellCommandExecutor {
session: None,
command: vec!["echo".to_string(), "ok".to_string()],
cwd: cwd.to_path_buf(),
env: HashMap::new(),
@@ -854,8 +829,6 @@ async fn prepare_escalated_exec_permission_profile_unions_turn_and_requested_mac
}),
codex_linux_sandbox_exe: None,
use_legacy_landlock: false,
shell_zsh_path: None,
host_zsh_path: None,
};
let prepared = executor

View File

@@ -36,10 +36,9 @@ pub(crate) async fn maybe_prepare_unified_exec(
req: &UnifiedExecRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx,
shell_command: &[String],
exec_request: ExecRequest,
) -> Result<Option<PreparedUnifiedExecSpawn>, ToolError> {
imp::maybe_prepare_unified_exec(req, attempt, ctx, shell_command, exec_request).await
imp::maybe_prepare_unified_exec(req, attempt, ctx, exec_request).await
}
#[cfg(unix)]
@@ -47,34 +46,16 @@ mod imp {
use super::*;
use crate::tools::runtimes::shell::unix_escalation;
use crate::unified_exec::SpawnLifecycle;
use codex_shell_escalation::ESCALATE_SOCKET_ENV_VAR;
use codex_shell_escalation::EscalationSession;
#[derive(Debug)]
struct ZshForkSpawnLifecycle {
escalation_session: Option<EscalationSession>,
escalation_session: EscalationSession,
}
impl SpawnLifecycle for ZshForkSpawnLifecycle {
fn inherited_fds(&self) -> Vec<i32> {
self.escalation_session
.as_ref()
.and_then(|escalation_session| {
escalation_session.env().get(ESCALATE_SOCKET_ENV_VAR)
})
.and_then(|fd| fd.parse().ok())
.into_iter()
.collect()
}
fn after_spawn(&mut self) {
if let Some(escalation_session) = self.escalation_session.as_ref() {
escalation_session.close_client_socket();
}
}
fn after_exit(&mut self) {
self.escalation_session = None;
self.escalation_session.close_client_socket();
}
}
@@ -91,17 +72,10 @@ mod imp {
req: &UnifiedExecRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx,
shell_command: &[String],
exec_request: ExecRequest,
) -> Result<Option<PreparedUnifiedExecSpawn>, ToolError> {
let Some(prepared) = unix_escalation::prepare_unified_exec_zsh_fork(
req,
attempt,
ctx,
shell_command,
exec_request,
)
.await?
let Some(prepared) =
unix_escalation::prepare_unified_exec_zsh_fork(req, attempt, ctx, exec_request).await?
else {
return Ok(None);
};
@@ -109,7 +83,7 @@ mod imp {
Ok(Some(PreparedUnifiedExecSpawn {
exec_request: prepared.exec_request,
spawn_lifecycle: Box::new(ZshForkSpawnLifecycle {
escalation_session: Some(prepared.escalation_session),
escalation_session: prepared.escalation_session,
}),
}))
}
@@ -133,10 +107,9 @@ mod imp {
req: &UnifiedExecRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx,
shell_command: &[String],
exec_request: ExecRequest,
) -> Result<Option<PreparedUnifiedExecSpawn>, ToolError> {
let _ = (req, attempt, ctx, shell_command, exec_request);
let _ = (req, attempt, ctx, exec_request);
Ok(None)
}
}

View File

@@ -222,11 +222,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
let exec_env = attempt
.env_for(spec, req.network.as_ref())
.map_err(|err| ToolError::Codex(err.into()))?;
match zsh_fork_backend::maybe_prepare_unified_exec(
req, attempt, ctx, &command, exec_env,
)
.await?
{
match zsh_fork_backend::maybe_prepare_unified_exec(req, attempt, ctx, exec_env).await? {
Some(prepared) => {
return self
.manager

View File

@@ -311,6 +311,8 @@ impl ToolsConfig {
);
let shell_type = if !features.enabled(Feature::ShellTool) {
ConfigShellToolType::Disabled
} else if features.enabled(Feature::ShellZshFork) {
ConfigShellToolType::ShellCommand
} else if features.enabled(Feature::UnifiedExec) && unified_exec_allowed {
// If ConPTY not supported (for old Windows versions), fallback on ShellCommand.
if codex_utils_pty::conpty_supported() {
@@ -321,8 +323,6 @@ impl ToolsConfig {
} else if model_info.shell_type == ConfigShellToolType::UnifiedExec && !unified_exec_allowed
{
ConfigShellToolType::ShellCommand
} else if features.enabled(Feature::ShellZshFork) {
ConfigShellToolType::ShellCommand
} else {
model_info.shell_type
};
@@ -578,7 +578,6 @@ fn create_approval_parameters(
fn create_exec_command_tool(
allow_login_shell: bool,
exec_permission_approvals_enabled: bool,
unified_exec_backend: UnifiedExecBackendConfig,
) -> ToolSpec {
let mut properties = BTreeMap::from([
(
@@ -596,6 +595,12 @@ fn create_exec_command_tool(
),
},
),
(
"shell".to_string(),
JsonSchema::String {
description: Some("Shell binary to launch. Defaults to the user's default shell.".to_string()),
},
),
(
"tty".to_string(),
JsonSchema::Boolean {
@@ -623,16 +628,6 @@ fn create_exec_command_tool(
},
),
]);
if unified_exec_backend != UnifiedExecBackendConfig::ZshFork {
properties.insert(
"shell".to_string(),
JsonSchema::String {
description: Some(
"Shell binary to launch. Defaults to the user's default shell.".to_string(),
),
},
);
}
if allow_login_shell {
properties.insert(
"login".to_string(),
@@ -2527,7 +2522,6 @@ pub(crate) fn build_specs_with_discoverable_tools(
create_exec_command_tool(
config.allow_login_shell,
exec_permission_approvals_enabled,
config.unified_exec_backend,
),
true,
config.code_mode_enabled,

View File

@@ -444,7 +444,7 @@ fn test_full_toolset_specs_for_gpt5_codex_unified_exec_web_search() {
// Build expected from the same helpers used by the builder.
let mut expected: BTreeMap<String, ToolSpec> = BTreeMap::from([]);
for spec in [
create_exec_command_tool(true, false, UnifiedExecBackendConfig::Direct),
create_exec_command_tool(true, false),
create_write_stdin_tool(),
PLAN_TOOL.clone(),
create_request_user_input_tool(CollaborationModesConfig::default()),
@@ -1350,7 +1350,7 @@ fn test_build_specs_default_shell_present() {
}
#[test]
fn shell_zsh_fork_uses_unified_exec_when_enabled() {
fn shell_zsh_fork_prefers_shell_command_over_unified_exec() {
let config = test_config();
let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config);
let mut features = Features::with_defaults();
@@ -1368,7 +1368,7 @@ fn shell_zsh_fork_uses_unified_exec_when_enabled() {
windows_sandbox_level: WindowsSandboxLevel::Disabled,
});
assert_eq!(tools_config.shell_type, ConfigShellToolType::UnifiedExec);
assert_eq!(tools_config.shell_type, ConfigShellToolType::ShellCommand);
assert_eq!(
tools_config.shell_command_backend,
ShellCommandBackendConfig::ZshFork
@@ -1377,19 +1377,6 @@ fn shell_zsh_fork_uses_unified_exec_when_enabled() {
tools_config.unified_exec_backend,
UnifiedExecBackendConfig::ZshFork
);
let (tools, _) = build_specs(&tools_config, Some(HashMap::new()), None, &[]).build();
let exec_spec = find_tool(&tools, "exec_command");
let ToolSpec::Function(exec_tool) = &exec_spec.spec else {
panic!("exec_command should be a function tool spec");
};
let JsonSchema::Object { properties, .. } = &exec_tool.parameters else {
panic!("exec_command parameters should be an object schema");
};
assert!(
!properties.contains_key("shell"),
"exec_command should omit `shell` when zsh-fork backend forces the configured shell",
);
}
#[test]

View File

@@ -1,7 +1,6 @@
#![allow(clippy::module_inception)]
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use tokio::sync::Mutex;
@@ -27,23 +26,7 @@ use super::UnifiedExecError;
use super::head_tail_buffer::HeadTailBuffer;
pub(crate) trait SpawnLifecycle: std::fmt::Debug + Send + Sync {
/// Returns file descriptors that must stay open across the child `exec()`.
///
/// The returned descriptors must already be valid in the parent process and
/// stay valid until `after_spawn()` runs, which is the first point where
/// the parent may release its copies.
fn inherited_fds(&self) -> Vec<i32> {
Vec::new()
}
fn after_spawn(&mut self) {}
/// Releases resources that needed to stay alive until the child process was
/// fully launched or until the session was torn down.
///
/// This hook must tolerate being called during normal process exit as well
/// as early termination paths, and it is guaranteed to run at most once.
fn after_exit(&mut self) {}
}
pub(crate) type SpawnLifecycleHandle = Box<dyn SpawnLifecycle>;
@@ -74,8 +57,7 @@ pub(crate) struct UnifiedExecProcess {
output_drained: Arc<Notify>,
output_task: JoinHandle<()>,
sandbox_type: SandboxType,
_spawn_lifecycle: Arc<StdMutex<SpawnLifecycleHandle>>,
spawn_lifecycle_released: Arc<AtomicBool>,
_spawn_lifecycle: SpawnLifecycleHandle,
}
impl UnifiedExecProcess {
@@ -83,7 +65,7 @@ impl UnifiedExecProcess {
process_handle: ExecCommandSession,
initial_output_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
sandbox_type: SandboxType,
spawn_lifecycle: Arc<StdMutex<SpawnLifecycleHandle>>,
spawn_lifecycle: SpawnLifecycleHandle,
) -> Self {
let output_buffer = Arc::new(Mutex::new(HeadTailBuffer::default()));
let output_notify = Arc::new(Notify::new());
@@ -128,19 +110,6 @@ impl UnifiedExecProcess {
output_task,
sandbox_type,
_spawn_lifecycle: spawn_lifecycle,
spawn_lifecycle_released: Arc::new(AtomicBool::new(false)),
}
}
fn release_spawn_lifecycle(&self) {
if self
.spawn_lifecycle_released
.swap(true, std::sync::atomic::Ordering::AcqRel)
{
return;
}
if let Ok(mut lifecycle) = self._spawn_lifecycle.lock() {
lifecycle.after_exit();
}
}
@@ -179,7 +148,6 @@ impl UnifiedExecProcess {
}
pub(super) fn terminate(&self) {
self.release_spawn_lifecycle();
self.output_closed.store(true, Ordering::Release);
self.output_closed_notify.notify_waiters();
self.process_handle.terminate();
@@ -255,19 +223,12 @@ impl UnifiedExecProcess {
mut exit_rx,
} = spawned;
let output_rx = codex_utils_pty::combine_output_receivers(stdout_rx, stderr_rx);
let spawn_lifecycle = Arc::new(StdMutex::new(spawn_lifecycle));
let managed = Self::new(
process_handle,
output_rx,
sandbox_type,
Arc::clone(&spawn_lifecycle),
);
let managed = Self::new(process_handle, output_rx, sandbox_type, spawn_lifecycle);
let exit_ready = matches!(exit_rx.try_recv(), Ok(_) | Err(TryRecvError::Closed));
if exit_ready {
managed.signal_exit();
managed.release_spawn_lifecycle();
managed.check_for_sandbox_denial().await?;
return Ok(managed);
}
@@ -277,22 +238,14 @@ impl UnifiedExecProcess {
.is_ok()
{
managed.signal_exit();
managed.release_spawn_lifecycle();
managed.check_for_sandbox_denial().await?;
return Ok(managed);
}
tokio::spawn({
let cancellation_token = managed.cancellation_token.clone();
let spawn_lifecycle = Arc::clone(&spawn_lifecycle);
let spawn_lifecycle_released = Arc::clone(&managed.spawn_lifecycle_released);
async move {
let _ = exit_rx.await;
if !spawn_lifecycle_released.swap(true, Ordering::AcqRel)
&& let Ok(mut lifecycle) = spawn_lifecycle.lock()
{
lifecycle.after_exit();
}
cancellation_token.cancel();
}
});

View File

@@ -537,27 +537,24 @@ impl UnifiedExecProcessManager {
.command
.split_first()
.ok_or(UnifiedExecError::MissingCommandLine)?;
let inherited_fds = spawn_lifecycle.inherited_fds();
let spawn_result = if tty {
codex_utils_pty::pty::spawn_process_with_inherited_fds(
codex_utils_pty::pty::spawn_process(
program,
args,
env.cwd.as_path(),
&env.env,
&env.arg0,
codex_utils_pty::TerminalSize::default(),
&inherited_fds,
)
.await
} else {
codex_utils_pty::pipe::spawn_process_no_stdin_with_inherited_fds(
codex_utils_pty::pipe::spawn_process_no_stdin(
program,
args,
env.cwd.as_path(),
&env.env,
&env.arg0,
&inherited_fds,
)
.await
};

View File

@@ -1,7 +1,6 @@
use anyhow::Context;
use std::fs;
use std::path::Path;
use std::process::Stdio;
use std::time::Duration;
pub async fn wait_for_pid_file(path: &Path) -> anyhow::Result<String> {
@@ -25,7 +24,6 @@ pub async fn wait_for_pid_file(path: &Path) -> anyhow::Result<String> {
pub fn process_is_alive(pid: &str) -> anyhow::Result<bool> {
let status = std::process::Command::new("kill")
.args(["-0", pid])
.stderr(Stdio::null())
.status()
.context("failed to probe process liveness with kill -0")?;
Ok(status.success())

View File

@@ -18,10 +18,6 @@ pub struct ZshForkRuntime {
}
impl ZshForkRuntime {
/// Path to the zsh binary this runtime was configured with.
pub fn zsh_path(&self) -> &Path {
    &self.zsh_path
}
fn apply_to_config(
&self,
config: &mut Config,
@@ -95,29 +91,6 @@ where
builder.build(server).await
}
/// Builds a test Codex session wired to `server` with the zsh-fork runtime
/// applied and the experimental unified-exec tool enabled.
///
/// `pre_build_hook` runs against the test home directory before the session
/// is built, letting callers stage fixtures (skills, scripts, etc.).
///
/// # Errors
/// Propagates any failure from building the underlying `TestCodex`.
pub async fn build_unified_exec_zsh_fork_test<F>(
    server: &wiremock::MockServer,
    runtime: ZshForkRuntime,
    approval_policy: AskForApproval,
    sandbox_policy: SandboxPolicy,
    pre_build_hook: F,
) -> Result<TestCodex>
where
    F: FnOnce(&Path) + Send + 'static,
{
    let mut builder = test_codex()
        .with_pre_build_hook(pre_build_hook)
        .with_config(move |config| {
            runtime.apply_to_config(config, approval_policy, sandbox_policy);
            config.use_experimental_unified_exec_tool = true;
            config
                .features
                .enable(Feature::UnifiedExec)
                .expect("test config should allow feature update");
        });
    builder.build(server).await
}
fn find_test_zsh_path() -> Result<Option<PathBuf>> {
let repo_root = codex_utils_cargo_bin::repo_root()?;
let dotslash_zsh = repo_root.join("codex-rs/app-server/tests/suite/zsh");

View File

@@ -35,7 +35,6 @@ use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_with_timeout;
use core_test_support::zsh_fork::build_unified_exec_zsh_fork_test;
use core_test_support::zsh_fork::build_zsh_fork_test;
use core_test_support::zsh_fork::restrictive_workspace_write_policy;
use core_test_support::zsh_fork::zsh_fork_runtime;
@@ -124,7 +123,7 @@ impl ActionKind {
let (path, _) = target.resolve_for_patch(test);
let _ = fs::remove_file(&path);
let command = format!("printf {content:?} > {path:?} && cat {path:?}");
let event = shell_event(call_id, &command, 5_000, sandbox_permissions)?;
let event = shell_event(call_id, &command, 1_000, sandbox_permissions)?;
Ok((event, Some(command)))
}
ActionKind::FetchUrl {
@@ -1986,158 +1985,6 @@ async fn approving_execpolicy_amendment_persists_policy_and_skips_future_prompts
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[cfg(unix)]
async fn unified_exec_zsh_fork_execpolicy_amendment_skips_later_subcommands() -> Result<()> {
skip_if_no_network!(Ok(()));
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork execpolicy amendment test")? else {
return Ok(());
};
let approval_policy = AskForApproval::UnlessTrusted;
let sandbox_policy = SandboxPolicy::new_read_only_policy();
let server = start_mock_server().await;
let test = build_unified_exec_zsh_fork_test(
&server,
runtime,
approval_policy,
sandbox_policy.clone(),
|_| {},
)
.await?;
let allow_prefix_path = test.cwd.path().join("allow-prefix-zsh-fork.txt");
let _ = fs::remove_file(&allow_prefix_path);
let call_id = "allow-prefix-zsh-fork";
let command = "touch allow-prefix-zsh-fork.txt && touch allow-prefix-zsh-fork.txt";
let event = exec_command_event(
call_id,
command,
Some(1_000),
SandboxPermissions::UseDefault,
None,
)?;
let _ = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-zsh-fork-allow-prefix-1"),
event,
ev_completed("resp-zsh-fork-allow-prefix-1"),
]),
)
.await;
let results = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-zsh-fork-allow-prefix-1", "done"),
ev_completed("resp-zsh-fork-allow-prefix-2"),
]),
)
.await;
submit_turn(
&test,
"allow-prefix-zsh-fork",
approval_policy,
sandbox_policy,
)
.await?;
let expected_execpolicy_amendment = ExecPolicyAmendment::new(vec![
"touch".to_string(),
"allow-prefix-zsh-fork.txt".to_string(),
]);
let mut saw_parent_approval = false;
let mut saw_subcommand_approval = false;
loop {
let event = wait_for_event(&test.codex, |event| {
matches!(
event,
EventMsg::ExecApprovalRequest(_) | EventMsg::TurnComplete(_)
)
})
.await;
match event {
EventMsg::TurnComplete(_) => break,
EventMsg::ExecApprovalRequest(approval) => {
let command_parts = approval.command.clone();
let last_arg = command_parts.last().map(String::as_str).unwrap_or_default();
if last_arg == command {
assert!(
!saw_parent_approval,
"unexpected duplicate parent approval: {command_parts:?}"
);
saw_parent_approval = true;
test.codex
.submit(Op::ExecApproval {
id: approval.effective_approval_id(),
turn_id: None,
decision: ReviewDecision::Approved,
})
.await?;
continue;
}
let is_touch_subcommand = command_parts
.iter()
.any(|part| part == "allow-prefix-zsh-fork.txt")
&& command_parts
.first()
.is_some_and(|part| part.ends_with("/touch") || part == "touch");
if is_touch_subcommand {
assert!(
!saw_subcommand_approval,
"execpolicy amendment should suppress later matching subcommand approvals: {command_parts:?}"
);
saw_subcommand_approval = true;
assert_eq!(
approval.proposed_execpolicy_amendment,
Some(expected_execpolicy_amendment.clone())
);
test.codex
.submit(Op::ExecApproval {
id: approval.effective_approval_id(),
turn_id: None,
decision: ReviewDecision::ApprovedExecpolicyAmendment {
proposed_execpolicy_amendment: expected_execpolicy_amendment
.clone(),
},
})
.await?;
continue;
}
test.codex
.submit(Op::ExecApproval {
id: approval.effective_approval_id(),
turn_id: None,
decision: ReviewDecision::Approved,
})
.await?;
}
other => panic!("unexpected event: {other:?}"),
}
}
assert!(saw_parent_approval, "expected parent unified-exec approval");
assert!(
saw_subcommand_approval,
"expected at least one intercepted touch approval"
);
let result = parse_result(&results.single_request().function_call_output(call_id));
assert_eq!(result.exit_code.unwrap_or(0), 0);
assert!(
allow_prefix_path.exists(),
"expected touch command to complete after approving the first intercepted subcommand; output: {}",
result.stdout
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[cfg(unix)]
async fn matched_prefix_rule_runs_unsandboxed_under_zsh_fork() -> Result<()> {

View File

@@ -20,7 +20,6 @@ use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_match;
use core_test_support::zsh_fork::build_unified_exec_zsh_fork_test;
use core_test_support::zsh_fork::build_zsh_fork_test;
use core_test_support::zsh_fork::restrictive_workspace_write_policy;
use core_test_support::zsh_fork::zsh_fork_runtime;
@@ -51,13 +50,6 @@ fn shell_command_arguments(command: &str) -> Result<String> {
}))?)
}
/// Serializes the `exec_command` tool-call arguments for `command`,
/// using the standard 500 ms yield window shared by these tests.
fn exec_command_arguments(command: &str) -> Result<String> {
    let arguments = json!({
        "cmd": command,
        "yield_time_ms": 500,
    });
    Ok(serde_json::to_string(&arguments)?)
}
async fn submit_turn_with_policies(
test: &TestCodex,
prompt: &str,
@@ -97,38 +89,6 @@ echo 'zsh-fork-stderr' >&2
)
}
#[cfg(unix)]
/// Creates a repo-local skill under `repo_root/.agents/skills/<name>` whose
/// `scripts/<script_name>` contains `script_contents`, and returns the
/// script's path.
///
/// Also writes a `.git` marker file at the repo root so the directory is
/// detected as a repository, and marks the script executable (0o755).
///
/// # Errors
/// Returns any filesystem error from creating directories, writing files,
/// or updating permissions.
fn write_repo_skill_with_shell_script_contents(
    repo_root: &Path,
    name: &str,
    script_name: &str,
    script_contents: &str,
) -> Result<PathBuf> {
    use std::os::unix::fs::PermissionsExt;
    let skill_dir = repo_root.join(".agents").join("skills").join(name);
    let scripts_dir = skill_dir.join("scripts");
    fs::create_dir_all(&scripts_dir)?;
    // A plain-file ".git" (gitfile) is enough for repo detection in tests.
    fs::write(repo_root.join(".git"), "gitdir: here")?;
    fs::write(
        skill_dir.join("SKILL.md"),
        format!(
            r#"---
name: {name}
description: {name} skill
---
"#
        ),
    )?;
    let script_path = scripts_dir.join(script_name);
    fs::write(&script_path, script_contents)?;
    // Mark the script executable so the shell can run it directly.
    let mut permissions = fs::metadata(&script_path)?.permissions();
    permissions.set_mode(0o755);
    fs::set_permissions(&script_path, permissions)?;
    Ok(script_path)
}
#[cfg(unix)]
fn write_skill_with_shell_script_contents(
home: &Path,
@@ -581,168 +541,6 @@ permissions:
Ok(())
}
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_zsh_fork_prompts_for_skill_script_execution() -> Result<()> {
skip_if_no_network!(Ok(()));
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork skill prompt test")? else {
return Ok(());
};
let server = start_mock_server().await;
let tool_call_id = "uexec-zsh-fork-skill-call";
let test = build_unified_exec_zsh_fork_test(
&server,
runtime,
AskForApproval::OnRequest,
SandboxPolicy::new_workspace_write_policy(),
|home| {
write_skill_with_shell_script(home, "mbolin-test-skill", "hello-mbolin.sh").unwrap();
write_skill_metadata(
home,
"mbolin-test-skill",
r#"
permissions:
file_system:
read:
- "./data"
write:
- "./output"
"#,
)
.unwrap();
},
)
.await?;
let (script_path_str, command) = skill_script_command(&test, "hello-mbolin.sh")?;
let arguments = exec_command_arguments(&command)?;
let mocks =
mount_function_call_agent_response(&server, tool_call_id, &arguments, "exec_command").await;
submit_turn_with_policies(
&test,
"use $mbolin-test-skill",
AskForApproval::OnRequest,
SandboxPolicy::new_workspace_write_policy(),
)
.await?;
let approval = wait_for_exec_approval_request(&test)
.await
.expect("expected exec approval request before completion");
assert_eq!(approval.call_id, tool_call_id);
assert_eq!(approval.command, vec![script_path_str.clone()]);
assert_eq!(
approval.available_decisions,
Some(vec![
ReviewDecision::Approved,
ReviewDecision::ApprovedForSession,
ReviewDecision::Abort,
])
);
assert_eq!(
approval.additional_permissions,
Some(PermissionProfile {
file_system: Some(FileSystemPermissions {
read: Some(vec![absolute_path(
&test.codex_home_path().join("skills/mbolin-test-skill/data"),
)]),
write: Some(vec![absolute_path(
&test
.codex_home_path()
.join("skills/mbolin-test-skill/output"),
)]),
}),
..Default::default()
})
);
test.codex
.submit(Op::ExecApproval {
id: approval.effective_approval_id(),
turn_id: None,
decision: ReviewDecision::Denied,
})
.await?;
wait_for_turn_complete(&test).await;
let call_output = mocks
.completion
.single_request()
.function_call_output(tool_call_id);
let output = call_output["output"].as_str().unwrap_or_default();
assert!(
output.contains("Execution denied: User denied execution"),
"expected rejection marker in function_call_output: {output:?}"
);
Ok(())
}
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_zsh_fork_keeps_skill_loading_pinned_to_turn_cwd() -> Result<()> {
skip_if_no_network!(Ok(()));
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork turn cwd skill test")? else {
return Ok(());
};
let server = start_mock_server().await;
let tool_call_id = "uexec-zsh-fork-repo-skill-call";
let test = build_unified_exec_zsh_fork_test(
&server,
runtime,
AskForApproval::OnRequest,
SandboxPolicy::new_workspace_write_policy(),
|_| {},
)
.await?;
let repo_root = test.cwd_path().join("repo");
let script_path = write_repo_skill_with_shell_script_contents(
&repo_root,
"repo-skill",
"repo-skill.sh",
"#!/bin/sh\necho 'repo-skill-output'\n",
)?;
let script_path_quoted = shlex::try_join([script_path.to_string_lossy().as_ref()])?;
let repo_root_quoted = shlex::try_join([repo_root.to_string_lossy().as_ref()])?;
let command = format!("cd {repo_root_quoted} && {script_path_quoted}");
let arguments = exec_command_arguments(&command)?;
let mocks =
mount_function_call_agent_response(&server, tool_call_id, &arguments, "exec_command").await;
submit_turn_with_policies(
&test,
"run the repo skill after changing directories",
AskForApproval::OnRequest,
SandboxPolicy::new_workspace_write_policy(),
)
.await?;
let approval = wait_for_exec_approval_request(&test).await;
assert!(
approval.is_none(),
"changing directories inside unified exec should not load repo-local skills from the shell cwd",
);
let call_output = mocks
.completion
.single_request()
.function_call_output(tool_call_id);
let output = call_output["output"].as_str().unwrap_or_default();
assert!(
output.contains("repo-skill-output"),
"expected repo skill script to run without skill-specific approval when only the shell cwd changes: {output:?}"
);
Ok(())
}
/// Permissionless skills should inherit the turn sandbox without prompting.
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]

View File

@@ -1,8 +1,6 @@
use std::collections::HashMap;
use std::ffi::OsStr;
use std::fs;
#[cfg(unix)]
use std::path::PathBuf;
use std::sync::OnceLock;
use anyhow::Context;
@@ -34,10 +32,6 @@ use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_match;
use core_test_support::wait_for_event_with_timeout;
#[cfg(unix)]
use core_test_support::zsh_fork::build_unified_exec_zsh_fork_test;
#[cfg(unix)]
use core_test_support::zsh_fork::zsh_fork_runtime;
use pretty_assertions::assert_eq;
use regex_lite::Regex;
use serde_json::Value;
@@ -161,27 +155,6 @@ fn collect_tool_outputs(bodies: &[Value]) -> Result<HashMap<String, ParsedUnifie
Ok(outputs)
}
#[cfg(unix)]
/// Resolves the executable ("txt") mapping of process `pid` via `lsof -Fn`.
///
/// `lsof` field output prefixes file names with `n`; the first such line is
/// taken as the binary path.
///
/// # Errors
/// Fails when `lsof` cannot be spawned, exits non-zero, emits non-UTF-8
/// output, or reports no text-binary line for the pid.
fn process_text_binary_path(pid: &str) -> Result<PathBuf> {
    let output = std::process::Command::new("lsof")
        .args(["-Fn", "-a", "-p", pid, "-d", "txt"])
        .output()
        .with_context(|| format!("failed to inspect process {pid} executable mapping with lsof"))?;
    if !output.status.success() {
        return Err(anyhow::anyhow!(
            "lsof failed for pid {pid} with status {:?}",
            output.status.code()
        ));
    }
    let stdout = String::from_utf8(output.stdout).context("lsof output was not UTF-8")?;
    // `-Fn` output: each file name line starts with the field tag 'n'.
    let path = stdout
        .lines()
        .find_map(|line| line.strip_prefix('n'))
        .ok_or_else(|| anyhow::anyhow!("lsof did not report a text binary path for pid {pid}"))?;
    Ok(PathBuf::from(path))
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_intercepts_apply_patch_exec_command() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -1350,6 +1323,7 @@ async fn exec_command_reports_chunk_and_exit_metadata() -> Result<()> {
.into_iter()
.map(|request| request.body_json())
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let metadata = outputs
.get(call_id)
@@ -1471,6 +1445,7 @@ async fn unified_exec_defaults_to_pipe() -> Result<()> {
.into_iter()
.map(|request| request.body_json())
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let output = outputs
.get(call_id)
@@ -1564,6 +1539,7 @@ async fn unified_exec_can_enable_tty() -> Result<()> {
.into_iter()
.map(|request| request.body_json())
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let output = outputs
.get(call_id)
@@ -1648,6 +1624,7 @@ async fn unified_exec_respects_early_exit_notifications() -> Result<()> {
.into_iter()
.map(|request| request.body_json())
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let output = outputs
.get(call_id)
@@ -1846,350 +1823,6 @@ async fn write_stdin_returns_exit_metadata_and_clears_session() -> Result<()> {
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[cfg(unix)]
async fn unified_exec_zsh_fork_keeps_python_repl_attached_to_zsh_session() -> Result<()> {
skip_if_no_network!(Ok(()));
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork tty session test")? else {
return Ok(());
};
let configured_zsh_path =
fs::canonicalize(runtime.zsh_path()).unwrap_or_else(|_| runtime.zsh_path().to_path_buf());
let python = match which("python3") {
Ok(path) => path,
Err(_) => {
eprintln!("python3 not found in PATH, skipping zsh-fork python repl test.");
return Ok(());
}
};
let server = start_mock_server().await;
let test = build_unified_exec_zsh_fork_test(
&server,
runtime,
AskForApproval::Never,
SandboxPolicy::new_workspace_write_policy(),
|_| {},
)
.await?;
let start_call_id = "uexec-zsh-fork-python-start";
let send_call_id = "uexec-zsh-fork-python-pid";
let exit_call_id = "uexec-zsh-fork-python-exit";
let start_command = format!("{}; :", python.display());
let start_args = serde_json::json!({
"cmd": start_command,
"yield_time_ms": 500,
"tty": true,
});
let send_args = serde_json::json!({
"chars": "import os; print('CODEX_PY_PID=' + str(os.getpid()))\r\n",
"session_id": 1000,
"yield_time_ms": 500,
});
let exit_args = serde_json::json!({
"chars": "import sys; sys.exit(0)\r\n",
"session_id": 1000,
"yield_time_ms": 500,
});
let responses = vec![
sse(vec![
ev_response_created("resp-1"),
ev_function_call(
start_call_id,
"exec_command",
&serde_json::to_string(&start_args)?,
),
ev_completed("resp-1"),
]),
sse(vec![
ev_response_created("resp-2"),
ev_function_call(
send_call_id,
"write_stdin",
&serde_json::to_string(&send_args)?,
),
ev_completed("resp-2"),
]),
sse(vec![
ev_assistant_message("msg-1", "python is running"),
ev_completed("resp-3"),
]),
sse(vec![
ev_response_created("resp-4"),
ev_function_call(
exit_call_id,
"write_stdin",
&serde_json::to_string(&exit_args)?,
),
ev_completed("resp-4"),
]),
sse(vec![
ev_assistant_message("msg-2", "all done"),
ev_completed("resp-5"),
]),
];
let request_log = mount_sse_sequence(&server, responses).await;
test.codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "test unified exec zsh-fork tty behavior".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: test.cwd_path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::new_workspace_write_policy(),
model: test.session_configured.model.clone(),
effort: None,
summary: None,
service_tier: None,
collaboration_mode: None,
personality: None,
})
.await?;
wait_for_event(&test.codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
let requests = request_log.requests();
assert!(!requests.is_empty(), "expected at least one POST request");
let bodies = requests
.into_iter()
.map(|request| request.body_json())
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let start_output = outputs
.get(start_call_id)
.expect("missing start output for exec_command");
let process_id = start_output
.process_id
.clone()
.expect("expected process id from exec_command");
assert!(
start_output.exit_code.is_none(),
"initial exec_command should leave the PTY session running"
);
let send_output = outputs
.get(send_call_id)
.expect("missing write_stdin output");
let normalized = send_output.output.replace("\r\n", "\n");
let python_pid = Regex::new(r"CODEX_PY_PID=(\d+)")
.expect("valid python pid marker regex")
.captures(&normalized)
.and_then(|captures| captures.get(1))
.map(|value| value.as_str().to_string())
.with_context(|| format!("missing python pid in output {normalized:?}"))?;
assert!(
process_is_alive(&python_pid)?,
"python process should still be alive after printing its pid, got output {normalized:?}"
);
assert_eq!(send_output.process_id.as_deref(), Some(process_id.as_str()));
assert!(
send_output.exit_code.is_none(),
"write_stdin should not report an exit code while the process is still running"
);
let zsh_pid = std::process::Command::new("ps")
.args(["-o", "ppid=", "-p", &python_pid])
.output()
.context("failed to look up python parent pid")?;
let zsh_pid = String::from_utf8(zsh_pid.stdout)
.context("python parent pid output is not UTF-8")?
.trim()
.to_string();
assert!(
!zsh_pid.is_empty(),
"expected python parent pid to identify the zsh session"
);
assert!(
process_is_alive(&zsh_pid)?,
"expected zsh parent process {zsh_pid} to still be alive"
);
let zsh_command = std::process::Command::new("ps")
.args(["-o", "command=", "-p", &zsh_pid])
.output()
.context("failed to look up zsh parent command")?;
let zsh_command =
String::from_utf8(zsh_command.stdout).context("zsh parent command output is not UTF-8")?;
assert!(
zsh_command.contains("zsh"),
"expected python parent command to be zsh, got {zsh_command:?}"
);
let zsh_text_binary = process_text_binary_path(&zsh_pid)?;
let zsh_text_binary = fs::canonicalize(&zsh_text_binary).unwrap_or(zsh_text_binary);
assert_eq!(
zsh_text_binary, configured_zsh_path,
"python parent shell should run with configured zsh-fork binary, got {:?} ({zsh_command:?})",
zsh_text_binary,
);
test.codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "shut down the python repl".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: test.cwd_path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::new_workspace_write_policy(),
model: test.session_configured.model.clone(),
effort: None,
summary: None,
service_tier: None,
collaboration_mode: None,
personality: None,
})
.await?;
wait_for_event(&test.codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
let requests = request_log.requests();
assert!(!requests.is_empty(), "expected at least one POST request");
let bodies = requests
.into_iter()
.map(|request| request.body_json())
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let exit_output = outputs
.get(exit_call_id)
.expect("missing exit output after requesting python shutdown");
assert!(
exit_output.exit_code.is_none() || exit_output.exit_code == Some(0),
"exit request should either leave cleanup to the background watcher or report success directly, got {exit_output:?}"
);
wait_for_process_exit(&python_pid).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[cfg(unix)]
async fn unified_exec_zsh_fork_rewrites_nested_zsh_exec_to_configured_binary() -> Result<()> {
skip_if_no_network!(Ok(()));
let Some(runtime) = zsh_fork_runtime("unified exec zsh-fork nested zsh rewrite test")? else {
return Ok(());
};
let configured_zsh_path =
fs::canonicalize(runtime.zsh_path()).unwrap_or_else(|_| runtime.zsh_path().to_path_buf());
let host_zsh = match which("zsh") {
Ok(path) => path,
Err(_) => {
eprintln!("zsh not found in PATH, skipping nested zsh rewrite test.");
return Ok(());
}
};
let server = start_mock_server().await;
let test = build_unified_exec_zsh_fork_test(
&server,
runtime,
AskForApproval::Never,
SandboxPolicy::new_workspace_write_policy(),
|_| {},
)
.await?;
let start_call_id = "uexec-zsh-fork-nested-start";
let nested_command = format!(
"exec {} -lc 'echo CODEX_NESTED_ZSH_PID=$$; sleep 3; :'",
host_zsh.display(),
);
let start_args = serde_json::json!({
"cmd": nested_command,
"yield_time_ms": 500,
"tty": true,
});
let responses = vec![
sse(vec![
ev_response_created("resp-nested-1"),
ev_function_call(
start_call_id,
"exec_command",
&serde_json::to_string(&start_args)?,
),
ev_completed("resp-nested-1"),
]),
sse(vec![
ev_assistant_message("msg-nested-1", "done"),
ev_completed("resp-nested-2"),
]),
];
let request_log = mount_sse_sequence(&server, responses).await;
test.codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "test nested zsh rewrite behavior".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: test.cwd_path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::new_workspace_write_policy(),
model: test.session_configured.model.clone(),
effort: None,
summary: None,
service_tier: None,
collaboration_mode: None,
personality: None,
})
.await?;
wait_for_event(&test.codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
let requests = request_log.requests();
let bodies = requests
.into_iter()
.map(|request| request.body_json())
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let start_output = outputs
.get(start_call_id)
.expect("missing start output for nested zsh exec_command");
let normalized = start_output.output.replace("\r\n", "\n");
let nested_zsh_pid = Regex::new(r"CODEX_NESTED_ZSH_PID=(\d+)")
.expect("valid nested zsh pid regex")
.captures(&normalized)
.and_then(|captures| captures.get(1))
.map(|value| value.as_str().to_string())
.with_context(|| format!("missing nested zsh pid marker in output {normalized:?}"))?;
assert!(
process_is_alive(&nested_zsh_pid)?,
"nested zsh process should be running before release, got output {normalized:?}"
);
let nested_text_binary = process_text_binary_path(&nested_zsh_pid)?;
let nested_text_binary = fs::canonicalize(&nested_text_binary).unwrap_or(nested_text_binary);
assert_eq!(
nested_text_binary, configured_zsh_path,
"nested zsh exec should be rewritten to configured zsh-fork binary, got {:?}",
nested_text_binary,
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_emits_end_event_when_session_dies_via_stdin() -> Result<()> {
skip_if_no_network!(Ok(()));

View File

@@ -28,8 +28,6 @@ pub use unix::ShellCommandExecutor;
#[cfg(unix)]
pub use unix::Stopwatch;
#[cfg(unix)]
pub use unix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
#[cfg(unix)]
pub use unix::main_execve_wrapper;
#[cfg(unix)]
pub use unix::run_shell_escalation_execve_wrapper;

View File

@@ -1,6 +1,6 @@
use std::io;
use std::os::fd::AsFd;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd as _;
use std::os::fd::OwnedFd;
use anyhow::Context as _;
@@ -28,12 +28,6 @@ fn get_escalate_client() -> anyhow::Result<AsyncDatagramSocket> {
Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?)
}
/// Duplicates `fd` into a fresh `OwnedFd` so it can be handed to another
/// process without closing the caller's copy; `name` labels the descriptor
/// in the error context.
fn duplicate_fd_for_transfer(fd: impl AsFd, name: &str) -> anyhow::Result<OwnedFd> {
    fd.as_fd()
        .try_clone_to_owned()
        .with_context(|| format!("failed to duplicate {name} for escalation transfer"))
}
pub async fn run_shell_escalation_execve_wrapper(
file: String,
argv: Vec<String>,
@@ -68,18 +62,11 @@ pub async fn run_shell_escalation_execve_wrapper(
.context("failed to receive EscalateResponse")?;
match message.action {
EscalateAction::Escalate => {
// Duplicate stdio before transferring ownership to the server. The
// wrapper must keep using its own stdin/stdout/stderr until the
// escalated child takes over.
let destination_fds = [
io::stdin().as_raw_fd(),
io::stdout().as_raw_fd(),
io::stderr().as_raw_fd(),
];
// TODO: maybe we should send ALL open FDs (except the escalate client)?
let fds_to_send = [
duplicate_fd_for_transfer(io::stdin(), "stdin")?,
duplicate_fd_for_transfer(io::stdout(), "stdout")?,
duplicate_fd_for_transfer(io::stderr(), "stderr")?,
unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) },
unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) },
unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) },
];
// TODO: also forward signals over the super-exec socket
@@ -87,7 +74,7 @@ pub async fn run_shell_escalation_execve_wrapper(
client
.send_with_fds(
SuperExecMessage {
fds: destination_fds.into_iter().collect(),
fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(),
},
&fds_to_send,
)
@@ -128,23 +115,3 @@ pub async fn run_shell_escalation_execve_wrapper(
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::os::fd::AsRawFd;
    use std::os::unix::net::UnixStream;
    // Duplicating a descriptor for transfer must yield a distinct fd and
    // leave the original open when the duplicate is dropped.
    #[test]
    fn duplicate_fd_for_transfer_does_not_close_original() {
        let (left, _right) = UnixStream::pair().expect("socket pair");
        let original_fd = left.as_raw_fd();
        let duplicate = duplicate_fd_for_transfer(&left, "test fd").expect("duplicate fd");
        assert_ne!(duplicate.as_raw_fd(), original_fd);
        drop(duplicate);
        // F_GETFD succeeds only on an open descriptor, proving the original
        // survived the duplicate's drop.
        assert_ne!(unsafe { libc::fcntl(original_fd, libc::F_GETFD) }, -1);
    }
}

View File

@@ -319,6 +319,16 @@ async fn handle_escalate_session_with_policy(
));
}
if msg
.fds
.iter()
.any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd))
{
return Err(anyhow::anyhow!(
"overlapping fds not yet supported in SuperExecMessage"
));
}
let PreparedExec {
command,
cwd,
@@ -802,93 +812,6 @@ mod tests {
server_task.await?
}
/// Test guard that temporarily closes a well-known descriptor (e.g. stdin)
/// and restores it on drop, so a test can observe behavior with that fd
/// number free.
#[cfg(unix)]
struct RestoredFd {
    // Descriptor number that was closed and must be restored.
    target_fd: i32,
    // Duplicate of the original descriptor, kept alive for restoration.
    original_fd: std::os::fd::OwnedFd,
}
#[cfg(unix)]
impl RestoredFd {
    /// Closes `target_fd` after saving a duplicate of it; the duplicate is
    /// dup2'd back over `target_fd` when the guard drops.
    fn close_temporarily(target_fd: i32) -> anyhow::Result<Self> {
        // SAFETY: dup on a raw fd; -1 signals failure and is checked below.
        let original_fd = unsafe { libc::dup(target_fd) };
        if original_fd == -1 {
            return Err(std::io::Error::last_os_error().into());
        }
        // SAFETY: closing the caller-supplied fd; failure is checked.
        if unsafe { libc::close(target_fd) } == -1 {
            let err = std::io::Error::last_os_error();
            // Don't leak the saved duplicate on the error path.
            unsafe {
                libc::close(original_fd);
            }
            return Err(err.into());
        }
        Ok(Self {
            target_fd,
            // SAFETY: `original_fd` is a valid descriptor we exclusively own.
            original_fd: unsafe { std::os::fd::OwnedFd::from_raw_fd(original_fd) },
        })
    }
}
#[cfg(unix)]
impl Drop for RestoredFd {
    fn drop(&mut self) {
        // SAFETY: dup2 restores the saved descriptor over the target number;
        // errors are ignored because drop cannot report them.
        unsafe {
            libc::dup2(self.original_fd.as_raw_fd(), self.target_fd);
        }
    }
}
#[tokio::test]
#[cfg(unix)]
async fn handle_escalate_session_accepts_received_fds_that_overlap_destinations()
-> anyhow::Result<()> {
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
let stdin_restore = RestoredFd::close_temporarily(libc::STDIN_FILENO)?;
let send_fd = std::os::fd::OwnedFd::from(std::fs::File::open("/dev/null")?);
let (server, client) = AsyncSocket::pair()?;
let server_task = tokio::spawn(handle_escalate_session_with_policy(
server,
Arc::new(DeterministicEscalationPolicy {
decision: EscalationDecision::escalate(EscalationExecution::Unsandboxed),
}),
Arc::new(ForwardingShellCommandExecutor),
CancellationToken::new(),
CancellationToken::new(),
));
client
.send(EscalateRequest {
file: PathBuf::from("/bin/sh"),
argv: vec!["sh".to_string(), "-c".to_string(), "exit 0".to_string()],
workdir: AbsolutePathBuf::current_dir()?,
env: HashMap::new(),
})
.await?;
let response = client.receive::<EscalateResponse>().await?;
assert_eq!(
EscalateResponse {
action: EscalateAction::Escalate,
},
response
);
client
.send_with_fds(
SuperExecMessage {
fds: vec![libc::STDIN_FILENO],
},
&[send_fd],
)
.await?;
let result = client.receive::<SuperExecResult>().await?;
assert_eq!(0, result.exit_code);
drop(stdin_restore);
server_task.await?
}
#[tokio::test]
async fn handle_escalate_session_passes_permissions_to_executor() -> anyhow::Result<()> {
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;

View File

@@ -102,15 +102,11 @@ async fn spawn_process_with_stdin_mode(
env: &HashMap<String, String>,
arg0: &Option<String>,
stdin_mode: PipeStdinMode,
inherited_fds: &[i32],
) -> Result<SpawnedProcess> {
if program.is_empty() {
anyhow::bail!("missing program for pipe spawn");
}
#[cfg(not(unix))]
let _ = inherited_fds;
let mut command = Command::new(program);
#[cfg(unix)]
if let Some(arg0) = arg0 {
@@ -119,14 +115,11 @@ async fn spawn_process_with_stdin_mode(
#[cfg(target_os = "linux")]
let parent_pid = unsafe { libc::getpid() };
#[cfg(unix)]
let inherited_fds = inherited_fds.to_vec();
#[cfg(unix)]
unsafe {
command.pre_exec(move || {
crate::process_group::detach_from_tty()?;
#[cfg(target_os = "linux")]
crate::process_group::set_parent_death_signal(parent_pid)?;
crate::pty::close_random_fds_except(&inherited_fds);
Ok(())
});
}
@@ -257,7 +250,7 @@ pub async fn spawn_process(
env: &HashMap<String, String>,
arg0: &Option<String>,
) -> Result<SpawnedProcess> {
spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Piped, &[]).await
spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Piped).await
}
/// Spawn a process using regular pipes, but close stdin immediately.
@@ -268,27 +261,5 @@ pub async fn spawn_process_no_stdin(
env: &HashMap<String, String>,
arg0: &Option<String>,
) -> Result<SpawnedProcess> {
spawn_process_no_stdin_with_inherited_fds(program, args, cwd, env, arg0, &[]).await
}
/// Spawn a process using regular pipes, close stdin immediately, and preserve
/// selected inherited file descriptors across exec on Unix.
pub async fn spawn_process_no_stdin_with_inherited_fds(
program: &str,
args: &[String],
cwd: &Path,
env: &HashMap<String, String>,
arg0: &Option<String>,
inherited_fds: &[i32],
) -> Result<SpawnedProcess> {
spawn_process_with_stdin_mode(
program,
args,
cwd,
env,
arg0,
PipeStdinMode::Null,
inherited_fds,
)
.await
spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Null).await
}

View File

@@ -1,7 +1,5 @@
use core::fmt;
use std::io;
#[cfg(unix)]
use std::os::fd::RawFd;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
@@ -43,24 +41,9 @@ impl From<TerminalSize> for PtySize {
}
}
#[cfg(unix)]
pub(crate) trait PtyHandleKeepAlive: Send {}
#[cfg(unix)]
impl<T: Send + ?Sized> PtyHandleKeepAlive for T {}
pub(crate) enum PtyMasterHandle {
Resizable(Box<dyn MasterPty + Send>),
#[cfg(unix)]
Opaque {
raw_fd: RawFd,
_handle: Box<dyn PtyHandleKeepAlive>,
},
}
pub struct PtyHandles {
pub _slave: Option<Box<dyn SlavePty + Send>>,
pub(crate) _master: PtyMasterHandle,
pub _master: Box<dyn MasterPty + Send>,
}
impl fmt::Debug for PtyHandles {
@@ -148,11 +131,7 @@ impl ProcessHandle {
let handles = handles
.as_ref()
.ok_or_else(|| anyhow!("process is not attached to a PTY"))?;
match &handles._master {
PtyMasterHandle::Resizable(master) => master.resize(size.into()),
#[cfg(unix)]
PtyMasterHandle::Opaque { raw_fd, .. } => resize_raw_pty(*raw_fd, size),
}
handles._master.resize(size.into())
}
/// Close the child's stdin channel.
@@ -205,21 +184,6 @@ impl Drop for ProcessHandle {
}
}
/// Apply `size` to the PTY referenced by `raw_fd` using the `TIOCSWINSZ`
/// ioctl, returning the OS error on failure.
#[cfg(unix)]
fn resize_raw_pty(raw_fd: RawFd, size: TerminalSize) -> anyhow::Result<()> {
    let mut window = libc::winsize {
        ws_row: size.rows,
        ws_col: size.cols,
        ws_xpixel: 0,
        ws_ypixel: 0,
    };
    // SAFETY: `window` is a valid, initialized winsize that outlives the call.
    match unsafe { libc::ioctl(raw_fd, libc::TIOCSWINSZ, &mut window) } {
        -1 => Err(std::io::Error::last_os_error().into()),
        _ => Ok(()),
    }
}
/// Combine split stdout/stderr receivers into a single broadcast receiver.
pub fn combine_output_receivers(
mut stdout_rx: mpsc::Receiver<Vec<u8>>,

View File

@@ -1,20 +1,6 @@
use std::collections::HashMap;
#[cfg(unix)]
use std::fs::File;
use std::io::ErrorKind;
#[cfg(unix)]
use std::os::fd::AsRawFd;
#[cfg(unix)]
use std::os::fd::FromRawFd;
#[cfg(unix)]
use std::os::fd::RawFd;
#[cfg(unix)]
use std::os::unix::process::CommandExt;
use std::path::Path;
#[cfg(unix)]
use std::process::Command as StdCommand;
#[cfg(unix)]
use std::process::Stdio;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
@@ -31,7 +17,6 @@ use tokio::task::JoinHandle;
use crate::process::ChildTerminator;
use crate::process::ProcessHandle;
use crate::process::PtyHandles;
use crate::process::PtyMasterHandle;
use crate::process::SpawnedProcess;
use crate::process::TerminalSize;
@@ -74,18 +59,6 @@ impl ChildTerminator for PtyChildTerminator {
}
}
#[cfg(unix)]
struct RawPidTerminator {
process_group_id: u32,
}
#[cfg(unix)]
impl ChildTerminator for RawPidTerminator {
fn kill(&mut self) -> std::io::Result<()> {
crate::process_group::kill_process_group(self.process_group_id)
}
}
fn platform_native_pty_system() -> Box<dyn portable_pty::PtySystem + Send> {
#[cfg(windows)]
{
@@ -106,45 +79,11 @@ pub async fn spawn_process(
env: &HashMap<String, String>,
arg0: &Option<String>,
size: TerminalSize,
) -> Result<SpawnedProcess> {
spawn_process_with_inherited_fds(program, args, cwd, env, arg0, size, &[]).await
}
/// Spawn a process attached to a PTY, preserving any inherited file
/// descriptors listed in `inherited_fds` across exec on Unix.
pub async fn spawn_process_with_inherited_fds(
program: &str,
args: &[String],
cwd: &Path,
env: &HashMap<String, String>,
arg0: &Option<String>,
size: TerminalSize,
inherited_fds: &[i32],
) -> Result<SpawnedProcess> {
if program.is_empty() {
anyhow::bail!("missing program for PTY spawn");
}
#[cfg(not(unix))]
let _ = inherited_fds;
#[cfg(unix)]
if !inherited_fds.is_empty() {
return spawn_process_preserving_fds(program, args, cwd, env, arg0, size, inherited_fds)
.await;
}
spawn_process_portable(program, args, cwd, env, arg0, size).await
}
async fn spawn_process_portable(
program: &str,
args: &[String],
cwd: &Path,
env: &HashMap<String, String>,
arg0: &Option<String>,
size: TerminalSize,
) -> Result<SpawnedProcess> {
let pty_system = platform_native_pty_system();
let pair = pty_system.openpty(size.into())?;
@@ -225,7 +164,7 @@ async fn spawn_process_portable(
} else {
None
},
_master: PtyMasterHandle::Resizable(pair.master),
_master: pair.master,
};
let handle = ProcessHandle::new(
@@ -251,225 +190,3 @@ async fn spawn_process_portable(
exit_rx,
})
}
/// Spawn `program` attached to a freshly opened PTY, keeping the descriptors
/// in `inherited_fds` open across `exec` (all other non-CLOEXEC fds are
/// closed in the child).
///
/// Unlike the portable-pty path, this drives `std::process::Command` directly
/// so the `pre_exec` hook can reset signal dispositions, start a new session,
/// claim the PTY as controlling terminal, and prune fds before `exec`.
#[cfg(unix)]
async fn spawn_process_preserving_fds(
    program: &str,
    args: &[String],
    cwd: &Path,
    env: &HashMap<String, String>,
    arg0: &Option<String>,
    size: TerminalSize,
    inherited_fds: &[RawFd],
) -> Result<SpawnedProcess> {
    let (master, slave) = open_unix_pty(size)?;
    let mut command = StdCommand::new(program);
    if let Some(arg0) = arg0 {
        command.arg0(arg0);
    }
    command.current_dir(cwd);
    // Start from an empty environment; only caller-provided vars apply.
    command.env_clear();
    for arg in args {
        command.arg(arg);
    }
    for (key, value) in env {
        command.env(key, value);
    }
    // Each stdio stream gets its own dup of the PTY slave.
    let stdin = slave.try_clone()?;
    let stdout = slave.try_clone()?;
    let stderr = slave.try_clone()?;
    let inherited_fds = inherited_fds.to_vec();
    // SAFETY: the pre_exec closure runs in the forked child and only calls
    // libc functions there.
    unsafe {
        command
            .stdin(Stdio::from(stdin))
            .stdout(Stdio::from(stdout))
            .stderr(Stdio::from(stderr))
            .pre_exec(move || {
                // Restore default dispositions for signals the parent may
                // have overridden, then unblock all signals.
                for signo in &[
                    libc::SIGCHLD,
                    libc::SIGHUP,
                    libc::SIGINT,
                    libc::SIGQUIT,
                    libc::SIGTERM,
                    libc::SIGALRM,
                ] {
                    libc::signal(*signo, libc::SIG_DFL);
                }
                let empty_set: libc::sigset_t = std::mem::zeroed();
                libc::sigprocmask(libc::SIG_SETMASK, &empty_set, std::ptr::null_mut());
                // New session: the child leads its own process group, which
                // the terminator kills as a group later.
                if libc::setsid() == -1 {
                    return Err(std::io::Error::last_os_error());
                }
                // Make the PTY (already installed as fd 0 via Stdio above)
                // the controlling terminal of the new session.
                #[allow(clippy::cast_lossless)]
                if libc::ioctl(0, libc::TIOCSCTTY as _, 0) == -1 {
                    return Err(std::io::Error::last_os_error());
                }
                close_random_fds_except(&inherited_fds);
                Ok(())
            });
    }
    let mut child = command.spawn()?;
    // Parent no longer needs the slave once the child holds its own dups;
    // dropping it lets reads on the master see EOF when the child exits.
    drop(slave);
    // After setsid, the child's pid is also its process-group id.
    let process_group_id = child.id();
    let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
    let (stdout_tx, stdout_rx) = mpsc::channel::<Vec<u8>>(128);
    // A PTY merges all child output onto the master; the stderr channel is
    // kept only for interface parity and never receives data.
    let (_stderr_tx, stderr_rx) = mpsc::channel::<Vec<u8>>(1);
    let mut reader = master.try_clone()?;
    // Blocking reader thread: forwards PTY output chunks to stdout_rx.
    let reader_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || {
        let mut buf = [0u8; 8_192];
        loop {
            match std::io::Read::read(&mut reader, &mut buf) {
                Ok(0) => break,
                Ok(n) => {
                    let _ = stdout_tx.blocking_send(buf[..n].to_vec());
                }
                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
                    // Master may be non-blocking; back off briefly and retry.
                    std::thread::sleep(Duration::from_millis(5));
                    continue;
                }
                Err(_) => break,
            }
        }
    });
    let writer = Arc::new(tokio::sync::Mutex::new(master.try_clone()?));
    // Async writer task: drains writer_rx into the PTY master.
    let writer_handle: JoinHandle<()> = tokio::spawn({
        let writer = Arc::clone(&writer);
        async move {
            while let Some(bytes) = writer_rx.recv().await {
                let mut guard = writer.lock().await;
                use std::io::Write;
                let _ = guard.write_all(&bytes);
                let _ = guard.flush();
            }
        }
    });
    let (exit_tx, exit_rx) = oneshot::channel::<i32>();
    let exit_status = Arc::new(AtomicBool::new(false));
    let wait_exit_status = Arc::clone(&exit_status);
    let exit_code = Arc::new(StdMutex::new(None));
    let wait_exit_code = Arc::clone(&exit_code);
    // Blocking waiter: records the exit code (-1 when unavailable) and
    // signals completion through both the flag and the oneshot channel.
    let wait_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || {
        let code = match child.wait() {
            Ok(status) => status.code().unwrap_or(-1),
            Err(_) => -1,
        };
        wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
        if let Ok(mut guard) = wait_exit_code.lock() {
            *guard = Some(code);
        }
        let _ = exit_tx.send(code);
    });
    let handles = PtyHandles {
        _slave: None,
        // Opaque handle: keeps the master File alive for the session's
        // lifetime and records its raw fd so resizes can use ioctl directly.
        _master: PtyMasterHandle::Opaque {
            raw_fd: master.as_raw_fd(),
            _handle: Box::new(master),
        },
    };
    let handle = ProcessHandle::new(
        writer_tx,
        Box::new(RawPidTerminator { process_group_id }),
        reader_handle,
        Vec::new(),
        writer_handle,
        wait_handle,
        exit_status,
        exit_code,
        Some(handles),
    );
    Ok(SpawnedProcess {
        session: handle,
        stdout_rx,
        stderr_rx,
        exit_rx,
    })
}
/// Open a master/slave PTY pair sized to `size` via `libc::openpty`, mark
/// both descriptors close-on-exec, and wrap them in owned `File`s
/// (returned as `(master, slave)`).
#[cfg(unix)]
fn open_unix_pty(size: TerminalSize) -> Result<(File, File)> {
    let mut master: RawFd = -1;
    let mut slave: RawFd = -1;
    let mut size = libc::winsize {
        ws_row: size.rows,
        ws_col: size.cols,
        ws_xpixel: 0,
        ws_ypixel: 0,
    };
    let winp = std::ptr::addr_of_mut!(size);
    // SAFETY: all out-pointers reference live locals; the name and termios
    // arguments are optional and passed as null.
    let result = unsafe {
        libc::openpty(
            &mut master,
            &mut slave,
            std::ptr::null_mut(),
            std::ptr::null_mut(),
            winp,
        )
    };
    if result != 0 {
        anyhow::bail!("failed to openpty: {:?}", std::io::Error::last_os_error());
    }
    // CLOEXEC so unrelated children do not inherit the PTY; the spawn path
    // hands the slave to its child explicitly via Stdio dups.
    set_cloexec(master)?;
    set_cloexec(slave)?;
    // SAFETY: openpty succeeded, so both fds are valid and owned only here.
    Ok(unsafe { (File::from_raw_fd(master), File::from_raw_fd(slave)) })
}
/// Mark `fd` close-on-exec by OR-ing `FD_CLOEXEC` into its descriptor flags.
#[cfg(unix)]
fn set_cloexec(fd: RawFd) -> std::io::Result<()> {
    // Fetch the current descriptor flags first so other bits are preserved.
    let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) };
    if flags == -1 {
        return Err(std::io::Error::last_os_error());
    }
    if unsafe { libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC) } == -1 {
        return Err(std::io::Error::last_os_error());
    }
    Ok(())
}
/// Close every open descriptor except stdio (0-2), the fds listed in
/// `preserved_fds`, and descriptors already marked close-on-exec.
///
/// Intended for use in a `pre_exec` hook so a spawned child does not inherit
/// stray descriptors. Best-effort: if `/dev/fd` cannot be read, nothing is
/// closed.
#[cfg(unix)]
pub(crate) fn close_random_fds_except(preserved_fds: &[RawFd]) {
    if let Ok(dir) = std::fs::read_dir("/dev/fd") {
        // Collect first, close afterwards: closing fds while the directory
        // stream is still being read could invalidate the iterator's own
        // underlying descriptor.
        let mut fds = Vec::new();
        for entry in dir {
            let num = entry
                .ok()
                .map(|entry| entry.file_name())
                .and_then(|name| name.into_string().ok())
                .and_then(|name| name.parse::<RawFd>().ok());
            if let Some(num) = num {
                if num <= 2 || preserved_fds.contains(&num) {
                    continue;
                }
                // Keep CLOEXEC descriptors open so std::process can still use
                // its internal exec-error pipe to report spawn failures.
                let flags = unsafe { libc::fcntl(num, libc::F_GETFD) };
                if flags == -1 || flags & libc::FD_CLOEXEC != 0 {
                    continue;
                }
                fds.push(num);
            }
        }
        for fd in fds {
            unsafe {
                libc::close(fd);
            }
        }
    }
}

View File

@@ -4,10 +4,6 @@ use std::path::Path;
use pretty_assertions::assert_eq;
use crate::combine_output_receivers;
#[cfg(unix)]
use crate::pipe::spawn_process_no_stdin_with_inherited_fds;
#[cfg(unix)]
use crate::pty::spawn_process_with_inherited_fds;
use crate::spawn_pipe_process;
use crate::spawn_pipe_process_no_stdin;
use crate::spawn_pty_process;
@@ -139,42 +135,6 @@ async fn collect_output_until_exit(
}
}
#[cfg(unix)]
async fn wait_for_output_contains(
output_rx: &mut tokio::sync::broadcast::Receiver<Vec<u8>>,
needle: &str,
timeout_ms: u64,
) -> anyhow::Result<Vec<u8>> {
let mut collected = Vec::new();
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
while tokio::time::Instant::now() < deadline {
let now = tokio::time::Instant::now();
let remaining = deadline.saturating_duration_since(now);
match tokio::time::timeout(remaining, output_rx.recv()).await {
Ok(Ok(chunk)) => {
collected.extend_from_slice(&chunk);
if String::from_utf8_lossy(&collected).contains(needle) {
return Ok(collected);
}
}
Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => continue,
Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => {
anyhow::bail!(
"PTY output closed while waiting for {needle:?}: {:?}",
String::from_utf8_lossy(&collected)
);
}
Err(_) => break,
}
}
anyhow::bail!(
"timed out waiting for {needle:?} in PTY output: {:?}",
String::from_utf8_lossy(&collected)
);
}
async fn wait_for_python_repl_ready(
output_rx: &mut tokio::sync::broadcast::Receiver<Vec<u8>>,
timeout_ms: u64,
@@ -210,58 +170,6 @@ async fn wait_for_python_repl_ready(
);
}
/// Repeatedly send a `print('<marker>')` probe to a Python REPL through
/// `writer` and collect PTY output until the marker is echoed back, proving
/// the REPL is accepting input. Bails if the output channel closes or
/// `timeout_ms` expires before the marker appears.
#[cfg(unix)]
async fn wait_for_python_repl_ready_via_probe(
    writer: &tokio::sync::mpsc::Sender<Vec<u8>>,
    output_rx: &mut tokio::sync::broadcast::Receiver<Vec<u8>>,
    timeout_ms: u64,
    newline: &str,
) -> anyhow::Result<Vec<u8>> {
    let mut collected = Vec::new();
    let marker = "__codex_pty_ready__";
    let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
    // Resend cadence for the probe; Windows gets a longer window (750ms vs
    // 250ms), presumably because its PTY output arrives more slowly.
    let probe_window = tokio::time::Duration::from_millis(if cfg!(windows) { 750 } else { 250 });
    while tokio::time::Instant::now() < deadline {
        writer
            .send(format!("print('{marker}'){newline}").into_bytes())
            .await?;
        let probe_deadline = tokio::time::Instant::now() + probe_window;
        // Inner loop: read output until either the overall deadline or the
        // current probe window expires, whichever comes first.
        loop {
            let now = tokio::time::Instant::now();
            if now >= deadline || now >= probe_deadline {
                break;
            }
            let remaining = std::cmp::min(
                deadline.saturating_duration_since(now),
                probe_deadline.saturating_duration_since(now),
            );
            match tokio::time::timeout(remaining, output_rx.recv()).await {
                Ok(Ok(chunk)) => {
                    collected.extend_from_slice(&chunk);
                    if String::from_utf8_lossy(&collected).contains(marker) {
                        return Ok(collected);
                    }
                }
                // Lagged only means chunks were dropped; keep reading.
                Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => continue,
                Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => {
                    anyhow::bail!(
                        "PTY output closed while waiting for Python REPL readiness: {:?}",
                        String::from_utf8_lossy(&collected)
                    );
                }
                Err(_) => break,
            }
        }
    }
    anyhow::bail!(
        "timed out waiting for Python REPL readiness in PTY: {:?}",
        String::from_utf8_lossy(&collected)
    );
}
#[cfg(unix)]
fn process_exists(pid: i32) -> anyhow::Result<bool> {
let result = unsafe { libc::kill(pid, 0) };
@@ -301,26 +209,16 @@ async fn wait_for_marker_pid(
collected.extend_from_slice(&chunk);
let text = String::from_utf8_lossy(&collected);
let mut offset = 0;
while let Some(pos) = text[offset..].find(marker) {
let marker_start = offset + pos;
let suffix = &text[marker_start + marker.len()..];
let digits_len = suffix
if let Some(marker_idx) = text.find(marker) {
let suffix = &text[marker_idx + marker.len()..];
let digits: String = suffix
.chars()
.skip_while(|ch| !ch.is_ascii_digit())
.take_while(char::is_ascii_digit)
.map(char::len_utf8)
.sum::<usize>();
if digits_len == 0 {
offset = marker_start + marker.len();
continue;
.collect();
if !digits.is_empty() {
return Ok(digits.parse()?);
}
let pid_str = &suffix[..digits_len];
let trailing = &suffix[digits_len..];
if trailing.is_empty() {
break;
}
return Ok(pid_str.parse()?);
}
}
}
@@ -671,276 +569,3 @@ async fn pty_terminate_kills_background_children_in_same_process_group() -> anyh
Ok(())
}
/// End-to-end check that a PTY child spawned with `inherited_fds` can write
/// through a preserved pipe fd that the fd-pruning would otherwise close
/// before exec.
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pty_spawn_can_preserve_inherited_fds() -> anyhow::Result<()> {
    use std::io::Read;
    use std::os::fd::AsRawFd;
    use std::os::fd::FromRawFd;
    let mut fds = [0; 2];
    let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
    if result != 0 {
        return Err(std::io::Error::last_os_error().into());
    }
    // SAFETY: pipe() succeeded, so both fds are valid and owned only here.
    let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
    let write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
    let mut env_map: HashMap<String, String> = std::env::vars().collect();
    // Tell the child which fd to write to via the environment.
    env_map.insert(
        "PRESERVED_FD".to_string(),
        write_end.as_raw_fd().to_string(),
    );
    let script = "printf __preserved__ >\"/dev/fd/$PRESERVED_FD\"";
    let spawned = spawn_process_with_inherited_fds(
        "/bin/sh",
        &["-c".to_string(), script.to_string()],
        Path::new("."),
        &env_map,
        &None,
        TerminalSize::default(),
        &[write_end.as_raw_fd()],
    )
    .await?;
    // Drop the parent's write end so the pipe reaches EOF once the child
    // exits and its copy closes.
    drop(write_end);
    let (_session, output_rx, exit_rx) = combine_spawned_output(spawned);
    let (_, code) = collect_output_until_exit(output_rx, exit_rx, 2_000).await;
    assert_eq!(code, 0, "expected preserved-fd PTY child to exit cleanly");
    let mut pipe_output = String::new();
    read_end.read_to_string(&mut pipe_output)?;
    assert_eq!(pipe_output, "__preserved__");
    Ok(())
}
/// Spawn a Python REPL on a PTY with a preserved pipe fd and check that the
/// interpreter stays alive and responsive to input, then exits cleanly when
/// asked.
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pty_preserving_inherited_fds_keeps_python_repl_running() -> anyhow::Result<()> {
    use std::os::fd::AsRawFd;
    use std::os::fd::FromRawFd;
    // Skip (not fail) on hosts without a python interpreter.
    let Some(python) = find_python() else {
        eprintln!(
            "python not found; skipping pty_preserving_inherited_fds_keeps_python_repl_running"
        );
        return Ok(());
    };
    let mut fds = [0; 2];
    let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
    if result != 0 {
        return Err(std::io::Error::last_os_error().into());
    }
    // SAFETY: pipe() succeeded, so both fds are valid and owned only here.
    let read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
    let preserved_fd = unsafe { std::fs::File::from_raw_fd(fds[1]) };
    let mut env_map: HashMap<String, String> = std::env::vars().collect();
    env_map.insert(
        "PRESERVED_FD".to_string(),
        preserved_fd.as_raw_fd().to_string(),
    );
    let spawned = spawn_process_with_inherited_fds(
        &python,
        &[],
        Path::new("."),
        &env_map,
        &None,
        TerminalSize::default(),
        &[preserved_fd.as_raw_fd()],
    )
    .await?;
    // The parent needs neither pipe end; the child's inherited copy is what
    // this test exercises.
    drop(read_end);
    drop(preserved_fd);
    let (session, mut output_rx, exit_rx) = combine_spawned_output(spawned);
    let writer = session.writer_sender();
    let newline = "\n";
    // Probe until the REPL echoes the readiness marker back.
    let mut output =
        wait_for_python_repl_ready_via_probe(&writer, &mut output_rx, 5_000, newline).await?;
    let marker = "__codex_preserved_py_pid:";
    // Ask the REPL for its pid so we can verify the process is still alive.
    writer
        .send(format!("import os; print('{marker}' + str(os.getpid())){newline}").into_bytes())
        .await?;
    let python_pid = match wait_for_marker_pid(&mut output_rx, marker, 2_000).await {
        Ok(pid) => pid,
        Err(err) => {
            // Don't leak the REPL when the pid never shows up.
            session.terminate();
            return Err(err);
        }
    };
    assert!(
        process_exists(python_pid)?,
        "expected python pid {python_pid} to stay alive after prompt output"
    );
    writer.send(format!("exit(){newline}").into_bytes()).await?;
    let (remaining_output, code) = collect_output_until_exit(output_rx, exit_rx, 5_000).await;
    // Output is accumulated mainly for debugging context on failure.
    output.extend_from_slice(&remaining_output);
    assert_eq!(code, 0, "expected python to exit cleanly");
    Ok(())
}
/// A missing executable must surface as a spawn error from
/// `spawn_process_with_inherited_fds` — i.e. the child-side fd pruning must
/// not break the mechanism std uses to report exec failures to the parent.
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pty_spawn_with_inherited_fds_reports_exec_failures() -> anyhow::Result<()> {
    use std::os::fd::AsRawFd;
    use std::os::fd::FromRawFd;
    let mut fds = [0; 2];
    let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
    if result != 0 {
        return Err(std::io::Error::last_os_error().into());
    }
    // SAFETY: pipe() succeeded, so both fds are valid and owned only here.
    let read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
    let write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
    let env_map: HashMap<String, String> = std::env::vars().collect();
    let spawn_result = spawn_process_with_inherited_fds(
        "/definitely/missing/command",
        &[],
        Path::new("."),
        &env_map,
        &None,
        TerminalSize::default(),
        &[write_end.as_raw_fd()],
    )
    .await;
    drop(read_end);
    drop(write_end);
    let err = match spawn_result {
        Ok(spawned) => {
            // Should be unreachable; clean up the stray child before failing.
            spawned.session.terminate();
            anyhow::bail!("missing executable unexpectedly spawned");
        }
        Err(err) => err,
    };
    let err_text = err.to_string();
    // Accept the platform-specific phrasings of ENOENT.
    assert!(
        err_text.contains("No such file")
            || err_text.contains("not found")
            || err_text.contains("os error 2"),
        "expected spawn error for missing executable, got: {err_text}",
    );
    Ok(())
}
/// Verifies resize support on the preserved-fd PTY path: the shell reports
/// its terminal size before and after `session.resize`, and the second
/// report must reflect the new dimensions.
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pty_spawn_with_inherited_fds_supports_resize() -> anyhow::Result<()> {
    use std::os::fd::AsRawFd;
    use std::os::fd::FromRawFd;
    let mut fds = [0; 2];
    let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
    if result != 0 {
        return Err(std::io::Error::last_os_error().into());
    }
    // SAFETY: pipe() succeeded, so both fds are valid and owned only here.
    let read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
    let write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
    let env_map: HashMap<String, String> = std::env::vars().collect();
    // The shell prints its tty size, blocks on a `read`, then prints the
    // size again — a before/after snapshot around the resize call.
    let script =
        "stty -echo; printf 'start:%s\\n' \"$(stty size)\"; IFS= read _line; printf 'after:%s\\n' \"$(stty size)\"";
    let spawned = spawn_process_with_inherited_fds(
        "/bin/sh",
        &["-c".to_string(), script.to_string()],
        Path::new("."),
        &env_map,
        &None,
        TerminalSize {
            rows: 31,
            cols: 101,
        },
        &[write_end.as_raw_fd()],
    )
    .await?;
    let (session, mut output_rx, exit_rx) = combine_spawned_output(spawned);
    let writer = session.writer_sender();
    // Wait for the initial report to confirm the spawn-time dimensions.
    let mut output = wait_for_output_contains(&mut output_rx, "start:31 101\r\n", 5_000).await?;
    session.resize(TerminalSize {
        rows: 45,
        cols: 132,
    })?;
    // Release the shell's `read` so it emits the post-resize size.
    writer.send(b"go\n".to_vec()).await?;
    session.close_stdin();
    let (remaining_output, code) = collect_output_until_exit(output_rx, exit_rx, 5_000).await;
    output.extend_from_slice(&remaining_output);
    let text = String::from_utf8_lossy(&output);
    // PTY output uses CRLF line endings; normalize before matching.
    let normalized = text.replace("\r\n", "\n");
    assert!(
        normalized.contains("after:45 132\n"),
        "expected resized PTY dimensions in output: {text:?}"
    );
    assert_eq!(code, 0, "expected shell to exit cleanly after resize");
    drop(read_end);
    drop(write_end);
    Ok(())
}
/// Pipe-path counterpart of the PTY preserved-fd test: a child spawned via
/// `spawn_process_no_stdin_with_inherited_fds` can write through a preserved
/// pipe fd.
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pipe_spawn_no_stdin_can_preserve_inherited_fds() -> anyhow::Result<()> {
    use std::io::Read;
    use std::os::fd::AsRawFd;
    use std::os::fd::FromRawFd;
    let mut fds = [0; 2];
    let result = unsafe { libc::pipe(fds.as_mut_ptr()) };
    if result != 0 {
        return Err(std::io::Error::last_os_error().into());
    }
    // SAFETY: pipe() succeeded, so both fds are valid and owned only here.
    let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
    let write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
    let mut env_map: HashMap<String, String> = std::env::vars().collect();
    // Tell the child which fd to write to via the environment.
    env_map.insert(
        "PRESERVED_FD".to_string(),
        write_end.as_raw_fd().to_string(),
    );
    let script = "printf __pipe_preserved__ >\"/dev/fd/$PRESERVED_FD\"";
    let spawned = spawn_process_no_stdin_with_inherited_fds(
        "/bin/sh",
        &["-c".to_string(), script.to_string()],
        Path::new("."),
        &env_map,
        &None,
        &[write_end.as_raw_fd()],
    )
    .await?;
    // Drop the parent's write end so the pipe reaches EOF once the child
    // exits and its copy closes.
    drop(write_end);
    let (_session, output_rx, exit_rx) = combine_spawned_output(spawned);
    let (_, code) = collect_output_until_exit(output_rx, exit_rx, 2_000).await;
    assert_eq!(code, 0, "expected preserved-fd pipe child to exit cleanly");
    let mut pipe_output = String::new();
    read_end.read_to_string(&mut pipe_output)?;
    assert_eq!(pipe_output, "__pipe_preserved__");
    Ok(())
}