Add model client sessions (#9102)

Maintain a long-running session.
This commit is contained in:
pakrym-oai
2026-01-12 17:15:56 -08:00
committed by GitHub
parent 87f7226cca
commit 490c1c1fdd
22 changed files with 874 additions and 196 deletions

View File

@@ -98,7 +98,8 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
summary,
conversation_id,
SessionSource::Exec,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = input;

View File

@@ -99,7 +99,8 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
summary,
conversation_id,
SessionSource::Exec,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {

View File

@@ -15,11 +15,13 @@ codex-core = { workspace = true, features = ["test-support"] }
codex-protocol = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-cargo-bin = { workspace = true }
futures = { workspace = true }
notify = { workspace = true }
regex-lite = { workspace = true }
serde_json = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true, features = ["time"] }
tokio = { workspace = true, features = ["net", "time"] }
tokio-tungstenite = { workspace = true }
walkdir = { workspace = true }
wiremock = { workspace = true }
shlex = { workspace = true }

View File

@@ -1,3 +1,4 @@
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
@@ -5,7 +6,12 @@ use std::time::Duration;
use anyhow::Result;
use base64::Engine;
use codex_protocol::openai_models::ModelsResponse;
use futures::SinkExt;
use futures::StreamExt;
use serde_json::Value;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message;
use wiremock::BodyPrintLimit;
use wiremock::Match;
use wiremock::Mock;
@@ -199,6 +205,47 @@ impl ResponsesRequest {
}
}
#[derive(Debug, Clone)]
pub struct WebSocketRequest {
body: Value,
}
impl WebSocketRequest {
pub fn body_json(&self) -> Value {
self.body.clone()
}
}
/// Handle to a running in-process WebSocket test server.
///
/// Dropping the handle does not stop the server task; call
/// [`WebSocketTestServer::shutdown`] to terminate it deterministically.
pub struct WebSocketTestServer {
    uri: String,
    connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
    shutdown: oneshot::Sender<()>,
    task: tokio::task::JoinHandle<()>,
}

impl WebSocketTestServer {
    /// The `ws://` URI clients should connect to.
    pub fn uri(&self) -> &str {
        self.uri.as_str()
    }

    /// Snapshot of every completed connection's recorded requests.
    pub fn connections(&self) -> Vec<Vec<WebSocketRequest>> {
        let log = self.connections.lock().unwrap();
        log.clone()
    }

    /// Returns the requests of the only connection seen so far.
    ///
    /// Panics when the server recorded zero or multiple connections, since
    /// the calling test expects exactly one.
    pub fn single_connection(&self) -> Vec<WebSocketRequest> {
        let log = self.connections.lock().unwrap();
        match log.as_slice() {
            [only] => only.clone(),
            other => panic!("expected 1 connection, got {}", other.len()),
        }
    }

