Use queued turn context updates in app server

This commit is contained in:
Eric Traut
2026-05-15 18:05:01 -07:00
parent ba863d1fc8
commit 8c4e73d322
4 changed files with 247 additions and 84 deletions

View File

@@ -383,6 +383,8 @@ use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionConfiguredEvent;
#[cfg(test)]
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::Submission;
use codex_protocol::protocol::TurnContextOverrides;
use codex_protocol::protocol::TurnEnvironmentSelection;
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
use codex_protocol::protocol::W3cTraceContext;

View File

@@ -47,7 +47,25 @@ impl TurnContextOverrideRequest {
}
}
const TURN_STARTED_ACK_TIMEOUT: Duration = Duration::from_secs(5);
fn op_turn_context_overrides(overrides: CodexThreadTurnContextOverrides) -> TurnContextOverrides {
TurnContextOverrides {
cwd: overrides.cwd,
approval_policy: overrides.approval_policy,
approvals_reviewer: overrides.approvals_reviewer,
sandbox_policy: overrides.sandbox_policy,
permission_profile: overrides.permission_profile,
active_permission_profile: overrides.active_permission_profile,
windows_sandbox_level: overrides.windows_sandbox_level,
model: overrides.model,
effort: overrides.effort,
summary: overrides.summary,
service_tier: overrides.service_tier,
collaboration_mode: overrides.collaboration_mode,
personality: overrides.personality,
}
}
/// Upper bound on how long to wait for the acknowledgement that submitted turn context
/// overrides were applied before giving up on the waiter.
const TURN_CONTEXT_OVERRIDE_ACK_TIMEOUT: Duration = Duration::from_secs(5);
impl TurnRequestProcessor {
#[allow(clippy::too_many_arguments)]
@@ -538,19 +556,7 @@ impl TurnRequestProcessor {
environments: environment_selections,
final_output_json_schema: params.output_schema,
responsesapi_client_metadata: params.responsesapi_client_metadata,
cwd: overrides.cwd,
approval_policy: overrides.approval_policy,
approvals_reviewer: overrides.approvals_reviewer,
sandbox_policy: overrides.sandbox_policy,
permission_profile: overrides.permission_profile,
active_permission_profile: overrides.active_permission_profile,
windows_sandbox_level: overrides.windows_sandbox_level,
model: overrides.model,
effort: overrides.effort,
summary: overrides.summary,
service_tier: overrides.service_tier,
collaboration_mode: overrides.collaboration_mode,
personality: overrides.personality,
turn_context: op_turn_context_overrides(overrides),
}
} else {
Op::UserInput {
@@ -560,46 +566,69 @@ impl TurnRequestProcessor {
responsesapi_client_metadata: params.responsesapi_client_metadata,
}
};
let turn_id = self
.submit_core_op(&request_id, thread.as_ref(), turn_op)
.await
.map_err(|err| {
let error = internal_error(format!("failed to start turn: {err}"));
self.track_error_response(&request_id, &error, /*error_type*/ None);
error
})?;
if has_turn_context_overrides {
// The queued UserInputWithTurnContext owns the sticky context
// mutation. Wait for core to start processing that turn before
// reporting the effective state, otherwise a later direct update
// can appear to win and then be overwritten by this turn.
let turn_id = Uuid::now_v7().to_string();
let turn_context_applied = if has_turn_context_overrides {
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
let turn_started = {
Some({
let mut thread_state = thread_state.lock().await;
thread_state.turn_started_receiver(&turn_id)
};
if let Some(turn_started) = turn_started {
// Bound how long the RPC waits for the core turn-start acknowledgement.
tokio::time::timeout(TURN_STARTED_ACK_TIMEOUT, turn_started)
.await
.map_err(|_| {
internal_error(
"timed out waiting for turn context overrides to apply".to_string(),
)
})?
.map_err(|_| {
internal_error("turn context override waiter was cancelled".to_string())
})?;
thread_state.track_pending_turn_context(turn_id.clone())
})
} else {
None
};
if let Err(err) = thread
.submit_with_id(Submission {
id: turn_id.clone(),
op: turn_op,
trace: self.request_trace_context(&request_id).await,
})
.await
{
if has_turn_context_overrides {
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
let mut thread_state = thread_state.lock().await;
thread_state.cancel_pending_turn_context(&turn_id);
}
let after_turn_context =
thread_turn_context_from_snapshot(&thread.config_snapshot().await);
self.maybe_emit_turn_context_updated(
&params.thread_id,
&before_turn_context,
after_turn_context,
)
.await;
let error = internal_error(format!("failed to start turn: {err}"));
self.track_error_response(&request_id, &error, /*error_type*/ None);
return Err(error);
}
if let Some(turn_context_applied) = turn_context_applied {
let processor = self.clone();
let api_thread_id = params.thread_id.clone();
let tracked_turn_id = turn_id.clone();
tokio::spawn(async move {
match tokio::time::timeout(TURN_CONTEXT_OVERRIDE_ACK_TIMEOUT, turn_context_applied)
.await
{
Ok(Ok(Ok(payload))) => {
let after_turn_context = thread_turn_context_from_applied_event(&payload);
processor
.maybe_emit_turn_context_updated(
&api_thread_id,
&before_turn_context,
after_turn_context,
)
.await;
}
Ok(Ok(Err(err))) => {
tracing::warn!(
"failed to apply turn context overrides for turn {tracked_turn_id}: {err}"
);
}
Ok(Err(_)) => {
tracing::warn!(
"turn context override acknowledgement was cancelled for turn {tracked_turn_id}"
);
}
Err(_) => {
tracing::warn!(
"timed out waiting for turn context overrides to apply for turn {tracked_turn_id}"
);
}
}
});
}
if turn_has_input {
@@ -636,12 +665,12 @@ impl TurnRequestProcessor {
request_id: &ConnectionRequestId,
params: ThreadTurnContextUpdateParams,
) -> Result<ThreadTurnContextUpdateResponse, JSONRPCErrorError> {
let (_, thread) = self
.load_thread(&params.thread_id)
.await
.inspect_err(|error| {
self.track_error_response(request_id, error, /*error_type*/ None);
})?;
let (thread_id, thread) =
self.load_thread(&params.thread_id)
.await
.inspect_err(|error| {
self.track_error_response(request_id, error, /*error_type*/ None);
})?;
let before_snapshot = thread.config_snapshot().await;
let before_turn_context = thread_turn_context_from_snapshot(&before_snapshot);
let resolved_overrides = self
@@ -663,17 +692,53 @@ impl TurnRequestProcessor {
)
.await?;
let after_snapshot = if let Some(overrides) = resolved_overrides {
// There is no queued turn to order against here, so applying
// directly gives the caller a synchronized response snapshot.
let after_turn_context = if let Some(overrides) = resolved_overrides {
thread
.update_turn_context_overrides(overrides)
.preview_turn_context_overrides(overrides.clone())
.await
.map_err(|err| invalid_request(format!("invalid turn context override: {err}")))?
.map_err(|err| invalid_request(format!("invalid turn context override: {err}")))?;
let update_id = Uuid::now_v7().to_string();
let turn_context_applied = {
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
let mut thread_state = thread_state.lock().await;
thread_state.track_pending_turn_context(update_id.clone())
};
if let Err(err) = thread
.submit_with_id(Submission {
id: update_id.clone(),
op: Op::TurnContext {
turn_context: op_turn_context_overrides(overrides),
},
trace: self.request_trace_context(request_id).await,
})
.await
{
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
let mut thread_state = thread_state.lock().await;
thread_state.cancel_pending_turn_context(&update_id);
return Err(internal_error(format!(
"failed to update turn context: {err}"
)));
}
match tokio::time::timeout(TURN_CONTEXT_OVERRIDE_ACK_TIMEOUT, turn_context_applied)
.await
{
Ok(Ok(Ok(payload))) => thread_turn_context_from_applied_event(&payload),
Ok(Ok(Err(err))) => return Err(invalid_request(err)),
Ok(Err(_)) => {
return Err(internal_error(
"turn context override waiter was cancelled".to_string(),
));
}
Err(_) => {
return Err(internal_error(
"timed out waiting for turn context overrides to apply".to_string(),
));
}
}
} else {
before_snapshot
before_turn_context.clone()
};
let after_turn_context = thread_turn_context_from_snapshot(&after_snapshot);
self.maybe_emit_turn_context_updated(
&params.thread_id,
&before_turn_context,
@@ -1303,6 +1368,32 @@ fn thread_turn_context_from_snapshot(config_snapshot: &ThreadConfigSnapshot) ->
}
}
/// Builds the API-facing `ThreadTurnContext` from the turn context carried by a
/// `TurnContextAppliedEvent`.
///
/// The sandbox policy is derived from the event's permission profile and cwd; the remaining
/// fields are cloned, copied, or converted with `into()` from the event's turn context.
fn thread_turn_context_from_applied_event(
    event: &codex_protocol::protocol::TurnContextAppliedEvent,
) -> ThreadTurnContext {
    let applied = &event.turn_context;

    // Derived values are computed up front so the struct literal below stays a flat mapping.
    let sandbox_policy =
        thread_response_sandbox_policy(&applied.permission_profile, applied.cwd.as_path());
    let active_permission_profile =
        thread_response_active_permission_profile(applied.active_permission_profile.clone());

    ThreadTurnContext {
        model: applied.model.clone(),
        model_provider: applied.model_provider_id.clone(),
        service_tier: applied.service_tier.clone(),
        cwd: applied.cwd.clone(),
        approval_policy: applied.approval_policy.into(),
        approvals_reviewer: applied.approvals_reviewer.into(),
        sandbox_policy,
        permission_profile: applied.permission_profile.clone().into(),
        active_permission_profile,
        effort: applied.reasoning_effort,
        summary: applied.reasoning_summary,
        personality: applied.personality,
        collaboration_mode: applied.collaboration_mode.clone(),
    }
}
fn xcode_26_4_mcp_elicitations_auto_deny(
client_name: Option<&str>,
client_version: Option<&str>,

View File

@@ -11,6 +11,7 @@ use codex_file_watcher::WatchRegistration;
use codex_protocol::ThreadId;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::TurnContextAppliedEvent;
use codex_rollout::state_db::StateDbHandle;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::collections::HashMap;
@@ -24,6 +25,7 @@ use tokio::sync::watch;
use tracing::error;
type PendingInterruptQueue = Vec<ConnectionRequestId>;
/// Outcome delivered to a pending turn-context waiter: the applied event on success, or an
/// error message on failure.
type TurnContextAck = Result<TurnContextAppliedEvent, String>;
pub(crate) struct PendingThreadResumeRequest {
pub(crate) request_id: ConnectionRequestId,
@@ -78,7 +80,7 @@ pub(crate) struct ThreadState {
pub(crate) listener_generation: u64,
listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
current_turn_history: ThreadHistoryBuilder,
pending_turn_started_waiters: HashMap<String, Vec<oneshot::Sender<()>>>,
pending_turn_context_waiters: HashMap<String, Vec<oneshot::Sender<TurnContextAck>>>,
listener_thread: Option<Weak<CodexThread>>,
watch_registration: WatchRegistration,
}
@@ -113,7 +115,7 @@ impl ThreadState {
let _ = cancel_tx.send(());
}
self.listener_command_tx = None;
self.pending_turn_started_waiters.clear();
self.pending_turn_context_waiters.clear();
self.current_turn_history.reset();
self.listener_thread = None;
self.watch_registration = WatchRegistration::default();
@@ -133,21 +135,20 @@ impl ThreadState {
self.current_turn_history.active_turn_snapshot()
}
pub(crate) fn turn_started_receiver(&mut self, turn_id: &str) -> Option<oneshot::Receiver<()>> {
if self
.active_turn_snapshot()
.is_some_and(|turn| turn.id == turn_id)
|| self.last_terminal_turn_id.as_deref() == Some(turn_id)
{
return None;
}
pub(crate) fn track_pending_turn_context(
&mut self,
submission_id: String,
) -> oneshot::Receiver<TurnContextAck> {
let (tx, rx) = oneshot::channel();
self.pending_turn_started_waiters
.entry(turn_id.to_string())
self.pending_turn_context_waiters
.entry(submission_id)
.or_default()
.push(tx);
Some(rx)
rx
}
/// Stops tracking every waiter registered under `submission_id`.
///
/// The removed senders are dropped without a value being sent. Does nothing when no waiter
/// is registered for that id.
pub(crate) fn cancel_pending_turn_context(&mut self, submission_id: &str) {
    drop(self.pending_turn_context_waiters.remove(submission_id));
}
pub(crate) fn track_current_turn_event(&mut self, event_turn_id: &str, event: &EventMsg) {
@@ -155,22 +156,24 @@ impl ThreadState {
self.turn_summary.started_at = payload.started_at;
}
self.current_turn_history.handle_event(event);
if let EventMsg::TurnStarted(payload) = event {
self.notify_turn_started(&payload.turn_id);
if let EventMsg::TurnContextApplied(payload) = event {
self.notify_turn_context_applied(event_turn_id, Ok(payload.clone()));
}
if let EventMsg::Error(error) = event {
self.notify_turn_context_applied(event_turn_id, Err(error.message.clone()));
}
if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_))
&& !self.current_turn_history.has_active_turn()
{
self.last_terminal_turn_id = Some(event_turn_id.to_string());
self.current_turn_history.reset();
self.notify_turn_started(event_turn_id);
}
}
fn notify_turn_started(&mut self, turn_id: &str) {
if let Some(waiters) = self.pending_turn_started_waiters.remove(turn_id) {
fn notify_turn_context_applied(&mut self, submission_id: &str, result: TurnContextAck) {
if let Some(waiters) = self.pending_turn_context_waiters.remove(submission_id) {
for waiter in waiters {
let _ = waiter.send(());
let _ = waiter.send(result.clone());
}
}
}

View File

@@ -340,6 +340,73 @@ async fn thread_turn_context_update_after_turn_start_preserves_newer_update() ->
.await
}
/// Sends a `turn/start` carrying model/effort overrides, then immediately a direct turn-context
/// update with different values, and checks that a turn-context-updated notification is seen
/// for each model/effort pair.
#[tokio::test]
async fn queued_updates_keep_each_turn_context_notification_snapshot() -> Result<()> {
    let responses = vec![create_final_assistant_message_sse_response("Done")?];
    let server = create_mock_responses_server_sequence_unchecked(responses).await;

    let codex_home = TempDir::new()?;
    write_config(&codex_home, &server.uri())?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let ThreadStartResponse { thread, .. } = start_thread(&mut mcp).await?;

    // First request: a turn whose overrides travel with the queued turn itself.
    let turn_params = TurnStartParams {
        thread_id: thread.id.clone(),
        input: vec![V2UserInput::Text {
            text: "Hello".to_string(),
            text_elements: Vec::new(),
        }],
        model: Some("gpt-5.2".to_string()),
        effort: Some(ReasoningEffort::Low),
        ..Default::default()
    };
    let turn_request_id = mcp.send_turn_start_request(turn_params).await?;

    // Second request: a direct context update issued before the turn response is read.
    let update_params = ThreadTurnContextUpdateParams {
        thread_id: thread.id,
        model: Some("gpt-5.4".to_string()),
        effort: Some(Some(ReasoningEffort::High)),
        ..Default::default()
    };
    let update_request_id = mcp
        .send_thread_turn_context_update_request(update_params)
        .await?;

    // Drain both responses in submission order.
    for pending_request_id in [turn_request_id, update_request_id] {
        timeout(
            DEFAULT_READ_TIMEOUT,
            mcp.read_stream_until_response_message(RequestId::Integer(pending_request_id)),
        )
        .await??;
    }

    let first_update = read_turn_context_updated(&mut mcp).await?;
    let second_update = read_turn_context_updated(&mut mcp).await?;
    let notifications = [first_update, second_update];

    // Arrival order is not asserted; each snapshot only has to show up once.
    assert!(notifications.iter().any(|seen| {
        seen.turn_context.model == "gpt-5.2"
            && seen.turn_context.effort == Some(ReasoningEffort::Low)
    }));
    assert!(notifications.iter().any(|seen| {
        seen.turn_context.model == "gpt-5.4"
            && seen.turn_context.effort == Some(ReasoningEffort::High)
    }));

    timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_notification_message("turn/completed"),
    )
    .await??;

    Ok(())
}
#[tokio::test]
async fn thread_turn_context_update_after_no_op_turn_start_override_preserves_newer_update()
-> Result<()> {