Compare commits

..

1 Commits

Author SHA1 Message Date
celia-oai
584baeb550 changes 2026-03-12 21:46:56 -07:00
26 changed files with 756 additions and 972 deletions

View File

@@ -72,7 +72,7 @@ members = [
resolver = "2"
[workspace.package]
version = "0.115.0-alpha.18"
version = "0.0.0"
# Track the edition for all workspace crates in one place. Individual
# crates can still override this value, but keeping it here means new
# crates created with `cargo new -w ...` automatically inherit the 2024

View File

@@ -1,20 +1,16 @@
use crate::endpoint::realtime_websocket::protocol::ConversationFunctionCallOutputItem;
use crate::endpoint::realtime_websocket::protocol::ConversationItem;
use crate::endpoint::realtime_websocket::protocol::ConversationItemContent;
use crate::endpoint::realtime_websocket::protocol::ConversationItemPayload;
use crate::endpoint::realtime_websocket::protocol::ConversationMessageItem;
use crate::endpoint::realtime_websocket::protocol::RealtimeAudioFrame;
use crate::endpoint::realtime_websocket::protocol::RealtimeEvent;
use crate::endpoint::realtime_websocket::protocol::RealtimeEventParser;
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionMode;
use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptDelta;
use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptEntry;
use crate::endpoint::realtime_websocket::protocol::SessionAudio;
use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat;
use crate::endpoint::realtime_websocket::protocol::SessionAudioInput;
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput;
use crate::endpoint::realtime_websocket::protocol::SessionFunctionTool;
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
use crate::error::ApiError;
@@ -25,7 +21,6 @@ use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use http::HeaderValue;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
@@ -46,23 +41,6 @@ use tracing::trace;
use tungstenite::protocol::WebSocketConfig;
use url::Url;
const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000;
const REALTIME_AUDIO_VOICE: &str = "fathom";
const REALTIME_V1_SESSION_TYPE: &str = "quicksilver";
const REALTIME_V2_SESSION_TYPE: &str = "realtime";
const REALTIME_V2_CODEX_TOOL_NAME: &str = "codex";
const REALTIME_V2_CODEX_TOOL_DESCRIPTION: &str = "Delegate work to Codex and return the result.";
fn normalized_session_mode(
event_parser: RealtimeEventParser,
session_mode: RealtimeSessionMode,
) -> RealtimeSessionMode {
match event_parser {
RealtimeEventParser::V1 => RealtimeSessionMode::Conversational,
RealtimeEventParser::RealtimeV2 => session_mode,
}
}
struct WsStream {
tx_command: mpsc::Sender<WsCommand>,
pump_task: tokio::task::JoinHandle<()>,
@@ -219,7 +197,6 @@ pub struct RealtimeWebsocketConnection {
pub struct RealtimeWebsocketWriter {
stream: Arc<WsStream>,
is_closed: Arc<AtomicBool>,
event_parser: RealtimeEventParser,
}
#[derive(Clone)]
@@ -281,7 +258,6 @@ impl RealtimeWebsocketConnection {
writer: RealtimeWebsocketWriter {
stream: Arc::clone(&stream),
is_closed: Arc::clone(&is_closed),
event_parser,
},
events: RealtimeWebsocketEvents {
rx_message: Arc::new(Mutex::new(rx_message)),
@@ -300,19 +276,15 @@ impl RealtimeWebsocketWriter {
}
pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> {
let content_kind = match self.event_parser {
RealtimeEventParser::V1 => "text",
RealtimeEventParser::RealtimeV2 => "input_text",
};
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
item: ConversationItemPayload::Message(ConversationMessageItem {
item: ConversationItem {
kind: "message".to_string(),
role: "user".to_string(),
content: vec![ConversationItemContent {
kind: content_kind.to_string(),
kind: "text".to_string(),
text,
}],
}),
},
})
.await
}
@@ -322,80 +294,29 @@ impl RealtimeWebsocketWriter {
handoff_id: String,
output_text: String,
) -> Result<(), ApiError> {
let message = match self.event_parser {
RealtimeEventParser::V1 => RealtimeOutboundMessage::ConversationHandoffAppend {
handoff_id,
output_text,
},
RealtimeEventParser::RealtimeV2 => RealtimeOutboundMessage::ConversationItemCreate {
item: ConversationItemPayload::FunctionCallOutput(
ConversationFunctionCallOutputItem {
kind: "function_call_output".to_string(),
call_id: handoff_id,
output: output_text,
},
),
},
};
self.send_json(message).await
self.send_json(RealtimeOutboundMessage::ConversationHandoffAppend {
handoff_id,
output_text,
})
.await
}
pub async fn send_session_update(
&self,
instructions: String,
session_mode: RealtimeSessionMode,
) -> Result<(), ApiError> {
let session_mode = normalized_session_mode(self.event_parser, session_mode);
let (session_kind, session_instructions, output_audio) = match session_mode {
RealtimeSessionMode::Conversational => {
let kind = match self.event_parser {
RealtimeEventParser::V1 => REALTIME_V1_SESSION_TYPE.to_string(),
RealtimeEventParser::RealtimeV2 => REALTIME_V2_SESSION_TYPE.to_string(),
};
(
kind,
Some(instructions),
Some(SessionAudioOutput {
voice: REALTIME_AUDIO_VOICE.to_string(),
}),
)
}
RealtimeSessionMode::Transcription => ("transcription".to_string(), None, None),
};
let tools = match self.event_parser {
RealtimeEventParser::RealtimeV2 => Some(vec![SessionFunctionTool {
kind: "function".to_string(),
name: REALTIME_V2_CODEX_TOOL_NAME.to_string(),
description: REALTIME_V2_CODEX_TOOL_DESCRIPTION.to_string(),
parameters: json!({
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Prompt text for the delegated Codex task."
}
},
"required": ["prompt"],
"additionalProperties": false
}),
}]),
RealtimeEventParser::V1 => None,
};
pub async fn send_session_update(&self, instructions: String) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::SessionUpdate {
session: SessionUpdateSession {
kind: session_kind,
instructions: session_instructions,
kind: "quicksilver".to_string(),
instructions,
audio: SessionAudio {
input: SessionAudioInput {
format: SessionAudioFormat {
kind: "audio/pcm".to_string(),
rate: REALTIME_AUDIO_SAMPLE_RATE,
rate: 24_000,
},
},
output: output_audio,
output: SessionAudioOutput {
voice: "fathom".to_string(),
},
},
tools,
},
})
.await
@@ -544,8 +465,6 @@ impl RealtimeWebsocketClient {
self.provider.base_url.as_str(),
self.provider.query_params.as_ref(),
config.model.as_deref(),
config.event_parser,
config.session_mode,
)?;
let mut request = ws_url
@@ -587,7 +506,7 @@ impl RealtimeWebsocketClient {
);
connection
.writer
.send_session_update(config.instructions, config.session_mode)
.send_session_update(config.instructions)
.await?;
Ok(connection)
}
@@ -632,8 +551,6 @@ fn websocket_url_from_api_url(
api_url: &str,
query_params: Option<&HashMap<String, String>>,
model: Option<&str>,
event_parser: RealtimeEventParser,
_session_mode: RealtimeSessionMode,
) -> Result<Url, ApiError> {
let mut url = Url::parse(api_url)
.map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?;
@@ -653,20 +570,9 @@ fn websocket_url_from_api_url(
}
}
let intent = match event_parser {
RealtimeEventParser::V1 => Some("quicksilver"),
RealtimeEventParser::RealtimeV2 => None,
};
let has_extra_query_params = query_params.is_some_and(|query_params| {
query_params
.iter()
.any(|(key, _)| key != "intent" && !(key == "model" && model.is_some()))
});
if intent.is_some() || model.is_some() || has_extra_query_params {
{
let mut query = url.query_pairs_mut();
if let Some(intent) = intent {
query.append_pair("intent", intent);
}
query.append_pair("intent", "quicksilver");
if let Some(model) = model {
query.append_pair("model", model);
}
@@ -947,14 +853,8 @@ mod tests {
#[test]
fn websocket_url_from_http_base_defaults_to_ws_path() {
let url = websocket_url_from_api_url(
"http://127.0.0.1:8011",
None,
None,
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
let url =
websocket_url_from_api_url("http://127.0.0.1:8011", None, None).expect("build ws url");
assert_eq!(
url.as_str(),
"ws://127.0.0.1:8011/v1/realtime?intent=quicksilver"
@@ -963,14 +863,9 @@ mod tests {
#[test]
fn websocket_url_from_ws_base_defaults_to_ws_path() {
let url = websocket_url_from_api_url(
"wss://example.com",
None,
Some("realtime-test-model"),
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
let url =
websocket_url_from_api_url("wss://example.com", None, Some("realtime-test-model"))
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?intent=quicksilver&model=realtime-test-model"
@@ -979,14 +874,8 @@ mod tests {
#[test]
fn websocket_url_from_v1_base_appends_realtime_path() {
let url = websocket_url_from_api_url(
"https://api.openai.com/v1",
None,
Some("snapshot"),
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
let url = websocket_url_from_api_url("https://api.openai.com/v1", None, Some("snapshot"))
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://api.openai.com/v1/realtime?intent=quicksilver&model=snapshot"
@@ -995,14 +884,9 @@ mod tests {
#[test]
fn websocket_url_from_nested_v1_base_appends_realtime_path() {
let url = websocket_url_from_api_url(
"https://example.com/openai/v1",
None,
Some("snapshot"),
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
let url =
websocket_url_from_api_url("https://example.com/openai/v1", None, Some("snapshot"))
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/openai/v1/realtime?intent=quicksilver&model=snapshot"
@@ -1018,8 +902,6 @@ mod tests {
("intent".to_string(), "ignored".to_string()),
])),
Some("snapshot"),
RealtimeEventParser::V1,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
assert_eq!(
@@ -1028,54 +910,6 @@ mod tests {
);
}
#[test]
fn websocket_url_v1_ignores_transcription_mode() {
let url = websocket_url_from_api_url(
"https://example.com",
None,
None,
RealtimeEventParser::V1,
RealtimeSessionMode::Transcription,
)
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?intent=quicksilver"
);
}
#[test]
fn websocket_url_omits_intent_for_realtime_v2_conversational_mode() {
let url = websocket_url_from_api_url(
"https://example.com/v1/realtime?foo=bar",
Some(&HashMap::from([
("trace".to_string(), "1".to_string()),
("intent".to_string(), "ignored".to_string()),
])),
Some("snapshot"),
RealtimeEventParser::RealtimeV2,
RealtimeSessionMode::Conversational,
)
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?foo=bar&model=snapshot&trace=1"
);
}
#[test]
fn websocket_url_omits_intent_for_realtime_v2_transcription_mode() {
let url = websocket_url_from_api_url(
"https://example.com",
None,
None,
RealtimeEventParser::RealtimeV2,
RealtimeSessionMode::Transcription,
)
.expect("build ws url");
assert_eq!(url.as_str(), "wss://example.com/v1/realtime");
}
#[tokio::test]
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
@@ -1241,7 +1075,6 @@ mod tests {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -1362,352 +1195,6 @@ mod tests {
server.await.expect("server task");
}
#[tokio::test]
async fn realtime_v2_session_update_includes_codex_tool_and_handoff_output_item() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("local addr");
let server = tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("accept");
let mut ws = accept_async(stream).await.expect("accept ws");
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["type"],
Value::String("realtime".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["type"],
Value::String("function".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["name"],
Value::String("codex".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["parameters"]["required"],
json!(["prompt"])
);
ws.send(Message::Text(
json!({
"type": "session.updated",
"session": {"id": "sess_v2", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.updated");
let second = ws
.next()
.await
.expect("second msg")
.expect("second msg ok")
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "conversation.item.create");
assert_eq!(
second_json["item"]["type"],
Value::String("message".to_string())
);
assert_eq!(
second_json["item"]["content"][0]["type"],
Value::String("input_text".to_string())
);
assert_eq!(
second_json["item"]["content"][0]["text"],
Value::String("delegate this".to_string())
);
let third = ws
.next()
.await
.expect("third msg")
.expect("third msg ok")
.into_text()
.expect("text");
let third_json: Value = serde_json::from_str(&third).expect("json");
assert_eq!(third_json["type"], "conversation.item.create");
assert_eq!(
third_json["item"]["type"],
Value::String("function_call_output".to_string())
);
assert_eq!(
third_json["item"]["call_id"],
Value::String("call_1".to_string())
);
assert_eq!(
third_json["item"]["output"],
Value::String("delegated result".to_string())
);
});
let provider = Provider {
name: "test".to_string(),
base_url: format!("http://{addr}"),
query_params: Some(HashMap::new()),
headers: HeaderMap::new(),
retry: crate::provider::RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: false,
retry_transport: false,
},
stream_idle_timeout: Duration::from_secs(5),
};
let client = RealtimeWebsocketClient::new(provider);
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::RealtimeV2,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let created = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
created,
RealtimeEvent::SessionUpdated {
session_id: "sess_v2".to_string(),
instructions: Some("backend prompt".to_string()),
}
);
connection
.send_conversation_item_create("delegate this".to_string())
.await
.expect("send text item");
connection
.send_conversation_handoff_append("call_1".to_string(), "delegated result".to_string())
.await
.expect("send handoff output");
connection.close().await.expect("close");
server.await.expect("server task");
}
#[tokio::test]
async fn transcription_mode_session_update_omits_output_audio_and_instructions() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("local addr");
let server = tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("accept");
let mut ws = accept_async(stream).await.expect("accept ws");
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["type"],
Value::String("transcription".to_string())
);
assert!(first_json["session"].get("instructions").is_none());
assert!(first_json["session"]["audio"].get("output").is_none());
assert_eq!(
first_json["session"]["tools"][0]["name"],
Value::String("codex".to_string())
);
ws.send(Message::Text(
json!({
"type": "session.updated",
"session": {"id": "sess_transcription"}
})
.to_string()
.into(),
))
.await
.expect("send session.updated");
let second = ws
.next()
.await
.expect("second msg")
.expect("second msg ok")
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "input_audio_buffer.append");
});
let provider = Provider {
name: "test".to_string(),
base_url: format!("http://{addr}"),
query_params: Some(HashMap::new()),
headers: HeaderMap::new(),
retry: crate::provider::RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: false,
retry_transport: false,
},
stream_idle_timeout: Duration::from_secs(5),
};
let client = RealtimeWebsocketClient::new(provider);
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::RealtimeV2,
session_mode: RealtimeSessionMode::Transcription,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let created = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
created,
RealtimeEvent::SessionUpdated {
session_id: "sess_transcription".to_string(),
instructions: None,
}
);
connection
.send_audio_frame(RealtimeAudioFrame {
data: "AQID".to_string(),
sample_rate: 24_000,
num_channels: 1,
samples_per_channel: Some(480),
})
.await
.expect("send audio");
connection.close().await.expect("close");
server.await.expect("server task");
}
#[tokio::test]
async fn v1_transcription_mode_is_treated_as_conversational() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("local addr");
let server = tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("accept");
let mut ws = accept_async(stream).await.expect("accept ws");
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["type"],
Value::String("quicksilver".to_string())
);
assert_eq!(
first_json["session"]["instructions"],
Value::String("backend prompt".to_string())
);
assert_eq!(
first_json["session"]["audio"]["output"]["voice"],
Value::String("fathom".to_string())
);
assert!(first_json["session"].get("tools").is_none());
ws.send(Message::Text(
json!({
"type": "session.updated",
"session": {"id": "sess_v1_mode"}
})
.to_string()
.into(),
))
.await
.expect("send session.updated");
});
let provider = Provider {
name: "test".to_string(),
base_url: format!("http://{addr}"),
query_params: Some(HashMap::new()),
headers: HeaderMap::new(),
retry: crate::provider::RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: false,
retry_transport: false,
},
stream_idle_timeout: Duration::from_secs(5),
};
let client = RealtimeWebsocketClient::new(provider);
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Transcription,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let created = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
created,
RealtimeEvent::SessionUpdated {
session_id: "sess_v1_mode".to_string(),
instructions: None,
}
);
connection.close().await.expect("close");
server.await.expect("server task");
}
#[tokio::test]
async fn send_does_not_block_while_next_event_waits_for_inbound_data() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
@@ -1771,7 +1258,6 @@ mod tests {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),

