mirror of
https://github.com/openai/codex.git
synced 2026-03-13 18:23:49 +00:00
Compare commits
3 Commits
dev/cc/mul
...
cconger/co
| Author | SHA1 | Date | |
|---|---|---|---|
| | c7e847aaeb | | |
| | 2253a9d1d7 | | |
| | eaf81d3f6f | | |
@@ -1,16 +1,20 @@
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItem;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationFunctionCallOutputItem;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItemContent;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItemPayload;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationMessageItem;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeAudioFrame;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeEvent;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeEventParser;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionMode;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptDelta;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptEntry;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudio;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioInput;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionFunctionTool;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
|
||||
use crate::error::ApiError;
|
||||
@@ -21,6 +25,7 @@ use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
@@ -41,6 +46,23 @@ use tracing::trace;
|
||||
use tungstenite::protocol::WebSocketConfig;
|
||||
use url::Url;
|
||||
|
||||
const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000;
|
||||
const REALTIME_AUDIO_VOICE: &str = "fathom";
|
||||
const REALTIME_V1_SESSION_TYPE: &str = "quicksilver";
|
||||
const REALTIME_V2_SESSION_TYPE: &str = "realtime";
|
||||
const REALTIME_V2_CODEX_TOOL_NAME: &str = "codex";
|
||||
const REALTIME_V2_CODEX_TOOL_DESCRIPTION: &str = "Delegate work to Codex and return the result.";
|
||||
|
||||
fn normalized_session_mode(
|
||||
event_parser: RealtimeEventParser,
|
||||
session_mode: RealtimeSessionMode,
|
||||
) -> RealtimeSessionMode {
|
||||
match event_parser {
|
||||
RealtimeEventParser::V1 => RealtimeSessionMode::Conversational,
|
||||
RealtimeEventParser::RealtimeV2 => session_mode,
|
||||
}
|
||||
}
|
||||
|
||||
struct WsStream {
|
||||
tx_command: mpsc::Sender<WsCommand>,
|
||||
pump_task: tokio::task::JoinHandle<()>,
|
||||
@@ -197,6 +219,7 @@ pub struct RealtimeWebsocketConnection {
|
||||
pub struct RealtimeWebsocketWriter {
|
||||
stream: Arc<WsStream>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
event_parser: RealtimeEventParser,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -258,6 +281,7 @@ impl RealtimeWebsocketConnection {
|
||||
writer: RealtimeWebsocketWriter {
|
||||
stream: Arc::clone(&stream),
|
||||
is_closed: Arc::clone(&is_closed),
|
||||
event_parser,
|
||||
},
|
||||
events: RealtimeWebsocketEvents {
|
||||
rx_message: Arc::new(Mutex::new(rx_message)),
|
||||
@@ -276,15 +300,19 @@ impl RealtimeWebsocketWriter {
|
||||
}
|
||||
|
||||
pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> {
|
||||
let content_kind = match self.event_parser {
|
||||
RealtimeEventParser::V1 => "text",
|
||||
RealtimeEventParser::RealtimeV2 => "input_text",
|
||||
};
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItem {
|
||||
item: ConversationItemPayload::Message(ConversationMessageItem {
|
||||
kind: "message".to_string(),
|
||||
role: "user".to_string(),
|
||||
content: vec![ConversationItemContent {
|
||||
kind: "text".to_string(),
|
||||
kind: content_kind.to_string(),
|
||||
text,
|
||||
}],
|
||||
},
|
||||
}),
|
||||
})
|
||||
.await
|
||||
}
|
||||
@@ -294,29 +322,80 @@ impl RealtimeWebsocketWriter {
|
||||
handoff_id: String,
|
||||
output_text: String,
|
||||
) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::ConversationHandoffAppend {
|
||||
handoff_id,
|
||||
output_text,
|
||||
})
|
||||
.await
|
||||
let message = match self.event_parser {
|
||||
RealtimeEventParser::V1 => RealtimeOutboundMessage::ConversationHandoffAppend {
|
||||
handoff_id,
|
||||
output_text,
|
||||
},
|
||||
RealtimeEventParser::RealtimeV2 => RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItemPayload::FunctionCallOutput(
|
||||
ConversationFunctionCallOutputItem {
|
||||
kind: "function_call_output".to_string(),
|
||||
call_id: handoff_id,
|
||||
output: output_text,
|
||||
},
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
self.send_json(message).await
|
||||
}
|
||||
|
||||
pub async fn send_session_update(&self, instructions: String) -> Result<(), ApiError> {
|
||||
pub async fn send_session_update(
|
||||
&self,
|
||||
instructions: String,
|
||||
session_mode: RealtimeSessionMode,
|
||||
) -> Result<(), ApiError> {
|
||||
let session_mode = normalized_session_mode(self.event_parser, session_mode);
|
||||
let (session_kind, session_instructions, output_audio) = match session_mode {
|
||||
RealtimeSessionMode::Conversational => {
|
||||
let kind = match self.event_parser {
|
||||
RealtimeEventParser::V1 => REALTIME_V1_SESSION_TYPE.to_string(),
|
||||
RealtimeEventParser::RealtimeV2 => REALTIME_V2_SESSION_TYPE.to_string(),
|
||||
};
|
||||
(
|
||||
kind,
|
||||
Some(instructions),
|
||||
Some(SessionAudioOutput {
|
||||
voice: REALTIME_AUDIO_VOICE.to_string(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
RealtimeSessionMode::Transcription => ("transcription".to_string(), None, None),
|
||||
};
|
||||
let tools = match self.event_parser {
|
||||
RealtimeEventParser::RealtimeV2 => Some(vec![SessionFunctionTool {
|
||||
kind: "function".to_string(),
|
||||
name: REALTIME_V2_CODEX_TOOL_NAME.to_string(),
|
||||
description: REALTIME_V2_CODEX_TOOL_DESCRIPTION.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Prompt text for the delegated Codex task."
|
||||
}
|
||||
},
|
||||
"required": ["prompt"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
}]),
|
||||
RealtimeEventParser::V1 => None,
|
||||
};
|
||||
self.send_json(RealtimeOutboundMessage::SessionUpdate {
|
||||
session: SessionUpdateSession {
|
||||
kind: "quicksilver".to_string(),
|
||||
instructions,
|
||||
kind: session_kind,
|
||||
instructions: session_instructions,
|
||||
audio: SessionAudio {
|
||||
input: SessionAudioInput {
|
||||
format: SessionAudioFormat {
|
||||
kind: "audio/pcm".to_string(),
|
||||
rate: 24_000,
|
||||
rate: REALTIME_AUDIO_SAMPLE_RATE,
|
||||
},
|
||||
},
|
||||
output: SessionAudioOutput {
|
||||
voice: "fathom".to_string(),
|
||||
},
|
||||
output: output_audio,
|
||||
},
|
||||
tools,
|
||||
},
|
||||
})
|
||||
.await
|
||||
@@ -465,6 +544,8 @@ impl RealtimeWebsocketClient {
|
||||
self.provider.base_url.as_str(),
|
||||
self.provider.query_params.as_ref(),
|
||||
config.model.as_deref(),
|
||||
config.event_parser,
|
||||
config.session_mode,
|
||||
)?;
|
||||
|
||||
let mut request = ws_url
|
||||
@@ -506,7 +587,7 @@ impl RealtimeWebsocketClient {
|
||||
);
|
||||
connection
|
||||
.writer
|
||||
.send_session_update(config.instructions)
|
||||
.send_session_update(config.instructions, config.session_mode)
|
||||
.await?;
|
||||
Ok(connection)
|
||||
}
|
||||
@@ -551,6 +632,8 @@ fn websocket_url_from_api_url(
|
||||
api_url: &str,
|
||||
query_params: Option<&HashMap<String, String>>,
|
||||
model: Option<&str>,
|
||||
event_parser: RealtimeEventParser,
|
||||
_session_mode: RealtimeSessionMode,
|
||||
) -> Result<Url, ApiError> {
|
||||
let mut url = Url::parse(api_url)
|
||||
.map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?;
|
||||
@@ -570,9 +653,20 @@ fn websocket_url_from_api_url(
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let intent = match event_parser {
|
||||
RealtimeEventParser::V1 => Some("quicksilver"),
|
||||
RealtimeEventParser::RealtimeV2 => None,
|
||||
};
|
||||
let has_extra_query_params = query_params.is_some_and(|query_params| {
|
||||
query_params
|
||||
.iter()
|
||||
.any(|(key, _)| key != "intent" && !(key == "model" && model.is_some()))
|
||||
});
|
||||
if intent.is_some() || model.is_some() || has_extra_query_params {
|
||||
let mut query = url.query_pairs_mut();
|
||||
query.append_pair("intent", "quicksilver");
|
||||
if let Some(intent) = intent {
|
||||
query.append_pair("intent", intent);
|
||||
}
|
||||
if let Some(model) = model {
|
||||
query.append_pair("model", model);
|
||||
}
|
||||
@@ -853,8 +947,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_http_base_defaults_to_ws_path() {
|
||||
let url =
|
||||
websocket_url_from_api_url("http://127.0.0.1:8011", None, None).expect("build ws url");
|
||||
let url = websocket_url_from_api_url(
|
||||
"http://127.0.0.1:8011",
|
||||
None,
|
||||
None,
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"ws://127.0.0.1:8011/v1/realtime?intent=quicksilver"
|
||||
@@ -863,9 +963,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_ws_base_defaults_to_ws_path() {
|
||||
let url =
|
||||
websocket_url_from_api_url("wss://example.com", None, Some("realtime-test-model"))
|
||||
.expect("build ws url");
|
||||
let url = websocket_url_from_api_url(
|
||||
"wss://example.com",
|
||||
None,
|
||||
Some("realtime-test-model"),
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/v1/realtime?intent=quicksilver&model=realtime-test-model"
|
||||
@@ -874,8 +979,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_v1_base_appends_realtime_path() {
|
||||
let url = websocket_url_from_api_url("https://api.openai.com/v1", None, Some("snapshot"))
|
||||
.expect("build ws url");
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://api.openai.com/v1",
|
||||
None,
|
||||
Some("snapshot"),
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://api.openai.com/v1/realtime?intent=quicksilver&model=snapshot"
|
||||
@@ -884,9 +995,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_nested_v1_base_appends_realtime_path() {
|
||||
let url =
|
||||
websocket_url_from_api_url("https://example.com/openai/v1", None, Some("snapshot"))
|
||||
.expect("build ws url");
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://example.com/openai/v1",
|
||||
None,
|
||||
Some("snapshot"),
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/openai/v1/realtime?intent=quicksilver&model=snapshot"
|
||||
@@ -902,6 +1018,8 @@ mod tests {
|
||||
("intent".to_string(), "ignored".to_string()),
|
||||
])),
|
||||
Some("snapshot"),
|
||||
RealtimeEventParser::V1,
|
||||
RealtimeSessionMode::Conversational,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
@@ -910,6 +1028,54 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
/// With the V1 parser the URL still carries `intent=quicksilver`, even when
/// transcription mode is requested.
#[test]
fn websocket_url_v1_ignores_transcription_mode() {
    let url = websocket_url_from_api_url(
        "https://example.com",
        None,
        None,
        RealtimeEventParser::V1,
        RealtimeSessionMode::Transcription,
    )
    .expect("build ws url");
    assert_eq!(
        url.as_str(),
        "wss://example.com/v1/realtime?intent=quicksilver"
    );
}
|
||||
|
||||
/// With the RealtimeV2 parser no `intent` query parameter is emitted: the
/// provider-supplied `intent` entry is dropped while the existing query string,
/// the model, and the other provider parameters are kept.
#[test]
fn websocket_url_omits_intent_for_realtime_v2_conversational_mode() {
    let url = websocket_url_from_api_url(
        "https://example.com/v1/realtime?foo=bar",
        Some(&HashMap::from([
            ("trace".to_string(), "1".to_string()),
            ("intent".to_string(), "ignored".to_string()),
        ])),
        Some("snapshot"),
        RealtimeEventParser::RealtimeV2,
        RealtimeSessionMode::Conversational,
    )
    .expect("build ws url");
    assert_eq!(
        url.as_str(),
        "wss://example.com/v1/realtime?foo=bar&model=snapshot&trace=1"
    );
}
|
||||
|
||||
/// With the RealtimeV2 parser, no model, and no provider query parameters, the
/// resulting URL has no query string at all.
#[test]
fn websocket_url_omits_intent_for_realtime_v2_transcription_mode() {
    let url = websocket_url_from_api_url(
        "https://example.com",
        None,
        None,
        RealtimeEventParser::RealtimeV2,
        RealtimeSessionMode::Transcription,
    )
    .expect("build ws url");
    assert_eq!(url.as_str(), "wss://example.com/v1/realtime");
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
@@ -1075,6 +1241,7 @@ mod tests {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -1195,6 +1362,352 @@ mod tests {
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_v2_session_update_includes_codex_tool_and_handoff_output_item() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["type"],
|
||||
Value::String("realtime".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["type"],
|
||||
Value::String("function".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["name"],
|
||||
Value::String("codex".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["parameters"]["required"],
|
||||
json!(["prompt"])
|
||||
);
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_v2", "instructions": "backend prompt"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
|
||||
let second = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("second msg")
|
||||
.expect("second msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let second_json: Value = serde_json::from_str(&second).expect("json");
|
||||
assert_eq!(second_json["type"], "conversation.item.create");
|
||||
assert_eq!(
|
||||
second_json["item"]["type"],
|
||||
Value::String("message".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
second_json["item"]["content"][0]["type"],
|
||||
Value::String("input_text".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
second_json["item"]["content"][0]["text"],
|
||||
Value::String("delegate this".to_string())
|
||||
);
|
||||
|
||||
let third = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("third msg")
|
||||
.expect("third msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let third_json: Value = serde_json::from_str(&third).expect("json");
|
||||
assert_eq!(third_json["type"], "conversation.item.create");
|
||||
assert_eq!(
|
||||
third_json["item"]["type"],
|
||||
Value::String("function_call_output".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
third_json["item"]["call_id"],
|
||||
Value::String("call_1".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
third_json["item"]["output"],
|
||||
Value::String("delegated result".to_string())
|
||||
);
|
||||
});
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: format!("http://{addr}"),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: crate::provider::RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::RealtimeV2,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let created = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
created,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_v2".to_string(),
|
||||
instructions: Some("backend prompt".to_string()),
|
||||
}
|
||||
);
|
||||
|
||||
connection
|
||||
.send_conversation_item_create("delegate this".to_string())
|
||||
.await
|
||||
.expect("send text item");
|
||||
connection
|
||||
.send_conversation_handoff_append("call_1".to_string(), "delegated result".to_string())
|
||||
.await
|
||||
.expect("send handoff output");
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transcription_mode_session_update_omits_output_audio_and_instructions() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["type"],
|
||||
Value::String("transcription".to_string())
|
||||
);
|
||||
assert!(first_json["session"].get("instructions").is_none());
|
||||
assert!(first_json["session"]["audio"].get("output").is_none());
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["name"],
|
||||
Value::String("codex".to_string())
|
||||
);
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_transcription"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
|
||||
let second = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("second msg")
|
||||
.expect("second msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let second_json: Value = serde_json::from_str(&second).expect("json");
|
||||
assert_eq!(second_json["type"], "input_audio_buffer.append");
|
||||
});
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: format!("http://{addr}"),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: crate::provider::RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::RealtimeV2,
|
||||
session_mode: RealtimeSessionMode::Transcription,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let created = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
created,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_transcription".to_string(),
|
||||
instructions: None,
|
||||
}
|
||||
);
|
||||
|
||||
connection
|
||||
.send_audio_frame(RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
sample_rate: 24_000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(480),
|
||||
})
|
||||
.await
|
||||
.expect("send audio");
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn v1_transcription_mode_is_treated_as_conversational() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["type"],
|
||||
Value::String("quicksilver".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["instructions"],
|
||||
Value::String("backend prompt".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["output"]["voice"],
|
||||
Value::String("fathom".to_string())
|
||||
);
|
||||
assert!(first_json["session"].get("tools").is_none());
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_v1_mode"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
});
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: format!("http://{addr}"),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: crate::provider::RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Transcription,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let created = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
created,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_v1_mode".to_string(),
|
||||
instructions: None,
|
||||
}
|
||||
);
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_does_not_block_while_next_event_waits_for_inbound_data() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
@@ -1258,6 +1771,7 @@ mod tests {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
pub mod methods;
|
||||
pub mod protocol;
|
||||
mod protocol_common;
|
||||
mod protocol_v1;
|
||||
mod protocol_v2;
|
||||
|
||||
pub use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
@@ -10,3 +12,4 @@ pub use methods::RealtimeWebsocketEvents;
|
||||
pub use methods::RealtimeWebsocketWriter;
|
||||
pub use protocol::RealtimeEventParser;
|
||||
pub use protocol::RealtimeSessionConfig;
|
||||
pub use protocol::RealtimeSessionMode;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::endpoint::realtime_websocket::protocol_v1::parse_realtime_event_v1;
|
||||
use crate::endpoint::realtime_websocket::protocol_v2::parse_realtime_event_v2;
|
||||
pub use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
pub use codex_protocol::protocol::RealtimeEvent;
|
||||
@@ -6,7 +7,6 @@ pub use codex_protocol::protocol::RealtimeTranscriptDelta;
|
||||
pub use codex_protocol::protocol::RealtimeTranscriptEntry;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RealtimeEventParser {
|
||||
@@ -14,12 +14,19 @@ pub enum RealtimeEventParser {
|
||||
RealtimeV2,
|
||||
}
|
||||
|
||||
/// Kind of realtime session a caller requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RealtimeSessionMode {
    /// Conversational session.
    Conversational,
    /// Transcription session.
    Transcription,
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RealtimeSessionConfig {
|
||||
pub instructions: String,
|
||||
pub model: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
pub event_parser: RealtimeEventParser,
|
||||
pub session_mode: RealtimeSessionMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -35,21 +42,25 @@ pub(super) enum RealtimeOutboundMessage {
|
||||
#[serde(rename = "session.update")]
|
||||
SessionUpdate { session: SessionUpdateSession },
|
||||
#[serde(rename = "conversation.item.create")]
|
||||
ConversationItemCreate { item: ConversationItem },
|
||||
ConversationItemCreate { item: ConversationItemPayload },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionUpdateSession {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) instructions: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) instructions: Option<String>,
|
||||
pub(super) audio: SessionAudio,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) tools: Option<Vec<SessionFunctionTool>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudio {
|
||||
pub(super) input: SessionAudioInput,
|
||||
pub(super) output: SessionAudioOutput,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) output: Option<SessionAudioOutput>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -70,13 +81,28 @@ pub(super) struct SessionAudioOutput {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct ConversationItem {
|
||||
pub(super) struct ConversationMessageItem {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) role: String,
|
||||
pub(super) content: Vec<ConversationItemContent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub(super) enum ConversationItemPayload {
|
||||
Message(ConversationMessageItem),
|
||||
FunctionCallOutput(ConversationFunctionCallOutputItem),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct ConversationFunctionCallOutputItem {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) call_id: String,
|
||||
pub(super) output: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct ConversationItemContent {
|
||||
#[serde(rename = "type")]
|
||||
@@ -84,6 +110,15 @@ pub(super) struct ConversationItemContent {
|
||||
pub(super) text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionFunctionTool {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) name: String,
|
||||
pub(super) description: String,
|
||||
pub(super) parameters: Value,
|
||||
}
|
||||
|
||||
pub(super) fn parse_realtime_event(
|
||||
payload: &str,
|
||||
event_parser: RealtimeEventParser,
|
||||
@@ -93,125 +128,3 @@ pub(super) fn parse_realtime_event(
|
||||
RealtimeEventParser::RealtimeV2 => parse_realtime_event_v2(payload),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
|
||||
let parsed: Value = match serde_json::from_str(payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => {
|
||||
debug!("failed to parse realtime event: {err}, data: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let message_type = match parsed.get("type").and_then(Value::as_str) {
|
||||
Some(message_type) => message_type,
|
||||
None => {
|
||||
debug!("received realtime event without type field: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
match message_type {
|
||||
"session.updated" => {
|
||||
let session_id = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
let instructions = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("instructions"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
|
||||
session_id,
|
||||
instructions,
|
||||
})
|
||||
}
|
||||
"conversation.output_audio.delta" => {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| parsed.get("data").and_then(Value::as_str))
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok())?;
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u16::try_from(v).ok())?;
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok()),
|
||||
}))
|
||||
}
|
||||
"conversation.input_transcript.delta" => parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"conversation.output_transcript.delta" => parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::OutputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"conversation.item.added" => parsed
|
||||
.get("item")
|
||||
.cloned()
|
||||
.map(RealtimeEvent::ConversationItemAdded),
|
||||
"conversation.item.done" => parsed
|
||||
.get("item")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|item| item.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
|
||||
"conversation.handoff.requested" => {
|
||||
let handoff_id = parsed
|
||||
.get("handoff_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let item_id = parsed
|
||||
.get("item_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let input_transcript = parsed
|
||||
.get("input_transcript")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id,
|
||||
item_id,
|
||||
input_transcript,
|
||||
active_transcript: Vec::new(),
|
||||
}))
|
||||
}
|
||||
"error" => parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.get("error")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|error| error.get("message"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.or_else(|| parsed.get("error").map(std::string::ToString::to_string))
|
||||
.map(RealtimeEvent::Error),
|
||||
_ => {
|
||||
debug!("received unsupported realtime event type: {message_type}, data: {payload}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeTranscriptDelta;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
|
||||
pub(super) fn parse_realtime_payload(payload: &str, parser_name: &str) -> Option<(Value, String)> {
|
||||
let parsed: Value = match serde_json::from_str(payload) {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
debug!("failed to parse {parser_name} event: {err}, data: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let message_type = match parsed.get("type").and_then(Value::as_str) {
|
||||
Some(message_type) => message_type.to_string(),
|
||||
None => {
|
||||
debug!("received {parser_name} event without type field: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
Some((parsed, message_type))
|
||||
}
|
||||
|
||||
pub(super) fn parse_session_updated_event(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
let session_id = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let instructions = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("instructions"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
Some(RealtimeEvent::SessionUpdated {
|
||||
session_id,
|
||||
instructions,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn parse_transcript_delta_event(
|
||||
parsed: &Value,
|
||||
field: &str,
|
||||
) -> Option<RealtimeTranscriptDelta> {
|
||||
parsed
|
||||
.get(field)
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeTranscriptDelta { delta })
|
||||
}
|
||||
|
||||
pub(super) fn parse_error_event(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.get("error")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|error| error.get("message"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.or_else(|| parsed.get("error").map(ToString::to_string))
|
||||
.map(RealtimeEvent::Error)
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_error_event;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_realtime_payload;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_session_updated_event;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_transcript_delta_event;
|
||||
use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeHandoffRequested;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
|
||||
pub(super) fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
|
||||
let (parsed, message_type) = parse_realtime_payload(payload, "realtime v1")?;
|
||||
match message_type.as_str() {
|
||||
"session.updated" => parse_session_updated_event(&parsed),
|
||||
"conversation.output_audio.delta" => {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| parsed.get("data").and_then(Value::as_str))
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok())?;
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u16::try_from(value).ok())?;
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok()),
|
||||
}))
|
||||
}
|
||||
"conversation.input_transcript.delta" => {
|
||||
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::InputTranscriptDelta)
|
||||
}
|
||||
"conversation.output_transcript.delta" => {
|
||||
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::OutputTranscriptDelta)
|
||||
}
|
||||
"conversation.item.added" => parsed
|
||||
.get("item")
|
||||
.cloned()
|
||||
.map(RealtimeEvent::ConversationItemAdded),
|
||||
"conversation.item.done" => parsed
|
||||
.get("item")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|item| item.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
|
||||
"conversation.handoff.requested" => {
|
||||
let handoff_id = parsed
|
||||
.get("handoff_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let item_id = parsed
|
||||
.get("item_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let input_transcript = parsed
|
||||
.get("input_transcript")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id,
|
||||
item_id,
|
||||
input_transcript,
|
||||
active_transcript: Vec::new(),
|
||||
}))
|
||||
}
|
||||
"error" => parse_error_event(&parsed),
|
||||
_ => {
|
||||
debug!("received unsupported realtime v1 event type: {message_type}, data: {payload}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,157 +1,130 @@
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_error_event;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_realtime_payload;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_session_updated_event;
|
||||
use crate::endpoint::realtime_websocket::protocol_common::parse_transcript_delta_event;
|
||||
use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeHandoffRequested;
|
||||
use codex_protocol::protocol::RealtimeTranscriptDelta;
|
||||
use serde_json::Map as JsonMap;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
|
||||
/// Function-call name that marks a finished conversation item as a handoff request.
const CODEX_TOOL_NAME: &str = "codex";
|
||||
/// Sample rate used for output audio when the event carries no usable `sample_rate`.
const DEFAULT_AUDIO_SAMPLE_RATE: u32 = 24_000;
|
||||
/// Channel count used for output audio when the event carries no usable channel field.
const DEFAULT_AUDIO_CHANNELS: u16 = 1;
|
||||
/// Argument keys probed, in this order, when extracting the handoff input transcript.
const TOOL_ARGUMENT_KEYS: [&str; 5] = ["input_transcript", "input", "text", "prompt", "query"];
|
||||
|
||||
pub(super) fn parse_realtime_event_v2(payload: &str) -> Option<RealtimeEvent> {
|
||||
let parsed: Value = match serde_json::from_str(payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => {
|
||||
debug!("failed to parse realtime v2 event: {err}, data: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let (parsed, message_type) = parse_realtime_payload(payload, "realtime v2")?;
|
||||
|
||||
let message_type = match parsed.get("type").and_then(Value::as_str) {
|
||||
Some(message_type) => message_type,
|
||||
None => {
|
||||
debug!("received realtime v2 event without type field: {payload}");
|
||||
return None;
|
||||
match message_type.as_str() {
|
||||
"session.updated" => parse_session_updated_event(&parsed),
|
||||
"response.output_audio.delta" => parse_output_audio_delta_event(&parsed),
|
||||
"conversation.item.input_audio_transcription.delta" => {
|
||||
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::InputTranscriptDelta)
|
||||
}
|
||||
};
|
||||
|
||||
match message_type {
|
||||
"session.updated" => {
|
||||
let session_id = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
let instructions = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("instructions"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
|
||||
session_id,
|
||||
instructions,
|
||||
})
|
||||
"conversation.item.input_audio_transcription.completed" => {
|
||||
parse_transcript_delta_event(&parsed, "transcript")
|
||||
.map(RealtimeEvent::InputTranscriptDelta)
|
||||
}
|
||||
"response.output_audio.delta" => {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok())
|
||||
.unwrap_or(24_000);
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u16::try_from(value).ok())
|
||||
.unwrap_or(1);
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok()),
|
||||
}))
|
||||
"response.output_text.delta" | "response.output_audio_transcript.delta" => {
|
||||
parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::OutputTranscriptDelta)
|
||||
}
|
||||
"conversation.item.input_audio_transcription.delta" => parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"conversation.item.input_audio_transcription.completed" => parsed
|
||||
.get("transcript")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"response.output_text.delta" | "response.output_audio_transcript.delta" => parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|delta| RealtimeEvent::OutputTranscriptDelta(RealtimeTranscriptDelta { delta })),
|
||||
"conversation.item.added" => parsed
|
||||
.get("item")
|
||||
.cloned()
|
||||
.map(RealtimeEvent::ConversationItemAdded),
|
||||
"conversation.item.done" => {
|
||||
let item = parsed.get("item")?.as_object()?;
|
||||
let item_type = item.get("type").and_then(Value::as_str);
|
||||
let item_name = item.get("name").and_then(Value::as_str);
|
||||
|
||||
if item_type == Some("function_call") && item_name == Some("codex") {
|
||||
let call_id = item
|
||||
.get("call_id")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| item.get("id").and_then(Value::as_str))?;
|
||||
let item_id = item
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or(call_id)
|
||||
.to_string();
|
||||
let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
|
||||
let mut input_transcript = String::new();
|
||||
if !arguments.is_empty() {
|
||||
if let Ok(arguments_json) = serde_json::from_str::<Value>(arguments)
|
||||
&& let Some(arguments_object) = arguments_json.as_object()
|
||||
{
|
||||
for key in ["input_transcript", "input", "text", "prompt", "query"] {
|
||||
if let Some(value) = arguments_object.get(key).and_then(Value::as_str) {
|
||||
let trimmed = value.trim();
|
||||
if !trimmed.is_empty() {
|
||||
input_transcript = trimmed.to_string();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if input_transcript.is_empty() {
|
||||
input_transcript = arguments.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
return Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id: call_id.to_string(),
|
||||
item_id,
|
||||
input_transcript,
|
||||
active_transcript: Vec::new(),
|
||||
}));
|
||||
}
|
||||
|
||||
item.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id })
|
||||
}
|
||||
"error" => parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.get("error")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|error| error.get("message"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.or_else(|| parsed.get("error").map(ToString::to_string))
|
||||
.map(RealtimeEvent::Error),
|
||||
"conversation.item.done" => parse_conversation_item_done_event(&parsed),
|
||||
"error" => parse_error_event(&parsed),
|
||||
_ => {
|
||||
debug!("received unsupported realtime v2 event type: {message_type}, data: {payload}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_output_audio_delta_event(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok())
|
||||
.unwrap_or(DEFAULT_AUDIO_SAMPLE_RATE);
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u16::try_from(value).ok())
|
||||
.unwrap_or(DEFAULT_AUDIO_CHANNELS);
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok()),
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_conversation_item_done_event(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
let item = parsed.get("item")?.as_object()?;
|
||||
if let Some(handoff) = parse_handoff_requested_event(item) {
|
||||
return Some(handoff);
|
||||
}
|
||||
|
||||
item.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id })
|
||||
}
|
||||
|
||||
fn parse_handoff_requested_event(item: &JsonMap<String, Value>) -> Option<RealtimeEvent> {
|
||||
let item_type = item.get("type").and_then(Value::as_str);
|
||||
let item_name = item.get("name").and_then(Value::as_str);
|
||||
if item_type != Some("function_call") || item_name != Some(CODEX_TOOL_NAME) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let call_id = item
|
||||
.get("call_id")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| item.get("id").and_then(Value::as_str))?;
|
||||
let item_id = item
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or(call_id)
|
||||
.to_string();
|
||||
let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
|
||||
|
||||
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id: call_id.to_string(),
|
||||
item_id,
|
||||
input_transcript: extract_input_transcript(arguments),
|
||||
active_transcript: Vec::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn extract_input_transcript(arguments: &str) -> String {
|
||||
if arguments.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
if let Ok(arguments_json) = serde_json::from_str::<Value>(arguments)
|
||||
&& let Some(arguments_object) = arguments_json.as_object()
|
||||
{
|
||||
for key in TOOL_ARGUMENT_KEYS {
|
||||
if let Some(value) = arguments_object.get(key).and_then(Value::as_str) {
|
||||
let trimmed = value.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
arguments.to_string()
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ pub use crate::endpoint::memories::MemoriesClient;
|
||||
pub use crate::endpoint::models::ModelsClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeEventParser;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeSessionMode;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketConnection;
|
||||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
|
||||
@@ -6,6 +6,7 @@ use codex_api::RealtimeAudioFrame;
|
||||
use codex_api::RealtimeEvent;
|
||||
use codex_api::RealtimeEventParser;
|
||||
use codex_api::RealtimeSessionConfig;
|
||||
use codex_api::RealtimeSessionMode;
|
||||
use codex_api::RealtimeWebsocketClient;
|
||||
use codex_api::provider::Provider;
|
||||
use codex_api::provider::RetryConfig;
|
||||
@@ -142,6 +143,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -235,6 +237,7 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -299,6 +302,7 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -360,6 +364,7 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -424,6 +429,7 @@ async fn realtime_ws_e2e_realtime_v2_parser_emits_handoff_requested() {
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::RealtimeV2,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
|
||||
@@ -1342,6 +1342,13 @@
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
"RealtimeWsMode": {
|
||||
"enum": [
|
||||
"conversational",
|
||||
"transcription"
|
||||
],
|
||||
"type": "string"
|
||||
},
|
||||
"ReasoningEffort": {
|
||||
"description": "See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning",
|
||||
"enum": [
|
||||
@@ -1816,6 +1823,14 @@
|
||||
"description": "Experimental / do not use. Overrides only the realtime conversation websocket transport base URL (the `Op::RealtimeConversation` `/v1/realtime` connection) without changing normal provider HTTP requests.",
|
||||
"type": "string"
|
||||
},
|
||||
"experimental_realtime_ws_mode": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/RealtimeWsMode"
|
||||
}
|
||||
],
|
||||
"description": "Experimental / do not use. Selects the realtime websocket intent mode. `conversational` is speech-to-speech while `transcription` is transcript-only."
|
||||
},
|
||||
"experimental_realtime_ws_model": {
|
||||
"description": "Experimental / do not use. Selects the realtime websocket model/snapshot used for the `Op::RealtimeConversation` connection.",
|
||||
"type": "string"
|
||||
|
||||
@@ -172,8 +172,6 @@ use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
#[cfg(test)]
|
||||
use crate::exec::StreamOutput;
|
||||
use crate::network_proxy_registry::NetworkProxyRegistry;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use codex_config::CONFIG_TOML_FILE;
|
||||
|
||||
mod rollout_reconstruction;
|
||||
@@ -278,7 +276,6 @@ use crate::skills::collect_explicit_skill_mentions;
|
||||
use crate::skills::injection::ToolMentionKind;
|
||||
use crate::skills::injection::app_id_from_path;
|
||||
use crate::skills::injection::tool_kind_for_path;
|
||||
use crate::skills::model::SkillManagedNetworkOverride;
|
||||
use crate::skills::resolve_skill_dependencies_for_turn;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::state::SessionServices;
|
||||
@@ -1185,61 +1182,6 @@ impl Session {
|
||||
Ok((network_proxy, session_network_proxy))
|
||||
}
|
||||
|
||||
pub(crate) async fn get_or_start_network_proxy(
|
||||
self: &Arc<Self>,
|
||||
scope: NetworkProxyScope,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
managed_network_override: Option<SkillManagedNetworkOverride>,
|
||||
) -> anyhow::Result<Option<NetworkProxy>> {
|
||||
let session = Arc::clone(self);
|
||||
let started = self
|
||||
.services
|
||||
.network_proxies
|
||||
.get_or_start(
|
||||
scope.clone(),
|
||||
move |spec, managed_enabled, audit_metadata| {
|
||||
let session = Arc::clone(&session);
|
||||
let managed_network_override = managed_network_override.clone();
|
||||
let scope = scope.clone();
|
||||
let sandbox_policy = sandbox_policy.clone();
|
||||
async move {
|
||||
let network_policy_decider = session
|
||||
.services
|
||||
.network_policy_decider_session
|
||||
.as_ref()
|
||||
.map(|network_policy_decider_session| {
|
||||
build_network_policy_decider(
|
||||
Arc::clone(&session.services.network_approval),
|
||||
Arc::clone(network_policy_decider_session),
|
||||
scope,
|
||||
)
|
||||
});
|
||||
let spec = if let Some(managed_network_override) =
|
||||
managed_network_override.as_ref()
|
||||
{
|
||||
spec.with_skill_managed_network_override(managed_network_override)
|
||||
} else {
|
||||
spec
|
||||
};
|
||||
spec.start_proxy(
|
||||
&sandbox_policy,
|
||||
network_policy_decider,
|
||||
session
|
||||
.services
|
||||
.network_blocked_request_observer
|
||||
.as_ref()
|
||||
.map(Arc::clone),
|
||||
managed_enabled,
|
||||
audit_metadata,
|
||||
)
|
||||
.await
|
||||
}
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
Ok(started.map(|started| started.proxy()))
|
||||
}
|
||||
|
||||
/// Don't expand the number of mutated arguments on config. We are in the process of getting rid of it.
|
||||
pub(crate) fn build_per_turn_config(session_configuration: &SessionConfiguration) -> Config {
|
||||
// todo(aibrahim): store this state somewhere else so we don't need to mut config
|
||||
@@ -1709,10 +1651,9 @@ impl Session {
|
||||
build_network_policy_decider(
|
||||
Arc::clone(&network_approval),
|
||||
Arc::clone(network_policy_decider_session),
|
||||
NetworkProxyScope::SessionDefault,
|
||||
)
|
||||
});
|
||||
let (default_network_proxy, session_network_proxy) =
|
||||
let (network_proxy, session_network_proxy) =
|
||||
if let Some(spec) = config.permissions.network.as_ref() {
|
||||
let (network_proxy, session_network_proxy) = Self::start_managed_network_proxy(
|
||||
spec,
|
||||
@@ -1720,19 +1661,13 @@ impl Session {
|
||||
network_policy_decider.as_ref().map(Arc::clone),
|
||||
blocked_request_observer.as_ref().map(Arc::clone),
|
||||
managed_network_requirements_enabled,
|
||||
network_proxy_audit_metadata.clone(),
|
||||
network_proxy_audit_metadata,
|
||||
)
|
||||
.await?;
|
||||
(Some(network_proxy), Some(session_network_proxy))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let network_proxies = NetworkProxyRegistry::new(
|
||||
config.permissions.network.clone(),
|
||||
managed_network_requirements_enabled,
|
||||
network_proxy_audit_metadata.clone(),
|
||||
default_network_proxy,
|
||||
);
|
||||
|
||||
let mut hook_shell_argv = default_shell.derive_exec_args("", false);
|
||||
let hook_shell_program = hook_shell_argv.remove(0);
|
||||
@@ -1790,9 +1725,7 @@ impl Session {
|
||||
mcp_manager: Arc::clone(&mcp_manager),
|
||||
file_watcher,
|
||||
agent_control,
|
||||
network_proxies,
|
||||
network_policy_decider_session,
|
||||
network_blocked_request_observer: blocked_request_observer,
|
||||
network_proxy,
|
||||
network_approval: Arc::clone(&network_approval),
|
||||
state_db: state_db_ctx.clone(),
|
||||
model_client: ModelClient::new(
|
||||
@@ -1831,9 +1764,7 @@ impl Session {
|
||||
js_repl,
|
||||
next_internal_sub_id: AtomicU64::new(0),
|
||||
});
|
||||
if let Some(network_policy_decider_session) =
|
||||
sess.services.network_policy_decider_session.as_ref()
|
||||
{
|
||||
if let Some(network_policy_decider_session) = network_policy_decider_session {
|
||||
let mut guard = network_policy_decider_session.write().await;
|
||||
*guard = Arc::downgrade(&sess);
|
||||
}
|
||||
@@ -2383,10 +2314,8 @@ impl Session {
|
||||
model_info,
|
||||
&self.services.models_manager,
|
||||
self.services
|
||||
.network_proxies
|
||||
.get(&NetworkProxyScope::SessionDefault)
|
||||
.await
|
||||
.as_deref()
|
||||
.network_proxy
|
||||
.as_ref()
|
||||
.map(StartedNetworkProxy::proxy),
|
||||
sub_id,
|
||||
Arc::clone(&self.js_repl),
|
||||
@@ -2758,7 +2687,6 @@ impl Session {
|
||||
&self,
|
||||
amendment: &NetworkPolicyAmendment,
|
||||
network_approval_context: &NetworkApprovalContext,
|
||||
scope: &NetworkProxyScope,
|
||||
) -> anyhow::Result<()> {
|
||||
let host =
|
||||
Self::validated_network_policy_amendment_host(amendment, network_approval_context)?;
|
||||
@@ -2772,7 +2700,7 @@ impl Session {
|
||||
let execpolicy_amendment =
|
||||
execpolicy_network_rule_amendment(amendment, network_approval_context, &host);
|
||||
|
||||
if let Some(started_network_proxy) = self.services.network_proxies.get(scope).await {
|
||||
if let Some(started_network_proxy) = self.services.network_proxy.as_ref() {
|
||||
let proxy = started_network_proxy.proxy();
|
||||
match amendment.action {
|
||||
NetworkPolicyRuleAction::Allow => proxy
|
||||
|
||||
@@ -11,11 +11,9 @@ use crate::exec::ExecToolCallOutput;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::mcp_connection_manager::ToolInfo;
|
||||
use crate::models_manager::model_info;
|
||||
use crate::network_proxy_registry::NetworkProxyRegistry;
|
||||
use crate::shell::default_user_shell;
|
||||
use crate::tools::format_exec_output_str;
|
||||
|
||||
use codex_network_proxy::NetworkProxyAuditMetadata;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
@@ -2154,14 +2152,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
|
||||
mcp_manager,
|
||||
file_watcher,
|
||||
agent_control,
|
||||
network_proxies: NetworkProxyRegistry::new(
|
||||
None,
|
||||
false,
|
||||
NetworkProxyAuditMetadata::default(),
|
||||
None,
|
||||
),
|
||||
network_policy_decider_session: None,
|
||||
network_blocked_request_observer: None,
|
||||
network_proxy: None,
|
||||
network_approval: Arc::clone(&network_approval),
|
||||
state_db: None,
|
||||
model_client: ModelClient::new(
|
||||
@@ -2803,14 +2794,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
|
||||
mcp_manager,
|
||||
file_watcher,
|
||||
agent_control,
|
||||
network_proxies: NetworkProxyRegistry::new(
|
||||
None,
|
||||
false,
|
||||
NetworkProxyAuditMetadata::default(),
|
||||
None,
|
||||
),
|
||||
network_policy_decider_session: None,
|
||||
network_blocked_request_observer: None,
|
||||
network_proxy: None,
|
||||
network_approval: Arc::clone(&network_approval),
|
||||
state_db: None,
|
||||
model_client: ModelClient::new(
|
||||
|
||||
@@ -4129,6 +4129,7 @@ fn test_precedence_fixture_with_o3_profile() -> std::io::Result<()> {
|
||||
experimental_realtime_start_instructions: None,
|
||||
experimental_realtime_ws_base_url: None,
|
||||
experimental_realtime_ws_model: None,
|
||||
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
|
||||
experimental_realtime_ws_backend_prompt: None,
|
||||
experimental_realtime_ws_startup_context: None,
|
||||
base_instructions: None,
|
||||
@@ -4265,6 +4266,7 @@ fn test_precedence_fixture_with_gpt3_profile() -> std::io::Result<()> {
|
||||
experimental_realtime_start_instructions: None,
|
||||
experimental_realtime_ws_base_url: None,
|
||||
experimental_realtime_ws_model: None,
|
||||
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
|
||||
experimental_realtime_ws_backend_prompt: None,
|
||||
experimental_realtime_ws_startup_context: None,
|
||||
base_instructions: None,
|
||||
@@ -4399,6 +4401,7 @@ fn test_precedence_fixture_with_zdr_profile() -> std::io::Result<()> {
|
||||
experimental_realtime_start_instructions: None,
|
||||
experimental_realtime_ws_base_url: None,
|
||||
experimental_realtime_ws_model: None,
|
||||
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
|
||||
experimental_realtime_ws_backend_prompt: None,
|
||||
experimental_realtime_ws_startup_context: None,
|
||||
base_instructions: None,
|
||||
@@ -4519,6 +4522,7 @@ fn test_precedence_fixture_with_gpt5_profile() -> std::io::Result<()> {
|
||||
experimental_realtime_start_instructions: None,
|
||||
experimental_realtime_ws_base_url: None,
|
||||
experimental_realtime_ws_model: None,
|
||||
experimental_realtime_ws_mode: RealtimeWsMode::Conversational,
|
||||
experimental_realtime_ws_backend_prompt: None,
|
||||
experimental_realtime_ws_startup_context: None,
|
||||
base_instructions: None,
|
||||
@@ -5566,6 +5570,34 @@ experimental_realtime_ws_model = "realtime-test-model"
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Checks that `experimental_realtime_ws_mode` survives both stages of config loading:
/// TOML deserialization into `ConfigToml` and resolution into the final `Config`.
#[test]
fn experimental_realtime_ws_mode_loads_from_config_toml() -> std::io::Result<()> {
    let raw_toml = r#"
experimental_realtime_ws_mode = "transcription"
"#;

    // Stage 1: the TOML layer keeps the mode optional.
    let parsed_toml: ConfigToml =
        toml::from_str(raw_toml).expect("TOML deserialization should succeed");
    assert_eq!(
        parsed_toml.experimental_realtime_ws_mode,
        Some(RealtimeWsMode::Transcription)
    );

    // Stage 2: the resolved config exposes the concrete mode.
    let home_dir = TempDir::new()?;
    let resolved = Config::load_from_base_config_with_overrides(
        parsed_toml,
        ConfigOverrides::default(),
        home_dir.path().to_path_buf(),
    )?;
    assert_eq!(
        resolved.experimental_realtime_ws_mode,
        RealtimeWsMode::Transcription
    );

    Ok(())
}
|
||||
|
||||
#[test]
|
||||
fn realtime_audio_loads_from_config_toml() -> std::io::Result<()> {
|
||||
let cfg: ConfigToml = toml::from_str(
|
||||
|
||||
@@ -463,6 +463,9 @@ pub struct Config {
|
||||
/// Experimental / do not use. Selects the realtime websocket model/snapshot
|
||||
/// used for the `Op::RealtimeConversation` connection.
|
||||
pub experimental_realtime_ws_model: Option<String>,
|
||||
/// Experimental / do not use. Selects the realtime websocket intent mode.
|
||||
/// `conversational` is speech-to-speech while `transcription` is transcript-only.
|
||||
pub experimental_realtime_ws_mode: RealtimeWsMode,
|
||||
/// Experimental / do not use. Overrides only the realtime conversation
|
||||
/// websocket transport instructions (the `Op::RealtimeConversation`
|
||||
/// `/ws` session.update instructions) without changing normal prompts.
|
||||
@@ -1238,6 +1241,9 @@ pub struct ConfigToml {
|
||||
/// Experimental / do not use. Selects the realtime websocket model/snapshot
|
||||
/// used for the `Op::RealtimeConversation` connection.
|
||||
pub experimental_realtime_ws_model: Option<String>,
|
||||
/// Experimental / do not use. Selects the realtime websocket intent mode.
|
||||
/// `conversational` is speech-to-speech while `transcription` is transcript-only.
|
||||
pub experimental_realtime_ws_mode: Option<RealtimeWsMode>,
|
||||
/// Experimental / do not use. Overrides only the realtime conversation
|
||||
/// websocket transport instructions (the `Op::RealtimeConversation`
|
||||
/// `/ws` session.update instructions) without changing normal prompts.
|
||||
@@ -1383,6 +1389,14 @@ pub struct RealtimeAudioConfig {
|
||||
pub speaker: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default, PartialEq, Eq, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RealtimeWsMode {
|
||||
#[default]
|
||||
Conversational,
|
||||
Transcription,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq, JsonSchema)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct RealtimeAudioToml {
|
||||
@@ -2462,6 +2476,7 @@ impl Config {
|
||||
}),
|
||||
experimental_realtime_ws_base_url: cfg.experimental_realtime_ws_base_url,
|
||||
experimental_realtime_ws_model: cfg.experimental_realtime_ws_model,
|
||||
experimental_realtime_ws_mode: cfg.experimental_realtime_ws_mode.unwrap_or_default(),
|
||||
experimental_realtime_ws_backend_prompt: cfg.experimental_realtime_ws_backend_prompt,
|
||||
experimental_realtime_ws_startup_context: cfg.experimental_realtime_ws_startup_context,
|
||||
experimental_realtime_start_instructions: cfg.experimental_realtime_start_instructions,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use crate::config_loader::NetworkConstraints;
|
||||
use crate::skills::model::SkillManagedNetworkOverride;
|
||||
use async_trait::async_trait;
|
||||
use codex_network_proxy::BlockedRequestObserver;
|
||||
use codex_network_proxy::ConfigReloader;
|
||||
@@ -82,28 +81,6 @@ impl NetworkProxySpec {
|
||||
self.config.network.enable_socks5
|
||||
}
|
||||
|
||||
pub(crate) fn with_skill_managed_network_override(
|
||||
&self,
|
||||
managed_network_override: &SkillManagedNetworkOverride,
|
||||
) -> Self {
|
||||
let mut spec = self.clone();
|
||||
|
||||
if let Some(allowed_domains) = managed_network_override.allowed_domains.clone() {
|
||||
spec.config.network.allowed_domains = allowed_domains.clone();
|
||||
if spec.constraints.allowed_domains.is_some() {
|
||||
spec.constraints.allowed_domains = Some(allowed_domains);
|
||||
}
|
||||
}
|
||||
if let Some(denied_domains) = managed_network_override.denied_domains.clone() {
|
||||
spec.config.network.denied_domains = denied_domains.clone();
|
||||
if spec.constraints.denied_domains.is_some() {
|
||||
spec.constraints.denied_domains = Some(denied_domains);
|
||||
}
|
||||
}
|
||||
|
||||
spec
|
||||
}
|
||||
|
||||
pub(crate) fn from_config_and_constraints(
|
||||
config: NetworkProxyConfig,
|
||||
requirements: Option<NetworkConstraints>,
|
||||
|
||||
@@ -1,94 +1,6 @@
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn skill_managed_network_override_replaces_allowed_domains_and_keeps_other_settings() {
|
||||
let mut config = NetworkProxyConfig::default();
|
||||
config.network.enabled = true;
|
||||
config.network.proxy_url = "http://127.0.0.1:4128".to_string();
|
||||
config.network.socks_url = "socks5://127.0.0.1:5128".to_string();
|
||||
config.network.enable_socks5 = true;
|
||||
config.network.enable_socks5_udp = true;
|
||||
config.network.allowed_domains = vec!["default.example.com".to_string()];
|
||||
config.network.denied_domains = vec!["blocked.example.com".to_string()];
|
||||
config.network.allow_upstream_proxy = true;
|
||||
config.network.dangerously_allow_all_unix_sockets = false;
|
||||
config.network.dangerously_allow_non_loopback_proxy = false;
|
||||
config.network.mode = codex_network_proxy::NetworkMode::Full;
|
||||
config.network.allow_unix_sockets = vec!["/tmp/default.sock".to_string()];
|
||||
config.network.allow_local_binding = true;
|
||||
config.network.mitm = false;
|
||||
let spec = NetworkProxySpec {
|
||||
config,
|
||||
constraints: NetworkProxyConstraints {
|
||||
allowed_domains: Some(vec!["default.example.com".to_string()]),
|
||||
denied_domains: Some(vec!["blocked.example.com".to_string()]),
|
||||
allowlist_expansion_enabled: Some(true),
|
||||
denylist_expansion_enabled: Some(false),
|
||||
allow_upstream_proxy: Some(true),
|
||||
allow_unix_sockets: Some(vec!["/tmp/default.sock".to_string()]),
|
||||
allow_local_binding: Some(true),
|
||||
..NetworkProxyConstraints::default()
|
||||
},
|
||||
hard_deny_allowlist_misses: true,
|
||||
};
|
||||
let managed_network_override = crate::skills::model::SkillManagedNetworkOverride {
|
||||
allowed_domains: Some(vec!["skill.example.com".to_string()]),
|
||||
denied_domains: None,
|
||||
};
|
||||
|
||||
let overridden = spec.with_skill_managed_network_override(&managed_network_override);
|
||||
|
||||
let mut expected = spec.clone();
|
||||
expected.config.network.allowed_domains = vec!["skill.example.com".to_string()];
|
||||
expected.constraints.allowed_domains = Some(vec!["skill.example.com".to_string()]);
|
||||
assert_eq!(overridden, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_managed_network_override_replaces_denied_domains_and_keeps_default_allowed_domains() {
|
||||
let mut config = NetworkProxyConfig::default();
|
||||
config.network.enabled = true;
|
||||
config.network.proxy_url = "http://127.0.0.1:4128".to_string();
|
||||
config.network.socks_url = "socks5://127.0.0.1:5128".to_string();
|
||||
config.network.enable_socks5 = true;
|
||||
config.network.enable_socks5_udp = true;
|
||||
config.network.allowed_domains = vec!["default.example.com".to_string()];
|
||||
config.network.denied_domains = vec!["blocked.example.com".to_string()];
|
||||
config.network.allow_upstream_proxy = true;
|
||||
config.network.dangerously_allow_all_unix_sockets = false;
|
||||
config.network.dangerously_allow_non_loopback_proxy = false;
|
||||
config.network.mode = codex_network_proxy::NetworkMode::Full;
|
||||
config.network.allow_unix_sockets = vec!["/tmp/default.sock".to_string()];
|
||||
config.network.allow_local_binding = true;
|
||||
config.network.mitm = false;
|
||||
let spec = NetworkProxySpec {
|
||||
config,
|
||||
constraints: NetworkProxyConstraints {
|
||||
allowed_domains: Some(vec!["default.example.com".to_string()]),
|
||||
denied_domains: Some(vec!["blocked.example.com".to_string()]),
|
||||
allowlist_expansion_enabled: Some(true),
|
||||
denylist_expansion_enabled: Some(false),
|
||||
allow_upstream_proxy: Some(true),
|
||||
allow_unix_sockets: Some(vec!["/tmp/default.sock".to_string()]),
|
||||
allow_local_binding: Some(true),
|
||||
..NetworkProxyConstraints::default()
|
||||
},
|
||||
hard_deny_allowlist_misses: false,
|
||||
};
|
||||
let managed_network_override = crate::skills::model::SkillManagedNetworkOverride {
|
||||
allowed_domains: None,
|
||||
denied_domains: Some(vec!["skill-blocked.example.com".to_string()]),
|
||||
};
|
||||
|
||||
let overridden = spec.with_skill_managed_network_override(&managed_network_override);
|
||||
|
||||
let mut expected = spec.clone();
|
||||
expected.config.network.denied_domains = vec!["skill-blocked.example.com".to_string()];
|
||||
expected.constraints.denied_domains = Some(vec!["skill-blocked.example.com".to_string()]);
|
||||
assert_eq!(overridden, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_state_with_audit_metadata_threads_metadata_to_state() {
|
||||
let spec = NetworkProxySpec {
|
||||
|
||||
@@ -39,7 +39,6 @@ use crate::config::Constrained;
|
||||
use crate::config::NetworkProxySpec;
|
||||
use crate::event_mapping::is_contextual_user_message_content;
|
||||
use crate::features::Feature;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use crate::protocol::Op;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::truncate::approx_bytes_for_tokens;
|
||||
@@ -551,12 +550,7 @@ async fn run_guardian_subagent(
|
||||
schema: Value,
|
||||
cancel_token: CancellationToken,
|
||||
) -> anyhow::Result<GuardianAssessment> {
|
||||
let live_network_config = match session
|
||||
.services
|
||||
.network_proxies
|
||||
.get(&NetworkProxyScope::SessionDefault)
|
||||
.await
|
||||
{
|
||||
let live_network_config = match session.services.network_proxy.as_ref() {
|
||||
Some(network_proxy) => Some(network_proxy.proxy().current_cfg().await?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
@@ -51,7 +51,6 @@ mod mcp_tool_approval_templates;
|
||||
pub mod models_manager;
|
||||
mod network_policy_decision;
|
||||
pub mod network_proxy_loader;
|
||||
mod network_proxy_registry;
|
||||
mod original_image_detail;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD;
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
use crate::config::NetworkProxySpec;
|
||||
use crate::config::StartedNetworkProxy;
|
||||
use anyhow::Result;
|
||||
use codex_network_proxy::NetworkProxyAuditMetadata;
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
pub(crate) enum NetworkProxyScope {
|
||||
SessionDefault,
|
||||
Skill { path_to_skills_md: PathBuf },
|
||||
}
|
||||
|
||||
pub(crate) struct NetworkProxyRegistry {
|
||||
spec: Option<NetworkProxySpec>,
|
||||
managed_network_requirements_enabled: bool,
|
||||
audit_metadata: NetworkProxyAuditMetadata,
|
||||
proxies: Mutex<HashMap<NetworkProxyScope, Arc<StartedNetworkProxy>>>,
|
||||
}
|
||||
|
||||
impl NetworkProxyRegistry {
|
||||
pub(crate) fn new(
|
||||
spec: Option<NetworkProxySpec>,
|
||||
managed_network_requirements_enabled: bool,
|
||||
audit_metadata: NetworkProxyAuditMetadata,
|
||||
default_proxy: Option<StartedNetworkProxy>,
|
||||
) -> Self {
|
||||
let mut proxies = HashMap::new();
|
||||
if let Some(default_proxy) = default_proxy {
|
||||
proxies.insert(NetworkProxyScope::SessionDefault, Arc::new(default_proxy));
|
||||
}
|
||||
|
||||
Self {
|
||||
spec,
|
||||
managed_network_requirements_enabled,
|
||||
audit_metadata,
|
||||
proxies: Mutex::new(proxies),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get(&self, scope: &NetworkProxyScope) -> Option<Arc<StartedNetworkProxy>> {
|
||||
self.proxies.lock().await.get(scope).cloned()
|
||||
}
|
||||
|
||||
pub(crate) async fn get_or_start<F, Fut>(
|
||||
&self,
|
||||
scope: NetworkProxyScope,
|
||||
start: F,
|
||||
) -> Result<Option<Arc<StartedNetworkProxy>>>
|
||||
where
|
||||
F: FnOnce(NetworkProxySpec, bool, NetworkProxyAuditMetadata) -> Fut,
|
||||
Fut: Future<Output = std::io::Result<StartedNetworkProxy>>,
|
||||
{
|
||||
let mut proxies = self.proxies.lock().await;
|
||||
if let Some(existing) = proxies.get(&scope).cloned() {
|
||||
return Ok(Some(existing));
|
||||
}
|
||||
|
||||
let Some(spec) = self.spec.clone() else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let started = Arc::new(
|
||||
start(
|
||||
spec,
|
||||
self.managed_network_requirements_enabled,
|
||||
self.audit_metadata.clone(),
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
proxies.insert(scope, Arc::clone(&started));
|
||||
Ok(Some(started))
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ use codex_api::RealtimeAudioFrame;
|
||||
use codex_api::RealtimeEvent;
|
||||
use codex_api::RealtimeEventParser;
|
||||
use codex_api::RealtimeSessionConfig;
|
||||
use codex_api::RealtimeSessionMode;
|
||||
use codex_api::RealtimeWebsocketClient;
|
||||
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketEvents;
|
||||
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketWriter;
|
||||
@@ -116,10 +117,7 @@ impl RealtimeConversationManager {
|
||||
&self,
|
||||
api_provider: ApiProvider,
|
||||
extra_headers: Option<HeaderMap>,
|
||||
prompt: String,
|
||||
model: Option<String>,
|
||||
session_id: Option<String>,
|
||||
event_parser: RealtimeEventParser,
|
||||
session_config: RealtimeSessionConfig,
|
||||
) -> CodexResult<(Receiver<RealtimeEvent>, Arc<AtomicBool>)> {
|
||||
let previous_state = {
|
||||
let mut guard = self.state.lock().await;
|
||||
@@ -131,12 +129,6 @@ impl RealtimeConversationManager {
|
||||
let _ = state.task.await;
|
||||
}
|
||||
|
||||
let session_config = RealtimeSessionConfig {
|
||||
instructions: prompt,
|
||||
model,
|
||||
session_id,
|
||||
event_parser,
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(api_provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
@@ -307,23 +299,26 @@ pub(crate) async fn handle_start(
|
||||
} else {
|
||||
RealtimeEventParser::V1
|
||||
};
|
||||
|
||||
let session_mode = match config.experimental_realtime_ws_mode {
|
||||
crate::config::RealtimeWsMode::Conversational => RealtimeSessionMode::Conversational,
|
||||
crate::config::RealtimeWsMode::Transcription => RealtimeSessionMode::Transcription,
|
||||
};
|
||||
let requested_session_id = params
|
||||
.session_id
|
||||
.or_else(|| Some(sess.conversation_id.to_string()));
|
||||
let session_config = RealtimeSessionConfig {
|
||||
instructions: prompt,
|
||||
model,
|
||||
session_id: requested_session_id.clone(),
|
||||
event_parser,
|
||||
session_mode,
|
||||
};
|
||||
let extra_headers =
|
||||
realtime_request_headers(requested_session_id.as_deref(), realtime_api_key.as_str())?;
|
||||
info!("starting realtime conversation");
|
||||
let (events_rx, realtime_active) = match sess
|
||||
.conversation
|
||||
.start(
|
||||
api_provider,
|
||||
extra_headers,
|
||||
prompt,
|
||||
model,
|
||||
requested_session_id.clone(),
|
||||
event_parser,
|
||||
)
|
||||
.start(api_provider, extra_headers, session_config)
|
||||
.await
|
||||
{
|
||||
Ok(events_rx) => events_rx,
|
||||
|
||||
@@ -6,13 +6,12 @@ use crate::RolloutRecorder;
|
||||
use crate::agent::AgentControl;
|
||||
use crate::analytics_client::AnalyticsEventsClient;
|
||||
use crate::client::ModelClient;
|
||||
use crate::codex::Session;
|
||||
use crate::config::StartedNetworkProxy;
|
||||
use crate::exec_policy::ExecPolicyManager;
|
||||
use crate::file_watcher::FileWatcher;
|
||||
use crate::mcp::McpManager;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::models_manager::manager::ModelsManager;
|
||||
use crate::network_proxy_registry::NetworkProxyRegistry;
|
||||
use crate::plugins::PluginsManager;
|
||||
use crate::skills::SkillsManager;
|
||||
use crate::state_db::StateDbHandle;
|
||||
@@ -22,11 +21,9 @@ use crate::tools::runtimes::ExecveSessionApproval;
|
||||
use crate::tools::sandboxing::ApprovalStore;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
use codex_hooks::Hooks;
|
||||
use codex_network_proxy::BlockedRequestObserver;
|
||||
use codex_otel::SessionTelemetry;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Weak;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::watch;
|
||||
@@ -58,9 +55,7 @@ pub(crate) struct SessionServices {
|
||||
pub(crate) mcp_manager: Arc<McpManager>,
|
||||
pub(crate) file_watcher: Arc<FileWatcher>,
|
||||
pub(crate) agent_control: AgentControl,
|
||||
pub(crate) network_proxies: NetworkProxyRegistry,
|
||||
pub(crate) network_policy_decider_session: Option<Arc<RwLock<Weak<Session>>>>,
|
||||
pub(crate) network_blocked_request_observer: Option<Arc<dyn BlockedRequestObserver>>,
|
||||
pub(crate) network_proxy: Option<StartedNetworkProxy>,
|
||||
pub(crate) network_approval: Arc<NetworkApprovalService>,
|
||||
pub(crate) state_db: Option<StateDbHandle>,
|
||||
/// Session-scoped model client shared across turns.
|
||||
|
||||
@@ -46,7 +46,6 @@ use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
|
||||
use crate::features::Feature;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
pub(crate) use compact::CompactTask;
|
||||
pub(crate) use ghost_snapshot::GhostSnapshotTask;
|
||||
pub(crate) use regular::RegularTask;
|
||||
@@ -293,12 +292,7 @@ impl Session {
|
||||
"false"
|
||||
},
|
||||
);
|
||||
let network_proxy_active = match self
|
||||
.services
|
||||
.network_proxies
|
||||
.get(&NetworkProxyScope::SessionDefault)
|
||||
.await
|
||||
{
|
||||
let network_proxy_active = match self.services.network_proxy.as_ref() {
|
||||
Some(started_network_proxy) => {
|
||||
match started_network_proxy.proxy().current_cfg().await {
|
||||
Ok(config) => config.network.enabled,
|
||||
|
||||
@@ -4,7 +4,6 @@ use crate::guardian::GuardianApprovalRequest;
|
||||
use crate::guardian::review_approval_request;
|
||||
use crate::guardian::routes_approval_to_guardian;
|
||||
use crate::network_policy_decision::denied_network_policy_message;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use crate::tools::sandboxing::ToolError;
|
||||
use codex_network_proxy::BlockedRequest;
|
||||
use codex_network_proxy::BlockedRequestObserver;
|
||||
@@ -77,20 +76,14 @@ impl ActiveNetworkApproval {
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
struct HostApprovalKey {
|
||||
scope: NetworkProxyScope,
|
||||
host: String,
|
||||
protocol: &'static str,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
impl HostApprovalKey {
|
||||
fn from_request(
|
||||
request: &NetworkPolicyRequest,
|
||||
protocol: NetworkApprovalProtocol,
|
||||
scope: NetworkProxyScope,
|
||||
) -> Self {
|
||||
fn from_request(request: &NetworkPolicyRequest, protocol: NetworkApprovalProtocol) -> Self {
|
||||
Self {
|
||||
scope,
|
||||
host: request.host.to_ascii_lowercase(),
|
||||
protocol: protocol_key_label(protocol),
|
||||
port: request.port,
|
||||
@@ -286,7 +279,6 @@ impl NetworkApprovalService {
|
||||
&self,
|
||||
session: Arc<Session>,
|
||||
request: NetworkPolicyRequest,
|
||||
scope: NetworkProxyScope,
|
||||
) -> NetworkDecision {
|
||||
const REASON_NOT_ALLOWED: &str = "not_allowed";
|
||||
|
||||
@@ -296,7 +288,7 @@ impl NetworkApprovalService {
|
||||
NetworkProtocol::Socks5Tcp => NetworkApprovalProtocol::Socks5Tcp,
|
||||
NetworkProtocol::Socks5Udp => NetworkApprovalProtocol::Socks5Udp,
|
||||
};
|
||||
let key = HostApprovalKey::from_request(&request, protocol, scope.clone());
|
||||
let key = HostApprovalKey::from_request(&request, protocol);
|
||||
|
||||
{
|
||||
let denied_hosts = self.session_denied_hosts.lock().await;
|
||||
@@ -395,7 +387,6 @@ impl NetworkApprovalService {
|
||||
.persist_network_policy_amendment(
|
||||
&network_policy_amendment,
|
||||
&network_approval_context,
|
||||
&scope,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -426,7 +417,6 @@ impl NetworkApprovalService {
|
||||
.persist_network_policy_amendment(
|
||||
&network_policy_amendment,
|
||||
&network_approval_context,
|
||||
&scope,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -516,18 +506,16 @@ pub(crate) fn build_blocked_request_observer(
|
||||
pub(crate) fn build_network_policy_decider(
|
||||
network_approval: Arc<NetworkApprovalService>,
|
||||
network_policy_decider_session: Arc<RwLock<std::sync::Weak<Session>>>,
|
||||
scope: NetworkProxyScope,
|
||||
) -> Arc<dyn NetworkPolicyDecider> {
|
||||
Arc::new(move |request: NetworkPolicyRequest| {
|
||||
let network_approval = Arc::clone(&network_approval);
|
||||
let network_policy_decider_session = Arc::clone(&network_policy_decider_session);
|
||||
let scope = scope.clone();
|
||||
async move {
|
||||
let Some(session) = network_policy_decider_session.read().await.upgrade() else {
|
||||
return NetworkDecision::ask("not_allowed");
|
||||
};
|
||||
network_approval
|
||||
.handle_inline_policy_request(session, request, scope)
|
||||
.handle_inline_policy_request(session, request)
|
||||
.await
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use super::*;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use codex_network_proxy::BlockedRequestArgs;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -8,7 +7,6 @@ use pretty_assertions::assert_eq;
|
||||
async fn pending_approvals_are_deduped_per_host_protocol_and_port() {
|
||||
let service = NetworkApprovalService::default();
|
||||
let key = HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "http",
|
||||
port: 443,
|
||||
@@ -26,13 +24,11 @@ async fn pending_approvals_are_deduped_per_host_protocol_and_port() {
|
||||
async fn pending_approvals_do_not_dedupe_across_ports() {
|
||||
let service = NetworkApprovalService::default();
|
||||
let first_key = HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 443,
|
||||
};
|
||||
let second_key = HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 8443,
|
||||
@@ -53,19 +49,16 @@ async fn session_approved_hosts_preserve_protocol_and_port_scope() {
|
||||
let mut approved_hosts = source.session_approved_hosts.lock().await;
|
||||
approved_hosts.extend([
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 443,
|
||||
},
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 8443,
|
||||
},
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "http",
|
||||
port: 80,
|
||||
@@ -89,19 +82,16 @@ async fn session_approved_hosts_preserve_protocol_and_port_scope() {
|
||||
copied,
|
||||
vec![
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "http",
|
||||
port: 80,
|
||||
},
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 443,
|
||||
},
|
||||
HostApprovalKey {
|
||||
scope: NetworkProxyScope::SessionDefault,
|
||||
host: "example.com".to_string(),
|
||||
protocol: "https",
|
||||
port: 8443,
|
||||
|
||||
@@ -9,7 +9,6 @@ use crate::features::Feature;
|
||||
use crate::guardian::GuardianApprovalRequest;
|
||||
use crate::guardian::review_approval_request;
|
||||
use crate::guardian::routes_approval_to_guardian;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use crate::sandboxing::ExecRequest;
|
||||
use crate::sandboxing::SandboxPermissions;
|
||||
use crate::shell::ShellType;
|
||||
@@ -51,7 +50,6 @@ use codex_shell_escalation::ShellCommandExecutor;
|
||||
use codex_shell_escalation::Stopwatch;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -144,7 +142,6 @@ pub(super) async fn try_run_zsh_fork(
|
||||
ctx.session.services.exec_policy.current().as_ref().clone(),
|
||||
));
|
||||
let command_executor = CoreShellCommandExecutor {
|
||||
session: Some(Arc::clone(&ctx.session)),
|
||||
command,
|
||||
cwd: sandbox_cwd,
|
||||
sandbox_policy,
|
||||
@@ -262,7 +259,6 @@ pub(crate) async fn prepare_unified_exec_zsh_fork(
|
||||
ctx.session.services.exec_policy.current().as_ref().clone(),
|
||||
));
|
||||
let command_executor = CoreShellCommandExecutor {
|
||||
session: Some(Arc::clone(&ctx.session)),
|
||||
command: exec_request.command.clone(),
|
||||
cwd: exec_request.cwd.clone(),
|
||||
sandbox_policy: exec_request.sandbox_policy.clone(),
|
||||
@@ -507,7 +503,26 @@ impl CoreShellActionProvider {
|
||||
/// an absolute path. The idea is that we check to see whether it matches
|
||||
/// any skills.
|
||||
async fn find_skill(&self, program: &AbsolutePathBuf) -> Option<SkillMetadata> {
|
||||
find_skill_for_program(self.session.as_ref(), self.turn.cwd.as_path(), program).await
|
||||
let force_reload = false;
|
||||
let skills_outcome = self
|
||||
.session
|
||||
.services
|
||||
.skills_manager
|
||||
.skills_for_cwd(&self.turn.cwd, force_reload)
|
||||
.await;
|
||||
|
||||
let program_path = program.as_path();
|
||||
for skill in skills_outcome.skills {
|
||||
// We intentionally ignore "enabled" status here for now.
|
||||
let Some(skill_root) = skill.path_to_skills_md.parent() else {
|
||||
continue;
|
||||
};
|
||||
if program_path.starts_with(skill_root.join("scripts")) {
|
||||
return Some(skill);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -616,32 +631,6 @@ impl CoreShellActionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
async fn find_skill_for_program(
|
||||
session: &crate::codex::Session,
|
||||
cwd: &Path,
|
||||
program: &AbsolutePathBuf,
|
||||
) -> Option<SkillMetadata> {
|
||||
let force_reload = false;
|
||||
let skills_outcome = session
|
||||
.services
|
||||
.skills_manager
|
||||
.skills_for_cwd(cwd, force_reload)
|
||||
.await;
|
||||
|
||||
let program_path = program.as_path();
|
||||
for skill in skills_outcome.skills {
|
||||
// We intentionally ignore "enabled" status here for now.
|
||||
let Some(skill_root) = skill.path_to_skills_md.parent() else {
|
||||
continue;
|
||||
};
|
||||
if program_path.starts_with(skill_root.join("scripts")) {
|
||||
return Some(skill);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// Shell-wrapper parsing is weaker than direct exec interception because it can
|
||||
// only see the script text, not the final resolved executable path. Keep it
|
||||
// disabled by default so path-sensitive rules rely on the later authoritative
|
||||
@@ -876,7 +865,6 @@ fn commands_for_intercepted_exec_policy(
|
||||
}
|
||||
|
||||
struct CoreShellCommandExecutor {
|
||||
session: Option<Arc<crate::codex::Session>>,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
@@ -900,7 +888,6 @@ struct PrepareSandboxedExecParams<'a> {
|
||||
command: Vec<String>,
|
||||
workdir: &'a AbsolutePathBuf,
|
||||
env: HashMap<String, String>,
|
||||
network: Option<codex_network_proxy::NetworkProxy>,
|
||||
sandbox_policy: &'a SandboxPolicy,
|
||||
file_system_sandbox_policy: &'a FileSystemSandboxPolicy,
|
||||
network_sandbox_policy: NetworkSandboxPolicy,
|
||||
@@ -982,12 +969,10 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
arg0: Some(first_arg.clone()),
|
||||
},
|
||||
EscalationExecution::TurnDefault => {
|
||||
let network = self.network_for_program(program).await?;
|
||||
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
|
||||
command,
|
||||
workdir,
|
||||
env,
|
||||
network,
|
||||
sandbox_policy: &self.sandbox_policy,
|
||||
file_system_sandbox_policy: &self.file_system_sandbox_policy,
|
||||
network_sandbox_policy: self.network_sandbox_policy,
|
||||
@@ -1001,14 +986,12 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
EscalationExecution::Permissions(EscalationPermissions::PermissionProfile(
|
||||
permission_profile,
|
||||
)) => {
|
||||
let network = self.network_for_program(program).await?;
|
||||
// Merge additive permissions into the existing turn/request sandbox policy.
|
||||
// On macOS, additional profile extensions are unioned with the turn defaults.
|
||||
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
|
||||
command,
|
||||
workdir,
|
||||
env,
|
||||
network,
|
||||
sandbox_policy: &self.sandbox_policy,
|
||||
file_system_sandbox_policy: &self.file_system_sandbox_policy,
|
||||
network_sandbox_policy: self.network_sandbox_policy,
|
||||
@@ -1020,13 +1003,11 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
})?
|
||||
}
|
||||
EscalationExecution::Permissions(EscalationPermissions::Permissions(permissions)) => {
|
||||
let network = self.network_for_program(program).await?;
|
||||
// Use a fully specified sandbox policy instead of merging into the turn policy.
|
||||
self.prepare_sandboxed_exec(PrepareSandboxedExecParams {
|
||||
command,
|
||||
workdir,
|
||||
env,
|
||||
network,
|
||||
sandbox_policy: &permissions.sandbox_policy,
|
||||
file_system_sandbox_policy: &permissions.file_system_sandbox_policy,
|
||||
network_sandbox_policy: permissions.network_sandbox_policy,
|
||||
@@ -1044,25 +1025,6 @@ impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
}
|
||||
|
||||
impl CoreShellCommandExecutor {
|
||||
async fn network_for_program(
|
||||
&self,
|
||||
program: &AbsolutePathBuf,
|
||||
) -> anyhow::Result<Option<codex_network_proxy::NetworkProxy>> {
|
||||
let Some(session) = self.session.as_ref() else {
|
||||
return Ok(self.network.clone());
|
||||
};
|
||||
let Some(skill) =
|
||||
find_skill_for_program(session.as_ref(), &self.sandbox_policy_cwd, program).await
|
||||
else {
|
||||
return Ok(self.network.clone());
|
||||
};
|
||||
let (scope, managed_network_override) = network_proxy_scope_for_skill(&skill);
|
||||
|
||||
session
|
||||
.get_or_start_network_proxy(scope, &self.sandbox_policy, managed_network_override)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn prepare_sandboxed_exec(
|
||||
&self,
|
||||
@@ -1072,7 +1034,6 @@ impl CoreShellCommandExecutor {
|
||||
command,
|
||||
workdir,
|
||||
env,
|
||||
network,
|
||||
sandbox_policy,
|
||||
file_system_sandbox_policy,
|
||||
network_sandbox_policy,
|
||||
@@ -1089,7 +1050,7 @@ impl CoreShellCommandExecutor {
|
||||
network_sandbox_policy,
|
||||
SandboxablePreference::Auto,
|
||||
self.windows_sandbox_level,
|
||||
network.is_some(),
|
||||
self.network.is_some(),
|
||||
);
|
||||
let mut exec_request =
|
||||
sandbox_manager.transform(crate::sandboxing::SandboxTransformRequest {
|
||||
@@ -1111,8 +1072,8 @@ impl CoreShellCommandExecutor {
|
||||
file_system_policy: file_system_sandbox_policy,
|
||||
network_policy: network_sandbox_policy,
|
||||
sandbox,
|
||||
enforce_managed_network: network.is_some(),
|
||||
network: network.as_ref(),
|
||||
enforce_managed_network: self.network.is_some(),
|
||||
network: self.network.as_ref(),
|
||||
sandbox_policy_cwd: &self.sandbox_policy_cwd,
|
||||
#[cfg(target_os = "macos")]
|
||||
macos_seatbelt_profile_extensions,
|
||||
@@ -1133,23 +1094,6 @@ impl CoreShellCommandExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
fn network_proxy_scope_for_skill(
|
||||
skill: &SkillMetadata,
|
||||
) -> (
|
||||
NetworkProxyScope,
|
||||
Option<crate::skills::model::SkillManagedNetworkOverride>,
|
||||
) {
|
||||
match skill.managed_network_override.clone() {
|
||||
Some(managed_network_override) => (
|
||||
NetworkProxyScope::Skill {
|
||||
path_to_skills_md: skill.path_to_skills_md.clone(),
|
||||
},
|
||||
Some(managed_network_override),
|
||||
),
|
||||
None => (NetworkProxyScope::SessionDefault, None),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
struct ParsedShellCommand {
|
||||
program: String,
|
||||
|
||||
@@ -15,7 +15,6 @@ use crate::config::Permissions;
|
||||
#[cfg(target_os = "macos")]
|
||||
use crate::config::types::ShellEnvironmentPolicy;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::network_proxy_registry::NetworkProxyScope;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::GranularApprovalConfig;
|
||||
use crate::protocol::ReadOnlyAccess;
|
||||
@@ -99,35 +98,6 @@ fn test_skill_metadata(permission_profile: Option<PermissionProfile>) -> SkillMe
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn network_proxy_scope_for_skill_without_override_reuses_session_default() {
|
||||
let skill = test_skill_metadata(None);
|
||||
|
||||
assert_eq!(
|
||||
super::network_proxy_scope_for_skill(&skill),
|
||||
(NetworkProxyScope::SessionDefault, None),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn network_proxy_scope_for_skill_with_override_uses_skill_scope() {
|
||||
let mut skill = test_skill_metadata(None);
|
||||
skill.managed_network_override = Some(crate::skills::model::SkillManagedNetworkOverride {
|
||||
allowed_domains: Some(vec!["skill.example.com".to_string()]),
|
||||
denied_domains: Some(vec!["blocked.skill.example.com".to_string()]),
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
super::network_proxy_scope_for_skill(&skill),
|
||||
(
|
||||
NetworkProxyScope::Skill {
|
||||
path_to_skills_md: PathBuf::from("/tmp/skill/SKILL.md"),
|
||||
},
|
||||
skill.managed_network_override.clone(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execve_prompt_rejection_uses_skill_approval_for_skill_scripts() {
|
||||
let decision_source = super::DecisionSource::SkillScript {
|
||||
@@ -681,7 +651,6 @@ host_executable(name = "git", paths = ["{allowed_git_literal}"])
|
||||
async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions() {
|
||||
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
|
||||
let executor = CoreShellCommandExecutor {
|
||||
session: None,
|
||||
command: vec!["echo".to_string(), "ok".to_string()],
|
||||
cwd: cwd.to_path_buf(),
|
||||
env: HashMap::new(),
|
||||
@@ -734,7 +703,6 @@ async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions
|
||||
async fn prepare_escalated_exec_permissions_preserve_macos_seatbelt_extensions() {
|
||||
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
|
||||
let executor = CoreShellCommandExecutor {
|
||||
session: None,
|
||||
command: vec!["echo".to_string(), "ok".to_string()],
|
||||
cwd: cwd.to_path_buf(),
|
||||
env: HashMap::new(),
|
||||
@@ -809,7 +777,6 @@ async fn prepare_escalated_exec_permission_profile_unions_turn_and_requested_mac
|
||||
let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap();
|
||||
let sandbox_policy = SandboxPolicy::new_read_only_policy();
|
||||
let executor = CoreShellCommandExecutor {
|
||||
session: None,
|
||||
command: vec!["echo".to_string(), "ok".to_string()],
|
||||
cwd: cwd.to_path_buf(),
|
||||
env: HashMap::new(),
|
||||
|
||||
@@ -123,7 +123,7 @@ impl ActionKind {
|
||||
let (path, _) = target.resolve_for_patch(test);
|
||||
let _ = fs::remove_file(&path);
|
||||
let command = format!("printf {content:?} > {path:?} && cat {path:?}");
|
||||
let event = shell_event(call_id, &command, 1_000, sandbox_permissions)?;
|
||||
let event = shell_event(call_id, &command, 5_000, sandbox_permissions)?;
|
||||
Ok((event, Some(command)))
|
||||
}
|
||||
ActionKind::FetchUrl {
|
||||
|
||||
Reference in New Issue
Block a user