mirror of
https://github.com/openai/codex.git
synced 2026-04-26 07:35:29 +00:00
Reuse websocket connection (#9127)
Reuses the connection but still sends full requests.
This commit is contained in:
@@ -16,8 +16,10 @@ use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
@@ -31,6 +33,69 @@ use url::Url;
|
||||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
|
||||
pub struct ResponsesWebsocketConnection {
|
||||
stream: Arc<Mutex<Option<WsStream>>>,
|
||||
// TODO (pakrym): is this the right place for timeout?
|
||||
idle_timeout: Duration,
|
||||
}
|
||||
|
||||
impl ResponsesWebsocketConnection {
|
||||
fn new(stream: WsStream, idle_timeout: Duration) -> Self {
|
||||
Self {
|
||||
stream: Arc::new(Mutex::new(Some(stream))),
|
||||
idle_timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_closed(&self) -> bool {
|
||||
self.stream.lock().await.is_none()
|
||||
}
|
||||
|
||||
pub async fn stream_request(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
if request.compression == Compression::Zstd {
|
||||
warn!(
|
||||
"request compression is not supported for websocket streaming; sending uncompressed payload"
|
||||
);
|
||||
}
|
||||
|
||||
let (tx_event, rx_event) =
|
||||
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
|
||||
let stream = Arc::clone(&self.stream);
|
||||
let idle_timeout = self.idle_timeout;
|
||||
let request_body = request.body;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut guard = stream.lock().await;
|
||||
let Some(ws_stream) = guard.as_mut() else {
|
||||
let _ = tx_event
|
||||
.send(Err(ApiError::Stream(
|
||||
"websocket connection is closed".to_string(),
|
||||
)))
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
if let Err(err) = run_websocket_response_stream(
|
||||
ws_stream,
|
||||
tx_event.clone(),
|
||||
request_body,
|
||||
idle_timeout,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let _ = ws_stream.close(None).await;
|
||||
*guard = None;
|
||||
let _ = tx_event.send(Err(err)).await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ResponsesWebsocketClient<A: AuthProvider> {
|
||||
provider: Provider,
|
||||
auth: A,
|
||||
@@ -41,12 +106,22 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
||||
Self { provider, auth }
|
||||
}
|
||||
|
||||
pub async fn stream_request(
|
||||
pub async fn connect(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers, request.compression)
|
||||
.await
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ResponsesWebsocketConnection, ApiError> {
|
||||
let ws_url = Url::parse(&self.provider.url_for_path("responses"))
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
|
||||
|
||||
let mut headers = self.provider.headers.clone();
|
||||
headers.extend(extra_headers);
|
||||
apply_auth_headers(&mut headers, &self.auth);
|
||||
|
||||
let stream = connect_websocket(ws_url, headers).await?;
|
||||
Ok(ResponsesWebsocketConnection::new(
|
||||
stream,
|
||||
self.provider.stream_idle_timeout,
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
@@ -82,7 +157,8 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
||||
.compression(compression)
|
||||
.build(&self.provider)?;
|
||||
|
||||
self.stream_request(request).await
|
||||
let connection = self.connect(request.headers.clone()).await?;
|
||||
connection.stream_request(request).await
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
@@ -91,41 +167,13 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
||||
extra_headers: HeaderMap,
|
||||
compression: Compression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
if compression == Compression::Zstd {
|
||||
warn!(
|
||||
"request compression is not supported for websocket streaming; sending uncompressed payload"
|
||||
);
|
||||
}
|
||||
|
||||
let ws_url = Url::parse(&self.provider.url_for_path("responses"))
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
|
||||
let mut headers = self.provider.headers.clone();
|
||||
headers.extend(extra_headers);
|
||||
apply_auth_headers(&mut headers, &self.auth);
|
||||
|
||||
let connection = connect_websocket(ws_url, headers).await?;
|
||||
|
||||
let (tx_event, rx_event) =
|
||||
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
|
||||
let idle_timeout = self.provider.stream_idle_timeout;
|
||||
|
||||
// TODO (pakrym): surface rate limits
|
||||
// TODO (pakrym): check models etags
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = run_websocket_response_stream(
|
||||
connection.stream,
|
||||
tx_event.clone(),
|
||||
body,
|
||||
idle_timeout,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let _ = tx_event.send(Err(err)).await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(ResponseStream { rx_event })
|
||||
let request = ResponsesRequest {
|
||||
body,
|
||||
headers: extra_headers,
|
||||
compression,
|
||||
};
|
||||
let connection = self.connect(request.headers.clone()).await?;
|
||||
connection.stream_request(request).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,11 +191,7 @@ fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) {
|
||||
}
|
||||
}
|
||||
|
||||
struct WebSocketConnection {
|
||||
stream: WsStream,
|
||||
}
|
||||
|
||||
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WebSocketConnection, ApiError> {
|
||||
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WsStream, ApiError> {
|
||||
let mut request = url
|
||||
.clone()
|
||||
.into_client_request()
|
||||
@@ -157,7 +201,7 @@ async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WebSocketConn
|
||||
let (stream, _) = tokio_tungstenite::connect_async(request)
|
||||
.await
|
||||
.map_err(|err| map_ws_error(err, &url))?;
|
||||
Ok(WebSocketConnection { stream })
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
||||
@@ -185,7 +229,7 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
||||
}
|
||||
|
||||
async fn run_websocket_response_stream(
|
||||
mut ws_stream: WsStream,
|
||||
ws_stream: &mut WsStream,
|
||||
tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
|
||||
request_body: Value,
|
||||
idle_timeout: Duration,
|
||||
@@ -193,7 +237,6 @@ async fn run_websocket_response_stream(
|
||||
let request_text = match serde_json::to_string(&request_body) {
|
||||
Ok(text) => text,
|
||||
Err(err) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(format!(
|
||||
"failed to encode websocket request: {err}"
|
||||
)));
|
||||
@@ -201,7 +244,6 @@ async fn run_websocket_response_stream(
|
||||
};
|
||||
|
||||
if let Err(err) = ws_stream.send(Message::Text(request_text)).await {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(format!(
|
||||
"failed to send websocket request: {err}"
|
||||
)));
|
||||
@@ -214,17 +256,14 @@ async fn run_websocket_response_stream(
|
||||
let message = match response {
|
||||
Ok(Some(Ok(msg))) => msg,
|
||||
Ok(Some(Err(err))) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(err.to_string()));
|
||||
}
|
||||
Ok(None) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
));
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
@@ -249,24 +288,20 @@ async fn run_websocket_response_stream(
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(error) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(error.into_api_error());
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Binary(_) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream("unexpected binary websocket event".into()));
|
||||
}
|
||||
Message::Ping(payload) => {
|
||||
if ws_stream.send(Message::Pong(payload)).await.is_err() {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream("websocket ping failed".into()));
|
||||
}
|
||||
}
|
||||
Message::Pong(_) => {}
|
||||
Message::Close(_) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(
|
||||
"websocket closed before response.completed".into(),
|
||||
));
|
||||
@@ -275,6 +310,5 @@ async fn run_websocket_response_stream(
|
||||
}
|
||||
}
|
||||
|
||||
let _ = ws_stream.close(None).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user