mirror of
https://github.com/openai/codex.git
synced 2026-06-01 19:02:59 +00:00
Support alternative websocket API (#10861)
**Test plan**
```
cargo build -p codex-cli && RUST_LOG='codex_api::endpoint::responses_websocket=trace,codex_core::client=debug,codex_core::codex=debug' \
./target/debug/codex \
--enable responses_websockets_v2 \
--profile byok \
--full-auto
```
This commit is contained in:
@@ -89,6 +89,8 @@ use reqwest::StatusCode;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::oneshot::error::TryRecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_tungstenite::tungstenite::Error;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
@@ -117,6 +119,7 @@ pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
|
||||
pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str =
|
||||
"x-responsesapi-include-timing-metrics";
|
||||
const RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=2026-02-06";
|
||||
|
||||
/// Session-scoped state shared by all [`ModelClient`] clones.
|
||||
///
|
||||
@@ -129,6 +132,7 @@ struct ModelClientState {
|
||||
session_source: SessionSource,
|
||||
model_verbosity: Option<VerbosityConfig>,
|
||||
enable_responses_websockets: bool,
|
||||
enable_responses_websockets_v2: bool,
|
||||
enable_request_compression: bool,
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
@@ -238,6 +242,8 @@ pub struct ModelClientSession {
|
||||
client: ModelClient,
|
||||
connection: Option<ApiWebSocketConnection>,
|
||||
websocket_last_items: Vec<ResponseItem>,
|
||||
websocket_last_response_id: Option<String>,
|
||||
websocket_last_response_id_rx: Option<oneshot::Receiver<String>>,
|
||||
/// Turn state for sticky routing.
|
||||
///
|
||||
/// This is an `OnceLock` that stores the turn state value received from the server
|
||||
@@ -264,6 +270,7 @@ impl ModelClient {
|
||||
session_source: SessionSource,
|
||||
model_verbosity: Option<VerbosityConfig>,
|
||||
enable_responses_websockets: bool,
|
||||
enable_responses_websockets_v2: bool,
|
||||
enable_request_compression: bool,
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
@@ -276,6 +283,7 @@ impl ModelClient {
|
||||
session_source,
|
||||
model_verbosity,
|
||||
enable_responses_websockets,
|
||||
enable_responses_websockets_v2,
|
||||
enable_request_compression,
|
||||
include_timing_metrics,
|
||||
beta_features_header,
|
||||
@@ -295,6 +303,8 @@ impl ModelClient {
|
||||
client: self.clone(),
|
||||
connection: None,
|
||||
websocket_last_items: Vec::new(),
|
||||
websocket_last_response_id: None,
|
||||
websocket_last_response_id_rx: None,
|
||||
turn_state: Arc::new(OnceLock::new()),
|
||||
}
|
||||
}
|
||||
@@ -479,6 +489,10 @@ impl ModelClient {
|
||||
self.state.provider.supports_websockets && self.state.enable_responses_websockets
|
||||
}
|
||||
|
||||
/// Returns whether the websocket v2 mode flag is set in the shared client state.
///
/// This reads only `enable_responses_websockets_v2`; it does not itself check
/// `provider.supports_websockets`.
fn responses_websockets_v2_enabled(&self) -> bool {
|
||||
self.state.enable_responses_websockets_v2
|
||||
}
|
||||
|
||||
/// Returns whether websocket transport has been permanently disabled for this session.
|
||||
///
|
||||
/// Once set by fallback activation, subsequent turns must stay on HTTP transport.
|
||||
@@ -544,9 +558,14 @@ impl ModelClient {
|
||||
headers.extend(build_conversation_headers(Some(
|
||||
self.state.conversation_id.to_string(),
|
||||
)));
|
||||
let responses_websockets_beta_header = if self.responses_websockets_v2_enabled() {
|
||||
RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE
|
||||
} else {
|
||||
OPENAI_BETA_RESPONSES_WEBSOCKETS
|
||||
};
|
||||
headers.insert(
|
||||
OPENAI_BETA_HEADER,
|
||||
HeaderValue::from_static(OPENAI_BETA_RESPONSES_WEBSOCKETS),
|
||||
HeaderValue::from_static(responses_websockets_beta_header),
|
||||
);
|
||||
if self.state.include_timing_metrics {
|
||||
headers.insert(
|
||||
@@ -789,18 +808,37 @@ impl ModelClientSession {
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_websocket_request(
|
||||
/// Polls the stored response-id receiver (if any) with `try_recv` and updates
/// `websocket_last_response_id` from the outcome.
///
/// Does nothing when no receiver is stored.
fn refresh_websocket_last_response_id(&mut self) {
|
||||
// `take()` empties the receiver slot; it is only restored in the `Empty` arm below.
if let Some(mut receiver) = self.websocket_last_response_id_rx.take() {
|
||||
match receiver.try_recv() {
|
||||
// A non-empty id was delivered: cache it.
Ok(response_id) if !response_id.is_empty() => {
|
||||
self.websocket_last_response_id = Some(response_id);
|
||||
}
|
||||
// An empty id, or a sender dropped without sending: clear the cached id.
Ok(_) | Err(TryRecvError::Closed) => {
|
||||
self.websocket_last_response_id = None;
|
||||
}
|
||||
// Nothing sent yet: keep the receiver so a later call can poll again.
Err(TryRecvError::Empty) => {
|
||||
self.websocket_last_response_id_rx = Some(receiver);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Refreshes the cached response id from the pending receiver, then returns a clone of it.
///
/// Returns `None` when no id is cached or the cached id is an empty string.
fn websocket_previous_response_id(&mut self) -> Option<String> {
|
||||
self.refresh_websocket_last_response_id();
|
||||
self.websocket_last_response_id
|
||||
.clone()
|
||||
.filter(|id| !id.is_empty())
|
||||
}
|
||||
|
||||
fn prepare_websocket_create_request(
|
||||
&self,
|
||||
model_slug: &str,
|
||||
api_prompt: &ApiPrompt,
|
||||
options: &ApiResponsesOptions,
|
||||
input: Vec<ResponseItem>,
|
||||
previous_response_id: Option<String>,
|
||||
) -> ResponsesWsRequest {
|
||||
if let Some(append_items) = self.get_incremental_items(&api_prompt.input) {
|
||||
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
|
||||
input: append_items,
|
||||
});
|
||||
}
|
||||
|
||||
let ApiResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
@@ -814,7 +852,8 @@ impl ModelClientSession {
|
||||
let payload = ResponseCreateWsRequest {
|
||||
model: model_slug.to_string(),
|
||||
instructions: api_prompt.instructions.clone(),
|
||||
input: api_prompt.input.clone(),
|
||||
previous_response_id,
|
||||
input,
|
||||
tools: api_prompt.tools.clone(),
|
||||
tool_choice: "auto".to_string(),
|
||||
parallel_tool_calls: api_prompt.parallel_tool_calls,
|
||||
@@ -829,6 +868,43 @@ impl ModelClientSession {
|
||||
ResponsesWsRequest::ResponseCreate(payload)
|
||||
}
|
||||
|
||||
/// Builds the websocket request to send for this turn.
///
/// The request shape depends on the v2 flag and on whether `get_incremental_items`
/// yields items for `api_prompt.input`:
/// - incremental items, v2 enabled, and a previous response id available: a create
///   request carrying only the incremental items plus that `previous_response_id`;
/// - incremental items and v2 disabled: a `ResponseAppend` request with those items;
/// - otherwise (no incremental items, or v2 without a previous response id): a create
///   request carrying the full `api_prompt.input` and no previous response id.
///
/// Takes `&mut self` because looking up the previous response id refreshes cached state.
fn prepare_websocket_request(
|
||||
&mut self,
|
||||
model_slug: &str,
|
||||
api_prompt: &ApiPrompt,
|
||||
options: &ApiResponsesOptions,
|
||||
) -> ResponsesWsRequest {
|
||||
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
|
||||
let incremental_items = self.get_incremental_items(&api_prompt.input);
|
||||
if let Some(append_items) = incremental_items {
|
||||
// v2: send only the new items, chained to the prior response by id.
if responses_websockets_v2_enabled
|
||||
&& let Some(previous_response_id) = self.websocket_previous_response_id()
|
||||
{
|
||||
return self.prepare_websocket_create_request(
|
||||
model_slug,
|
||||
api_prompt,
|
||||
options,
|
||||
append_items,
|
||||
Some(previous_response_id),
|
||||
);
|
||||
}
|
||||
|
|
||||
// Non-v2: send the new items as an append request.
if !responses_websockets_v2_enabled {
|
||||
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
|
||||
input: append_items,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
||||
// Fallback: full create request with the complete input and no previous response id.
self.prepare_websocket_create_request(
|
||||
model_slug,
|
||||
api_prompt,
|
||||
options,
|
||||
api_prompt.input.clone(),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a websocket connection for this turn, reusing preconnect when possible.
|
||||
///
|
||||
/// This method first tries to adopt the session-level preconnect slot, then falls back to a
|
||||
@@ -863,6 +939,9 @@ impl ModelClientSession {
|
||||
|
||||
if needs_new {
|
||||
self.client.clear_preconnect();
|
||||
self.websocket_last_items.clear();
|
||||
self.websocket_last_response_id = None;
|
||||
self.websocket_last_response_id_rx = None;
|
||||
let turn_state = options
|
||||
.turn_state
|
||||
.clone()
|
||||
@@ -1023,9 +1102,8 @@ impl ModelClientSession {
|
||||
turn_metadata_header,
|
||||
compression,
|
||||
);
|
||||
let request = self.prepare_websocket_request(&model_info.slug, &api_prompt, &options);
|
||||
|
||||
let connection = match self
|
||||
match self
|
||||
.websocket_connection(
|
||||
otel_manager,
|
||||
client_setup.api_provider,
|
||||
@@ -1035,7 +1113,7 @@ impl ModelClientSession {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(connection) => connection,
|
||||
Ok(_) => {}
|
||||
Err(ApiError::Transport(
|
||||
unauthorized_transport @ TransportError::Http { status, .. },
|
||||
)) if status == StatusCode::UNAUTHORIZED => {
|
||||
@@ -1043,13 +1121,33 @@ impl ModelClientSession {
|
||||
continue;
|
||||
}
|
||||
Err(err) => return Err(map_api_error(err)),
|
||||
};
|
||||
}
|
||||
|
||||
let stream_result = connection
|
||||
let request = self.prepare_websocket_request(&model_info.slug, &api_prompt, &options);
|
||||
|
||||
let stream_result = self
|
||||
.connection
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
map_api_error(ApiError::Stream(
|
||||
"websocket connection is unavailable".to_string(),
|
||||
))
|
||||
})?
|
||||
.stream_request(request)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
self.websocket_last_items = api_prompt.input.clone();
|
||||
let (last_response_id_sender, last_response_id_receiver) = oneshot::channel();
|
||||
self.websocket_last_response_id_rx = Some(last_response_id_receiver);
|
||||
let mut last_response_id_sender = Some(last_response_id_sender);
|
||||
let stream_result = stream_result.inspect(move |event| {
|
||||
if let Ok(ResponseEvent::Completed { response_id, .. }) = event
|
||||
&& !response_id.is_empty()
|
||||
&& let Some(sender) = last_response_id_sender.take()
|
||||
{
|
||||
let _ = sender.send(response_id.clone());
|
||||
}
|
||||
});
|
||||
|
||||
return Ok(map_response_stream(stream_result, otel_manager.clone()));
|
||||
}
|
||||
|
||||
@@ -1060,7 +1060,9 @@ impl Session {
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
config.model_verbosity,
|
||||
config.features.enabled(Feature::ResponsesWebsockets),
|
||||
config.features.enabled(Feature::ResponsesWebsockets)
|
||||
|| config.features.enabled(Feature::ResponsesWebsocketsV2),
|
||||
config.features.enabled(Feature::ResponsesWebsocketsV2),
|
||||
config.features.enabled(Feature::EnableRequestCompression),
|
||||
config.features.enabled(Feature::RuntimeMetrics),
|
||||
Self::build_model_client_beta_features_header(config.as_ref()),
|
||||
@@ -5871,7 +5873,9 @@ mod tests {
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
config.model_verbosity,
|
||||
config.features.enabled(Feature::ResponsesWebsockets),
|
||||
config.features.enabled(Feature::ResponsesWebsockets)
|
||||
|| config.features.enabled(Feature::ResponsesWebsocketsV2),
|
||||
config.features.enabled(Feature::ResponsesWebsocketsV2),
|
||||
config.features.enabled(Feature::EnableRequestCompression),
|
||||
config.features.enabled(Feature::RuntimeMetrics),
|
||||
Session::build_model_client_beta_features_header(config.as_ref()),
|
||||
@@ -6001,7 +6005,9 @@ mod tests {
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
config.model_verbosity,
|
||||
config.features.enabled(Feature::ResponsesWebsockets),
|
||||
config.features.enabled(Feature::ResponsesWebsockets)
|
||||
|| config.features.enabled(Feature::ResponsesWebsocketsV2),
|
||||
config.features.enabled(Feature::ResponsesWebsocketsV2),
|
||||
config.features.enabled(Feature::EnableRequestCompression),
|
||||
config.features.enabled(Feature::RuntimeMetrics),
|
||||
Session::build_model_client_beta_features_header(config.as_ref()),
|
||||
|
||||
@@ -2685,25 +2685,27 @@ profile = "project"
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_websockets_feature_does_not_change_wire_api() -> std::io::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let mut entries = BTreeMap::new();
|
||||
entries.insert("responses_websockets".to_string(), true);
|
||||
let cfg = ConfigToml {
|
||||
features: Some(crate::features::FeaturesToml { entries }),
|
||||
..Default::default()
|
||||
};
|
||||
fn responses_websocket_features_do_not_change_wire_api() -> std::io::Result<()> {
|
||||
for feature_key in ["responses_websockets", "responses_websockets_v2"] {
|
||||
let codex_home = TempDir::new()?;
|
||||
let mut entries = BTreeMap::new();
|
||||
entries.insert(feature_key.to_string(), true);
|
||||
let cfg = ConfigToml {
|
||||
features: Some(crate::features::FeaturesToml { entries }),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
config.model_provider.wire_api,
|
||||
crate::model_provider_info::WireApi::Responses
|
||||
);
|
||||
assert_eq!(
|
||||
config.model_provider.wire_api,
|
||||
crate::model_provider_info::WireApi::Responses
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -127,6 +127,8 @@ pub enum Feature {
|
||||
Personality,
|
||||
/// Use the Responses API WebSocket transport for OpenAI by default.
|
||||
ResponsesWebsockets,
|
||||
/// Enable Responses API websocket v2 mode.
|
||||
ResponsesWebsocketsV2,
|
||||
}
|
||||
|
||||
impl Feature {
|
||||
@@ -569,6 +571,12 @@ pub const FEATURES: &[FeatureSpec] = &[
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::ResponsesWebsocketsV2,
|
||||
key: "responses_websockets_v2",
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
];
|
||||
|
||||
/// Push a warning event if any under-development features are enabled.
|
||||
|
||||
Reference in New Issue
Block a user