This commit is contained in:
Ahmed Ibrahim
2025-12-18 19:02:58 -08:00
parent f8ba48d995
commit 09693d259b
3 changed files with 63 additions and 9 deletions

View File

@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::sync::RwLock;
use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
@@ -53,12 +54,12 @@ use crate::openai_models::model_family::ModelFamily;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::tools::spec::create_tools_json_for_responses_api;
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_family: ModelFamily,
models_etag: Option<String>,
model_family: RwLock<ModelFamily>,
models_etag: RwLock<Option<String>>,
otel_manager: OtelManager,
provider: ModelProviderInfo,
conversation_id: ConversationId,
@@ -84,8 +85,8 @@ impl ModelClient {
Self {
config,
auth_manager,
model_family,
models_etag,
model_family: RwLock::new(model_family),
models_etag: RwLock::new(models_etag),
otel_manager,
provider,
conversation_id,
@@ -111,6 +112,22 @@ impl ModelClient {
&self.provider
}
/// Replaces the stored models ETag with `models_etag`.
///
/// The write goes through even when the lock is poisoned: the guard is
/// recovered from the poison error and the value is overwritten anyway.
pub fn update_models_etag(&self, models_etag: Option<String>) {
    match self.models_etag.write() {
        Ok(mut slot) => *slot = models_etag,
        // A poisoned lock still holds a usable value; take the guard back
        // and overwrite it instead of panicking.
        Err(poisoned) => *poisoned.into_inner() = models_etag,
    }
}
/// Replaces the stored model family with `model_family`.
///
/// A poisoned lock does not prevent the update: the write guard is
/// recovered from the poison error before the new value is stored.
pub fn update_model_family(&self, model_family: ModelFamily) {
    let mut slot = match self.model_family.write() {
        Ok(guard) => guard,
        // Recover the guard from a poisoned lock rather than panicking.
        Err(poisoned) => poisoned.into_inner(),
    };
    *slot = model_family;
}
/// Streams a single model turn using either the Responses or Chat
/// Completions wire API, depending on the configured provider.
///
@@ -265,7 +282,7 @@ impl ModelClient {
store_override: None,
conversation_id: Some(conversation_id.clone()),
session_source: Some(session_source.clone()),
extra_headers: beta_feature_headers(&self.config, self.models_etag.clone()),
extra_headers: beta_feature_headers(&self.config, self.get_models_etag()),
};
let stream_result = client
@@ -306,7 +323,17 @@ impl ModelClient {
/// Returns the currently configured model family.
///
/// The value is read through the `RwLock` and cloned; if the lock is
/// poisoned, the inner value is still recovered and cloned instead of
/// panicking.
pub fn get_model_family(&self) -> ModelFamily {
// NOTE(review): the bare `.clone()` line directly below looks like the
// pre-change body kept by the diff view, with the lock-based read that
// follows being its replacement — confirm against the real file.
self.model_family.clone()
self.model_family
.read()
.map(|model_family| model_family.clone())
.unwrap_or_else(|err| err.into_inner().clone())
}
/// Returns a clone of the currently stored models ETag, if any.
///
/// Reads through the lock; a poisoned lock is tolerated by recovering the
/// read guard from the poison error and cloning the value it protects.
fn get_models_etag(&self) -> Option<String> {
    let current = match self.models_etag.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    // `Option<String>::clone` via deref of the read guard.
    current.clone()
}
/// Returns the current reasoning effort setting.

View File

@@ -2459,6 +2459,34 @@ async fn run_turn(
Err(e @ CodexErr::InvalidRequest(_)) => return Err(e),
Err(e @ CodexErr::RefreshTokenFailed(_)) => return Err(e),
Err(e) => {
// Refresh models if we got an outdated models error
if matches!(e, CodexErr::OutdatedModels) {
let config = {
let state = sess.state.lock().await;
state
.session_configuration
.original_config_do_not_use
.clone()
};
if let Err(err) = sess
.services
.models_manager
.refresh_available_models(&config)
.await
{
error!("failed to refresh models after outdated models error: {err}");
}
let models_etag = sess.services.models_manager.get_models_etag().await;
let model = turn_context.client.get_model();
let model_family = sess
.services
.models_manager
.construct_model_family(&model, &config)
.await;
turn_context.client.update_model_family(model_family);
turn_context.client.update_models_etag(models_etag);
}
// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.client.get_provider().stream_max_retries();
if retries < max_retries {
@@ -2543,7 +2571,6 @@ async fn try_run_turn(
sess.persist_rollout_items(&[rollout_item]).await;
let mut stream = turn_context
.client
.clone()
.stream(prompt)
.instrument(trace_span!("stream_request"))
.or_cancel(&cancellation_token)

View File

@@ -290,7 +290,7 @@ async fn drain_to_completed(
turn_context: &TurnContext,
prompt: &Prompt,
) -> CodexResult<()> {
let mut stream = turn_context.client.clone().stream(prompt).await?;
let mut stream = turn_context.client.stream(prompt).await?;
loop {
let maybe_event = stream.next().await;
let Some(event) = maybe_event else {