View File

@@ -1,7 +1,5 @@
pub mod methods;
pub mod protocol;
mod protocol_common;
mod protocol_v1;
mod protocol_v2;
pub use codex_protocol::protocol::RealtimeAudioFrame;
@@ -12,4 +10,3 @@ pub use methods::RealtimeWebsocketEvents;
pub use methods::RealtimeWebsocketWriter;
pub use protocol::RealtimeEventParser;
pub use protocol::RealtimeSessionConfig;
pub use protocol::RealtimeSessionMode;

View File

@@ -1,4 +1,3 @@
use crate::endpoint::realtime_websocket::protocol_v1::parse_realtime_event_v1;
use crate::endpoint::realtime_websocket::protocol_v2::parse_realtime_event_v2;
pub use codex_protocol::protocol::RealtimeAudioFrame;
pub use codex_protocol::protocol::RealtimeEvent;
@@ -7,6 +6,7 @@ pub use codex_protocol::protocol::RealtimeTranscriptDelta;
pub use codex_protocol::protocol::RealtimeTranscriptEntry;
use serde::Serialize;
use serde_json::Value;
use tracing::debug;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RealtimeEventParser {
@@ -14,19 +14,12 @@ pub enum RealtimeEventParser {
RealtimeV2,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RealtimeSessionMode {
Conversational,
Transcription,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RealtimeSessionConfig {
pub instructions: String,
pub model: Option<String>,
pub session_id: Option<String>,
pub event_parser: RealtimeEventParser,
pub session_mode: RealtimeSessionMode,
}
#[derive(Debug, Clone, Serialize)]
@@ -42,25 +35,21 @@ pub(super) enum RealtimeOutboundMessage {
#[serde(rename = "session.update")]
SessionUpdate { session: SessionUpdateSession },
#[serde(rename = "conversation.item.create")]
ConversationItemCreate { item: ConversationItemPayload },
ConversationItemCreate { item: ConversationItem },
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionUpdateSession {
#[serde(rename = "type")]
pub(super) kind: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) instructions: Option<String>,
pub(super) instructions: String,
pub(super) audio: SessionAudio,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) tools: Option<Vec<SessionFunctionTool>>,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionAudio {
pub(super) input: SessionAudioInput,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) output: Option<SessionAudioOutput>,
pub(super) output: SessionAudioOutput,
}
#[derive(Debug, Clone, Serialize)]
@@ -81,28 +70,13 @@ pub(super) struct SessionAudioOutput {
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct ConversationMessageItem {
pub(super) struct ConversationItem {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) role: String,
pub(super) content: Vec<ConversationItemContent>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub(super) enum ConversationItemPayload {
Message(ConversationMessageItem),
FunctionCallOutput(ConversationFunctionCallOutputItem),
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct ConversationFunctionCallOutputItem {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) call_id: String,
pub(super) output: String,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct ConversationItemContent {
#[serde(rename = "type")]
@@ -110,15 +84,6 @@ pub(super) struct ConversationItemContent {
pub(super) text: String,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionFunctionTool {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) name: String,
pub(super) description: String,
pub(super) parameters: Value,
}
pub(super) fn parse_realtime_event(
payload: &str,
event_parser: RealtimeEventParser,
@@ -128,3 +93,125 @@ pub(super) fn parse_realtime_event(
RealtimeEventParser::RealtimeV2 => parse_realtime_event_v2(payload),
}
}
fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
let parsed: Value = match serde_json::from_str(payload) {
Ok(msg) => msg,
Err(err) => {
debug!("failed to parse realtime event: {err}, data: {payload}");
return None;
}
};
let message_type = match parsed.get("type").and_then(Value::as_str) {
Some(message_type) => message_type,
None => {
debug!("received realtime event without type field: {payload}");
return None;
}
};
match message_type {
"session.updated" => {
let session_id = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("id"))
.and_then(Value::as_str)
.map(str::to_string);
let instructions = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("instructions"))
.and_then(Value::as_str)
.map(str::to_string);
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
session_id,
instructions,
})
}
"conversation.output_audio.delta" => {
let data = parsed
.get("delta")
.and_then(Value::as_str)
.or_else(|| parsed.get("data").and_then(Value::as_str))
.map(str::to_string)?;
let sample_rate = parsed
.get("sample_rate")
.and_then(Value::as_u64)
.and_then(|v| u32::try_from(v).ok())?;
let num_channels = parsed
.get("channels")
.or_else(|| parsed.get("num_channels"))
.and_then(Value::as_u64)
.and_then(|v| u16::try_from(v).ok())?;
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
data,
sample_rate,
num_channels,
samples_per_channel: parsed
.get("samples_per_channel")
.and_then(Value::as_u64)
.and_then(|v| u32::try_from(v).ok()),
}))
}
"conversation.input_transcript.delta" => parsed
.get("delta")
.and_then(Value::as_str)
.map(str::to_string)
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
"conversation.output_transcript.delta" => parsed
.get("delta")
.and_then(Value::as_str)
.map(str::to_string)
.map(|delta| RealtimeEvent::OutputTranscriptDelta(RealtimeTranscriptDelta { delta })),
"conversation.item.added" => parsed
.get("item")
.cloned()
.map(RealtimeEvent::ConversationItemAdded),
"conversation.item.done" => parsed
.get("item")
.and_then(Value::as_object)
.and_then(|item| item.get("id"))
.and_then(Value::as_str)
.map(str::to_string)
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
"conversation.handoff.requested" => {
let handoff_id = parsed
.get("handoff_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let item_id = parsed
.get("item_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let input_transcript = parsed
.get("input_transcript")
.and_then(Value::as_str)
.map(str::to_string)?;
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id,
item_id,
input_transcript,
active_transcript: Vec::new(),
}))
}
"error" => parsed
.get("message")
.and_then(Value::as_str)
.map(str::to_string)
.or_else(|| {
parsed
.get("error")
.and_then(Value::as_object)
.and_then(|error| error.get("message"))
.and_then(Value::as_str)
.map(str::to_string)
})
.or_else(|| parsed.get("error").map(std::string::ToString::to_string))
.map(RealtimeEvent::Error),
_ => {
debug!("received unsupported realtime event type: {message_type}, data: {payload}");
None
}
}
}

