diff --git a/codex-rs/core/BUILD.bazel b/codex-rs/core/BUILD.bazel index 37a3173f9a..f1649b36c7 100644 --- a/codex-rs/core/BUILD.bazel +++ b/codex-rs/core/BUILD.bazel @@ -18,6 +18,7 @@ codex_rust_crate( ), integration_compile_data_extra = [ "//codex-rs/apply-patch:apply_patch_tool_instructions.md", + "models.json", "prompt.md", ], test_data_extra = [ diff --git a/codex-rs/core/src/models_manager/manager.rs b/codex-rs/core/src/models_manager/manager.rs index b7cc66159a..7210e1c62d 100644 --- a/codex-rs/core/src/models_manager/manager.rs +++ b/codex-rs/core/src/models_manager/manager.rs @@ -239,7 +239,18 @@ impl ModelsManager { /// Replace the cached remote models and rebuild the derived presets list. async fn apply_remote_models(&self, models: Vec) { - *self.remote_models.write().await = models; + let mut existing_models = Self::load_remote_models_from_file().unwrap_or_default(); + for model in models { + if let Some(existing_index) = existing_models + .iter() + .position(|existing| existing.slug == model.slug) + { + existing_models[existing_index] = model; + } else { + existing_models.push(model); + } + } + *self.remote_models.write().await = existing_models; } fn load_remote_models_from_file() -> Result, std::io::Error> { @@ -272,16 +283,16 @@ impl ModelsManager { let chatgpt_mode = self.auth_manager.get_auth_mode() == Some(AuthMode::ChatGPT); merged_presets = ModelPreset::filter_by_auth(merged_presets, chatgpt_mode); - let has_default = merged_presets.iter().any(|preset| preset.is_default); - if !has_default { - if let Some(default) = merged_presets - .iter_mut() - .find(|preset| preset.show_in_picker) - { - default.is_default = true; - } else if let Some(default) = merged_presets.first_mut() { - default.is_default = true; - } + for preset in &mut merged_presets { + preset.is_default = false; + } + if let Some(default) = merged_presets + .iter_mut() + .find(|preset| preset.show_in_picker) + { + default.is_default = true; + } else if let Some(default) = merged_presets.first_mut() { + default.is_default = true; } merged_presets @@ -396,6 +407,16 @@ mod tests { .expect("valid model") } + fn assert_models_contain(actual: &[ModelInfo], expected: &[ModelInfo]) { + for model in expected { + assert!( + actual.iter().any(|candidate| candidate.slug == model.slug), + "expected model {} in cached list", + model.slug + ); + } + } + fn provider_for(base_url: String) -> ModelProviderInfo { ModelProviderInfo { name: "mock".into(), @@ -415,7 +436,7 @@ mod tests { } #[tokio::test] - async fn refresh_available_models_sorts_and_marks_default() { + async fn refresh_available_models_sorts_by_priority() { let server = MockServer::start().await; let remote_models = vec![ remote_model("priority-low", "Low", 1), @@ -447,7 +468,7 @@ mod tests { .await .expect("refresh succeeds"); let cached_remote = manager.get_remote_models(&config).await; - assert_eq!(cached_remote, remote_models); + assert_models_contain(&cached_remote, &remote_models); let available = manager .list_models(&config, RefreshStrategy::OnlineIfUncached) @@ -464,11 +485,6 @@ mod tests { high_idx < low_idx, "higher priority should be listed before lower priority" ); - assert!( - available[high_idx].is_default, - "highest priority should be default" - ); - assert!(!available[low_idx].is_default); assert_eq!( models_mock.requests().len(), 1, @@ -508,22 +524,14 @@ mod tests { .refresh_available_models(&config, RefreshStrategy::OnlineIfUncached) .await .expect("first refresh succeeds"); - assert_eq!( - manager.get_remote_models(&config).await, - remote_models, - "remote cache should store fetched models" - ); + assert_models_contain(&manager.get_remote_models(&config).await, &remote_models); // Second call should read from cache and avoid the network. manager .refresh_available_models(&config, RefreshStrategy::OnlineIfUncached) .await .expect("cached refresh succeeds"); - assert_eq!( - manager.get_remote_models(&config).await, - remote_models, - "cache path should not mutate stored models" - ); + assert_models_contain(&manager.get_remote_models(&config).await, &remote_models); assert_eq!( models_mock.requests().len(), 1, @@ -587,11 +595,7 @@ mod tests { .refresh_available_models(&config, RefreshStrategy::OnlineIfUncached) .await .expect("second refresh succeeds"); - assert_eq!( - manager.get_remote_models(&config).await, - updated_models, - "stale cache should trigger refetch" - ); + assert_models_contain(&manager.get_remote_models(&config).await, &updated_models); assert_eq!( initial_mock.requests().len(), 1, diff --git a/codex-rs/core/tests/suite/remote_models.rs b/codex-rs/core/tests/suite/remote_models.rs index 8a6beb6e24..3b45989d8a 100644 --- a/codex-rs/core/tests/suite/remote_models.rs +++ b/codex-rs/core/tests/suite/remote_models.rs @@ -1,4 +1,5 @@ #![cfg(not(target_os = "windows"))] +#![allow(clippy::expect_used)] // unified exec is not supported on Windows OS use std::sync::Arc; @@ -434,8 +435,18 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> { .find(|model| model.model == "remote-alpha") .expect("remote model should be listed"); let mut expected_remote: ModelPreset = remote_model.into(); - expected_remote.is_default = true; + expected_remote.is_default = remote.is_default; assert_eq!(*remote, expected_remote); + let default_model = available + .iter() + .find(|model| model.show_in_picker) + .expect("default model should be set"); + assert!(default_model.is_default); + assert_eq!( + available.iter().filter(|model| model.is_default).count(), + 1, + "expected a single default model" + ); assert!( available .iter() @@ -451,6 +462,148 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn remote_models_merge_adds_new_high_priority_first() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = MockServer::start().await; + let remote_model = test_remote_model("remote-top", ModelVisibility::List, -10_000); + let models_mock = mount_models_once( + &server, + ModelsResponse { + models: vec![remote_model], + }, + ) + .await; + + let codex_home = TempDir::new()?; + let mut config = load_default_config_for_test(&codex_home).await; + config.features.enable(Feature::RemoteModels); + + let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); + let provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + let manager = ModelsManager::with_provider( + codex_home.path().to_path_buf(), + codex_core::auth::AuthManager::from_auth_for_testing(auth), + provider, + ); + + let available = manager + .list_models(&config, RefreshStrategy::OnlineIfUncached) + .await; + assert_eq!( + available.first().map(|model| model.model.as_str()), + Some("remote-top") + ); + assert_eq!( + models_mock.requests().len(), + 1, + "expected a single /models request" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn remote_models_merge_replaces_overlapping_model() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = MockServer::start().await; + let slug = bundled_model_slug(); + let mut remote_model = test_remote_model(&slug, ModelVisibility::List, 0); + remote_model.display_name = "Overridden".to_string(); + remote_model.description = Some("Overridden description".to_string()); + let models_mock = mount_models_once( + &server, + ModelsResponse { + models: vec![remote_model.clone()], + }, + ) + .await; + + let codex_home = TempDir::new()?; + let mut config = load_default_config_for_test(&codex_home).await; + config.features.enable(Feature::RemoteModels); + + let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); + let provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + let manager = ModelsManager::with_provider( + codex_home.path().to_path_buf(), + codex_core::auth::AuthManager::from_auth_for_testing(auth), + provider, + ); + + let available = manager + .list_models(&config, RefreshStrategy::OnlineIfUncached) + .await; + let overridden = available + .iter() + .find(|model| model.model == slug) + .expect("overlapping model should be listed"); + assert_eq!(overridden.display_name, remote_model.display_name); + assert_eq!( + overridden.description, + remote_model + .description + .expect("remote model should include description") + ); + assert_eq!( + models_mock.requests().len(), + 1, + "expected a single /models request" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn remote_models_merge_preserves_bundled_models_on_empty_response() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = MockServer::start().await; + let models_mock = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await; + + let codex_home = TempDir::new()?; + let mut config = load_default_config_for_test(&codex_home).await; + config.features.enable(Feature::RemoteModels); + + let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); + let provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + let manager = ModelsManager::with_provider( + codex_home.path().to_path_buf(), + codex_core::auth::AuthManager::from_auth_for_testing(auth), + provider, + ); + + let available = manager + .list_models(&config, RefreshStrategy::OnlineIfUncached) + .await; + let bundled_slug = bundled_model_slug(); + assert!( + available.iter().any(|model| model.model == bundled_slug), + "bundled models should remain available after empty remote response" + ); + assert_eq!( + models_mock.requests().len(), + 1, + "expected a single /models request" + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn remote_models_request_times_out_after_5s() -> Result<()> { skip_if_no_network!(Ok(())); @@ -588,6 +741,17 @@ async fn wait_for_model_available( } } +fn bundled_model_slug() -> String { + let response: ModelsResponse = serde_json::from_str(include_str!("../../models.json")) + .expect("bundled models.json should deserialize"); + response + .models + .first() + .expect("bundled models.json should include at least one model") + .slug + .clone() +} + fn test_remote_model(slug: &str, visibility: ModelVisibility, priority: i32) -> ModelInfo { test_remote_model_with_policy( slug,