Compare commits

...

1 Commits

Author SHA1 Message Date
Ahmed Ibrahim
e5c2ecdc92 Store previous turn context on session state 2026-02-10 19:11:15 -08:00
3 changed files with 55 additions and 27 deletions

View File

@@ -1324,6 +1324,16 @@ impl Session {
state.clear_mcp_tool_selection();
}
async fn previous_turn_context(&self) -> Option<Arc<TurnContext>> {
let state = self.state.lock().await;
state.previous_turn_context()
}
pub(crate) async fn set_previous_turn_context(&self, turn_context: Arc<TurnContext>) {
let mut state = self.state.lock().await;
state.set_previous_turn_context(turn_context);
}
async fn record_initial_history(&self, conversation_history: InitialHistory) {
let turn_context = self.new_default_turn().await;
match conversation_history {
@@ -2822,7 +2832,12 @@ impl Session {
async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiver<Submission>) {
// Seed with context in case there is an OverrideTurnContext first.
let mut previous_context: Option<Arc<TurnContext>> = Some(sess.new_default_turn().await);
if sess.previous_turn_context().await.is_none() {
let default_turn = sess.new_default_turn().await;
if sess.previous_turn_context().await.is_none() {
sess.set_previous_turn_context(default_turn).await;
}
}
// To break out of this loop, send Op::Shutdown.
while let Ok(sub) = rx_sub.recv().await {
@@ -2872,8 +2887,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
.await;
}
Op::UserInput { .. } | Op::UserTurn { .. } => {
handlers::user_input_or_turn(&sess, sub.id.clone(), sub.op, &mut previous_context)
.await;
handlers::user_input_or_turn(&sess, sub.id.clone(), sub.op).await;
}
Op::ExecApproval {
id: approval_id,
@@ -2939,13 +2953,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
handlers::set_thread_name(&sess, sub.id.clone(), name).await;
}
Op::RunUserShellCommand { command } => {
handlers::run_user_shell_command(
&sess,
sub.id.clone(),
command,
&mut previous_context,
)
.await;
handlers::run_user_shell_command(&sess, sub.id.clone(), command).await;
}
Op::ResolveElicitation {
server_name,
@@ -2973,7 +2981,6 @@ mod handlers {
use crate::codex::Session;
use crate::codex::SessionSettingsUpdate;
use crate::codex::SteerInputError;
use crate::codex::TurnContext;
use crate::codex::spawn_review_thread;
use crate::config::Config;
@@ -3048,12 +3055,7 @@ mod handlers {
}
}
pub async fn user_input_or_turn(
sess: &Arc<Session>,
sub_id: String,
op: Op,
previous_context: &mut Option<Arc<TurnContext>>,
) {
pub async fn user_input_or_turn(sess: &Arc<Session>, sub_id: String, op: Op) {
let (items, updates) = match op {
Op::UserTurn {
cwd,
@@ -3113,6 +3115,7 @@ mod handlers {
// Attempt to inject input into current task.
if let Err(SteerInputError::NoActiveTurn(items)) = sess.steer_input(items, None).await {
sess.seed_initial_context_if_needed(&current_context).await;
let previous_context = sess.previous_turn_context().await;
let resumed_model = sess.take_pending_resume_previous_model().await;
let update_items = sess.build_settings_update_items(
previous_context.as_ref(),
@@ -3129,16 +3132,10 @@ mod handlers {
let regular_task = sess.take_startup_regular_task().await.unwrap_or_default();
sess.spawn_task(Arc::clone(&current_context), items, regular_task)
.await;
*previous_context = Some(current_context);
}
}
pub async fn run_user_shell_command(
sess: &Arc<Session>,
sub_id: String,
command: String,
previous_context: &mut Option<Arc<TurnContext>>,
) {
pub async fn run_user_shell_command(sess: &Arc<Session>, sub_id: String, command: String) {
if let Some((turn_context, cancellation_token)) =
sess.active_turn_context_and_cancellation_token().await
{
@@ -3163,7 +3160,6 @@ mod handlers {
UserShellCommandTask::new(command),
)
.await;
*previous_context = Some(turn_context);
}
pub async fn resolve_elicitation(

View File

@@ -3,8 +3,10 @@
use codex_protocol::models::ResponseItem;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use crate::codex::SessionConfiguration;
use crate::codex::TurnContext;
use crate::context_manager::ContextManager;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
@@ -30,6 +32,7 @@ pub(crate) struct SessionState {
/// Startup regular task pre-created during session initialization.
pub(crate) startup_regular_task: Option<RegularTask>,
pub(crate) active_mcp_tool_selection: Option<Vec<String>>,
pub(crate) previous_turn_context: Option<Arc<TurnContext>>,
}
impl SessionState {
@@ -47,6 +50,7 @@ impl SessionState {
pending_resume_previous_model: None,
startup_regular_task: None,
active_mcp_tool_selection: None,
previous_turn_context: None,
}
}
@@ -168,6 +172,14 @@ impl SessionState {
pub(crate) fn clear_mcp_tool_selection(&mut self) {
self.active_mcp_tool_selection = None;
}
pub(crate) fn previous_turn_context(&self) -> Option<Arc<TurnContext>> {
self.previous_turn_context.clone()
}
pub(crate) fn set_previous_turn_context(&mut self, turn_context: Arc<TurnContext>) {
self.previous_turn_context = Some(turn_context);
}
}
// Sometimes new snapshots don't include credits or plan information.
@@ -187,6 +199,7 @@ fn merge_rate_limit_fields(
#[cfg(test)]
mod tests {
use super::*;
use crate::codex::make_session_and_context;
use crate::codex::make_session_configuration_for_tests;
use pretty_assertions::assert_eq;
@@ -258,4 +271,21 @@ mod tests {
assert_eq!(state.get_mcp_tool_selection(), None);
}
#[tokio::test]
async fn set_previous_turn_context_stores_context() {
let (_session, turn_context) = make_session_and_context().await;
let session_configuration = make_session_configuration_for_tests().await;
let mut state = SessionState::new(session_configuration);
state.set_previous_turn_context(Arc::new(turn_context));
assert_eq!(
state
.previous_turn_context()
.as_ref()
.map(|context| &context.sub_id),
Some(&"turn_id".to_string())
);
}
}

View File

@@ -148,10 +148,12 @@ impl Session {
task_cancellation_token.child_token(),
)
.await;
session_ctx.clone_session().flush_rollout().await;
let sess = session_ctx.clone_session();
sess.flush_rollout().await;
sess.set_previous_turn_context(Arc::clone(&ctx_for_finish))
.await;
if !task_cancellation_token.is_cancelled() {
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
let sess = session_ctx.clone_session();
sess.on_task_finished(ctx_for_finish, last_agent_message)
.await;
}