mirror of
https://github.com/openai/codex.git
synced 2026-02-02 06:57:03 +00:00
Compare commits
15 Commits
animations
...
tag
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
124c7fc2be | ||
|
|
0a1323747b | ||
|
|
348d379509 | ||
|
|
6912ba9fda | ||
|
|
27cec53ddc | ||
|
|
42273d94e8 | ||
|
|
1a5289a4ef | ||
|
|
359142f22f | ||
|
|
ecff4d4f72 | ||
|
|
985333feff | ||
|
|
e01610f762 | ||
|
|
09693d259b | ||
|
|
f8ba48d995 | ||
|
|
677532f97b | ||
|
|
beb83225e5 |
@@ -5,6 +5,7 @@ use crate::provider::Provider;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
@@ -41,7 +42,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
|
||||
&self,
|
||||
client_version: &str,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ModelsResponse, ApiError> {
|
||||
) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::GET, self.path());
|
||||
req.headers.extend(extra_headers.clone());
|
||||
@@ -66,7 +67,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToString::to_string);
|
||||
|
||||
let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
|
||||
let ModelsResponse { models } = serde_json::from_slice::<ModelsResponse>(&resp.body)
|
||||
.map_err(|e| {
|
||||
ApiError::Stream(format!(
|
||||
"failed to decode models response: {e}; body: {}",
|
||||
@@ -74,9 +75,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
|
||||
))
|
||||
})?;
|
||||
|
||||
let etag = header_etag.unwrap_or(etag);
|
||||
|
||||
Ok(ModelsResponse { models, etag })
|
||||
Ok((models, header_etag))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,16 +101,15 @@ mod tests {
|
||||
struct CapturingTransport {
|
||||
last_request: Arc<Mutex<Option<Request>>>,
|
||||
body: Arc<ModelsResponse>,
|
||||
response_etag: Arc<Option<String>>,
|
||||
}
|
||||
|
||||
impl Default for CapturingTransport {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(ModelsResponse {
|
||||
models: Vec::new(),
|
||||
etag: String::new(),
|
||||
}),
|
||||
body: Arc::new(ModelsResponse { models: Vec::new() }),
|
||||
response_etag: Arc::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -122,8 +120,8 @@ mod tests {
|
||||
*self.last_request.lock().unwrap() = Some(req);
|
||||
let body = serde_json::to_vec(&*self.body).unwrap();
|
||||
let mut headers = HeaderMap::new();
|
||||
if !self.body.etag.is_empty() {
|
||||
headers.insert(ETAG, self.body.etag.parse().unwrap());
|
||||
if let Some(etag) = self.response_etag.as_ref().as_deref() {
|
||||
headers.insert(ETAG, etag.parse().unwrap());
|
||||
}
|
||||
Ok(Response {
|
||||
status: StatusCode::OK,
|
||||
@@ -166,14 +164,12 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn appends_client_version_query() {
|
||||
let response = ModelsResponse {
|
||||
models: Vec::new(),
|
||||
etag: String::new(),
|
||||
};
|
||||
let response = ModelsResponse { models: Vec::new() };
|
||||
|
||||
let transport = CapturingTransport {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(response),
|
||||
response_etag: Arc::new(None),
|
||||
};
|
||||
|
||||
let client = ModelsClient::new(
|
||||
@@ -182,12 +178,12 @@ mod tests {
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let result = client
|
||||
let (models, _etag) = client
|
||||
.list_models("0.99.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 0);
|
||||
assert_eq!(models.len(), 0);
|
||||
|
||||
let url = transport
|
||||
.last_request
|
||||
@@ -232,12 +228,12 @@ mod tests {
|
||||
}))
|
||||
.unwrap(),
|
||||
],
|
||||
etag: String::new(),
|
||||
};
|
||||
|
||||
let transport = CapturingTransport {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(response),
|
||||
response_etag: Arc::new(None),
|
||||
};
|
||||
|
||||
let client = ModelsClient::new(
|
||||
@@ -246,27 +242,25 @@ mod tests {
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let result = client
|
||||
let (models, _etag) = client
|
||||
.list_models("0.99.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 1);
|
||||
assert_eq!(result.models[0].slug, "gpt-test");
|
||||
assert_eq!(result.models[0].supported_in_api, true);
|
||||
assert_eq!(result.models[0].priority, 1);
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].slug, "gpt-test");
|
||||
assert_eq!(models[0].supported_in_api, true);
|
||||
assert_eq!(models[0].priority, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_models_includes_etag() {
|
||||
let response = ModelsResponse {
|
||||
models: Vec::new(),
|
||||
etag: "\"abc\"".to_string(),
|
||||
};
|
||||
let response = ModelsResponse { models: Vec::new() };
|
||||
|
||||
let transport = CapturingTransport {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(response),
|
||||
response_etag: Arc::new(Some("\"abc\"".to_string())),
|
||||
};
|
||||
|
||||
let client = ModelsClient::new(
|
||||
@@ -275,12 +269,12 @@ mod tests {
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let result = client
|
||||
let (models, etag) = client
|
||||
.list_models("0.1.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 0);
|
||||
assert_eq!(result.etag, "\"abc\"");
|
||||
assert_eq!(models.len(), 0);
|
||||
assert_eq!(etag.as_deref(), Some("\"abc\""));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,7 +90,6 @@ async fn models_client_hits_models_endpoint() {
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
experimental_supported_tools: Vec::new(),
|
||||
}],
|
||||
etag: String::new(),
|
||||
};
|
||||
|
||||
Mock::given(method("GET"))
|
||||
@@ -106,13 +105,13 @@ async fn models_client_hits_models_endpoint() {
|
||||
let transport = ReqwestTransport::new(reqwest::Client::new());
|
||||
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);
|
||||
|
||||
let result = client
|
||||
let (models, _etag) = client
|
||||
.list_models("0.1.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("models request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 1);
|
||||
assert_eq!(result.models[0].slug, "gpt-test");
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].slug, "gpt-test");
|
||||
|
||||
let received = server
|
||||
.received_requests()
|
||||
|
||||
@@ -67,6 +67,11 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
|
||||
status,
|
||||
request_id: extract_request_id(headers.as_ref()),
|
||||
})
|
||||
} else if status == http::StatusCode::PRECONDITION_FAILED
|
||||
&& body_text
|
||||
.contains("Models catalog has changed. Please refresh your models list.")
|
||||
{
|
||||
CodexErr::OutdatedModels
|
||||
} else {
|
||||
CodexErr::UnexpectedStatus(UnexpectedResponseError {
|
||||
status,
|
||||
|
||||
@@ -33,6 +33,7 @@ use http::StatusCode as HttpStatusCode;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -53,11 +54,12 @@ use crate::openai_models::model_family::ModelFamily;
|
||||
use crate::tools::spec::create_tools_json_for_chat_completions_api;
|
||||
use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub struct ModelClient {
|
||||
config: Arc<Config>,
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
model_family: ModelFamily,
|
||||
model_family: RwLock<ModelFamily>,
|
||||
models_etag: RwLock<Option<String>>,
|
||||
otel_manager: OtelManager,
|
||||
provider: ModelProviderInfo,
|
||||
conversation_id: ConversationId,
|
||||
@@ -72,6 +74,7 @@ impl ModelClient {
|
||||
config: Arc<Config>,
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
model_family: ModelFamily,
|
||||
models_etag: Option<String>,
|
||||
otel_manager: OtelManager,
|
||||
provider: ModelProviderInfo,
|
||||
effort: Option<ReasoningEffortConfig>,
|
||||
@@ -82,7 +85,8 @@ impl ModelClient {
|
||||
Self {
|
||||
config,
|
||||
auth_manager,
|
||||
model_family,
|
||||
model_family: RwLock::new(model_family),
|
||||
models_etag: RwLock::new(models_etag),
|
||||
otel_manager,
|
||||
provider,
|
||||
conversation_id,
|
||||
@@ -92,8 +96,8 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_model_context_window(&self) -> Option<i64> {
|
||||
let model_family = self.get_model_family();
|
||||
pub async fn get_model_context_window(&self) -> Option<i64> {
|
||||
let model_family = self.get_model_family().await;
|
||||
let effective_context_window_percent = model_family.effective_context_window_percent;
|
||||
model_family
|
||||
.context_window
|
||||
@@ -146,7 +150,7 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let model_family = self.get_model_family();
|
||||
let model_family = self.get_model_family().await;
|
||||
let instructions = prompt.get_full_instructions(&model_family).into_owned();
|
||||
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
|
||||
let api_prompt = build_api_prompt(prompt, instructions, tools_json);
|
||||
@@ -167,7 +171,7 @@ impl ModelClient {
|
||||
|
||||
let stream_result = client
|
||||
.stream_prompt(
|
||||
&self.get_model(),
|
||||
&self.get_model().await,
|
||||
&api_prompt,
|
||||
Some(conversation_id.clone()),
|
||||
Some(session_source.clone()),
|
||||
@@ -200,7 +204,7 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let model_family = self.get_model_family();
|
||||
let model_family = self.get_model_family().await;
|
||||
let instructions = prompt.get_full_instructions(&model_family).into_owned();
|
||||
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
|
||||
@@ -262,11 +266,14 @@ impl ModelClient {
|
||||
store_override: None,
|
||||
conversation_id: Some(conversation_id.clone()),
|
||||
session_source: Some(session_source.clone()),
|
||||
extra_headers: beta_feature_headers(&self.config),
|
||||
extra_headers: beta_feature_headers(
|
||||
&self.config,
|
||||
self.get_models_etag().await.clone(),
|
||||
),
|
||||
};
|
||||
|
||||
let stream_result = client
|
||||
.stream_prompt(&self.get_model(), &api_prompt, options)
|
||||
.stream_prompt(&self.get_model().await, &api_prompt, options)
|
||||
.await;
|
||||
|
||||
match stream_result {
|
||||
@@ -297,13 +304,25 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
/// Returns the currently configured model slug.
|
||||
pub fn get_model(&self) -> String {
|
||||
self.get_model_family().get_model_slug().to_string()
|
||||
pub async fn get_model(&self) -> String {
|
||||
self.get_model_family().await.get_model_slug().to_string()
|
||||
}
|
||||
|
||||
/// Returns the currently configured model family.
|
||||
pub fn get_model_family(&self) -> ModelFamily {
|
||||
self.model_family.clone()
|
||||
pub async fn get_model_family(&self) -> ModelFamily {
|
||||
self.model_family.read().await.clone()
|
||||
}
|
||||
|
||||
pub async fn get_models_etag(&self) -> Option<String> {
|
||||
self.models_etag.read().await.clone()
|
||||
}
|
||||
|
||||
pub async fn update_models_etag(&self, etag: Option<String>) {
|
||||
*self.models_etag.write().await = etag;
|
||||
}
|
||||
|
||||
pub async fn update_model_family(&self, model_family: ModelFamily) {
|
||||
*self.model_family.write().await = model_family;
|
||||
}
|
||||
|
||||
/// Returns the current reasoning effort setting.
|
||||
@@ -340,10 +359,10 @@ impl ModelClient {
|
||||
.with_telemetry(Some(request_telemetry));
|
||||
|
||||
let instructions = prompt
|
||||
.get_full_instructions(&self.get_model_family())
|
||||
.get_full_instructions(&self.get_model_family().await)
|
||||
.into_owned();
|
||||
let payload = ApiCompactionInput {
|
||||
model: &self.get_model(),
|
||||
model: &self.get_model().await,
|
||||
input: &prompt.input,
|
||||
instructions: &instructions,
|
||||
};
|
||||
@@ -398,7 +417,7 @@ fn build_api_prompt(prompt: &Prompt, instructions: String, tools_json: Vec<Value
|
||||
}
|
||||
}
|
||||
|
||||
fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
|
||||
fn beta_feature_headers(config: &Config, models_etag: Option<String>) -> ApiHeaderMap {
|
||||
let enabled = FEATURES
|
||||
.iter()
|
||||
.filter_map(|spec| {
|
||||
@@ -416,6 +435,11 @@ fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
|
||||
{
|
||||
headers.insert("x-codex-beta-features", header_value);
|
||||
}
|
||||
if let Some(etag) = models_etag
|
||||
&& let Ok(header_value) = HeaderValue::from_str(&etag)
|
||||
{
|
||||
headers.insert("X-If-Models-Match", header_value);
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use crate::client_common::tools::ToolSpec;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::error::Result;
|
||||
use crate::features::Feature;
|
||||
use crate::openai_models::model_family::ModelFamily;
|
||||
use crate::tools::ToolRouter;
|
||||
pub use codex_api::common::ResponseEvent;
|
||||
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -44,6 +48,28 @@ pub struct Prompt {
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) async fn new(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
router: &ToolRouter,
|
||||
input: &[ResponseItem],
|
||||
) -> Prompt {
|
||||
let model_supports_parallel = turn_context
|
||||
.client
|
||||
.get_model_family()
|
||||
.await
|
||||
.supports_parallel_tool_calls;
|
||||
|
||||
Prompt {
|
||||
input: input.to_vec(),
|
||||
tools: router.specs(),
|
||||
parallel_tool_calls: model_supports_parallel
|
||||
&& sess.enabled(Feature::ParallelToolCalls),
|
||||
base_instructions_override: turn_context.base_instructions.clone(),
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
|
||||
let base = self
|
||||
.base_instructions_override
|
||||
|
||||
@@ -249,7 +249,7 @@ impl Codex {
|
||||
|
||||
let config = Arc::new(config);
|
||||
if config.features.enabled(Feature::RemoteModels)
|
||||
&& let Err(err) = models_manager.refresh_available_models(&config).await
|
||||
&& let Err(err) = models_manager.try_refresh_available_models(&config).await
|
||||
{
|
||||
error!("failed to refresh available models: {err:?}");
|
||||
}
|
||||
@@ -492,6 +492,7 @@ impl Session {
|
||||
session_configuration: &SessionConfiguration,
|
||||
per_turn_config: Config,
|
||||
model_family: ModelFamily,
|
||||
models_etag: Option<String>,
|
||||
conversation_id: ConversationId,
|
||||
sub_id: String,
|
||||
) -> TurnContext {
|
||||
@@ -505,6 +506,7 @@ impl Session {
|
||||
per_turn_config.clone(),
|
||||
auth_manager,
|
||||
model_family.clone(),
|
||||
models_etag,
|
||||
otel_manager,
|
||||
provider,
|
||||
session_configuration.model_reasoning_effort,
|
||||
@@ -788,7 +790,7 @@ impl Session {
|
||||
}
|
||||
})
|
||||
{
|
||||
let curr = turn_context.client.get_model();
|
||||
let curr = turn_context.client.get_model().await;
|
||||
if prev != curr {
|
||||
warn!(
|
||||
"resuming session with different model: previous={prev}, current={curr}"
|
||||
@@ -919,6 +921,7 @@ impl Session {
|
||||
.models_manager
|
||||
.construct_model_family(session_configuration.model.as_str(), &per_turn_config)
|
||||
.await;
|
||||
let models_etag = self.services.models_manager.get_models_etag().await;
|
||||
let mut turn_context: TurnContext = Self::make_turn_context(
|
||||
Some(Arc::clone(&self.services.auth_manager)),
|
||||
&self.services.otel_manager,
|
||||
@@ -926,6 +929,7 @@ impl Session {
|
||||
&session_configuration,
|
||||
per_turn_config,
|
||||
model_family,
|
||||
models_etag,
|
||||
self.conversation_id,
|
||||
sub_id,
|
||||
);
|
||||
@@ -1334,7 +1338,7 @@ impl Session {
|
||||
if let Some(token_usage) = token_usage {
|
||||
state.update_token_info_from_usage(
|
||||
token_usage,
|
||||
turn_context.client.get_model_context_window(),
|
||||
turn_context.client.get_model_context_window().await,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1346,6 +1350,7 @@ impl Session {
|
||||
.clone_history()
|
||||
.await
|
||||
.estimate_token_count(turn_context)
|
||||
.await
|
||||
else {
|
||||
return;
|
||||
};
|
||||
@@ -1366,7 +1371,7 @@ impl Session {
|
||||
};
|
||||
|
||||
if info.model_context_window.is_none() {
|
||||
info.model_context_window = turn_context.client.get_model_context_window();
|
||||
info.model_context_window = turn_context.client.get_model_context_window().await;
|
||||
}
|
||||
|
||||
state.set_token_info(Some(info));
|
||||
@@ -1396,7 +1401,7 @@ impl Session {
|
||||
}
|
||||
|
||||
pub(crate) async fn set_total_tokens_full(&self, turn_context: &TurnContext) {
|
||||
let context_window = turn_context.client.get_model_context_window();
|
||||
let context_window = turn_context.client.get_model_context_window().await;
|
||||
if let Some(context_window) = context_window {
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
@@ -2105,6 +2110,7 @@ async fn spawn_review_thread(
|
||||
.models_manager
|
||||
.construct_model_family(&model, &config)
|
||||
.await;
|
||||
let models_etag = sess.services.models_manager.get_models_etag().await;
|
||||
// For reviews, disable web_search and view_image regardless of global settings.
|
||||
let mut review_features = sess.features.clone();
|
||||
review_features
|
||||
@@ -2137,6 +2143,7 @@ async fn spawn_review_thread(
|
||||
per_turn_config.clone(),
|
||||
auth_manager,
|
||||
model_family.clone(),
|
||||
models_etag,
|
||||
otel_manager,
|
||||
provider,
|
||||
per_turn_config.model_reasoning_effort,
|
||||
@@ -2231,6 +2238,7 @@ pub(crate) async fn run_task(
|
||||
let auto_compact_limit = turn_context
|
||||
.client
|
||||
.get_model_family()
|
||||
.await
|
||||
.auto_compact_token_limit()
|
||||
.unwrap_or(i64::MAX);
|
||||
let total_usage_tokens = sess.get_total_token_usage().await;
|
||||
@@ -2238,7 +2246,7 @@ pub(crate) async fn run_task(
|
||||
run_auto_compact(&sess, &turn_context).await;
|
||||
}
|
||||
let event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
model_context_window: turn_context.client.get_model_context_window().await,
|
||||
});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
|
||||
@@ -2303,7 +2311,7 @@ pub(crate) async fn run_task(
|
||||
.collect::<Vec<String>>();
|
||||
match run_turn(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
&turn_context,
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
turn_input,
|
||||
cancellation_token.child_token(),
|
||||
@@ -2362,6 +2370,36 @@ pub(crate) async fn run_task(
|
||||
last_agent_message
|
||||
}
|
||||
|
||||
pub(crate) async fn refresh_models_and_reset_turn_context(
|
||||
sess: &Arc<Session>,
|
||||
turn_context: &Arc<TurnContext>,
|
||||
) {
|
||||
let config = {
|
||||
let state = sess.state.lock().await;
|
||||
state
|
||||
.session_configuration
|
||||
.original_config_do_not_use
|
||||
.clone()
|
||||
};
|
||||
if let Err(err) = sess
|
||||
.services
|
||||
.models_manager
|
||||
.refresh_available_models(&config)
|
||||
.await
|
||||
{
|
||||
error!("failed to refresh models after outdated models error: {err}");
|
||||
}
|
||||
let model = turn_context.client.get_model().await;
|
||||
let model_family = sess
|
||||
.services
|
||||
.models_manager
|
||||
.construct_model_family(&model, &config)
|
||||
.await;
|
||||
let models_etag = sess.services.models_manager.get_models_etag().await;
|
||||
turn_context.client.update_model_family(model_family).await;
|
||||
turn_context.client.update_models_etag(models_etag).await;
|
||||
}
|
||||
|
||||
async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>) {
|
||||
if should_use_remote_compact_task(sess.as_ref(), &turn_context.client.get_provider()) {
|
||||
run_inline_remote_auto_compact_task(Arc::clone(sess), Arc::clone(turn_context)).await;
|
||||
@@ -2374,17 +2412,19 @@ async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>)
|
||||
skip_all,
|
||||
fields(
|
||||
turn_id = %turn_context.sub_id,
|
||||
model = %turn_context.client.get_model(),
|
||||
model = tracing::field::Empty,
|
||||
cwd = %turn_context.cwd.display()
|
||||
)
|
||||
)]
|
||||
async fn run_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_context: &Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
input: Vec<ResponseItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
let model = turn_context.client.get_model().await;
|
||||
tracing::Span::current().record("model", field::display(&model));
|
||||
let mcp_tools = sess
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
@@ -2393,37 +2433,32 @@ async fn run_turn(
|
||||
.list_all_tools()
|
||||
.or_cancel(&cancellation_token)
|
||||
.await?;
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(
|
||||
mcp_tools
|
||||
.into_iter()
|
||||
.map(|(name, tool)| (name, tool.tool))
|
||||
.collect(),
|
||||
),
|
||||
));
|
||||
|
||||
let model_supports_parallel = turn_context
|
||||
.client
|
||||
.get_model_family()
|
||||
.supports_parallel_tool_calls;
|
||||
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
tools: router.specs(),
|
||||
parallel_tool_calls: model_supports_parallel && sess.enabled(Feature::ParallelToolCalls),
|
||||
base_instructions_override: turn_context.base_instructions.clone(),
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
};
|
||||
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(
|
||||
mcp_tools
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(name, tool)| (name, tool.tool))
|
||||
.collect(),
|
||||
),
|
||||
));
|
||||
let prompt = Prompt::new(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
router.as_ref(),
|
||||
&input,
|
||||
);
|
||||
|
||||
match try_run_turn(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&prompt,
|
||||
&prompt.await,
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
.await
|
||||
@@ -2437,13 +2472,13 @@ async fn run_turn(
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||
sess.set_total_tokens_full(&turn_context).await;
|
||||
sess.set_total_tokens_full(turn_context).await;
|
||||
return Err(e);
|
||||
}
|
||||
Err(CodexErr::UsageLimitReached(e)) => {
|
||||
let rate_limits = e.rate_limits.clone();
|
||||
if let Some(rate_limits) = rate_limits {
|
||||
sess.update_rate_limits(&turn_context, rate_limits).await;
|
||||
sess.update_rate_limits(turn_context, rate_limits).await;
|
||||
}
|
||||
return Err(CodexErr::UsageLimitReached(e));
|
||||
}
|
||||
@@ -2457,6 +2492,11 @@ async fn run_turn(
|
||||
let max_retries = turn_context.client.get_provider().stream_max_retries();
|
||||
if retries < max_retries {
|
||||
retries += 1;
|
||||
// Refresh models if we got an outdated models error
|
||||
if matches!(e, CodexErr::OutdatedModels) {
|
||||
refresh_models_and_reset_turn_context(&sess, turn_context).await;
|
||||
continue;
|
||||
}
|
||||
let delay = match e {
|
||||
CodexErr::Stream(_, Some(delay)) => delay,
|
||||
_ => backoff(retries),
|
||||
@@ -2469,7 +2509,7 @@ async fn run_turn(
|
||||
// user understands what is happening instead of staring
|
||||
// at a seemingly frozen screen.
|
||||
sess.notify_stream_error(
|
||||
&turn_context,
|
||||
turn_context,
|
||||
format!("Reconnecting... {retries}/{max_retries}"),
|
||||
e,
|
||||
)
|
||||
@@ -2514,7 +2554,7 @@ async fn drain_in_flight(
|
||||
skip_all,
|
||||
fields(
|
||||
turn_id = %turn_context.sub_id,
|
||||
model = %turn_context.client.get_model()
|
||||
model = tracing::field::Empty,
|
||||
)
|
||||
)]
|
||||
async fn try_run_turn(
|
||||
@@ -2525,11 +2565,13 @@ async fn try_run_turn(
|
||||
prompt: &Prompt,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
let model = turn_context.client.get_model().await;
|
||||
tracing::Span::current().record("model", field::display(&model));
|
||||
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: turn_context.client.get_model(),
|
||||
model,
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
@@ -2537,7 +2579,6 @@ async fn try_run_turn(
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context
|
||||
.client
|
||||
.clone()
|
||||
.stream(prompt)
|
||||
.instrument(trace_span!("stream_request"))
|
||||
.or_cancel(&cancellation_token)
|
||||
@@ -3163,6 +3204,7 @@ mod tests {
|
||||
&session_configuration,
|
||||
per_turn_config,
|
||||
model_family,
|
||||
None,
|
||||
conversation_id,
|
||||
"turn_id".to_string(),
|
||||
);
|
||||
@@ -3249,6 +3291,7 @@ mod tests {
|
||||
&session_configuration,
|
||||
per_turn_config,
|
||||
model_family,
|
||||
None,
|
||||
conversation_id,
|
||||
"turn_id".to_string(),
|
||||
));
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::client_common::ResponseEvent;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::get_last_assistant_message_from_turn;
|
||||
use crate::codex::refresh_models_and_reset_turn_context;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::features::Feature;
|
||||
@@ -55,7 +56,7 @@ pub(crate) async fn run_compact_task(
|
||||
input: Vec<UserInput>,
|
||||
) {
|
||||
let start_event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
model_context_window: turn_context.client.get_model_context_window().await,
|
||||
});
|
||||
sess.send_event(&turn_context, start_event).await;
|
||||
run_compact_task_inner(sess.clone(), turn_context, input).await;
|
||||
@@ -83,7 +84,7 @@ async fn run_compact_task_inner(
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: turn_context.client.get_model(),
|
||||
model: turn_context.client.get_model().await,
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
@@ -132,6 +133,10 @@ async fn run_compact_task_inner(
|
||||
Err(e) => {
|
||||
if retries < max_retries {
|
||||
retries += 1;
|
||||
if matches!(e, CodexErr::OutdatedModels) {
|
||||
refresh_models_and_reset_turn_context(&sess, &turn_context).await;
|
||||
continue;
|
||||
}
|
||||
let delay = backoff(retries);
|
||||
sess.notify_stream_error(
|
||||
turn_context.as_ref(),
|
||||
@@ -290,7 +295,7 @@ async fn drain_to_completed(
|
||||
turn_context: &TurnContext,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let mut stream = turn_context.client.clone().stream(prompt).await?;
|
||||
let mut stream = turn_context.client.stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
let Some(event) = maybe_event else {
|
||||
|
||||
@@ -20,7 +20,7 @@ pub(crate) async fn run_inline_remote_auto_compact_task(
|
||||
|
||||
pub(crate) async fn run_remote_compact_task(sess: Arc<Session>, turn_context: Arc<TurnContext>) {
|
||||
let start_event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
model_context_window: turn_context.client.get_model_context_window().await,
|
||||
});
|
||||
sess.send_event(&turn_context, start_event).await;
|
||||
|
||||
|
||||
@@ -79,8 +79,8 @@ impl ContextManager {
|
||||
|
||||
// Estimate token usage using byte-based heuristics from the truncation helpers.
|
||||
// This is a coarse lower bound, not a tokenizer-accurate count.
|
||||
pub(crate) fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
|
||||
let model_family = turn_context.client.get_model_family();
|
||||
pub(crate) async fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
|
||||
let model_family = turn_context.client.get_model_family().await;
|
||||
let base_tokens =
|
||||
i64::try_from(approx_token_count(model_family.base_instructions.as_str()))
|
||||
.unwrap_or(i64::MAX);
|
||||
|
||||
@@ -90,6 +90,10 @@ pub enum CodexErr {
|
||||
#[error("spawn failed: child stdout/stderr not captured")]
|
||||
Spawn,
|
||||
|
||||
/// Returned when the models list is outdated and needs to be refreshed.
|
||||
#[error("remote models list is outdated")]
|
||||
OutdatedModels,
|
||||
|
||||
/// Returned by run_command_stream when the user pressed Ctrl‑C (SIGINT). Session uses this to
|
||||
/// surface a polite FunctionCallOutput back to the model instead of crashing the CLI.
|
||||
#[error("interrupted (Ctrl-C). Something went wrong? Hit `/feedback` to report the issue.")]
|
||||
|
||||
@@ -77,7 +77,7 @@ impl ModelsManager {
|
||||
}
|
||||
|
||||
/// Fetch the latest remote models, using the on-disk cache when still fresh.
|
||||
pub async fn refresh_available_models(&self, config: &Config) -> CoreResult<()> {
|
||||
pub async fn try_refresh_available_models(&self, config: &Config) -> CoreResult<()> {
|
||||
if !config.features.enabled(Feature::RemoteModels)
|
||||
|| self.auth_manager.get_auth_mode() == Some(AuthMode::ApiKey)
|
||||
{
|
||||
@@ -86,7 +86,15 @@ impl ModelsManager {
|
||||
if self.try_load_cache().await {
|
||||
return Ok(());
|
||||
}
|
||||
self.refresh_available_models(config).await
|
||||
}
|
||||
|
||||
pub async fn refresh_available_models(&self, config: &Config) -> CoreResult<()> {
|
||||
if !config.features.enabled(Feature::RemoteModels)
|
||||
|| self.auth_manager.get_auth_mode() == Some(AuthMode::ApiKey)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
let auth = self.auth_manager.auth();
|
||||
let api_provider = self.provider.to_api_provider(Some(AuthMode::ChatGPT))?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?;
|
||||
@@ -94,12 +102,12 @@ impl ModelsManager {
|
||||
let client = ModelsClient::new(transport, api_provider, api_auth);
|
||||
|
||||
let client_version = format_client_version_to_whole();
|
||||
let ModelsResponse { models, etag } = client
|
||||
let (models, etag) = client
|
||||
.list_models(&client_version, HeaderMap::new())
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
|
||||
let etag = (!etag.is_empty()).then_some(etag);
|
||||
let etag = etag.filter(|value| !value.is_empty());
|
||||
|
||||
self.apply_remote_models(models.clone()).await;
|
||||
*self.etag.write().await = etag.clone();
|
||||
@@ -108,7 +116,7 @@ impl ModelsManager {
|
||||
}
|
||||
|
||||
pub async fn list_models(&self, config: &Config) -> Vec<ModelPreset> {
|
||||
if let Err(err) = self.refresh_available_models(config).await {
|
||||
if let Err(err) = self.try_refresh_available_models(config).await {
|
||||
error!("failed to refresh available models: {err}");
|
||||
}
|
||||
let remote_models = self.remote_models(config).await;
|
||||
@@ -131,11 +139,15 @@ impl ModelsManager {
|
||||
.with_config_overrides(config)
|
||||
}
|
||||
|
||||
pub async fn get_models_etag(&self) -> Option<String> {
|
||||
self.etag.read().await.clone()
|
||||
}
|
||||
|
||||
pub async fn get_model(&self, model: &Option<String>, config: &Config) -> String {
|
||||
if let Some(model) = model.as_ref() {
|
||||
return model.to_string();
|
||||
}
|
||||
if let Err(err) = self.refresh_available_models(config).await {
|
||||
if let Err(err) = self.try_refresh_available_models(config).await {
|
||||
error!("failed to refresh available models: {err}");
|
||||
}
|
||||
// if codex-auto-balanced exists & signed in with chatgpt mode, return it, otherwise return the default model
|
||||
@@ -389,7 +401,6 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: remote_models.clone(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -407,7 +418,7 @@ mod tests {
|
||||
let manager = ModelsManager::with_provider(auth_manager, provider);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.try_refresh_available_models(&config)
|
||||
.await
|
||||
.expect("refresh succeeds");
|
||||
let cached_remote = manager.remote_models(&config).await;
|
||||
@@ -446,7 +457,6 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: remote_models.clone(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -467,7 +477,7 @@ mod tests {
|
||||
let manager = ModelsManager::with_provider(auth_manager, provider);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.try_refresh_available_models(&config)
|
||||
.await
|
||||
.expect("first refresh succeeds");
|
||||
assert_eq!(
|
||||
@@ -478,7 +488,7 @@ mod tests {
|
||||
|
||||
// Second call should read from cache and avoid the network.
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.try_refresh_available_models(&config)
|
||||
.await
|
||||
.expect("cached refresh succeeds");
|
||||
assert_eq!(
|
||||
@@ -501,7 +511,6 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: initial_models.clone(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -522,7 +531,7 @@ mod tests {
|
||||
let manager = ModelsManager::with_provider(auth_manager, provider);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.try_refresh_available_models(&config)
|
||||
.await
|
||||
.expect("initial refresh succeeds");
|
||||
|
||||
@@ -542,13 +551,12 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: updated_models.clone(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.try_refresh_available_models(&config)
|
||||
.await
|
||||
.expect("second refresh succeeds");
|
||||
assert_eq!(
|
||||
@@ -576,7 +584,6 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: initial_models,
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -595,7 +602,7 @@ mod tests {
|
||||
manager.cache_ttl = Duration::ZERO;
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.try_refresh_available_models(&config)
|
||||
.await
|
||||
.expect("initial refresh succeeds");
|
||||
|
||||
@@ -605,13 +612,12 @@ mod tests {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: refreshed_models,
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.try_refresh_available_models(&config)
|
||||
.await
|
||||
.expect("second refresh succeeds");
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ impl SessionTask for UserShellCommandTask {
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
let event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
model_context_window: turn_context.client.get_model_context_window().await,
|
||||
});
|
||||
let session = session.clone_session();
|
||||
session.send_event(turn_context.as_ref(), event).await;
|
||||
|
||||
@@ -92,6 +92,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_family,
|
||||
None,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
|
||||
@@ -93,6 +93,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_family,
|
||||
None,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
|
||||
@@ -670,6 +670,24 @@ pub async fn mount_models_once(server: &MockServer, body: ModelsResponse) -> Mod
|
||||
models_mock
|
||||
}
|
||||
|
||||
pub async fn mount_models_once_with_etag(
|
||||
server: &MockServer,
|
||||
body: ModelsResponse,
|
||||
etag: &str,
|
||||
) -> ModelsMock {
|
||||
let (mock, models_mock) = models_mock();
|
||||
mock.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "application/json")
|
||||
.insert_header("etag", etag)
|
||||
.set_body_json(body.clone()),
|
||||
)
|
||||
.up_to_n_times(1)
|
||||
.mount(server)
|
||||
.await;
|
||||
models_mock
|
||||
}
|
||||
|
||||
pub async fn start_mock_server() -> MockServer {
|
||||
let server = MockServer::builder()
|
||||
.body_print_limit(BodyPrintLimit::Limited(80_000))
|
||||
@@ -677,14 +695,7 @@ pub async fn start_mock_server() -> MockServer {
|
||||
.await;
|
||||
|
||||
// Provide a default `/models` response so tests remain hermetic when the client queries it.
|
||||
let _ = mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: Vec::new(),
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let _ = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await;
|
||||
|
||||
server
|
||||
}
|
||||
|
||||
@@ -86,6 +86,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_family,
|
||||
None,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
@@ -181,6 +182,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_family,
|
||||
None,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
@@ -275,6 +277,7 @@ async fn responses_respects_model_family_overrides_from_config() {
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_family,
|
||||
None,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
|
||||
@@ -1146,6 +1146,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_family,
|
||||
None,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
|
||||
@@ -33,8 +33,12 @@ use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::ev_shell_command_call;
|
||||
use core_test_support::responses::mount_models_once;
|
||||
use core_test_support::responses::mount_models_once_with_etag;
|
||||
use core_test_support::responses::mount_response_once_match;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_once_match;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::skip_if_no_network;
|
||||
@@ -42,6 +46,7 @@ use core_test_support::skip_if_sandbox;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_match;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::Duration;
|
||||
@@ -49,9 +54,92 @@ use tokio::time::Instant;
|
||||
use tokio::time::sleep;
|
||||
use wiremock::BodyPrintLimit;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
|
||||
const REMOTE_MODEL_SLUG: &str = "codex-test";
|
||||
|
||||
/// Criteria used to match an incoming request in tests.
///
/// Every field starts as `None` (via `Default`); a `None` field places no
/// constraint on the request. Fields are set through the `with_*` builders.
#[derive(Clone, Default)]
struct ResponsesMatch {
    /// Expected value of the `X-If-Models-Match` request header.
    etag: Option<String>,
    /// Expected text of a user `input_text` span in the request input.
    user_text: Option<String>,
    /// Expected `call_id` of a `function_call_output` item in the request input.
    call_id: Option<String>,
}

impl ResponsesMatch {
    /// Requires the request to carry `etag` as its models ETag header.
    #[must_use]
    fn with_etag(mut self, etag: &str) -> Self {
        self.etag = Some(etag.to_string());
        self
    }

    /// Requires the request input to contain a user message with exactly `text`.
    #[must_use]
    fn with_user_text(mut self, text: &str) -> Self {
        self.user_text = Some(text.to_string());
        self
    }

    /// Requires the request input to contain a function-call output for `call_id`.
    #[must_use]
    fn with_function_call_output(mut self, call_id: &str) -> Self {
        self.call_id = Some(call_id.to_string());
        self
    }
}
|
||||
|
||||
impl wiremock::Match for ResponsesMatch {
|
||||
fn matches(&self, request: &wiremock::Request) -> bool {
|
||||
if let Some(expected_etag) = &self.etag {
|
||||
let header = request
|
||||
.headers
|
||||
.get("X-If-Models-Match")
|
||||
.and_then(|value| value.to_str().ok());
|
||||
if header != Some(expected_etag.as_str()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
let Ok(body): Result<Value, _> = request.body_json() else {
|
||||
return false;
|
||||
};
|
||||
let Some(items) = body.get("input").and_then(Value::as_array) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if let Some(expected_text) = &self.user_text
|
||||
&& !input_has_user_text(items, expected_text)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(expected_call_id) = &self.call_id
|
||||
&& !input_has_function_call_output(items, expected_call_id)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn input_has_user_text(items: &[Value], expected: &str) -> bool {
|
||||
items.iter().any(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("message")
|
||||
&& item.get("role").and_then(Value::as_str) == Some("user")
|
||||
&& item
|
||||
.get("content")
|
||||
.and_then(Value::as_array)
|
||||
.is_some_and(|content| {
|
||||
content.iter().any(|span| {
|
||||
span.get("type").and_then(Value::as_str) == Some("input_text")
|
||||
&& span.get("text").and_then(Value::as_str) == Some(expected)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn input_has_function_call_output(items: &[Value], call_id: &str) -> bool {
|
||||
items.iter().any(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn remote_models_remote_model_uses_unified_exec() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -93,7 +181,6 @@ async fn remote_models_remote_model_uses_unified_exec() -> Result<()> {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -232,7 +319,6 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -299,6 +385,208 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Exercises the remote-models retry flow:
|
||||
/// 1) initial `/models` fetch stores an ETag,
|
||||
/// 2) `/responses` uses that ETag for a tool call,
|
||||
/// 3) the tool-output turn receives a 412 (stale models),
|
||||
/// 4) Codex refreshes `/models` to get a new ETag and retries,
|
||||
/// 5) subsequent user turns keep sending the refreshed ETag.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn remote_models_refresh_etag_after_outdated_models() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let server = MockServer::builder()
|
||||
.body_print_limit(BodyPrintLimit::Limited(80_000))
|
||||
.start()
|
||||
.await;
|
||||
|
||||
let remote_model = test_remote_model("remote-etag", ModelVisibility::List, 1);
|
||||
let initial_etag = "models-etag-initial";
|
||||
let refreshed_etag = "models-etag-refreshed";
|
||||
|
||||
// Phase 1a: seed the initial `/models` response with an ETag.
|
||||
let models_mock = mount_models_once_with_etag(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model.clone()],
|
||||
},
|
||||
initial_etag,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Phase 1b: boot a Codex session configured for remote models.
|
||||
let harness = build_remote_models_harness(&server, |config| {
|
||||
config.features.enable(Feature::RemoteModels);
|
||||
config.model = Some("gpt-5.1".to_string());
|
||||
})
|
||||
.await?;
|
||||
|
||||
let RemoteModelsHarness {
|
||||
codex,
|
||||
cwd,
|
||||
config,
|
||||
conversation_manager,
|
||||
..
|
||||
} = harness;
|
||||
|
||||
let models_manager = conversation_manager.get_models_manager();
|
||||
wait_for_model_available(&models_manager, "remote-etag", &config).await;
|
||||
|
||||
// Phase 1c: confirm the ETag is stored and `/models` was called.
|
||||
assert_eq!(
|
||||
models_manager.get_models_etag().await.as_deref(),
|
||||
Some(initial_etag),
|
||||
);
|
||||
assert_eq!(
|
||||
models_mock.requests().len(),
|
||||
1,
|
||||
"expected an initial /models request",
|
||||
);
|
||||
assert_eq!(models_mock.requests()[0].url.path(), "/v1/models");
|
||||
|
||||
// Phase 2a: reset mocks so the next `/models` call must be explicit.
|
||||
server.reset().await;
|
||||
// Phase 2b: mount a refreshed `/models` response with a new ETag.
|
||||
let refreshed_models_mock = mount_models_once_with_etag(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model],
|
||||
},
|
||||
refreshed_etag,
|
||||
)
|
||||
.await;
|
||||
|
||||
let call_id = "shell-command-call";
|
||||
let first_prompt = "run a shell command";
|
||||
let followup_prompt = "send another message";
|
||||
|
||||
// Phase 2c: first `/responses` turn uses the initial ETag and emits a tool call.
|
||||
let first_response = mount_sse_once_match(
|
||||
&server,
|
||||
ResponsesMatch::default()
|
||||
.with_etag(initial_etag)
|
||||
.with_user_text(first_prompt),
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_shell_command_call(call_id, "echo refreshed"),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Phase 2d: the tool-output follow-up returns 412 (stale models).
|
||||
let stale_response = mount_response_once_match(
|
||||
&server,
|
||||
ResponsesMatch::default()
|
||||
.with_etag(initial_etag)
|
||||
.with_function_call_output(call_id),
|
||||
ResponseTemplate::new(412)
|
||||
.set_body_string("Models catalog has changed. Please refresh your models list."),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Phase 2e: retry tool-output follow-up should use the refreshed ETag.
|
||||
let refreshed_response = mount_sse_once_match(
|
||||
&server,
|
||||
ResponsesMatch::default()
|
||||
.with_etag(refreshed_etag)
|
||||
.with_function_call_output(call_id),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Phase 3a: next user turn should also use the refreshed ETag.
|
||||
let next_turn_response = mount_sse_once_match(
|
||||
&server,
|
||||
ResponsesMatch::default()
|
||||
.with_etag(refreshed_etag)
|
||||
.with_user_text(followup_prompt),
|
||||
sse(vec![
|
||||
ev_response_created("resp-3"),
|
||||
ev_assistant_message("msg-2", "ok"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Phase 3b: run the first user turn and let retries complete.
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: first_prompt.into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: "gpt-5.1".to_string(),
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// Phase 3c: assert the refresh happened and the ETag was updated.
|
||||
assert_eq!(
|
||||
refreshed_models_mock.requests().len(),
|
||||
1,
|
||||
"expected a refreshed /models request",
|
||||
);
|
||||
assert_eq!(
|
||||
models_manager.get_models_etag().await.as_deref(),
|
||||
Some(refreshed_etag),
|
||||
);
|
||||
|
||||
// Phase 3d: assert the ETag header progression across the retry sequence.
|
||||
assert_eq!(
|
||||
first_response.single_request().header("X-If-Models-Match"),
|
||||
Some(initial_etag.to_string()),
|
||||
);
|
||||
assert_eq!(
|
||||
stale_response.single_request().header("X-If-Models-Match"),
|
||||
Some(initial_etag.to_string()),
|
||||
);
|
||||
assert_eq!(
|
||||
refreshed_response
|
||||
.single_request()
|
||||
.header("X-If-Models-Match"),
|
||||
Some(refreshed_etag.to_string()),
|
||||
);
|
||||
|
||||
// Phase 3e: execute a new user turn and ensure the refreshed ETag persists.
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: followup_prompt.into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: "gpt-5.1".to_string(),
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
assert_eq!(
|
||||
next_turn_response
|
||||
.single_request()
|
||||
.header("X-If-Models-Match"),
|
||||
Some(refreshed_etag.to_string()),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn remote_models_preserve_builtin_presets() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -310,7 +598,6 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model.clone()],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -330,7 +617,7 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
|
||||
);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.try_refresh_available_models(&config)
|
||||
.await
|
||||
.expect("refresh succeeds");
|
||||
|
||||
@@ -368,7 +655,6 @@ async fn remote_models_hide_picker_only_models() -> Result<()> {
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -197,8 +197,6 @@ pub struct ModelInfo {
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema, Default)]
|
||||
pub struct ModelsResponse {
|
||||
pub models: Vec<ModelInfo>,
|
||||
#[serde(default)]
|
||||
pub etag: String,
|
||||
}
|
||||
|
||||
// convert ModelInfo to ModelPreset
|
||||
|
||||
Reference in New Issue
Block a user