add a codex.metadata event for websockets

This commit is contained in:
Rasmus Rygaard
2026-01-31 20:49:44 -08:00
parent dfba95309f
commit 9ba03c0c8f
4 changed files with 152 additions and 2 deletions

View File

@@ -4,6 +4,7 @@ use crate::common::ResponseStream;
use crate::common::ResponsesWsRequest;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::rate_limits::parse_rate_limit;
use crate::sse::responses::ResponsesStreamEvent;
use crate::sse::responses::process_responses_event;
use crate::telemetry::WebsocketTelemetry;
@@ -11,6 +12,7 @@ use codex_client::TransportError;
use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use http::HeaderName;
use http::HeaderValue;
use serde_json::Value;
use std::sync::Arc;
@@ -41,6 +43,7 @@ pub struct ResponsesWebsocketConnection {
idle_timeout: Duration,
server_reasoning_included: bool,
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
turn_state: Option<Arc<OnceLock<String>>>,
}
impl ResponsesWebsocketConnection {
@@ -49,12 +52,14 @@ impl ResponsesWebsocketConnection {
idle_timeout: Duration,
server_reasoning_included: bool,
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Self {
Self {
stream: Arc::new(Mutex::new(Some(stream))),
idle_timeout,
server_reasoning_included,
telemetry,
turn_state,
}
}
@@ -72,6 +77,7 @@ impl ResponsesWebsocketConnection {
let idle_timeout = self.idle_timeout;
let server_reasoning_included = self.server_reasoning_included;
let telemetry = self.telemetry.clone();
let turn_state = self.turn_state.clone();
let request_body = serde_json::to_value(&request).map_err(|err| {
ApiError::Stream(format!("failed to encode websocket request: {err}"))
})?;
@@ -98,6 +104,7 @@ impl ResponsesWebsocketConnection {
request_body,
idle_timeout,
telemetry,
turn_state,
)
.await
{
@@ -137,12 +144,13 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
apply_auth_headers(&mut headers, &self.auth);
let (stream, server_reasoning_included) =
connect_websocket(ws_url, headers, turn_state).await?;
connect_websocket(ws_url, headers, turn_state.clone()).await?;
Ok(ResponsesWebsocketConnection::new(
stream,
self.provider.stream_idle_timeout,
server_reasoning_included,
telemetry,
turn_state,
))
}
}
@@ -226,12 +234,31 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError {
}
}
/// Builds a [`HeaderMap`] from a JSON object mapping header names to string values.
///
/// Returns `None` when `raw` is not a JSON object. Entries are skipped (not
/// treated as errors) when the value is not a JSON string, when the key is
/// rejected by `HeaderName::from_bytes`, or when the string is rejected by
/// `HeaderValue::from_str`. If two keys resolve to the same header name, the
/// later entry replaces the earlier one.
fn headers_from_value(raw: &Value) -> Option<HeaderMap> {
    let entries = raw.as_object()?.iter().filter_map(|(key, json_value)| {
        let text = json_value.as_str()?;
        let name = HeaderName::from_bytes(key.as_bytes()).ok()?;
        let value = HeaderValue::from_str(text).ok()?;
        Some((name, value))
    });

    let mut header_map = HeaderMap::new();
    for (name, value) in entries {
        // Deliberately `insert` rather than collecting/extending: the last
        // occurrence of a header name must overwrite any previous one.
        header_map.insert(name, value);
    }
    Some(header_map)
}
async fn run_websocket_response_stream(
ws_stream: &mut WsStream,
tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
request_body: Value,
idle_timeout: Duration,
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<(), ApiError> {
let request_text = match serde_json::to_string(&request_body) {
Ok(text) => text,
@@ -287,6 +314,36 @@ async fn run_websocket_response_stream(
continue;
}
};
// A `codex.metadata` event is consumed entirely here: recognised entries in its
// `headers` payload are forwarded as `ResponseEvent`s, and the event never
// reaches `process_responses_event` below.
if event.kind() == "codex.metadata" {
// Nothing is forwarded when the event has no `headers` field or when
// `headers_from_value` returns `None` for it.
if let Some(raw_headers) = event.headers()
&& let Some(headers) = headers_from_value(raw_headers)
{
// Record the turn-state header value when a `turn_state` cell was supplied.
// The result of `set` is discarded.
// NOTE(review): presumably because the cell may already hold a value from
// an earlier response — confirm.
if let Some(turn_state) = turn_state.as_ref()
&& let Some(header_value) = headers
.get(X_CODEX_TURN_STATE_HEADER)
.and_then(|value| value.to_str().ok())
{
let _ = turn_state.set(header_value.to_string());
}
// Forward a rate-limit snapshot if `parse_rate_limit` produces one from
// these headers. Send results are discarded here and below.
if let Some(snapshot) = parse_rate_limit(&headers) {
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
}
// Forward the models ETag when the header is present and `to_str` succeeds.
if let Some(etag) = headers
.get("X-Models-Etag")
.and_then(|value| value.to_str().ok())
{
let _ = tx_event
.send(Ok(ResponseEvent::ModelsEtag(etag.to_string())))
.await;
}
// Only the presence of the header is checked; its value is not inspected.
if headers.contains_key(X_REASONING_INCLUDED_HEADER) {
let _ = tx_event
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
.await;
}
}
// Skip the regular event handling whether or not any headers were usable.
continue;
}
match process_responses_event(event) {
Ok(Some(event)) => {
let is_completed = matches!(event, ResponseEvent::Completed { .. });

View File

@@ -163,6 +163,17 @@ pub struct ResponsesStreamEvent {
delta: Option<String>,
summary_index: Option<i64>,
content_index: Option<i64>,
headers: Option<Value>,
}
/// Read-only accessors for fields of a [`ResponsesStreamEvent`].
impl ResponsesStreamEvent {
/// Returns this event's `kind` field as a string slice.
pub fn kind(&self) -> &str {
&self.kind
}
/// Returns a reference to the event's `headers` JSON value, or `None` when
/// the field is unset.
pub fn headers(&self) -> Option<&Value> {
self.headers.as_ref()
}
}
#[derive(Debug)]

View File

@@ -81,6 +81,7 @@ use mcp_types::ReadResourceResult;
use mcp_types::RequestId;
use serde_json;
use serde_json::Value;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::oneshot;
@@ -4356,12 +4357,26 @@ async fn try_run_sampling_request(
.await;
}
ResponseEvent::Completed {
response_id: _,
response_id,
token_usage,
} => {
if let Some(state) = plan_mode_state.as_mut() {
flush_proposed_plan_segments_all(&sess, &turn_context, state).await;
}
if let Some(usage) = token_usage.as_ref() {
info!(
target: "codex_core::stream_events_utils",
"TokenUsage: {}",
json!({
"response_id": response_id,
"input_tokens": usage.input_tokens,
"cached_input_tokens": usage.cached_input_tokens,
"output_tokens": usage.output_tokens,
"reasoning_output_tokens": usage.reasoning_output_tokens,
"total_tokens": usage.total_tokens,
})
);
}
sess.update_token_usage_info(&turn_context, token_usage.as_ref())
.await;
should_emit_turn_diff = true;

View File

@@ -29,6 +29,7 @@ use core_test_support::skip_if_no_network;
use futures::StreamExt;
use opentelemetry_sdk::metrics::InMemoryMetricExporter;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::sync::Arc;
use std::time::Duration;
use tempfile::TempDir;
@@ -136,6 +137,72 @@ async fn responses_websocket_emits_reasoning_included_event() {
server.shutdown().await;
}
/// Verifies that a `codex.metadata` websocket frame carrying rate-limit,
/// models-etag and reasoning-included headers is surfaced to the stream
/// consumer as `RateLimits`, `ModelsEtag` and `ServerReasoningIncluded(true)`
/// events before the response completes.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_emits_metadata_events() {
    skip_if_no_network!();

    // Single connection, single request: the metadata frame is sent first,
    // followed by a normal created/completed pair.
    let metadata_frame = json!({
        "type": "codex.metadata",
        "headers": {
            "X-Codex-Primary-Used-Percent": "42",
            "X-Codex-Primary-Window-Minutes": "60",
            "X-Codex-Primary-Reset-At": "1700000000",
            "X-Codex-Credits-Has-Credits": "true",
            "X-Codex-Credits-Unlimited": "false",
            "X-Codex-Credits-Balance": "123",
            "X-Models-Etag": "etag-123",
            "X-Reasoning-Included": "true",
        },
    });
    let request_frames = vec![
        metadata_frame,
        ev_response_created("resp-1"),
        ev_completed("resp-1"),
    ];
    let server = start_websocket_server(vec![vec![request_frames]]).await;

    let harness = websocket_harness(&server).await;
    let mut session = harness.client.new_session();
    let prompt = prompt_with_input(vec![message_item("hello")]);
    let mut events = session
        .stream(&prompt)
        .await
        .expect("websocket stream failed");

    // Drain the stream until completion, remembering each metadata-derived event.
    let mut rate_limit_snapshot = None;
    let mut models_etag = None;
    let mut reasoning_included = false;
    while let Some(next) = events.next().await {
        match next.expect("event") {
            ResponseEvent::Completed { .. } => break,
            ResponseEvent::RateLimits(snapshot) => rate_limit_snapshot = Some(snapshot),
            ResponseEvent::ModelsEtag(etag) => models_etag = Some(etag),
            ResponseEvent::ServerReasoningIncluded(true) => reasoning_included = true,
            _ => {}
        }
    }

    let snapshot = rate_limit_snapshot.expect("missing rate limits");

    let primary_window = snapshot.primary.expect("missing primary window");
    assert_eq!(primary_window.used_percent, 42.0);
    assert_eq!(primary_window.window_minutes, Some(60));
    assert_eq!(primary_window.resets_at, Some(1_700_000_000));

    let credit_info = snapshot.credits.expect("missing credits");
    assert!(credit_info.has_credits);
    assert!(!credit_info.unlimited);
    assert_eq!(credit_info.balance.as_deref(), Some("123"));

    assert_eq!(models_etag.as_deref(), Some("etag-123"));
    assert!(reasoning_included);

    server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_appends_on_prefix() {
skip_if_no_network!();