Compare commits

...

15 Commits

Author SHA1 Message Date
Ahmed Ibrahim
124c7fc2be compact 2025-12-18 22:56:49 -08:00
Ahmed Ibrahim
0a1323747b rwlock 2025-12-18 22:38:40 -08:00
Ahmed Ibrahim
348d379509 rwlock 2025-12-18 22:20:35 -08:00
Ahmed Ibrahim
6912ba9fda final_output_json_schema 2025-12-18 22:06:04 -08:00
Ahmed Ibrahim
27cec53ddc error 2025-12-18 21:44:02 -08:00
Ahmed Ibrahim
42273d94e8 Merge branch 'main' into tag 2025-12-18 21:31:14 -08:00
Ahmed Ibrahim
1a5289a4ef test 2025-12-18 20:32:51 -08:00
Ahmed Ibrahim
359142f22f comments 2025-12-18 20:03:11 -08:00
Ahmed Ibrahim
ecff4d4f72 comments 2025-12-18 19:55:38 -08:00
Ahmed Ibrahim
985333feff comments 2025-12-18 19:44:55 -08:00
Ahmed Ibrahim
e01610f762 unit test 2025-12-18 19:37:24 -08:00
Ahmed Ibrahim
09693d259b rwlock 2025-12-18 19:02:58 -08:00
Ahmed Ibrahim
f8ba48d995 progress 2025-12-18 18:32:41 -08:00
Ahmed Ibrahim
677532f97b progress 2025-12-18 18:25:07 -08:00
Ahmed Ibrahim
beb83225e5 etag remove 2025-12-18 17:59:59 -08:00
19 changed files with 537 additions and 130 deletions

View File

@@ -5,6 +5,7 @@ use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;
use http::Method;
@@ -41,7 +42,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
&self,
client_version: &str,
extra_headers: HeaderMap,
) -> Result<ModelsResponse, ApiError> {
) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
let builder = || {
let mut req = self.provider.build_request(Method::GET, self.path());
req.headers.extend(extra_headers.clone());
@@ -66,7 +67,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);
let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
let ModelsResponse { models } = serde_json::from_slice::<ModelsResponse>(&resp.body)
.map_err(|e| {
ApiError::Stream(format!(
"failed to decode models response: {e}; body: {}",
@@ -74,9 +75,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
))
})?;
let etag = header_etag.unwrap_or(etag);
Ok(ModelsResponse { models, etag })
Ok((models, header_etag))
}
}
@@ -102,16 +101,15 @@ mod tests {
struct CapturingTransport {
last_request: Arc<Mutex<Option<Request>>>,
body: Arc<ModelsResponse>,
response_etag: Arc<Option<String>>,
}
impl Default for CapturingTransport {
fn default() -> Self {
Self {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(ModelsResponse {
models: Vec::new(),
etag: String::new(),
}),
body: Arc::new(ModelsResponse { models: Vec::new() }),
response_etag: Arc::new(None),
}
}
}
@@ -122,8 +120,8 @@ mod tests {
*self.last_request.lock().unwrap() = Some(req);
let body = serde_json::to_vec(&*self.body).unwrap();
let mut headers = HeaderMap::new();
if !self.body.etag.is_empty() {
headers.insert(ETAG, self.body.etag.parse().unwrap());
if let Some(etag) = self.response_etag.as_ref().as_deref() {
headers.insert(ETAG, etag.parse().unwrap());
}
Ok(Response {
status: StatusCode::OK,
@@ -166,14 +164,12 @@ mod tests {
#[tokio::test]
async fn appends_client_version_query() {
let response = ModelsResponse {
models: Vec::new(),
etag: String::new(),
};
let response = ModelsResponse { models: Vec::new() };
let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
response_etag: Arc::new(None),
};
let client = ModelsClient::new(
@@ -182,12 +178,12 @@ mod tests {
DummyAuth,
);
let result = client
let (models, _etag) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");
assert_eq!(result.models.len(), 0);
assert_eq!(models.len(), 0);
let url = transport
.last_request
@@ -232,12 +228,12 @@ mod tests {
}))
.unwrap(),
],
etag: String::new(),
};
let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
response_etag: Arc::new(None),
};
let client = ModelsClient::new(
@@ -246,27 +242,25 @@ mod tests {
DummyAuth,
);
let result = client
let (models, _etag) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");
assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");
assert_eq!(result.models[0].supported_in_api, true);
assert_eq!(result.models[0].priority, 1);
assert_eq!(models.len(), 1);
assert_eq!(models[0].slug, "gpt-test");
assert_eq!(models[0].supported_in_api, true);
assert_eq!(models[0].priority, 1);
}
#[tokio::test]
async fn list_models_includes_etag() {
let response = ModelsResponse {
models: Vec::new(),
etag: "\"abc\"".to_string(),
};
let response = ModelsResponse { models: Vec::new() };
let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
response_etag: Arc::new(Some("\"abc\"".to_string())),
};
let client = ModelsClient::new(
@@ -275,12 +269,12 @@ mod tests {
DummyAuth,
);
let result = client
let (models, etag) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("request should succeed");
assert_eq!(result.models.len(), 0);
assert_eq!(result.etag, "\"abc\"");
assert_eq!(models.len(), 0);
assert_eq!(etag.as_deref(), Some("\"abc\""));
}
}

View File

@@ -90,7 +90,6 @@ async fn models_client_hits_models_endpoint() {
reasoning_summary_format: ReasoningSummaryFormat::None,
experimental_supported_tools: Vec::new(),
}],
etag: String::new(),
};
Mock::given(method("GET"))
@@ -106,13 +105,13 @@ async fn models_client_hits_models_endpoint() {
let transport = ReqwestTransport::new(reqwest::Client::new());
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);
let result = client
let (models, _etag) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("models request should succeed");
assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");
assert_eq!(models.len(), 1);
assert_eq!(models[0].slug, "gpt-test");
let received = server
.received_requests()

