diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index c106ee2fc4..dfc1bc4790 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -3407,6 +3407,7 @@ dependencies = [ "notify", "opentelemetry", "opentelemetry_sdk", + "opus", "pretty_assertions", "regex-lite", "reqwest", @@ -3419,6 +3420,7 @@ dependencies = [ "tracing-opentelemetry", "tracing-subscriber", "walkdir", + "webrtc", "wiremock", "zstd", ] diff --git a/codex-rs/core/tests/common/Cargo.toml b/codex-rs/core/tests/common/Cargo.toml index e2765e8be6..026be6d388 100644 --- a/codex-rs/core/tests/common/Cargo.toml +++ b/codex-rs/core/tests/common/Cargo.toml @@ -29,6 +29,7 @@ futures = { workspace = true } notify = { workspace = true } opentelemetry = { workspace = true } opentelemetry_sdk = { workspace = true } +opus = { workspace = true } regex-lite = { workspace = true } serde_json = { workspace = true } tempfile = { workspace = true } @@ -38,6 +39,7 @@ tracing = { workspace = true } tracing-opentelemetry = { workspace = true } tracing-subscriber = { workspace = true } walkdir = { workspace = true } +webrtc = { workspace = true } wiremock = { workspace = true } shlex = { workspace = true } zstd = { workspace = true } diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 2e2155ebdd..8ad3ef34e6 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -37,6 +37,8 @@ use wiremock::matchers::path_regex; use crate::test_codex::ApplyPatchModelOutput; +mod realtime_webrtc_server; + #[derive(Debug, Clone)] pub struct ResponseMock { requests: Arc>>, @@ -1238,6 +1240,29 @@ pub async fn start_websocket_server_with_headers( tokio::time::sleep(delay).await; } + if realtime_webrtc_server::accept_is_http_post(&stream).await { + let connection_index = { + let mut log = requests.lock().unwrap(); + log.push(Vec::new()); + log.len() - 1 + }; + realtime_webrtc_server::serve_connection( + stream, + connection, + connection_index, + Arc::clone(&requests), + Arc::clone(&handshakes), + Arc::clone(&request_log), + start, + ) + .await; + + if connections.lock().unwrap().is_empty() { + return; + } + continue; + } + let response_headers = connection.response_headers.clone(); let handshake_log = Arc::clone(&handshakes); let callback = move |req: &Request, mut response: Response| { diff --git a/codex-rs/core/tests/common/responses/realtime_webrtc_server.rs b/codex-rs/core/tests/common/responses/realtime_webrtc_server.rs new file mode 100644 index 0000000000..d9315f90b6 --- /dev/null +++ b/codex-rs/core/tests/common/responses/realtime_webrtc_server.rs @@ -0,0 +1,423 @@ +use std::collections::VecDeque; +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Duration; + +use audiopus::Channels; +use audiopus::coder::Decoder as OpusDecoder; +use base64::Engine; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use serde_json::Value; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpStream; +use tokio::sync::Notify; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::time::timeout; +use webrtc::api::APIBuilder; +use webrtc::api::interceptor_registry::register_default_interceptors; +use webrtc::api::media_engine::MediaEngine; +use webrtc::data_channel::RTCDataChannel; +use webrtc::data_channel::data_channel_message::DataChannelMessage; +use webrtc::interceptor::registry::Registry; +use webrtc::peer_connection::RTCPeerConnection; +use webrtc::peer_connection::configuration::RTCConfiguration; +use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::rtp_transceiver::rtp_codec::RTPCodecType; +use webrtc::track::track_remote::TrackRemote; + +use super::WebSocketConnectionConfig; +use super::WebSocketHandshake; +use super::WebSocketRequest; + +const HTTP_HEADER_TERMINATOR: &[u8] = b"\r\n\r\n"; +const REALTIME_AUDIO_CHANNELS: u8 = 1; +const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000; +const REALTIME_DATA_CHANNEL_TIMEOUT: Duration = Duration::from_secs(10); +const REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL: usize = 5760; + +pub(super) async fn accept_is_http_post(stream: &TcpStream) -> bool { + let mut method = [0u8; 4]; + matches!(stream.peek(&mut method).await, Ok(4)) && method == *b"POST" +} + +pub(super) async fn serve_connection( + mut stream: TcpStream, + connection: WebSocketConnectionConfig, + connection_index: usize, + requests: Arc>>>, + handshakes: Arc>>, + request_log_updated: Arc, + start: std::time::Instant, +) { + let Some(request) = read_http_request(&mut stream).await else { + return; + }; + + handshakes.lock().unwrap().push(WebSocketHandshake { + uri: request.uri, + headers: request.headers, + }); + + let Some(offer_sdp) = parse_multipart_field(&request.body, &request.boundary, "sdp") else { + let _ = stream + .write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n") + .await; + return; + }; + + let Some(session) = start_session( + offer_sdp, + connection, + connection_index, + requests, + request_log_updated, + start, + ) + .await + else { + let _ = stream + .write_all(b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n") + .await; + return; + }; + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/sdp\r\nContent-Length: {}\r\n\r\n{}", + session.answer_sdp.len(), + session.answer_sdp + ); + if stream.write_all(response.as_bytes()).await.is_err() { + return; + } + + let _ = session.done_rx.await; + let _ = session.peer_connection.close().await; +} + +struct HttpRealtimeRequest { + uri: String, + headers: Vec<(String, String)>, + boundary: String, + body: Vec, +} + +struct RealtimeSession { + answer_sdp: String, + peer_connection: Arc, + done_rx: oneshot::Receiver<()>, +} + +async fn read_http_request(stream: &mut TcpStream) -> Option { + let mut received = Vec::new(); + let headers_end = loop { + if let Some(headers_end) = received + .windows(HTTP_HEADER_TERMINATOR.len()) + .position(|window| window == HTTP_HEADER_TERMINATOR) + { + break headers_end + HTTP_HEADER_TERMINATOR.len(); + } + + let mut chunk = [0u8; 1024]; + let read = stream.read(&mut chunk).await.ok()?; + if read == 0 { + return None; + } + received.extend_from_slice(&chunk[..read]); + }; + + let header_text = std::str::from_utf8(&received[..headers_end]).ok()?; + let mut lines = header_text.split("\r\n").filter(|line| !line.is_empty()); + let request_line = lines.next()?; + let mut request_line_parts = request_line.split_whitespace(); + if request_line_parts.next()? != "POST" { + return None; + } + let uri = request_line_parts.next()?.to_string(); + + let mut headers = Vec::new(); + let mut content_length = None; + let mut boundary = None; + for line in lines { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_string(); + let value = value.trim().to_string(); + if name.eq_ignore_ascii_case("content-length") { + content_length = value.parse::().ok(); + } + if name.eq_ignore_ascii_case("content-type") { + boundary = value + .split(';') + .map(str::trim) + .find_map(|part| part.strip_prefix("boundary=")) + .map(|boundary| boundary.trim_matches('"').to_string()); + } + headers.push((name, value)); + } + + let content_length = content_length?; + while received.len() - headers_end < content_length { + let mut chunk = [0u8; 1024]; + let read = stream.read(&mut chunk).await.ok()?; + if read == 0 { + return None; + } + received.extend_from_slice(&chunk[..read]); + } + + let body_end = headers_end + content_length; + let body = received[headers_end..body_end].to_vec(); + Some(HttpRealtimeRequest { + uri, + headers, + boundary: boundary?, + body, + }) +} + +fn parse_multipart_field(body: &[u8], boundary: &str, field_name: &str) -> Option { + let body = std::str::from_utf8(body).ok()?; + let delimiter = format!("--{boundary}"); + body.split(&delimiter).find_map(|part| { + let (headers, value) = part.split_once("\r\n\r\n")?; + if !headers.contains(&format!("name=\"{field_name}\"")) { + return None; + } + Some(value.trim_end_matches("\r\n").to_string()) + }) +} + +async fn start_session( + offer_sdp: String, + connection: WebSocketConnectionConfig, + connection_index: usize, + requests: Arc>>>, + request_log_updated: Arc, + start: std::time::Instant, +) -> Option { + let peer_connection = create_peer_connection().await?; + let (tx_request, rx_request) = mpsc::unbounded_channel::(); + let (tx_data_channel, rx_data_channel) = oneshot::channel::>(); + let tx_data_channel = Mutex::new(Some(tx_data_channel)); + + peer_connection.on_data_channel(Box::new(move |data_channel| { + let tx_request = tx_request.clone(); + if let Ok(mut tx_data_channel) = tx_data_channel.lock() + && let Some(tx_data_channel) = tx_data_channel.take() + { + let _ = tx_data_channel.send(Arc::clone(&data_channel)); + } + data_channel.on_message(Box::new(move |message: DataChannelMessage| { + let tx_request = tx_request.clone(); + Box::pin(async move { + if !message.is_string { + return; + } + let Ok(text) = String::from_utf8(message.data.to_vec()) else { + return; + }; + let Ok(body) = serde_json::from_str::(&text) else { + return; + }; + let _ = tx_request.send(body); + }) + })); + Box::pin(async {}) + })); + + register_remote_audio_handler(&peer_connection, tx_request.clone()); + + let mut gather_complete = peer_connection.gathering_complete_promise().await; + let offer = RTCSessionDescription::offer(offer_sdp).ok()?; + peer_connection.set_remote_description(offer).await.ok()?; + let answer = peer_connection.create_answer(None).await.ok()?; + peer_connection.set_local_description(answer).await.ok()?; + let _ = gather_complete.recv().await; + let answer_sdp = peer_connection.local_description().await?.sdp; + + let data_channel = timeout(REALTIME_DATA_CHANNEL_TIMEOUT, rx_data_channel) + .await + .ok()? + .ok()?; + + let (done_tx, done_rx) = oneshot::channel(); + tokio::spawn(async move { + serve_scripted_requests( + connection, + connection_index, + requests, + request_log_updated, + rx_request, + data_channel, + start, + ) + .await; + let _ = done_tx.send(()); + }); + + Some(RealtimeSession { + answer_sdp, + peer_connection, + done_rx, + }) +} + +async fn create_peer_connection() -> Option> { + let mut media_engine = MediaEngine::default(); + media_engine.register_default_codecs().ok()?; + let registry = register_default_interceptors(Registry::new(), &mut media_engine).ok()?; + let api = APIBuilder::new() + .with_media_engine(media_engine) + .with_interceptor_registry(registry) + .build(); + api.new_peer_connection(RTCConfiguration::default()) + .await + .map(Arc::new) + .ok() +} + +fn register_remote_audio_handler( + peer_connection: &Arc, + tx_request: mpsc::UnboundedSender, +) { + peer_connection.on_track(Box::new(move |track, _, _| { + let tx_request = tx_request.clone(); + Box::pin(async move { + if track.kind() != RTPCodecType::Audio { + return; + } + pump_remote_audio_track(track, tx_request).await; + }) + })); +} + +async fn pump_remote_audio_track( + track: Arc, + tx_request: mpsc::UnboundedSender, +) { + let Ok(mut decoder) = OpusDecoder::new(24_000, Channels::Mono) else { + return; + }; + let mut decoded = vec![0i16; REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL]; + + while let Ok((packet, _)) = track.read_rtp().await { + if packet.payload.is_empty() { + continue; + } + let Ok(samples_per_channel) = decoder.decode(&packet.payload, &mut decoded, false) else { + return; + }; + if samples_per_channel == 0 { + continue; + } + + let mut pcm_bytes = Vec::with_capacity(samples_per_channel * 2); + for sample in &decoded[..samples_per_channel] { + pcm_bytes.extend_from_slice(&sample.to_le_bytes()); + } + let _ = tx_request.send(serde_json::json!({ + "type": "input_audio_buffer.append", + "audio": BASE64_STANDARD.encode(pcm_bytes), + "sample_rate": REALTIME_AUDIO_SAMPLE_RATE, + "channels": REALTIME_AUDIO_CHANNELS, + "samples_per_channel": samples_per_channel, + })); + } +} + +async fn serve_scripted_requests( + connection: WebSocketConnectionConfig, + connection_index: usize, + requests: Arc>>>, + request_log_updated: Arc, + mut rx_request: mpsc::UnboundedReceiver, + data_channel: Arc, + start: std::time::Instant, +) { + let mut scripted_requests = VecDeque::from(connection.requests); + while let Some(request_events) = scripted_requests.pop_front() { + let Some(body) = rx_request.recv().await else { + break; + }; + log_request( + connection_index, + body, + &requests, + &request_log_updated, + start, + ); + eprintln!( + "[ws test server +{}ms] connection={} sending batch_size={} event_types={:?} audio_data={:?}", + start.elapsed().as_millis(), + connection_index, + request_events.len(), + request_events + .iter() + .map(|event| event.get("type").and_then(Value::as_str)) + .collect::>(), + request_events + .iter() + .find_map(|event| event.get("delta").and_then(Value::as_str)), + ); + for event in &request_events { + let Ok(payload) = serde_json::to_string(event) else { + continue; + }; + if data_channel.send_text(payload).await.is_err() { + return; + } + } + } + + if connection.close_after_requests { + let _ = data_channel.close().await; + } +} + +fn log_request( + connection_index: usize, + body: Value, + requests: &Arc>>>, + request_log_updated: &Arc, + start: std::time::Instant, +) { + let mut log = requests.lock().unwrap(); + if log.len() <= connection_index { + log.resize_with(connection_index + 1, Vec::new); + } + if let Some(connection_log) = log.get_mut(connection_index) { + connection_log.push(WebSocketRequest { body }); + let request_index = connection_log.len() - 1; + let request_body = connection_log[request_index].body_json(); + eprintln!( + "[ws test server +{}ms] connection={} received request={} type={:?} role={:?} text={:?} data={:?}", + start.elapsed().as_millis(), + connection_index, + request_index, + request_body.get("type").and_then(Value::as_str), + request_body + .get("item") + .and_then(|item| item.get("role")) + .and_then(Value::as_str), + request_body + .get("item") + .and_then(|item| item.get("content")) + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|content| content.get("text")) + .and_then(Value::as_str), + request_body + .get("item") + .and_then(|item| item.get("content")) + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|content| content.get("data")) + .and_then(Value::as_str), + ); + } + drop(log); + request_log_updated.notify_waiters(); +} diff --git a/codex-rs/core/tests/suite/realtime_conversation.rs b/codex-rs/core/tests/suite/realtime_conversation.rs index b169e229a5..660168aaf2 100644 --- a/codex-rs/core/tests/suite/realtime_conversation.rs +++ b/codex-rs/core/tests/suite/realtime_conversation.rs @@ -256,7 +256,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> { ); assert_eq!( server.handshakes()[1].uri(), - "/v1/realtime?intent=quicksilver&model=realtime-test-model" + "/v1/realtime/calls?intent=quicksilver&model=realtime-test-model" ); let mut request_types = [ connection[1].body_json()["type"]