mirror of
https://github.com/openai/codex.git
synced 2026-04-27 08:05:51 +00:00
codex-api: share websocket pump and simplify realtime events
This commit is contained in:
@@ -7,21 +7,17 @@ use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionCreateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
|
||||
use crate::endpoint::websocket_pump::WebsocketMessage;
|
||||
use crate::endpoint::websocket_pump::WebsocketPump;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Error as WsError;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
@@ -30,123 +26,6 @@ use tracing::trace;
|
||||
use tungstenite::protocol::WebSocketConfig;
|
||||
use url::Url;
|
||||
|
||||
struct WsStream {
|
||||
tx_command: mpsc::Sender<WsCommand>,
|
||||
pump_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
enum WsCommand {
|
||||
Send {
|
||||
message: Message,
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
Close {
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl WsStream {
|
||||
fn new(
|
||||
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
) -> (Self, mpsc::UnboundedReceiver<Result<Message, WsError>>) {
|
||||
let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
|
||||
let (tx_message, rx_message) = mpsc::unbounded_channel::<Result<Message, WsError>>();
|
||||
|
||||
let pump_task = tokio::spawn(async move {
|
||||
let mut inner = inner;
|
||||
loop {
|
||||
tokio::select! {
|
||||
command = rx_command.recv() => {
|
||||
let Some(command) = command else {
|
||||
break;
|
||||
};
|
||||
match command {
|
||||
WsCommand::Send { message, tx_result } => {
|
||||
let result = inner.send(message).await;
|
||||
let should_break = result.is_err();
|
||||
let _ = tx_result.send(result);
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
}
|
||||
WsCommand::Close { tx_result } => {
|
||||
let result = inner.close(None).await;
|
||||
let _ = tx_result.send(result);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
message = inner.next() => {
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
match message {
|
||||
Ok(Message::Ping(payload)) => {
|
||||
if let Err(err) = inner.send(Message::Pong(payload)).await {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Message::Pong(_)) => {}
|
||||
Ok(message @ (Message::Text(_)
|
||||
| Message::Binary(_)
|
||||
| Message::Close(_)
|
||||
| Message::Frame(_))) => {
|
||||
let is_close = matches!(message, Message::Close(_));
|
||||
if tx_message.send(Ok(message)).is_err() {
|
||||
break;
|
||||
}
|
||||
if is_close {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(
|
||||
Self {
|
||||
tx_command,
|
||||
pump_task,
|
||||
},
|
||||
rx_message,
|
||||
)
|
||||
}
|
||||
|
||||
async fn request(
|
||||
&self,
|
||||
make_command: impl FnOnce(oneshot::Sender<Result<(), WsError>>) -> WsCommand,
|
||||
) -> Result<(), WsError> {
|
||||
let (tx_result, rx_result) = oneshot::channel();
|
||||
if self.tx_command.send(make_command(tx_result)).await.is_err() {
|
||||
return Err(WsError::ConnectionClosed);
|
||||
}
|
||||
rx_result.await.unwrap_or(Err(WsError::ConnectionClosed))
|
||||
}
|
||||
|
||||
async fn send(&self, message: Message) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Send { message, tx_result })
|
||||
.await
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Close { tx_result })
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WsStream {
|
||||
fn drop(&mut self) {
|
||||
self.pump_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RealtimeWebsocketConnection {
|
||||
writer: RealtimeWebsocketWriter,
|
||||
events: RealtimeWebsocketEvents,
|
||||
@@ -154,13 +33,13 @@ pub struct RealtimeWebsocketConnection {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealtimeWebsocketWriter {
|
||||
stream: Arc<WsStream>,
|
||||
stream: Arc<WebsocketPump>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealtimeWebsocketEvents {
|
||||
rx_message: Arc<Mutex<mpsc::UnboundedReceiver<Result<Message, WsError>>>>,
|
||||
rx_message: Arc<Mutex<mpsc::UnboundedReceiver<WebsocketMessage>>>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
@@ -209,10 +88,7 @@ impl RealtimeWebsocketConnection {
|
||||
self.events.clone()
|
||||
}
|
||||
|
||||
fn new(
|
||||
stream: WsStream,
|
||||
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
|
||||
) -> Self {
|
||||
fn new(stream: WebsocketPump, rx_message: mpsc::UnboundedReceiver<WebsocketMessage>) -> Self {
|
||||
let stream = Arc::new(stream);
|
||||
let is_closed = Arc::new(AtomicBool::new(false));
|
||||
Self {
|
||||
@@ -389,7 +265,7 @@ impl RealtimeWebsocketClient {
|
||||
ApiError::Stream(format!("failed to connect realtime websocket: {err}"))
|
||||
})?;
|
||||
|
||||
let (stream, rx_message) = WsStream::new(stream);
|
||||
let (stream, rx_message) = WebsocketPump::new(stream);
|
||||
let connection = RealtimeWebsocketConnection::new(stream, rx_message);
|
||||
connection
|
||||
.send_session_create(config.prompt, config.session_id)
|
||||
@@ -445,6 +321,8 @@ fn websocket_url_from_api_url(api_url: &str) -> Result<Url, ApiError> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderValue;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
|
||||
Reference in New Issue
Block a user