mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
Require models refresh on cli version mismatch (#10414)
This commit is contained in:
@@ -75,9 +75,11 @@ pub fn write_models_cache_with_models(
|
||||
let cache_path = codex_home.join("models_cache.json");
|
||||
// DateTime<Utc> serializes to RFC3339 format by default with serde
|
||||
let fetched_at: DateTime<Utc> = Utc::now();
|
||||
let client_version = codex_core::models_manager::client_version_to_whole();
|
||||
let cache = json!({
|
||||
"fetched_at": fetched_at,
|
||||
"etag": null,
|
||||
"client_version": client_version,
|
||||
"models": models
|
||||
});
|
||||
std::fs::write(cache_path, serde_json::to_string_pretty(&cache)?)
|
||||
|
||||
@@ -27,7 +27,7 @@ impl ModelsCacheManager {
|
||||
}
|
||||
|
||||
/// Attempt to load a fresh cache entry. Returns `None` if the cache doesn't exist or is stale.
|
||||
pub(crate) async fn load_fresh(&self) -> Option<ModelsCache> {
|
||||
pub(crate) async fn load_fresh(&self, expected_version: &str) -> Option<ModelsCache> {
|
||||
let cache = match self.load().await {
|
||||
Ok(cache) => cache?,
|
||||
Err(err) => {
|
||||
@@ -35,6 +35,9 @@ impl ModelsCacheManager {
|
||||
return None;
|
||||
}
|
||||
};
|
||||
if cache.client_version.as_deref() != Some(expected_version) {
|
||||
return None;
|
||||
}
|
||||
if !cache.is_fresh(self.cache_ttl) {
|
||||
return None;
|
||||
}
|
||||
@@ -42,10 +45,16 @@ impl ModelsCacheManager {
|
||||
}
|
||||
|
||||
/// Persist the cache to disk, creating parent directories as needed.
|
||||
pub(crate) async fn persist_cache(&self, models: &[ModelInfo], etag: Option<String>) {
|
||||
pub(crate) async fn persist_cache(
|
||||
&self,
|
||||
models: &[ModelInfo],
|
||||
etag: Option<String>,
|
||||
client_version: String,
|
||||
) {
|
||||
let cache = ModelsCache {
|
||||
fetched_at: Utc::now(),
|
||||
etag,
|
||||
client_version: Some(client_version),
|
||||
models: models.to_vec(),
|
||||
};
|
||||
if let Err(err) = self.save_internal(&cache).await {
|
||||
@@ -103,6 +112,20 @@ impl ModelsCacheManager {
|
||||
f(&mut cache.fetched_at);
|
||||
self.save_internal(&cache).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
/// Mutate the full cache contents for testing.
|
||||
pub(crate) async fn mutate_cache_for_test<F>(&self, f: F) -> io::Result<()>
|
||||
where
|
||||
F: FnOnce(&mut ModelsCache),
|
||||
{
|
||||
let mut cache = match self.load().await? {
|
||||
Some(cache) => cache,
|
||||
None => return Err(io::Error::new(ErrorKind::NotFound, "cache not found")),
|
||||
};
|
||||
f(&mut cache);
|
||||
self.save_internal(&cache).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialized snapshot of models and metadata cached on disk.
|
||||
@@ -111,6 +134,8 @@ pub(crate) struct ModelsCache {
|
||||
pub(crate) fetched_at: DateTime<Utc>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) etag: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) client_version: Option<String>,
|
||||
pub(crate) models: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
|
||||
@@ -210,7 +210,7 @@ impl ModelsManager {
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let client = ModelsClient::new(transport, api_provider, api_auth);
|
||||
|
||||
let client_version = format_client_version_to_whole();
|
||||
let client_version = crate::models_manager::client_version_to_whole();
|
||||
let (models, etag) = timeout(
|
||||
MODELS_REFRESH_TIMEOUT,
|
||||
client.list_models(&client_version, HeaderMap::new()),
|
||||
@@ -221,7 +221,9 @@ impl ModelsManager {
|
||||
|
||||
self.apply_remote_models(models.clone()).await;
|
||||
*self.etag.write().await = etag.clone();
|
||||
self.cache_manager.persist_cache(&models, etag).await;
|
||||
self.cache_manager
|
||||
.persist_cache(&models, etag, client_version)
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -255,7 +257,8 @@ impl ModelsManager {
|
||||
async fn try_load_cache(&self) -> bool {
|
||||
let _timer =
|
||||
codex_otel::start_global_timer("codex.remote_models.load_cache.duration_ms", &[]);
|
||||
let cache = match self.cache_manager.load_fresh().await {
|
||||
let client_version = crate::models_manager::client_version_to_whole();
|
||||
let cache = match self.cache_manager.load_fresh(&client_version).await {
|
||||
Some(cache) => cache,
|
||||
None => return false,
|
||||
};
|
||||
@@ -350,16 +353,6 @@ impl ModelsManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a client version string to a whole version string (e.g. "1.2.3-alpha.4" -> "1.2.3")
|
||||
fn format_client_version_to_whole() -> String {
|
||||
format!(
|
||||
"{}.{}.{}",
|
||||
env!("CARGO_PKG_VERSION_MAJOR"),
|
||||
env!("CARGO_PKG_VERSION_MINOR"),
|
||||
env!("CARGO_PKG_VERSION_PATCH")
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -613,6 +606,75 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_available_models_refetches_when_version_mismatch() {
|
||||
let server = MockServer::start().await;
|
||||
let initial_models = vec![remote_model("old", "Old", 1)];
|
||||
let initial_mock = mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: initial_models.clone(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let codex_home = tempdir().expect("temp dir");
|
||||
let mut config = ConfigBuilder::default()
|
||||
.codex_home(codex_home.path().to_path_buf())
|
||||
.build()
|
||||
.await
|
||||
.expect("load default test config");
|
||||
config.features.enable(Feature::RemoteModels);
|
||||
let auth_manager = Arc::new(AuthManager::new(
|
||||
codex_home.path().to_path_buf(),
|
||||
false,
|
||||
AuthCredentialsStoreMode::File,
|
||||
));
|
||||
let provider = provider_for(server.uri());
|
||||
let manager =
|
||||
ModelsManager::with_provider(codex_home.path().to_path_buf(), auth_manager, provider);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config, RefreshStrategy::OnlineIfUncached)
|
||||
.await
|
||||
.expect("initial refresh succeeds");
|
||||
|
||||
manager
|
||||
.cache_manager
|
||||
.mutate_cache_for_test(|cache| {
|
||||
let client_version = crate::models_manager::client_version_to_whole();
|
||||
cache.client_version = Some(format!("{client_version}-mismatch"));
|
||||
})
|
||||
.await
|
||||
.expect("cache mutation succeeds");
|
||||
|
||||
let updated_models = vec![remote_model("new", "New", 2)];
|
||||
server.reset().await;
|
||||
let refreshed_mock = mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: updated_models.clone(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config, RefreshStrategy::OnlineIfUncached)
|
||||
.await
|
||||
.expect("second refresh succeeds");
|
||||
assert_models_contain(&manager.get_remote_models(&config).await, &updated_models);
|
||||
assert_eq!(
|
||||
initial_mock.requests().len(),
|
||||
1,
|
||||
"initial refresh should only hit /models once"
|
||||
);
|
||||
assert_eq!(
|
||||
refreshed_mock.requests().len(),
|
||||
1,
|
||||
"version mismatch should fetch /models once"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_available_models_drops_removed_remote_models() {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
@@ -6,3 +6,13 @@ pub mod model_presets;
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub use collaboration_mode_presets::test_builtin_collaboration_mode_presets;
|
||||
|
||||
/// Convert the client version string to a whole version string (e.g. "1.2.3-alpha.4" -> "1.2.3").
|
||||
pub fn client_version_to_whole() -> String {
|
||||
format!(
|
||||
"{}.{}.{}",
|
||||
env!("CARGO_PKG_VERSION_MAJOR"),
|
||||
env!("CARGO_PKG_VERSION_MINOR"),
|
||||
env!("CARGO_PKG_VERSION_PATCH")
|
||||
)
|
||||
}
|
||||
|
||||
@@ -36,6 +36,9 @@ use wiremock::MockServer;
|
||||
const ETAG: &str = "\"models-etag-ttl\"";
|
||||
const CACHE_FILE: &str = "models_cache.json";
|
||||
const REMOTE_MODEL: &str = "codex-test-ttl";
|
||||
const VERSIONED_MODEL: &str = "codex-test-versioned";
|
||||
const MISSING_VERSION_MODEL: &str = "codex-test-missing-version";
|
||||
const DIFFERENT_VERSION_MODEL: &str = "codex-test-different-version";
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn renews_cache_ttl_on_matching_models_etag() -> Result<()> {
|
||||
@@ -131,11 +134,157 @@ async fn renews_cache_ttl_on_matching_models_etag() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn uses_cache_when_version_matches() -> Result<()> {
|
||||
let server = MockServer::start().await;
|
||||
let cached_model = test_remote_model(VERSIONED_MODEL, 1);
|
||||
let models_mock = responses::mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![test_remote_model("remote", 2)],
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing());
|
||||
builder = builder
|
||||
.with_pre_build_hook(move |home| {
|
||||
let cache = ModelsCache {
|
||||
fetched_at: Utc::now(),
|
||||
etag: None,
|
||||
client_version: Some(codex_core::models_manager::client_version_to_whole()),
|
||||
models: vec![cached_model],
|
||||
};
|
||||
let cache_path = home.join(CACHE_FILE);
|
||||
write_cache_sync(&cache_path, &cache).expect("write cache");
|
||||
})
|
||||
.with_config(|config| {
|
||||
config.features.enable(Feature::RemoteModels);
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
});
|
||||
|
||||
let test = builder.build(&server).await?;
|
||||
let models_manager = test.thread_manager.get_models_manager();
|
||||
let models = models_manager
|
||||
.list_models(&test.config, RefreshStrategy::OnlineIfUncached)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
models.iter().any(|preset| preset.model == VERSIONED_MODEL),
|
||||
"expected cached model"
|
||||
);
|
||||
assert_eq!(
|
||||
models_mock.requests().len(),
|
||||
0,
|
||||
"/models should not be called when cache version matches"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn refreshes_when_cache_version_missing() -> Result<()> {
|
||||
let server = MockServer::start().await;
|
||||
let cached_model = test_remote_model(MISSING_VERSION_MODEL, 1);
|
||||
let models_mock = responses::mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![test_remote_model("remote-missing", 2)],
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing());
|
||||
builder = builder
|
||||
.with_pre_build_hook(move |home| {
|
||||
let cache = ModelsCache {
|
||||
fetched_at: Utc::now(),
|
||||
etag: None,
|
||||
client_version: None,
|
||||
models: vec![cached_model],
|
||||
};
|
||||
let cache_path = home.join(CACHE_FILE);
|
||||
write_cache_sync(&cache_path, &cache).expect("write cache");
|
||||
})
|
||||
.with_config(|config| {
|
||||
config.features.enable(Feature::RemoteModels);
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
});
|
||||
|
||||
let test = builder.build(&server).await?;
|
||||
let models_manager = test.thread_manager.get_models_manager();
|
||||
let models = models_manager
|
||||
.list_models(&test.config, RefreshStrategy::OnlineIfUncached)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
models.iter().any(|preset| preset.model == "remote-missing"),
|
||||
"expected refreshed models"
|
||||
);
|
||||
assert_eq!(
|
||||
models_mock.requests().len(),
|
||||
1,
|
||||
"/models should be called when cache version is missing"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn refreshes_when_cache_version_differs() -> Result<()> {
|
||||
let server = MockServer::start().await;
|
||||
let cached_model = test_remote_model(DIFFERENT_VERSION_MODEL, 1);
|
||||
let models_mock = responses::mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![test_remote_model("remote-different", 2)],
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing());
|
||||
builder = builder
|
||||
.with_pre_build_hook(move |home| {
|
||||
let client_version = codex_core::models_manager::client_version_to_whole();
|
||||
let cache = ModelsCache {
|
||||
fetched_at: Utc::now(),
|
||||
etag: None,
|
||||
client_version: Some(format!("{client_version}-diff")),
|
||||
models: vec![cached_model],
|
||||
};
|
||||
let cache_path = home.join(CACHE_FILE);
|
||||
write_cache_sync(&cache_path, &cache).expect("write cache");
|
||||
})
|
||||
.with_config(|config| {
|
||||
config.features.enable(Feature::RemoteModels);
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
});
|
||||
|
||||
let test = builder.build(&server).await?;
|
||||
let models_manager = test.thread_manager.get_models_manager();
|
||||
let models = models_manager
|
||||
.list_models(&test.config, RefreshStrategy::OnlineIfUncached)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
models
|
||||
.iter()
|
||||
.any(|preset| preset.model == "remote-different"),
|
||||
"expected refreshed models"
|
||||
);
|
||||
assert_eq!(
|
||||
models_mock.requests().len(),
|
||||
1,
|
||||
"/models should be called when cache version differs"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn rewrite_cache_timestamp(path: &Path, fetched_at: DateTime<Utc>) -> Result<()> {
|
||||
let mut cache = read_cache(path).await?;
|
||||
cache.fetched_at = fetched_at;
|
||||
let contents = serde_json::to_vec_pretty(&cache)?;
|
||||
tokio::fs::write(path, contents).await?;
|
||||
write_cache(path, &cache).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -145,11 +294,25 @@ async fn read_cache(path: &Path) -> Result<ModelsCache> {
|
||||
Ok(cache)
|
||||
}
|
||||
|
||||
async fn write_cache(path: &Path, cache: &ModelsCache) -> Result<()> {
|
||||
let contents = serde_json::to_vec_pretty(cache)?;
|
||||
tokio::fs::write(path, contents).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_cache_sync(path: &Path, cache: &ModelsCache) -> Result<()> {
|
||||
let contents = serde_json::to_vec_pretty(cache)?;
|
||||
std::fs::write(path, contents)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ModelsCache {
|
||||
fetched_at: DateTime<Utc>,
|
||||
#[serde(default)]
|
||||
etag: Option<String>,
|
||||
#[serde(default)]
|
||||
client_version: Option<String>,
|
||||
models: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user