diff --git a/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs b/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs index d125784483..dafdd9eb88 100644 --- a/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs +++ b/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs @@ -51,7 +51,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> { vec![], vec![ json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AQID", "sample_rate": 24_000, "channels": 1, diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs index 60cb5d2c31..84b0cd0983 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs @@ -10,6 +10,12 @@ use crate::endpoint::realtime_websocket::protocol::SessionAudio; use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat; use crate::endpoint::realtime_websocket::protocol::SessionAudioInput; use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput; +use crate::endpoint::realtime_websocket::protocol::SessionAudioOutputFormat; +use crate::endpoint::realtime_websocket::protocol::SessionTool; +use crate::endpoint::realtime_websocket::protocol::SessionToolParameters; +use crate::endpoint::realtime_websocket::protocol::SessionToolProperties; +use crate::endpoint::realtime_websocket::protocol::SessionToolProperty; +use crate::endpoint::realtime_websocket::protocol::SessionTurnDetection; use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession; use crate::endpoint::realtime_websocket::protocol::parse_realtime_event; use crate::error::ApiError; @@ -19,6 +25,7 @@ use futures::SinkExt; use futures::StreamExt; use http::HeaderMap; use http::HeaderValue; +use serde_json::json; use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::AtomicBool; @@ -228,6 +235,20 @@ impl RealtimeWebsocketConnection { .await } + pub async fn send_function_call_output( + &self, + call_id: String, + output_text: String, + ) -> Result<(), ApiError> { + self.writer + .send_function_call_output(call_id, output_text) + .await + } + + pub async fn send_response_create(&self) -> Result<(), ApiError> { + self.writer.send_response_create().await + } + pub async fn close(&self) -> Result<(), ApiError> { self.writer.close().await } @@ -272,11 +293,10 @@ impl RealtimeWebsocketWriter { pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> { self.send_json(RealtimeOutboundMessage::ConversationItemCreate { - item: ConversationItem { - kind: "message".to_string(), + item: ConversationItem::Message { role: "user".to_string(), content: vec![ConversationItemContent { - kind: "text".to_string(), + kind: "input_text".to_string(), text, }], }, @@ -286,32 +306,85 @@ impl RealtimeWebsocketWriter { pub async fn send_conversation_handoff_append( &self, - handoff_id: String, + _handoff_id: String, output_text: String, ) -> Result<(), ApiError> { - self.send_json(RealtimeOutboundMessage::ConversationHandoffAppend { - handoff_id, - output_text, + self.send_json(RealtimeOutboundMessage::ConversationItemCreate { + item: ConversationItem::Message { + role: "assistant".to_string(), + content: vec![ConversationItemContent { + kind: "output_text".to_string(), + text: output_text, + }], + }, }) .await } + pub async fn send_function_call_output( + &self, + call_id: String, + output_text: String, + ) -> Result<(), ApiError> { + let output = json!({ + "content": output_text, + }) + .to_string(); + self.send_json(RealtimeOutboundMessage::ConversationItemCreate { + item: ConversationItem::FunctionCallOutput { call_id, output }, + }) + .await + } + + pub async fn send_response_create(&self) -> Result<(), ApiError> { + self.send_json(RealtimeOutboundMessage::ResponseCreate) + .await + } + pub async fn send_session_update(&self, instructions: String) -> Result<(), ApiError> { self.send_json(RealtimeOutboundMessage::SessionUpdate { session: SessionUpdateSession { - kind: "quicksilver".to_string(), + kind: "realtime".to_string(), instructions, + output_modalities: vec!["audio".to_string()], audio: SessionAudio { input: SessionAudioInput { format: SessionAudioFormat { kind: "audio/pcm".to_string(), rate: 24_000, }, + turn_detection: SessionTurnDetection { + kind: "server_vad".to_string(), + interrupt_response: true, + create_response: true, + }, }, output: SessionAudioOutput { - voice: "fathom".to_string(), + format: SessionAudioOutputFormat { + kind: "audio/pcm".to_string(), + rate: 24_000, + }, + voice: "marin".to_string(), }, }, + tools: vec![SessionTool { + kind: "function".to_string(), + name: "codex".to_string(), + description: + "Delegate a request to Codex and return the final result to the user." + .to_string(), + parameters: SessionToolParameters { + kind: "object".to_string(), + properties: SessionToolProperties { + prompt: SessionToolProperty { + kind: "string".to_string(), + description: "The user request to delegate to Codex.".to_string(), + }, + }, + required: vec!["prompt".to_string()], + }, + }], + tool_choice: "auto".to_string(), }, }) .await @@ -558,15 +631,16 @@ fn websocket_url_from_api_url( } } - { + let has_additional_query_params = query_params + .is_some_and(|params| params.keys().any(|key| key != "model" || model.is_none())); + if model.is_some() || has_additional_query_params { let mut query = url.query_pairs_mut(); - query.append_pair("intent", "quicksilver"); if let Some(model) = model { query.append_pair("model", model); } if let Some(query_params) = query_params { for (key, value) in query_params { - if key == "intent" || (key == "model" && model.is_some()) { + if key == "model" && model.is_some() { continue; } query.append_pair(key, value); @@ -639,7 +713,7 @@ mod tests { #[test] fn parse_audio_delta_event() { let payload = json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AAA=", "sample_rate": 48000, "channels": 1, @@ -657,6 +731,24 @@ mod tests { ); } + #[test] + fn parse_audio_delta_event_defaults_audio_shape() { + let payload = json!({ + "type": "response.output_audio.delta", + "delta": "AAA=" + }) + .to_string(); + assert_eq!( + parse_realtime_event(payload.as_str()), + Some(RealtimeEvent::AudioOut(RealtimeAudioFrame { + data: "AAA=".to_string(), + sample_rate: 24_000, + num_channels: 1, + samples_per_channel: None, + })) + ); + } + #[test] fn parse_conversation_item_added_event() { let payload = json!({ @@ -690,10 +782,18 @@ mod tests { #[test] fn parse_handoff_requested_event() { let payload = json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_123", - "item_id": "item_123", - "input_transcript": "delegate this" + "type": "response.done", + "response": { + "output": [ + { + "id": "item_123", + "type": "function_call", + "name": "codex", + "call_id": "handoff_123", + "arguments": "{\"prompt\":\"delegate this\"}" + } + ] + } }) .to_string(); @@ -708,6 +808,24 @@ mod tests { ); } + #[test] + fn parse_unknown_event_as_conversation_item_added() { + let payload = json!({ + "type": "response.output_text.delta", + "delta": "hello", + "response_id": "resp_1" + }) + .to_string(); + assert_eq!( + parse_realtime_event(payload.as_str()), + Some(RealtimeEvent::ConversationItemAdded(json!({ + "type": "response.output_text.delta", + "delta": "hello", + "response_id": "resp_1" + }))) + ); + } + #[test] fn parse_input_transcript_delta_event() { let payload = json!({ @@ -781,10 +899,7 @@ mod tests { fn websocket_url_from_http_base_defaults_to_ws_path() { let url = websocket_url_from_api_url("http://127.0.0.1:8011", None, None).expect("build ws url"); - assert_eq!( - url.as_str(), - "ws://127.0.0.1:8011/v1/realtime?intent=quicksilver" - ); + assert_eq!(url.as_str(), "ws://127.0.0.1:8011/v1/realtime"); } #[test] @@ -794,7 +909,7 @@ mod tests { .expect("build ws url"); assert_eq!( url.as_str(), - "wss://example.com/v1/realtime?intent=quicksilver&model=realtime-test-model" + "wss://example.com/v1/realtime?model=realtime-test-model" ); } @@ -804,7 +919,7 @@ mod tests { .expect("build ws url"); assert_eq!( url.as_str(), - "wss://api.openai.com/v1/realtime?intent=quicksilver&model=snapshot" + "wss://api.openai.com/v1/realtime?model=snapshot" ); } @@ -815,7 +930,7 @@ mod tests { .expect("build ws url"); assert_eq!( url.as_str(), - "wss://example.com/openai/v1/realtime?intent=quicksilver&model=snapshot" + "wss://example.com/openai/v1/realtime?model=snapshot" ); } @@ -823,16 +938,13 @@ mod tests { fn websocket_url_preserves_existing_realtime_path_and_extra_query_params() { let url = websocket_url_from_api_url( "https://example.com/v1/realtime?foo=bar", - Some(&HashMap::from([ - ("trace".to_string(), "1".to_string()), - ("intent".to_string(), "ignored".to_string()), - ])), + Some(&HashMap::from([("trace".to_string(), "1".to_string())])), Some("snapshot"), ) .expect("build ws url"); assert_eq!( url.as_str(), - "wss://example.com/v1/realtime?foo=bar&intent=quicksilver&model=snapshot&trace=1" + "wss://example.com/v1/realtime?foo=bar&model=snapshot&trace=1" ); } @@ -856,12 +968,16 @@ mod tests { assert_eq!(first_json["type"], "session.update"); assert_eq!( first_json["session"]["type"], - Value::String("quicksilver".to_string()) + Value::String("realtime".to_string()) ); assert_eq!( first_json["session"]["instructions"], Value::String("backend prompt".to_string()) ); + assert_eq!( + first_json["session"]["output_modalities"][0], + Value::String("audio".to_string()) + ); assert_eq!( first_json["session"]["audio"]["input"]["format"]["type"], Value::String("audio/pcm".to_string()) @@ -870,9 +986,45 @@ mod tests { first_json["session"]["audio"]["input"]["format"]["rate"], Value::from(24_000) ); + assert_eq!( + first_json["session"]["audio"]["input"]["turn_detection"]["type"], + Value::String("server_vad".to_string()) + ); + assert_eq!( + first_json["session"]["audio"]["input"]["turn_detection"]["interrupt_response"], + Value::Bool(true) + ); + assert_eq!( + first_json["session"]["audio"]["input"]["turn_detection"]["create_response"], + Value::Bool(true) + ); + assert_eq!( + first_json["session"]["audio"]["output"]["format"]["type"], + Value::String("audio/pcm".to_string()) + ); + assert_eq!( + first_json["session"]["audio"]["output"]["format"]["rate"], + Value::from(24_000) + ); assert_eq!( first_json["session"]["audio"]["output"]["voice"], - Value::String("fathom".to_string()) + Value::String("marin".to_string()) + ); + assert_eq!( + first_json["session"]["tool_choice"], + Value::String("auto".to_string()) + ); + assert_eq!( + first_json["session"]["tools"][0]["type"], + Value::String("function".to_string()) + ); + assert_eq!( + first_json["session"]["tools"][0]["name"], + Value::String("codex".to_string()) + ); + assert_eq!( + first_json["session"]["tools"][0]["parameters"]["required"][0], + Value::String("prompt".to_string()) ); ws.send(Message::Text( @@ -915,13 +1067,43 @@ mod tests { .into_text() .expect("text"); let fourth_json: Value = serde_json::from_str(&fourth).expect("json"); - assert_eq!(fourth_json["type"], "conversation.handoff.append"); - assert_eq!(fourth_json["handoff_id"], "handoff_1"); - assert_eq!(fourth_json["output_text"], "hello from codex"); + assert_eq!(fourth_json["type"], "conversation.item.create"); + assert_eq!(fourth_json["item"]["type"], "message"); + assert_eq!(fourth_json["item"]["role"], "assistant"); + assert_eq!( + fourth_json["item"]["content"][0]["type"], + Value::String("output_text".to_string()) + ); + assert_eq!( + fourth_json["item"]["content"][0]["text"], + Value::String("hello from codex".to_string()) + ); + + let fifth = ws + .next() + .await + .expect("fifth msg") + .expect("fifth msg ok") + .into_text() + .expect("text"); + let fifth_json: Value = serde_json::from_str(&fifth).expect("json"); + assert_eq!(fifth_json["type"], "conversation.item.create"); + assert_eq!(fifth_json["item"]["type"], "function_call_output"); + assert_eq!(fifth_json["item"]["call_id"], "handoff_1"); + + let sixth = ws + .next() + .await + .expect("sixth msg") + .expect("sixth msg ok") + .into_text() + .expect("text"); + let sixth_json: Value = serde_json::from_str(&sixth).expect("json"); + assert_eq!(sixth_json["type"], "response.create"); ws.send(Message::Text( json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AQID", "sample_rate": 48000, "channels": 1 @@ -967,10 +1149,18 @@ mod tests { ws.send(Message::Text( json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_1", - "item_id": "item_2", - "input_transcript": "delegate now" + "type": "response.done", + "response": { + "output": [ + { + "id": "item_2", + "type": "function_call", + "name": "codex", + "call_id": "handoff_1", + "arguments": "{\"prompt\":\"delegate now\"}" + } + ] + } }) .to_string() .into(), @@ -1040,6 +1230,14 @@ mod tests { ) .await .expect("send handoff"); + connection + .send_function_call_output("handoff_1".to_string(), "final from codex".to_string()) + .await + .expect("send function output"); + connection + .send_response_create() + .await + .expect("send response.create"); let audio_event = connection .next_event() diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs index 7967d59991..daac248b9a 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs @@ -3,8 +3,10 @@ pub use codex_protocol::protocol::RealtimeEvent; pub use codex_protocol::protocol::RealtimeHandoffRequested; pub use codex_protocol::protocol::RealtimeTranscriptDelta; pub use codex_protocol::protocol::RealtimeTranscriptEntry; +use serde::Deserialize; use serde::Serialize; use serde_json::Value; +use std::string::ToString; use tracing::debug; #[derive(Debug, Clone, PartialEq, Eq)] @@ -19,11 +21,8 @@ pub struct RealtimeSessionConfig { pub(super) enum RealtimeOutboundMessage { #[serde(rename = "input_audio_buffer.append")] InputAudioBufferAppend { audio: String }, - #[serde(rename = "conversation.handoff.append")] - ConversationHandoffAppend { - handoff_id: String, - output_text: String, - }, + #[serde(rename = "response.create")] + ResponseCreate, #[serde(rename = "session.update")] SessionUpdate { session: SessionUpdateSession }, #[serde(rename = "conversation.item.create")] @@ -35,7 +34,10 @@ pub(super) struct SessionUpdateSession { #[serde(rename = "type")] pub(super) kind: String, pub(super) instructions: String, + pub(super) output_modalities: Vec, pub(super) audio: SessionAudio, + pub(super) tools: Vec, + pub(super) tool_choice: String, } #[derive(Debug, Clone, Serialize)] @@ -47,6 +49,7 @@ pub(super) struct SessionAudio { #[derive(Debug, Clone, Serialize)] pub(super) struct SessionAudioInput { pub(super) format: SessionAudioFormat, + pub(super) turn_detection: SessionTurnDetection, } #[derive(Debug, Clone, Serialize)] @@ -56,17 +59,66 @@ pub(super) struct SessionAudioFormat { pub(super) rate: u32, } +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionTurnDetection { + #[serde(rename = "type")] + pub(super) kind: String, + pub(super) interrupt_response: bool, + pub(super) create_response: bool, +} + #[derive(Debug, Clone, Serialize)] pub(super) struct SessionAudioOutput { + pub(super) format: SessionAudioOutputFormat, pub(super) voice: String, } #[derive(Debug, Clone, Serialize)] -pub(super) struct ConversationItem { +pub(super) struct SessionAudioOutputFormat { #[serde(rename = "type")] pub(super) kind: String, - pub(super) role: String, - pub(super) content: Vec, + pub(super) rate: u32, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionTool { + #[serde(rename = "type")] + pub(super) kind: String, + pub(super) name: String, + pub(super) description: String, + pub(super) parameters: SessionToolParameters, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionToolParameters { + #[serde(rename = "type")] + pub(super) kind: String, + pub(super) properties: SessionToolProperties, + pub(super) required: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionToolProperties { + pub(super) prompt: SessionToolProperty, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionToolProperty { + #[serde(rename = "type")] + pub(super) kind: String, + pub(super) description: String, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type")] +pub(super) enum ConversationItem { + #[serde(rename = "message")] + Message { + role: String, + content: Vec, + }, + #[serde(rename = "function_call_output")] + FunctionCallOutput { call_id: String, output: String }, } #[derive(Debug, Clone, Serialize)] @@ -93,7 +145,7 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option { } }; match message_type { - "session.updated" => { + "session.created" | "session.updated" => { let session_id = parsed .get("session") .and_then(Value::as_object) @@ -111,7 +163,7 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option { instructions, }) } - "conversation.output_audio.delta" => { + "conversation.output_audio.delta" | "response.output_audio.delta" => { let data = parsed .get("delta") .and_then(Value::as_str) @@ -120,12 +172,14 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option { let sample_rate = parsed .get("sample_rate") .and_then(Value::as_u64) - .and_then(|v| u32::try_from(v).ok())?; + .and_then(|v| u32::try_from(v).ok()) + .unwrap_or(24_000); let num_channels = parsed .get("channels") .or_else(|| parsed.get("num_channels")) .and_then(Value::as_u64) - .and_then(|v| u16::try_from(v).ok())?; + .and_then(|v| u16::try_from(v).ok()) + .unwrap_or(1); Some(RealtimeEvent::AudioOut(RealtimeAudioFrame { data, sample_rate, @@ -177,6 +231,12 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option { active_transcript: Vec::new(), })) } + "response.done" => { + if let Some(handoff) = parse_handoff_requested(&parsed) { + return Some(RealtimeEvent::HandoffRequested(handoff)); + } + Some(RealtimeEvent::ConversationItemAdded(parsed)) + } "error" => parsed .get("message") .and_then(Value::as_str) @@ -189,11 +249,86 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option { .and_then(Value::as_str) .map(str::to_string) }) - .or_else(|| parsed.get("error").map(std::string::ToString::to_string)) + .or_else(|| parsed.get("error").map(ToString::to_string)) .map(RealtimeEvent::Error), - _ => { - debug!("received unsupported realtime event type: {message_type}, data: {payload}"); - None - } + _ => Some(RealtimeEvent::ConversationItemAdded(parsed)), } } + +fn parse_handoff_requested(parsed: &Value) -> Option { + let outputs = parsed + .get("response") + .and_then(Value::as_object) + .and_then(|response| response.get("output")) + .and_then(Value::as_array)?; + let function_call = outputs.iter().find(|item| { + item.get("type").and_then(Value::as_str) == Some("function_call") + && item.get("name").and_then(Value::as_str) == Some("codex") + })?; + let handoff_id = function_call + .get("call_id") + .and_then(Value::as_str) + .map(str::to_string)?; + let item_id = function_call + .get("id") + .and_then(Value::as_str) + .map(str::to_string) + .unwrap_or_else(|| handoff_id.clone()); + let arguments = function_call + .get("arguments") + .and_then(Value::as_str) + .unwrap_or_default(); + Some(RealtimeHandoffRequested { + handoff_id, + item_id, + input_transcript: parse_handoff_arguments(arguments), + active_transcript: Vec::new(), + }) +} + +fn parse_handoff_arguments(arguments: &str) -> String { + #[derive(Debug, Deserialize)] + struct HandoffArguments { + #[serde(default)] + prompt: Option, + #[serde(default)] + text: Option, + #[serde(default)] + input: Option, + #[serde(default)] + message: Option, + #[serde(default)] + input_transcript: Option, + #[serde(default)] + messages: Vec, + } + + let Some(parsed) = serde_json::from_str::(arguments).ok() else { + return arguments.to_string(); + }; + + for value in [ + parsed.prompt, + parsed.text, + parsed.input, + parsed.message, + parsed.input_transcript, + ] + .into_iter() + .flatten() + { + if !value.is_empty() { + return value; + } + } + + if let Some(message) = parsed + .messages + .into_iter() + .find(|message| message.role == "user" && !message.text.is_empty()) + { + return message.text; + } + + String::new() +} diff --git a/codex-rs/codex-api/tests/realtime_websocket_e2e.rs b/codex-rs/codex-api/tests/realtime_websocket_e2e.rs index aa11b79584..e79f0aac07 100644 --- a/codex-rs/codex-api/tests/realtime_websocket_e2e.rs +++ b/codex-rs/codex-api/tests/realtime_websocket_e2e.rs @@ -81,7 +81,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() { assert_eq!(first_json["type"], "session.update"); assert_eq!( first_json["session"]["type"], - Value::String("quicksilver".to_string()) + Value::String("realtime".to_string()) ); assert_eq!( first_json["session"]["instructions"], @@ -95,6 +95,42 @@ async fn realtime_ws_e2e_session_create_and_event_flow() { first_json["session"]["audio"]["input"]["format"]["rate"], Value::from(24_000) ); + assert_eq!( + first_json["session"]["audio"]["input"]["turn_detection"]["type"], + Value::String("server_vad".to_string()) + ); + assert_eq!( + first_json["session"]["audio"]["input"]["turn_detection"]["interrupt_response"], + Value::Bool(true) + ); + assert_eq!( + first_json["session"]["audio"]["input"]["turn_detection"]["create_response"], + Value::Bool(true) + ); + assert_eq!( + first_json["session"]["audio"]["output"]["format"]["type"], + Value::String("audio/pcm".to_string()) + ); + assert_eq!( + first_json["session"]["audio"]["output"]["format"]["rate"], + Value::from(24_000) + ); + assert_eq!( + first_json["session"]["audio"]["output"]["voice"], + Value::String("marin".to_string()) + ); + assert_eq!( + first_json["session"]["tool_choice"], + Value::String("auto".to_string()) + ); + assert_eq!( + first_json["session"]["tools"][0]["type"], + Value::String("function".to_string()) + ); + assert_eq!( + first_json["session"]["tools"][0]["name"], + Value::String("codex".to_string()) + ); ws.send(Message::Text( json!({ @@ -119,7 +155,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() { ws.send(Message::Text( json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AQID", "sample_rate": 48000, "channels": 1 @@ -311,7 +347,7 @@ async fn realtime_ws_e2e_disconnected_emitted_once() { } #[tokio::test] -async fn realtime_ws_e2e_ignores_unknown_text_events() { +async fn realtime_ws_e2e_forwards_unknown_text_events() { let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move { let first = ws .next() @@ -361,13 +397,26 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() { .await .expect("connect"); - let event = connection + let first_event = connection .next_event() .await .expect("next event") .expect("event"); assert_eq!( - event, + first_event, + RealtimeEvent::ConversationItemAdded(json!({ + "type": "response.created", + "response": {"id": "resp_unknown"} + })) + ); + + let second_event = connection + .next_event() + .await + .expect("next event") + .expect("event"); + assert_eq!( + second_event, RealtimeEvent::SessionUpdated { session_id: "sess_after_unknown".to_string(), instructions: Some("backend prompt".to_string()), diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 842b256fdf..2bc25b2997 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -2544,6 +2544,9 @@ impl Session { if !matches!(msg, EventMsg::TurnComplete(_)) { return; } + if let Err(err) = self.conversation.handoff_complete().await { + debug!("failed to send final realtime handoff tool output: {err}"); + } self.conversation.clear_active_handoff().await; } diff --git a/codex-rs/core/src/realtime_conversation.rs b/codex-rs/core/src/realtime_conversation.rs index 3baea265ad..fff7a495ab 100644 --- a/codex-rs/core/src/realtime_conversation.rs +++ b/codex-rs/core/src/realtime_conversation.rs @@ -45,6 +45,7 @@ const USER_TEXT_IN_QUEUE_CAPACITY: usize = 64; const HANDOFF_OUT_QUEUE_CAPACITY: usize = 64; const OUTPUT_EVENTS_QUEUE_CAPACITY: usize = 256; const REALTIME_STARTUP_CONTEXT_TOKEN_BUDGET: usize = 5_000; +const DEFAULT_REALTIME_MODEL: &str = "gpt-realtime-1.5"; pub(crate) struct RealtimeConversationManager { state: Mutex>, @@ -54,12 +55,19 @@ pub(crate) struct RealtimeConversationManager { struct RealtimeHandoffState { output_tx: Sender, active_handoff: Arc>>, + last_output_text: Arc>>, } #[derive(Debug, PartialEq, Eq)] -struct HandoffOutput { - handoff_id: String, - output_text: String, +enum HandoffOutput { + TextUpdate { + handoff_id: String, + output_text: String, + }, + FinalToolCall { + call_id: String, + output_text: String, + }, } impl RealtimeHandoffState { @@ -67,6 +75,7 @@ impl RealtimeHandoffState { Self { output_tx, active_handoff: Arc::new(Mutex::new(None)), + last_output_text: Arc::new(Mutex::new(None)), } } @@ -74,9 +83,10 @@ impl RealtimeHandoffState { let Some(handoff_id) = self.active_handoff.lock().await.clone() else { return Ok(()); }; + *self.last_output_text.lock().await = Some(output_text.clone()); self.output_tx - .send(HandoffOutput { + .send(HandoffOutput::TextUpdate { handoff_id, output_text, }) @@ -84,6 +94,23 @@ impl RealtimeHandoffState { .map_err(|_| CodexErr::InvalidRequest("conversation is not running".to_string()))?; Ok(()) } + + async fn send_final_output(&self) -> CodexResult<()> { + let Some(call_id) = self.active_handoff.lock().await.clone() else { + return Ok(()); + }; + let Some(output_text) = self.last_output_text.lock().await.clone() else { + return Ok(()); + }; + self.output_tx + .send(HandoffOutput::FinalToolCall { + call_id, + output_text, + }) + .await + .map_err(|_| CodexErr::InvalidRequest("conversation is not running".to_string()))?; + Ok(()) + } } #[allow(dead_code)] @@ -234,6 +261,17 @@ impl RealtimeConversationManager { handoff.send_output(output_text).await } + pub(crate) async fn handoff_complete(&self) -> CodexResult<()> { + let handoff = { + let guard = self.state.lock().await; + guard.as_ref().map(|state| state.handoff.clone()) + }; + let Some(handoff) = handoff else { + return Ok(()); + }; + handoff.send_final_output().await + } + pub(crate) async fn active_handoff_id(&self) -> Option { let handoff = { let guard = self.state.lock().await; @@ -249,6 +287,7 @@ impl RealtimeConversationManager { }; if let Some(handoff) = handoff { *handoff.active_handoff.lock().await = None; + *handoff.last_output_text.lock().await = None; } } @@ -297,7 +336,7 @@ pub(crate) async fn handle_start( } else { format!("{prompt}\n\n{startup_context}") }; - let model = config.experimental_realtime_ws_model.clone(); + let model = Some(DEFAULT_REALTIME_MODEL.to_string()); let requested_session_id = params .session_id @@ -512,17 +551,41 @@ fn spawn_realtime_input_task( } handoff_output = handoff_output_rx.recv() => { match handoff_output { - Ok(HandoffOutput { - handoff_id, - output_text, - }) => { - if let Err(err) = writer - .send_conversation_handoff_append(handoff_id, output_text) - .await - { - let mapped_error = map_api_error(err); - warn!("failed to send handoff output: {mapped_error}"); - break; + Ok(handoff_output) => { + match handoff_output { + HandoffOutput::TextUpdate { + handoff_id, + output_text, + } => { + if let Err(err) = writer + .send_conversation_handoff_append(handoff_id, output_text) + .await + { + let mapped_error = map_api_error(err); + warn!("failed to send handoff output: {mapped_error}"); + break; + } + } + HandoffOutput::FinalToolCall { + call_id, + output_text, + } => { + if let Err(err) = writer + .send_function_call_output(call_id, output_text) + .await + { + let mapped_error = map_api_error(err); + warn!("failed to send handoff tool output: {mapped_error}"); + break; + } + if let Err(err) = writer.send_response_create().await { + let mapped_error = map_api_error(err); + warn!( + "failed to send handoff response.create: {mapped_error}" + ); + break; + } + } } } Err(_) => break, @@ -534,6 +597,7 @@ fn spawn_realtime_input_task( if let RealtimeEvent::HandoffRequested(handoff) = &event { *handoff_state.active_handoff.lock().await = Some(handoff.handoff_id.clone()); + *handoff_state.last_output_text.lock().await = None; } let should_stop = matches!(&event, RealtimeEvent::Error(_)); if events_tx.send(event).await.is_err() { @@ -693,7 +757,7 @@ mod tests { let output_1 = rx.recv().await.expect("recv"); assert_eq!( output_1, - HandoffOutput { + HandoffOutput::TextUpdate { handoff_id: "handoff_1".to_string(), output_text: "result".to_string(), } @@ -702,7 +766,7 @@ mod tests { let output_2 = rx.recv().await.expect("recv"); assert_eq!( output_2, - HandoffOutput { + HandoffOutput::TextUpdate { handoff_id: "handoff_1".to_string(), output_text: "result 2".to_string(), } @@ -715,4 +779,27 @@ mod tests { .expect("send"); assert!(rx.is_empty()); } + + #[tokio::test] + async fn sends_final_tool_call_output_for_active_handoff() { + let (tx, rx) = bounded(4); + let state = RealtimeHandoffState::new(tx); + *state.active_handoff.lock().await = Some("handoff_2".to_string()); + + state + .send_output("final text".to_string()) + .await + .expect("send"); + let _ = rx.recv().await.expect("recv text update"); + + state.send_final_output().await.expect("send final output"); + let final_output = rx.recv().await.expect("recv final output"); + assert_eq!( + final_output, + HandoffOutput::FinalToolCall { + call_id: "handoff_2".to_string(), + output_text: "final text".to_string(), + } + ); + } } diff --git a/codex-rs/core/tests/suite/realtime_conversation.rs b/codex-rs/core/tests/suite/realtime_conversation.rs index 0d49f8c8d5..3750a5fa25 100644 --- a/codex-rs/core/tests/suite/realtime_conversation.rs +++ b/codex-rs/core/tests/suite/realtime_conversation.rs @@ -122,7 +122,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> { vec![], vec![ json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AQID", "sample_rate": 24000, "channels": 1 @@ -220,7 +220,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> { ); assert_eq!( server.handshakes()[1].uri(), - "/v1/realtime?intent=quicksilver&model=realtime-test-model" + "/v1/realtime?model=gpt-realtime-1.5" ); let mut request_types = [ connection[1].body_json()["type"] @@ -467,7 +467,7 @@ async fn conversation_second_start_replaces_runtime() -> Result<()> { "session": { "id": "sess_new", "instructions": "new" } })], vec![json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AQID", "sample_rate": 24000, "channels": 1 @@ -974,14 +974,11 @@ async fn conversation_mirrors_assistant_message_text_to_realtime_handoff() -> Re "type": "conversation.input_transcript.delta", "delta": "delegate hello" }), - json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_1", - "item_id": "item_1", - "input_transcript": "delegate hello" - }), + realtime_handoff_requested_event("handoff_1", "item_1", "delegate hello"), ], vec![], + vec![], + vec![], ]]) .await; @@ -1025,7 +1022,7 @@ async fn conversation_mirrors_assistant_message_text_to_realtime_handoff() -> Re let deadline = tokio::time::Instant::now() + Duration::from_secs(2); while tokio::time::Instant::now() < deadline { let connections = realtime_server.connections(); - if connections.len() == 1 && connections[0].len() >= 2 { + if connections.len() == 1 && connections[0].len() >= 4 { break; } tokio::time::sleep(Duration::from_millis(10)).await; @@ -1033,22 +1030,46 @@ async fn conversation_mirrors_assistant_message_text_to_realtime_handoff() -> Re let realtime_connections = realtime_server.connections(); assert_eq!(realtime_connections.len(), 1); - assert_eq!(realtime_connections[0].len(), 2); + assert_eq!(realtime_connections[0].len(), 4); assert_eq!( realtime_connections[0][0].body_json()["type"].as_str(), Some("session.update") ); assert_eq!( realtime_connections[0][1].body_json()["type"].as_str(), - Some("conversation.handoff.append") + Some("conversation.item.create") ); assert_eq!( - realtime_connections[0][1].body_json()["handoff_id"].as_str(), + realtime_connections[0][1].body_json()["item"]["type"].as_str(), + Some("message") + ); + assert_eq!( + realtime_connections[0][1].body_json()["item"]["role"].as_str(), + Some("assistant") + ); + assert_eq!( + realtime_connections[0][1].body_json()["item"]["content"][0]["type"].as_str(), + Some("output_text") + ); + assert_eq!( + realtime_connections[0][1].body_json()["item"]["content"][0]["text"].as_str(), + Some("assistant says hi") + ); + assert_eq!( + realtime_connections[0][2].body_json()["type"].as_str(), + Some("conversation.item.create") + ); + assert_eq!( + realtime_connections[0][2].body_json()["item"]["type"].as_str(), + Some("function_call_output") + ); + assert_eq!( + realtime_connections[0][2].body_json()["item"]["call_id"].as_str(), Some("handoff_1") ); assert_eq!( - realtime_connections[0][1].body_json()["output_text"].as_str(), - Some("assistant says hi") + realtime_connections[0][3].body_json()["type"].as_str(), + Some("response.create") ); realtime_server.shutdown().await; @@ -1096,18 +1117,15 @@ async fn conversation_handoff_persists_across_item_done_until_turn_complete() -> "type": "conversation.input_transcript.delta", "delta": "delegate now" }), - json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_item_done", - "item_id": "item_item_done", - "input_transcript": "delegate now" - }), + realtime_handoff_requested_event("handoff_item_done", "item_item_done", "delegate now"), ], vec![json!({ "type": "conversation.item.done", "item": { "id": "item_item_done" } })], vec![], + vec![], + vec![], ]]) .await; @@ -1145,14 +1163,22 @@ async fn conversation_handoff_persists_across_item_done_until_turn_complete() -> let first_append = realtime_server.wait_for_request(0, 1).await; assert_eq!( first_append.body_json()["type"].as_str(), - Some("conversation.handoff.append") + Some("conversation.item.create") ); assert_eq!( - first_append.body_json()["handoff_id"].as_str(), - Some("handoff_item_done") + first_append.body_json()["item"]["type"].as_str(), + Some("message") ); assert_eq!( - first_append.body_json()["output_text"].as_str(), + first_append.body_json()["item"]["role"].as_str(), + Some("assistant") + ); + assert_eq!( + first_append.body_json()["item"]["content"][0]["type"].as_str(), + Some("output_text") + ); + assert_eq!( + first_append.body_json()["item"]["content"][0]["text"].as_str(), Some("assistant message 1") ); @@ -1169,14 +1195,22 @@ async fn conversation_handoff_persists_across_item_done_until_turn_complete() -> let second_append = realtime_server.wait_for_request(0, 2).await; assert_eq!( second_append.body_json()["type"].as_str(), - Some("conversation.handoff.append") + Some("conversation.item.create") ); assert_eq!( - second_append.body_json()["handoff_id"].as_str(), - Some("handoff_item_done") + second_append.body_json()["item"]["type"].as_str(), + Some("message") ); assert_eq!( - second_append.body_json()["output_text"].as_str(), + second_append.body_json()["item"]["role"].as_str(), + Some("assistant") + ); + assert_eq!( + second_append.body_json()["item"]["content"][0]["type"].as_str(), + Some("output_text") + ); + assert_eq!( + second_append.body_json()["item"]["content"][0]["text"].as_str(), Some("assistant message 2") ); @@ -1192,6 +1226,30 @@ async fn conversation_handoff_persists_across_item_done_until_turn_complete() -> }) .await; + let final_tool_call = realtime_server.wait_for_request(0, 3).await; + assert_eq!( + final_tool_call.body_json()["type"].as_str(), + Some("conversation.item.create") + ); + assert_eq!( + final_tool_call.body_json()["item"]["type"].as_str(), + Some("function_call_output") + ); + assert_eq!( + final_tool_call.body_json()["item"]["call_id"].as_str(), + Some("handoff_item_done") + ); + assert_eq!( + final_tool_call.body_json()["item"]["output"].as_str(), + Some("{\"content\":\"assistant message 2\"}") + ); + + let response_create = realtime_server.wait_for_request(0, 4).await; + assert_eq!( + response_create.body_json()["type"].as_str(), + Some("response.create") + ); + realtime_server.shutdown().await; api_server.shutdown().await; Ok(()) @@ -1201,6 +1259,23 @@ fn sse_event(event: Value) -> String { responses::sse(vec![event]) } +fn realtime_handoff_requested_event(handoff_id: &str, item_id: &str, prompt: &str) -> Value { + json!({ + "type": "response.done", + "response": { + "output": [ + { + "id": item_id, + "type": "function_call", + "name": "codex", + "call_id": handoff_id, + "arguments": json!({ "prompt": prompt }).to_string(), + } + ] + } + }) +} + fn message_input_texts(body: &Value, role: &str) -> Vec { body.get("input") .and_then(Value::as_array) @@ -1239,12 +1314,7 @@ async fn inbound_handoff_request_starts_turn() -> Result<()> { "type": "conversation.input_transcript.delta", "delta": "text from realtime" }), - json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_inbound", - "item_id": "item_inbound", - "input_transcript": "text from realtime" - }), + realtime_handoff_requested_event("handoff_inbound", "item_inbound", "text from realtime"), ]]]) .await; @@ -1333,12 +1403,7 @@ async fn inbound_handoff_request_uses_active_transcript() -> Result<()> { "type": "conversation.output_transcript.delta", "delta": "assist confirm" }), - json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_inbound_multi", - "item_id": "item_inbound_multi", - "input_transcript": "ignored" - }), + realtime_handoff_requested_event("handoff_inbound_multi", "item_inbound_multi", "ignored"), ]]]) .await; @@ -1411,12 +1476,11 @@ async fn inbound_handoff_request_clears_active_transcript_after_each_handoff() - "type": "conversation.input_transcript.delta", "delta": "first question" }), - json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_inbound_clear_1", - "item_id": "item_inbound_clear_1", - "input_transcript": "first question" - }), + realtime_handoff_requested_event( + "handoff_inbound_clear_1", + "item_inbound_clear_1", + "first question", + ), ], vec![], vec![ @@ -1424,12 +1488,11 @@ async fn inbound_handoff_request_clears_active_transcript_after_each_handoff() - "type": "conversation.input_transcript.delta", "delta": "second question" }), - json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_inbound_clear_2", - "item_id": "item_inbound_clear_2", - "input_transcript": "second question" - }), + realtime_handoff_requested_event( + "handoff_inbound_clear_2", + "item_inbound_clear_2", + "second question", + ), ], ]]) .await; @@ -1524,7 +1587,7 @@ async fn inbound_conversation_item_does_not_start_turn_and_still_forwards_audio( } }), json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AQID", "sample_rate": 24000, "channels": 1 @@ -1618,24 +1681,23 @@ async fn delegated_turn_user_role_echo_does_not_redelegate_and_still_forwards_au "type": "conversation.input_transcript.delta", "delta": "delegate now" }), - json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_echo_guard", - "item_id": "item_echo_guard", - "input_transcript": "delegate now" - }), + realtime_handoff_requested_event( + "handoff_echo_guard", + "item_echo_guard", + "delegate now", + ), ], vec![ json!({ - "type": "conversation.item.added", - "item": { - "type": "message", + "type": "conversation.item.added", + "item": { + "type": "message", "role": "user", "content": [{"type": "text", "text": "assistant says hi"}] } }), json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AQID", "sample_rate": 24000, "channels": 1 @@ -1683,22 +1745,30 @@ async fn delegated_turn_user_role_echo_does_not_redelegate_and_still_forwards_au let mirrored_request = realtime_server.wait_for_request(0, 1).await; let mirrored_request_body = mirrored_request.body_json(); eprintln!( - "[realtime test +{}ms] saw mirrored request type={:?} handoff_id={:?} text={:?}", + "[realtime test +{}ms] saw mirrored request type={:?} role={:?} text={:?}", start.elapsed().as_millis(), mirrored_request_body["type"].as_str(), - mirrored_request_body["handoff_id"].as_str(), - mirrored_request_body["output_text"].as_str(), + mirrored_request_body["item"]["role"].as_str(), + mirrored_request_body["item"]["content"][0]["text"].as_str(), ); assert_eq!( mirrored_request_body["type"].as_str(), - Some("conversation.handoff.append") + Some("conversation.item.create") ); assert_eq!( - mirrored_request_body["handoff_id"].as_str(), - Some("handoff_echo_guard") + mirrored_request_body["item"]["type"].as_str(), + Some("message") ); assert_eq!( - mirrored_request_body["output_text"].as_str(), + mirrored_request_body["item"]["role"].as_str(), + Some("assistant") + ); + assert_eq!( + mirrored_request_body["item"]["content"][0]["type"].as_str(), + Some("output_text") + ); + assert_eq!( + mirrored_request_body["item"]["content"][0]["text"].as_str(), Some("assistant says hi") ); @@ -1769,14 +1839,13 @@ async fn inbound_handoff_request_does_not_block_realtime_event_forwarding() -> R "type": "conversation.input_transcript.delta", "delta": "delegate now" }), + realtime_handoff_requested_event( + "handoff_non_blocking", + "item_non_blocking", + "delegate now", + ), json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_non_blocking", - "item_id": "item_non_blocking", - "input_transcript": "delegate now" - }), - json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AQID", "sample_rate": 24000, "channels": 1 @@ -1900,12 +1969,7 @@ async fn inbound_handoff_request_steers_active_turn() -> Result<()> { "type": "conversation.input_transcript.delta", "delta": "steer via realtime" }), - json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_steer", - "item_id": "item_steer", - "input_transcript": "steer via realtime" - }), + realtime_handoff_requested_event("handoff_steer", "item_steer", "steer via realtime"), ], ]]) .await; @@ -2035,14 +2099,9 @@ async fn inbound_handoff_request_starts_turn_and_does_not_block_realtime_audio() "type": "conversation.input_transcript.delta", "delta": delegated_text }), + realtime_handoff_requested_event("handoff_audio", "item_audio", delegated_text), json!({ - "type": "conversation.handoff.requested", - "handoff_id": "handoff_audio", - "item_id": "item_audio", - "input_transcript": delegated_text - }), - json!({ - "type": "conversation.output_audio.delta", + "type": "response.output_audio.delta", "delta": "AQID", "sample_rate": 24000, "channels": 1 diff --git a/codex-rs/tui/src/chatwidget/realtime.rs b/codex-rs/tui/src/chatwidget/realtime.rs index 4e4f2f0e70..c38615eceb 100644 --- a/codex-rs/tui/src/chatwidget/realtime.rs +++ b/codex-rs/tui/src/chatwidget/realtime.rs @@ -306,7 +306,7 @@ impl ChatWidget { } } - #[cfg(not(target_os = "linux"))] + #[cfg(all(not(target_os = "linux"), feature = "voice-input"))] fn start_realtime_local_audio(&mut self) { if self.realtime_conversation.capture_stop_flag.is_some() { return; @@ -363,6 +363,9 @@ impl ChatWidget { #[cfg(target_os = "linux")] fn start_realtime_local_audio(&mut self) {} + #[cfg(all(not(target_os = "linux"), not(feature = "voice-input")))] + fn start_realtime_local_audio(&mut self) {} + #[cfg(all(not(target_os = "linux"), feature = "voice-input"))] pub(crate) fn restart_realtime_audio_device(&mut self, kind: RealtimeAudioDeviceKind) { if !self.realtime_conversation.is_active() { diff --git a/codex-rs/tui/src/voice.rs b/codex-rs/tui/src/voice.rs index 227d27c88f..7bd566bf60 100644 --- a/codex-rs/tui/src/voice.rs +++ b/codex-rs/tui/src/voice.rs @@ -484,7 +484,7 @@ fn convert_u16_to_i16_and_peak(input: &[u16], out: &mut Vec) -> u16 { pub(crate) struct RealtimeAudioPlayer { _stream: cpal::Stream, - queue: Arc>>, + queue: Arc>, output_sample_rate: u32, output_channels: u16, } @@ -495,8 +495,9 @@ impl RealtimeAudioPlayer { crate::audio_device::select_configured_output_device_and_config(config)?; let output_sample_rate = config.sample_rate().0; let output_channels = config.channels(); - let queue = Arc::new(Mutex::new(VecDeque::new())); - let stream = build_output_stream(&device, &config, Arc::clone(&queue))?; + let prebuffer_samples = output_prebuffer_samples(output_sample_rate, output_channels); + let queue = Arc::new(Mutex::new(OutputAudioQueue::default())); + let stream = build_output_stream(&device, &config, Arc::clone(&queue), prebuffer_samples)?; stream .play() .map_err(|e| format!("failed to start output stream: {e}"))?; @@ -537,13 +538,14 @@ impl RealtimeAudioPlayer { .lock() .map_err(|_| "failed to lock output audio queue".to_string())?; // TODO(aibrahim): Cap or trim this queue if we observe producer bursts outrunning playback. - guard.extend(converted); + guard.samples.extend(converted); Ok(()) } pub(crate) fn clear(&self) { if let Ok(mut guard) = self.queue.lock() { - guard.clear(); + guard.samples.clear(); + guard.primed = false; } } } @@ -551,14 +553,15 @@ impl RealtimeAudioPlayer { fn build_output_stream( device: &cpal::Device, config: &cpal::SupportedStreamConfig, - queue: Arc>>, + queue: Arc>, + prebuffer_samples: usize, ) -> Result { let config_any: cpal::StreamConfig = config.clone().into(); match config.sample_format() { cpal::SampleFormat::F32 => device .build_output_stream( &config_any, - move |output: &mut [f32], _| fill_output_f32(output, &queue), + move |output: &mut [f32], _| fill_output_f32(output, &queue, prebuffer_samples), move |err| error!("audio output error: {err}"), None, ) @@ -566,7 +569,7 @@ fn build_output_stream( cpal::SampleFormat::I16 => device .build_output_stream( &config_any, - move |output: &mut [i16], _| fill_output_i16(output, &queue), + move |output: &mut [i16], _| fill_output_i16(output, &queue, prebuffer_samples), move |err| error!("audio output error: {err}"), None, ) @@ -574,7 +577,7 @@ fn build_output_stream( cpal::SampleFormat::U16 => device .build_output_stream( &config_any, - move |output: &mut [u16], _| fill_output_u16(output, &queue), + move |output: &mut [u16], _| fill_output_u16(output, &queue, prebuffer_samples), move |err| error!("audio output error: {err}"), None, ) @@ -583,20 +586,63 @@ fn build_output_stream( } } -fn fill_output_i16(output: &mut [i16], queue: &Arc>>) { +#[derive(Default)] +struct OutputAudioQueue { + samples: VecDeque, + primed: bool, +} + +fn output_prebuffer_samples(sample_rate: u32, channels: u16) -> usize { + let samples_per_second = (sample_rate as usize).saturating_mul(channels as usize); + ((samples_per_second as u64) * 120 / 1_000) as usize +} + +fn should_output_silence(queue: &mut OutputAudioQueue, min_buffer_samples: usize) -> bool { + if !queue.primed { + if queue.samples.len() < min_buffer_samples { + return true; + } + queue.primed = true; + } + + if queue.samples.is_empty() { + queue.primed = false; + return true; + } + + false +} + +fn fill_output_i16( + output: &mut [i16], + queue: &Arc>, + prebuffer_samples: usize, +) { if let Ok(mut guard) = queue.lock() { + if should_output_silence(&mut guard, prebuffer_samples) { + output.fill(0); + return; + } for sample in output { - *sample = guard.pop_front().unwrap_or(0); + *sample = guard.samples.pop_front().unwrap_or(0); } return; } output.fill(0); } -fn fill_output_f32(output: &mut [f32], queue: &Arc>>) { +fn fill_output_f32( + output: &mut [f32], + queue: &Arc>, + prebuffer_samples: usize, +) { if let Ok(mut guard) = queue.lock() { + if should_output_silence(&mut guard, prebuffer_samples) { + output.fill(0.0); + return; + } for sample in output { - let v = guard.pop_front().unwrap_or(0); + let v = guard.samples.pop_front().unwrap_or(0); *sample = (v as f32) / (i16::MAX as f32); } return; @@ -604,10 +650,18 @@ fn fill_output_f32(output: &mut [f32], queue: &Arc>>) { output.fill(0.0); } -fn fill_output_u16(output: &mut [u16], queue: &Arc>>) { +fn fill_output_u16( + output: &mut [u16], + queue: &Arc>, + prebuffer_samples: usize, +) { if let Ok(mut guard) = queue.lock() { + if should_output_silence(&mut guard, prebuffer_samples) { + output.fill(32768); + return; + } for sample in output { - let v = guard.pop_front().unwrap_or(0); + let v = guard.samples.pop_front().unwrap_or(0); *sample = (v as i32 + 32768).clamp(0, u16::MAX as i32) as u16; } return;