use crate::auth::AuthProvider; use crate::auth::add_auth_headers_to_header_map; use crate::common::ResponseEvent; use crate::common::ResponseStream; use crate::common::ResponsesWsRequest; use crate::error::ApiError; use crate::provider::Provider; use crate::rate_limits::parse_rate_limit_event; use crate::sse::responses::ResponsesStreamEvent; use crate::sse::responses::process_responses_event; use crate::telemetry::WebsocketTelemetry; use codex_client::TransportError; use futures::SinkExt; use futures::StreamExt; use http::HeaderMap; use serde_json::Value; use std::sync::Arc; use std::sync::OnceLock; use std::time::Duration; use tokio::net::TcpStream; use tokio::sync::Mutex; use tokio::sync::mpsc; use tokio::time::Instant; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Error as WsError; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tracing::debug; use tracing::error; use tracing::info; use tracing::trace; use tungstenite::extensions::ExtensionsConfig; use tungstenite::extensions::compression::deflate::DeflateConfig; use tungstenite::protocol::WebSocketConfig; use url::Url; type WsStream = WebSocketStream>; const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state"; const X_MODELS_ETAG_HEADER: &str = "x-models-etag"; const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included"; pub struct ResponsesWebsocketConnection { stream: Arc>>, // TODO (pakrym): is this the right place for timeout? idle_timeout: Duration, server_reasoning_included: bool, models_etag: Option, telemetry: Option>, } impl ResponsesWebsocketConnection { fn new( stream: WsStream, idle_timeout: Duration, server_reasoning_included: bool, models_etag: Option, telemetry: Option>, ) -> Self { Self { stream: Arc::new(Mutex::new(Some(stream))), idle_timeout, server_reasoning_included, models_etag, telemetry, } } pub async fn is_closed(&self) -> bool { self.stream.lock().await.is_none() } pub async fn stream_request( &self, request: ResponsesWsRequest, ) -> Result { let (tx_event, rx_event) = mpsc::channel::>(1600); let stream = Arc::clone(&self.stream); let idle_timeout = self.idle_timeout; let server_reasoning_included = self.server_reasoning_included; let models_etag = self.models_etag.clone(); let telemetry = self.telemetry.clone(); let request_body = serde_json::to_value(&request).map_err(|err| { ApiError::Stream(format!("failed to encode websocket request: {err}")) })?; tokio::spawn(async move { if let Some(etag) = models_etag { let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await; } if server_reasoning_included { let _ = tx_event .send(Ok(ResponseEvent::ServerReasoningIncluded(true))) .await; } let mut guard = stream.lock().await; let Some(ws_stream) = guard.as_mut() else { let _ = tx_event .send(Err(ApiError::Stream( "websocket connection is closed".to_string(), ))) .await; return; }; if let Err(err) = run_websocket_response_stream( ws_stream, tx_event.clone(), request_body, idle_timeout, telemetry, ) .await { let _ = ws_stream.close(None).await; *guard = None; let _ = tx_event.send(Err(err)).await; } }); Ok(ResponseStream { rx_event }) } } pub struct ResponsesWebsocketClient { provider: Provider, auth: A, } impl ResponsesWebsocketClient { pub fn new(provider: Provider, auth: A) -> Self { Self { provider, auth } } pub async fn connect( &self, extra_headers: HeaderMap, turn_state: Option>>, telemetry: Option>, ) -> Result { let ws_url = self .provider .websocket_url_for_path("responses") .map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?; let mut headers = self.provider.headers.clone(); headers.extend(extra_headers); add_auth_headers_to_header_map(&self.auth, &mut headers); let (stream, server_reasoning_included, models_etag) = connect_websocket(ws_url, headers, turn_state.clone()).await?; Ok(ResponsesWebsocketConnection::new( stream, self.provider.stream_idle_timeout, server_reasoning_included, models_etag, telemetry, )) } } async fn connect_websocket( url: Url, headers: HeaderMap, turn_state: Option>>, ) -> Result<(WsStream, bool, Option), ApiError> { info!("connecting to websocket: {url}"); let mut request = url .as_str() .into_client_request() .map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?; request.headers_mut().extend(headers); let response = tokio_tungstenite::connect_async_with_config(request, Some(websocket_config()), false) .await; let (stream, response) = match response { Ok((stream, response)) => { info!( "successfully connected to websocket: {url}, headers: {:?}", response.headers() ); (stream, response) } Err(err) => { error!("failed to connect to websocket: {err}, url: {url}"); return Err(map_ws_error(err, &url)); } }; let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER); let models_etag = response .headers() .get(X_MODELS_ETAG_HEADER) .and_then(|value| value.to_str().ok()) .map(ToString::to_string); if let Some(turn_state) = turn_state && let Some(header_value) = response .headers() .get(X_CODEX_TURN_STATE_HEADER) .and_then(|value| value.to_str().ok()) { let _ = turn_state.set(header_value.to_string()); } Ok((stream, reasoning_included, models_etag)) } fn websocket_config() -> WebSocketConfig { let mut extensions = ExtensionsConfig::default(); extensions.permessage_deflate = Some(DeflateConfig::default()); let mut config = WebSocketConfig::default(); config.extensions = extensions; config } #[cfg(test)] mod tests { use super::*; #[test] fn websocket_config_enables_permessage_deflate() { let config = websocket_config(); assert!(config.extensions.permessage_deflate.is_some()); } } fn map_ws_error(err: WsError, url: &Url) -> ApiError { match err { WsError::Http(response) => { let status = response.status(); let headers = response.headers().clone(); let body = response .body() .as_ref() .and_then(|bytes| String::from_utf8(bytes.clone()).ok()); ApiError::Transport(TransportError::Http { status, url: Some(url.to_string()), headers: Some(headers), body, }) } WsError::ConnectionClosed | WsError::AlreadyClosed => { ApiError::Stream("websocket closed".to_string()) } WsError::Io(err) => ApiError::Transport(TransportError::Network(err.to_string())), other => ApiError::Transport(TransportError::Network(other.to_string())), } } async fn run_websocket_response_stream( ws_stream: &mut WsStream, tx_event: mpsc::Sender>, request_body: Value, idle_timeout: Duration, telemetry: Option>, ) -> Result<(), ApiError> { let request_text = match serde_json::to_string(&request_body) { Ok(text) => text, Err(err) => { return Err(ApiError::Stream(format!( "failed to encode websocket request: {err}" ))); } }; let request_start = Instant::now(); let result = ws_stream .send(Message::Text(request_text.into())) .await .map_err(|err| ApiError::Stream(format!("failed to send websocket request: {err}"))); if let Some(t) = telemetry.as_ref() { t.on_ws_request(request_start.elapsed(), result.as_ref().err()); } result?; loop { let poll_start = Instant::now(); let response = tokio::time::timeout(idle_timeout, ws_stream.next()) .await .map_err(|_| ApiError::Stream("idle timeout waiting for websocket".into())); if let Some(t) = telemetry.as_ref() { t.on_ws_event(&response, poll_start.elapsed()); } let message = match response { Ok(Some(Ok(msg))) => msg, Ok(Some(Err(err))) => { return Err(ApiError::Stream(err.to_string())); } Ok(None) => { return Err(ApiError::Stream( "stream closed before response.completed".into(), )); } Err(err) => { return Err(err); } }; match message { Message::Text(text) => { trace!("websocket event: {text}"); let event = match serde_json::from_str::(&text) { Ok(event) => event, Err(err) => { debug!("failed to parse websocket event: {err}, data: {text}"); continue; } }; if event.kind() == "codex.rate_limits" { if let Some(snapshot) = parse_rate_limit_event(&text) { let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await; } continue; } match process_responses_event(event) { Ok(Some(event)) => { let is_completed = matches!(event, ResponseEvent::Completed { .. }); let _ = tx_event.send(Ok(event)).await; if is_completed { break; } } Ok(None) => {} Err(error) => { return Err(error.into_api_error()); } } } Message::Binary(_) => { return Err(ApiError::Stream("unexpected binary websocket event".into())); } Message::Ping(payload) => { if ws_stream.send(Message::Pong(payload)).await.is_err() { return Err(ApiError::Stream("websocket ping failed".into())); } } Message::Pong(_) => {} Message::Close(_) => { return Err(ApiError::Stream( "websocket closed by server before response.completed".into(), )); } _ => {} } } Ok(()) }