mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
Attach WebRTC realtime starts to sideband websocket (#17057)
Summary: - parse the realtime call Location header and join that call over the direct realtime WebSocket - keep WebRTC starts alive on the existing realtime conversation path Validation: - just fmt - git diff --check - cargo check -p codex-api - cargo check -p codex-core --tests - local cargo tests not run; relying on PR CI
This commit is contained in:
@@ -29,7 +29,9 @@ use codex_app_server_protocol::ThreadStartResponse;
|
||||
use codex_features::FEATURES;
|
||||
use codex_features::Feature;
|
||||
use codex_protocol::protocol::RealtimeConversationVersion;
|
||||
use core_test_support::responses::WebSocketConnectionConfig;
|
||||
use core_test_support::responses::start_websocket_server;
|
||||
use core_test_support::responses::start_websocket_server_with_headers;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::de::DeserializeOwned;
|
||||
@@ -429,10 +431,23 @@ async fn realtime_webrtc_start_emits_sdp_notification() -> Result<()> {
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/realtime/calls"))
|
||||
.and(call_capture.clone())
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string("v=answer\r\n"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("Location", "/v1/realtime/calls/rtc_app_test")
|
||||
.set_body_string("v=answer\r\n"),
|
||||
)
|
||||
.mount(&responses_server)
|
||||
.await;
|
||||
let realtime_server = start_websocket_server(vec![vec![]]).await;
|
||||
let realtime_server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
|
||||
requests: vec![vec![json!({
|
||||
"type": "session.updated",
|
||||
"session": { "id": "sess_webrtc", "instructions": "backend prompt" }
|
||||
})]],
|
||||
response_headers: Vec::new(),
|
||||
accept_delay: None,
|
||||
close_after_requests: false,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(
|
||||
@@ -475,6 +490,12 @@ async fn realtime_webrtc_start_emits_sdp_notification() -> Result<()> {
|
||||
.await??;
|
||||
let _: ThreadRealtimeStartResponse = to_response(start_response)?;
|
||||
|
||||
let started =
|
||||
read_notification::<ThreadRealtimeStartedNotification>(&mut mcp, "thread/realtime/started")
|
||||
.await?;
|
||||
assert_eq!(started.thread_id, thread_id);
|
||||
assert_eq!(started.version, RealtimeConversationVersion::V2);
|
||||
|
||||
let sdp_notification =
|
||||
read_notification::<ThreadRealtimeSdpNotification>(&mut mcp, "thread/realtime/sdp").await?;
|
||||
assert_eq!(
|
||||
@@ -484,20 +505,59 @@ async fn realtime_webrtc_start_emits_sdp_notification() -> Result<()> {
|
||||
sdp: "v=answer\r\n".to_string()
|
||||
}
|
||||
);
|
||||
|
||||
let session_update = realtime_server
|
||||
.wait_for_request(/*connection_index*/ 0, /*request_index*/ 0)
|
||||
.await;
|
||||
assert_eq!(
|
||||
session_update.body_json()["type"].as_str(),
|
||||
Some("session.update")
|
||||
);
|
||||
assert!(
|
||||
session_update.body_json()["session"]["instructions"]
|
||||
.as_str()
|
||||
.context("expected session.update instructions")?
|
||||
.contains("startup context")
|
||||
);
|
||||
assert_eq!(
|
||||
realtime_server.single_handshake().uri(),
|
||||
"/v1/realtime?call_id=rtc_app_test"
|
||||
);
|
||||
|
||||
let stop_request_id = mcp
|
||||
.send_thread_realtime_stop_request(ThreadRealtimeStopParams {
|
||||
thread_id: thread_id.clone(),
|
||||
})
|
||||
.await?;
|
||||
let stop_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(stop_request_id)),
|
||||
)
|
||||
.await??;
|
||||
let _: ThreadRealtimeStopResponse = to_response(stop_response)?;
|
||||
|
||||
let closed_notification =
|
||||
read_notification::<ThreadRealtimeClosedNotification>(&mut mcp, "thread/realtime/closed")
|
||||
.await?;
|
||||
assert_eq!(
|
||||
closed_notification,
|
||||
ThreadRealtimeClosedNotification {
|
||||
thread_id: thread_id.clone(),
|
||||
reason: Some("transport_closed".to_string())
|
||||
}
|
||||
assert_eq!(closed_notification.thread_id, thread_id);
|
||||
assert!(
|
||||
matches!(
|
||||
closed_notification.reason.as_deref(),
|
||||
Some("requested" | "transport_closed")
|
||||
),
|
||||
"unexpected close reason: {closed_notification:?}"
|
||||
);
|
||||
|
||||
let request = call_capture.single_request();
|
||||
assert_eq!(request.url.path(), "/v1/realtime/calls");
|
||||
assert_eq!(request.url.query(), None);
|
||||
assert_eq!(
|
||||
request
|
||||
.headers
|
||||
.get("content-type")
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("multipart/form-data; boundary=codex-realtime-call-boundary")
|
||||
);
|
||||
let body = String::from_utf8(request.body).context("multipart body should be utf-8")?;
|
||||
let session = r#"{"tool_choice":"auto","type":"realtime","instructions":"backend prompt\n\nstartup context","output_modalities":["audio"],"audio":{"input":{"format":{"type":"audio/pcm","rate":24000},"noise_reduction":{"type":"near_field"},"turn_detection":{"type":"server_vad","interrupt_response":true,"create_response":true}},"output":{"format":{"type":"audio/pcm","rate":24000},"voice":"marin"}},"tools":[{"type":"function","name":"codex","description":"Delegate a request to Codex and return the final result to the user. Use this as the default action. If the user asks to do something next, later, after this, or once current work finishes, call this tool so the work is actually queued instead of merely promising to do it later.","parameters":{"type":"object","properties":{"prompt":{"type":"string","description":"The user request to delegate to Codex."}},"required":["prompt"],"additionalProperties":false}}]}"#;
|
||||
assert_eq!(
|
||||
|
||||
@@ -12,6 +12,7 @@ use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use http::Method;
|
||||
use http::header::CONTENT_TYPE;
|
||||
use http::header::LOCATION;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::to_string;
|
||||
@@ -26,9 +27,14 @@ pub struct RealtimeCallClient<T: HttpTransport, A: AuthProvider> {
|
||||
session: EndpointSession<T, A>,
|
||||
}
|
||||
|
||||
/// Answer from creating a WebRTC Realtime call.
|
||||
///
|
||||
/// `sdp` configures the peer connection. `call_id` is parsed from the response `Location` header
|
||||
/// and is later used by the server-side sideband WebSocket to join this exact call.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RealtimeCallResponse {
|
||||
pub sdp: String,
|
||||
pub call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -101,8 +107,9 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
|
||||
.await?;
|
||||
|
||||
let sdp = decode_sdp_response(resp.body.as_ref())?;
|
||||
let call_id = decode_call_id_from_location(&resp.headers)?;
|
||||
|
||||
Ok(RealtimeCallResponse { sdp })
|
||||
Ok(RealtimeCallResponse { sdp, call_id })
|
||||
}
|
||||
|
||||
pub async fn create_with_session_and_headers(
|
||||
@@ -111,6 +118,9 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
|
||||
session_config: RealtimeSessionConfig,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<RealtimeCallResponse, ApiError> {
|
||||
// WebRTC can begin inference as soon as the peer connection comes up, so the initial
|
||||
// session payload is sent with call creation. The sideband WebSocket still sends its normal
|
||||
// session.update after it joins.
|
||||
let mut session = realtime_session_json(session_config)?;
|
||||
if let Some(session) = session.as_object_mut() {
|
||||
session.remove("id");
|
||||
@@ -127,7 +137,8 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
|
||||
.execute(Method::POST, Self::path(), extra_headers, Some(body))
|
||||
.await?;
|
||||
let sdp = decode_sdp_response(resp.body.as_ref())?;
|
||||
return Ok(RealtimeCallResponse { sdp });
|
||||
let call_id = decode_call_id_from_location(&resp.headers)?;
|
||||
return Ok(RealtimeCallResponse { sdp, call_id });
|
||||
}
|
||||
|
||||
let session = to_string(&session).map_err(|err| ApiError::InvalidRequest {
|
||||
@@ -164,8 +175,9 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
|
||||
.await?;
|
||||
|
||||
let sdp = decode_sdp_response(resp.body.as_ref())?;
|
||||
let call_id = decode_call_id_from_location(&resp.headers)?;
|
||||
|
||||
Ok(RealtimeCallResponse { sdp })
|
||||
Ok(RealtimeCallResponse { sdp, call_id })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,6 +194,27 @@ fn decode_sdp_response(body: &[u8]) -> Result<String, ApiError> {
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_call_id_from_location(headers: &HeaderMap) -> Result<String, ApiError> {
|
||||
let location = headers
|
||||
.get(LOCATION)
|
||||
.ok_or_else(|| ApiError::Stream("realtime call response missing Location".to_string()))?
|
||||
.to_str()
|
||||
.map_err(|err| ApiError::Stream(format!("invalid realtime call Location: {err}")))?;
|
||||
|
||||
location
|
||||
.split('?')
|
||||
.next()
|
||||
.unwrap_or(location)
|
||||
.rsplit('/')
|
||||
.find(|segment| segment.starts_with("rtc_") && segment.len() > "rtc_".len())
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| {
|
||||
ApiError::Stream(format!(
|
||||
"realtime call Location does not contain a call id: {location}"
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -201,12 +234,27 @@ mod tests {
|
||||
#[derive(Clone)]
|
||||
struct CapturingTransport {
|
||||
last_request: Arc<Mutex<Option<Request>>>,
|
||||
response_headers: HeaderMap,
|
||||
}
|
||||
|
||||
impl CapturingTransport {
|
||||
fn new() -> Self {
|
||||
Self::with_location("/v1/realtime/calls/rtc_test")
|
||||
}
|
||||
|
||||
fn with_location(location: &str) -> Self {
|
||||
let mut response_headers = HeaderMap::new();
|
||||
response_headers.insert(LOCATION, HeaderValue::from_str(location).unwrap());
|
||||
Self {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
response_headers,
|
||||
}
|
||||
}
|
||||
|
||||
fn without_location() -> Self {
|
||||
Self {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
response_headers: HeaderMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -217,7 +265,7 @@ mod tests {
|
||||
*self.last_request.lock().unwrap() = Some(req);
|
||||
Ok(Response {
|
||||
status: StatusCode::OK,
|
||||
headers: HeaderMap::new(),
|
||||
headers: self.response_headers.clone(),
|
||||
body: Bytes::from_static(b"v=0\r\n"),
|
||||
})
|
||||
}
|
||||
@@ -280,7 +328,8 @@ mod tests {
|
||||
assert_eq!(
|
||||
response,
|
||||
RealtimeCallResponse {
|
||||
sdp: "v=0\r\n".to_string()
|
||||
sdp: "v=0\r\n".to_string(),
|
||||
call_id: "rtc_test".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
@@ -304,6 +353,41 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extracts_call_id_from_forwarded_backend_location() {
|
||||
let transport =
|
||||
CapturingTransport::with_location("/v1/realtime/calls/calls/rtc_backend_test");
|
||||
let client = RealtimeCallClient::new(
|
||||
transport.clone(),
|
||||
provider("https://chatgpt.com/backend-api/codex"),
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let response = client
|
||||
.create("v=offer\r\n".to_string())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
RealtimeCallResponse {
|
||||
sdp: "v=0\r\n".to_string(),
|
||||
call_id: "rtc_backend_test".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
let request = transport.last_request.lock().unwrap().clone().unwrap();
|
||||
assert_eq!(request.method, Method::POST);
|
||||
assert_eq!(
|
||||
request.url,
|
||||
"https://chatgpt.com/backend-api/codex/realtime/calls"
|
||||
);
|
||||
assert_eq!(
|
||||
request.body,
|
||||
Some(RequestBody::Raw(Bytes::from_static(b"v=offer\r\n")))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sends_api_session_call_as_multipart_body() {
|
||||
let transport = CapturingTransport::new();
|
||||
@@ -324,7 +408,8 @@ mod tests {
|
||||
assert_eq!(
|
||||
response,
|
||||
RealtimeCallResponse {
|
||||
sdp: "v=0\r\n".to_string()
|
||||
sdp: "v=0\r\n".to_string(),
|
||||
call_id: "rtc_test".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
@@ -385,7 +470,8 @@ mod tests {
|
||||
assert_eq!(
|
||||
response,
|
||||
RealtimeCallResponse {
|
||||
sdp: "v=0\r\n".to_string()
|
||||
sdp: "v=0\r\n".to_string(),
|
||||
call_id: "rtc_test".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
@@ -412,4 +498,35 @@ mod tests {
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn errors_when_location_is_missing() {
|
||||
let transport = CapturingTransport::without_location();
|
||||
let client =
|
||||
RealtimeCallClient::new(transport, provider("https://api.openai.com/v1"), DummyAuth);
|
||||
|
||||
let err = client
|
||||
.create("v=offer\r\n".to_string())
|
||||
.await
|
||||
.expect_err("request should require Location");
|
||||
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"stream error: realtime call response missing Location"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_location_without_call_id() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(LOCATION, HeaderValue::from_static("/v1/realtime/calls"));
|
||||
|
||||
let err = decode_call_id_from_location(&headers)
|
||||
.expect_err("Location without rtc_ segment should fail");
|
||||
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"stream error: realtime call Location does not contain a call id: /v1/realtime/calls"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptEntry;
|
||||
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use codex_client::backoff;
|
||||
use codex_client::maybe_build_rustls_client_config_with_custom_ca;
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
use futures::SinkExt;
|
||||
@@ -28,6 +29,7 @@ use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::sleep;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Error as WsError;
|
||||
@@ -37,6 +39,7 @@ use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
use tungstenite::protocol::WebSocketConfig;
|
||||
use url::Url;
|
||||
|
||||
@@ -455,7 +458,6 @@ impl RealtimeWebsocketClient {
|
||||
extra_headers: HeaderMap,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<RealtimeWebsocketConnection, ApiError> {
|
||||
ensure_rustls_crypto_provider();
|
||||
let ws_url = websocket_url_from_api_url(
|
||||
self.provider.base_url.as_str(),
|
||||
self.provider.query_params.as_ref(),
|
||||
@@ -463,6 +465,78 @@ impl RealtimeWebsocketClient {
|
||||
config.event_parser,
|
||||
config.session_mode,
|
||||
)?;
|
||||
self.connect_realtime_websocket_url(ws_url, config, extra_headers, default_headers)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn connect_webrtc_sideband(
|
||||
&self,
|
||||
config: RealtimeSessionConfig,
|
||||
call_id: &str,
|
||||
extra_headers: HeaderMap,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<RealtimeWebsocketConnection, ApiError> {
|
||||
// The WebRTC call already exists; this loop only retries joining its sideband control
|
||||
// socket. Once joined, the returned connection is the same reader/writer state that the
|
||||
// ordinary websocket start path uses.
|
||||
for attempt in 0..=self.provider.retry.max_attempts {
|
||||
let result = self
|
||||
.connect_webrtc_sideband_once(
|
||||
config.clone(),
|
||||
call_id,
|
||||
extra_headers.clone(),
|
||||
default_headers.clone(),
|
||||
)
|
||||
.await;
|
||||
match result {
|
||||
Ok(connection) => return Ok(connection),
|
||||
Err(err) if attempt < self.provider.retry.max_attempts => {
|
||||
let delay = backoff(self.provider.retry.base_delay, attempt + 1);
|
||||
warn!(
|
||||
attempt = attempt + 1,
|
||||
call_id,
|
||||
delay_ms = delay.as_millis(),
|
||||
"realtime sideband websocket connect failed; retrying: {err}"
|
||||
);
|
||||
sleep(delay).await;
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
Err(ApiError::Stream(
|
||||
"realtime sideband websocket retry loop exhausted".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn connect_webrtc_sideband_once(
|
||||
&self,
|
||||
config: RealtimeSessionConfig,
|
||||
call_id: &str,
|
||||
extra_headers: HeaderMap,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<RealtimeWebsocketConnection, ApiError> {
|
||||
// Keep the parser/session query shaping from standalone realtime while replacing the model
|
||||
// query with a call_id join onto an existing WebRTC session.
|
||||
let ws_url = websocket_url_from_api_url_for_call(
|
||||
self.provider.base_url.as_str(),
|
||||
self.provider.query_params.as_ref(),
|
||||
config.event_parser,
|
||||
config.session_mode,
|
||||
call_id,
|
||||
)?;
|
||||
self.connect_realtime_websocket_url(ws_url, config, extra_headers, default_headers)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn connect_realtime_websocket_url(
|
||||
&self,
|
||||
ws_url: Url,
|
||||
config: RealtimeSessionConfig,
|
||||
extra_headers: HeaderMap,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<RealtimeWebsocketConnection, ApiError> {
|
||||
ensure_rustls_crypto_provider();
|
||||
|
||||
let mut request = ws_url
|
||||
.as_str()
|
||||
@@ -596,6 +670,24 @@ fn websocket_url_from_api_url(
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
fn websocket_url_from_api_url_for_call(
|
||||
api_url: &str,
|
||||
query_params: Option<&HashMap<String, String>>,
|
||||
event_parser: RealtimeEventParser,
|
||||
session_mode: RealtimeSessionMode,
|
||||
call_id: &str,
|
||||
) -> Result<Url, ApiError> {
|
||||
let mut url = websocket_url_from_api_url(
|
||||
api_url,
|
||||
query_params,
|
||||
/*model*/ None,
|
||||
event_parser,
|
||||
session_mode,
|
||||
)?;
|
||||
url.query_pairs_mut().append_pair("call_id", call_id);
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
fn normalize_realtime_path(url: &mut Url) {
|
||||
let path = url.path().to_string();
|
||||
if path.is_empty() || path == "/" {
|
||||
@@ -1094,6 +1186,22 @@ mod tests {
|
||||
assert_eq!(url.as_str(), "wss://example.com/v1/realtime");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_url_for_call_id_joins_existing_realtime_session() {
|
||||
let url = websocket_url_from_api_url_for_call(
|
||||
"https://api.openai.com/v1",
|
||||
/*query_params*/ None,
|
||||
RealtimeEventParser::RealtimeV2,
|
||||
RealtimeSessionMode::Conversational,
|
||||
"rtc_test",
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://api.openai.com/v1/realtime?call_id=rtc_test"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
|
||||
@@ -195,6 +195,82 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_connect_webrtc_sideband_retries_join_until_server_is_available() {
|
||||
let reserving_listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = reserving_listener.local_addr().expect("local addr");
|
||||
drop(reserving_listener);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
let listener = TcpListener::bind(addr).await.expect("bind delayed server");
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["instructions"],
|
||||
Value::String("backend prompt".to_string())
|
||||
);
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_joined", "instructions": "backend prompt"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
});
|
||||
|
||||
let mut provider = test_provider(format!("http://{addr}"));
|
||||
provider.retry.max_attempts = 1;
|
||||
provider.retry.base_delay = Duration::from_millis(100);
|
||||
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let connection = client
|
||||
.connect_webrtc_sideband(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::RealtimeV2,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
"rtc_test",
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect on retry");
|
||||
|
||||
let event = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
event,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_joined".to_string(),
|
||||
instructions: Some("backend prompt".to_string()),
|
||||
}
|
||||
);
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_send_while_next_event_waits() {
|
||||
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
|
||||
|
||||
@@ -40,7 +40,7 @@ use codex_api::MemorySummarizeInput as ApiMemorySummarizeInput;
|
||||
use codex_api::MemorySummarizeOutput as ApiMemorySummarizeOutput;
|
||||
use codex_api::RawMemory as ApiRawMemory;
|
||||
use codex_api::RealtimeCallClient as ApiRealtimeCallClient;
|
||||
use codex_api::RealtimeSessionConfig;
|
||||
use codex_api::RealtimeSessionConfig as ApiRealtimeSessionConfig;
|
||||
use codex_api::Reasoning;
|
||||
use codex_api::RequestTelemetry;
|
||||
use codex_api::ReqwestTransport;
|
||||
@@ -83,6 +83,7 @@ use futures::StreamExt;
|
||||
use http::HeaderMap as ApiHeaderMap;
|
||||
use http::HeaderValue;
|
||||
use http::StatusCode as HttpStatusCode;
|
||||
use http::header::AUTHORIZATION;
|
||||
use reqwest::StatusCode;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
@@ -258,6 +259,37 @@ enum WebsocketStreamOutcome {
|
||||
FallbackToHttp,
|
||||
}
|
||||
|
||||
/// Result of opening a WebRTC Realtime call.
|
||||
///
|
||||
/// The SDP answer goes back to the client. The call id and auth headers stay on the server so the
|
||||
/// ordinary Realtime WebSocket machinery can join the same in-progress call as a sideband
|
||||
/// controller.
|
||||
pub(crate) struct RealtimeWebrtcCallStart {
|
||||
pub(crate) sdp: String,
|
||||
pub(crate) call_id: String,
|
||||
pub(crate) sideband_headers: ApiHeaderMap,
|
||||
}
|
||||
|
||||
/// Reuses the API-auth material that created the WebRTC call for the sideband WebSocket join.
|
||||
///
|
||||
/// API-key sessions send that API bearer. ChatGPT-auth sessions send their bearer plus account id;
|
||||
/// transceiver is responsible for accepting that same call-create identity on the direct
|
||||
/// `api.openai.com` sideband path.
|
||||
fn sideband_websocket_auth_headers(api_auth: &CoreAuthProvider) -> ApiHeaderMap {
|
||||
let mut headers = ApiHeaderMap::new();
|
||||
if let Some(token) = api_auth.token.as_ref()
|
||||
&& let Ok(value) = HeaderValue::from_str(&format!("Bearer {token}"))
|
||||
{
|
||||
headers.insert(AUTHORIZATION, value);
|
||||
}
|
||||
if let Some(account_id) = api_auth.account_id.as_ref()
|
||||
&& let Ok(value) = HeaderValue::from_str(account_id)
|
||||
{
|
||||
headers.insert("ChatGPT-Account-ID", value);
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
impl ModelClient {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Creates a new session-scoped `ModelClient`.
|
||||
@@ -445,28 +477,28 @@ impl ModelClient {
|
||||
.map_err(map_api_error)
|
||||
}
|
||||
|
||||
pub async fn create_realtime_call(
|
||||
pub(crate) async fn create_realtime_call_with_headers(
|
||||
&self,
|
||||
sdp: String,
|
||||
session_config: RealtimeSessionConfig,
|
||||
) -> Result<String> {
|
||||
self.create_realtime_call_with_headers(sdp, session_config, ApiHeaderMap::new())
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn create_realtime_call_with_headers(
|
||||
&self,
|
||||
sdp: String,
|
||||
session_config: RealtimeSessionConfig,
|
||||
session_config: ApiRealtimeSessionConfig,
|
||||
extra_headers: ApiHeaderMap,
|
||||
) -> Result<String> {
|
||||
) -> Result<RealtimeWebrtcCallStart> {
|
||||
// Create the media call over HTTP first, then retain matching auth so realtime can attach
|
||||
// the server-side control WebSocket to the call id from that HTTP response.
|
||||
let client_setup = self.current_client_setup().await?;
|
||||
let mut sideband_headers = extra_headers.clone();
|
||||
sideband_headers.extend(sideband_websocket_auth_headers(&client_setup.api_auth));
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
ApiRealtimeCallClient::new(transport, client_setup.api_provider, client_setup.api_auth)
|
||||
.create_with_session_and_headers(sdp, session_config, extra_headers)
|
||||
.await
|
||||
.map(|response| response.sdp)
|
||||
.map_err(map_api_error)
|
||||
let response =
|
||||
ApiRealtimeCallClient::new(transport, client_setup.api_provider, client_setup.api_auth)
|
||||
.create_with_session_and_headers(sdp, session_config, extra_headers)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
Ok(RealtimeWebrtcCallStart {
|
||||
sdp: response.sdp,
|
||||
call_id: response.call_id,
|
||||
sideband_headers,
|
||||
})
|
||||
}
|
||||
|
||||
/// Builds memory summaries for each provided normalized raw memory.
|
||||
|
||||
@@ -152,12 +152,8 @@ struct RealtimeStart {
|
||||
|
||||
struct RealtimeStartOutput {
|
||||
realtime_active: Arc<AtomicBool>,
|
||||
connection: RealtimeStartConnection,
|
||||
}
|
||||
|
||||
enum RealtimeStartConnection {
|
||||
Websocket { events_rx: Receiver<RealtimeEvent> },
|
||||
Webrtc { sdp: String },
|
||||
events_rx: Receiver<RealtimeEvent>,
|
||||
sdp: Option<String>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
@@ -199,31 +195,37 @@ impl RealtimeConversationManager {
|
||||
RealtimeEventParser::V1 => RealtimeSessionKind::V1,
|
||||
RealtimeEventParser::RealtimeV2 => RealtimeSessionKind::V2,
|
||||
};
|
||||
let realtime_active = Arc::new(AtomicBool::new(true));
|
||||
|
||||
if let Some(sdp) = sdp {
|
||||
let sdp = model_client
|
||||
let client = RealtimeWebsocketClient::new(api_provider);
|
||||
let (connection, sdp) = if let Some(sdp) = sdp {
|
||||
let call = model_client
|
||||
.create_realtime_call_with_headers(
|
||||
sdp,
|
||||
session_config,
|
||||
session_config.clone(),
|
||||
extra_headers.unwrap_or_default(),
|
||||
)
|
||||
.await?;
|
||||
return Ok(RealtimeStartOutput {
|
||||
realtime_active,
|
||||
connection: RealtimeStartConnection::Webrtc { sdp },
|
||||
});
|
||||
}
|
||||
|
||||
let client = RealtimeWebsocketClient::new(api_provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
session_config,
|
||||
extra_headers.unwrap_or_default(),
|
||||
default_headers(),
|
||||
)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
let connection = client
|
||||
.connect_webrtc_sideband(
|
||||
session_config,
|
||||
&call.call_id,
|
||||
call.sideband_headers,
|
||||
default_headers(),
|
||||
)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
(connection, Some(call.sdp))
|
||||
} else {
|
||||
let connection = client
|
||||
.connect(
|
||||
session_config,
|
||||
extra_headers.unwrap_or_default(),
|
||||
default_headers(),
|
||||
)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
(connection, None)
|
||||
};
|
||||
|
||||
let writer = connection.writer();
|
||||
let events = connection.events();
|
||||
@@ -261,7 +263,8 @@ impl RealtimeConversationManager {
|
||||
});
|
||||
Ok(RealtimeStartOutput {
|
||||
realtime_active,
|
||||
connection: RealtimeStartConnection::Websocket { events_rx },
|
||||
events_rx,
|
||||
sdp,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -511,11 +514,7 @@ async fn prepare_realtime_start(
|
||||
let transport = params
|
||||
.transport
|
||||
.unwrap_or(ConversationStartTransport::Websocket);
|
||||
let mut api_provider = if matches!(transport, ConversationStartTransport::Websocket) {
|
||||
provider.to_api_provider(Some(AuthMode::ApiKey))?
|
||||
} else {
|
||||
provider.to_api_provider(auth.as_ref().map(CodexAuth::auth_mode))?
|
||||
};
|
||||
let mut api_provider = provider.to_api_provider(Some(AuthMode::ApiKey))?;
|
||||
if let Some(realtime_ws_base_url) = &config.experimental_realtime_ws_base_url {
|
||||
api_provider.base_url = realtime_ws_base_url.clone();
|
||||
}
|
||||
@@ -626,26 +625,16 @@ async fn handle_start_inner(
|
||||
|
||||
let RealtimeStartOutput {
|
||||
realtime_active,
|
||||
connection,
|
||||
events_rx,
|
||||
sdp,
|
||||
} = start_output;
|
||||
let events_rx = match connection {
|
||||
RealtimeStartConnection::Websocket { events_rx } => events_rx,
|
||||
RealtimeStartConnection::Webrtc { sdp } => {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::RealtimeConversationSdp(RealtimeConversationSdpEvent { sdp }),
|
||||
})
|
||||
.await;
|
||||
sess.conversation.finish_if_active(&realtime_active).await;
|
||||
send_realtime_conversation_closed(
|
||||
sess,
|
||||
sub_id.to_string(),
|
||||
RealtimeConversationEnd::TransportClosed,
|
||||
)
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
if let Some(sdp) = sdp {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::RealtimeConversationSdp(RealtimeConversationSdpEvent { sdp }),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
let sess_clone = Arc::clone(sess);
|
||||
let sub_id = sub_id.to_string();
|
||||
|
||||
@@ -19,8 +19,10 @@ use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::WebSocketConnectionConfig;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::responses::start_websocket_server;
|
||||
use core_test_support::responses::start_websocket_server_with_headers;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::streaming_sse::StreamingSseChunk;
|
||||
use core_test_support::streaming_sse::start_streaming_sse_server;
|
||||
@@ -337,19 +339,37 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn conversation_webrtc_start_posts_generated_session() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let capture = RealtimeCallRequestCapture::new();
|
||||
Mock::given(method("POST"))
|
||||
.and(path_regex(".*/realtime/calls$"))
|
||||
.and(capture.clone())
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string("v=answer\r\n"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("Location", "/v1/realtime/calls/calls/rtc_core_test")
|
||||
.set_body_string("v=answer\r\n"),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
let realtime_server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
|
||||
requests: vec![vec![json!({
|
||||
"type": "session.updated",
|
||||
"session": { "id": "sess_webrtc", "instructions": "backend prompt" }
|
||||
})]],
|
||||
response_headers: Vec::new(),
|
||||
accept_delay: None,
|
||||
close_after_requests: false,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
let realtime_ws_base_url = realtime_server.uri().to_string();
|
||||
let mut builder = test_codex().with_config(move |config| {
|
||||
config.experimental_realtime_ws_backend_prompt = Some("backend prompt".to_string());
|
||||
config.experimental_realtime_ws_model = Some("realtime-test-model".to_string());
|
||||
config.experimental_realtime_ws_startup_context = Some("startup context".to_string());
|
||||
config.experimental_realtime_ws_base_url = Some(realtime_ws_base_url);
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
@@ -363,6 +383,8 @@ async fn conversation_webrtc_start_posts_generated_session() -> Result<()> {
|
||||
}))
|
||||
.await?;
|
||||
|
||||
// Phase 1: the client gets the SDP answer that configures its peer connection, and then the
|
||||
// normal realtime event stream from the joined sideband WebSocket.
|
||||
let created = wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationSdp(created) => Some(Ok(created.clone())),
|
||||
EventMsg::Error(err) => Some(Err(err.clone())),
|
||||
@@ -371,13 +393,18 @@ async fn conversation_webrtc_start_posts_generated_session() -> Result<()> {
|
||||
.await
|
||||
.unwrap_or_else(|err: ErrorEvent| panic!("conversation call create failed: {err:?}"));
|
||||
assert_eq!(created.sdp, "v=answer\r\n");
|
||||
let closed = wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationClosed(closed) => Some(closed.clone()),
|
||||
|
||||
let session_updated = wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
payload: RealtimeEvent::SessionUpdated { session_id, .. },
|
||||
}) => Some(session_id.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
assert_eq!(closed.reason.as_deref(), Some("transport_closed"));
|
||||
assert_eq!(session_updated, "sess_webrtc");
|
||||
|
||||
// Phase 2: call creation posts the offer and generated session together, so the media leg can
|
||||
// begin inference before the sideband WebSocket is ready.
|
||||
let request = capture.single_request();
|
||||
assert_eq!(request.url.path(), "/v1/realtime/calls");
|
||||
assert_eq!(request.url.query(), None);
|
||||
@@ -415,6 +442,42 @@ async fn conversation_webrtc_start_posts_generated_session() -> Result<()> {
|
||||
)
|
||||
);
|
||||
|
||||
// Phase 3: the server joins that same call over the direct sideband WebSocket, sends the
|
||||
// ordinary session.update, and keeps the conversation alive until the client closes it.
|
||||
let session_update = realtime_server
|
||||
.wait_for_request(/*connection_index*/ 0, /*request_index*/ 0)
|
||||
.await;
|
||||
assert_eq!(
|
||||
session_update.body_json()["type"].as_str(),
|
||||
Some("session.update")
|
||||
);
|
||||
assert!(
|
||||
websocket_request_instructions(&session_update)
|
||||
.context("session.update should include instructions")?
|
||||
.contains("startup context")
|
||||
);
|
||||
let handshake = realtime_server.single_handshake();
|
||||
assert_eq!(
|
||||
handshake.uri(),
|
||||
"/v1/realtime?intent=quicksilver&call_id=rtc_core_test"
|
||||
);
|
||||
assert_eq!(
|
||||
handshake.header("authorization").as_deref(),
|
||||
Some("Bearer dummy")
|
||||
);
|
||||
|
||||
test.codex.submit(Op::RealtimeConversationClose).await?;
|
||||
let closed = wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationClosed(closed) => Some(closed.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
assert!(matches!(
|
||||
closed.reason.as_deref(),
|
||||
Some("requested" | "transport_closed")
|
||||
));
|
||||
|
||||
realtime_server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user