View File

@@ -1,71 +0,0 @@
use codex_protocol::protocol::RealtimeEvent;
use codex_protocol::protocol::RealtimeTranscriptDelta;
use serde_json::Value;
use tracing::debug;
pub(super) fn parse_realtime_payload(payload: &str, parser_name: &str) -> Option<(Value, String)> {
let parsed: Value = match serde_json::from_str(payload) {
Ok(message) => message,
Err(err) => {
debug!("failed to parse {parser_name} event: {err}, data: {payload}");
return None;
}
};
let message_type = match parsed.get("type").and_then(Value::as_str) {
Some(message_type) => message_type.to_string(),
None => {
debug!("received {parser_name} event without type field: {payload}");
return None;
}
};
Some((parsed, message_type))
}
pub(super) fn parse_session_updated_event(parsed: &Value) -> Option<RealtimeEvent> {
let session_id = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("id"))
.and_then(Value::as_str)
.map(str::to_string)?;
let instructions = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("instructions"))
.and_then(Value::as_str)
.map(str::to_string);
Some(RealtimeEvent::SessionUpdated {
session_id,
instructions,
})
}
pub(super) fn parse_transcript_delta_event(
parsed: &Value,
field: &str,
) -> Option<RealtimeTranscriptDelta> {
parsed
.get(field)
.and_then(Value::as_str)
.map(str::to_string)
.map(|delta| RealtimeTranscriptDelta { delta })
}
pub(super) fn parse_error_event(parsed: &Value) -> Option<RealtimeEvent> {
parsed
.get("message")
.and_then(Value::as_str)
.map(str::to_string)
.or_else(|| {
parsed
.get("error")
.and_then(Value::as_object)
.and_then(|error| error.get("message"))
.and_then(Value::as_str)
.map(str::to_string)
})
.or_else(|| parsed.get("error").map(ToString::to_string))
.map(RealtimeEvent::Error)
}

View File

@@ -1,83 +0,0 @@
use crate::endpoint::realtime_websocket::protocol_common::parse_error_event;
use crate::endpoint::realtime_websocket::protocol_common::parse_realtime_payload;
use crate::endpoint::realtime_websocket::protocol_common::parse_session_updated_event;
use crate::endpoint::realtime_websocket::protocol_common::parse_transcript_delta_event;
use codex_protocol::protocol::RealtimeAudioFrame;
use codex_protocol::protocol::RealtimeEvent;
use codex_protocol::protocol::RealtimeHandoffRequested;
use serde_json::Value;
use tracing::debug;
pub(super) fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
let (parsed, message_type) = parse_realtime_payload(payload, "realtime v1")?;
match message_type.as_str() {
"session.updated" => parse_session_updated_event(&parsed),
"conversation.output_audio.delta" => {
let data = parsed
.get("delta")
.and_then(Value::as_str)
.or_else(|| parsed.get("data").and_then(Value::as_str))
.map(str::to_string)?;
let sample_rate = parsed
.get("sample_rate")
.and_then(Value::as_u64)
.and_then(|value| u32::try_from(value).ok())?;
let num_channels = parsed
.get("channels")
.or_else(|| parsed.get("num_channels"))
.and_then(Value::as_u64)
.and_then(|value| u16::try_from(value).ok())?;
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
data,
sample_rate,
num_channels,
samples_per_channel: parsed
.get("samples_per_channel")
.and_then(Value::as_u64)
.and_then(|value| u32::try_from(value).ok()),
}))
}
"conversation.input_transcript.delta" => {
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::InputTranscriptDelta)
}
"conversation.output_transcript.delta" => {
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::OutputTranscriptDelta)
}
"conversation.item.added" => parsed
.get("item")
.cloned()
.map(RealtimeEvent::ConversationItemAdded),
"conversation.item.done" => parsed
.get("item")
.and_then(Value::as_object)
.and_then(|item| item.get("id"))
.and_then(Value::as_str)
.map(str::to_string)
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
"conversation.handoff.requested" => {
let handoff_id = parsed
.get("handoff_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let item_id = parsed
.get("item_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let input_transcript = parsed
.get("input_transcript")
.and_then(Value::as_str)
.map(str::to_string)?;
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id,
item_id,
input_transcript,
active_transcript: Vec::new(),
}))
}
"error" => parse_error_event(&parsed),
_ => {
debug!("received unsupported realtime v1 event type: {message_type}, data: {payload}");
None
}
}
}

View File

