mirror of
https://github.com/openai/codex.git
synced 2026-04-26 23:55:25 +00:00
Turn-state sticky routing per turn (#9332)
- capture the header from SSE/WS handshakes, store it per ModelClientSession using `Oncelock`, echo it on turn-scoped requests, and add SSE+WS integration tests for within-turn persistence + cross-turn reset. - keep `x-codex-turn-state` sticky within a user turn to maintain routing continuity for retries/tool follow-ups.
This commit is contained in:
@@ -12,6 +12,8 @@ use serde_json::Value;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::tungstenite::handshake::server::Request;
|
||||
use tokio_tungstenite::tungstenite::handshake::server::Response;
|
||||
use wiremock::BodyPrintLimit;
|
||||
use wiremock::Match;
|
||||
use wiremock::Mock;
|
||||
@@ -19,6 +21,8 @@ use wiremock::MockBuilder;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::Respond;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::http::HeaderName;
|
||||
use wiremock::http::HeaderValue;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path_regex;
|
||||
|
||||
@@ -216,9 +220,30 @@ impl WebSocketRequest {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WebSocketHandshake {
|
||||
headers: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl WebSocketHandshake {
|
||||
pub fn header(&self, name: &str) -> Option<String> {
|
||||
self.headers
|
||||
.iter()
|
||||
.find(|(header, _)| header.eq_ignore_ascii_case(name))
|
||||
.map(|(_, value)| value.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WebSocketConnectionConfig {
|
||||
pub requests: Vec<Vec<Value>>,
|
||||
pub response_headers: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
pub struct WebSocketTestServer {
|
||||
uri: String,
|
||||
connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
handshakes: Arc<Mutex<Vec<WebSocketHandshake>>>,
|
||||
shutdown: oneshot::Sender<()>,
|
||||
task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
@@ -240,6 +265,18 @@ impl WebSocketTestServer {
|
||||
connections.first().cloned().unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn handshakes(&self) -> Vec<WebSocketHandshake> {
|
||||
self.handshakes.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn single_handshake(&self) -> WebSocketHandshake {
|
||||
let handshakes = self.handshakes.lock().unwrap();
|
||||
if handshakes.len() != 1 {
|
||||
panic!("expected 1 handshake, got {}", handshakes.len());
|
||||
}
|
||||
handshakes.first().cloned().unwrap()
|
||||
}
|
||||
|
||||
pub async fn shutdown(self) {
|
||||
let _ = self.shutdown.send(());
|
||||
let _ = self.task.await;
|
||||
@@ -786,13 +823,28 @@ pub async fn start_mock_server() -> MockServer {
|
||||
/// request message, the server records the payload and streams the matching
|
||||
/// events as WebSocket text frames before moving to the next request.
|
||||
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
|
||||
let connections = connections
|
||||
.into_iter()
|
||||
.map(|requests| WebSocketConnectionConfig {
|
||||
requests,
|
||||
response_headers: Vec::new(),
|
||||
})
|
||||
.collect();
|
||||
start_websocket_server_with_headers(connections).await
|
||||
}
|
||||
|
||||
pub async fn start_websocket_server_with_headers(
|
||||
connections: Vec<WebSocketConnectionConfig>,
|
||||
) -> WebSocketTestServer {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind websocket server");
|
||||
let addr = listener.local_addr().expect("websocket server address");
|
||||
let uri = format!("ws://{addr}");
|
||||
let connections_log = Arc::new(Mutex::new(Vec::new()));
|
||||
let handshakes_log = Arc::new(Mutex::new(Vec::new()));
|
||||
let requests = Arc::clone(&connections_log);
|
||||
let handshakes = Arc::clone(&handshakes_log);
|
||||
let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
||||
|
||||
@@ -806,27 +858,57 @@ pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSoc
|
||||
Ok(value) => value,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut ws_stream = match tokio_tungstenite::accept_async(stream).await {
|
||||
Ok(ws) => ws,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let connection_requests = {
|
||||
let connection = {
|
||||
let mut pending = connections.lock().unwrap();
|
||||
pending.pop_front()
|
||||
};
|
||||
|
||||
let Some(connection_requests) = connection_requests else {
|
||||
let _ = ws_stream.close(None).await;
|
||||
let Some(connection) = connection else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let response_headers = connection.response_headers.clone();
|
||||
let handshake_log = Arc::clone(&handshakes);
|
||||
let callback = move |req: &Request, mut response: Response| {
|
||||
let headers = req
|
||||
.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
value
|
||||
.to_str()
|
||||
.ok()
|
||||
.map(|value| (name.as_str().to_string(), value.to_string()))
|
||||
})
|
||||
.collect();
|
||||
handshake_log
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push(WebSocketHandshake { headers });
|
||||
|
||||
let headers_mut = response.headers_mut();
|
||||
for (name, value) in &response_headers {
|
||||
if let (Ok(name), Ok(value)) = (
|
||||
HeaderName::from_bytes(name.as_bytes()),
|
||||
HeaderValue::from_str(value),
|
||||
) {
|
||||
headers_mut.insert(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
};
|
||||
|
||||
let mut ws_stream = match tokio_tungstenite::accept_hdr_async(stream, callback).await {
|
||||
Ok(ws) => ws,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let connection_index = {
|
||||
let mut log = requests.lock().unwrap();
|
||||
log.push(Vec::new());
|
||||
log.len() - 1
|
||||
};
|
||||
for request_events in connection_requests {
|
||||
for request_events in connection.requests {
|
||||
let Some(Ok(message)) = ws_stream.next().await else {
|
||||
break;
|
||||
};
|
||||
@@ -858,6 +940,7 @@ pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSoc
|
||||
WebSocketTestServer {
|
||||
uri,
|
||||
connections: connections_log,
|
||||
handshakes: handshakes_log,
|
||||
shutdown: shutdown_tx,
|
||||
task,
|
||||
}
|
||||
@@ -942,6 +1025,45 @@ pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) -> Res
|
||||
response_mock
|
||||
}
|
||||
|
||||
/// Mounts a sequence of responses for each POST to `/v1/responses`.
|
||||
/// Panics if more requests are received than responses provided.
|
||||
pub async fn mount_response_sequence(
|
||||
server: &MockServer,
|
||||
responses: Vec<ResponseTemplate>,
|
||||
) -> ResponseMock {
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
struct SeqResponder {
|
||||
num_calls: AtomicUsize,
|
||||
responses: Vec<ResponseTemplate>,
|
||||
}
|
||||
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
|
||||
let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
|
||||
self.responses
|
||||
.get(call_num)
|
||||
.unwrap_or_else(|| panic!("no response for {call_num}"))
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
let num_calls = responses.len();
|
||||
let responder = SeqResponder {
|
||||
num_calls: AtomicUsize::new(0),
|
||||
responses,
|
||||
};
|
||||
|
||||
let (mock, response_mock) = base_mock();
|
||||
mock.respond_with(responder)
|
||||
.up_to_n_times(num_calls as u64)
|
||||
.expect(num_calls as u64)
|
||||
.mount(server)
|
||||
.await;
|
||||
response_mock
|
||||
}
|
||||
|
||||
/// Validate invariants on the request body sent to `/v1/responses`.
|
||||
///
|
||||
/// - No `function_call_output`/`custom_tool_call_output` with missing/empty `call_id`.
|
||||
|
||||
Reference in New Issue
Block a user