Turn-state sticky routing per turn (#9332)

- capture the header from SSE/WS handshakes, store it per
ModelClientSession using `OnceLock`, echo it on turn-scoped requests,
and add SSE+WS integration tests for within-turn persistence +
cross-turn reset.

- keep `x-codex-turn-state` sticky within a user turn to maintain
routing continuity for retries/tool follow-ups.
This commit is contained in:
Ahmed Ibrahim
2026-01-16 09:30:11 -08:00
committed by GitHub
parent 4125c825f9
commit ebdd8795e9
11 changed files with 343 additions and 24 deletions

View File

@@ -12,6 +12,8 @@ use serde_json::Value;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::handshake::server::Request;
use tokio_tungstenite::tungstenite::handshake::server::Response;
use wiremock::BodyPrintLimit;
use wiremock::Match;
use wiremock::Mock;
@@ -19,6 +21,8 @@ use wiremock::MockBuilder;
use wiremock::MockServer;
use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::http::HeaderName;
use wiremock::http::HeaderValue;
use wiremock::matchers::method;
use wiremock::matchers::path_regex;
@@ -216,9 +220,30 @@ impl WebSocketRequest {
}
}
/// Request headers captured from a single WebSocket handshake.
#[derive(Debug, Clone)]
pub struct WebSocketHandshake {
    /// `(name, value)` pairs in the order they were recorded.
    headers: Vec<(String, String)>,
}

