Add model client sessions (#9102)

Maintain a long-running session.
This commit is contained in:
pakrym-oai
2026-01-12 17:15:56 -08:00
committed by GitHub
parent 87f7226cca
commit 490c1c1fdd
22 changed files with 874 additions and 196 deletions

View File

@@ -1181,7 +1181,8 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
summary,
conversation_id,
SessionSource::Exec,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input.push(ResponseItem::Reasoning {

View File

@@ -71,3 +71,4 @@ mod user_notification;
mod user_shell_cmd;
mod view_image;
mod web_search_cached;
mod websocket;

View File

@@ -67,7 +67,7 @@ async fn retries_on_early_close() {
name: "openai".into(),
base_url: Some(format!("{}/v1", server.uri())),
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// ModelClientSession will return an error if the environment variable for the
// provider is not set.
env_key: Some("PATH".into()),
env_key_instructions: None,

View File

@@ -0,0 +1,112 @@
use codex_core::AuthManager;
use codex_core::CodexAuth;
use codex_core::ContentItem;
use codex_core::ModelClient;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
use codex_core::protocol::SessionSource;
use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::start_websocket_server;
use futures::StreamExt;
use std::sync::Arc;
use tempfile::TempDir;
/// Streams a single prompt through a `ModelClient` session configured with
/// `WireApi::ResponsesWebsocket` and checks the request recorded by the test
/// websocket server.
///
/// The server is started with one scripted batch of events (a
/// response-created event followed by a completed event, both for
/// `"resp-1"`). After the stream has been drained, the test asserts that the
/// server saw exactly one request whose JSON body carries the configured
/// model, `stream: true`, and a single input item.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_streams_request() {
    // Scripted events handed to the test server helper.
    let scripted_events = vec![ev_response_created("resp-1"), ev_completed("resp-1")];
    let ws_server = start_websocket_server(vec![vec![scripted_events]]).await;

    // Provider pointing at the test server, with retries disabled.
    let ws_provider = ModelProviderInfo {
        name: "mock-ws".into(),
        base_url: Some(format!("{}/v1", ws_server.uri())),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        wire_api: WireApi::ResponsesWebsocket,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: Some(0),
        stream_max_retries: Some(0),
        stream_idle_timeout_ms: Some(5_000),
        requires_openai_auth: false,
    };

    // Default test configuration rooted in a temporary directory, retargeted
    // at the provider above.
    let home_dir = TempDir::new().unwrap();
    let mut cfg = load_default_config_for_test(&home_dir).await;
    cfg.model_provider_id = ws_provider.name.clone();
    cfg.model_provider = ws_provider.clone();
    let reasoning_effort = cfg.model_reasoning_effort;
    let reasoning_summary = cfg.model_reasoning_summary;
    let model_name = ModelsManager::get_model_offline(cfg.model.as_deref());
    cfg.model = Some(model_name.clone());
    let cfg = Arc::new(cfg);
    let info = ModelsManager::construct_model_info_offline(model_name.as_str(), &cfg);

    let thread_id = ThreadId::new();
    let auth = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
    let telemetry = OtelManager::new(
        thread_id,
        model_name.as_str(),
        info.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        auth.get_auth_mode(),
        false,
        "test".to_string(),
        SessionSource::Exec,
    );

    // Build the client and open the session used for streaming.
    let session = ModelClient::new(
        Arc::clone(&cfg),
        None,
        info,
        telemetry,
        ws_provider,
        reasoning_effort,
        reasoning_summary,
        thread_id,
        SessionSource::Exec,
    )
    .new_session();

    // A prompt whose input is exactly one user text message.
    let user_message = ResponseItem::Message {
        id: None,
        role: "user".into(),
        content: vec![ContentItem::InputText {
            text: "hello".into(),
        }],
    };
    let mut prompt = Prompt::default();
    prompt.input = vec![user_message];

    let mut events = session
        .stream(&prompt)
        .await
        .expect("websocket stream failed");

    // Drain the stream until a completed event arrives or the stream ends;
    // every other item (including errors) is skipped.
    loop {
        match events.next().await {
            Some(Ok(ResponseEvent::Completed { .. })) | None => break,
            Some(_) => {}
        }
    }

    // Inspect what the server recorded for its single connection.
    let requests = ws_server.single_connection();
    assert_eq!(requests.len(), 1);
    let first_request = requests.first().cloned().unwrap();
    let payload = first_request.body_json();
    assert_eq!(payload["model"].as_str(), Some(model_name.as_str()));
    assert_eq!(payload["stream"], serde_json::Value::Bool(true));
    assert_eq!(payload["input"].as_array().map(Vec::len), Some(1));

    ws_server.shutdown().await;
}