mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
codex-api: share websocket pump and simplify realtime events
This commit is contained in:
@@ -6,3 +6,4 @@ pub mod realtime_websocket;
|
||||
pub mod responses;
|
||||
pub mod responses_websocket;
|
||||
mod session;
|
||||
mod websocket_pump;
|
||||
|
||||
@@ -7,21 +7,17 @@ use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionCreateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
|
||||
use crate::endpoint::websocket_pump::WebsocketMessage;
|
||||
use crate::endpoint::websocket_pump::WebsocketPump;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Error as WsError;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
@@ -30,123 +26,6 @@ use tracing::trace;
|
||||
use tungstenite::protocol::WebSocketConfig;
|
||||
use url::Url;
|
||||
|
||||
struct WsStream {
|
||||
tx_command: mpsc::Sender<WsCommand>,
|
||||
pump_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
enum WsCommand {
|
||||
Send {
|
||||
message: Message,
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
Close {
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl WsStream {
|
||||
fn new(
|
||||
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
) -> (Self, mpsc::UnboundedReceiver<Result<Message, WsError>>) {
|
||||
let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
|
||||
let (tx_message, rx_message) = mpsc::unbounded_channel::<Result<Message, WsError>>();
|
||||
|
||||
let pump_task = tokio::spawn(async move {
|
||||
let mut inner = inner;
|
||||
loop {
|
||||
tokio::select! {
|
||||
command = rx_command.recv() => {
|
||||
let Some(command) = command else {
|
||||
break;
|
||||
};
|
||||
match command {
|
||||
WsCommand::Send { message, tx_result } => {
|
||||
let result = inner.send(message).await;
|
||||
let should_break = result.is_err();
|
||||
let _ = tx_result.send(result);
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
}
|
||||
WsCommand::Close { tx_result } => {
|
||||
let result = inner.close(None).await;
|
||||
let _ = tx_result.send(result);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
message = inner.next() => {
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
match message {
|
||||
Ok(Message::Ping(payload)) => {
|
||||
if let Err(err) = inner.send(Message::Pong(payload)).await {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Message::Pong(_)) => {}
|
||||
Ok(message @ (Message::Text(_)
|
||||
| Message::Binary(_)
|
||||
| Message::Close(_)
|
||||
| Message::Frame(_))) => {
|
||||
let is_close = matches!(message, Message::Close(_));
|
||||
if tx_message.send(Ok(message)).is_err() {
|
||||
break;
|
||||
}
|
||||
if is_close {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(
|
||||
Self {
|
||||
tx_command,
|
||||
pump_task,
|
||||
},
|
||||
rx_message,
|
||||
)
|
||||
}
|
||||
|
||||
async fn request(
|
||||
&self,
|
||||
make_command: impl FnOnce(oneshot::Sender<Result<(), WsError>>) -> WsCommand,
|
||||
) -> Result<(), WsError> {
|
||||
let (tx_result, rx_result) = oneshot::channel();
|
||||
if self.tx_command.send(make_command(tx_result)).await.is_err() {
|
||||
return Err(WsError::ConnectionClosed);
|
||||
}
|
||||
rx_result.await.unwrap_or(Err(WsError::ConnectionClosed))
|
||||
}
|
||||
|
||||
async fn send(&self, message: Message) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Send { message, tx_result })
|
||||
.await
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Close { tx_result })
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WsStream {
|
||||
fn drop(&mut self) {
|
||||
self.pump_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RealtimeWebsocketConnection {
|
||||
writer: RealtimeWebsocketWriter,
|
||||
events: RealtimeWebsocketEvents,
|
||||
@@ -154,13 +33,13 @@ pub struct RealtimeWebsocketConnection {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealtimeWebsocketWriter {
|
||||
stream: Arc<WsStream>,
|
||||
stream: Arc<WebsocketPump>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealtimeWebsocketEvents {
|
||||
rx_message: Arc<Mutex<mpsc::UnboundedReceiver<Result<Message, WsError>>>>,
|
||||
rx_message: Arc<Mutex<mpsc::UnboundedReceiver<WebsocketMessage>>>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
@@ -209,10 +88,7 @@ impl RealtimeWebsocketConnection {
|
||||
self.events.clone()
|
||||
}
|
||||
|
||||
fn new(
|
||||
stream: WsStream,
|
||||
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
|
||||
) -> Self {
|
||||
fn new(stream: WebsocketPump, rx_message: mpsc::UnboundedReceiver<WebsocketMessage>) -> Self {
|
||||
let stream = Arc::new(stream);
|
||||
let is_closed = Arc::new(AtomicBool::new(false));
|
||||
Self {
|
||||
@@ -389,7 +265,7 @@ impl RealtimeWebsocketClient {
|
||||
ApiError::Stream(format!("failed to connect realtime websocket: {err}"))
|
||||
})?;
|
||||
|
||||
let (stream, rx_message) = WsStream::new(stream);
|
||||
let (stream, rx_message) = WebsocketPump::new(stream);
|
||||
let connection = RealtimeWebsocketConnection::new(stream, rx_message);
|
||||
connection
|
||||
.send_session_create(config.prompt, config.session_id)
|
||||
@@ -445,6 +321,8 @@ fn websocket_url_from_api_url(api_url: &str) -> Result<Url, ApiError> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderValue;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
|
||||
@@ -6,6 +6,5 @@ pub use methods::RealtimeWebsocketConnection;
|
||||
pub use methods::RealtimeWebsocketEvents;
|
||||
pub use methods::RealtimeWebsocketWriter;
|
||||
pub use protocol::RealtimeAudioFrame;
|
||||
pub use protocol::RealtimeConnectionState;
|
||||
pub use protocol::RealtimeEvent;
|
||||
pub use protocol::RealtimeSessionConfig;
|
||||
|
||||
@@ -19,15 +19,8 @@ pub struct RealtimeAudioFrame {
|
||||
pub samples_per_channel: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RealtimeConnectionState {
|
||||
Connected,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum RealtimeEvent {
|
||||
State(RealtimeConnectionState),
|
||||
SessionCreated { session_id: String },
|
||||
SessionUpdated { backend_prompt: Option<String> },
|
||||
AudioOut(RealtimeAudioFrame),
|
||||
@@ -86,84 +79,68 @@ pub(super) struct ConversationItemContent {
|
||||
pub(super) text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum RealtimeInboundMessage {
|
||||
#[serde(rename = "session.created")]
|
||||
SessionCreated {
|
||||
session_id: Option<String>,
|
||||
session: Option<RealtimeInboundSession>,
|
||||
},
|
||||
#[serde(rename = "session.updated")]
|
||||
SessionUpdated {
|
||||
session: Option<RealtimeInboundSession>,
|
||||
},
|
||||
#[serde(rename = "response.output_audio.delta")]
|
||||
OutputAudioDelta {
|
||||
delta: Option<String>,
|
||||
data: Option<String>,
|
||||
sample_rate: Option<u32>,
|
||||
num_channels: Option<u16>,
|
||||
samples_per_channel: Option<u32>,
|
||||
},
|
||||
#[serde(rename = "conversation.item.added")]
|
||||
ConversationItemAdded { item: Option<Value> },
|
||||
#[serde(rename = "error")]
|
||||
Error {
|
||||
error: Option<Value>,
|
||||
message: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RealtimeInboundSession {
|
||||
id: Option<String>,
|
||||
backend_prompt: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
let parsed: RealtimeInboundMessage = match serde_json::from_str(payload) {
|
||||
Ok(msg) => msg,
|
||||
let parsed: Value = match serde_json::from_str(payload) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
debug!("failed to parse realtime event: {err}, data: {payload}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
match parsed {
|
||||
RealtimeInboundMessage::SessionCreated {
|
||||
session_id,
|
||||
session,
|
||||
} => {
|
||||
let session_id = session.and_then(|s| s.id).or(session_id);
|
||||
session_id.map(|id| RealtimeEvent::SessionCreated { session_id: id })
|
||||
let event_type = parsed.get("type")?.as_str()?;
|
||||
match event_type {
|
||||
"session.created" => {
|
||||
let session_id = parsed
|
||||
.pointer("/session/id")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| parsed.get("session_id").and_then(Value::as_str))?;
|
||||
Some(RealtimeEvent::SessionCreated {
|
||||
session_id: session_id.to_string(),
|
||||
})
|
||||
}
|
||||
RealtimeInboundMessage::SessionUpdated { session } => Some(RealtimeEvent::SessionUpdated {
|
||||
backend_prompt: session.and_then(|s| s.backend_prompt),
|
||||
"session.updated" => Some(RealtimeEvent::SessionUpdated {
|
||||
backend_prompt: parsed
|
||||
.pointer("/session/backend_prompt")
|
||||
.and_then(Value::as_str)
|
||||
.map(ToString::to_string),
|
||||
}),
|
||||
RealtimeInboundMessage::OutputAudioDelta {
|
||||
delta,
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel,
|
||||
} => {
|
||||
let data = delta.or(data)?;
|
||||
let sample_rate = sample_rate?;
|
||||
let num_channels = num_channels?;
|
||||
"response.output_audio.delta" => {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| parsed.get("data").and_then(Value::as_str))?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok())?;
|
||||
let num_channels = parsed
|
||||
.get("num_channels")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u16::try_from(value).ok())?;
|
||||
let samples_per_channel = parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|value| u32::try_from(value).ok());
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
data: data.to_string(),
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel,
|
||||
}))
|
||||
}
|
||||
RealtimeInboundMessage::ConversationItemAdded { item } => {
|
||||
item.map(RealtimeEvent::ConversationItemAdded)
|
||||
}
|
||||
RealtimeInboundMessage::Error { error, message } => {
|
||||
let message = message.or_else(|| error.map(|e| e.to_string()))?;
|
||||
"conversation.item.added" => parsed
|
||||
.get("item")
|
||||
.cloned()
|
||||
.map(RealtimeEvent::ConversationItemAdded),
|
||||
"error" => {
|
||||
let message = parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.map(ToString::to_string)
|
||||
.or_else(|| parsed.get("error").map(ToString::to_string))?;
|
||||
Some(RealtimeEvent::Error(message))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ use crate::auth::add_auth_headers_to_header_map;
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::common::ResponsesWsRequest;
|
||||
use crate::endpoint::websocket_pump::WebsocketMessage;
|
||||
use crate::endpoint::websocket_pump::WebsocketPump;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::rate_limits::parse_rate_limit_event;
|
||||
@@ -11,8 +13,6 @@ use crate::sse::responses::process_responses_event;
|
||||
use crate::telemetry::WebsocketTelemetry;
|
||||
use codex_client::TransportError;
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderName;
|
||||
use http::HeaderValue;
|
||||
@@ -23,10 +23,8 @@ use serde_json::map::Map as JsonMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Instant;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
@@ -43,110 +41,22 @@ use tungstenite::protocol::WebSocketConfig;
|
||||
use url::Url;
|
||||
|
||||
struct WsStream {
|
||||
tx_command: mpsc::Sender<WsCommand>,
|
||||
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
|
||||
pump_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
enum WsCommand {
|
||||
Send {
|
||||
message: Message,
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
Close {
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
pump: WebsocketPump,
|
||||
rx_message: mpsc::UnboundedReceiver<WebsocketMessage>,
|
||||
}
|
||||
|
||||
impl WsStream {
|
||||
fn new(inner: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
|
||||
let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
|
||||
let (tx_message, rx_message) = mpsc::unbounded_channel::<Result<Message, WsError>>();
|
||||
|
||||
let pump_task = tokio::spawn(async move {
|
||||
let mut inner = inner;
|
||||
loop {
|
||||
tokio::select! {
|
||||
command = rx_command.recv() => {
|
||||
let Some(command) = command else {
|
||||
break;
|
||||
};
|
||||
match command {
|
||||
WsCommand::Send { message, tx_result } => {
|
||||
let result = inner.send(message).await;
|
||||
let should_break = result.is_err();
|
||||
let _ = tx_result.send(result);
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
}
|
||||
WsCommand::Close { tx_result } => {
|
||||
let result = inner.close(None).await;
|
||||
let _ = tx_result.send(result);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
message = inner.next() => {
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
match message {
|
||||
Ok(Message::Ping(payload)) => {
|
||||
if let Err(err) = inner.send(Message::Pong(payload)).await {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Message::Pong(_)) => {}
|
||||
Ok(message @ (Message::Text(_)
|
||||
| Message::Binary(_)
|
||||
| Message::Close(_)
|
||||
| Message::Frame(_))) => {
|
||||
let is_close = matches!(message, Message::Close(_));
|
||||
if tx_message.send(Ok(message)).is_err() {
|
||||
break;
|
||||
}
|
||||
if is_close {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
tx_command,
|
||||
rx_message,
|
||||
pump_task,
|
||||
}
|
||||
}
|
||||
|
||||
async fn request(
|
||||
&self,
|
||||
make_command: impl FnOnce(oneshot::Sender<Result<(), WsError>>) -> WsCommand,
|
||||
) -> Result<(), WsError> {
|
||||
let (tx_result, rx_result) = oneshot::channel();
|
||||
if self.tx_command.send(make_command(tx_result)).await.is_err() {
|
||||
return Err(WsError::ConnectionClosed);
|
||||
}
|
||||
rx_result.await.unwrap_or(Err(WsError::ConnectionClosed))
|
||||
fn new(inner: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>) -> Self {
|
||||
let (pump, rx_message) = WebsocketPump::new(inner);
|
||||
Self { pump, rx_message }
|
||||
}
|
||||
|
||||
async fn send(&self, message: Message) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Send { message, tx_result })
|
||||
.await
|
||||
self.pump.send(message).await
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Close { tx_result })
|
||||
.await
|
||||
self.pump.close().await
|
||||
}
|
||||
|
||||
async fn next(&mut self) -> Option<Result<Message, WsError>> {
|
||||
@@ -154,12 +64,6 @@ impl WsStream {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WsStream {
|
||||
fn drop(&mut self) {
|
||||
self.pump_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
const X_MODELS_ETAG_HEADER: &str = "x-models-etag";
|
||||
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
|
||||
|
||||
128
codex-rs/codex-api/src/endpoint/websocket_pump.rs
Normal file
128
codex-rs/codex-api/src/endpoint/websocket_pump.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Error as WsError;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
pub(crate) type WebsocketMessage = Result<Message, WsError>;
|
||||
|
||||
pub(crate) struct WebsocketPump {
|
||||
tx_command: mpsc::Sender<WsCommand>,
|
||||
pump_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
enum WsCommand {
|
||||
Send {
|
||||
message: Message,
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
Close {
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl WebsocketPump {
|
||||
pub(crate) fn new(
|
||||
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
) -> (Self, mpsc::UnboundedReceiver<WebsocketMessage>) {
|
||||
let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
|
||||
let (tx_message, rx_message) = mpsc::unbounded_channel::<WebsocketMessage>();
|
||||
|
||||
let pump_task = tokio::spawn(async move {
|
||||
let mut inner = inner;
|
||||
loop {
|
||||
tokio::select! {
|
||||
command = rx_command.recv() => {
|
||||
let Some(command) = command else {
|
||||
break;
|
||||
};
|
||||
match command {
|
||||
WsCommand::Send { message, tx_result } => {
|
||||
let result = inner.send(message).await;
|
||||
let should_break = result.is_err();
|
||||
let _ = tx_result.send(result);
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
}
|
||||
WsCommand::Close { tx_result } => {
|
||||
let result = inner.close(None).await;
|
||||
let _ = tx_result.send(result);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
message = inner.next() => {
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
match message {
|
||||
Ok(Message::Ping(payload)) => {
|
||||
if let Err(err) = inner.send(Message::Pong(payload)).await {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Message::Pong(_)) => {}
|
||||
Ok(message @ (Message::Text(_)
|
||||
| Message::Binary(_)
|
||||
| Message::Close(_)
|
||||
| Message::Frame(_))) => {
|
||||
let is_close = matches!(message, Message::Close(_));
|
||||
if tx_message.send(Ok(message)).is_err() {
|
||||
break;
|
||||
}
|
||||
if is_close {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(
|
||||
Self {
|
||||
tx_command,
|
||||
pump_task,
|
||||
},
|
||||
rx_message,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) async fn send(&self, message: Message) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Send { message, tx_result })
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn close(&self) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Close { tx_result })
|
||||
.await
|
||||
}
|
||||
|
||||
async fn request(
|
||||
&self,
|
||||
make_command: impl FnOnce(oneshot::Sender<Result<(), WsError>>) -> WsCommand,
|
||||
) -> Result<(), WsError> {
|
||||
let (tx_result, rx_result) = oneshot::channel();
|
||||
if self.tx_command.send(make_command(tx_result)).await.is_err() {
|
||||
return Err(WsError::ConnectionClosed);
|
||||
}
|
||||
rx_result.await.unwrap_or(Err(WsError::ConnectionClosed))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WebsocketPump {
|
||||
fn drop(&mut self) {
|
||||
self.pump_task.abort();
|
||||
}
|
||||
}
|
||||
@@ -30,7 +30,6 @@ pub use crate::endpoint::compact::CompactClient;
|
||||
pub use crate::endpoint::memories::MemoriesClient;
|
||||
pub use crate::endpoint::models::ModelsClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeAudioFrame;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeConnectionState;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeEvent;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient;
|
||||
|
||||
1
codex-rs/codex-api/tests/common/mod.rs
Normal file
1
codex-rs/codex-api/tests/common/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod ws_harness;
|
||||
43
codex-rs/codex-api/tests/common/ws_harness.rs
Normal file
43
codex-rs/codex-api/tests/common/ws_harness.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_api::Provider;
|
||||
use codex_api::provider::RetryConfig;
|
||||
use http::HeaderMap;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::accept_async;
|
||||
|
||||
pub(crate) async fn spawn_ws_server<F, Fut>(handler: F) -> (String, JoinHandle<()>)
|
||||
where
|
||||
F: FnOnce(WebSocketStream<tokio::net::TcpStream>) -> Fut + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let ws = accept_async(stream).await.expect("accept ws");
|
||||
handler(ws).await;
|
||||
});
|
||||
(format!("ws://{addr}"), server)
|
||||
}
|
||||
|
||||
pub(crate) fn test_provider() -> Provider {
|
||||
Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
}
|
||||
}
|
||||
@@ -1,30 +1,22 @@
|
||||
use std::collections::HashMap;
|
||||
mod common;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_api::RealtimeAudioFrame;
|
||||
use codex_api::RealtimeEvent;
|
||||
use codex_api::RealtimeSessionConfig;
|
||||
use codex_api::RealtimeWebsocketClient;
|
||||
use codex_api::provider::Provider;
|
||||
use codex_api::provider::RetryConfig;
|
||||
use common::ws_harness;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let (api_url, server) = ws_harness::spawn_ws_server(|mut ws| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
@@ -76,27 +68,14 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
))
|
||||
.await
|
||||
.expect("send audio out");
|
||||
});
|
||||
})
|
||||
.await;
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let client = RealtimeWebsocketClient::new(ws_harness::test_provider());
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
api_url: format!("ws://{addr}"),
|
||||
api_url,
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
},
|
||||
@@ -149,13 +128,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_send_while_next_event_waits() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let (api_url, server) = ws_harness::spawn_ws_server(|mut ws| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
@@ -186,27 +159,14 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
|
||||
))
|
||||
.await
|
||||
.expect("send session.created");
|
||||
});
|
||||
})
|
||||
.await;
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let client = RealtimeWebsocketClient::new(ws_harness::test_provider());
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
api_url: format!("ws://{addr}"),
|
||||
api_url,
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
},
|
||||
@@ -249,13 +209,7 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_disconnected_emitted_once() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let (api_url, server) = ws_harness::spawn_ws_server(|mut ws| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
@@ -267,27 +221,14 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
|
||||
assert_eq!(first_json["type"], "session.create");
|
||||
|
||||
ws.send(Message::Close(None)).await.expect("send close");
|
||||
});
|
||||
})
|
||||
.await;
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let client = RealtimeWebsocketClient::new(ws_harness::test_provider());
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
api_url: format!("ws://{addr}"),
|
||||
api_url,
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
},
|
||||
@@ -308,13 +249,7 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_ws_e2e_ignores_unknown_text_events() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
let addr = listener.local_addr().expect("local addr");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("accept");
|
||||
let mut ws = accept_async(stream).await.expect("accept ws");
|
||||
|
||||
let (api_url, server) = ws_harness::spawn_ws_server(|mut ws| async move {
|
||||
let first = ws
|
||||
.next()
|
||||
.await
|
||||
@@ -346,27 +281,14 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
|
||||
))
|
||||
.await
|
||||
.expect("send session.created");
|
||||
});
|
||||
})
|
||||
.await;
|
||||
|
||||
let provider = Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
query_params: Some(HashMap::new()),
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: false,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
};
|
||||
let client = RealtimeWebsocketClient::new(provider);
|
||||
let client = RealtimeWebsocketClient::new(ws_harness::test_provider());
|
||||
let connection = client
|
||||
.connect(
|
||||
RealtimeSessionConfig {
|
||||
api_url: format!("ws://{addr}"),
|
||||
api_url,
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user