Add model client sessions (#9102)

Maintain a long-running session.
This commit is contained in:
pakrym-oai
2026-01-12 17:15:56 -08:00
committed by GitHub
parent 87f7226cca
commit 490c1c1fdd
22 changed files with 874 additions and 196 deletions

View File

@@ -1,3 +1,4 @@
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
@@ -5,7 +6,12 @@ use std::time::Duration;
use anyhow::Result;
use base64::Engine;
use codex_protocol::openai_models::ModelsResponse;
use futures::SinkExt;
use futures::StreamExt;
use serde_json::Value;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message;
use wiremock::BodyPrintLimit;
use wiremock::Match;
use wiremock::Mock;
@@ -199,6 +205,47 @@ impl ResponsesRequest {
}
}
#[derive(Debug, Clone)]
pub struct WebSocketRequest {
body: Value,
}
impl WebSocketRequest {
pub fn body_json(&self) -> Value {
self.body.clone()
}
}
pub struct WebSocketTestServer {
uri: String,
connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
shutdown: oneshot::Sender<()>,
task: tokio::task::JoinHandle<()>,
}
impl WebSocketTestServer {
pub fn uri(&self) -> &str {
&self.uri
}
pub fn connections(&self) -> Vec<Vec<WebSocketRequest>> {
self.connections.lock().unwrap().clone()
}
pub fn single_connection(&self) -> Vec<WebSocketRequest> {
let connections = self.connections.lock().unwrap();
if connections.len() != 1 {
panic!("expected 1 connection, got {}", connections.len());
}
connections.first().cloned().unwrap_or_default()
}
pub async fn shutdown(self) {
let _ = self.shutdown.send(());
let _ = self.task.await;
}
}
#[derive(Debug, Clone)]
pub struct ModelsMock {
requests: Arc<Mutex<Vec<wiremock::Request>>>,
@@ -724,6 +771,91 @@ pub async fn start_mock_server() -> MockServer {
server
}
/// Starts a lightweight WebSocket server for `/v1/responses` tests.
///
/// Each connection consumes a queue of request/event sequences. For each
/// request message, the server records the payload and streams the matching
/// events as WebSocket text frames before moving to the next request.
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("bind websocket server");
let addr = listener.local_addr().expect("websocket server address");
let uri = format!("ws://{addr}");
let connections_log = Arc::new(Mutex::new(Vec::new()));
let requests = Arc::clone(&connections_log);
let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
let task = tokio::spawn(async move {
loop {
let accept_res = tokio::select! {
_ = &mut shutdown_rx => return,
accept_res = listener.accept() => accept_res,
};
let (stream, _) = match accept_res {
Ok(value) => value,
Err(_) => return,
};
let mut ws_stream = match tokio_tungstenite::accept_async(stream).await {
Ok(ws) => ws,
Err(_) => continue,
};
let connection_requests = {
let mut pending = connections.lock().unwrap();
pending.pop_front()
};
let Some(connection_requests) = connection_requests else {
let _ = ws_stream.close(None).await;
continue;
};
let mut connection_log = Vec::new();
for request_events in connection_requests {
let Some(Ok(message)) = ws_stream.next().await else {
break;
};
if let Some(body) = parse_ws_request_body(message) {
connection_log.push(WebSocketRequest { body });
}
for event in &request_events {
let Ok(payload) = serde_json::to_string(event) else {
continue;
};
if ws_stream.send(Message::Text(payload)).await.is_err() {
break;
}
}
}
requests.lock().unwrap().push(connection_log);
let _ = ws_stream.close(None).await;
if connections.lock().unwrap().is_empty() {
return;
}
}
});
WebSocketTestServer {
uri,
connections: connections_log,
shutdown: shutdown_tx,
task,
}
}
fn parse_ws_request_body(message: Message) -> Option<Value> {
match message {
Message::Text(text) => serde_json::from_str(&text).ok(),
Message::Binary(bytes) => serde_json::from_slice(&bytes).ok(),
_ => None,
}
}
#[derive(Clone)]
pub struct FunctionCallResponseMocks {
pub function_call: ResponseMock,