mirror of
https://github.com/openai/codex.git
synced 2026-05-09 13:52:41 +00:00
Compare commits
4 Commits
xli-codex/
...
dev/cc/bro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09637409b0 | ||
|
|
79908b64a1 | ||
|
|
6014b6679f | ||
|
|
8426edf71e |
@@ -2,6 +2,7 @@ mod invocation;
|
||||
mod parser;
|
||||
mod seek_sequence;
|
||||
mod standalone_executable;
|
||||
mod streaming_parser;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
@@ -20,8 +21,8 @@ pub use parser::ParseError;
|
||||
use parser::ParseError::*;
|
||||
pub use parser::UpdateFileChunk;
|
||||
pub use parser::parse_patch;
|
||||
pub use parser::parse_patch_streaming;
|
||||
use similar::TextDiff;
|
||||
pub use streaming_parser::StreamingPatchParser;
|
||||
use thiserror::Error;
|
||||
|
||||
pub use invocation::maybe_parse_apply_patch_verified;
|
||||
|
||||
@@ -31,15 +31,15 @@ use std::path::PathBuf;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
const BEGIN_PATCH_MARKER: &str = "*** Begin Patch";
|
||||
const END_PATCH_MARKER: &str = "*** End Patch";
|
||||
const ADD_FILE_MARKER: &str = "*** Add File: ";
|
||||
const DELETE_FILE_MARKER: &str = "*** Delete File: ";
|
||||
const UPDATE_FILE_MARKER: &str = "*** Update File: ";
|
||||
const MOVE_TO_MARKER: &str = "*** Move to: ";
|
||||
const EOF_MARKER: &str = "*** End of File";
|
||||
const CHANGE_CONTEXT_MARKER: &str = "@@ ";
|
||||
const EMPTY_CHANGE_CONTEXT_MARKER: &str = "@@";
|
||||
pub(crate) const BEGIN_PATCH_MARKER: &str = "*** Begin Patch";
|
||||
pub(crate) const END_PATCH_MARKER: &str = "*** End Patch";
|
||||
pub(crate) const ADD_FILE_MARKER: &str = "*** Add File: ";
|
||||
pub(crate) const DELETE_FILE_MARKER: &str = "*** Delete File: ";
|
||||
pub(crate) const UPDATE_FILE_MARKER: &str = "*** Update File: ";
|
||||
pub(crate) const MOVE_TO_MARKER: &str = "*** Move to: ";
|
||||
pub(crate) const EOF_MARKER: &str = "*** End of File";
|
||||
pub(crate) const CHANGE_CONTEXT_MARKER: &str = "@@ ";
|
||||
pub(crate) const EMPTY_CHANGE_CONTEXT_MARKER: &str = "@@";
|
||||
|
||||
/// Currently, the only OpenAI model that knowingly requires lenient parsing is
|
||||
/// gpt-4.1. While we could try to require everyone to pass in a strictness
|
||||
@@ -132,14 +132,6 @@ pub fn parse_patch(patch: &str) -> Result<ApplyPatchArgs, ParseError> {
|
||||
parse_patch_text(patch, mode)
|
||||
}
|
||||
|
||||
/// Parses streamed patch text that may not have reached `*** End Patch` yet.
|
||||
///
|
||||
/// This entry point is for progress reporting only; callers must not use its
|
||||
/// output to apply a patch.
|
||||
pub fn parse_patch_streaming(patch: &str) -> Result<ApplyPatchArgs, ParseError> {
|
||||
parse_patch_text(patch, ParseMode::Streaming)
|
||||
}
|
||||
|
||||
enum ParseMode {
|
||||
/// Parse the patch text argument as is.
|
||||
Strict,
|
||||
@@ -177,12 +169,6 @@ enum ParseMode {
|
||||
/// `<<'EOF'` and ends with `EOF\n`. If so, we strip off these markers,
|
||||
/// trim() the result, and treat what is left as the patch text.
|
||||
Lenient,
|
||||
|
||||
/// Parse partial patch text for progress reporting while the model is
|
||||
/// still streaming tool input. This mode requires a begin marker but does
|
||||
/// not require an end marker, and its output must not be used to apply a
|
||||
/// patch.
|
||||
Streaming,
|
||||
}
|
||||
|
||||
fn parse_patch_text(patch: &str, mode: ParseMode) -> Result<ApplyPatchArgs, ParseError> {
|
||||
@@ -190,15 +176,13 @@ fn parse_patch_text(patch: &str, mode: ParseMode) -> Result<ApplyPatchArgs, Pars
|
||||
let (patch_lines, hunk_lines) = match mode {
|
||||
ParseMode::Strict => check_patch_boundaries_strict(&lines)?,
|
||||
ParseMode::Lenient => check_patch_boundaries_lenient(&lines)?,
|
||||
ParseMode::Streaming => check_patch_boundaries_streaming(&lines)?,
|
||||
};
|
||||
|
||||
let mut hunks: Vec<Hunk> = Vec::new();
|
||||
let mut remaining_lines = hunk_lines;
|
||||
let mut line_number = 2;
|
||||
let allow_incomplete = matches!(mode, ParseMode::Streaming);
|
||||
while !remaining_lines.is_empty() {
|
||||
let (hunk, hunk_lines) = parse_one_hunk(remaining_lines, line_number, allow_incomplete)?;
|
||||
let (hunk, hunk_lines) = parse_one_hunk(remaining_lines, line_number)?;
|
||||
hunks.push(hunk);
|
||||
line_number += hunk_lines;
|
||||
remaining_lines = &remaining_lines[hunk_lines..]
|
||||
@@ -211,25 +195,6 @@ fn parse_patch_text(patch: &str, mode: ParseMode) -> Result<ApplyPatchArgs, Pars
|
||||
})
|
||||
}
|
||||
|
||||
fn check_patch_boundaries_streaming<'a>(
|
||||
original_lines: &'a [&'a str],
|
||||
) -> Result<(&'a [&'a str], &'a [&'a str]), ParseError> {
|
||||
match original_lines {
|
||||
[first, ..] if first.trim() == BEGIN_PATCH_MARKER => {
|
||||
let body_lines = if original_lines
|
||||
.last()
|
||||
.is_some_and(|line| line.trim() == END_PATCH_MARKER)
|
||||
{
|
||||
&original_lines[1..original_lines.len() - 1]
|
||||
} else {
|
||||
&original_lines[1..]
|
||||
};
|
||||
Ok((original_lines, body_lines))
|
||||
}
|
||||
_ => check_patch_boundaries_strict(original_lines),
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks the start and end lines of the patch text for `apply_patch`,
|
||||
/// returning an error if they do not match the expected markers.
|
||||
fn check_patch_boundaries_strict<'a>(
|
||||
@@ -297,15 +262,9 @@ fn check_start_and_end_lines_strict(
|
||||
|
||||
/// Attempts to parse a single hunk from the start of lines.
|
||||
/// Returns the parsed hunk and the number of lines parsed (or a ParseError).
|
||||
fn parse_one_hunk(
|
||||
lines: &[&str],
|
||||
line_number: usize,
|
||||
allow_incomplete: bool,
|
||||
) -> Result<(Hunk, usize), ParseError> {
|
||||
// Be tolerant of case mismatches and extra padding around marker strings.
|
||||
fn parse_one_hunk(lines: &[&str], line_number: usize) -> Result<(Hunk, usize), ParseError> {
|
||||
let first_line = lines[0].trim();
|
||||
if let Some(path) = first_line.strip_prefix(ADD_FILE_MARKER) {
|
||||
// Add File
|
||||
let mut contents = String::new();
|
||||
let mut parsed_lines = 1;
|
||||
for add_line in &lines[1..] {
|
||||
@@ -325,7 +284,6 @@ fn parse_one_hunk(
|
||||
parsed_lines,
|
||||
));
|
||||
} else if let Some(path) = first_line.strip_prefix(DELETE_FILE_MARKER) {
|
||||
// Delete File
|
||||
return Ok((
|
||||
DeleteFile {
|
||||
path: PathBuf::from(path),
|
||||
@@ -333,11 +291,8 @@ fn parse_one_hunk(
|
||||
1,
|
||||
));
|
||||
} else if let Some(path) = first_line.strip_prefix(UPDATE_FILE_MARKER) {
|
||||
// Update File
|
||||
let mut remaining_lines = &lines[1..];
|
||||
let mut parsed_lines = 1;
|
||||
|
||||
// Optional: move file line
|
||||
let move_path = remaining_lines
|
||||
.first()
|
||||
.and_then(|x| x.strip_prefix(MOVE_TO_MARKER));
|
||||
@@ -348,9 +303,7 @@ fn parse_one_hunk(
|
||||
}
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
// NOTE: we need to know to stop once we reach the next special marker header.
|
||||
while !remaining_lines.is_empty() {
|
||||
// Skip over any completely blank lines that may separate chunks.
|
||||
if remaining_lines[0].trim().is_empty() {
|
||||
parsed_lines += 1;
|
||||
remaining_lines = &remaining_lines[1..];
|
||||
@@ -361,22 +314,11 @@ fn parse_one_hunk(
|
||||
break;
|
||||
}
|
||||
|
||||
if allow_incomplete && remaining_lines[0] == "@" {
|
||||
break;
|
||||
}
|
||||
|
||||
let parsed_chunk = parse_update_file_chunk(
|
||||
let (chunk, chunk_lines) = parse_update_file_chunk(
|
||||
remaining_lines,
|
||||
line_number + parsed_lines,
|
||||
chunks.is_empty(),
|
||||
);
|
||||
let (chunk, chunk_lines) = match parsed_chunk {
|
||||
Ok(parsed) => parsed,
|
||||
Err(InvalidHunkError { .. }) if allow_incomplete && !chunks.is_empty() => {
|
||||
break;
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
)?;
|
||||
chunks.push(chunk);
|
||||
parsed_lines += chunk_lines;
|
||||
remaining_lines = &remaining_lines[chunk_lines..]
|
||||
@@ -384,7 +326,10 @@ fn parse_one_hunk(
|
||||
|
||||
if chunks.is_empty() {
|
||||
return Err(InvalidHunkError {
|
||||
message: format!("Update file hunk for path '{path}' is empty"),
|
||||
message: format!(
|
||||
"Update file hunk for path '{}' is empty",
|
||||
Path::new(path).display()
|
||||
),
|
||||
line_number,
|
||||
});
|
||||
}
|
||||
@@ -418,8 +363,6 @@ fn parse_update_file_chunk(
|
||||
line_number,
|
||||
});
|
||||
}
|
||||
// If we see an explicit context marker @@ or @@ <context>, consume it; otherwise, optionally
|
||||
// allow treating the chunk as starting directly with diff lines.
|
||||
let (change_context, start_index) = if lines[0] == EMPTY_CHANGE_CONTEXT_MARKER {
|
||||
(None, 1)
|
||||
} else if let Some(context) = lines[0].strip_prefix(CHANGE_CONTEXT_MARKER) {
|
||||
@@ -501,162 +444,113 @@ fn parse_update_file_chunk(
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_patch_streaming() {
|
||||
fn test_parse_one_hunk() {
|
||||
assert_eq!(
|
||||
parse_patch_streaming("*** Begin Patch\n*** Add File: src/hello.txt\n+hello\n+wor"),
|
||||
Ok(ApplyPatchArgs {
|
||||
hunks: vec![AddFile {
|
||||
path: PathBuf::from("src/hello.txt"),
|
||||
contents: "hello\nwor\n".to_string(),
|
||||
}],
|
||||
patch: "*** Begin Patch\n*** Add File: src/hello.txt\n+hello\n+wor".to_string(),
|
||||
workdir: None,
|
||||
})
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
parse_patch_streaming(
|
||||
"*** Begin Patch\n*** Update File: src/old.rs\n*** Move to: src/new.rs\n@@\n-old\n+new",
|
||||
),
|
||||
Ok(ApplyPatchArgs {
|
||||
hunks: vec![UpdateFile {
|
||||
path: PathBuf::from("src/old.rs"),
|
||||
move_path: Some(PathBuf::from("src/new.rs")),
|
||||
chunks: vec![UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: vec!["old".to_string()],
|
||||
new_lines: vec!["new".to_string()],
|
||||
is_end_of_file: false,
|
||||
}],
|
||||
}],
|
||||
patch: "*** Begin Patch\n*** Update File: src/old.rs\n*** Move to: src/new.rs\n@@\n-old\n+new".to_string(),
|
||||
workdir: None,
|
||||
})
|
||||
);
|
||||
|
||||
assert!(
|
||||
parse_patch_text(
|
||||
"*** Begin Patch\n*** Delete File: gone.txt",
|
||||
ParseMode::Streaming
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
assert!(
|
||||
parse_patch_text(
|
||||
"*** Begin Patch\n*** Delete File: gone.txt",
|
||||
ParseMode::Strict
|
||||
)
|
||||
.is_err()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
parse_patch_streaming(
|
||||
"*** Begin Patch\n*** Add File: src/one.txt\n+one\n*** Delete File: src/two.txt\n",
|
||||
),
|
||||
Ok(ApplyPatchArgs {
|
||||
hunks: vec![
|
||||
AddFile {
|
||||
path: PathBuf::from("src/one.txt"),
|
||||
contents: "one\n".to_string(),
|
||||
},
|
||||
DeleteFile {
|
||||
path: PathBuf::from("src/two.txt"),
|
||||
},
|
||||
],
|
||||
patch: "*** Begin Patch\n*** Add File: src/one.txt\n+one\n*** Delete File: src/two.txt"
|
||||
.to_string(),
|
||||
workdir: None,
|
||||
parse_one_hunk(&["bad"], /*line_number*/ 234),
|
||||
Err(InvalidHunkError {
|
||||
message: "'bad' is not a valid hunk header. \
|
||||
Valid hunk headers: '*** Add File: {path}', '*** Delete File: {path}', '*** Update File: {path}'".to_string(),
|
||||
line_number: 234
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_patch_streaming_large_patch_by_character() {
|
||||
let patch = "\
|
||||
*** Begin Patch
|
||||
*** Add File: docs/release-notes.md
|
||||
+# Release notes
|
||||
+
|
||||
+## CLI
|
||||
+- Surface apply_patch progress while arguments stream.
|
||||
+- Keep final patch application gated on the completed tool call.
|
||||
+- Include file summaries in the progress event payload.
|
||||
*** Update File: src/config.rs
|
||||
@@ impl Config
|
||||
- pub apply_patch_progress: bool,
|
||||
+ pub stream_apply_patch_progress: bool,
|
||||
pub include_diagnostics: bool,
|
||||
@@ fn default_progress_interval()
|
||||
- Duration::from_millis(500)
|
||||
+ Duration::from_millis(250)
|
||||
*** Delete File: src/legacy_patch_progress.rs
|
||||
*** Update File: crates/cli/src/main.rs
|
||||
*** Move to: crates/cli/src/bin/codex.rs
|
||||
@@ fn run()
|
||||
- let args = Args::parse();
|
||||
- dispatch(args)
|
||||
+ let cli = Cli::parse();
|
||||
+ dispatch(cli)
|
||||
*** Add File: tests/fixtures/apply_patch_progress.json
|
||||
+{
|
||||
+ \"type\": \"apply_patch_progress\",
|
||||
+ \"hunks\": [
|
||||
+ { \"operation\": \"add\", \"path\": \"docs/release-notes.md\" },
|
||||
+ { \"operation\": \"update\", \"path\": \"src/config.rs\" }
|
||||
+ ]
|
||||
+}
|
||||
*** Update File: README.md
|
||||
@@ Development workflow
|
||||
Build the Rust workspace before opening a pull request.
|
||||
+When touching streamed tool calls, include parser coverage for partial input.
|
||||
+Prefer tests that exercise the exact event payload shape.
|
||||
*** Delete File: docs/old-apply-patch-progress.md
|
||||
*** End Patch";
|
||||
|
||||
let mut max_hunk_count = 0;
|
||||
let mut saw_hunk_counts = Vec::new();
|
||||
for i in 1..=patch.len() {
|
||||
let partial = &patch[..i];
|
||||
if let Ok(parsed) = parse_patch_streaming(partial) {
|
||||
let hunk_count = parsed.hunks.len();
|
||||
assert!(
|
||||
hunk_count >= max_hunk_count,
|
||||
"hunk count should never decrease while streaming: {hunk_count} < {max_hunk_count} for {partial:?}",
|
||||
);
|
||||
if hunk_count > max_hunk_count {
|
||||
saw_hunk_counts.push(hunk_count);
|
||||
max_hunk_count = hunk_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(saw_hunk_counts, vec![1, 2, 3, 4, 5, 6, 7]);
|
||||
let parsed = parse_patch_streaming(patch).unwrap();
|
||||
assert_eq!(parsed.hunks.len(), 7);
|
||||
fn test_update_file_chunk() {
|
||||
assert_eq!(
|
||||
parsed
|
||||
.hunks
|
||||
.iter()
|
||||
.map(|hunk| match hunk {
|
||||
AddFile { .. } => "add",
|
||||
DeleteFile { .. } => "delete",
|
||||
UpdateFile {
|
||||
move_path: Some(_), ..
|
||||
} => "move-update",
|
||||
UpdateFile {
|
||||
move_path: None, ..
|
||||
} => "update",
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
"add",
|
||||
"update",
|
||||
"delete",
|
||||
"move-update",
|
||||
"add",
|
||||
"update",
|
||||
"delete"
|
||||
]
|
||||
parse_update_file_chunk(
|
||||
&["bad"],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false,
|
||||
),
|
||||
Err(InvalidHunkError {
|
||||
message: "Expected update hunk to start with a @@ context marker, got: 'bad'"
|
||||
.to_string(),
|
||||
line_number: 123
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&["@@"],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false,
|
||||
),
|
||||
Err(InvalidHunkError {
|
||||
message: "Update hunk does not contain any lines".to_string(),
|
||||
line_number: 124
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&["@@", "bad"],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false,
|
||||
),
|
||||
Err(InvalidHunkError {
|
||||
message: "Unexpected line found in update hunk: 'bad'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)".to_string(),
|
||||
line_number: 124
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&["@@", "*** End of File"],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false,
|
||||
),
|
||||
Err(InvalidHunkError {
|
||||
message: "Update hunk does not contain any lines".to_string(),
|
||||
line_number: 124
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&[
|
||||
"@@ change_context",
|
||||
"",
|
||||
" context",
|
||||
"-remove",
|
||||
"+add",
|
||||
" context2",
|
||||
"*** End Patch",
|
||||
],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false,
|
||||
),
|
||||
Ok((
|
||||
UpdateFileChunk {
|
||||
change_context: Some("change_context".to_string()),
|
||||
old_lines: vec![
|
||||
String::new(),
|
||||
"context".to_string(),
|
||||
"remove".to_string(),
|
||||
"context2".to_string(),
|
||||
],
|
||||
new_lines: vec![
|
||||
String::new(),
|
||||
"context".to_string(),
|
||||
"add".to_string(),
|
||||
"context2".to_string(),
|
||||
],
|
||||
is_end_of_file: false,
|
||||
},
|
||||
6,
|
||||
))
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&["@@", "+line", "*** End of File"],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false,
|
||||
),
|
||||
Ok((
|
||||
UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: Vec::new(),
|
||||
new_lines: vec!["line".to_string()],
|
||||
is_end_of_file: true,
|
||||
},
|
||||
3,
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
@@ -997,112 +891,3 @@ fn test_parse_patch_lenient() {
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_one_hunk() {
|
||||
assert_eq!(
|
||||
parse_one_hunk(&["bad"], /*line_number*/ 234, /*allow_incomplete*/ false),
|
||||
Err(InvalidHunkError {
|
||||
message: "'bad' is not a valid hunk header. \
|
||||
Valid hunk headers: '*** Add File: {path}', '*** Delete File: {path}', '*** Update File: {path}'".to_string(),
|
||||
line_number: 234
|
||||
})
|
||||
);
|
||||
// Other edge cases are already covered by tests above/below.
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_file_chunk() {
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&["bad"],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false
|
||||
),
|
||||
Err(InvalidHunkError {
|
||||
message: "Expected update hunk to start with a @@ context marker, got: 'bad'"
|
||||
.to_string(),
|
||||
line_number: 123
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&["@@"],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false
|
||||
),
|
||||
Err(InvalidHunkError {
|
||||
message: "Update hunk does not contain any lines".to_string(),
|
||||
line_number: 124
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(&["@@", "bad"], /*line_number*/ 123, /*allow_missing_context*/ false),
|
||||
Err(InvalidHunkError {
|
||||
message: "Unexpected line found in update hunk: 'bad'. \
|
||||
Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)".to_string(),
|
||||
line_number: 124
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&["@@", "*** End of File"],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false
|
||||
),
|
||||
Err(InvalidHunkError {
|
||||
message: "Update hunk does not contain any lines".to_string(),
|
||||
line_number: 124
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&[
|
||||
"@@ change_context",
|
||||
"",
|
||||
" context",
|
||||
"-remove",
|
||||
"+add",
|
||||
" context2",
|
||||
"*** End Patch",
|
||||
],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false
|
||||
),
|
||||
Ok((
|
||||
(UpdateFileChunk {
|
||||
change_context: Some("change_context".to_string()),
|
||||
old_lines: vec![
|
||||
"".to_string(),
|
||||
"context".to_string(),
|
||||
"remove".to_string(),
|
||||
"context2".to_string()
|
||||
],
|
||||
new_lines: vec![
|
||||
"".to_string(),
|
||||
"context".to_string(),
|
||||
"add".to_string(),
|
||||
"context2".to_string()
|
||||
],
|
||||
is_end_of_file: false
|
||||
}),
|
||||
6
|
||||
))
|
||||
);
|
||||
assert_eq!(
|
||||
parse_update_file_chunk(
|
||||
&["@@", "+line", "*** End of File"],
|
||||
/*line_number*/ 123,
|
||||
/*allow_missing_context*/ false
|
||||
),
|
||||
Ok((
|
||||
(UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: vec![],
|
||||
new_lines: vec!["line".to_string()],
|
||||
is_end_of_file: true
|
||||
}),
|
||||
3
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
813
codex-rs/apply-patch/src/streaming_parser.rs
Normal file
813
codex-rs/apply-patch/src/streaming_parser.rs
Normal file
@@ -0,0 +1,813 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::parser::ADD_FILE_MARKER;
|
||||
use crate::parser::BEGIN_PATCH_MARKER;
|
||||
use crate::parser::CHANGE_CONTEXT_MARKER;
|
||||
use crate::parser::DELETE_FILE_MARKER;
|
||||
use crate::parser::EMPTY_CHANGE_CONTEXT_MARKER;
|
||||
use crate::parser::END_PATCH_MARKER;
|
||||
use crate::parser::EOF_MARKER;
|
||||
use crate::parser::Hunk;
|
||||
use crate::parser::MOVE_TO_MARKER;
|
||||
use crate::parser::ParseError;
|
||||
use crate::parser::UPDATE_FILE_MARKER;
|
||||
use crate::parser::UpdateFileChunk;
|
||||
|
||||
use Hunk::*;
|
||||
use ParseError::*;
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct StreamingPatchParser {
|
||||
line_buffer: String,
|
||||
state: StreamingParserState,
|
||||
line_number: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
struct StreamingParserState {
|
||||
mode: StreamingParserMode,
|
||||
hunks: Vec<Hunk>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
enum StreamingParserMode {
|
||||
#[default]
|
||||
NotStarted,
|
||||
StartedPatch,
|
||||
AddFile,
|
||||
DeleteFile,
|
||||
UpdateFile {
|
||||
hunk_line_number: usize,
|
||||
},
|
||||
EndedPatch,
|
||||
}
|
||||
|
||||
impl StreamingPatchParser {
|
||||
fn ensure_update_hunk_is_not_empty(&self, line: &str) -> Result<(), ParseError> {
|
||||
if let Some(UpdateFile { path, chunks, .. }) = self.state.hunks.last() {
|
||||
if chunks.is_empty()
|
||||
&& let StreamingParserMode::UpdateFile { hunk_line_number } = self.state.mode
|
||||
{
|
||||
return Err(InvalidHunkError {
|
||||
message: format!("Update file hunk for path '{}' is empty", path.display()),
|
||||
line_number: hunk_line_number,
|
||||
});
|
||||
}
|
||||
if chunks
|
||||
.last()
|
||||
.is_some_and(|chunk| chunk.old_lines.is_empty() && chunk.new_lines.is_empty())
|
||||
{
|
||||
if line == END_PATCH_MARKER {
|
||||
return Err(InvalidHunkError {
|
||||
message: "Update hunk does not contain any lines".to_string(),
|
||||
line_number: self.line_number,
|
||||
});
|
||||
}
|
||||
return Err(InvalidHunkError {
|
||||
message: format!(
|
||||
"Unexpected line found in update hunk: '{line}'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
|
||||
),
|
||||
line_number: self.line_number,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_hunk_headers_and_end_patch(&mut self, trimmed: &str) -> Result<bool, ParseError> {
|
||||
if trimmed == END_PATCH_MARKER {
|
||||
self.ensure_update_hunk_is_not_empty(trimmed)?;
|
||||
self.state.mode = StreamingParserMode::EndedPatch;
|
||||
return Ok(true);
|
||||
}
|
||||
if let Some(path) = trimmed.strip_prefix(ADD_FILE_MARKER) {
|
||||
self.ensure_update_hunk_is_not_empty(trimmed)?;
|
||||
self.state.hunks.push(AddFile {
|
||||
path: PathBuf::from(path),
|
||||
contents: String::new(),
|
||||
});
|
||||
self.state.mode = StreamingParserMode::AddFile;
|
||||
return Ok(true);
|
||||
}
|
||||
if let Some(path) = trimmed.strip_prefix(DELETE_FILE_MARKER) {
|
||||
self.ensure_update_hunk_is_not_empty(trimmed)?;
|
||||
self.state.hunks.push(DeleteFile {
|
||||
path: PathBuf::from(path),
|
||||
});
|
||||
self.state.mode = StreamingParserMode::DeleteFile;
|
||||
return Ok(true);
|
||||
}
|
||||
if let Some(path) = trimmed.strip_prefix(UPDATE_FILE_MARKER) {
|
||||
self.ensure_update_hunk_is_not_empty(trimmed)?;
|
||||
self.state.hunks.push(UpdateFile {
|
||||
path: PathBuf::from(path),
|
||||
move_path: None,
|
||||
chunks: Vec::new(),
|
||||
});
|
||||
self.state.mode = StreamingParserMode::UpdateFile {
|
||||
hunk_line_number: self.line_number,
|
||||
};
|
||||
return Ok(true);
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
pub fn push_delta(&mut self, delta: &str) -> Result<Vec<Hunk>, ParseError> {
|
||||
for ch in delta.chars() {
|
||||
if ch == '\n' {
|
||||
let mut line = std::mem::take(&mut self.line_buffer);
|
||||
line.truncate(line.strip_suffix('\r').map_or(line.len(), str::len));
|
||||
self.line_number += 1;
|
||||
self.process_line(&line)?;
|
||||
} else {
|
||||
self.line_buffer.push(ch);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self.state.hunks.clone())
|
||||
}
|
||||
|
||||
pub fn finish(&mut self) -> Result<Vec<Hunk>, ParseError> {
|
||||
if !self.line_buffer.is_empty() {
|
||||
let line = std::mem::take(&mut self.line_buffer);
|
||||
self.line_number += 1;
|
||||
if line.trim() == END_PATCH_MARKER {
|
||||
self.ensure_update_hunk_is_not_empty(line.trim())?;
|
||||
self.state.mode = StreamingParserMode::EndedPatch;
|
||||
} else {
|
||||
self.process_line(&line)?;
|
||||
}
|
||||
}
|
||||
|
||||
if !matches!(self.state.mode, StreamingParserMode::EndedPatch) {
|
||||
return Err(InvalidPatchError(
|
||||
"The last line of the patch must be '*** End Patch'".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(self.state.hunks.clone())
|
||||
}
|
||||
|
||||
fn process_line(&mut self, line: &str) -> Result<(), ParseError> {
|
||||
let trimmed = line.trim();
|
||||
match self.state.mode.clone() {
|
||||
StreamingParserMode::NotStarted => {
|
||||
if trimmed == BEGIN_PATCH_MARKER {
|
||||
self.state.mode = StreamingParserMode::StartedPatch;
|
||||
return Ok(());
|
||||
}
|
||||
Err(InvalidPatchError(
|
||||
"The first line of the patch must be '*** Begin Patch'".to_string(),
|
||||
))
|
||||
}
|
||||
StreamingParserMode::StartedPatch => {
|
||||
if self.handle_hunk_headers_and_end_patch(trimmed)? {
|
||||
return Ok(());
|
||||
}
|
||||
Err(InvalidHunkError {
|
||||
message: format!(
|
||||
"'{trimmed}' is not a valid hunk header. Valid hunk headers: '*** Add File: {{path}}', '*** Delete File: {{path}}', '*** Update File: {{path}}'"
|
||||
),
|
||||
line_number: self.line_number,
|
||||
})
|
||||
}
|
||||
StreamingParserMode::AddFile => {
|
||||
if self.handle_hunk_headers_and_end_patch(trimmed)? {
|
||||
return Ok(());
|
||||
}
|
||||
if let Some(line_to_add) = line.strip_prefix('+')
|
||||
&& let Some(AddFile { contents, .. }) = self.state.hunks.last_mut()
|
||||
{
|
||||
contents.push_str(line_to_add);
|
||||
contents.push('\n');
|
||||
return Ok(());
|
||||
}
|
||||
Err(InvalidHunkError {
|
||||
message: format!(
|
||||
"'{trimmed}' is not a valid hunk header. Valid hunk headers: '*** Add File: {{path}}', '*** Delete File: {{path}}', '*** Update File: {{path}}'"
|
||||
),
|
||||
line_number: self.line_number,
|
||||
})
|
||||
}
|
||||
StreamingParserMode::DeleteFile => {
|
||||
if self.handle_hunk_headers_and_end_patch(trimmed)? {
|
||||
return Ok(());
|
||||
}
|
||||
Err(InvalidHunkError {
|
||||
message: format!(
|
||||
"'{trimmed}' is not a valid hunk header. Valid hunk headers: '*** Add File: {{path}}', '*** Delete File: {{path}}', '*** Update File: {{path}}'"
|
||||
),
|
||||
line_number: self.line_number,
|
||||
})
|
||||
}
|
||||
StreamingParserMode::UpdateFile { hunk_line_number } => {
|
||||
let update_line = line.trim_end();
|
||||
if self.handle_hunk_headers_and_end_patch(update_line)? {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(UpdateFile {
|
||||
move_path, chunks, ..
|
||||
}) = self.state.hunks.last_mut()
|
||||
{
|
||||
if chunks.is_empty()
|
||||
&& move_path.is_none()
|
||||
&& let Some(move_to_path) = update_line.strip_prefix(MOVE_TO_MARKER)
|
||||
{
|
||||
*move_path = Some(PathBuf::from(move_to_path));
|
||||
self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if (update_line == EMPTY_CHANGE_CONTEXT_MARKER
|
||||
|| update_line.starts_with(CHANGE_CONTEXT_MARKER))
|
||||
&& chunks.last().is_some_and(|chunk| {
|
||||
chunk.old_lines.is_empty() && chunk.new_lines.is_empty()
|
||||
})
|
||||
{
|
||||
return Err(InvalidHunkError {
|
||||
message: format!(
|
||||
"Unexpected line found in update hunk: '{line}'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
|
||||
),
|
||||
line_number: self.line_number,
|
||||
});
|
||||
}
|
||||
|
||||
if update_line == EMPTY_CHANGE_CONTEXT_MARKER {
|
||||
chunks.push(UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: Vec::new(),
|
||||
new_lines: Vec::new(),
|
||||
is_end_of_file: false,
|
||||
});
|
||||
self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(change_context) = update_line.strip_prefix(CHANGE_CONTEXT_MARKER) {
|
||||
chunks.push(UpdateFileChunk {
|
||||
change_context: Some(change_context.to_string()),
|
||||
old_lines: Vec::new(),
|
||||
new_lines: Vec::new(),
|
||||
is_end_of_file: false,
|
||||
});
|
||||
self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if update_line == EOF_MARKER {
|
||||
if chunks.last().is_some_and(|chunk| {
|
||||
chunk.old_lines.is_empty() && chunk.new_lines.is_empty()
|
||||
}) {
|
||||
return Err(InvalidHunkError {
|
||||
message: "Update hunk does not contain any lines".to_string(),
|
||||
line_number: self.line_number,
|
||||
});
|
||||
}
|
||||
if let Some(chunk) = chunks.last_mut() {
|
||||
chunk.is_end_of_file = true;
|
||||
}
|
||||
self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if line.is_empty() {
|
||||
if chunks.is_empty() {
|
||||
chunks.push(UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: Vec::new(),
|
||||
new_lines: Vec::new(),
|
||||
is_end_of_file: false,
|
||||
});
|
||||
}
|
||||
if let Some(chunk) = chunks.last_mut() {
|
||||
chunk.old_lines.push(String::new());
|
||||
chunk.new_lines.push(String::new());
|
||||
}
|
||||
self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(line_to_add) = line.strip_prefix(' ') {
|
||||
if chunks.is_empty() {
|
||||
chunks.push(UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: Vec::new(),
|
||||
new_lines: Vec::new(),
|
||||
is_end_of_file: false,
|
||||
});
|
||||
}
|
||||
if let Some(chunk) = chunks.last_mut() {
|
||||
chunk.old_lines.push(line_to_add.to_string());
|
||||
chunk.new_lines.push(line_to_add.to_string());
|
||||
}
|
||||
self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(line_to_add) = line.strip_prefix('+') {
|
||||
if chunks.is_empty() {
|
||||
chunks.push(UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: Vec::new(),
|
||||
new_lines: Vec::new(),
|
||||
is_end_of_file: false,
|
||||
});
|
||||
}
|
||||
if let Some(chunk) = chunks.last_mut() {
|
||||
chunk.new_lines.push(line_to_add.to_string());
|
||||
}
|
||||
self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(line_to_remove) = line.strip_prefix('-') {
|
||||
if chunks.is_empty() {
|
||||
chunks.push(UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: Vec::new(),
|
||||
new_lines: Vec::new(),
|
||||
is_end_of_file: false,
|
||||
});
|
||||
}
|
||||
if let Some(chunk) = chunks.last_mut() {
|
||||
chunk.old_lines.push(line_to_remove.to_string());
|
||||
}
|
||||
self.state.mode = StreamingParserMode::UpdateFile { hunk_line_number };
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if chunks.last().is_some_and(|chunk| {
|
||||
!chunk.old_lines.is_empty() || !chunk.new_lines.is_empty()
|
||||
}) {
|
||||
return Err(InvalidHunkError {
|
||||
message: format!(
|
||||
"Expected update hunk to start with a @@ context marker, got: '{line}'"
|
||||
),
|
||||
line_number: self.line_number,
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(InvalidHunkError {
|
||||
message: format!(
|
||||
"Unexpected line found in update hunk: '{line}'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
|
||||
),
|
||||
line_number: self.line_number,
|
||||
})
|
||||
}
|
||||
StreamingParserMode::EndedPatch => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_streaming_patch_parser_streams_complete_lines_before_end_patch() {
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Add File: src/hello.txt\n+hello\n+wor"),
|
||||
Ok(vec![AddFile {
|
||||
path: PathBuf::from("src/hello.txt"),
|
||||
contents: "hello\n".to_string(),
|
||||
}])
|
||||
);
|
||||
assert_eq!(
|
||||
parser.push_delta("ld\n"),
|
||||
Ok(vec![AddFile {
|
||||
path: PathBuf::from("src/hello.txt"),
|
||||
contents: "hello\nworld\n".to_string(),
|
||||
}])
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta(
|
||||
"*** Begin Patch\n*** Update File: src/old.rs\n*** Move to: src/new.rs\n@@\n-old\n+new\n",
|
||||
),
|
||||
Ok(vec![UpdateFile {
|
||||
path: PathBuf::from("src/old.rs"),
|
||||
move_path: Some(PathBuf::from("src/new.rs")),
|
||||
chunks: vec![UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: vec!["old".to_string()],
|
||||
new_lines: vec!["new".to_string()],
|
||||
is_end_of_file: false,
|
||||
}],
|
||||
}])
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Delete File: gone.txt"),
|
||||
Ok(Vec::new())
|
||||
);
|
||||
assert_eq!(
|
||||
parser.push_delta("\n"),
|
||||
Ok(vec![DeleteFile {
|
||||
path: PathBuf::from("gone.txt"),
|
||||
}])
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta(
|
||||
"*** Begin Patch\n*** Add File: src/one.txt\n+one\n*** Delete File: src/two.txt\n",
|
||||
),
|
||||
Ok(vec![
|
||||
AddFile {
|
||||
path: PathBuf::from("src/one.txt"),
|
||||
contents: "one\n".to_string(),
|
||||
},
|
||||
DeleteFile {
|
||||
path: PathBuf::from("src/two.txt"),
|
||||
},
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_patch_parser_large_patch_split_by_character() {
|
||||
let patch = "\
|
||||
*** Begin Patch
|
||||
*** Add File: docs/release-notes.md
|
||||
+# Release notes
|
||||
+
|
||||
+## CLI
|
||||
+- Surface apply_patch progress while arguments stream.
|
||||
+- Keep final patch application gated on the completed tool call.
|
||||
+- Include file summaries in the progress event payload.
|
||||
*** Update File: src/config.rs
|
||||
@@ impl Config
|
||||
- pub apply_patch_progress: bool,
|
||||
+ pub stream_apply_patch_progress: bool,
|
||||
pub include_diagnostics: bool,
|
||||
@@ fn default_progress_interval()
|
||||
- Duration::from_millis(500)
|
||||
+ Duration::from_millis(250)
|
||||
*** Delete File: src/legacy_patch_progress.rs
|
||||
*** Update File: crates/cli/src/main.rs
|
||||
*** Move to: crates/cli/src/bin/codex.rs
|
||||
@@ fn run()
|
||||
- let args = Args::parse();
|
||||
- dispatch(args)
|
||||
+ let cli = Cli::parse();
|
||||
+ dispatch(cli)
|
||||
*** Add File: tests/fixtures/apply_patch_progress.json
|
||||
+{
|
||||
+ \"type\": \"apply_patch_progress\",
|
||||
+ \"hunks\": [
|
||||
+ { \"operation\": \"add\", \"path\": \"docs/release-notes.md\" },
|
||||
+ { \"operation\": \"update\", \"path\": \"src/config.rs\" }
|
||||
+ ]
|
||||
+}
|
||||
*** Update File: README.md
|
||||
@@ Development workflow
|
||||
Build the Rust workspace before opening a pull request.
|
||||
+When touching streamed tool calls, include parser coverage for partial input.
|
||||
+Prefer tests that exercise the exact event payload shape.
|
||||
*** Delete File: docs/old-apply-patch-progress.md
|
||||
*** End Patch";
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
let mut max_hunk_count = 0;
|
||||
let mut saw_hunk_counts = Vec::new();
|
||||
let mut hunks = Vec::new();
|
||||
for ch in patch.chars() {
|
||||
let updated_hunks = parser.push_delta(&ch.to_string()).unwrap();
|
||||
if !updated_hunks.is_empty() {
|
||||
let hunk_count = updated_hunks.len();
|
||||
assert!(
|
||||
hunk_count >= max_hunk_count,
|
||||
"hunk count should never decrease while streaming: {hunk_count} < {max_hunk_count}",
|
||||
);
|
||||
if hunk_count > max_hunk_count {
|
||||
saw_hunk_counts.push(hunk_count);
|
||||
max_hunk_count = hunk_count;
|
||||
}
|
||||
hunks = updated_hunks;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(saw_hunk_counts, vec![1, 2, 3, 4, 5, 6, 7]);
|
||||
assert_eq!(hunks.len(), 7);
|
||||
assert_eq!(
|
||||
hunks
|
||||
.iter()
|
||||
.map(|hunk| match hunk {
|
||||
AddFile { .. } => "add",
|
||||
DeleteFile { .. } => "delete",
|
||||
UpdateFile {
|
||||
move_path: Some(_), ..
|
||||
} => "move-update",
|
||||
UpdateFile {
|
||||
move_path: None, ..
|
||||
} => "update",
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
"add",
|
||||
"update",
|
||||
"delete",
|
||||
"move-update",
|
||||
"add",
|
||||
"update",
|
||||
"delete"
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_patch_parser_keeps_indented_update_markers_as_context_lines() {
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta(
|
||||
"\
|
||||
*** Begin Patch
|
||||
*** Update File: a.txt
|
||||
@@
|
||||
-old a
|
||||
+new a
|
||||
*** Update File: b.txt
|
||||
@@
|
||||
-old b
|
||||
+new b
|
||||
*** End Patch
|
||||
",
|
||||
),
|
||||
Ok(vec![UpdateFile {
|
||||
path: PathBuf::from("a.txt"),
|
||||
move_path: None,
|
||||
chunks: vec![
|
||||
UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: vec!["old a".to_string(), "*** Update File: b.txt".to_string()],
|
||||
new_lines: vec!["new a".to_string(), "*** Update File: b.txt".to_string()],
|
||||
is_end_of_file: false,
|
||||
},
|
||||
UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: vec!["old b".to_string()],
|
||||
new_lines: vec!["new b".to_string()],
|
||||
is_end_of_file: false,
|
||||
},
|
||||
],
|
||||
}])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_patch_parser_preserves_bare_empty_update_lines() {
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta(
|
||||
"\
|
||||
*** Begin Patch
|
||||
*** Update File: file.txt
|
||||
@@
|
||||
context before
|
||||
|
||||
context after
|
||||
*** End Patch
|
||||
",
|
||||
),
|
||||
Ok(vec![UpdateFile {
|
||||
path: PathBuf::from("file.txt"),
|
||||
move_path: None,
|
||||
chunks: vec![UpdateFileChunk {
|
||||
change_context: None,
|
||||
// The normal parser treats a bare empty line in an update hunk as an
|
||||
// empty context line. Preserve that leniency in the streaming parser.
|
||||
old_lines: vec![
|
||||
"context before".to_string(),
|
||||
String::new(),
|
||||
"context after".to_string(),
|
||||
],
|
||||
new_lines: vec![
|
||||
"context before".to_string(),
|
||||
String::new(),
|
||||
"context after".to_string(),
|
||||
],
|
||||
is_end_of_file: false,
|
||||
}],
|
||||
}])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_patch_parser_matches_line_ending_behavior() {
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\r\n*** Update File: file.txt\r\n@@\r\n-old\r\n+new\r\n*** End Patch\r\n"),
|
||||
Ok(vec![UpdateFile {
|
||||
path: PathBuf::from("file.txt"),
|
||||
move_path: None,
|
||||
chunks: vec![UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: vec!["old".to_string()],
|
||||
new_lines: vec!["new".to_string()],
|
||||
is_end_of_file: false,
|
||||
}],
|
||||
}])
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\r\n*** Update File: file.txt\r\n@@\r\n-old\r\r\n+new\r\n*** End Patch\r\n"),
|
||||
Ok(vec![UpdateFile {
|
||||
path: PathBuf::from("file.txt"),
|
||||
move_path: None,
|
||||
chunks: vec![UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: vec!["old\r".to_string()],
|
||||
new_lines: vec!["new".to_string()],
|
||||
is_end_of_file: false,
|
||||
}],
|
||||
}])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_patch_parser_finish_processes_final_line_without_newline() {
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Add File: file.txt\n+hello\n*** End Patch"),
|
||||
Ok(vec![AddFile {
|
||||
path: PathBuf::from("file.txt"),
|
||||
contents: "hello\n".to_string(),
|
||||
}])
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
Ok(vec![AddFile {
|
||||
path: PathBuf::from("file.txt"),
|
||||
contents: "hello\n".to_string(),
|
||||
}])
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta(
|
||||
"*** Begin Patch\n*** Update File: file.txt\n@@\n-old\n+new\n *** End Patch",
|
||||
),
|
||||
Ok(vec![UpdateFile {
|
||||
path: PathBuf::from("file.txt"),
|
||||
move_path: None,
|
||||
chunks: vec![UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: vec!["old".to_string()],
|
||||
new_lines: vec!["new".to_string()],
|
||||
is_end_of_file: false,
|
||||
}],
|
||||
}])
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
Ok(vec![UpdateFile {
|
||||
path: PathBuf::from("file.txt"),
|
||||
move_path: None,
|
||||
chunks: vec![UpdateFileChunk {
|
||||
change_context: None,
|
||||
old_lines: vec!["old".to_string()],
|
||||
new_lines: vec!["new".to_string()],
|
||||
is_end_of_file: false,
|
||||
}],
|
||||
}])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_patch_parser_finish_requires_end_patch() {
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Add File: file.txt\n+hello\n"),
|
||||
Ok(vec![AddFile {
|
||||
path: PathBuf::from("file.txt"),
|
||||
contents: "hello\n".to_string(),
|
||||
}])
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
Err(InvalidPatchError(
|
||||
"The last line of the patch must be '*** End Patch'".to_string(),
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_patch_parser_returns_errors() {
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("bad\n"),
|
||||
Err(InvalidPatchError(
|
||||
"The first line of the patch must be '*** Begin Patch'".to_string(),
|
||||
))
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(parser.push_delta("*** Begin Patch\n"), Ok(Vec::new()));
|
||||
assert_eq!(
|
||||
parser.push_delta("bad\n"),
|
||||
Err(InvalidHunkError {
|
||||
message: "'bad' is not a valid hunk header. Valid hunk headers: '*** Add File: {path}', '*** Delete File: {path}', '*** Update File: {path}'"
|
||||
.to_string(),
|
||||
line_number: 2,
|
||||
})
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Add File: file.txt\nbad\n"),
|
||||
Err(InvalidHunkError {
|
||||
message: "'bad' is not a valid hunk header. Valid hunk headers: '*** Add File: {path}', '*** Delete File: {path}', '*** Update File: {path}'"
|
||||
.to_string(),
|
||||
line_number: 3,
|
||||
})
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Delete File: file.txt\nbad\n"),
|
||||
Err(InvalidHunkError {
|
||||
message: "'bad' is not a valid hunk header. Valid hunk headers: '*** Add File: {path}', '*** Delete File: {path}', '*** Update File: {path}'"
|
||||
.to_string(),
|
||||
line_number: 3,
|
||||
})
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Update File: file.txt\n*** End Patch\n"),
|
||||
Err(InvalidHunkError {
|
||||
message: "Update file hunk for path 'file.txt' is empty".to_string(),
|
||||
line_number: 2,
|
||||
})
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta(
|
||||
"*** Begin Patch\n*** Update File: old.txt\n*** Move to: new.txt\n*** Delete File: other.txt\n",
|
||||
),
|
||||
Err(InvalidHunkError {
|
||||
message: "Update file hunk for path 'old.txt' is empty".to_string(),
|
||||
line_number: 2,
|
||||
})
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Update File: file.txt\n@@\n*** End Patch\n"),
|
||||
Err(InvalidHunkError {
|
||||
message: "Update hunk does not contain any lines".to_string(),
|
||||
line_number: 4,
|
||||
})
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Update File: file.txt\n@@\n*** End of File\n"),
|
||||
Err(InvalidHunkError {
|
||||
message: "Update hunk does not contain any lines".to_string(),
|
||||
line_number: 4,
|
||||
})
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Update File: file.txt\n@@\n@@\n"),
|
||||
Err(InvalidHunkError {
|
||||
message: "Unexpected line found in update hunk: '@@'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
|
||||
.to_string(),
|
||||
line_number: 4,
|
||||
})
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta("*** Begin Patch\n*** Update File: file.txt\n@@\n-old\nbad\n"),
|
||||
Err(InvalidHunkError {
|
||||
message: "Expected update hunk to start with a @@ context marker, got: 'bad'"
|
||||
.to_string(),
|
||||
line_number: 5,
|
||||
})
|
||||
);
|
||||
|
||||
let mut parser = StreamingPatchParser::default();
|
||||
assert_eq!(
|
||||
parser.push_delta(
|
||||
"*** Begin Patch\n*** Update File: file.txt\n@@\n*** Update File: other.txt\n",
|
||||
),
|
||||
Err(InvalidHunkError {
|
||||
message: "Unexpected line found in update hunk: '*** Update File: other.txt'. Every line should start with ' ' (context line), '+' (added line), or '-' (removed line)"
|
||||
.to_string(),
|
||||
line_number: 4,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -97,6 +97,11 @@ impl AwsAuthContext {
|
||||
&self.service
|
||||
}
|
||||
|
||||
pub async fn preload_credentials(&self) -> Result<(), AwsAuthError> {
|
||||
let _ = self.credentials_provider.provide_credentials().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn sign(&self, request: AwsRequestToSign) -> Result<AwsSignedRequest, AwsAuthError> {
|
||||
self.sign_at(request, SystemTime::now()).await
|
||||
}
|
||||
@@ -202,6 +207,14 @@ mod tests {
|
||||
assert!(signing::header_value(&signed.headers, "x-amz-date").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn preload_credentials_resolves_provider() {
|
||||
test_context(/*session_token*/ None)
|
||||
.preload_credentials()
|
||||
.await
|
||||
.expect("static credentials should resolve");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_provider_failures_are_retryable() {
|
||||
assert!(
|
||||
|
||||
@@ -673,6 +673,15 @@ impl ModelClient {
|
||||
true
|
||||
}
|
||||
|
||||
/// Resolves provider credentials during session startup when the provider requests it.
|
||||
pub(crate) async fn prewarm_provider_auth(&self) -> Result<()> {
|
||||
if !self.state.provider.prewarms_auth_on_startup() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.state.provider.prewarm_auth().await
|
||||
}
|
||||
|
||||
/// Returns auth + provider configuration resolved from the current session auth state.
|
||||
///
|
||||
/// This centralizes setup used by both prewarm and normal request paths so they stay in
|
||||
|
||||
@@ -82,6 +82,7 @@ pub(crate) mod mentions {
|
||||
mod sandbox_tags;
|
||||
pub mod sandboxing;
|
||||
mod session_prefix;
|
||||
mod session_startup_auth_prewarm;
|
||||
mod session_startup_prewarm;
|
||||
mod shell_detect;
|
||||
pub mod skills;
|
||||
|
||||
@@ -980,6 +980,7 @@ impl Session {
|
||||
anyhow::bail!("required MCP servers failed to initialize: {details}");
|
||||
}
|
||||
}
|
||||
sess.schedule_startup_auth_prewarm().await;
|
||||
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
|
||||
.await;
|
||||
let session_start_source = match &initial_history {
|
||||
|
||||
@@ -1902,7 +1902,7 @@ async fn try_run_sampling_request(
|
||||
ResponseEvent::Created => {}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
if let Some((_, mut consumer)) = active_tool_argument_diff_consumer.take()
|
||||
&& let Some(event) = consumer.flush_on_complete()
|
||||
&& let Ok(Some(event)) = consumer.finish()
|
||||
{
|
||||
sess.send_event(&turn_context, event).await;
|
||||
}
|
||||
|
||||
16
codex-rs/core/src/session_startup_auth_prewarm.rs
Normal file
16
codex-rs/core/src/session_startup_auth_prewarm.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tracing::debug;
|
||||
|
||||
use crate::session::session::Session;
|
||||
|
||||
impl Session {
|
||||
pub(crate) async fn schedule_startup_auth_prewarm(self: &Arc<Self>) {
|
||||
let model_client = self.services.model_client.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = model_client.prewarm_provider_auth().await {
|
||||
debug!("startup provider auth prewarm failed: {err:#}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -33,10 +33,9 @@ use crate::tools::runtimes::apply_patch::ApplyPatchRequest;
|
||||
use crate::tools::runtimes::apply_patch::ApplyPatchRuntime;
|
||||
use crate::tools::sandboxing::ToolCtx;
|
||||
use codex_apply_patch::ApplyPatchAction;
|
||||
use codex_apply_patch::ApplyPatchArgs;
|
||||
use codex_apply_patch::ApplyPatchFileChange;
|
||||
use codex_apply_patch::Hunk;
|
||||
use codex_apply_patch::parse_patch_streaming;
|
||||
use codex_apply_patch::StreamingPatchParser;
|
||||
use codex_exec_server::ExecutorFileSystem;
|
||||
use codex_features::Feature;
|
||||
use codex_protocol::models::AdditionalPermissionProfile;
|
||||
@@ -56,8 +55,7 @@ pub struct ApplyPatchHandler;
|
||||
|
||||
#[derive(Default)]
|
||||
struct ApplyPatchArgumentDiffConsumer {
|
||||
input: String,
|
||||
last_progress: Option<Vec<Hunk>>,
|
||||
parser: StreamingPatchParser,
|
||||
last_sent_at: Option<Instant>,
|
||||
pending: Option<PatchApplyUpdatedEvent>,
|
||||
}
|
||||
@@ -77,26 +75,19 @@ impl ToolArgumentDiffConsumer for ApplyPatchArgumentDiffConsumer {
|
||||
.map(EventMsg::PatchApplyUpdated)
|
||||
}
|
||||
|
||||
fn flush_on_complete(&mut self) -> Option<EventMsg> {
|
||||
self.flush_update_on_complete()
|
||||
.map(EventMsg::PatchApplyUpdated)
|
||||
fn finish(&mut self) -> Result<Option<EventMsg>, FunctionCallError> {
|
||||
self.finish_update_on_complete()
|
||||
.map(|event| event.map(EventMsg::PatchApplyUpdated))
|
||||
}
|
||||
}
|
||||
|
||||
impl ApplyPatchArgumentDiffConsumer {
|
||||
fn push_delta(&mut self, call_id: String, delta: &str) -> Option<PatchApplyUpdatedEvent> {
|
||||
self.input.push_str(delta);
|
||||
|
||||
let ApplyPatchArgs { hunks, .. } = parse_patch_streaming(&self.input).ok()?;
|
||||
let hunks = self.parser.push_delta(delta).ok()?;
|
||||
if hunks.is_empty() {
|
||||
return None;
|
||||
}
|
||||
if self.last_progress.as_ref() == Some(&hunks) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let changes = convert_apply_patch_hunks_to_protocol(&hunks);
|
||||
self.last_progress = Some(hunks);
|
||||
let event = PatchApplyUpdatedEvent { call_id, changes };
|
||||
let now = Instant::now();
|
||||
match self.last_sent_at {
|
||||
@@ -114,12 +105,18 @@ impl ApplyPatchArgumentDiffConsumer {
|
||||
}
|
||||
}
|
||||
|
||||
fn flush_update_on_complete(&mut self) -> Option<PatchApplyUpdatedEvent> {
|
||||
fn finish_update_on_complete(
|
||||
&mut self,
|
||||
) -> Result<Option<PatchApplyUpdatedEvent>, FunctionCallError> {
|
||||
self.parser.finish().map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("failed to parse apply_patch: {err}"))
|
||||
})?;
|
||||
|
||||
let event = self.pending.take();
|
||||
if event.is_some() {
|
||||
self.last_sent_at = Some(Instant::now());
|
||||
}
|
||||
event
|
||||
Ok(event)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -136,7 +136,7 @@ fn diff_consumer_streams_apply_patch_changes() {
|
||||
HashMap::from([(
|
||||
PathBuf::from("hello.txt"),
|
||||
FileChange::Add {
|
||||
content: "hello\n".to_string(),
|
||||
content: String::new(),
|
||||
},
|
||||
)]),
|
||||
)
|
||||
@@ -147,8 +147,16 @@ fn diff_consumer_streams_apply_patch_changes() {
|
||||
.push_delta("call-1".to_string(), "\n+world")
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
consumer
|
||||
.push_delta("call-1".to_string(), "\n*** End Patch")
|
||||
.is_none()
|
||||
);
|
||||
|
||||
let event = consumer.flush_update_on_complete().expect("progress event");
|
||||
let event = consumer
|
||||
.finish_update_on_complete()
|
||||
.expect("finish parser")
|
||||
.expect("progress event");
|
||||
assert_eq!(
|
||||
(event.call_id, event.changes),
|
||||
(
|
||||
@@ -175,7 +183,7 @@ fn diff_consumer_sends_next_update_after_buffer_interval() {
|
||||
HashMap::from([(
|
||||
PathBuf::from("hello.txt"),
|
||||
FileChange::Add {
|
||||
content: "hello\n".to_string(),
|
||||
content: String::new(),
|
||||
},
|
||||
)])
|
||||
);
|
||||
@@ -190,7 +198,7 @@ fn diff_consumer_sends_next_update_after_buffer_interval() {
|
||||
HashMap::from([(
|
||||
PathBuf::from("hello.txt"),
|
||||
FileChange::Add {
|
||||
content: "hello\nworld\n".to_string(),
|
||||
content: "hello\n".to_string(),
|
||||
},
|
||||
)])
|
||||
);
|
||||
|
||||
@@ -98,9 +98,9 @@ pub(crate) trait ToolArgumentDiffConsumer: Send {
|
||||
fn consume_diff(&mut self, turn: &TurnContext, call_id: String, diff: &str)
|
||||
-> Option<EventMsg>;
|
||||
|
||||
/// Flush any buffered event before the tool call completes.
|
||||
fn flush_on_complete(&mut self) -> Option<EventMsg> {
|
||||
None
|
||||
/// Finish consuming argument diffs before the tool call completes.
|
||||
fn finish(&mut self) -> Result<Option<EventMsg>, FunctionCallError> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1027,7 +1027,7 @@ async fn apply_patch_custom_tool_streaming_emits_updated_changes() -> Result<()>
|
||||
.changes
|
||||
.get(&std::path::PathBuf::from("streamed.txt")),
|
||||
Some(&codex_protocol::protocol::FileChange::Add {
|
||||
content: "hello\n".to_string(),
|
||||
content: String::new(),
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
|
||||
@@ -129,7 +129,7 @@ pub struct ShutdownHandle {
|
||||
impl ShutdownHandle {
|
||||
/// Signals the login loop to terminate.
|
||||
pub fn shutdown(&self) {
|
||||
self.shutdown_notify.notify_waiters();
|
||||
self.shutdown_notify.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ use super::mantle::region_from_config;
|
||||
const AWS_BEARER_TOKEN_BEDROCK_ENV_VAR: &str = "AWS_BEARER_TOKEN_BEDROCK";
|
||||
const LEGACY_SESSION_ID_HEADER: &str = "session_id";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(super) enum BedrockAuthMethod {
|
||||
EnvBearerToken { token: String, region: String },
|
||||
AwsSdkAuth { context: AwsAuthContext },
|
||||
@@ -42,17 +43,25 @@ pub(super) async fn resolve_auth_method(
|
||||
Ok(BedrockAuthMethod::AwsSdkAuth { context })
|
||||
}
|
||||
|
||||
pub(super) async fn resolve_provider_auth(
|
||||
aws: &ModelProviderAwsAuthInfo,
|
||||
) -> Result<SharedAuthProvider> {
|
||||
match resolve_auth_method(aws).await? {
|
||||
BedrockAuthMethod::EnvBearerToken { token, .. } => Ok(Arc::new(BearerAuthProvider {
|
||||
pub(super) async fn prewarm_credentials(auth_method: &BedrockAuthMethod) -> Result<()> {
|
||||
match auth_method {
|
||||
BedrockAuthMethod::EnvBearerToken { .. } => Ok(()),
|
||||
BedrockAuthMethod::AwsSdkAuth { context } => context
|
||||
.preload_credentials()
|
||||
.await
|
||||
.map_err(aws_auth_error_to_codex_error),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn provider_auth_from_method(auth_method: BedrockAuthMethod) -> SharedAuthProvider {
|
||||
match auth_method {
|
||||
BedrockAuthMethod::EnvBearerToken { token, .. } => Arc::new(BearerAuthProvider {
|
||||
token: Some(token),
|
||||
account_id: None,
|
||||
is_fedramp_account: false,
|
||||
})),
|
||||
}),
|
||||
BedrockAuthMethod::AwsSdkAuth { context } => {
|
||||
Ok(Arc::new(BedrockMantleSigV4AuthProvider::new(context)))
|
||||
Arc::new(BedrockMantleSigV4AuthProvider::new(context))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result;
|
||||
|
||||
use super::auth::BedrockAuthMethod;
|
||||
use super::auth::resolve_auth_method;
|
||||
|
||||
const BEDROCK_MANTLE_SERVICE_NAME: &str = "bedrock-mantle";
|
||||
const BEDROCK_MANTLE_SUPPORTED_REGIONS: [&str; 12] = [
|
||||
@@ -48,16 +47,15 @@ pub(super) fn base_url(region: &str) -> Result<String> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn runtime_base_url(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
|
||||
let region = resolve_region(aws).await?;
|
||||
base_url(®ion)
|
||||
pub(super) fn region_from_auth_method(auth_method: &BedrockAuthMethod) -> String {
|
||||
match auth_method {
|
||||
BedrockAuthMethod::EnvBearerToken { region, .. } => region.clone(),
|
||||
BedrockAuthMethod::AwsSdkAuth { context } => context.region().to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_region(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
|
||||
match resolve_auth_method(aws).await? {
|
||||
BedrockAuthMethod::EnvBearerToken { region, .. } => Ok(region),
|
||||
BedrockAuthMethod::AwsSdkAuth { context } => Ok(context.region().to_string()),
|
||||
}
|
||||
pub(super) fn runtime_base_url_from_auth_method(auth_method: &BedrockAuthMethod) -> Result<String> {
|
||||
base_url(®ion_from_auth_method(auth_method))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -16,20 +16,26 @@ use codex_models_manager::manager::StaticModelsManager;
|
||||
use codex_protocol::account::ProviderAccount;
|
||||
use codex_protocol::error::Result;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use tokio::sync::OnceCell;
|
||||
|
||||
use crate::provider::ModelProvider;
|
||||
use crate::provider::ProviderAccountResult;
|
||||
use crate::provider::ProviderAccountState;
|
||||
use crate::provider::ProviderCapabilities;
|
||||
use auth::resolve_provider_auth;
|
||||
use auth::BedrockAuthMethod;
|
||||
use auth::prewarm_credentials;
|
||||
use auth::provider_auth_from_method;
|
||||
use auth::resolve_auth_method;
|
||||
pub(crate) use catalog::static_model_catalog;
|
||||
use mantle::runtime_base_url;
|
||||
use mantle::runtime_base_url_from_auth_method;
|
||||
|
||||
/// Runtime provider for Amazon Bedrock's OpenAI-compatible Mantle endpoint.
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct AmazonBedrockModelProvider {
|
||||
pub(crate) info: ModelProviderInfo,
|
||||
pub(crate) aws: ModelProviderAwsAuthInfo,
|
||||
auth_method: Arc<OnceCell<BedrockAuthMethod>>,
|
||||
credentials_prewarmed: Arc<OnceCell<()>>,
|
||||
}
|
||||
|
||||
impl AmazonBedrockModelProvider {
|
||||
@@ -44,8 +50,25 @@ impl AmazonBedrockModelProvider {
|
||||
Self {
|
||||
info: provider_info,
|
||||
aws,
|
||||
auth_method: Arc::new(OnceCell::new()),
|
||||
credentials_prewarmed: Arc::new(OnceCell::new()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn auth_method(&self) -> Result<BedrockAuthMethod> {
|
||||
self.auth_method
|
||||
.get_or_try_init(|| resolve_auth_method(&self.aws))
|
||||
.await
|
||||
.cloned()
|
||||
}
|
||||
|
||||
async fn prewarm_bedrock_credentials(&self) -> Result<()> {
|
||||
let auth_method = self.auth_method().await?;
|
||||
self.credentials_prewarmed
|
||||
.get_or_try_init(|| async move { prewarm_credentials(&auth_method).await })
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -70,6 +93,14 @@ impl ModelProvider for AmazonBedrockModelProvider {
|
||||
None
|
||||
}
|
||||
|
||||
fn prewarms_auth_on_startup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn prewarm_auth(&self) -> Result<()> {
|
||||
self.prewarm_bedrock_credentials().await
|
||||
}
|
||||
|
||||
fn account_state(&self) -> ProviderAccountResult {
|
||||
Ok(ProviderAccountState {
|
||||
account: Some(ProviderAccount::AmazonBedrock),
|
||||
@@ -79,16 +110,20 @@ impl ModelProvider for AmazonBedrockModelProvider {
|
||||
|
||||
async fn api_provider(&self) -> Result<Provider> {
|
||||
let mut api_provider_info = self.info.clone();
|
||||
api_provider_info.base_url = Some(runtime_base_url(&self.aws).await?);
|
||||
api_provider_info.base_url = Some(runtime_base_url_from_auth_method(
|
||||
&self.auth_method().await?,
|
||||
)?);
|
||||
api_provider_info.to_api_provider(/*auth_mode*/ None)
|
||||
}
|
||||
|
||||
async fn runtime_base_url(&self) -> Result<Option<String>> {
|
||||
Ok(Some(runtime_base_url(&self.aws).await?))
|
||||
Ok(Some(runtime_base_url_from_auth_method(
|
||||
&self.auth_method().await?,
|
||||
)?))
|
||||
}
|
||||
|
||||
async fn api_auth(&self) -> Result<SharedAuthProvider> {
|
||||
resolve_provider_auth(&self.aws).await
|
||||
Ok(provider_auth_from_method(self.auth_method().await?))
|
||||
}
|
||||
|
||||
fn models_manager(
|
||||
|
||||
@@ -96,6 +96,16 @@ pub trait ModelProvider: fmt::Debug + Send + Sync {
|
||||
/// Returns the current provider-scoped auth value, if one is configured.
|
||||
async fn auth(&self) -> Option<CodexAuth>;
|
||||
|
||||
/// Returns whether this provider should resolve request credentials during session startup.
|
||||
fn prewarms_auth_on_startup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Resolves provider credentials before the first model request when startup prewarm is enabled.
|
||||
async fn prewarm_auth(&self) -> codex_protocol::error::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the current app-visible account state for this provider.
|
||||
fn account_state(&self) -> ProviderAccountResult;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user