mirror of
https://github.com/openai/codex.git
synced 2026-04-26 15:45:02 +00:00
@@ -1,3 +1,4 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
@@ -5,7 +6,12 @@ use std::time::Duration;
|
||||
use anyhow::Result;
|
||||
use base64::Engine;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use serde_json::Value;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use wiremock::BodyPrintLimit;
|
||||
use wiremock::Match;
|
||||
use wiremock::Mock;
|
||||
@@ -199,6 +205,47 @@ impl ResponsesRequest {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WebSocketRequest {
|
||||
body: Value,
|
||||
}
|
||||
|
||||
impl WebSocketRequest {
|
||||
pub fn body_json(&self) -> Value {
|
||||
self.body.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WebSocketTestServer {
|
||||
uri: String,
|
||||
connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
shutdown: oneshot::Sender<()>,
|
||||
task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl WebSocketTestServer {
|
||||
pub fn uri(&self) -> &str {
|
||||
&self.uri
|
||||
}
|
||||
|
||||
pub fn connections(&self) -> Vec<Vec<WebSocketRequest>> {
|
||||
self.connections.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn single_connection(&self) -> Vec<WebSocketRequest> {
|
||||
let connections = self.connections.lock().unwrap();
|
||||
if connections.len() != 1 {
|
||||
panic!("expected 1 connection, got {}", connections.len());
|
||||
}
|
||||
connections.first().cloned().unwrap_or_default()
|
||||
}
|
||||
|
||||
pub async fn shutdown(self) {
|
||||
let _ = self.shutdown.send(());
|
||||
let _ = self.task.await;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelsMock {
|
||||
requests: Arc<Mutex<Vec<wiremock::Request>>>,
|
||||
@@ -724,6 +771,91 @@ pub async fn start_mock_server() -> MockServer {
|
||||
server
|
||||
}
|
||||
|
||||
/// Starts a lightweight WebSocket server for `/v1/responses` tests.
|
||||
///
|
||||
/// Each connection consumes a queue of request/event sequences. For each
|
||||
/// request message, the server records the payload and streams the matching
|
||||
/// events as WebSocket text frames before moving to the next request.
|
||||
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind websocket server");
|
||||
let addr = listener.local_addr().expect("websocket server address");
|
||||
let uri = format!("ws://{addr}");
|
||||
let connections_log = Arc::new(Mutex::new(Vec::new()));
|
||||
let requests = Arc::clone(&connections_log);
|
||||
let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
loop {
|
||||
let accept_res = tokio::select! {
|
||||
_ = &mut shutdown_rx => return,
|
||||
accept_res = listener.accept() => accept_res,
|
||||
};
|
||||
let (stream, _) = match accept_res {
|
||||
Ok(value) => value,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut ws_stream = match tokio_tungstenite::accept_async(stream).await {
|
||||
Ok(ws) => ws,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let connection_requests = {
|
||||
let mut pending = connections.lock().unwrap();
|
||||
pending.pop_front()
|
||||
};
|
||||
|
||||
let Some(connection_requests) = connection_requests else {
|
||||
let _ = ws_stream.close(None).await;
|
||||
continue;
|
||||
};
|
||||
|
||||
let mut connection_log = Vec::new();
|
||||
for request_events in connection_requests {
|
||||
let Some(Ok(message)) = ws_stream.next().await else {
|
||||
break;
|
||||
};
|
||||
if let Some(body) = parse_ws_request_body(message) {
|
||||
connection_log.push(WebSocketRequest { body });
|
||||
}
|
||||
|
||||
for event in &request_events {
|
||||
let Ok(payload) = serde_json::to_string(event) else {
|
||||
continue;
|
||||
};
|
||||
if ws_stream.send(Message::Text(payload)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
requests.lock().unwrap().push(connection_log);
|
||||
let _ = ws_stream.close(None).await;
|
||||
|
||||
if connections.lock().unwrap().is_empty() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
WebSocketTestServer {
|
||||
uri,
|
||||
connections: connections_log,
|
||||
shutdown: shutdown_tx,
|
||||
task,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ws_request_body(message: Message) -> Option<Value> {
|
||||
match message {
|
||||
Message::Text(text) => serde_json::from_str(&text).ok(),
|
||||
Message::Binary(bytes) => serde_json::from_slice(&bytes).ok(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FunctionCallResponseMocks {
|
||||
pub function_call: ResponseMock,
|
||||
|
||||
Reference in New Issue
Block a user