diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs index 2a7d8726f1..168a6d3f17 100644 --- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -4,6 +4,7 @@ use crate::common::ResponseStream; use crate::common::ResponsesWsRequest; use crate::error::ApiError; use crate::provider::Provider; +use crate::rate_limits::parse_rate_limit; use crate::sse::responses::ResponsesStreamEvent; use crate::sse::responses::process_responses_event; use crate::telemetry::WebsocketTelemetry; @@ -11,6 +12,7 @@ use codex_client::TransportError; use futures::SinkExt; use futures::StreamExt; use http::HeaderMap; +use http::HeaderName; use http::HeaderValue; use serde_json::Value; use std::sync::Arc; @@ -41,6 +43,7 @@ pub struct ResponsesWebsocketConnection { idle_timeout: Duration, server_reasoning_included: bool, telemetry: Option>, + turn_state: Option>>, } impl ResponsesWebsocketConnection { @@ -49,12 +52,14 @@ impl ResponsesWebsocketConnection { idle_timeout: Duration, server_reasoning_included: bool, telemetry: Option>, + turn_state: Option>>, ) -> Self { Self { stream: Arc::new(Mutex::new(Some(stream))), idle_timeout, server_reasoning_included, telemetry, + turn_state, } } @@ -72,6 +77,7 @@ impl ResponsesWebsocketConnection { let idle_timeout = self.idle_timeout; let server_reasoning_included = self.server_reasoning_included; let telemetry = self.telemetry.clone(); + let turn_state = self.turn_state.clone(); let request_body = serde_json::to_value(&request).map_err(|err| { ApiError::Stream(format!("failed to encode websocket request: {err}")) })?; @@ -98,6 +104,7 @@ impl ResponsesWebsocketConnection { request_body, idle_timeout, telemetry, + turn_state, ) .await { @@ -137,12 +144,13 @@ impl ResponsesWebsocketClient { apply_auth_headers(&mut headers, &self.auth); let (stream, server_reasoning_included) = - connect_websocket(ws_url, headers, turn_state).await?; + connect_websocket(ws_url, headers, turn_state.clone()).await?; Ok(ResponsesWebsocketConnection::new( stream, self.provider.stream_idle_timeout, server_reasoning_included, telemetry, + turn_state, )) } } @@ -226,12 +234,31 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError { } } +fn headers_from_value(raw: &Value) -> Option { + let obj = raw.as_object()?; + let mut headers = HeaderMap::new(); + for (name, value) in obj { + let Some(value_str) = value.as_str() else { + continue; + }; + let Ok(header_name) = HeaderName::from_bytes(name.as_bytes()) else { + continue; + }; + let Ok(header_value) = HeaderValue::from_str(value_str) else { + continue; + }; + headers.insert(header_name, header_value); + } + Some(headers) +} + async fn run_websocket_response_stream( ws_stream: &mut WsStream, tx_event: mpsc::Sender>, request_body: Value, idle_timeout: Duration, telemetry: Option>, + turn_state: Option>>, ) -> Result<(), ApiError> { let request_text = match serde_json::to_string(&request_body) { Ok(text) => text, @@ -287,6 +314,36 @@ async fn run_websocket_response_stream( continue; } }; + if event.kind() == "codex.metadata" { + if let Some(raw_headers) = event.headers() + && let Some(headers) = headers_from_value(raw_headers) + { + if let Some(turn_state) = turn_state.as_ref() + && let Some(header_value) = headers + .get(X_CODEX_TURN_STATE_HEADER) + .and_then(|value| value.to_str().ok()) + { + let _ = turn_state.set(header_value.to_string()); + } + if let Some(snapshot) = parse_rate_limit(&headers) { + let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await; + } + if let Some(etag) = headers + .get("X-Models-Etag") + .and_then(|value| value.to_str().ok()) + { + let _ = tx_event + .send(Ok(ResponseEvent::ModelsEtag(etag.to_string()))) + .await; + } + if headers.contains_key(X_REASONING_INCLUDED_HEADER) { + let _ = tx_event + .send(Ok(ResponseEvent::ServerReasoningIncluded(true))) + .await; + } + } + continue; + } match process_responses_event(event) { Ok(Some(event)) => { let is_completed = matches!(event, ResponseEvent::Completed { .. }); diff --git a/codex-rs/codex-api/src/sse/responses.rs b/codex-rs/codex-api/src/sse/responses.rs index b363671a11..25fdc94c28 100644 --- a/codex-rs/codex-api/src/sse/responses.rs +++ b/codex-rs/codex-api/src/sse/responses.rs @@ -163,6 +163,17 @@ pub struct ResponsesStreamEvent { delta: Option, summary_index: Option, content_index: Option, + headers: Option, +} + +impl ResponsesStreamEvent { + pub fn kind(&self) -> &str { + &self.kind + } + + pub fn headers(&self) -> Option<&Value> { + self.headers.as_ref() + } } #[derive(Debug)] diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 29b02c006a..ac25be6457 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -81,6 +81,7 @@ use mcp_types::ReadResourceResult; use mcp_types::RequestId; use serde_json; use serde_json::Value; +use serde_json::json; use tokio::sync::Mutex; use tokio::sync::RwLock; use tokio::sync::oneshot; @@ -4356,12 +4357,26 @@ async fn try_run_sampling_request( .await; } ResponseEvent::Completed { - response_id: _, + response_id, token_usage, } => { if let Some(state) = plan_mode_state.as_mut() { flush_proposed_plan_segments_all(&sess, &turn_context, state).await; } + if let Some(usage) = token_usage.as_ref() { + info!( + target: "codex_core::stream_events_utils", + "TokenUsage: {}", + json!({ + "response_id": response_id, + "input_tokens": usage.input_tokens, + "cached_input_tokens": usage.cached_input_tokens, + "output_tokens": usage.output_tokens, + "reasoning_output_tokens": usage.reasoning_output_tokens, + "total_tokens": usage.total_tokens, + }) + ); + } sess.update_token_usage_info(&turn_context, token_usage.as_ref()) .await; should_emit_turn_diff = true; diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs index fc0ff37cf4..48c2ed94ca 100644 --- a/codex-rs/core/tests/suite/client_websockets.rs +++ b/codex-rs/core/tests/suite/client_websockets.rs @@ -29,6 +29,7 @@ use core_test_support::skip_if_no_network; use futures::StreamExt; use opentelemetry_sdk::metrics::InMemoryMetricExporter; use pretty_assertions::assert_eq; +use serde_json::json; use std::sync::Arc; use std::time::Duration; use tempfile::TempDir; @@ -136,6 +137,72 @@ async fn responses_websocket_emits_reasoning_included_event() { server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_emits_metadata_events() { + skip_if_no_network!(); + + let metadata_headers = json!({ + "X-Codex-Primary-Used-Percent": "42", + "X-Codex-Primary-Window-Minutes": "60", + "X-Codex-Primary-Reset-At": "1700000000", + "X-Codex-Credits-Has-Credits": "true", + "X-Codex-Credits-Unlimited": "false", + "X-Codex-Credits-Balance": "123", + "X-Models-Etag": "etag-123", + "X-Reasoning-Included": "true", + }); + + let server = start_websocket_server(vec![vec![vec![ + json!({"type": "codex.metadata", "headers": metadata_headers}), + ev_response_created("resp-1"), + ev_completed("resp-1"), + ]]]) + .await; + + let harness = websocket_harness(&server).await; + let mut session = harness.client.new_session(); + let prompt = prompt_with_input(vec![message_item("hello")]); + + let mut stream = session + .stream(&prompt) + .await + .expect("websocket stream failed"); + + let mut saw_rate_limits = None; + let mut saw_models_etag = None; + let mut saw_reasoning_included = false; + + while let Some(event) = stream.next().await { + match event.expect("event") { + ResponseEvent::RateLimits(snapshot) => { + saw_rate_limits = Some(snapshot); + } + ResponseEvent::ModelsEtag(etag) => { + saw_models_etag = Some(etag); + } + ResponseEvent::ServerReasoningIncluded(true) => { + saw_reasoning_included = true; + } + ResponseEvent::Completed { .. } => break, + _ => {} + } + } + + let rate_limits = saw_rate_limits.expect("missing rate limits"); + let primary = rate_limits.primary.expect("missing primary window"); + assert_eq!(primary.used_percent, 42.0); + assert_eq!(primary.window_minutes, Some(60)); + assert_eq!(primary.resets_at, Some(1_700_000_000)); + let credits = rate_limits.credits.expect("missing credits"); + assert!(credits.has_credits); + assert!(!credits.unlimited); + assert_eq!(credits.balance.as_deref(), Some("123")); + assert_eq!(saw_models_etag.as_deref(), Some("etag-123")); + assert!(saw_reasoning_included); + + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn responses_websocket_appends_on_prefix() { skip_if_no_network!();