@@ -1,130 +1,157 @@
use crate::endpoint::realtime_websocket::protocol_common::parse_error_event;
use crate::endpoint::realtime_websocket::protocol_common::parse_realtime_payload;
use crate::endpoint::realtime_websocket::protocol_common::parse_session_updated_event;
use crate::endpoint::realtime_websocket::protocol_common::parse_transcript_delta_event;
use codex_protocol::protocol::RealtimeAudioFrame;
use codex_protocol::protocol::RealtimeEvent;
use codex_protocol::protocol::RealtimeHandoffRequested;
use serde_json::Map as JsonMap;
use codex_protocol::protocol::RealtimeTranscriptDelta;
use serde_json::Value;
use tracing::debug;
const CODEX_TOOL_NAME: &str = "codex";
const DEFAULT_AUDIO_SAMPLE_RATE: u32 = 24_000;
const DEFAULT_AUDIO_CHANNELS: u16 = 1;
const TOOL_ARGUMENT_KEYS: [&str; 5] = ["input_transcript", "input", "text", "prompt", "query"];
pub(super) fn parse_realtime_event_v2(payload: &str) -> Option<RealtimeEvent> {
let (parsed, message_type) = parse_realtime_payload(payload, "realtime v2")?;
let parsed: Value = match serde_json::from_str(payload) {
Ok(msg) => msg,
Err(err) => {
debug!("failed to parse realtime v2 event: {err}, data: {payload}");
return None;
}
};
match message_type.as_str() {
"session.updated" => parse_session_updated_event(&parsed),
"response.output_audio.delta" => parse_output_audio_delta_event(&parsed),
"conversation.item.input_audio_transcription.delta" => {
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::InputTranscriptDelta)
let message_type = match parsed.get("type").and_then(Value::as_str) {
Some(message_type) => message_type,
None => {
debug!("received realtime v2 event without type field: {payload}");
return None;
}
"conversation.item.input_audio_transcription.completed" => {
parse_transcript_delta_event(&parsed, "transcript")
.map(RealtimeEvent::InputTranscriptDelta)
};
match message_type {
"session.updated" => {
let session_id = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("id"))
.and_then(Value::as_str)
.map(str::to_string);
let instructions = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("instructions"))
.and_then(Value::as_str)
.map(str::to_string);
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
session_id,
instructions,
})
}
"response.output_text.delta" | "response.output_audio_transcript.delta" => {
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::OutputTranscriptDelta)
"response.output_audio.delta" => {
let data = parsed
.get("delta")
.and_then(Value::as_str)
.map(str::to_string)?;
let sample_rate = parsed
.get("sample_rate")
.and_then(Value::as_u64)
.and_then(|value| u32::try_from(value).ok())
.unwrap_or(24_000);
let num_channels = parsed
.get("channels")
.or_else(|| parsed.get("num_channels"))
.and_then(Value::as_u64)
.and_then(|value| u16::try_from(value).ok())
.unwrap_or(1);
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
data,
sample_rate,
num_channels,
samples_per_channel: parsed
.get("samples_per_channel")
.and_then(Value::as_u64)
.and_then(|value| u32::try_from(value).ok()),
}))
}
"conversation.item.input_audio_transcription.delta" => parsed
.get("delta")
.and_then(Value::as_str)
.map(str::to_string)
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
"conversation.item.input_audio_transcription.completed" => parsed
.get("transcript")
.and_then(Value::as_str)
.map(str::to_string)
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
"response.output_text.delta" | "response.output_audio_transcript.delta" => parsed
.get("delta")
.and_then(Value::as_str)
.map(str::to_string)
.map(|delta| RealtimeEvent::OutputTranscriptDelta(RealtimeTranscriptDelta { delta })),
"conversation.item.added" => parsed
.get("item")
.cloned()
.map(RealtimeEvent::ConversationItemAdded),
"conversation.item.done" => parse_conversation_item_done_event(&parsed),
"error" => parse_error_event(&parsed),
"conversation.item.done" => {
let item = parsed.get("item")?.as_object()?;
let item_type = item.get("type").and_then(Value::as_str);
let item_name = item.get("name").and_then(Value::as_str);
if item_type == Some("function_call") && item_name == Some("codex") {
let call_id = item
.get("call_id")
.and_then(Value::as_str)
.or_else(|| item.get("id").and_then(Value::as_str))?;
let item_id = item
.get("id")
.and_then(Value::as_str)
.unwrap_or(call_id)
.to_string();
let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
let mut input_transcript = String::new();
if !arguments.is_empty() {
if let Ok(arguments_json) = serde_json::from_str::<Value>(arguments)
&& let Some(arguments_object) = arguments_json.as_object()
{
for key in ["input_transcript", "input", "text", "prompt", "query"] {
if let Some(value) = arguments_object.get(key).and_then(Value::as_str) {
let trimmed = value.trim();
if !trimmed.is_empty() {
input_transcript = trimmed.to_string();
break;
}
}
}
}
if input_transcript.is_empty() {
input_transcript = arguments.to_string();
}
}
return Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id: call_id.to_string(),
item_id,
input_transcript,
active_transcript: Vec::new(),
}));
}
item.get("id")
.and_then(Value::as_str)
.map(str::to_string)
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id })
}
"error" => parsed
.get("message")
.and_then(Value::as_str)
.map(str::to_string)
.or_else(|| {
parsed
.get("error")
.and_then(Value::as_object)
.and_then(|error| error.get("message"))
.and_then(Value::as_str)
.map(str::to_string)
})
.or_else(|| parsed.get("error").map(ToString::to_string))
.map(RealtimeEvent::Error),
_ => {
debug!("received unsupported realtime v2 event type: {message_type}, data: {payload}");
None
}
}
}
fn parse_output_audio_delta_event(parsed: &Value) -> Option<RealtimeEvent> {
let data = parsed
.get("delta")
.and_then(Value::as_str)
.map(str::to_string)?;
let sample_rate = parsed
.get("sample_rate")
.and_then(Value::as_u64)
.and_then(|value| u32::try_from(value).ok())
.unwrap_or(DEFAULT_AUDIO_SAMPLE_RATE);
let num_channels = parsed
.get("channels")
.or_else(|| parsed.get("num_channels"))
.and_then(Value::as_u64)
.and_then(|value| u16::try_from(value).ok())
.unwrap_or(DEFAULT_AUDIO_CHANNELS);
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
data,
sample_rate,
num_channels,
samples_per_channel: parsed
.get("samples_per_channel")
.and_then(Value::as_u64)
.and_then(|value| u32::try_from(value).ok()),
}))
}
fn parse_conversation_item_done_event(parsed: &Value) -> Option<RealtimeEvent> {
let item = parsed.get("item")?.as_object()?;
if let Some(handoff) = parse_handoff_requested_event(item) {
return Some(handoff);
}
item.get("id")
.and_then(Value::as_str)
.map(str::to_string)
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id })
}
fn parse_handoff_requested_event(item: &JsonMap<String, Value>) -> Option<RealtimeEvent> {
let item_type = item.get("type").and_then(Value::as_str);
let item_name = item.get("name").and_then(Value::as_str);
if item_type != Some("function_call") || item_name != Some(CODEX_TOOL_NAME) {
return None;
}
let call_id = item
.get("call_id")
.and_then(Value::as_str)
.or_else(|| item.get("id").and_then(Value::as_str))?;
let item_id = item
.get("id")
.and_then(Value::as_str)
.unwrap_or(call_id)
.to_string();
let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id: call_id.to_string(),
item_id,
input_transcript: extract_input_transcript(arguments),
active_transcript: Vec::new(),
}))
}
fn extract_input_transcript(arguments: &str) -> String {
if arguments.is_empty() {
return String::new();
}
if let Ok(arguments_json) = serde_json::from_str::<Value>(arguments)
&& let Some(arguments_object) = arguments_json.as_object()
{
for key in TOOL_ARGUMENT_KEYS {
if let Some(value) = arguments_object.get(key).and_then(Value::as_str) {
let trimmed = value.trim();
if !trimmed.is_empty() {
return trimmed.to_string();
}
}
}
}
arguments.to_string()
}

View File

@@ -29,7 +29,6 @@ pub use crate::endpoint::memories::MemoriesClient;
pub use crate::endpoint::models::ModelsClient;
pub use crate::endpoint::realtime_websocket::RealtimeEventParser;
pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
pub use crate::endpoint::realtime_websocket::RealtimeSessionMode;
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient;
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketConnection;
pub use crate::endpoint::responses::ResponsesClient;

View File

@@ -6,7 +6,6 @@ use codex_api::RealtimeAudioFrame;
use codex_api::RealtimeEvent;
use codex_api::RealtimeEventParser;
use codex_api::RealtimeSessionConfig;
use codex_api::RealtimeSessionMode;
use codex_api::RealtimeWebsocketClient;
use codex_api::provider::Provider;
use codex_api::provider::RetryConfig;
@@ -143,7 +142,6 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -237,7 +235,6 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -302,7 +299,6 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -364,7 +360,6 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
@@ -429,7 +424,6 @@ async fn realtime_ws_e2e_realtime_v2_parser_emits_handoff_requested() {
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::RealtimeV2,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),

View File

