use crate::auth::AuthProvider; use crate::auth::add_auth_headers_to_header_map; use crate::common::ResponseEvent; use crate::common::ResponseStream; use crate::common::ResponsesWsRequest; use crate::error::ApiError; use crate::provider::Provider; use crate::rate_limits::parse_rate_limit_event; use crate::sse::responses::ResponsesStreamEvent; use crate::sse::responses::process_responses_event; use crate::telemetry::WebsocketTelemetry; use codex_client::TransportError; use codex_client::maybe_build_rustls_client_config_with_custom_ca; use codex_utils_rustls_provider::ensure_rustls_crypto_provider; use futures::SinkExt; use futures::StreamExt; use http::HeaderMap; use http::HeaderName; use http::HeaderValue; use http::StatusCode; use serde::Deserialize; use serde_json::Value; use serde_json::map::Map as JsonMap; use std::sync::Arc; use std::sync::OnceLock; use std::time::Duration; use tokio::net::TcpStream; use tokio::sync::Mutex; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio::time::Instant; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::connect_async_tls_with_config; use tokio_tungstenite::tungstenite::Error as WsError; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tracing::Instrument; use tracing::Span; use tracing::debug; use tracing::error; use tracing::info; use tracing::instrument; use tracing::trace; use tungstenite::extensions::ExtensionsConfig; use tungstenite::extensions::compression::deflate::DeflateConfig; use tungstenite::protocol::WebSocketConfig; use url::Url; struct WsStream { tx_command: mpsc::Sender, rx_message: mpsc::UnboundedReceiver>, pump_task: tokio::task::JoinHandle<()>, } enum WsCommand { Send { message: Message, tx_result: oneshot::Sender>, }, } impl WsStream { fn new(inner: WebSocketStream>) -> Self { let (tx_command, mut rx_command) = mpsc::channel::(32); let (tx_message, rx_message) = mpsc::unbounded_channel::>(); let pump_task = tokio::spawn(async move { let mut inner = inner; loop { tokio::select! { command = rx_command.recv() => { let Some(command) = command else { break; }; match command { WsCommand::Send { message, tx_result } => { let result = inner.send(message).await; let should_break = result.is_err(); let _ = tx_result.send(result); if should_break { break; } } } } message = inner.next() => { let Some(message) = message else { break; }; match message { Ok(Message::Ping(payload)) => { if let Err(err) = inner.send(Message::Pong(payload)).await { let _ = tx_message.send(Err(err)); break; } } Ok(Message::Pong(_)) => {} Ok(message @ (Message::Text(_) | Message::Binary(_) | Message::Close(_) | Message::Frame(_))) => { let is_close = matches!(message, Message::Close(_)); if tx_message.send(Ok(message)).is_err() { break; } if is_close { break; } } Err(err) => { let _ = tx_message.send(Err(err)); break; } } } } } }); Self { tx_command, rx_message, pump_task, } } async fn request( &self, make_command: impl FnOnce(oneshot::Sender>) -> WsCommand, ) -> Result<(), WsError> { let (tx_result, rx_result) = oneshot::channel(); if self.tx_command.send(make_command(tx_result)).await.is_err() { return Err(WsError::ConnectionClosed); } rx_result.await.unwrap_or(Err(WsError::ConnectionClosed)) } async fn send(&self, message: Message) -> Result<(), WsError> { self.request(|tx_result| WsCommand::Send { message, tx_result }) .await } async fn next(&mut self) -> Option> { self.rx_message.recv().await } } impl Drop for WsStream { fn drop(&mut self) { self.pump_task.abort(); } } const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state"; const X_MODELS_ETAG_HEADER: &str = "x-models-etag"; const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included"; const OPENAI_MODEL_HEADER: &str = "openai-model"; const WEBSOCKET_CONNECTION_LIMIT_REACHED_CODE: &str = "websocket_connection_limit_reached"; const WEBSOCKET_CONNECTION_LIMIT_REACHED_MESSAGE: &str = "Responses websocket connection limit reached (60 minutes). Create a new websocket connection to continue."; pub struct ResponsesWebsocketConnection { stream: Arc>>, // TODO (pakrym): is this the right place for timeout? idle_timeout: Duration, server_reasoning_included: bool, models_etag: Option, server_model: Option, telemetry: Option>, } impl std::fmt::Debug for ResponsesWebsocketConnection { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ResponsesWebsocketConnection") .field("stream", &"") .field("idle_timeout", &self.idle_timeout) .field("server_reasoning_included", &self.server_reasoning_included) .field("models_etag", &self.models_etag) .field("server_model", &self.server_model) .field("telemetry", &self.telemetry.as_ref().map(|_| "")) .finish() } } impl ResponsesWebsocketConnection { fn new( stream: WsStream, idle_timeout: Duration, server_reasoning_included: bool, models_etag: Option, server_model: Option, telemetry: Option>, ) -> Self { Self { stream: Arc::new(Mutex::new(Some(stream))), idle_timeout, server_reasoning_included, models_etag, server_model, telemetry, } } pub async fn is_closed(&self) -> bool { self.stream.lock().await.is_none() } #[instrument( name = "responses_websocket.stream_request", level = "info", skip_all, fields(transport = "responses_websocket", api.path = "responses") )] pub async fn stream_request( &self, request: ResponsesWsRequest, connection_reused: bool, ) -> Result { let (tx_event, rx_event) = mpsc::channel::>(1600); let stream = Arc::clone(&self.stream); let idle_timeout = self.idle_timeout; let server_reasoning_included = self.server_reasoning_included; let models_etag = self.models_etag.clone(); let server_model = self.server_model.clone(); let telemetry = self.telemetry.clone(); let request_body = serde_json::to_value(&request).map_err(|err| { ApiError::Stream(format!("failed to encode websocket request: {err}")) })?; let current_span = Span::current(); tokio::spawn( async move { if let Some(model) = server_model { let _ = tx_event.send(Ok(ResponseEvent::ServerModel(model))).await; } if let Some(etag) = models_etag { let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await; } if server_reasoning_included { let _ = tx_event .send(Ok(ResponseEvent::ServerReasoningIncluded(true))) .await; } let mut guard = stream.lock().await; let result = { let Some(ws_stream) = guard.as_mut() else { let _ = tx_event .send(Err(ApiError::Stream( "websocket connection is closed".to_string(), ))) .await; return; }; run_websocket_response_stream( ws_stream, tx_event.clone(), request_body, idle_timeout, telemetry, connection_reused, ) .await }; if let Err(err) = result { // A terminal stream error should reach the caller immediately. Waiting for a // graceful close handshake here can stall indefinitely and mask the error. let failed_stream = guard.take(); drop(guard); drop(failed_stream); let _ = tx_event.send(Err(err)).await; } } .instrument(current_span), ); Ok(ResponseStream { rx_event }) } } pub struct ResponsesWebsocketClient { provider: Provider, auth: A, } impl ResponsesWebsocketClient { pub fn new(provider: Provider, auth: A) -> Self { Self { provider, auth } } #[instrument( name = "responses_websocket.connect", level = "info", skip_all, fields(transport = "responses_websocket", api.path = "responses") )] pub async fn connect( &self, extra_headers: HeaderMap, default_headers: HeaderMap, turn_state: Option>>, telemetry: Option>, ) -> Result { let ws_url = self .provider .websocket_url_for_path("responses") .map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?; let mut headers = merge_request_headers(&self.provider.headers, extra_headers, default_headers); add_auth_headers_to_header_map(&self.auth, &mut headers); let (stream, server_reasoning_included, models_etag, server_model) = connect_websocket(ws_url, headers, turn_state.clone()).await?; Ok(ResponsesWebsocketConnection::new( stream, self.provider.stream_idle_timeout, server_reasoning_included, models_etag, server_model, telemetry, )) } } fn merge_request_headers( provider_headers: &HeaderMap, extra_headers: HeaderMap, default_headers: HeaderMap, ) -> HeaderMap { let mut headers = provider_headers.clone(); headers.extend(extra_headers); for (name, value) in &default_headers { if let http::header::Entry::Vacant(entry) = headers.entry(name) { entry.insert(value.clone()); } } headers } async fn connect_websocket( url: Url, headers: HeaderMap, turn_state: Option>>, ) -> Result<(WsStream, bool, Option, Option), ApiError> { ensure_rustls_crypto_provider(); info!("connecting to websocket: {url}"); let mut request = url .as_str() .into_client_request() .map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?; request.headers_mut().extend(headers); // Secure websocket traffic needs the same custom-CA policy as reqwest-based HTTPS traffic. // If a Codex-specific CA bundle is configured, build an explicit rustls connector so this // websocket path does not fall back to tungstenite's default native-roots-only behavior. let connector = maybe_build_rustls_client_config_with_custom_ca() .map_err(|err| ApiError::Stream(format!("failed to configure websocket TLS: {err}")))? .map(tokio_tungstenite::Connector::Rustls); let response = connect_async_tls_with_config( request, Some(websocket_config()), false, // `false` means "do not disable Nagle", which is tungstenite's recommended default. connector, ) .await; let (stream, response) = match response { Ok((stream, response)) => { info!( "successfully connected to websocket: {url}, headers: {:?}", response.headers() ); (stream, response) } Err(err) => { error!("failed to connect to websocket: {err}, url: {url}"); return Err(map_ws_error(err, &url)); } }; let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER); let models_etag = response .headers() .get(X_MODELS_ETAG_HEADER) .and_then(|value| value.to_str().ok()) .map(ToString::to_string); let server_model = response .headers() .get(OPENAI_MODEL_HEADER) .and_then(|value| value.to_str().ok()) .map(ToString::to_string); if let Some(turn_state) = turn_state && let Some(header_value) = response .headers() .get(X_CODEX_TURN_STATE_HEADER) .and_then(|value| value.to_str().ok()) { let _ = turn_state.set(header_value.to_string()); } Ok(( WsStream::new(stream), reasoning_included, models_etag, server_model, )) } fn websocket_config() -> WebSocketConfig { let mut extensions = ExtensionsConfig::default(); extensions.permessage_deflate = Some(DeflateConfig::default()); let mut config = WebSocketConfig::default(); config.extensions = extensions; config } fn map_ws_error(err: WsError, url: &Url) -> ApiError { match err { WsError::Http(response) => { let status = response.status(); let headers = response.headers().clone(); let body = response .body() .as_ref() .and_then(|bytes| String::from_utf8(bytes.clone()).ok()); ApiError::Transport(TransportError::Http { status, url: Some(url.to_string()), headers: Some(headers), body, }) } WsError::ConnectionClosed | WsError::AlreadyClosed => { ApiError::Stream("websocket closed".to_string()) } WsError::Io(err) => ApiError::Transport(TransportError::Network(err.to_string())), other => ApiError::Transport(TransportError::Network(other.to_string())), } } #[derive(Debug, Deserialize)] struct WrappedWebsocketError { code: Option, message: Option, } #[derive(Debug, Deserialize)] struct WrappedWebsocketErrorEvent { #[serde(rename = "type")] kind: String, #[serde(alias = "status_code")] status: Option, #[serde(default)] error: Option, #[serde(default)] headers: Option>, } fn parse_wrapped_websocket_error_event(payload: &str) -> Option { let event: WrappedWebsocketErrorEvent = serde_json::from_str(payload).ok()?; if event.kind != "error" { return None; } Some(event) } fn map_wrapped_websocket_error_event( event: WrappedWebsocketErrorEvent, original_payload: String, ) -> Option { let WrappedWebsocketErrorEvent { status, error, headers, .. } = event; if let Some(error) = error.as_ref() && let Some(code) = error.code.as_deref() && code == WEBSOCKET_CONNECTION_LIMIT_REACHED_CODE { return Some(ApiError::Retryable { message: error .message .clone() .unwrap_or_else(|| WEBSOCKET_CONNECTION_LIMIT_REACHED_MESSAGE.to_string()), delay: None, }); } let status = StatusCode::from_u16(status?).ok()?; if status.is_success() { return None; } Some(ApiError::Transport(TransportError::Http { status, url: None, headers: headers.map(json_headers_to_http_headers), body: Some(original_payload), })) } fn json_headers_to_http_headers(headers: JsonMap) -> HeaderMap { let mut mapped = HeaderMap::new(); for (name, value) in headers { let Ok(header_name) = HeaderName::from_bytes(name.as_bytes()) else { continue; }; let Some(header_value) = json_header_value(value) else { continue; }; mapped.insert(header_name, header_value); } mapped } fn json_header_value(value: Value) -> Option { let value = match value { Value::String(value) => value, Value::Number(value) => value.to_string(), Value::Bool(value) => value.to_string(), _ => return None, }; HeaderValue::from_str(&value).ok() } async fn run_websocket_response_stream( ws_stream: &mut WsStream, tx_event: mpsc::Sender>, request_body: Value, idle_timeout: Duration, telemetry: Option>, connection_reused: bool, ) -> Result<(), ApiError> { let mut last_server_model: Option = None; let request_text = match serde_json::to_string(&request_body) { Ok(text) => text, Err(err) => { return Err(ApiError::Stream(format!( "failed to encode websocket request: {err}" ))); } }; trace!("websocket request: {request_text}"); let request_start = Instant::now(); let result = ws_stream .send(Message::Text(request_text.into())) .await .map_err(|err| ApiError::Stream(format!("failed to send websocket request: {err}"))); if let Some(t) = telemetry.as_ref() { t.on_ws_request( request_start.elapsed(), result.as_ref().err(), connection_reused, ); } result?; loop { let poll_start = Instant::now(); let response = tokio::time::timeout(idle_timeout, ws_stream.next()) .await .map_err(|_| ApiError::Stream("idle timeout waiting for websocket".into())); if let Some(t) = telemetry.as_ref() { t.on_ws_event(&response, poll_start.elapsed()); } let message = match response { Ok(Some(Ok(msg))) => msg, Ok(Some(Err(err))) => { return Err(ApiError::Stream(err.to_string())); } Ok(None) => { return Err(ApiError::Stream( "stream closed before response.completed".into(), )); } Err(err) => { return Err(err); } }; match message { Message::Text(text) => { trace!("websocket event: {text}"); if let Some(wrapped_error) = parse_wrapped_websocket_error_event(&text) && let Some(error) = map_wrapped_websocket_error_event(wrapped_error, text.to_string()) { return Err(error); } let event = match serde_json::from_str::(&text) { Ok(event) => event, Err(err) => { debug!("failed to parse websocket event: {err}, data: {text}"); continue; } }; if event.kind() == "codex.rate_limits" { if let Some(snapshot) = parse_rate_limit_event(&text) { let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await; } continue; } if let Some(model) = event.response_model() && last_server_model.as_deref() != Some(model.as_str()) { let _ = tx_event .send(Ok(ResponseEvent::ServerModel(model.clone()))) .await; last_server_model = Some(model); } match process_responses_event(event) { Ok(Some(event)) => { let is_completed = matches!(event, ResponseEvent::Completed { .. }); let _ = tx_event.send(Ok(event)).await; if is_completed { break; } } Ok(None) => {} Err(error) => { return Err(error.into_api_error()); } } } Message::Binary(_) => { return Err(ApiError::Stream("unexpected binary websocket event".into())); } Message::Close(_) => { return Err(ApiError::Stream( "websocket closed by server before response.completed".into(), )); } Message::Frame(_) => {} Message::Ping(_) | Message::Pong(_) => {} } } Ok(()) } #[cfg(test)] mod tests { use super::*; use pretty_assertions::assert_eq; use serde_json::json; #[test] fn websocket_config_enables_permessage_deflate() { let config = websocket_config(); assert!(config.extensions.permessage_deflate.is_some()); } #[test] fn parse_wrapped_websocket_error_event_maps_to_transport_http() { let payload = json!({ "type": "error", "status": 429, "error": { "type": "usage_limit_reached", "message": "The usage limit has been reached", "plan_type": "pro", "resets_at": 1738888888 }, "headers": { "x-codex-primary-used-percent": "100.0", "x-codex-primary-window-minutes": 15 } }) .to_string(); let wrapped_error = parse_wrapped_websocket_error_event(&payload) .expect("expected websocket error payload to be parsed"); let api_error = map_wrapped_websocket_error_event(wrapped_error, payload) .expect("expected websocket error payload to map to ApiError"); let ApiError::Transport(TransportError::Http { status, headers, body, .. }) = api_error else { panic!("expected ApiError::Transport(Http)"); }; assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); let headers = headers.expect("expected headers"); assert_eq!( headers .get("x-codex-primary-used-percent") .and_then(|value| value.to_str().ok()), Some("100.0") ); assert_eq!( headers .get("x-codex-primary-window-minutes") .and_then(|value| value.to_str().ok()), Some("15") ); let body = body.expect("expected body"); assert!(body.contains("usage_limit_reached")); assert!(body.contains("The usage limit has been reached")); } #[test] fn parse_wrapped_websocket_error_event_ignores_non_error_payloads() { let payload = json!({ "type": "response.created", "response": { "id": "resp-1" } }) .to_string(); let wrapped_error = parse_wrapped_websocket_error_event(&payload); assert!(wrapped_error.is_none()); } #[test] fn parse_wrapped_websocket_error_event_with_status_maps_invalid_request() { let payload = json!({ "type": "error", "status": 400, "error": { "type": "invalid_request_error", "message": "Model does not support image inputs" } }) .to_string(); let wrapped_error = parse_wrapped_websocket_error_event(&payload) .expect("expected websocket error payload to be parsed"); let api_error = map_wrapped_websocket_error_event(wrapped_error, payload) .expect("expected websocket error payload to map to ApiError"); let ApiError::Transport(TransportError::Http { status, body, .. }) = api_error else { panic!("expected ApiError::Transport(Http)"); }; assert_eq!(status, StatusCode::BAD_REQUEST); let body = body.expect("expected body"); assert!(body.contains("invalid_request_error")); assert!(body.contains("Model does not support image inputs")); } #[test] fn parse_wrapped_websocket_error_event_with_connection_limit_maps_retryable() { let payload = json!({ "type": "error", "status": 400, "error": { "type": "invalid_request_error", "code": "websocket_connection_limit_reached", "message": "Responses websocket connection limit reached (60 minutes). Create a new websocket connection to continue." } }) .to_string(); let wrapped_error = parse_wrapped_websocket_error_event(&payload) .expect("expected websocket error payload to be parsed"); let api_error = map_wrapped_websocket_error_event(wrapped_error, payload) .expect("expected websocket error payload to map to ApiError"); let ApiError::Retryable { message, delay } = api_error else { panic!("expected ApiError::Retryable"); }; assert_eq!(message, WEBSOCKET_CONNECTION_LIMIT_REACHED_MESSAGE); assert_eq!(delay, None); } #[test] fn parse_wrapped_websocket_error_event_without_status_is_not_mapped() { let payload = json!({ "type": "error", "error": { "type": "usage_limit_reached", "message": "The usage limit has been reached" }, "headers": { "x-codex-primary-used-percent": "100.0", "x-codex-primary-window-minutes": 15 } }) .to_string(); let wrapped_error = parse_wrapped_websocket_error_event(&payload) .expect("expected websocket error payload to be parsed"); let api_error = map_wrapped_websocket_error_event(wrapped_error, payload); assert!(api_error.is_none()); } #[test] fn merge_request_headers_matches_http_precedence() { let mut provider_headers = HeaderMap::new(); provider_headers.insert( "originator", HeaderValue::from_static("provider-originator"), ); provider_headers.insert("x-priority", HeaderValue::from_static("provider")); let mut extra_headers = HeaderMap::new(); extra_headers.insert("x-priority", HeaderValue::from_static("extra")); let mut default_headers = HeaderMap::new(); default_headers.insert("originator", HeaderValue::from_static("default-originator")); default_headers.insert("x-priority", HeaderValue::from_static("default")); default_headers.insert("x-default-only", HeaderValue::from_static("default-only")); let merged = merge_request_headers(&provider_headers, extra_headers, default_headers); assert_eq!( merged.get("originator"), Some(&HeaderValue::from_static("provider-originator")) ); assert_eq!( merged.get("x-priority"), Some(&HeaderValue::from_static("extra")) ); assert_eq!( merged.get("x-default-only"), Some(&HeaderValue::from_static("default-only")) ); } }