Compare commits

...

5 Commits

Author SHA1 Message Date
jimmyfraiture
ee6677d398 NIT 2 2025-09-08 11:57:06 -07:00
jimmyfraiture
927ccb3299 V4 2025-09-08 10:52:08 -07:00
jimmyfraiture
10537867ad V3 2025-09-08 10:42:56 -07:00
jimmyfraiture
fdf52e87c2 V2 2025-09-08 10:33:21 -07:00
jimmyfraiture
731a354f6c V1 2025-09-08 10:07:57 -07:00
7 changed files with 415 additions and 23 deletions

View File

@@ -267,6 +267,12 @@ struct State {
pending_input: Vec<ResponseInputItem>,
history: ConversationHistory,
token_info: Option<TokenUsageInfo>,
last_undo_patch: Option<StoredUndoPatch>,
}
#[derive(Clone)]
struct StoredUndoPatch {
patch: String,
}
/// Context for an initialized model agent
@@ -660,6 +666,19 @@ impl Session {
state.approved_commands.insert(cmd);
}
fn store_last_undo_patch(&self, patch: String) {
let mut state = self.state.lock_unchecked();
state.last_undo_patch = Some(StoredUndoPatch { patch });
}
fn last_undo_patch(&self) -> Option<StoredUndoPatch> {
self.state.lock_unchecked().last_undo_patch.clone()
}
fn clear_last_undo_patch(&self) {
self.state.lock_unchecked().last_undo_patch = None;
}
/// Records items to both the rollout and the chat completions/ZDR
/// transcript, if enabled.
async fn record_conversation_items(&self, items: &[ResponseItem]) {
@@ -704,6 +723,7 @@ impl Session {
user_explicitly_approved_this_action,
changes,
}) => {
self.clear_last_undo_patch();
turn_diff_tracker.on_patch_begin(&changes);
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
@@ -732,8 +752,7 @@ impl Session {
async fn on_exec_command_end(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
call_id: &str,
context: &ExecCommandContext,
output: &ExecToolCallOutput,
is_apply_patch: bool,
) {
@@ -752,14 +771,14 @@ impl Session {
let msg = if is_apply_patch {
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
call_id: call_id.to_string(),
call_id: context.call_id.clone(),
stdout,
stderr,
success: *exit_code == 0,
})
} else {
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
call_id: call_id.to_string(),
call_id: context.call_id.clone(),
stdout,
stderr,
aggregated_output,
@@ -770,7 +789,7 @@ impl Session {
};
let event = Event {
id: sub_id.to_string(),
id: context.sub_id.clone(),
msg,
};
let _ = self.tx_event.send(event).await;
@@ -778,14 +797,55 @@ impl Session {
// If this is an apply_patch, after we emit the end patch, emit a second event
// with the full turn diff if there is one.
if is_apply_patch {
let unified_diff = turn_diff_tracker.get_unified_diff();
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event {
id: sub_id.into(),
msg,
};
let _ = self.tx_event.send(event).await;
match turn_diff_tracker.get_unified_diff() {
Ok(Some(unified_diff)) => {
let msg = EventMsg::TurnDiff(TurnDiffEvent {
unified_diff: unified_diff.clone(),
});
let event = Event {
id: context.sub_id.clone(),
msg,
};
let _ = self.tx_event.send(event).await;
if *exit_code == 0 {
match turn_diff_tracker.build_undo_patch() {
Ok(Some(patch)) => {
self.store_last_undo_patch(patch);
}
Ok(None) => {
self.clear_last_undo_patch();
}
Err(error) => {
warn!("failed to prepare undo patch: {error:#}");
self.clear_last_undo_patch();
self.notify_background_event(
&context.sub_id,
format!("Undo is unavailable for this turn: {error:#}"),
)
.await;
}
}
}
}
Ok(None) => {
if *exit_code == 0 {
self.clear_last_undo_patch();
}
}
Err(error) => {
warn!("failed to compute unified diff: {error:#}");
if *exit_code == 0 {
self.clear_last_undo_patch();
self
.notify_background_event(
&context.sub_id,
format!(
"Undo is unavailable for this turn: failed to compute diff: {error:#}"
),
)
.await;
}
}
}
}
}
@@ -800,8 +860,6 @@ impl Session {
exec_args: ExecInvokeArgs<'a>,
) -> crate::error::Result<ExecToolCallOutput> {
let is_apply_patch = begin_ctx.apply_patch.is_some();
let sub_id = begin_ctx.sub_id.clone();
let call_id = begin_ctx.call_id.clone();
self.on_exec_command_begin(turn_diff_tracker, begin_ctx.clone())
.await;
@@ -829,14 +887,8 @@ impl Session {
&output_stderr
}
};
self.on_exec_command_end(
turn_diff_tracker,
&sub_id,
&call_id,
borrowed,
is_apply_patch,
)
.await;
self.on_exec_command_end(turn_diff_tracker, &begin_ctx, borrowed, is_apply_patch)
.await;
result
}
@@ -864,6 +916,37 @@ impl Session {
let _ = self.tx_event.send(event).await;
}
async fn undo_last_turn_diff(&self, sub_id: &str) {
let Some(stored_patch) = self.last_undo_patch() else {
self.notify_background_event(sub_id, "No turn diff available to undo.")
.await;
return;
};
let mut stdout = Vec::new();
let mut stderr = Vec::new();
match codex_apply_patch::apply_patch(&stored_patch.patch, &mut stdout, &mut stderr) {
Ok(()) => {
self.clear_last_undo_patch();
if stdout.is_empty() {
self.notify_background_event(sub_id, "Reverted last turn diff.")
.await;
} else if let Ok(output) = String::from_utf8(stdout) {
self.notify_background_event(sub_id, output).await;
}
}
Err(error) => {
let mut message = format!("failed to undo turn diff: {error:#}");
if let Ok(stderr_text) = String::from_utf8(stderr)
&& !stderr_text.is_empty()
{
message = format!("{message}\n{stderr_text}");
}
self.notify_stream_error(sub_id, message).await;
}
}
}
/// Build the full turn input by concatenating the current conversation
/// history with additional items for this turn.
pub fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
@@ -1141,6 +1224,9 @@ async fn submission_loop(
sess.set_task(task);
}
}
Op::UndoLastTurnDiff => {
sess.undo_last_turn_diff(&sub.id).await;
}
Op::UserTurn {
items,
cwd,

View File

@@ -1,3 +1,4 @@
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
@@ -249,6 +250,64 @@ impl TurnDiffTracker {
}
}
pub fn build_undo_patch(&mut self) -> Result<Option<String>> {
let mut delete_paths: BTreeSet<PathBuf> = BTreeSet::new();
let mut add_entries: Vec<(PathBuf, String)> = Vec::new();
let mut baseline_file_names: Vec<String> =
self.baseline_file_info.keys().cloned().collect();
baseline_file_names.sort();
for internal in baseline_file_names {
let Some(info) = self.baseline_file_info.get(&internal) else {
continue;
};
let current_path = self
.get_path_for_internal(&internal)
.unwrap_or(info.path.clone());
if current_path.exists() {
delete_paths.insert(current_path);
}
if info.oid.as_str() != ZERO_OID {
let content = String::from_utf8(info.content.clone()).map_err(|_| {
anyhow!(
"undo is not supported for non-UTF8 baseline file {}",
info.path.display()
)
})?;
add_entries.push((info.path.clone(), content));
}
}
if delete_paths.is_empty() && add_entries.is_empty() {
return Ok(None);
}
add_entries.sort_by(|(left_path, _), (right_path, _)| left_path.cmp(right_path));
let mut patch = String::from("*** Begin Patch\n");
for path in delete_paths {
patch.push_str(&format!("*** Delete File: {}\n", path.display()));
}
for (path, content) in add_entries {
patch.push_str(&format!("*** Add File: {}\n", path.display()));
if !content.is_empty() {
for line in content.split_terminator('\n') {
patch.push('+');
patch.push_str(line);
patch.push('\n');
}
if !content.ends_with('\n') {
patch.push_str("+\n");
}
}
}
patch.push_str("*** End Patch\n");
Ok(Some(patch))
}
fn get_file_diff(&mut self, internal_file_name: &str) -> String {
let mut aggregated = String::new();
@@ -503,6 +562,146 @@ mod tests {
out
}
fn normalize_patch_for_test(input: &str, root: &Path) -> String {
let root_str = root.display().to_string().replace('\\', "/");
let mut replaced = input.replace('\\', "/");
replaced = replaced.replace(&root_str, "<TMP>");
if let Some(root_name) = root.file_name().and_then(|name| name.to_str()) {
let marker = format!("/{root_name}");
let mut normalized = String::with_capacity(replaced.len());
let mut search_start = 0;
while let Some(relative_pos) = replaced[search_start..].find(&marker) {
let absolute_pos = search_start + relative_pos;
let path_start = replaced[..absolute_pos]
.rfind(['\n', ' '])
.map(|idx| idx + 1)
.unwrap_or(0);
let prefix_end = replaced[path_start..absolute_pos]
.find('/')
.map(|idx| path_start + idx + 1)
.unwrap_or(path_start);
normalized.push_str(&replaced[search_start..prefix_end]);
normalized.push_str("<TMP>");
let after_marker = absolute_pos + marker.len();
let mut rest_start = after_marker;
if after_marker < replaced.len() && replaced.as_bytes()[after_marker] == b'/' {
normalized.push('/');
rest_start += 1;
}
search_start = rest_start;
}
normalized.push_str(&replaced[search_start..]);
replaced = normalized;
}
if !replaced.ends_with('\n') {
replaced.push('\n');
}
replaced
}
#[test]
fn build_undo_patch_returns_none_without_baseline() {
let mut tracker = TurnDiffTracker::new();
assert_eq!(tracker.build_undo_patch().unwrap(), None);
}
#[test]
fn build_undo_patch_restores_updated_file() {
let dir = tempdir().unwrap();
let path = dir.path().join("undo.txt");
fs::write(&path, "before\n").unwrap();
let mut tracker = TurnDiffTracker::new();
let update_changes = HashMap::from([(
path.clone(),
FileChange::Update {
unified_diff: String::new(),
move_path: None,
},
)]);
tracker.on_patch_begin(&update_changes);
fs::write(&path, "after\n").unwrap();
let patch = tracker
.build_undo_patch()
.expect("undo patch")
.expect("some undo patch");
let normalized = normalize_patch_for_test(&patch, dir.path());
let expected = concat!(
"*** Begin Patch\n",
"*** Delete File: <TMP>/undo.txt\n",
"*** Add File: <TMP>/undo.txt\n",
"+before\n",
"*** End Patch\n",
);
assert_eq!(normalized, expected);
}
#[test]
fn build_undo_patch_restores_deleted_file() {
let dir = tempdir().unwrap();
let path = dir.path().join("gone.txt");
fs::write(&path, "gone\n").unwrap();
let mut tracker = TurnDiffTracker::new();
let delete_changes = HashMap::from([(
path.clone(),
FileChange::Delete {
content: "gone\n".to_string(),
},
)]);
tracker.on_patch_begin(&delete_changes);
fs::remove_file(&path).unwrap();
let patch = tracker
.build_undo_patch()
.expect("undo patch")
.expect("some undo patch");
let normalized = normalize_patch_for_test(&patch, dir.path());
let expected = concat!(
"*** Begin Patch\n",
"*** Add File: <TMP>/gone.txt\n",
"+gone\n",
"*** End Patch\n",
);
assert_eq!(normalized, expected);
}
#[test]
fn build_undo_patch_rejects_non_utf8_content() {
let dir = tempdir().unwrap();
let path = dir.path().join("binary.bin");
fs::write(&path, [0xff, 0xfe, 0x00]).unwrap();
let mut tracker = TurnDiffTracker::new();
let update_changes = HashMap::from([(
path.clone(),
FileChange::Update {
unified_diff: String::new(),
move_path: None,
},
)]);
tracker.on_patch_begin(&update_changes);
let err = tracker.build_undo_patch().unwrap_err();
let message = format!("{err:#}");
assert!(
message.contains("undo is not supported for non-UTF8 baseline file"),
"unexpected error message: {message}"
);
}
#[test]
fn accumulates_add_and_update() {
let mut acc = TurnDiffTracker::new();

View File

@@ -58,6 +58,9 @@ pub enum Op {
items: Vec<InputItem>,
},
/// Undo the most recently applied turn diff using the local git repo.
UndoLastTurnDiff,
/// Similar to [`Op::UserInput`], but contains additional context required
/// for a turn of a [`crate::codex_conversation::CodexConversation`].
UserTurn {
@@ -1008,6 +1011,7 @@ pub enum TurnAbortReason {
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
/// Serialize Event to verify that its JSON representation has the expected
/// amount of nesting.

View File

@@ -411,6 +411,8 @@ impl ChatWidget {
fn on_background_event(&mut self, message: String) {
debug!("BackgroundEvent: {message}");
self.add_to_history(history_cell::new_background_event(message));
self.request_redraw();
}
fn on_stream_error(&mut self, message: String) {
@@ -862,6 +864,9 @@ impl ChatWidget {
tx.send(AppEvent::DiffResult(text));
});
}
SlashCommand::Undo => {
self.open_undo_confirmation_popup();
}
SlashCommand::Mention => {
self.insert_str("@");
}
@@ -1253,6 +1258,43 @@ impl ChatWidget {
);
}
fn open_undo_confirmation_popup(&mut self) {
let confirm_message = "Undoing the last Codex turn diff.".to_string();
let undo_actions: Vec<SelectionAction> = vec![Box::new(move |tx| {
tx.send(AppEvent::InsertHistoryCell(Box::new(
history_cell::new_background_event(confirm_message.clone()),
)));
tx.send(AppEvent::CodexOp(Op::UndoLastTurnDiff));
})];
let items = vec![
SelectionItem {
name: "Undo last turn diff".to_string(),
description: Some(
"Revert files that Codex changed during the most recent turn.".to_string(),
),
is_current: false,
actions: undo_actions,
},
SelectionItem {
name: "Cancel".to_string(),
description: Some("Close without undoing any files.".to_string()),
is_current: false,
actions: Vec::new(),
},
];
self.bottom_pane.show_selection_view(
"Undo last Codex turn?".to_string(),
Some(
"Codex will apply a patch to restore files from before the previous turn."
.to_string(),
),
Some("Press Enter to confirm or Esc to cancel".to_string()),
items,
);
}
/// Set the approval policy in the widget's config copy.
pub(crate) fn set_approval_policy(&mut self, policy: AskForApproval) {
self.config.approval_policy = policy;

View File

@@ -13,6 +13,7 @@ use codex_core::protocol::AgentMessageEvent;
use codex_core::protocol::AgentReasoningDeltaEvent;
use codex_core::protocol::AgentReasoningEvent;
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
use codex_core::protocol::BackgroundEventEvent;
use codex_core::protocol::Event;
use codex_core::protocol::EventMsg;
use codex_core::protocol::ExecApprovalRequestEvent;
@@ -614,6 +615,58 @@ fn disabled_slash_command_while_task_running_snapshot() {
assert_snapshot!(blob);
}
#[test]
fn undo_command_requires_confirmation() {
let (mut chat, mut rx, _op_rx) = make_chatwidget_manual();
chat.dispatch_command(SlashCommand::Undo);
assert!(rx.try_recv().is_err(), "undo should require confirmation");
chat.handle_key_event(KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE));
let mut undo_requested = false;
let mut history_lines = Vec::new();
while let Ok(event) = rx.try_recv() {
match event {
AppEvent::InsertHistoryCell(cell) => {
history_lines.push(cell.display_lines(80));
}
AppEvent::CodexOp(Op::UndoLastTurnDiff) => {
undo_requested = true;
}
_ => {}
}
}
assert!(undo_requested, "expected undo op after confirmation");
let combined = history_lines
.iter()
.map(|lines| lines_to_single_string(lines))
.collect::<String>();
assert!(combined.contains("Undoing the last Codex turn diff."));
}
#[test]
fn background_events_are_rendered_in_history() {
let (mut chat, mut rx, _op_rx) = make_chatwidget_manual();
chat.handle_codex_event(Event {
id: "undo".to_string(),
msg: EventMsg::BackgroundEvent(BackgroundEventEvent {
message: "Reverted last turn diff.".to_string(),
}),
});
let history = drain_insert_history(&mut rx);
let combined = history
.iter()
.map(|lines| lines_to_single_string(lines))
.collect::<String>();
assert!(combined.contains("Reverted last turn diff."));
}
#[tokio::test(flavor = "current_thread")]
async fn binary_size_transcript_snapshot() {
let (mut chat, mut rx, _op_rx) = make_chatwidget_manual();

View File

@@ -1064,6 +1064,11 @@ pub(crate) fn new_stream_error_event(message: String) -> PlainHistoryCell {
PlainHistoryCell { lines }
}
pub(crate) fn new_background_event(message: String) -> PlainHistoryCell {
let lines: Vec<Line<'static>> = vec![vec![padded_emoji("").into(), message.into()].into()];
PlainHistoryCell { lines }
}
/// Render a userfriendly plan update styled like a checkbox todo list.
pub(crate) fn new_plan_update(update: UpdatePlanArgs) -> PlanUpdateCell {
let UpdatePlanArgs { explanation, plan } = update;

View File

@@ -18,6 +18,7 @@ pub enum SlashCommand {
Init,
Compact,
Diff,
Undo,
Mention,
Status,
Mcp,
@@ -36,6 +37,7 @@ impl SlashCommand {
SlashCommand::Compact => "summarize conversation to prevent hitting the context limit",
SlashCommand::Quit => "exit Codex",
SlashCommand::Diff => "show git diff (including untracked files)",
SlashCommand::Undo => "undo the last turn diff applied by Codex",
SlashCommand::Mention => "mention a file",
SlashCommand::Status => "show current session configuration and token usage",
SlashCommand::Model => "choose what model and reasoning effort to use",
@@ -63,6 +65,7 @@ impl SlashCommand {
| SlashCommand::Approvals
| SlashCommand::Logout => false,
SlashCommand::Diff
| SlashCommand::Undo
| SlashCommand::Mention
| SlashCommand::Status
| SlashCommand::Mcp