mirror of
https://github.com/openai/codex.git
synced 2026-04-26 07:35:29 +00:00
Add a codex.rate_limits event for websockets (#10324)
When communicating over websockets, we can't rely on headers to deliver rate limit information. This PR adds a `codex.rate_limits` event that the server can pass to the client to inform them about rate limit usage. The client parses this data the same way we parse rate limit headers in HTTP mode. This PR also wires up the etag and reasoning headers for websockets
This commit is contained in:
@@ -5,6 +5,7 @@ use crate::common::ResponseStream;
|
||||
use crate::common::ResponsesWsRequest;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::rate_limits::parse_rate_limit_event;
|
||||
use crate::sse::responses::ResponsesStreamEvent;
|
||||
use crate::sse::responses::process_responses_event;
|
||||
use crate::telemetry::WebsocketTelemetry;
|
||||
@@ -33,6 +34,7 @@ use url::Url;
|
||||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
const X_MODELS_ETAG_HEADER: &str = "x-models-etag";
|
||||
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
|
||||
|
||||
pub struct ResponsesWebsocketConnection {
|
||||
@@ -40,6 +42,7 @@ pub struct ResponsesWebsocketConnection {
|
||||
// TODO (pakrym): is this the right place for timeout?
|
||||
idle_timeout: Duration,
|
||||
server_reasoning_included: bool,
|
||||
models_etag: Option<String>,
|
||||
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
|
||||
}
|
||||
|
||||
@@ -48,12 +51,14 @@ impl ResponsesWebsocketConnection {
|
||||
stream: WsStream,
|
||||
idle_timeout: Duration,
|
||||
server_reasoning_included: bool,
|
||||
models_etag: Option<String>,
|
||||
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream: Arc::new(Mutex::new(Some(stream))),
|
||||
idle_timeout,
|
||||
server_reasoning_included,
|
||||
models_etag,
|
||||
telemetry,
|
||||
}
|
||||
}
|
||||
@@ -71,12 +76,16 @@ impl ResponsesWebsocketConnection {
|
||||
let stream = Arc::clone(&self.stream);
|
||||
let idle_timeout = self.idle_timeout;
|
||||
let server_reasoning_included = self.server_reasoning_included;
|
||||
let models_etag = self.models_etag.clone();
|
||||
let telemetry = self.telemetry.clone();
|
||||
let request_body = serde_json::to_value(&request).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to encode websocket request: {err}"))
|
||||
})?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Some(etag) = models_etag {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
|
||||
}
|
||||
if server_reasoning_included {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
|
||||
@@ -136,12 +145,13 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
||||
headers.extend(extra_headers);
|
||||
add_auth_headers_to_header_map(&self.auth, &mut headers);
|
||||
|
||||
let (stream, server_reasoning_included) =
|
||||
connect_websocket(ws_url, headers, turn_state).await?;
|
||||
let (stream, server_reasoning_included, models_etag) =
|
||||
connect_websocket(ws_url, headers, turn_state.clone()).await?;
|
||||
Ok(ResponsesWebsocketConnection::new(
|
||||
stream,
|
||||
self.provider.stream_idle_timeout,
|
||||
server_reasoning_included,
|
||||
models_etag,
|
||||
telemetry,
|
||||
))
|
||||
}
|
||||
@@ -151,7 +161,7 @@ async fn connect_websocket(
|
||||
url: Url,
|
||||
headers: HeaderMap,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Result<(WsStream, bool), ApiError> {
|
||||
) -> Result<(WsStream, bool, Option<String>), ApiError> {
|
||||
info!("connecting to websocket: {url}");
|
||||
|
||||
let mut request = url
|
||||
@@ -177,6 +187,11 @@ async fn connect_websocket(
|
||||
};
|
||||
|
||||
let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER);
|
||||
let models_etag = response
|
||||
.headers()
|
||||
.get(X_MODELS_ETAG_HEADER)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToString::to_string);
|
||||
if let Some(turn_state) = turn_state
|
||||
&& let Some(header_value) = response
|
||||
.headers()
|
||||
@@ -185,7 +200,7 @@ async fn connect_websocket(
|
||||
{
|
||||
let _ = turn_state.set(header_value.to_string());
|
||||
}
|
||||
Ok((stream, reasoning_included))
|
||||
Ok((stream, reasoning_included, models_etag))
|
||||
}
|
||||
|
||||
fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
||||
@@ -273,6 +288,12 @@ async fn run_websocket_response_stream(
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if event.kind() == "codex.rate_limits" {
|
||||
if let Some(snapshot) = parse_rate_limit_event(&text) {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
match process_responses_event(event) {
|
||||
Ok(Some(event)) => {
|
||||
let is_completed = matches!(event, ResponseEvent::Completed { .. });
|
||||
|
||||
Reference in New Issue
Block a user