View File

@@ -67,6 +67,11 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
status,
request_id: extract_request_id(headers.as_ref()),
})
} else if status == http::StatusCode::PRECONDITION_FAILED
&& body_text
.contains("Models catalog has changed. Please refresh your models list.")
{
CodexErr::OutdatedModels
} else {
CodexErr::UnexpectedStatus(UnexpectedResponseError {
status,

View File

@@ -33,6 +33,7 @@ use http::StatusCode as HttpStatusCode;
use reqwest::StatusCode;
use serde_json::Value;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::sync::mpsc;
use tracing::warn;
@@ -53,11 +54,12 @@ use crate::openai_models::model_family::ModelFamily;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::tools::spec::create_tools_json_for_responses_api;
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_family: ModelFamily,
model_family: RwLock<ModelFamily>,
models_etag: RwLock<Option<String>>,
otel_manager: OtelManager,
provider: ModelProviderInfo,
conversation_id: ConversationId,
@@ -72,6 +74,7 @@ impl ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_family: ModelFamily,
models_etag: Option<String>,
otel_manager: OtelManager,
provider: ModelProviderInfo,
effort: Option<ReasoningEffortConfig>,
@@ -82,7 +85,8 @@ impl ModelClient {
Self {
config,
auth_manager,
model_family,
model_family: RwLock::new(model_family),
models_etag: RwLock::new(models_etag),
otel_manager,
provider,
conversation_id,
@@ -92,8 +96,8 @@ impl ModelClient {
}
}
pub fn get_model_context_window(&self) -> Option<i64> {
let model_family = self.get_model_family();
pub async fn get_model_context_window(&self) -> Option<i64> {
let model_family = self.get_model_family().await;
let effective_context_window_percent = model_family.effective_context_window_percent;
model_family
.context_window
@@ -146,7 +150,7 @@ impl ModelClient {
}
let auth_manager = self.auth_manager.clone();
let model_family = self.get_model_family();
let model_family = self.get_model_family().await;
let instructions = prompt.get_full_instructions(&model_family).into_owned();
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let api_prompt = build_api_prompt(prompt, instructions, tools_json);
@@ -167,7 +171,7 @@ impl ModelClient {
let stream_result = client
.stream_prompt(
&self.get_model(),
&self.get_model().await,
&api_prompt,
Some(conversation_id.clone()),
Some(session_source.clone()),
@@ -200,7 +204,7 @@ impl ModelClient {
}
let auth_manager = self.auth_manager.clone();
let model_family = self.get_model_family();
let model_family = self.get_model_family().await;
let instructions = prompt.get_full_instructions(&model_family).into_owned();
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
@@ -262,11 +266,14 @@ impl ModelClient {
store_override: None,
conversation_id: Some(conversation_id.clone()),
session_source: Some(session_source.clone()),
extra_headers: beta_feature_headers(&self.config),
extra_headers: beta_feature_headers(
&self.config,
self.get_models_etag().await.clone(),
),
};
let stream_result = client
.stream_prompt(&self.get_model(), &api_prompt, options)
.stream_prompt(&self.get_model().await, &api_prompt, options)
.await;
match stream_result {
@@ -297,13 +304,25 @@ impl ModelClient {
}
/// Returns the currently configured model slug.
pub fn get_model(&self) -> String {
self.get_model_family().get_model_slug().to_string()
pub async fn get_model(&self) -> String {
self.get_model_family().await.get_model_slug().to_string()
}
/// Returns the currently configured model family.
pub fn get_model_family(&self) -> ModelFamily {
self.model_family.clone()
pub async fn get_model_family(&self) -> ModelFamily {
self.model_family.read().await.clone()
}
pub async fn get_models_etag(&self) -> Option<String> {
self.models_etag.read().await.clone()
}
pub async fn update_models_etag(&self, etag: Option<String>) {
*self.models_etag.write().await = etag;
}
pub async fn update_model_family(&self, model_family: ModelFamily) {
*self.model_family.write().await = model_family;
}
/// Returns the current reasoning effort setting.
@@ -340,10 +359,10 @@ impl ModelClient {
.with_telemetry(Some(request_telemetry));
let instructions = prompt
.get_full_instructions(&self.get_model_family())
.get_full_instructions(&self.get_model_family().await)
.into_owned();
let payload = ApiCompactionInput {
model: &self.get_model(),
model: &self.get_model().await,
input: &prompt.input,
instructions: &instructions,
};
@@ -398,7 +417,7 @@ fn build_api_prompt(prompt: &Prompt, instructions: String, tools_json: Vec<Value
}
}
fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
fn beta_feature_headers(config: &Config, models_etag: Option<String>) -> ApiHeaderMap {
let enabled = FEATURES
.iter()
.filter_map(|spec| {
@@ -416,6 +435,11 @@ fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
{
headers.insert("x-codex-beta-features", header_value);
}
if let Some(etag) = models_etag
&& let Ok(header_value) = HeaderValue::from_str(&etag)
{
headers.insert("X-If-Models-Match", header_value);
}
headers
}

View File

@@ -1,6 +1,10 @@
use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::Result;
use crate::features::Feature;
use crate::openai_models::model_family::ModelFamily;
use crate::tools::ToolRouter;
pub use codex_api::common::ResponseEvent;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
use codex_protocol::models::ResponseItem;
@@ -44,6 +48,28 @@ pub struct Prompt {
}
impl Prompt {
pub(crate) async fn new(
sess: &Session,
turn_context: &TurnContext,
router: &ToolRouter,
input: &[ResponseItem],
) -> Prompt {
let model_supports_parallel = turn_context
.client
.get_model_family()
.await
.supports_parallel_tool_calls;
Prompt {
input: input.to_vec(),
tools: router.specs(),
parallel_tool_calls: model_supports_parallel
&& sess.enabled(Feature::ParallelToolCalls),
base_instructions_override: turn_context.base_instructions.clone(),
output_schema: turn_context.final_output_json_schema.clone(),
}
}
pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
let base = self
.base_instructions_override

View File

@@ -249,7 +249,7 @@ impl Codex {
let config = Arc::new(config);
if config.features.enabled(Feature::RemoteModels)
&& let Err(err) = models_manager.refresh_available_models(&config).await
&& let Err(err) = models_manager.try_refresh_available_models(&config).await
{
error!("failed to refresh available models: {err:?}");
}
@@ -492,6 +492,7 @@ impl Session {
session_configuration: &SessionConfiguration,
per_turn_config: Config,
model_family: ModelFamily,
models_etag: Option<String>,
conversation_id: ConversationId,
sub_id: String,
) -> TurnContext {
@@ -505,6 +506,7 @@ impl Session {
per_turn_config.clone(),
auth_manager,
model_family.clone(),
models_etag,
otel_manager,
provider,
session_configuration.model_reasoning_effort,
@@ -788,7 +790,7 @@ impl Session {
}
})
{
let curr = turn_context.client.get_model();
let curr = turn_context.client.get_model().await;
if prev != curr {
warn!(
"resuming session with different model: previous={prev}, current={curr}"
@@ -919,6 +921,7 @@ impl Session {
.models_manager
.construct_model_family(session_configuration.model.as_str(), &per_turn_config)
.await;
let models_etag = self.services.models_manager.get_models_etag().await;
let mut turn_context: TurnContext = Self::make_turn_context(
Some(Arc::clone(&self.services.auth_manager)),
&self.services.otel_manager,
@@ -926,6 +929,7 @@ impl Session {
&session_configuration,
per_turn_config,
model_family,
models_etag,
self.conversation_id,
sub_id,
);
@@ -1334,7 +1338,7 @@ impl Session {
if let Some(token_usage) = token_usage {
state.update_token_info_from_usage(
token_usage,
turn_context.client.get_model_context_window(),
turn_context.client.get_model_context_window().await,
);
}
}
@@ -1346,6 +1350,7 @@ impl Session {
.clone_history()
.await
.estimate_token_count(turn_context)
.await
else {
return;
};
@@ -1366,7 +1371,7 @@ impl Session {
};
if info.model_context_window.is_none() {
info.model_context_window = turn_context.client.get_model_context_window();
info.model_context_window = turn_context.client.get_model_context_window().await;
}
state.set_token_info(Some(info));
@@ -1396,7 +1401,7 @@ impl Session {
}
pub(crate) async fn set_total_tokens_full(&self, turn_context: &TurnContext) {
let context_window = turn_context.client.get_model_context_window();
let context_window = turn_context.client.get_model_context_window().await;
if let Some(context_window) = context_window {
{
let mut state = self.state.lock().await;
@@ -2105,6 +2110,7 @@ async fn spawn_review_thread(
.models_manager
.construct_model_family(&model, &config)
.await;
let models_etag = sess.services.models_manager.get_models_etag().await;
// For reviews, disable web_search and view_image regardless of global settings.
let mut review_features = sess.features.clone();
review_features
@@ -2137,6 +2143,7 @@ async fn spawn_review_thread(
per_turn_config.clone(),
auth_manager,
model_family.clone(),
models_etag,
otel_manager,
provider,
per_turn_config.model_reasoning_effort,
@@ -2231,6 +2238,7 @@ pub(crate) async fn run_task(
let auto_compact_limit = turn_context
.client
.get_model_family()
.await
.auto_compact_token_limit()
.unwrap_or(i64::MAX);
let total_usage_tokens = sess.get_total_token_usage().await;
@@ -2238,7 +2246,7 @@ pub(crate) async fn run_task(
run_auto_compact(&sess, &turn_context).await;
}
let event = EventMsg::TaskStarted(TaskStartedEvent {
model_context_window: turn_context.client.get_model_context_window(),
model_context_window: turn_context.client.get_model_context_window().await,
});
sess.send_event(&turn_context, event).await;
@@ -2303,7 +2311,7 @@ pub(crate) async fn run_task(
.collect::<Vec<String>>();
match run_turn(
Arc::clone(&sess),
Arc::clone(&turn_context),
&turn_context,
Arc::clone(&turn_diff_tracker),
turn_input,
cancellation_token.child_token(),
@@ -2362,6 +2370,36 @@ pub(crate) async fn run_task(
last_agent_message
}
pub(crate) async fn refresh_models_and_reset_turn_context(
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
) {
let config = {
let state = sess.state.lock().await;
state
.session_configuration
.original_config_do_not_use
.clone()
};
if let Err(err) = sess
.services
.models_manager
.refresh_available_models(&config)
.await
{
error!("failed to refresh models after outdated models error: {err}");
}
let model = turn_context.client.get_model().await;
let model_family = sess
.services
.models_manager
.construct_model_family(&model, &config)
.await;
let models_etag = sess.services.models_manager.get_models_etag().await;
turn_context.client.update_model_family(model_family).await;
turn_context.client.update_models_etag(models_etag).await;
}
async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>) {
if should_use_remote_compact_task(sess.as_ref(), &turn_context.client.get_provider()) {
run_inline_remote_auto_compact_task(Arc::clone(sess), Arc::clone(turn_context)).await;
@@ -2374,17 +2412,19 @@ async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>)
skip_all,
fields(
turn_id = %turn_context.sub_id,
model = %turn_context.client.get_model(),
model = tracing::field::Empty,
cwd = %turn_context.cwd.display()
)
)]
async fn run_turn(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_context: &Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
input: Vec<ResponseItem>,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
let model = turn_context.client.get_model().await;
tracing::Span::current().record("model", field::display(&model));
let mcp_tools = sess
.services
.mcp_connection_manager
@@ -2393,37 +2433,32 @@ async fn run_turn(
.list_all_tools()
.or_cancel(&cancellation_token)
.await?;
let router = Arc::new(ToolRouter::from_config(
&turn_context.tools_config,
Some(
mcp_tools
.into_iter()
.map(|(name, tool)| (name, tool.tool))
.collect(),
),
));
let model_supports_parallel = turn_context
.client
.get_model_family()
.supports_parallel_tool_calls;
let prompt = Prompt {
input,
tools: router.specs(),
parallel_tool_calls: model_supports_parallel && sess.enabled(Feature::ParallelToolCalls),
base_instructions_override: turn_context.base_instructions.clone(),
output_schema: turn_context.final_output_json_schema.clone(),
};
let mut retries = 0;
loop {
let router = Arc::new(ToolRouter::from_config(
&turn_context.tools_config,
Some(
mcp_tools
.clone()
.into_iter()
.map(|(name, tool)| (name, tool.tool))
.collect(),
),
));
let prompt = Prompt::new(
sess.as_ref(),
turn_context.as_ref(),
router.as_ref(),
&input,
);
match try_run_turn(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(turn_context),
Arc::clone(&turn_diff_tracker),
&prompt,
&prompt.await,
cancellation_token.child_token(),
)
.await
@@ -2437,13 +2472,13 @@ async fn run_turn(
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e @ CodexErr::Fatal(_)) => return Err(e),
Err(e @ CodexErr::ContextWindowExceeded) => {
sess.set_total_tokens_full(&turn_context).await;
sess.set_total_tokens_full(turn_context).await;
return Err(e);
}
Err(CodexErr::UsageLimitReached(e)) => {
let rate_limits = e.rate_limits.clone();
if let Some(rate_limits) = rate_limits {
sess.update_rate_limits(&turn_context, rate_limits).await;
sess.update_rate_limits(turn_context, rate_limits).await;
}
return Err(CodexErr::UsageLimitReached(e));
}
@@ -2457,6 +2492,11 @@ async fn run_turn(
let max_retries = turn_context.client.get_provider().stream_max_retries();
if retries < max_retries {
retries += 1;
// Refresh models if we got an outdated models error
if matches!(e, CodexErr::OutdatedModels) {
refresh_models_and_reset_turn_context(&sess, turn_context).await;
continue;
}
let delay = match e {
CodexErr::Stream(_, Some(delay)) => delay,
_ => backoff(retries),
@@ -2469,7 +2509,7 @@ async fn run_turn(
// user understands what is happening instead of staring
// at a seemingly frozen screen.
sess.notify_stream_error(
&turn_context,
turn_context,
format!("Reconnecting... {retries}/{max_retries}"),
e,
)
@@ -2514,7 +2554,7 @@ async fn drain_in_flight(
skip_all,
fields(
turn_id = %turn_context.sub_id,
model = %turn_context.client.get_model()
model = tracing::field::Empty,
)
)]
async fn try_run_turn(
@@ -2525,11 +2565,13 @@ async fn try_run_turn(
prompt: &Prompt,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
let model = turn_context.client.get_model().await;
tracing::Span::current().record("model", field::display(&model));
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
sandbox_policy: turn_context.sandbox_policy.clone(),
model: turn_context.client.get_model(),
model,
effort: turn_context.client.get_reasoning_effort(),
summary: turn_context.client.get_reasoning_summary(),
});
@@ -2537,7 +2579,6 @@ async fn try_run_turn(
sess.persist_rollout_items(&[rollout_item]).await;
let mut stream = turn_context
.client
.clone()
.stream(prompt)
.instrument(trace_span!("stream_request"))
.or_cancel(&cancellation_token)
@@ -3163,6 +3204,7 @@ mod tests {
&session_configuration,
per_turn_config,
model_family,
None,
conversation_id,
"turn_id".to_string(),
);
@@ -3249,6 +3291,7 @@ mod tests {
&session_configuration,
per_turn_config,
model_family,
None,
conversation_id,
"turn_id".to_string(),
));

View File

@@ -6,6 +6,7 @@ use crate::client_common::ResponseEvent;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::codex::get_last_assistant_message_from_turn;
use crate::codex::refresh_models_and_reset_turn_context;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::features::Feature;
@@ -55,7 +56,7 @@ pub(crate) async fn run_compact_task(
input: Vec<UserInput>,
) {
let start_event = EventMsg::TaskStarted(TaskStartedEvent {
model_context_window: turn_context.client.get_model_context_window(),
model_context_window: turn_context.client.get_model_context_window().await,
});
sess.send_event(&turn_context, start_event).await;
run_compact_task_inner(sess.clone(), turn_context, input).await;
@@ -83,7 +84,7 @@ async fn run_compact_task_inner(
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
sandbox_policy: turn_context.sandbox_policy.clone(),
model: turn_context.client.get_model(),
model: turn_context.client.get_model().await,
effort: turn_context.client.get_reasoning_effort(),
summary: turn_context.client.get_reasoning_summary(),
});
@@ -132,6 +133,10 @@ async fn run_compact_task_inner(
Err(e) => {
if retries < max_retries {
retries += 1;
if matches!(e, CodexErr::OutdatedModels) {
refresh_models_and_reset_turn_context(&sess, &turn_context).await;
continue;
}
let delay = backoff(retries);
sess.notify_stream_error(
turn_context.as_ref(),
@@ -290,7 +295,7 @@ async fn drain_to_completed(
turn_context: &TurnContext,
prompt: &Prompt,
) -> CodexResult<()> {
let mut stream = turn_context.client.clone().stream(prompt).await?;
let mut stream = turn_context.client.stream(prompt).await?;
loop {
let maybe_event = stream.next().await;
let Some(event) = maybe_event else {

View File

@@ -20,7 +20,7 @@ pub(crate) async fn run_inline_remote_auto_compact_task(
pub(crate) async fn run_remote_compact_task(sess: Arc<Session>, turn_context: Arc<TurnContext>) {
let start_event = EventMsg::TaskStarted(TaskStartedEvent {
model_context_window: turn_context.client.get_model_context_window(),
model_context_window: turn_context.client.get_model_context_window().await,
});
sess.send_event(&turn_context, start_event).await;

View File

@@ -79,8 +79,8 @@ impl ContextManager {
// Estimate token usage using byte-based heuristics from the truncation helpers.
// This is a coarse lower bound, not a tokenizer-accurate count.
pub(crate) fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
let model_family = turn_context.client.get_model_family();
pub(crate) async fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
let model_family = turn_context.client.get_model_family().await;
let base_tokens =
i64::try_from(approx_token_count(model_family.base_instructions.as_str()))
.unwrap_or(i64::MAX);

View File

@@ -90,6 +90,10 @@ pub enum CodexErr {
#[error("spawn failed: child stdout/stderr not captured")]
Spawn,
/// Returned when the models list is outdated and needs to be refreshed.
#[error("remote models list is outdated")]
OutdatedModels,
/// Returned by run_command_stream when the user pressed CtrlC (SIGINT). Session uses this to
/// surface a polite FunctionCallOutput back to the model instead of crashing the CLI.
#[error("interrupted (Ctrl-C). Something went wrong? Hit `/feedback` to report the issue.")]

View File

@@ -77,7 +77,7 @@ impl ModelsManager {
}
/// Fetch the latest remote models, using the on-disk cache when still fresh.
pub async fn refresh_available_models(&self, config: &Config) -> CoreResult<()> {
pub async fn try_refresh_available_models(&self, config: &Config) -> CoreResult<()> {
if !config.features.enabled(Feature::RemoteModels)
|| self.auth_manager.get_auth_mode() == Some(AuthMode::ApiKey)
{
@@ -86,7 +86,15 @@ impl ModelsManager {
if self.try_load_cache().await {
return Ok(());
}
self.refresh_available_models(config).await
}
pub async fn refresh_available_models(&self, config: &Config) -> CoreResult<()> {
if !config.features.enabled(Feature::RemoteModels)
|| self.auth_manager.get_auth_mode() == Some(AuthMode::ApiKey)
{
return Ok(());
}
let auth = self.auth_manager.auth();
let api_provider = self.provider.to_api_provider(Some(AuthMode::ChatGPT))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?;
@@ -94,12 +102,12 @@ impl ModelsManager {
let client = ModelsClient::new(transport, api_provider, api_auth);
let client_version = format_client_version_to_whole();
let ModelsResponse { models, etag } = client
let (models, etag) = client
.list_models(&client_version, HeaderMap::new())
.await
.map_err(map_api_error)?;
let etag = (!etag.is_empty()).then_some(etag);
let etag = etag.filter(|value| !value.is_empty());
self.apply_remote_models(models.clone()).await;
*self.etag.write().await = etag.clone();
@@ -108,7 +116,7 @@ impl ModelsManager {
}
pub async fn list_models(&self, config: &Config) -> Vec<ModelPreset> {
if let Err(err) = self.refresh_available_models(config).await {
if let Err(err) = self.try_refresh_available_models(config).await {
error!("failed to refresh available models: {err}");
}
let remote_models = self.remote_models(config).await;
@@ -131,11 +139,15 @@ impl ModelsManager {
.with_config_overrides(config)
}
pub async fn get_models_etag(&self) -> Option<String> {
self.etag.read().await.clone()
}
pub async fn get_model(&self, model: &Option<String>, config: &Config) -> String {
if let Some(model) = model.as_ref() {
return model.to_string();
}
if let Err(err) = self.refresh_available_models(config).await {
if let Err(err) = self.try_refresh_available_models(config).await {
error!("failed to refresh available models: {err}");
}
// if codex-auto-balanced exists & signed in with chatgpt mode, return it, otherwise return the default model
@@ -389,7 +401,6 @@ mod tests {
&server,
ModelsResponse {
models: remote_models.clone(),
etag: String::new(),
},
)
.await;
@@ -407,7 +418,7 @@ mod tests {
let manager = ModelsManager::with_provider(auth_manager, provider);
manager
.refresh_available_models(&config)
.try_refresh_available_models(&config)
.await
.expect("refresh succeeds");
let cached_remote = manager.remote_models(&config).await;
@@ -446,7 +457,6 @@ mod tests {
&server,
ModelsResponse {
models: remote_models.clone(),
etag: String::new(),
},
)
.await;
@@ -467,7 +477,7 @@ mod tests {
let manager = ModelsManager::with_provider(auth_manager, provider);
manager
.refresh_available_models(&config)
.try_refresh_available_models(&config)
.await
.expect("first refresh succeeds");
assert_eq!(
@@ -478,7 +488,7 @@ mod tests {
// Second call should read from cache and avoid the network.
manager
.refresh_available_models(&config)
.try_refresh_available_models(&config)
.await
.expect("cached refresh succeeds");
assert_eq!(
@@ -501,7 +511,6 @@ mod tests {
&server,
ModelsResponse {
models: initial_models.clone(),
etag: String::new(),
},
)
.await;
@@ -522,7 +531,7 @@ mod tests {
let manager = ModelsManager::with_provider(auth_manager, provider);
manager
.refresh_available_models(&config)
.try_refresh_available_models(&config)
.await
.expect("initial refresh succeeds");
@@ -542,13 +551,12 @@ mod tests {
&server,
ModelsResponse {
models: updated_models.clone(),
etag: String::new(),
},
)
.await;
manager
.refresh_available_models(&config)
.try_refresh_available_models(&config)
.await
.expect("second refresh succeeds");
assert_eq!(
@@ -576,7 +584,6 @@ mod tests {
&server,
ModelsResponse {
models: initial_models,
etag: String::new(),
},
)
.await;
@@ -595,7 +602,7 @@ mod tests {
manager.cache_ttl = Duration::ZERO;
manager
.refresh_available_models(&config)
.try_refresh_available_models(&config)
.await
.expect("initial refresh succeeds");
@@ -605,13 +612,12 @@ mod tests {
&server,
ModelsResponse {
models: refreshed_models,
etag: String::new(),
},
)
.await;
manager
.refresh_available_models(&config)
.try_refresh_available_models(&config)
.await
.expect("second refresh succeeds");

View File

@@ -59,7 +59,7 @@ impl SessionTask for UserShellCommandTask {
cancellation_token: CancellationToken,
) -> Option<String> {
let event = EventMsg::TaskStarted(TaskStartedEvent {
model_context_window: turn_context.client.get_model_context_window(),
model_context_window: turn_context.client.get_model_context_window().await,
});
let session = session.clone_session();
session.send_event(turn_context.as_ref(), event).await;

View File

@@ -92,6 +92,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
Arc::clone(&config),
None,
model_family,
None,
otel_manager,
provider,
effort,

View File

@@ -93,6 +93,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
Arc::clone(&config),
None,
model_family,
None,
otel_manager,
provider,
effort,

View File

@@ -670,6 +670,24 @@ pub async fn mount_models_once(server: &MockServer, body: ModelsResponse) -> Mod
models_mock
}
pub async fn mount_models_once_with_etag(
server: &MockServer,
body: ModelsResponse,
etag: &str,
) -> ModelsMock {
let (mock, models_mock) = models_mock();
mock.respond_with(
ResponseTemplate::new(200)
.insert_header("content-type", "application/json")
.insert_header("etag", etag)
.set_body_json(body.clone()),
)
.up_to_n_times(1)
.mount(server)
.await;
models_mock
}
pub async fn start_mock_server() -> MockServer {
let server = MockServer::builder()
.body_print_limit(BodyPrintLimit::Limited(80_000))
@@ -677,14 +695,7 @@ pub async fn start_mock_server() -> MockServer {
.await;
// Provide a default `/models` response so tests remain hermetic when the client queries it.
let _ = mount_models_once(
&server,
ModelsResponse {
models: Vec::new(),
etag: String::new(),
},
)
.await;
let _ = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await;
server
}

View File

@@ -86,6 +86,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
Arc::clone(&config),
None,
model_family,
None,
otel_manager,
provider,
effort,
@@ -181,6 +182,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
Arc::clone(&config),
None,
model_family,
None,
otel_manager,
provider,
effort,
@@ -275,6 +277,7 @@ async fn responses_respects_model_family_overrides_from_config() {
Arc::clone(&config),
None,
model_family,
None,
otel_manager,
provider,
effort,

View File

@@ -1146,6 +1146,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
Arc::clone(&config),
None,
model_family,
None,
otel_manager,
provider,
effort,

View File

@@ -33,8 +33,12 @@ use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_shell_command_call;
use core_test_support::responses::mount_models_once;
use core_test_support::responses::mount_models_once_with_etag;
use core_test_support::responses::mount_response_once_match;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::skip_if_no_network;
@@ -42,6 +46,7 @@ use core_test_support::skip_if_sandbox;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_match;
use pretty_assertions::assert_eq;
use serde_json::Value;
use serde_json::json;
use tempfile::TempDir;
use tokio::time::Duration;
@@ -49,9 +54,92 @@ use tokio::time::Instant;
use tokio::time::sleep;
use wiremock::BodyPrintLimit;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
const REMOTE_MODEL_SLUG: &str = "codex-test";
#[derive(Clone, Default)]
struct ResponsesMatch {
etag: Option<String>,
user_text: Option<String>,
call_id: Option<String>,
}
impl ResponsesMatch {
fn with_etag(mut self, etag: &str) -> Self {
self.etag = Some(etag.to_string());
self
}
fn with_user_text(mut self, text: &str) -> Self {
self.user_text = Some(text.to_string());
self
}
fn with_function_call_output(mut self, call_id: &str) -> Self {
self.call_id = Some(call_id.to_string());
self
}
}
impl wiremock::Match for ResponsesMatch {
fn matches(&self, request: &wiremock::Request) -> bool {
if let Some(expected_etag) = &self.etag {
let header = request
.headers
.get("X-If-Models-Match")
.and_then(|value| value.to_str().ok());
if header != Some(expected_etag.as_str()) {
return false;
}
}
let Ok(body): Result<Value, _> = request.body_json() else {
return false;
};
let Some(items) = body.get("input").and_then(Value::as_array) else {
return false;
};
if let Some(expected_text) = &self.user_text
&& !input_has_user_text(items, expected_text)
{
return false;
}
if let Some(expected_call_id) = &self.call_id
&& !input_has_function_call_output(items, expected_call_id)
{
return false;
}
true
}
}
fn input_has_user_text(items: &[Value], expected: &str) -> bool {
items.iter().any(|item| {
item.get("type").and_then(Value::as_str) == Some("message")
&& item.get("role").and_then(Value::as_str) == Some("user")
&& item
.get("content")
.and_then(Value::as_array)
.is_some_and(|content| {
content.iter().any(|span| {
span.get("type").and_then(Value::as_str) == Some("input_text")
&& span.get("text").and_then(Value::as_str) == Some(expected)
})
})
})
}
fn input_has_function_call_output(items: &[Value], call_id: &str) -> bool {
items.iter().any(|item| {
item.get("type").and_then(Value::as_str) == Some("function_call_output")
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
})
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_models_remote_model_uses_unified_exec() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -93,7 +181,6 @@ async fn remote_models_remote_model_uses_unified_exec() -> Result<()> {
&server,
ModelsResponse {
models: vec![remote_model],
etag: String::new(),
},
)
.await;
@@ -232,7 +319,6 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> {
&server,
ModelsResponse {
models: vec![remote_model],
etag: String::new(),
},
)
.await;
@@ -299,6 +385,208 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> {
Ok(())
}
/// Exercises the remote-models retry flow:
/// 1) initial `/models` fetch stores an ETag,
/// 2) `/responses` uses that ETag for a tool call,
/// 3) the tool-output turn receives a 412 (stale models),
/// 4) Codex refreshes `/models` to get a new ETag and retries,
/// 5) subsequent user turns keep sending the refreshed ETag.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_models_refresh_etag_after_outdated_models() -> Result<()> {
skip_if_no_network!(Ok(()));
skip_if_sandbox!(Ok(()));
let server = MockServer::builder()
.body_print_limit(BodyPrintLimit::Limited(80_000))
.start()
.await;
let remote_model = test_remote_model("remote-etag", ModelVisibility::List, 1);
let initial_etag = "models-etag-initial";
let refreshed_etag = "models-etag-refreshed";
// Phase 1a: seed the initial `/models` response with an ETag.
let models_mock = mount_models_once_with_etag(
&server,
ModelsResponse {
models: vec![remote_model.clone()],
},
initial_etag,
)
.await;
// Phase 1b: boot a Codex session configured for remote models.
let harness = build_remote_models_harness(&server, |config| {
config.features.enable(Feature::RemoteModels);
config.model = Some("gpt-5.1".to_string());
})
.await?;
let RemoteModelsHarness {
codex,
cwd,
config,
conversation_manager,
..
} = harness;
let models_manager = conversation_manager.get_models_manager();
wait_for_model_available(&models_manager, "remote-etag", &config).await;
// Phase 1c: confirm the ETag is stored and `/models` was called.
assert_eq!(
models_manager.get_models_etag().await.as_deref(),
Some(initial_etag),
);
assert_eq!(
models_mock.requests().len(),
1,
"expected an initial /models request",
);
assert_eq!(models_mock.requests()[0].url.path(), "/v1/models");
// Phase 2a: reset mocks so the next `/models` call must be explicit.
server.reset().await;
// Phase 2b: mount a refreshed `/models` response with a new ETag.
let refreshed_models_mock = mount_models_once_with_etag(
&server,
ModelsResponse {
models: vec![remote_model],
},
refreshed_etag,
)
.await;
let call_id = "shell-command-call";
let first_prompt = "run a shell command";
let followup_prompt = "send another message";
// Phase 2c: first `/responses` turn uses the initial ETag and emits a tool call.
let first_response = mount_sse_once_match(
&server,
ResponsesMatch::default()
.with_etag(initial_etag)
.with_user_text(first_prompt),
sse(vec![
ev_response_created("resp-1"),
ev_shell_command_call(call_id, "echo refreshed"),
ev_completed("resp-1"),
]),
)
.await;
// Phase 2d: the tool-output follow-up returns 412 (stale models).
let stale_response = mount_response_once_match(
&server,
ResponsesMatch::default()
.with_etag(initial_etag)
.with_function_call_output(call_id),
ResponseTemplate::new(412)
.set_body_string("Models catalog has changed. Please refresh your models list."),
)
.await;
// Phase 2e: retry tool-output follow-up should use the refreshed ETag.
let refreshed_response = mount_sse_once_match(
&server,
ResponsesMatch::default()
.with_etag(refreshed_etag)
.with_function_call_output(call_id),
sse(vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]),
)
.await;
// Phase 3a: next user turn should also use the refreshed ETag.
let next_turn_response = mount_sse_once_match(
&server,
ResponsesMatch::default()
.with_etag(refreshed_etag)
.with_user_text(followup_prompt),
sse(vec![
ev_response_created("resp-3"),
ev_assistant_message("msg-2", "ok"),
ev_completed("resp-3"),
]),
)
.await;
// Phase 3b: run the first user turn and let retries complete.
codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: first_prompt.into(),
}],
final_output_json_schema: None,
cwd: cwd.path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::DangerFullAccess,
model: "gpt-5.1".to_string(),
effort: None,
summary: ReasoningSummary::Auto,
})
.await?;
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
// Phase 3c: assert the refresh happened and the ETag was updated.
assert_eq!(
refreshed_models_mock.requests().len(),
1,
"expected a refreshed /models request",
);
assert_eq!(
models_manager.get_models_etag().await.as_deref(),
Some(refreshed_etag),
);
// Phase 3d: assert the ETag header progression across the retry sequence.
assert_eq!(
first_response.single_request().header("X-If-Models-Match"),
Some(initial_etag.to_string()),
);
assert_eq!(
stale_response.single_request().header("X-If-Models-Match"),
Some(initial_etag.to_string()),
);
assert_eq!(
refreshed_response
.single_request()
.header("X-If-Models-Match"),
Some(refreshed_etag.to_string()),
);
// Phase 3e: execute a new user turn and ensure the refreshed ETag persists.
codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: followup_prompt.into(),
}],
final_output_json_schema: None,
cwd: cwd.path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::DangerFullAccess,
model: "gpt-5.1".to_string(),
effort: None,
summary: ReasoningSummary::Auto,
})
.await?;
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
assert_eq!(
next_turn_response
.single_request()
.header("X-If-Models-Match"),
Some(refreshed_etag.to_string()),
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_models_preserve_builtin_presets() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -310,7 +598,6 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
&server,
ModelsResponse {
models: vec![remote_model.clone()],
etag: String::new(),
},
)
.await;
@@ -330,7 +617,7 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
);
manager
.refresh_available_models(&config)
.try_refresh_available_models(&config)
.await
.expect("refresh succeeds");
@@ -368,7 +655,6 @@ async fn remote_models_hide_picker_only_models() -> Result<()> {
&server,
ModelsResponse {
models: vec![remote_model],
etag: String::new(),
},
)
.await;

View File

@@ -197,8 +197,6 @@ pub struct ModelInfo {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema, Default)]
pub struct ModelsResponse {
pub models: Vec<ModelInfo>,
#[serde(default)]
pub etag: String,
}
// convert ModelInfo to ModelPreset