impl WebSocketHandshake {
    /// Returns an owned copy of the value of the first recorded header whose
    /// name matches `name`, comparing names ASCII case-insensitively.
    ///
    /// Returns `None` when no recorded header has that name.
    pub fn header(&self, name: &str) -> Option<String> {
        self.headers.iter().find_map(|(key, value)| {
            if key.eq_ignore_ascii_case(name) {
                Some(value.clone())
            } else {
                None
            }
        })
    }
}
/// Scripted behavior for one connection accepted by the WebSocket test server.
#[derive(Debug, Clone)]
pub struct WebSocketConnectionConfig {
/// Event batches for this connection: one inner `Vec` of events per
/// request message the server reads from the client.
pub requests: Vec<Vec<Value>>,
/// `(name, value)` pairs inserted into the handshake response headers.
/// Pairs that do not parse as a valid header name/value are skipped.
pub response_headers: Vec<(String, String)>,
}
/// Handle to a running WebSocket test server and the logs it records.
pub struct WebSocketTestServer {
/// Address of the bound listener, formatted as `ws://{addr}`.
uri: String,
/// Recorded requests, grouped per accepted connection.
connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
/// Request headers captured from each handshake, in arrival order.
handshakes: Arc<Mutex<Vec<WebSocketHandshake>>>,
/// Sender used by `shutdown` to signal the server task.
shutdown: oneshot::Sender<()>,
/// Handle of the spawned server task; awaited by `shutdown`.
task: tokio::task::JoinHandle<()>,
}
@@ -240,6 +265,18 @@ impl WebSocketTestServer {
connections.first().cloned().unwrap_or_default()
}
/// Returns a snapshot (clone) of every handshake recorded so far.
///
/// Panics if the handshake log mutex is poisoned.
pub fn handshakes(&self) -> Vec<WebSocketHandshake> {
    let recorded = self.handshakes.lock().unwrap();
    recorded.iter().cloned().collect()
}
/// Returns a clone of the only recorded handshake.
///
/// Panics if the number of recorded handshakes is not exactly one, or if
/// the handshake log mutex is poisoned.
pub fn single_handshake(&self) -> WebSocketHandshake {
    let recorded = self.handshakes.lock().unwrap();
    match recorded.as_slice() {
        [only] => only.clone(),
        other => panic!("expected 1 handshake, got {}", other.len()),
    }
}
pub async fn shutdown(self) {
let _ = self.shutdown.send(());
let _ = self.task.await;
@@ -786,13 +823,28 @@ pub async fn start_mock_server() -> MockServer {
/// request message, the server records the payload and streams the matching
/// events as WebSocket text frames before moving to the next request.
///
/// This variant adds no extra handshake response headers: every scripted
/// connection is wrapped in a `WebSocketConnectionConfig` with an empty
/// `response_headers` list and handed to `start_websocket_server_with_headers`.
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
    let mut configs = Vec::with_capacity(connections.len());
    for requests in connections {
        configs.push(WebSocketConnectionConfig {
            requests,
            response_headers: Vec::new(),
        });
    }
    start_websocket_server_with_headers(configs).await
}
pub async fn start_websocket_server_with_headers(
connections: Vec<WebSocketConnectionConfig>,
) -> WebSocketTestServer {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("bind websocket server");
let addr = listener.local_addr().expect("websocket server address");
let uri = format!("ws://{addr}");
let connections_log = Arc::new(Mutex::new(Vec::new()));
let handshakes_log = Arc::new(Mutex::new(Vec::new()));
let requests = Arc::clone(&connections_log);
let handshakes = Arc::clone(&handshakes_log);
let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
@@ -806,27 +858,57 @@ pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSoc
Ok(value) => value,
Err(_) => return,
};
let mut ws_stream = match tokio_tungstenite::accept_async(stream).await {
Ok(ws) => ws,
Err(_) => continue,
};
let connection_requests = {
let connection = {
let mut pending = connections.lock().unwrap();
pending.pop_front()
};
let Some(connection_requests) = connection_requests else {
let _ = ws_stream.close(None).await;
let Some(connection) = connection else {
continue;
};
let response_headers = connection.response_headers.clone();
let handshake_log = Arc::clone(&handshakes);
let callback = move |req: &Request, mut response: Response| {
let headers = req
.headers()
.iter()
.filter_map(|(name, value)| {
value
.to_str()
.ok()
.map(|value| (name.as_str().to_string(), value.to_string()))
})
.collect();
handshake_log
.lock()
.unwrap()
.push(WebSocketHandshake { headers });
let headers_mut = response.headers_mut();
for (name, value) in &response_headers {
if let (Ok(name), Ok(value)) = (
HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_str(value),
) {
headers_mut.insert(name, value);
}
}
Ok(response)
};
let mut ws_stream = match tokio_tungstenite::accept_hdr_async(stream, callback).await {
Ok(ws) => ws,
Err(_) => continue,
};
let connection_index = {
let mut log = requests.lock().unwrap();
log.push(Vec::new());
log.len() - 1
};
for request_events in connection_requests {
for request_events in connection.requests {
let Some(Ok(message)) = ws_stream.next().await else {
break;
};
@@ -858,6 +940,7 @@ pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSoc
WebSocketTestServer {
uri,
connections: connections_log,
handshakes: handshakes_log,
shutdown: shutdown_tx,
task,
}
@@ -942,6 +1025,45 @@ pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) -> Res
response_mock
}
/// Mounts a mock on `server` that answers successive matching requests with
/// `responses`, in order.
///
/// The mock is limited to, and expects, exactly `responses.len()` calls.
/// The responder panics if it is invoked more times than there are templates.
pub async fn mount_response_sequence(
    server: &MockServer,
    responses: Vec<ResponseTemplate>,
) -> ResponseMock {
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Hands out the stored templates one at a time, in order.
    struct SequencedResponder {
        next_index: AtomicUsize,
        templates: Vec<ResponseTemplate>,
    }

    impl Respond for SequencedResponder {
        fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
            // Claim the next slot atomically so concurrent requests never share a template.
            let index = self.next_index.fetch_add(1, Ordering::SeqCst);
            match self.templates.get(index) {
                Some(template) => template.clone(),
                None => panic!("no response for {index}"),
            }
        }
    }

    let expected_calls = responses.len() as u64;
    let (mock, response_mock) = base_mock();
    let sequenced = SequencedResponder {
        next_index: AtomicUsize::new(0),
        templates: responses,
    };

    mock.respond_with(sequenced)
        .up_to_n_times(expected_calls)
        .expect(expected_calls)
        .mount(server)
        .await;

    response_mock
}
/// Validate invariants on the request body sent to `/v1/responses`.
///
/// - No `function_call_output`/`custom_tool_call_output` with missing/empty `call_id`.