Compare commits

...

24 Commits

Author SHA1 Message Date
Ahmed Ibrahim
b080c91693 Use aec3 for TUI realtime echo cancellation
Replace the native audio-processing dependency with a smaller pure-Rust realtime echo-cancellation path.

Co-authored-by: Codex <noreply@openai.com>
2026-04-05 14:38:57 -07:00
Ahmed Ibrahim
3e206d1e6a codex: fix CI failure on PR #16806
Co-authored-by: Codex <noreply@openai.com>
2026-04-05 14:10:37 -07:00
Ahmed Ibrahim
3e203f8633 codex: fix CI failure on PR #16806
Co-authored-by: Codex <noreply@openai.com>
2026-04-05 14:08:57 -07:00
Ahmed Ibrahim
90f10e9ab9 Route TUI realtime audio through shared echo cancellation
Co-authored-by: Codex <noreply@openai.com>
2026-04-05 13:00:28 -07:00
Ahmed Ibrahim
8207ff6ead codex: fix realtime CI failures
Install the rustls crypto provider in the realtime WebRTC test server so the app-server realtime tests stop panicking, and update the core handshake assertion to match the current realtime calls URL.

Co-authored-by: Codex <noreply@openai.com>
2026-04-05 10:57:48 -07:00
Ahmed Ibrahim
99f72e27ae codex: fix CI failure on PR #16805
Keep the first realtime audio frame from being dropped during WebRTC startup.

Also give the realtime test server a tiny flush window before it closes the data channel so the last scripted events land reliably.

Co-authored-by: Codex <noreply@openai.com>
2026-04-05 10:44:06 -07:00
Ahmed Ibrahim
5d5305c5d4 codex: fix CI failure on PR #16805
Stop the realtime test WebRTC server from hanging after the client closes.

Teach the scripted request loop to exit when the data channel or peer connection closes so compact-remote tests can unwind instead of timing out at shutdown.

Co-authored-by: Codex <noreply@openai.com>
2026-04-05 10:26:58 -07:00
Ahmed Ibrahim
236891f5d3 Use non-silent realtime test audio fixtures
- replace silent PCM realtime test fixtures with a deterministic tone
- avoid codec paths optimizing away audio in WebRTC test flows

Co-authored-by: Codex <noreply@openai.com>
2026-04-05 09:48:50 -07:00
Ahmed Ibrahim
11180cefd9 Fix realtime test audio fixtures
- use valid 24 kHz mono PCM audio in realtime tests
- keep websocket/WebRTC request sequencing aligned with transport behavior

Co-authored-by: Codex <noreply@openai.com>
2026-04-05 01:38:42 -07:00
Ahmed Ibrahim
76397dbdd0 Use opus-rs for realtime transport
- switch realtime transport and test server to opus-rs
- drop native opus bazel and cmake plumbing

Co-authored-by: Codex <noreply@openai.com>
2026-04-05 00:43:41 -07:00
Ahmed Ibrahim
f9739d0178 Add git diff headers to Opus Bazel patch
Co-authored-by: Codex <noreply@openai.com>
2026-04-04 16:17:54 -07:00
Ahmed Ibrahim
bbff67a7b9 Fix Opus Bazel patch hunk headers
Co-authored-by: Codex <noreply@openai.com>
2026-04-04 16:13:38 -07:00
Ahmed Ibrahim
d887e0be7f Fix Opus Bazel patch paths for BCR overlays
Co-authored-by: Codex <noreply@openai.com>
2026-04-04 16:05:41 -07:00
Ahmed Ibrahim
1f9567d121 Disable Opus stack probes under Windows gnullvm Bazel builds
Co-authored-by: Codex <noreply@openai.com>
2026-04-04 15:58:17 -07:00
Ahmed Ibrahim
763dc66fb6 Return realtime test SDP answers before data channel setup
Co-authored-by: Codex <noreply@openai.com>
2026-04-04 15:30:07 -07:00
Ahmed Ibrahim
357140d3c9 Fix realtime parser item_id handling and test callsites
Preserve item_id on v1 audio deltas and annotate opaque None arguments in calls URL tests so argument-comment lint passes.

Co-authored-by: Codex <noreply@openai.com>
2026-04-04 14:18:19 -07:00
Ahmed Ibrahim
1aa17fafcc Fix realtime test server helper imports
Use the workspace opus crate name directly and clone the request sender before moving it into the data-channel callback so the RTP track handler can still enqueue decoded packets.

Co-authored-by: Codex <noreply@openai.com>
2026-04-04 14:11:47 -07:00
Ahmed Ibrahim
5462954edd codex: fix CI failure on PR #16805
Explicitly include first-level response helper modules in the core_test_support Bazel target so realtime_webrtc_server.rs is available to macOS builds and lints.

Co-authored-by: Codex <noreply@openai.com>
2026-04-04 13:31:53 -07:00
Ahmed Ibrahim
da1ad103fa codex: fix CI failure on PR #16805
Include nested response helper modules in the core_test_support Bazel target so the new realtime WebRTC test server source is visible to macOS Bazel and argument-comment-lint jobs.

Co-authored-by: Codex <noreply@openai.com>
2026-04-04 13:26:19 -07:00
Ahmed Ibrahim
4abb01d268 codex: fix CI failure on PR #16805
Ignore the workspace `opus` dependency in cargo-shear for core_test_support because Rust imports that package as `audiopus`.

Co-authored-by: Codex <noreply@openai.com>
2026-04-04 13:20:32 -07:00
Ahmed Ibrahim
d6d8d6304d Add WebRTC support to realtime test server
Teach the shared test helper to accept realtime /calls POSTs, answer SDP offers, and relay scripted events over a data channel while logging incoming RTP audio packets as synthetic append requests. Update the one stale handshake-path assertion to the /realtime/calls URL.

Co-authored-by: Codex <noreply@openai.com>
2026-04-04 13:14:31 -07:00
Ahmed Ibrahim
63c1223141 codex: fix CI failure on PR #16805
Link audiopus_sys against the Bazel opus module so remote Bazel builds do not depend on host cmake.

Co-authored-by: Codex <noreply@openai.com>
2026-04-04 12:56:37 -07:00
Ahmed Ibrahim
a40422b85f codex: fix CI failure on PR #16805
Install cmake on Linux and macOS Bazel runners so audiopus_sys build scripts can build bundled Opus.

Co-authored-by: Codex <noreply@openai.com>
2026-04-04 12:51:33 -07:00
Ahmed Ibrahim
175d831ff4 Replace realtime websocket transport with WebRTC
Move the realtime model transport implementation to WebRTC while keeping the core session/event interface intact for the TUI layer.

Co-authored-by: Codex <noreply@openai.com>
2026-04-04 12:43:42 -07:00
25 changed files with 2594 additions and 2197 deletions

MODULE.bazel.lock (generated, 62 changed lines)

File diff suppressed because one or more lines are too long

codex-rs/Cargo.lock (generated, 863 changed lines)

File diff suppressed because it is too large


@@ -197,6 +197,7 @@ ansi-to-tui = "7.0.0"
anyhow = "1"
arboard = { version = "3", features = ["wayland-data-control"] }
arc-swap = "1.9.0"
aec3 = "0.1.7"
assert_cmd = "2"
assert_matches = "1.5.0"
async-channel = "2.3.1"
@@ -239,6 +240,7 @@ image = { version = "^0.25.9", default-features = false }
include_dir = "0.7.4"
indexmap = "2.12.0"
insta = "1.46.3"
interceptor = "0.17.1"
inventory = "0.3.19"
itertools = "0.14.0"
jsonwebtoken = "9.3.1"
@@ -255,6 +257,7 @@ notify = "8.2.0"
nucleo = { git = "https://github.com/helix-editor/nucleo.git", rev = "4253de9faabb4e5c6d81d946a5e35a90f87347ee" }
once_cell = "1.20.2"
openssl-sys = "*"
opus-rs = "0.1.11"
opentelemetry = "0.31.0"
opentelemetry-appender-tracing = "0.31.0"
opentelemetry-otlp = "0.31.0"
@@ -349,6 +352,7 @@ v8 = "=146.4.0"
vt100 = "0.16.2"
walkdir = "2.5.0"
webbrowser = "1.0"
webrtc = "0.17.1"
which = "8"
wildmatch = "2.6.1"
zip = "2.4.2"


@@ -3,6 +3,8 @@ use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_mock_responses_server_sequence_unchecked;
use app_test_support::to_response;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::LoginAccountResponse;
@@ -40,6 +42,16 @@ use tokio::time::timeout;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex.";
fn realtime_pcm_test_tone_20ms_base64() -> String {
let pcm_bytes: Vec<u8> = (0..480)
.flat_map(|index| {
let sample = if index % 2 == 0 { 1024_i16 } else { -1024_i16 };
sample.to_le_bytes()
})
.collect();
BASE64_STANDARD.encode(pcm_bytes)
}
#[tokio::test]
async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -155,7 +167,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
.send_thread_realtime_append_audio_request(ThreadRealtimeAppendAudioParams {
thread_id: started.thread_id.clone(),
audio: ThreadRealtimeAudioChunk {
data: "BQYH".to_string(),
data: realtime_pcm_test_tone_20ms_base64(),
sample_rate: 24_000,
num_channels: 1,
samples_per_channel: Some(480),
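The new helper above replaces the near-silent `"BQYH"` payload with a deterministic tone. As a standalone sketch of why the fixture is exactly 960 bytes (20 ms of mono PCM at 24 kHz is 480 `i16` samples), with the base64 encoding step from the real helper omitted because it uses the external `base64` crate:

```rust
// Sketch of the deterministic test tone from the diff above, minus the
// base64 encoding step.
fn realtime_pcm_test_tone_20ms() -> Vec<u8> {
    // 24_000 samples/s * 0.020 s = 480 samples; alternating +/-1024 keeps
    // the audio non-silent so codec paths cannot optimize it away.
    (0..480)
        .flat_map(|index| {
            let sample: i16 = if index % 2 == 0 { 1024 } else { -1024 };
            sample.to_le_bytes()
        })
        .collect()
}
```

Each `i16` contributes two little-endian bytes, so the buffer is 480 × 2 = 960 bytes, matching the `samples_per_channel: Some(480)` sent in the append request.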


@@ -14,17 +14,21 @@ codex-protocol = { workspace = true }
codex-utils-rustls-provider = { workspace = true }
futures = { workspace = true }
http = { workspace = true }
interceptor = { workspace = true }
opus-rs = { workspace = true }
reqwest = { workspace = true, features = ["multipart"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["macros", "net", "rt", "sync", "time"] }
tokio-tungstenite = { workspace = true }
tungstenite = { workspace = true }
tracing = { workspace = true }
eventsource-stream = { workspace = true }
regex-lite = { workspace = true }
tokio-util = { workspace = true, features = ["codec"] }
tungstenite = { workspace = true }
url = { workspace = true }
webrtc = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
@@ -32,7 +36,6 @@ assert_matches = { workspace = true }
pretty_assertions = { workspace = true }
tokio-test = { workspace = true }
wiremock = { workspace = true }
reqwest = { workspace = true }
[lints]
workspace = true

File diff suppressed because it is too large


@@ -1,11 +1,9 @@
use crate::endpoint::realtime_websocket::methods_v1::conversation_handoff_append_message as v1_conversation_handoff_append_message;
use crate::endpoint::realtime_websocket::methods_v1::conversation_item_create_message as v1_conversation_item_create_message;
use crate::endpoint::realtime_websocket::methods_v1::session_update_session as v1_session_update_session;
use crate::endpoint::realtime_websocket::methods_v1::websocket_intent as v1_websocket_intent;
use crate::endpoint::realtime_websocket::methods_v2::conversation_handoff_append_message as v2_conversation_handoff_append_message;
use crate::endpoint::realtime_websocket::methods_v2::conversation_item_create_message as v2_conversation_item_create_message;
use crate::endpoint::realtime_websocket::methods_v2::session_update_session as v2_session_update_session;
use crate::endpoint::realtime_websocket::methods_v2::websocket_intent as v2_websocket_intent;
use crate::endpoint::realtime_websocket::protocol::RealtimeEventParser;
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionMode;
@@ -59,10 +57,3 @@ pub(super) fn session_update_session(
RealtimeEventParser::RealtimeV2 => v2_session_update_session(instructions, session_mode),
}
}
pub(super) fn websocket_intent(event_parser: RealtimeEventParser) -> Option<&'static str> {
match event_parser {
RealtimeEventParser::V1 => v1_websocket_intent(),
RealtimeEventParser::RealtimeV2 => v2_websocket_intent(),
}
}


@@ -61,7 +61,3 @@ pub(super) fn session_update_session(instructions: String) -> SessionUpdateSessi
tool_choice: None,
}
}
pub(super) fn websocket_intent() -> Option<&'static str> {
Some("quicksilver")
}


@@ -126,7 +126,3 @@ pub(super) fn session_update_session(
},
}
}
pub(super) fn websocket_intent() -> Option<&'static str> {
None
}
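The three hunks above delete `websocket_intent` from the dispatch layer and from both versioned modules, presumably because the WebRTC transport negotiates via `/realtime/calls` rather than a websocket intent header. A minimal stdlib-only sketch of the dispatch pattern being removed (the enum variants and return values come from the diff; everything else is simplified):

```rust
// Simplified stand-in for codex-api's parser enum; in the real code the
// shared helper fans out to methods_v1 / methods_v2 implementations.
#[derive(Clone, Copy, Debug, PartialEq)]
enum RealtimeEventParser {
    V1,
    RealtimeV2,
}

// The removed helper: v1 sessions advertised the "quicksilver" intent,
// v2 sessions advertised none.
fn websocket_intent(event_parser: RealtimeEventParser) -> Option<&'static str> {
    match event_parser {
        RealtimeEventParser::V1 => Some("quicksilver"),
        RealtimeEventParser::RealtimeV2 => None,
    }
}
```

Once both arms of a version-dispatched helper like this are dead, the dispatcher itself must be deleted in the same change, which is why all three files are touched together.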


@@ -9,10 +9,10 @@ mod protocol_v2;
pub use codex_protocol::protocol::RealtimeAudioFrame;
pub use codex_protocol::protocol::RealtimeEvent;
pub use methods::RealtimeWebsocketClient;
pub use methods::RealtimeWebsocketConnection;
pub use methods::RealtimeWebsocketEvents;
pub use methods::RealtimeWebsocketWriter;
pub use methods::RealtimeWebRtcClient;
pub use methods::RealtimeWebRtcConnection;
pub use methods::RealtimeWebRtcEvents;
pub use methods::RealtimeWebRtcWriter;
pub use protocol::RealtimeEventParser;
pub use protocol::RealtimeSessionConfig;
pub use protocol::RealtimeSessionMode;


@@ -32,8 +32,6 @@ pub struct RealtimeSessionConfig {
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
pub(super) enum RealtimeOutboundMessage {
#[serde(rename = "input_audio_buffer.append")]
InputAudioBufferAppend { audio: String },
#[serde(rename = "conversation.handoff.append")]
ConversationHandoffAppend {
handoff_id: String,


@@ -35,7 +35,10 @@ pub(super) fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
.get("samples_per_channel")
.and_then(Value::as_u64)
.and_then(|value| u32::try_from(value).ok()),
item_id: None,
item_id: parsed
.get("item_id")
.and_then(Value::as_str)
.map(str::to_string),
}))
}
"conversation.input_transcript.delta" => {

View File

@@ -34,8 +34,8 @@ pub use crate::endpoint::models::ModelsClient;
pub use crate::endpoint::realtime_websocket::RealtimeEventParser;
pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
pub use crate::endpoint::realtime_websocket::RealtimeSessionMode;
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient;
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketConnection;
pub use crate::endpoint::realtime_websocket::RealtimeWebRtcClient;
pub use crate::endpoint::realtime_websocket::RealtimeWebRtcConnection;
pub use crate::endpoint::responses::ResponsesClient;
pub use crate::endpoint::responses::ResponsesOptions;
pub use crate::endpoint::responses_websocket::ResponsesWebsocketClient;


@@ -1,460 +0,0 @@
use std::collections::HashMap;
use std::future::Future;
use std::time::Duration;
use codex_api::RealtimeAudioFrame;
use codex_api::RealtimeEvent;
use codex_api::RealtimeEventParser;
use codex_api::RealtimeSessionConfig;
use codex_api::RealtimeSessionMode;
use codex_api::RealtimeWebsocketClient;
use codex_api::provider::Provider;
use codex_api::provider::RetryConfig;
use codex_protocol::protocol::RealtimeHandoffRequested;
use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use serde_json::Value;
use serde_json::json;
use tokio::net::TcpListener;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message;
type RealtimeWsStream = tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>;
async fn spawn_realtime_ws_server<Handler, Fut>(
handler: Handler,
) -> (String, tokio::task::JoinHandle<()>)
where
Handler: FnOnce(RealtimeWsStream) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(listener) => listener,
Err(err) => panic!("failed to bind test websocket listener: {err}"),
};
let addr = match listener.local_addr() {
Ok(addr) => addr.to_string(),
Err(err) => panic!("failed to read local websocket listener address: {err}"),
};
let server = tokio::spawn(async move {
let (stream, _) = match listener.accept().await {
Ok(stream) => stream,
Err(err) => panic!("failed to accept test websocket connection: {err}"),
};
let ws = match accept_async(stream).await {
Ok(ws) => ws,
Err(err) => panic!("failed to complete websocket handshake: {err}"),
};
handler(ws).await;
});
(addr, server)
}
fn test_provider(base_url: String) -> Provider {
Provider {
name: "test".to_string(),
base_url,
query_params: Some(HashMap::new()),
headers: HeaderMap::new(),
retry: RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: false,
retry_transport: false,
},
stream_idle_timeout: Duration::from_secs(5),
}
}
#[tokio::test]
async fn realtime_ws_e2e_session_create_and_event_flow() {
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["type"],
Value::String("quicksilver".to_string())
);
assert_eq!(
first_json["session"]["instructions"],
Value::String("backend prompt".to_string())
);
assert_eq!(
first_json["session"]["audio"]["input"]["format"]["type"],
Value::String("audio/pcm".to_string())
);
assert_eq!(
first_json["session"]["audio"]["input"]["format"]["rate"],
Value::from(24_000)
);
ws.send(Message::Text(
json!({
"type": "session.updated",
"session": {"id": "sess_mock", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.updated");
let second = ws
.next()
.await
.expect("second msg")
.expect("second msg ok")
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "input_audio_buffer.append");
ws.send(Message::Text(
json!({
"type": "conversation.output_audio.delta",
"delta": "AQID",
"sample_rate": 48000,
"channels": 1
})
.to_string()
.into(),
))
.await
.expect("send audio out");
})
.await;
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let created = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
created,
RealtimeEvent::SessionUpdated {
session_id: "sess_mock".to_string(),
instructions: Some("backend prompt".to_string()),
}
);
connection
.send_audio_frame(RealtimeAudioFrame {
data: "AQID".to_string(),
sample_rate: 48000,
num_channels: 1,
samples_per_channel: Some(960),
item_id: None,
})
.await
.expect("send audio");
let audio_event = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
audio_event,
RealtimeEvent::AudioOut(RealtimeAudioFrame {
data: "AQID".to_string(),
sample_rate: 48000,
num_channels: 1,
samples_per_channel: None,
item_id: None,
})
);
connection.close().await.expect("close");
server.await.expect("server task");
}
#[tokio::test]
async fn realtime_ws_e2e_send_while_next_event_waits() {
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
let second = ws
.next()
.await
.expect("second msg")
.expect("second msg ok")
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "input_audio_buffer.append");
ws.send(Message::Text(
json!({
"type": "session.updated",
"session": {"id": "sess_after_send", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.updated");
})
.await;
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let (send_result, next_result) = tokio::join!(
async {
tokio::time::timeout(
Duration::from_millis(200),
connection.send_audio_frame(RealtimeAudioFrame {
data: "AQID".to_string(),
sample_rate: 48000,
num_channels: 1,
samples_per_channel: Some(960),
item_id: None,
}),
)
.await
},
connection.next_event()
);
send_result
.expect("send should not block on next_event")
.expect("send audio");
let next_event = next_result.expect("next event").expect("event");
assert_eq!(
next_event,
RealtimeEvent::SessionUpdated {
session_id: "sess_after_send".to_string(),
instructions: Some("backend prompt".to_string()),
}
);
connection.close().await.expect("close");
server.await.expect("server task");
}
#[tokio::test]
async fn realtime_ws_e2e_disconnected_emitted_once() {
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
ws.send(Message::Close(None)).await.expect("send close");
})
.await;
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let first = connection.next_event().await.expect("next event");
assert_eq!(first, None);
let second = connection.next_event().await.expect("next event");
assert_eq!(second, None);
server.await.expect("server task");
}
#[tokio::test]
async fn realtime_ws_e2e_ignores_unknown_text_events() {
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
ws.send(Message::Text(
json!({
"type": "response.created",
"response": {"id": "resp_unknown"}
})
.to_string()
.into(),
))
.await
.expect("send unknown event");
ws.send(Message::Text(
json!({
"type": "session.updated",
"session": {"id": "sess_after_unknown", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.updated");
})
.await;
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::V1,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let event = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
event,
RealtimeEvent::SessionUpdated {
session_id: "sess_after_unknown".to_string(),
instructions: Some("backend prompt".to_string()),
}
);
connection.close().await.expect("close");
server.await.expect("server task");
}
#[tokio::test]
async fn realtime_ws_e2e_realtime_v2_parser_emits_handoff_requested() {
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
let first = ws
.next()
.await
.expect("first msg")
.expect("first msg ok")
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.update");
ws.send(Message::Text(
json!({
"type": "conversation.item.done",
"item": {
"id": "item_123",
"type": "function_call",
"name": "codex",
"call_id": "call_123",
"arguments": "{\"prompt\":\"delegate now\"}"
}
})
.to_string()
.into(),
))
.await
.expect("send function call");
})
.await;
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
let connection = client
.connect(
RealtimeSessionConfig {
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
event_parser: RealtimeEventParser::RealtimeV2,
session_mode: RealtimeSessionMode::Conversational,
},
HeaderMap::new(),
HeaderMap::new(),
)
.await
.expect("connect");
let event = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
event,
RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id: "call_123".to_string(),
item_id: "item_123".to_string(),
input_transcript: "delegate now".to_string(),
active_transcript: Vec::new(),
})
);
connection.close().await.expect("close");
server.await.expect("server task");
}


@@ -13,10 +13,10 @@ use codex_api::RealtimeEvent;
use codex_api::RealtimeEventParser;
use codex_api::RealtimeSessionConfig;
use codex_api::RealtimeSessionMode;
use codex_api::RealtimeWebsocketClient;
use codex_api::RealtimeWebRtcClient;
use codex_api::api_bridge::map_api_error;
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketEvents;
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketWriter;
use codex_api::endpoint::realtime_websocket::RealtimeWebRtcEvents;
use codex_api::endpoint::realtime_websocket::RealtimeWebRtcWriter;
use codex_app_server_protocol::AuthMode;
use codex_login::CodexAuth;
use codex_login::default_client::default_headers;
@@ -107,8 +107,8 @@ struct OutputAudioState {
}
struct RealtimeInputTask {
writer: RealtimeWebsocketWriter,
events: RealtimeWebsocketEvents,
writer: RealtimeWebRtcWriter,
events: RealtimeWebRtcEvents,
user_text_rx: Receiver<String>,
handoff_output_rx: Receiver<HandoffOutput>,
audio_rx: Receiver<RealtimeAudioFrame>,
@@ -132,7 +132,7 @@ impl RealtimeHandoffState {
struct ConversationState {
audio_tx: Sender<RealtimeAudioFrame>,
user_text_tx: Sender<String>,
writer: RealtimeWebsocketWriter,
writer: RealtimeWebRtcWriter,
handoff: RealtimeHandoffState,
input_task: JoinHandle<()>,
fanout_task: Option<JoinHandle<()>>,
@@ -172,7 +172,7 @@ impl RealtimeConversationManager {
RealtimeEventParser::RealtimeV2 => RealtimeSessionKind::V2,
};
let client = RealtimeWebsocketClient::new(api_provider);
let client = RealtimeWebRtcClient::new(api_provider);
let connection = client
.connect(
session_config,
@@ -392,6 +392,7 @@ async fn stop_conversation_state(
fanout_task_stop: RealtimeFanoutTaskStop,
) {
state.realtime_active.store(false, Ordering::Relaxed);
let _ = state.writer.close().await;
state.input_task.abort();
let _ = state.input_task.await;


@@ -3,7 +3,11 @@ load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "common",
crate_name = "core_test_support",
crate_srcs = glob(["*.rs"]),
crate_srcs = glob([
"*.rs",
"responses/*.rs",
"responses/**/*.rs",
]),
lib_data_extra = [
"//codex-rs/core:model_availability_nux_fixtures",
],


@@ -24,11 +24,13 @@ codex-models-manager = { workspace = true }
codex-protocol = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-cargo-bin = { workspace = true }
codex-utils-rustls-provider = { workspace = true }
ctor = { workspace = true }
futures = { workspace = true }
notify = { workspace = true }
opentelemetry = { workspace = true }
opentelemetry_sdk = { workspace = true }
opus-rs = { workspace = true }
regex-lite = { workspace = true }
serde_json = { workspace = true }
tempfile = { workspace = true }
@@ -38,6 +40,7 @@ tracing = { workspace = true }
tracing-opentelemetry = { workspace = true }
tracing-subscriber = { workspace = true }
walkdir = { workspace = true }
webrtc = { workspace = true }
wiremock = { workspace = true }
shlex = { workspace = true }
zstd = { workspace = true }


@@ -37,6 +37,8 @@ use wiremock::matchers::path_regex;
use crate::test_codex::ApplyPatchModelOutput;
mod realtime_webrtc_server;
#[derive(Debug, Clone)]
pub struct ResponseMock {
requests: Arc<Mutex<Vec<ResponsesRequest>>>,
@@ -1238,6 +1240,28 @@ pub async fn start_websocket_server_with_headers(
tokio::time::sleep(delay).await;
}
if realtime_webrtc_server::accept_is_http_post(&stream).await {
let connection_index = {
let mut log = requests.lock().unwrap();
log.push(Vec::new());
log.len() - 1
};
realtime_webrtc_server::serve_connection(
stream,
connection,
connection_index,
Arc::clone(&requests),
Arc::clone(&handshakes),
Arc::clone(&request_log),
)
.await;
if connections.lock().unwrap().is_empty() {
return;
}
continue;
}
let response_headers = connection.response_headers.clone();
let handshake_log = Arc::clone(&handshakes);
let callback = move |req: &Request, mut response: Response| {


@@ -0,0 +1,431 @@
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use opus_rs::OpusDecoder;
use serde_json::Value;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tracing::debug;
use tracing::warn;
use webrtc::api::APIBuilder;
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MediaEngine;
use webrtc::data_channel::RTCDataChannel;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::RTCPeerConnection;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::rtp_transceiver::rtp_codec::RTPCodecType;
use webrtc::track::track_remote::TrackRemote;
use super::WebSocketConnectionConfig;
use super::WebSocketHandshake;
use super::WebSocketRequest;
const HTTP_HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";
const REALTIME_AUDIO_CHANNELS: u8 = 1;
const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000;
const REALTIME_DATA_CHANNEL_TIMEOUT: Duration = Duration::from_secs(10);
const REALTIME_DATA_CHANNEL_FLUSH_DELAY: Duration = Duration::from_millis(25);
const REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL: usize = 5760;
pub(super) async fn accept_is_http_post(stream: &TcpStream) -> bool {
let mut method = [0u8; 4];
matches!(stream.peek(&mut method).await, Ok(4)) && method == *b"POST"
}
pub(super) async fn serve_connection(
mut stream: TcpStream,
connection: WebSocketConnectionConfig,
connection_index: usize,
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
handshakes: Arc<Mutex<Vec<WebSocketHandshake>>>,
request_log_updated: Arc<Notify>,
) {
let Some(request) = read_http_request(&mut stream).await else {
return;
};
handshakes.lock().unwrap().push(WebSocketHandshake {
uri: request.uri,
headers: request.headers,
});
let Some(offer_sdp) = parse_multipart_field(&request.body, &request.boundary, "sdp") else {
let _ = stream
.write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n")
.await;
return;
};
let Some(session) = start_session(
offer_sdp,
connection,
connection_index,
requests,
request_log_updated,
)
.await
else {
let _ = stream
.write_all(b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n")
.await;
return;
};
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/sdp\r\nContent-Length: {}\r\n\r\n{}",
session.answer_sdp.len(),
session.answer_sdp
);
if stream.write_all(response.as_bytes()).await.is_err() {
return;
}
let _ = session.done_rx.await;
let _ = session.peer_connection.close().await;
}
struct HttpRealtimeRequest {
uri: String,
headers: Vec<(String, String)>,
boundary: String,
body: Vec<u8>,
}
struct RealtimeSession {
answer_sdp: String,
peer_connection: Arc<RTCPeerConnection>,
done_rx: oneshot::Receiver<()>,
}
async fn read_http_request(stream: &mut TcpStream) -> Option<HttpRealtimeRequest> {
let mut received = Vec::new();
let headers_end = loop {
if let Some(headers_end) = received
.windows(HTTP_HEADER_TERMINATOR.len())
.position(|window| window == HTTP_HEADER_TERMINATOR)
{
break headers_end + HTTP_HEADER_TERMINATOR.len();
}
let mut chunk = [0u8; 1024];
let read = stream.read(&mut chunk).await.ok()?;
if read == 0 {
return None;
}
received.extend_from_slice(&chunk[..read]);
};
let header_text = std::str::from_utf8(&received[..headers_end]).ok()?;
let mut lines = header_text.split("\r\n").filter(|line| !line.is_empty());
let request_line = lines.next()?;
let mut request_line_parts = request_line.split_whitespace();
if request_line_parts.next()? != "POST" {
return None;
}
let uri = request_line_parts.next()?.to_string();
let mut headers = Vec::new();
let mut content_length = None;
let mut boundary = None;
for line in lines {
let Some((name, value)) = line.split_once(':') else {
continue;
};
let name = name.trim().to_string();
let value = value.trim().to_string();
if name.eq_ignore_ascii_case("content-length") {
content_length = value.parse::<usize>().ok();
}
if name.eq_ignore_ascii_case("content-type") {
boundary = value
.split(';')
.map(str::trim)
.find_map(|part| part.strip_prefix("boundary="))
.map(|boundary| boundary.trim_matches('"').to_string());
}
headers.push((name, value));
}
let content_length = content_length?;
while received.len() - headers_end < content_length {
let mut chunk = [0u8; 1024];
let read = stream.read(&mut chunk).await.ok()?;
if read == 0 {
return None;
}
received.extend_from_slice(&chunk[..read]);
}
let body_end = headers_end + content_length;
let body = received[headers_end..body_end].to_vec();
Some(HttpRealtimeRequest {
uri,
headers,
boundary: boundary?,
body,
})
}
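The accumulation loop above keeps reading until the `\r\n\r\n` header terminator appears, using a sliding-window scan over the raw bytes. A standalone sketch of just that scan (the function name is illustrative, not from the source):

```rust
/// Find the index just past the HTTP header terminator "\r\n\r\n",
/// using the same windows().position() scan as the read loop above.
/// Returns None while the header block is still incomplete.
fn header_block_end(received: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    received
        .windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
        .map(|pos| pos + TERMINATOR.len())
}
```

The body then starts at the returned offset, which is how the parser splits `received` into header text and `Content-Length`-framed body bytes.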
fn parse_multipart_field(body: &[u8], boundary: &str, field_name: &str) -> Option<String> {
let body = std::str::from_utf8(body).ok()?;
let delimiter = format!("--{boundary}");
body.split(&delimiter).find_map(|part| {
let (headers, value) = part.split_once("\r\n\r\n")?;
if !headers.contains(&format!("name=\"{field_name}\"")) {
return None;
}
Some(value.trim_end_matches("\r\n").to_string())
})
}
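`parse_multipart_field` splits on the `--{boundary}` delimiter, picks the part whose headers carry the matching `name="…"`, and returns the text after the blank line separating part headers from the part body. A minimal standalone copy of that approach, exercised on a tiny hypothetical body:

```rust
/// Extract a named field from a text multipart/form-data body.
/// Mirrors the parser above: split on "--{boundary}", match the part
/// whose headers mention name="field", return its body text.
fn multipart_field(body: &str, boundary: &str, field: &str) -> Option<String> {
    let delimiter = format!("--{boundary}");
    body.split(&delimiter).find_map(|part| {
        // Part headers end at the first blank line.
        let (headers, value) = part.split_once("\r\n\r\n")?;
        if !headers.contains(&format!("name=\"{field}\"")) {
            return None;
        }
        Some(value.trim_end_matches("\r\n").to_string())
    })
}
```

This is sufficient for the test server because the SDP offer arrives as a single text field; it does not handle binary parts or boundaries that occur inside a part body.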
async fn start_session(
offer_sdp: String,
connection: WebSocketConnectionConfig,
connection_index: usize,
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
request_log_updated: Arc<Notify>,
) -> Option<RealtimeSession> {
let peer_connection = create_peer_connection().await?;
let (tx_request, rx_request) = mpsc::unbounded_channel::<Value>();
let (tx_data_channel, rx_data_channel) = oneshot::channel::<Arc<RTCDataChannel>>();
let tx_data_channel = Mutex::new(Some(tx_data_channel));
let tx_data_channel_request = tx_request.clone();
let connection_closed = Arc::new(Notify::new());
let on_data_channel_closed = Arc::clone(&connection_closed);
let on_peer_connection_closed = Arc::clone(&connection_closed);
peer_connection.on_data_channel(Box::new(move |data_channel| {
let tx_request = tx_data_channel_request.clone();
let on_data_channel_closed = Arc::clone(&on_data_channel_closed);
if let Ok(mut tx_data_channel) = tx_data_channel.lock()
&& let Some(tx_data_channel) = tx_data_channel.take()
{
let _ = tx_data_channel.send(Arc::clone(&data_channel));
}
data_channel.on_close(Box::new(move || {
let on_data_channel_closed = Arc::clone(&on_data_channel_closed);
Box::pin(async move {
on_data_channel_closed.notify_waiters();
})
}));
data_channel.on_message(Box::new(move |message: DataChannelMessage| {
let tx_request = tx_request.clone();
Box::pin(async move {
if !message.is_string {
return;
}
let Ok(text) = String::from_utf8(message.data.to_vec()) else {
return;
};
let Ok(body) = serde_json::from_str::<Value>(&text) else {
return;
};
let _ = tx_request.send(body);
})
}));
Box::pin(async {})
}));
peer_connection.on_peer_connection_state_change(Box::new(move |state| {
let on_peer_connection_closed = Arc::clone(&on_peer_connection_closed);
Box::pin(async move {
if matches!(
state,
RTCPeerConnectionState::Closed
| RTCPeerConnectionState::Disconnected
| RTCPeerConnectionState::Failed
) {
on_peer_connection_closed.notify_waiters();
}
})
}));
register_remote_audio_handler(&peer_connection, tx_request.clone());
let mut gather_complete = peer_connection.gathering_complete_promise().await;
let offer = RTCSessionDescription::offer(offer_sdp).ok()?;
peer_connection.set_remote_description(offer).await.ok()?;
let answer = peer_connection.create_answer(None).await.ok()?;
peer_connection.set_local_description(answer).await.ok()?;
let _ = gather_complete.recv().await;
let answer_sdp = peer_connection.local_description().await?.sdp;
let (done_tx, done_rx) = oneshot::channel();
tokio::spawn(async move {
serve_scripted_requests(
connection,
connection_index,
requests,
request_log_updated,
rx_request,
rx_data_channel,
connection_closed,
)
.await;
let _ = done_tx.send(());
});
Some(RealtimeSession {
answer_sdp,
peer_connection,
done_rx,
})
}
async fn create_peer_connection() -> Option<Arc<RTCPeerConnection>> {
ensure_rustls_crypto_provider();
let mut media_engine = MediaEngine::default();
media_engine.register_default_codecs().ok()?;
let registry = register_default_interceptors(Registry::new(), &mut media_engine).ok()?;
let api = APIBuilder::new()
.with_media_engine(media_engine)
.with_interceptor_registry(registry)
.build();
api.new_peer_connection(RTCConfiguration::default())
.await
.map(Arc::new)
.ok()
}
fn register_remote_audio_handler(
peer_connection: &Arc<RTCPeerConnection>,
tx_request: mpsc::UnboundedSender<Value>,
) {
peer_connection.on_track(Box::new(move |track, _, _| {
let tx_request = tx_request.clone();
Box::pin(async move {
if track.kind() != RTPCodecType::Audio {
return;
}
pump_remote_audio_track(track, tx_request).await;
})
}));
}
async fn pump_remote_audio_track(
track: Arc<TrackRemote>,
tx_request: mpsc::UnboundedSender<Value>,
) {
let mut decoder = match OpusDecoder::new(24_000, usize::from(REALTIME_AUDIO_CHANNELS)) {
Ok(decoder) => decoder,
Err(err) => {
warn!(%err, "failed to initialize realtime Opus decoder in test server");
return;
}
};
debug!("initialized realtime Opus decoder in test server");
let mut decoded = vec![0.0f32; REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL];
while let Ok((packet, _)) = track.read_rtp().await {
if packet.payload.is_empty() {
continue;
}
let samples_per_channel = match decoder.decode(
&packet.payload,
REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL,
&mut decoded,
) {
Ok(samples_per_channel) => samples_per_channel,
Err(err) => {
warn!(
%err,
payload_len = packet.payload.len(),
"failed to decode realtime Opus packet in test server"
);
return;
}
};
if samples_per_channel == 0 {
continue;
}
let mut pcm_bytes = Vec::with_capacity(samples_per_channel * 2);
for sample in &decoded[..samples_per_channel] {
pcm_bytes.extend_from_slice(&f32_to_i16(*sample).to_le_bytes());
}
let _ = tx_request.send(serde_json::json!({
"type": "input_audio_buffer.append",
"audio": BASE64_STANDARD.encode(pcm_bytes),
"sample_rate": REALTIME_AUDIO_SAMPLE_RATE,
"channels": REALTIME_AUDIO_CHANNELS,
"samples_per_channel": samples_per_channel,
}));
}
}
fn f32_to_i16(sample: f32) -> i16 {
(sample.clamp(-1.0, 1.0) * i16::MAX as f32) as i16
}
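The clamp in `f32_to_i16` matters because decoded float samples can overshoot `[-1.0, 1.0]` slightly; out-of-range input saturates rather than wrapping. A standalone copy showing the saturation behavior:

```rust
/// Convert one f32 PCM sample to i16, saturating out-of-range input.
/// Note the scale is i16::MAX, so -1.0 maps to -32767, not -32768.
fn f32_to_i16(sample: f32) -> i16 {
    (sample.clamp(-1.0, 1.0) * i16::MAX as f32) as i16
}
```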
async fn serve_scripted_requests(
connection: WebSocketConnectionConfig,
connection_index: usize,
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
request_log_updated: Arc<Notify>,
mut rx_request: mpsc::UnboundedReceiver<Value>,
rx_data_channel: oneshot::Receiver<Arc<RTCDataChannel>>,
connection_closed: Arc<Notify>,
) {
let Ok(Ok(data_channel)) = timeout(REALTIME_DATA_CHANNEL_TIMEOUT, rx_data_channel).await else {
return;
};
let mut scripted_requests = VecDeque::from(connection.requests);
while let Some(request_events) = scripted_requests.pop_front() {
// WebRTC compact-remote tests often close the session before consuming every scripted
// request slot. Treat transport closure as end-of-script instead of waiting forever for
// another request that can no longer arrive.
let body = tokio::select! {
body = rx_request.recv() => body,
_ = connection_closed.notified() => None,
};
let Some(body) = body else {
break;
};
log_request(connection_index, body, &requests, &request_log_updated);
for event in &request_events {
let Ok(payload) = serde_json::to_string(event) else {
continue;
};
if data_channel.send_text(payload).await.is_err() {
return;
}
}
}
if connection.close_after_requests {
tokio::time::sleep(REALTIME_DATA_CHANNEL_FLUSH_DELAY).await;
let _ = data_channel.close().await;
}
}
fn log_request(
connection_index: usize,
body: Value,
requests: &Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
request_log_updated: &Arc<Notify>,
) {
let mut log = requests.lock().unwrap();
if log.len() <= connection_index {
log.resize_with(connection_index + 1, Vec::new);
}
if let Some(connection_log) = log.get_mut(connection_index) {
connection_log.push(WebSocketRequest { body });
}
drop(log);
request_log_updated.notify_waiters();
}
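`log_request` grows the per-connection log with `resize_with` so an arbitrary `connection_index` is always addressable before appending. A minimal sketch of that pattern with `String` entries standing in for `WebSocketRequest` (hypothetical helper name):

```rust
/// Append an entry to a per-connection log, growing the outer Vec
/// with empty inner Vecs so `index` is always in bounds first —
/// the same resize_with pattern as log_request above.
fn log_at(log: &mut Vec<Vec<String>>, index: usize, entry: String) {
    if log.len() <= index {
        log.resize_with(index + 1, Vec::new);
    }
    log[index].push(entry);
}
```

Earlier connection slots that never received a request stay present but empty, so test assertions can index connections positionally.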

@@ -1,5 +1,7 @@
use anyhow::Context;
use anyhow::Result;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use chrono::Utc;
use codex_login::CodexAuth;
use codex_login::OPENAI_API_KEY_ENV_VAR;
@@ -41,6 +43,17 @@ const MEMORY_PROMPT_PHRASE: &str =
"You have access to a memory folder with guidance from prior runs.";
const REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR: &str =
"CODEX_REALTIME_CONVERSATION_TEST_SUBPROCESS";
fn realtime_pcm_test_tone_20ms_base64() -> String {
let pcm_bytes: Vec<u8> = (0..480)
.flat_map(|index| {
let sample = if index % 2 == 0 { 1024_i16 } else { -1024_i16 };
sample.to_le_bytes()
})
.collect();
BASE64_STANDARD.encode(pcm_bytes)
}
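The fixture replaces the old silent `"AQID"` payload with 480 alternating ±1024 samples: 20 ms of non-silent audio at the 24 kHz mono realtime format, so codec paths cannot optimize it away. A standalone sketch of just the PCM construction (the base64 step is omitted to keep it dependency-free):

```rust
/// Build the 20 ms square-wave fixture as raw little-endian PCM16:
/// 480 mono samples at 24 kHz, alternating +1024 / -1024.
fn tone_20ms_pcm() -> Vec<u8> {
    (0..480)
        .flat_map(|index| {
            let sample: i16 = if index % 2 == 0 { 1024 } else { -1024 };
            sample.to_le_bytes()
        })
        .collect()
}
```

At 2 bytes per sample this yields 960 bytes, which base64-encodes to exactly 1280 characters with no padding.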
fn websocket_request_text(
request: &core_test_support::responses::WebSocketRequest,
) -> Option<String> {
@@ -207,7 +220,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
test.codex
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
frame: RealtimeAudioFrame {
data: "AQID".to_string(),
data: realtime_pcm_test_tone_20ms_base64(),
sample_rate: 24000,
num_channels: 1,
samples_per_channel: Some(480),
@@ -254,10 +267,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
server.handshakes()[1].header("authorization").as_deref(),
Some("Bearer dummy")
);
assert_eq!(
server.handshakes()[1].uri(),
"/v1/realtime?intent=quicksilver&model=realtime-test-model"
);
assert_eq!(server.handshakes()[1].uri(), "/v1/realtime/calls");
let mut request_types = [
connection[1].body_json()["type"]
.as_str()
@@ -426,7 +436,7 @@ async fn conversation_audio_before_start_emits_error() -> Result<()> {
test.codex
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
frame: RealtimeAudioFrame {
data: "AQID".to_string(),
data: realtime_pcm_test_tone_20ms_base64(),
sample_rate: 24000,
num_channels: 1,
samples_per_channel: Some(480),
@@ -625,7 +635,7 @@ async fn conversation_second_start_replaces_runtime() -> Result<()> {
test.codex
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
frame: RealtimeAudioFrame {
data: "AQID".to_string(),
data: realtime_pcm_test_tone_20ms_base64(),
sample_rate: 24000,
num_channels: 1,
samples_per_channel: Some(480),
@@ -1585,7 +1595,7 @@ async fn inbound_handoff_request_clears_active_transcript_after_each_handoff() -
test.codex
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
frame: RealtimeAudioFrame {
data: "AQID".to_string(),
data: realtime_pcm_test_tone_20ms_base64(),
sample_rate: 24000,
num_channels: 1,
samples_per_channel: Some(480),
@@ -2073,7 +2083,7 @@ async fn inbound_handoff_request_steers_active_turn() -> Result<()> {
test.codex
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
frame: RealtimeAudioFrame {
data: "AQID".to_string(),
data: realtime_pcm_test_tone_20ms_base64(),
sample_rate: 24000,
num_channels: 1,
samples_per_channel: Some(480),

@@ -112,6 +112,7 @@ codex-windows-sandbox = { workspace = true }
tokio-util = { workspace = true, features = ["time"] }
[target.'cfg(not(target_os = "linux"))'.dependencies]
aec3 = { workspace = true }
cpal = "0.15"
[target.'cfg(unix)'.dependencies]

@@ -1,4 +1,6 @@
use super::*;
#[cfg(not(target_os = "linux"))]
use crate::realtime_audio_processing::RealtimeAudioProcessor;
use codex_protocol::protocol::ConversationStartParams;
use codex_protocol::protocol::RealtimeAudioFrame;
use codex_protocol::protocol::RealtimeConversationClosedEvent;
@@ -30,6 +32,8 @@ pub(super) struct RealtimeConversationUiState {
#[cfg(not(target_os = "linux"))]
capture_stop_flag: Option<Arc<AtomicBool>>,
#[cfg(not(target_os = "linux"))]
audio_processor: Option<RealtimeAudioProcessor>,
#[cfg(not(target_os = "linux"))]
capture: Option<crate::voice::VoiceCapture>,
#[cfg(not(target_os = "linux"))]
audio_player: Option<crate::voice::RealtimeAudioPlayer>,
@@ -331,9 +335,26 @@ impl ChatWidget {
fn enqueue_realtime_audio_out(&mut self, frame: &RealtimeAudioFrame) {
#[cfg(not(target_os = "linux"))]
{
if !self.realtime_conversation.is_active() {
return;
}
if self.realtime_conversation.audio_player.is_none() {
self.realtime_conversation.audio_player =
crate::voice::RealtimeAudioPlayer::start(&self.config).ok();
let Some(audio_processor) = self.realtime_conversation.audio_processor.clone()
else {
self.fail_realtime_conversation(
"Realtime audio processor was unavailable".to_string(),
);
return;
};
match crate::voice::RealtimeAudioPlayer::start(&self.config, audio_processor) {
Ok(player) => self.realtime_conversation.audio_player = Some(player),
Err(err) => {
self.fail_realtime_conversation(format!(
"Failed to start speaker output: {err}"
));
return;
}
}
}
if let Some(player) = &self.realtime_conversation.audio_player
&& let Err(err) = player.enqueue_frame(frame)
@@ -367,12 +388,42 @@ impl ChatWidget {
self.realtime_conversation.meter_placeholder_id = Some(placeholder_id.clone());
self.request_redraw();
let audio_processor = match RealtimeAudioProcessor::new() {
Ok(audio_processor) => audio_processor,
Err(err) => {
self.realtime_conversation.meter_placeholder_id = None;
self.remove_recording_meter_placeholder(&placeholder_id);
self.fail_realtime_conversation(format!(
"Failed to start realtime audio processor: {err}"
));
return;
}
};
self.realtime_conversation.audio_processor = Some(audio_processor.clone());
let audio_player =
match crate::voice::RealtimeAudioPlayer::start(&self.config, audio_processor.clone()) {
Ok(player) => player,
Err(err) => {
self.realtime_conversation.audio_processor = None;
self.realtime_conversation.meter_placeholder_id = None;
self.remove_recording_meter_placeholder(&placeholder_id);
self.fail_realtime_conversation(format!(
"Failed to start speaker output: {err}"
));
return;
}
};
let capture = match crate::voice::VoiceCapture::start_realtime(
&self.config,
self.app_event_tx.clone(),
audio_processor,
) {
Ok(capture) => capture,
Err(err) => {
drop(audio_player);
self.realtime_conversation.audio_processor = None;
self.realtime_conversation.meter_placeholder_id = None;
self.remove_recording_meter_placeholder(&placeholder_id);
self.fail_realtime_conversation(format!(
@@ -389,10 +440,7 @@ impl ChatWidget {
self.realtime_conversation.capture_stop_flag = Some(stop_flag.clone());
self.realtime_conversation.capture = Some(capture);
if self.realtime_conversation.audio_player.is_none() {
self.realtime_conversation.audio_player =
crate::voice::RealtimeAudioPlayer::start(&self.config).ok();
}
self.realtime_conversation.audio_player = Some(audio_player);
std::thread::spawn(move || {
let mut meter = crate::voice::RecordingMeterState::new();
@@ -423,23 +471,10 @@ impl ChatWidget {
}
match kind {
RealtimeAudioDeviceKind::Microphone => {
self.stop_realtime_microphone();
RealtimeAudioDeviceKind::Microphone | RealtimeAudioDeviceKind::Speaker => {
self.stop_realtime_local_audio();
self.start_realtime_local_audio();
}
RealtimeAudioDeviceKind::Speaker => {
self.stop_realtime_speaker();
match crate::voice::RealtimeAudioPlayer::start(&self.config) {
Ok(player) => {
self.realtime_conversation.audio_player = Some(player);
}
Err(err) => {
self.fail_realtime_conversation(format!(
"Failed to start speaker output: {err}"
));
}
}
}
}
self.request_redraw();
}
@@ -453,6 +488,7 @@ impl ChatWidget {
fn stop_realtime_local_audio(&mut self) {
self.stop_realtime_microphone();
self.stop_realtime_speaker();
self.realtime_conversation.audio_processor = None;
}
#[cfg(target_os = "linux")]

@@ -130,6 +130,8 @@ pub mod onboarding;
mod oss_selection;
mod pager_overlay;
pub mod public_widgets;
#[cfg(not(target_os = "linux"))]
mod realtime_audio_processing;
mod render;
mod resume_picker;
mod selection_list;

@@ -0,0 +1,298 @@
use aec3::voip::VoipAec3;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::mpsc;
use std::sync::mpsc::Receiver;
use std::sync::mpsc::Sender;
use tracing::warn;
pub(crate) const AUDIO_PROCESSING_SAMPLE_RATE: u32 = 24_000;
pub(crate) const AUDIO_PROCESSING_CHANNELS: u16 = 1;
enum AudioProcessorCommand {
Capture(Vec<i16>),
Render(Vec<i16>),
}
#[derive(Clone)]
pub(crate) struct RealtimeAudioProcessor {
capture_rx: Arc<Mutex<Option<Receiver<Vec<i16>>>>>,
command_tx: Sender<AudioProcessorCommand>,
}
impl RealtimeAudioProcessor {
pub(crate) fn new() -> Result<Self, String> {
build_pipeline()?;
let (command_tx, command_rx) = mpsc::channel();
let (capture_tx, capture_rx) = mpsc::channel();
std::thread::Builder::new()
.name("codex-realtime-aec3".to_string())
.spawn(move || run_processor_thread(command_rx, capture_tx))
.map_err(|err| format!("failed to spawn realtime audio processor: {err}"))?;
Ok(Self {
capture_rx: Arc::new(Mutex::new(Some(capture_rx))),
command_tx,
})
}
pub(crate) fn capture_stage(
&self,
input_sample_rate: u32,
input_channels: u16,
) -> Result<RealtimeCaptureAudioProcessor, String> {
let capture_rx = self
.capture_rx
.lock()
.ok()
.and_then(|mut capture_rx| capture_rx.take())
.ok_or_else(|| "realtime capture stage was already created".to_string())?;
Ok(RealtimeCaptureAudioProcessor {
capture_rx,
command_tx: self.command_tx.clone(),
input_sample_rate,
input_channels,
processed_samples: VecDeque::new(),
})
}
pub(crate) fn render_stage(
&self,
output_sample_rate: u32,
output_channels: u16,
) -> RealtimeRenderAudioProcessor {
RealtimeRenderAudioProcessor {
command_tx: self.command_tx.clone(),
output_sample_rate,
output_channels,
}
}
}
pub(crate) struct RealtimeCaptureAudioProcessor {
capture_rx: Receiver<Vec<i16>>,
command_tx: Sender<AudioProcessorCommand>,
input_sample_rate: u32,
input_channels: u16,
processed_samples: VecDeque<i16>,
}
impl RealtimeCaptureAudioProcessor {
pub(crate) fn process_samples(&mut self, samples: &[i16]) -> Vec<i16> {
let converted = convert_pcm16(
samples,
self.input_sample_rate,
self.input_channels,
AUDIO_PROCESSING_SAMPLE_RATE,
AUDIO_PROCESSING_CHANNELS,
);
if !converted.is_empty()
&& let Err(err) = self
.command_tx
.send(AudioProcessorCommand::Capture(converted))
{
warn!("failed to queue realtime capture audio: {err}");
}
loop {
match self.capture_rx.try_recv() {
Ok(processed) => self.processed_samples.extend(processed),
Err(mpsc::TryRecvError::Empty) => break,
Err(mpsc::TryRecvError::Disconnected) => {
warn!("realtime capture audio processor disconnected");
break;
}
}
}
self.processed_samples.drain(..).collect()
}
}
pub(crate) struct RealtimeRenderAudioProcessor {
command_tx: Sender<AudioProcessorCommand>,
output_sample_rate: u32,
output_channels: u16,
}
impl RealtimeRenderAudioProcessor {
pub(crate) fn process_samples(&mut self, samples: &[i16]) {
let converted = convert_pcm16(
samples,
self.output_sample_rate,
self.output_channels,
AUDIO_PROCESSING_SAMPLE_RATE,
AUDIO_PROCESSING_CHANNELS,
);
if !converted.is_empty()
&& let Err(err) = self
.command_tx
.send(AudioProcessorCommand::Render(converted))
{
warn!("failed to queue realtime render audio: {err}");
}
}
}
fn build_pipeline() -> Result<VoipAec3, String> {
VoipAec3::builder(
AUDIO_PROCESSING_SAMPLE_RATE as usize,
usize::from(AUDIO_PROCESSING_CHANNELS),
usize::from(AUDIO_PROCESSING_CHANNELS),
)
.enable_high_pass(false)
.enable_noise_suppression(false)
.build()
.map_err(|err| format!("failed to initialize realtime audio processor: {err}"))
}
fn run_processor_thread(command_rx: Receiver<AudioProcessorCommand>, capture_tx: Sender<Vec<i16>>) {
let mut pipeline = match build_pipeline() {
Ok(pipeline) => pipeline,
Err(err) => {
warn!("{err}");
return;
}
};
let capture_frame_len =
pipeline.capture_frame_samples() * usize::from(AUDIO_PROCESSING_CHANNELS);
let render_frame_len = pipeline.render_frame_samples() * usize::from(AUDIO_PROCESSING_CHANNELS);
let mut pending_capture = VecDeque::new();
let mut pending_render = VecDeque::new();
while let Ok(command) = command_rx.recv() {
match command {
AudioProcessorCommand::Capture(samples) => {
pending_capture.extend(samples);
while pending_capture.len() >= capture_frame_len {
let capture_frame =
drain_pending_frame(&mut pending_capture, capture_frame_len);
let capture_frame = capture_frame
.iter()
.copied()
.map(i16_to_f32)
.collect::<Vec<_>>();
let mut output = vec![0.0; capture_frame.len()];
if let Err(err) = pipeline.process_capture_frame(
&capture_frame,
/*level_change*/ false,
&mut output,
) {
warn!("failed to process realtime capture audio: {err}");
continue;
}
let processed = output.into_iter().map(f32_to_i16).collect::<Vec<_>>();
if let Err(err) = capture_tx.send(processed) {
warn!("failed to deliver realtime capture audio: {err}");
return;
}
}
}
AudioProcessorCommand::Render(samples) => {
pending_render.extend(samples);
while pending_render.len() >= render_frame_len {
let render_frame = drain_pending_frame(&mut pending_render, render_frame_len);
let render_frame = render_frame
.iter()
.copied()
.map(i16_to_f32)
.collect::<Vec<_>>();
if let Err(err) = pipeline.handle_render_frame(&render_frame) {
warn!("failed to process realtime render audio: {err}");
}
}
}
}
}
}
fn drain_pending_frame(pending_samples: &mut VecDeque<i16>, frame_len: usize) -> Vec<i16> {
pending_samples.drain(..frame_len).collect()
}
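The processor thread buffers incoming samples in a `VecDeque` and only hands AEC3 exact frame-sized chunks, draining from the front once enough samples accumulate. A standalone copy of the draining helper:

```rust
use std::collections::VecDeque;

/// Pop exactly `frame_len` samples off the front of the pending buffer,
/// as the processor thread does once a full frame has accumulated.
/// Callers must check `pending.len() >= frame_len` first, as the loops
/// above do; drain panics on an out-of-range span.
fn drain_frame(pending: &mut VecDeque<i16>, frame_len: usize) -> Vec<i16> {
    pending.drain(..frame_len).collect()
}
```

Leftover samples stay queued for the next frame, so no audio is dropped between callbacks that deliver partial frames.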
pub(crate) fn convert_pcm16(
input: &[i16],
input_sample_rate: u32,
input_channels: u16,
output_sample_rate: u32,
output_channels: u16,
) -> Vec<i16> {
if input.is_empty() || input_channels == 0 || output_channels == 0 {
return Vec::new();
}
let in_channels = input_channels as usize;
let out_channels = output_channels as usize;
let in_frames = input.len() / in_channels;
if in_frames == 0 {
return Vec::new();
}
let out_frames = if input_sample_rate == output_sample_rate {
in_frames
} else {
(((in_frames as u64) * (output_sample_rate as u64)) / (input_sample_rate as u64)).max(1)
as usize
};
let mut out = Vec::with_capacity(out_frames.saturating_mul(out_channels));
for out_frame_idx in 0..out_frames {
let src_frame_idx = if out_frames <= 1 || in_frames <= 1 {
0
} else {
((out_frame_idx as u64) * ((in_frames - 1) as u64) / ((out_frames - 1) as u64)) as usize
};
let src_start = src_frame_idx.saturating_mul(in_channels);
let src = &input[src_start..src_start + in_channels];
match (in_channels, out_channels) {
(1, 1) => out.push(src[0]),
(1, n) => {
for _ in 0..n {
out.push(src[0]);
}
}
(n, 1) if n >= 2 => {
let sum: i32 = src.iter().map(|s| *s as i32).sum();
out.push((sum / (n as i32)) as i16);
}
(n, m) if n == m => out.extend_from_slice(src),
(n, m) if n > m => out.extend_from_slice(&src[..m]),
(n, m) => {
out.extend_from_slice(src);
let last = *src.last().unwrap_or(&0);
for _ in n..m {
out.push(last);
}
}
}
}
out
}
fn i16_to_f32(sample: i16) -> f32 {
(sample as f32) / (i16::MAX as f32)
}
fn f32_to_i16(sample: f32) -> i16 {
(sample.clamp(-1.0, 1.0) * i16::MAX as f32) as i16
}
#[cfg(test)]
mod tests {
use super::convert_pcm16;
use pretty_assertions::assert_eq;
#[test]
fn convert_pcm16_downmixes_and_resamples_for_model_input() {
let input = vec![100, 300, 200, 400, 500, 700, 600, 800];
let converted = convert_pcm16(
&input, /*input_sample_rate*/ 48_000, /*input_channels*/ 2,
/*output_sample_rate*/ 24_000, /*output_channels*/ 1,
);
assert_eq!(converted, vec![200, 700]);
}
}

@@ -1,4 +1,10 @@
use crate::app_event_sender::AppEventSender;
use crate::realtime_audio_processing::AUDIO_PROCESSING_CHANNELS;
use crate::realtime_audio_processing::AUDIO_PROCESSING_SAMPLE_RATE;
use crate::realtime_audio_processing::RealtimeAudioProcessor;
use crate::realtime_audio_processing::RealtimeCaptureAudioProcessor;
use crate::realtime_audio_processing::RealtimeRenderAudioProcessor;
use crate::realtime_audio_processing::convert_pcm16;
use base64::Engine;
use codex_core::config::Config;
use codex_protocol::protocol::ConversationAudioParams;
@@ -23,7 +29,11 @@ pub struct VoiceCapture {
}
impl VoiceCapture {
pub fn start_realtime(config: &Config, tx: AppEventSender) -> Result<Self, String> {
pub fn start_realtime(
config: &Config,
tx: AppEventSender,
audio_processor: RealtimeAudioProcessor,
) -> Result<Self, String> {
let (device, config) = select_realtime_input_device_and_config(config)?;
let sample_rate = config.sample_rate().0;
@@ -34,9 +44,8 @@ impl VoiceCapture {
let stream = build_realtime_input_stream(
&device,
&config,
sample_rate,
channels,
tx,
audio_processor.capture_stage(sample_rate, channels)?,
last_peak.clone(),
)?;
stream
@@ -138,50 +147,76 @@ fn select_realtime_input_device_and_config(
fn build_realtime_input_stream(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
sample_rate: u32,
channels: u16,
tx: AppEventSender,
capture_processor: RealtimeCaptureAudioProcessor,
last_peak: Arc<AtomicU16>,
) -> Result<cpal::Stream, String> {
match config.sample_format() {
cpal::SampleFormat::F32 => device
.build_input_stream(
&config.clone().into(),
move |input: &[f32], _| {
let peak = peak_f32(input);
last_peak.store(peak, Ordering::Relaxed);
let samples = input.iter().copied().map(f32_to_i16).collect::<Vec<_>>();
send_realtime_audio_chunk(&tx, samples, sample_rate, channels);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}")),
cpal::SampleFormat::I16 => device
.build_input_stream(
&config.clone().into(),
move |input: &[i16], _| {
let peak = peak_i16(input);
last_peak.store(peak, Ordering::Relaxed);
send_realtime_audio_chunk(&tx, input.to_vec(), sample_rate, channels);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}")),
cpal::SampleFormat::U16 => device
.build_input_stream(
&config.clone().into(),
move |input: &[u16], _| {
let mut samples = Vec::with_capacity(input.len());
let peak = convert_u16_to_i16_and_peak(input, &mut samples);
last_peak.store(peak, Ordering::Relaxed);
send_realtime_audio_chunk(&tx, samples, sample_rate, channels);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}")),
cpal::SampleFormat::F32 => {
let mut capture_processor = capture_processor;
device
.build_input_stream(
&config.clone().into(),
move |input: &[f32], _| {
let peak = peak_f32(input);
last_peak.store(peak, Ordering::Relaxed);
let samples = input.iter().copied().map(f32_to_i16).collect::<Vec<_>>();
let samples = capture_processor.process_samples(&samples);
send_realtime_audio_chunk(
&tx,
samples,
AUDIO_PROCESSING_SAMPLE_RATE,
AUDIO_PROCESSING_CHANNELS,
);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}"))
}
cpal::SampleFormat::I16 => {
let mut capture_processor = capture_processor;
device
.build_input_stream(
&config.clone().into(),
move |input: &[i16], _| {
let peak = peak_i16(input);
last_peak.store(peak, Ordering::Relaxed);
let samples = capture_processor.process_samples(input);
send_realtime_audio_chunk(
&tx,
samples,
AUDIO_PROCESSING_SAMPLE_RATE,
AUDIO_PROCESSING_CHANNELS,
);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}"))
}
cpal::SampleFormat::U16 => {
let mut capture_processor = capture_processor;
device
.build_input_stream(
&config.clone().into(),
move |input: &[u16], _| {
let mut samples = Vec::with_capacity(input.len());
let peak = convert_u16_to_i16_and_peak(input, &mut samples);
last_peak.store(peak, Ordering::Relaxed);
let samples = capture_processor.process_samples(&samples);
send_realtime_audio_chunk(
&tx,
samples,
AUDIO_PROCESSING_SAMPLE_RATE,
AUDIO_PROCESSING_CHANNELS,
);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}"))
}
_ => Err("unsupported input sample format".to_string()),
}
}
@@ -288,13 +323,21 @@ pub(crate) struct RealtimeAudioPlayer {
}
impl RealtimeAudioPlayer {
pub(crate) fn start(config: &Config) -> Result<Self, String> {
pub(crate) fn start(
config: &Config,
audio_processor: RealtimeAudioProcessor,
) -> Result<Self, String> {
let (device, config) =
crate::audio_device::select_configured_output_device_and_config(config)?;
let output_sample_rate = config.sample_rate().0;
let output_channels = config.channels();
let queue = Arc::new(Mutex::new(VecDeque::new()));
let stream = build_output_stream(&device, &config, Arc::clone(&queue))?;
let stream = build_output_stream(
&device,
&config,
Arc::clone(&queue),
audio_processor.render_stage(output_sample_rate, output_channels),
)?;
stream
.play()
.map_err(|e| format!("failed to start output stream: {e}"))?;
@@ -350,140 +393,94 @@ fn build_output_stream(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
queue: Arc<Mutex<VecDeque<i16>>>,
render_processor: RealtimeRenderAudioProcessor,
) -> Result<cpal::Stream, String> {
let config_any: cpal::StreamConfig = config.clone().into();
match config.sample_format() {
cpal::SampleFormat::F32 => device
.build_output_stream(
&config_any,
move |output: &mut [f32], _| fill_output_f32(output, &queue),
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build f32 output stream: {e}")),
cpal::SampleFormat::I16 => device
.build_output_stream(
&config_any,
move |output: &mut [i16], _| fill_output_i16(output, &queue),
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build i16 output stream: {e}")),
cpal::SampleFormat::U16 => device
.build_output_stream(
&config_any,
move |output: &mut [u16], _| fill_output_u16(output, &queue),
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build u16 output stream: {e}")),
cpal::SampleFormat::F32 => {
let mut render_processor = render_processor;
device
.build_output_stream(
&config_any,
move |output: &mut [f32], _| {
fill_output_f32(output, &queue, &mut render_processor)
},
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build f32 output stream: {e}"))
}
cpal::SampleFormat::I16 => {
let mut render_processor = render_processor;
device
.build_output_stream(
&config_any,
move |output: &mut [i16], _| {
fill_output_i16(output, &queue, &mut render_processor)
},
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build i16 output stream: {e}"))
}
cpal::SampleFormat::U16 => {
let mut render_processor = render_processor;
device
.build_output_stream(
&config_any,
move |output: &mut [u16], _| {
fill_output_u16(output, &queue, &mut render_processor)
},
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build u16 output stream: {e}"))
}
other => Err(format!("unsupported output sample format: {other:?}")),
}
}
fn fill_output_i16(output: &mut [i16], queue: &Arc<Mutex<VecDeque<i16>>>) {
if let Ok(mut guard) = queue.lock() {
for sample in output {
*sample = guard.pop_front().unwrap_or(0);
}
return;
}
output.fill(0);
fn fill_output_i16(
output: &mut [i16],
queue: &Arc<Mutex<VecDeque<i16>>>,
render_processor: &mut RealtimeRenderAudioProcessor,
) {
let samples = drain_output_samples(output.len(), queue);
output.copy_from_slice(&samples);
render_processor.process_samples(output);
}
fn fill_output_f32(output: &mut [f32], queue: &Arc<Mutex<VecDeque<i16>>>) {
if let Ok(mut guard) = queue.lock() {
for sample in output {
let v = guard.pop_front().unwrap_or(0);
*sample = (v as f32) / (i16::MAX as f32);
}
return;
fn fill_output_f32(
output: &mut [f32],
queue: &Arc<Mutex<VecDeque<i16>>>,
render_processor: &mut RealtimeRenderAudioProcessor,
) {
let samples = drain_output_samples(output.len(), queue);
for (output_sample, sample) in output.iter_mut().zip(samples.iter()) {
*output_sample = (*sample as f32) / (i16::MAX as f32);
}
output.fill(0.0);
render_processor.process_samples(&samples);
}
fn fill_output_u16(
    output: &mut [u16],
    queue: &Arc<Mutex<VecDeque<i16>>>,
    render_processor: &mut RealtimeRenderAudioProcessor,
) {
    let samples = drain_output_samples(output.len(), queue);
    for (output_sample, sample) in output.iter_mut().zip(samples.iter()) {
        *output_sample = (*sample as i32 + 32768).clamp(0, u16::MAX as i32) as u16;
    }
    render_processor.process_samples(&samples);
}
fn convert_pcm16(
input: &[i16],
input_sample_rate: u32,
input_channels: u16,
output_sample_rate: u32,
output_channels: u16,
) -> Vec<i16> {
if input.is_empty() || input_channels == 0 || output_channels == 0 {
return Vec::new();
}
let in_channels = input_channels as usize;
let out_channels = output_channels as usize;
let in_frames = input.len() / in_channels;
if in_frames == 0 {
return Vec::new();
}
    let out_frames = if input_sample_rate == output_sample_rate {
        in_frames
    } else {
        (((in_frames as u64) * (output_sample_rate as u64)) / (input_sample_rate as u64)).max(1)
            as usize
    };
    let mut out = Vec::with_capacity(out_frames.saturating_mul(out_channels));
    for out_frame_idx in 0..out_frames {
        let src_frame_idx = if out_frames <= 1 || in_frames <= 1 {
            0
        } else {
            ((out_frame_idx as u64) * ((in_frames - 1) as u64) / ((out_frames - 1) as u64)) as usize
        };
        let src_start = src_frame_idx.saturating_mul(in_channels);
        let src = &input[src_start..src_start + in_channels];
        match (in_channels, out_channels) {
            (1, 1) => out.push(src[0]),
            (1, n) => {
                for _ in 0..n {
                    out.push(src[0]);
                }
            }
            (n, 1) if n >= 2 => {
                let sum: i32 = src.iter().map(|s| *s as i32).sum();
                out.push((sum / (n as i32)) as i16);
            }
            (n, m) if n == m => out.extend_from_slice(src),
            (n, m) if n > m => out.extend_from_slice(&src[..m]),
            (n, m) => {
                out.extend_from_slice(src);
                let last = *src.last().unwrap_or(&0);
                for _ in n..m {
                    out.push(last);
                }
            }
        }
    }
    out
}

fn drain_output_samples(output_len: usize, queue: &Arc<Mutex<VecDeque<i16>>>) -> Vec<i16> {
    let mut samples = vec![0; output_len];
    let Ok(mut guard) = queue.lock() else {
        return samples;
    };
    for sample in &mut samples {
        *sample = guard.pop_front().unwrap_or(0);
    }
    samples
}

#[cfg(test)]
mod tests {
    use super::convert_pcm16;
    use pretty_assertions::assert_eq;

    #[test]
    fn convert_pcm16_downmixes_and_resamples_for_model_input() {
        let input = vec![100, 300, 200, 400, 500, 700, 600, 800];
        let converted = convert_pcm16(
            &input, /*input_sample_rate*/ 48_000, /*input_channels*/ 2,
            /*output_sample_rate*/ 24_000, /*output_channels*/ 1,
        );
        assert_eq!(converted, vec![200, 700]);
    }
}