mirror of
https://github.com/openai/codex.git
synced 2026-05-26 05:55:36 +00:00
Use queued turn context updates in app server
This commit is contained in:
@@ -383,6 +383,8 @@ use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::SessionConfiguredEvent;
|
||||
#[cfg(test)]
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
use codex_protocol::protocol::Submission;
|
||||
use codex_protocol::protocol::TurnContextOverrides;
|
||||
use codex_protocol::protocol::TurnEnvironmentSelection;
|
||||
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
|
||||
use codex_protocol::protocol::W3cTraceContext;
|
||||
|
||||
@@ -47,7 +47,25 @@ impl TurnContextOverrideRequest {
|
||||
}
|
||||
}
|
||||
|
||||
const TURN_STARTED_ACK_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
fn op_turn_context_overrides(overrides: CodexThreadTurnContextOverrides) -> TurnContextOverrides {
|
||||
TurnContextOverrides {
|
||||
cwd: overrides.cwd,
|
||||
approval_policy: overrides.approval_policy,
|
||||
approvals_reviewer: overrides.approvals_reviewer,
|
||||
sandbox_policy: overrides.sandbox_policy,
|
||||
permission_profile: overrides.permission_profile,
|
||||
active_permission_profile: overrides.active_permission_profile,
|
||||
windows_sandbox_level: overrides.windows_sandbox_level,
|
||||
model: overrides.model,
|
||||
effort: overrides.effort,
|
||||
summary: overrides.summary,
|
||||
service_tier: overrides.service_tier,
|
||||
collaboration_mode: overrides.collaboration_mode,
|
||||
personality: overrides.personality,
|
||||
}
|
||||
}
|
||||
|
||||
const TURN_CONTEXT_OVERRIDE_ACK_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
impl TurnRequestProcessor {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -538,19 +556,7 @@ impl TurnRequestProcessor {
|
||||
environments: environment_selections,
|
||||
final_output_json_schema: params.output_schema,
|
||||
responsesapi_client_metadata: params.responsesapi_client_metadata,
|
||||
cwd: overrides.cwd,
|
||||
approval_policy: overrides.approval_policy,
|
||||
approvals_reviewer: overrides.approvals_reviewer,
|
||||
sandbox_policy: overrides.sandbox_policy,
|
||||
permission_profile: overrides.permission_profile,
|
||||
active_permission_profile: overrides.active_permission_profile,
|
||||
windows_sandbox_level: overrides.windows_sandbox_level,
|
||||
model: overrides.model,
|
||||
effort: overrides.effort,
|
||||
summary: overrides.summary,
|
||||
service_tier: overrides.service_tier,
|
||||
collaboration_mode: overrides.collaboration_mode,
|
||||
personality: overrides.personality,
|
||||
turn_context: op_turn_context_overrides(overrides),
|
||||
}
|
||||
} else {
|
||||
Op::UserInput {
|
||||
@@ -560,46 +566,69 @@ impl TurnRequestProcessor {
|
||||
responsesapi_client_metadata: params.responsesapi_client_metadata,
|
||||
}
|
||||
};
|
||||
let turn_id = self
|
||||
.submit_core_op(&request_id, thread.as_ref(), turn_op)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
let error = internal_error(format!("failed to start turn: {err}"));
|
||||
self.track_error_response(&request_id, &error, /*error_type*/ None);
|
||||
error
|
||||
})?;
|
||||
|
||||
if has_turn_context_overrides {
|
||||
// The queued UserInputWithTurnContext owns the sticky context
|
||||
// mutation. Wait for core to start processing that turn before
|
||||
// reporting the effective state, otherwise a later direct update
|
||||
// can appear to win and then be overwritten by this turn.
|
||||
let turn_id = Uuid::now_v7().to_string();
|
||||
let turn_context_applied = if has_turn_context_overrides {
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
|
||||
let turn_started = {
|
||||
Some({
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.turn_started_receiver(&turn_id)
|
||||
};
|
||||
if let Some(turn_started) = turn_started {
|
||||
// Bound how long the RPC waits for the core turn-start acknowledgement.
|
||||
tokio::time::timeout(TURN_STARTED_ACK_TIMEOUT, turn_started)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
internal_error(
|
||||
"timed out waiting for turn context overrides to apply".to_string(),
|
||||
)
|
||||
})?
|
||||
.map_err(|_| {
|
||||
internal_error("turn context override waiter was cancelled".to_string())
|
||||
})?;
|
||||
thread_state.track_pending_turn_context(turn_id.clone())
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Err(err) = thread
|
||||
.submit_with_id(Submission {
|
||||
id: turn_id.clone(),
|
||||
op: turn_op,
|
||||
trace: self.request_trace_context(&request_id).await,
|
||||
})
|
||||
.await
|
||||
{
|
||||
if has_turn_context_overrides {
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.cancel_pending_turn_context(&turn_id);
|
||||
}
|
||||
let after_turn_context =
|
||||
thread_turn_context_from_snapshot(&thread.config_snapshot().await);
|
||||
self.maybe_emit_turn_context_updated(
|
||||
¶ms.thread_id,
|
||||
&before_turn_context,
|
||||
after_turn_context,
|
||||
)
|
||||
.await;
|
||||
let error = internal_error(format!("failed to start turn: {err}"));
|
||||
self.track_error_response(&request_id, &error, /*error_type*/ None);
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
if let Some(turn_context_applied) = turn_context_applied {
|
||||
let processor = self.clone();
|
||||
let api_thread_id = params.thread_id.clone();
|
||||
let tracked_turn_id = turn_id.clone();
|
||||
tokio::spawn(async move {
|
||||
match tokio::time::timeout(TURN_CONTEXT_OVERRIDE_ACK_TIMEOUT, turn_context_applied)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(Ok(payload))) => {
|
||||
let after_turn_context = thread_turn_context_from_applied_event(&payload);
|
||||
processor
|
||||
.maybe_emit_turn_context_updated(
|
||||
&api_thread_id,
|
||||
&before_turn_context,
|
||||
after_turn_context,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Ok(Ok(Err(err))) => {
|
||||
tracing::warn!(
|
||||
"failed to apply turn context overrides for turn {tracked_turn_id}: {err}"
|
||||
);
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
tracing::warn!(
|
||||
"turn context override acknowledgement was cancelled for turn {tracked_turn_id}"
|
||||
);
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!(
|
||||
"timed out waiting for turn context overrides to apply for turn {tracked_turn_id}"
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if turn_has_input {
|
||||
@@ -636,12 +665,12 @@ impl TurnRequestProcessor {
|
||||
request_id: &ConnectionRequestId,
|
||||
params: ThreadTurnContextUpdateParams,
|
||||
) -> Result<ThreadTurnContextUpdateResponse, JSONRPCErrorError> {
|
||||
let (_, thread) = self
|
||||
.load_thread(¶ms.thread_id)
|
||||
.await
|
||||
.inspect_err(|error| {
|
||||
self.track_error_response(request_id, error, /*error_type*/ None);
|
||||
})?;
|
||||
let (thread_id, thread) =
|
||||
self.load_thread(¶ms.thread_id)
|
||||
.await
|
||||
.inspect_err(|error| {
|
||||
self.track_error_response(request_id, error, /*error_type*/ None);
|
||||
})?;
|
||||
let before_snapshot = thread.config_snapshot().await;
|
||||
let before_turn_context = thread_turn_context_from_snapshot(&before_snapshot);
|
||||
let resolved_overrides = self
|
||||
@@ -663,17 +692,53 @@ impl TurnRequestProcessor {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let after_snapshot = if let Some(overrides) = resolved_overrides {
|
||||
// There is no queued turn to order against here, so applying
|
||||
// directly gives the caller a synchronized response snapshot.
|
||||
let after_turn_context = if let Some(overrides) = resolved_overrides {
|
||||
thread
|
||||
.update_turn_context_overrides(overrides)
|
||||
.preview_turn_context_overrides(overrides.clone())
|
||||
.await
|
||||
.map_err(|err| invalid_request(format!("invalid turn context override: {err}")))?
|
||||
.map_err(|err| invalid_request(format!("invalid turn context override: {err}")))?;
|
||||
let update_id = Uuid::now_v7().to_string();
|
||||
let turn_context_applied = {
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.track_pending_turn_context(update_id.clone())
|
||||
};
|
||||
if let Err(err) = thread
|
||||
.submit_with_id(Submission {
|
||||
id: update_id.clone(),
|
||||
op: Op::TurnContext {
|
||||
turn_context: op_turn_context_overrides(overrides),
|
||||
},
|
||||
trace: self.request_trace_context(request_id).await,
|
||||
})
|
||||
.await
|
||||
{
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.cancel_pending_turn_context(&update_id);
|
||||
return Err(internal_error(format!(
|
||||
"failed to update turn context: {err}"
|
||||
)));
|
||||
}
|
||||
match tokio::time::timeout(TURN_CONTEXT_OVERRIDE_ACK_TIMEOUT, turn_context_applied)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(Ok(payload))) => thread_turn_context_from_applied_event(&payload),
|
||||
Ok(Ok(Err(err))) => return Err(invalid_request(err)),
|
||||
Ok(Err(_)) => {
|
||||
return Err(internal_error(
|
||||
"turn context override waiter was cancelled".to_string(),
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(internal_error(
|
||||
"timed out waiting for turn context overrides to apply".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
before_snapshot
|
||||
before_turn_context.clone()
|
||||
};
|
||||
let after_turn_context = thread_turn_context_from_snapshot(&after_snapshot);
|
||||
self.maybe_emit_turn_context_updated(
|
||||
¶ms.thread_id,
|
||||
&before_turn_context,
|
||||
@@ -1303,6 +1368,32 @@ fn thread_turn_context_from_snapshot(config_snapshot: &ThreadConfigSnapshot) ->
|
||||
}
|
||||
}
|
||||
|
||||
fn thread_turn_context_from_applied_event(
|
||||
event: &codex_protocol::protocol::TurnContextAppliedEvent,
|
||||
) -> ThreadTurnContext {
|
||||
let turn_context = &event.turn_context;
|
||||
ThreadTurnContext {
|
||||
model: turn_context.model.clone(),
|
||||
model_provider: turn_context.model_provider_id.clone(),
|
||||
service_tier: turn_context.service_tier.clone(),
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy.into(),
|
||||
approvals_reviewer: turn_context.approvals_reviewer.into(),
|
||||
sandbox_policy: thread_response_sandbox_policy(
|
||||
&turn_context.permission_profile,
|
||||
turn_context.cwd.as_path(),
|
||||
),
|
||||
permission_profile: turn_context.permission_profile.clone().into(),
|
||||
active_permission_profile: thread_response_active_permission_profile(
|
||||
turn_context.active_permission_profile.clone(),
|
||||
),
|
||||
effort: turn_context.reasoning_effort,
|
||||
summary: turn_context.reasoning_summary,
|
||||
personality: turn_context.personality,
|
||||
collaboration_mode: turn_context.collaboration_mode.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn xcode_26_4_mcp_elicitations_auto_deny(
|
||||
client_name: Option<&str>,
|
||||
client_version: Option<&str>,
|
||||
|
||||
@@ -11,6 +11,7 @@ use codex_file_watcher::WatchRegistration;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::TurnContextAppliedEvent;
|
||||
use codex_rollout::state_db::StateDbHandle;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use std::collections::HashMap;
|
||||
@@ -24,6 +25,7 @@ use tokio::sync::watch;
|
||||
use tracing::error;
|
||||
|
||||
type PendingInterruptQueue = Vec<ConnectionRequestId>;
|
||||
type TurnContextAck = Result<TurnContextAppliedEvent, String>;
|
||||
|
||||
pub(crate) struct PendingThreadResumeRequest {
|
||||
pub(crate) request_id: ConnectionRequestId,
|
||||
@@ -78,7 +80,7 @@ pub(crate) struct ThreadState {
|
||||
pub(crate) listener_generation: u64,
|
||||
listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
|
||||
current_turn_history: ThreadHistoryBuilder,
|
||||
pending_turn_started_waiters: HashMap<String, Vec<oneshot::Sender<()>>>,
|
||||
pending_turn_context_waiters: HashMap<String, Vec<oneshot::Sender<TurnContextAck>>>,
|
||||
listener_thread: Option<Weak<CodexThread>>,
|
||||
watch_registration: WatchRegistration,
|
||||
}
|
||||
@@ -113,7 +115,7 @@ impl ThreadState {
|
||||
let _ = cancel_tx.send(());
|
||||
}
|
||||
self.listener_command_tx = None;
|
||||
self.pending_turn_started_waiters.clear();
|
||||
self.pending_turn_context_waiters.clear();
|
||||
self.current_turn_history.reset();
|
||||
self.listener_thread = None;
|
||||
self.watch_registration = WatchRegistration::default();
|
||||
@@ -133,21 +135,20 @@ impl ThreadState {
|
||||
self.current_turn_history.active_turn_snapshot()
|
||||
}
|
||||
|
||||
pub(crate) fn turn_started_receiver(&mut self, turn_id: &str) -> Option<oneshot::Receiver<()>> {
|
||||
if self
|
||||
.active_turn_snapshot()
|
||||
.is_some_and(|turn| turn.id == turn_id)
|
||||
|| self.last_terminal_turn_id.as_deref() == Some(turn_id)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
pub(crate) fn track_pending_turn_context(
|
||||
&mut self,
|
||||
submission_id: String,
|
||||
) -> oneshot::Receiver<TurnContextAck> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.pending_turn_started_waiters
|
||||
.entry(turn_id.to_string())
|
||||
self.pending_turn_context_waiters
|
||||
.entry(submission_id)
|
||||
.or_default()
|
||||
.push(tx);
|
||||
Some(rx)
|
||||
rx
|
||||
}
|
||||
|
||||
pub(crate) fn cancel_pending_turn_context(&mut self, submission_id: &str) {
|
||||
self.pending_turn_context_waiters.remove(submission_id);
|
||||
}
|
||||
|
||||
pub(crate) fn track_current_turn_event(&mut self, event_turn_id: &str, event: &EventMsg) {
|
||||
@@ -155,22 +156,24 @@ impl ThreadState {
|
||||
self.turn_summary.started_at = payload.started_at;
|
||||
}
|
||||
self.current_turn_history.handle_event(event);
|
||||
if let EventMsg::TurnStarted(payload) = event {
|
||||
self.notify_turn_started(&payload.turn_id);
|
||||
if let EventMsg::TurnContextApplied(payload) = event {
|
||||
self.notify_turn_context_applied(event_turn_id, Ok(payload.clone()));
|
||||
}
|
||||
if let EventMsg::Error(error) = event {
|
||||
self.notify_turn_context_applied(event_turn_id, Err(error.message.clone()));
|
||||
}
|
||||
if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_))
|
||||
&& !self.current_turn_history.has_active_turn()
|
||||
{
|
||||
self.last_terminal_turn_id = Some(event_turn_id.to_string());
|
||||
self.current_turn_history.reset();
|
||||
self.notify_turn_started(event_turn_id);
|
||||
}
|
||||
}
|
||||
|
||||
fn notify_turn_started(&mut self, turn_id: &str) {
|
||||
if let Some(waiters) = self.pending_turn_started_waiters.remove(turn_id) {
|
||||
fn notify_turn_context_applied(&mut self, submission_id: &str, result: TurnContextAck) {
|
||||
if let Some(waiters) = self.pending_turn_context_waiters.remove(submission_id) {
|
||||
for waiter in waiters {
|
||||
let _ = waiter.send(());
|
||||
let _ = waiter.send(result.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -340,6 +340,73 @@ async fn thread_turn_context_update_after_turn_start_preserves_newer_update() ->
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn queued_updates_keep_each_turn_context_notification_snapshot() -> Result<()> {
|
||||
let server = create_mock_responses_server_sequence_unchecked(vec![
|
||||
create_final_assistant_message_sse_response("Done")?,
|
||||
])
|
||||
.await;
|
||||
let codex_home = TempDir::new()?;
|
||||
write_config(&codex_home, &server.uri())?;
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
let ThreadStartResponse { thread, .. } = start_thread(&mut mcp).await?;
|
||||
|
||||
let turn_request_id = mcp
|
||||
.send_turn_start_request(TurnStartParams {
|
||||
thread_id: thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: "Hello".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
model: Some("gpt-5.2".to_string()),
|
||||
effort: Some(ReasoningEffort::Low),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let update_request_id = mcp
|
||||
.send_thread_turn_context_update_request(ThreadTurnContextUpdateParams {
|
||||
thread_id: thread.id,
|
||||
model: Some("gpt-5.4".to_string()),
|
||||
effort: Some(Some(ReasoningEffort::High)),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(turn_request_id)),
|
||||
)
|
||||
.await??;
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(update_request_id)),
|
||||
)
|
||||
.await??;
|
||||
|
||||
let notifications = [
|
||||
read_turn_context_updated(&mut mcp).await?,
|
||||
read_turn_context_updated(&mut mcp).await?,
|
||||
];
|
||||
assert!(notifications.iter().any(|notification| {
|
||||
notification.turn_context.model == "gpt-5.2"
|
||||
&& notification.turn_context.effort == Some(ReasoningEffort::Low)
|
||||
}));
|
||||
assert!(notifications.iter().any(|notification| {
|
||||
notification.turn_context.model == "gpt-5.4"
|
||||
&& notification.turn_context.effort == Some(ReasoningEffort::High)
|
||||
}));
|
||||
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("turn/completed"),
|
||||
)
|
||||
.await??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_turn_context_update_after_no_op_turn_start_override_preserves_newer_update()
|
||||
-> Result<()> {
|
||||
|
||||
Reference in New Issue
Block a user