Add a codex.rate_limits event for websockets (#10324)

When communicating over websockets, we can't rely on headers to deliver
rate limit information. This PR adds a `codex.rate_limits` event that
the server can pass to the client to inform them about rate limit usage.
The client parses this data the same way we parse rate limit headers in
HTTP mode.

This PR also wires up the etag and reasoning headers for websockets
This commit is contained in:
Rasmus Rygaard
2026-02-04 06:01:47 -08:00
committed by GitHub
parent aab60a55f1
commit df000da917
4 changed files with 183 additions and 4 deletions

View File

@@ -5,6 +5,7 @@ use crate::common::ResponseStream;
use crate::common::ResponsesWsRequest;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::rate_limits::parse_rate_limit_event;
use crate::sse::responses::ResponsesStreamEvent;
use crate::sse::responses::process_responses_event;
use crate::telemetry::WebsocketTelemetry;
@@ -33,6 +34,7 @@ use url::Url;
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
const X_MODELS_ETAG_HEADER: &str = "x-models-etag";
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
pub struct ResponsesWebsocketConnection {
@@ -40,6 +42,7 @@ pub struct ResponsesWebsocketConnection {
// TODO (pakrym): is this the right place for timeout?
idle_timeout: Duration,
server_reasoning_included: bool,
models_etag: Option<String>,
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
}
@@ -48,12 +51,14 @@ impl ResponsesWebsocketConnection {
stream: WsStream,
idle_timeout: Duration,
server_reasoning_included: bool,
models_etag: Option<String>,
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
) -> Self {
Self {
stream: Arc::new(Mutex::new(Some(stream))),
idle_timeout,
server_reasoning_included,
models_etag,
telemetry,
}
}
@@ -71,12 +76,16 @@ impl ResponsesWebsocketConnection {
let stream = Arc::clone(&self.stream);
let idle_timeout = self.idle_timeout;
let server_reasoning_included = self.server_reasoning_included;
let models_etag = self.models_etag.clone();
let telemetry = self.telemetry.clone();
let request_body = serde_json::to_value(&request).map_err(|err| {
ApiError::Stream(format!("failed to encode websocket request: {err}"))
})?;
tokio::spawn(async move {
if let Some(etag) = models_etag {
let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
}
if server_reasoning_included {
let _ = tx_event
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
@@ -136,12 +145,13 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
headers.extend(extra_headers);
add_auth_headers_to_header_map(&self.auth, &mut headers);
let (stream, server_reasoning_included) =
connect_websocket(ws_url, headers, turn_state).await?;
let (stream, server_reasoning_included, models_etag) =
connect_websocket(ws_url, headers, turn_state.clone()).await?;
Ok(ResponsesWebsocketConnection::new(
stream,
self.provider.stream_idle_timeout,
server_reasoning_included,
models_etag,
telemetry,
))
}
@@ -151,7 +161,7 @@ async fn connect_websocket(
url: Url,
headers: HeaderMap,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<(WsStream, bool), ApiError> {
) -> Result<(WsStream, bool, Option<String>), ApiError> {
info!("connecting to websocket: {url}");
let mut request = url
@@ -177,6 +187,11 @@ async fn connect_websocket(
};
let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER);
let models_etag = response
.headers()
.get(X_MODELS_ETAG_HEADER)
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);
if let Some(turn_state) = turn_state
&& let Some(header_value) = response
.headers()
@@ -185,7 +200,7 @@ async fn connect_websocket(
{
let _ = turn_state.set(header_value.to_string());
}
Ok((stream, reasoning_included))
Ok((stream, reasoning_included, models_etag))
}
fn map_ws_error(err: WsError, url: &Url) -> ApiError {
@@ -273,6 +288,12 @@ async fn run_websocket_response_stream(
continue;
}
};
if event.kind() == "codex.rate_limits" {
if let Some(snapshot) = parse_rate_limit_event(&text) {
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
}
continue;
}
match process_responses_event(event) {
Ok(Some(event)) => {
let is_completed = matches!(event, ResponseEvent::Completed { .. });