mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
rwlock
This commit is contained in:
@@ -33,6 +33,7 @@ use http::StatusCode as HttpStatusCode;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -57,8 +58,8 @@ use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
pub struct ModelClient {
|
||||
config: Arc<Config>,
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
model_family: ModelFamily,
|
||||
models_etag: Option<String>,
|
||||
model_family: RwLock<ModelFamily>,
|
||||
models_etag: RwLock<Option<String>>,
|
||||
otel_manager: OtelManager,
|
||||
provider: ModelProviderInfo,
|
||||
conversation_id: ConversationId,
|
||||
@@ -84,8 +85,8 @@ impl ModelClient {
|
||||
Self {
|
||||
config,
|
||||
auth_manager,
|
||||
model_family,
|
||||
models_etag,
|
||||
model_family: RwLock::new(model_family),
|
||||
models_etag: RwLock::new(models_etag),
|
||||
otel_manager,
|
||||
provider,
|
||||
conversation_id,
|
||||
@@ -95,8 +96,8 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_model_context_window(&self) -> Option<i64> {
|
||||
let model_family = self.get_model_family();
|
||||
pub async fn get_model_context_window(&self) -> Option<i64> {
|
||||
let model_family = self.get_model_family().await;
|
||||
let effective_context_window_percent = model_family.effective_context_window_percent;
|
||||
model_family
|
||||
.context_window
|
||||
@@ -149,8 +150,8 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let model_family = self.get_model_family();
|
||||
let instructions = prompt.get_full_instructions(model_family).into_owned();
|
||||
let model_family = self.get_model_family().await;
|
||||
let instructions = prompt.get_full_instructions(&model_family).into_owned();
|
||||
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
|
||||
let api_prompt = build_api_prompt(prompt, instructions, tools_json);
|
||||
let conversation_id = self.conversation_id.to_string();
|
||||
@@ -170,7 +171,7 @@ impl ModelClient {
|
||||
|
||||
let stream_result = client
|
||||
.stream_prompt(
|
||||
&self.get_model(),
|
||||
&self.get_model().await,
|
||||
&api_prompt,
|
||||
Some(conversation_id.clone()),
|
||||
Some(session_source.clone()),
|
||||
@@ -203,8 +204,8 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let model_family = self.get_model_family();
|
||||
let instructions = prompt.get_full_instructions(model_family).into_owned();
|
||||
let model_family = self.get_model_family().await;
|
||||
let instructions = prompt.get_full_instructions(&model_family).into_owned();
|
||||
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
|
||||
let reasoning = if model_family.supports_reasoning_summaries {
|
||||
@@ -265,11 +266,14 @@ impl ModelClient {
|
||||
store_override: None,
|
||||
conversation_id: Some(conversation_id.clone()),
|
||||
session_source: Some(session_source.clone()),
|
||||
extra_headers: beta_feature_headers(&self.config, self.get_models_etag().clone()),
|
||||
extra_headers: beta_feature_headers(
|
||||
&self.config,
|
||||
self.get_models_etag().await.clone(),
|
||||
),
|
||||
};
|
||||
|
||||
let stream_result = client
|
||||
.stream_prompt(&self.get_model(), &api_prompt, options)
|
||||
.stream_prompt(&self.get_model().await, &api_prompt, options)
|
||||
.await;
|
||||
|
||||
match stream_result {
|
||||
@@ -300,17 +304,25 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
/// Returns the currently configured model slug.
|
||||
pub fn get_model(&self) -> String {
|
||||
self.get_model_family().get_model_slug().to_string()
|
||||
pub async fn get_model(&self) -> String {
|
||||
self.get_model_family().await.get_model_slug().to_string()
|
||||
}
|
||||
|
||||
/// Returns the currently configured model family.
|
||||
pub fn get_model_family(&self) -> &ModelFamily {
|
||||
&self.model_family
|
||||
pub async fn get_model_family(&self) -> ModelFamily {
|
||||
self.model_family.read().await.clone()
|
||||
}
|
||||
|
||||
fn get_models_etag(&self) -> &Option<String> {
|
||||
&self.models_etag
|
||||
pub async fn get_models_etag(&self) -> Option<String> {
|
||||
self.models_etag.read().await.clone()
|
||||
}
|
||||
|
||||
pub async fn update_models_etag(&self, etag: Option<String>) {
|
||||
*self.models_etag.write().await = etag;
|
||||
}
|
||||
|
||||
pub async fn update_model_family(&self, model_family: ModelFamily) {
|
||||
*self.model_family.write().await = model_family;
|
||||
}
|
||||
|
||||
/// Returns the current reasoning effort setting.
|
||||
@@ -347,10 +359,10 @@ impl ModelClient {
|
||||
.with_telemetry(Some(request_telemetry));
|
||||
|
||||
let instructions = prompt
|
||||
.get_full_instructions(self.get_model_family())
|
||||
.get_full_instructions(&self.get_model_family().await)
|
||||
.into_owned();
|
||||
let payload = ApiCompactionInput {
|
||||
model: &self.get_model(),
|
||||
model: &self.get_model().await,
|
||||
input: &prompt.input,
|
||||
instructions: &instructions,
|
||||
};
|
||||
|
||||
@@ -48,7 +48,7 @@ pub struct Prompt {
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) fn new(
|
||||
pub(crate) async fn new(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
router: &ToolRouter,
|
||||
@@ -57,6 +57,7 @@ impl Prompt {
|
||||
let model_supports_parallel = turn_context
|
||||
.client
|
||||
.get_model_family()
|
||||
.await
|
||||
.supports_parallel_tool_calls;
|
||||
|
||||
Prompt {
|
||||
|
||||
@@ -790,7 +790,7 @@ impl Session {
|
||||
}
|
||||
})
|
||||
{
|
||||
let curr = turn_context.client.get_model();
|
||||
let curr = turn_context.client.get_model().await;
|
||||
if prev != curr {
|
||||
warn!(
|
||||
"resuming session with different model: previous={prev}, current={curr}"
|
||||
@@ -1338,7 +1338,7 @@ impl Session {
|
||||
if let Some(token_usage) = token_usage {
|
||||
state.update_token_info_from_usage(
|
||||
token_usage,
|
||||
turn_context.client.get_model_context_window(),
|
||||
turn_context.client.get_model_context_window().await,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1350,6 +1350,7 @@ impl Session {
|
||||
.clone_history()
|
||||
.await
|
||||
.estimate_token_count(turn_context)
|
||||
.await
|
||||
else {
|
||||
return;
|
||||
};
|
||||
@@ -1370,7 +1371,7 @@ impl Session {
|
||||
};
|
||||
|
||||
if info.model_context_window.is_none() {
|
||||
info.model_context_window = turn_context.client.get_model_context_window();
|
||||
info.model_context_window = turn_context.client.get_model_context_window().await;
|
||||
}
|
||||
|
||||
state.set_token_info(Some(info));
|
||||
@@ -1400,7 +1401,7 @@ impl Session {
|
||||
}
|
||||
|
||||
pub(crate) async fn set_total_tokens_full(&self, turn_context: &TurnContext) {
|
||||
let context_window = turn_context.client.get_model_context_window();
|
||||
let context_window = turn_context.client.get_model_context_window().await;
|
||||
if let Some(context_window) = context_window {
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
@@ -2226,7 +2227,7 @@ fn errors_to_info(errors: &[SkillError]) -> Vec<SkillErrorInfo> {
|
||||
///
|
||||
pub(crate) async fn run_task(
|
||||
sess: Arc<Session>,
|
||||
mut turn_context: Arc<TurnContext>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
@@ -2237,6 +2238,7 @@ pub(crate) async fn run_task(
|
||||
let auto_compact_limit = turn_context
|
||||
.client
|
||||
.get_model_family()
|
||||
.await
|
||||
.auto_compact_token_limit()
|
||||
.unwrap_or(i64::MAX);
|
||||
let total_usage_tokens = sess.get_total_token_usage().await;
|
||||
@@ -2244,7 +2246,7 @@ pub(crate) async fn run_task(
|
||||
run_auto_compact(&sess, &turn_context).await;
|
||||
}
|
||||
let event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
model_context_window: turn_context.client.get_model_context_window().await,
|
||||
});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
|
||||
@@ -2309,7 +2311,7 @@ pub(crate) async fn run_task(
|
||||
.collect::<Vec<String>>();
|
||||
match run_turn(
|
||||
Arc::clone(&sess),
|
||||
&mut turn_context,
|
||||
&turn_context,
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
turn_input,
|
||||
cancellation_token.child_token(),
|
||||
@@ -2370,7 +2372,7 @@ pub(crate) async fn run_task(
|
||||
|
||||
async fn refresh_models_and_reset_turn_context(
|
||||
sess: &Arc<Session>,
|
||||
turn_context: &mut Arc<TurnContext>,
|
||||
turn_context: &Arc<TurnContext>,
|
||||
) {
|
||||
let config = {
|
||||
let state = sess.state.lock().await;
|
||||
@@ -2387,15 +2389,15 @@ async fn refresh_models_and_reset_turn_context(
|
||||
{
|
||||
error!("failed to refresh models after outdated models error: {err}");
|
||||
}
|
||||
let session_configuration = sess.state.lock().await.session_configuration.clone();
|
||||
*turn_context = sess
|
||||
.new_turn_from_configuration(
|
||||
turn_context.sub_id.clone(),
|
||||
session_configuration,
|
||||
Some(turn_context.final_output_json_schema.clone()),
|
||||
false,
|
||||
)
|
||||
let model = turn_context.client.get_model().await;
|
||||
let model_family = sess
|
||||
.services
|
||||
.models_manager
|
||||
.construct_model_family(&model, &config)
|
||||
.await;
|
||||
let models_etag = sess.services.models_manager.get_models_etag().await;
|
||||
turn_context.client.update_model_family(model_family).await;
|
||||
turn_context.client.update_models_etag(models_etag).await;
|
||||
}
|
||||
|
||||
async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>) {
|
||||
@@ -2410,13 +2412,13 @@ async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>)
|
||||
skip_all,
|
||||
fields(
|
||||
turn_id = %turn_context.sub_id,
|
||||
model = %turn_context.client.get_model(),
|
||||
model = %turn_context.client.get_model().await,
|
||||
cwd = %turn_context.cwd.display()
|
||||
)
|
||||
)]
|
||||
async fn run_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: &mut Arc<TurnContext>,
|
||||
turn_context: &Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
input: Vec<ResponseItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -2454,7 +2456,7 @@ async fn run_turn(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&prompt,
|
||||
&prompt.await,
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
.await
|
||||
@@ -2490,7 +2492,7 @@ async fn run_turn(
|
||||
retries += 1;
|
||||
// Refresh models if we got an outdated models error
|
||||
if matches!(e, CodexErr::OutdatedModels) {
|
||||
refresh_models_and_reset_turn_context(&sess, turn_context).await;
|
||||
refresh_models_and_reset_turn_context(&sess, &turn_context).await;
|
||||
continue;
|
||||
}
|
||||
let delay = match e {
|
||||
@@ -2550,7 +2552,7 @@ async fn drain_in_flight(
|
||||
skip_all,
|
||||
fields(
|
||||
turn_id = %turn_context.sub_id,
|
||||
model = %turn_context.client.get_model()
|
||||
model = %turn_context.client.get_model().await,
|
||||
)
|
||||
)]
|
||||
async fn try_run_turn(
|
||||
@@ -2565,7 +2567,7 @@ async fn try_run_turn(
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: turn_context.client.get_model(),
|
||||
model: turn_context.client.get_model().await,
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
|
||||
@@ -55,7 +55,7 @@ pub(crate) async fn run_compact_task(
|
||||
input: Vec<UserInput>,
|
||||
) {
|
||||
let start_event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
model_context_window: turn_context.client.get_model_context_window().await,
|
||||
});
|
||||
sess.send_event(&turn_context, start_event).await;
|
||||
run_compact_task_inner(sess.clone(), turn_context, input).await;
|
||||
@@ -83,7 +83,7 @@ async fn run_compact_task_inner(
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: turn_context.client.get_model(),
|
||||
model: turn_context.client.get_model().await,
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
|
||||
@@ -20,7 +20,7 @@ pub(crate) async fn run_inline_remote_auto_compact_task(
|
||||
|
||||
pub(crate) async fn run_remote_compact_task(sess: Arc<Session>, turn_context: Arc<TurnContext>) {
|
||||
let start_event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
model_context_window: turn_context.client.get_model_context_window().await,
|
||||
});
|
||||
sess.send_event(&turn_context, start_event).await;
|
||||
|
||||
|
||||
@@ -79,8 +79,8 @@ impl ContextManager {
|
||||
|
||||
// Estimate token usage using byte-based heuristics from the truncation helpers.
|
||||
// This is a coarse lower bound, not a tokenizer-accurate count.
|
||||
pub(crate) fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
|
||||
let model_family = turn_context.client.get_model_family();
|
||||
pub(crate) async fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
|
||||
let model_family = turn_context.client.get_model_family().await;
|
||||
let base_tokens =
|
||||
i64::try_from(approx_token_count(model_family.base_instructions.as_str()))
|
||||
.unwrap_or(i64::MAX);
|
||||
|
||||
@@ -59,7 +59,7 @@ impl SessionTask for UserShellCommandTask {
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
let event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
model_context_window: turn_context.client.get_model_context_window().await,
|
||||
});
|
||||
let session = session.clone_session();
|
||||
session.send_event(turn_context.as_ref(), event).await;
|
||||
|
||||
Reference in New Issue
Block a user