mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
rwlock
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use crate::api_bridge::auth_provider_from_auth;
|
||||
use crate::api_bridge::map_api_error;
|
||||
@@ -53,12 +54,12 @@ use crate::openai_models::model_family::ModelFamily;
|
||||
use crate::tools::spec::create_tools_json_for_chat_completions_api;
|
||||
use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub struct ModelClient {
|
||||
config: Arc<Config>,
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
model_family: ModelFamily,
|
||||
models_etag: Option<String>,
|
||||
model_family: RwLock<ModelFamily>,
|
||||
models_etag: RwLock<Option<String>>,
|
||||
otel_manager: OtelManager,
|
||||
provider: ModelProviderInfo,
|
||||
conversation_id: ConversationId,
|
||||
@@ -84,8 +85,8 @@ impl ModelClient {
|
||||
Self {
|
||||
config,
|
||||
auth_manager,
|
||||
model_family,
|
||||
models_etag,
|
||||
model_family: RwLock::new(model_family),
|
||||
models_etag: RwLock::new(models_etag),
|
||||
otel_manager,
|
||||
provider,
|
||||
conversation_id,
|
||||
@@ -111,6 +112,22 @@ impl ModelClient {
|
||||
&self.provider
|
||||
}
|
||||
|
||||
pub fn update_models_etag(&self, models_etag: Option<String>) {
|
||||
let mut guard = self
|
||||
.models_etag
|
||||
.write()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = models_etag;
|
||||
}
|
||||
|
||||
pub fn update_model_family(&self, model_family: ModelFamily) {
|
||||
let mut guard = self
|
||||
.model_family
|
||||
.write()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = model_family;
|
||||
}
|
||||
|
||||
/// Streams a single model turn using either the Responses or Chat
|
||||
/// Completions wire API, depending on the configured provider.
|
||||
///
|
||||
@@ -265,7 +282,7 @@ impl ModelClient {
|
||||
store_override: None,
|
||||
conversation_id: Some(conversation_id.clone()),
|
||||
session_source: Some(session_source.clone()),
|
||||
extra_headers: beta_feature_headers(&self.config, self.models_etag.clone()),
|
||||
extra_headers: beta_feature_headers(&self.config, self.get_models_etag()),
|
||||
};
|
||||
|
||||
let stream_result = client
|
||||
@@ -306,7 +323,17 @@ impl ModelClient {
|
||||
|
||||
/// Returns the currently configured model family.
|
||||
pub fn get_model_family(&self) -> ModelFamily {
|
||||
self.model_family.clone()
|
||||
self.model_family
|
||||
.read()
|
||||
.map(|model_family| model_family.clone())
|
||||
.unwrap_or_else(|err| err.into_inner().clone())
|
||||
}
|
||||
|
||||
fn get_models_etag(&self) -> Option<String> {
|
||||
self.models_etag
|
||||
.read()
|
||||
.map(|models_etag| models_etag.clone())
|
||||
.unwrap_or_else(|err| err.into_inner().clone())
|
||||
}
|
||||
|
||||
/// Returns the current reasoning effort setting.
|
||||
|
||||
@@ -2459,6 +2459,34 @@ async fn run_turn(
|
||||
Err(e @ CodexErr::InvalidRequest(_)) => return Err(e),
|
||||
Err(e @ CodexErr::RefreshTokenFailed(_)) => return Err(e),
|
||||
Err(e) => {
|
||||
// Refresh models if we got an outdated models error
|
||||
if matches!(e, CodexErr::OutdatedModels) {
|
||||
let config = {
|
||||
let state = sess.state.lock().await;
|
||||
state
|
||||
.session_configuration
|
||||
.original_config_do_not_use
|
||||
.clone()
|
||||
};
|
||||
if let Err(err) = sess
|
||||
.services
|
||||
.models_manager
|
||||
.refresh_available_models(&config)
|
||||
.await
|
||||
{
|
||||
error!("failed to refresh models after outdated models error: {err}");
|
||||
}
|
||||
let models_etag = sess.services.models_manager.get_models_etag().await;
|
||||
let model = turn_context.client.get_model();
|
||||
let model_family = sess
|
||||
.services
|
||||
.models_manager
|
||||
.construct_model_family(&model, &config)
|
||||
.await;
|
||||
turn_context.client.update_model_family(model_family);
|
||||
turn_context.client.update_models_etag(models_etag);
|
||||
}
|
||||
|
||||
// Use the configured provider-specific stream retry budget.
|
||||
let max_retries = turn_context.client.get_provider().stream_max_retries();
|
||||
if retries < max_retries {
|
||||
@@ -2543,7 +2571,6 @@ async fn try_run_turn(
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context
|
||||
.client
|
||||
.clone()
|
||||
.stream(prompt)
|
||||
.instrument(trace_span!("stream_request"))
|
||||
.or_cancel(&cancellation_token)
|
||||
|
||||
@@ -290,7 +290,7 @@ async fn drain_to_completed(
|
||||
turn_context: &TurnContext,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let mut stream = turn_context.client.clone().stream(prompt).await?;
|
||||
let mut stream = turn_context.client.stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
let Some(event) = maybe_event else {
|
||||
|
||||
Reference in New Issue
Block a user