This commit is contained in:
jimmyfraiture
2025-09-09 16:17:27 -07:00
parent 82ce78b1ea
commit 03c287f921
7 changed files with 98 additions and 39 deletions

View File

@@ -772,7 +772,7 @@ impl Session {
sub_id: &str,
call_id: &str,
output: &ExecToolCallOutput,
is_apply_patch: bool,
apply_patch: Option<&ApplyPatchCommandContext>,
) {
let ExecToolCallOutput {
stdout,
@@ -787,12 +787,25 @@ impl Session {
let formatted_output = format_exec_output_str(output);
let aggregated_output: String = aggregated_output.text.clone();
let msg = if is_apply_patch {
let msg = if let Some(apply_patch) = apply_patch {
let diff = if *exit_code == 0 {
match turn_diff_tracker.get_unified_diff_for_changes(&apply_patch.changes) {
Ok(diff) => diff,
Err(err) => {
warn!("failed to compute patch diff: {err:#}");
None
}
}
} else {
None
};
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
call_id: call_id.to_string(),
stdout,
stderr,
success: *exit_code == 0,
diff,
})
} else {
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
@@ -814,7 +827,7 @@ impl Session {
// If this is an apply_patch, after we emit the end patch, emit a second event
// with the full turn diff if there is one.
if is_apply_patch {
if apply_patch.is_some() {
let unified_diff = turn_diff_tracker.get_unified_diff();
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
@@ -836,7 +849,7 @@ impl Session {
begin_ctx: ExecCommandContext,
exec_args: ExecInvokeArgs<'a>,
) -> crate::error::Result<ExecToolCallOutput> {
let is_apply_patch = begin_ctx.apply_patch.is_some();
let apply_patch_ctx = begin_ctx.apply_patch.clone();
let sub_id = begin_ctx.sub_id.clone();
let call_id = begin_ctx.call_id.clone();
@@ -871,7 +884,7 @@ impl Session {
&sub_id,
&call_id,
borrowed,
is_apply_patch,
apply_patch_ctx.as_ref(),
)
.await;

View File

@@ -249,6 +249,56 @@ impl TurnDiffTracker {
}
}
pub fn get_unified_diff_for_changes(
&mut self,
changes: &HashMap<PathBuf, FileChange>,
) -> Result<Option<String>> {
let mut internal_names: Vec<String> = changes
.iter()
.flat_map(|(path, change)| {
let mut names = Vec::new();
if let Some(internal) = self.external_to_temp_name.get(path) {
names.push(internal.clone());
}
if let FileChange::Update {
move_path: Some(dest),
..
} = change
&& let Some(internal) = self.external_to_temp_name.get(dest)
{
names.push(internal.clone());
}
names
})
.collect();
if internal_names.is_empty() {
return Ok(None);
}
internal_names.sort();
internal_names.dedup();
internal_names.sort_by_key(|internal| {
self.get_path_for_internal(internal)
.map(|p| self.relative_to_git_root_str(&p))
.unwrap_or_default()
});
let mut aggregated = String::new();
for internal in internal_names {
aggregated.push_str(self.get_file_diff(&internal).as_str());
if !aggregated.ends_with('\n') {
aggregated.push('\n');
}
}
if aggregated.trim().is_empty() {
Ok(None)
} else {
Ok(Some(aggregated))
}
}
fn get_file_diff(&mut self, internal_file_name: &str) -> String {
let mut aggregated = String::new();
@@ -784,6 +834,9 @@ index {left_oid_b}..{ZERO_OID}
assert_eq!(combined, expected);
}
/// Confirms that updating a binary file (non-UTF8 content) produces a
/// unified diff that reports "Binary files differ" with correct blob OIDs
/// instead of a textual hunk.
#[test]
fn binary_files_differ_update() {
let dir = tempdir().unwrap();

View File

@@ -463,7 +463,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
stdout,
stderr,
success,
..
diff: _diff,
}) => {
let patch_begin = self.call_id_to_patch.remove(&call_id);

View File

@@ -913,6 +913,9 @@ pub struct PatchApplyEndEvent {
pub stderr: String,
/// Whether the patch was applied successfully.
pub success: bool,
/// Unified diff describing the patch the tool applied, if available.
#[serde(default)]
pub diff: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize, TS)]

View File

@@ -37,7 +37,6 @@ use codex_core::protocol::TurnDiffEvent;
use codex_core::protocol::UserMessageEvent;
use codex_core::protocol::WebSearchBeginEvent;
use codex_core::protocol::WebSearchEndEvent;
use codex_core::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::parse_command::ParsedCommand;
use crossterm::event::KeyCode;
use crossterm::event::KeyEvent;
@@ -103,16 +102,11 @@ struct AppliedPatchDiff {
working_dir: PathBuf,
}
struct PatchInFlight {
tracker: TurnDiffTracker,
working_dir: PathBuf,
}
const MAX_COMPLETED_UNDO_TURNS: usize = 1;
#[derive(Default)]
struct PatchUndoHistory {
trackers: HashMap<String, PatchInFlight>,
working_dirs: HashMap<String, PathBuf>,
active_turn: Vec<AppliedPatchDiff>,
completed_turns: Vec<Arc<[AppliedPatchDiff]>>,
undo_in_progress: Option<Arc<[AppliedPatchDiff]>>,
@@ -124,6 +118,7 @@ impl PatchUndoHistory {
}
fn start_turn(&mut self) {
self.working_dirs.clear();
self.active_turn.clear();
}
@@ -143,28 +138,16 @@ impl PatchUndoHistory {
fn start_patch(
&mut self,
call_id: String,
changes: HashMap<PathBuf, FileChange>,
changes: &HashMap<PathBuf, FileChange>,
patch_cwd: PathBuf,
default_cwd: &Path,
) {
let mut tracker = TurnDiffTracker::new();
tracker.on_patch_begin(&changes);
let working_dir = working_dir_for_patch(&changes, &patch_cwd, default_cwd);
self.trackers.insert(
call_id,
PatchInFlight {
tracker,
working_dir,
},
);
let working_dir = working_dir_for_patch(changes, &patch_cwd, default_cwd);
self.working_dirs.insert(call_id, working_dir);
}
fn complete_patch(&mut self, call_id: &str, success: bool) {
let Some(PatchInFlight {
mut tracker,
working_dir,
}) = self.trackers.remove(call_id)
else {
fn complete_patch(&mut self, call_id: &str, success: bool, diff: Option<String>) {
let Some(working_dir) = self.working_dirs.remove(call_id) else {
return;
};
@@ -172,14 +155,16 @@ impl PatchUndoHistory {
return;
}
if let Ok(Some(diff)) = tracker.get_unified_diff() {
if diff.trim().is_empty() {
return;
}
let Some(diff) = diff else {
return;
};
self.active_turn
.push(AppliedPatchDiff { diff, working_dir });
if diff.trim().is_empty() {
return;
}
self.active_turn
.push(AppliedPatchDiff { diff, working_dir });
}
fn prune_completed_turns(&mut self) {
@@ -512,7 +497,7 @@ impl ChatWidget {
let history_changes = changes.clone();
self.patch_history
.start_patch(call_id, changes, cwd, &self.config.cwd);
.start_patch(call_id, &changes, cwd, &self.config.cwd);
self.add_to_history(history_cell::new_patch_event(
PatchEventType::ApplyBegin { auto_approved },
history_changes,
@@ -678,7 +663,7 @@ impl ChatWidget {
event: codex_core::protocol::PatchApplyEndEvent,
) {
self.patch_history
.complete_patch(&event.call_id, event.success);
.complete_patch(&event.call_id, event.success, event.diff.clone());
if !event.success {
self.add_to_history(history_cell::new_patch_apply_failure(event.stderr));
}

View File

@@ -1152,6 +1152,7 @@ fn apply_patch_events_emit_history_cells() {
stdout: "ok\n".into(),
stderr: String::new(),
success: true,
diff: Some("diff --git a/foo.txt b/foo.txt\n".to_string()),
};
chat.handle_codex_event(Event {
id: "s1".into(),
@@ -1376,6 +1377,7 @@ fn apply_patch_full_flow_integration_like() {
stdout: String::from("ok"),
stderr: String::new(),
success: true,
diff: Some("diff --git a/pkg.rs b/pkg.rs\n".to_string()),
}),
});
}

View File

@@ -35,7 +35,10 @@ async fn run_git_apply(diff: &str, cwd: &Path, args: &[&str]) -> io::Result<Undo
}
pub(crate) async fn undo_patch(diff: &str, cwd: &Path) -> io::Result<UndoPatchResult> {
const UNDO_ARGS: [&[&str]; 2] = [&["apply", "-R"], &["apply", "--3way", "-R"]];
const UNDO_ARGS: [&[&str]; 2] = [
&["apply", "--unsafe-paths", "-R"],
&["apply", "--unsafe-paths", "--3way", "-R"],
];
let mut last_result = UndoPatchResult::default();
for args in UNDO_ARGS {