codex-api: share websocket pump and simplify realtime events

This commit is contained in:
Ahmed Ibrahim
2026-02-17 16:18:38 -08:00
parent f26c001568
commit 71b1d9ff0d
10 changed files with 259 additions and 407 deletions

View File

@@ -7,21 +7,17 @@ use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
use crate::endpoint::realtime_websocket::protocol::SessionCreateSession;
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
use crate::endpoint::websocket_pump::WebsocketMessage;
use crate::endpoint::websocket_pump::WebsocketPump;
use crate::error::ApiError;
use crate::provider::Provider;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Error as WsError;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
@@ -30,123 +26,6 @@ use tracing::trace;
use tungstenite::protocol::WebSocketConfig;
use url::Url;
struct WsStream {
tx_command: mpsc::Sender<WsCommand>,
pump_task: tokio::task::JoinHandle<()>,
}
enum WsCommand {
Send {
message: Message,
tx_result: oneshot::Sender<Result<(), WsError>>,
},
Close {
tx_result: oneshot::Sender<Result<(), WsError>>,
},
}
impl WsStream {
fn new(
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
) -> (Self, mpsc::UnboundedReceiver<Result<Message, WsError>>) {
let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
let (tx_message, rx_message) = mpsc::unbounded_channel::<Result<Message, WsError>>();
let pump_task = tokio::spawn(async move {
let mut inner = inner;
loop {
tokio::select! {
command = rx_command.recv() => {
let Some(command) = command else {
break;
};
match command {
WsCommand::Send { message, tx_result } => {
let result = inner.send(message).await;
let should_break = result.is_err();
let _ = tx_result.send(result);
if should_break {
break;
}
}
WsCommand::Close { tx_result } => {
let result = inner.close(None).await;
let _ = tx_result.send(result);
break;
}
}
}
message = inner.next() => {
let Some(message) = message else {
break;
};
match message {
Ok(Message::Ping(payload)) => {
if let Err(err) = inner.send(Message::Pong(payload)).await {
let _ = tx_message.send(Err(err));
break;
}
}
Ok(Message::Pong(_)) => {}
Ok(message @ (Message::Text(_)
| Message::Binary(_)
| Message::Close(_)
| Message::Frame(_))) => {
let is_close = matches!(message, Message::Close(_));
if tx_message.send(Ok(message)).is_err() {
break;
}
if is_close {
break;
}
}
Err(err) => {
let _ = tx_message.send(Err(err));
break;
}
}
}
}
}
});
(
Self {
tx_command,
pump_task,
},
rx_message,
)
}
async fn request(
&self,
make_command: impl FnOnce(oneshot::Sender<Result<(), WsError>>) -> WsCommand,
) -> Result<(), WsError> {
let (tx_result, rx_result) = oneshot::channel();
if self.tx_command.send(make_command(tx_result)).await.is_err() {
return Err(WsError::ConnectionClosed);
}
rx_result.await.unwrap_or(Err(WsError::ConnectionClosed))
}
async fn send(&self, message: Message) -> Result<(), WsError> {
self.request(|tx_result| WsCommand::Send { message, tx_result })
.await
}
async fn close(&self) -> Result<(), WsError> {
self.request(|tx_result| WsCommand::Close { tx_result })
.await
}
}
impl Drop for WsStream {
fn drop(&mut self) {
self.pump_task.abort();
}
}
pub struct RealtimeWebsocketConnection {
writer: RealtimeWebsocketWriter,
events: RealtimeWebsocketEvents,
@@ -154,13 +33,13 @@ pub struct RealtimeWebsocketConnection {
#[derive(Clone)]
pub struct RealtimeWebsocketWriter {
stream: Arc<WsStream>,
stream: Arc<WebsocketPump>,
is_closed: Arc<AtomicBool>,
}
#[derive(Clone)]
pub struct RealtimeWebsocketEvents {
rx_message: Arc<Mutex<mpsc::UnboundedReceiver<Result<Message, WsError>>>>,
rx_message: Arc<Mutex<mpsc::UnboundedReceiver<WebsocketMessage>>>,
is_closed: Arc<AtomicBool>,
}
@@ -209,10 +88,7 @@ impl RealtimeWebsocketConnection {
self.events.clone()
}
fn new(
stream: WsStream,
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
) -> Self {
fn new(stream: WebsocketPump, rx_message: mpsc::UnboundedReceiver<WebsocketMessage>) -> Self {
let stream = Arc::new(stream);
let is_closed = Arc::new(AtomicBool::new(false));
Self {
@@ -389,7 +265,7 @@ impl RealtimeWebsocketClient {
ApiError::Stream(format!("failed to connect realtime websocket: {err}"))
})?;
let (stream, rx_message) = WsStream::new(stream);
let (stream, rx_message) = WebsocketPump::new(stream);
let connection = RealtimeWebsocketConnection::new(stream, rx_message);
connection
.send_session_create(config.prompt, config.session_id)
@@ -445,6 +321,8 @@ fn websocket_url_from_api_url(api_url: &str) -> Result<Url, ApiError> {
#[cfg(test)]
mod tests {
use super::*;
use futures::SinkExt;
use futures::StreamExt;
use http::HeaderValue;
use pretty_assertions::assert_eq;
use serde_json::Value;