mirror of
https://github.com/openai/codex.git
synced 2026-04-23 22:24:57 +00:00
Add WebRTC support to realtime test server
Teach the shared test helper to accept realtime /calls POSTs, answer SDP offers, and relay scripted events over a data channel while logging incoming RTP audio packets as synthetic append requests. Update the one stale handshake-path assertion to the /realtime/calls URL. Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
@@ -3407,6 +3407,7 @@ dependencies = [
|
||||
"notify",
|
||||
"opentelemetry",
|
||||
"opentelemetry_sdk",
|
||||
"opus",
|
||||
"pretty_assertions",
|
||||
"regex-lite",
|
||||
"reqwest",
|
||||
@@ -3419,6 +3420,7 @@ dependencies = [
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
"walkdir",
|
||||
"webrtc",
|
||||
"wiremock",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
@@ -29,6 +29,7 @@ futures = { workspace = true }
|
||||
notify = { workspace = true }
|
||||
opentelemetry = { workspace = true }
|
||||
opentelemetry_sdk = { workspace = true }
|
||||
opus = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
@@ -38,6 +39,7 @@ tracing = { workspace = true }
|
||||
tracing-opentelemetry = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
walkdir = { workspace = true }
|
||||
webrtc = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
zstd = { workspace = true }
|
||||
|
||||
@@ -37,6 +37,8 @@ use wiremock::matchers::path_regex;
|
||||
|
||||
use crate::test_codex::ApplyPatchModelOutput;
|
||||
|
||||
mod realtime_webrtc_server;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResponseMock {
|
||||
requests: Arc<Mutex<Vec<ResponsesRequest>>>,
|
||||
@@ -1238,6 +1240,29 @@ pub async fn start_websocket_server_with_headers(
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
|
||||
if realtime_webrtc_server::accept_is_http_post(&stream).await {
|
||||
let connection_index = {
|
||||
let mut log = requests.lock().unwrap();
|
||||
log.push(Vec::new());
|
||||
log.len() - 1
|
||||
};
|
||||
realtime_webrtc_server::serve_connection(
|
||||
stream,
|
||||
connection,
|
||||
connection_index,
|
||||
Arc::clone(&requests),
|
||||
Arc::clone(&handshakes),
|
||||
Arc::clone(&request_log),
|
||||
start,
|
||||
)
|
||||
.await;
|
||||
|
||||
if connections.lock().unwrap().is_empty() {
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let response_headers = connection.response_headers.clone();
|
||||
let handshake_log = Arc::clone(&handshakes);
|
||||
let callback = move |req: &Request, mut response: Response| {
|
||||
|
||||
423
codex-rs/core/tests/common/responses/realtime_webrtc_server.rs
Normal file
423
codex-rs/core/tests/common/responses/realtime_webrtc_server.rs
Normal file
@@ -0,0 +1,423 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
|
||||
use audiopus::Channels;
|
||||
use audiopus::coder::Decoder as OpusDecoder;
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
|
||||
use serde_json::Value;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use webrtc::api::APIBuilder;
|
||||
use webrtc::api::interceptor_registry::register_default_interceptors;
|
||||
use webrtc::api::media_engine::MediaEngine;
|
||||
use webrtc::data_channel::RTCDataChannel;
|
||||
use webrtc::data_channel::data_channel_message::DataChannelMessage;
|
||||
use webrtc::interceptor::registry::Registry;
|
||||
use webrtc::peer_connection::RTCPeerConnection;
|
||||
use webrtc::peer_connection::configuration::RTCConfiguration;
|
||||
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
||||
use webrtc::rtp_transceiver::rtp_codec::RTPCodecType;
|
||||
use webrtc::track::track_remote::TrackRemote;
|
||||
|
||||
use super::WebSocketConnectionConfig;
|
||||
use super::WebSocketHandshake;
|
||||
use super::WebSocketRequest;
|
||||
|
||||
const HTTP_HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";
|
||||
const REALTIME_AUDIO_CHANNELS: u8 = 1;
|
||||
const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000;
|
||||
const REALTIME_DATA_CHANNEL_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL: usize = 5760;
|
||||
|
||||
pub(super) async fn accept_is_http_post(stream: &TcpStream) -> bool {
|
||||
let mut method = [0u8; 4];
|
||||
matches!(stream.peek(&mut method).await, Ok(4)) && method == *b"POST"
|
||||
}
|
||||
|
||||
pub(super) async fn serve_connection(
|
||||
mut stream: TcpStream,
|
||||
connection: WebSocketConnectionConfig,
|
||||
connection_index: usize,
|
||||
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
handshakes: Arc<Mutex<Vec<WebSocketHandshake>>>,
|
||||
request_log_updated: Arc<Notify>,
|
||||
start: std::time::Instant,
|
||||
) {
|
||||
let Some(request) = read_http_request(&mut stream).await else {
|
||||
return;
|
||||
};
|
||||
|
||||
handshakes.lock().unwrap().push(WebSocketHandshake {
|
||||
uri: request.uri,
|
||||
headers: request.headers,
|
||||
});
|
||||
|
||||
let Some(offer_sdp) = parse_multipart_field(&request.body, &request.boundary, "sdp") else {
|
||||
let _ = stream
|
||||
.write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n")
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(session) = start_session(
|
||||
offer_sdp,
|
||||
connection,
|
||||
connection_index,
|
||||
requests,
|
||||
request_log_updated,
|
||||
start,
|
||||
)
|
||||
.await
|
||||
else {
|
||||
let _ = stream
|
||||
.write_all(b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n")
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: application/sdp\r\nContent-Length: {}\r\n\r\n{}",
|
||||
session.answer_sdp.len(),
|
||||
session.answer_sdp
|
||||
);
|
||||
if stream.write_all(response.as_bytes()).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = session.done_rx.await;
|
||||
let _ = session.peer_connection.close().await;
|
||||
}
|
||||
|
||||
struct HttpRealtimeRequest {
|
||||
uri: String,
|
||||
headers: Vec<(String, String)>,
|
||||
boundary: String,
|
||||
body: Vec<u8>,
|
||||
}
|
||||
|
||||
struct RealtimeSession {
|
||||
answer_sdp: String,
|
||||
peer_connection: Arc<RTCPeerConnection>,
|
||||
done_rx: oneshot::Receiver<()>,
|
||||
}
|
||||
|
||||
async fn read_http_request(stream: &mut TcpStream) -> Option<HttpRealtimeRequest> {
|
||||
let mut received = Vec::new();
|
||||
let headers_end = loop {
|
||||
if let Some(headers_end) = received
|
||||
.windows(HTTP_HEADER_TERMINATOR.len())
|
||||
.position(|window| window == HTTP_HEADER_TERMINATOR)
|
||||
{
|
||||
break headers_end + HTTP_HEADER_TERMINATOR.len();
|
||||
}
|
||||
|
||||
let mut chunk = [0u8; 1024];
|
||||
let read = stream.read(&mut chunk).await.ok()?;
|
||||
if read == 0 {
|
||||
return None;
|
||||
}
|
||||
received.extend_from_slice(&chunk[..read]);
|
||||
};
|
||||
|
||||
let header_text = std::str::from_utf8(&received[..headers_end]).ok()?;
|
||||
let mut lines = header_text.split("\r\n").filter(|line| !line.is_empty());
|
||||
let request_line = lines.next()?;
|
||||
let mut request_line_parts = request_line.split_whitespace();
|
||||
if request_line_parts.next()? != "POST" {
|
||||
return None;
|
||||
}
|
||||
let uri = request_line_parts.next()?.to_string();
|
||||
|
||||
let mut headers = Vec::new();
|
||||
let mut content_length = None;
|
||||
let mut boundary = None;
|
||||
for line in lines {
|
||||
let Some((name, value)) = line.split_once(':') else {
|
||||
continue;
|
||||
};
|
||||
let name = name.trim().to_string();
|
||||
let value = value.trim().to_string();
|
||||
if name.eq_ignore_ascii_case("content-length") {
|
||||
content_length = value.parse::<usize>().ok();
|
||||
}
|
||||
if name.eq_ignore_ascii_case("content-type") {
|
||||
boundary = value
|
||||
.split(';')
|
||||
.map(str::trim)
|
||||
.find_map(|part| part.strip_prefix("boundary="))
|
||||
.map(|boundary| boundary.trim_matches('"').to_string());
|
||||
}
|
||||
headers.push((name, value));
|
||||
}
|
||||
|
||||
let content_length = content_length?;
|
||||
while received.len() - headers_end < content_length {
|
||||
let mut chunk = [0u8; 1024];
|
||||
let read = stream.read(&mut chunk).await.ok()?;
|
||||
if read == 0 {
|
||||
return None;
|
||||
}
|
||||
received.extend_from_slice(&chunk[..read]);
|
||||
}
|
||||
|
||||
let body_end = headers_end + content_length;
|
||||
let body = received[headers_end..body_end].to_vec();
|
||||
Some(HttpRealtimeRequest {
|
||||
uri,
|
||||
headers,
|
||||
boundary: boundary?,
|
||||
body,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_multipart_field(body: &[u8], boundary: &str, field_name: &str) -> Option<String> {
|
||||
let body = std::str::from_utf8(body).ok()?;
|
||||
let delimiter = format!("--{boundary}");
|
||||
body.split(&delimiter).find_map(|part| {
|
||||
let (headers, value) = part.split_once("\r\n\r\n")?;
|
||||
if !headers.contains(&format!("name=\"{field_name}\"")) {
|
||||
return None;
|
||||
}
|
||||
Some(value.trim_end_matches("\r\n").to_string())
|
||||
})
|
||||
}
|
||||
|
||||
async fn start_session(
|
||||
offer_sdp: String,
|
||||
connection: WebSocketConnectionConfig,
|
||||
connection_index: usize,
|
||||
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
request_log_updated: Arc<Notify>,
|
||||
start: std::time::Instant,
|
||||
) -> Option<RealtimeSession> {
|
||||
let peer_connection = create_peer_connection().await?;
|
||||
let (tx_request, rx_request) = mpsc::unbounded_channel::<Value>();
|
||||
let (tx_data_channel, rx_data_channel) = oneshot::channel::<Arc<RTCDataChannel>>();
|
||||
let tx_data_channel = Mutex::new(Some(tx_data_channel));
|
||||
|
||||
peer_connection.on_data_channel(Box::new(move |data_channel| {
|
||||
let tx_request = tx_request.clone();
|
||||
if let Ok(mut tx_data_channel) = tx_data_channel.lock()
|
||||
&& let Some(tx_data_channel) = tx_data_channel.take()
|
||||
{
|
||||
let _ = tx_data_channel.send(Arc::clone(&data_channel));
|
||||
}
|
||||
data_channel.on_message(Box::new(move |message: DataChannelMessage| {
|
||||
let tx_request = tx_request.clone();
|
||||
Box::pin(async move {
|
||||
if !message.is_string {
|
||||
return;
|
||||
}
|
||||
let Ok(text) = String::from_utf8(message.data.to_vec()) else {
|
||||
return;
|
||||
};
|
||||
let Ok(body) = serde_json::from_str::<Value>(&text) else {
|
||||
return;
|
||||
};
|
||||
let _ = tx_request.send(body);
|
||||
})
|
||||
}));
|
||||
Box::pin(async {})
|
||||
}));
|
||||
|
||||
register_remote_audio_handler(&peer_connection, tx_request.clone());
|
||||
|
||||
let mut gather_complete = peer_connection.gathering_complete_promise().await;
|
||||
let offer = RTCSessionDescription::offer(offer_sdp).ok()?;
|
||||
peer_connection.set_remote_description(offer).await.ok()?;
|
||||
let answer = peer_connection.create_answer(None).await.ok()?;
|
||||
peer_connection.set_local_description(answer).await.ok()?;
|
||||
let _ = gather_complete.recv().await;
|
||||
let answer_sdp = peer_connection.local_description().await?.sdp;
|
||||
|
||||
let data_channel = timeout(REALTIME_DATA_CHANNEL_TIMEOUT, rx_data_channel)
|
||||
.await
|
||||
.ok()?
|
||||
.ok()?;
|
||||
|
||||
let (done_tx, done_rx) = oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
serve_scripted_requests(
|
||||
connection,
|
||||
connection_index,
|
||||
requests,
|
||||
request_log_updated,
|
||||
rx_request,
|
||||
data_channel,
|
||||
start,
|
||||
)
|
||||
.await;
|
||||
let _ = done_tx.send(());
|
||||
});
|
||||
|
||||
Some(RealtimeSession {
|
||||
answer_sdp,
|
||||
peer_connection,
|
||||
done_rx,
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_peer_connection() -> Option<Arc<RTCPeerConnection>> {
|
||||
let mut media_engine = MediaEngine::default();
|
||||
media_engine.register_default_codecs().ok()?;
|
||||
let registry = register_default_interceptors(Registry::new(), &mut media_engine).ok()?;
|
||||
let api = APIBuilder::new()
|
||||
.with_media_engine(media_engine)
|
||||
.with_interceptor_registry(registry)
|
||||
.build();
|
||||
api.new_peer_connection(RTCConfiguration::default())
|
||||
.await
|
||||
.map(Arc::new)
|
||||
.ok()
|
||||
}
|
||||
|
||||
fn register_remote_audio_handler(
|
||||
peer_connection: &Arc<RTCPeerConnection>,
|
||||
tx_request: mpsc::UnboundedSender<Value>,
|
||||
) {
|
||||
peer_connection.on_track(Box::new(move |track, _, _| {
|
||||
let tx_request = tx_request.clone();
|
||||
Box::pin(async move {
|
||||
if track.kind() != RTPCodecType::Audio {
|
||||
return;
|
||||
}
|
||||
pump_remote_audio_track(track, tx_request).await;
|
||||
})
|
||||
}));
|
||||
}
|
||||
|
||||
async fn pump_remote_audio_track(
|
||||
track: Arc<TrackRemote>,
|
||||
tx_request: mpsc::UnboundedSender<Value>,
|
||||
) {
|
||||
let Ok(mut decoder) = OpusDecoder::new(24_000, Channels::Mono) else {
|
||||
return;
|
||||
};
|
||||
let mut decoded = vec![0i16; REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL];
|
||||
|
||||
while let Ok((packet, _)) = track.read_rtp().await {
|
||||
if packet.payload.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let Ok(samples_per_channel) = decoder.decode(&packet.payload, &mut decoded, false) else {
|
||||
return;
|
||||
};
|
||||
if samples_per_channel == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut pcm_bytes = Vec::with_capacity(samples_per_channel * 2);
|
||||
for sample in &decoded[..samples_per_channel] {
|
||||
pcm_bytes.extend_from_slice(&sample.to_le_bytes());
|
||||
}
|
||||
let _ = tx_request.send(serde_json::json!({
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": BASE64_STANDARD.encode(pcm_bytes),
|
||||
"sample_rate": REALTIME_AUDIO_SAMPLE_RATE,
|
||||
"channels": REALTIME_AUDIO_CHANNELS,
|
||||
"samples_per_channel": samples_per_channel,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
async fn serve_scripted_requests(
|
||||
connection: WebSocketConnectionConfig,
|
||||
connection_index: usize,
|
||||
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
request_log_updated: Arc<Notify>,
|
||||
mut rx_request: mpsc::UnboundedReceiver<Value>,
|
||||
data_channel: Arc<RTCDataChannel>,
|
||||
start: std::time::Instant,
|
||||
) {
|
||||
let mut scripted_requests = VecDeque::from(connection.requests);
|
||||
while let Some(request_events) = scripted_requests.pop_front() {
|
||||
let Some(body) = rx_request.recv().await else {
|
||||
break;
|
||||
};
|
||||
log_request(
|
||||
connection_index,
|
||||
body,
|
||||
&requests,
|
||||
&request_log_updated,
|
||||
start,
|
||||
);
|
||||
eprintln!(
|
||||
"[ws test server +{}ms] connection={} sending batch_size={} event_types={:?} audio_data={:?}",
|
||||
start.elapsed().as_millis(),
|
||||
connection_index,
|
||||
request_events.len(),
|
||||
request_events
|
||||
.iter()
|
||||
.map(|event| event.get("type").and_then(Value::as_str))
|
||||
.collect::<Vec<_>>(),
|
||||
request_events
|
||||
.iter()
|
||||
.find_map(|event| event.get("delta").and_then(Value::as_str)),
|
||||
);
|
||||
for event in &request_events {
|
||||
let Ok(payload) = serde_json::to_string(event) else {
|
||||
continue;
|
||||
};
|
||||
if data_channel.send_text(payload).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if connection.close_after_requests {
|
||||
let _ = data_channel.close().await;
|
||||
}
|
||||
}
|
||||
|
||||
fn log_request(
|
||||
connection_index: usize,
|
||||
body: Value,
|
||||
requests: &Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
request_log_updated: &Arc<Notify>,
|
||||
start: std::time::Instant,
|
||||
) {
|
||||
let mut log = requests.lock().unwrap();
|
||||
if log.len() <= connection_index {
|
||||
log.resize_with(connection_index + 1, Vec::new);
|
||||
}
|
||||
if let Some(connection_log) = log.get_mut(connection_index) {
|
||||
connection_log.push(WebSocketRequest { body });
|
||||
let request_index = connection_log.len() - 1;
|
||||
let request_body = connection_log[request_index].body_json();
|
||||
eprintln!(
|
||||
"[ws test server +{}ms] connection={} received request={} type={:?} role={:?} text={:?} data={:?}",
|
||||
start.elapsed().as_millis(),
|
||||
connection_index,
|
||||
request_index,
|
||||
request_body.get("type").and_then(Value::as_str),
|
||||
request_body
|
||||
.get("item")
|
||||
.and_then(|item| item.get("role"))
|
||||
.and_then(Value::as_str),
|
||||
request_body
|
||||
.get("item")
|
||||
.and_then(|item| item.get("content"))
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|content| content.first())
|
||||
.and_then(|content| content.get("text"))
|
||||
.and_then(Value::as_str),
|
||||
request_body
|
||||
.get("item")
|
||||
.and_then(|item| item.get("content"))
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|content| content.first())
|
||||
.and_then(|content| content.get("data"))
|
||||
.and_then(Value::as_str),
|
||||
);
|
||||
}
|
||||
drop(log);
|
||||
request_log_updated.notify_waiters();
|
||||
}
|
||||
@@ -256,7 +256,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
|
||||
);
|
||||
assert_eq!(
|
||||
server.handshakes()[1].uri(),
|
||||
"/v1/realtime?intent=quicksilver&model=realtime-test-model"
|
||||
"/v1/realtime/calls?intent=quicksilver&model=realtime-test-model"
|
||||
);
|
||||
let mut request_types = [
|
||||
connection[1].body_json()["type"]
|
||||
|
||||
Reference in New Issue
Block a user