Support alternative websocket API (#10861)

**Test plan**

```
cargo build -p codex-cli && RUST_LOG='codex_api::endpoint::responses_websocket=trace,codex_core::client=debug,codex_core::codex=debug' \
  ./target/debug/codex \
    --enable responses_websockets_v2 \
    --profile byok \
    --full-auto
```
This commit is contained in:
Brian Yu
2026-02-06 14:40:50 -08:00
committed by GitHub
parent ba8b5d9018
commit 1fbf5ed06f
10 changed files with 410 additions and 35 deletions

View File

@@ -172,6 +172,8 @@ pub struct ResponsesApiRequest<'a> {
pub struct ResponseCreateWsRequest {
pub model: String,
pub instructions: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
pub input: Vec<ResponseItem>,
pub tools: Vec<Value>,
pub tool_choice: String,

View File

@@ -239,6 +239,9 @@
"responses_websockets": {
"type": "boolean"
},
"responses_websockets_v2": {
"type": "boolean"
},
"runtime_metrics": {
"type": "boolean"
},
@@ -1279,6 +1282,9 @@
"responses_websockets": {
"type": "boolean"
},
"responses_websockets_v2": {
"type": "boolean"
},
"runtime_metrics": {
"type": "boolean"
},
@@ -1623,4 +1629,4 @@
},
"title": "ConfigToml",
"type": "object"
}
}

View File

