Add WebRTC support to realtime test server

Teach the shared test helper to accept realtime /calls POSTs, answer SDP offers, and relay scripted events over a data channel while logging incoming RTP audio packets as synthetic append requests. Update the one stale handshake-path assertion to the /realtime/calls URL.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-04-04 13:14:31 -07:00
parent 63c1223141
commit d6d8d6304d
5 changed files with 453 additions and 1 deletions

2
codex-rs/Cargo.lock generated
View File

@@ -3407,6 +3407,7 @@ dependencies = [
"notify",
"opentelemetry",
"opentelemetry_sdk",
"opus",
"pretty_assertions",
"regex-lite",
"reqwest",
@@ -3419,6 +3420,7 @@ dependencies = [
"tracing-opentelemetry",
"tracing-subscriber",
"walkdir",
"webrtc",
"wiremock",
"zstd",
]

View File

@@ -29,6 +29,7 @@ futures = { workspace = true }
notify = { workspace = true }
opentelemetry = { workspace = true }
opentelemetry_sdk = { workspace = true }
opus = { workspace = true }
regex-lite = { workspace = true }
serde_json = { workspace = true }
tempfile = { workspace = true }
@@ -38,6 +39,7 @@ tracing = { workspace = true }
tracing-opentelemetry = { workspace = true }
tracing-subscriber = { workspace = true }
walkdir = { workspace = true }
webrtc = { workspace = true }
wiremock = { workspace = true }
shlex = { workspace = true }
zstd = { workspace = true }

View File

@@ -37,6 +37,8 @@ use wiremock::matchers::path_regex;
use crate::test_codex::ApplyPatchModelOutput;
mod realtime_webrtc_server;
#[derive(Debug, Clone)]
pub struct ResponseMock {
requests: Arc<Mutex<Vec<ResponsesRequest>>>,
@@ -1238,6 +1240,29 @@ pub async fn start_websocket_server_with_headers(
tokio::time::sleep(delay).await;
}
if realtime_webrtc_server::accept_is_http_post(&stream).await {
let connection_index = {
let mut log = requests.lock().unwrap();
log.push(Vec::new());
log.len() - 1
};
realtime_webrtc_server::serve_connection(
stream,
connection,
connection_index,
Arc::clone(&requests),
Arc::clone(&handshakes),
Arc::clone(&request_log),
start,
)
.await;
if connections.lock().unwrap().is_empty() {
return;
}
continue;
}
let response_headers = connection.response_headers.clone();
let handshake_log = Arc::clone(&handshakes);
let callback = move |req: &Request, mut response: Response| {

View File

@@ -0,0 +1,423 @@
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use audiopus::Channels;
use audiopus::coder::Decoder as OpusDecoder;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use serde_json::Value;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::timeout;
use webrtc::api::APIBuilder;
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MediaEngine;
use webrtc::data_channel::RTCDataChannel;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::RTCPeerConnection;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::rtp_transceiver::rtp_codec::RTPCodecType;
use webrtc::track::track_remote::TrackRemote;
use super::WebSocketConnectionConfig;
use super::WebSocketHandshake;
use super::WebSocketRequest;
const HTTP_HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";
const REALTIME_AUDIO_CHANNELS: u8 = 1;
const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000;
const REALTIME_DATA_CHANNEL_TIMEOUT: Duration = Duration::from_secs(10);
const REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL: usize = 5760;
pub(super) async fn accept_is_http_post(stream: &TcpStream) -> bool {
let mut method = [0u8; 4];
matches!(stream.peek(&mut method).await, Ok(4)) && method == *b"POST"
}
pub(super) async fn serve_connection(
mut stream: TcpStream,
connection: WebSocketConnectionConfig,
connection_index: usize,
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
handshakes: Arc<Mutex<Vec<WebSocketHandshake>>>,
request_log_updated: Arc<Notify>,
start: std::time::Instant,
) {
let Some(request) = read_http_request(&mut stream).await else {
return;
};
handshakes.lock().unwrap().push(WebSocketHandshake {
uri: request.uri,
headers: request.headers,
});
let Some(offer_sdp) = parse_multipart_field(&request.body, &request.boundary, "sdp") else {
let _ = stream
.write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n")
.await;
return;
};
let Some(session) = start_session(
offer_sdp,
connection,
connection_index,
requests,
request_log_updated,
start,
)
.await
else {
let _ = stream
.write_all(b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n")
.await;
return;
};
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/sdp\r\nContent-Length: {}\r\n\r\n{}",
session.answer_sdp.len(),
session.answer_sdp
);
if stream.write_all(response.as_bytes()).await.is_err() {
return;
}
let _ = session.done_rx.await;
let _ = session.peer_connection.close().await;
}
struct HttpRealtimeRequest {
uri: String,
headers: Vec<(String, String)>,
boundary: String,
body: Vec<u8>,
}
struct RealtimeSession {
answer_sdp: String,
peer_connection: Arc<RTCPeerConnection>,
done_rx: oneshot::Receiver<()>,
}
async fn read_http_request(stream: &mut TcpStream) -> Option<HttpRealtimeRequest> {
let mut received = Vec::new();
let headers_end = loop {
if let Some(headers_end) = received
.windows(HTTP_HEADER_TERMINATOR.len())
.position(|window| window == HTTP_HEADER_TERMINATOR)
{
break headers_end + HTTP_HEADER_TERMINATOR.len();
}
let mut chunk = [0u8; 1024];
let read = stream.read(&mut chunk).await.ok()?;
if read == 0 {
return None;
}
received.extend_from_slice(&chunk[..read]);
};
let header_text = std::str::from_utf8(&received[..headers_end]).ok()?;
let mut lines = header_text.split("\r\n").filter(|line| !line.is_empty());
let request_line = lines.next()?;
let mut request_line_parts = request_line.split_whitespace();
if request_line_parts.next()? != "POST" {
return None;
}
let uri = request_line_parts.next()?.to_string();
let mut headers = Vec::new();
let mut content_length = None;
let mut boundary = None;
for line in lines {
let Some((name, value)) = line.split_once(':') else {
continue;
};
let name = name.trim().to_string();
let value = value.trim().to_string();
if name.eq_ignore_ascii_case("content-length") {
content_length = value.parse::<usize>().ok();
}
if name.eq_ignore_ascii_case("content-type") {
boundary = value
.split(';')
.map(str::trim)
.find_map(|part| part.strip_prefix("boundary="))
.map(|boundary| boundary.trim_matches('"').to_string());
}
headers.push((name, value));
}
let content_length = content_length?;
while received.len() - headers_end < content_length {
let mut chunk = [0u8; 1024];
let read = stream.read(&mut chunk).await.ok()?;
if read == 0 {
return None;
}
received.extend_from_slice(&chunk[..read]);
}
let body_end = headers_end + content_length;
let body = received[headers_end..body_end].to_vec();
Some(HttpRealtimeRequest {
uri,
headers,
boundary: boundary?,
body,
})
}
fn parse_multipart_field(body: &[u8], boundary: &str, field_name: &str) -> Option<String> {
let body = std::str::from_utf8(body).ok()?;
let delimiter = format!("--{boundary}");
body.split(&delimiter).find_map(|part| {
let (headers, value) = part.split_once("\r\n\r\n")?;
if !headers.contains(&format!("name=\"{field_name}\"")) {
return None;
}
Some(value.trim_end_matches("\r\n").to_string())
})
}
async fn start_session(
offer_sdp: String,
connection: WebSocketConnectionConfig,
connection_index: usize,
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
request_log_updated: Arc<Notify>,
start: std::time::Instant,
) -> Option<RealtimeSession> {
let peer_connection = create_peer_connection().await?;
let (tx_request, rx_request) = mpsc::unbounded_channel::<Value>();
let (tx_data_channel, rx_data_channel) = oneshot::channel::<Arc<RTCDataChannel>>();
let tx_data_channel = Mutex::new(Some(tx_data_channel));
peer_connection.on_data_channel(Box::new(move |data_channel| {
let tx_request = tx_request.clone();
if let Ok(mut tx_data_channel) = tx_data_channel.lock()
&& let Some(tx_data_channel) = tx_data_channel.take()
{
let _ = tx_data_channel.send(Arc::clone(&data_channel));
}
data_channel.on_message(Box::new(move |message: DataChannelMessage| {
let tx_request = tx_request.clone();
Box::pin(async move {
if !message.is_string {
return;
}
let Ok(text) = String::from_utf8(message.data.to_vec()) else {
return;
};
let Ok(body) = serde_json::from_str::<Value>(&text) else {
return;
};
let _ = tx_request.send(body);
})
}));
Box::pin(async {})
}));
register_remote_audio_handler(&peer_connection, tx_request.clone());
let mut gather_complete = peer_connection.gathering_complete_promise().await;
let offer = RTCSessionDescription::offer(offer_sdp).ok()?;
peer_connection.set_remote_description(offer).await.ok()?;
let answer = peer_connection.create_answer(None).await.ok()?;
peer_connection.set_local_description(answer).await.ok()?;
let _ = gather_complete.recv().await;
let answer_sdp = peer_connection.local_description().await?.sdp;
let data_channel = timeout(REALTIME_DATA_CHANNEL_TIMEOUT, rx_data_channel)
.await
.ok()?
.ok()?;
let (done_tx, done_rx) = oneshot::channel();
tokio::spawn(async move {
serve_scripted_requests(
connection,
connection_index,
requests,
request_log_updated,
rx_request,
data_channel,
start,
)
.await;
let _ = done_tx.send(());
});
Some(RealtimeSession {
answer_sdp,
peer_connection,
done_rx,
})
}
async fn create_peer_connection() -> Option<Arc<RTCPeerConnection>> {
let mut media_engine = MediaEngine::default();
media_engine.register_default_codecs().ok()?;
let registry = register_default_interceptors(Registry::new(), &mut media_engine).ok()?;
let api = APIBuilder::new()
.with_media_engine(media_engine)
.with_interceptor_registry(registry)
.build();
api.new_peer_connection(RTCConfiguration::default())
.await
.map(Arc::new)
.ok()
}
fn register_remote_audio_handler(
peer_connection: &Arc<RTCPeerConnection>,
tx_request: mpsc::UnboundedSender<Value>,
) {
peer_connection.on_track(Box::new(move |track, _, _| {
let tx_request = tx_request.clone();
Box::pin(async move {
if track.kind() != RTPCodecType::Audio {
return;
}
pump_remote_audio_track(track, tx_request).await;
})
}));
}
async fn pump_remote_audio_track(
track: Arc<TrackRemote>,
tx_request: mpsc::UnboundedSender<Value>,
) {
let Ok(mut decoder) = OpusDecoder::new(24_000, Channels::Mono) else {
return;
};
let mut decoded = vec![0i16; REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL];
while let Ok((packet, _)) = track.read_rtp().await {
if packet.payload.is_empty() {
continue;
}
let Ok(samples_per_channel) = decoder.decode(&packet.payload, &mut decoded, false) else {
return;
};
if samples_per_channel == 0 {
continue;
}
let mut pcm_bytes = Vec::with_capacity(samples_per_channel * 2);
for sample in &decoded[..samples_per_channel] {
pcm_bytes.extend_from_slice(&sample.to_le_bytes());
}
let _ = tx_request.send(serde_json::json!({
"type": "input_audio_buffer.append",
"audio": BASE64_STANDARD.encode(pcm_bytes),
"sample_rate": REALTIME_AUDIO_SAMPLE_RATE,
"channels": REALTIME_AUDIO_CHANNELS,
"samples_per_channel": samples_per_channel,
}));
}
}
async fn serve_scripted_requests(
connection: WebSocketConnectionConfig,
connection_index: usize,
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
request_log_updated: Arc<Notify>,
mut rx_request: mpsc::UnboundedReceiver<Value>,
data_channel: Arc<RTCDataChannel>,
start: std::time::Instant,
) {
let mut scripted_requests = VecDeque::from(connection.requests);
while let Some(request_events) = scripted_requests.pop_front() {
let Some(body) = rx_request.recv().await else {
break;
};
log_request(
connection_index,
body,
&requests,
&request_log_updated,
start,
);
eprintln!(
"[ws test server +{}ms] connection={} sending batch_size={} event_types={:?} audio_data={:?}",
start.elapsed().as_millis(),
connection_index,
request_events.len(),
request_events
.iter()
.map(|event| event.get("type").and_then(Value::as_str))
.collect::<Vec<_>>(),
request_events
.iter()
.find_map(|event| event.get("delta").and_then(Value::as_str)),
);
for event in &request_events {
let Ok(payload) = serde_json::to_string(event) else {
continue;
};
if data_channel.send_text(payload).await.is_err() {
return;
}
}
}
if connection.close_after_requests {
let _ = data_channel.close().await;
}
}
fn log_request(
connection_index: usize,
body: Value,
requests: &Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
request_log_updated: &Arc<Notify>,
start: std::time::Instant,
) {
let mut log = requests.lock().unwrap();
if log.len() <= connection_index {
log.resize_with(connection_index + 1, Vec::new);
}
if let Some(connection_log) = log.get_mut(connection_index) {
connection_log.push(WebSocketRequest { body });
let request_index = connection_log.len() - 1;
let request_body = connection_log[request_index].body_json();
eprintln!(
"[ws test server +{}ms] connection={} received request={} type={:?} role={:?} text={:?} data={:?}",
start.elapsed().as_millis(),
connection_index,
request_index,
request_body.get("type").and_then(Value::as_str),
request_body
.get("item")
.and_then(|item| item.get("role"))
.and_then(Value::as_str),
request_body
.get("item")
.and_then(|item| item.get("content"))
.and_then(Value::as_array)
.and_then(|content| content.first())
.and_then(|content| content.get("text"))
.and_then(Value::as_str),
request_body
.get("item")
.and_then(|item| item.get("content"))
.and_then(Value::as_array)
.and_then(|content| content.first())
.and_then(|content| content.get("data"))
.and_then(Value::as_str),
);
}
drop(log);
request_log_updated.notify_waiters();
}

View File

@@ -256,7 +256,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
);
assert_eq!(
server.handshakes()[1].uri(),
"/v1/realtime?intent=quicksilver&model=realtime-test-model"
"/v1/realtime/calls?intent=quicksilver&model=realtime-test-model"
);
let mut request_types = [
connection[1].body_json()["type"]