This commit is contained in:
Ahmed Ibrahim
2025-12-18 22:20:35 -08:00
parent 6912ba9fda
commit 348d379509
7 changed files with 65 additions and 50 deletions

View File

@@ -790,7 +790,7 @@ impl Session {
}
})
{
let curr = turn_context.client.get_model();
let curr = turn_context.client.get_model().await;
if prev != curr {
warn!(
"resuming session with different model: previous={prev}, current={curr}"
@@ -1338,7 +1338,7 @@ impl Session {
if let Some(token_usage) = token_usage {
state.update_token_info_from_usage(
token_usage,
turn_context.client.get_model_context_window(),
turn_context.client.get_model_context_window().await,
);
}
}
@@ -1350,6 +1350,7 @@ impl Session {
.clone_history()
.await
.estimate_token_count(turn_context)
.await
else {
return;
};
@@ -1370,7 +1371,7 @@ impl Session {
};
if info.model_context_window.is_none() {
info.model_context_window = turn_context.client.get_model_context_window();
info.model_context_window = turn_context.client.get_model_context_window().await;
}
state.set_token_info(Some(info));
@@ -1400,7 +1401,7 @@ impl Session {
}
pub(crate) async fn set_total_tokens_full(&self, turn_context: &TurnContext) {
let context_window = turn_context.client.get_model_context_window();
let context_window = turn_context.client.get_model_context_window().await;
if let Some(context_window) = context_window {
{
let mut state = self.state.lock().await;
@@ -2226,7 +2227,7 @@ fn errors_to_info(errors: &[SkillError]) -> Vec<SkillErrorInfo> {
///
pub(crate) async fn run_task(
sess: Arc<Session>,
mut turn_context: Arc<TurnContext>,
turn_context: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
@@ -2237,6 +2238,7 @@ pub(crate) async fn run_task(
let auto_compact_limit = turn_context
.client
.get_model_family()
.await
.auto_compact_token_limit()
.unwrap_or(i64::MAX);
let total_usage_tokens = sess.get_total_token_usage().await;
@@ -2244,7 +2246,7 @@ pub(crate) async fn run_task(
run_auto_compact(&sess, &turn_context).await;
}
let event = EventMsg::TaskStarted(TaskStartedEvent {
model_context_window: turn_context.client.get_model_context_window(),
model_context_window: turn_context.client.get_model_context_window().await,
});
sess.send_event(&turn_context, event).await;
@@ -2309,7 +2311,7 @@ pub(crate) async fn run_task(
.collect::<Vec<String>>();
match run_turn(
Arc::clone(&sess),
&mut turn_context,
&turn_context,
Arc::clone(&turn_diff_tracker),
turn_input,
cancellation_token.child_token(),
@@ -2370,7 +2372,7 @@ pub(crate) async fn run_task(
async fn refresh_models_and_reset_turn_context(
sess: &Arc<Session>,
turn_context: &mut Arc<TurnContext>,
turn_context: &Arc<TurnContext>,
) {
let config = {
let state = sess.state.lock().await;
@@ -2387,15 +2389,15 @@ async fn refresh_models_and_reset_turn_context(
{
error!("failed to refresh models after outdated models error: {err}");
}
let session_configuration = sess.state.lock().await.session_configuration.clone();
*turn_context = sess
.new_turn_from_configuration(
turn_context.sub_id.clone(),
session_configuration,
Some(turn_context.final_output_json_schema.clone()),
false,
)
let model = turn_context.client.get_model().await;
let model_family = sess
.services
.models_manager
.construct_model_family(&model, &config)
.await;
let models_etag = sess.services.models_manager.get_models_etag().await;
turn_context.client.update_model_family(model_family).await;
turn_context.client.update_models_etag(models_etag).await;
}
async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>) {
@@ -2410,13 +2412,13 @@ async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>)
skip_all,
fields(
turn_id = %turn_context.sub_id,
model = %turn_context.client.get_model(),
model = %turn_context.client.get_model().await,
cwd = %turn_context.cwd.display()
)
)]
async fn run_turn(
sess: Arc<Session>,
turn_context: &mut Arc<TurnContext>,
turn_context: &Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
input: Vec<ResponseItem>,
cancellation_token: CancellationToken,
@@ -2454,7 +2456,7 @@ async fn run_turn(
Arc::clone(&sess),
Arc::clone(turn_context),
Arc::clone(&turn_diff_tracker),
&prompt,
&prompt.await,
cancellation_token.child_token(),
)
.await
@@ -2490,7 +2492,7 @@ async fn run_turn(
retries += 1;
// Refresh models if we got an outdated models error
if matches!(e, CodexErr::OutdatedModels) {
refresh_models_and_reset_turn_context(&sess, turn_context).await;
refresh_models_and_reset_turn_context(&sess, &turn_context).await;
continue;
}
let delay = match e {
@@ -2550,7 +2552,7 @@ async fn drain_in_flight(
skip_all,
fields(
turn_id = %turn_context.sub_id,
model = %turn_context.client.get_model()
model = %turn_context.client.get_model().await,
)
)]
async fn try_run_turn(
@@ -2565,7 +2567,7 @@ async fn try_run_turn(
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
sandbox_policy: turn_context.sandbox_policy.clone(),
model: turn_context.client.get_model(),
model: turn_context.client.get_model().await,
effort: turn_context.client.get_reasoning_effort(),
summary: turn_context.client.get_reasoning_summary(),
});