@@ -89,6 +89,8 @@ use reqwest::StatusCode;
use serde_json::Value;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::Error;
use tokio_tungstenite::tungstenite::Message;
@@ -117,6 +119,7 @@ pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str =
"x-responsesapi-include-timing-metrics";
const RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=2026-02-06";
/// Session-scoped state shared by all [`ModelClient`] clones.
///
@@ -129,6 +132,7 @@ struct ModelClientState {
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
enable_responses_websockets: bool,
enable_responses_websockets_v2: bool,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
@@ -238,6 +242,8 @@ pub struct ModelClientSession {
client: ModelClient,
connection: Option<ApiWebSocketConnection>,
websocket_last_items: Vec<ResponseItem>,
websocket_last_response_id: Option<String>,
websocket_last_response_id_rx: Option<oneshot::Receiver<String>>,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
@@ -264,6 +270,7 @@ impl ModelClient {
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
enable_responses_websockets: bool,
enable_responses_websockets_v2: bool,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
@@ -276,6 +283,7 @@ impl ModelClient {
session_source,
model_verbosity,
enable_responses_websockets,
enable_responses_websockets_v2,
enable_request_compression,
include_timing_metrics,
beta_features_header,
@@ -295,6 +303,8 @@ impl ModelClient {
client: self.clone(),
connection: None,
websocket_last_items: Vec::new(),
websocket_last_response_id: None,
websocket_last_response_id_rx: None,
turn_state: Arc::new(OnceLock::new()),
}
}
@@ -479,6 +489,10 @@ impl ModelClient {
self.state.provider.supports_websockets && self.state.enable_responses_websockets
}
fn responses_websockets_v2_enabled(&self) -> bool {
self.state.enable_responses_websockets_v2
}
/// Returns whether websocket transport has been permanently disabled for this session.
///
/// Once set by fallback activation, subsequent turns must stay on HTTP transport.
@@ -544,9 +558,14 @@ impl ModelClient {
headers.extend(build_conversation_headers(Some(
self.state.conversation_id.to_string(),
)));
let responses_websockets_beta_header = if self.responses_websockets_v2_enabled() {
RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE
} else {
OPENAI_BETA_RESPONSES_WEBSOCKETS
};
headers.insert(
OPENAI_BETA_HEADER,
HeaderValue::from_static(OPENAI_BETA_RESPONSES_WEBSOCKETS),
HeaderValue::from_static(responses_websockets_beta_header),
);
if self.state.include_timing_metrics {
headers.insert(
@@ -789,18 +808,37 @@ impl ModelClientSession {
}
}
fn prepare_websocket_request(
/// Polls the pending `response_id` oneshot (if any) and caches its result.
///
/// The receiver is taken out of `self` so a resolved or closed channel is
/// never polled twice; it is only put back when the sender has not fired yet.
fn refresh_websocket_last_response_id(&mut self) {
    let Some(mut receiver) = self.websocket_last_response_id_rx.take() else {
        return;
    };
    match receiver.try_recv() {
        // Sender has not fired yet: restore the receiver for a later poll.
        Err(TryRecvError::Empty) => {
            self.websocket_last_response_id_rx = Some(receiver);
        }
        // A non-empty id means the previous response completed; cache it.
        Ok(response_id) if !response_id.is_empty() => {
            self.websocket_last_response_id = Some(response_id);
        }
        // Empty id or dropped sender: there is nothing valid to resume from.
        Ok(_) | Err(TryRecvError::Closed) => {
            self.websocket_last_response_id = None;
        }
    }
}
/// Returns the id of the last completed websocket response, if one is cached.
///
/// Refreshes the cache from the completion channel first, and never returns
/// an empty id.
fn websocket_previous_response_id(&mut self) -> Option<String> {
    // Pull any freshly completed response id out of the oneshot first.
    self.refresh_websocket_last_response_id();
    match self.websocket_last_response_id.as_deref() {
        Some(id) if !id.is_empty() => Some(id.to_string()),
        _ => None,
    }
}
fn prepare_websocket_create_request(
&self,
model_slug: &str,
api_prompt: &ApiPrompt,
options: &ApiResponsesOptions,
input: Vec<ResponseItem>,
previous_response_id: Option<String>,
) -> ResponsesWsRequest {
if let Some(append_items) = self.get_incremental_items(&api_prompt.input) {
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
input: append_items,
});
}
let ApiResponsesOptions {
reasoning,
include,
@@ -814,7 +852,8 @@ impl ModelClientSession {
let payload = ResponseCreateWsRequest {
model: model_slug.to_string(),
instructions: api_prompt.instructions.clone(),
input: api_prompt.input.clone(),
previous_response_id,
input,
tools: api_prompt.tools.clone(),
tool_choice: "auto".to_string(),
parallel_tool_calls: api_prompt.parallel_tool_calls,
@@ -829,6 +868,43 @@ impl ModelClientSession {
ResponsesWsRequest::ResponseCreate(payload)
}
/// Builds the websocket request for the current turn.
///
/// v1 behavior: incremental (prefix-extending) input is sent as a
/// `response.append`; otherwise a full `response.create` is sent.
/// v2 behavior: incremental input is sent as a `response.create` carrying only
/// the new items plus `previous_response_id`; if no completed response id is
/// available (e.g. first turn, or after an error cleared it), it falls back to
/// a full `response.create` without `previous_response_id`.
fn prepare_websocket_request(
&mut self,
model_slug: &str,
api_prompt: &ApiPrompt,
options: &ApiResponsesOptions,
) -> ResponsesWsRequest {
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
// `Some(delta)` when `api_prompt.input` extends the previously sent items.
let incremental_items = self.get_incremental_items(&api_prompt.input);
if let Some(append_items) = incremental_items {
if responses_websockets_v2_enabled
&& let Some(previous_response_id) = self.websocket_previous_response_id()
{
// v2 with a known previous response: send only the delta, chained by id.
return self.prepare_websocket_create_request(
model_slug,
api_prompt,
options,
append_items,
Some(previous_response_id),
);
}
if !responses_websockets_v2_enabled {
// v1: deltas go through the dedicated append message.
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
input: append_items,
});
}
// v2 without a previous response id falls through to a full create below.
}
self.prepare_websocket_create_request(
model_slug,
api_prompt,
options,
api_prompt.input.clone(),
None,
)
}
/// Returns a websocket connection for this turn, reusing preconnect when possible.
///
/// This method first tries to adopt the session-level preconnect slot, then falls back to a
@@ -863,6 +939,9 @@ impl ModelClientSession {
if needs_new {
self.client.clear_preconnect();
self.websocket_last_items.clear();
self.websocket_last_response_id = None;
self.websocket_last_response_id_rx = None;
let turn_state = options
.turn_state
.clone()
@@ -1023,9 +1102,8 @@ impl ModelClientSession {
turn_metadata_header,
compression,
);
let request = self.prepare_websocket_request(&model_info.slug, &api_prompt, &options);
let connection = match self
match self
.websocket_connection(
otel_manager,
client_setup.api_provider,
@@ -1035,7 +1113,7 @@ impl ModelClientSession {
)
.await
{
Ok(connection) => connection,
Ok(_) => {}
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
@@ -1043,13 +1121,33 @@ impl ModelClientSession {
continue;
}
Err(err) => return Err(map_api_error(err)),
};
}
let stream_result = connection
let request = self.prepare_websocket_request(&model_info.slug, &api_prompt, &options);
let stream_result = self
.connection
.as_ref()
.ok_or_else(|| {
map_api_error(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
})?
.stream_request(request)
.await
.map_err(map_api_error)?;
self.websocket_last_items = api_prompt.input.clone();
let (last_response_id_sender, last_response_id_receiver) = oneshot::channel();
self.websocket_last_response_id_rx = Some(last_response_id_receiver);
let mut last_response_id_sender = Some(last_response_id_sender);
let stream_result = stream_result.inspect(move |event| {
if let Ok(ResponseEvent::Completed { response_id, .. }) = event
&& !response_id.is_empty()
&& let Some(sender) = last_response_id_sender.take()
{
let _ = sender.send(response_id.clone());
}
});
return Ok(map_response_stream(stream_result, otel_manager.clone()));
}

View File

@@ -1060,7 +1060,9 @@ impl Session {
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::ResponsesWebsockets),
config.features.enabled(Feature::ResponsesWebsockets)
|| config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Self::build_model_client_beta_features_header(config.as_ref()),
@@ -5871,7 +5873,9 @@ mod tests {
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::ResponsesWebsockets),
config.features.enabled(Feature::ResponsesWebsockets)
|| config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
@@ -6001,7 +6005,9 @@ mod tests {
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::ResponsesWebsockets),
config.features.enabled(Feature::ResponsesWebsockets)
|| config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::ResponsesWebsocketsV2),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),

View File

@@ -2685,25 +2685,27 @@ profile = "project"
}
#[test]
fn responses_websockets_feature_does_not_change_wire_api() -> std::io::Result<()> {
let codex_home = TempDir::new()?;
let mut entries = BTreeMap::new();
entries.insert("responses_websockets".to_string(), true);
let cfg = ConfigToml {
features: Some(crate::features::FeaturesToml { entries }),
..Default::default()
};
fn responses_websocket_features_do_not_change_wire_api() -> std::io::Result<()> {
for feature_key in ["responses_websockets", "responses_websockets_v2"] {
let codex_home = TempDir::new()?;
let mut entries = BTreeMap::new();
entries.insert(feature_key.to_string(), true);
let cfg = ConfigToml {
features: Some(crate::features::FeaturesToml { entries }),
..Default::default()
};
let config = Config::load_from_base_config_with_overrides(
cfg,
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)?;
let config = Config::load_from_base_config_with_overrides(
cfg,
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)?;
assert_eq!(
config.model_provider.wire_api,
crate::model_provider_info::WireApi::Responses
);
assert_eq!(
config.model_provider.wire_api,
crate::model_provider_info::WireApi::Responses
);
}
Ok(())
}

View File

@@ -127,6 +127,8 @@ pub enum Feature {
Personality,
/// Use the Responses API WebSocket transport for OpenAI by default.
ResponsesWebsockets,
/// Enable Responses API websocket v2 mode.
ResponsesWebsocketsV2,
}
impl Feature {
@@ -569,6 +571,12 @@ pub const FEATURES: &[FeatureSpec] = &[
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::ResponsesWebsocketsV2,
key: "responses_websockets_v2",
stage: Stage::UnderDevelopment,
default_enabled: false,
},
];
/// Push a warning event if any under-development features are enabled.

View File

@@ -94,6 +94,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
false,
false,
false,
false,
None,
);
let mut client_session = client.new_session();
@@ -196,6 +197,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
false,
false,
false,
false,
None,
);
let mut client_session = client.new_session();
@@ -297,6 +299,7 @@ async fn responses_respects_model_info_overrides_from_config() {
false,
false,
false,
false,
None,
);
let mut client_session = client.new_session();

View File

@@ -1,4 +1,5 @@
use anyhow::Result;
use codex_core::features::Feature;
use core_test_support::responses::WebSocketConnectionConfig;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
@@ -13,6 +14,8 @@ use pretty_assertions::assert_eq;
use serde_json::Value;
use std::time::Duration;
const WS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=2026-02-06";
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_test_codex_shell_chain() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -120,3 +123,70 @@ async fn websocket_first_turn_waits_for_inflight_preconnect() -> Result<()> {
server.shutdown().await;
Ok(())
}
// End-to-end v2 test: a shell tool call followed by a second turn must chain
// the second request to the first completed response via
// `previous_response_id`, and the handshake must advertise the v2 beta header.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
skip_if_no_network!(Ok(()));
let call_id = "shell-command-call";
// One connection serving two responses: a shell call, then a final message.
let server = start_websocket_server(vec![vec![
vec![
ev_response_created("resp-1"),
ev_shell_command_call(call_id, "echo websocket"),
ev_completed("resp-1"),
],
vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
],
]])
.await;
let mut builder = test_codex().with_config(|config| {
config.features.enable(Feature::ResponsesWebsocketsV2);
});
let test = builder.build_with_websocket_server(&server).await?;
test.submit_turn("run the echo command").await?;
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
let first = connection
.first()
.expect("missing first request")
.body_json();
let second = connection
.get(1)
.expect("missing second request")
.body_json();
// In v2 both requests are creates; the follow-up chains to "resp-1".
assert_eq!(first["type"].as_str(), Some("response.create"));
assert_eq!(second["type"].as_str(), Some("response.create"));
assert_eq!(second["previous_response_id"].as_str(), Some("resp-1"));
let create_items = second
.get("input")
.and_then(Value::as_array)
.expect("response.create input array");
assert!(!create_items.is_empty());
// The follow-up create must carry the shell call's output back to the model.
let output_item = create_items
.iter()
.find(|item| item.get("type").and_then(Value::as_str) == Some("function_call_output"))
.expect("function_call_output in create");
assert_eq!(
output_item.get("call_id").and_then(Value::as_str),
Some(call_id)
);
let handshake = server.single_handshake();
assert_eq!(
handshake.header("openai-beta"),
Some(WS_V2_BETA_HEADER_VALUE.to_string())
);
server.shutdown().await;
Ok(())
}

View File

@@ -1273,6 +1273,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
false,
false,
false,
false,
None,
);
let mut client_session = client.new_session();

View File

@@ -42,6 +42,7 @@ use tracing_test::traced_test;
const MODEL: &str = "gpt-5.2-codex";
const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
const OPENAI_BETA_RESPONSES_WEBSOCKETS: &str = "responses_websockets=2026-02-04";
const WS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=2026-02-06";
struct WebsocketTestHarness {
_codex_home: TempDir,
@@ -456,6 +457,165 @@ async fn responses_websocket_creates_on_non_prefix() {
server.shutdown().await;
}
// When the second prompt extends the first (prefix match), v2 must send a
// `response.create` containing only the delta items, chained to the first
// response via `previous_response_id` — not a v1 `response.append`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_v2_creates_with_previous_response_id_on_prefix() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
]])
.await;
let harness = websocket_harness_with_v2(&server, true).await;
let mut session = harness.client.new_session();
// prompt_two is prompt_one plus one extra item, triggering the prefix path.
let prompt_one = prompt_with_input(vec![message_item("hello")]);
let prompt_two = prompt_with_input(vec![message_item("hello"), message_item("second")]);
stream_until_complete(&mut session, &harness, &prompt_one).await;
stream_until_complete(&mut session, &harness, &prompt_two).await;
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
let first = connection.first().expect("missing request").body_json();
let second = connection.get(1).expect("missing request").body_json();
assert_eq!(first["type"].as_str(), Some("response.create"));
assert_eq!(second["type"].as_str(), Some("response.create"));
assert_eq!(second["previous_response_id"].as_str(), Some("resp-1"));
// Only the items after the shared prefix are re-sent.
assert_eq!(
second["input"],
serde_json::to_value(&prompt_two.input[1..]).unwrap()
);
server.shutdown().await;
}
// After a failed response, the cached response id must be discarded: the next
// turn reconnects and sends a full `response.create` with the complete input
// and no `previous_response_id`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_v2_after_error_uses_full_create_without_previous_response_id() {
skip_if_no_network!();
// Two connections: the first completes "resp-1" then fails the second
// request; the second connection serves "resp-3" successfully.
let server = start_websocket_server(vec![
vec![
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
vec![json!({
"type": "response.failed",
"response": {
"error": {
"code": "invalid_prompt",
"message": "synthetic websocket failure"
}
}
})],
],
vec![vec![ev_response_created("resp-3"), ev_completed("resp-3")]],
])
.await;
let harness = websocket_harness_with_v2(&server, true).await;
let mut session = harness.client.new_session();
// Each prompt extends the previous one so the prefix/delta path is eligible.
let prompt_one = prompt_with_input(vec![message_item("hello")]);
let prompt_two = prompt_with_input(vec![message_item("hello"), message_item("second")]);
let prompt_three = prompt_with_input(vec![
message_item("hello"),
message_item("second"),
message_item("third"),
]);
stream_until_complete(&mut session, &harness, &prompt_one).await;
// Drive the second turn manually so we can observe the synthetic failure.
let mut second_stream = session
.stream(
&prompt_two,
&harness.model_info,
&harness.otel_manager,
harness.effort,
harness.summary,
None,
)
.await
.expect("websocket stream failed");
let mut saw_error = false;
while let Some(event) = second_stream.next().await {
if event.is_err() {
saw_error = true;
break;
}
}
assert!(saw_error, "expected second websocket stream to error");
stream_until_complete(&mut session, &harness, &prompt_three).await;
// The failure forces a reconnect: two handshakes, two connections.
assert_eq!(server.handshakes().len(), 2);
let connections = server.connections();
assert_eq!(connections.len(), 2);
let first_connection = connections.first().expect("missing first connection");
assert_eq!(first_connection.len(), 2);
let first = first_connection
.first()
.expect("missing first request")
.body_json();
let second = first_connection
.get(1)
.expect("missing second request")
.body_json();
let third = connections
.get(1)
.and_then(|connection| connection.first())
.expect("missing third request")
.body_json();
assert_eq!(first["type"].as_str(), Some("response.create"));
// Before the failure, turn two still chained to "resp-1".
assert_eq!(second["type"].as_str(), Some("response.create"));
assert_eq!(second["previous_response_id"].as_str(), Some("resp-1"));
// After the failure, turn three must not chain and must re-send everything.
assert_eq!(third["type"].as_str(), Some("response.create"));
assert_eq!(third.get("previous_response_id"), None);
assert_eq!(
third["input"],
serde_json::to_value(&prompt_three.input).unwrap()
);
server.shutdown().await;
}
// The v2 handshake must advertise the v2 beta value in `OpenAI-Beta` and must
// not carry the v1 value at the same time.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_v2_sets_openai_beta_header() {
    skip_if_no_network!();
    // Single connection serving one immediately completed response.
    let server = start_websocket_server(vec![vec![vec![
        ev_response_created("resp-1"),
        ev_completed("resp-1"),
    ]]])
    .await;
    let harness = websocket_harness_with_v2(&server, true).await;
    let mut session = harness.client.new_session();
    let prompt = prompt_with_input(vec![message_item("hello")]);
    stream_until_complete(&mut session, &harness, &prompt).await;
    let handshake = server.single_handshake();
    let openai_beta_header = handshake
        .header(OPENAI_BETA_HEADER)
        .expect("missing OpenAI-Beta header");
    // Split the comma-separated header once and check membership both ways.
    let advertised: Vec<&str> = openai_beta_header.split(',').map(str::trim).collect();
    assert!(advertised.contains(&WS_V2_BETA_HEADER_VALUE));
    assert!(!advertised.contains(&OPENAI_BETA_RESPONSES_WEBSOCKETS));
    server.shutdown().await;
}
fn message_item(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
@@ -498,6 +658,21 @@ async fn websocket_harness(server: &WebSocketTestServer) -> WebsocketTestHarness
/// Convenience wrapper: harness with runtime metrics toggled and v2 disabled.
async fn websocket_harness_with_runtime_metrics(
server: &WebSocketTestServer,
runtime_metrics_enabled: bool,
) -> WebsocketTestHarness {
websocket_harness_with_options(server, runtime_metrics_enabled, false).await
}
/// Convenience wrapper: harness with websocket v2 toggled and metrics disabled.
async fn websocket_harness_with_v2(
server: &WebSocketTestServer,
websocket_v2_enabled: bool,
) -> WebsocketTestHarness {
websocket_harness_with_options(server, false, websocket_v2_enabled).await
}
async fn websocket_harness_with_options(
server: &WebSocketTestServer,
runtime_metrics_enabled: bool,
websocket_v2_enabled: bool,
) -> WebsocketTestHarness {
let provider = websocket_provider(server);
let codex_home = TempDir::new().unwrap();
@@ -507,6 +682,9 @@ async fn websocket_harness_with_runtime_metrics(
if runtime_metrics_enabled {
config.features.enable(Feature::RuntimeMetrics);
}
if websocket_v2_enabled {
config.features.enable(Feature::ResponsesWebsocketsV2);
}
let config = Arc::new(config);
let model_info = ModelsManager::construct_model_info_offline(MODEL, &config);
let conversation_id = ThreadId::new();
@@ -538,6 +716,7 @@ async fn websocket_harness_with_runtime_metrics(
SessionSource::Exec,
config.model_verbosity,
true,
websocket_v2_enabled,
false,
runtime_metrics_enabled,
None,