    /// Signals the accept loop to stop and waits for the server task to end.
    pub async fn shutdown(self) {
        let Self { shutdown, task, .. } = self;
        // The task may already have exited on its own; ignore both failures.
        let _ = shutdown.send(());
        let _ = task.await;
    }
}
#[derive(Debug, Clone)]
pub struct ModelsMock {
requests: Arc<Mutex<Vec<wiremock::Request>>>,
@@ -724,6 +771,91 @@ pub async fn start_mock_server() -> MockServer {
server
}
/// Starts a lightweight WebSocket server for `/v1/responses` tests.
///
/// Each connection consumes a queue of request/event sequences. For each
/// request message, the server records the payload and streams the matching
/// events as WebSocket text frames before moving to the next request.
///
/// The server exits on shutdown, on an accept error, or once every scripted
/// connection in `connections` has been served.
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
    // Ephemeral port (port 0): the OS picks a free one so parallel tests
    // never collide.
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind websocket server");
    let addr = listener.local_addr().expect("websocket server address");
    let uri = format!("ws://{addr}");
    // Per-connection log of recorded request payloads, shared with the handle.
    let connections_log = Arc::new(Mutex::new(Vec::new()));
    let requests = Arc::clone(&connections_log);
    // Scripted event sequences: one queue entry per expected connection.
    let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
    let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
    let task = tokio::spawn(async move {
        loop {
            // Stop accepting as soon as the handle signals shutdown.
            let accept_res = tokio::select! {
                _ = &mut shutdown_rx => return,
                accept_res = listener.accept() => accept_res,
            };
            let (stream, _) = match accept_res {
                Ok(value) => value,
                // Accept errors are treated as fatal for the test server.
                Err(_) => return,
            };
            // A failed handshake only skips this client; keep serving others.
            let mut ws_stream = match tokio_tungstenite::accept_async(stream).await {
                Ok(ws) => ws,
                Err(_) => continue,
            };
            // Pop this connection's script in a scope so the std mutex guard
            // is released before any `.await` below.
            let connection_requests = {
                let mut pending = connections.lock().unwrap();
                pending.pop_front()
            };
            // No script left for this connection: close it immediately.
            let Some(connection_requests) = connection_requests else {
                let _ = ws_stream.close(None).await;
                continue;
            };
            let mut connection_log = Vec::new();
            for request_events in connection_requests {
                // Wait for the client's next request frame; a closed or
                // errored stream ends this connection early.
                let Some(Ok(message)) = ws_stream.next().await else {
                    break;
                };
                // Frames that don't parse as JSON (e.g. control frames) are
                // not recorded, but their scripted events are still sent.
                if let Some(body) = parse_ws_request_body(message) {
                    connection_log.push(WebSocketRequest { body });
                }
                // Stream the scripted events back as text frames.
                for event in &request_events {
                    let Ok(payload) = serde_json::to_string(event) else {
                        continue;
                    };
                    // A send failure ends this request's event stream; the
                    // next receive will then observe the broken connection.
                    if ws_stream.send(Message::Text(payload)).await.is_err() {
                        break;
                    }
                }
            }
            requests.lock().unwrap().push(connection_log);
            let _ = ws_stream.close(None).await;
            // All scripted connections have been served: exit the loop.
            if connections.lock().unwrap().is_empty() {
                return;
            }
        }
    });
    WebSocketTestServer {
        uri,
        connections: connections_log,
        shutdown: shutdown_tx,
        task,
    }
}
/// Extracts a JSON value from a request frame, if it carries one.
///
/// Text and binary frames are parsed as JSON; any other frame kind
/// (ping, pong, close, ...) yields `None`, as does unparseable JSON.
fn parse_ws_request_body(message: Message) -> Option<Value> {
    let parsed = match message {
        Message::Text(text) => serde_json::from_str(&text),
        Message::Binary(bytes) => serde_json::from_slice(&bytes),
        _ => return None,
    };
    parsed.ok()
}
#[derive(Clone)]
pub struct FunctionCallResponseMocks {
pub function_call: ResponseMock,

View File

@@ -91,7 +91,8 @@ async fn responses_stream_includes_subagent_header_on_review() {
summary,
conversation_id,
session_source,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
@@ -186,7 +187,8 @@ async fn responses_stream_includes_subagent_header_on_other() {
summary,
conversation_id,
session_source,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
@@ -279,7 +281,8 @@ async fn responses_respects_model_info_overrides_from_config() {
summary,
conversation_id,
session_source,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {

View File

@@ -1181,7 +1181,8 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
summary,
conversation_id,
SessionSource::Exec,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input.push(ResponseItem::Reasoning {

View File

@@ -71,3 +71,4 @@ mod user_notification;
mod user_shell_cmd;
mod view_image;
mod web_search_cached;
mod websocket;

View File

@@ -67,7 +67,7 @@ async fn retries_on_early_close() {
name: "openai".into(),
base_url: Some(format!("{}/v1", server.uri())),
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// ModelClientSession will return an error if the environment variable for the
// provider is not set.
env_key: Some("PATH".into()),
env_key_instructions: None,

View File

@@ -0,0 +1,112 @@
use codex_core::AuthManager;
use codex_core::CodexAuth;
use codex_core::ContentItem;
use codex_core::ModelClient;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
use codex_core::protocol::SessionSource;
use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::start_websocket_server;
use futures::StreamExt;
use std::sync::Arc;
use tempfile::TempDir;
// End-to-end check that a websocket-wire-API model client session sends a
// well-formed request and drives the stream to completion.
//
// Multi-thread flavor so the in-process websocket server task and the client
// can run concurrently.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_streams_request() {
    // One scripted connection answering a single request with a
    // created + completed event pair.
    let server = start_websocket_server(vec![vec![vec![
        ev_response_created("resp-1"),
        ev_completed("resp-1"),
    ]]])
    .await;
    let provider = ModelProviderInfo {
        name: "mock-ws".into(),
        base_url: Some(format!("{}/v1", server.uri())),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        // Route requests over the websocket responses wire API under test.
        wire_api: WireApi::ResponsesWebsocket,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        // No retries: any failure should surface immediately in the test.
        request_max_retries: Some(0),
        stream_max_retries: Some(0),
        // Bounds the test if the server never sends the completed event.
        stream_idle_timeout_ms: Some(5_000),
        requires_openai_auth: false,
    };
    let codex_home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&codex_home).await;
    config.model_provider_id = provider.name.clone();
    config.model_provider = provider.clone();
    let effort = config.model_reasoning_effort;
    let summary = config.model_reasoning_summary;
    // Offline lookups avoid any network dependency when resolving the model.
    let model = ModelsManager::get_model_offline(config.model.as_deref());
    config.model = Some(model.clone());
    let config = Arc::new(config);
    let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);
    let conversation_id = ThreadId::new();
    let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
    let otel_manager = OtelManager::new(
        conversation_id,
        model.as_str(),
        model_info.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        auth_manager.get_auth_mode(),
        false,
        "test".to_string(),
        SessionSource::Exec,
    );
    // Build the client and open a long-running session for streaming.
    let client = ModelClient::new(
        Arc::clone(&config),
        None,
        model_info,
        otel_manager,
        provider,
        effort,
        summary,
        conversation_id,
        SessionSource::Exec,
    )
    .new_session();
    let mut prompt = Prompt::default();
    prompt.input = vec![ResponseItem::Message {
        id: None,
        role: "user".into(),
        content: vec![ContentItem::InputText {
            text: "hello".into(),
        }],
    }];
    let mut stream = client
        .stream(&prompt)
        .await
        .expect("websocket stream failed");
    // Drain events until the scripted Completed event arrives; the provider's
    // idle timeout keeps this loop from hanging forever on a broken stream.
    while let Some(event) = stream.next().await {
        if matches!(event, Ok(ResponseEvent::Completed { .. })) {
            break;
        }
    }
    // Exactly one connection with exactly one recorded request is expected.
    let connection = server.single_connection();
    assert_eq!(connection.len(), 1);
    let request = connection.first().cloned().unwrap();
    let body = request.body_json();
    // Spot-check the wire payload: model name, streaming flag, and the single
    // input item round-tripped to the server.
    assert_eq!(body["model"].as_str(), Some(model.as_str()));
    assert_eq!(body["stream"], serde_json::Value::Bool(true));
    assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
    server.shutdown().await;
}