From 8c4e73d322f489cf677198cd9714db1b27400039 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Fri, 15 May 2026 18:05:01 -0700 Subject: [PATCH] Use queued turn context updates in app server --- codex-rs/app-server/src/request_processors.rs | 2 + .../src/request_processors/turn_processor.rs | 219 +++++++++++++----- codex-rs/app-server/src/thread_state.rs | 43 ++-- .../tests/suite/v2/thread_turn_context.rs | 67 ++++++ 4 files changed, 247 insertions(+), 84 deletions(-) diff --git a/codex-rs/app-server/src/request_processors.rs b/codex-rs/app-server/src/request_processors.rs index 376dfd329b..5e6cc42e02 100644 --- a/codex-rs/app-server/src/request_processors.rs +++ b/codex-rs/app-server/src/request_processors.rs @@ -383,6 +383,8 @@ use codex_protocol::protocol::RolloutItem; use codex_protocol::protocol::SessionConfiguredEvent; #[cfg(test)] use codex_protocol::protocol::SessionMetaLine; +use codex_protocol::protocol::Submission; +use codex_protocol::protocol::TurnContextOverrides; use codex_protocol::protocol::TurnEnvironmentSelection; use codex_protocol::protocol::USER_MESSAGE_BEGIN; use codex_protocol::protocol::W3cTraceContext; diff --git a/codex-rs/app-server/src/request_processors/turn_processor.rs b/codex-rs/app-server/src/request_processors/turn_processor.rs index 2ae905cc39..19873b6c28 100644 --- a/codex-rs/app-server/src/request_processors/turn_processor.rs +++ b/codex-rs/app-server/src/request_processors/turn_processor.rs @@ -47,7 +47,25 @@ impl TurnContextOverrideRequest { } } -const TURN_STARTED_ACK_TIMEOUT: Duration = Duration::from_secs(5); +fn op_turn_context_overrides(overrides: CodexThreadTurnContextOverrides) -> TurnContextOverrides { + TurnContextOverrides { + cwd: overrides.cwd, + approval_policy: overrides.approval_policy, + approvals_reviewer: overrides.approvals_reviewer, + sandbox_policy: overrides.sandbox_policy, + permission_profile: overrides.permission_profile, + active_permission_profile: overrides.active_permission_profile, + windows_sandbox_level: overrides.windows_sandbox_level, + model: overrides.model, + effort: overrides.effort, + summary: overrides.summary, + service_tier: overrides.service_tier, + collaboration_mode: overrides.collaboration_mode, + personality: overrides.personality, + } +} + +const TURN_CONTEXT_OVERRIDE_ACK_TIMEOUT: Duration = Duration::from_secs(5); impl TurnRequestProcessor { #[allow(clippy::too_many_arguments)] @@ -538,19 +556,7 @@ impl TurnRequestProcessor { environments: environment_selections, final_output_json_schema: params.output_schema, responsesapi_client_metadata: params.responsesapi_client_metadata, - cwd: overrides.cwd, - approval_policy: overrides.approval_policy, - approvals_reviewer: overrides.approvals_reviewer, - sandbox_policy: overrides.sandbox_policy, - permission_profile: overrides.permission_profile, - active_permission_profile: overrides.active_permission_profile, - windows_sandbox_level: overrides.windows_sandbox_level, - model: overrides.model, - effort: overrides.effort, - summary: overrides.summary, - service_tier: overrides.service_tier, - collaboration_mode: overrides.collaboration_mode, - personality: overrides.personality, + turn_context: op_turn_context_overrides(overrides), } } else { Op::UserInput { @@ -560,46 +566,69 @@ impl TurnRequestProcessor { responsesapi_client_metadata: params.responsesapi_client_metadata, } }; - let turn_id = self - .submit_core_op(&request_id, thread.as_ref(), turn_op) - .await - .map_err(|err| { - let error = internal_error(format!("failed to start turn: {err}")); - self.track_error_response(&request_id, &error, /*error_type*/ None); - error - })?; - - if has_turn_context_overrides { - // The queued UserInputWithTurnContext owns the sticky context - // mutation. Wait for core to start processing that turn before - // reporting the effective state, otherwise a later direct update - // can appear to win and then be overwritten by this turn. + let turn_id = Uuid::now_v7().to_string(); + let turn_context_applied = if has_turn_context_overrides { let thread_state = self.thread_state_manager.thread_state(thread_id).await; - let turn_started = { + Some({ let mut thread_state = thread_state.lock().await; - thread_state.turn_started_receiver(&turn_id) - }; - if let Some(turn_started) = turn_started { - // Bound how long the RPC waits for the core turn-start acknowledgement. - tokio::time::timeout(TURN_STARTED_ACK_TIMEOUT, turn_started) - .await - .map_err(|_| { - internal_error( - "timed out waiting for turn context overrides to apply".to_string(), - ) - })? - .map_err(|_| { - internal_error("turn context override waiter was cancelled".to_string()) - })?; + thread_state.track_pending_turn_context(turn_id.clone()) + }) + } else { + None + }; + if let Err(err) = thread + .submit_with_id(Submission { + id: turn_id.clone(), + op: turn_op, + trace: self.request_trace_context(&request_id).await, + }) + .await + { + if has_turn_context_overrides { + let thread_state = self.thread_state_manager.thread_state(thread_id).await; + let mut thread_state = thread_state.lock().await; + thread_state.cancel_pending_turn_context(&turn_id); } - let after_turn_context = - thread_turn_context_from_snapshot(&thread.config_snapshot().await); - self.maybe_emit_turn_context_updated( - ¶ms.thread_id, - &before_turn_context, - after_turn_context, - ) - .await; + let error = internal_error(format!("failed to start turn: {err}")); + self.track_error_response(&request_id, &error, /*error_type*/ None); + return Err(error); + } + + if let Some(turn_context_applied) = turn_context_applied { + let processor = self.clone(); + let api_thread_id = params.thread_id.clone(); + let tracked_turn_id = turn_id.clone(); + tokio::spawn(async move { + match tokio::time::timeout(TURN_CONTEXT_OVERRIDE_ACK_TIMEOUT, turn_context_applied) + .await + { + Ok(Ok(Ok(payload))) => { + let after_turn_context = thread_turn_context_from_applied_event(&payload); + processor + .maybe_emit_turn_context_updated( + &api_thread_id, + &before_turn_context, + after_turn_context, + ) + .await; + } + Ok(Ok(Err(err))) => { + tracing::warn!( + "failed to apply turn context overrides for turn {tracked_turn_id}: {err}" + ); + } + Ok(Err(_)) => { + tracing::warn!( + "turn context override acknowledgement was cancelled for turn {tracked_turn_id}" + ); + } + Err(_) => { + tracing::warn!( + "timed out waiting for turn context overrides to apply for turn {tracked_turn_id}" + ); + } + } + }); } if turn_has_input { @@ -636,12 +665,12 @@ impl TurnRequestProcessor { request_id: &ConnectionRequestId, params: ThreadTurnContextUpdateParams, ) -> Result { - let (_, thread) = self - .load_thread(¶ms.thread_id) - .await - .inspect_err(|error| { - self.track_error_response(request_id, error, /*error_type*/ None); - })?; + let (thread_id, thread) = + self.load_thread(¶ms.thread_id) + .await + .inspect_err(|error| { + self.track_error_response(request_id, error, /*error_type*/ None); + })?; let before_snapshot = thread.config_snapshot().await; let before_turn_context = thread_turn_context_from_snapshot(&before_snapshot); let resolved_overrides = self @@ -663,17 +692,53 @@ impl TurnRequestProcessor { ) .await?; - let after_snapshot = if let Some(overrides) = resolved_overrides { - // There is no queued turn to order against here, so applying - // directly gives the caller a synchronized response snapshot. + let after_turn_context = if let Some(overrides) = resolved_overrides { thread - .update_turn_context_overrides(overrides) + .preview_turn_context_overrides(overrides.clone()) .await - .map_err(|err| invalid_request(format!("invalid turn context override: {err}")))? + .map_err(|err| invalid_request(format!("invalid turn context override: {err}")))?; + let update_id = Uuid::now_v7().to_string(); + let turn_context_applied = { + let thread_state = self.thread_state_manager.thread_state(thread_id).await; + let mut thread_state = thread_state.lock().await; + thread_state.track_pending_turn_context(update_id.clone()) + }; + if let Err(err) = thread + .submit_with_id(Submission { + id: update_id.clone(), + op: Op::TurnContext { + turn_context: op_turn_context_overrides(overrides), + }, + trace: self.request_trace_context(request_id).await, + }) + .await + { + let thread_state = self.thread_state_manager.thread_state(thread_id).await; + let mut thread_state = thread_state.lock().await; + thread_state.cancel_pending_turn_context(&update_id); + return Err(internal_error(format!( + "failed to update turn context: {err}" + ))); + } + match tokio::time::timeout(TURN_CONTEXT_OVERRIDE_ACK_TIMEOUT, turn_context_applied) + .await + { + Ok(Ok(Ok(payload))) => thread_turn_context_from_applied_event(&payload), + Ok(Ok(Err(err))) => return Err(invalid_request(err)), + Ok(Err(_)) => { + return Err(internal_error( + "turn context override waiter was cancelled".to_string(), + )); + } + Err(_) => { + return Err(internal_error( + "timed out waiting for turn context overrides to apply".to_string(), + )); + } + } } else { - before_snapshot + before_turn_context.clone() }; - let after_turn_context = thread_turn_context_from_snapshot(&after_snapshot); self.maybe_emit_turn_context_updated( ¶ms.thread_id, &before_turn_context, @@ -1303,6 +1368,32 @@ fn thread_turn_context_from_snapshot(config_snapshot: &ThreadConfigSnapshot) -> } } +fn thread_turn_context_from_applied_event( + event: &codex_protocol::protocol::TurnContextAppliedEvent, +) -> ThreadTurnContext { + let turn_context = &event.turn_context; + ThreadTurnContext { + model: turn_context.model.clone(), + model_provider: turn_context.model_provider_id.clone(), + service_tier: turn_context.service_tier.clone(), + cwd: turn_context.cwd.clone(), + approval_policy: turn_context.approval_policy.into(), + approvals_reviewer: turn_context.approvals_reviewer.into(), + sandbox_policy: thread_response_sandbox_policy( + &turn_context.permission_profile, + turn_context.cwd.as_path(), + ), + permission_profile: turn_context.permission_profile.clone().into(), + active_permission_profile: thread_response_active_permission_profile( + turn_context.active_permission_profile.clone(), + ), + effort: turn_context.reasoning_effort, + summary: turn_context.reasoning_summary, + personality: turn_context.personality, + collaboration_mode: turn_context.collaboration_mode.clone(), + } +} + fn xcode_26_4_mcp_elicitations_auto_deny( client_name: Option<&str>, client_version: Option<&str>, diff --git a/codex-rs/app-server/src/thread_state.rs b/codex-rs/app-server/src/thread_state.rs index 179bf73466..39637dcf01 100644 --- a/codex-rs/app-server/src/thread_state.rs +++ b/codex-rs/app-server/src/thread_state.rs @@ -11,6 +11,7 @@ use codex_file_watcher::WatchRegistration; use codex_protocol::ThreadId; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::RolloutItem; +use codex_protocol::protocol::TurnContextAppliedEvent; use codex_rollout::state_db::StateDbHandle; use codex_utils_absolute_path::AbsolutePathBuf; use std::collections::HashMap; @@ -24,6 +25,7 @@ use tokio::sync::watch; use tracing::error; type PendingInterruptQueue = Vec; +type TurnContextAck = Result; pub(crate) struct PendingThreadResumeRequest { pub(crate) request_id: ConnectionRequestId, @@ -78,7 +80,7 @@ pub(crate) struct ThreadState { pub(crate) listener_generation: u64, listener_command_tx: Option>, current_turn_history: ThreadHistoryBuilder, - pending_turn_started_waiters: HashMap>>, + pending_turn_context_waiters: HashMap>>, listener_thread: Option>, watch_registration: WatchRegistration, } @@ -113,7 +115,7 @@ impl ThreadState { let _ = cancel_tx.send(()); } self.listener_command_tx = None; - self.pending_turn_started_waiters.clear(); + self.pending_turn_context_waiters.clear(); self.current_turn_history.reset(); self.listener_thread = None; self.watch_registration = WatchRegistration::default(); @@ -133,21 +135,20 @@ impl ThreadState { self.current_turn_history.active_turn_snapshot() } - pub(crate) fn turn_started_receiver(&mut self, turn_id: &str) -> Option> { - if self - .active_turn_snapshot() - .is_some_and(|turn| turn.id == turn_id) - || self.last_terminal_turn_id.as_deref() == Some(turn_id) - { - return None; - } - + pub(crate) fn track_pending_turn_context( + &mut self, + submission_id: String, + ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); - self.pending_turn_started_waiters - .entry(turn_id.to_string()) + self.pending_turn_context_waiters + .entry(submission_id) .or_default() .push(tx); - Some(rx) + rx + } + + pub(crate) fn cancel_pending_turn_context(&mut self, submission_id: &str) { + self.pending_turn_context_waiters.remove(submission_id); } pub(crate) fn track_current_turn_event(&mut self, event_turn_id: &str, event: &EventMsg) { @@ -155,22 +156,24 @@ impl ThreadState { self.turn_summary.started_at = payload.started_at; } self.current_turn_history.handle_event(event); - if let EventMsg::TurnStarted(payload) = event { - self.notify_turn_started(&payload.turn_id); + if let EventMsg::TurnContextApplied(payload) = event { + self.notify_turn_context_applied(event_turn_id, Ok(payload.clone())); + } + if let EventMsg::Error(error) = event { + self.notify_turn_context_applied(event_turn_id, Err(error.message.clone())); } if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_)) && !self.current_turn_history.has_active_turn() { self.last_terminal_turn_id = Some(event_turn_id.to_string()); self.current_turn_history.reset(); - self.notify_turn_started(event_turn_id); } } - fn notify_turn_started(&mut self, turn_id: &str) { - if let Some(waiters) = self.pending_turn_started_waiters.remove(turn_id) { + fn notify_turn_context_applied(&mut self, submission_id: &str, result: TurnContextAck) { + if let Some(waiters) = self.pending_turn_context_waiters.remove(submission_id) { for waiter in waiters { - let _ = waiter.send(()); + let _ = waiter.send(result.clone()); } } } diff --git a/codex-rs/app-server/tests/suite/v2/thread_turn_context.rs b/codex-rs/app-server/tests/suite/v2/thread_turn_context.rs index 8f1bb998b4..58a27983be 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_turn_context.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_turn_context.rs @@ -340,6 +340,73 @@ async fn thread_turn_context_update_after_turn_start_preserves_newer_update() -> .await } +#[tokio::test] +async fn queued_updates_keep_each_turn_context_notification_snapshot() -> Result<()> { + let server = create_mock_responses_server_sequence_unchecked(vec![ + create_final_assistant_message_sse_response("Done")?, + ]) + .await; + let codex_home = TempDir::new()?; + write_config(&codex_home, &server.uri())?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + let ThreadStartResponse { thread, .. } = start_thread(&mut mcp).await?; + + let turn_request_id = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: "Hello".to_string(), + text_elements: Vec::new(), + }], + model: Some("gpt-5.2".to_string()), + effort: Some(ReasoningEffort::Low), + ..Default::default() + }) + .await?; + let update_request_id = mcp + .send_thread_turn_context_update_request(ThreadTurnContextUpdateParams { + thread_id: thread.id, + model: Some("gpt-5.4".to_string()), + effort: Some(Some(ReasoningEffort::High)), + ..Default::default() + }) + .await?; + + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(turn_request_id)), + ) + .await??; + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(update_request_id)), + ) + .await??; + + let notifications = [ + read_turn_context_updated(&mut mcp).await?, + read_turn_context_updated(&mut mcp).await?, + ]; + assert!(notifications.iter().any(|notification| { + notification.turn_context.model == "gpt-5.2" + && notification.turn_context.effort == Some(ReasoningEffort::Low) + })); + assert!(notifications.iter().any(|notification| { + notification.turn_context.model == "gpt-5.4" + && notification.turn_context.effort == Some(ReasoningEffort::High) + })); + + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + + Ok(()) +} + #[tokio::test] async fn thread_turn_context_update_after_no_op_turn_start_override_preserves_newer_update() -> Result<()> {