mirror of
https://github.com/openai/codex.git
synced 2026-05-21 19:45:26 +00:00
Port realtime websocket handoff flow
This commit is contained in:
@@ -10,6 +10,12 @@ use crate::endpoint::realtime_websocket::protocol::SessionAudio;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioInput;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutputFormat;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionTool;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionToolParameters;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionToolProperties;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionToolProperty;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionTurnDetection;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
|
||||
use crate::error::ApiError;
|
||||
@@ -19,6 +25,7 @@ use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
@@ -228,6 +235,20 @@ impl RealtimeWebsocketConnection {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_function_call_output(
|
||||
&self,
|
||||
call_id: String,
|
||||
output_text: String,
|
||||
) -> Result<(), ApiError> {
|
||||
self.writer
|
||||
.send_function_call_output(call_id, output_text)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Delegates to the writer's `send_response_create` and returns its result.
pub async fn send_response_create(&self) -> Result<(), ApiError> {
    let writer = &self.writer;
    writer.send_response_create().await
}
|
||||
|
||||
/// Delegates to the writer's `close` and returns its result.
pub async fn close(&self) -> Result<(), ApiError> {
    let writer = &self.writer;
    writer.close().await
}
|
||||
@@ -272,11 +293,10 @@ impl RealtimeWebsocketWriter {
|
||||
|
||||
pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItem {
|
||||
kind: "message".to_string(),
|
||||
item: ConversationItem::Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ConversationItemContent {
|
||||
kind: "text".to_string(),
|
||||
kind: "input_text".to_string(),
|
||||
text,
|
||||
}],
|
||||
},
|
||||
@@ -286,32 +306,85 @@ impl RealtimeWebsocketWriter {
|
||||
|
||||
pub async fn send_conversation_handoff_append(
|
||||
&self,
|
||||
handoff_id: String,
|
||||
_handoff_id: String,
|
||||
output_text: String,
|
||||
) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::ConversationHandoffAppend {
|
||||
handoff_id,
|
||||
output_text,
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItem::Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ConversationItemContent {
|
||||
kind: "output_text".to_string(),
|
||||
text: output_text,
|
||||
}],
|
||||
},
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_function_call_output(
|
||||
&self,
|
||||
call_id: String,
|
||||
output_text: String,
|
||||
) -> Result<(), ApiError> {
|
||||
let output = json!({
|
||||
"content": output_text,
|
||||
})
|
||||
.to_string();
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItem::FunctionCallOutput { call_id, output },
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_response_create(&self) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::ResponseCreate)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_session_update(&self, instructions: String) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::SessionUpdate {
|
||||
session: SessionUpdateSession {
|
||||
kind: "quicksilver".to_string(),
|
||||
kind: "realtime".to_string(),
|
||||
instructions,
|
||||
output_modalities: vec!["audio".to_string()],
|
||||
audio: SessionAudio {
|
||||
input: SessionAudioInput {
|
||||
format: SessionAudioFormat {
|
||||
kind: "audio/pcm".to_string(),
|
||||
rate: 24_000,
|
||||
},
|
||||
turn_detection: SessionTurnDetection {
|
||||
kind: "server_vad".to_string(),
|
||||
interrupt_response: true,
|
||||
create_response: true,
|
||||
},
|
||||
},
|
||||
output: SessionAudioOutput {
|
||||
voice: "fathom".to_string(),
|
||||
format: SessionAudioOutputFormat {
|
||||
kind: "audio/pcm".to_string(),
|
||||
rate: 24_000,
|
||||
},
|
||||
voice: "marin".to_string(),
|
||||
},
|
||||
},
|
||||
tools: vec![SessionTool {
|
||||
kind: "function".to_string(),
|
||||
name: "codex".to_string(),
|
||||
description:
|
||||
"Delegate a request to Codex and return the final result to the user."
|
||||
.to_string(),
|
||||
parameters: SessionToolParameters {
|
||||
kind: "object".to_string(),
|
||||
properties: SessionToolProperties {
|
||||
prompt: SessionToolProperty {
|
||||
kind: "string".to_string(),
|
||||
description: "The user request to delegate to Codex.".to_string(),
|
||||
},
|
||||
},
|
||||
required: vec!["prompt".to_string()],
|
||||
},
|
||||
}],
|
||||
tool_choice: "auto".to_string(),
|
||||
},
|
||||
})
|
||||
.await
|
||||
@@ -558,15 +631,16 @@ fn websocket_url_from_api_url(
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let has_additional_query_params = query_params
|
||||
.is_some_and(|params| params.keys().any(|key| key != "model" || model.is_none()));
|
||||
if model.is_some() || has_additional_query_params {
|
||||
let mut query = url.query_pairs_mut();
|
||||
query.append_pair("intent", "quicksilver");
|
||||
if let Some(model) = model {
|
||||
query.append_pair("model", model);
|
||||
}
|
||||
if let Some(query_params) = query_params {
|
||||
for (key, value) in query_params {
|
||||
if key == "intent" || (key == "model" && model.is_some()) {
|
||||
if key == "model" && model.is_some() {
|
||||
continue;
|
||||
}
|
||||
query.append_pair(key, value);
|
||||
@@ -639,7 +713,7 @@ mod tests {
|
||||
#[test]
|
||||
fn parse_audio_delta_event() {
|
||||
let payload = json!({
|
||||
"type": "conversation.output_audio.delta",
|
||||
"type": "response.output_audio.delta",
|
||||
"delta": "AAA=",
|
||||
"sample_rate": 48000,
|
||||
"channels": 1,
|
||||
@@ -657,6 +731,24 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn parse_audio_delta_event_defaults_audio_shape() {
    // The payload deliberately omits `sample_rate` and `channels`.
    let payload = json!({
        "type": "response.output_audio.delta",
        "delta": "AAA="
    })
    .to_string();

    let expected = RealtimeEvent::AudioOut(RealtimeAudioFrame {
        data: "AAA=".to_string(),
        sample_rate: 24_000,
        num_channels: 1,
        samples_per_channel: None,
    });

    assert_eq!(parse_realtime_event(payload.as_str()), Some(expected));
}
|
||||
|
||||
#[test]
|
||||
fn parse_conversation_item_added_event() {
|
||||
let payload = json!({
|
||||
@@ -690,10 +782,18 @@ mod tests {
|
||||
#[test]
|
||||
fn parse_handoff_requested_event() {
|
||||
let payload = json!({
|
||||
"type": "conversation.handoff.requested",
|
||||
"handoff_id": "handoff_123",
|
||||
"item_id": "item_123",
|
||||
"input_transcript": "delegate this"
|
||||
"type": "response.done",
|
||||
"response": {
|
||||
"output": [
|
||||
{
|
||||
"id": "item_123",
|
||||
"type": "function_call",
|
||||
"name": "codex",
|
||||
"call_id": "handoff_123",
|
||||
"arguments": "{\"prompt\":\"delegate this\"}"
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
@@ -708,6 +808,24 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn parse_unknown_event_as_conversation_item_added() {
    // An event type with no dedicated handling should come back wrapped,
    // unchanged, in `ConversationItemAdded`.
    let event = json!({
        "type": "response.output_text.delta",
        "delta": "hello",
        "response_id": "resp_1"
    });
    let payload = event.to_string();

    assert_eq!(
        parse_realtime_event(payload.as_str()),
        Some(RealtimeEvent::ConversationItemAdded(event))
    );
}
|
||||
|
||||
#[test]
|
||||
fn parse_input_transcript_delta_event() {
|
||||
let payload = json!({
|
||||
@@ -781,10 +899,7 @@ mod tests {
|
||||
fn websocket_url_from_http_base_defaults_to_ws_path() {
|
||||
let url =
|
||||
websocket_url_from_api_url("http://127.0.0.1:8011", None, None).expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"ws://127.0.0.1:8011/v1/realtime?intent=quicksilver"
|
||||
);
|
||||
assert_eq!(url.as_str(), "ws://127.0.0.1:8011/v1/realtime");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -794,7 +909,7 @@ mod tests {
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/v1/realtime?intent=quicksilver&model=realtime-test-model"
|
||||
"wss://example.com/v1/realtime?model=realtime-test-model"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -804,7 +919,7 @@ mod tests {
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://api.openai.com/v1/realtime?intent=quicksilver&model=snapshot"
|
||||
"wss://api.openai.com/v1/realtime?model=snapshot"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -815,7 +930,7 @@ mod tests {
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/openai/v1/realtime?intent=quicksilver&model=snapshot"
|
||||
"wss://example.com/openai/v1/realtime?model=snapshot"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -823,16 +938,13 @@ mod tests {
|
||||
fn websocket_url_preserves_existing_realtime_path_and_extra_query_params() {
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://example.com/v1/realtime?foo=bar",
|
||||
Some(&HashMap::from([
|
||||
("trace".to_string(), "1".to_string()),
|
||||
("intent".to_string(), "ignored".to_string()),
|
||||
])),
|
||||
Some(&HashMap::from([("trace".to_string(), "1".to_string())])),
|
||||
Some("snapshot"),
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/v1/realtime?foo=bar&intent=quicksilver&model=snapshot&trace=1"
|
||||
"wss://example.com/v1/realtime?foo=bar&model=snapshot&trace=1"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -856,12 +968,16 @@ mod tests {
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["type"],
|
||||
Value::String("quicksilver".to_string())
|
||||
Value::String("realtime".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["instructions"],
|
||||
Value::String("backend prompt".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["output_modalities"][0],
|
||||
Value::String("audio".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["input"]["format"]["type"],
|
||||
Value::String("audio/pcm".to_string())
|
||||
@@ -870,9 +986,45 @@ mod tests {
|
||||
first_json["session"]["audio"]["input"]["format"]["rate"],
|
||||
Value::from(24_000)
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["input"]["turn_detection"]["type"],
|
||||
Value::String("server_vad".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["input"]["turn_detection"]["interrupt_response"],
|
||||
Value::Bool(true)
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["input"]["turn_detection"]["create_response"],
|
||||
Value::Bool(true)
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["output"]["format"]["type"],
|
||||
Value::String("audio/pcm".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["output"]["format"]["rate"],
|
||||
Value::from(24_000)
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["output"]["voice"],
|
||||
Value::String("fathom".to_string())
|
||||
Value::String("marin".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tool_choice"],
|
||||
Value::String("auto".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["type"],
|
||||
Value::String("function".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["name"],
|
||||
Value::String("codex".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["parameters"]["required"][0],
|
||||
Value::String("prompt".to_string())
|
||||
);
|
||||
|
||||
ws.send(Message::Text(
|
||||
@@ -915,13 +1067,43 @@ mod tests {
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let fourth_json: Value = serde_json::from_str(&fourth).expect("json");
|
||||
assert_eq!(fourth_json["type"], "conversation.handoff.append");
|
||||
assert_eq!(fourth_json["handoff_id"], "handoff_1");
|
||||
assert_eq!(fourth_json["output_text"], "hello from codex");
|
||||
assert_eq!(fourth_json["type"], "conversation.item.create");
|
||||
assert_eq!(fourth_json["item"]["type"], "message");
|
||||
assert_eq!(fourth_json["item"]["role"], "assistant");
|
||||
assert_eq!(
|
||||
fourth_json["item"]["content"][0]["type"],
|
||||
Value::String("output_text".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
fourth_json["item"]["content"][0]["text"],
|
||||
Value::String("hello from codex".to_string())
|
||||
);
|
||||
|
||||
let fifth = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("fifth msg")
|
||||
.expect("fifth msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let fifth_json: Value = serde_json::from_str(&fifth).expect("json");
|
||||
assert_eq!(fifth_json["type"], "conversation.item.create");
|
||||
assert_eq!(fifth_json["item"]["type"], "function_call_output");
|
||||
assert_eq!(fifth_json["item"]["call_id"], "handoff_1");
|
||||
|
||||
let sixth = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("sixth msg")
|
||||
.expect("sixth msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let sixth_json: Value = serde_json::from_str(&sixth).expect("json");
|
||||
assert_eq!(sixth_json["type"], "response.create");
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "conversation.output_audio.delta",
|
||||
"type": "response.output_audio.delta",
|
||||
"delta": "AQID",
|
||||
"sample_rate": 48000,
|
||||
"channels": 1
|
||||
@@ -967,10 +1149,18 @@ mod tests {
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "conversation.handoff.requested",
|
||||
"handoff_id": "handoff_1",
|
||||
"item_id": "item_2",
|
||||
"input_transcript": "delegate now"
|
||||
"type": "response.done",
|
||||
"response": {
|
||||
"output": [
|
||||
{
|
||||
"id": "item_2",
|
||||
"type": "function_call",
|
||||
"name": "codex",
|
||||
"call_id": "handoff_1",
|
||||
"arguments": "{\"prompt\":\"delegate now\"}"
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
@@ -1040,6 +1230,14 @@ mod tests {
|
||||
)
|
||||
.await
|
||||
.expect("send handoff");
|
||||
connection
|
||||
.send_function_call_output("handoff_1".to_string(), "final from codex".to_string())
|
||||
.await
|
||||
.expect("send function output");
|
||||
connection
|
||||
.send_response_create()
|
||||
.await
|
||||
.expect("send response.create");
|
||||
|
||||
let audio_event = connection
|
||||
.next_event()
|
||||
|
||||
@@ -3,8 +3,10 @@ pub use codex_protocol::protocol::RealtimeEvent;
|
||||
pub use codex_protocol::protocol::RealtimeHandoffRequested;
|
||||
pub use codex_protocol::protocol::RealtimeTranscriptDelta;
|
||||
pub use codex_protocol::protocol::RealtimeTranscriptEntry;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use std::string::ToString;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -19,11 +21,8 @@ pub struct RealtimeSessionConfig {
|
||||
pub(super) enum RealtimeOutboundMessage {
|
||||
#[serde(rename = "input_audio_buffer.append")]
|
||||
InputAudioBufferAppend { audio: String },
|
||||
#[serde(rename = "conversation.handoff.append")]
|
||||
ConversationHandoffAppend {
|
||||
handoff_id: String,
|
||||
output_text: String,
|
||||
},
|
||||
#[serde(rename = "response.create")]
|
||||
ResponseCreate,
|
||||
#[serde(rename = "session.update")]
|
||||
SessionUpdate { session: SessionUpdateSession },
|
||||
#[serde(rename = "conversation.item.create")]
|
||||
@@ -35,7 +34,10 @@ pub(super) struct SessionUpdateSession {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) instructions: String,
|
||||
pub(super) output_modalities: Vec<String>,
|
||||
pub(super) audio: SessionAudio,
|
||||
pub(super) tools: Vec<SessionTool>,
|
||||
pub(super) tool_choice: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -47,6 +49,7 @@ pub(super) struct SessionAudio {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudioInput {
|
||||
pub(super) format: SessionAudioFormat,
|
||||
pub(super) turn_detection: SessionTurnDetection,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -56,17 +59,66 @@ pub(super) struct SessionAudioFormat {
|
||||
pub(super) rate: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionTurnDetection {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) interrupt_response: bool,
|
||||
pub(super) create_response: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudioOutput {
|
||||
pub(super) format: SessionAudioOutputFormat,
|
||||
pub(super) voice: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct ConversationItem {
|
||||
pub(super) struct SessionAudioOutputFormat {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) role: String,
|
||||
pub(super) content: Vec<ConversationItemContent>,
|
||||
pub(super) rate: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionTool {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) name: String,
|
||||
pub(super) description: String,
|
||||
pub(super) parameters: SessionToolParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionToolParameters {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) properties: SessionToolProperties,
|
||||
pub(super) required: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionToolProperties {
|
||||
pub(super) prompt: SessionToolProperty,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionToolProperty {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) description: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub(super) enum ConversationItem {
|
||||
#[serde(rename = "message")]
|
||||
Message {
|
||||
role: String,
|
||||
content: Vec<ConversationItemContent>,
|
||||
},
|
||||
#[serde(rename = "function_call_output")]
|
||||
FunctionCallOutput { call_id: String, output: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -93,7 +145,7 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
}
|
||||
};
|
||||
match message_type {
|
||||
"session.updated" => {
|
||||
"session.created" | "session.updated" => {
|
||||
let session_id = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
@@ -111,7 +163,7 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
instructions,
|
||||
})
|
||||
}
|
||||
"conversation.output_audio.delta" => {
|
||||
"conversation.output_audio.delta" | "response.output_audio.delta" => {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
@@ -120,12 +172,14 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok())?;
|
||||
.and_then(|v| u32::try_from(v).ok())
|
||||
.unwrap_or(24_000);
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u16::try_from(v).ok())?;
|
||||
.and_then(|v| u16::try_from(v).ok())
|
||||
.unwrap_or(1);
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
@@ -177,6 +231,12 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
active_transcript: Vec::new(),
|
||||
}))
|
||||
}
|
||||
"response.done" => {
|
||||
if let Some(handoff) = parse_handoff_requested(&parsed) {
|
||||
return Some(RealtimeEvent::HandoffRequested(handoff));
|
||||
}
|
||||
Some(RealtimeEvent::ConversationItemAdded(parsed))
|
||||
}
|
||||
"error" => parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
@@ -189,11 +249,86 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.or_else(|| parsed.get("error").map(std::string::ToString::to_string))
|
||||
.or_else(|| parsed.get("error").map(ToString::to_string))
|
||||
.map(RealtimeEvent::Error),
|
||||
_ => {
|
||||
debug!("received unsupported realtime event type: {message_type}, data: {payload}");
|
||||
None
|
||||
}
|
||||
_ => Some(RealtimeEvent::ConversationItemAdded(parsed)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_handoff_requested(parsed: &Value) -> Option<RealtimeHandoffRequested> {
|
||||
let outputs = parsed
|
||||
.get("response")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|response| response.get("output"))
|
||||
.and_then(Value::as_array)?;
|
||||
let function_call = outputs.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call")
|
||||
&& item.get("name").and_then(Value::as_str) == Some("codex")
|
||||
})?;
|
||||
let handoff_id = function_call
|
||||
.get("call_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let item_id = function_call
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| handoff_id.clone());
|
||||
let arguments = function_call
|
||||
.get("arguments")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
Some(RealtimeHandoffRequested {
|
||||
handoff_id,
|
||||
item_id,
|
||||
input_transcript: parse_handoff_arguments(arguments),
|
||||
active_transcript: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_handoff_arguments(arguments: &str) -> String {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct HandoffArguments {
|
||||
#[serde(default)]
|
||||
prompt: Option<String>,
|
||||
#[serde(default)]
|
||||
text: Option<String>,
|
||||
#[serde(default)]
|
||||
input: Option<String>,
|
||||
#[serde(default)]
|
||||
message: Option<String>,
|
||||
#[serde(default)]
|
||||
input_transcript: Option<String>,
|
||||
#[serde(default)]
|
||||
messages: Vec<RealtimeTranscriptEntry>,
|
||||
}
|
||||
|
||||
let Some(parsed) = serde_json::from_str::<HandoffArguments>(arguments).ok() else {
|
||||
return arguments.to_string();
|
||||
};
|
||||
|
||||
for value in [
|
||||
parsed.prompt,
|
||||
parsed.text,
|
||||
parsed.input,
|
||||
parsed.message,
|
||||
parsed.input_transcript,
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
{
|
||||
if !value.is_empty() {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(message) = parsed
|
||||
.messages
|
||||
.into_iter()
|
||||
.find(|message| message.role == "user" && !message.text.is_empty())
|
||||
{
|
||||
return message.text;
|
||||
}
|
||||
|
||||
String::new()
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
assert_eq!(first_json["type"], "session.update");
|
||||
assert_eq!(
|
||||
first_json["session"]["type"],
|
||||
Value::String("quicksilver".to_string())
|
||||
Value::String("realtime".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["instructions"],
|
||||
@@ -95,6 +95,42 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
first_json["session"]["audio"]["input"]["format"]["rate"],
|
||||
Value::from(24_000)
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["input"]["turn_detection"]["type"],
|
||||
Value::String("server_vad".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["input"]["turn_detection"]["interrupt_response"],
|
||||
Value::Bool(true)
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["input"]["turn_detection"]["create_response"],
|
||||
Value::Bool(true)
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["output"]["format"]["type"],
|
||||
Value::String("audio/pcm".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["output"]["format"]["rate"],
|
||||
Value::from(24_000)
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["audio"]["output"]["voice"],
|
||||
Value::String("marin".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tool_choice"],
|
||||
Value::String("auto".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["type"],
|
||||
Value::String("function".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
first_json["session"]["tools"][0]["name"],
|
||||
Value::String("codex".to_string())
|
||||
);
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
@@ -119,7 +155,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "conversation.output_audio.delta",
|
||||
"type": "response.output_audio.delta",
|
||||
"delta": "AQID",
|
||||
"sample_rate": 48000,
|
||||
"channels": 1
|
||||
@@ -311,7 +347,7 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_ignores_unknown_text_events() {
|
||||
async fn realtime_ws_e2e_forwards_unknown_text_events() {
|
||||
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
@@ -361,13 +397,26 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let event = connection
|
||||
let first_event = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
event,
|
||||
first_event,
|
||||
RealtimeEvent::ConversationItemAdded(json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp_unknown"}
|
||||
}))
|
||||
);
|
||||
|
||||
let second_event = connection
|
||||
.next_event()
|
||||
.await
|
||||
.expect("next event")
|
||||
.expect("event");
|
||||
assert_eq!(
|
||||
second_event,
|
||||
RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_after_unknown".to_string(),
|
||||
instructions: Some("backend prompt".to_string()),
|
||||
|
||||
Reference in New Issue
Block a user