core: emit interrupted abort after rollout flush

This commit is contained in:
Charles Cunningham
2026-02-20 00:11:48 -08:00
parent f06b6f238f
commit 3519b96a1e
9 changed files with 78 additions and 58 deletions

View File

@@ -269,6 +269,7 @@ use crate::tasks::RegularTask;
use crate::tasks::ReviewTask;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use crate::tasks::TaskRunOutput;
use crate::tools::ToolRouter;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolDispatchOutput;
@@ -4811,9 +4812,9 @@ pub(crate) async fn run_turn(
input: Vec<UserInput>,
prewarmed_client_session: Option<ModelClientSession>,
cancellation_token: CancellationToken,
) -> Option<String> {
) -> TaskRunOutput {
if input.is_empty() {
return None;
return TaskRunOutput::default();
}
let model_info = turn_context.model_info.clone();
@@ -4834,7 +4835,7 @@ pub(crate) async fn run_turn(
.is_err()
{
error!("Failed to run pre-sampling compact");
return None;
return TaskRunOutput::default();
}
let skills_outcome = Some(turn_context.turn_skills.outcome.as_ref());
@@ -4853,7 +4854,7 @@ pub(crate) async fn run_turn(
.await
{
Ok(mcp_tools) => mcp_tools,
Err(_) => return None,
Err(_) => return TaskRunOutput::default(),
};
connectors::with_app_enabled_state(
connectors::accessible_connectors_from_mcp_tools(&mcp_tools),
@@ -4967,6 +4968,7 @@ pub(crate) async fn run_turn(
// many turns, from the perspective of the user, it is a single turn.
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let mut server_model_warning_emitted_for_turn = false;
let mut abort_reason = None;
// `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse
// one instance across retries within this turn.
@@ -5044,11 +5046,9 @@ pub(crate) async fn run_turn(
cancellation_token.cancel();
sess.finish_turn_without_completion_event(turn_context.as_ref())
.await;
sess.emit_turn_aborted_without_rollout_flush(
&turn_context,
TurnAbortReason::Interrupted,
)
.await;
// Defer TurnAborted emission until run_turn unwinds so the caller can
// flush the rollout marker without blocking the in-flight tool loop.
abort_reason = Some(TurnAbortReason::Interrupted);
break;
}
let total_usage_tokens = sess.get_total_token_usage().await;
@@ -5077,7 +5077,7 @@ pub(crate) async fn run_turn(
.await
.is_err()
{
return None;
return TaskRunOutput::default();
}
continue;
}
@@ -5140,7 +5140,7 @@ pub(crate) async fn run_turn(
}),
)
.await;
return None;
return TaskRunOutput::default();
}
break;
}
@@ -5176,7 +5176,10 @@ pub(crate) async fn run_turn(
}
}
last_agent_message
TaskRunOutput {
last_agent_message,
abort_reason,
}
}
async fn run_pre_sampling_compact(
@@ -9241,10 +9244,10 @@ mod tests {
_ctx: Arc<TurnContext>,
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
) -> TaskRunOutput {
if self.listen_to_cancellation_token {
cancellation_token.cancelled().await;
return None;
return TaskRunOutput::default();
}
loop {
sleep(Duration::from_secs(60)).await;

View File

@@ -56,8 +56,11 @@ impl ActiveTurn {
self.tasks.insert(sub_id, task);
}
pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool {
self.tasks.swap_remove(sub_id);
pub(crate) fn remove_task(&mut self, sub_id: &str) -> Option<RunningTask> {
self.tasks.swap_remove(sub_id)
}
pub(crate) fn is_empty(&self) -> bool {
self.tasks.is_empty()
}

View File

@@ -2,6 +2,7 @@ use std::sync::Arc;
use super::SessionTask;
use super::SessionTaskContext;
use super::TaskRunOutput;
use crate::codex::TurnContext;
use crate::state::TaskKind;
use async_trait::async_trait;
@@ -23,7 +24,7 @@ impl SessionTask for CompactTask {
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
_cancellation_token: CancellationToken,
) -> Option<String> {
) -> TaskRunOutput {
let session = session.clone_session();
let _ = if crate::compact::should_use_remote_compact_task(&ctx.provider) {
let _ = session.services.otel_manager.counter(
@@ -39,7 +40,11 @@ impl SessionTask for CompactTask {
&[("type", "local")],
);
crate::compact::run_compact_task(session.clone(), ctx, input).await
}
TaskRunOutput::default()
};
None
TaskRunOutput::default()
}
}

View File

@@ -4,6 +4,7 @@ use crate::protocol::WarningEvent;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use crate::tasks::TaskRunOutput;
use async_trait::async_trait;
use codex_git::CreateGhostCommitOptions;
use codex_git::GhostSnapshotReport;
@@ -38,7 +39,7 @@ impl SessionTask for GhostSnapshotTask {
ctx: Arc<TurnContext>,
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
) -> TaskRunOutput {
tokio::task::spawn(async move {
let token = self.token;
let warnings_enabled = !ctx.ghost_snapshot.disable_warnings;
@@ -152,7 +153,7 @@ impl SessionTask for GhostSnapshotTask {
Err(err) => warn!("failed to mark ghost snapshot ready: {err}"),
}
});
None
TaskRunOutput::default()
}
}

View File

@@ -72,6 +72,12 @@ impl SessionTaskContext {
}
}
#[derive(Default)]
pub(crate) struct TaskRunOutput {
pub(crate) last_agent_message: Option<String>,
pub(crate) abort_reason: Option<TurnAbortReason>,
}
/// Async task that drives a [`Session`] turn.
///
/// Implementations encapsulate a specific Codex workflow (regular chat,
@@ -100,7 +106,7 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String>;
) -> TaskRunOutput;
/// Gives the task a chance to perform cleanup after an abort.
///
@@ -138,7 +144,11 @@ impl Session {
tokio::spawn(
async move {
let ctx_for_finish = Arc::clone(&ctx);
let last_agent_message = task_for_run
let model_slug = ctx_for_finish.model_info.slug.clone();
let TaskRunOutput {
last_agent_message,
abort_reason,
} = task_for_run
.run(
Arc::clone(&session_ctx),
ctx,
@@ -148,7 +158,15 @@ impl Session {
.await;
let sess = session_ctx.clone_session();
sess.flush_rollout().await;
if !task_cancellation_token.is_cancelled() {
// Update previous model before TurnComplete is emitted so
// immediately following turns observe the correct switch state.
sess.set_previous_model(Some(model_slug)).await;
if let Some(reason) = abort_reason {
// Emit TurnAborted from the spawn site so the rollout flush above
// makes the interrupt marker durable before clients observe the event.
sess.emit_turn_aborted(ctx_for_finish.as_ref(), reason)
.await;
} else if !task_cancellation_token.is_cancelled() {
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
sess.on_task_finished(Arc::clone(&ctx_for_finish), last_agent_message)
.await;
@@ -231,18 +249,25 @@ impl Session {
pub(crate) async fn finish_turn_without_completion_event(&self, turn_context: &TurnContext) {
let mut active = self.active_turn.lock().await;
let mut pending_input = Vec::<ResponseInputItem>::new();
let mut removed_handle: Option<Arc<AbortOnDropHandle<()>>> = None;
let mut should_clear_active_turn = false;
if let Some(at) = active.as_mut()
&& at.remove_task(&turn_context.sub_id)
&& let Some(task) = at.remove_task(&turn_context.sub_id)
{
removed_handle = Some(task.handle);
let mut ts = at.turn_state.lock().await;
pending_input = ts.take_pending_input();
should_clear_active_turn = true;
should_clear_active_turn = at.is_empty();
}
if should_clear_active_turn {
*active = None;
}
drop(active);
if let Some(handle) = removed_handle
&& let Ok(handle) = Arc::try_unwrap(handle)
{
drop(handle.detach());
}
if !pending_input.is_empty() {
let pending_response_items = pending_input
.into_iter()
@@ -288,25 +313,6 @@ impl Session {
self: &Arc<Self>,
turn_context: &TurnContext,
reason: TurnAbortReason,
) {
self.emit_turn_aborted_inner(turn_context, reason, true)
.await;
}
pub(crate) async fn emit_turn_aborted_without_rollout_flush(
self: &Arc<Self>,
turn_context: &TurnContext,
reason: TurnAbortReason,
) {
self.emit_turn_aborted_inner(turn_context, reason, false)
.await;
}
async fn emit_turn_aborted_inner(
self: &Arc<Self>,
turn_context: &TurnContext,
reason: TurnAbortReason,
flush_rollout_before_event: bool,
) {
if reason == TurnAbortReason::Interrupted {
let marker = ResponseItem::Message {
@@ -324,11 +330,9 @@ impl Session {
.await;
self.persist_rollout_items(&[RolloutItem::ResponseItem(marker)])
.await;
if flush_rollout_before_event {
// Ensure the marker is durably visible before emitting TurnAborted: some clients
// synchronously re-read the rollout on receipt of the abort event.
self.flush_rollout().await;
}
// Ensure the marker is durably visible before emitting TurnAborted: some clients
// synchronously re-read the rollout on receipt of the abort event.
self.flush_rollout().await;
}
let event = EventMsg::TurnAborted(TurnAbortedEvent {

View File

@@ -16,6 +16,7 @@ use tracing::trace_span;
use super::SessionTask;
use super::SessionTaskContext;
use super::TaskRunOutput;
pub(crate) struct RegularTask {
prewarmed_session: Mutex<Option<ModelClientSession>>,
@@ -73,7 +74,7 @@ impl SessionTask for RegularTask {
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
) -> TaskRunOutput {
let sess = session.clone_session();
let run_turn_span = trace_span!("run_turn");
sess.set_server_reasoning_included(false).await;

View File

@@ -27,6 +27,7 @@ use codex_protocol::user_input::UserInput;
use super::SessionTask;
use super::SessionTaskContext;
use super::TaskRunOutput;
#[derive(Clone, Copy)]
pub(crate) struct ReviewTask;
@@ -49,7 +50,7 @@ impl SessionTask for ReviewTask {
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
) -> TaskRunOutput {
let _ = session
.session
.services
@@ -71,7 +72,7 @@ impl SessionTask for ReviewTask {
if !cancellation_token.is_cancelled() {
exit_review_mode(session.clone_session(), output.clone(), ctx.clone()).await;
}
None
TaskRunOutput::default()
}
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {

View File

@@ -7,6 +7,7 @@ use crate::protocol::UndoStartedEvent;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use crate::tasks::TaskRunOutput;
use async_trait::async_trait;
use codex_git::RestoreGhostCommitOptions;
use codex_git::restore_ghost_commit_with_options;
@@ -37,7 +38,7 @@ impl SessionTask for UndoTask {
ctx: Arc<TurnContext>,
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
) -> TaskRunOutput {
let _ = session
.session
.services
@@ -61,7 +62,7 @@ impl SessionTask for UndoTask {
}),
)
.await;
return None;
return TaskRunOutput::default();
}
let history = sess.clone_history().await;
@@ -86,7 +87,7 @@ impl SessionTask for UndoTask {
completed.message = Some("No ghost snapshot available to undo.".to_string());
sess.send_event(ctx.as_ref(), EventMsg::UndoCompleted(completed))
.await;
return None;
return TaskRunOutput::default();
};
let commit_id = ghost_commit.id().to_string();
@@ -122,6 +123,6 @@ impl SessionTask for UndoTask {
sess.send_event(ctx.as_ref(), EventMsg::UndoCompleted(completed))
.await;
None
TaskRunOutput::default()
}
}

View File

@@ -33,6 +33,7 @@ use crate::user_shell_command::user_shell_command_record_item;
use super::SessionTask;
use super::SessionTaskContext;
use super::TaskRunOutput;
use crate::codex::Session;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
@@ -72,7 +73,7 @@ impl SessionTask for UserShellCommandTask {
turn_context: Arc<TurnContext>,
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
) -> TaskRunOutput {
execute_user_shell_command(
session.clone_session(),
turn_context,
@@ -81,7 +82,7 @@ impl SessionTask for UserShellCommandTask {
UserShellCommandMode::StandaloneTurn,
)
.await;
None
TaskRunOutput::default()
}
}