Allow turn context refresh between sampling requests

Reload the active turn context before each sampling request so nudges and other mid-turn setting changes can take effect without mutating TurnContext in place. Move mid-turn compaction to the start of follow-up requests, skip separate diff injection when compaction re-establishes full context, and persist TurnContextItem baselines only when the active context actually changes.

Also refresh the active turn context after override_turn_context, reset the reused model client session on mid-turn model changes, and key model-switch diffs off the persisted reference context so same-turn switch-backs still re-inject model instructions.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Charles Cunningham
2026-03-01 10:19:00 -08:00
parent 9950b5e265
commit 2327519d34
3 changed files with 366 additions and 70 deletions

View File

@@ -2053,31 +2053,75 @@ impl Session {
&self,
updates: SessionSettingsUpdate,
) -> ConstraintResult<()> {
let mut state = self.state.lock().await;
match state.session_configuration.apply(&updates) {
Ok(updated) => {
let previous_cwd = state.session_configuration.cwd.clone();
let next_cwd = updated.cwd.clone();
let codex_home = updated.codex_home.clone();
let session_source = updated.session_source.clone();
state.session_configuration = updated;
drop(state);
self.maybe_refresh_shell_snapshot_for_cwd(
&previous_cwd,
&next_cwd,
&codex_home,
&session_source,
);
Ok(())
let (
session_configuration,
previous_cwd,
sandbox_policy_changed,
approval_policy,
codex_home,
session_source,
) = {
let mut state = self.state.lock().await;
match state.session_configuration.apply(&updates) {
Ok(updated) => {
let previous_cwd = state.session_configuration.cwd.clone();
let sandbox_policy_changed =
state.session_configuration.sandbox_policy != updated.sandbox_policy;
let approval_policy = updated.approval_policy.clone();
let codex_home = updated.codex_home.clone();
let session_source = updated.session_source.clone();
state.session_configuration = updated.clone();
(
updated,
previous_cwd,
sandbox_policy_changed,
approval_policy,
codex_home,
session_source,
)
}
Err(err) => {
warn!("rejected session settings update: {err}");
return Err(err);
}
}
Err(err) => {
warn!("rejected session settings update: {err}");
Err(err)
};
self.maybe_refresh_shell_snapshot_for_cwd(
&previous_cwd,
&session_configuration.cwd,
&codex_home,
&session_source,
);
self.services
.mcp_connection_manager
.read()
.await
.set_approval_policy(&approval_policy);
if sandbox_policy_changed {
let per_turn_config = Self::build_per_turn_config(&session_configuration);
let sandbox_state = SandboxState {
sandbox_policy: per_turn_config.permissions.sandbox_policy.get().clone(),
codex_linux_sandbox_exe: per_turn_config.codex_linux_sandbox_exe.clone(),
sandbox_cwd: per_turn_config.cwd.clone(),
use_linux_sandbox_bwrap: per_turn_config
.features
.enabled(Feature::UseLinuxSandboxBwrap),
};
if let Err(err) = self
.services
.mcp_connection_manager
.read()
.await
.notify_sandbox_state_change(&sandbox_state)
.await
{
warn!("Failed to notify sandbox state change to MCP servers: {err:#}");
}
}
Ok(())
}
pub(crate) async fn new_turn_with_sub_id(
@@ -2215,6 +2259,100 @@ impl Session {
turn_context
}
async fn build_updated_turn_context(
&self,
current_turn_context: &TurnContext,
session_configuration: &SessionConfiguration,
) -> Arc<TurnContext> {
let per_turn_config = Self::build_per_turn_config(session_configuration);
let model_info = self
.services
.models_manager
.get_model_info(
session_configuration.collaboration_mode.model(),
&per_turn_config,
)
.await;
let reasoning_effort = session_configuration.collaboration_mode.reasoning_effort();
let reasoning_summary = session_configuration
.model_reasoning_summary
.unwrap_or(model_info.default_reasoning_summary);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_info: &model_info,
features: &per_turn_config.features,
web_search_mode: Some(per_turn_config.web_search_mode.value()),
session_source: current_turn_context.session_source.clone(),
})
.with_allow_login_shell(per_turn_config.permissions.allow_login_shell)
.with_agent_roles(per_turn_config.agent_roles.clone());
let turn_metadata_state = Arc::new(TurnMetadataState::new(
current_turn_context.sub_id.clone(),
session_configuration.cwd.clone(),
session_configuration.sandbox_policy.get(),
session_configuration.windows_sandbox_level,
per_turn_config
.features
.enabled(Feature::UseLinuxSandboxBwrap),
));
turn_metadata_state.spawn_git_enrichment_task();
Arc::new(TurnContext {
sub_id: current_turn_context.sub_id.clone(),
config: Arc::new(per_turn_config.clone()),
auth_manager: current_turn_context.auth_manager.clone(),
model_info: model_info.clone(),
otel_manager: self.services.otel_manager.clone().with_model(
session_configuration.collaboration_mode.model(),
model_info.slug.as_str(),
),
provider: session_configuration.provider.clone(),
reasoning_effort,
reasoning_summary,
session_source: current_turn_context.session_source.clone(),
cwd: session_configuration.cwd.clone(),
current_date: current_turn_context.current_date.clone(),
timezone: current_turn_context.timezone.clone(),
app_server_client_name: session_configuration.app_server_client_name.clone(),
developer_instructions: session_configuration.developer_instructions.clone(),
compact_prompt: session_configuration.compact_prompt.clone(),
user_instructions: session_configuration.user_instructions.clone(),
collaboration_mode: session_configuration.collaboration_mode.clone(),
personality: session_configuration.personality,
approval_policy: session_configuration.approval_policy.clone(),
sandbox_policy: session_configuration.sandbox_policy.clone(),
network: current_turn_context.network.clone(),
windows_sandbox_level: session_configuration.windows_sandbox_level,
shell_environment_policy: per_turn_config.permissions.shell_environment_policy.clone(),
tools_config,
features: per_turn_config.features.clone(),
ghost_snapshot: per_turn_config.ghost_snapshot.clone(),
final_output_json_schema: current_turn_context.final_output_json_schema.clone(),
codex_linux_sandbox_exe: per_turn_config.codex_linux_sandbox_exe.clone(),
tool_call_gate: Arc::clone(&current_turn_context.tool_call_gate),
truncation_policy: model_info.truncation_policy.into(),
js_repl: Arc::clone(&current_turn_context.js_repl),
dynamic_tools: session_configuration.dynamic_tools.clone(),
turn_metadata_state,
turn_skills: current_turn_context.turn_skills.clone(),
})
}
async fn refresh_current_active_turn_context_from_session_configuration(&self) {
let Some(current_turn_context) = self.current_active_turn_context().await else {
return;
};
let session_configuration = {
let state = self.state.lock().await;
state.session_configuration.clone()
};
let next_turn_context = self
.build_updated_turn_context(current_turn_context.as_ref(), &session_configuration)
.await;
let _ = self
.set_current_active_turn_context(next_turn_context)
.await;
}
pub(crate) async fn maybe_emit_unknown_model_warning_for_turn(&self, tc: &TurnContext) {
if tc.model_info.used_fallback_model_metadata {
self.send_event(
@@ -2382,9 +2520,21 @@ impl Session {
};
let shell = self.user_shell();
let exec_policy = self.services.exec_policy.current();
let effective_previous_turn_settings = match (
reference_context_item,
previous_turn_settings,
) {
(Some(item), previous_turn_settings) => Some(PreviousTurnSettings {
model: item.model.clone(),
realtime_active: item
.realtime_active
.or(previous_turn_settings.and_then(|settings| settings.realtime_active)),
}),
(None, previous_turn_settings) => previous_turn_settings,
};
crate::context_manager::updates::build_settings_update_items(
reference_context_item,
previous_turn_settings.as_ref(),
effective_previous_turn_settings.as_ref(),
current_context,
shell.as_ref(),
exec_policy.as_ref(),
@@ -2519,19 +2669,44 @@ impl Session {
pub(crate) async fn turn_context_for_sub_id(&self, sub_id: &str) -> Option<Arc<TurnContext>> {
let active = self.active_turn.lock().await;
active
.as_ref()
.and_then(|turn| turn.tasks.get(sub_id))
.map(|task| Arc::clone(&task.turn_context))
active.as_ref().and_then(|turn| {
turn.tasks.get(sub_id).map(|task| {
turn.current_turn_context
.clone()
.unwrap_or_else(|| Arc::clone(&task.turn_context))
})
})
}
async fn current_active_turn_context(&self) -> Option<Arc<TurnContext>> {
let active = self.active_turn.lock().await;
let turn = active.as_ref()?;
turn.current_turn_context.clone().or_else(|| {
turn.tasks
.first()
.map(|(_, task)| Arc::clone(&task.turn_context))
})
}
async fn set_current_active_turn_context(&self, turn_context: Arc<TurnContext>) -> bool {
let mut active = self.active_turn.lock().await;
let Some(turn) = active.as_mut() else {
return false;
};
turn.current_turn_context = Some(turn_context);
true
}
async fn active_turn_context_and_cancellation_token(
&self,
) -> Option<(Arc<TurnContext>, CancellationToken)> {
let active = self.active_turn.lock().await;
let (_, task) = active.as_ref()?.tasks.first()?;
let turn = active.as_ref()?;
let (_, task) = turn.tasks.first()?;
Some((
Arc::clone(&task.turn_context),
turn.current_turn_context
.clone()
.unwrap_or_else(|| Arc::clone(&task.turn_context)),
task.cancellation_token.child_token(),
))
}
@@ -3174,8 +3349,24 @@ impl Session {
state.reference_context_item()
}
/// Persist the latest turn context snapshot for the first real user turn and for
/// steady-state turns that emit model-visible context updates.
async fn maybe_record_context_updates_for_turn(
&self,
turn_context: &TurnContext,
) {
let current_context_item = turn_context.to_turn_context_item();
let reference_context_item = {
let state = self.state.lock().await;
state.reference_context_item()
};
if reference_context_item.as_ref() == Some(&current_context_item) {
return;
}
self.record_context_updates_and_set_reference_context_item(turn_context)
.await;
}
/// Persist the latest turn context snapshot whenever committed model-visible context changes.
///
/// When the reference snapshot is missing, this injects full initial context. Otherwise, it
/// emits only settings diff items.
@@ -3208,8 +3399,9 @@ impl Session {
self.record_conversation_items(turn_context, &context_items)
.await;
}
// Persist one `TurnContextItem` per real user turn so resume/lazy replay can recover the
// latest durable baseline even when this turn emitted no model-visible context diffs.
// Persist one `TurnContextItem` per committed model-visible context update so resume/lazy
// replay can recover the latest durable baseline even when this update emitted no
// model-visible context diffs.
self.persist_rollout_items(&[RolloutItem::TurnContext(turn_context_item.clone())])
.await;
@@ -3999,7 +4191,11 @@ mod handlers {
}),
})
.await;
return;
}
sess.refresh_current_active_turn_context_from_session_configuration()
.await;
}
pub async fn user_input_or_turn(sess: &Arc<Session>, sub_id: String, op: Op) {
@@ -4916,9 +5112,6 @@ pub(crate) async fn run_turn(
return None;
}
let model_info = turn_context.model_info.clone();
let auto_compact_limit = model_info.auto_compact_token_limit().unwrap_or(i64::MAX);
let event = EventMsg::TurnStarted(TurnStartedEvent {
turn_id: turn_context.sub_id.clone(),
model_context_window: turn_context.model_context_window(),
@@ -4939,7 +5132,7 @@ pub(crate) async fn run_turn(
let skills_outcome = Some(turn_context.turn_skills.outcome.as_ref());
sess.record_context_updates_and_set_reference_context_item(turn_context.as_ref())
sess.maybe_record_context_updates_for_turn(turn_context.as_ref())
.await;
let available_connectors = if turn_context.config.features.enabled(Feature::Apps) {
@@ -5077,8 +5270,57 @@ pub(crate) async fn run_turn(
// one instance across retries within this turn.
let mut client_session =
prewarmed_client_session.unwrap_or_else(|| sess.services.model_client.new_session());
let mut client_session_model_slug = turn_context.model_info.slug.clone();
let initial_turn_context = Arc::clone(&turn_context);
let mut is_first_sampling_request = true;
loop {
let turn_context = sess
.current_active_turn_context()
.await
.unwrap_or_else(|| Arc::clone(&initial_turn_context));
let auto_compact_limit = turn_context
.model_info
.auto_compact_token_limit()
.unwrap_or(i64::MAX);
if !is_first_sampling_request {
let total_usage_tokens = sess.get_total_token_usage().await;
let token_limit_reached = total_usage_tokens >= auto_compact_limit;
let estimated_token_count = sess.get_estimated_token_count(turn_context.as_ref()).await;
trace!(
turn_id = %turn_context.sub_id,
total_usage_tokens,
estimated_token_count = ?estimated_token_count,
auto_compact_limit,
token_limit_reached,
"pre sampling token usage"
);
if token_limit_reached {
if run_auto_compact(
&sess,
&turn_context,
InitialContextInjection::BeforeLastUserMessage,
)
.await
.is_err()
{
return None;
}
continue;
}
}
if client_session_model_slug != turn_context.model_info.slug {
client_session = sess.services.model_client.new_session();
client_session_model_slug = turn_context.model_info.slug.clone();
server_model_warning_emitted_for_turn = false;
}
if !is_first_sampling_request {
sess.maybe_record_context_updates_for_turn(turn_context.as_ref())
.await;
}
is_first_sampling_request = false;
// Note that pending_input would be something like a message the user
// submitted through the UI while the model was running. Though the UI
// may support this, the model might not.
@@ -5145,8 +5387,6 @@ pub(crate) async fn run_turn(
last_agent_message: sampling_request_last_agent_message,
} = sampling_request_output;
let total_usage_tokens = sess.get_total_token_usage().await;
let token_limit_reached = total_usage_tokens >= auto_compact_limit;
let estimated_token_count =
sess.get_estimated_token_count(turn_context.as_ref()).await;
@@ -5155,26 +5395,10 @@ pub(crate) async fn run_turn(
total_usage_tokens,
estimated_token_count = ?estimated_token_count,
auto_compact_limit,
token_limit_reached,
needs_follow_up,
"post sampling token usage"
);
// as long as compaction works well in getting us way below the token limit, we shouldn't worry about being in an infinite loop.
if token_limit_reached && needs_follow_up {
if run_auto_compact(
&sess,
&turn_context,
InitialContextInjection::BeforeLastUserMessage,
)
.await
.is_err()
{
return None;
}
continue;
}
if !needs_follow_up {
last_agent_message = sampling_request_last_agent_message;
let hook_outcomes = sess
@@ -9038,6 +9262,45 @@ mod tests {
assert_eq!(sess.previous_turn_settings().await, None);
}
#[tokio::test]
async fn override_turn_context_updates_active_turn_context() {
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
let mut active_turn = crate::state::ActiveTurn::default();
active_turn.current_turn_context = Some(Arc::clone(&tc));
*sess.active_turn.lock().await = Some(active_turn);
let next_model = if tc.model_info.slug == "gpt-5.1" {
"gpt-5"
} else {
"gpt-5.1"
};
let next_cwd = PathBuf::from("/tmp/updated-turn-context");
handlers::override_turn_context(
sess.as_ref(),
"override".to_string(),
SessionSettingsUpdate {
cwd: Some(next_cwd.clone()),
collaboration_mode: Some(tc.collaboration_mode.with_updates(
Some(next_model.to_string()),
None,
None,
)),
..Default::default()
},
)
.await;
let updated = sess
.current_active_turn_context()
.await
.expect("updated active turn context");
assert_eq!(updated.cwd, next_cwd);
assert_eq!(updated.model_info.slug, next_model);
assert_eq!(updated.collaboration_mode.model(), next_model);
assert!(Arc::ptr_eq(&updated.tool_call_gate, &tc.tool_call_gate));
}
#[tokio::test]
async fn build_settings_update_items_emits_environment_item_for_network_changes() {
let (session, previous_context) = make_session_and_context().await;
@@ -9350,14 +9613,6 @@ mod tests {
async fn record_context_updates_and_set_reference_context_item_persists_baseline_without_emitting_diffs()
{
let (session, previous_context) = make_session_and_context().await;
let next_model = if previous_context.model_info.slug == "gpt-5.1" {
"gpt-5"
} else {
"gpt-5.1"
};
let turn_context = previous_context
.with_model(next_model.to_string(), &session.services.models_manager)
.await;
let previous_context_item = previous_context.to_turn_context_item();
{
let mut state = session.state.lock().await;
@@ -9401,7 +9656,7 @@ mod tests {
assert_eq!(
serde_json::to_value(session.reference_context_item().await)
.expect("serialize current context item"),
serde_json::to_value(Some(turn_context.to_turn_context_item()))
serde_json::to_value(Some(previous_context.to_turn_context_item()))
.expect("serialize expected context item")
);
session.ensure_rollout_materialized().await;
@@ -9420,7 +9675,7 @@ mod tests {
assert_eq!(
serde_json::to_value(persisted_turn_context)
.expect("serialize persisted turn context item"),
serde_json::to_value(Some(turn_context.to_turn_context_item()))
serde_json::to_value(Some(previous_context.to_turn_context_item()))
.expect("serialize expected turn context item")
);
}
@@ -9448,6 +9703,45 @@ mod tests {
assert!(text.contains("<model_switch>"));
}
#[tokio::test]
async fn build_settings_update_items_uses_reference_context_model_for_switch_back() {
let (session, previous_context) = make_session_and_context().await;
let next_model = if previous_context.model_info.slug == "gpt-5.1" {
"gpt-5"
} else {
"gpt-5.1"
};
let switched_context = previous_context
.with_model(next_model.to_string(), &session.services.models_manager)
.await;
let switched_back_context = switched_context
.with_model(
previous_context.model_info.slug.clone(),
&session.services.models_manager,
)
.await;
let update_items = session.build_settings_update_items(
Some(&switched_context.to_turn_context_item()),
Some(previous_context.model_info.slug.as_str()),
&switched_back_context,
);
let developer_text = update_items
.iter()
.find_map(|item| match item {
ResponseItem::Message { role, content, .. } if role == "developer" => {
content.iter().find_map(|content| match content {
ContentItem::InputText { text } => Some(text.as_str()),
_ => None,
})
}
_ => None,
})
.expect("developer update item");
assert!(developer_text.contains("<model_switch>"));
}
#[tokio::test]
async fn record_context_updates_and_set_reference_context_item_persists_full_reinjection_to_rollout()
{