mirror of https://github.com/openai/codex.git (synced 2026-02-01 22:47:52 +00:00)
merge remote models (#9547)
We have `models.json` and the `/models` response. Merge behavior:

1. New models from the `/models` endpoint get added.
2. Shared models get replaced by the remote ones.
3. Existing models in `models.json` but not in `/models` are kept.
4. The highest-priority model is marked as the default.
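
Rules 1-3 amount to a slug-keyed merge in which the remote side wins on collisions. A minimal standalone sketch of that rule, assuming a stripped-down two-field `ModelInfo` and made-up slugs (not the real types):

#[derive(Clone, Debug, PartialEq)]
struct ModelInfo {
    slug: String,
    display_name: String,
}

fn info(slug: &str, name: &str) -> ModelInfo {
    ModelInfo {
        slug: slug.into(),
        display_name: name.into(),
    }
}

// Apply the `/models` payload on top of the existing list: shared slugs are
// replaced (rule 2), new slugs are appended (rule 1), and existing entries
// absent from the payload are kept (rule 3).
fn merge(mut existing: Vec<ModelInfo>, remote: Vec<ModelInfo>) -> Vec<ModelInfo> {
    for model in remote {
        match existing.iter().position(|m| m.slug == model.slug) {
            Some(i) => existing[i] = model,
            None => existing.push(model),
        }
    }
    existing
}

fn main() {
    let bundled = vec![info("alpha", "Alpha"), info("beta", "Beta")];
    let remote = vec![info("alpha", "Alpha (remote)"), info("gamma", "Gamma")];
    let merged = merge(bundled, remote);
    assert_eq!(merged.len(), 3); // "beta" kept, "gamma" added
    assert_eq!(merged[0].display_name, "Alpha (remote)"); // shared slug replaced
}

Rule 4 (default marking) is a separate pass over the merged presets; see the second `ModelsManager` hunk below.
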
@@ -18,6 +18,7 @@ codex_rust_crate(
     ),
     integration_compile_data_extra = [
         "//codex-rs/apply-patch:apply_patch_tool_instructions.md",
+        "models.json",
         "prompt.md",
     ],
     test_data_extra = [

@@ -239,7 +239,18 @@ impl ModelsManager {
 
     /// Replace the cached remote models and rebuild the derived presets list.
     async fn apply_remote_models(&self, models: Vec<ModelInfo>) {
-        *self.remote_models.write().await = models;
+        let mut existing_models = Self::load_remote_models_from_file().unwrap_or_default();
+        for model in models {
+            if let Some(existing_index) = existing_models
+                .iter()
+                .position(|existing| existing.slug == model.slug)
+            {
+                existing_models[existing_index] = model;
+            } else {
+                existing_models.push(model);
+            }
+        }
+        *self.remote_models.write().await = existing_models;
     }
 
     fn load_remote_models_from_file() -> Result<Vec<ModelInfo>, std::io::Error> {

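Note the shape of the merge: the current list is re-read via `load_remote_models_from_file().unwrap_or_default()` before the payload is applied, so an empty `/models` response simply writes the existing models back unchanged. The new `remote_models_merge_preserves_bundled_models_on_empty_response` test below pins that case down.
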
@@ -272,16 +283,16 @@ impl ModelsManager {
         let chatgpt_mode = self.auth_manager.get_auth_mode() == Some(AuthMode::ChatGPT);
         merged_presets = ModelPreset::filter_by_auth(merged_presets, chatgpt_mode);
 
-        let has_default = merged_presets.iter().any(|preset| preset.is_default);
-        if !has_default {
-            if let Some(default) = merged_presets
-                .iter_mut()
-                .find(|preset| preset.show_in_picker)
-            {
-                default.is_default = true;
-            } else if let Some(default) = merged_presets.first_mut() {
-                default.is_default = true;
-            }
-        }
+        for preset in &mut merged_presets {
+            preset.is_default = false;
+        }
+        if let Some(default) = merged_presets
+            .iter_mut()
+            .find(|preset| preset.show_in_picker)
+        {
+            default.is_default = true;
+        } else if let Some(default) = merged_presets.first_mut() {
+            default.is_default = true;
+        }
 
         merged_presets

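This pass now re-derives the default on every merge instead of only assigning one when none exists; clearing the flags first means a stale `is_default` carried in from a cached or remote entry cannot leave two defaults behind. A minimal sketch of the same pass, assuming a hypothetical two-field `Preset` rather than the real type:

#[derive(Debug)]
struct Preset {
    show_in_picker: bool,
    is_default: bool,
}

// Clear every flag, then mark exactly one preset: the first one visible
// in the picker, falling back to the first preset overall.
fn mark_default(presets: &mut [Preset]) {
    for preset in presets.iter_mut() {
        preset.is_default = false;
    }
    if let Some(default) = presets.iter_mut().find(|p| p.show_in_picker) {
        default.is_default = true;
    } else if let Some(default) = presets.first_mut() {
        default.is_default = true;
    }
}

fn main() {
    let mut presets = vec![
        // A stale default on a hidden preset, e.g. from an old cache.
        Preset { show_in_picker: false, is_default: true },
        Preset { show_in_picker: true, is_default: false },
    ];
    mark_default(&mut presets);
    assert_eq!(presets.iter().filter(|p| p.is_default).count(), 1);
    assert!(presets[1].is_default); // the visible preset wins
}

Since the presets arrive sorted by priority (see the tests below), the first picker-visible preset is also the highest-priority one, which matches rule 4 of the commit message.
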
@@ -396,6 +407,16 @@ mod tests {
             .expect("valid model")
     }
 
+    fn assert_models_contain(actual: &[ModelInfo], expected: &[ModelInfo]) {
+        for model in expected {
+            assert!(
+                actual.iter().any(|candidate| candidate.slug == model.slug),
+                "expected model {} in cached list",
+                model.slug
+            );
+        }
+    }
+
     fn provider_for(base_url: String) -> ModelProviderInfo {
         ModelProviderInfo {
             name: "mock".into(),

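The helper checks containment rather than exact equality on purpose: after the merge, the cached list is a superset of any single `/models` payload (existing entries survive, per rule 3 of the commit message), so the strict `assert_eq!` comparisons in the tests below are relaxed to `assert_models_contain`.
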
@@ -415,7 +436,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn refresh_available_models_sorts_and_marks_default() {
+    async fn refresh_available_models_sorts_by_priority() {
         let server = MockServer::start().await;
         let remote_models = vec![
             remote_model("priority-low", "Low", 1),

@@ -447,7 +468,7 @@
             .await
             .expect("refresh succeeds");
         let cached_remote = manager.get_remote_models(&config).await;
-        assert_eq!(cached_remote, remote_models);
+        assert_models_contain(&cached_remote, &remote_models);
 
         let available = manager
             .list_models(&config, RefreshStrategy::OnlineIfUncached)

@@ -464,11 +485,6 @@
             high_idx < low_idx,
             "higher priority should be listed before lower priority"
         );
-        assert!(
-            available[high_idx].is_default,
-            "highest priority should be default"
-        );
-        assert!(!available[low_idx].is_default);
         assert_eq!(
             models_mock.requests().len(),
             1,

@@ -508,22 +524,14 @@
             .refresh_available_models(&config, RefreshStrategy::OnlineIfUncached)
             .await
             .expect("first refresh succeeds");
-        assert_eq!(
-            manager.get_remote_models(&config).await,
-            remote_models,
-            "remote cache should store fetched models"
-        );
+        assert_models_contain(&manager.get_remote_models(&config).await, &remote_models);
 
         // Second call should read from cache and avoid the network.
         manager
             .refresh_available_models(&config, RefreshStrategy::OnlineIfUncached)
             .await
             .expect("cached refresh succeeds");
-        assert_eq!(
-            manager.get_remote_models(&config).await,
-            remote_models,
-            "cache path should not mutate stored models"
-        );
+        assert_models_contain(&manager.get_remote_models(&config).await, &remote_models);
         assert_eq!(
             models_mock.requests().len(),
             1,

@@ -587,11 +595,7 @@
             .refresh_available_models(&config, RefreshStrategy::OnlineIfUncached)
             .await
             .expect("second refresh succeeds");
-        assert_eq!(
-            manager.get_remote_models(&config).await,
-            updated_models,
-            "stale cache should trigger refetch"
-        );
+        assert_models_contain(&manager.get_remote_models(&config).await, &updated_models);
         assert_eq!(
             initial_mock.requests().len(),
             1,

@@ -1,4 +1,5 @@
 #![cfg(not(target_os = "windows"))]
+#![allow(clippy::expect_used)]
 // unified exec is not supported on Windows OS
 use std::sync::Arc;
 

@@ -434,8 +435,18 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
         .find(|model| model.model == "remote-alpha")
         .expect("remote model should be listed");
     let mut expected_remote: ModelPreset = remote_model.into();
-    expected_remote.is_default = true;
+    expected_remote.is_default = remote.is_default;
     assert_eq!(*remote, expected_remote);
+    let default_model = available
+        .iter()
+        .find(|model| model.show_in_picker)
+        .expect("default model should be set");
+    assert!(default_model.is_default);
+    assert_eq!(
+        available.iter().filter(|model| model.is_default).count(),
+        1,
+        "expected a single default model"
+    );
     assert!(
         available
             .iter()

@@ -451,6 +462,148 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn remote_models_merge_adds_new_high_priority_first() -> Result<()> {
+    skip_if_no_network!(Ok(()));
+    skip_if_sandbox!(Ok(()));
+
+    let server = MockServer::start().await;
+    let remote_model = test_remote_model("remote-top", ModelVisibility::List, -10_000);
+    let models_mock = mount_models_once(
+        &server,
+        ModelsResponse {
+            models: vec![remote_model],
+        },
+    )
+    .await;
+
+    let codex_home = TempDir::new()?;
+    let mut config = load_default_config_for_test(&codex_home).await;
+    config.features.enable(Feature::RemoteModels);
+
+    let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing();
+    let provider = ModelProviderInfo {
+        base_url: Some(format!("{}/v1", server.uri())),
+        ..built_in_model_providers()["openai"].clone()
+    };
+    let manager = ModelsManager::with_provider(
+        codex_home.path().to_path_buf(),
+        codex_core::auth::AuthManager::from_auth_for_testing(auth),
+        provider,
+    );
+
+    let available = manager
+        .list_models(&config, RefreshStrategy::OnlineIfUncached)
+        .await;
+    assert_eq!(
+        available.first().map(|model| model.model.as_str()),
+        Some("remote-top")
+    );
+    assert_eq!(
+        models_mock.requests().len(),
+        1,
+        "expected a single /models request"
+    );
+
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn remote_models_merge_replaces_overlapping_model() -> Result<()> {
+    skip_if_no_network!(Ok(()));
+    skip_if_sandbox!(Ok(()));
+
+    let server = MockServer::start().await;
+    let slug = bundled_model_slug();
+    let mut remote_model = test_remote_model(&slug, ModelVisibility::List, 0);
+    remote_model.display_name = "Overridden".to_string();
+    remote_model.description = Some("Overridden description".to_string());
+    let models_mock = mount_models_once(
+        &server,
+        ModelsResponse {
+            models: vec![remote_model.clone()],
+        },
+    )
+    .await;
+
+    let codex_home = TempDir::new()?;
+    let mut config = load_default_config_for_test(&codex_home).await;
+    config.features.enable(Feature::RemoteModels);
+
+    let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing();
+    let provider = ModelProviderInfo {
+        base_url: Some(format!("{}/v1", server.uri())),
+        ..built_in_model_providers()["openai"].clone()
+    };
+    let manager = ModelsManager::with_provider(
+        codex_home.path().to_path_buf(),
+        codex_core::auth::AuthManager::from_auth_for_testing(auth),
+        provider,
+    );
+
+    let available = manager
+        .list_models(&config, RefreshStrategy::OnlineIfUncached)
+        .await;
+    let overridden = available
+        .iter()
+        .find(|model| model.model == slug)
+        .expect("overlapping model should be listed");
+    assert_eq!(overridden.display_name, remote_model.display_name);
+    assert_eq!(
+        overridden.description,
+        remote_model
+            .description
+            .expect("remote model should include description")
+    );
+    assert_eq!(
+        models_mock.requests().len(),
+        1,
+        "expected a single /models request"
+    );
+
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn remote_models_merge_preserves_bundled_models_on_empty_response() -> Result<()> {
+    skip_if_no_network!(Ok(()));
+    skip_if_sandbox!(Ok(()));
+
+    let server = MockServer::start().await;
+    let models_mock = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await;
+
+    let codex_home = TempDir::new()?;
+    let mut config = load_default_config_for_test(&codex_home).await;
+    config.features.enable(Feature::RemoteModels);
+
+    let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing();
+    let provider = ModelProviderInfo {
+        base_url: Some(format!("{}/v1", server.uri())),
+        ..built_in_model_providers()["openai"].clone()
+    };
+    let manager = ModelsManager::with_provider(
+        codex_home.path().to_path_buf(),
+        codex_core::auth::AuthManager::from_auth_for_testing(auth),
+        provider,
+    );
+
+    let available = manager
+        .list_models(&config, RefreshStrategy::OnlineIfUncached)
+        .await;
+    let bundled_slug = bundled_model_slug();
+    assert!(
+        available.iter().any(|model| model.model == bundled_slug),
+        "bundled models should remain available after empty remote response"
+    );
+    assert_eq!(
+        models_mock.requests().len(),
+        1,
+        "expected a single /models request"
+    );
+
+    Ok(())
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn remote_models_request_times_out_after_5s() -> Result<()> {
     skip_if_no_network!(Ok(()));

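The three new tests line up with the merge rules in the commit message: a new high-priority remote model is added (and sorts first), an overlapping slug takes the remote definition, and an empty response leaves the bundled models available.
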
@@ -588,6 +741,17 @@ async fn wait_for_model_available(
     }
 }
 
+fn bundled_model_slug() -> String {
+    let response: ModelsResponse = serde_json::from_str(include_str!("../../models.json"))
+        .expect("bundled models.json should deserialize");
+    response
+        .models
+        .first()
+        .expect("bundled models.json should include at least one model")
+        .slug
+        .clone()
+}
+
 fn test_remote_model(slug: &str, visibility: ModelVisibility, priority: i32) -> ModelInfo {
     test_remote_model_with_policy(
         slug,