mirror of
https://github.com/openai/codex.git
synced 2026-02-01 14:44:17 +00:00
add a codex.metadata event for websockets
This commit is contained in:
@@ -4,6 +4,7 @@ use crate::common::ResponseStream;
|
||||
use crate::common::ResponsesWsRequest;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::rate_limits::parse_rate_limit;
|
||||
use crate::sse::responses::ResponsesStreamEvent;
|
||||
use crate::sse::responses::process_responses_event;
|
||||
use crate::telemetry::WebsocketTelemetry;
|
||||
@@ -11,6 +12,7 @@ use codex_client::TransportError;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderName;
|
||||
use http::HeaderValue;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
@@ -41,6 +43,7 @@ pub struct ResponsesWebsocketConnection {
|
||||
idle_timeout: Duration,
|
||||
server_reasoning_included: bool,
|
||||
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
}
|
||||
|
||||
impl ResponsesWebsocketConnection {
|
||||
@@ -49,12 +52,14 @@ impl ResponsesWebsocketConnection {
|
||||
idle_timeout: Duration,
|
||||
server_reasoning_included: bool,
|
||||
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream: Arc::new(Mutex::new(Some(stream))),
|
||||
idle_timeout,
|
||||
server_reasoning_included,
|
||||
telemetry,
|
||||
turn_state,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +77,7 @@ impl ResponsesWebsocketConnection {
|
||||
let idle_timeout = self.idle_timeout;
|
||||
let server_reasoning_included = self.server_reasoning_included;
|
||||
let telemetry = self.telemetry.clone();
|
||||
let turn_state = self.turn_state.clone();
|
||||
let request_body = serde_json::to_value(&request).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to encode websocket request: {err}"))
|
||||
})?;
|
||||
@@ -98,6 +104,7 @@ impl ResponsesWebsocketConnection {
|
||||
request_body,
|
||||
idle_timeout,
|
||||
telemetry,
|
||||
turn_state,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -137,12 +144,13 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
||||
apply_auth_headers(&mut headers, &self.auth);
|
||||
|
||||
let (stream, server_reasoning_included) =
|
||||
connect_websocket(ws_url, headers, turn_state).await?;
|
||||
connect_websocket(ws_url, headers, turn_state.clone()).await?;
|
||||
Ok(ResponsesWebsocketConnection::new(
|
||||
stream,
|
||||
self.provider.stream_idle_timeout,
|
||||
server_reasoning_included,
|
||||
telemetry,
|
||||
turn_state,
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -226,12 +234,31 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
||||
}
|
||||
}
|
||||
|
||||
fn headers_from_value(raw: &Value) -> Option<HeaderMap> {
|
||||
let obj = raw.as_object()?;
|
||||
let mut headers = HeaderMap::new();
|
||||
for (name, value) in obj {
|
||||
let Some(value_str) = value.as_str() else {
|
||||
continue;
|
||||
};
|
||||
let Ok(header_name) = HeaderName::from_bytes(name.as_bytes()) else {
|
||||
continue;
|
||||
};
|
||||
let Ok(header_value) = HeaderValue::from_str(value_str) else {
|
||||
continue;
|
||||
};
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
Some(headers)
|
||||
}
|
||||
|
||||
async fn run_websocket_response_stream(
|
||||
ws_stream: &mut WsStream,
|
||||
tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
|
||||
request_body: Value,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Result<(), ApiError> {
|
||||
let request_text = match serde_json::to_string(&request_body) {
|
||||
Ok(text) => text,
|
||||
@@ -287,6 +314,36 @@ async fn run_websocket_response_stream(
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if event.kind() == "codex.metadata" {
|
||||
if let Some(raw_headers) = event.headers()
|
||||
&& let Some(headers) = headers_from_value(raw_headers)
|
||||
{
|
||||
if let Some(turn_state) = turn_state.as_ref()
|
||||
&& let Some(header_value) = headers
|
||||
.get(X_CODEX_TURN_STATE_HEADER)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
{
|
||||
let _ = turn_state.set(header_value.to_string());
|
||||
}
|
||||
if let Some(snapshot) = parse_rate_limit(&headers) {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
|
||||
}
|
||||
if let Some(etag) = headers
|
||||
.get("X-Models-Etag")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
{
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ModelsEtag(etag.to_string())))
|
||||
.await;
|
||||
}
|
||||
if headers.contains_key(X_REASONING_INCLUDED_HEADER) {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
match process_responses_event(event) {
|
||||
Ok(Some(event)) => {
|
||||
let is_completed = matches!(event, ResponseEvent::Completed { .. });
|
||||
|
||||
@@ -163,6 +163,17 @@ pub struct ResponsesStreamEvent {
|
||||
delta: Option<String>,
|
||||
summary_index: Option<i64>,
|
||||
content_index: Option<i64>,
|
||||
headers: Option<Value>,
|
||||
}
|
||||
|
||||
impl ResponsesStreamEvent {
|
||||
pub fn kind(&self) -> &str {
|
||||
&self.kind
|
||||
}
|
||||
|
||||
pub fn headers(&self) -> Option<&Value> {
|
||||
self.headers.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -81,6 +81,7 @@ use mcp_types::ReadResourceResult;
|
||||
use mcp_types::RequestId;
|
||||
use serde_json;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::oneshot;
|
||||
@@ -4356,12 +4357,26 @@ async fn try_run_sampling_request(
|
||||
.await;
|
||||
}
|
||||
ResponseEvent::Completed {
|
||||
response_id: _,
|
||||
response_id,
|
||||
token_usage,
|
||||
} => {
|
||||
if let Some(state) = plan_mode_state.as_mut() {
|
||||
flush_proposed_plan_segments_all(&sess, &turn_context, state).await;
|
||||
}
|
||||
if let Some(usage) = token_usage.as_ref() {
|
||||
info!(
|
||||
target: "codex_core::stream_events_utils",
|
||||
"TokenUsage: {}",
|
||||
json!({
|
||||
"response_id": response_id,
|
||||
"input_tokens": usage.input_tokens,
|
||||
"cached_input_tokens": usage.cached_input_tokens,
|
||||
"output_tokens": usage.output_tokens,
|
||||
"reasoning_output_tokens": usage.reasoning_output_tokens,
|
||||
"total_tokens": usage.total_tokens,
|
||||
})
|
||||
);
|
||||
}
|
||||
sess.update_token_usage_info(&turn_context, token_usage.as_ref())
|
||||
.await;
|
||||
should_emit_turn_diff = true;
|
||||
|
||||
@@ -29,6 +29,7 @@ use core_test_support::skip_if_no_network;
|
||||
use futures::StreamExt;
|
||||
use opentelemetry_sdk::metrics::InMemoryMetricExporter;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tempfile::TempDir;
|
||||
@@ -136,6 +137,72 @@ async fn responses_websocket_emits_reasoning_included_event() {
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_emits_metadata_events() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let metadata_headers = json!({
|
||||
"X-Codex-Primary-Used-Percent": "42",
|
||||
"X-Codex-Primary-Window-Minutes": "60",
|
||||
"X-Codex-Primary-Reset-At": "1700000000",
|
||||
"X-Codex-Credits-Has-Credits": "true",
|
||||
"X-Codex-Credits-Unlimited": "false",
|
||||
"X-Codex-Credits-Balance": "123",
|
||||
"X-Models-Etag": "etag-123",
|
||||
"X-Reasoning-Included": "true",
|
||||
});
|
||||
|
||||
let server = start_websocket_server(vec![vec![vec![
|
||||
json!({"type": "codex.metadata", "headers": metadata_headers}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_completed("resp-1"),
|
||||
]]])
|
||||
.await;
|
||||
|
||||
let harness = websocket_harness(&server).await;
|
||||
let mut session = harness.client.new_session();
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
|
||||
let mut stream = session
|
||||
.stream(&prompt)
|
||||
.await
|
||||
.expect("websocket stream failed");
|
||||
|
||||
let mut saw_rate_limits = None;
|
||||
let mut saw_models_etag = None;
|
||||
let mut saw_reasoning_included = false;
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event.expect("event") {
|
||||
ResponseEvent::RateLimits(snapshot) => {
|
||||
saw_rate_limits = Some(snapshot);
|
||||
}
|
||||
ResponseEvent::ModelsEtag(etag) => {
|
||||
saw_models_etag = Some(etag);
|
||||
}
|
||||
ResponseEvent::ServerReasoningIncluded(true) => {
|
||||
saw_reasoning_included = true;
|
||||
}
|
||||
ResponseEvent::Completed { .. } => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let rate_limits = saw_rate_limits.expect("missing rate limits");
|
||||
let primary = rate_limits.primary.expect("missing primary window");
|
||||
assert_eq!(primary.used_percent, 42.0);
|
||||
assert_eq!(primary.window_minutes, Some(60));
|
||||
assert_eq!(primary.resets_at, Some(1_700_000_000));
|
||||
let credits = rate_limits.credits.expect("missing credits");
|
||||
assert!(credits.has_credits);
|
||||
assert!(!credits.unlimited);
|
||||
assert_eq!(credits.balance.as_deref(), Some("123"));
|
||||
assert_eq!(saw_models_etag.as_deref(), Some("etag-123"));
|
||||
assert!(saw_reasoning_included);
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_appends_on_prefix() {
|
||||
skip_if_no_network!();
|
||||
|
||||
Reference in New Issue
Block a user