Compare commits

...

4 Commits

Author SHA1 Message Date
celia-oai
09637409b0 changes 2026-04-30 20:25:16 -07:00
celia-oai
79908b64a1 changes 2026-04-30 20:08:54 -07:00
Owen Lin
6014b6679f fix flaky test falls_back_to_registered_fallback_port_when_default_port_is_in_use (#20504)
2026-04-30 22:06:04 +00:00
Akshay Nathan
8426edf71e Stateful streaming apply_patch parser 2026-04-30 21:41:15 +00:00
18 changed files with 1077 additions and 381 deletions

View File

@@ -2,6 +2,7 @@ mod invocation;
mod parser;
mod seek_sequence;
mod standalone_executable;
mod streaming_parser;
use std::collections::HashMap;
use std::io;
@@ -20,8 +21,8 @@ pub use parser::ParseError;
use parser::ParseError::*;
pub use parser::UpdateFileChunk;
pub use parser::parse_patch;
pub use parser::parse_patch_streaming;
use similar::TextDiff;
pub use streaming_parser::StreamingPatchParser;
use thiserror::Error;
pub use invocation::maybe_parse_apply_patch_verified;

View File

@@ -31,15 +31,15 @@ use std::path::PathBuf;
use thiserror::Error;
const BEGIN_PATCH_MARKER: &str = "*** Begin Patch";
const END_PATCH_MARKER: &str = "*** End Patch";
const ADD_FILE_MARKER: &str = "*** Add File: ";
const DELETE_FILE_MARKER: &str = "*** Delete File: ";
const UPDATE_FILE_MARKER: &str = "*** Update File: ";
const MOVE_TO_MARKER: &str = "*** Move to: ";
const EOF_MARKER: &str = "*** End of File";
const CHANGE_CONTEXT_MARKER: &str = "@@ ";
const EMPTY_CHANGE_CONTEXT_MARKER: &str = "@@";
pub(crate) const BEGIN_PATCH_MARKER: &str = "*** Begin Patch";
pub(crate) const END_PATCH_MARKER: &str = "*** End Patch";
pub(crate) const ADD_FILE_MARKER: &str = "*** Add File: ";
pub(crate) const DELETE_FILE_MARKER: &str = "*** Delete File: ";
pub(crate) const UPDATE_FILE_MARKER: &str = "*** Update File: ";
pub(crate) const MOVE_TO_MARKER: &str = "*** Move to: ";
pub(crate) const EOF_MARKER: &str = "*** End of File";
pub(crate) const CHANGE_CONTEXT_MARKER: &str = "@@ ";
pub(crate) const EMPTY_CHANGE_CONTEXT_MARKER: &str = "@@";
/// Currently, the only OpenAI model that knowingly requires lenient parsing is
/// gpt-4.1. While we could try to require everyone to pass in a strictness
@@ -132,14 +132,6 @@ pub fn parse_patch(patch: &str) -> Result<ApplyPatchArgs, ParseError> {
parse_patch_text(patch, mode)
}
/// Parses streamed patch text that may not have reached `*** End Patch` yet.
///
/// This entry point is for progress reporting only; callers must not use its
/// output to apply a patch.
pub fn parse_patch_streaming(patch: &str) -> Result<ApplyPatchArgs, ParseError> {
parse_patch_text(patch, ParseMode::Streaming)
}
enum ParseMode {
/// Parse the patch text argument as is.
Strict,
@@ -177,12 +169,6 @@ enum ParseMode {
/// `<<'EOF'` and ends with `EOF\n`. If so, we strip off these markers,
/// trim() the result, and treat what is left as the patch text.
Lenient,
/// Parse partial patch text for progress reporting while the model is
/// still streaming tool input. This mode requires a begin marker but does
/// not require an end marker, and its output must not be used to apply a
/// patch.
Streaming,
}
fn parse_patch_text(patch: &str, mode: ParseMode) -> Result<ApplyPatchArgs, ParseError> {
@@ -190,15 +176,13 @@ fn parse_patch_text(patch: &str, mode: ParseMode) -> Result<ApplyPatchArgs, Pars
let (patch_lines, hunk_lines) = match mode {
ParseMode::Strict => check_patch_boundaries_strict(&lines)?,
ParseMode::Lenient => check_patch_boundaries_lenient(&lines)?,
ParseMode::Streaming => check_patch_boundaries_streaming(&lines)?,
};
let mut hunks: Vec<Hunk> = Vec::new();
let mut remaining_lines = hunk_lines;
let mut line_number = 2;
let allow_incomplete = matches!(mode, ParseMode::Streaming);
while !remaining_lines.is_empty() {
let (hunk, hunk_lines) = parse_one_hunk(remaining_lines, line_number, allow_incomplete)?;
let (hunk, hunk_lines) = parse_one_hunk(remaining_lines, line_number)?;
hunks.push(hunk);
line_number += hunk_lines;
remaining_lines = &remaining_lines[hunk_lines..]
@@ -211,25 +195,6 @@ fn parse_patch_text(patch: &str, mode: ParseMode) -> Result<ApplyPatchArgs, Pars
})
}
fn check_patch_boundaries_streaming<'a>(
original_lines: &'a [&'a str],
) -> Result<(&'a [&'a str], &'a [&'a str]), ParseError> {
match original_lines {
[first, ..] if first.trim() == BEGIN_PATCH_MARKER => {
let body_lines = if original_lines
.last()
.is_some_and(|line| line.trim() == END_PATCH_MARKER)
{
&original_lines[1..original_lines.len() - 1]
} else {
&original_lines[1..]
};
Ok((original_lines, body_lines))
}
_ => check_patch_boundaries_strict(original_lines),
}
}
/// Checks the start and end lines of the patch text for `apply_patch`,
/// returning an error if they do not match the expected markers.
fn check_patch_boundaries_strict<'a>(
@@ -297,15 +262,9 @@ fn check_start_and_end_lines_strict(
/// Attempts to parse a single hunk from the start of lines.
/// Returns the parsed hunk and the number of lines parsed (or a ParseError).
fn parse_one_hunk(
lines: &[&str],
line_number: usize,
allow_incomplete: bool,
) -> Result<(Hunk, usize), ParseError> {
// Be tolerant of case mismatches and extra padding around marker strings.
fn parse_one_hunk(lines: &[&str], line_number: usize) -> Result<(Hunk, usize), ParseError> {
let first_line = lines[0].trim();
if let Some(path) = first_line.strip_prefix(ADD_FILE_MARKER) {
// Add File
let mut contents = String::new();
let mut parsed_lines = 1;
for add_line in &lines[1..] {
@@ -325,7 +284,6 @@ fn parse_one_hunk(
parsed_lines,
));
} else if let Some(path) = first_line.strip_prefix(DELETE_FILE_MARKER) {
// Delete File
return Ok((
DeleteFile {
path: PathBuf::from(path),
@@ -333,11 +291,8 @@ fn parse_one_hunk(
1,
));
} else if let Some(path) = first_line.strip_prefix(UPDATE_FILE_MARKER) {
// Update File
let mut remaining_lines = &lines[1..];
let mut parsed_lines = 1;
// Optional: move file line
let move_path = remaining_lines
.first()
.and_then(|x| x.strip_prefix(MOVE_TO_MARKER));
@@ -348,9 +303,7 @@ fn parse_one_hunk(
}
let mut chunks = Vec::new();
// NOTE: we need to know to stop once we reach the next special marker header.
while !remaining_lines.is_empty() {
// Skip over any completely blank lines that may separate chunks.
if remaining_lines[0].trim().is_empty() {
parsed_lines += 1;
remaining_lines = &remaining_lines[1..];
@@ -361,22 +314,11 @@ fn parse_one_hunk(
break;
}
if allow_incomplete && remaining_lines[0] == "@" {
break;
}
let parsed_chunk = parse_update_file_chunk(
let (chunk, chunk_lines) = parse_update_file_chunk(
remaining_lines,
line_number + parsed_lines,
chunks.is_empty(),
);
let (chunk, chunk_lines) = match parsed_chunk {
Ok(parsed) => parsed,
Err(InvalidHunkError { .. }) if allow_incomplete && !chunks.is_empty() => {
break;
}
Err(err) => return Err(err),
};
)?;
chunks.push(chunk);
parsed_lines += chunk_lines;
remaining_lines = &remaining_lines[chunk_lines..]
@@ -384,7 +326,10 @@ fn parse_one_hunk(
if chunks.is_empty() {
return Err(InvalidHunkError {
message: format!("Update file hunk for path '{path}' is empty"),
message: format!(
"Update file hunk for path '{}' is empty",
Path::new(path).display()
),
line_number,
});
}
@@ -418,8 +363,6 @@ fn parse_update_file_chunk(
line_number,
});
}
// If we see an explicit context marker @@ or @@ <context>, consume it; otherwise, optionally
// allow treating the chunk as starting directly with diff lines.
let (change_context, start_index) = if lines[0] == EMPTY_CHANGE_CONTEXT_MARKER {
(None, 1)
} else if let Some(context) = lines[0].strip_prefix(CHANGE_CONTEXT_MARKER) {
@@ -501,162 +444,113 @@ fn parse_update_file_chunk(
}
#[test]
fn test_parse_patch_streaming() {
fn test_parse_one_hunk() {
assert_eq!(
parse_patch_streaming("*** Begin Patch\n*** Add File: src/hello.txt\n+hello\n+wor"),
Ok(ApplyPatchArgs {
hunks: vec![AddFile {
path: PathBuf::from("src/hello.txt"),
contents: "hello\nwor\n".to_string(),
}],
patch: "*** Begin Patch\n*** Add File: src/hello.txt\n+hello\n+wor".to_string(),
workdir: None,
})
);
assert_eq!(
parse_patch_streaming(
"*** Begin Patch\n*** Update File: src/old.rs\n*** Move to: src/new.rs\n@@\n-old\n+new",
),
Ok(ApplyPatchArgs {
hunks: vec![UpdateFile {
path: PathBuf::from("src/old.rs"),
move_path: Some(PathBuf::from("src/new.rs")),
chunks: vec![UpdateFileChunk {
change_context: None,
old_lines: vec!["old".to_string()],
new_lines: vec!["new".to_string()],
is_end_of_file: false,
}],
}],
patch: "*** Begin Patch\n*** Update File: src/old.rs\n*** Move to: src/new.rs\n@@\n-old\n+new".to_string(),
workdir: None,
})
);
assert!(
parse_patch_text(
"*** Begin Patch\n*** Delete File: gone.txt",
ParseMode::Streaming
)
.is_ok()
);
assert!(
parse_patch_text(
"*** Begin Patch\n*** Delete File: gone.txt",
ParseMode::Strict
)
.is_err()
);
assert_eq!(
parse_patch_streaming(
"*** Begin Patch\n*** Add File: src/one.txt\n+one\n*** Delete File: src/two.txt\n",
),
Ok(ApplyPatchArgs {
hunks: vec![
AddFile {
path: PathBuf::from("src/one.txt"),
contents: "one\n".to_string(),
},
DeleteFile {
path: PathBuf::from("src/two.txt"),
},
],
patch: "*** Begin Patch\n*** Add File: src/one.txt\n+one\n*** Delete File: src/two.txt"
.to_string(),
workdir: None,
parse_one_hunk(&["bad"], /*line_number*/ 234),
Err(InvalidHunkError {
message: "'bad' is not a valid hunk header. \
Valid hunk headers: '*** Add File: {path}', '*** Delete File: {path}', '*** Update File: {path}'".to_string(),
line_number: 234
})
);
}
#[test]
fn test_parse_patch_streaming_large_patch_by_character() {
let patch = "\
*** Begin Patch
*** Add File: docs/release-notes.md
+# Release notes
+
+## CLI
+- Surface apply_patch progress while arguments stream.
+- Keep final patch application gated on the completed tool call.
+- Include file summaries in the progress event payload.
*** Update File: src/config.rs
@@ impl Config
- pub apply_patch_progress: bool,
+ pub stream_apply_patch_progress: bool,
pub include_diagnostics: bool,
@@ fn default_progress_interval()
- Duration::from_millis(500)
+ Duration::from_millis(250)
*** Delete File: src/legacy_patch_progress.rs
*** Update File: crates/cli/src/main.rs
*** Move to: crates/cli/src/bin/codex.rs
@@ fn run()
- let args = Args::parse();
- dispatch(args)
+ let cli = Cli::parse();
+ dispatch(cli)
*** Add File: tests/fixtures/apply_patch_progress.json
+{
+ \"type\": \"apply_patch_progress\",
+ \"hunks\": [
+ { \"operation\": \"add\", \"path\": \"docs/release-notes.md\" },
+ { \"operation\": \"update\", \"path\": \"src/config.rs\" }
+ ]
+}
*** Update File: README.md
@@ Development workflow
Build the Rust workspace before opening a pull request.
+When touching streamed tool calls, include parser coverage for partial input.
+Prefer tests that exercise the exact event payload shape.
*** Delete File: docs/old-apply-patch-progress.md
*** End Patch";
let mut max_hunk_count = 0;
let mut saw_hunk_counts = Vec::new();
for i in 1..=patch.len() {
let partial = &patch[..i];
if let Ok(parsed) = parse_patch_streaming(partial) {
let hunk_count = parsed.hunks.len();
assert!(
hunk_count >= max_hunk_count,
"hunk count should never decrease while streaming: {hunk_count} < {max_hunk_count} for {partial:?}",
);
if hunk_count > max_hunk_count {
saw_hunk_counts.push(hunk_count);
max_hunk_count = hunk_count;
}
}
}
assert_eq!(saw_hunk_counts, vec![1, 2, 3, 4, 5, 6, 7]);
let parsed = parse_patch_streaming(patch).unwrap();
assert_eq!(parsed.hunks.len(), 7);
fn test_update_file_chunk() {
assert_eq!(
parsed
.hunks
.iter()
.map(|hunk| match hunk {
AddFile { .. } => "add",
DeleteFile { .. } => "delete",
UpdateFile {
move_path: Some(_), ..
} => "move-update",
UpdateFile {
move_path: None, ..
} => "update",
})
.collect::<Vec<_>>(),
vec![
"add",
"update",
"delete",
"move-update",
"add",
"update",
"delete"
]
parse_update_file_chunk(
&["bad"],
/*line_number*/ 123,
/*allow_missing_context*/ false,
),
Err(InvalidHunkError {
message: "Expected update hunk to start with a @@ context marker, got: 'bad'"
.to_string(),
line_number: 123
})
);
assert_eq!(
parse_update_file_chunk(
&["@@"],
/*line_number*/ 123,
/*allow_missing_context*/ false,
),
Err(InvalidHunkError {
message: "Update hunk does not contain any lines".to_string(),
line_number: 124
})
);
assert_eq!(
parse_update_file_chunk(
&["@@", "bad"],
/*line_number*/ 123,
/*allow_missing_context*/ false,
),
Err(InvalidHunkError {
message: "Unexpected line found in update hunk: 'bad'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)".to_string(),
line_number: 124
})
);
assert_eq!(
parse_update_file_chunk(
&["@@", "*** End of File"],
/*line_number*/ 123,
/*allow_missing_context*/ false,
),
Err(InvalidHunkError {
message: "Update hunk does not contain any lines".to_string(),
line_number: 124
})
);
assert_eq!(
parse_update_file_chunk(
&[
"@@ change_context",
"",
" context",
"-remove",
"+add",
" context2",
"*** End Patch",
],
/*line_number*/ 123,
/*allow_missing_context*/ false,
),
Ok((
UpdateFileChunk {
change_context: Some("change_context".to_string()),
old_lines: vec![
String::new(),
"context".to_string(),
"remove".to_string(),
"context2".to_string(),
],
new_lines: vec![
String::new(),
"context".to_string(),
"add".to_string(),
"context2".to_string(),
],
is_end_of_file: false,
},
6,
))
);
assert_eq!(
parse_update_file_chunk(
&["@@", "+line", "*** End of File"],
/*line_number*/ 123,
/*allow_missing_context*/ false,
),
Ok((
UpdateFileChunk {
change_context: None,
old_lines: Vec::new(),
new_lines: vec!["line".to_string()],
is_end_of_file: true,
},
3,
))
);
}
@@ -997,112 +891,3 @@ fn test_parse_patch_lenient() {
))
);
}
#[test]
fn test_parse_one_hunk() {
assert_eq!(
parse_one_hunk(&["bad"], /*line_number*/ 234, /*allow_incomplete*/ false),
Err(InvalidHunkError {
message: "'bad' is not a valid hunk header. \
Valid hunk headers: '*** Add File: {path}', '*** Delete File: {path}', '*** Update File: {path}'".to_string(),
line_number: 234
})
);
// Other edge cases are already covered by tests above/below.
}
#[test]
fn test_update_file_chunk() {
assert_eq!(
parse_update_file_chunk(
&["bad"],
/*line_number*/ 123,
/*allow_missing_context*/ false
),
Err(InvalidHunkError {
message: "Expected update hunk to start with a @@ context marker, got: 'bad'"
.to_string(),
line_number: 123
})
);
assert_eq!(
parse_update_file_chunk(
&["@@"],
/*line_number*/ 123,
/*allow_missing_context*/ false
),
Err(InvalidHunkError {
message: "Update hunk does not contain any lines".to_string(),
line_number: 124
})
);
assert_eq!(
parse_update_file_chunk(&["@@", "bad"], /*line_number*/ 123, /*allow_missing_context*/ false),
Err(InvalidHunkError {
message: "Unexpected line found in update hunk: 'bad'. \
Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)".to_string(),
line_number: 124
})
);
assert_eq!(
parse_update_file_chunk(
&["@@", "*** End of File"],
/*line_number*/ 123,
/*allow_missing_context*/ false
),
Err(InvalidHunkError {
message: "Update hunk does not contain any lines".to_string(),
line_number: 124
})
);
assert_eq!(
parse_update_file_chunk(
&[
"@@ change_context",
"",
" context",
"-remove",
"+add",
" context2",
"*** End Patch",
],
/*line_number*/ 123,
/*allow_missing_context*/ false
),
Ok((
(UpdateFileChunk {
change_context: Some("change_context".to_string()),
old_lines: vec![
"".to_string(),
"context".to_string(),
"remove".to_string(),
"context2".to_string()
],
new_lines: vec![
"".to_string(),
"context".to_string(),
"add".to_string(),
"context2".to_string()
],
is_end_of_file: false
}),
6
))
);
assert_eq!(
parse_update_file_chunk(
&["@@", "+line", "*** End of File"],
/*line_number*/ 123,
/*allow_missing_context*/ false
),
Ok((
(UpdateFileChunk {
change_context: None,
old_lines: vec![],
new_lines: vec!["line".to_string()],
is_end_of_file: true
}),
3
))
);
}

View File

@@ -0,0 +1,813 @@
use std::path::PathBuf;
use crate::parser::ADD_FILE_MARKER;
use crate::parser::BEGIN_PATCH_MARKER;
use crate::parser::CHANGE_CONTEXT_MARKER;
use crate::parser::DELETE_FILE_MARKER;
use crate::parser::EMPTY_CHANGE_CONTEXT_MARKER;
use crate::parser::END_PATCH_MARKER;
use crate::parser::EOF_MARKER;
use crate::parser::Hunk;
use crate::parser::MOVE_TO_MARKER;
use crate::parser::ParseError;
use crate::parser::UPDATE_FILE_MARKER;
use crate::parser::UpdateFileChunk;
use Hunk::*;
use ParseError::*;
/// Incremental `apply_patch` parser: feed it streamed tool-call input via
/// `push_delta` and it reports the hunks parsed so far after each complete
/// line; call `finish` once the stream ends.
#[derive(Debug, Default, Clone)]
pub struct StreamingPatchParser {
    // Characters of the current, not-yet-newline-terminated line.
    line_buffer: String,
    // Current mode plus all hunks accumulated so far.
    state: StreamingParserState,
    // Number of the line currently being processed (1-based); used in errors.
    line_number: usize,
}
/// Mutable parse state: which section the parser is inside and every hunk
/// completed or in progress so far.
#[derive(Debug, Default, Clone)]
struct StreamingParserState {
    mode: StreamingParserMode,
    hunks: Vec<Hunk>,
}
/// Which section of the patch the parser is currently inside.
#[derive(Debug, Default, Clone)]
enum StreamingParserMode {
    /// No `*** Begin Patch` line has been seen yet.
    #[default]
    NotStarted,
    /// Inside the patch, before the first hunk header.
    StartedPatch,
    /// Inside an `*** Add File:` hunk.
    AddFile,
    /// Inside a `*** Delete File:` hunk.
    DeleteFile,
    /// Inside an `*** Update File:` hunk.
    UpdateFile {
        /// Line number of the `*** Update File:` header, kept for the
        /// "hunk is empty" error.
        hunk_line_number: usize,
    },
    /// `*** End Patch` has been seen; any further lines are ignored.
    EndedPatch,
}
impl StreamingPatchParser {
    /// Validates that the most recent `Update File` hunk has content before a
    /// new section header (or the end marker) is accepted.
    ///
    /// `line` is the marker line that triggered the check; it is used only to
    /// select the error message. Errors when the update hunk has no chunks at
    /// all, or when its trailing chunk contains no diff lines.
    fn ensure_update_hunk_is_not_empty(&self, line: &str) -> Result<(), ParseError> {
        if let Some(UpdateFile { path, chunks, .. }) = self.state.hunks.last() {
            // No chunks were ever started for this update hunk; report the
            // error at the `*** Update File:` header line.
            if chunks.is_empty()
                && let StreamingParserMode::UpdateFile { hunk_line_number } = self.state.mode
            {
                return Err(InvalidHunkError {
                    message: format!("Update file hunk for path '{}' is empty", path.display()),
                    line_number: hunk_line_number,
                });
            }
            // A `@@` marker was seen but no diff lines followed it.
            if chunks
                .last()
                .is_some_and(|chunk| chunk.old_lines.is_empty() && chunk.new_lines.is_empty())
            {
                if line == END_PATCH_MARKER {
                    return Err(InvalidHunkError {
                        message: "Update hunk does not contain any lines".to_string(),
                        line_number: self.line_number,
                    });
                }
                return Err(InvalidHunkError {
                    message: format!(
                        "Unexpected line found in update hunk: '{line}'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
                    ),
                    line_number: self.line_number,
                });
            }
        }
        Ok(())
    }

    /// Handles the marker lines that may follow any hunk: `*** End Patch`,
    /// `*** Add File:`, `*** Delete File:` and `*** Update File:`.
    ///
    /// Returns `Ok(true)` when `trimmed` was one of those markers (and the
    /// parser state has been advanced accordingly), `Ok(false)` when the
    /// caller should interpret the line itself.
    fn handle_hunk_headers_and_end_patch(&mut self, trimmed: &str) -> Result<bool, ParseError> {
        if trimmed == END_PATCH_MARKER {
            self.ensure_update_hunk_is_not_empty(trimmed)?;
            self.state.mode = StreamingParserMode::EndedPatch;
            return Ok(true);
        }
        if let Some(path) = trimmed.strip_prefix(ADD_FILE_MARKER) {
            self.ensure_update_hunk_is_not_empty(trimmed)?;
            self.state.hunks.push(AddFile {
                path: PathBuf::from(path),
                contents: String::new(),
            });
            self.state.mode = StreamingParserMode::AddFile;
            return Ok(true);
        }
        if let Some(path) = trimmed.strip_prefix(DELETE_FILE_MARKER) {
            self.ensure_update_hunk_is_not_empty(trimmed)?;
            self.state.hunks.push(DeleteFile {
                path: PathBuf::from(path),
            });
            self.state.mode = StreamingParserMode::DeleteFile;
            return Ok(true);
        }
        if let Some(path) = trimmed.strip_prefix(UPDATE_FILE_MARKER) {
            self.ensure_update_hunk_is_not_empty(trimmed)?;
            self.state.hunks.push(UpdateFile {
                path: PathBuf::from(path),
                move_path: None,
                chunks: Vec::new(),
            });
            // Remember the header's line number for "hunk is empty" errors.
            self.state.mode = StreamingParserMode::UpdateFile {
                hunk_line_number: self.line_number,
            };
            return Ok(true);
        }
        Ok(false)
    }

    /// Feeds a chunk of streamed patch text into the parser.
    ///
    /// Only complete (newline-terminated) lines are processed; a trailing
    /// partial line stays buffered until more input (or `finish`) arrives.
    /// Returns a snapshot of all hunks parsed so far.
    pub fn push_delta(&mut self, delta: &str) -> Result<Vec<Hunk>, ParseError> {
        for ch in delta.chars() {
            if ch == '\n' {
                let mut line = std::mem::take(&mut self.line_buffer);
                // Drop a single trailing '\r' so CRLF input parses like LF.
                line.truncate(line.strip_suffix('\r').map_or(line.len(), str::len));
                self.line_number += 1;
                self.process_line(&line)?;
            } else {
                self.line_buffer.push(ch);
            }
        }
        Ok(self.state.hunks.clone())
    }

    /// Flushes any buffered final line and verifies the patch terminated.
    ///
    /// A final `*** End Patch` without a trailing newline is accepted here;
    /// errors unless the patch actually reached the end marker.
    pub fn finish(&mut self) -> Result<Vec<Hunk>, ParseError> {
        if !self.line_buffer.is_empty() {
            let line = std::mem::take(&mut self.line_buffer);
            self.line_number += 1;
            if line.trim() == END_PATCH_MARKER {
                self.ensure_update_hunk_is_not_empty(line.trim())?;
                self.state.mode = StreamingParserMode::EndedPatch;
            } else {
                self.process_line(&line)?;
            }
        }
        if !matches!(self.state.mode, StreamingParserMode::EndedPatch) {
            return Err(InvalidPatchError(
                "The last line of the patch must be '*** End Patch'".to_string(),
            ));
        }
        Ok(self.state.hunks.clone())
    }

    /// Dispatches one complete line to the handler for the current mode.
    fn process_line(&mut self, line: &str) -> Result<(), ParseError> {
        let trimmed = line.trim();
        // Clone the mode so `self` can be mutably borrowed inside the match.
        match self.state.mode.clone() {
            StreamingParserMode::NotStarted => {
                if trimmed == BEGIN_PATCH_MARKER {
                    self.state.mode = StreamingParserMode::StartedPatch;
                    return Ok(());
                }
                Err(InvalidPatchError(
                    "The first line of the patch must be '*** Begin Patch'".to_string(),
                ))
            }
            StreamingParserMode::StartedPatch => {
                if self.handle_hunk_headers_and_end_patch(trimmed)? {
                    return Ok(());
                }
                // Directly after `*** Begin Patch`, only hunk headers (or the
                // end marker) are valid.
                Err(InvalidHunkError {
                    message: format!(
                        "'{trimmed}' is not a valid hunk header. Valid hunk headers: '*** Add File: {{path}}', '*** Delete File: {{path}}', '*** Update File: {{path}}'"
                    ),
                    line_number: self.line_number,
                })
            }
            StreamingParserMode::AddFile => {
                if self.handle_hunk_headers_and_end_patch(trimmed)? {
                    return Ok(());
                }
                // Every content line of an added file must start with '+';
                // the '+' is stripped and a trailing newline appended.
                if let Some(line_to_add) = line.strip_prefix('+')
                    && let Some(AddFile { contents, .. }) = self.state.hunks.last_mut()
                {
                    contents.push_str(line_to_add);
                    contents.push('\n');
                    return Ok(());
                }
                Err(InvalidHunkError {
                    message: format!(
                        "'{trimmed}' is not a valid hunk header. Valid hunk headers: '*** Add File: {{path}}', '*** Delete File: {{path}}', '*** Update File: {{path}}'"
                    ),
                    line_number: self.line_number,
                })
            }
            StreamingParserMode::DeleteFile => {
                // A delete hunk has no body: the next line must be another
                // hunk header or the end marker.
                if self.handle_hunk_headers_and_end_patch(trimmed)? {
                    return Ok(());
                }
                Err(InvalidHunkError {
                    message: format!(
                        "'{trimmed}' is not a valid hunk header. Valid hunk headers: '*** Add File: {{path}}', '*** Delete File: {{path}}', '*** Update File: {{path}}'"
                    ),
                    line_number: self.line_number,
                })
            }
            StreamingParserMode::UpdateFile { hunk_line_number } => {
                // Only trailing whitespace is trimmed: leading whitespace
                // keeps an indented `*** ...` line a context line rather than
                // a marker.
                let update_line = line.trim_end();
                if self.handle_hunk_headers_and_end_patch(update_line)? {
                    return Ok(());
                }
                if let Some(UpdateFile {
                    move_path, chunks, ..
                }) = self.state.hunks.last_mut()
                {
                    // `*** Move to:` is only valid before any chunk, and at
                    // most once per update hunk.
                    if chunks.is_empty()
                        && move_path.is_none()
                        && let Some(move_to_path) = update_line.strip_prefix(MOVE_TO_MARKER)
                    {
                        *move_path = Some(PathBuf::from(move_to_path));
                        self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
                        return Ok(());
                    }
                    // A new `@@` marker while the previous chunk is still
                    // empty means that chunk never received diff lines.
                    if (update_line == EMPTY_CHANGE_CONTEXT_MARKER
                        || update_line.starts_with(CHANGE_CONTEXT_MARKER))
                        && chunks.last().is_some_and(|chunk| {
                            chunk.old_lines.is_empty() && chunk.new_lines.is_empty()
                        })
                    {
                        return Err(InvalidHunkError {
                            message: format!(
                                "Unexpected line found in update hunk: '{line}'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
                            ),
                            line_number: self.line_number,
                        });
                    }
                    // Bare `@@`: start a new chunk with no change context.
                    if update_line == EMPTY_CHANGE_CONTEXT_MARKER {
                        chunks.push(UpdateFileChunk {
                            change_context: None,
                            old_lines: Vec::new(),
                            new_lines: Vec::new(),
                            is_end_of_file: false,
                        });
                        self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
                        return Ok(());
                    }
                    // `@@ <context>`: start a chunk anchored to that context.
                    if let Some(change_context) = update_line.strip_prefix(CHANGE_CONTEXT_MARKER) {
                        chunks.push(UpdateFileChunk {
                            change_context: Some(change_context.to_string()),
                            old_lines: Vec::new(),
                            new_lines: Vec::new(),
                            is_end_of_file: false,
                        });
                        self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
                        return Ok(());
                    }
                    // `*** End of File` flags the current chunk, which must
                    // already contain diff lines.
                    if update_line == EOF_MARKER {
                        if chunks.last().is_some_and(|chunk| {
                            chunk.old_lines.is_empty() && chunk.new_lines.is_empty()
                        }) {
                            return Err(InvalidHunkError {
                                message: "Update hunk does not contain any lines".to_string(),
                                line_number: self.line_number,
                            });
                        }
                        if let Some(chunk) = chunks.last_mut() {
                            chunk.is_end_of_file = true;
                        }
                        self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
                        return Ok(());
                    }
                    // A completely blank line counts as an empty context line
                    // (matching the non-streaming parser's leniency), opening
                    // a context-less chunk if none exists yet.
                    if line.is_empty() {
                        if chunks.is_empty() {
                            chunks.push(UpdateFileChunk {
                                change_context: None,
                                old_lines: Vec::new(),
                                new_lines: Vec::new(),
                                is_end_of_file: false,
                            });
                        }
                        if let Some(chunk) = chunks.last_mut() {
                            chunk.old_lines.push(String::new());
                            chunk.new_lines.push(String::new());
                        }
                        self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
                        return Ok(());
                    }
                    // ' ' prefix: context line, recorded on both sides.
                    if let Some(line_to_add) = line.strip_prefix(' ') {
                        if chunks.is_empty() {
                            chunks.push(UpdateFileChunk {
                                change_context: None,
                                old_lines: Vec::new(),
                                new_lines: Vec::new(),
                                is_end_of_file: false,
                            });
                        }
                        if let Some(chunk) = chunks.last_mut() {
                            chunk.old_lines.push(line_to_add.to_string());
                            chunk.new_lines.push(line_to_add.to_string());
                        }
                        self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
                        return Ok(());
                    }
                    // '+' prefix: added line, recorded on the new side only.
                    if let Some(line_to_add) = line.strip_prefix('+') {
                        if chunks.is_empty() {
                            chunks.push(UpdateFileChunk {
                                change_context: None,
                                old_lines: Vec::new(),
                                new_lines: Vec::new(),
                                is_end_of_file: false,
                            });
                        }
                        if let Some(chunk) = chunks.last_mut() {
                            chunk.new_lines.push(line_to_add.to_string());
                        }
                        self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
                        return Ok(());
                    }
                    // '-' prefix: removed line, recorded on the old side only.
                    if let Some(line_to_remove) = line.strip_prefix('-') {
                        if chunks.is_empty() {
                            chunks.push(UpdateFileChunk {
                                change_context: None,
                                old_lines: Vec::new(),
                                new_lines: Vec::new(),
                                is_end_of_file: false,
                            });
                        }
                        if let Some(chunk) = chunks.last_mut() {
                            chunk.old_lines.push(line_to_remove.to_string());
                        }
                        self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
                        return Ok(());
                    }
                    // An unprefixed line after diff content has already been
                    // collected gets the "missing @@ marker" error.
                    if chunks.last().is_some_and(|chunk| {
                        !chunk.old_lines.is_empty() || !chunk.new_lines.is_empty()
                    }) {
                        return Err(InvalidHunkError {
                            message: format!(
                                "Expected update hunk to start with a @@ context marker, got: '{line}'"
                            ),
                            line_number: self.line_number,
                        });
                    }
                }
                // Otherwise the line is not a valid diff line at all.
                Err(InvalidHunkError {
                    message: format!(
                        "Unexpected line found in update hunk: '{line}'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
                    ),
                    line_number: self.line_number,
                })
            }
            // After `*** End Patch`, extra lines are silently ignored.
            StreamingParserMode::EndedPatch => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use std::path::PathBuf;
use super::*;
#[test]
fn test_streaming_patch_parser_streams_complete_lines_before_end_patch() {
let mut parser = StreamingPatchParser::default();
assert_eq!(
parser.push_delta("*** Begin Patch\n*** Add File: src/hello.txt\n+hello\n+wor"),
Ok(vec![AddFile {
path: PathBuf::from("src/hello.txt"),
contents: "hello\n".to_string(),
}])
);
assert_eq!(
parser.push_delta("ld\n"),
Ok(vec![AddFile {
path: PathBuf::from("src/hello.txt"),
contents: "hello\nworld\n".to_string(),
}])
);
let mut parser = StreamingPatchParser::default();
assert_eq!(
parser.push_delta(
"*** Begin Patch\n*** Update File: src/old.rs\n*** Move to: src/new.rs\n@@\n-old\n+new\n",
),
Ok(vec![UpdateFile {
path: PathBuf::from("src/old.rs"),
move_path: Some(PathBuf::from("src/new.rs")),
chunks: vec![UpdateFileChunk {
change_context: None,
old_lines: vec!["old".to_string()],
new_lines: vec!["new".to_string()],
is_end_of_file: false,
}],
}])
);
let mut parser = StreamingPatchParser::default();
assert_eq!(
parser.push_delta("*** Begin Patch\n*** Delete File: gone.txt"),
Ok(Vec::new())
);
assert_eq!(
parser.push_delta("\n"),
Ok(vec![DeleteFile {
path: PathBuf::from("gone.txt"),
}])
);
let mut parser = StreamingPatchParser::default();
assert_eq!(
parser.push_delta(
"*** Begin Patch\n*** Add File: src/one.txt\n+one\n*** Delete File: src/two.txt\n",
),
Ok(vec![
AddFile {
path: PathBuf::from("src/one.txt"),
contents: "one\n".to_string(),
},
DeleteFile {
path: PathBuf::from("src/two.txt"),
},
])
);
}
#[test]
fn test_streaming_patch_parser_large_patch_split_by_character() {
let patch = "\
*** Begin Patch
*** Add File: docs/release-notes.md
+# Release notes
+
+## CLI
+- Surface apply_patch progress while arguments stream.
+- Keep final patch application gated on the completed tool call.
+- Include file summaries in the progress event payload.
*** Update File: src/config.rs
@@ impl Config
- pub apply_patch_progress: bool,
+ pub stream_apply_patch_progress: bool,
pub include_diagnostics: bool,
@@ fn default_progress_interval()
- Duration::from_millis(500)
+ Duration::from_millis(250)
*** Delete File: src/legacy_patch_progress.rs
*** Update File: crates/cli/src/main.rs
*** Move to: crates/cli/src/bin/codex.rs
@@ fn run()
- let args = Args::parse();
- dispatch(args)
+ let cli = Cli::parse();
+ dispatch(cli)
*** Add File: tests/fixtures/apply_patch_progress.json
+{
+ \"type\": \"apply_patch_progress\",
+ \"hunks\": [
+ { \"operation\": \"add\", \"path\": \"docs/release-notes.md\" },
+ { \"operation\": \"update\", \"path\": \"src/config.rs\" }
+ ]
+}
*** Update File: README.md
@@ Development workflow
Build the Rust workspace before opening a pull request.
+When touching streamed tool calls, include parser coverage for partial input.
+Prefer tests that exercise the exact event payload shape.
*** Delete File: docs/old-apply-patch-progress.md
*** End Patch";
let mut parser = StreamingPatchParser::default();
let mut max_hunk_count = 0;
let mut saw_hunk_counts = Vec::new();
let mut hunks = Vec::new();
for ch in patch.chars() {
let updated_hunks = parser.push_delta(&ch.to_string()).unwrap();
if !updated_hunks.is_empty() {
let hunk_count = updated_hunks.len();
assert!(
hunk_count >= max_hunk_count,
"hunk count should never decrease while streaming: {hunk_count} < {max_hunk_count}",
);
if hunk_count > max_hunk_count {
saw_hunk_counts.push(hunk_count);
max_hunk_count = hunk_count;
}
hunks = updated_hunks;
}
}
assert_eq!(saw_hunk_counts, vec![1, 2, 3, 4, 5, 6, 7]);
assert_eq!(hunks.len(), 7);
assert_eq!(
hunks
.iter()
.map(|hunk| match hunk {
AddFile { .. } => "add",
DeleteFile { .. } => "delete",
UpdateFile {
move_path: Some(_), ..
} => "move-update",
UpdateFile {
move_path: None, ..
} => "update",
})
.collect::<Vec<_>>(),
vec![
"add",
"update",
"delete",
"move-update",
"add",
"update",
"delete"
]
);
}
#[test]
fn test_streaming_patch_parser_keeps_indented_update_markers_as_context_lines() {
let mut parser = StreamingPatchParser::default();
assert_eq!(
parser.push_delta(
"\
*** Begin Patch
*** Update File: a.txt
@@
-old a
+new a
*** Update File: b.txt
@@
-old b
+new b
*** End Patch
",
),
Ok(vec![UpdateFile {
path: PathBuf::from("a.txt"),
move_path: None,
chunks: vec![
UpdateFileChunk {
change_context: None,
old_lines: vec!["old a".to_string(), "*** Update File: b.txt".to_string()],
new_lines: vec!["new a".to_string(), "*** Update File: b.txt".to_string()],
is_end_of_file: false,
},
UpdateFileChunk {
change_context: None,
old_lines: vec!["old b".to_string()],
new_lines: vec!["new b".to_string()],
is_end_of_file: false,
},
],
}])
);
}
#[test]
fn test_streaming_patch_parser_preserves_bare_empty_update_lines() {
let mut parser = StreamingPatchParser::default();
assert_eq!(
parser.push_delta(
"\
*** Begin Patch
*** Update File: file.txt
@@
context before
context after
*** End Patch
",
),
Ok(vec![UpdateFile {
path: PathBuf::from("file.txt"),
move_path: None,
chunks: vec![UpdateFileChunk {
change_context: None,
// The normal parser treats a bare empty line in an update hunk as an
// empty context line. Preserve that leniency in the streaming parser.
old_lines: vec![
"context before".to_string(),
String::new(),
"context after".to_string(),
],
new_lines: vec![
"context before".to_string(),
String::new(),
"context after".to_string(),
],
is_end_of_file: false,
}],
}])
);
}
#[test]
fn test_streaming_patch_parser_matches_line_ending_behavior() {
let mut parser = StreamingPatchParser::default();
assert_eq!(
parser.push_delta("*** Begin Patch\r\n*** Update File: file.txt\r\n@@\r\n-old\r\n+new\r\n*** End Patch\r\n"),
Ok(vec![UpdateFile {
path: PathBuf::from("file.txt"),
move_path: None,
chunks: vec![UpdateFileChunk {
change_context: None,
old_lines: vec!["old".to_string()],
new_lines: vec!["new".to_string()],
is_end_of_file: false,
}],
}])
);
let mut parser = StreamingPatchParser::default();
assert_eq!(
parser.push_delta("*** Begin Patch\r\n*** Update File: file.txt\r\n@@\r\n-old\r\r\n+new\r\n*** End Patch\r\n"),
Ok(vec![UpdateFile {
path: PathBuf::from("file.txt"),
move_path: None,
chunks: vec![UpdateFileChunk {
change_context: None,
old_lines: vec!["old\r".to_string()],
new_lines: vec!["new".to_string()],
is_end_of_file: false,
}],
}])
);
}
#[test]
fn test_streaming_patch_parser_finish_processes_final_line_without_newline() {
    // An Add File hunk whose "*** End Patch" terminator arrives with no
    // trailing newline: push_delta already yields the completed hunk, and a
    // subsequent finish() returns the same completed hunks again.
    let mut parser = StreamingPatchParser::default();
    assert_eq!(
        parser.push_delta("*** Begin Patch\n*** Add File: file.txt\n+hello\n*** End Patch"),
        Ok(vec![AddFile {
            path: PathBuf::from("file.txt"),
            contents: "hello\n".to_string(),
        }])
    );
    assert_eq!(
        parser.finish(),
        Ok(vec![AddFile {
            path: PathBuf::from("file.txt"),
            contents: "hello\n".to_string(),
        }])
    );
    // Same shape for an Update File hunk whose final line has no newline:
    // both the in-flight result and finish() agree on the parsed chunks.
    let mut parser = StreamingPatchParser::default();
    assert_eq!(
        parser.push_delta(
            "*** Begin Patch\n*** Update File: file.txt\n@@\n-old\n+new\n *** End Patch",
        ),
        Ok(vec![UpdateFile {
            path: PathBuf::from("file.txt"),
            move_path: None,
            chunks: vec![UpdateFileChunk {
                change_context: None,
                old_lines: vec!["old".to_string()],
                new_lines: vec!["new".to_string()],
                is_end_of_file: false,
            }],
        }])
    );
    assert_eq!(
        parser.finish(),
        Ok(vec![UpdateFile {
            path: PathBuf::from("file.txt"),
            move_path: None,
            chunks: vec![UpdateFileChunk {
                change_context: None,
                old_lines: vec!["old".to_string()],
                new_lines: vec!["new".to_string()],
                is_end_of_file: false,
            }],
        }])
    );
}
#[test]
fn test_streaming_patch_parser_finish_requires_end_patch() {
    // A stream that delivers a complete Add File hunk but never the closing
    // marker parses incrementally, yet finish() must reject the patch.
    let mut parser = StreamingPatchParser::default();
    let partial = parser.push_delta("*** Begin Patch\n*** Add File: file.txt\n+hello\n");
    let expected_hunk = AddFile {
        path: PathBuf::from("file.txt"),
        contents: "hello\n".to_string(),
    };
    assert_eq!(partial, Ok(vec![expected_hunk]));
    let expected_err =
        InvalidPatchError("The last line of the patch must be '*** End Patch'".to_string());
    assert_eq!(parser.finish(), Err(expected_err));
}
#[test]
fn test_streaming_patch_parser_returns_errors() {
    // Message produced when a line is not recognized as a hunk header.
    let bad_header_msg = "'bad' is not a valid hunk header. Valid hunk headers: '*** Add File: {path}', '*** Delete File: {path}', '*** Update File: {path}'";
    // Runs a fresh parser over one complete input and returns the result.
    let parse = |input: &str| StreamingPatchParser::default().push_delta(input);

    // The stream must open with the begin marker.
    assert_eq!(
        parse("bad\n"),
        Err(InvalidPatchError(
            "The first line of the patch must be '*** Begin Patch'".to_string(),
        ))
    );

    // The same header error is reported when the bad line arrives in a later delta.
    let mut parser = StreamingPatchParser::default();
    assert_eq!(parser.push_delta("*** Begin Patch\n"), Ok(Vec::new()));
    assert_eq!(
        parser.push_delta("bad\n"),
        Err(InvalidHunkError {
            message: bad_header_msg.to_string(),
            line_number: 2,
        })
    );

    // After an Add File or Delete File hunk, only another hunk header may follow.
    assert_eq!(
        parse("*** Begin Patch\n*** Add File: file.txt\nbad\n"),
        Err(InvalidHunkError {
            message: bad_header_msg.to_string(),
            line_number: 3,
        })
    );
    assert_eq!(
        parse("*** Begin Patch\n*** Delete File: file.txt\nbad\n"),
        Err(InvalidHunkError {
            message: bad_header_msg.to_string(),
            line_number: 3,
        })
    );

    // An update hunk with no body is rejected, pointing at its header line...
    assert_eq!(
        parse("*** Begin Patch\n*** Update File: file.txt\n*** End Patch\n"),
        Err(InvalidHunkError {
            message: "Update file hunk for path 'file.txt' is empty".to_string(),
            line_number: 2,
        })
    );
    // ...even when a move destination precedes the next hunk header.
    assert_eq!(
        parse(
            "*** Begin Patch\n*** Update File: old.txt\n*** Move to: new.txt\n*** Delete File: other.txt\n",
        ),
        Err(InvalidHunkError {
            message: "Update file hunk for path 'old.txt' is empty".to_string(),
            line_number: 2,
        })
    );

    // A chunk marker followed by no +/-/context lines is an error, whether the
    // patch then ends, declares end-of-file, or opens another chunk.
    assert_eq!(
        parse("*** Begin Patch\n*** Update File: file.txt\n@@\n*** End Patch\n"),
        Err(InvalidHunkError {
            message: "Update hunk does not contain any lines".to_string(),
            line_number: 4,
        })
    );
    assert_eq!(
        parse("*** Begin Patch\n*** Update File: file.txt\n@@\n*** End of File\n"),
        Err(InvalidHunkError {
            message: "Update hunk does not contain any lines".to_string(),
            line_number: 4,
        })
    );
    assert_eq!(
        parse("*** Begin Patch\n*** Update File: file.txt\n@@\n@@\n"),
        Err(InvalidHunkError {
            message: "Unexpected line found in update hunk: '@@'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
                .to_string(),
            line_number: 4,
        })
    );

    // Lines inside a chunk must carry a diff prefix.
    assert_eq!(
        parse("*** Begin Patch\n*** Update File: file.txt\n@@\n-old\nbad\n"),
        Err(InvalidHunkError {
            message: "Expected update hunk to start with a @@ context marker, got: 'bad'"
                .to_string(),
            line_number: 5,
        })
    );
    assert_eq!(
        parse("*** Begin Patch\n*** Update File: file.txt\n@@\n*** Update File: other.txt\n"),
        Err(InvalidHunkError {
            message: "Unexpected line found in update hunk: '*** Update File: other.txt'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
                .to_string(),
            line_number: 4,
        })
    );
}
}

View File

@@ -97,6 +97,11 @@ impl AwsAuthContext {
&self.service
}
pub async fn preload_credentials(&self) -> Result<(), AwsAuthError> {
let _ = self.credentials_provider.provide_credentials().await?;
Ok(())
}
pub async fn sign(&self, request: AwsRequestToSign) -> Result<AwsSignedRequest, AwsAuthError> {
self.sign_at(request, SystemTime::now()).await
}
@@ -202,6 +207,14 @@ mod tests {
assert!(signing::header_value(&signed.headers, "x-amz-date").is_some());
}
#[tokio::test]
async fn preload_credentials_resolves_provider() {
test_context(/*session_token*/ None)
.preload_credentials()
.await
.expect("static credentials should resolve");
}
#[test]
fn credentials_provider_failures_are_retryable() {
assert!(

View File

@@ -673,6 +673,15 @@ impl ModelClient {
true
}
/// Resolves provider credentials during session startup when the provider requests it.
pub(crate) async fn prewarm_provider_auth(&self) -> Result<()> {
if !self.state.provider.prewarms_auth_on_startup() {
return Ok(());
}
self.state.provider.prewarm_auth().await
}
/// Returns auth + provider configuration resolved from the current session auth state.
///
/// This centralizes setup used by both prewarm and normal request paths so they stay in

View File

@@ -82,6 +82,7 @@ pub(crate) mod mentions {
mod sandbox_tags;
pub mod sandboxing;
mod session_prefix;
mod session_startup_auth_prewarm;
mod session_startup_prewarm;
mod shell_detect;
pub mod skills;

View File

@@ -980,6 +980,7 @@ impl Session {
anyhow::bail!("required MCP servers failed to initialize: {details}");
}
}
sess.schedule_startup_auth_prewarm().await;
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
.await;
let session_start_source = match &initial_history {

View File

@@ -1902,7 +1902,7 @@ async fn try_run_sampling_request(
ResponseEvent::Created => {}
ResponseEvent::OutputItemDone(item) => {
if let Some((_, mut consumer)) = active_tool_argument_diff_consumer.take()
&& let Some(event) = consumer.flush_on_complete()
&& let Ok(Some(event)) = consumer.finish()
{
sess.send_event(&turn_context, event).await;
}

View File

@@ -0,0 +1,16 @@
use std::sync::Arc;
use tracing::debug;
use crate::session::session::Session;
impl Session {
pub(crate) async fn schedule_startup_auth_prewarm(self: &Arc<Self>) {
let model_client = self.services.model_client.clone();
tokio::spawn(async move {
if let Err(err) = model_client.prewarm_provider_auth().await {
debug!("startup provider auth prewarm failed: {err:#}");
}
});
}
}

View File

@@ -33,10 +33,9 @@ use crate::tools::runtimes::apply_patch::ApplyPatchRequest;
use crate::tools::runtimes::apply_patch::ApplyPatchRuntime;
use crate::tools::sandboxing::ToolCtx;
use codex_apply_patch::ApplyPatchAction;
use codex_apply_patch::ApplyPatchArgs;
use codex_apply_patch::ApplyPatchFileChange;
use codex_apply_patch::Hunk;
use codex_apply_patch::parse_patch_streaming;
use codex_apply_patch::StreamingPatchParser;
use codex_exec_server::ExecutorFileSystem;
use codex_features::Feature;
use codex_protocol::models::AdditionalPermissionProfile;
@@ -56,8 +55,7 @@ pub struct ApplyPatchHandler;
#[derive(Default)]
struct ApplyPatchArgumentDiffConsumer {
input: String,
last_progress: Option<Vec<Hunk>>,
parser: StreamingPatchParser,
last_sent_at: Option<Instant>,
pending: Option<PatchApplyUpdatedEvent>,
}
@@ -77,26 +75,19 @@ impl ToolArgumentDiffConsumer for ApplyPatchArgumentDiffConsumer {
.map(EventMsg::PatchApplyUpdated)
}
fn flush_on_complete(&mut self) -> Option<EventMsg> {
self.flush_update_on_complete()
.map(EventMsg::PatchApplyUpdated)
fn finish(&mut self) -> Result<Option<EventMsg>, FunctionCallError> {
self.finish_update_on_complete()
.map(|event| event.map(EventMsg::PatchApplyUpdated))
}
}
impl ApplyPatchArgumentDiffConsumer {
fn push_delta(&mut self, call_id: String, delta: &str) -> Option<PatchApplyUpdatedEvent> {
self.input.push_str(delta);
let ApplyPatchArgs { hunks, .. } = parse_patch_streaming(&self.input).ok()?;
let hunks = self.parser.push_delta(delta).ok()?;
if hunks.is_empty() {
return None;
}
if self.last_progress.as_ref() == Some(&hunks) {
return None;
}
let changes = convert_apply_patch_hunks_to_protocol(&hunks);
self.last_progress = Some(hunks);
let event = PatchApplyUpdatedEvent { call_id, changes };
let now = Instant::now();
match self.last_sent_at {
@@ -114,12 +105,18 @@ impl ApplyPatchArgumentDiffConsumer {
}
}
fn flush_update_on_complete(&mut self) -> Option<PatchApplyUpdatedEvent> {
fn finish_update_on_complete(
&mut self,
) -> Result<Option<PatchApplyUpdatedEvent>, FunctionCallError> {
self.parser.finish().map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to parse apply_patch: {err}"))
})?;
let event = self.pending.take();
if event.is_some() {
self.last_sent_at = Some(Instant::now());
}
event
Ok(event)
}
}

View File

@@ -136,7 +136,7 @@ fn diff_consumer_streams_apply_patch_changes() {
HashMap::from([(
PathBuf::from("hello.txt"),
FileChange::Add {
content: "hello\n".to_string(),
content: String::new(),
},
)]),
)
@@ -147,8 +147,16 @@ fn diff_consumer_streams_apply_patch_changes() {
.push_delta("call-1".to_string(), "\n+world")
.is_none()
);
assert!(
consumer
.push_delta("call-1".to_string(), "\n*** End Patch")
.is_none()
);
let event = consumer.flush_update_on_complete().expect("progress event");
let event = consumer
.finish_update_on_complete()
.expect("finish parser")
.expect("progress event");
assert_eq!(
(event.call_id, event.changes),
(
@@ -175,7 +183,7 @@ fn diff_consumer_sends_next_update_after_buffer_interval() {
HashMap::from([(
PathBuf::from("hello.txt"),
FileChange::Add {
content: "hello\n".to_string(),
content: String::new(),
},
)])
);
@@ -190,7 +198,7 @@ fn diff_consumer_sends_next_update_after_buffer_interval() {
HashMap::from([(
PathBuf::from("hello.txt"),
FileChange::Add {
content: "hello\nworld\n".to_string(),
content: "hello\n".to_string(),
},
)])
);

View File

@@ -98,9 +98,9 @@ pub(crate) trait ToolArgumentDiffConsumer: Send {
fn consume_diff(&mut self, turn: &TurnContext, call_id: String, diff: &str)
-> Option<EventMsg>;
/// Flush any buffered event before the tool call completes.
fn flush_on_complete(&mut self) -> Option<EventMsg> {
None
/// Finish consuming argument diffs before the tool call completes.
fn finish(&mut self) -> Result<Option<EventMsg>, FunctionCallError> {
Ok(None)
}
}

View File

@@ -1027,7 +1027,7 @@ async fn apply_patch_custom_tool_streaming_emits_updated_changes() -> Result<()>
.changes
.get(&std::path::PathBuf::from("streamed.txt")),
Some(&codex_protocol::protocol::FileChange::Add {
content: "hello\n".to_string(),
content: String::new(),
})
);
assert_eq!(

View File

@@ -129,7 +129,7 @@ pub struct ShutdownHandle {
impl ShutdownHandle {
/// Signals the login loop to terminate.
pub fn shutdown(&self) {
self.shutdown_notify.notify_waiters();
self.shutdown_notify.notify_one();
}
}

View File

@@ -22,6 +22,7 @@ use super::mantle::region_from_config;
const AWS_BEARER_TOKEN_BEDROCK_ENV_VAR: &str = "AWS_BEARER_TOKEN_BEDROCK";
const LEGACY_SESSION_ID_HEADER: &str = "session_id";
#[derive(Clone, Debug)]
pub(super) enum BedrockAuthMethod {
EnvBearerToken { token: String, region: String },
AwsSdkAuth { context: AwsAuthContext },
@@ -42,17 +43,25 @@ pub(super) async fn resolve_auth_method(
Ok(BedrockAuthMethod::AwsSdkAuth { context })
}
pub(super) async fn resolve_provider_auth(
aws: &ModelProviderAwsAuthInfo,
) -> Result<SharedAuthProvider> {
match resolve_auth_method(aws).await? {
BedrockAuthMethod::EnvBearerToken { token, .. } => Ok(Arc::new(BearerAuthProvider {
pub(super) async fn prewarm_credentials(auth_method: &BedrockAuthMethod) -> Result<()> {
match auth_method {
BedrockAuthMethod::EnvBearerToken { .. } => Ok(()),
BedrockAuthMethod::AwsSdkAuth { context } => context
.preload_credentials()
.await
.map_err(aws_auth_error_to_codex_error),
}
}
pub(super) fn provider_auth_from_method(auth_method: BedrockAuthMethod) -> SharedAuthProvider {
match auth_method {
BedrockAuthMethod::EnvBearerToken { token, .. } => Arc::new(BearerAuthProvider {
token: Some(token),
account_id: None,
is_fedramp_account: false,
})),
}),
BedrockAuthMethod::AwsSdkAuth { context } => {
Ok(Arc::new(BedrockMantleSigV4AuthProvider::new(context)))
Arc::new(BedrockMantleSigV4AuthProvider::new(context))
}
}
}

View File

@@ -4,7 +4,6 @@ use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
use super::auth::BedrockAuthMethod;
use super::auth::resolve_auth_method;
const BEDROCK_MANTLE_SERVICE_NAME: &str = "bedrock-mantle";
const BEDROCK_MANTLE_SUPPORTED_REGIONS: [&str; 12] = [
@@ -48,16 +47,15 @@ pub(super) fn base_url(region: &str) -> Result<String> {
}
}
pub(super) async fn runtime_base_url(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
let region = resolve_region(aws).await?;
base_url(&region)
pub(super) fn region_from_auth_method(auth_method: &BedrockAuthMethod) -> String {
match auth_method {
BedrockAuthMethod::EnvBearerToken { region, .. } => region.clone(),
BedrockAuthMethod::AwsSdkAuth { context } => context.region().to_string(),
}
}
async fn resolve_region(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
match resolve_auth_method(aws).await? {
BedrockAuthMethod::EnvBearerToken { region, .. } => Ok(region),
BedrockAuthMethod::AwsSdkAuth { context } => Ok(context.region().to_string()),
}
pub(super) fn runtime_base_url_from_auth_method(auth_method: &BedrockAuthMethod) -> Result<String> {
base_url(&region_from_auth_method(auth_method))
}
#[cfg(test)]

View File

@@ -16,20 +16,26 @@ use codex_models_manager::manager::StaticModelsManager;
use codex_protocol::account::ProviderAccount;
use codex_protocol::error::Result;
use codex_protocol::openai_models::ModelsResponse;
use tokio::sync::OnceCell;
use crate::provider::ModelProvider;
use crate::provider::ProviderAccountResult;
use crate::provider::ProviderAccountState;
use crate::provider::ProviderCapabilities;
use auth::resolve_provider_auth;
use auth::BedrockAuthMethod;
use auth::prewarm_credentials;
use auth::provider_auth_from_method;
use auth::resolve_auth_method;
pub(crate) use catalog::static_model_catalog;
use mantle::runtime_base_url;
use mantle::runtime_base_url_from_auth_method;
/// Runtime provider for Amazon Bedrock's OpenAI-compatible Mantle endpoint.
#[derive(Clone, Debug)]
pub(crate) struct AmazonBedrockModelProvider {
pub(crate) info: ModelProviderInfo,
pub(crate) aws: ModelProviderAwsAuthInfo,
auth_method: Arc<OnceCell<BedrockAuthMethod>>,
credentials_prewarmed: Arc<OnceCell<()>>,
}
impl AmazonBedrockModelProvider {
@@ -44,8 +50,25 @@ impl AmazonBedrockModelProvider {
Self {
info: provider_info,
aws,
auth_method: Arc::new(OnceCell::new()),
credentials_prewarmed: Arc::new(OnceCell::new()),
}
}
async fn auth_method(&self) -> Result<BedrockAuthMethod> {
self.auth_method
.get_or_try_init(|| resolve_auth_method(&self.aws))
.await
.cloned()
}
async fn prewarm_bedrock_credentials(&self) -> Result<()> {
let auth_method = self.auth_method().await?;
self.credentials_prewarmed
.get_or_try_init(|| async move { prewarm_credentials(&auth_method).await })
.await?;
Ok(())
}
}
#[async_trait::async_trait]
@@ -70,6 +93,14 @@ impl ModelProvider for AmazonBedrockModelProvider {
None
}
fn prewarms_auth_on_startup(&self) -> bool {
true
}
async fn prewarm_auth(&self) -> Result<()> {
self.prewarm_bedrock_credentials().await
}
fn account_state(&self) -> ProviderAccountResult {
Ok(ProviderAccountState {
account: Some(ProviderAccount::AmazonBedrock),
@@ -79,16 +110,20 @@ impl ModelProvider for AmazonBedrockModelProvider {
async fn api_provider(&self) -> Result<Provider> {
let mut api_provider_info = self.info.clone();
api_provider_info.base_url = Some(runtime_base_url(&self.aws).await?);
api_provider_info.base_url = Some(runtime_base_url_from_auth_method(
&self.auth_method().await?,
)?);
api_provider_info.to_api_provider(/*auth_mode*/ None)
}
async fn runtime_base_url(&self) -> Result<Option<String>> {
Ok(Some(runtime_base_url(&self.aws).await?))
Ok(Some(runtime_base_url_from_auth_method(
&self.auth_method().await?,
)?))
}
async fn api_auth(&self) -> Result<SharedAuthProvider> {
resolve_provider_auth(&self.aws).await
Ok(provider_auth_from_method(self.auth_method().await?))
}
fn models_manager(

View File

@@ -96,6 +96,16 @@ pub trait ModelProvider: fmt::Debug + Send + Sync {
/// Returns the current provider-scoped auth value, if one is configured.
async fn auth(&self) -> Option<CodexAuth>;
/// Returns whether this provider should resolve request credentials during session startup.
fn prewarms_auth_on_startup(&self) -> bool {
false
}
/// Resolves provider credentials before the first model request when startup prewarm is enabled.
async fn prewarm_auth(&self) -> codex_protocol::error::Result<()> {
Ok(())
}
/// Returns the current app-visible account state for this provider.
fn account_state(&self) -> ProviderAccountResult;