mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
Refresh on models etag mismatch (#8491)
- Send models etag - Refresh models on 412 - This wires `ModelsManager` to `ModelFamily` so we don't mutate it mid-turn
This commit is contained in:
@@ -59,6 +59,7 @@ pub enum ResponseEvent {
|
||||
summary_index: i64,
|
||||
},
|
||||
RateLimits(RateLimitSnapshot),
|
||||
ModelsEtag(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
|
||||
@@ -152,6 +152,9 @@ impl Stream for AggregatedStream {
|
||||
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::provider::Provider;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
@@ -41,7 +42,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
|
||||
&self,
|
||||
client_version: &str,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ModelsResponse, ApiError> {
|
||||
) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::GET, self.path());
|
||||
req.headers.extend(extra_headers.clone());
|
||||
@@ -66,7 +67,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToString::to_string);
|
||||
|
||||
let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
|
||||
let ModelsResponse { models } = serde_json::from_slice::<ModelsResponse>(&resp.body)
|
||||
.map_err(|e| {
|
||||
ApiError::Stream(format!(
|
||||
"failed to decode models response: {e}; body: {}",
|
||||
@@ -74,9 +75,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
|
||||
))
|
||||
})?;
|
||||
|
||||
let etag = header_etag.unwrap_or(etag);
|
||||
|
||||
Ok(ModelsResponse { models, etag })
|
||||
Ok((models, header_etag))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,16 +101,15 @@ mod tests {
|
||||
struct CapturingTransport {
|
||||
last_request: Arc<Mutex<Option<Request>>>,
|
||||
body: Arc<ModelsResponse>,
|
||||
etag: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for CapturingTransport {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(ModelsResponse {
|
||||
models: Vec::new(),
|
||||
etag: String::new(),
|
||||
}),
|
||||
body: Arc::new(ModelsResponse { models: Vec::new() }),
|
||||
etag: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -122,8 +120,8 @@ mod tests {
|
||||
*self.last_request.lock().unwrap() = Some(req);
|
||||
let body = serde_json::to_vec(&*self.body).unwrap();
|
||||
let mut headers = HeaderMap::new();
|
||||
if !self.body.etag.is_empty() {
|
||||
headers.insert(ETAG, self.body.etag.parse().unwrap());
|
||||
if let Some(etag) = &self.etag {
|
||||
headers.insert(ETAG, etag.parse().unwrap());
|
||||
}
|
||||
Ok(Response {
|
||||
status: StatusCode::OK,
|
||||
@@ -166,14 +164,12 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn appends_client_version_query() {
|
||||
let response = ModelsResponse {
|
||||
models: Vec::new(),
|
||||
etag: String::new(),
|
||||
};
|
||||
let response = ModelsResponse { models: Vec::new() };
|
||||
|
||||
let transport = CapturingTransport {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(response),
|
||||
etag: None,
|
||||
};
|
||||
|
||||
let client = ModelsClient::new(
|
||||
@@ -182,12 +178,12 @@ mod tests {
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let result = client
|
||||
let (models, _) = client
|
||||
.list_models("0.99.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 0);
|
||||
assert_eq!(models.len(), 0);
|
||||
|
||||
let url = transport
|
||||
.last_request
|
||||
@@ -231,12 +227,12 @@ mod tests {
|
||||
}))
|
||||
.unwrap(),
|
||||
],
|
||||
etag: String::new(),
|
||||
};
|
||||
|
||||
let transport = CapturingTransport {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(response),
|
||||
etag: None,
|
||||
};
|
||||
|
||||
let client = ModelsClient::new(
|
||||
@@ -245,27 +241,25 @@ mod tests {
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let result = client
|
||||
let (models, _) = client
|
||||
.list_models("0.99.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 1);
|
||||
assert_eq!(result.models[0].slug, "gpt-test");
|
||||
assert_eq!(result.models[0].supported_in_api, true);
|
||||
assert_eq!(result.models[0].priority, 1);
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].slug, "gpt-test");
|
||||
assert_eq!(models[0].supported_in_api, true);
|
||||
assert_eq!(models[0].priority, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_models_includes_etag() {
|
||||
let response = ModelsResponse {
|
||||
models: Vec::new(),
|
||||
etag: "\"abc\"".to_string(),
|
||||
};
|
||||
let response = ModelsResponse { models: Vec::new() };
|
||||
|
||||
let transport = CapturingTransport {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(response),
|
||||
etag: Some("\"abc\"".to_string()),
|
||||
};
|
||||
|
||||
let client = ModelsClient::new(
|
||||
@@ -274,12 +268,12 @@ mod tests {
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let result = client
|
||||
let (models, etag) = client
|
||||
.list_models("0.1.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 0);
|
||||
assert_eq!(result.etag, "\"abc\"");
|
||||
assert_eq!(models.len(), 0);
|
||||
assert_eq!(etag, Some("\"abc\"".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,11 +51,19 @@ pub fn spawn_response_stream(
|
||||
telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> ResponseStream {
|
||||
let rate_limits = parse_rate_limit(&stream_response.headers);
|
||||
let models_etag = stream_response
|
||||
.headers
|
||||
.get("X-Models-Etag")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(ToString::to_string);
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(async move {
|
||||
if let Some(snapshot) = rate_limits {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
|
||||
}
|
||||
if let Some(etag) = models_etag {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
|
||||
}
|
||||
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
|
||||
});
|
||||
|
||||
|
||||
@@ -86,7 +86,6 @@ async fn models_client_hits_models_endpoint() {
|
||||
context_window: None,
|
||||
experimental_supported_tools: Vec::new(),
|
||||
}],
|
||||
etag: String::new(),
|
||||
};
|
||||
|
||||
Mock::given(method("GET"))
|
||||
@@ -102,13 +101,13 @@ async fn models_client_hits_models_endpoint() {
|
||||
let transport = ReqwestTransport::new(reqwest::Client::new());
|
||||
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);
|
||||
|
||||
let result = client
|
||||
let (models, _) = client
|
||||
.list_models("0.1.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("models request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 1);
|
||||
assert_eq!(result.models[0].slug, "gpt-test");
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].slug, "gpt-test");
|
||||
|
||||
let received = server
|
||||
.received_requests()
|
||||
|
||||
@@ -246,7 +246,9 @@ impl Codex {
|
||||
|
||||
let config = Arc::new(config);
|
||||
if config.features.enabled(Feature::RemoteModels)
|
||||
&& let Err(err) = models_manager.refresh_available_models(&config).await
|
||||
&& let Err(err) = models_manager
|
||||
.refresh_available_models_with_cache(&config)
|
||||
.await
|
||||
{
|
||||
error!("failed to refresh available models: {err:?}");
|
||||
}
|
||||
@@ -2613,6 +2615,10 @@ async fn try_run_turn(
|
||||
// token usage is available to avoid duplicate TokenCount events.
|
||||
sess.update_rate_limits(&turn_context, snapshot).await;
|
||||
}
|
||||
ResponseEvent::ModelsEtag(etag) => {
|
||||
// Update internal state with latest models etag
|
||||
sess.services.models_manager.refresh_if_new_etag(etag).await;
|
||||
}
|
||||
ResponseEvent::Completed {
|
||||
response_id: _,
|
||||
token_usage,
|
||||
@@ -3140,7 +3146,7 @@ mod tests {
|
||||
exec_policy,
|
||||
auth_manager: auth_manager.clone(),
|
||||
otel_manager: otel_manager.clone(),
|
||||
models_manager,
|
||||
models_manager: Arc::clone(&models_manager),
|
||||
tool_approvals: Mutex::new(ApprovalStore::default()),
|
||||
skills_manager,
|
||||
};
|
||||
@@ -3227,7 +3233,7 @@ mod tests {
|
||||
exec_policy,
|
||||
auth_manager: Arc::clone(&auth_manager),
|
||||
otel_manager: otel_manager.clone(),
|
||||
models_manager,
|
||||
models_manager: Arc::clone(&models_manager),
|
||||
tool_approvals: Mutex::new(ApprovalStore::default()),
|
||||
skills_manager,
|
||||
};
|
||||
|
||||
@@ -77,7 +77,7 @@ impl ModelsManager {
|
||||
}
|
||||
|
||||
/// Fetch the latest remote models, using the on-disk cache when still fresh.
|
||||
pub async fn refresh_available_models(&self, config: &Config) -> CoreResult<()> {
|
||||
pub async fn refresh_available_models_with_cache(&self, config: &Config) -> CoreResult<()> {
|
||||
if !config.features.enabled(Feature::RemoteModels)
|
||||
|| self.auth_manager.get_auth_mode() == Some(AuthMode::ApiKey)
|
||||
{
|
||||
@@ -86,7 +86,13 @@ impl ModelsManager {
|
||||
if self.try_load_cache().await {
|
||||
return Ok(());
|
||||
}
|
||||
self.refresh_available_models_no_cache().await
|
||||
}
|
||||
|
||||
pub(crate) async fn refresh_available_models_no_cache(&self) -> CoreResult<()> {
|
||||
if self.auth_manager.get_auth_mode() == Some(AuthMode::ApiKey) {
|
||||
return Ok(());
|
||||
}
|
||||
let auth = self.auth_manager.auth();
|
||||
let api_provider = self.provider.to_api_provider(Some(AuthMode::ChatGPT))?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?;
|
||||
@@ -94,13 +100,11 @@ impl ModelsManager {
|
||||
let client = ModelsClient::new(transport, api_provider, api_auth);
|
||||
|
||||
let client_version = format_client_version_to_whole();
|
||||
let ModelsResponse { models, etag } = client
|
||||
let (models, etag) = client
|
||||
.list_models(&client_version, HeaderMap::new())
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
|
||||
let etag = (!etag.is_empty()).then_some(etag);
|
||||
|
||||
self.apply_remote_models(models.clone()).await;
|
||||
*self.etag.write().await = etag.clone();
|
||||
self.persist_cache(&models, etag).await;
|
||||
@@ -108,7 +112,7 @@ impl ModelsManager {
|
||||
}
|
||||
|
||||
pub async fn list_models(&self, config: &Config) -> Vec<ModelPreset> {
|
||||
if let Err(err) = self.refresh_available_models(config).await {
|
||||
if let Err(err) = self.refresh_available_models_with_cache(config).await {
|
||||
error!("failed to refresh available models: {err}");
|
||||
}
|
||||
let remote_models = self.remote_models(config).await;
|
||||
@@ -135,7 +139,7 @@ impl ModelsManager {
|
||||
if let Some(model) = model.as_ref() {
|
||||
return model.to_string();
|
||||
}
|
||||
if let Err(err) = self.refresh_available_models(config).await {
|
||||
if let Err(err) = self.refresh_available_models_with_cache(config).await {
|
||||
error!("failed to refresh available models: {err}");
|
||||
}
|
||||
// if codex-auto-balanced exists & signed in with chatgpt mode, return it, otherwise return the default model
|
||||
@@ -153,6 +157,15 @@ impl ModelsManager {
|
||||
}
|
||||
OPENAI_DEFAULT_API_MODEL.to_string()
|
||||
}
|
||||
pub async fn refresh_if_new_etag(&self, etag: String) {
|
||||
let current_etag = self.get_etag().await;
|
||||
if current_etag.clone().is_some() && current_etag.as_deref() == Some(etag.as_str()) {
|
||||
return;
|
||||
}
|
||||
if let Err(err) = self.refresh_available_models_no_cache().await {
|
||||
error!("failed to refresh available models: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn get_model_offline(model: Option<&str>) -> String {
|
||||
@@ -165,6 +178,10 @@ impl ModelsManager {
|
||||
Self::find_family_for_model(model).with_config_overrides(config)
|
||||
}
|
||||
|
||||
async fn get_etag(&self) -> Option<String> {
|
||||
self.etag.read().await.clone()
|
||||
}
|
||||
|
||||
/// Replace the cached remote models and rebuild the derived presets list.
|
||||
async fn apply_remote_models(&self, models: Vec<ModelInfo>) {
|
||||
*self.remote_models.write().await = models;
|
||||
@@ -288,26 +305,14 @@ impl ModelsManager {
|
||||
|
||||
/// Convert a client version string to a whole version string (e.g. "1.2.3-alpha.4" -> "1.2.3")
|
||||
fn format_client_version_to_whole() -> String {
|
||||
format_client_version_from_parts(
|
||||
format!(
|
||||
"{}.{}.{}",
|
||||
env!("CARGO_PKG_VERSION_MAJOR"),
|
||||
env!("CARGO_PKG_VERSION_MINOR"),
|
||||
env!("CARGO_PKG_VERSION_PATCH"),
|
||||
env!("CARGO_PKG_VERSION_PATCH")
|
||||
)
|
||||
}
|
||||
|
||||
fn format_client_version_from_parts(major: &str, minor: &str, patch: &str) -> String {
|
||||
const DEV_VERSION: &str = "0.0.0";
|
||||
const FALLBACK_VERSION: &str = "99.99.99";
|
||||
|
||||
let normalized = format!("{major}.{minor}.{patch}");
|
||||
|
||||
if normalized == DEV_VERSION {
|
||||
FALLBACK_VERSION.to_string()
|
||||
} else {
|
||||
normalized
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::cache::ModelsCache;
|
||||
@@ -388,7 +393,6 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: remote_models.clone(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -406,7 +410,7 @@ mod tests {
|
||||
let manager = ModelsManager::with_provider(auth_manager, provider);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.refresh_available_models_with_cache(&config)
|
||||
.await
|
||||
.expect("refresh succeeds");
|
||||
let cached_remote = manager.remote_models(&config).await;
|
||||
@@ -445,7 +449,6 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: remote_models.clone(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -466,7 +469,7 @@ mod tests {
|
||||
let manager = ModelsManager::with_provider(auth_manager, provider);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.refresh_available_models_with_cache(&config)
|
||||
.await
|
||||
.expect("first refresh succeeds");
|
||||
assert_eq!(
|
||||
@@ -477,7 +480,7 @@ mod tests {
|
||||
|
||||
// Second call should read from cache and avoid the network.
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.refresh_available_models_with_cache(&config)
|
||||
.await
|
||||
.expect("cached refresh succeeds");
|
||||
assert_eq!(
|
||||
@@ -500,7 +503,6 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: initial_models.clone(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -521,7 +523,7 @@ mod tests {
|
||||
let manager = ModelsManager::with_provider(auth_manager, provider);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.refresh_available_models_with_cache(&config)
|
||||
.await
|
||||
.expect("initial refresh succeeds");
|
||||
|
||||
@@ -541,13 +543,12 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: updated_models.clone(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.refresh_available_models_with_cache(&config)
|
||||
.await
|
||||
.expect("second refresh succeeds");
|
||||
assert_eq!(
|
||||
@@ -575,7 +576,6 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: initial_models,
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -594,7 +594,7 @@ mod tests {
|
||||
manager.cache_ttl = Duration::ZERO;
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.refresh_available_models_with_cache(&config)
|
||||
.await
|
||||
.expect("initial refresh succeeds");
|
||||
|
||||
@@ -604,13 +604,12 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: refreshed_models,
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.refresh_available_models_with_cache(&config)
|
||||
.await
|
||||
.expect("second refresh succeeds");
|
||||
|
||||
|
||||
@@ -670,6 +670,25 @@ pub async fn mount_models_once(server: &MockServer, body: ModelsResponse) -> Mod
|
||||
models_mock
|
||||
}
|
||||
|
||||
pub async fn mount_models_once_with_etag(
|
||||
server: &MockServer,
|
||||
body: ModelsResponse,
|
||||
etag: &str,
|
||||
) -> ModelsMock {
|
||||
let (mock, models_mock) = models_mock();
|
||||
mock.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "application/json")
|
||||
// ModelsClient reads the ETag header, not a JSON field.
|
||||
.insert_header("ETag", etag)
|
||||
.set_body_json(body.clone()),
|
||||
)
|
||||
.up_to_n_times(1)
|
||||
.mount(server)
|
||||
.await;
|
||||
models_mock
|
||||
}
|
||||
|
||||
pub async fn start_mock_server() -> MockServer {
|
||||
let server = MockServer::builder()
|
||||
.body_print_limit(BodyPrintLimit::Limited(80_000))
|
||||
@@ -677,14 +696,7 @@ pub async fn start_mock_server() -> MockServer {
|
||||
.await;
|
||||
|
||||
// Provide a default `/models` response so tests remain hermetic when the client queries it.
|
||||
let _ = mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: Vec::new(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let _ = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await;
|
||||
|
||||
server
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ mod list_models;
|
||||
mod live_cli;
|
||||
mod model_overrides;
|
||||
mod model_tools;
|
||||
mod models_etag_responses;
|
||||
mod otel;
|
||||
mod prompt_caching;
|
||||
mod quota_exceeded;
|
||||
|
||||
139
codex-rs/core/tests/suite/models_etag_responses.rs
Normal file
139
codex-rs/core/tests/suite/models_etag_responses.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_local_shell_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::sse_response;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use wiremock::MockServer;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn refresh_models_on_models_etag_mismatch_and_avoid_duplicate_models_fetch() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
const ETAG_1: &str = "\"models-etag-1\"";
|
||||
const ETAG_2: &str = "\"models-etag-2\"";
|
||||
const CALL_ID: &str = "local-shell-call-1";
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// 1) On spawn, Codex fetches /models and stores the ETag.
|
||||
let spawn_models_mock = responses::mount_models_once_with_etag(
|
||||
&server,
|
||||
ModelsResponse { models: Vec::new() },
|
||||
ETAG_1,
|
||||
)
|
||||
.await;
|
||||
|
||||
let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing();
|
||||
let mut builder = test_codex()
|
||||
.with_auth(auth)
|
||||
.with_model("gpt-5")
|
||||
.with_config(|config| {
|
||||
config.features.enable(Feature::RemoteModels);
|
||||
// Keep this test deterministic: no request retries, and a small stream retry budget.
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
config.model_provider.stream_max_retries = Some(1);
|
||||
});
|
||||
|
||||
let test = builder.build(&server).await?;
|
||||
let codex = Arc::clone(&test.codex);
|
||||
let cwd = Arc::clone(&test.cwd);
|
||||
let session_model = test.session_configured.model.clone();
|
||||
|
||||
assert_eq!(spawn_models_mock.requests().len(), 1);
|
||||
assert_eq!(spawn_models_mock.single_request_path(), "/v1/models");
|
||||
|
||||
// 2) If the server sends a different X-Models-Etag on /responses, Codex refreshes /models.
|
||||
let refresh_models_mock = responses::mount_models_once_with_etag(
|
||||
&server,
|
||||
ModelsResponse { models: Vec::new() },
|
||||
ETAG_2,
|
||||
)
|
||||
.await;
|
||||
|
||||
// First /responses request (user message) succeeds and returns a tool call.
|
||||
// It also includes a mismatched X-Models-Etag, which should trigger a /models refresh.
|
||||
let first_response_body = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_local_shell_call(CALL_ID, "completed", vec!["/bin/echo", "etag ok"]),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_response_once(
|
||||
&server,
|
||||
sse_response(first_response_body).insert_header("X-Models-Etag", ETAG_2),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Second /responses request (tool output) includes the same X-Models-Etag; Codex should not
|
||||
// refetch /models again after it has already refreshed the catalog.
|
||||
let completion_response_body = sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
let tool_output_mock = responses::mount_response_once(
|
||||
&server,
|
||||
sse_response(completion_response_body).insert_header("X-Models-Etag", ETAG_2),
|
||||
)
|
||||
.await;
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "please run a tool".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let _ = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// Assert /models was refreshed exactly once after the X-Models-Etag mismatch.
|
||||
assert_eq!(refresh_models_mock.requests().len(), 1);
|
||||
assert_eq!(refresh_models_mock.single_request_path(), "/v1/models");
|
||||
let refresh_req = refresh_models_mock
|
||||
.requests()
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("one request");
|
||||
// Ensure Codex includes client_version on refresh. (This is a stable signal that we're using the /models client.)
|
||||
assert!(
|
||||
refresh_req
|
||||
.url
|
||||
.query_pairs()
|
||||
.any(|(k, _)| k == "client_version"),
|
||||
"expected /models refresh to include client_version query param"
|
||||
);
|
||||
|
||||
// Assert the tool output /responses request succeeded and did not trigger another /models fetch.
|
||||
let tool_req = tool_output_mock.single_request();
|
||||
let _ = tool_req.function_call_output(CALL_ID);
|
||||
assert_eq!(refresh_models_mock.requests().len(), 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -89,7 +89,6 @@ async fn remote_models_remote_model_uses_unified_exec() -> Result<()> {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -226,7 +225,6 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -304,7 +302,6 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model.clone()],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -324,7 +321,7 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
|
||||
);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.refresh_available_models_with_cache(&config)
|
||||
.await
|
||||
.expect("refresh succeeds");
|
||||
|
||||
@@ -362,7 +359,6 @@ async fn remote_models_hide_picker_only_models() -> Result<()> {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -511,6 +511,7 @@ impl OtelManager {
|
||||
"reasoning_summary_part_added".into()
|
||||
}
|
||||
ResponseEvent::RateLimits(_) => "rate_limits".into(),
|
||||
ResponseEvent::ModelsEtag(_) => "models_etag".into(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -187,8 +187,6 @@ pub struct ModelInfo {
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema, Default)]
|
||||
pub struct ModelsResponse {
|
||||
pub models: Vec<ModelInfo>,
|
||||
#[serde(default)]
|
||||
pub etag: String,
|
||||
}
|
||||
|
||||
// convert ModelInfo to ModelPreset
|
||||
|
||||
Reference in New Issue
Block a user