Turn-state sticky routing per turn (#9332)

- capture the header from SSE/WS handshakes, store it per
ModelClientSession using `Oncelock`, echo it on turn-scoped requests,
and add SSE+WS integration tests for within-turn persistence +
cross-turn reset.

- keep `x-codex-turn-state` sticky within a user turn to maintain
routing continuity for retries/tool follow-ups.
This commit is contained in:
Ahmed Ibrahim
2026-01-16 09:30:11 -08:00
committed by GitHub
parent 4125c825f9
commit ebdd8795e9
11 changed files with 343 additions and 24 deletions

View File

@@ -13,6 +13,7 @@ use http::HeaderMap;
use http::HeaderValue;
use serde_json::Value;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
@@ -27,6 +28,7 @@ use tracing::trace;
use url::Url;
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
pub struct ResponsesWebsocketConnection {
stream: Arc<Mutex<Option<WsStream>>>,
@@ -100,6 +102,7 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
pub async fn connect(
&self,
extra_headers: HeaderMap,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<ResponsesWebsocketConnection, ApiError> {
let ws_url = Url::parse(&self.provider.url_for_path("responses"))
.map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
@@ -108,7 +111,7 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
headers.extend(extra_headers);
apply_auth_headers(&mut headers, &self.auth);
let stream = connect_websocket(ws_url, headers).await?;
let stream = connect_websocket(ws_url, headers, turn_state).await?;
Ok(ResponsesWebsocketConnection::new(
stream,
self.provider.stream_idle_timeout,
@@ -130,16 +133,28 @@ fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) {
}
}
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WsStream, ApiError> {
async fn connect_websocket(
url: Url,
headers: HeaderMap,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<WsStream, ApiError> {
let mut request = url
.clone()
.into_client_request()
.map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?;
request.headers_mut().extend(headers);
let (stream, _) = tokio_tungstenite::connect_async(request)
let (stream, response) = tokio_tungstenite::connect_async(request)
.await
.map_err(|err| map_ws_error(err, &url))?;
if let Some(turn_state) = turn_state
&& let Some(header_value) = response
.headers()
.get(X_CODEX_TURN_STATE_HEADER)
.and_then(|value| value.to_str().ok())
{
let _ = turn_state.set(header_value.to_string());
}
Ok(stream)
}