Compare commits

...

2 Commits

Author SHA1 Message Date
Brian Yu
a7c1a8cdd0 Reset websocket turn state on recon− 2026-02-07 10:58:02 -08:00
Brian Yu
a27322b039 Review websocket v2 state changes 2026-02-07 00:09:50 -08:00
2 changed files with 252 additions and 16 deletions

View File

@@ -142,6 +142,8 @@ struct ModelClientState {
/// This keeps startup preconnect task tracking and warmed-socket adoption in one lock so
/// turn-time websocket setup observes a single, coherent state.
preconnect: Mutex<PreconnectState>,
/// Session-scoped websocket v2 state shared across turns.
responses_websocket_v2: Mutex<ResponsesWebsocketV2State>,
}
impl std::fmt::Debug for ModelClientState {
@@ -167,6 +169,7 @@ impl std::fmt::Debug for ModelClientState {
&self.disable_websockets.load(Ordering::Relaxed),
)
.field("preconnect", &"<opaque>")
.field("responses_websocket_v2", &"<opaque>")
.finish()
}
}
@@ -186,7 +189,7 @@ struct CurrentClientSetup {
/// This bundles the socket with optional sticky-routing state captured during
/// handshake so they are taken and cleared atomically.
struct PreconnectedWebSocket {
connection: ApiWebSocketConnection,
connection: Arc<ApiWebSocketConnection>,
turn_state: Option<String>,
}
@@ -203,6 +206,13 @@ enum PreconnectState {
Ready(PreconnectedWebSocket),
}
#[derive(Default)]
struct ResponsesWebsocketV2State {
connection: Option<Arc<ApiWebSocketConnection>>,
websocket_last_items: Vec<ResponseItem>,
websocket_last_response_id: Option<String>,
}
/// A session-scoped client for model-provider API calls.
///
/// This holds configuration and state that should be shared across turns within a Codex session
@@ -240,7 +250,7 @@ pub struct ModelClient {
/// contract and can cause routing bugs.
pub struct ModelClientSession {
client: ModelClient,
connection: Option<ApiWebSocketConnection>,
connection: Option<Arc<ApiWebSocketConnection>>,
websocket_last_items: Vec<ResponseItem>,
websocket_last_response_id: Option<String>,
websocket_last_response_id_rx: Option<oneshot::Receiver<String>>,
@@ -289,6 +299,7 @@ impl ModelClient {
beta_features_header,
disable_websockets: AtomicBool::new(false),
preconnect: Mutex::new(PreconnectState::Idle),
responses_websocket_v2: Mutex::new(ResponsesWebsocketV2State::default()),
}),
}
}
@@ -493,6 +504,74 @@ impl ModelClient {
self.state.enable_responses_websockets_v2
}
fn with_shared_v2_state<R>(
&self,
action: impl FnOnce(&mut ResponsesWebsocketV2State) -> R,
) -> R {
let mut state = self
.state
.responses_websocket_v2
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
action(&mut state)
}
fn shared_v2_connection(&self) -> Option<Arc<ApiWebSocketConnection>> {
self.with_shared_v2_state(|state| state.connection.clone())
}
fn set_shared_v2_connection(&self, connection: Option<Arc<ApiWebSocketConnection>>) {
self.with_shared_v2_state(|state| {
let connection_changed = match (&state.connection, &connection) {
(Some(existing), Some(new_connection)) => !Arc::ptr_eq(existing, new_connection),
(None, Some(_)) | (Some(_), None) => true,
(None, None) => false,
};
if connection_changed {
// Response chaining state is scoped to a single backend websocket connection.
// When the connection changes, start the next request as a full create.
state.websocket_last_items.clear();
state.websocket_last_response_id = None;
}
state.connection = connection;
});
}
fn shared_v2_request_chain_state(&self) -> (Vec<ResponseItem>, Option<String>) {
self.with_shared_v2_state(|state| {
(
state.websocket_last_items.clone(),
state.websocket_last_response_id.clone(),
)
})
}
fn set_shared_v2_request_chain_state(
&self,
websocket_last_items: Vec<ResponseItem>,
response_id: String,
) {
self.with_shared_v2_state(|state| {
state.websocket_last_items = websocket_last_items;
state.websocket_last_response_id = Some(response_id);
});
}
fn clear_shared_v2_request_chain_state(&self) {
self.with_shared_v2_state(|state| {
state.websocket_last_items.clear();
state.websocket_last_response_id = None;
});
}
fn clear_shared_v2_state(&self) {
self.with_shared_v2_state(|state| {
state.connection = None;
state.websocket_last_items.clear();
state.websocket_last_response_id = None;
});
}
/// Returns whether websocket transport has been permanently disabled for this session.
///
/// Once set by fallback activation, subsequent turns must stay on HTTP transport.
@@ -612,7 +691,7 @@ impl ModelClient {
return;
}
*state = PreconnectState::Ready(PreconnectedWebSocket {
connection,
connection: Arc::new(connection),
turn_state,
});
}
@@ -875,6 +954,17 @@ impl ModelClientSession {
options: &ApiResponsesOptions,
) -> ResponsesWsRequest {
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
if responses_websockets_v2_enabled
&& self.websocket_last_items.is_empty()
&& self.websocket_last_response_id.is_none()
&& self.websocket_last_response_id_rx.is_none()
{
let (websocket_last_items, websocket_last_response_id) =
self.client.shared_v2_request_chain_state();
self.websocket_last_items = websocket_last_items;
self.websocket_last_response_id = websocket_last_response_id;
}
let incremental_items = self.get_incremental_items(&api_prompt.input);
if let Some(append_items) = incremental_items {
if responses_websockets_v2_enabled
@@ -920,6 +1010,11 @@ impl ModelClientSession {
turn_metadata_header: Option<&str>,
options: &ApiResponsesOptions,
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
if responses_websockets_v2_enabled && self.connection.is_none() {
self.connection = self.client.shared_v2_connection();
}
// Prefer the session-level preconnect slot before creating a new websocket.
if self.connection.is_none() {
if let Some(preconnected) = self.try_use_preconnected_websocket() {
@@ -942,24 +1037,36 @@ impl ModelClientSession {
self.websocket_last_items.clear();
self.websocket_last_response_id = None;
self.websocket_last_response_id_rx = None;
if responses_websockets_v2_enabled {
self.client.clear_shared_v2_request_chain_state();
}
let turn_state = options
.turn_state
.clone()
.unwrap_or_else(|| Arc::clone(&self.turn_state));
let new_conn = self
.client
.connect_websocket(
otel_manager,
api_provider,
api_auth,
Some(turn_state),
turn_metadata_header,
)
.await?;
let new_conn = Arc::new(
self.client
.connect_websocket(
otel_manager,
api_provider,
api_auth,
Some(turn_state),
turn_metadata_header,
)
.await?,
);
if responses_websockets_v2_enabled {
self.client
.set_shared_v2_connection(Some(Arc::clone(&new_conn)));
}
self.connection = Some(new_conn);
} else if responses_websockets_v2_enabled && let Some(connection) = self.connection.as_ref()
{
self.client
.set_shared_v2_connection(Some(Arc::clone(connection)));
}
self.connection.as_ref().ok_or(ApiError::Stream(
self.connection.as_deref().ok_or(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
}
@@ -1136,16 +1243,33 @@ impl ModelClientSession {
.stream_request(request)
.await
.map_err(map_api_error)?;
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
if responses_websockets_v2_enabled {
// Keep chain state completion-driven: an interrupted request should not
// leave a mixed {new input, old response_id} pair for the next turn.
self.client.clear_shared_v2_request_chain_state();
}
self.websocket_last_items = api_prompt.input.clone();
let mut completed_input = Some(self.websocket_last_items.clone());
let (last_response_id_sender, last_response_id_receiver) = oneshot::channel();
self.websocket_last_response_id_rx = Some(last_response_id_receiver);
let mut last_response_id_sender = Some(last_response_id_sender);
let client = self.client.clone();
let stream_result = stream_result.inspect(move |event| {
if let Ok(ResponseEvent::Completed { response_id, .. }) = event
&& !response_id.is_empty()
&& let Some(sender) = last_response_id_sender.take()
{
let _ = sender.send(response_id.clone());
if let Some(sender) = last_response_id_sender.take() {
let _ = sender.send(response_id.clone());
}
if responses_websockets_v2_enabled
&& let Some(completed_input) = completed_input.take()
{
client.set_shared_v2_request_chain_state(
completed_input,
response_id.clone(),
);
}
}
});
@@ -1241,6 +1365,7 @@ impl ModelClientSession {
self.connection = None;
self.websocket_last_items.clear();
self.client.clear_shared_v2_state();
self.client.clear_preconnect();
}
activated

View File

@@ -1,6 +1,7 @@
#![allow(clippy::expect_used, clippy::unwrap_used)]
use anyhow::Result;
use codex_core::features::Feature;
use core_test_support::responses::WebSocketConnectionConfig;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
@@ -123,3 +124,113 @@ async fn websocket_turn_state_persists_within_turn_and_resets_after() -> Result<
server.shutdown().await;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_v2_reuses_connection_across_turns() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
requests: vec![
vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "done"),
ev_completed("resp-1"),
],
vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-2", "done"),
ev_completed("resp-2"),
],
],
response_headers: Vec::new(),
accept_delay: None,
}])
.await;
let builder = test_codex();
let test = builder
.with_config(|config| {
config.features.enable(Feature::ResponsesWebsocketsV2);
})
.build_with_websocket_server(&server)
.await?;
test.submit_turn("first turn").await?;
test.submit_turn("second turn").await?;
let handshakes = server.handshakes();
assert_eq!(handshakes.len(), 1);
let requests = server.single_connection();
assert_eq!(requests.len(), 2);
let second_request = requests[1].body_json();
assert_eq!(
second_request["previous_response_id"].as_str(),
Some("resp-1")
);
server.shutdown().await;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_v2_reconnect_after_turn_boundary_does_not_replay_turn_state() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server_with_headers(vec![
WebSocketConnectionConfig {
requests: vec![vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "done"),
ev_completed("resp-1"),
]],
response_headers: vec![(TURN_STATE_HEADER.to_string(), "ts-1".to_string())],
accept_delay: None,
},
WebSocketConnectionConfig {
requests: vec![vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-2", "done"),
ev_completed("resp-2"),
]],
response_headers: Vec::new(),
accept_delay: None,
},
])
.await;
let builder = test_codex();
let test = builder
.with_config(|config| {
config.features.enable(Feature::ResponsesWebsocketsV2);
})
.build_with_websocket_server(&server)
.await?;
test.submit_turn("first turn").await?;
test.submit_turn("second turn").await?;
let handshakes = server.handshakes();
assert_eq!(handshakes.len(), 2);
assert_eq!(handshakes[0].header(TURN_STATE_HEADER), None);
assert_eq!(handshakes[1].header(TURN_STATE_HEADER), None);
let connections = server.connections();
assert_eq!(connections.len(), 2);
let second_connection = connections
.get(1)
.expect("second websocket connection should exist");
let second_request = second_connection
.first()
.expect("second websocket connection should have a request")
.body_json();
assert_eq!(second_request.get("previous_response_id"), None);
let second_input_len = second_request["input"]
.as_array()
.map(Vec::len)
.unwrap_or(0);
assert!(
second_input_len > 1,
"reconnect should send full input items, got {second_input_len}"
);
server.shutdown().await;
Ok(())
}