mirror of
https://github.com/openai/codex.git
synced 2026-05-02 18:37:01 +00:00
refactor: split realtime_websocket endpoint into modules
This commit is contained in:
@@ -1,13 +1,22 @@
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItem;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItemContent;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeAudioFrame;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeConnectionState;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeEvent;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionCreateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -17,7 +26,6 @@ use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Error as WsError;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tracing::debug;
|
||||
use tracing::info;
|
||||
use tracing::trace;
|
||||
use tungstenite::protocol::WebSocketConfig;
|
||||
@@ -25,7 +33,6 @@ use url::Url;
|
||||
|
||||
struct WsStream {
|
||||
tx_command: mpsc::Sender<WsCommand>,
|
||||
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
|
||||
pump_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
@@ -40,7 +47,9 @@ enum WsCommand {
|
||||
}
|
||||
|
||||
impl WsStream {
|
||||
fn new(inner: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
|
||||
fn new(
|
||||
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
) -> (Self, mpsc::UnboundedReceiver<Result<Message, WsError>>) {
|
||||
let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
|
||||
let (tx_message, rx_message) = mpsc::unbounded_channel::<Result<Message, WsError>>();
|
||||
|
||||
@@ -102,11 +111,13 @@ impl WsStream {
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
tx_command,
|
||||
(
|
||||
Self {
|
||||
tx_command,
|
||||
pump_task,
|
||||
},
|
||||
rx_message,
|
||||
pump_task,
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
async fn request(
|
||||
@@ -129,10 +140,6 @@ impl WsStream {
|
||||
self.request(|tx_result| WsCommand::Close { tx_result })
|
||||
.await
|
||||
}
|
||||
|
||||
async fn next(&mut self) -> Option<Result<Message, WsError>> {
|
||||
self.rx_message.recv().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WsStream {
|
||||
@@ -141,152 +148,88 @@ impl Drop for WsStream {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RealtimeSessionConfig {
|
||||
pub api_url: String,
|
||||
pub prompt: String,
|
||||
pub session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct RealtimeAudioFrame {
|
||||
pub data: String,
|
||||
pub sample_rate: u32,
|
||||
pub num_channels: u16,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub samples_per_channel: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RealtimeConnectionState {
|
||||
Connected,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum RealtimeEvent {
|
||||
State(RealtimeConnectionState),
|
||||
SessionCreated { session_id: String },
|
||||
SessionUpdated { backend_prompt: Option<String> },
|
||||
AudioOut(RealtimeAudioFrame),
|
||||
ConversationItemAdded(Value),
|
||||
Error(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum RealtimeOutboundMessage {
|
||||
#[serde(rename = "response.input_audio.delta")]
|
||||
InputAudioDelta {
|
||||
delta: String,
|
||||
sample_rate: u32,
|
||||
num_channels: u16,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
samples_per_channel: Option<u32>,
|
||||
},
|
||||
#[serde(rename = "session.create")]
|
||||
SessionCreate {
|
||||
session: SessionCreateSession,
|
||||
},
|
||||
#[serde(rename = "session.update")]
|
||||
SessionUpdate {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session: Option<SessionUpdateSession>,
|
||||
},
|
||||
#[serde(rename = "conversation.item.create")]
|
||||
ConversationItemCreate { item: ConversationItem },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct SessionUpdateSession {
|
||||
backend_prompt: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
conversation_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct SessionCreateSession {
|
||||
backend_prompt: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
conversation_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct ConversationItem {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
role: String,
|
||||
content: Vec<ConversationItemContent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct ConversationItemContent {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum RealtimeInboundMessage {
|
||||
#[serde(rename = "session.created")]
|
||||
SessionCreated {
|
||||
#[serde(default)]
|
||||
session_id: Option<String>,
|
||||
#[serde(default)]
|
||||
session: Option<RealtimeInboundSession>,
|
||||
},
|
||||
#[serde(rename = "session.updated")]
|
||||
SessionUpdated {
|
||||
#[serde(default)]
|
||||
session: Option<RealtimeInboundSession>,
|
||||
},
|
||||
#[serde(rename = "response.output_audio.delta")]
|
||||
OutputAudioDelta {
|
||||
#[serde(default)]
|
||||
delta: Option<String>,
|
||||
#[serde(default)]
|
||||
data: Option<String>,
|
||||
#[serde(default)]
|
||||
sample_rate: Option<u32>,
|
||||
#[serde(default)]
|
||||
num_channels: Option<u16>,
|
||||
#[serde(default)]
|
||||
samples_per_channel: Option<u32>,
|
||||
},
|
||||
#[serde(rename = "conversation.item.added")]
|
||||
ConversationItemAdded {
|
||||
#[serde(default)]
|
||||
item: Option<Value>,
|
||||
},
|
||||
#[serde(rename = "error")]
|
||||
Error {
|
||||
#[serde(default)]
|
||||
error: Option<Value>,
|
||||
#[serde(default)]
|
||||
message: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RealtimeInboundSession {
|
||||
#[serde(default)]
|
||||
id: Option<String>,
|
||||
#[serde(default)]
|
||||
backend_prompt: Option<String>,
|
||||
}
|
||||
|
||||
pub struct RealtimeWebsocketConnection {
|
||||
stream: Arc<Mutex<Option<WsStream>>>,
|
||||
writer: RealtimeWebsocketWriter,
|
||||
events: RealtimeWebsocketEvents,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealtimeWebsocketWriter {
|
||||
stream: Arc<WsStream>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealtimeWebsocketEvents {
|
||||
rx_message: Arc<Mutex<mpsc::UnboundedReceiver<Result<Message, WsError>>>>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl RealtimeWebsocketConnection {
|
||||
fn new(stream: WsStream) -> Self {
|
||||
Self {
|
||||
stream: Arc::new(Mutex::new(Some(stream))),
|
||||
}
|
||||
pub async fn send_audio_frame(&self, frame: RealtimeAudioFrame) -> Result<(), ApiError> {
|
||||
self.writer.send_audio_frame(frame).await
|
||||
}
|
||||
|
||||
pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> {
|
||||
self.writer.send_conversation_item_create(text).await
|
||||
}
|
||||
|
||||
pub async fn send_session_update(
|
||||
&self,
|
||||
backend_prompt: String,
|
||||
conversation_id: Option<String>,
|
||||
) -> Result<(), ApiError> {
|
||||
self.writer
|
||||
.send_session_update(backend_prompt, conversation_id)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_session_create(
|
||||
&self,
|
||||
backend_prompt: String,
|
||||
conversation_id: Option<String>,
|
||||
) -> Result<(), ApiError> {
|
||||
self.writer
|
||||
.send_session_create(backend_prompt, conversation_id)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), ApiError> {
|
||||
self.writer.close().await
|
||||
}
|
||||
|
||||
pub async fn next_event(&self) -> Result<Option<RealtimeEvent>, ApiError> {
|
||||
self.events.next_event().await
|
||||
}
|
||||
|
||||
pub fn writer(&self) -> RealtimeWebsocketWriter {
|
||||
self.writer.clone()
|
||||
}
|
||||
|
||||
pub fn events(&self) -> RealtimeWebsocketEvents {
|
||||
self.events.clone()
|
||||
}
|
||||
|
||||
fn new(
|
||||
stream: WsStream,
|
||||
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
|
||||
) -> Self {
|
||||
let stream = Arc::new(stream);
|
||||
let is_closed = Arc::new(AtomicBool::new(false));
|
||||
Self {
|
||||
writer: RealtimeWebsocketWriter {
|
||||
stream: Arc::clone(&stream),
|
||||
is_closed: Arc::clone(&is_closed),
|
||||
},
|
||||
events: RealtimeWebsocketEvents {
|
||||
rx_message: Arc::new(Mutex::new(rx_message)),
|
||||
is_closed,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RealtimeWebsocketWriter {
|
||||
pub async fn send_audio_frame(&self, frame: RealtimeAudioFrame) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::InputAudioDelta {
|
||||
delta: frame.data,
|
||||
@@ -340,36 +283,56 @@ impl RealtimeWebsocketConnection {
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), ApiError> {
|
||||
let mut guard = self.stream.lock().await;
|
||||
let Some(ws) = guard.as_ref() else {
|
||||
if self.is_closed.swap(true, Ordering::SeqCst) {
|
||||
return Ok(());
|
||||
};
|
||||
if let Err(err) = ws.close().await
|
||||
}
|
||||
if let Err(err) = self.stream.close().await
|
||||
&& !matches!(err, WsError::ConnectionClosed | WsError::AlreadyClosed)
|
||||
{
|
||||
return Err(ApiError::Stream(format!("failed to close websocket: {err}")));
|
||||
return Err(ApiError::Stream(format!(
|
||||
"failed to close websocket: {err}"
|
||||
)));
|
||||
}
|
||||
*guard = None;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_json(&self, message: RealtimeOutboundMessage) -> Result<(), ApiError> {
|
||||
let payload = serde_json::to_string(&message)
|
||||
.map_err(|err| ApiError::Stream(format!("failed to encode realtime request: {err}")))?;
|
||||
trace!("realtime websocket request: {payload}");
|
||||
|
||||
if self.is_closed.load(Ordering::SeqCst) {
|
||||
return Err(ApiError::Stream(
|
||||
"realtime websocket connection is closed".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
self.stream
|
||||
.send(Message::Text(payload.into()))
|
||||
.await
|
||||
.map_err(|err| ApiError::Stream(format!("failed to send realtime request: {err}")))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl RealtimeWebsocketEvents {
|
||||
pub async fn next_event(&self) -> Result<Option<RealtimeEvent>, ApiError> {
|
||||
let mut guard = self.stream.lock().await;
|
||||
let Some(ws) = guard.as_mut() else {
|
||||
if self.is_closed.load(Ordering::SeqCst) {
|
||||
return Ok(Some(RealtimeEvent::State(
|
||||
RealtimeConnectionState::Disconnected,
|
||||
)));
|
||||
};
|
||||
}
|
||||
|
||||
let msg = match ws.next().await {
|
||||
let msg = match self.rx_message.lock().await.recv().await {
|
||||
Some(Ok(msg)) => msg,
|
||||
Some(Err(err)) => {
|
||||
self.is_closed.store(true, Ordering::SeqCst);
|
||||
return Err(ApiError::Stream(format!(
|
||||
"failed to read websocket message: {err}"
|
||||
)));
|
||||
}
|
||||
None => {
|
||||
*guard = None;
|
||||
self.is_closed.store(true, Ordering::SeqCst);
|
||||
return Ok(Some(RealtimeEvent::State(
|
||||
RealtimeConnectionState::Disconnected,
|
||||
)));
|
||||
@@ -379,7 +342,7 @@ impl RealtimeWebsocketConnection {
|
||||
match msg {
|
||||
Message::Text(text) => Ok(parse_realtime_event(&text)),
|
||||
Message::Close(_) => {
|
||||
*guard = None;
|
||||
self.is_closed.store(true, Ordering::SeqCst);
|
||||
Ok(Some(RealtimeEvent::State(
|
||||
RealtimeConnectionState::Disconnected,
|
||||
)))
|
||||
@@ -390,24 +353,6 @@ impl RealtimeWebsocketConnection {
|
||||
Message::Frame(_) | Message::Ping(_) | Message::Pong(_) => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_json(&self, message: RealtimeOutboundMessage) -> Result<(), ApiError> {
|
||||
let payload = serde_json::to_string(&message)
|
||||
.map_err(|err| ApiError::Stream(format!("failed to encode realtime request: {err}")))?;
|
||||
trace!("realtime websocket request: {payload}");
|
||||
|
||||
let guard = self.stream.lock().await;
|
||||
let Some(ws) = guard.as_ref() else {
|
||||
return Err(ApiError::Stream(
|
||||
"realtime websocket connection is closed".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
ws.send(Message::Text(payload.into()))
|
||||
.await
|
||||
.map_err(|err| ApiError::Stream(format!("failed to send realtime request: {err}")))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RealtimeWebsocketClient {
|
||||
@@ -432,20 +377,19 @@ impl RealtimeWebsocketClient {
|
||||
.as_str()
|
||||
.into_client_request()
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?;
|
||||
let headers =
|
||||
merge_request_headers(&self.provider.headers, extra_headers, default_headers);
|
||||
let headers = merge_request_headers(&self.provider.headers, extra_headers, default_headers);
|
||||
request.headers_mut().extend(headers);
|
||||
|
||||
info!("connecting realtime websocket: {ws_url}");
|
||||
let (stream, _) = tokio_tungstenite::connect_async_with_config(
|
||||
request,
|
||||
Some(websocket_config()),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| ApiError::Stream(format!("failed to connect realtime websocket: {err}")))?;
|
||||
let (stream, _) =
|
||||
tokio_tungstenite::connect_async_with_config(request, Some(websocket_config()), false)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
ApiError::Stream(format!("failed to connect realtime websocket: {err}"))
|
||||
})?;
|
||||
|
||||
let connection = RealtimeWebsocketConnection::new(WsStream::new(stream));
|
||||
let (stream, rx_message) = WsStream::new(stream);
|
||||
let connection = RealtimeWebsocketConnection::new(stream, rx_message);
|
||||
connection
|
||||
.send_session_create(config.prompt, config.session_id)
|
||||
.await?;
|
||||
@@ -453,53 +397,6 @@ impl RealtimeWebsocketClient {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
let parsed: RealtimeInboundMessage = match serde_json::from_str(payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => {
|
||||
debug!("failed to parse realtime event: {err}, data: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
match parsed {
|
||||
RealtimeInboundMessage::SessionCreated {
|
||||
session_id,
|
||||
session,
|
||||
} => {
|
||||
let session_id = session.and_then(|s| s.id).or(session_id);
|
||||
session_id.map(|id| RealtimeEvent::SessionCreated { session_id: id })
|
||||
}
|
||||
RealtimeInboundMessage::SessionUpdated { session } => Some(RealtimeEvent::SessionUpdated {
|
||||
backend_prompt: session.and_then(|s| s.backend_prompt),
|
||||
}),
|
||||
RealtimeInboundMessage::OutputAudioDelta {
|
||||
delta,
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel,
|
||||
} => {
|
||||
let data = delta.or(data).unwrap_or_default();
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate: sample_rate.unwrap_or(0),
|
||||
num_channels: num_channels.unwrap_or(1),
|
||||
samples_per_channel,
|
||||
}))
|
||||
}
|
||||
RealtimeInboundMessage::ConversationItemAdded { item } => {
|
||||
item.map(RealtimeEvent::ConversationItemAdded)
|
||||
}
|
||||
RealtimeInboundMessage::Error { error, message } => {
|
||||
let message = message
|
||||
.or_else(|| error.map(|e| e.to_string()))
|
||||
.unwrap_or_else(|| "unknown realtime error".to_string());
|
||||
Some(RealtimeEvent::Error(message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_request_headers(
|
||||
provider_headers: &HeaderMap,
|
||||
extra_headers: HeaderMap,
|
||||
@@ -549,6 +446,7 @@ mod tests {
|
||||
use super::*;
|
||||
use http::HeaderValue;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
@@ -822,4 +720,104 @@ mod tests {
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_does_not_block_while_next_event_waits_for_inbound_data() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("first msg")
|
||||
.expect("first msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let first_json: Value = serde_json::from_str(&first).expect("json");
|
||||
assert_eq!(first_json["type"], "session.create");
|
||||
|
||||
let second = ws
|
||||
.next()
|
||||
.await
|
||||
.expect("second msg")
|
||||
.expect("second msg ok")
|
||||
.into_text()
|
||||
.expect("text");
|
||||
let second_json: Value = serde_json::from_str(&second).expect("json");
|
||||
assert_eq!(second_json["type"], "response.input_audio.delta");
|
||||
|
||||
ws.send(Message::Text(
|
||||
json!({
|
||||
"type": "session.created",
|
||||
"session": {"id": "sess_after_send"}
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await
|
||||
.expect("send session.created");
|
||||
});
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: crate::provider::RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
api_url: format!("ws://{addr}"),
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
let (send_result, next_result) = tokio::join!(
|
||||
async {
|
||||
tokio::time::timeout(
|
||||
Duration::from_millis(200),
|
||||
connection.send_audio_frame(RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
sample_rate: 48000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(960),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
},
|
||||
connection.next_event()
|
||||
);
|
||||
|
||||
send_result
|
||||
.expect("send should not block on next_event")
|
||||
.expect("send audio");
|
||||
let next_event = next_result.expect("next event").expect("event");
|
||||
assert_eq!(
|
||||
next_event,
|
||||
RealtimeEvent::SessionCreated {
|
||||
session_id: "sess_after_send".to_string()
|
||||
}
|
||||
);
|
||||
|
||||
connection.close().await.expect("close");
|
||||
server.await.expect("server task");
|
||||
}
|
||||
}
|
||||
11
codex-rs/codex-api/src/endpoint/realtime_websocket/mod.rs
Normal file
11
codex-rs/codex-api/src/endpoint/realtime_websocket/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
pub mod methods;
|
||||
pub mod protocol;
|
||||
|
||||
pub use methods::RealtimeWebsocketClient;
|
||||
pub use methods::RealtimeWebsocketConnection;
|
||||
pub use methods::RealtimeWebsocketEvents;
|
||||
pub use methods::RealtimeWebsocketWriter;
|
||||
pub use protocol::RealtimeAudioFrame;
|
||||
pub use protocol::RealtimeConnectionState;
|
||||
pub use protocol::RealtimeEvent;
|
||||
pub use protocol::RealtimeSessionConfig;
|
||||
169
codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs
Normal file
169
codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RealtimeSessionConfig {
|
||||
pub api_url: String,
|
||||
pub prompt: String,
|
||||
pub session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct RealtimeAudioFrame {
|
||||
pub data: String,
|
||||
pub sample_rate: u32,
|
||||
pub num_channels: u16,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub samples_per_channel: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RealtimeConnectionState {
|
||||
Connected,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum RealtimeEvent {
|
||||
State(RealtimeConnectionState),
|
||||
SessionCreated { session_id: String },
|
||||
SessionUpdated { backend_prompt: Option<String> },
|
||||
AudioOut(RealtimeAudioFrame),
|
||||
ConversationItemAdded(Value),
|
||||
Error(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub(super) enum RealtimeOutboundMessage {
|
||||
#[serde(rename = "response.input_audio.delta")]
|
||||
InputAudioDelta {
|
||||
delta: String,
|
||||
sample_rate: u32,
|
||||
num_channels: u16,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
samples_per_channel: Option<u32>,
|
||||
},
|
||||
#[serde(rename = "session.create")]
|
||||
SessionCreate { session: SessionCreateSession },
|
||||
#[serde(rename = "session.update")]
|
||||
SessionUpdate {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session: Option<SessionUpdateSession>,
|
||||
},
|
||||
#[serde(rename = "conversation.item.create")]
|
||||
ConversationItemCreate { item: ConversationItem },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionUpdateSession {
|
||||
pub(super) backend_prompt: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) conversation_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionCreateSession {
|
||||
pub(super) backend_prompt: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) conversation_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct ConversationItem {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) role: String,
|
||||
pub(super) content: Vec<ConversationItemContent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct ConversationItemContent {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum RealtimeInboundMessage {
|
||||
#[serde(rename = "session.created")]
|
||||
SessionCreated {
|
||||
session_id: Option<String>,
|
||||
session: Option<RealtimeInboundSession>,
|
||||
},
|
||||
#[serde(rename = "session.updated")]
|
||||
SessionUpdated {
|
||||
session: Option<RealtimeInboundSession>,
|
||||
},
|
||||
#[serde(rename = "response.output_audio.delta")]
|
||||
OutputAudioDelta {
|
||||
delta: Option<String>,
|
||||
data: Option<String>,
|
||||
sample_rate: Option<u32>,
|
||||
num_channels: Option<u16>,
|
||||
samples_per_channel: Option<u32>,
|
||||
},
|
||||
#[serde(rename = "conversation.item.added")]
|
||||
ConversationItemAdded { item: Option<Value> },
|
||||
#[serde(rename = "error")]
|
||||
Error {
|
||||
error: Option<Value>,
|
||||
message: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RealtimeInboundSession {
|
||||
id: Option<String>,
|
||||
backend_prompt: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
let parsed: RealtimeInboundMessage = match serde_json::from_str(payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => {
|
||||
debug!("failed to parse realtime event: {err}, data: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
match parsed {
|
||||
RealtimeInboundMessage::SessionCreated {
|
||||
session_id,
|
||||
session,
|
||||
} => {
|
||||
let session_id = session.and_then(|s| s.id).or(session_id);
|
||||
session_id.map(|id| RealtimeEvent::SessionCreated { session_id: id })
|
||||
}
|
||||
RealtimeInboundMessage::SessionUpdated { session } => Some(RealtimeEvent::SessionUpdated {
|
||||
backend_prompt: session.and_then(|s| s.backend_prompt),
|
||||
}),
|
||||
RealtimeInboundMessage::OutputAudioDelta {
|
||||
delta,
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel,
|
||||
} => {
|
||||
let data = delta.or(data)?;
|
||||
let sample_rate = sample_rate?;
|
||||
let num_channels = num_channels?;
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel,
|
||||
}))
|
||||
}
|
||||
RealtimeInboundMessage::ConversationItemAdded { item } => {
|
||||
item.map(RealtimeEvent::ConversationItemAdded)
|
||||
}
|
||||
RealtimeInboundMessage::Error { error, message } => {
|
||||
let message = message.or_else(|| error.map(|e| e.to_string()))?;
|
||||
Some(RealtimeEvent::Error(message))
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user