@@ -1342,13 +1342,6 @@
},
"type": "object"
},
"RealtimeWsMode": {
"enum": [
"conversational",
"transcription"
],
"type": "string"
},
"ReasoningEffort": {
"description": "See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning",
"enum": [
@@ -1823,14 +1816,6 @@
"description": "Experimental / do not use. Overrides only the realtime conversation websocket transport base URL (the `Op::RealtimeConversation` `/v1/realtime` connection) without changing normal provider HTTP requests.",
"type": "string"
},
"experimental_realtime_ws_mode": {
"allOf": [
{
"$ref": "#/definitions/RealtimeWsMode"
}
],
"description": "Experimental / do not use. Selects the realtime websocket intent mode. `conversational` is speech-to-speech while `transcription` is transcript-only."
},
"experimental_realtime_ws_model": {
"description": "Experimental / do not use. Selects the realtime websocket model/snapshot used for the `Op::RealtimeConversation` connection.",
"type": "string"

View File

@@ -172,6 +172,8 @@ use crate::error::CodexErr;
use crate::error::Result as CodexResult;
#[cfg(test)]
use crate::exec::StreamOutput;
use crate::network_proxy_registry::NetworkProxyRegistry;
use crate::network_proxy_registry::NetworkProxyScope;
use codex_config::CONFIG_TOML_FILE;
mod rollout_reconstruction;
@@ -276,6 +278,7 @@ use crate::skills::collect_explicit_skill_mentions;
use crate::skills::injection::ToolMentionKind;
use crate::skills::injection::app_id_from_path;
use crate::skills::injection::tool_kind_for_path;
use crate::skills::model::SkillManagedNetworkOverride;
use crate::skills::resolve_skill_dependencies_for_turn;
use crate::state::ActiveTurn;
use crate::state::SessionServices;
@@ -1182,6 +1185,61 @@ impl Session {
Ok((network_proxy, session_network_proxy))
}
pub(crate) async fn get_or_start_network_proxy(
self: &Arc<Self>,
scope: NetworkProxyScope,
sandbox_policy: &SandboxPolicy,
managed_network_override: Option<SkillManagedNetworkOverride>,
) -> anyhow::Result<Option<NetworkProxy>> {
let session = Arc::clone(self);
let started = self
.services
.network_proxies
.get_or_start(
scope.clone(),
move |spec, managed_enabled, audit_metadata| {
let session = Arc::clone(&session);
let managed_network_override = managed_network_override.clone();
let scope = scope.clone();
let sandbox_policy = sandbox_policy.clone();
async move {
let network_policy_decider = session
.services
.network_policy_decider_session
.as_ref()
.map(|network_policy_decider_session| {
build_network_policy_decider(
Arc::clone(&session.services.network_approval),
Arc::clone(network_policy_decider_session),
scope,
)
});
let spec = if let Some(managed_network_override) =
managed_network_override.as_ref()
{
spec.with_skill_managed_network_override(managed_network_override)
} else {
spec
};
spec.start_proxy(
&sandbox_policy,
network_policy_decider,
session
.services
.network_blocked_request_observer
.as_ref()
.map(Arc::clone),
managed_enabled,
audit_metadata,
)
.await
}
},
)
.await?;
Ok(started.map(|started| started.proxy()))
}
/// Don't expand the number of mutated arguments on config. We are in the process of getting rid of it.
pub(crate) fn build_per_turn_config(session_configuration: &SessionConfiguration) -> Config {
// todo(aibrahim): store this state somewhere else so we don't need to mut config
@@ -1651,9 +1709,10 @@ impl Session {
build_network_policy_decider(
Arc::clone(&network_approval),
Arc::clone(network_policy_decider_session),
NetworkProxyScope::SessionDefault,
)
});
let (network_proxy, session_network_proxy) =
let (default_network_proxy, session_network_proxy) =
if let Some(spec) = config.permissions.network.as_ref() {
let (network_proxy, session_network_proxy) = Self::start_managed_network_proxy(
spec,
@@ -1661,13 +1720,19 @@ impl Session {
network_policy_decider.as_ref().map(Arc::clone),
blocked_request_observer.as_ref().map(Arc::clone),
managed_network_requirements_enabled,
network_proxy_audit_metadata,
network_proxy_audit_metadata.clone(),
)
.await?;
(Some(network_proxy), Some(session_network_proxy))
} else {
(None, None)
};
let network_proxies = NetworkProxyRegistry::new(
config.permissions.network.clone(),
managed_network_requirements_enabled,
network_proxy_audit_metadata.clone(),
default_network_proxy,
);
let mut hook_shell_argv = default_shell.derive_exec_args("", false);
let hook_shell_program = hook_shell_argv.remove(0);
@@ -1725,7 +1790,9 @@ impl Session {
mcp_manager: Arc::clone(&mcp_manager),
file_watcher,
agent_control,
network_proxy,
network_proxies,
network_policy_decider_session,
network_blocked_request_observer: blocked_request_observer,
network_approval: Arc::clone(&network_approval),
state_db: state_db_ctx.clone(),
model_client: ModelClient::new(
@@ -1764,7 +1831,9 @@ impl Session {
js_repl,
next_internal_sub_id: AtomicU64::new(0),
});
if let Some(network_policy_decider_session) = network_policy_decider_session {
if let Some(network_policy_decider_session) =
sess.services.network_policy_decider_session.as_ref()
{
let mut guard = network_policy_decider_session.write().await;
*guard = Arc::downgrade(&sess);
}
@@ -2314,8 +2383,10 @@ impl Session {
model_info,
&self.services.models_manager,
self.services
.network_proxy
.as_ref()
.network_proxies
.get(&NetworkProxyScope::SessionDefault)
.await
.as_deref()
.map(StartedNetworkProxy::proxy),
sub_id,
Arc::clone(&self.js_repl),
@@ -2687,6 +2758,7 @@ impl Session {
&self,
amendment: &NetworkPolicyAmendment,
network_approval_context: &NetworkApprovalContext,
scope: &NetworkProxyScope,
) -> anyhow::Result<()> {
let host =
Self::validated_network_policy_amendment_host(amendment, network_approval_context)?;
@@ -2700,7 +2772,7 @@ impl Session {
let execpolicy_amendment =
execpolicy_network_rule_amendment(amendment, network_approval_context, &host);
if let Some(started_network_proxy) = self.services.network_proxy.as_ref() {
if let Some(started_network_proxy) = self.services.network_proxies.get(scope).await {
let proxy = started_network_proxy.proxy();
match amendment.action {
NetworkPolicyRuleAction::Allow => proxy

View File

@@ -11,9 +11,11 @@ use crate::exec::ExecToolCallOutput;
use crate::function_tool::FunctionCallError;
use crate::mcp_connection_manager::ToolInfo;
use crate::models_manager::model_info;
use crate::network_proxy_registry::NetworkProxyRegistry;
use crate::shell::default_user_shell;
use crate::tools::format_exec_output_str;
use codex_network_proxy::NetworkProxyAuditMetadata;
use codex_protocol::ThreadId;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
@@ -2152,7 +2154,14 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
mcp_manager,
file_watcher,
agent_control,
network_proxy: None,
network_proxies: NetworkProxyRegistry::new(
None,
false,
NetworkProxyAuditMetadata::default(),
None,
),
network_policy_decider_session: None,
network_blocked_request_observer: None,
network_approval: Arc::clone(&network_approval),
state_db: None,
model_client: ModelClient::new(
@@ -2794,7 +2803,14 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
mcp_manager,
file_watcher,
agent_control,
network_proxy: None,
network_proxies: NetworkProxyRegistry::new(
None,
false,
NetworkProxyAuditMetadata::default(),
None,
),
network_policy_decider_session: None,
network_blocked_request_observer: None,
network_approval: Arc::clone(&network_approval),
state_db: None,
model_client: ModelClient::new(

View File

@@ -4129,7 +4129,6 @@ fn test_precedence_fixture_with_o3_profile() -> std::io::Result<()> {
experimental_realtime_start_instructions: None,
experimental_realtime_ws_base_url: None,
experimental_realtime_ws_model: None,
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
experimental_realtime_ws_backend_prompt: None,
experimental_realtime_ws_startup_context: None,
base_instructions: None,
@@ -4266,7 +4265,6 @@ fn test_precedence_fixture_with_gpt3_profile() -> std::io::Result<()> {
experimental_realtime_start_instructions: None,
experimental_realtime_ws_base_url: None,
experimental_realtime_ws_model: None,
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
experimental_realtime_ws_backend_prompt: None,
experimental_realtime_ws_startup_context: None,
base_instructions: None,
@@ -4401,7 +4399,6 @@ fn test_precedence_fixture_with_zdr_profile() -> std::io::Result<()> {
experimental_realtime_start_instructions: None,
experimental_realtime_ws_base_url: None,
experimental_realtime_ws_model: None,
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
experimental_realtime_ws_backend_prompt: None,
experimental_realtime_ws_startup_context: None,
base_instructions: None,
@@ -4522,7 +4519,6 @@ fn test_precedence_fixture_with_gpt5_profile() -> std::io::Result<()> {
experimental_realtime_start_instructions: None,
experimental_realtime_ws_base_url: None,
experimental_realtime_ws_model: None,
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
experimental_realtime_ws_backend_prompt: None,
experimental_realtime_ws_startup_context: None,
base_instructions: None,
@@ -5570,34 +5566,6 @@ experimental_realtime_ws_model = "realtime-test-model"
Ok(())
}
#[test]
fn experimental_realtime_ws_mode_loads_from_config_toml() -> std::io::Result<()> {
let cfg: ConfigToml = toml::from_str(
r#"
experimental_realtime_ws_mode = "transcription"
"#,
)
.expect("TOML deserialization should succeed");
assert_eq!(
cfg.experimental_realtime_ws_mode,
Some(RealtimeWsMode::Transcription)
);
let codex_home = TempDir::new()?;
let config = Config::load_from_base_config_with_overrides(
cfg,
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)?;
assert_eq!(
config.experimental_realtime_ws_mode,
RealtimeWsMode::Transcription
);
Ok(())
}
#[test]
fn realtime_audio_loads_from_config_toml() -> std::io::Result<()> {
let cfg: ConfigToml = toml::from_str(

View File

@@ -463,9 +463,6 @@ pub struct Config {
/// Experimental / do not use. Selects the realtime websocket model/snapshot
/// used for the `Op::RealtimeConversation` connection.
pub experimental_realtime_ws_model: Option<String>,
/// Experimental / do not use. Selects the realtime websocket intent mode.
/// `conversational` is speech-to-speech while `transcription` is transcript-only.
pub experimental_realtime_ws_mode: RealtimeWsMode,
/// Experimental / do not use. Overrides only the realtime conversation
/// websocket transport instructions (the `Op::RealtimeConversation`
/// `/ws` session.update instructions) without changing normal prompts.
@@ -1241,9 +1238,6 @@ pub struct ConfigToml {
/// Experimental / do not use. Selects the realtime websocket model/snapshot
/// used for the `Op::RealtimeConversation` connection.
pub experimental_realtime_ws_model: Option<String>,
/// Experimental / do not use. Selects the realtime websocket intent mode.
/// `conversational` is speech-to-speech while `transcription` is transcript-only.
pub experimental_realtime_ws_mode: Option<RealtimeWsMode>,
/// Experimental / do not use. Overrides only the realtime conversation
/// websocket transport instructions (the `Op::RealtimeConversation`
/// `/ws` session.update instructions) without changing normal prompts.
@@ -1389,14 +1383,6 @@ pub struct RealtimeAudioConfig {
pub speaker: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum RealtimeWsMode {
#[default]
Conversational,
Transcription,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq, JsonSchema)]
#[schemars(deny_unknown_fields)]
pub struct RealtimeAudioToml {
@@ -2476,7 +2462,6 @@ impl Config {
}),
experimental_realtime_ws_base_url: cfg.experimental_realtime_ws_base_url,
experimental_realtime_ws_model: cfg.experimental_realtime_ws_model,
experimental_realtime_ws_mode: cfg.experimental_realtime_ws_mode.unwrap_or_default(),
experimental_realtime_ws_backend_prompt: cfg.experimental_realtime_ws_backend_prompt,
experimental_realtime_ws_startup_context: cfg.experimental_realtime_ws_startup_context,
experimental_realtime_start_instructions: cfg.experimental_realtime_start_instructions,

View File

@@ -1,4 +1,5 @@
use crate::config_loader::NetworkConstraints;
use crate::skills::model::SkillManagedNetworkOverride;
use async_trait::async_trait;
use codex_network_proxy::BlockedRequestObserver;
use codex_network_proxy::ConfigReloader;
@@ -81,6 +82,28 @@ impl NetworkProxySpec {
self.config.network.enable_socks5
}
pub(crate) fn with_skill_managed_network_override(
&self,
managed_network_override: &SkillManagedNetworkOverride,
) -> Self {
let mut spec = self.clone();
if let Some(allowed_domains) = managed_network_override.allowed_domains.clone() {
spec.config.network.allowed_domains = allowed_domains.clone();
if spec.constraints.allowed_domains.is_some() {
spec.constraints.allowed_domains = Some(allowed_domains);
}
}
if let Some(denied_domains) = managed_network_override.denied_domains.clone() {
spec.config.network.denied_domains = denied_domains.clone();
if spec.constraints.denied_domains.is_some() {
spec.constraints.denied_domains = Some(denied_domains);
}
}
spec
}
pub(crate) fn from_config_and_constraints(
config: NetworkProxyConfig,
requirements: Option<NetworkConstraints>,

View File

@@ -1,6 +1,94 @@
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn skill_managed_network_override_replaces_allowed_domains_and_keeps_other_settings() {
let mut config = NetworkProxyConfig::default();
config.network.enabled = true;
config.network.proxy_url = "http://127.0.0.1:4128".to_string();
config.network.socks_url = "socks5://127.0.0.1:5128".to_string();
config.network.enable_socks5 = true;
config.network.enable_socks5_udp = true;
config.network.allowed_domains = vec!["default.example.com".to_string()];
config.network.denied_domains = vec!["blocked.example.com".to_string()];
config.network.allow_upstream_proxy = true;
config.network.dangerously_allow_all_unix_sockets = false;
config.network.dangerously_allow_non_loopback_proxy = false;
config.network.mode = codex_network_proxy::NetworkMode::Full;
config.network.allow_unix_sockets = vec!["/tmp/default.sock".to_string()];
config.network.allow_local_binding = true;
config.network.mitm = false;
let spec = NetworkProxySpec {
config,
constraints: NetworkProxyConstraints {
allowed_domains: Some(vec!["default.example.com".to_string()]),
denied_domains: Some(vec!["blocked.example.com".to_string()]),
allowlist_expansion_enabled: Some(true),
denylist_expansion_enabled: Some(false),
allow_upstream_proxy: Some(true),
allow_unix_sockets: Some(vec!["/tmp/default.sock".to_string()]),
allow_local_binding: Some(true),
..NetworkProxyConstraints::default()
},
hard_deny_allowlist_misses: true,
};
let managed_network_override = crate::skills::model::SkillManagedNetworkOverride {
allowed_domains: Some(vec!["skill.example.com".to_string()]),
denied_domains: None,
};
let overridden = spec.with_skill_managed_network_override(&managed_network_override);
let mut expected = spec.clone();
expected.config.network.allowed_domains = vec!["skill.example.com".to_string()];
expected.constraints.allowed_domains = Some(vec!["skill.example.com".to_string()]);
assert_eq!(overridden, expected);
}
#[test]
fn skill_managed_network_override_replaces_denied_domains_and_keeps_default_allowed_domains() {
let mut config = NetworkProxyConfig::default();
config.network.enabled = true;
config.network.proxy_url = "http://127.0.0.1:4128".to_string();
config.network.socks_url = "socks5://127.0.0.1:5128".to_string();
config.network.enable_socks5 = true;
config.network.enable_socks5_udp = true;
config.network.allowed_domains = vec!["default.example.com".to_string()];
config.network.denied_domains = vec!["blocked.example.com".to_string()];
config.network.allow_upstream_proxy = true;
config.network.dangerously_allow_all_unix_sockets = false;
config.network.dangerously_allow_non_loopback_proxy = false;
config.network.mode = codex_network_proxy::NetworkMode::Full;
config.network.allow_unix_sockets = vec!["/tmp/default.sock".to_string()];
config.network.allow_local_binding = true;
config.network.mitm = false;
let spec = NetworkProxySpec {
config,
constraints: NetworkProxyConstraints {
allowed_domains: Some(vec!["default.example.com".to_string()]),
denied_domains: Some(vec!["blocked.example.com".to_string()]),
allowlist_expansion_enabled: Some(true),
denylist_expansion_enabled: Some(false),
allow_upstream_proxy: Some(true),
allow_unix_sockets: Some(vec!["/tmp/default.sock".to_string()]),
allow_local_binding: Some(true),
..NetworkProxyConstraints::default()
},
hard_deny_allowlist_misses: false,
};
let managed_network_override = crate::skills::model::SkillManagedNetworkOverride {
allowed_domains: None,
denied_domains: Some(vec!["skill-blocked.example.com".to_string()]),
};
let overridden = spec.with_skill_managed_network_override(&managed_network_override);
let mut expected = spec.clone();
expected.config.network.denied_domains = vec!["skill-blocked.example.com".to_string()];
expected.constraints.denied_domains = Some(vec!["skill-blocked.example.com".to_string()]);
assert_eq!(overridden, expected);
}
#[test]
fn build_state_with_audit_metadata_threads_metadata_to_state() {
let spec = NetworkProxySpec {

View File

@@ -39,6 +39,7 @@ use crate::config::Constrained;
use crate::config::NetworkProxySpec;
use crate::event_mapping::is_contextual_user_message_content;
use crate::features::Feature;
use crate::network_proxy_registry::NetworkProxyScope;
use crate::protocol::Op;
use crate::protocol::SandboxPolicy;
use crate::truncate::approx_bytes_for_tokens;
@@ -550,7 +551,12 @@ async fn run_guardian_subagent(
schema: Value,
cancel_token: CancellationToken,
) -> anyhow::Result<GuardianAssessment> {
let live_network_config = match session.services.network_proxy.as_ref() {
let live_network_config = match session
.services
.network_proxies
.get(&NetworkProxyScope::SessionDefault)
.await
{
Some(network_proxy) => Some(network_proxy.proxy().current_cfg().await?),
None => None,
};

View File

@@ -51,6 +51,7 @@ mod mcp_tool_approval_templates;
pub mod models_manager;
mod network_policy_decision;
pub mod network_proxy_loader;
mod network_proxy_registry;
mod original_image_detail;
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD;

View File

@@ -0,0 +1,77 @@
use crate::config::NetworkProxySpec;
use crate::config::StartedNetworkProxy;
use anyhow::Result;
use codex_network_proxy::NetworkProxyAuditMetadata;
use std::collections::HashMap;
use std::future::Future;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub(crate) enum NetworkProxyScope {
SessionDefault,
Skill { path_to_skills_md: PathBuf },
}
pub(crate) struct NetworkProxyRegistry {
spec: Option<NetworkProxySpec>,
managed_network_requirements_enabled: bool,
audit_metadata: NetworkProxyAuditMetadata,
proxies: Mutex<HashMap<NetworkProxyScope, Arc<StartedNetworkProxy>>>,
}
impl NetworkProxyRegistry {
pub(crate) fn new(
spec: Option<NetworkProxySpec>,
managed_network_requirements_enabled: bool,
audit_metadata: NetworkProxyAuditMetadata,
default_proxy: Option<StartedNetworkProxy>,
) -> Self {
let mut proxies = HashMap::new();
if let Some(default_proxy) = default_proxy {
proxies.insert(NetworkProxyScope::SessionDefault, Arc::new(default_proxy));
}
Self {
spec,
managed_network_requirements_enabled,
audit_metadata,
proxies: Mutex::new(proxies),
}
}
pub(crate) async fn get(&self, scope: &NetworkProxyScope) -> Option<Arc<StartedNetworkProxy>> {
self.proxies.lock().await.get(scope).cloned()
}
pub(crate) async fn get_or_start<F, Fut>(
&self,
scope: NetworkProxyScope,
start: F,
) -> Result<Option<Arc<StartedNetworkProxy>>>
where
F: FnOnce(NetworkProxySpec, bool, NetworkProxyAuditMetadata) -> Fut,
Fut: Future<Output = std::io::Result<StartedNetworkProxy>>,
{
let mut proxies = self.proxies.lock().await;
if let Some(existing) = proxies.get(&scope).cloned() {
return Ok(Some(existing));
}
let Some(spec) = self.spec.clone() else {
return Ok(None);
};
let started = Arc::new(
start(
spec,
self.managed_network_requirements_enabled,
self.audit_metadata.clone(),
)
.await?,
);
proxies.insert(scope, Arc::clone(&started));
Ok(Some(started))
}
}

View File

@@ -15,7 +15,6 @@ use codex_api::RealtimeAudioFrame;
use codex_api::RealtimeEvent;
use codex_api::RealtimeEventParser;
use codex_api::RealtimeSessionConfig;
use codex_api::RealtimeSessionMode;
use codex_api::RealtimeWebsocketClient;
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketEvents;
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketWriter;
@@ -117,7 +116,10 @@ impl RealtimeConversationManager {
&self,
api_provider: ApiProvider,
extra_headers: Option<HeaderMap>,
session_config: RealtimeSessionConfig,
prompt: String,
model: Option<String>,
session_id: Option<String>,
event_parser: RealtimeEventParser,
) -> CodexResult<(Receiver<RealtimeEvent>, Arc<AtomicBool>)> {
let previous_state = {
let mut guard = self.state.lock().await;
@@ -129,6 +131,12 @@ impl RealtimeConversationManager {
let _ = state.task.await;
}
let session_config = RealtimeSessionConfig {
instructions: prompt,
model,
session_id,
event_parser,
};
let client = RealtimeWebsocketClient::new(api_provider);
let connection = client
.connect(
@@ -299,26 +307,23 @@ pub(crate) async fn handle_start(
} else {
RealtimeEventParser::V1
};
let session_mode = match config.experimental_realtime_ws_mode {
crate::config::RealtimeWsMode::Conversational => RealtimeSessionMode::Conversational,
crate::config::RealtimeWsMode::Transcription => RealtimeSessionMode::Transcription,
};
let requested_session_id = params
.session_id
.or_else(|| Some(sess.conversation_id.to_string()));
let session_config = RealtimeSessionConfig {
instructions: prompt,
model,
session_id: requested_session_id.clone(),
event_parser,
session_mode,
};
let extra_headers =
realtime_request_headers(requested_session_id.as_deref(), realtime_api_key.as_str())?;
info!("starting realtime conversation");
let (events_rx, realtime_active) = match sess
.conversation
.start(api_provider, extra_headers, session_config)
.start(
api_provider,
extra_headers,
prompt,
model,
requested_session_id.clone(),
event_parser,
)
.await
{
Ok(events_rx) => events_rx,

View File

@@ -6,12 +6,13 @@ use crate::RolloutRecorder;
use crate::agent::AgentControl;
use crate::analytics_client::AnalyticsEventsClient;
use crate::client::ModelClient;
use crate::config::StartedNetworkProxy;
use crate::codex::Session;
use crate::exec_policy::ExecPolicyManager;
use crate::file_watcher::FileWatcher;
use crate::mcp::McpManager;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::models_manager::manager::ModelsManager;
use crate::network_proxy_registry::NetworkProxyRegistry;
use crate::plugins::PluginsManager;
use crate::skills::SkillsManager;
use crate::state_db::StateDbHandle;
@@ -21,9 +22,11 @@ use crate::tools::runtimes::ExecveSessionApproval;
use crate::tools::sandboxing::ApprovalStore;
use crate::unified_exec::UnifiedExecProcessManager;
use codex_hooks::Hooks;
use codex_network_proxy::BlockedRequestObserver;
use codex_otel::SessionTelemetry;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::path::PathBuf;
use std::sync::Weak;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::watch;
@@ -55,7 +58,9 @@ pub(crate) struct SessionServices {
pub(crate) mcp_manager: Arc<McpManager>,
pub(crate) file_watcher: Arc<FileWatcher>,
pub(crate) agent_control: AgentControl,
pub(crate) network_proxy: Option<StartedNetworkProxy>,
pub(crate) network_proxies: NetworkProxyRegistry,
pub(crate) network_policy_decider_session: Option<Arc<RwLock<Weak<Session>>>>,
pub(crate) network_blocked_request_observer: Option<Arc<dyn BlockedRequestObserver>>,
pub(crate) network_approval: Arc<NetworkApprovalService>,
pub(crate) state_db: Option<StateDbHandle>,
/// Session-scoped model client shared across turns.

View File

@@ -46,6 +46,7 @@ use codex_protocol::protocol::RolloutItem;
use codex_protocol::user_input::UserInput;
use crate::features::Feature;
use crate::network_proxy_registry::NetworkProxyScope;
pub(crate) use compact::CompactTask;
pub(crate) use ghost_snapshot::GhostSnapshotTask;
pub(crate) use regular::RegularTask;
@@ -292,7 +293,12 @@ impl Session {
"false"
},
);
let network_proxy_active = match self.services.network_proxy.as_ref() {
let network_proxy_active = match self
.services
.network_proxies
.get(&NetworkProxyScope::SessionDefault)
.await
{
Some(started_network_proxy) => {
match started_network_proxy.proxy().current_cfg().await {
Ok(config) => config.network.enabled,

View File

@@ -4,6 +4,7 @@ use crate::guardian::GuardianApprovalRequest;
use crate::guardian::review_approval_request;
use crate::guardian::routes_approval_to_guardian;
use crate::network_policy_decision::denied_network_policy_message;
use crate::network_proxy_registry::NetworkProxyScope;
use crate::tools::sandboxing::ToolError;
use codex_network_proxy::BlockedRequest;
use codex_network_proxy::BlockedRequestObserver;
@@ -76,14 +77,20 @@ impl ActiveNetworkApproval {
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct HostApprovalKey {
scope: NetworkProxyScope,
host: String,
protocol: &'static str,
port: u16,
}
impl HostApprovalKey {
fn from_request(request: &NetworkPolicyRequest, protocol: NetworkApprovalProtocol) -> Self {
fn from_request(
request: &NetworkPolicyRequest,
protocol: NetworkApprovalProtocol,
scope: NetworkProxyScope,
) -> Self {
Self {
scope,
host: request.host.to_ascii_lowercase(),
protocol: protocol_key_label(protocol),
port: request.port,
@@ -279,6 +286,7 @@ impl NetworkApprovalService {
&self,
session: Arc<Session>,
request: NetworkPolicyRequest,
scope: NetworkProxyScope,
) -> NetworkDecision {
const REASON_NOT_ALLOWED: &str = "not_allowed";
@@ -288,7 +296,7 @@ impl NetworkApprovalService {
NetworkProtocol::Socks5Tcp => NetworkApprovalProtocol::Socks5Tcp,
NetworkProtocol::Socks5Udp => NetworkApprovalProtocol::Socks5Udp,
};
let key = HostApprovalKey::from_request(&request, protocol);
let key = HostApprovalKey::from_request(&request, protocol, scope.clone());
{
let denied_hosts = self.session_denied_hosts.lock().await;
@@ -387,6 +395,7 @@ impl NetworkApprovalService {
.persist_network_policy_amendment(
&network_policy_amendment,
&network_approval_context,
&scope,
)
.await
{
@@ -417,6 +426,7 @@ impl NetworkApprovalService {
.persist_network_policy_amendment(
&network_policy_amendment,
&network_approval_context,
&scope,
)
.await
{
@@ -506,16 +516,18 @@ pub(crate) fn build_blocked_request_observer(
pub(crate) fn build_network_policy_decider(
network_approval: Arc<NetworkApprovalService>,
network_policy_decider_session: Arc<RwLock<std::sync::Weak<Session>>>,
scope: NetworkProxyScope,
) -> Arc<dyn NetworkPolicyDecider> {
Arc::new(move |request: NetworkPolicyRequest| {
let network_approval = Arc::clone(&network_approval);
let network_policy_decider_session = Arc::clone(&network_policy_decider_session);
let scope = scope.clone();
async move {
let Some(session) = network_policy_decider_session.read().await.upgrade() else {
return NetworkDecision::ask("not_allowed");
};
network_approval
.handle_inline_policy_request(session, request)
.handle_inline_policy_request(session, request, scope)
.await
}
})

View File

@@ -1,4 +1,5 @@
use super::*;
use crate::network_proxy_registry::NetworkProxyScope;
use codex_network_proxy::BlockedRequestArgs;
use codex_protocol::protocol::AskForApproval;
use pretty_assertions::assert_eq;
@@ -7,6 +8,7 @@ use pretty_assertions::assert_eq;
async fn pending_approvals_are_deduped_per_host_protocol_and_port() {
let service = NetworkApprovalService::default();
let key = HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "http",
port: 443,
@@ -24,11 +26,13 @@ async fn pending_approvals_are_deduped_per_host_protocol_and_port() {
async fn pending_approvals_do_not_dedupe_across_ports() {
let service = NetworkApprovalService::default();
let first_key = HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 443,
};
let second_key = HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 8443,
@@ -49,16 +53,19 @@ async fn session_approved_hosts_preserve_protocol_and_port_scope() {
let mut approved_hosts = source.session_approved_hosts.lock().await;
approved_hosts.extend([
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 443,
},
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 8443,
},
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "http",
port: 80,
@@ -82,16 +89,19 @@ async fn session_approved_hosts_preserve_protocol_and_port_scope() {
copied,
vec![
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "http",
port: 80,
},
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 443,
},
HostApprovalKey {
scope: NetworkProxyScope::SessionDefault,
host: "example.com".to_string(),
protocol: "https",
port: 8443,

View File

@@ -9,6 +9,7 @@ use crate::features::Feature;
use crate::guardian::GuardianApprovalRequest;
use crate::guardian::review_approval_request;
use crate::guardian::routes_approval_to_guardian;
use crate::network_proxy_registry::NetworkProxyScope;
use crate::sandboxing::ExecRequest;
use crate::sandboxing::SandboxPermissions;
use crate::shell::ShellType;
@@ -50,6 +51,7 @@ use codex_shell_escalation::ShellCommandExecutor;
use codex_shell_escalation::Stopwatch;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
@@ -142,6 +144,7 @@ pub(super) async fn try_run_zsh_fork(
ctx.session.services.exec_policy.current().as_ref().clone(),
));
let command_executor = CoreShellCommandExecutor {
session: Some(Arc::clone(&ctx.session)),
command,
cwd: sandbox_cwd,
sandbox_policy,
@@ -259,6 +262,7 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
ctx.session.services.exec_policy.current().as_ref().clone(),
));
let command_executor = CoreShellCommandExecutor {
session: Some(Arc::clone(&ctx.session)),
command: exec_request.command.clone(),
cwd: exec_request.cwd.clone(),
sandbox_policy: exec_request.sandbox_policy.clone(),
@@ -503,26 +507,7 @@ impl CoreShellActionProvider {
/// an absolute path. The idea is that we check to see whether it matches
/// any skills.
async fn find_skill(&self, program: &AbsolutePathBuf) -> Option<SkillMetadata> {
let force_reload = false;
let skills_outcome = self
.session
.services
.skills_manager
.skills_for_cwd(&self.turn.cwd, force_reload)
.await;
let program_path = program.as_path();
for skill in skills_outcome.skills {
// We intentionally ignore "enabled" status here for now.
let Some(skill_root) = skill.path_to_skills_md.parent() else {
continue;
};
if program_path.starts_with(skill_root.join("scripts")) {
return Some(skill);
}
}
None
find_skill_for_program(self.session.as_ref(), self.turn.cwd.as_path(), program).await
}
#[allow(clippy::too_many_arguments)]
@@ -631,6 +616,32 @@ impl CoreShellActionProvider {
}
}
async fn find_skill_for_program(
session: &crate::codex::Session,
cwd: &Path,
program: &AbsolutePathBuf,
) -> Option<SkillMetadata> {
let force_reload = false;
let skills_outcome = session
.services
.skills_manager
.skills_for_cwd(cwd, force_reload)
.await;
let program_path = program.as_path();
for skill in skills_outcome.skills {
// We intentionally ignore "enabled" status here for now.
let Some(skill_root) = skill.path_to_skills_md.parent() else {
continue;
};
if program_path.starts_with(skill_root.join("scripts")) {
return Some(skill);
}
}
None
}
// Shell-wrapper parsing is weaker than direct exec interception because it can
// only see the script text, not the final resolved executable path. Keep it
// disabled by default so path-sensitive rules rely on the later authoritative
@@ -865,6 +876,7 @@ fn commands_for_intercepted_exec_policy(
}
struct CoreShellCommandExecutor {
session: Option<Arc<crate::codex::Session>>,
command: Vec<String>,
cwd: PathBuf,
sandbox_policy: SandboxPolicy,
@@ -888,6 +900,7 @@ struct PrepareSandboxedExecParams<'a> {
command: Vec<String>,
workdir: &'a AbsolutePathBuf,
env: HashMap<String, String>,
network: Option<codex_network_proxy::NetworkProxy>,
sandbox_policy: &'a SandboxPolicy,
file_system_sandbox_policy: &'a FileSystemSandboxPolicy,
network_sandbox_policy: NetworkSandboxPolicy,
@@ -969,10 +982,12 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
arg0: Some(first_arg.clone()),
},
EscalationExecution::TurnDefault => {
let network = self.network_for_program(program).await?;
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
command,
workdir,
env,
network,
sandbox_policy: &self.sandbox_policy,
file_system_sandbox_policy: &self.file_system_sandbox_policy,
network_sandbox_policy: self.network_sandbox_policy,
@@ -986,12 +1001,14 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
EscalationExecution::Permissions(EscalationPermissions::PermissionProfile(
permission_profile,
)) => {
let network = self.network_for_program(program).await?;
// Merge additive permissions into the existing turn/request sandbox policy.
// On macOS, additional profile extensions are unioned with the turn defaults.
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
command,
workdir,
env,
network,
sandbox_policy: &self.sandbox_policy,
file_system_sandbox_policy: &self.file_system_sandbox_policy,
network_sandbox_policy: self.network_sandbox_policy,
@@ -1003,11 +1020,13 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
})?
}
EscalationExecution::Permissions(EscalationPermissions::Permissions(permissions)) => {
let network = self.network_for_program(program).await?;
// Use a fully specified sandbox policy instead of merging into the turn policy.
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
command,
workdir,
env,
network,
sandbox_policy: &permissions.sandbox_policy,
file_system_sandbox_policy: &permissions.file_system_sandbox_policy,
network_sandbox_policy: permissions.network_sandbox_policy,
@@ -1025,6 +1044,25 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
}
impl CoreShellCommandExecutor {
async fn network_for_program(
&self,
program: &AbsolutePathBuf,
) -> anyhow::Result<Option<codex_network_proxy::NetworkProxy>> {
let Some(session) = self.session.as_ref() else {
return Ok(self.network.clone());
};
let Some(skill) =
find_skill_for_program(session.as_ref(), &self.sandbox_policy_cwd, program).await
else {
return Ok(self.network.clone());
};
let (scope, managed_network_override) = network_proxy_scope_for_skill(&skill);
session
.get_or_start_network_proxy(scope, &self.sandbox_policy, managed_network_override)
.await
}
#[allow(clippy::too_many_arguments)]
fn prepare_sandboxed_exec(
&self,
@@ -1034,6 +1072,7 @@ impl CoreShellCommandExecutor {
command,
workdir,
env,
network,
sandbox_policy,
file_system_sandbox_policy,
network_sandbox_policy,
@@ -1050,7 +1089,7 @@ impl CoreShellCommandExecutor {
network_sandbox_policy,
SandboxablePreference::Auto,
self.windows_sandbox_level,
self.network.is_some(),
network.is_some(),
);
let mut exec_request =
sandbox_manager.transform(crate::sandboxing::SandboxTransformRequest {
@@ -1072,8 +1111,8 @@ impl CoreShellCommandExecutor {
file_system_policy: file_system_sandbox_policy,
network_policy: network_sandbox_policy,
sandbox,
enforce_managed_network: self.network.is_some(),
network: self.network.as_ref(),
enforce_managed_network: network.is_some(),
network: network.as_ref(),
sandbox_policy_cwd: &self.sandbox_policy_cwd,
#[cfg(target_os = "macos")]
macos_seatbelt_profile_extensions,
@@ -1094,6 +1133,23 @@ impl CoreShellCommandExecutor {
}
}
fn network_proxy_scope_for_skill(
skill: &SkillMetadata,
) -> (
NetworkProxyScope,
Option<crate::skills::model::SkillManagedNetworkOverride>,
) {
match skill.managed_network_override.clone() {
Some(managed_network_override) => (
NetworkProxyScope::Skill {
path_to_skills_md: skill.path_to_skills_md.clone(),
},
Some(managed_network_override),
),
None => (NetworkProxyScope::SessionDefault, None),
}
}
#[derive(Debug, Eq, PartialEq)]
struct ParsedShellCommand {
program: String,

View File

@@ -15,6 +15,7 @@ use crate::config::Permissions;
#[cfg(target_os = "macos")]
use crate::config::types::ShellEnvironmentPolicy;
use crate::exec::SandboxType;
use crate::network_proxy_registry::NetworkProxyScope;
use crate::protocol::AskForApproval;
use crate::protocol::GranularApprovalConfig;
use crate::protocol::ReadOnlyAccess;
@@ -98,6 +99,35 @@ fn test_skill_metadata(permission_profile: Option<PermissionProfile>) -> SkillMe
}
}
#[test]
fn network_proxy_scope_for_skill_without_override_reuses_session_default() {
let skill = test_skill_metadata(None);
assert_eq!(
super::network_proxy_scope_for_skill(&skill),
(NetworkProxyScope::SessionDefault, None),
);
}
#[test]
fn network_proxy_scope_for_skill_with_override_uses_skill_scope() {
let mut skill = test_skill_metadata(None);
skill.managed_network_override = Some(crate::skills::model::SkillManagedNetworkOverride {
allowed_domains: Some(vec!["skill.example.com".to_string()]),
denied_domains: Some(vec!["blocked.skill.example.com".to_string()]),
});
assert_eq!(
super::network_proxy_scope_for_skill(&skill),
(
NetworkProxyScope::Skill {
path_to_skills_md: PathBuf::from("/tmp/skill/SKILL.md"),
},
skill.managed_network_override.clone(),
),
);
}
#[test]
fn execve_prompt_rejection_uses_skill_approval_for_skill_scripts() {
let decision_source = super::DecisionSource::SkillScript {
@@ -651,6 +681,7 @@ host_executable(name = "git", paths = ["{allowed_git_literal}"])
async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions() {
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
let executor = CoreShellCommandExecutor {
session: None,
command: vec!["echo".to_string(), "ok".to_string()],
cwd: cwd.to_path_buf(),
env: HashMap::new(),
@@ -703,6 +734,7 @@ async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions
async fn prepare_escalated_exec_permissions_preserve_macos_seatbelt_extensions() {
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
let executor = CoreShellCommandExecutor {
session: None,
command: vec!["echo".to_string(), "ok".to_string()],
cwd: cwd.to_path_buf(),
env: HashMap::new(),
@@ -777,6 +809,7 @@ async fn prepare_escalated_exec_permission_profile_unions_turn_and_requested_mac
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
let sandbox_policy = SandboxPolicy::new_read_only_policy();
let executor = CoreShellCommandExecutor {
session: None,
command: vec!["echo".to_string(), "ok".to_string()],
cwd: cwd.to_path_buf(),
env: HashMap::new(),