mirror of
https://github.com/openai/codex.git
synced 2026-04-17 11:14:48 +00:00
Compare commits
24 Commits
dev/shaqay
...
realtime-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b080c91693 | ||
|
|
3e206d1e6a | ||
|
|
3e203f8633 | ||
|
|
90f10e9ab9 | ||
|
|
8207ff6ead | ||
|
|
99f72e27ae | ||
|
|
5d5305c5d4 | ||
|
|
236891f5d3 | ||
|
|
11180cefd9 | ||
|
|
76397dbdd0 | ||
|
|
f9739d0178 | ||
|
|
bbff67a7b9 | ||
|
|
d887e0be7f | ||
|
|
1f9567d121 | ||
|
|
763dc66fb6 | ||
|
|
357140d3c9 | ||
|
|
1aa17fafcc | ||
|
|
5462954edd | ||
|
|
da1ad103fa | ||
|
|
4abb01d268 | ||
|
|
d6d8d6304d | ||
|
|
63c1223141 | ||
|
|
a40422b85f | ||
|
|
175d831ff4 |
62
MODULE.bazel.lock
generated
62
MODULE.bazel.lock
generated
File diff suppressed because one or more lines are too long
863
codex-rs/Cargo.lock
generated
863
codex-rs/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -197,6 +197,7 @@ ansi-to-tui = "7.0.0"
|
||||
anyhow = "1"
|
||||
arboard = { version = "3", features = ["wayland-data-control"] }
|
||||
arc-swap = "1.9.0"
|
||||
aec3 = "0.1.7"
|
||||
assert_cmd = "2"
|
||||
assert_matches = "1.5.0"
|
||||
async-channel = "2.3.1"
|
||||
@@ -239,6 +240,7 @@ image = { version = "^0.25.9", default-features = false }
|
||||
include_dir = "0.7.4"
|
||||
indexmap = "2.12.0"
|
||||
insta = "1.46.3"
|
||||
interceptor = "0.17.1"
|
||||
inventory = "0.3.19"
|
||||
itertools = "0.14.0"
|
||||
jsonwebtoken = "9.3.1"
|
||||
@@ -255,6 +257,7 @@ notify = "8.2.0"
|
||||
nucleo = { git = "https://github.com/helix-editor/nucleo.git", rev = "4253de9faabb4e5c6d81d946a5e35a90f87347ee" }
|
||||
once_cell = "1.20.2"
|
||||
openssl-sys = "*"
|
||||
opus-rs = "0.1.11"
|
||||
opentelemetry = "0.31.0"
|
||||
opentelemetry-appender-tracing = "0.31.0"
|
||||
opentelemetry-otlp = "0.31.0"
|
||||
@@ -349,6 +352,7 @@ v8 = "=146.4.0"
|
||||
vt100 = "0.16.2"
|
||||
walkdir = "2.5.0"
|
||||
webbrowser = "1.0"
|
||||
webrtc = "0.17.1"
|
||||
which = "8"
|
||||
wildmatch = "2.6.1"
|
||||
zip = "2.4.2"
|
||||
|
||||
@@ -3,6 +3,8 @@ use anyhow::Result;
|
||||
use app_test_support::McpProcess;
|
||||
use app_test_support::create_mock_responses_server_sequence_unchecked;
|
||||
use app_test_support::to_response;
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
|
||||
use codex_app_server_protocol::JSONRPCError;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::LoginAccountResponse;
|
||||
@@ -40,6 +42,16 @@ use tokio::time::timeout;
|
||||
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex.";
|
||||
|
||||
fn realtime_pcm_test_tone_20ms_base64() -> String {
|
||||
let pcm_bytes: Vec<u8> = (0..480)
|
||||
.flat_map(|index| {
|
||||
let sample = if index % 2 == 0 { 1024_i16 } else { -1024_i16 };
|
||||
sample.to_le_bytes()
|
||||
})
|
||||
.collect();
|
||||
BASE64_STANDARD.encode(pcm_bytes)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -155,7 +167,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
|
||||
.send_thread_realtime_append_audio_request(ThreadRealtimeAppendAudioParams {
|
||||
thread_id: started.thread_id.clone(),
|
||||
audio: ThreadRealtimeAudioChunk {
|
||||
data: "BQYH".to_string(),
|
||||
data: realtime_pcm_test_tone_20ms_base64(),
|
||||
sample_rate: 24_000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(480),
|
||||
|
||||
@@ -14,17 +14,21 @@ codex-protocol = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
http = { workspace = true }
|
||||
interceptor = { workspace = true }
|
||||
opus-rs = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["multipart"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "net", "rt", "sync", "time"] }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tungstenite = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["codec"] }
|
||||
tungstenite = { workspace = true }
|
||||
url = { workspace = true }
|
||||
webrtc = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@@ -32,7 +36,6 @@ assert_matches = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
tokio-test = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,9 @@
|
||||
use crate::endpoint::realtime_websocket::methods_v1::conversation_handoff_append_message as v1_conversation_handoff_append_message;
|
||||
use crate::endpoint::realtime_websocket::methods_v1::conversation_item_create_message as v1_conversation_item_create_message;
|
||||
use crate::endpoint::realtime_websocket::methods_v1::session_update_session as v1_session_update_session;
|
||||
use crate::endpoint::realtime_websocket::methods_v1::websocket_intent as v1_websocket_intent;
|
||||
use crate::endpoint::realtime_websocket::methods_v2::conversation_handoff_append_message as v2_conversation_handoff_append_message;
|
||||
use crate::endpoint::realtime_websocket::methods_v2::conversation_item_create_message as v2_conversation_item_create_message;
|
||||
use crate::endpoint::realtime_websocket::methods_v2::session_update_session as v2_session_update_session;
|
||||
use crate::endpoint::realtime_websocket::methods_v2::websocket_intent as v2_websocket_intent;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeEventParser;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionMode;
|
||||
@@ -59,10 +57,3 @@ pub(super) fn session_update_session(
|
||||
RealtimeEventParser::RealtimeV2 => v2_session_update_session(instructions, session_mode),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn websocket_intent(event_parser: RealtimeEventParser) -> Option<&'static str> {
|
||||
match event_parser {
|
||||
RealtimeEventParser::V1 => v1_websocket_intent(),
|
||||
RealtimeEventParser::RealtimeV2 => v2_websocket_intent(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,3 @@ pub(super) fn session_update_session(instructions: String) -> SessionUpdateSessi
|
||||
tool_choice: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn websocket_intent() -> Option<&'static str> {
|
||||
Some("quicksilver")
|
||||
}
|
||||
|
||||
@@ -126,7 +126,3 @@ pub(super) fn session_update_session(
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn websocket_intent() -> Option<&'static str> {
|
||||
None
|
||||
}
|
||||
|
||||
@@ -9,10 +9,10 @@ mod protocol_v2;
|
||||
|
||||
pub use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
pub use codex_protocol::protocol::RealtimeEvent;
|
||||
pub use methods::RealtimeWebsocketClient;
|
||||
pub use methods::RealtimeWebsocketConnection;
|
||||
pub use methods::RealtimeWebsocketEvents;
|
||||
pub use methods::RealtimeWebsocketWriter;
|
||||
pub use methods::RealtimeWebRtcClient;
|
||||
pub use methods::RealtimeWebRtcConnection;
|
||||
pub use methods::RealtimeWebRtcEvents;
|
||||
pub use methods::RealtimeWebRtcWriter;
|
||||
pub use protocol::RealtimeEventParser;
|
||||
pub use protocol::RealtimeSessionConfig;
|
||||
pub use protocol::RealtimeSessionMode;
|
||||
|
||||
@@ -32,8 +32,6 @@ pub struct RealtimeSessionConfig {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub(super) enum RealtimeOutboundMessage {
|
||||
#[serde(rename = "input_audio_buffer.append")]
|
||||
InputAudioBufferAppend { audio: String },
|
||||
#[serde(rename = "conversation.handoff.append")]
|
||||
ConversationHandoffAppend {
|
||||
handoff_id: String,
|
||||
|
||||
@@ -35,7 +35,10 @@ pub(super) fn parse_realtime_event_v1(payload: &str) -> Option<RealtimeEvent> {
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok()),
|
||||
item_id: None,
|
||||
item_id: parsed
|
||||
.get("item_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string),
|
||||
}))
|
||||
}
|
||||
"conversation.input_transcript.delta" => {
|
||||
|
||||
@@ -34,8 +34,8 @@ pub use crate::endpoint::models::ModelsClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeEventParser;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeSessionMode;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketConnection;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebRtcClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebRtcConnection;
|
||||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
pub use crate::endpoint::responses::ResponsesOptions;
|
||||
pub use crate::endpoint::responses_websocket::ResponsesWebsocketClient;
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_api::RealtimeAudioFrame;
|
||||
use codex_api::RealtimeEvent;
|
||||
use codex_api::RealtimeEventParser;
|
||||
use codex_api::RealtimeSessionConfig;
|
||||
use codex_api::RealtimeSessionMode;
|
||||
use codex_api::RealtimeWebsocketClient;
|
||||
use codex_api::provider::Provider;
|
||||
use codex_api::provider::RetryConfig;
|
||||
use codex_protocol::protocol::RealtimeHandoffRequested;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
type RealtimeWsStream = tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>;
|
||||
|
||||
async fn spawn_realtime_ws_server<Handler, Fut>(
|
||||
handler: Handler,
|
||||
) -> (String, tokio::task::JoinHandle<()>)
|
||||
where
|
||||
Handler: FnOnce(RealtimeWsStream) -> Fut + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let listener = match TcpListener::bind("127.0.0.1:0").await {
|
||||
Ok(listener) => listener,
|
||||
Err(err) => panic!("failed to bind test websocket listener: {err}"),
|
||||
};
|
||||
let addr = match listener.local_addr() {
|
||||
Ok(addr) => addr.to_string(),
|
||||
Err(err) => panic!("failed to read local websocket listener address: {err}"),
|
||||
};
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = match listener.accept().await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => panic!("failed to accept test websocket connection: {err}"),
|
||||
};
|
||||
let ws = match accept_async(stream).await {
|
||||
Ok(ws) => ws,
|
||||
Err(err) => panic!("failed to complete websocket handshake: {err}"),
|
||||
};
|
||||
handler(ws).await;
|
||||
});
|
||||
|
||||
(addr, server)
|
||||
}
|
||||
|
||||
fn test_provider(base_url: String) -> Provider {
|
||||
Provider {
|
||||
name: "test".to_string(),
|
||||
base_url,
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["type"],
|
||||
Value::String("quicksilver".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["instructions"],
|
||||
Value::String("backend prompt".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["input"]["format"]["type"],
|
||||
Value::String("audio/pcm".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["input"]["format"]["rate"],
|
||||
Value::from(24_000)
|
||||
);
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_mock", "instructions": "backend prompt"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
|
||||
let second = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("second msg")
|
||||
.expect("second msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let second_json: Value = serde_json::from_str(&second).expect("json");
|
||||
assert_eq!(second_json["type"], "input_audio_buffer.append");
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "conversation.output_audio.delta",
|
||||
"delta": "AQID",
|
||||
"sample_rate": 48000,
|
||||
"channels": 1
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send audio out");
|
||||
})
|
||||
.await;
|
||||
|
||||
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let created = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
created,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_mock".to_string(),
|
||||
instructions: Some("backend prompt".to_string()),
|
||||
}
|
||||
);
|
||||
|
||||
connection
|
||||
.send_audio_frame(RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
sample_rate: 48000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(960),
|
||||
item_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("send audio");
|
||||
|
||||
let audio_event = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
audio_event,
|
||||
RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
sample_rate: 48000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: None,
|
||||
item_id: None,
|
||||
})
|
||||
);
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_send_while_next_event_waits() {
|
||||
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
|
||||
let second = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("second msg")
|
||||
.expect("second msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let second_json: Value = serde_json::from_str(&second).expect("json");
|
||||
assert_eq!(second_json["type"], "input_audio_buffer.append");
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_after_send", "instructions": "backend prompt"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
})
|
||||
.await;
|
||||
|
||||
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let (send_result, next_result) = tokio::join!(
|
||||
async {
|
||||
tokio::time::timeout(
|
||||
Duration::from_millis(200),
|
||||
connection.send_audio_frame(RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
sample_rate: 48000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(960),
|
||||
item_id: None,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
},
|
||||
connection.next_event()
|
||||
);
|
||||
|
||||
send_result
|
||||
.expect("send should not block on next_event")
|
||||
.expect("send audio");
|
||||
let next_event = next_result.expect("next event").expect("event");
|
||||
assert_eq!(
|
||||
next_event,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_after_send".to_string(),
|
||||
instructions: Some("backend prompt".to_string()),
|
||||
}
|
||||
);
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_disconnected_emitted_once() {
|
||||
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
|
||||
ws.send(Message::Close(None)).await.expect("send close");
|
||||
})
|
||||
.await;
|
||||
|
||||
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let first = connection.next_event().await.expect("next event");
|
||||
assert_eq!(first, None);
|
||||
|
||||
let second = connection.next_event().await.expect("next event");
|
||||
assert_eq!(second, None);
|
||||
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_ignores_unknown_text_events() {
|
||||
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp_unknown"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send unknown event");
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.updated",
|
||||
"session": {"id": "sess_after_unknown", "instructions": "backend prompt"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.updated");
|
||||
})
|
||||
.await;
|
||||
|
||||
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::V1,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let event = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
event,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_after_unknown".to_string(),
|
||||
instructions: Some("backend prompt".to_string()),
|
||||
}
|
||||
);
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_realtime_v2_parser_emits_handoff_requested() {
|
||||
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "conversation.item.done",
|
||||
"item": {
|
||||
"id": "item_123",
|
||||
"type": "function_call",
|
||||
"name": "codex",
|
||||
"call_id": "call_123",
|
||||
"arguments": "{\"prompt\":\"delegate now\"}"
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send function call");
|
||||
})
|
||||
.await;
|
||||
|
||||
let client = RealtimeWebsocketClient::new(test_provider(format!("http://{addr}")));
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
event_parser: RealtimeEventParser::RealtimeV2,
|
||||
session_mode: RealtimeSessionMode::Conversational,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let event = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
event,
|
||||
RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id: "call_123".to_string(),
|
||||
item_id: "item_123".to_string(),
|
||||
input_transcript: "delegate now".to_string(),
|
||||
active_transcript: Vec::new(),
|
||||
})
|
||||
);
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
@@ -13,10 +13,10 @@ use codex_api::RealtimeEvent;
|
||||
use codex_api::RealtimeEventParser;
|
||||
use codex_api::RealtimeSessionConfig;
|
||||
use codex_api::RealtimeSessionMode;
|
||||
use codex_api::RealtimeWebsocketClient;
|
||||
use codex_api::RealtimeWebRtcClient;
|
||||
use codex_api::api_bridge::map_api_error;
|
||||
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketEvents;
|
||||
use codex_api::endpoint::realtime_websocket::RealtimeWebsocketWriter;
|
||||
use codex_api::endpoint::realtime_websocket::RealtimeWebRtcEvents;
|
||||
use codex_api::endpoint::realtime_websocket::RealtimeWebRtcWriter;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_login::default_client::default_headers;
|
||||
@@ -107,8 +107,8 @@ struct OutputAudioState {
|
||||
}
|
||||
|
||||
struct RealtimeInputTask {
|
||||
writer: RealtimeWebsocketWriter,
|
||||
events: RealtimeWebsocketEvents,
|
||||
writer: RealtimeWebRtcWriter,
|
||||
events: RealtimeWebRtcEvents,
|
||||
user_text_rx: Receiver<String>,
|
||||
handoff_output_rx: Receiver<HandoffOutput>,
|
||||
audio_rx: Receiver<RealtimeAudioFrame>,
|
||||
@@ -132,7 +132,7 @@ impl RealtimeHandoffState {
|
||||
struct ConversationState {
|
||||
audio_tx: Sender<RealtimeAudioFrame>,
|
||||
user_text_tx: Sender<String>,
|
||||
writer: RealtimeWebsocketWriter,
|
||||
writer: RealtimeWebRtcWriter,
|
||||
handoff: RealtimeHandoffState,
|
||||
input_task: JoinHandle<()>,
|
||||
fanout_task: Option<JoinHandle<()>>,
|
||||
@@ -172,7 +172,7 @@ impl RealtimeConversationManager {
|
||||
RealtimeEventParser::RealtimeV2 => RealtimeSessionKind::V2,
|
||||
};
|
||||
|
||||
let client = RealtimeWebsocketClient::new(api_provider);
|
||||
let client = RealtimeWebRtcClient::new(api_provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
session_config,
|
||||
@@ -392,6 +392,7 @@ async fn stop_conversation_state(
|
||||
fanout_task_stop: RealtimeFanoutTaskStop,
|
||||
) {
|
||||
state.realtime_active.store(false, Ordering::Relaxed);
|
||||
let _ = state.writer.close().await;
|
||||
state.input_task.abort();
|
||||
let _ = state.input_task.await;
|
||||
|
||||
|
||||
@@ -3,7 +3,11 @@ load("//:defs.bzl", "codex_rust_crate")
|
||||
codex_rust_crate(
|
||||
name = "common",
|
||||
crate_name = "core_test_support",
|
||||
crate_srcs = glob(["*.rs"]),
|
||||
crate_srcs = glob([
|
||||
"*.rs",
|
||||
"responses/*.rs",
|
||||
"responses/**/*.rs",
|
||||
]),
|
||||
lib_data_extra = [
|
||||
"//codex-rs/core:model_availability_nux_fixtures",
|
||||
],
|
||||
|
||||
@@ -24,11 +24,13 @@ codex-models-manager = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-cargo-bin = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
ctor = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
notify = { workspace = true }
|
||||
opentelemetry = { workspace = true }
|
||||
opentelemetry_sdk = { workspace = true }
|
||||
opus-rs = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
@@ -38,6 +40,7 @@ tracing = { workspace = true }
|
||||
tracing-opentelemetry = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
walkdir = { workspace = true }
|
||||
webrtc = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
zstd = { workspace = true }
|
||||
|
||||
@@ -37,6 +37,8 @@ use wiremock::matchers::path_regex;
|
||||
|
||||
use crate::test_codex::ApplyPatchModelOutput;
|
||||
|
||||
mod realtime_webrtc_server;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResponseMock {
|
||||
requests: Arc<Mutex<Vec<ResponsesRequest>>>,
|
||||
@@ -1238,6 +1240,28 @@ pub async fn start_websocket_server_with_headers(
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
|
||||
if realtime_webrtc_server::accept_is_http_post(&stream).await {
|
||||
let connection_index = {
|
||||
let mut log = requests.lock().unwrap();
|
||||
log.push(Vec::new());
|
||||
log.len() - 1
|
||||
};
|
||||
realtime_webrtc_server::serve_connection(
|
||||
stream,
|
||||
connection,
|
||||
connection_index,
|
||||
Arc::clone(&requests),
|
||||
Arc::clone(&handshakes),
|
||||
Arc::clone(&request_log),
|
||||
)
|
||||
.await;
|
||||
|
||||
if connections.lock().unwrap().is_empty() {
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let response_headers = connection.response_headers.clone();
|
||||
let handshake_log = Arc::clone(&handshakes);
|
||||
let callback = move |req: &Request, mut response: Response| {
|
||||
|
||||
431
codex-rs/core/tests/common/responses/realtime_webrtc_server.rs
Normal file
431
codex-rs/core/tests/common/responses/realtime_webrtc_server.rs
Normal file
@@ -0,0 +1,431 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
use opus_rs::OpusDecoder;
|
||||
use serde_json::Value;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use tracing::debug;
|
||||
use tracing::warn;
|
||||
use webrtc::api::APIBuilder;
|
||||
use webrtc::api::interceptor_registry::register_default_interceptors;
|
||||
use webrtc::api::media_engine::MediaEngine;
|
||||
use webrtc::data_channel::RTCDataChannel;
|
||||
use webrtc::data_channel::data_channel_message::DataChannelMessage;
|
||||
use webrtc::interceptor::registry::Registry;
|
||||
use webrtc::peer_connection::RTCPeerConnection;
|
||||
use webrtc::peer_connection::configuration::RTCConfiguration;
|
||||
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
|
||||
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
||||
use webrtc::rtp_transceiver::rtp_codec::RTPCodecType;
|
||||
use webrtc::track::track_remote::TrackRemote;
|
||||
|
||||
use super::WebSocketConnectionConfig;
|
||||
use super::WebSocketHandshake;
|
||||
use super::WebSocketRequest;
|
||||
|
||||
const HTTP_HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";
|
||||
const REALTIME_AUDIO_CHANNELS: u8 = 1;
|
||||
const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000;
|
||||
const REALTIME_DATA_CHANNEL_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const REALTIME_DATA_CHANNEL_FLUSH_DELAY: Duration = Duration::from_millis(25);
|
||||
const REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL: usize = 5760;
|
||||
|
||||
pub(super) async fn accept_is_http_post(stream: &TcpStream) -> bool {
|
||||
let mut method = [0u8; 4];
|
||||
matches!(stream.peek(&mut method).await, Ok(4)) && method == *b"POST"
|
||||
}
|
||||
|
||||
pub(super) async fn serve_connection(
|
||||
mut stream: TcpStream,
|
||||
connection: WebSocketConnectionConfig,
|
||||
connection_index: usize,
|
||||
requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
handshakes: Arc<Mutex<Vec<WebSocketHandshake>>>,
|
||||
request_log_updated: Arc<Notify>,
|
||||
) {
|
||||
let Some(request) = read_http_request(&mut stream).await else {
|
||||
return;
|
||||
};
|
||||
|
||||
handshakes.lock().unwrap().push(WebSocketHandshake {
|
||||
uri: request.uri,
|
||||
headers: request.headers,
|
||||
});
|
||||
|
||||
let Some(offer_sdp) = parse_multipart_field(&request.body, &request.boundary, "sdp") else {
|
||||
let _ = stream
|
||||
.write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n")
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(session) = start_session(
|
||||
offer_sdp,
|
||||
connection,
|
||||
connection_index,
|
||||
requests,
|
||||
request_log_updated,
|
||||
)
|
||||
.await
|
||||
else {
|
||||
let _ = stream
|
||||
.write_all(b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n")
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: application/sdp\r\nContent-Length: {}\r\n\r\n{}",
|
||||
session.answer_sdp.len(),
|
||||
session.answer_sdp
|
||||
);
|
||||
if stream.write_all(response.as_bytes()).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = session.done_rx.await;
|
||||
let _ = session.peer_connection.close().await;
|
||||
}
|
||||
|
||||
/// A parsed realtime HTTP POST request (the multipart SDP-offer upload).
struct HttpRealtimeRequest {
    // Request-target from the request line, e.g. `/v1/realtime/calls`.
    uri: String,
    // All header name/value pairs, trimmed, in arrival order.
    headers: Vec<(String, String)>,
    // Multipart boundary extracted from the Content-Type header.
    boundary: String,
    // Raw request body (exactly Content-Length bytes).
    body: Vec<u8>,
}
|
||||
|
||||
/// A negotiated WebRTC session plus the handle used to await its completion.
struct RealtimeSession {
    // Local SDP answer to return in the HTTP response body.
    answer_sdp: String,
    // Kept so `serve_connection` can close the peer when the session ends.
    peer_connection: Arc<RTCPeerConnection>,
    // Fires once the scripted-request task has finished.
    done_rx: oneshot::Receiver<()>,
}
|
||||
|
||||
/// Reads a complete HTTP/1.1 POST request from `stream`.
///
/// Buffers bytes until the header terminator is seen, parses the request
/// line and headers, then keeps reading until `Content-Length` bytes of body
/// have arrived. Returns `None` on EOF, I/O error, a non-POST method, or a
/// missing `Content-Length` / multipart boundary.
async fn read_http_request(stream: &mut TcpStream) -> Option<HttpRealtimeRequest> {
    let mut received = Vec::new();
    // Read until the blank line that terminates the header section; the
    // break value is the byte offset where the body starts.
    let headers_end = loop {
        if let Some(headers_end) = received
            .windows(HTTP_HEADER_TERMINATOR.len())
            .position(|window| window == HTTP_HEADER_TERMINATOR)
        {
            break headers_end + HTTP_HEADER_TERMINATOR.len();
        }

        let mut chunk = [0u8; 1024];
        let read = stream.read(&mut chunk).await.ok()?;
        if read == 0 {
            // EOF before the headers completed.
            return None;
        }
        received.extend_from_slice(&chunk[..read]);
    };

    let header_text = std::str::from_utf8(&received[..headers_end]).ok()?;
    let mut lines = header_text.split("\r\n").filter(|line| !line.is_empty());
    let request_line = lines.next()?;
    let mut request_line_parts = request_line.split_whitespace();
    // Only POST requests carry SDP offers; reject anything else.
    if request_line_parts.next()? != "POST" {
        return None;
    }
    let uri = request_line_parts.next()?.to_string();

    let mut headers = Vec::new();
    let mut content_length = None;
    let mut boundary = None;
    for line in lines {
        let Some((name, value)) = line.split_once(':') else {
            continue;
        };
        let name = name.trim().to_string();
        let value = value.trim().to_string();
        if name.eq_ignore_ascii_case("content-length") {
            content_length = value.parse::<usize>().ok();
        }
        if name.eq_ignore_ascii_case("content-type") {
            // Pull the multipart boundary out of e.g.
            // `multipart/form-data; boundary="xyz"`.
            boundary = value
                .split(';')
                .map(str::trim)
                .find_map(|part| part.strip_prefix("boundary="))
                .map(|boundary| boundary.trim_matches('"').to_string());
        }
        headers.push((name, value));
    }

    // Keep reading until the full body (per Content-Length) is buffered.
    let content_length = content_length?;
    while received.len() - headers_end < content_length {
        let mut chunk = [0u8; 1024];
        let read = stream.read(&mut chunk).await.ok()?;
        if read == 0 {
            return None;
        }
        received.extend_from_slice(&chunk[..read]);
    }

    let body_end = headers_end + content_length;
    let body = received[headers_end..body_end].to_vec();
    Some(HttpRealtimeRequest {
        uri,
        headers,
        // `boundary?` makes a missing boundary a parse failure.
        boundary: boundary?,
        body,
    })
}
|
||||
|
||||
fn parse_multipart_field(body: &[u8], boundary: &str, field_name: &str) -> Option<String> {
|
||||
let body = std::str::from_utf8(body).ok()?;
|
||||
let delimiter = format!("--{boundary}");
|
||||
body.split(&delimiter).find_map(|part| {
|
||||
let (headers, value) = part.split_once("\r\n\r\n")?;
|
||||
if !headers.contains(&format!("name=\"{field_name}\"")) {
|
||||
return None;
|
||||
}
|
||||
Some(value.trim_end_matches("\r\n").to_string())
|
||||
})
|
||||
}
|
||||
|
||||
/// Negotiates a WebRTC session from the client's `offer_sdp` and spawns the
/// background task that replays scripted responses over the data channel.
///
/// Callback wiring (installed before answering):
/// - `on_data_channel` hands the *first* data channel to the scripted task
///   via a oneshot and forwards every JSON text message into `tx_request`;
/// - the data channel's `on_close` and the peer-connection state handler
///   both notify `connection_closed` so the scripted task can stop waiting.
///
/// Returns `None` if any step of the SDP negotiation fails.
async fn start_session(
    offer_sdp: String,
    connection: WebSocketConnectionConfig,
    connection_index: usize,
    requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
    request_log_updated: Arc<Notify>,
) -> Option<RealtimeSession> {
    let peer_connection = create_peer_connection().await?;
    let (tx_request, rx_request) = mpsc::unbounded_channel::<Value>();
    let (tx_data_channel, rx_data_channel) = oneshot::channel::<Arc<RTCDataChannel>>();
    // Mutex<Option<..>> because the callback may fire more than once but the
    // oneshot sender can only be consumed once.
    let tx_data_channel = Mutex::new(Some(tx_data_channel));
    let tx_data_channel_request = tx_request.clone();
    let connection_closed = Arc::new(Notify::new());
    let on_data_channel_closed = Arc::clone(&connection_closed);
    let on_peer_connection_closed = Arc::clone(&connection_closed);

    peer_connection.on_data_channel(Box::new(move |data_channel| {
        let tx_request = tx_data_channel_request.clone();
        let on_data_channel_closed = Arc::clone(&on_data_channel_closed);
        // Deliver only the first data channel to the scripted task.
        if let Ok(mut tx_data_channel) = tx_data_channel.lock()
            && let Some(tx_data_channel) = tx_data_channel.take()
        {
            let _ = tx_data_channel.send(Arc::clone(&data_channel));
        }
        data_channel.on_close(Box::new(move || {
            let on_data_channel_closed = Arc::clone(&on_data_channel_closed);
            Box::pin(async move {
                on_data_channel_closed.notify_waiters();
            })
        }));
        data_channel.on_message(Box::new(move |message: DataChannelMessage| {
            let tx_request = tx_request.clone();
            Box::pin(async move {
                // Only JSON text frames count as client requests; binary
                // frames and invalid JSON are ignored.
                if !message.is_string {
                    return;
                }
                let Ok(text) = String::from_utf8(message.data.to_vec()) else {
                    return;
                };
                let Ok(body) = serde_json::from_str::<Value>(&text) else {
                    return;
                };
                let _ = tx_request.send(body);
            })
        }));
        Box::pin(async {})
    }));
    peer_connection.on_peer_connection_state_change(Box::new(move |state| {
        let on_peer_connection_closed = Arc::clone(&on_peer_connection_closed);
        Box::pin(async move {
            // Any terminal/degraded state counts as transport closure.
            if matches!(
                state,
                RTCPeerConnectionState::Closed
                    | RTCPeerConnectionState::Disconnected
                    | RTCPeerConnectionState::Failed
            ) {
                on_peer_connection_closed.notify_waiters();
            }
        })
    }));

    // Remote audio is decoded and logged as synthetic append requests.
    register_remote_audio_handler(&peer_connection, tx_request.clone());

    // Standard SDP offer/answer exchange; wait for ICE gathering so the
    // answer SDP contains the final candidates.
    let mut gather_complete = peer_connection.gathering_complete_promise().await;
    let offer = RTCSessionDescription::offer(offer_sdp).ok()?;
    peer_connection.set_remote_description(offer).await.ok()?;
    let answer = peer_connection.create_answer(None).await.ok()?;
    peer_connection.set_local_description(answer).await.ok()?;
    let _ = gather_complete.recv().await;
    let answer_sdp = peer_connection.local_description().await?.sdp;

    let (done_tx, done_rx) = oneshot::channel();
    tokio::spawn(async move {
        serve_scripted_requests(
            connection,
            connection_index,
            requests,
            request_log_updated,
            rx_request,
            rx_data_channel,
            connection_closed,
        )
        .await;
        // Signal `serve_connection` that the session is over.
        let _ = done_tx.send(());
    });

    Some(RealtimeSession {
        answer_sdp,
        peer_connection,
        done_rx,
    })
}
|
||||
|
||||
/// Builds a webrtc-rs peer connection with default codecs and interceptors.
///
/// Returns `None` if codec/interceptor registration or connection creation
/// fails; this is a test server, so failures are simply swallowed.
async fn create_peer_connection() -> Option<Arc<RTCPeerConnection>> {
    // webrtc's DTLS stack needs a rustls crypto provider installed once per
    // process before any connection is created.
    ensure_rustls_crypto_provider();

    let mut media_engine = MediaEngine::default();
    media_engine.register_default_codecs().ok()?;
    let registry = register_default_interceptors(Registry::new(), &mut media_engine).ok()?;
    let api = APIBuilder::new()
        .with_media_engine(media_engine)
        .with_interceptor_registry(registry)
        .build();
    api.new_peer_connection(RTCConfiguration::default())
        .await
        .map(Arc::new)
        .ok()
}
|
||||
|
||||
/// Installs an `on_track` handler that forwards every remote *audio* track
/// into `pump_remote_audio_track`; non-audio tracks are ignored.
fn register_remote_audio_handler(
    peer_connection: &Arc<RTCPeerConnection>,
    tx_request: mpsc::UnboundedSender<Value>,
) {
    peer_connection.on_track(Box::new(move |track, _, _| {
        let tx_request = tx_request.clone();
        Box::pin(async move {
            if track.kind() != RTPCodecType::Audio {
                return;
            }
            // Runs until the track ends or a decode error occurs.
            pump_remote_audio_track(track, tx_request).await;
        })
    }));
}
|
||||
|
||||
/// Reads RTP packets from a remote audio track, Opus-decodes them, and
/// forwards the PCM16 audio as synthetic `input_audio_buffer.append` request
/// JSON on `tx_request` so tests can observe what the client "spoke".
///
/// Returns on the first decode failure or once `read_rtp` errors
/// (track ended / connection closed).
async fn pump_remote_audio_track(
    track: Arc<TrackRemote>,
    tx_request: mpsc::UnboundedSender<Value>,
) {
    let mut decoder = match OpusDecoder::new(24_000, usize::from(REALTIME_AUDIO_CHANNELS)) {
        Ok(decoder) => decoder,
        Err(err) => {
            warn!(%err, "failed to initialize realtime Opus decoder in test server");
            return;
        }
    };
    debug!("initialized realtime Opus decoder in test server");
    // Reusable decode buffer sized for the largest possible Opus frame.
    let mut decoded = vec![0.0f32; REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL];

    while let Ok((packet, _)) = track.read_rtp().await {
        // Empty payloads (e.g. padding-only packets) carry no audio.
        if packet.payload.is_empty() {
            continue;
        }
        let samples_per_channel = match decoder.decode(
            &packet.payload,
            REALTIME_MAX_DECODED_SAMPLES_PER_CHANNEL,
            &mut decoded,
        ) {
            Ok(samples_per_channel) => samples_per_channel,
            Err(err) => {
                warn!(
                    %err,
                    payload_len = packet.payload.len(),
                    "failed to decode realtime Opus packet in test server"
                );
                return;
            }
        };
        if samples_per_channel == 0 {
            continue;
        }

        // Re-encode the decoded floats as little-endian PCM16 bytes.
        let mut pcm_bytes = Vec::with_capacity(samples_per_channel * 2);
        for sample in &decoded[..samples_per_channel] {
            pcm_bytes.extend_from_slice(&f32_to_i16(*sample).to_le_bytes());
        }
        let _ = tx_request.send(serde_json::json!({
            "type": "input_audio_buffer.append",
            "audio": BASE64_STANDARD.encode(pcm_bytes),
            "sample_rate": REALTIME_AUDIO_SAMPLE_RATE,
            "channels": REALTIME_AUDIO_CHANNELS,
            "samples_per_channel": samples_per_channel,
        }));
    }
}
|
||||
|
||||
/// Converts a normalized [-1.0, 1.0] float sample to PCM16, saturating
/// out-of-range input instead of wrapping.
fn f32_to_i16(sample: f32) -> i16 {
    let bounded = sample.clamp(-1.0, 1.0);
    (bounded * f32::from(i16::MAX)) as i16
}
|
||||
|
||||
/// Replays the connection's scripted responses: for each scripted request
/// slot, waits for one client request (or transport closure), logs it, and
/// sends that slot's scripted events back over the data channel.
///
/// Abandons the session if the data channel never opens within
/// `REALTIME_DATA_CHANNEL_TIMEOUT`.
async fn serve_scripted_requests(
    connection: WebSocketConnectionConfig,
    connection_index: usize,
    requests: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
    request_log_updated: Arc<Notify>,
    mut rx_request: mpsc::UnboundedReceiver<Value>,
    rx_data_channel: oneshot::Receiver<Arc<RTCDataChannel>>,
    connection_closed: Arc<Notify>,
) {
    let Ok(Ok(data_channel)) = timeout(REALTIME_DATA_CHANNEL_TIMEOUT, rx_data_channel).await else {
        return;
    };

    let mut scripted_requests = VecDeque::from(connection.requests);
    while let Some(request_events) = scripted_requests.pop_front() {
        // WebRTC compact-remote tests often close the session before consuming every scripted
        // request slot. Treat transport closure as end-of-script instead of waiting forever for
        // another request that can no longer arrive.
        let body = tokio::select! {
            body = rx_request.recv() => body,
            _ = connection_closed.notified() => None,
        };
        let Some(body) = body else {
            break;
        };
        log_request(connection_index, body, &requests, &request_log_updated);
        // Send every scripted event for this slot; unserializable events are
        // skipped, but a dead channel ends the session.
        for event in &request_events {
            let Ok(payload) = serde_json::to_string(event) else {
                continue;
            };
            if data_channel.send_text(payload).await.is_err() {
                return;
            }
        }
    }

    if connection.close_after_requests {
        // Give the final events a moment to flush before closing the channel.
        tokio::time::sleep(REALTIME_DATA_CHANNEL_FLUSH_DELAY).await;
        let _ = data_channel.close().await;
    }
}
|
||||
|
||||
fn log_request(
|
||||
connection_index: usize,
|
||||
body: Value,
|
||||
requests: &Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
request_log_updated: &Arc<Notify>,
|
||||
) {
|
||||
let mut log = requests.lock().unwrap();
|
||||
if log.len() <= connection_index {
|
||||
log.resize_with(connection_index + 1, Vec::new);
|
||||
}
|
||||
if let Some(connection_log) = log.get_mut(connection_index) {
|
||||
connection_log.push(WebSocketRequest { body });
|
||||
}
|
||||
drop(log);
|
||||
request_log_updated.notify_waiters();
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
|
||||
use chrono::Utc;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_login::OPENAI_API_KEY_ENV_VAR;
|
||||
@@ -41,6 +43,17 @@ const MEMORY_PROMPT_PHRASE: &str =
|
||||
"You have access to a memory folder with guidance from prior runs.";
|
||||
const REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR: &str =
|
||||
"CODEX_REALTIME_CONVERSATION_TEST_SUBPROCESS";
|
||||
|
||||
fn realtime_pcm_test_tone_20ms_base64() -> String {
|
||||
let pcm_bytes: Vec<u8> = (0..480)
|
||||
.flat_map(|index| {
|
||||
let sample = if index % 2 == 0 { 1024_i16 } else { -1024_i16 };
|
||||
sample.to_le_bytes()
|
||||
})
|
||||
.collect();
|
||||
BASE64_STANDARD.encode(pcm_bytes)
|
||||
}
|
||||
|
||||
fn websocket_request_text(
|
||||
request: &core_test_support::responses::WebSocketRequest,
|
||||
) -> Option<String> {
|
||||
@@ -207,7 +220,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
|
||||
frame: RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
data: realtime_pcm_test_tone_20ms_base64(),
|
||||
sample_rate: 24000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(480),
|
||||
@@ -254,10 +267,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
|
||||
server.handshakes()[1].header("authorization").as_deref(),
|
||||
Some("Bearer dummy")
|
||||
);
|
||||
assert_eq!(
|
||||
server.handshakes()[1].uri(),
|
||||
"/v1/realtime?intent=quicksilver&model=realtime-test-model"
|
||||
);
|
||||
assert_eq!(server.handshakes()[1].uri(), "/v1/realtime/calls");
|
||||
let mut request_types = [
|
||||
connection[1].body_json()["type"]
|
||||
.as_str()
|
||||
@@ -426,7 +436,7 @@ async fn conversation_audio_before_start_emits_error() -> Result<()> {
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
|
||||
frame: RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
data: realtime_pcm_test_tone_20ms_base64(),
|
||||
sample_rate: 24000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(480),
|
||||
@@ -625,7 +635,7 @@ async fn conversation_second_start_replaces_runtime() -> Result<()> {
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
|
||||
frame: RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
data: realtime_pcm_test_tone_20ms_base64(),
|
||||
sample_rate: 24000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(480),
|
||||
@@ -1585,7 +1595,7 @@ async fn inbound_handoff_request_clears_active_transcript_after_each_handoff() -
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
|
||||
frame: RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
data: realtime_pcm_test_tone_20ms_base64(),
|
||||
sample_rate: 24000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(480),
|
||||
@@ -2073,7 +2083,7 @@ async fn inbound_handoff_request_steers_active_turn() -> Result<()> {
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
|
||||
frame: RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
data: realtime_pcm_test_tone_20ms_base64(),
|
||||
sample_rate: 24000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(480),
|
||||
|
||||
@@ -112,6 +112,7 @@ codex-windows-sandbox = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["time"] }
|
||||
|
||||
[target.'cfg(not(target_os = "linux"))'.dependencies]
|
||||
aec3 = { workspace = true }
|
||||
cpal = "0.15"
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use super::*;
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
use crate::realtime_audio_processing::RealtimeAudioProcessor;
|
||||
use codex_protocol::protocol::ConversationStartParams;
|
||||
use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
use codex_protocol::protocol::RealtimeConversationClosedEvent;
|
||||
@@ -30,6 +32,8 @@ pub(super) struct RealtimeConversationUiState {
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
capture_stop_flag: Option<Arc<AtomicBool>>,
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
audio_processor: Option<RealtimeAudioProcessor>,
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
capture: Option<crate::voice::VoiceCapture>,
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
audio_player: Option<crate::voice::RealtimeAudioPlayer>,
|
||||
@@ -331,9 +335,26 @@ impl ChatWidget {
|
||||
fn enqueue_realtime_audio_out(&mut self, frame: &RealtimeAudioFrame) {
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
{
|
||||
if !self.realtime_conversation.is_active() {
|
||||
return;
|
||||
}
|
||||
if self.realtime_conversation.audio_player.is_none() {
|
||||
self.realtime_conversation.audio_player =
|
||||
crate::voice::RealtimeAudioPlayer::start(&self.config).ok();
|
||||
let Some(audio_processor) = self.realtime_conversation.audio_processor.clone()
|
||||
else {
|
||||
self.fail_realtime_conversation(
|
||||
"Realtime audio processor was unavailable".to_string(),
|
||||
);
|
||||
return;
|
||||
};
|
||||
match crate::voice::RealtimeAudioPlayer::start(&self.config, audio_processor) {
|
||||
Ok(player) => self.realtime_conversation.audio_player = Some(player),
|
||||
Err(err) => {
|
||||
self.fail_realtime_conversation(format!(
|
||||
"Failed to start speaker output: {err}"
|
||||
));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(player) = &self.realtime_conversation.audio_player
|
||||
&& let Err(err) = player.enqueue_frame(frame)
|
||||
@@ -367,12 +388,42 @@ impl ChatWidget {
|
||||
self.realtime_conversation.meter_placeholder_id = Some(placeholder_id.clone());
|
||||
self.request_redraw();
|
||||
|
||||
let audio_processor = match RealtimeAudioProcessor::new() {
|
||||
Ok(audio_processor) => audio_processor,
|
||||
Err(err) => {
|
||||
self.realtime_conversation.meter_placeholder_id = None;
|
||||
self.remove_recording_meter_placeholder(&placeholder_id);
|
||||
self.fail_realtime_conversation(format!(
|
||||
"Failed to start realtime audio processor: {err}"
|
||||
));
|
||||
return;
|
||||
}
|
||||
};
|
||||
self.realtime_conversation.audio_processor = Some(audio_processor.clone());
|
||||
|
||||
let audio_player =
|
||||
match crate::voice::RealtimeAudioPlayer::start(&self.config, audio_processor.clone()) {
|
||||
Ok(player) => player,
|
||||
Err(err) => {
|
||||
self.realtime_conversation.audio_processor = None;
|
||||
self.realtime_conversation.meter_placeholder_id = None;
|
||||
self.remove_recording_meter_placeholder(&placeholder_id);
|
||||
self.fail_realtime_conversation(format!(
|
||||
"Failed to start speaker output: {err}"
|
||||
));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let capture = match crate::voice::VoiceCapture::start_realtime(
|
||||
&self.config,
|
||||
self.app_event_tx.clone(),
|
||||
audio_processor,
|
||||
) {
|
||||
Ok(capture) => capture,
|
||||
Err(err) => {
|
||||
drop(audio_player);
|
||||
self.realtime_conversation.audio_processor = None;
|
||||
self.realtime_conversation.meter_placeholder_id = None;
|
||||
self.remove_recording_meter_placeholder(&placeholder_id);
|
||||
self.fail_realtime_conversation(format!(
|
||||
@@ -389,10 +440,7 @@ impl ChatWidget {
|
||||
|
||||
self.realtime_conversation.capture_stop_flag = Some(stop_flag.clone());
|
||||
self.realtime_conversation.capture = Some(capture);
|
||||
if self.realtime_conversation.audio_player.is_none() {
|
||||
self.realtime_conversation.audio_player =
|
||||
crate::voice::RealtimeAudioPlayer::start(&self.config).ok();
|
||||
}
|
||||
self.realtime_conversation.audio_player = Some(audio_player);
|
||||
|
||||
std::thread::spawn(move || {
|
||||
let mut meter = crate::voice::RecordingMeterState::new();
|
||||
@@ -423,23 +471,10 @@ impl ChatWidget {
|
||||
}
|
||||
|
||||
match kind {
|
||||
RealtimeAudioDeviceKind::Microphone => {
|
||||
self.stop_realtime_microphone();
|
||||
RealtimeAudioDeviceKind::Microphone | RealtimeAudioDeviceKind::Speaker => {
|
||||
self.stop_realtime_local_audio();
|
||||
self.start_realtime_local_audio();
|
||||
}
|
||||
RealtimeAudioDeviceKind::Speaker => {
|
||||
self.stop_realtime_speaker();
|
||||
match crate::voice::RealtimeAudioPlayer::start(&self.config) {
|
||||
Ok(player) => {
|
||||
self.realtime_conversation.audio_player = Some(player);
|
||||
}
|
||||
Err(err) => {
|
||||
self.fail_realtime_conversation(format!(
|
||||
"Failed to start speaker output: {err}"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
self.request_redraw();
|
||||
}
|
||||
@@ -453,6 +488,7 @@ impl ChatWidget {
|
||||
fn stop_realtime_local_audio(&mut self) {
|
||||
self.stop_realtime_microphone();
|
||||
self.stop_realtime_speaker();
|
||||
self.realtime_conversation.audio_processor = None;
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
|
||||
@@ -130,6 +130,8 @@ pub mod onboarding;
|
||||
mod oss_selection;
|
||||
mod pager_overlay;
|
||||
pub mod public_widgets;
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
mod realtime_audio_processing;
|
||||
mod render;
|
||||
mod resume_picker;
|
||||
mod selection_list;
|
||||
|
||||
298
codex-rs/tui/src/realtime_audio_processing.rs
Normal file
298
codex-rs/tui/src/realtime_audio_processing.rs
Normal file
@@ -0,0 +1,298 @@
|
||||
use aec3::voip::VoipAec3;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::mpsc;
|
||||
use std::sync::mpsc::Receiver;
|
||||
use std::sync::mpsc::Sender;
|
||||
use tracing::warn;
|
||||
|
||||
pub(crate) const AUDIO_PROCESSING_SAMPLE_RATE: u32 = 24_000;
|
||||
pub(crate) const AUDIO_PROCESSING_CHANNELS: u16 = 1;
|
||||
|
||||
enum AudioProcessorCommand {
|
||||
Capture(Vec<i16>),
|
||||
Render(Vec<i16>),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RealtimeAudioProcessor {
|
||||
capture_rx: Arc<Mutex<Option<Receiver<Vec<i16>>>>>,
|
||||
command_tx: Sender<AudioProcessorCommand>,
|
||||
}
|
||||
|
||||
impl RealtimeAudioProcessor {
|
||||
pub(crate) fn new() -> Result<Self, String> {
|
||||
build_pipeline()?;
|
||||
|
||||
let (command_tx, command_rx) = mpsc::channel();
|
||||
let (capture_tx, capture_rx) = mpsc::channel();
|
||||
std::thread::Builder::new()
|
||||
.name("codex-realtime-aec3".to_string())
|
||||
.spawn(move || run_processor_thread(command_rx, capture_tx))
|
||||
.map_err(|err| format!("failed to spawn realtime audio processor: {err}"))?;
|
||||
|
||||
Ok(Self {
|
||||
capture_rx: Arc::new(Mutex::new(Some(capture_rx))),
|
||||
command_tx,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn capture_stage(
|
||||
&self,
|
||||
input_sample_rate: u32,
|
||||
input_channels: u16,
|
||||
) -> Result<RealtimeCaptureAudioProcessor, String> {
|
||||
let capture_rx = self
|
||||
.capture_rx
|
||||
.lock()
|
||||
.ok()
|
||||
.and_then(|mut capture_rx| capture_rx.take())
|
||||
.ok_or_else(|| "realtime capture stage was already created".to_string())?;
|
||||
|
||||
Ok(RealtimeCaptureAudioProcessor {
|
||||
capture_rx,
|
||||
command_tx: self.command_tx.clone(),
|
||||
input_sample_rate,
|
||||
input_channels,
|
||||
processed_samples: VecDeque::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn render_stage(
|
||||
&self,
|
||||
output_sample_rate: u32,
|
||||
output_channels: u16,
|
||||
) -> RealtimeRenderAudioProcessor {
|
||||
RealtimeRenderAudioProcessor {
|
||||
command_tx: self.command_tx.clone(),
|
||||
output_sample_rate,
|
||||
output_channels,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct RealtimeCaptureAudioProcessor {
|
||||
capture_rx: Receiver<Vec<i16>>,
|
||||
command_tx: Sender<AudioProcessorCommand>,
|
||||
input_sample_rate: u32,
|
||||
input_channels: u16,
|
||||
processed_samples: VecDeque<i16>,
|
||||
}
|
||||
|
||||
impl RealtimeCaptureAudioProcessor {
|
||||
pub(crate) fn process_samples(&mut self, samples: &[i16]) -> Vec<i16> {
|
||||
let converted = convert_pcm16(
|
||||
samples,
|
||||
self.input_sample_rate,
|
||||
self.input_channels,
|
||||
AUDIO_PROCESSING_SAMPLE_RATE,
|
||||
AUDIO_PROCESSING_CHANNELS,
|
||||
);
|
||||
if !converted.is_empty()
|
||||
&& let Err(err) = self
|
||||
.command_tx
|
||||
.send(AudioProcessorCommand::Capture(converted))
|
||||
{
|
||||
warn!("failed to queue realtime capture audio: {err}");
|
||||
}
|
||||
|
||||
loop {
|
||||
match self.capture_rx.try_recv() {
|
||||
Ok(processed) => self.processed_samples.extend(processed),
|
||||
Err(mpsc::TryRecvError::Empty) => break,
|
||||
Err(mpsc::TryRecvError::Disconnected) => {
|
||||
warn!("realtime capture audio processor disconnected");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.processed_samples.drain(..).collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct RealtimeRenderAudioProcessor {
|
||||
command_tx: Sender<AudioProcessorCommand>,
|
||||
output_sample_rate: u32,
|
||||
output_channels: u16,
|
||||
}
|
||||
|
||||
impl RealtimeRenderAudioProcessor {
|
||||
pub(crate) fn process_samples(&mut self, samples: &[i16]) {
|
||||
let converted = convert_pcm16(
|
||||
samples,
|
||||
self.output_sample_rate,
|
||||
self.output_channels,
|
||||
AUDIO_PROCESSING_SAMPLE_RATE,
|
||||
AUDIO_PROCESSING_CHANNELS,
|
||||
);
|
||||
if !converted.is_empty()
|
||||
&& let Err(err) = self
|
||||
.command_tx
|
||||
.send(AudioProcessorCommand::Render(converted))
|
||||
{
|
||||
warn!("failed to queue realtime render audio: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_pipeline() -> Result<VoipAec3, String> {
|
||||
VoipAec3::builder(
|
||||
AUDIO_PROCESSING_SAMPLE_RATE as usize,
|
||||
usize::from(AUDIO_PROCESSING_CHANNELS),
|
||||
usize::from(AUDIO_PROCESSING_CHANNELS),
|
||||
)
|
||||
.enable_high_pass(false)
|
||||
.enable_noise_suppression(false)
|
||||
.build()
|
||||
.map_err(|err| format!("failed to initialize realtime audio processor: {err}"))
|
||||
}
|
||||
|
||||
fn run_processor_thread(command_rx: Receiver<AudioProcessorCommand>, capture_tx: Sender<Vec<i16>>) {
|
||||
let mut pipeline = match build_pipeline() {
|
||||
Ok(pipeline) => pipeline,
|
||||
Err(err) => {
|
||||
warn!("{err}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let capture_frame_len =
|
||||
pipeline.capture_frame_samples() * usize::from(AUDIO_PROCESSING_CHANNELS);
|
||||
let render_frame_len = pipeline.render_frame_samples() * usize::from(AUDIO_PROCESSING_CHANNELS);
|
||||
let mut pending_capture = VecDeque::new();
|
||||
let mut pending_render = VecDeque::new();
|
||||
|
||||
while let Ok(command) = command_rx.recv() {
|
||||
match command {
|
||||
AudioProcessorCommand::Capture(samples) => {
|
||||
pending_capture.extend(samples);
|
||||
while pending_capture.len() >= capture_frame_len {
|
||||
let capture_frame =
|
||||
drain_pending_frame(&mut pending_capture, capture_frame_len);
|
||||
let capture_frame = capture_frame
|
||||
.iter()
|
||||
.copied()
|
||||
.map(i16_to_f32)
|
||||
.collect::<Vec<_>>();
|
||||
let mut output = vec![0.0; capture_frame.len()];
|
||||
if let Err(err) = pipeline.process_capture_frame(
|
||||
&capture_frame,
|
||||
/*level_change*/ false,
|
||||
&mut output,
|
||||
) {
|
||||
warn!("failed to process realtime capture audio: {err}");
|
||||
continue;
|
||||
}
|
||||
|
||||
let processed = output.into_iter().map(f32_to_i16).collect::<Vec<_>>();
|
||||
if let Err(err) = capture_tx.send(processed) {
|
||||
warn!("failed to deliver realtime capture audio: {err}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
AudioProcessorCommand::Render(samples) => {
|
||||
pending_render.extend(samples);
|
||||
while pending_render.len() >= render_frame_len {
|
||||
let render_frame = drain_pending_frame(&mut pending_render, render_frame_len);
|
||||
let render_frame = render_frame
|
||||
.iter()
|
||||
.copied()
|
||||
.map(i16_to_f32)
|
||||
.collect::<Vec<_>>();
|
||||
if let Err(err) = pipeline.handle_render_frame(&render_frame) {
|
||||
warn!("failed to process realtime render audio: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn drain_pending_frame(pending_samples: &mut VecDeque<i16>, frame_len: usize) -> Vec<i16> {
|
||||
pending_samples.drain(..frame_len).collect()
|
||||
}
|
||||
|
||||
/// Converts interleaved PCM16 between sample rates and channel layouts.
///
/// Resampling is nearest-index (no interpolation): each output frame maps
/// onto one source frame. Channel conversion: mono fans out by duplication,
/// N->1 averages all source channels, surplus source channels are truncated,
/// and missing output channels are padded with the last source channel.
/// Returns an empty vec for empty input or a zero channel count.
pub(crate) fn convert_pcm16(
    input: &[i16],
    input_sample_rate: u32,
    input_channels: u16,
    output_sample_rate: u32,
    output_channels: u16,
) -> Vec<i16> {
    if input.is_empty() || input_channels == 0 || output_channels == 0 {
        return Vec::new();
    }

    let in_channels = usize::from(input_channels);
    let out_channels = usize::from(output_channels);
    let in_frames = input.len() / in_channels;
    if in_frames == 0 {
        return Vec::new();
    }

    // Scale the frame count by the rate ratio; always emit at least one
    // frame when there is input.
    let out_frames = if input_sample_rate == output_sample_rate {
        in_frames
    } else {
        let scaled =
            (in_frames as u64) * (output_sample_rate as u64) / (input_sample_rate as u64);
        scaled.max(1) as usize
    };

    let mut converted = Vec::with_capacity(out_frames.saturating_mul(out_channels));
    for frame_idx in 0..out_frames {
        // Map the output frame index onto the source timeline so the first
        // and last frames always line up.
        let src_frame_idx = if out_frames <= 1 || in_frames <= 1 {
            0
        } else {
            ((frame_idx as u64) * ((in_frames - 1) as u64) / ((out_frames - 1) as u64)) as usize
        };
        let start = src_frame_idx.saturating_mul(in_channels);
        let frame = &input[start..start + in_channels];

        if in_channels == out_channels {
            converted.extend_from_slice(frame);
        } else if in_channels == 1 {
            // Mono fan-out: copy the single sample into every output channel.
            converted.extend(std::iter::repeat(frame[0]).take(out_channels));
        } else if out_channels == 1 {
            // Downmix by averaging all source channels in i32 to avoid
            // overflow before narrowing back to i16.
            let sum: i32 = frame.iter().map(|sample| i32::from(*sample)).sum();
            converted.push((sum / (in_channels as i32)) as i16);
        } else if in_channels > out_channels {
            converted.extend_from_slice(&frame[..out_channels]);
        } else {
            // Fewer source channels: keep them all, pad with the last one.
            converted.extend_from_slice(frame);
            let last = *frame.last().unwrap_or(&0);
            converted.extend(std::iter::repeat(last).take(out_channels - in_channels));
        }
    }
    converted
}
|
||||
|
||||
fn i16_to_f32(sample: i16) -> f32 {
|
||||
(sample as f32) / (i16::MAX as f32)
|
||||
}
|
||||
|
||||
fn f32_to_i16(sample: f32) -> i16 {
|
||||
(sample.clamp(-1.0, 1.0) * i16::MAX as f32) as i16
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::convert_pcm16;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn convert_pcm16_downmixes_and_resamples_for_model_input() {
|
||||
let input = vec![100, 300, 200, 400, 500, 700, 600, 800];
|
||||
let converted = convert_pcm16(
|
||||
&input, /*input_sample_rate*/ 48_000, /*input_channels*/ 2,
|
||||
/*output_sample_rate*/ 24_000, /*output_channels*/ 1,
|
||||
);
|
||||
assert_eq!(converted, vec![200, 700]);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,10 @@
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
use crate::realtime_audio_processing::AUDIO_PROCESSING_CHANNELS;
|
||||
use crate::realtime_audio_processing::AUDIO_PROCESSING_SAMPLE_RATE;
|
||||
use crate::realtime_audio_processing::RealtimeAudioProcessor;
|
||||
use crate::realtime_audio_processing::RealtimeCaptureAudioProcessor;
|
||||
use crate::realtime_audio_processing::RealtimeRenderAudioProcessor;
|
||||
use crate::realtime_audio_processing::convert_pcm16;
|
||||
use base64::Engine;
|
||||
use codex_core::config::Config;
|
||||
use codex_protocol::protocol::ConversationAudioParams;
|
||||
@@ -23,7 +29,11 @@ pub struct VoiceCapture {
|
||||
}
|
||||
|
||||
impl VoiceCapture {
|
||||
pub fn start_realtime(config: &Config, tx: AppEventSender) -> Result<Self, String> {
|
||||
pub fn start_realtime(
|
||||
config: &Config,
|
||||
tx: AppEventSender,
|
||||
audio_processor: RealtimeAudioProcessor,
|
||||
) -> Result<Self, String> {
|
||||
let (device, config) = select_realtime_input_device_and_config(config)?;
|
||||
|
||||
let sample_rate = config.sample_rate().0;
|
||||
@@ -34,9 +44,8 @@ impl VoiceCapture {
|
||||
let stream = build_realtime_input_stream(
|
||||
&device,
|
||||
&config,
|
||||
sample_rate,
|
||||
channels,
|
||||
tx,
|
||||
audio_processor.capture_stage(sample_rate, channels)?,
|
||||
last_peak.clone(),
|
||||
)?;
|
||||
stream
|
||||
@@ -138,50 +147,76 @@ fn select_realtime_input_device_and_config(
|
||||
fn build_realtime_input_stream(
|
||||
device: &cpal::Device,
|
||||
config: &cpal::SupportedStreamConfig,
|
||||
sample_rate: u32,
|
||||
channels: u16,
|
||||
tx: AppEventSender,
|
||||
capture_processor: RealtimeCaptureAudioProcessor,
|
||||
last_peak: Arc<AtomicU16>,
|
||||
) -> Result<cpal::Stream, String> {
|
||||
match config.sample_format() {
|
||||
cpal::SampleFormat::F32 => device
|
||||
.build_input_stream(
|
||||
&config.clone().into(),
|
||||
move |input: &[f32], _| {
|
||||
let peak = peak_f32(input);
|
||||
last_peak.store(peak, Ordering::Relaxed);
|
||||
let samples = input.iter().copied().map(f32_to_i16).collect::<Vec<_>>();
|
||||
send_realtime_audio_chunk(&tx, samples, sample_rate, channels);
|
||||
},
|
||||
move |err| error!("audio input error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build input stream: {e}")),
|
||||
cpal::SampleFormat::I16 => device
|
||||
.build_input_stream(
|
||||
&config.clone().into(),
|
||||
move |input: &[i16], _| {
|
||||
let peak = peak_i16(input);
|
||||
last_peak.store(peak, Ordering::Relaxed);
|
||||
send_realtime_audio_chunk(&tx, input.to_vec(), sample_rate, channels);
|
||||
},
|
||||
move |err| error!("audio input error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build input stream: {e}")),
|
||||
cpal::SampleFormat::U16 => device
|
||||
.build_input_stream(
|
||||
&config.clone().into(),
|
||||
move |input: &[u16], _| {
|
||||
let mut samples = Vec::with_capacity(input.len());
|
||||
let peak = convert_u16_to_i16_and_peak(input, &mut samples);
|
||||
last_peak.store(peak, Ordering::Relaxed);
|
||||
send_realtime_audio_chunk(&tx, samples, sample_rate, channels);
|
||||
},
|
||||
move |err| error!("audio input error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build input stream: {e}")),
|
||||
cpal::SampleFormat::F32 => {
|
||||
let mut capture_processor = capture_processor;
|
||||
device
|
||||
.build_input_stream(
|
||||
&config.clone().into(),
|
||||
move |input: &[f32], _| {
|
||||
let peak = peak_f32(input);
|
||||
last_peak.store(peak, Ordering::Relaxed);
|
||||
let samples = input.iter().copied().map(f32_to_i16).collect::<Vec<_>>();
|
||||
let samples = capture_processor.process_samples(&samples);
|
||||
send_realtime_audio_chunk(
|
||||
&tx,
|
||||
samples,
|
||||
AUDIO_PROCESSING_SAMPLE_RATE,
|
||||
AUDIO_PROCESSING_CHANNELS,
|
||||
);
|
||||
},
|
||||
move |err| error!("audio input error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build input stream: {e}"))
|
||||
}
|
||||
cpal::SampleFormat::I16 => {
|
||||
let mut capture_processor = capture_processor;
|
||||
device
|
||||
.build_input_stream(
|
||||
&config.clone().into(),
|
||||
move |input: &[i16], _| {
|
||||
let peak = peak_i16(input);
|
||||
last_peak.store(peak, Ordering::Relaxed);
|
||||
let samples = capture_processor.process_samples(input);
|
||||
send_realtime_audio_chunk(
|
||||
&tx,
|
||||
samples,
|
||||
AUDIO_PROCESSING_SAMPLE_RATE,
|
||||
AUDIO_PROCESSING_CHANNELS,
|
||||
);
|
||||
},
|
||||
move |err| error!("audio input error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build input stream: {e}"))
|
||||
}
|
||||
cpal::SampleFormat::U16 => {
|
||||
let mut capture_processor = capture_processor;
|
||||
device
|
||||
.build_input_stream(
|
||||
&config.clone().into(),
|
||||
move |input: &[u16], _| {
|
||||
let mut samples = Vec::with_capacity(input.len());
|
||||
let peak = convert_u16_to_i16_and_peak(input, &mut samples);
|
||||
last_peak.store(peak, Ordering::Relaxed);
|
||||
let samples = capture_processor.process_samples(&samples);
|
||||
send_realtime_audio_chunk(
|
||||
&tx,
|
||||
samples,
|
||||
AUDIO_PROCESSING_SAMPLE_RATE,
|
||||
AUDIO_PROCESSING_CHANNELS,
|
||||
);
|
||||
},
|
||||
move |err| error!("audio input error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build input stream: {e}"))
|
||||
}
|
||||
_ => Err("unsupported input sample format".to_string()),
|
||||
}
|
||||
}
|
||||
@@ -288,13 +323,21 @@ pub(crate) struct RealtimeAudioPlayer {
|
||||
}
|
||||
|
||||
impl RealtimeAudioPlayer {
|
||||
pub(crate) fn start(config: &Config) -> Result<Self, String> {
|
||||
pub(crate) fn start(
|
||||
config: &Config,
|
||||
audio_processor: RealtimeAudioProcessor,
|
||||
) -> Result<Self, String> {
|
||||
let (device, config) =
|
||||
crate::audio_device::select_configured_output_device_and_config(config)?;
|
||||
let output_sample_rate = config.sample_rate().0;
|
||||
let output_channels = config.channels();
|
||||
let queue = Arc::new(Mutex::new(VecDeque::new()));
|
||||
let stream = build_output_stream(&device, &config, Arc::clone(&queue))?;
|
||||
let stream = build_output_stream(
|
||||
&device,
|
||||
&config,
|
||||
Arc::clone(&queue),
|
||||
audio_processor.render_stage(output_sample_rate, output_channels),
|
||||
)?;
|
||||
stream
|
||||
.play()
|
||||
.map_err(|e| format!("failed to start output stream: {e}"))?;
|
||||
@@ -350,140 +393,94 @@ fn build_output_stream(
|
||||
device: &cpal::Device,
|
||||
config: &cpal::SupportedStreamConfig,
|
||||
queue: Arc<Mutex<VecDeque<i16>>>,
|
||||
render_processor: RealtimeRenderAudioProcessor,
|
||||
) -> Result<cpal::Stream, String> {
|
||||
let config_any: cpal::StreamConfig = config.clone().into();
|
||||
match config.sample_format() {
|
||||
cpal::SampleFormat::F32 => device
|
||||
.build_output_stream(
|
||||
&config_any,
|
||||
move |output: &mut [f32], _| fill_output_f32(output, &queue),
|
||||
move |err| error!("audio output error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build f32 output stream: {e}")),
|
||||
cpal::SampleFormat::I16 => device
|
||||
.build_output_stream(
|
||||
&config_any,
|
||||
move |output: &mut [i16], _| fill_output_i16(output, &queue),
|
||||
move |err| error!("audio output error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build i16 output stream: {e}")),
|
||||
cpal::SampleFormat::U16 => device
|
||||
.build_output_stream(
|
||||
&config_any,
|
||||
move |output: &mut [u16], _| fill_output_u16(output, &queue),
|
||||
move |err| error!("audio output error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build u16 output stream: {e}")),
|
||||
cpal::SampleFormat::F32 => {
|
||||
let mut render_processor = render_processor;
|
||||
device
|
||||
.build_output_stream(
|
||||
&config_any,
|
||||
move |output: &mut [f32], _| {
|
||||
fill_output_f32(output, &queue, &mut render_processor)
|
||||
},
|
||||
move |err| error!("audio output error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build f32 output stream: {e}"))
|
||||
}
|
||||
cpal::SampleFormat::I16 => {
|
||||
let mut render_processor = render_processor;
|
||||
device
|
||||
.build_output_stream(
|
||||
&config_any,
|
||||
move |output: &mut [i16], _| {
|
||||
fill_output_i16(output, &queue, &mut render_processor)
|
||||
},
|
||||
move |err| error!("audio output error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build i16 output stream: {e}"))
|
||||
}
|
||||
cpal::SampleFormat::U16 => {
|
||||
let mut render_processor = render_processor;
|
||||
device
|
||||
.build_output_stream(
|
||||
&config_any,
|
||||
move |output: &mut [u16], _| {
|
||||
fill_output_u16(output, &queue, &mut render_processor)
|
||||
},
|
||||
move |err| error!("audio output error: {err}"),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| format!("failed to build u16 output stream: {e}"))
|
||||
}
|
||||
other => Err(format!("unsupported output sample format: {other:?}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn fill_output_i16(output: &mut [i16], queue: &Arc<Mutex<VecDeque<i16>>>) {
|
||||
if let Ok(mut guard) = queue.lock() {
|
||||
for sample in output {
|
||||
*sample = guard.pop_front().unwrap_or(0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
output.fill(0);
|
||||
fn fill_output_i16(
|
||||
output: &mut [i16],
|
||||
queue: &Arc<Mutex<VecDeque<i16>>>,
|
||||
render_processor: &mut RealtimeRenderAudioProcessor,
|
||||
) {
|
||||
let samples = drain_output_samples(output.len(), queue);
|
||||
output.copy_from_slice(&samples);
|
||||
render_processor.process_samples(output);
|
||||
}
|
||||
|
||||
fn fill_output_f32(output: &mut [f32], queue: &Arc<Mutex<VecDeque<i16>>>) {
|
||||
if let Ok(mut guard) = queue.lock() {
|
||||
for sample in output {
|
||||
let v = guard.pop_front().unwrap_or(0);
|
||||
*sample = (v as f32) / (i16::MAX as f32);
|
||||
}
|
||||
return;
|
||||
fn fill_output_f32(
|
||||
output: &mut [f32],
|
||||
queue: &Arc<Mutex<VecDeque<i16>>>,
|
||||
render_processor: &mut RealtimeRenderAudioProcessor,
|
||||
) {
|
||||
let samples = drain_output_samples(output.len(), queue);
|
||||
for (output_sample, sample) in output.iter_mut().zip(samples.iter()) {
|
||||
*output_sample = (*sample as f32) / (i16::MAX as f32);
|
||||
}
|
||||
output.fill(0.0);
|
||||
render_processor.process_samples(&samples);
|
||||
}
|
||||
|
||||
fn fill_output_u16(output: &mut [u16], queue: &Arc<Mutex<VecDeque<i16>>>) {
|
||||
if let Ok(mut guard) = queue.lock() {
|
||||
for sample in output {
|
||||
let v = guard.pop_front().unwrap_or(0);
|
||||
*sample = (v as i32 + 32768).clamp(0, u16::MAX as i32) as u16;
|
||||
}
|
||||
return;
|
||||
fn fill_output_u16(
|
||||
output: &mut [u16],
|
||||
queue: &Arc<Mutex<VecDeque<i16>>>,
|
||||
render_processor: &mut RealtimeRenderAudioProcessor,
|
||||
) {
|
||||
let samples = drain_output_samples(output.len(), queue);
|
||||
for (output_sample, sample) in output.iter_mut().zip(samples.iter()) {
|
||||
*output_sample = (*sample as i32 + 32768).clamp(0, u16::MAX as i32) as u16;
|
||||
}
|
||||
output.fill(32768);
|
||||
render_processor.process_samples(&samples);
|
||||
}
|
||||
|
||||
fn convert_pcm16(
|
||||
input: &[i16],
|
||||
input_sample_rate: u32,
|
||||
input_channels: u16,
|
||||
output_sample_rate: u32,
|
||||
output_channels: u16,
|
||||
) -> Vec<i16> {
|
||||
if input.is_empty() || input_channels == 0 || output_channels == 0 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let in_channels = input_channels as usize;
|
||||
let out_channels = output_channels as usize;
|
||||
let in_frames = input.len() / in_channels;
|
||||
if in_frames == 0 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let out_frames = if input_sample_rate == output_sample_rate {
|
||||
in_frames
|
||||
} else {
|
||||
(((in_frames as u64) * (output_sample_rate as u64)) / (input_sample_rate as u64)).max(1)
|
||||
as usize
|
||||
fn drain_output_samples(output_len: usize, queue: &Arc<Mutex<VecDeque<i16>>>) -> Vec<i16> {
|
||||
let mut samples = vec![0; output_len];
|
||||
let Ok(mut guard) = queue.lock() else {
|
||||
return samples;
|
||||
};
|
||||
|
||||
let mut out = Vec::with_capacity(out_frames.saturating_mul(out_channels));
|
||||
for out_frame_idx in 0..out_frames {
|
||||
let src_frame_idx = if out_frames <= 1 || in_frames <= 1 {
|
||||
0
|
||||
} else {
|
||||
((out_frame_idx as u64) * ((in_frames - 1) as u64) / ((out_frames - 1) as u64)) as usize
|
||||
};
|
||||
let src_start = src_frame_idx.saturating_mul(in_channels);
|
||||
let src = &input[src_start..src_start + in_channels];
|
||||
match (in_channels, out_channels) {
|
||||
(1, 1) => out.push(src[0]),
|
||||
(1, n) => {
|
||||
for _ in 0..n {
|
||||
out.push(src[0]);
|
||||
}
|
||||
}
|
||||
(n, 1) if n >= 2 => {
|
||||
let sum: i32 = src.iter().map(|s| *s as i32).sum();
|
||||
out.push((sum / (n as i32)) as i16);
|
||||
}
|
||||
(n, m) if n == m => out.extend_from_slice(src),
|
||||
(n, m) if n > m => out.extend_from_slice(&src[..m]),
|
||||
(n, m) => {
|
||||
out.extend_from_slice(src);
|
||||
let last = *src.last().unwrap_or(&0);
|
||||
for _ in n..m {
|
||||
out.push(last);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::convert_pcm16;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn convert_pcm16_downmixes_and_resamples_for_model_input() {
|
||||
let input = vec![100, 300, 200, 400, 500, 700, 600, 800];
|
||||
let converted = convert_pcm16(
|
||||
&input, /*input_sample_rate*/ 48_000, /*input_channels*/ 2,
|
||||
/*output_sample_rate*/ 24_000, /*output_channels*/ 1,
|
||||
);
|
||||
assert_eq!(converted, vec![200, 700]);
|
||||
for sample in &mut samples {
|
||||
*sample = guard.pop_front().unwrap_or(0);
|
||||
}
|
||||
samples
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user