Mirror of https://github.com/openai/codex.git (synced 2026-02-05 00:13:42 +00:00)

Compare commits: forking-2...patch-tool (62 commits)
SHA1:

570639cf98
1c50fbb8a7
3316d04ed4
67a8566f59
2d36621f48
0a70810fc0
b5cf9e09ff
b2067c73d9
13e8771ee9
bba567cee9
6577197fa4
fd1e12f34e
ba6af23cb6
f805d17930
9580603fed
da38a8f56a
552a438cc9
a36a273d4e
6884c6ccf6
1e5a613c55
90965fbc84
4fee2ca3fd
3318cf9369
5ba0bcf035
6d55ef62f9
cecf3a82a6
c172e8e997
9bbeb75361
6ccd32c601
3b5a5412bb
44bb53df1e
9a7266a33f
2abad8fece
0d4a25b981
8453915e02
44587c2443
8f7b22b652
027944c64e
bec51f6c05
66967500bb
167b4f0e25
167154178b
674e3d3c90
114ce9ff4d
e13b35ecb0
377af75730
86e0f31a7e
8f837f1093
162e1235a8
c09ed74a16
65f3528cad
44262d8fd8
95a9938d3a
f69f07b028
8d766088e6
87654ec0b7
51d9e05de7
8068cc75f8
acb28bf914
97338de578
5200b7a95d
64e6c4afbb
.github/workflows/rust-ci.yml (vendored; 2 lines changed)

@@ -62,6 +62,8 @@ jobs:
           components: rustfmt
       - name: cargo fmt
         run: cargo fmt -- --config imports_granularity=Item --check
+      - name: Verify codegen for mcp-types
+        run: ./mcp-types/check_lib_rs.py

   cargo_shear:
     name: cargo shear
.github/workflows/rust-release.yml (vendored; 19 lines changed)

@@ -219,3 +219,22 @@ jobs:
         with:
           tag: ${{ github.ref_name }}
           config: .github/dotslash-config.json
+
+  update-branch:
+    name: Update latest-alpha-cli branch
+    permissions:
+      contents: write
+    needs: release
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Update latest-alpha-cli branch
+        env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        run: |
+          set -euo pipefail
+          gh api \
+            repos/${GITHUB_REPOSITORY}/git/refs/heads/latest-alpha-cli \
+            -X PATCH \
+            -f sha="${GITHUB_SHA}" \
+            -F force=true
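(For context on the `gh api` step above: it force-moves the `latest-alpha-cli` branch ref to the commit that triggered the release workflow. Assuming push rights and a standard `origin` remote, a rough local equivalent would be:)

    # force-update the remote branch ref to the release commit
    git push --force origin "${GITHUB_SHA}:refs/heads/latest-alpha-cli"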
@@ -2,7 +2,10 @@

 <p align="center"><code>npm i -g @openai/codex</code><br />or <code>brew install codex</code></p>

-<p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.</br>If you are looking for the <em>cloud-based agent</em> from OpenAI, <strong>Codex Web</strong>, see <a href="https://chatgpt.com/codex">chatgpt.com/codex</a>.</p>
+<p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.
+</br>
+</br>If you want Codex in your code editor (VS Code, Cursor, Windsurf), <a href="https://developers.openai.com/codex/ide">install in your IDE</a>
+</br>If you are looking for the <em>cloud-based agent</em> from OpenAI, <strong>Codex Web</strong>, go to <a href="https://chatgpt.com/codex">chatgpt.com/codex</a></p>

 <p align="center">
   <img src="./.github/codex-cli-splash.png" alt="Codex CLI splash" width="80%" />
codex-rs/Cargo.lock (generated; 191 lines changed)

@@ -311,15 +311,6 @@ version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1"

-[[package]]
-name = "bincode"
-version = "1.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
-dependencies = [
- "serde",
-]
-
 [[package]]
 name = "bit-set"
 version = "0.5.3"
@@ -570,7 +561,6 @@ dependencies = [
  "clap",
  "codex-common",
  "codex-core",
- "codex-protocol",
  "serde",
  "serde_json",
  "tempfile",
@@ -655,7 +645,7 @@ dependencies = [
  "tokio-test",
  "tokio-util",
  "toml",
- "toml_edit 0.23.4",
+ "toml_edit",
  "tracing",
  "tree-sitter",
  "tree-sitter-bash",
@@ -778,6 +768,7 @@ version = "0.0.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
  "base64",
  "codex-arg0",
  "codex-common",
  "codex-core",
@@ -879,6 +870,7 @@ dependencies = [
  "path-clean",
  "pathdiff",
  "pretty_assertions",
+ "pulldown-cmark",
  "rand 0.9.2",
  "ratatui",
  "regex-lite",
@@ -895,7 +887,6 @@ dependencies = [
  "tracing",
  "tracing-appender",
  "tracing-subscriber",
- "tui-markdown",
  "unicode-segmentation",
  "unicode-width 0.1.14",
  "url",
@@ -1763,12 +1754,6 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"

-[[package]]
-name = "futures-timer"
-version = "3.0.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
-
 [[package]]
 name = "futures-util"
 version = "0.3.31"
@@ -1854,12 +1839,6 @@ version = "0.31.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"

-[[package]]
-name = "glob"
-version = "0.3.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
-
 [[package]]
 name = "globset"
 version = "0.4.16"
@@ -2567,12 +2546,6 @@ dependencies = [
  "libc",
 ]

-[[package]]
-name = "linked-hash-map"
-version = "0.5.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
-
 [[package]]
 name = "linux-raw-sys"
 version = "0.4.15"
@@ -3014,28 +2987,6 @@ version = "1.70.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad"

-[[package]]
-name = "onig"
-version = "6.5.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0"
-dependencies = [
- "bitflags 2.9.1",
- "libc",
- "once_cell",
- "onig_sys",
-]
-
-[[package]]
-name = "onig_sys"
-version = "69.9.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc"
-dependencies = [
- "cc",
- "pkg-config",
-]
-
 [[package]]
 name = "openssl"
 version = "0.10.73"
@@ -3361,15 +3312,6 @@ dependencies = [
  "yansi",
 ]

-[[package]]
-name = "proc-macro-crate"
-version = "3.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35"
-dependencies = [
- "toml_edit 0.22.27",
-]
-
 [[package]]
 name = "proc-macro2"
 version = "1.0.95"
@@ -3381,9 +3323,9 @@ dependencies = [

 [[package]]
 name = "pulldown-cmark"
-version = "0.13.0"
+version = "0.10.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0"
+checksum = "76979bea66e7875e7509c4ec5300112b316af87fa7a252ca91c448b32dfe3993"
 dependencies = [
  "bitflags 2.9.1",
  "getopts",
@@ -3394,9 +3336,9 @@ dependencies = [

 [[package]]
 name = "pulldown-cmark-escape"
-version = "0.11.0"
+version = "0.10.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "007d8adb5ddab6f8e3f491ac63566a7d5002cc7ed73901f72057943fa71ae1ae"
+checksum = "bd348ff538bc9caeda7ee8cad2d1d48236a1f443c1fa3913c6a02fe0043b1dd3"

 [[package]]
 name = "pxfm"
@@ -3627,12 +3569,6 @@ version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"

-[[package]]
-name = "relative-path"
-version = "1.9.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
-
 [[package]]
 name = "reqwest"
 version = "0.12.23"
@@ -3691,51 +3627,12 @@ dependencies = [
  "windows-sys 0.52.0",
 ]

-[[package]]
-name = "rstest"
-version = "0.25.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6fc39292f8613e913f7df8fa892b8944ceb47c247b78e1b1ae2f09e019be789d"
-dependencies = [
- "futures-timer",
- "futures-util",
- "rstest_macros",
- "rustc_version",
-]
-
-[[package]]
-name = "rstest_macros"
-version = "0.25.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1f168d99749d307be9de54d23fd226628d99768225ef08f6ffb52e0182a27746"
-dependencies = [
- "cfg-if",
- "glob",
- "proc-macro-crate",
- "proc-macro2",
- "quote",
- "regex",
- "relative-path",
- "rustc_version",
- "syn 2.0.104",
- "unicode-ident",
-]
-
 [[package]]
 name = "rustc-demangle"
 version = "0.1.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f"

-[[package]]
-name = "rustc_version"
-version = "0.4.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
-dependencies = [
- "semver",
-]
-
 [[package]]
 name = "rustix"
 version = "0.38.44"
@@ -3975,12 +3872,6 @@ dependencies = [
  "libc",
 ]

-[[package]]
-name = "semver"
-version = "1.0.26"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
-
 [[package]]
 name = "serde"
 version = "1.0.219"
@@ -4464,28 +4355,6 @@ dependencies = [
  "syn 2.0.104",
 ]

-[[package]]
-name = "syntect"
-version = "5.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "874dcfa363995604333cf947ae9f751ca3af4522c60886774c4963943b4746b1"
-dependencies = [
- "bincode",
- "bitflags 1.3.2",
- "flate2",
- "fnv",
- "once_cell",
- "onig",
- "plist",
- "regex-syntax 0.8.5",
- "serde",
- "serde_derive",
- "serde_json",
- "thiserror 1.0.69",
- "walkdir",
- "yaml-rust",
-]
-
 [[package]]
 name = "sys-locale"
 version = "0.3.2"
@@ -4809,18 +4678,12 @@ dependencies = [
  "indexmap 2.10.0",
  "serde",
  "serde_spanned",
- "toml_datetime 0.7.0",
+ "toml_datetime",
  "toml_parser",
  "toml_writer",
  "winnow",
 ]

-[[package]]
-name = "toml_datetime"
-version = "0.6.11"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c"
-
 [[package]]
 name = "toml_datetime"
 version = "0.7.0"
@@ -4830,17 +4693,6 @@ dependencies = [
  "serde",
 ]

-[[package]]
-name = "toml_edit"
-version = "0.22.27"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
-dependencies = [
- "indexmap 2.10.0",
- "toml_datetime 0.6.11",
- "winnow",
-]
-
 [[package]]
 name = "toml_edit"
 version = "0.23.4"
@@ -4848,7 +4700,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7211ff1b8f0d3adae1663b7da9ffe396eabe1ca25f0b0bee42b0da29a9ddce93"
 dependencies = [
  "indexmap 2.10.0",
- "toml_datetime 0.7.0",
+ "toml_datetime",
  "toml_parser",
  "toml_writer",
  "winnow",
@@ -5058,22 +4910,6 @@ dependencies = [
  "termcolor",
 ]

-[[package]]
-name = "tui-markdown"
-version = "0.3.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d10648c25931bfaaf5334ff4e7dc5f3d830e0c50d7b0119b1d5cfe771f540536"
-dependencies = [
- "ansi-to-tui",
- "itertools 0.14.0",
- "pretty_assertions",
- "pulldown-cmark",
- "ratatui",
- "rstest",
- "syntect",
- "tracing",
-]
-
 [[package]]
 name = "typenum"
 version = "1.18.0"
@@ -5855,15 +5691,6 @@ version = "0.13.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d"

-[[package]]
-name = "yaml-rust"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85"
-dependencies = [
- "linked-hash-map",
-]
-
 [[package]]
 name = "yansi"
 version = "1.0.1"
@@ -34,6 +34,7 @@ rust = {}

 [workspace.lints.clippy]
 expect_used = "deny"
 redundant_clone = "deny"
 uninlined_format_args = "deny"
 unwrap_used = "deny"

@@ -35,7 +35,7 @@ npx @modelcontextprotocol/inspector codex mcp

 You can enable notifications by configuring a script that is run whenever the agent finishes a turn. The [notify documentation](../docs/config.md#notify) includes a detailed example that explains how to get desktop notifications via [terminal-notifier](https://github.com/julienXX/terminal-notifier) on macOS.

-### `codex exec` to run Codex programmatially/non-interactively
+### `codex exec` to run Codex programmatically/non-interactively

 To run Codex non-interactively, run `codex exec PROMPT` (you can also pass the prompt via `stdin`) and Codex will work on your task until it decides that it is done and exits. Output is printed to the terminal directly. You can set the `RUST_LOG` environment variable to see more about what's going on.
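(As a quick illustration of the non-interactive mode described in the hunk above — a minimal sketch, with made-up prompt text:)

    # one-off, non-interactive run; output goes straight to the terminal
    codex exec "list the TODO comments in this repo"

    # the prompt can also arrive via stdin, and RUST_LOG surfaces more detail
    echo "list the TODO comments in this repo" | RUST_LOG=info codex exec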
@@ -733,6 +733,8 @@ fn compute_replacements(
         }
     }

+    replacements.sort_by(|(lhs_idx, _, _), (rhs_idx, _, _)| lhs_idx.cmp(rhs_idx));
+
     Ok(replacements)
 }
@@ -1216,6 +1218,33 @@ PATCH"#,
     assert_eq!(contents, "a\nB\nc\nd\nE\nf\ng\n");
 }

 #[test]
 fn test_pure_addition_chunk_followed_by_removal() {
     let dir = tempdir().unwrap();
     let path = dir.path().join("panic.txt");
     fs::write(&path, "line1\nline2\nline3\n").unwrap();
     let patch = wrap_patch(&format!(
         r#"*** Update File: {}
 @@
 +after-context
 +second-line
 @@
 line1
 -line2
 -line3
 +line2-replacement"#,
         path.display()
     ));
     let mut stdout = Vec::new();
     let mut stderr = Vec::new();
     apply_patch(&patch, &mut stdout, &mut stderr).unwrap();
     let contents = fs::read_to_string(path).unwrap();
     assert_eq!(
         contents,
         "line1\nline2-replacement\nafter-context\nsecond-line\n"
     );
 }

 /// Ensure that patches authored with ASCII characters can update lines that
 /// contain typographic Unicode punctuation (e.g. EN DASH, NON-BREAKING
 /// HYPHEN). Historically `git apply` succeeds in such scenarios but our
@@ -617,7 +617,7 @@ fn test_parse_patch_lenient() {
     assert_eq!(
         parse_patch_text(&patch_text_in_double_quoted_heredoc, ParseMode::Lenient),
         Ok(ApplyPatchArgs {
-            hunks: expected_patch.clone(),
+            hunks: expected_patch,
             patch: patch_text.to_string(),
             workdir: None,
         })
@@ -637,7 +637,7 @@ fn test_parse_patch_lenient() {
         "<<EOF\n*** Begin Patch\n*** Update File: file2.py\nEOF\n".to_string();
     assert_eq!(
         parse_patch_text(&patch_text_with_missing_closing_heredoc, ParseMode::Strict),
-        Err(expected_error.clone())
+        Err(expected_error)
     );
     assert_eq!(
         parse_patch_text(&patch_text_with_missing_closing_heredoc, ParseMode::Lenient),
@@ -11,7 +11,6 @@ anyhow = "1"
 clap = { version = "4", features = ["derive"] }
 codex-common = { path = "../common", features = ["cli"] }
 codex-core = { path = "../core" }
-codex-protocol = { path = "../protocol" }
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
 tokio = { version = "1", features = ["full"] }
@@ -1,5 +1,4 @@
 use codex_core::CodexAuth;
-use codex_protocol::mcp_protocol::AuthMode;
 use std::path::Path;
 use std::sync::LazyLock;
 use std::sync::RwLock;
@@ -20,7 +19,7 @@ pub fn set_chatgpt_token_data(value: TokenData) {

 /// Initialize the ChatGPT token from auth.json file
 pub async fn init_chatgpt_token_from_auth(codex_home: &Path) -> std::io::Result<()> {
-    let auth = CodexAuth::from_codex_home(codex_home, AuthMode::ChatGPT)?;
+    let auth = CodexAuth::from_codex_home(codex_home)?;
     if let Some(auth) = auth {
         let token_data = auth.get_token_data().await?;
         set_chatgpt_token_data(token_data);
@@ -1,7 +1,6 @@
 use codex_common::CliConfigOverrides;
 use codex_core::CodexAuth;
 use codex_core::auth::CLIENT_ID;
 use codex_core::auth::OPENAI_API_KEY_ENV_VAR;
 use codex_core::auth::login_with_api_key;
 use codex_core::auth::logout;
 use codex_core::config::Config;
@@ -9,7 +8,6 @@ use codex_core::config::ConfigOverrides;
 use codex_login::ServerOptions;
 use codex_login::run_login_server;
 use codex_protocol::mcp_protocol::AuthMode;
 use std::env;
 use std::path::PathBuf;

 pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {
@@ -60,19 +58,11 @@ pub async fn run_login_with_api_key(
 pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
     let config = load_config_or_exit(cli_config_overrides);

-    match CodexAuth::from_codex_home(&config.codex_home, config.preferred_auth_method) {
+    match CodexAuth::from_codex_home(&config.codex_home) {
         Ok(Some(auth)) => match auth.mode {
             AuthMode::ApiKey => match auth.get_token().await {
                 Ok(api_key) => {
                     eprintln!("Logged in using an API key - {}", safe_format_key(&api_key));

                     if let Ok(env_api_key) = env::var(OPENAI_API_KEY_ENV_VAR)
                         && env_api_key == api_key
                     {
                         eprintln!(
                             "  API loaded from OPENAI_API_KEY environment variable or .env file"
                         );
                     }
                     std::process::exit(0);
                 }
                 Err(e) => {
@@ -37,10 +37,8 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> {

     let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?;
     // Use conversation_manager API to start a conversation
-    let conversation_manager = ConversationManager::new(AuthManager::shared(
-        config.codex_home.clone(),
-        config.preferred_auth_method,
-    ));
+    let conversation_manager =
+        ConversationManager::new(AuthManager::shared(config.codex_home.clone()));
     let NewConversation {
         conversation_id: _,
         conversation,
@@ -2,7 +2,7 @@ use std::time::Duration;
 use std::time::Instant;

 /// Returns a string representing the elapsed time since `start_time` like
-/// "1m15s" or "1.50s".
+/// "1m 15s" or "1.50s".
 pub fn format_elapsed(start_time: Instant) -> String {
     format_duration(start_time.elapsed())
 }
@@ -12,7 +12,7 @@ pub fn format_elapsed(start_time: Instant) -> String {
 /// Formatting rules:
 /// * < 1 s -> "{milli}ms"
 /// * < 60 s -> "{sec:.2}s" (two decimal places)
-/// * >= 60 s -> "{min}m{sec:02}s"
+/// * >= 60 s -> "{min}m {sec:02}s"
 pub fn format_duration(duration: Duration) -> String {
     let millis = duration.as_millis() as i64;
     format_elapsed_millis(millis)
@@ -26,7 +26,7 @@ fn format_elapsed_millis(millis: i64) -> String {
     } else {
         let minutes = millis / 60_000;
         let seconds = (millis % 60_000) / 1000;
-        format!("{minutes}m{seconds:02}s")
+        format!("{minutes}m {seconds:02}s")
     }
 }
@@ -61,12 +61,18 @@ mod tests {
     fn test_format_duration_minutes() {
         // Durations ≥ 1 minute should be printed mmss.
         let dur = Duration::from_millis(75_000); // 1m15s
-        assert_eq!(format_duration(dur), "1m15s");
+        assert_eq!(format_duration(dur), "1m 15s");

         let dur_exact = Duration::from_millis(60_000); // 1m0s
-        assert_eq!(format_duration(dur_exact), "1m00s");
+        assert_eq!(format_duration(dur_exact), "1m 00s");

         let dur_long = Duration::from_millis(3_601_000);
-        assert_eq!(format_duration(dur_long), "60m01s");
+        assert_eq!(format_duration(dur_long), "60m 01s");
     }

+    #[test]
+    fn test_format_duration_one_hour_has_space() {
+        let dur_hour = Duration::from_millis(3_600_000);
+        assert_eq!(format_duration(dur_hour), "60m 00s");
+    }
 }
@@ -49,6 +49,13 @@ pub fn builtin_model_presets() -> &'static [ModelPreset] {
             model: "gpt-5",
             effort: ReasoningEffort::High,
         },
+        ModelPreset {
+            id: "gpt-5-high-new",
+            label: "gpt-5 high new",
+            description: "— our latest release tuned to rely on the model's built-in reasoning defaults",
+            model: "gpt-5-high-new",
+            effort: ReasoningEffort::Medium,
+        },
     ];
     PRESETS
 }
@@ -54,6 +54,7 @@ tracing = { version = "0.1.41", features = ["log"] }
 tree-sitter = "0.25.9"
 tree-sitter-bash = "0.25.0"
 uuid = { version = "1", features = ["serde", "v4"] }
+which = "6"
 wildmatch = "2.4.0"

@@ -69,9 +70,6 @@ openssl-sys = { version = "*", features = ["vendored"] }
 [target.aarch64-unknown-linux-musl.dependencies]
 openssl-sys = { version = "*", features = ["vendored"] }

-[target.'cfg(target_os = "windows")'.dependencies]
-which = "6"
-
 [dev-dependencies]
 assert_cmd = "2"
 core_test_support = { path = "tests/common" }
@@ -17,6 +17,7 @@ use std::time::Duration;

 use codex_protocol::mcp_protocol::AuthMode;

+use crate::token_data::PlanType;
 use crate::token_data::TokenData;
 use crate::token_data::parse_id_token;
@@ -70,13 +71,9 @@ impl CodexAuth {
         Ok(access)
     }

-    /// Loads the available auth information from the auth.json or
-    /// OPENAI_API_KEY environment variable.
-    pub fn from_codex_home(
-        codex_home: &Path,
-        preferred_auth_method: AuthMode,
-    ) -> std::io::Result<Option<CodexAuth>> {
-        load_auth(codex_home, true, preferred_auth_method)
+    /// Loads the available auth information from the auth.json.
+    pub fn from_codex_home(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
+        load_auth(codex_home)
     }

     pub async fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
@@ -135,13 +132,12 @@
     }

     pub fn get_account_id(&self) -> Option<String> {
-        self.get_current_token_data()
-            .and_then(|t| t.account_id.clone())
+        self.get_current_token_data().and_then(|t| t.account_id)
     }

-    pub fn get_plan_type(&self) -> Option<String> {
+    pub(crate) fn get_plan_type(&self) -> Option<PlanType> {
         self.get_current_token_data()
-            .and_then(|t| t.id_token.chatgpt_plan_type.as_ref().map(|p| p.as_string()))
+            .and_then(|t| t.id_token.chatgpt_plan_type)
     }

     fn get_current_auth_json(&self) -> Option<AuthDotJson> {
@@ -150,7 +146,7 @@
     }

     fn get_current_token_data(&self) -> Option<TokenData> {
-        self.get_current_auth_json().and_then(|t| t.tokens.clone())
+        self.get_current_auth_json().and_then(|t| t.tokens)
     }

     /// Consider this private to integration tests.
@@ -193,10 +189,11 @@ impl CodexAuth {

 pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";

-fn read_openai_api_key_from_env() -> Option<String> {
+pub fn read_openai_api_key_from_env() -> Option<String> {
     env::var(OPENAI_API_KEY_ENV_VAR)
         .ok()
-        .filter(|s| !s.is_empty())
+        .map(|value| value.trim().to_string())
+        .filter(|value| !value.is_empty())
 }

 pub fn get_auth_file(codex_home: &Path) -> PathBuf {
@@ -214,7 +211,7 @@ pub fn logout(codex_home: &Path) -> std::io::Result<bool> {
     }
 }

-/// Writes an `auth.json` that contains only the API key. Intended for CLI use.
+/// Writes an `auth.json` that contains only the API key.
 pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<()> {
     let auth_dot_json = AuthDotJson {
         openai_api_key: Some(api_key.to_string()),
@@ -224,28 +221,11 @@ pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<(
     write_auth_json(&get_auth_file(codex_home), &auth_dot_json)
 }

-fn load_auth(
-    codex_home: &Path,
-    include_env_var: bool,
-    preferred_auth_method: AuthMode,
-) -> std::io::Result<Option<CodexAuth>> {
-    // First, check to see if there is a valid auth.json file. If not, we fall
-    // back to AuthMode::ApiKey using the OPENAI_API_KEY environment variable
-    // (if it is set).
+fn load_auth(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
     let auth_file = get_auth_file(codex_home);
     let client = crate::default_client::create_client();
     let auth_dot_json = match try_read_auth_json(&auth_file) {
         Ok(auth) => auth,
-        // If auth.json does not exist, try to read the OPENAI_API_KEY from the
-        // environment variable.
-        Err(e) if e.kind() == std::io::ErrorKind::NotFound && include_env_var => {
-            return match read_openai_api_key_from_env() {
-                Some(api_key) => Ok(Some(CodexAuth::from_api_key_with_client(&api_key, client))),
-                None => Ok(None),
-            };
-        }
-        // Though if auth.json exists but is malformed, do not fall back to the
-        // env var because the user may be expecting to use AuthMode::ChatGPT.
         Err(e) => {
             return Err(e);
         }
@@ -257,32 +237,11 @@ fn load_auth(
         last_refresh,
     } = auth_dot_json;

-    // If the auth.json has an API key AND does not appear to be on a plan that
-    // should prefer AuthMode::ChatGPT, use AuthMode::ApiKey.
+    // Prefer AuthMode.ApiKey if it's set in the auth.json.
     if let Some(api_key) = &auth_json_api_key {
-        // Should any of these be AuthMode::ChatGPT with the api_key set?
-        // Does AuthMode::ChatGPT indicate that there is an auth.json that is
-        // "refreshable" even if we are using the API key for auth?
-        match &tokens {
-            Some(tokens) => {
-                if tokens.should_use_api_key(preferred_auth_method, tokens.is_openai_email()) {
-                    return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
-                } else {
-                    // Ignore the API key and fall through to ChatGPT auth.
-                }
-            }
-            None => {
-                // We have an API key but no tokens in the auth.json file.
-                // Perhaps the user ran `codex login --api-key <KEY>` or updated
-                // auth.json by hand. Either way, let's assume they are trying
-                // to use their API key.
-                return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
-            }
-        }
+        return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
     }

-    // For the AuthMode::ChatGPT variant, perhaps neither api_key nor
-    // openai_api_key should exist?
     Ok(Some(CodexAuth {
         api_key: None,
         mode: AuthMode::ChatGPT,
@@ -332,10 +291,10 @@ async fn update_tokens(
     let tokens = auth_dot_json.tokens.get_or_insert_with(TokenData::default);
     tokens.id_token = parse_id_token(&id_token).map_err(std::io::Error::other)?;
     if let Some(access_token) = access_token {
-        tokens.access_token = access_token.to_string();
+        tokens.access_token = access_token;
     }
     if let Some(refresh_token) = refresh_token {
-        tokens.refresh_token = refresh_token.to_string();
+        tokens.refresh_token = refresh_token;
     }
     auth_dot_json.last_refresh = Some(Utc::now());
     write_auth_json(auth_file, &auth_dot_json)?;
@@ -412,7 +371,6 @@ use std::sync::RwLock;
 /// Internal cached auth state.
 #[derive(Clone, Debug)]
 struct CachedAuth {
-    preferred_auth_mode: AuthMode,
     auth: Option<CodexAuth>,
 }
@@ -468,9 +426,7 @@ mod tests {
             auth_dot_json,
             auth_file: _,
             ..
-        } = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT)
-            .unwrap()
-            .unwrap();
+        } = super::load_auth(codex_home.path()).unwrap().unwrap();
         assert_eq!(None, api_key);
         assert_eq!(AuthMode::ChatGPT, mode);
@@ -499,88 +455,6 @@ mod tests {
         )
     }

-    /// Even if the OPENAI_API_KEY is set in auth.json, if the plan is not in
-    /// [`TokenData::is_plan_that_should_use_api_key`], it should use
-    /// [`AuthMode::ChatGPT`].
-    #[tokio::test]
-    async fn pro_account_with_api_key_still_uses_chatgpt_auth() {
-        let codex_home = tempdir().unwrap();
-        let fake_jwt = write_auth_file(
-            AuthFileParams {
-                openai_api_key: Some("sk-test-key".to_string()),
-                chatgpt_plan_type: "pro".to_string(),
-            },
-            codex_home.path(),
-        )
-        .expect("failed to write auth file");
-
-        let CodexAuth {
-            api_key,
-            mode,
-            auth_dot_json,
-            auth_file: _,
-            ..
-        } = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT)
-            .unwrap()
-            .unwrap();
-        assert_eq!(None, api_key);
-        assert_eq!(AuthMode::ChatGPT, mode);
-
-        let guard = auth_dot_json.lock().unwrap();
-        let auth_dot_json = guard.as_ref().expect("AuthDotJson should exist");
-        assert_eq!(
-            &AuthDotJson {
-                openai_api_key: None,
-                tokens: Some(TokenData {
-                    id_token: IdTokenInfo {
-                        email: Some("user@example.com".to_string()),
-                        chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
-                        raw_jwt: fake_jwt,
-                    },
-                    access_token: "test-access-token".to_string(),
-                    refresh_token: "test-refresh-token".to_string(),
-                    account_id: None,
-                }),
-                last_refresh: Some(
-                    DateTime::parse_from_rfc3339(LAST_REFRESH)
-                        .unwrap()
-                        .with_timezone(&Utc)
-                ),
-            },
-            auth_dot_json
-        )
-    }
-
-    /// If the OPENAI_API_KEY is set in auth.json and it is an enterprise
-    /// account, then it should use [`AuthMode::ApiKey`].
-    #[tokio::test]
-    async fn enterprise_account_with_api_key_uses_apikey_auth() {
-        let codex_home = tempdir().unwrap();
-        write_auth_file(
-            AuthFileParams {
-                openai_api_key: Some("sk-test-key".to_string()),
-                chatgpt_plan_type: "enterprise".to_string(),
-            },
-            codex_home.path(),
-        )
-        .expect("failed to write auth file");
-
-        let CodexAuth {
-            api_key,
-            mode,
-            auth_dot_json,
-            auth_file: _,
-            ..
-        } = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT)
-            .unwrap()
-            .unwrap();
-        assert_eq!(Some("sk-test-key".to_string()), api_key);
-        assert_eq!(AuthMode::ApiKey, mode);
-
-        let guard = auth_dot_json.lock().expect("should unwrap");
-        assert!(guard.is_none(), "auth_dot_json should be None");
-    }
-
     #[tokio::test]
     async fn loads_api_key_from_auth_json() {
         let dir = tempdir().unwrap();
@@ -591,9 +465,7 @@ mod tests {
         )
         .unwrap();

-        let auth = super::load_auth(dir.path(), false, AuthMode::ChatGPT)
-            .unwrap()
-            .unwrap();
+        let auth = super::load_auth(dir.path()).unwrap().unwrap();
         assert_eq!(auth.mode, AuthMode::ApiKey);
         assert_eq!(auth.api_key, Some("sk-test-key".to_string()));
@@ -683,26 +555,17 @@ impl AuthManager {
     /// preferred auth method. Errors loading auth are swallowed; `auth()` will
     /// simply return `None` in that case so callers can treat it as an
     /// unauthenticated state.
-    pub fn new(codex_home: PathBuf, preferred_auth_mode: AuthMode) -> Self {
-        let auth = CodexAuth::from_codex_home(&codex_home, preferred_auth_mode)
-            .ok()
-            .flatten();
+    pub fn new(codex_home: PathBuf) -> Self {
+        let auth = CodexAuth::from_codex_home(&codex_home).ok().flatten();
         Self {
             codex_home,
-            inner: RwLock::new(CachedAuth {
-                preferred_auth_mode,
-                auth,
-            }),
+            inner: RwLock::new(CachedAuth { auth }),
         }
     }

     /// Create an AuthManager with a specific CodexAuth, for testing only.
     pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
-        let preferred_auth_mode = auth.mode;
-        let cached = CachedAuth {
-            preferred_auth_mode,
-            auth: Some(auth),
-        };
+        let cached = CachedAuth { auth: Some(auth) };
         Arc::new(Self {
             codex_home: PathBuf::new(),
             inner: RwLock::new(cached),
@@ -714,21 +577,10 @@ impl AuthManager {
         self.inner.read().ok().and_then(|c| c.auth.clone())
     }

-    /// Preferred auth method used when (re)loading.
-    pub fn preferred_auth_method(&self) -> AuthMode {
-        self.inner
-            .read()
-            .map(|c| c.preferred_auth_mode)
-            .unwrap_or(AuthMode::ApiKey)
-    }
-
-    /// Force a reload using the existing preferred auth method. Returns
+    /// Force a reload of the auth information from auth.json. Returns
     /// whether the auth value changed.
     pub fn reload(&self) -> bool {
-        let preferred = self.preferred_auth_method();
-        let new_auth = CodexAuth::from_codex_home(&self.codex_home, preferred)
-            .ok()
-            .flatten();
+        let new_auth = CodexAuth::from_codex_home(&self.codex_home).ok().flatten();
         if let Ok(mut guard) = self.inner.write() {
             let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
             guard.auth = new_auth;
@@ -747,8 +599,8 @@ impl AuthManager {
     }

     /// Convenience constructor returning an `Arc` wrapper.
-    pub fn shared(codex_home: PathBuf, preferred_auth_mode: AuthMode) -> Arc<Self> {
-        Arc::new(Self::new(codex_home, preferred_auth_mode))
+    pub fn shared(codex_home: PathBuf) -> Arc<Self> {
+        Arc::new(Self::new(codex_home))
     }

     /// Attempt to refresh the current auth token (if any). On success, reload
@@ -41,6 +41,7 @@ use crate::model_provider_info::WireApi;
 use crate::openai_model_info::get_model_info;
 use crate::openai_tools::create_tools_json_for_responses_api;
 use crate::protocol::TokenUsage;
+use crate::token_data::PlanType;
 use crate::util::backoff;
 use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
 use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
@@ -60,7 +61,7 @@ struct Error {
     message: Option<String>,

     // Optional fields available on "usage_limit_reached" and "usage_not_included" errors
-    plan_type: Option<String>,
+    plan_type: Option<PlanType>,
     resets_in_seconds: Option<u64>,
 }
@@ -239,10 +240,10 @@ impl ModelClient {
         let res = req_builder.send().await;
         if let Ok(resp) = &res {
             trace!(
-                "Response status: {}, request-id: {}",
+                "Response status: {}, cf-ray: {}",
                 resp.status(),
                 resp.headers()
-                    .get("x-request-id")
+                    .get("cf-ray")
                     .map(|v| v.to_str().unwrap_or_default())
                     .unwrap_or_default()
             );
@@ -304,7 +305,7 @@ impl ModelClient {
                 // token.
                 let plan_type = error
                     .plan_type
-                    .or_else(|| auth.and_then(|a| a.get_plan_type()));
+                    .or_else(|| auth.as_ref().and_then(|a| a.get_plan_type()));
                 let resets_in_seconds = error.resets_in_seconds;
                 return Err(CodexErr::UsageLimitReached(UsageLimitReachedError {
                     plan_type,
@@ -1037,4 +1038,37 @@ mod tests {
         let delay = try_parse_retry_after(&err);
         assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
     }
+
+    #[test]
+    fn error_response_deserializes_old_schema_known_plan_type_and_serializes_back() {
+        use crate::token_data::KnownPlan;
+        use crate::token_data::PlanType;
+
+        let json = r#"{"error":{"type":"usage_limit_reached","plan_type":"pro","resets_in_seconds":3600}}"#;
+        let resp: ErrorResponse =
+            serde_json::from_str(json).expect("should deserialize old schema");
+
+        assert!(matches!(
+            resp.error.plan_type,
+            Some(PlanType::Known(KnownPlan::Pro))
+        ));
+
+        let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
+        assert_eq!(plan_json, "\"pro\"");
+    }
+
+    #[test]
+    fn error_response_deserializes_old_schema_unknown_plan_type_and_serializes_back() {
+        use crate::token_data::PlanType;
+
+        let json =
+            r#"{"error":{"type":"usage_limit_reached","plan_type":"vip","resets_in_seconds":60}}"#;
+        let resp: ErrorResponse =
+            serde_json::from_str(json).expect("should deserialize old schema");
+
+        assert!(matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip"));
+
+        let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
+        assert_eq!(plan_json, "\"vip\"");
+    }
 }
@@ -9,7 +9,6 @@ use std::sync::atomic::AtomicU64;
 use std::time::Duration;

 use crate::AuthManager;
-use crate::conversation_history::EventMsgsHistory;
 use crate::event_mapping::map_response_item_to_event_messages;
 use async_channel::Receiver;
 use async_channel::Sender;
@@ -17,13 +16,16 @@ use codex_apply_patch::ApplyPatchAction;
 use codex_apply_patch::MaybeApplyPatchVerified;
 use codex_apply_patch::maybe_parse_apply_patch_verified;
 use codex_protocol::mcp_protocol::ConversationId;
-use codex_protocol::protocol::ConversationHistoryResponseEvent;
+use codex_protocol::protocol::CompactedItem;
+use codex_protocol::protocol::ConversationPathResponseEvent;
 use codex_protocol::protocol::RolloutItem;
 use codex_protocol::protocol::TaskStartedEvent;
 use codex_protocol::protocol::TurnAbortReason;
 use codex_protocol::protocol::TurnAbortedEvent;
+use codex_protocol::protocol::TurnContextItem;
 use futures::prelude::*;
 use mcp_types::CallToolResult;
 use serde::Deserialize;
 use serde::Serialize;
 use serde_json;
 use tokio::sync::oneshot;
@@ -45,7 +47,7 @@ use crate::client_common::Prompt;
 use crate::client_common::ResponseEvent;
 use crate::config::Config;
 use crate::config_types::ShellEnvironmentPolicy;
-use crate::conversation_history::ResponseItemsHistory;
+use crate::conversation_history::ConversationHistory;
 use crate::environment_context::EnvironmentContext;
 use crate::error::CodexErr;
 use crate::error::Result as CodexResult;
@@ -110,6 +112,7 @@ use crate::safety::assess_command_safety;
 use crate::safety::assess_safety_for_untrusted_command;
 use crate::shell;
 use crate::turn_diff_tracker::TurnDiffTracker;
+use crate::unified_exec::UnifiedExecSessionManager;
 use crate::user_instructions::UserInstructions;
 use crate::user_notification::UserNotification;
 use crate::util::backoff;
@@ -208,12 +211,7 @@ impl Codex {
         let conversation_id = session.conversation_id;

         // This task will run until Op::Shutdown is received.
-        tokio::spawn(submission_loop(
-            session.clone(),
-            turn_context,
-            config,
-            rx_sub,
-        ));
+        tokio::spawn(submission_loop(session, turn_context, config, rx_sub));
         let codex = Codex {
             next_id: AtomicU64::new(0),
             tx_sub,
@@ -264,8 +262,7 @@ struct State {
     current_task: Option<AgentTask>,
     pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
     pending_input: Vec<ResponseInputItem>,
-    response_items: ResponseItemsHistory,
-    event_msgs: EventMsgsHistory,
+    history: ConversationHistory,
     token_info: Option<TokenUsageInfo>,
 }
@@ -279,6 +276,7 @@ pub(crate) struct Session {
     /// Manager for external MCP servers/tools.
     mcp_connection_manager: McpConnectionManager,
     session_manager: ExecSessionManager,
+    unified_exec_manager: UnifiedExecSessionManager,

     /// External notifier command (will be passed as args to exec()). When
     /// `None` this feature is disabled.
@@ -405,7 +403,7 @@ impl Session {
         let rollout_fut = RolloutRecorder::new(&config, rollout_params);

         let mcp_fut = McpConnectionManager::new(config.mcp_servers.clone());
-        let default_shell_fut = shell::default_user_shell(conversation_id.0, &config.codex_home);
+        let default_shell_fut = shell::default_user_shell();
         let history_meta_fut = crate::message_history::history_metadata(&config);

         // Join all independent futures.
@@ -419,7 +417,7 @@ impl Session {
         let rollout_path = rollout_recorder.rollout_path.clone();
         // Create the mutable state for the Session.
         let state = State {
-            response_items: ResponseItemsHistory::new(),
+            history: ConversationHistory::new(),
             ..Default::default()
         };
@@ -464,12 +462,12 @@ impl Session {
             tools_config: ToolsConfig::new(&ToolsConfigParams {
                 model_family: &config.model_family,
                 approval_policy,
                 sandbox_policy: sandbox_policy.clone(),
                 include_plan_tool: config.include_plan_tool,
                 include_apply_patch_tool: config.include_apply_patch_tool,
                 include_web_search_request: config.tools_web_search_request,
                 use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
                 include_view_image_tool: config.include_view_image_tool,
+                experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
             }),
             user_instructions,
             base_instructions,
@@ -478,12 +476,12 @@ impl Session {
             shell_environment_policy: config.shell_environment_policy.clone(),
             cwd,
         };

         let sess = Arc::new(Session {
             conversation_id,
             tx_event: tx_event.clone(),
             mcp_connection_manager,
             session_manager: ExecSessionManager::default(),
+            unified_exec_manager: UnifiedExecSessionManager::default(),
             notify,
             state: Mutex::new(state),
             rollout: Mutex::new(Some(rollout_recorder)),
@@ -503,6 +501,7 @@ impl Session {
             msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
                 session_id: conversation_id,
                 model,
+                reasoning_effort: model_reasoning_effort,
                 history_log_id,
                 history_entry_count,
                 initial_messages,
@@ -543,7 +542,7 @@ impl Session {
             InitialHistory::New => {
                 // Build and record initial items (user instructions + environment context)
                 let items = self.build_initial_context(turn_context);
-                self.record_response_items(&items).await;
+                self.record_conversation_items(&items).await;
             }
             InitialHistory::Resumed(_) | InitialHistory::Forked(_) => {
                 let rollout_items = conversation_history.get_rollout_items();
@@ -552,13 +551,7 @@ impl Session {
                 // Always add response items to conversation history
                 let response_items = conversation_history.get_response_items();
                 if !response_items.is_empty() {
-                    self.record_into_history_response_items(&response_items);
-                }
-
-                // Always add event msgs to conversation history
-                let event_msgs = conversation_history.get_event_msgs();
-                if let Some(event_msgs) = event_msgs {
-                    self.record_into_history_event_msgs(&event_msgs);
+                    self.record_into_history(&response_items);
                 }

                 // If persisting, persist all rollout items as-is (recorder filters)
@@ -571,9 +564,9 @@ impl Session {

     /// Persist the event to rollout and send it to clients.
     pub(crate) async fn send_event(&self, event: Event) {
-        // Persist the event into event_msgs in memory
-        self.record_conversation_event_msgs(std::slice::from_ref(&event.msg))
-            .await;
+        // Persist the event into rollout (recorder filters as needed)
+        let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
+        self.persist_rollout_items(&rollout_items).await;
         if let Err(e) = self.tx_event.send(event).await {
             error!("failed to send tool call event: {e}");
         }
@@ -663,31 +656,18 @@ impl Session {
         state.approved_commands.insert(cmd);
     }

-    async fn record_conversation_event_msgs(&self, items: &[EventMsg]) {
-        self.record_into_history_event_msgs(items);
-        self.persist_rollout_event_msgs(items).await;
-    }
-
     /// Records input items: always append to conversation history and
     /// persist these response items to rollout.
-    async fn record_response_items(&self, items: &[ResponseItem]) {
-        self.record_into_history_response_items(items);
+    async fn record_conversation_items(&self, items: &[ResponseItem]) {
+        self.record_into_history(items);
         self.persist_rollout_response_items(items).await;
     }

     /// Append ResponseItems to the in-memory conversation history only.
-    fn record_into_history_response_items(&self, items: &[ResponseItem]) {
+    fn record_into_history(&self, items: &[ResponseItem]) {
         self.state
             .lock_unchecked()
-            .response_items
-            .record_items(items.iter());
-    }
-
-    /// Append EventMsgs to the in-memory conversation history only.
-    fn record_into_history_event_msgs(&self, items: &[EventMsg]) {
-        self.state
-            .lock_unchecked()
-            .event_msgs
+            .history
             .record_items(items.iter());
     }
@@ -700,23 +680,11 @@ impl Session {
         self.persist_rollout_items(&rollout_items).await;
     }

-    async fn persist_rollout_event_msgs(&self, items: &[EventMsg]) {
-        let rollout_items: Vec<RolloutItem> =
-            items.iter().cloned().map(RolloutItem::EventMsg).collect();
-        self.persist_rollout_items(&rollout_items).await;
-    }
-
     fn build_initial_context(&self, turn_context: &TurnContext) -> Vec<ResponseItem> {
         let mut items = Vec::<ResponseItem>::with_capacity(2);
         if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
             items.push(UserInstructions::new(user_instructions.to_string()).into());
         }
         items.push(ResponseItem::from(EnvironmentContext::new(
             Some(turn_context.cwd.clone()),
             Some(turn_context.approval_policy),
             Some(turn_context.sandbox_policy.clone()),
             Some(self.user_shell.clone()),
         )));
         items
     }
@@ -737,14 +705,13 @@ impl Session {
     async fn record_input_and_rollout_usermsg(&self, response_input: &ResponseInputItem) {
         let response_item: ResponseItem = response_input.clone().into();
         // Add to conversation history and persist response item to rollout
-        self.record_response_items(std::slice::from_ref(&response_item))
+        self.record_conversation_items(std::slice::from_ref(&response_item))
             .await;

         // Derive user message events and persist only UserMessage to rollout
         let msgs =
             map_response_item_to_event_messages(&response_item, self.show_raw_agent_reasoning);
         let user_msgs: Vec<RolloutItem> = msgs
-            .clone()
             .into_iter()
             .filter_map(|m| match m {
                 EventMsg::UserMessage(ev) => Some(RolloutItem::EventMsg(EventMsg::UserMessage(ev))),
@@ -754,7 +721,6 @@ impl Session {
         if !user_msgs.is_empty() {
             self.persist_rollout_items(&user_msgs).await;
         }
-        self.state.lock_unchecked().event_msgs.record_items(&msgs);
     }

     async fn on_exec_command_begin(
@@ -937,7 +903,7 @@ impl Session {
     /// Build the full turn input by concatenating the current conversation
     /// history with additional items for this turn.
     pub fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
-        [self.state.lock_unchecked().response_items.contents(), extra].concat()
+        [self.state.lock_unchecked().history.contents(), extra].concat()
     }

     /// Returns the input if there was no task running to inject into
@@ -1094,7 +1060,7 @@ impl AgentTask {
             id: self.sub_id,
             msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
         };
-        let sess = self.sess.clone();
+        let sess = self.sess;
         tokio::spawn(async move {
             sess.send_event(event).await;
         });
@@ -1130,10 +1096,10 @@ async fn submission_loop(
                     let provider = prev.client.get_provider();

                     // Effective model + family
-                    let (effective_model, effective_family) = if let Some(m) = model {
+                    let (effective_model, effective_family) = if let Some(ref m) = model {
                         let fam =
-                            find_family_for_model(&m).unwrap_or_else(|| config.model_family.clone());
-                        (m, fam)
+                            find_family_for_model(m).unwrap_or_else(|| config.model_family.clone());
+                        (m.clone(), fam)
                     } else {
                         (prev.client.get_model(), prev.client.get_model_family())
                     };
@@ -1170,12 +1136,12 @@ async fn submission_loop(
                     let tools_config = ToolsConfig::new(&ToolsConfigParams {
                         model_family: &effective_family,
                         approval_policy: new_approval_policy,
                         sandbox_policy: new_sandbox_policy.clone(),
                         include_plan_tool: config.include_plan_tool,
                         include_apply_patch_tool: config.include_apply_patch_tool,
                         include_web_search_request: config.tools_web_search_request,
                         use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
                         include_view_image_tool: config.include_view_image_tool,
+                        experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
                     });

                     let new_turn_context = TurnContext {
@@ -1191,25 +1157,18 @@ async fn submission_loop(

                     // Install the new persistent context for subsequent tasks/turns.
                     turn_context = Arc::new(new_turn_context);
-                    if cwd.is_some() || approval_policy.is_some() || sandbox_policy.is_some() {
-                        sess.record_response_items(&[ResponseItem::from(EnvironmentContext::new(
-                            cwd,
-                            approval_policy,
-                            sandbox_policy,
-                            // Shell is not configurable from turn to turn
-                            None,
-                        ))])
-                        .await;
-                    }
                 }
                 Op::UserInput { items } => {
-                    // attempt to inject input into current task
-                    if let Err(items) = sess.inject_input(items) {
-                        // no current task, spawn a new one
-                        let task =
-                            AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
-                        sess.set_task(task);
-                    }
+                    submit_user_input(
+                        turn_context.cwd.clone(),
+                        turn_context.approval_policy,
+                        turn_context.sandbox_policy.clone(),
+                        &sess,
+                        &turn_context,
+                        sub.id.clone(),
+                        items,
+                    )
+                    .await;
                 }
                 Op::UserTurn {
                     items,
@@ -1254,13 +1213,14 @@ async fn submission_loop(
                         tools_config: ToolsConfig::new(&ToolsConfigParams {
                             model_family: &model_family,
                             approval_policy,
                             sandbox_policy: sandbox_policy.clone(),
                             include_plan_tool: config.include_plan_tool,
                             include_apply_patch_tool: config.include_apply_patch_tool,
                             include_web_search_request: config.tools_web_search_request,
                             use_streamable_shell_tool: config
                                 .use_experimental_streamable_shell_tool,
                             include_view_image_tool: config.include_view_image_tool,
+                            experimental_unified_exec_tool: config
+                                .use_experimental_unified_exec_tool,
                         }),
                         user_instructions: turn_context.user_instructions.clone(),
                         base_instructions: turn_context.base_instructions.clone(),
@@ -1269,11 +1229,16 @@ async fn submission_loop(
                         shell_environment_policy: turn_context.shell_environment_policy.clone(),
                         cwd,
                     };
+                    // TODO: record the new environment context in the conversation history
-                    // no current task, spawn a new one with the per‑turn context
-                    let task =
-                        AgentTask::spawn(sess.clone(), Arc::new(fresh_turn_context), sub.id, items);
-                    sess.set_task(task);
+                    submit_user_input(
+                        fresh_turn_context.cwd.clone(),
+                        fresh_turn_context.approval_policy,
+                        fresh_turn_context.sandbox_policy.clone(),
+                        &sess,
+                        &Arc::new(fresh_turn_context),
+                        sub.id.clone(),
+                        items,
+                    )
+                    .await;
                 }
             }
             Op::ExecApproval { id, decision } => match decision {
@@ -1407,19 +1372,29 @@ async fn submission_loop(
                 sess.send_event(event).await;
                 break;
             }
-            Op::GetHistory => {
+            Op::GetPath => {
                 let sub_id = sub.id.clone();
-                let entries = {
-                    let state = sess.state.lock_unchecked();
-                    let rolled_response_items: Vec<RolloutItem> = (&state.response_items).into();
-                    let rolled_event_msgs: Vec<RolloutItem> = (&state.event_msgs).into();
-                    [rolled_response_items, rolled_event_msgs].concat()
-                };
+                // Flush rollout writes before returning the path so readers observe a consistent file.
+                let (path, rec_opt) = {
+                    let guard = sess.rollout.lock_unchecked();
+                    match guard.as_ref() {
+                        Some(rec) => (rec.get_rollout_path(), Some(rec.clone())),
+                        None => {
+                            error!("rollout recorder not found");
+                            continue;
+                        }
+                    }
+                };
+                if let Some(rec) = rec_opt
+                    && let Err(e) = rec.flush().await
+                {
+                    warn!("failed to flush rollout recorder before GetHistory: {e}");
+                }
                 let event = Event {
                     id: sub_id.clone(),
-                    msg: EventMsg::ConversationHistory(ConversationHistoryResponseEvent {
+                    msg: EventMsg::ConversationPath(ConversationPathResponseEvent {
                         conversation_id: sess.conversation_id,
-                        history: InitialHistory::Forked(entries),
+                        path,
                     }),
                 };
                 sess.send_event(event).await;
@@ -1480,7 +1455,7 @@ async fn run_task(
|
||||
.into_iter()
|
||||
.map(ResponseItem::from)
|
||||
.collect::<Vec<ResponseItem>>();
|
||||
sess.record_response_items(&pending_input).await;
|
||||
sess.record_conversation_items(&pending_input).await;
|
||||
|
||||
// Construct the input that we will send to the model. When using the
|
||||
// Chat completions API (or ZDR clients), the model needs the full
|
||||
@@ -1607,7 +1582,7 @@ async fn run_task(
|
||||
|
||||
// Only attempt to take the lock if there is something to record.
|
||||
if !items_to_record_in_conversation_history.is_empty() {
|
||||
sess.record_response_items(&items_to_record_in_conversation_history)
|
||||
sess.record_conversation_items(&items_to_record_in_conversation_history)
|
||||
.await;
|
||||
}
|
||||
|
||||
@@ -1762,7 +1737,7 @@ async fn try_run_turn(
|
||||
}
|
||||
})
|
||||
.map(|call_id| ResponseItem::CustomToolCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
call_id,
|
||||
output: "aborted".to_string(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
@@ -1778,6 +1753,15 @@ async fn try_run_turn(
|
||||
})
|
||||
};
|
||||
|
||||
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: turn_context.client.get_model(),
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context.client.clone().stream(&prompt).await?;
|
||||
|
||||
let mut output = Vec::new();
|
||||
@@ -1960,10 +1944,14 @@ async fn run_compact_task(
|
||||
|
||||
sess.remove_task(&sub_id);
|
||||
|
||||
{
|
||||
let rollout_item = {
|
||||
let mut state = sess.state.lock_unchecked();
|
||||
state.response_items.keep_last_messages(1);
|
||||
}
|
||||
state.history.keep_last_messages(1);
|
||||
RolloutItem::Compacted(CompactedItem {
|
||||
message: state.history.last_agent_message(),
|
||||
})
|
||||
};
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
@@ -2097,6 +2085,72 @@ async fn handle_response_item(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
async fn handle_unified_exec_tool_call(
|
||||
sess: &Session,
|
||||
call_id: String,
|
||||
session_id: Option<String>,
|
||||
arguments: Vec<String>,
|
||||
timeout_ms: Option<u64>,
|
||||
) -> ResponseInputItem {
|
||||
let parsed_session_id = if let Some(session_id) = session_id {
|
||||
match session_id.parse::<i32>() {
|
||||
Ok(parsed) => Some(parsed),
|
||||
Err(output) => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call_id.to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("invalid session_id: {session_id} due to error {output}"),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let request = crate::unified_exec::UnifiedExecRequest {
|
||||
session_id: parsed_session_id,
|
||||
input_chunks: &arguments,
|
||||
timeout_ms,
|
||||
};
|
||||
|
||||
let result = sess.unified_exec_manager.handle_request(request).await;
|
||||
|
||||
let output_payload = match result {
|
||||
Ok(value) => {
|
||||
#[derive(Serialize)]
|
||||
struct SerializedUnifiedExecResult<'a> {
|
||||
session_id: Option<String>,
|
||||
output: &'a str,
|
||||
}
|
||||
|
||||
match serde_json::to_string(&SerializedUnifiedExecResult {
|
||||
session_id: value.session_id.map(|id| id.to_string()),
|
||||
output: &value.output,
|
||||
}) {
|
||||
Ok(serialized) => FunctionCallOutputPayload {
|
||||
content: serialized,
|
||||
success: Some(true),
|
||||
},
|
||||
Err(err) => FunctionCallOutputPayload {
|
||||
content: format!("failed to serialize unified exec output: {err}"),
|
||||
success: Some(false),
|
||||
},
|
||||
}
|
||||
}
|
||||
Err(err) => FunctionCallOutputPayload {
|
||||
content: format!("unified exec failed: {err}"),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
|
||||
ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: output_payload,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_function_call(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
@@ -2124,6 +2178,38 @@ async fn handle_function_call(
|
||||
)
|
||||
.await
|
||||
}
|
||||
"unified_exec" => {
|
||||
#[derive(Deserialize)]
|
||||
struct UnifiedExecArgs {
|
||||
input: Vec<String>,
|
||||
#[serde(default)]
|
||||
session_id: Option<String>,
|
||||
#[serde(default)]
|
||||
timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
let args = match serde_json::from_str::<UnifiedExecArgs>(&arguments) {
|
||||
Ok(args) => args,
|
||||
Err(err) => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("failed to parse function arguments: {err}"),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
handle_unified_exec_tool_call(
|
||||
sess,
|
||||
call_id,
|
||||
args.session_id,
|
||||
args.input,
|
||||
args.timeout_ms,
|
||||
)
|
||||
.await
|
||||
}
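
For context, a minimal standalone sketch of the argument payload this arm parses. The struct mirrors the `UnifiedExecArgs` shape in the diff above; running it as its own binary with `serde`/`serde_json` as dependencies, and the sample commands, are assumptions for illustration only.

use serde::Deserialize;

// Mirror of the `UnifiedExecArgs` shape the dispatcher deserializes.
#[derive(Deserialize, Debug)]
struct UnifiedExecArgs {
    input: Vec<String>,
    #[serde(default)]
    session_id: Option<String>,
    #[serde(default)]
    timeout_ms: Option<u64>,
}

fn main() -> serde_json::Result<()> {
    // First call: no session_id, so the handler would open a new session.
    let first = r#"{ "input": ["bash", "-lc", "echo hello"] }"#;
    let args: UnifiedExecArgs = serde_json::from_str(first)?;
    assert!(args.session_id.is_none());

    // Follow-up call: reuses session "7"; the handler parses it as an i32.
    let next = r#"{ "input": ["echo again\n"], "session_id": "7", "timeout_ms": 5000 }"#;
    let args: UnifiedExecArgs = serde_json::from_str(next)?;
    assert_eq!(args.session_id.as_deref(), Some("7"));
    assert_eq!(args.timeout_ms, Some(5000));
    Ok(())
}

Note that `session_id` arrives as a string and is validated by `handle_unified_exec_tool_call`, which returns a failed `FunctionCallOutput` rather than erroring out when it does not parse.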
"view_image" => {
#[derive(serde::Deserialize)]
struct SeeImageArgs {
@@ -2352,25 +2438,13 @@ pub struct ExecInvokeArgs<'a> {
pub stdout_stream: Option<StdoutStream>,
}

fn should_translate_shell_command(
shell: &crate::shell::Shell,
shell_policy: &ShellEnvironmentPolicy,
) -> bool {
matches!(shell, crate::shell::Shell::PowerShell(_))
|| shell_policy.use_profile
|| matches!(
shell,
crate::shell::Shell::Posix(shell) if shell.shell_snapshot.is_some()
)
}

fn maybe_translate_shell_command(
params: ExecParams,
sess: &Session,
turn_context: &TurnContext,
) -> ExecParams {
let should_translate =
should_translate_shell_command(&sess.user_shell, &turn_context.shell_environment_policy);
let should_translate = matches!(sess.user_shell, crate::shell::Shell::PowerShell(_))
|| turn_context.shell_environment_policy.use_profile;

if should_translate
&& let Some(command) = sess
@@ -2614,6 +2688,20 @@ async fn handle_sandbox_error(
let sub_id = exec_command_context.sub_id.clone();
let cwd = exec_command_context.cwd.clone();

// if the command timed out, we can simply return this failure to the model
if matches!(error, SandboxErr::Timeout) {
return ResponseInputItem::FunctionCallOutput {
call_id,
output: FunctionCallOutputPayload {
content: format!(
"command timed out after {} milliseconds",
params.timeout_duration().as_millis()
),
success: Some(false),
},
};
}

// Early out if either the user never wants to be asked for approval, or
// we're letting the model manage escalation requests. Otherwise, continue
match turn_context.approval_policy {
@@ -2631,20 +2719,6 @@ async fn handle_sandbox_error(
AskForApproval::UnlessTrusted | AskForApproval::OnFailure => (),
}

// similarly, if the command timed out, we can simply return this failure to the model
if matches!(error, SandboxErr::Timeout) {
return ResponseInputItem::FunctionCallOutput {
call_id,
output: FunctionCallOutputPayload {
content: format!(
"command timed out after {} milliseconds",
params.timeout_duration().as_millis()
),
success: Some(false),
},
};
}

// Note that when `error` is `SandboxErr::Denied`, it could be a false
// positive. That is, it may have exited with a non-zero exit code, not
// because the sandbox denied it, but because that is its expected behavior,
@@ -2739,6 +2813,29 @@ async fn handle_sandbox_error(
}
}

async fn submit_user_input(
cwd: PathBuf,
approval_policy: AskForApproval,
sandbox_policy: SandboxPolicy,
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
sub_id: String,
items: Vec<InputItem>,
) {
sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new(
Some(cwd),
Some(approval_policy),
Some(sandbox_policy),
Some(sess.user_shell.clone()),
))])
.await;
if let Err(items) = sess.inject_input(items) {
// no current task, spawn a new one
let task = AgentTask::spawn(Arc::clone(sess), Arc::clone(turn_context), sub_id, items);
sess.set_task(task);
}
}

fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
let ExecToolCallOutput {
aggregated_output, ..
@@ -2903,6 +3000,15 @@ async fn drain_to_completed(
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<()> {
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
sandbox_policy: turn_context.sandbox_policy.clone(),
model: turn_context.client.get_model(),
effort: turn_context.client.get_reasoning_effort(),
summary: turn_context.client.get_reasoning_summary(),
});
sess.persist_rollout_items(&[rollout_item]).await;
let mut stream = turn_context.client.clone().stream(prompt).await?;
loop {
let maybe_event = stream.next().await;
@@ -2916,9 +3022,7 @@ async fn drain_to_completed(
Ok(ResponseEvent::OutputItemDone(item)) => {
// Record only to in-memory conversation history; avoid state snapshot.
let mut state = sess.state.lock_unchecked();
state
.response_items
.record_items(std::slice::from_ref(&item));
state.history.record_items(std::slice::from_ref(&item));
}
Ok(ResponseEvent::Completed {
response_id: _,
@@ -2989,15 +3093,10 @@ fn convert_call_tool_result_to_function_call_output_payload(
#[cfg(test)]
mod tests {
use super::*;
use crate::config_types::ShellEnvironmentPolicyInherit;
use mcp_types::ContentBlock;
use mcp_types::TextContent;
use pretty_assertions::assert_eq;
use serde_json::json;
use shell::ShellSnapshot;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration as StdDuration;

fn text_block(s: &str) -> ContentBlock {
@@ -3008,48 +3107,6 @@ mod tests {
})
}

fn shell_policy_with_profile(use_profile: bool) -> ShellEnvironmentPolicy {
ShellEnvironmentPolicy {
inherit: ShellEnvironmentPolicyInherit::All,
ignore_default_excludes: false,
exclude: Vec::new(),
r#set: HashMap::new(),
include_only: Vec::new(),
use_profile,
}
}

fn zsh_shell(shell_snapshot: Option<Arc<ShellSnapshot>>) -> shell::Shell {
shell::Shell::Posix(shell::PosixShell {
shell_path: "/bin/zsh".to_string(),
rc_path: "/Users/example/.zshrc".to_string(),
shell_snapshot,
})
}

#[test]
fn translates_commands_when_shell_policy_requests_profile() {
let policy = shell_policy_with_profile(true);
let shell = zsh_shell(None);
assert!(should_translate_shell_command(&shell, &policy));
}

#[test]
fn translates_commands_for_zsh_with_snapshot() {
let policy = shell_policy_with_profile(false);
let shell = zsh_shell(Some(Arc::new(ShellSnapshot::new(PathBuf::from(
"/tmp/snapshot",
)))));
assert!(should_translate_shell_command(&shell, &policy));
}

#[test]
fn bypasses_translation_for_zsh_without_snapshot_or_profile() {
let policy = shell_policy_with_profile(false);
let shell = zsh_shell(None);
assert!(!should_translate_shell_command(&shell, &policy));
}

#[test]
fn prefers_structured_content_when_present() {
let ctr = CallToolResult {
@@ -3085,7 +3142,7 @@ mod tests {
exit_code: 0,
stdout: StreamOutput::new(String::new()),
stderr: StreamOutput::new(String::new()),
aggregated_output: StreamOutput::new(full.clone()),
aggregated_output: StreamOutput::new(full),
duration: StdDuration::from_secs(1),
};

@@ -3119,7 +3176,7 @@ mod tests {
fn model_truncation_respects_byte_budget() {
// Construct a large output (about 100kB) so byte budget dominates
let big_line = "x".repeat(100);
let full = std::iter::repeat_n(big_line.clone(), 1000)
let full = std::iter::repeat_n(big_line, 1000)
.collect::<Vec<_>>()
.join("\n");


@@ -15,11 +15,11 @@ use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use anyhow::Context;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::config_types::Verbosity;
use codex_protocol::mcp_protocol::AuthMode;
use codex_protocol::mcp_protocol::Tools;
use codex_protocol::mcp_protocol::UserSavedConfig;
use dirs::home_dir;
@@ -32,13 +32,14 @@ use toml::Value as TomlValue;
use toml_edit::DocumentMut;

const OPENAI_DEFAULT_MODEL: &str = "gpt-5";
pub const GPT5_HIGH_MODEL: &str = "gpt-5-high";

/// Maximum number of bytes of the documentation that will be embedded. Larger
/// files are *silently truncated* to this size so we do not take up too much of
/// the context window.
pub(crate) const PROJECT_DOC_MAX_BYTES: usize = 32 * 1024; // 32 KiB

const CONFIG_TOML_FILE: &str = "config.toml";
pub(crate) const CONFIG_TOML_FILE: &str = "config.toml";

/// Application configuration loaded from disk and merged with overrides.
#[derive(Debug, Clone, PartialEq)]
@@ -129,9 +130,6 @@ pub struct Config {
/// output will be hyperlinked using the specified URI scheme.
pub file_opener: UriBasedFileOpener,

/// Collection of settings that are specific to the TUI.
pub tui: Tui,

/// Path to the `codex-linux-sandbox` executable. This must be set if
/// [`crate::exec::SandboxType::LinuxSeccomp`] is used. Note that this
/// cannot be set in the config file: it must be set in code via
@@ -167,13 +165,17 @@ pub struct Config {

pub tools_web_search_request: bool,

/// If set to `true`, the API key will be signed with the `originator` header.
pub preferred_auth_method: AuthMode,

pub use_experimental_streamable_shell_tool: bool,

/// If set to `true`, use only the experimental unified exec tool.
pub use_experimental_unified_exec_tool: bool,

/// Include the `view_image` tool that lets the agent attach a local image path to context.
pub include_view_image_tool: bool,

/// The active profile name used to derive this `Config` (if any).
pub active_profile: Option<String>,

/// When true, disables burst-paste detection for typed input entirely.
/// All characters are inserted as they are received, and no buffering
/// or placeholder replacement will occur for fast keypress bursts.
@@ -257,17 +259,7 @@ pub fn load_config_as_toml(codex_home: &Path) -> std::io::Result<TomlValue> {
}
}

/// Patch `CODEX_HOME/config.toml` project state.
/// Use with caution.
pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Result<()> {
let config_path = codex_home.join(CONFIG_TOML_FILE);
// Parse existing config if present; otherwise start a new document.
let mut doc = match std::fs::read_to_string(config_path.clone()) {
Ok(s) => s.parse::<DocumentMut>()?,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
Err(e) => return Err(e.into()),
};

fn set_project_trusted_inner(doc: &mut DocumentMut, project_path: &Path) -> anyhow::Result<()> {
// Ensure we render a human-friendly structure:
//
// [projects]
@@ -283,14 +275,26 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
// Ensure top-level `projects` exists as a non-inline, explicit table. If it
// exists but was previously represented as a non-table (e.g., inline),
// replace it with an explicit table.
let mut created_projects_table = false;
{
let root = doc.as_table_mut();
let needs_table = !root.contains_key("projects")
|| root.get("projects").and_then(|i| i.as_table()).is_none();
if needs_table {
root.insert("projects", toml_edit::table());
created_projects_table = true;
// If `projects` exists but isn't a standard table (e.g., it's an inline table),
// convert it to an explicit table while preserving existing entries.
let existing_projects = root.get("projects").cloned();
if existing_projects.as_ref().is_none_or(|i| !i.is_table()) {
let mut projects_tbl = toml_edit::Table::new();
projects_tbl.set_implicit(true);

// If there was an existing inline table, migrate its entries to explicit tables.
if let Some(inline_tbl) = existing_projects.as_ref().and_then(|i| i.as_inline_table()) {
for (k, v) in inline_tbl.iter() {
if let Some(inner_tbl) = v.as_inline_table() {
let new_tbl = inner_tbl.clone().into_table();
projects_tbl.insert(k, toml_edit::Item::Table(new_tbl));
}
}
}

root.insert("projects", toml_edit::Item::Table(projects_tbl));
}
}
let Some(projects_tbl) = doc["projects"].as_table_mut() else {
@@ -299,12 +303,6 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
));
};

// If we created the `projects` table ourselves, keep it implicit so we
// don't render a standalone `[projects]` header.
if created_projects_table {
projects_tbl.set_implicit(true);
}

// Ensure the per-project entry is its own explicit table. If it exists but
// is not a table (e.g., an inline table), replace it with an explicit table.
let needs_proj_table = !projects_tbl.contains_key(project_key.as_str())
@@ -323,6 +321,21 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
};
proj_tbl.set_implicit(false);
proj_tbl["trust_level"] = toml_edit::value("trusted");
Ok(())
}

/// Patch `CODEX_HOME/config.toml` project state.
/// Use with caution.
pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Result<()> {
let config_path = codex_home.join(CONFIG_TOML_FILE);
// Parse existing config if present; otherwise start a new document.
let mut doc = match std::fs::read_to_string(config_path.clone()) {
Ok(s) => s.parse::<DocumentMut>()?,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
Err(e) => return Err(e.into()),
};

set_project_trusted_inner(&mut doc, project_path)?;

// ensure codex_home exists
std::fs::create_dir_all(codex_home)?;
@@ -337,6 +350,107 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
Ok(())
}
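
A minimal sketch (not the crate's code) of the document shape `set_project_trusted_inner` produces, assuming current `toml_edit` semantics: an implicit `projects` table so no bare `[projects]` header is rendered, plus an explicit per-project table carrying `trust_level = "trusted"`. The `/some/path` key comes from the tests later in this diff.

use toml_edit::{table, value, DocumentMut};

fn main() {
    let mut doc = DocumentMut::new();
    // Implicit parent: only child table headers are rendered.
    doc["projects"] = table();
    doc["projects"]
        .as_table_mut()
        .expect("projects is a table")
        .set_implicit(true);
    // Explicit per-project table, as the refactored helper writes it.
    doc["projects"]["/some/path"] = table();
    doc["projects"]["/some/path"]["trust_level"] = value("trusted");

    let expected = "[projects.\"/some/path\"]\ntrust_level = \"trusted\"\n";
    assert_eq!(doc.to_string(), expected);
}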
fn ensure_profile_table<'a>(
doc: &'a mut DocumentMut,
profile_name: &str,
) -> anyhow::Result<&'a mut toml_edit::Table> {
let mut created_profiles_table = false;
{
let root = doc.as_table_mut();
let needs_table = !root.contains_key("profiles")
|| root
.get("profiles")
.and_then(|item| item.as_table())
.is_none();
if needs_table {
root.insert("profiles", toml_edit::table());
created_profiles_table = true;
}
}

let Some(profiles_table) = doc["profiles"].as_table_mut() else {
return Err(anyhow::anyhow!(
"profiles table missing after initialization"
));
};

if created_profiles_table {
profiles_table.set_implicit(true);
}

let needs_profile_table = !profiles_table.contains_key(profile_name)
|| profiles_table
.get(profile_name)
.and_then(|item| item.as_table())
.is_none();
if needs_profile_table {
profiles_table.insert(profile_name, toml_edit::table());
}

let Some(profile_table) = profiles_table
.get_mut(profile_name)
.and_then(|item| item.as_table_mut())
else {
return Err(anyhow::anyhow!(format!(
"profile table missing for {profile_name}"
)));
};

profile_table.set_implicit(false);
Ok(profile_table)
}

// TODO(jif) refactor config persistence.
pub async fn persist_model_selection(
codex_home: &Path,
active_profile: Option<&str>,
model: &str,
effort: Option<ReasoningEffort>,
) -> anyhow::Result<()> {
let config_path = codex_home.join(CONFIG_TOML_FILE);
let serialized = match tokio::fs::read_to_string(&config_path).await {
Ok(contents) => contents,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => String::new(),
Err(err) => return Err(err.into()),
};

let mut doc = if serialized.is_empty() {
DocumentMut::new()
} else {
serialized.parse::<DocumentMut>()?
};

if let Some(profile_name) = active_profile {
let profile_table = ensure_profile_table(&mut doc, profile_name)?;
profile_table["model"] = toml_edit::value(model);
if let Some(effort) = effort {
profile_table["model_reasoning_effort"] = toml_edit::value(effort.to_string());
}
} else {
let table = doc.as_table_mut();
table["model"] = toml_edit::value(model);
if let Some(effort) = effort {
table["model_reasoning_effort"] = toml_edit::value(effort.to_string());
}
}

// TODO(jif) refactor the home creation
tokio::fs::create_dir_all(codex_home)
.await
.with_context(|| {
format!(
"failed to create Codex home directory at {}",
codex_home.display()
)
})?;

tokio::fs::write(&config_path, doc.to_string())
.await
.with_context(|| format!("failed to persist config.toml at {}", config_path.display()))?;

Ok(())
}
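
A usage sketch of `persist_model_selection`, mirroring the `persist_model_selection_updates_profile` test later in this diff; the profile name "dev" and the model/effort values are just the test's sample inputs.

// Hypothetical caller inside the same module.
async fn example(codex_home: &std::path::Path) -> anyhow::Result<()> {
    // With an active profile, the update lands under [profiles.dev] ...
    persist_model_selection(codex_home, Some("dev"), "gpt-5-high-new", Some(ReasoningEffort::Low))
        .await?;
    // ... and with no profile it lands at the top level of config.toml.
    persist_model_selection(codex_home, None, "gpt-5-high-new", Some(ReasoningEffort::High))
        .await?;
    Ok(())
}

Existing keys elsewhere in the document are left untouched, since the function edits a parsed `DocumentMut` rather than rewriting the file from scratch.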
/// Apply a single dotted-path override onto a TOML value.
fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) {
use toml::value::Table;
@@ -381,7 +495,7 @@ fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) {
}

/// Base config deserialized from ~/.codex/config.toml.
#[derive(Deserialize, Debug, Clone, Default)]
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
pub struct ConfigToml {
/// Optional override of model selection.
pub model: Option<String>,
@@ -472,12 +586,10 @@ pub struct ConfigToml {
pub experimental_instructions_file: Option<PathBuf>,

pub experimental_use_exec_command_tool: Option<bool>,
pub experimental_use_unified_exec_tool: Option<bool>,

pub projects: Option<HashMap<String, ProjectConfig>>,

/// If set to `true`, the API key will be signed with the `originator` header.
pub preferred_auth_method: Option<AuthMode>,

/// Nested tools section for feature toggles
pub tools: Option<ToolsToml>,

@@ -515,7 +627,7 @@ pub struct ProjectConfig {
pub trust_level: Option<String>,
}

#[derive(Deserialize, Debug, Clone, Default)]
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
pub struct ToolsToml {
#[serde(default, alias = "web_search_request")]
pub web_search: Option<bool>,
@@ -653,7 +765,11 @@ impl Config {
tools_web_search_request: override_tools_web_search_request,
} = overrides;

let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) {
let active_profile_name = config_profile_key
.as_ref()
.or(cfg.profile.as_ref())
.cloned();
let config_profile = match active_profile_name.as_ref() {
Some(key) => cfg
.profiles
.get(key)
@@ -788,7 +904,6 @@ impl Config {
codex_home,
history,
file_opener: cfg.file_opener.unwrap_or(UriBasedFileOpener::VsCode),
tui: cfg.tui.unwrap_or_default(),
codex_linux_sandbox_exe,

hide_agent_reasoning: cfg.hide_agent_reasoning.unwrap_or(false),
@@ -814,11 +929,14 @@ impl Config {
include_plan_tool: include_plan_tool.unwrap_or(false),
include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
tools_web_search_request,
preferred_auth_method: cfg.preferred_auth_method.unwrap_or(AuthMode::ChatGPT),
use_experimental_streamable_shell_tool: cfg
.experimental_use_exec_command_tool
.unwrap_or(false),
use_experimental_unified_exec_tool: cfg
.experimental_use_unified_exec_tool
.unwrap_or(false),
include_view_image_tool,
active_profile: active_profile_name,
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
};
Ok(config)
@@ -929,6 +1047,7 @@ mod tests {

use super::*;
use pretty_assertions::assert_eq;

use tempfile::TempDir;

#[test]
@@ -1019,6 +1138,145 @@ exclude_slash_tmp = true
);
}

#[tokio::test]
async fn persist_model_selection_updates_defaults() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;

persist_model_selection(
codex_home.path(),
None,
"gpt-5-high-new",
Some(ReasoningEffort::High),
)
.await?;

let serialized =
tokio::fs::read_to_string(codex_home.path().join(CONFIG_TOML_FILE)).await?;
let parsed: ConfigToml = toml::from_str(&serialized)?;

assert_eq!(parsed.model.as_deref(), Some("gpt-5-high-new"));
assert_eq!(parsed.model_reasoning_effort, Some(ReasoningEffort::High));

Ok(())
}

#[tokio::test]
async fn persist_model_selection_overwrites_existing_model() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;
let config_path = codex_home.path().join(CONFIG_TOML_FILE);

tokio::fs::write(
&config_path,
r#"
model = "gpt-5"
model_reasoning_effort = "medium"

[profiles.dev]
model = "gpt-4.1"
"#,
)
.await?;

persist_model_selection(
codex_home.path(),
None,
"o4-mini",
Some(ReasoningEffort::High),
)
.await?;

let serialized = tokio::fs::read_to_string(config_path).await?;
let parsed: ConfigToml = toml::from_str(&serialized)?;

assert_eq!(parsed.model.as_deref(), Some("o4-mini"));
assert_eq!(parsed.model_reasoning_effort, Some(ReasoningEffort::High));
assert_eq!(
parsed
.profiles
.get("dev")
.and_then(|profile| profile.model.as_deref()),
Some("gpt-4.1"),
);

Ok(())
}

#[tokio::test]
async fn persist_model_selection_updates_profile() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;

persist_model_selection(
codex_home.path(),
Some("dev"),
"gpt-5-high-new",
Some(ReasoningEffort::Low),
)
.await?;

let serialized =
tokio::fs::read_to_string(codex_home.path().join(CONFIG_TOML_FILE)).await?;
let parsed: ConfigToml = toml::from_str(&serialized)?;
let profile = parsed
.profiles
.get("dev")
.expect("profile should be created");

assert_eq!(profile.model.as_deref(), Some("gpt-5-high-new"));
assert_eq!(profile.model_reasoning_effort, Some(ReasoningEffort::Low));

Ok(())
}

#[tokio::test]
async fn persist_model_selection_updates_existing_profile() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;
let config_path = codex_home.path().join(CONFIG_TOML_FILE);

tokio::fs::write(
&config_path,
r#"
[profiles.dev]
model = "gpt-4"
model_reasoning_effort = "medium"

[profiles.prod]
model = "gpt-5"
"#,
)
.await?;

persist_model_selection(
codex_home.path(),
Some("dev"),
"o4-high",
Some(ReasoningEffort::Medium),
)
.await?;

let serialized = tokio::fs::read_to_string(config_path).await?;
let parsed: ConfigToml = toml::from_str(&serialized)?;

let dev_profile = parsed
.profiles
.get("dev")
.expect("dev profile should survive updates");
assert_eq!(dev_profile.model.as_deref(), Some("o4-high"));
assert_eq!(
dev_profile.model_reasoning_effort,
Some(ReasoningEffort::Medium)
);

assert_eq!(
parsed
.profiles
.get("prod")
.and_then(|profile| profile.model.as_deref()),
Some("gpt-5"),
);

Ok(())
}

struct PrecedenceTestFixture {
cwd: TempDir,
codex_home: TempDir,
@@ -1177,7 +1435,6 @@ model_verbosity = "high"
codex_home: fixture.codex_home(),
history: History::default(),
file_opener: UriBasedFileOpener::VsCode,
tui: Tui::default(),
codex_linux_sandbox_exe: None,
hide_agent_reasoning: false,
show_raw_agent_reasoning: false,
@@ -1190,9 +1447,10 @@ model_verbosity = "high"
include_plan_tool: false,
include_apply_patch_tool: false,
tools_web_search_request: false,
preferred_auth_method: AuthMode::ChatGPT,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
include_view_image_tool: true,
active_profile: Some("o3".to_string()),
disable_paste_burst: false,
},
o3_profile_config
@@ -1233,7 +1491,6 @@ model_verbosity = "high"
codex_home: fixture.codex_home(),
history: History::default(),
file_opener: UriBasedFileOpener::VsCode,
tui: Tui::default(),
codex_linux_sandbox_exe: None,
hide_agent_reasoning: false,
show_raw_agent_reasoning: false,
@@ -1246,9 +1503,10 @@ model_verbosity = "high"
include_plan_tool: false,
include_apply_patch_tool: false,
tools_web_search_request: false,
preferred_auth_method: AuthMode::ChatGPT,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
include_view_image_tool: true,
active_profile: Some("gpt3".to_string()),
disable_paste_burst: false,
};

@@ -1304,7 +1562,6 @@ model_verbosity = "high"
codex_home: fixture.codex_home(),
history: History::default(),
file_opener: UriBasedFileOpener::VsCode,
tui: Tui::default(),
codex_linux_sandbox_exe: None,
hide_agent_reasoning: false,
show_raw_agent_reasoning: false,
@@ -1317,9 +1574,10 @@ model_verbosity = "high"
include_plan_tool: false,
include_apply_patch_tool: false,
tools_web_search_request: false,
preferred_auth_method: AuthMode::ChatGPT,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
include_view_image_tool: true,
active_profile: Some("zdr".to_string()),
disable_paste_burst: false,
};

@@ -1361,7 +1619,6 @@ model_verbosity = "high"
codex_home: fixture.codex_home(),
history: History::default(),
file_opener: UriBasedFileOpener::VsCode,
tui: Tui::default(),
codex_linux_sandbox_exe: None,
hide_agent_reasoning: false,
show_raw_agent_reasoning: false,
@@ -1374,9 +1631,10 @@ model_verbosity = "high"
include_plan_tool: false,
include_apply_patch_tool: false,
tools_web_search_request: false,
preferred_auth_method: AuthMode::ChatGPT,
use_experimental_streamable_shell_tool: false,
use_experimental_unified_exec_tool: false,
include_view_image_tool: true,
active_profile: Some("gpt5".to_string()),
disable_paste_burst: false,
};

@@ -1387,17 +1645,14 @@ model_verbosity = "high"

#[test]
fn test_set_project_trusted_writes_explicit_tables() -> anyhow::Result<()> {
let codex_home = TempDir::new().unwrap();
let project_dir = TempDir::new().unwrap();
let project_dir = Path::new("/some/path");
let mut doc = DocumentMut::new();

// Call the function under test
set_project_trusted(codex_home.path(), project_dir.path())?;
set_project_trusted_inner(&mut doc, project_dir)?;

// Read back the generated config.toml and assert exact contents
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
let contents = std::fs::read_to_string(&config_path)?;
let contents = doc.to_string();

let raw_path = project_dir.path().to_string_lossy();
let raw_path = project_dir.to_string_lossy();
let path_str = if raw_path.contains('\\') {
format!("'{raw_path}'")
} else {
@@ -1415,12 +1670,10 @@ trust_level = "trusted"

#[test]
fn test_set_project_trusted_converts_inline_to_explicit() -> anyhow::Result<()> {
let codex_home = TempDir::new().unwrap();
let project_dir = TempDir::new().unwrap();
let project_dir = Path::new("/some/path");

// Seed config.toml with an inline project entry under [projects]
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
let raw_path = project_dir.path().to_string_lossy();
let raw_path = project_dir.to_string_lossy();
let path_str = if raw_path.contains('\\') {
format!("'{raw_path}'")
} else {
@@ -1432,13 +1685,12 @@ trust_level = "trusted"
{path_str} = {{ trust_level = "untrusted" }}
"#
);
std::fs::create_dir_all(codex_home.path())?;
std::fs::write(&config_path, initial)?;
let mut doc = initial.parse::<DocumentMut>()?;

// Run the function; it should convert to explicit tables and set trusted
set_project_trusted(codex_home.path(), project_dir.path())?;
set_project_trusted_inner(&mut doc, project_dir)?;

let contents = std::fs::read_to_string(&config_path)?;
let contents = doc.to_string();

// Assert exact output after conversion to explicit table
let expected = format!(
@@ -1453,5 +1705,37 @@ trust_level = "trusted"
Ok(())
}

// No test enforcing the presence of a standalone [projects] header.
#[test]
fn test_set_project_trusted_migrates_top_level_inline_projects_preserving_entries()
-> anyhow::Result<()> {
let initial = r#"toplevel = "baz"
projects = { "/Users/mbolin/code/codex4" = { trust_level = "trusted", foo = "bar" } , "/Users/mbolin/code/codex3" = { trust_level = "trusted" } }
model = "foo""#;
let mut doc = initial.parse::<DocumentMut>()?;

// Approve a new directory
let new_project = Path::new("/Users/mbolin/code/codex2");
set_project_trusted_inner(&mut doc, new_project)?;

let contents = doc.to_string();

// Since we created the [projects] table as part of migration, it is kept implicit.
// Expect explicit per-project tables, preserving prior entries and appending the new one.
let expected = r#"toplevel = "baz"
model = "foo"

[projects."/Users/mbolin/code/codex4"]
trust_level = "trusted"
foo = "bar"

[projects."/Users/mbolin/code/codex3"]
trust_level = "trusted"

[projects."/Users/mbolin/code/codex2"]
trust_level = "trusted"
"#;
assert_eq!(contents, expected);

Ok(())
}
}

582
codex-rs/core/src/config_edit.rs
Normal file
@@ -0,0 +1,582 @@
use crate::config::CONFIG_TOML_FILE;
use anyhow::Result;
use std::path::Path;
use tempfile::NamedTempFile;
use toml_edit::DocumentMut;

pub const CONFIG_KEY_MODEL: &str = "model";
pub const CONFIG_KEY_EFFORT: &str = "model_reasoning_effort";

/// Persist overrides into `config.toml` using explicit key segments per
/// override. This avoids ambiguity with keys that contain dots or spaces.
pub async fn persist_overrides(
codex_home: &Path,
profile: Option<&str>,
overrides: &[(&[&str], &str)],
) -> Result<()> {
let config_path = codex_home.join(CONFIG_TOML_FILE);

let mut doc = match tokio::fs::read_to_string(&config_path).await {
Ok(s) => s.parse::<DocumentMut>()?,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
tokio::fs::create_dir_all(codex_home).await?;
DocumentMut::new()
}
Err(e) => return Err(e.into()),
};

let effective_profile = if let Some(p) = profile {
Some(p.to_owned())
} else {
doc.get("profile")
.and_then(|i| i.as_str())
.map(|s| s.to_string())
};

for (segments, val) in overrides.iter().copied() {
let value = toml_edit::value(val);
if let Some(ref name) = effective_profile {
if segments.first().copied() == Some("profiles") {
apply_toml_edit_override_segments(&mut doc, segments, value);
} else {
let mut seg_buf: Vec<&str> = Vec::with_capacity(2 + segments.len());
seg_buf.push("profiles");
seg_buf.push(name.as_str());
seg_buf.extend_from_slice(segments);
apply_toml_edit_override_segments(&mut doc, &seg_buf, value);
}
} else {
apply_toml_edit_override_segments(&mut doc, segments, value);
}
}

let tmp_file = NamedTempFile::new_in(codex_home)?;
tokio::fs::write(tmp_file.path(), doc.to_string()).await?;
tmp_file.persist(config_path)?;

Ok(())
}

/// Persist overrides where values may be optional. Any entries with `None`
/// values are skipped. If all values are `None`, this becomes a no-op and
/// returns `Ok(())` without touching the file.
pub async fn persist_non_null_overrides(
codex_home: &Path,
profile: Option<&str>,
overrides: &[(&[&str], Option<&str>)],
) -> Result<()> {
let filtered: Vec<(&[&str], &str)> = overrides
.iter()
.filter_map(|(k, v)| v.map(|vv| (*k, vv)))
.collect();

if filtered.is_empty() {
return Ok(());
}

persist_overrides(codex_home, profile, &filtered).await
}
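
A sketch of why overrides take explicit key segments rather than dotted strings: a path such as "profiles.my.team name.model" cannot be split unambiguously, while a segment slice can. The values mirror the `set_defaults_update_profile_with_dot_and_space` test below.

// Hypothetical caller inside the same crate.
async fn example(codex_home: &std::path::Path) -> anyhow::Result<()> {
    persist_overrides(
        codex_home,
        Some("my.team name"), // profile name containing a dot and a space
        &[
            (&[CONFIG_KEY_MODEL], "o3"),
            (&[CONFIG_KEY_EFFORT], "minimal"),
        ],
    )
    .await?;
    // config.toml then renders as:
    //   [profiles."my.team name"]
    //   model = "o3"
    //   model_reasoning_effort = "minimal"
    Ok(())
}

The write goes through a `NamedTempFile` in the same directory followed by `persist`, so a crash mid-write cannot leave a half-written config.toml behind.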
|
||||
|
||||
/// Apply a single override onto a `toml_edit` document while preserving
|
||||
/// existing formatting/comments.
|
||||
/// The key is expressed as explicit segments to correctly handle keys that
|
||||
/// contain dots or spaces.
|
||||
fn apply_toml_edit_override_segments(
|
||||
doc: &mut DocumentMut,
|
||||
segments: &[&str],
|
||||
value: toml_edit::Item,
|
||||
) {
|
||||
use toml_edit::Item;
|
||||
|
||||
if segments.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut current = doc.as_table_mut();
|
||||
for seg in &segments[..segments.len() - 1] {
|
||||
if !current.contains_key(seg) {
|
||||
current[*seg] = Item::Table(toml_edit::Table::new());
|
||||
if let Some(t) = current[*seg].as_table_mut() {
|
||||
t.set_implicit(true);
|
||||
}
|
||||
}
|
||||
|
||||
let maybe_item = current.get_mut(seg);
|
||||
let Some(item) = maybe_item else { return };
|
||||
|
||||
if !item.is_table() {
|
||||
*item = Item::Table(toml_edit::Table::new());
|
||||
if let Some(t) = item.as_table_mut() {
|
||||
t.set_implicit(true);
|
||||
}
|
||||
}
|
||||
|
||||
let Some(tbl) = item.as_table_mut() else {
|
||||
return;
|
||||
};
|
||||
current = tbl;
|
||||
}
|
||||
|
||||
let last = segments[segments.len() - 1];
|
||||
current[last] = value;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// Verifies model and effort are written at top-level when no profile is set.
|
||||
#[tokio::test]
|
||||
async fn set_default_model_and_effort_top_level_when_no_profile() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "gpt-5"),
|
||||
(&[CONFIG_KEY_EFFORT], "high"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model = "gpt-5"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies values are written under the active profile when `profile` is set.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_profile_when_profile_set() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed config with a profile selection but without profiles table
|
||||
let seed = "profile = \"o3\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "o3"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "o3"
|
||||
|
||||
[profiles.o3]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies profile names with dots/spaces are preserved via explicit segments.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_profile_with_dot_and_space() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed config with a profile name that contains a dot and a space
|
||||
let seed = "profile = \"my.team name\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "o3"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "my.team name"
|
||||
|
||||
[profiles."my.team name"]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies explicit profile override writes under that profile even without active profile.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_when_profile_override_supplied() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// No profile key in config.toml
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), "")
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Persist with an explicit profile override
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
Some("o3"),
|
||||
&[(&[CONFIG_KEY_MODEL], "o3"), (&[CONFIG_KEY_EFFORT], "high")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.o3]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies nested tables are created as needed when applying overrides.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_creates_nested_tables() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&["a", "b", "c"], "v"),
|
||||
(&["x"], "y"),
|
||||
(&["profiles", "p1", CONFIG_KEY_MODEL], "gpt-5"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"x = "y"
|
||||
|
||||
[a.b]
|
||||
c = "v"
|
||||
|
||||
[profiles.p1]
|
||||
model = "gpt-5"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies a scalar key becomes a table when nested keys are written.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_replaces_scalar_with_table() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
let seed = "foo = \"bar\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(codex_home, None, &[(&["foo", "bar", "baz"], "ok")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[foo.bar]
|
||||
baz = "ok"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies comments and spacing are preserved when writing under active profile.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_preserve_comments() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed a config with comments and spacing we expect to preserve
|
||||
let seed = r#"# Global comment
|
||||
# Another line
|
||||
|
||||
profile = "o3"
|
||||
|
||||
# Profile settings
|
||||
[profiles.o3]
|
||||
# keep me
|
||||
existing = "keep"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Apply defaults; since profile is set, it should write under [profiles.o3]
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[(&[CONFIG_KEY_MODEL], "o3"), (&[CONFIG_KEY_EFFORT], "high")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"# Global comment
|
||||
# Another line
|
||||
|
||||
profile = "o3"
|
||||
|
||||
# Profile settings
|
||||
[profiles.o3]
|
||||
# keep me
|
||||
existing = "keep"
|
||||
model = "o3"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies comments and spacing are preserved when writing at top level.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_preserve_global_comments() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed a config WITHOUT a profile, containing comments and spacing
|
||||
let seed = r#"# Top-level comments
|
||||
# should be preserved
|
||||
|
||||
existing = "keep"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Since there is no profile, the defaults should be written at top-level
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "gpt-5"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"# Top-level comments
|
||||
# should be preserved
|
||||
|
||||
existing = "keep"
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies errors on invalid TOML propagate and file is not clobbered.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_errors_on_parse_failure() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Write an intentionally invalid TOML file
|
||||
let invalid = "invalid = [unclosed";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), invalid)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Attempting to persist should return an error and must not clobber the file.
|
||||
let res = persist_overrides(codex_home, None, &[(&["x"], "y")]).await;
|
||||
assert!(res.is_err(), "expected parse error to propagate");
|
||||
|
||||
// File should be unchanged
|
||||
let contents = read_config(codex_home).await;
|
||||
assert_eq!(contents, invalid);
|
||||
}
|
||||
|
||||
/// Verifies changing model only preserves existing effort at top-level.
|
||||
#[tokio::test]
|
||||
async fn changing_only_model_preserves_existing_effort_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with an effort value only
|
||||
let seed = "model_reasoning_effort = \"minimal\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Change only the model
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_MODEL], "o3")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model_reasoning_effort = "minimal"
|
||||
model = "o3"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing effort only preserves existing model at top-level.
|
||||
#[tokio::test]
|
||||
async fn changing_only_effort_preserves_existing_model_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with a model value only
|
||||
let seed = "model = \"gpt-5\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Change only the effort
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_EFFORT], "high")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model = "gpt-5"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing model only preserves existing effort in active profile.
|
||||
#[tokio::test]
|
||||
async fn changing_only_model_preserves_effort_in_active_profile() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with an active profile and an existing effort under that profile
|
||||
let seed = r#"profile = "p1"
|
||||
|
||||
[profiles.p1]
|
||||
model_reasoning_effort = "low"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_MODEL], "o4-mini")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "p1"
|
||||
|
||||
[profiles.p1]
|
||||
model_reasoning_effort = "low"
|
||||
model = "o4-mini"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing effort only preserves existing model in a profile override.
|
||||
#[tokio::test]
|
||||
async fn changing_only_effort_preserves_model_in_profile_override() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// No active profile key; we'll target an explicit override
|
||||
let seed = r#"[profiles.team]
|
||||
model = "gpt-5"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
Some("team"),
|
||||
&[(&[CONFIG_KEY_EFFORT], "minimal")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.team]
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies `persist_non_null_overrides` skips `None` entries and writes only present values at top-level.
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_skips_none_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], Some("gpt-5")),
|
||||
(&[CONFIG_KEY_EFFORT], None),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = "model = \"gpt-5\"\n";
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies no-op behavior when all provided overrides are `None` (no file created/modified).
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_noop_when_all_none() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[(&["a"], None), (&["profiles", "p", "x"], None)],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
// Should not create config.toml on a pure no-op
|
||||
assert!(!codex_home.join(CONFIG_TOML_FILE).exists());
|
||||
}
|
||||
|
||||
/// Verifies entries are written under the specified profile and `None` entries are skipped.
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_respects_profile_override() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
Some("team"),
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], Some("o3")),
|
||||
(&[CONFIG_KEY_EFFORT], None),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.team]
|
||||
model = "o3"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
// Test helper moved to bottom per review guidance.
|
||||
async fn read_config(codex_home: &Path) -> String {
|
||||
let p = codex_home.join(CONFIG_TOML_FILE);
|
||||
tokio::fs::read_to_string(p).await.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,14 @@
-use crate::rollout::policy::should_persist_event_msg;
+use codex_protocol::models::ContentItem;
 use codex_protocol::models::ResponseItem;
-use codex_protocol::protocol::EventMsg;
-use codex_protocol::protocol::RolloutItem;
 
 /// Transcript of conversation history
 #[derive(Debug, Clone, Default)]
-pub(crate) struct ResponseItemsHistory {
+pub(crate) struct ConversationHistory {
     /// The oldest items are at the beginning of the vector.
     items: Vec<ResponseItem>,
 }
 
-impl ResponseItemsHistory {
+impl ConversationHistory {
     pub(crate) fn new() -> Self {
         Self { items: Vec::new() }
     }
@@ -62,50 +60,25 @@ impl ResponseItemsHistory {
         kept.reverse();
         self.items = kept;
     }
-}
 
-#[derive(Debug, Clone, Default)]
-pub(crate) struct EventMsgsHistory {
-    items: Vec<EventMsg>,
-}
-
-impl EventMsgsHistory {
-    pub(crate) fn record_items<I>(&mut self, items: I)
-    where
-        I: IntoIterator,
-        I::Item: std::ops::Deref<Target = EventMsg>,
-    {
-        for item in items {
-            if self.should_record_item(&item) {
-                self.items.push(item.clone());
+    pub(crate) fn last_agent_message(&self) -> String {
+        for item in self.items.iter().rev() {
+            if let ResponseItem::Message { role, content, .. } = item
+                && role == "assistant"
+            {
+                return content
+                    .iter()
+                    .find_map(|ci| {
+                        if let ContentItem::OutputText { text } = ci {
+                            Some(text.clone())
+                        } else {
+                            None
+                        }
+                    })
+                    .unwrap_or_default();
             }
         }
-    }
-
-    fn should_record_item(&self, item: &EventMsg) -> bool {
-        should_persist_event_msg(item)
-    }
-}
-
-impl From<&ResponseItemsHistory> for Vec<RolloutItem> {
-    fn from(history: &ResponseItemsHistory) -> Self {
-        history
-            .items
-            .iter()
-            .cloned()
-            .map(RolloutItem::ResponseItem)
-            .collect()
-    }
-}
-
-impl From<&EventMsgsHistory> for Vec<RolloutItem> {
-    fn from(history: &EventMsgsHistory) -> Self {
-        history
-            .items
-            .iter()
-            .cloned()
-            .map(RolloutItem::EventMsg)
-            .collect()
+        String::new()
     }
 }
@@ -151,7 +124,7 @@ mod tests {
 
     #[test]
     fn filters_non_api_messages() {
-        let mut h = ResponseItemsHistory::default();
+        let mut h = ConversationHistory::default();
        // System message is not an API message; Other is ignored.
        let system = ResponseItem::Message {
            id: None,

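// --- illustrative sketch, not part of the diff ---
// How the new `last_agent_message` resolves its answer: scan the transcript
// from newest to oldest, take the most recent assistant message, and return
// its first output-text payload (empty string if none). `Msg` and `Item` are
// stand-in types, not the real protocol structs.
enum Item {
    OutputText(String),
    Other,
}
struct Msg {
    role: &'static str,
    content: Vec<Item>,
}

fn last_agent_message(items: &[Msg]) -> String {
    items
        .iter()
        .rev()
        .find(|m| m.role == "assistant")
        .and_then(|m| {
            m.content.iter().find_map(|c| match c {
                Item::OutputText(text) => Some(text.clone()),
                Item::Other => None,
            })
        })
        .unwrap_or_default()
}
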
@@ -150,13 +150,13 @@ impl ConversationManager {
     /// caller's `config`). The new conversation will have a fresh id.
     pub async fn fork_conversation(
         &self,
-        conversation_history: InitialHistory,
         num_messages_to_drop: usize,
         config: Config,
+        path: PathBuf,
     ) -> CodexResult<NewConversation> {
         // Compute the prefix up to the cut point.
-        let history =
-            truncate_after_dropping_last_messages(conversation_history, num_messages_to_drop);
+        let history = RolloutRecorder::get_rollout_history(&path).await?;
+        let history = truncate_after_dropping_last_messages(history, num_messages_to_drop);
 
         // Spawn a new conversation with the computed initial history.
         let auth_manager = self.auth_manager.clone();
@@ -173,18 +173,31 @@ impl ConversationManager {
 /// and all items that follow them.
 fn truncate_after_dropping_last_messages(history: InitialHistory, n: usize) -> InitialHistory {
     if n == 0 {
-        return history;
+        return InitialHistory::Forked(history.get_rollout_items());
     }
 
-    // Compute event prefix by dropping the last `n` user events (counted from the end).
-    let event_msgs_prefix: Vec<EventMsg> =
-        build_event_prefix_excluding_last_n_user_turns(&history, n);
+    // Work directly on rollout items, and cut the vector at the nth-from-last user message input.
+    let items: Vec<RolloutItem> = history.get_rollout_items();
 
-    // Keep only response items strictly before the cut (drop last `n` user messages).
-    let response_prefix: Vec<ResponseItem> =
-        build_response_prefix_excluding_last_n_user_turns(&history, n);
+    // Find indices of user message inputs in rollout order.
+    let mut user_positions: Vec<usize> = Vec::new();
+    for (idx, item) in items.iter().enumerate() {
+        if let RolloutItem::ResponseItem(ResponseItem::Message { role, .. }) = item
+            && role == "user"
+        {
+            user_positions.push(idx);
+        }
+    }
+
+    // If fewer than n user messages exist, treat as empty.
+    if user_positions.len() < n {
+        return InitialHistory::New;
+    }
+
+    // Cut strictly before the nth-from-last user message (do not keep the nth itself).
+    let cut_idx = user_positions[user_positions.len() - n];
+    let rolled: Vec<RolloutItem> = items.into_iter().take(cut_idx).collect();
 
-    let rolled = build_truncated_rollout(&event_msgs_prefix, &response_prefix);
     if rolled.is_empty() {
         InitialHistory::New
     } else {
@@ -192,96 +205,9 @@ fn truncate_after_dropping_last_messages(history: InitialHistory, n: usize) -> I
     }
 }
 
-/// Build the event messages prefix from `history` by dropping the last `n` user
-/// turns (counted from the end) and taking everything before that cut.
-fn build_event_prefix_excluding_last_n_user_turns(
-    history: &InitialHistory,
-    n: usize,
-) -> Vec<EventMsg> {
-    match history.get_event_msgs() {
-        Some(all_events) => {
-            take_prefix_before_index(&all_events, find_cut_event_index(&all_events, n))
-        }
-        None => Vec::new(),
-    }
-}
-
-/// Build the response items prefix from `history` by dropping the last `n` user
-/// turns (counted from the end) and taking everything before that cut.
-fn build_response_prefix_excluding_last_n_user_turns(
-    history: &InitialHistory,
-    n: usize,
-) -> Vec<ResponseItem> {
-    let all_items: Vec<ResponseItem> = history.get_response_items();
-    take_prefix_before_index(&all_items, find_cut_response_index(&all_items, n))
-}
-
-/// Return a cloned prefix of `items` up to (but not including) `idx`.
-/// If `idx` is `None`, returns an empty vector.
-fn take_prefix_before_index<T: Clone>(items: &[T], idx: Option<usize>) -> Vec<T> {
-    match idx {
-        Some(i) => items[..i].to_vec(),
-        None => Vec::new(),
-    }
-}
-
-/// Find the index (into response items) of the Nth user message from the end.
-fn find_cut_response_index(response_items: &[ResponseItem], n: usize) -> Option<usize> {
-    if n == 0 {
-        return None;
-    }
-    let mut remaining = n;
-    for (idx, item) in response_items.iter().enumerate().rev() {
-        if let ResponseItem::Message { role, .. } = item
-            && role == "user"
-        {
-            remaining -= 1;
-            if remaining == 0 {
-                return Some(idx);
-            }
-        }
-    }
-    None
-}
-
-/// Find the index (into event messages) of the Nth user event from the end.
-fn find_cut_event_index(event_msgs: &[EventMsg], n: usize) -> Option<usize> {
-    if n == 0 {
-        return None;
-    }
-    let mut remaining = n;
-    for (idx, ev) in event_msgs.iter().enumerate().rev() {
-        if matches!(ev, EventMsg::UserMessage(_)) {
-            remaining -= 1;
-            if remaining == 0 {
-                return Some(idx);
-            }
-        }
-    }
-    None
-}
-
-/// Build a truncated rollout by concatenating the (already-sliced) event messages and response items.
-fn build_truncated_rollout(
-    event_msgs: &[EventMsg],
-    response_items: &[ResponseItem],
-) -> Vec<RolloutItem> {
-    let mut rolled: Vec<RolloutItem> = Vec::with_capacity(event_msgs.len() + response_items.len());
-    rolled.extend(event_msgs.iter().cloned().map(RolloutItem::EventMsg));
-    rolled.extend(
-        response_items
-            .iter()
-            .cloned()
-            .map(RolloutItem::ResponseItem),
-    );
-    rolled
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::event_mapping::map_response_item_to_event_messages;
-    use crate::protocol::EventMsg;
     use codex_protocol::models::ContentItem;
     use codex_protocol::models::ReasoningItemReasoningSummary;
     use codex_protocol::models::ResponseItem;
@@ -295,15 +221,6 @@ mod tests {
             }],
         }
     }
-    fn user_input(text: &str) -> ResponseItem {
-        ResponseItem::Message {
-            id: None,
-            role: "user".to_string(),
-            content: vec![ContentItem::InputText {
-                text: text.to_string(),
-            }],
-        }
-    }
     fn assistant_msg(text: &str) -> ResponseItem {
        ResponseItem::Message {
            id: None,
@@ -365,54 +282,4 @@ mod tests {
         let truncated2 = truncate_after_dropping_last_messages(InitialHistory::Forked(initial2), 2);
         assert!(matches!(truncated2, InitialHistory::New));
     }
-
-    #[test]
-    fn event_prefix_counts_from_end_with_duplicate_user_prompts() {
-        // Two identical user prompts with assistant replies between them.
-        let responses = vec![
-            user_input("same"),
-            assistant_msg("a1"),
-            user_input("same"),
-            assistant_msg("a2"),
-        ];
-
-        // Derive event messages in order from responses (user → UserMessage, assistant → AgentMessage).
-        let mut events: Vec<EventMsg> = Vec::new();
-        for r in &responses {
-            events.extend(map_response_item_to_event_messages(r, false));
-        }
-
-        // Build initial history containing both events and responses.
-        let mut initial: Vec<RolloutItem> = Vec::new();
-        initial.extend(events.iter().cloned().map(RolloutItem::EventMsg));
-        initial.extend(responses.iter().cloned().map(RolloutItem::ResponseItem));
-
-        // Drop the last user turn.
-        let truncated = truncate_after_dropping_last_messages(InitialHistory::Forked(initial), 1);
-
-        // Expect the event prefix to include the first user + first assistant only,
-        // and the response prefix to include the first user + first assistant only.
-        let got_items = truncated.get_rollout_items();
-
-        // Compute expected events and responses after cut.
-        let expected_event_prefix: Vec<RolloutItem> = events[..2]
-            .iter()
-            .cloned()
-            .map(RolloutItem::EventMsg)
-            .collect();
-        let expected_response_prefix: Vec<RolloutItem> = responses[..2]
-            .iter()
-            .cloned()
-            .map(RolloutItem::ResponseItem)
-            .collect();
-
-        let mut expected: Vec<RolloutItem> = Vec::new();
-        expected.extend(expected_event_prefix);
-        expected.extend(expected_response_prefix);
-
-        assert_eq!(
-            serde_json::to_value(&got_items).unwrap(),
-            serde_json::to_value(&expected).unwrap()
-        );
-    }
 }

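// --- illustrative sketch, not part of the diff ---
// The cut-point computation the new truncation relies on, reduced to roles:
// collect the positions of user messages in rollout order, then keep
// everything strictly before the nth-from-last one.
fn cut_index(roles: &[&str], n: usize) -> Option<usize> {
    let user_positions: Vec<usize> = roles
        .iter()
        .enumerate()
        .filter(|(_, r)| **r == "user")
        .map(|(i, _)| i)
        .collect();
    if n == 0 || user_positions.len() < n {
        return None; // caller keeps everything (n == 0) or treats history as empty
    }
    Some(user_positions[user_positions.len() - n])
}

// cut_index(&["user", "assistant", "user", "assistant"], 1) == Some(2):
// dropping the last user turn keeps only the first user/assistant pair.
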
@@ -26,6 +26,7 @@ pub(crate) struct EnvironmentContext {
     pub approval_policy: Option<AskForApproval>,
     pub sandbox_mode: Option<SandboxMode>,
     pub network_access: Option<NetworkAccess>,
+    pub writable_roots: Option<Vec<PathBuf>>,
     pub shell: Option<Shell>,
 }
@@ -57,6 +58,16 @@ impl EnvironmentContext {
                 }
                 None => None,
             },
+            writable_roots: match sandbox_policy {
+                Some(SandboxPolicy::WorkspaceWrite { writable_roots, .. }) => {
+                    if writable_roots.is_empty() {
+                        None
+                    } else {
+                        Some(writable_roots)
+                    }
+                }
+                _ => None,
+            },
             shell,
         }
     }
@@ -72,6 +83,7 @@ impl EnvironmentContext {
     ///   <cwd>...</cwd>
     ///   <approval_policy>...</approval_policy>
     ///   <sandbox_mode>...</sandbox_mode>
+    ///   <writable_roots>...</writable_roots>
     ///   <network_access>...</network_access>
     ///   <shell>...</shell>
     /// </environment_context>
@@ -94,6 +106,16 @@ impl EnvironmentContext {
                 "  <network_access>{network_access}</network_access>"
             ));
         }
+        if let Some(writable_roots) = self.writable_roots {
+            lines.push("  <writable_roots>".to_string());
+            for writable_root in writable_roots {
+                lines.push(format!(
+                    "    <root>{}</root>",
+                    writable_root.to_string_lossy()
+                ));
+            }
+            lines.push("  </writable_roots>".to_string());
+        }
         if let Some(shell) = self.shell
             && let Some(shell_name) = shell.name()
         {
@@ -115,3 +137,77 @@ impl From<EnvironmentContext> for ResponseItem {
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use pretty_assertions::assert_eq;
+
+    fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy {
+        SandboxPolicy::WorkspaceWrite {
+            writable_roots: writable_roots.into_iter().map(PathBuf::from).collect(),
+            network_access,
+            exclude_tmpdir_env_var: false,
+            exclude_slash_tmp: false,
+        }
+    }
+
+    #[test]
+    fn serialize_workspace_write_environment_context() {
+        let context = EnvironmentContext::new(
+            Some(PathBuf::from("/repo")),
+            Some(AskForApproval::OnRequest),
+            Some(workspace_write_policy(vec!["/repo", "/tmp"], false)),
+            None,
+        );
+
+        let expected = r#"<environment_context>
+  <cwd>/repo</cwd>
+  <approval_policy>on-request</approval_policy>
+  <sandbox_mode>workspace-write</sandbox_mode>
+  <network_access>restricted</network_access>
+  <writable_roots>
+    <root>/repo</root>
+    <root>/tmp</root>
+  </writable_roots>
+</environment_context>"#;
+
+        assert_eq!(context.serialize_to_xml(), expected);
+    }
+
+    #[test]
+    fn serialize_read_only_environment_context() {
+        let context = EnvironmentContext::new(
+            None,
+            Some(AskForApproval::Never),
+            Some(SandboxPolicy::ReadOnly),
+            None,
+        );
+
+        let expected = r#"<environment_context>
+  <approval_policy>never</approval_policy>
+  <sandbox_mode>read-only</sandbox_mode>
+  <network_access>restricted</network_access>
+</environment_context>"#;
+
+        assert_eq!(context.serialize_to_xml(), expected);
+    }
+
+    #[test]
+    fn serialize_full_access_environment_context() {
+        let context = EnvironmentContext::new(
+            None,
+            Some(AskForApproval::OnFailure),
+            Some(SandboxPolicy::DangerFullAccess),
+            None,
+        );
+
+        let expected = r#"<environment_context>
+  <approval_policy>on-failure</approval_policy>
+  <sandbox_mode>danger-full-access</sandbox_mode>
+  <network_access>enabled</network_access>
+</environment_context>"#;
+
+        assert_eq!(context.serialize_to_xml(), expected);
+    }
+}

@@ -1,3 +1,5 @@
+use crate::token_data::KnownPlan;
+use crate::token_data::PlanType;
 use codex_protocol::mcp_protocol::ConversationId;
 use reqwest::StatusCode;
 use serde_json;
@@ -127,38 +129,58 @@ pub enum CodexErr {
 
 #[derive(Debug)]
 pub struct UsageLimitReachedError {
-    pub plan_type: Option<String>,
-    pub resets_in_seconds: Option<u64>,
+    pub(crate) plan_type: Option<PlanType>,
+    pub(crate) resets_in_seconds: Option<u64>,
 }
 
 impl std::fmt::Display for UsageLimitReachedError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        // Base message differs slightly for legacy ChatGPT Plus plan users.
-        if let Some(plan_type) = &self.plan_type
-            && plan_type == "plus"
-        {
-            write!(
-                f,
-                "You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing) or try again"
-            )?;
-            if let Some(secs) = self.resets_in_seconds {
-                let reset_duration = format_reset_duration(secs);
-                write!(f, " in {reset_duration}.")?;
-            } else {
-                write!(f, " later.")?;
+        let message = match self.plan_type.as_ref() {
+            Some(PlanType::Known(KnownPlan::Plus)) => format!(
+                "You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing){}",
+                retry_suffix_after_or(self.resets_in_seconds)
+            ),
+            Some(PlanType::Known(KnownPlan::Team)) | Some(PlanType::Known(KnownPlan::Business)) => {
+                format!(
+                    "You've hit your usage limit. To get more access now, send a request to your admin{}",
+                    retry_suffix_after_or(self.resets_in_seconds)
+                )
             }
-        } else {
-            write!(f, "You've hit your usage limit.")?;
-
-            if let Some(secs) = self.resets_in_seconds {
-                let reset_duration = format_reset_duration(secs);
-                write!(f, " Try again in {reset_duration}.")?;
-            } else {
-                write!(f, " Try again later.")?;
+            Some(PlanType::Known(KnownPlan::Free)) => {
+                "To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
+                    .to_string()
             }
-        }
+            Some(PlanType::Known(KnownPlan::Pro))
+            | Some(PlanType::Known(KnownPlan::Enterprise))
+            | Some(PlanType::Known(KnownPlan::Edu)) => format!(
+                "You've hit your usage limit.{}",
+                retry_suffix(self.resets_in_seconds)
+            ),
+            Some(PlanType::Unknown(_)) | None => format!(
+                "You've hit your usage limit.{}",
+                retry_suffix(self.resets_in_seconds)
+            ),
+        };
 
-        Ok(())
+        write!(f, "{message}")
     }
 }
 
+fn retry_suffix(resets_in_seconds: Option<u64>) -> String {
+    if let Some(secs) = resets_in_seconds {
+        let reset_duration = format_reset_duration(secs);
+        format!(" Try again in {reset_duration}.")
+    } else {
+        " Try again later.".to_string()
+    }
+}
+
+fn retry_suffix_after_or(resets_in_seconds: Option<u64>) -> String {
+    if let Some(secs) = resets_in_seconds {
+        let reset_duration = format_reset_duration(secs);
+        format!(" or try again in {reset_duration}.")
+    } else {
+        " or try again later.".to_string()
+    }
+}
@@ -237,7 +259,7 @@ mod tests {
     #[test]
     fn usage_limit_reached_error_formats_plus_plan() {
         let err = UsageLimitReachedError {
-            plan_type: Some("plus".to_string()),
+            plan_type: Some(PlanType::Known(KnownPlan::Plus)),
             resets_in_seconds: None,
         };
         assert_eq!(
@@ -246,6 +268,18 @@ mod tests {
         );
     }
 
+    #[test]
+    fn usage_limit_reached_error_formats_free_plan() {
+        let err = UsageLimitReachedError {
+            plan_type: Some(PlanType::Known(KnownPlan::Free)),
+            resets_in_seconds: Some(3600),
+        };
+        assert_eq!(
+            err.to_string(),
+            "To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
+        );
+    }
+
     #[test]
     fn usage_limit_reached_error_formats_default_when_none() {
         let err = UsageLimitReachedError {
@@ -258,10 +292,34 @@ mod tests {
         );
     }
 
+    #[test]
+    fn usage_limit_reached_error_formats_team_plan() {
+        let err = UsageLimitReachedError {
+            plan_type: Some(PlanType::Known(KnownPlan::Team)),
+            resets_in_seconds: Some(3600),
+        };
+        assert_eq!(
+            err.to_string(),
+            "You've hit your usage limit. To get more access now, send a request to your admin or try again in 1 hour."
+        );
+    }
+
+    #[test]
+    fn usage_limit_reached_error_formats_business_plan_without_reset() {
+        let err = UsageLimitReachedError {
+            plan_type: Some(PlanType::Known(KnownPlan::Business)),
+            resets_in_seconds: None,
+        };
+        assert_eq!(
+            err.to_string(),
+            "You've hit your usage limit. To get more access now, send a request to your admin or try again later."
+        );
+    }
+
     #[test]
     fn usage_limit_reached_error_formats_default_for_other_plans() {
         let err = UsageLimitReachedError {
-            plan_type: Some("pro".to_string()),
+            plan_type: Some(PlanType::Known(KnownPlan::Pro)),
             resets_in_seconds: None,
         };
         assert_eq!(
@@ -285,7 +343,7 @@ mod tests {
     #[test]
     fn usage_limit_reached_includes_hours_and_minutes() {
         let err = UsageLimitReachedError {
-            plan_type: Some("plus".to_string()),
+            plan_type: Some(PlanType::Known(KnownPlan::Plus)),
            resets_in_seconds: Some(3 * 3600 + 32 * 60),
        };
        assert_eq!(

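// --- illustrative only, not part of the diff ---
// How the plan-aware arms compose, assuming format_reset_duration(3600)
// renders as "1 hour" (as the team-plan test above confirms):
//
//   Plus, resets in 3600s -> "You've hit your usage limit. Upgrade to Pro
//       (https://openai.com/chatgpt/pricing) or try again in 1 hour."
//   Team, no reset info   -> "You've hit your usage limit. To get more access
//       now, send a request to your admin or try again later."
//   Unknown plan or None  -> "You've hit your usage limit. Try again later."
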
@@ -159,7 +159,7 @@ mod tests {
         EventMsg::UserMessage(user) => {
             assert_eq!(user.message, "Hello world");
             assert!(matches!(user.kind, Some(InputMessageKind::Plain)));
-            assert_eq!(user.images, Some(vec![img1.clone(), img2.clone()]));
+            assert_eq!(user.images, Some(vec![img1, img2]));
         }
         other => panic!("expected UserMessage, got {other:?}"),
     }

@@ -24,6 +24,9 @@ pub(crate) struct ExecCommandSession {
 
     /// JoinHandle for the child wait task.
     wait_handle: StdMutex<Option<JoinHandle<()>>>,
+
+    /// Tracks whether the underlying process has exited.
+    exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
 }
 
 impl ExecCommandSession {
@@ -34,6 +37,7 @@ impl ExecCommandSession {
         reader_handle: JoinHandle<()>,
         writer_handle: JoinHandle<()>,
         wait_handle: JoinHandle<()>,
+        exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
     ) -> Self {
         Self {
             writer_tx,
@@ -42,6 +46,7 @@ impl ExecCommandSession {
             reader_handle: StdMutex::new(Some(reader_handle)),
             writer_handle: StdMutex::new(Some(writer_handle)),
             wait_handle: StdMutex::new(Some(wait_handle)),
+            exit_status,
         }
     }
 
@@ -52,6 +57,10 @@ impl ExecCommandSession {
     pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
         self.output_tx.subscribe()
     }
+
+    pub(crate) fn has_exited(&self) -> bool {
+        self.exit_status.load(std::sync::atomic::Ordering::SeqCst)
+    }
 }
 
 impl Drop for ExecCommandSession {

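// --- illustrative sketch, not part of the diff ---
// The exit flag is a shared AtomicBool: the blocking wait task stores `true`
// once the child exits, and `has_exited` loads it without blocking. A thread
// stands in here for the spawn_blocking task used in the real code.
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

fn spawn_wait_task(exit_status: Arc<AtomicBool>) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || {
        // ... child.wait() would block here in the real session ...
        exit_status.store(true, Ordering::SeqCst);
    })
}

fn has_exited(exit_status: &AtomicBool) -> bool {
    exit_status.load(Ordering::SeqCst)
}
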
@@ -6,6 +6,7 @@ mod session_manager;
 
 pub use exec_command_params::ExecCommandParams;
 pub use exec_command_params::WriteStdinParams;
+pub(crate) use exec_command_session::ExecCommandSession;
 pub use responses_api::EXEC_COMMAND_TOOL_NAME;
 pub use responses_api::WRITE_STDIN_TOOL_NAME;
 pub use responses_api::create_exec_command_tool_for_responses_api;

@@ -3,6 +3,7 @@ use std::io::ErrorKind;
 use std::io::Read;
 use std::sync::Arc;
 use std::sync::Mutex as StdMutex;
+use std::sync::atomic::AtomicBool;
 use std::sync::atomic::AtomicU32;
 
 use portable_pty::CommandBuilder;
@@ -19,6 +20,7 @@ use crate::exec_command::exec_command_params::ExecCommandParams;
 use crate::exec_command::exec_command_params::WriteStdinParams;
 use crate::exec_command::exec_command_session::ExecCommandSession;
 use crate::exec_command::session_id::SessionId;
+use crate::truncate::truncate_middle;
 use codex_protocol::models::FunctionCallOutputPayload;
 
 #[derive(Debug, Default)]
@@ -327,11 +329,14 @@ async fn create_exec_command_session(
 
     // Keep the child alive until it exits, then signal exit code.
     let (exit_tx, exit_rx) = oneshot::channel::<i32>();
+    let exit_status = Arc::new(AtomicBool::new(false));
+    let wait_exit_status = exit_status.clone();
     let wait_handle = tokio::task::spawn_blocking(move || {
         let code = match child.wait() {
             Ok(status) => status.exit_code() as i32,
             Err(_) => -1,
         };
+        wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
         let _ = exit_tx.send(code);
     });
 
@@ -343,116 +348,11 @@ async fn create_exec_command_session(
         reader_handle,
         writer_handle,
         wait_handle,
+        exit_status,
     );
     Ok((session, exit_rx))
 }
 
-/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
-/// preserving the beginning and the end. Returns the possibly truncated
-/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
-/// if truncation occurred; otherwise returns the original string and `None`.
-fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
-    // No truncation needed
-    if s.len() <= max_bytes {
-        return (s.to_string(), None);
-    }
-    let est_tokens = (s.len() as u64).div_ceil(4);
-    if max_bytes == 0 {
-        // Cannot keep any content; still return a full marker (never truncated).
-        return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
-    }
-
-    // Helper to truncate a string to a given byte length on a char boundary.
-    fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
-        if input.len() <= max_len {
-            return input;
-        }
-        let mut end = max_len;
-        while end > 0 && !input.is_char_boundary(end) {
-            end -= 1;
-        }
-        &input[..end]
-    }
-
-    // Given a left/right budget, prefer newline boundaries; otherwise fall back
-    // to UTF-8 char boundaries.
-    fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
-        if let Some(head) = s.get(..left_budget)
-            && let Some(i) = head.rfind('\n')
-        {
-            return i + 1; // keep the newline so suffix starts on a fresh line
-        }
-        truncate_on_boundary(s, left_budget).len()
-    }
-
-    fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
-        let start_tail = s.len().saturating_sub(right_budget);
-        if let Some(tail) = s.get(start_tail..)
-            && let Some(i) = tail.find('\n')
-        {
-            return start_tail + i + 1; // start after newline
-        }
-        // Fall back to a char boundary at or after start_tail.
-        let mut idx = start_tail.min(s.len());
-        while idx < s.len() && !s.is_char_boundary(idx) {
-            idx += 1;
-        }
-        idx
-    }
-
-    // Refine marker length and budgets until stable. Marker is never truncated.
-    let mut guess_tokens = est_tokens; // worst-case: everything truncated
-    for _ in 0..4 {
-        let marker = format!("…{guess_tokens} tokens truncated…");
-        let marker_len = marker.len();
-        let keep_budget = max_bytes.saturating_sub(marker_len);
-        if keep_budget == 0 {
-            // No room for any content within the cap; return a full, untruncated marker
-            // that reflects the entire truncated content.
-            return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
-        }
-
-        let left_budget = keep_budget / 2;
-        let right_budget = keep_budget - left_budget;
-        let prefix_end = pick_prefix_end(s, left_budget);
-        let mut suffix_start = pick_suffix_start(s, right_budget);
-        if suffix_start < prefix_end {
-            suffix_start = prefix_end;
-        }
-        let kept_content_bytes = prefix_end + (s.len() - suffix_start);
-        let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
-        let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
-        if new_tokens == guess_tokens {
-            let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
-            out.push_str(&s[..prefix_end]);
-            out.push_str(&marker);
-            // Place marker on its own line for symmetry when we keep line boundaries.
-            out.push('\n');
-            out.push_str(&s[suffix_start..]);
-            return (out, Some(est_tokens));
-        }
-        guess_tokens = new_tokens;
-    }
-
-    // Fallback: use last guess to build output.
-    let marker = format!("…{guess_tokens} tokens truncated…");
-    let marker_len = marker.len();
-    let keep_budget = max_bytes.saturating_sub(marker_len);
-    if keep_budget == 0 {
-        return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
-    }
-    let left_budget = keep_budget / 2;
-    let right_budget = keep_budget - left_budget;
-    let prefix_end = pick_prefix_end(s, left_budget);
-    let suffix_start = pick_suffix_start(s, right_budget);
-    let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
-    out.push_str(&s[..prefix_end]);
-    out.push_str(&marker);
-    out.push('\n');
-    out.push_str(&s[suffix_start..]);
-    (out, Some(est_tokens))
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -616,50 +516,4 @@ Output:
 abc"#;
         assert_eq!(expected, text);
     }
-
-    #[test]
-    fn truncate_middle_no_newlines_fallback() {
-        // A long string with no newlines that exceeds the cap.
-        let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
-        let max_bytes = 16; // force truncation
-        let (out, original) = truncate_middle(s, max_bytes);
-        // For very small caps, we return the full, untruncated marker,
-        // even if it exceeds the cap.
-        assert_eq!(out, "…16 tokens truncated…");
-        // Original string length is 62 bytes => ceil(62/4) = 16 tokens.
-        assert_eq!(original, Some(16));
-    }
-
-    #[test]
-    fn truncate_middle_prefers_newline_boundaries() {
-        // Build a multi-line string of 20 numbered lines (each "NNN\n").
-        let mut s = String::new();
-        for i in 1..=20 {
-            s.push_str(&format!("{i:03}\n"));
-        }
-        // Total length: 20 lines * 4 bytes per line = 80 bytes.
-        assert_eq!(s.len(), 80);
-
-        // Choose a cap that forces truncation while leaving room for
-        // a few lines on each side after accounting for the marker.
-        let max_bytes = 64;
-        // Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
-        assert_eq!(
-            truncate_middle(&s, max_bytes),
-            (
-                r#"001
-002
-003
-004
-…12 tokens truncated…
-017
-018
-019
-020
-"#
-                .to_string(),
-                Some(20)
-            )
-        );
-    }
 }

@@ -802,7 +802,7 @@ mod tests {
     async fn resolve_root_git_project_for_trust_regular_repo_returns_repo_root() {
         let temp_dir = TempDir::new().expect("Failed to create temp dir");
         let repo_path = create_test_git_repo(&temp_dir).await;
-        let expected = std::fs::canonicalize(&repo_path).unwrap().to_path_buf();
+        let expected = std::fs::canonicalize(&repo_path).unwrap();
 
         assert_eq!(
             resolve_root_git_project_for_trust(&repo_path),
@@ -810,10 +810,7 @@ mod tests {
         );
         let nested = repo_path.join("sub/dir");
         std::fs::create_dir_all(&nested).unwrap();
-        assert_eq!(
-            resolve_root_git_project_for_trust(&nested),
-            Some(expected.clone())
-        );
+        assert_eq!(resolve_root_git_project_for_trust(&nested), Some(expected));
     }
 
     #[tokio::test]

68 codex-rs/core/src/internal_storage.rs Normal file
@@ -0,0 +1,68 @@
+use anyhow::Context;
+use serde::Deserialize;
+use serde::Serialize;
+use std::path::Path;
+use std::path::PathBuf;
+
+pub(crate) const INTERNAL_STORAGE_FILE: &str = "internal_storage.json";
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub struct InternalStorage {
+    #[serde(skip)]
+    storage_path: PathBuf,
+    #[serde(default)]
+    pub gpt_5_high_model_prompt_seen: bool,
+}
+
+// TODO(jif) generalise all the file writers and build proper async channel inserters.
+impl InternalStorage {
+    pub fn load(codex_home: &Path) -> Self {
+        let storage_path = codex_home.join(INTERNAL_STORAGE_FILE);
+
+        match std::fs::read_to_string(&storage_path) {
+            Ok(serialized) => match serde_json::from_str::<Self>(&serialized) {
+                Ok(mut storage) => {
+                    storage.storage_path = storage_path;
+                    storage
+                }
+                Err(error) => {
+                    tracing::warn!("failed to parse internal storage: {error:?}");
+                    Self::empty(storage_path)
+                }
+            },
+            Err(error) => {
+                tracing::warn!("failed to read internal storage: {error:?}");
+                Self::empty(storage_path)
+            }
+        }
+    }
+
+    fn empty(storage_path: PathBuf) -> Self {
+        Self {
+            storage_path,
+            ..Default::default()
+        }
+    }
+
+    pub async fn persist(&self) -> anyhow::Result<()> {
+        let serialized = serde_json::to_string_pretty(self)?;
+
+        if let Some(parent) = self.storage_path.parent() {
+            tokio::fs::create_dir_all(parent).await.with_context(|| {
+                format!(
+                    "failed to create internal storage directory at {}",
+                    parent.display()
+                )
+            })?;
+        }
+
+        tokio::fs::write(&self.storage_path, serialized)
+            .await
+            .with_context(|| {
+                format!(
+                    "failed to persist internal storage at {}",
+                    self.storage_path.display()
+                )
+            })
+    }
+}

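// --- illustrative sketch, not part of the diff ---
// A typical round trip for the new InternalStorage, assuming a `codex_home`
// directory: load (which falls back to an empty struct on read/parse errors),
// flip the flag, and persist the pretty-printed JSON back to disk.
async fn mark_prompt_seen(codex_home: &std::path::Path) -> anyhow::Result<()> {
    let mut storage = InternalStorage::load(codex_home);
    storage.gpt_5_high_model_prompt_seen = true;
    storage.persist().await // writes internal_storage.json under codex_home
}
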
@@ -16,6 +16,7 @@ mod codex_conversation;
 pub mod token_data;
 pub use codex_conversation::CodexConversation;
 pub mod config;
+pub mod config_edit;
 pub mod config_profile;
 pub mod config_types;
 mod conversation_history;
@@ -27,6 +28,7 @@ mod exec_command;
 pub mod exec_env;
 mod flags;
 pub mod git_info;
+pub mod internal_storage;
 mod is_safe_command;
 pub mod landlock;
 mod mcp_connection_manager;
@@ -34,6 +36,8 @@ mod mcp_tool_call;
 mod message_history;
 mod model_provider_info;
 pub mod parse_command;
+mod truncate;
+mod unified_exec;
 mod user_instructions;
 pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 pub use model_provider_info::ModelProviderInfo;
@@ -71,6 +75,7 @@ pub use rollout::list::ConversationsPage;
 pub use rollout::list::Cursor;
 mod user_notification;
 pub mod util;
 
 pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
 pub use safety::get_platform_sandbox;
 // Re-export the protocol types from the standalone `codex-protocol` crate so existing

@@ -17,7 +17,7 @@ use anyhow::Result;
 use anyhow::anyhow;
 use codex_mcp_client::McpClient;
 use mcp_types::ClientCapabilities;
-use mcp_types::McpClientInfo;
+use mcp_types::Implementation;
 use mcp_types::Tool;
 
 use serde_json::json;
@@ -159,10 +159,14 @@ impl McpConnectionManager {
                 // indicates this should be an empty object.
                 elicitation: Some(json!({})),
             },
-            client_info: McpClientInfo {
+            client_info: Implementation {
                 name: "codex-mcp-client".to_owned(),
                 version: env!("CARGO_PKG_VERSION").to_owned(),
                 title: Some("Codex".into()),
+                // This field is used by Codex when it is an MCP
+                // server: it should not be used when Codex is
+                // an MCP client.
+                user_agent: None,
             },
             protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
         };

@@ -80,7 +80,10 @@ pub struct ModelProviderInfo {
     /// the connection as lost.
     pub stream_idle_timeout_ms: Option<u64>,
 
-    /// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
+    /// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
+    /// user is presented with login screen on first run, and login preference and token/key
+    /// are stored in auth.json. If false (which is the default), login screen is skipped,
+    /// and API key (if needed) comes from the "env_key" environment variable.
     #[serde(default)]
     pub requires_openai_auth: bool,
 }

@@ -78,7 +78,7 @@ pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
             max_output_tokens: 4_096,
         }),
 
-        "gpt-5" => Some(ModelInfo {
+        _ if slug.starts_with("gpt-5") => Some(ModelInfo {
             context_window: 272_000,
             max_output_tokens: 128_000,
         }),

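// --- illustrative only, not part of the diff ---
// With the prefix arm, any slug beginning with "gpt-5" (e.g. the hypothetical
// "gpt-5-mini") resolves to the same 272k context window / 128k output limits,
// instead of only the exact "gpt-5" slug matching.
fn context_window_for(slug: &str) -> Option<u64> {
    if slug.starts_with("gpt-5") {
        Some(272_000)
    } else {
        None
    }
}
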
@@ -8,7 +8,6 @@ use std::collections::HashMap;
 use crate::model_family::ModelFamily;
 use crate::plan_tool::PLAN_TOOL;
 use crate::protocol::AskForApproval;
-use crate::protocol::SandboxPolicy;
 use crate::tool_apply_patch::ApplyPatchToolType;
 use crate::tool_apply_patch::create_apply_patch_freeform_tool;
 use crate::tool_apply_patch::create_apply_patch_json_tool;
@@ -58,7 +57,7 @@ pub(crate) enum OpenAiTool {
 #[derive(Debug, Clone)]
 pub enum ConfigShellToolType {
     DefaultShell,
-    ShellWithRequest { sandbox_policy: SandboxPolicy },
+    ShellWithRequest,
     LocalShell,
     StreamableShell,
 }
@@ -70,17 +69,18 @@ pub(crate) struct ToolsConfig {
     pub apply_patch_tool_type: Option<ApplyPatchToolType>,
     pub web_search_request: bool,
     pub include_view_image_tool: bool,
+    pub experimental_unified_exec_tool: bool,
 }
 
 pub(crate) struct ToolsConfigParams<'a> {
     pub(crate) model_family: &'a ModelFamily,
     pub(crate) approval_policy: AskForApproval,
-    pub(crate) sandbox_policy: SandboxPolicy,
     pub(crate) include_plan_tool: bool,
     pub(crate) include_apply_patch_tool: bool,
     pub(crate) include_web_search_request: bool,
     pub(crate) use_streamable_shell_tool: bool,
     pub(crate) include_view_image_tool: bool,
+    pub(crate) experimental_unified_exec_tool: bool,
 }
 
 impl ToolsConfig {
@@ -88,12 +88,12 @@ impl ToolsConfig {
         let ToolsConfigParams {
             model_family,
             approval_policy,
-            sandbox_policy,
             include_plan_tool,
             include_apply_patch_tool,
             include_web_search_request,
             use_streamable_shell_tool,
            include_view_image_tool,
+            experimental_unified_exec_tool,
         } = params;
         let mut shell_type = if *use_streamable_shell_tool {
             ConfigShellToolType::StreamableShell
@@ -103,9 +103,7 @@ impl ToolsConfig {
             ConfigShellToolType::DefaultShell
         };
         if matches!(approval_policy, AskForApproval::OnRequest) && !use_streamable_shell_tool {
-            shell_type = ConfigShellToolType::ShellWithRequest {
-                sandbox_policy: sandbox_policy.clone(),
-            }
+            shell_type = ConfigShellToolType::ShellWithRequest;
         }
 
         let apply_patch_tool_type = match model_family.apply_patch_tool_type {
@@ -126,6 +124,7 @@ impl ToolsConfig {
             apply_patch_tool_type,
             web_search_request: *include_web_search_request,
             include_view_image_tool: *include_view_image_tool,
+            experimental_unified_exec_tool: *experimental_unified_exec_tool,
         }
     }
 }
@@ -200,7 +199,56 @@ fn create_shell_tool() -> OpenAiTool {
     })
 }
 
-fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
+fn create_unified_exec_tool() -> OpenAiTool {
+    let mut properties = BTreeMap::new();
+    properties.insert(
+        "input".to_string(),
+        JsonSchema::Array {
+            items: Box::new(JsonSchema::String { description: None }),
+            description: Some(
+                "When no session_id is provided, treat the array as the command and arguments \
+                 to launch. When session_id is set, concatenate the strings (in order) and write \
+                 them to the session's stdin."
+                    .to_string(),
+            ),
+        },
+    );
+    properties.insert(
+        "session_id".to_string(),
+        JsonSchema::String {
+            description: Some(
+                "Identifier for an existing interactive session. If omitted, a new command \
+                 is spawned."
+                    .to_string(),
+            ),
+        },
+    );
+    properties.insert(
+        "timeout_ms".to_string(),
+        JsonSchema::Number {
+            description: Some(
+                "Maximum time in milliseconds to wait for output after writing the input."
+                    .to_string(),
+            ),
+        },
+    );
+
+    OpenAiTool::Function(ResponsesApiTool {
+        name: "unified_exec".to_string(),
+        description:
+            "Runs a command in a PTY. Provide a session_id to reuse an existing interactive session.".to_string(),
+        strict: false,
+        parameters: JsonSchema::Object {
+            properties,
+            required: Some(vec!["input".to_string()]),
+            additional_properties: Some(false),
+        },
+    })
+}
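// --- illustrative sketch, not part of the diff ---
// Arguments a model might send to unified_exec, following the schema above:
// first call launches a command, second writes to its stdin. The session id
// value is hypothetical; it would come from the output of the first call.
fn example_unified_exec_args() -> (serde_json::Value, serde_json::Value) {
    let start = serde_json::json!({
        "input": ["python3", "-i"],
        "timeout_ms": 5_000
    });
    let send_stdin = serde_json::json!({
        "input": ["print(40 + 2)\n"],
        "session_id": "session-0",
        "timeout_ms": 1_000
    });
    (start, send_stdin)
}
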
+
+const SHELL_TOOL_DESCRIPTION: &str = r#"Runs a shell command and returns its output"#;
+
+fn create_shell_tool_for_request() -> OpenAiTool {
     let mut properties = BTreeMap::new();
     properties.insert(
         "command".to_string(),
@@ -212,82 +260,29 @@ fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
     properties.insert(
         "workdir".to_string(),
         JsonSchema::String {
-            description: Some("The working directory to execute the command in".to_string()),
+            description: Some("Working directory to execute the command in.".to_string()),
         },
     );
     properties.insert(
         "timeout_ms".to_string(),
         JsonSchema::Number {
-            description: Some("The timeout for the command in milliseconds".to_string()),
+            description: Some("Timeout for the command in milliseconds.".to_string()),
         },
     );
 
-    if matches!(sandbox_policy, SandboxPolicy::WorkspaceWrite { .. }) {
-        properties.insert(
+    properties.insert(
         "with_escalated_permissions".to_string(),
         JsonSchema::Boolean {
-            description: Some("Whether to request escalated permissions. Set to true if command needs to be run without sandbox restrictions".to_string()),
+            description: Some("Request escalated permissions, only for when a command would otherwise be blocked by the sandbox.".to_string()),
         },
     );
-        properties.insert(
+    properties.insert(
         "justification".to_string(),
         JsonSchema::String {
-            description: Some("Only set if with_escalated_permissions is true. 1-sentence explanation of why we want to run this command.".to_string()),
+            description: Some("Required if and only if with_escalated_permissions == true. One sentence explaining why escalation is needed (e.g., write outside CWD, network fetch, git commit).".to_string()),
         },
     );
-    }
 
-    let description = match sandbox_policy {
-        SandboxPolicy::WorkspaceWrite {
-            network_access,
-            writable_roots,
-            ..
-        } => {
-            format!(
-                r#"
-The shell tool is used to execute shell commands.
-- When invoking the shell tool, your call will be running in a sandbox, and some shell commands will require escalated privileges:
-  - Types of actions that require escalated privileges:
-    - Writing files other than those in the writable roots
-      - writable roots:
-{}{}
-  - Examples of commands that require escalated privileges:
-    - git commit
-    - npm install or pnpm install
-    - cargo build
-    - cargo test
-- When invoking a command that will require escalated privileges:
-  - Provide the with_escalated_permissions parameter with the boolean value true
-  - Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#,
-                writable_roots.iter().map(|wr| format!("        - {}", wr.to_string_lossy())).collect::<Vec<String>>().join("\n"),
-                if !network_access {
-                    "\n    - Commands that require network access\n"
-                } else {
-                    ""
-                }
-            )
-        }
-        SandboxPolicy::DangerFullAccess => {
-            "Runs a shell command and returns its output.".to_string()
-        }
-        SandboxPolicy::ReadOnly => {
-            r#"
-The shell tool is used to execute shell commands.
-- When invoking the shell tool, your call will be running in a sandbox, and some shell commands (including apply_patch) will require escalated permissions:
-  - Types of actions that require escalated privileges:
-    - Writing files
-    - Applying patches
-  - Examples of commands that require escalated privileges:
-    - apply_patch
-    - git commit
-    - npm install or pnpm install
-    - cargo build
-    - cargo test
-- When invoking a command that will require escalated privileges:
-  - Provide the with_escalated_permissions parameter with the boolean value true
-  - Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#.to_string()
-        }
-    };
+    let description = SHELL_TOOL_DESCRIPTION.to_string();
 
     OpenAiTool::Function(ResponsesApiTool {
         name: "shell".to_string(),
@@ -300,7 +295,6 @@ The shell tool is used to execute shell commands.
         },
     })
 }
-
 fn create_view_image_tool() -> OpenAiTool {
     // Support only local filesystem path.
     let mut properties = BTreeMap::new();
@@ -534,23 +528,27 @@ pub(crate) fn get_openai_tools(
 ) -> Vec<OpenAiTool> {
     let mut tools: Vec<OpenAiTool> = Vec::new();
 
-    match &config.shell_type {
-        ConfigShellToolType::DefaultShell => {
-            tools.push(create_shell_tool());
-        }
-        ConfigShellToolType::ShellWithRequest { sandbox_policy } => {
-            tools.push(create_shell_tool_for_sandbox(sandbox_policy));
-        }
-        ConfigShellToolType::LocalShell => {
-            tools.push(OpenAiTool::LocalShell {});
-        }
-        ConfigShellToolType::StreamableShell => {
-            tools.push(OpenAiTool::Function(
-                crate::exec_command::create_exec_command_tool_for_responses_api(),
-            ));
-            tools.push(OpenAiTool::Function(
-                crate::exec_command::create_write_stdin_tool_for_responses_api(),
-            ));
+    if config.experimental_unified_exec_tool {
+        tools.push(create_unified_exec_tool());
+    } else {
+        match &config.shell_type {
+            ConfigShellToolType::DefaultShell => {
+                tools.push(create_shell_tool());
+            }
+            ConfigShellToolType::ShellWithRequest => {
+                tools.push(create_shell_tool_for_request());
+            }
+            ConfigShellToolType::LocalShell => {
+                tools.push(OpenAiTool::LocalShell {});
+            }
+            ConfigShellToolType::StreamableShell => {
+                tools.push(OpenAiTool::Function(
+                    crate::exec_command::create_exec_command_tool_for_responses_api(),
+                ));
+                tools.push(OpenAiTool::Function(
+                    crate::exec_command::create_write_stdin_tool_for_responses_api(),
+                ));
+            }
         }
     }
 
@@ -577,10 +575,8 @@ pub(crate) fn get_openai_tools(
     if config.include_view_image_tool {
         tools.push(create_view_image_tool());
     }
-
     if let Some(mcp_tools) = mcp_tools {
         // Ensure deterministic ordering to maximize prompt cache hits.
-        // HashMap iteration order is non-deterministic, so sort by fully-qualified tool name.
         let mut entries: Vec<(String, mcp_types::Tool)> = mcp_tools.into_iter().collect();
         entries.sort_by(|a, b| a.0.cmp(&b.0));
 
@@ -636,18 +632,18 @@ mod tests {
         let config = ToolsConfig::new(&ToolsConfigParams {
             model_family: &model_family,
             approval_policy: AskForApproval::Never,
-            sandbox_policy: SandboxPolicy::ReadOnly,
             include_plan_tool: true,
             include_apply_patch_tool: false,
             include_web_search_request: true,
             use_streamable_shell_tool: false,
             include_view_image_tool: true,
+            experimental_unified_exec_tool: true,
         });
         let tools = get_openai_tools(&config, Some(HashMap::new()));
 
         assert_eq_tool_names(
             &tools,
-            &["local_shell", "update_plan", "web_search", "view_image"],
+            &["unified_exec", "update_plan", "web_search", "view_image"],
         );
     }
 
@@ -657,18 +653,18 @@ mod tests {
         let config = ToolsConfig::new(&ToolsConfigParams {
             model_family: &model_family,
             approval_policy: AskForApproval::Never,
-            sandbox_policy: SandboxPolicy::ReadOnly,
             include_plan_tool: true,
             include_apply_patch_tool: false,
             include_web_search_request: true,
             use_streamable_shell_tool: false,
             include_view_image_tool: true,
+            experimental_unified_exec_tool: true,
        });
         let tools = get_openai_tools(&config, Some(HashMap::new()));
 
         assert_eq_tool_names(
             &tools,
-            &["shell", "update_plan", "web_search", "view_image"],
+            &["unified_exec", "update_plan", "web_search", "view_image"],
         );
     }
 
@@ -678,12 +674,12 @@ mod tests {
         let config = ToolsConfig::new(&ToolsConfigParams {
             model_family: &model_family,
             approval_policy: AskForApproval::Never,
-            sandbox_policy: SandboxPolicy::ReadOnly,
             include_plan_tool: false,
             include_apply_patch_tool: false,
             include_web_search_request: true,
             use_streamable_shell_tool: false,
             include_view_image_tool: true,
+            experimental_unified_exec_tool: true,
         });
         let tools = get_openai_tools(
             &config,
@@ -726,7 +722,7 @@ mod tests {
         assert_eq_tool_names(
             &tools,
             &[
-                "shell",
+                "unified_exec",
                 "web_search",
                 "view_image",
                 "test_server/do_something_cool",
@@ -783,12 +779,12 @@ mod tests {
         let config = ToolsConfig::new(&ToolsConfigParams {
             model_family: &model_family,
             approval_policy: AskForApproval::Never,
-            sandbox_policy: SandboxPolicy::ReadOnly,
             include_plan_tool: false,
             include_apply_patch_tool: false,
             include_web_search_request: false,
             use_streamable_shell_tool: false,
             include_view_image_tool: true,
+            experimental_unified_exec_tool: true,
         });
 
         // Intentionally construct a map with keys that would sort alphabetically.
@@ -841,11 +837,11 @@ mod tests {
         ]);
 
         let tools = get_openai_tools(&config, Some(tools_map));
-        // Expect shell first, followed by MCP tools sorted by fully-qualified name.
+        // Expect unified_exec first, followed by MCP tools sorted by fully-qualified name.
         assert_eq_tool_names(
             &tools,
             &[
-                "shell",
+                "unified_exec",
                 "view_image",
                 "test_server/cool",
                 "test_server/do",
@@ -860,12 +856,12 @@ mod tests {
         let config = ToolsConfig::new(&ToolsConfigParams {
             model_family: &model_family,
             approval_policy: AskForApproval::Never,
-            sandbox_policy: SandboxPolicy::ReadOnly,
             include_plan_tool: false,
             include_apply_patch_tool: false,
             include_web_search_request: true,
             use_streamable_shell_tool: false,
             include_view_image_tool: true,
+            experimental_unified_exec_tool: true,
         });
 
         let tools = get_openai_tools(
@@ -893,7 +889,7 @@ mod tests {
 
         assert_eq_tool_names(
             &tools,
-            &["shell", "web_search", "view_image", "dash/search"],
+            &["unified_exec", "web_search", "view_image", "dash/search"],
         );
 
         assert_eq!(
@@ -922,12 +918,12 @@ mod tests {
         let config = ToolsConfig::new(&ToolsConfigParams {
             model_family: &model_family,
             approval_policy: AskForApproval::Never,
-            sandbox_policy: SandboxPolicy::ReadOnly,
             include_plan_tool: false,
             include_apply_patch_tool: false,
             include_web_search_request: true,
             use_streamable_shell_tool: false,
             include_view_image_tool: true,
+            experimental_unified_exec_tool: true,
         });
 
         let tools = get_openai_tools(
@@ -953,7 +949,7 @@ mod tests {
 
         assert_eq_tool_names(
             &tools,
-            &["shell", "web_search", "view_image", "dash/paginate"],
+            &["unified_exec", "web_search", "view_image", "dash/paginate"],
         );
         assert_eq!(
             tools[3],
@@ -979,12 +975,12 @@ mod tests {
         let config = ToolsConfig::new(&ToolsConfigParams {
             model_family: &model_family,
             approval_policy: AskForApproval::Never,
-            sandbox_policy: SandboxPolicy::ReadOnly,
             include_plan_tool: false,
             include_apply_patch_tool: false,
             include_web_search_request: true,
             use_streamable_shell_tool: false,
             include_view_image_tool: true,
+            experimental_unified_exec_tool: true,
         });
 
         let tools = get_openai_tools(
@@ -1008,7 +1004,10 @@ mod tests {
             )])),
         );
 
-        assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/tags"]);
+        assert_eq_tool_names(
+            &tools,
+            &["unified_exec", "web_search", "view_image", "dash/tags"],
+        );
         assert_eq!(
             tools[3],
             OpenAiTool::Function(ResponsesApiTool {
@@ -1036,12 +1035,12 @@ mod tests {
         let config = ToolsConfig::new(&ToolsConfigParams {
             model_family: &model_family,
             approval_policy: AskForApproval::Never,
-            sandbox_policy: SandboxPolicy::ReadOnly,
             include_plan_tool: false,
             include_apply_patch_tool: false,
             include_web_search_request: true,
             use_streamable_shell_tool: false,
             include_view_image_tool: true,
+            experimental_unified_exec_tool: true,
         });
 
         let tools = get_openai_tools(
@@ -1065,7 +1064,10 @@ mod tests {
             )])),
         );
 
-        assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/value"]);
+        assert_eq_tool_names(
+            &tools,
+            &["unified_exec", "web_search", "view_image", "dash/value"],
+        );
         assert_eq!(
             tools[3],
             OpenAiTool::Function(ResponsesApiTool {
@@ -1086,13 +1088,7 @@ mod tests {
 
     #[test]
     fn test_shell_tool_for_sandbox_workspace_write() {
-        let sandbox_policy = SandboxPolicy::WorkspaceWrite {
-            writable_roots: vec!["workspace".into()],
-            network_access: false,
-            exclude_tmpdir_env_var: false,
-            exclude_slash_tmp: false,
-        };
-        let tool = super::create_shell_tool_for_sandbox(&sandbox_policy);
+        let tool = super::create_shell_tool_for_request();
         let OpenAiTool::Function(ResponsesApiTool {
             description, name, ..
         }) = &tool
@@ -1101,29 +1097,13 @@ mod tests {
         };
         assert_eq!(name, "shell");
 
-        let expected = r#"
-The shell tool is used to execute shell commands.
-- When invoking the shell tool, your call will be running in a sandbox, and some shell commands will require escalated privileges:
-  - Types of actions that require escalated privileges:
-    - Writing files other than those in the writable roots
-      - writable roots:
-        - workspace
-    - Commands that require network access
-
-  - Examples of commands that require escalated privileges:
-    - git commit
-    - npm install or pnpm install
-    - cargo build
-    - cargo test
-- When invoking a command that will require escalated privileges:
-  - Provide the with_escalated_permissions parameter with the boolean value true
-  - Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#;
+        let expected = super::SHELL_TOOL_DESCRIPTION;
         assert_eq!(description, expected);
     }
 
     #[test]
     fn test_shell_tool_for_sandbox_readonly() {
-        let tool = super::create_shell_tool_for_sandbox(&SandboxPolicy::ReadOnly);
+        let tool = super::create_shell_tool_for_request();
         let OpenAiTool::Function(ResponsesApiTool {
             description, name, ..
         }) = &tool
@@ -1132,27 +1112,13 @@ The shell tool is used to execute shell commands.
         };
         assert_eq!(name, "shell");
 
-        let expected = r#"
-The shell tool is used to execute shell commands.
-- When invoking the shell tool, your call will be running in a sandbox, and some shell commands (including apply_patch) will require escalated permissions:
-  - Types of actions that require escalated privileges:
-    - Writing files
-    - Applying patches
-  - Examples of commands that require escalated privileges:
-    - apply_patch
-    - git commit
-    - npm install or pnpm install
-    - cargo build
-    - cargo test
-- When invoking a command that will require escalated privileges:
-  - Provide the with_escalated_permissions parameter with the boolean value true
-  - Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#;
+        let expected = super::SHELL_TOOL_DESCRIPTION;
         assert_eq!(description, expected);
     }
 
     #[test]
     fn test_shell_tool_for_sandbox_danger_full_access() {
-        let tool = super::create_shell_tool_for_sandbox(&SandboxPolicy::DangerFullAccess);
+        let tool = super::create_shell_tool_for_request();
         let OpenAiTool::Function(ResponsesApiTool {
             description, name, ..
         }) = &tool
@@ -1161,6 +1127,7 @@ The shell tool is used to execute shell commands.
         };
         assert_eq!(name, "shell");
 
-        assert_eq!(description, "Runs a shell command and returns its output.");
+        let expected = super::SHELL_TOOL_DESCRIPTION;
+        assert_eq!(description, expected);
     }
 }

||||
|
||||
@@ -868,7 +868,7 @@ pub fn parse_command_impl(command: &[String]) -> Vec<ParsedCommand> {
    let parts = if contains_connectors(&normalized) {
        split_on_connectors(&normalized)
    } else {
        vec![normalized.clone()]
        vec![normalized]
    };

    // Preserve left-to-right execution order for all commands, including bash -c/-lc
@@ -1201,10 +1201,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
                        name,
                    }
                } else {
                    ParsedCommand::Read {
                        cmd: cmd.clone(),
                        name,
                    }
                    ParsedCommand::Read { cmd, name }
                }
            } else {
                ParsedCommand::Read {
@@ -1215,10 +1212,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
            }
            ParsedCommand::ListFiles { path, cmd, .. } => {
                if had_connectors {
                    ParsedCommand::ListFiles {
                        cmd: cmd.clone(),
                        path,
                    }
                    ParsedCommand::ListFiles { cmd, path }
                } else {
                    ParsedCommand::ListFiles {
                        cmd: shlex_join(&script_tokens),
@@ -1230,11 +1224,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
                query, path, cmd, ..
            } => {
                if had_connectors {
                    ParsedCommand::Search {
                        cmd: cmd.clone(),
                        query,
                        path,
                    }
                    ParsedCommand::Search { cmd, query, path }
                } else {
                    ParsedCommand::Search {
                        cmd: shlex_join(&script_tokens),

@@ -26,7 +26,7 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";

/// Combines `Config::instructions` and `AGENTS.md` (if present) into a single
/// string of instructions.
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
pub async fn get_user_instructions(config: &Config) -> Option<String> {
    match read_project_docs(config).await {
        Ok(Some(project_doc)) => match &config.user_instructions {
            Some(original_instructions) => Some(format!(
@@ -115,7 +115,7 @@ pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBu
    // Build chain from cwd upwards and detect git root.
    let mut chain: Vec<PathBuf> = vec![dir.clone()];
    let mut git_root: Option<PathBuf> = None;
    let mut cursor = dir.clone();
    let mut cursor = dir;
    while let Some(parent) = cursor.parent() {
        let git_marker = cursor.join(".git");
        let git_exists = match std::fs::metadata(&git_marker) {

@@ -318,6 +318,12 @@ async fn read_head_and_flags(
                    head.push(val);
                }
            }
            RolloutItem::TurnContext(_) => {
                // Not included in `head`; skip.
            }
            RolloutItem::Compacted(_) => {
                // Not included in `head`; skip.
            }
            RolloutItem::EventMsg(ev) => {
                if matches!(ev, EventMsg::UserMessage(_)) {
                    saw_user_event = true;

@@ -8,8 +8,10 @@ pub(crate) fn is_persisted_response_item(item: &RolloutItem) -> bool {
    match item {
        RolloutItem::ResponseItem(item) => should_persist_response_item(item),
        RolloutItem::EventMsg(ev) => should_persist_event_msg(ev),
        // Always persist session meta
        RolloutItem::SessionMeta(_) => true,
        // Persist Codex executive markers so we can analyze flows (e.g., compaction, API turns).
        RolloutItem::Compacted(_) | RolloutItem::TurnContext(_) | RolloutItem::SessionMeta(_) => {
            true
        }
    }
}

@@ -65,6 +67,6 @@ pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool {
        | EventMsg::PlanUpdate(_)
        | EventMsg::TurnAborted(_)
        | EventMsg::ShutdownComplete
        | EventMsg::ConversationHistory(_) => false,
        | EventMsg::ConversationPath(_) => false,
    }
}

@@ -77,7 +77,13 @@ pub enum RolloutRecorderParams {

enum RolloutCmd {
    AddItems(Vec<RolloutItem>),
    Shutdown { ack: oneshot::Sender<()> },
    /// Ensure all prior writes are processed; respond when flushed.
    Flush {
        ack: oneshot::Sender<()>,
    },
    Shutdown {
        ack: oneshot::Sender<()>,
    },
}

impl RolloutRecorderParams {
@@ -185,6 +191,17 @@ impl RolloutRecorder {
            .map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
    }

    /// Flush all queued writes and wait until they are committed by the writer task.
    pub async fn flush(&self) -> std::io::Result<()> {
        let (tx, rx) = oneshot::channel();
        self.tx
            .send(RolloutCmd::Flush { ack: tx })
            .await
            .map_err(|e| IoError::other(format!("failed to queue rollout flush: {e}")))?;
        rx.await
            .map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
    }
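
As an aside, a minimal sketch of how the new flush API composes with resume (the wrapper function and its name are hypothetical; `flush` and `get_rollout_history` are the APIs shown in this diff):

// Hypothetical helper: make queued rollout lines durable before re-reading
// the file, e.g. ahead of resuming a conversation.
async fn persist_then_reload(
    recorder: &RolloutRecorder,
    path: &std::path::Path,
) -> std::io::Result<InitialHistory> {
    recorder.flush().await?; // waits for the writer task to commit
    RolloutRecorder::get_rollout_history(path).await
}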

    pub(crate) async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
        info!("Resuming rollout from {path:?}");
        tracing::error!("Resuming rollout from {path:?}");
@@ -211,16 +228,22 @@ impl RolloutRecorder {
                    match serde_json::from_value::<RolloutLine>(v.clone()) {
                        Ok(rollout_line) => match rollout_line.item {
                            RolloutItem::SessionMeta(session_meta_line) => {
                                tracing::error!(
                                    "Parsed conversation ID from rollout file: {:?}",
                                    session_meta_line.meta.id
                                );
                                conversation_id = Some(session_meta_line.meta.id);
                                // Use the FIRST SessionMeta encountered in the file as the canonical
                                // conversation id and main session information. Keep all items intact.
                                if conversation_id.is_none() {
                                    conversation_id = Some(session_meta_line.meta.id);
                                }
                                items.push(RolloutItem::SessionMeta(session_meta_line));
                            }
                            RolloutItem::ResponseItem(item) => {
                                items.push(RolloutItem::ResponseItem(item));
                            }
                            RolloutItem::Compacted(item) => {
                                items.push(RolloutItem::Compacted(item));
                            }
                            RolloutItem::TurnContext(item) => {
                                items.push(RolloutItem::TurnContext(item));
                            }
                            RolloutItem::EventMsg(_ev) => {
                                items.push(RolloutItem::EventMsg(_ev));
                            }
@@ -251,6 +274,10 @@ impl RolloutRecorder {
        }))
    }

    pub(crate) fn get_rollout_path(&self) -> PathBuf {
        self.rollout_path.clone()
    }

    pub async fn shutdown(&self) -> std::io::Result<()> {
        let (tx_done, rx_done) = oneshot::channel();
        match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
@@ -351,6 +378,14 @@ async fn rollout_writer(
                    }
                }
            }
            RolloutCmd::Flush { ack } => {
                // Ensure underlying file is flushed and then ack.
                if let Err(e) = writer.file.flush().await {
                    let _ = ack.send(());
                    return Err(e);
                }
                let _ = ack.send(());
            }
            RolloutCmd::Shutdown { ack } => {
                let _ = ack.send(());
            }

@@ -305,7 +305,7 @@ async fn test_pagination_cursor() {
            path: p1,
            head: head_1,
        }],
        next_cursor: Some(expected_cursor3.clone()),
        next_cursor: Some(expected_cursor3),
        num_scanned_files: 5, // scanned 05, 04 (anchor), 03, 02 (anchor), 01
        reached_scan_cap: false,
    };
@@ -344,7 +344,7 @@ async fn test_get_conversation_contents() {
    let expected_cursor: Cursor = serde_json::from_str(&format!("\"{ts}|{uuid}\"")).unwrap();
    let expected_page = ConversationsPage {
        items: vec![ConversationItem {
            path: expected_path.clone(),
            path: expected_path,
            head: expected_head,
        }],
        next_cursor: Some(expected_cursor),
@@ -437,7 +437,7 @@ async fn test_stable_ordering_same_second_pagination() {
            path: p1,
            head: head(u1),
        }],
        next_cursor: Some(expected_cursor2.clone()),
        next_cursor: Some(expected_cursor2),
        num_scanned_files: 3, // scanned u3, u2 (anchor), u1
        reached_scan_cap: false,
    };

@@ -293,7 +293,7 @@ mod tests {
        // With the parent dir explicitly added as a writable root, the
        // outside write should be permitted.
        let policy_with_parent = SandboxPolicy::WorkspaceWrite {
            writable_roots: vec![parent.clone()],
            writable_roots: vec![parent],
            network_access: false,
            exclude_tmpdir_env_var: true,
            exclude_slash_tmp: true,

@@ -153,7 +153,7 @@ mod tests {
        // Build a policy that only includes the two test roots as writable and
        // does not automatically include the default TMPDIR or /tmp.
        let policy = SandboxPolicy::WorkspaceWrite {
            writable_roots: vec![root_with_git.clone(), root_without_git.clone()],
            writable_roots: vec![root_with_git, root_without_git],
            network_access: false,
            exclude_tmpdir_env_var: true,
            exclude_slash_tmp: true,

@@ -1,36 +1,18 @@
use serde::Deserialize;
use serde::Serialize;
use shlex;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use tracing::trace;
use uuid::Uuid;

#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
/// This structure must not derive Clone; doing so would break the Drop implementation.
pub struct ShellSnapshot {
    pub(crate) path: PathBuf,
}

impl ShellSnapshot {
    pub fn new(path: PathBuf) -> Self {
        Self { path }
    }
}

impl Drop for ShellSnapshot {
    fn drop(&mut self) {
        delete_shell_snapshot(&self.path);
    }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ZshShell {
    shell_path: String,
    zshrc_path: String,
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct PosixShell {
    pub(crate) shell_path: String,
    pub(crate) rc_path: String,
    #[serde(skip_serializing, skip_deserializing)]
    pub(crate) shell_snapshot: Option<Arc<ShellSnapshot>>,
pub struct BashShell {
    shell_path: String,
    bashrc_path: String,
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
@@ -41,7 +23,8 @@ pub struct PowerShellConfig {

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Shell {
    Posix(PosixShell),
    Zsh(ZshShell),
    Bash(BashShell),
    PowerShell(PowerShellConfig),
    Unknown,
}
@@ -49,27 +32,11 @@ pub enum Shell {
impl Shell {
    pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
        match self {
            Shell::Posix(shell) => {
                let joined = strip_bash_lc(&command)
                    .or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok())?;

                let mut source_path = Path::new(&shell.rc_path);

                let session_cmd = if let Some(shell_snapshot) = &shell.shell_snapshot
                    && shell_snapshot.path.exists()
                {
                    source_path = shell_snapshot.path.as_path();
                    "-c".to_string()
                } else {
                    "-lc".to_string()
                };

                let source_path_str = source_path.to_string_lossy().to_string();
                let quoted_source_path = shlex::try_quote(&source_path_str).ok()?;
                let rc_command =
                    format!("[ -f {quoted_source_path} ] && . {quoted_source_path}; ({joined})");

                Some(vec![shell.shell_path.clone(), session_cmd, rc_command])
            Shell::Zsh(zsh) => {
                format_shell_invocation_with_rc(&command, &zsh.shell_path, &zsh.zshrc_path)
            }
            Shell::Bash(bash) => {
                format_shell_invocation_with_rc(&command, &bash.shell_path, &bash.bashrc_path)
            }
            Shell::PowerShell(ps) => {
                // If the model generated a bash command, prefer a detected bash fallback
@@ -122,20 +89,33 @@ impl Shell {

    pub fn name(&self) -> Option<String> {
        match self {
            Shell::Posix(shell) => Path::new(&shell.shell_path)
            Shell::Zsh(zsh) => std::path::Path::new(&zsh.shell_path)
                .file_name()
                .map(|s| s.to_string_lossy().to_string()),
            Shell::Bash(bash) => std::path::Path::new(&bash.shell_path)
                .file_name()
                .map(|s| s.to_string_lossy().to_string()),
            Shell::PowerShell(ps) => Some(ps.exe.clone()),
            Shell::Unknown => None,
        }
    }
}

    pub fn get_snapshot(&self) -> Option<Arc<ShellSnapshot>> {
        match self {
            Shell::Posix(shell) => shell.shell_snapshot.clone(),
            _ => None,
        }
    }
fn format_shell_invocation_with_rc(
    command: &Vec<String>,
    shell_path: &str,
    rc_path: &str,
) -> Option<Vec<String>> {
    let joined = strip_bash_lc(command)
        .or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok())?;

    let rc_command = if std::path::Path::new(rc_path).exists() {
        format!("source {rc_path} && ({joined})")
    } else {
        joined
    };

    Some(vec![shell_path.to_string(), "-lc".to_string(), rc_command])
}
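
For reference, a sketch of what the new helper produces when the rc file exists (paths illustrative, not from the source):

// Illustrative only: with an existing rc file, the command is wrapped so the
// rc file is sourced first in a login shell (`-lc`).
let argv = format_shell_invocation_with_rc(
    &vec!["echo".to_string(), "hi".to_string()],
    "/bin/zsh",
    "/home/user/.zshrc",
);
assert_eq!(
    argv,
    Some(vec![
        "/bin/zsh".to_string(),
        "-lc".to_string(),
        "source /home/user/.zshrc && (echo hi)".to_string(),
    ])
);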

fn strip_bash_lc(command: &Vec<String>) -> Option<String> {
@@ -152,7 +132,7 @@ fn strip_bash_lc(command: &Vec<String>) -> Option<String> {
}

#[cfg(unix)]
async fn detect_default_user_shell(session_id: Uuid, codex_home: &Path) -> Shell {
fn detect_default_user_shell() -> Shell {
    use libc::getpwuid;
    use libc::getuid;
    use std::ffi::CStr;
@@ -167,45 +147,31 @@ async fn detect_default_user_shell(session_id: Uuid, codex_home: &Path) -> Shell
                .into_owned();
            let home_path = CStr::from_ptr((*pw).pw_dir).to_string_lossy().into_owned();

            let rc_path = if shell_path.ends_with("/zsh") {
                format!("{home_path}/.zshrc")
            } else if shell_path.ends_with("/bash") {
                format!("{home_path}/.bashrc")
            } else {
                return Shell::Unknown;
            };

            let snapshot_path = snapshots::ensure_posix_snapshot(
                &shell_path,
                &rc_path,
                Path::new(&home_path),
                codex_home,
                session_id,
            )
            .await;
            if snapshot_path.is_none() {
                trace!("failed to prepare posix snapshot; using live profile");
            if shell_path.ends_with("/zsh") {
                return Shell::Zsh(ZshShell {
                    shell_path,
                    zshrc_path: format!("{home_path}/.zshrc"),
                });
            }
            let shell_snapshot =
                snapshot_path.map(|snapshot| Arc::new(ShellSnapshot::new(snapshot)));

            return Shell::Posix(PosixShell {
                shell_path,
                rc_path,
                shell_snapshot,
            });
            if shell_path.ends_with("/bash") {
                return Shell::Bash(BashShell {
                    shell_path,
                    bashrc_path: format!("{home_path}/.bashrc"),
                });
            }
        }
    }
    Shell::Unknown
}

#[cfg(unix)]
pub async fn default_user_shell(session_id: Uuid, codex_home: &Path) -> Shell {
    detect_default_user_shell(session_id, codex_home).await
pub async fn default_user_shell() -> Shell {
    detect_default_user_shell()
}

#[cfg(target_os = "windows")]
pub async fn default_user_shell(_session_id: Uuid, _codex_home: &Path) -> Shell {
pub async fn default_user_shell() -> Shell {
    use tokio::process::Command;

    // Prefer PowerShell 7+ (`pwsh`) if available, otherwise fall back to Windows PowerShell.
@@ -245,158 +211,42 @@ pub async fn default_user_shell(_session_id: Uuid, _codex_home: &Path) -> Shell
}

#[cfg(all(not(target_os = "windows"), not(unix)))]
pub async fn default_user_shell(_session_id: Uuid, _codex_home: &Path) -> Shell {
pub async fn default_user_shell() -> Shell {
    Shell::Unknown
}

#[cfg(unix)]
mod snapshots {
    use super::*;

    fn zsh_profile_paths(home: &Path) -> Vec<PathBuf> {
        [".zshenv", ".zprofile", ".zshrc", ".zlogin"]
            .into_iter()
            .map(|name| home.join(name))
            .collect()
    }

    fn posix_profile_source_script(home: &Path) -> String {
        zsh_profile_paths(home)
            .into_iter()
            .map(|profile| {
                let profile_string = profile.to_string_lossy().into_owned();
                let quoted = shlex::try_quote(&profile_string)
                    .map(|cow| cow.into_owned())
                    .unwrap_or(profile_string.clone());

                format!("[ -f {quoted} ] && . {quoted}")
            })
            .collect::<Vec<_>>()
            .join("; ")
    }
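
For a home directory of /home/user (illustrative path), the prelude produced by posix_profile_source_script looks like: [ -f /home/user/.zshenv ] && . /home/user/.zshenv; [ -f /home/user/.zprofile ] && . /home/user/.zprofile; [ -f /home/user/.zshrc ] && . /home/user/.zshrc; [ -f /home/user/.zlogin ] && . /home/user/.zlogin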

    pub(crate) async fn ensure_posix_snapshot(
        shell_path: &str,
        rc_path: &str,
        home: &Path,
        codex_home: &Path,
        session_id: Uuid,
    ) -> Option<PathBuf> {
        let snapshot_path = codex_home.join(format!("shell_snapshots/snapshot_{session_id}.zsh"));

        // Check whether a profile update requires regenerating the snapshot.
        let snapshot_is_stale = async {
            let snapshot_metadata = tokio::fs::metadata(&snapshot_path).await.ok()?;
            let snapshot_modified = snapshot_metadata.modified().ok()?;

            for profile in zsh_profile_paths(home) {
                let Ok(profile_metadata) = tokio::fs::metadata(&profile).await else {
                    continue;
                };

                let Ok(profile_modified) = profile_metadata.modified() else {
                    return Some(true);
                };

                if profile_modified > snapshot_modified {
                    return Some(true);
                }
            }

            Some(false)
        }
        .await
        .unwrap_or(true);

        if !snapshot_is_stale {
            return Some(snapshot_path);
        }

        match regenerate_posix_snapshot(shell_path, rc_path, home, &snapshot_path).await {
            Ok(()) => Some(snapshot_path),
            Err(err) => {
                tracing::warn!("failed to generate posix snapshot: {err}");
                None
            }
        }
    }
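
The staleness check above boils down to mtime comparisons; a simplified standalone sketch (synchronous, and slightly coarser than the original: any unreadable snapshot counts as stale):

// Sketch: a snapshot is stale when it is missing or any readable profile
// is newer than it. Missing profiles do not force regeneration.
use std::path::Path;

fn is_stale(snapshot: &Path, profiles: &[&Path]) -> bool {
    let Ok(snap_mtime) = std::fs::metadata(snapshot).and_then(|m| m.modified()) else {
        return true;
    };
    profiles.iter().any(|profile| {
        std::fs::metadata(profile)
            .and_then(|m| m.modified())
            .map(|mtime| mtime > snap_mtime)
            .unwrap_or(false)
    })
}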

    async fn regenerate_posix_snapshot(
        shell_path: &str,
        rc_path: &str,
        home: &Path,
        snapshot_path: &Path,
    ) -> std::io::Result<()> {
        // Use `emulate -L sh` instead of `set -o posix` so we work on zsh builds
        // that disable that option. Guard `alias -p` with `|| true` so the script
        // keeps a zero exit status even if aliases are disabled.
        let mut capture_script = String::new();
        let profile_sources = posix_profile_source_script(home);
        if !profile_sources.is_empty() {
            capture_script.push_str(&format!("{profile_sources}; "));
        }

        let zshrc = home.join(rc_path);

        capture_script.push_str(
            &format!(". {}; setopt posixbuiltins; export -p; {{ alias | sed 's/^/alias /'; }} 2>/dev/null || true", zshrc.display()),
        );
        let output = tokio::process::Command::new(shell_path)
            .arg("-lc")
            .arg(capture_script)
            .env("HOME", home)
            .output()
            .await?;

        if !output.status.success() {
            return Err(std::io::Error::other(format!(
                "snapshot capture exited with status {}",
                output.status
            )));
        }

        let mut contents = String::from("# Generated by Codex. Do not edit.\n");

        contents.push_str(&String::from_utf8_lossy(&output.stdout));
        contents.push('\n');

        if let Some(parent) = snapshot_path.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }

        let tmp_path = snapshot_path.with_extension("tmp");
        tokio::fs::write(&tmp_path, contents).await?;

        // Restrict the snapshot to user read/write so that environment variables or aliases
        // that may contain secrets are not exposed to other users on the system.
        use std::os::unix::fs::PermissionsExt;
        let permissions = std::fs::Permissions::from_mode(0o600);
        tokio::fs::set_permissions(&tmp_path, permissions).await?;

        tokio::fs::rename(&tmp_path, snapshot_path).await?;
        Ok(())
    }
}

pub(crate) fn delete_shell_snapshot(path: &Path) {
    if let Err(err) = std::fs::remove_file(path) {
        trace!("failed to delete shell snapshot {path:?}: {err}");
    }
}

#[cfg(test)]
#[cfg(unix)]
mod tests {
    use super::*;
    use std::process::Command;

    use std::path::PathBuf;
    #[tokio::test]
    async fn test_current_shell_detects_zsh() {
        let shell = Command::new("sh")
            .arg("-c")
            .arg("echo $SHELL")
            .output()
            .unwrap();

        let home = std::env::var("HOME").unwrap();
        let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string();
        if shell_path.ends_with("/zsh") {
            assert_eq!(
                default_user_shell().await,
                Shell::Zsh(ZshShell {
                    shell_path: shell_path.to_string(),
                    zshrc_path: format!("{home}/.zshrc",),
                })
            );
        }
    }

    #[tokio::test]
    async fn test_run_with_profile_zshrc_not_exists() {
        let shell = Shell::Posix(PosixShell {
        let shell = Shell::Zsh(ZshShell {
            shell_path: "/bin/zsh".to_string(),
            rc_path: "/does/not/exist/.zshrc".to_string(),
            shell_snapshot: None,
            zshrc_path: "/does/not/exist/.zshrc".to_string(),
        });
        let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
        assert_eq!(
@@ -404,7 +254,24 @@ mod tests {
            Some(vec![
                "/bin/zsh".to_string(),
                "-lc".to_string(),
                "[ -f /does/not/exist/.zshrc ] && . /does/not/exist/.zshrc; (myecho)".to_string(),
                "myecho".to_string()
            ])
        );
    }

    #[tokio::test]
    async fn test_run_with_profile_bashrc_not_exists() {
        let shell = Shell::Bash(BashShell {
            shell_path: "/bin/bash".to_string(),
            bashrc_path: "/does/not/exist/.bashrc".to_string(),
        });
        let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
        assert_eq!(
            actual_cmd,
            Some(vec![
                "/bin/bash".to_string(),
                "-lc".to_string(),
                "myecho".to_string()
            ])
        );
    }
@@ -416,11 +283,7 @@ mod tests {
        let cases = vec![
            (
                vec!["myecho"],
                vec![
                    shell_path,
                    "-lc",
                    "[ -f BASHRC_PATH ] && . BASHRC_PATH; (myecho)",
                ],
                vec![shell_path, "-lc", "source BASHRC_PATH && (myecho)"],
                Some("It works!\n"),
            ),
            (
@@ -428,7 +291,7 @@ mod tests {
                vec![
                    shell_path,
                    "-lc",
                    "[ -f BASHRC_PATH ] && . BASHRC_PATH; (echo 'single' \"double\")",
                    "source BASHRC_PATH && (echo 'single' \"double\")",
                ],
                Some("single double\n"),
            ),
@@ -454,20 +317,16 @@ mod tests {
        "#,
        )
        .unwrap();
        let shell = Shell::Posix(PosixShell {
        let shell = Shell::Bash(BashShell {
            shell_path: shell_path.to_string(),
            rc_path: bashrc_path.to_str().unwrap().to_string(),
            shell_snapshot: None,
            bashrc_path: bashrc_path.to_str().unwrap().to_string(),
        });

        let actual_cmd = shell
            .format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
        let expected_cmd = expected_cmd
            .iter()
            .map(|s| {
                s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap())
                    .to_string()
            })
            .map(|s| s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap()))
            .collect();

        assert_eq!(actual_cmd, Some(expected_cmd));

@@ -507,82 +366,6 @@ mod tests {
#[cfg(target_os = "macos")]
mod macos_tests {
    use super::*;
    use crate::shell::snapshots::ensure_posix_snapshot;

    #[tokio::test]
    async fn test_snapshot_generation_uses_session_id_and_cleanup() {
        let shell_path = "/bin/zsh";

        let temp_home = tempfile::tempdir().unwrap();
        let codex_home = tempfile::tempdir().unwrap();
        std::fs::write(
            temp_home.path().join(".zshrc"),
            "export SNAPSHOT_TEST_VAR=1\nalias snapshot_test_alias='echo hi'\n",
        )
        .unwrap();

        let session_id = Uuid::new_v4();
        let snapshot_path = ensure_posix_snapshot(
            shell_path,
            ".zshrc",
            temp_home.path(),
            codex_home.path(),
            session_id,
        )
        .await
        .expect("snapshot path");

        let filename = snapshot_path
            .file_name()
            .unwrap()
            .to_string_lossy()
            .to_string();
        assert!(filename.contains(&session_id.to_string()));
        assert!(snapshot_path.exists());

        let snapshot_path_second = ensure_posix_snapshot(
            shell_path,
            ".zshrc",
            temp_home.path(),
            codex_home.path(),
            session_id,
        )
        .await
        .expect("snapshot path");
        assert_eq!(snapshot_path, snapshot_path_second);

        let contents = std::fs::read_to_string(&snapshot_path).unwrap();
        assert!(contents.contains("alias snapshot_test_alias='echo hi'"));
        assert!(contents.contains("SNAPSHOT_TEST_VAR=1"));

        delete_shell_snapshot(&snapshot_path);
        assert!(!snapshot_path.exists());
    }

    #[test]
    fn format_default_shell_invocation_prefers_snapshot_when_available() {
        let temp_dir = tempfile::tempdir().unwrap();
        let snapshot_path = temp_dir.path().join("snapshot.zsh");
        std::fs::write(&snapshot_path, "export SNAPSHOT_READY=1").unwrap();

        let shell = Shell::Posix(PosixShell {
            shell_path: "/bin/zsh".to_string(),
            rc_path: {
                let path = temp_dir.path().join(".zshrc");
                std::fs::write(&path, "# test zshrc").unwrap();
                path.to_string_lossy().to_string()
            },
            shell_snapshot: Some(Arc::new(ShellSnapshot::new(snapshot_path.clone()))),
        });

        let invocation = shell.format_default_shell_invocation(vec!["echo".to_string()]);
        let expected_command = vec!["/bin/zsh".to_string(), "-c".to_string(), {
            let snapshot_path = snapshot_path.to_string_lossy();
            format!("[ -f {snapshot_path} ] && . {snapshot_path}; (echo)")
        }];

        assert_eq!(invocation, Some(expected_command));
    }

    #[tokio::test]
    async fn test_run_with_profile_escaping_and_execution() {
@@ -591,20 +374,12 @@ mod macos_tests {
        let cases = vec![
            (
                vec!["myecho"],
                vec![
                    shell_path,
                    "-lc",
                    "[ -f ZSHRC_PATH ] && . ZSHRC_PATH; (myecho)",
                ],
                vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
                Some("It works!\n"),
            ),
            (
                vec!["myecho"],
                vec![
                    shell_path,
                    "-lc",
                    "[ -f ZSHRC_PATH ] && . ZSHRC_PATH; (myecho)",
                ],
                vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
                Some("It works!\n"),
            ),
            (
@@ -612,7 +387,7 @@ mod macos_tests {
                vec![
                    shell_path,
                    "-lc",
                    "[ -f ZSHRC_PATH ] && . ZSHRC_PATH; (bash -c \"echo 'single' \\\"double\\\"\")",
                    "source ZSHRC_PATH && (bash -c \"echo 'single' \\\"double\\\"\")",
                ],
                Some("single double\n"),
            ),
@@ -621,7 +396,7 @@ mod macos_tests {
                vec![
                    shell_path,
                    "-lc",
                    "[ -f ZSHRC_PATH ] && . ZSHRC_PATH; (echo 'single' \"double\")",
                    "source ZSHRC_PATH && (echo 'single' \"double\")",
                ],
                Some("single double\n"),
            ),
@@ -648,20 +423,16 @@ mod macos_tests {
        "#,
        )
        .unwrap();
        let shell = Shell::Posix(PosixShell {
        let shell = Shell::Zsh(ZshShell {
            shell_path: shell_path.to_string(),
            rc_path: zshrc_path.to_str().unwrap().to_string(),
            shell_snapshot: None,
            zshrc_path: zshrc_path.to_str().unwrap().to_string(),
        });

        let actual_cmd = shell
            .format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
        let expected_cmd = expected_cmd
            .iter()
            .map(|s| {
                s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap())
                    .to_string()
            })
            .map(|s| s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap()))
            .collect();

        assert_eq!(actual_cmd, Some(expected_cmd));

@@ -3,8 +3,6 @@ use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;

use codex_protocol::mcp_protocol::AuthMode;

#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
pub struct TokenData {
    /// Flat info parsed from the JWT in auth.json.
@@ -22,36 +20,6 @@ pub struct TokenData {
    pub account_id: Option<String>,
}

impl TokenData {
    /// Returns true if this is a plan that should use the traditional
    /// "metered" billing via an API key.
    pub(crate) fn should_use_api_key(
        &self,
        preferred_auth_method: AuthMode,
        is_openai_email: bool,
    ) -> bool {
        if preferred_auth_method == AuthMode::ApiKey {
            return true;
        }
        // If the email is an OpenAI email, use AuthMode::ChatGPT unless preferred_auth_method is AuthMode::ApiKey.
        if is_openai_email {
            return false;
        }

        self.id_token
            .chatgpt_plan_type
            .as_ref()
            .is_none_or(|plan| plan.is_plan_that_should_use_api_key())
    }

    pub fn is_openai_email(&self) -> bool {
        self.id_token
            .email
            .as_deref()
            .is_some_and(|email| email.trim().to_ascii_lowercase().ends_with("@openai.com"))
    }
}
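
The removed decision logic can be summarized by this standalone mirror (function name and boolean inputs are illustrative; Some(flag) stands in for the plan's is_plan_that_should_use_api_key result):

// Sketch of should_use_api_key: an explicit API-key preference wins, OpenAI
// emails default to ChatGPT auth, and unknown/absent plans fall back to
// metered API-key billing.
fn should_use_api_key_sketch(
    prefers_api_key: bool,
    is_openai_email: bool,
    plan_wants_api_key: Option<bool>,
) -> bool {
    if prefers_api_key {
        return true;
    }
    if is_openai_email {
        return false;
    }
    plan_wants_api_key.unwrap_or(true)
}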

/// Flat subset of useful claims in id_token from auth.json.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct IdTokenInfo {
@@ -79,28 +47,6 @@ pub(crate) enum PlanType {
    Unknown(String),
}

impl PlanType {
    fn is_plan_that_should_use_api_key(&self) -> bool {
        match self {
            Self::Known(known) => {
                use KnownPlan::*;
                !matches!(known, Free | Plus | Pro | Team)
            }
            Self::Unknown(_) => {
                // Unknown plans should use the API key.
                true
            }
        }
    }

    pub fn as_string(&self) -> String {
        match self {
            Self::Known(known) => format!("{known:?}").to_lowercase(),
            Self::Unknown(s) => s.clone(),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub(crate) enum KnownPlan {

180
codex-rs/core/src/truncate.rs
Normal file
@@ -0,0 +1,180 @@
//! Utilities for truncating large chunks of output while preserving a prefix
//! and suffix on UTF-8 boundaries.

/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
/// preserving the beginning and the end. Returns the possibly truncated
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
/// if truncation occurred; otherwise returns the original string and `None`.
pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
    if s.len() <= max_bytes {
        return (s.to_string(), None);
    }

    let est_tokens = (s.len() as u64).div_ceil(4);
    if max_bytes == 0 {
        return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
    }

    fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
        if input.len() <= max_len {
            return input;
        }
        let mut end = max_len;
        while end > 0 && !input.is_char_boundary(end) {
            end -= 1;
        }
        &input[..end]
    }

    fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
        if let Some(head) = s.get(..left_budget)
            && let Some(i) = head.rfind('\n')
        {
            return i + 1;
        }
        truncate_on_boundary(s, left_budget).len()
    }

    fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
        let start_tail = s.len().saturating_sub(right_budget);
        if let Some(tail) = s.get(start_tail..)
            && let Some(i) = tail.find('\n')
        {
            return start_tail + i + 1;
        }

        let mut idx = start_tail.min(s.len());
        while idx < s.len() && !s.is_char_boundary(idx) {
            idx += 1;
        }
        idx
    }

    let mut guess_tokens = est_tokens;
    for _ in 0..4 {
        let marker = format!("…{guess_tokens} tokens truncated…");
        let marker_len = marker.len();
        let keep_budget = max_bytes.saturating_sub(marker_len);
        if keep_budget == 0 {
            return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
        }

        let left_budget = keep_budget / 2;
        let right_budget = keep_budget - left_budget;
        let prefix_end = pick_prefix_end(s, left_budget);
        let mut suffix_start = pick_suffix_start(s, right_budget);
        if suffix_start < prefix_end {
            suffix_start = prefix_end;
        }

        let kept_content_bytes = prefix_end + (s.len() - suffix_start);
        let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
        let new_tokens = (truncated_content_bytes as u64).div_ceil(4);

        if new_tokens == guess_tokens {
            let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
            out.push_str(&s[..prefix_end]);
            out.push_str(&marker);
            out.push('\n');
            out.push_str(&s[suffix_start..]);
            return (out, Some(est_tokens));
        }

        guess_tokens = new_tokens;
    }

    let marker = format!("…{guess_tokens} tokens truncated…");
    let marker_len = marker.len();
    let keep_budget = max_bytes.saturating_sub(marker_len);
    if keep_budget == 0 {
        return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
    }

    let left_budget = keep_budget / 2;
    let right_budget = keep_budget - left_budget;
    let prefix_end = pick_prefix_end(s, left_budget);
    let suffix_start = pick_suffix_start(s, right_budget);

    let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
    out.push_str(&s[..prefix_end]);
    out.push_str(&marker);
    out.push('\n');
    out.push_str(&s[suffix_start..]);
    (out, Some(est_tokens))
}

#[cfg(test)]
mod tests {
    use super::truncate_middle;

    #[test]
    fn truncate_middle_no_newlines_fallback() {
        let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
        let max_bytes = 32;
        let (out, original) = truncate_middle(s, max_bytes);
        assert!(out.starts_with("abc"));
        assert!(out.contains("tokens truncated"));
        assert!(out.ends_with("XYZ*"));
        assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
    }

    #[test]
    fn truncate_middle_prefers_newline_boundaries() {
        let mut s = String::new();
        for i in 1..=20 {
            s.push_str(&format!("{i:03}\n"));
        }
        assert_eq!(s.len(), 80);

        let max_bytes = 64;
        let (out, tokens) = truncate_middle(&s, max_bytes);
        assert!(out.starts_with("001\n002\n003\n004\n"));
        assert!(out.contains("tokens truncated"));
        assert!(out.ends_with("017\n018\n019\n020\n"));
        assert_eq!(tokens, Some(20));
    }

    #[test]
    fn truncate_middle_handles_utf8_content() {
        let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
        let max_bytes = 32;
        let (out, tokens) = truncate_middle(s, max_bytes);

        assert!(out.contains("tokens truncated"));
        assert!(!out.contains('\u{fffd}'));
        assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
    }

    #[test]
    fn truncate_middle_prefers_newline_boundaries_2() {
        // Build a multi-line string of 20 numbered lines (each "NNN\n").
        let mut s = String::new();
        for i in 1..=20 {
            s.push_str(&format!("{i:03}\n"));
        }
        // Total length: 20 lines * 4 bytes per line = 80 bytes.
        assert_eq!(s.len(), 80);

        // Choose a cap that forces truncation while leaving room for
        // a few lines on each side after accounting for the marker.
        let max_bytes = 64;
        // Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
        assert_eq!(
            truncate_middle(&s, max_bytes),
            (
                r#"001
002
003
004
…12 tokens truncated…
017
018
019
020
"#
                .to_string(),
                Some(20)
            )
        );
    }
}
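
A sketch of a typical call site for truncate_middle inside this crate (the wrapper below and its name are hypothetical; the function itself is crate-private):

// Hypothetical helper: cap captured command output before attaching it to a
// prompt, logging roughly how many estimated tokens were elided.
fn cap_output_for_model(captured: &str) -> String {
    let (snippet, dropped) = truncate_middle(captured, 8 * 1024);
    if let Some(tokens) = dropped {
        tracing::debug!("elided roughly {tokens} estimated tokens of output");
    }
    snippet
}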

@@ -678,7 +678,7 @@ index {left_oid}..{right_oid}
        let dest = dir.path().join("dest.txt");
        let mut acc = TurnDiffTracker::new();
        let mv = HashMap::from([(
            src.clone(),
            src,
            FileChange::Update {
                unified_diff: "".into(),
                move_path: Some(dest.clone()),

22
codex-rs/core/src/unified_exec/errors.rs
Normal file
@@ -0,0 +1,22 @@
use thiserror::Error;

#[derive(Debug, Error)]
pub(crate) enum UnifiedExecError {
    #[error("Failed to create unified exec session: {pty_error}")]
    CreateSession {
        #[source]
        pty_error: anyhow::Error,
    },
    #[error("Unknown session id {session_id}")]
    UnknownSessionId { session_id: i32 },
    #[error("failed to write to stdin")]
    WriteToStdin,
    #[error("missing command line for unified exec request")]
    MissingCommandLine,
}

impl UnifiedExecError {
    pub(crate) fn create_session(error: anyhow::Error) -> Self {
        Self::CreateSession { pty_error: error }
    }
}
633
codex-rs/core/src/unified_exec/mod.rs
Normal file
@@ -0,0 +1,633 @@
use portable_pty::CommandBuilder;
use portable_pty::PtySize;
use portable_pty::native_pty_system;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::io::ErrorKind;
use std::io::Read;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicI32;
use std::sync::atomic::Ordering;
use tokio::sync::Mutex;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio::time::Duration;
use tokio::time::Instant;

use crate::exec_command::ExecCommandSession;
use crate::truncate::truncate_middle;

mod errors;

pub(crate) use errors::UnifiedExecError;

const DEFAULT_TIMEOUT_MS: u64 = 1_000;
const MAX_TIMEOUT_MS: u64 = 60_000;
const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 128 KiB

#[derive(Debug)]
pub(crate) struct UnifiedExecRequest<'a> {
    pub session_id: Option<i32>,
    pub input_chunks: &'a [String],
    pub timeout_ms: Option<u64>,
}

#[derive(Debug, Clone, PartialEq)]
pub(crate) struct UnifiedExecResult {
    pub session_id: Option<i32>,
    pub output: String,
}

#[derive(Debug, Default)]
pub(crate) struct UnifiedExecSessionManager {
    next_session_id: AtomicI32,
    sessions: Mutex<HashMap<i32, ManagedUnifiedExecSession>>,
}

#[derive(Debug)]
struct ManagedUnifiedExecSession {
    session: ExecCommandSession,
    output_buffer: OutputBuffer,
    /// Notifies waiters whenever new output has been appended to
    /// `output_buffer`, allowing clients to poll for fresh data.
    output_notify: Arc<Notify>,
    output_task: JoinHandle<()>,
}

#[derive(Debug, Default)]
struct OutputBufferState {
    chunks: VecDeque<Vec<u8>>,
    total_bytes: usize,
}

impl OutputBufferState {
    fn push_chunk(&mut self, chunk: Vec<u8>) {
        self.total_bytes = self.total_bytes.saturating_add(chunk.len());
        self.chunks.push_back(chunk);

        let mut excess = self
            .total_bytes
            .saturating_sub(UNIFIED_EXEC_OUTPUT_MAX_BYTES);

        while excess > 0 {
            match self.chunks.front_mut() {
                Some(front) if excess >= front.len() => {
                    excess -= front.len();
                    self.total_bytes = self.total_bytes.saturating_sub(front.len());
                    self.chunks.pop_front();
                }
                Some(front) => {
                    front.drain(..excess);
                    self.total_bytes = self.total_bytes.saturating_sub(excess);
                    break;
                }
                None => break,
            }
        }
    }

    fn drain(&mut self) -> Vec<Vec<u8>> {
        let drained: Vec<Vec<u8>> = self.chunks.drain(..).collect();
        self.total_bytes = 0;
        drained
    }
}

type OutputBuffer = Arc<Mutex<OutputBufferState>>;
type OutputHandles = (OutputBuffer, Arc<Notify>);

impl ManagedUnifiedExecSession {
    fn new(session: ExecCommandSession) -> Self {
        let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
        let output_notify = Arc::new(Notify::new());
        let mut receiver = session.output_receiver();
        let buffer_clone = Arc::clone(&output_buffer);
        let notify_clone = Arc::clone(&output_notify);
        let output_task = tokio::spawn(async move {
            while let Ok(chunk) = receiver.recv().await {
                let mut guard = buffer_clone.lock().await;
                guard.push_chunk(chunk);
                drop(guard);
                notify_clone.notify_waiters();
            }
        });

        Self {
            session,
            output_buffer,
            output_notify,
            output_task,
        }
    }

    fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
        self.session.writer_sender()
    }

    fn output_handles(&self) -> OutputHandles {
        (
            Arc::clone(&self.output_buffer),
            Arc::clone(&self.output_notify),
        )
    }

    fn has_exited(&self) -> bool {
        self.session.has_exited()
    }
}

impl Drop for ManagedUnifiedExecSession {
    fn drop(&mut self) {
        self.output_task.abort();
    }
}

impl UnifiedExecSessionManager {
    pub async fn handle_request(
        &self,
        request: UnifiedExecRequest<'_>,
    ) -> Result<UnifiedExecResult, UnifiedExecError> {
        let (timeout_ms, timeout_warning) = match request.timeout_ms {
            Some(requested) if requested > MAX_TIMEOUT_MS => (
                MAX_TIMEOUT_MS,
                Some(format!(
                    "Warning: requested timeout {requested}ms exceeds maximum of {MAX_TIMEOUT_MS}ms; clamping to {MAX_TIMEOUT_MS}ms.\n"
                )),
            ),
            Some(requested) => (requested, None),
            None => (DEFAULT_TIMEOUT_MS, None),
        };

        let mut new_session: Option<ManagedUnifiedExecSession> = None;
        let session_id;
        let writer_tx;
        let output_buffer;
        let output_notify;

        if let Some(existing_id) = request.session_id {
            let mut sessions = self.sessions.lock().await;
            match sessions.get(&existing_id) {
                Some(session) => {
                    if session.has_exited() {
                        sessions.remove(&existing_id);
                        return Err(UnifiedExecError::UnknownSessionId {
                            session_id: existing_id,
                        });
                    }
                    let (buffer, notify) = session.output_handles();
                    session_id = existing_id;
                    writer_tx = session.writer_sender();
                    output_buffer = buffer;
                    output_notify = notify;
                }
                None => {
                    return Err(UnifiedExecError::UnknownSessionId {
                        session_id: existing_id,
                    });
                }
            }
            drop(sessions);
        } else {
            let command = request.input_chunks.to_vec();
            let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
            let session = create_unified_exec_session(&command).await?;
            let managed_session = ManagedUnifiedExecSession::new(session);
            let (buffer, notify) = managed_session.output_handles();
            writer_tx = managed_session.writer_sender();
            output_buffer = buffer;
            output_notify = notify;
            session_id = new_id;
            new_session = Some(managed_session);
        };

        if request.session_id.is_some() {
            let joined_input = request.input_chunks.join(" ");
            if !joined_input.is_empty() && writer_tx.send(joined_input.into_bytes()).await.is_err()
            {
                return Err(UnifiedExecError::WriteToStdin);
            }
        }

        let mut collected: Vec<u8> = Vec::with_capacity(4096);
        let start = Instant::now();
        let deadline = start + Duration::from_millis(timeout_ms);

        loop {
            let drained_chunks;
            let mut wait_for_output = None;
            {
                let mut guard = output_buffer.lock().await;
                drained_chunks = guard.drain();
                if drained_chunks.is_empty() {
                    wait_for_output = Some(output_notify.notified());
                }
            }

            if drained_chunks.is_empty() {
                let remaining = deadline.saturating_duration_since(Instant::now());
                if remaining == Duration::ZERO {
                    break;
                }

                let notified = wait_for_output.unwrap_or_else(|| output_notify.notified());
                tokio::pin!(notified);
                tokio::select! {
                    _ = &mut notified => {}
                    _ = tokio::time::sleep(remaining) => break,
                }
                continue;
            }

            for chunk in drained_chunks {
                collected.extend_from_slice(&chunk);
            }

            if Instant::now() >= deadline {
                break;
            }
        }

        let (output, _maybe_tokens) = truncate_middle(
            &String::from_utf8_lossy(&collected),
            UNIFIED_EXEC_OUTPUT_MAX_BYTES,
        );
        let output = if let Some(warning) = timeout_warning {
            format!("{warning}{output}")
        } else {
            output
        };

        let should_store_session = if let Some(session) = new_session.as_ref() {
            !session.has_exited()
        } else if request.session_id.is_some() {
            let mut sessions = self.sessions.lock().await;
            if let Some(existing) = sessions.get(&session_id) {
                if existing.has_exited() {
                    sessions.remove(&session_id);
                    false
                } else {
                    true
                }
            } else {
                false
            }
        } else {
            true
        };

        if should_store_session {
            if let Some(session) = new_session {
                self.sessions.lock().await.insert(session_id, session);
            }
            Ok(UnifiedExecResult {
                session_id: Some(session_id),
                output,
            })
        } else {
            Ok(UnifiedExecResult {
                session_id: None,
                output,
            })
        }
    }
}

async fn create_unified_exec_session(
    command: &[String],
) -> Result<ExecCommandSession, UnifiedExecError> {
    if command.is_empty() {
        return Err(UnifiedExecError::MissingCommandLine);
    }

    let pty_system = native_pty_system();

    let pair = pty_system
        .openpty(PtySize {
            rows: 24,
            cols: 80,
            pixel_width: 0,
            pixel_height: 0,
        })
        .map_err(UnifiedExecError::create_session)?;

    // Safe thanks to the check at the top of the function.
    let mut command_builder = CommandBuilder::new(command[0].clone());
    for arg in &command[1..] {
        command_builder.arg(arg);
    }

    let mut child = pair
        .slave
        .spawn_command(command_builder)
        .map_err(UnifiedExecError::create_session)?;
    let killer = child.clone_killer();

    let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
    let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);

    let mut reader = pair
        .master
        .try_clone_reader()
        .map_err(UnifiedExecError::create_session)?;
    let output_tx_clone = output_tx.clone();
    let reader_handle = tokio::task::spawn_blocking(move || {
        let mut buf = [0u8; 8192];
        loop {
            match reader.read(&mut buf) {
                Ok(0) => break,
                Ok(n) => {
                    let _ = output_tx_clone.send(buf[..n].to_vec());
                }
                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
                    std::thread::sleep(Duration::from_millis(5));
                    continue;
                }
                Err(_) => break,
            }
        }
    });

    let writer = pair
        .master
        .take_writer()
        .map_err(UnifiedExecError::create_session)?;
    let writer = Arc::new(StdMutex::new(writer));
    let writer_handle = tokio::spawn({
        let writer = writer.clone();
        async move {
            while let Some(bytes) = writer_rx.recv().await {
                let writer = writer.clone();
                let _ = tokio::task::spawn_blocking(move || {
                    if let Ok(mut guard) = writer.lock() {
                        use std::io::Write;
                        let _ = guard.write_all(&bytes);
                        let _ = guard.flush();
                    }
                })
                .await;
            }
        }
    });

    let exit_status = Arc::new(AtomicBool::new(false));
    let wait_exit_status = Arc::clone(&exit_status);
    let wait_handle = tokio::task::spawn_blocking(move || {
        let _ = child.wait();
        wait_exit_status.store(true, Ordering::SeqCst);
    });

    Ok(ExecCommandSession::new(
        writer_tx,
        output_tx,
        killer,
        reader_handle,
        writer_handle,
        wait_handle,
        exit_status,
    ))
}
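
For orientation, a minimal self-contained sketch of the portable_pty flow used above (illustrative only; the real session also wires the writer, killer, and exit-status tasks, and PTY read behavior varies by platform):

// Spawn a command on a fresh PTY and read its first chunk of output.
use portable_pty::{native_pty_system, CommandBuilder, PtySize};
use std::io::Read;

fn pty_demo() -> anyhow::Result<Vec<u8>> {
    let pair = native_pty_system().openpty(PtySize {
        rows: 24,
        cols: 80,
        pixel_width: 0,
        pixel_height: 0,
    })?;
    let mut cmd = CommandBuilder::new("echo");
    cmd.arg("hello");
    let mut child = pair.slave.spawn_command(cmd)?;
    let mut reader = pair.master.try_clone_reader()?;
    let mut buf = vec![0u8; 256];
    let n = reader.read(&mut buf)?; // blocks until the PTY produces output
    child.wait()?;
    buf.truncate(n);
    Ok(buf)
}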

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn push_chunk_trims_only_excess_bytes() {
        let mut buffer = OutputBufferState::default();
        buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]);
        buffer.push_chunk(vec![b'b']);
        buffer.push_chunk(vec![b'c']);

        assert_eq!(buffer.total_bytes, UNIFIED_EXEC_OUTPUT_MAX_BYTES);
        assert_eq!(buffer.chunks.len(), 3);
        assert_eq!(
            buffer.chunks.front().unwrap().len(),
            UNIFIED_EXEC_OUTPUT_MAX_BYTES - 2
        );
        assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'c']);
        assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'b']);
    }

    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn unified_exec_persists_across_requests_jif() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let open_shell = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["bash".to_string(), "-i".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        let session_id = open_shell.session_id.expect("expected session_id");

        manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &[
                    "export".to_string(),
                    "CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
                ],
                timeout_ms: Some(2_500),
            })
            .await?;

        let out_2 = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        assert!(out_2.output.contains("codex"));

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn multi_unified_exec_sessions() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let shell_a = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        let session_a = shell_a.session_id.expect("expected session id");

        manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_a),
                input_chunks: &["export CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;

        let out_2 = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &[
                    "echo".to_string(),
                    "$CODEX_INTERACTIVE_SHELL_VAR\n".to_string(),
                ],
                timeout_ms: Some(1_500),
            })
            .await?;
        assert!(!out_2.output.contains("codex"));

        let out_3 = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_a),
                input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        assert!(out_3.output.contains("codex"));

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn unified_exec_timeouts() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let open_shell = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["bash".to_string(), "-i".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        let session_id = open_shell.session_id.expect("expected session id");

        manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &[
                    "export".to_string(),
                    "CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
                ],
                timeout_ms: Some(1_500),
            })
            .await?;

        let out_2 = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &["sleep 5 && echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
                timeout_ms: Some(10),
            })
            .await?;
        assert!(!out_2.output.contains("codex"));

        tokio::time::sleep(Duration::from_secs(7)).await;

        let empty = Vec::new();
        let out_3 = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &empty,
                timeout_ms: Some(100),
            })
            .await?;

        assert!(out_3.output.contains("codex"));

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn requests_with_large_timeout_are_capped() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let result = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["echo".to_string(), "codex".to_string()],
                timeout_ms: Some(120_000),
            })
            .await?;

        assert!(result.output.starts_with(
            "Warning: requested timeout 120000ms exceeds maximum of 60000ms; clamping to 60000ms.\n"
        ));
        assert!(result.output.contains("codex"));

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn completed_commands_do_not_persist_sessions() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();
        let result = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["/bin/echo".to_string(), "codex".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;

        assert!(result.session_id.is_none());
        assert!(result.output.contains("codex"));

        assert!(manager.sessions.lock().await.is_empty());

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn reusing_completed_session_returns_unknown_session() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let open_shell = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        let session_id = open_shell.session_id.expect("expected session id");

        manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &["exit\n".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;

        tokio::time::sleep(Duration::from_millis(200)).await;

        let err = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &[],
                timeout_ms: Some(100),
            })
            .await
            .expect_err("expected unknown session error");

        match err {
            UnifiedExecError::UnknownSessionId { session_id: err_id } => {
                assert_eq!(err_id, session_id);
            }
            other => panic!("expected UnknownSessionId, got {other:?}"),
        }

        assert!(!manager.sessions.lock().await.contains_key(&session_id));

        Ok(())
    }
}

@@ -4,11 +4,12 @@ use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::project_doc::get_user_instructions;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::shell::default_user_shell;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::wait_for_event;
|
||||
@@ -222,6 +223,8 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
};
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
let cwd = TempDir::new().unwrap();
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.experimental_resume = Some(session_path.clone());
|
||||
// Also configure user instructions to ensure they are NOT delivered on resume.
|
||||
@@ -260,6 +263,29 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
// Build expected environment context for this turn.
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
let expected_env_text_turn = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
{}</environment_context>"#,
|
||||
cwd.path().to_string_lossy(),
|
||||
shell_line.as_str(),
|
||||
);
|
||||
let expected_env_msg_turn = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text_turn } ]
|
||||
});
|
||||
|
||||
let expected_input = json!([
|
||||
{
|
||||
"type": "message",
|
||||
@@ -271,12 +297,14 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
"role": "assistant",
|
||||
"content": [{ "type": "output_text", "text": "resumed assistant message" }]
|
||||
},
|
||||
expected_env_msg_turn,
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{ "type": "input_text", "text": "hello" }]
|
||||
}
|
||||
]);
|
||||
|
||||
assert_eq!(request_body["input"], expected_input);
|
||||
}
|
||||
|
||||
@@ -489,79 +517,6 @@ async fn chatgpt_auth_sends_correct_request() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn prefers_chatgpt_token_when_config_prefers_chatgpt() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
// Expect ChatGPT base path and correct headers
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(header_regex("Authorization", r"Bearer Access-123"))
|
||||
.and(header_regex("chatgpt-account-id", r"acc-123"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
// Write auth.json that contains both API key and ChatGPT tokens for a plan that should prefer ChatGPT.
|
||||
let _jwt = write_auth_json(
|
||||
&codex_home,
|
||||
Some("sk-test-key"),
|
||||
"pro",
|
||||
"Access-123",
|
||||
Some("acc-123"),
|
||||
);
|
||||
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.preferred_auth_method = AuthMode::ChatGPT;
|
||||
|
||||
let auth_manager =
|
||||
match CodexAuth::from_codex_home(codex_home.path(), config.preferred_auth_method) {
|
||||
Ok(Some(auth)) => codex_core::AuthManager::from_auth_for_testing(auth),
|
||||
Ok(None) => panic!("No CodexAuth found in codex_home"),
|
||||
Err(e) => panic!("Failed to load CodexAuth: {e}"),
|
||||
};
|
||||
let conversation_manager = ConversationManager::new(auth_manager);
|
||||
let NewConversation {
|
||||
conversation: codex,
|
||||
..
|
||||
} = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation");
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
@@ -606,14 +561,12 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
|
||||
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.preferred_auth_method = AuthMode::ApiKey;
|
||||
|
||||
let auth_manager =
|
||||
match CodexAuth::from_codex_home(codex_home.path(), config.preferred_auth_method) {
|
||||
Ok(Some(auth)) => codex_core::AuthManager::from_auth_for_testing(auth),
|
||||
Ok(None) => panic!("No CodexAuth found in codex_home"),
|
||||
Err(e) => panic!("Failed to load CodexAuth: {e}"),
|
||||
};
|
||||
let auth_manager = match CodexAuth::from_codex_home(codex_home.path()) {
|
||||
Ok(Some(auth)) => codex_core::AuthManager::from_auth_for_testing(auth),
|
||||
Ok(None) => panic!("No CodexAuth found in codex_home"),
|
||||
Err(e) => panic!("Failed to load CodexAuth: {e}"),
|
||||
};
|
||||
let conversation_manager = ConversationManager::new(auth_manager);
|
||||
let NewConversation {
|
||||
conversation: codex,
|
||||
@@ -914,7 +867,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
conversation: codex,
|
||||
..
|
||||
} = conversation_manager
|
||||
.new_conversation(config)
|
||||
.new_conversation(config.clone())
|
||||
.await
|
||||
.expect("create new conversation");
|
||||
|
||||
@@ -949,34 +902,49 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(requests.len(), 3, "expected 3 requests (one per turn)");
|
||||
|
||||
// Replace full-array compare with tail-only raw JSON compare using a single hard-coded value.
|
||||
let r3_tail_expected = json!([
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type":"input_text","text":"U1"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type":"output_text","text":"Hey there!\n"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type":"input_text","text":"U2"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type":"output_text","text":"Hey there!\n"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type":"input_text","text":"U3"}]
|
||||
}
|
||||
]);
|
||||
// Build expected environment context dynamically to avoid OS-dependent flakiness.
|
||||
let user_instructions = get_user_instructions(&config).await;
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
let expected_env_text = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
{}</environment_context>"#,
|
||||
std::env::current_dir().unwrap().to_string_lossy(),
|
||||
shell_line.as_str(),
|
||||
);
|
||||
let expected_env_msg = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text } ]
|
||||
});
|
||||
// Wrap user instructions in the XML container to match the raw/ingest view
|
||||
let expected_ui_text = format!(
|
||||
"<user_instructions>\n\n{}\n\n</user_instructions>",
|
||||
user_instructions.clone().unwrap()
|
||||
);
|
||||
let expected_ui_msg = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_ui_text } ]
|
||||
});
|
||||
|
||||
let expected_full = json!([
|
||||
expected_ui_msg,
|
||||
expected_env_msg.clone(),
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"U1"}]},
|
||||
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hey there!\n"}]},
|
||||
expected_env_msg.clone(),
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"U2"}]},
|
||||
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hey there!\n"}]},
|
||||
expected_env_msg,
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"U3"}]}]);
|
||||
|
||||
let r3_input_array = requests[2]
|
||||
.body_json::<serde_json::Value>()
|
||||
@@ -985,12 +953,6 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
.and_then(|v| v.as_array())
|
||||
.cloned()
|
||||
.expect("r3 missing input array");
|
||||
// skipping earlier context and developer messages
|
||||
let tail_len = r3_tail_expected.as_array().unwrap().len();
|
||||
let actual_tail = &r3_input_array[r3_input_array.len() - tail_len..];
|
||||
assert_eq!(
|
||||
serde_json::Value::Array(actual_tail.to_vec()),
|
||||
r3_tail_expected,
|
||||
"request 3 tail mismatch",
|
||||
);
|
||||
|
||||
assert_eq!(json!(r3_input_array), expected_full);
|
||||
}
|
||||
|
||||
@@ -3,10 +3,13 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
@@ -142,11 +145,12 @@ async fn summarize_context_three_requests_and_instructions() {
|
||||
let mut config = load_default_config_for_test(&home);
|
||||
config.model_provider = model_provider;
|
||||
let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.unwrap()
|
||||
.conversation;
|
||||
let NewConversation {
|
||||
conversation: codex,
|
||||
session_configured,
|
||||
..
|
||||
} = conversation_manager.new_conversation(config).await.unwrap();
|
||||
let rollout_path = session_configured.rollout_path;
|
||||
|
||||
// 1) Normal user input – should hit server once.
|
||||
codex
|
||||
@@ -248,4 +252,47 @@ async fn summarize_context_three_requests_and_instructions() {
|
||||
!messages.iter().any(|(_, t)| t.contains(SUMMARIZE_TRIGGER)),
|
||||
"third request should not include the summarize trigger"
|
||||
);
|
||||
|
||||
// Shut down Codex to flush rollout entries before inspecting the file.
|
||||
codex.submit(Op::Shutdown).await.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;
|
||||
|
||||
// Verify rollout contains APITurn entries for each API call and a Compacted entry.
|
||||
let text = std::fs::read_to_string(&rollout_path).unwrap_or_else(|e| {
|
||||
panic!(
|
||||
"failed to read rollout file {}: {e}",
|
||||
rollout_path.display()
|
||||
)
|
||||
});
|
||||
let mut api_turn_count = 0usize;
|
||||
let mut saw_compacted_summary = false;
|
||||
for line in text.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let Ok(entry): Result<RolloutLine, _> = serde_json::from_str(trimmed) else {
|
||||
continue;
|
||||
};
|
||||
match entry.item {
|
||||
RolloutItem::TurnContext(_) => {
|
||||
api_turn_count += 1;
|
||||
}
|
||||
RolloutItem::Compacted(ci) => {
|
||||
if ci.message == SUMMARY_TEXT {
|
||||
saw_compacted_summary = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
api_turn_count == 3,
|
||||
"expected three APITurn entries in rollout"
|
||||
);
|
||||
assert!(
|
||||
saw_compacted_summary,
|
||||
"expected a Compacted entry containing the summarizer output"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ContentItem;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::ConversationHistoryResponseEvent;
|
||||
use codex_core::protocol::ConversationPathResponseEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
use tempfile::TempDir;
|
||||
@@ -71,92 +75,121 @@ async fn fork_conversation_twice_drops_to_first_message() {
|
||||
let _ = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
// Request history from the base conversation.
|
||||
codex.submit(Op::GetHistory).await.unwrap();
|
||||
// Request history from the base conversation to obtain rollout path.
|
||||
codex.submit(Op::GetPath).await.unwrap();
|
||||
let base_history =
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationHistory(_))).await;
|
||||
|
||||
// Capture entries from the base history and compute expected prefixes after each fork.
|
||||
let history_after_three = match &base_history {
|
||||
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { history, .. }) => {
|
||||
history.clone()
|
||||
}
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationPath(_))).await;
|
||||
let base_path = match &base_history {
|
||||
EventMsg::ConversationPath(ConversationPathResponseEvent { path, .. }) => path.clone(),
|
||||
_ => panic!("expected ConversationHistory event"),
|
||||
};
|
||||
let entries_after_three = history_after_three.get_rollout_items();
|
||||
// History layout for this test:
|
||||
// [0] user instructions,
|
||||
// [1] environment context,
|
||||
// [2] "first" user message,
|
||||
// [3] "second" user message,
|
||||
// [4] "third" user message.
|
||||
|
||||
// Fork 1: drops the last user message and everything after.
|
||||
let expected_after_first = vec![
|
||||
entries_after_three[0].clone(),
|
||||
entries_after_three[1].clone(),
|
||||
entries_after_three[2].clone(),
|
||||
entries_after_three[3].clone(),
|
||||
];
|
||||
// GetHistory flushes before returning the path; no wait needed.
|
||||
|
||||
// Fork 2: drops the last user message and everything after.
|
||||
// [0] user instructions,
|
||||
// [1] environment context,
|
||||
// [2] "first" user message,
|
||||
let expected_after_second = vec![
|
||||
entries_after_three[0].clone(),
|
||||
entries_after_three[1].clone(),
|
||||
entries_after_three[2].clone(),
|
||||
];
|
||||
// Helper: read rollout items (excluding SessionMeta) from a JSONL path.
|
||||
let read_items = |p: &std::path::Path| -> Vec<RolloutItem> {
|
||||
let text = std::fs::read_to_string(p).expect("read rollout file");
|
||||
let mut items: Vec<RolloutItem> = Vec::new();
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: serde_json::Value = serde_json::from_str(line).expect("jsonl line");
|
||||
let rl: RolloutLine = serde_json::from_value(v).expect("rollout line");
|
||||
match rl.item {
|
||||
RolloutItem::SessionMeta(_) => {}
|
||||
other => items.push(other),
|
||||
}
|
||||
}
|
||||
items
|
||||
};
|
||||
|
||||
// Fork once with n=1 → drops the last user message and everything after.
|
||||
// Compute expected prefixes after each fork by truncating base rollout at nth-from-last user input.
|
||||
let base_items = read_items(&base_path);
|
||||
let find_user_input_positions = |items: &[RolloutItem]| -> Vec<usize> {
|
||||
let mut pos = Vec::new();
|
||||
for (i, it) in items.iter().enumerate() {
|
||||
if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. }) = it
|
||||
&& role == "user"
|
||||
{
|
||||
// Consider any user message as an input boundary; recorder stores both EventMsg and ResponseItem.
|
||||
// We specifically look for input items, which are represented as ContentItem::InputText.
|
||||
if content
|
||||
.iter()
|
||||
.any(|c| matches!(c, ContentItem::InputText { .. }))
|
||||
{
|
||||
pos.push(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
pos
|
||||
};
|
||||
let user_inputs = find_user_input_positions(&base_items);
|
||||
|
||||
// After dropping last user input (n=1), cut strictly before that input if present, else empty.
|
||||
let cut1 = user_inputs
|
||||
.get(user_inputs.len().saturating_sub(1))
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let expected_after_first: Vec<RolloutItem> = base_items[..cut1].to_vec();
|
||||
|
||||
// After dropping again (n=1 on fork1), compute expected relative to fork1's rollout.
|
||||
|
||||
// Fork once with n=1 → drops the last user input and everything after.
|
||||
let NewConversation {
|
||||
conversation: codex_fork1,
|
||||
..
|
||||
} = conversation_manager
|
||||
.fork_conversation(history_after_three.clone(), 1, config_for_fork.clone())
|
||||
.fork_conversation(1, config_for_fork.clone(), base_path.clone())
|
||||
.await
|
||||
.expect("fork 1");
|
||||
|
||||
codex_fork1.submit(Op::GetHistory).await.unwrap();
|
||||
codex_fork1.submit(Op::GetPath).await.unwrap();
|
||||
let fork1_history = wait_for_event(&codex_fork1, |ev| {
|
||||
matches!(ev, EventMsg::ConversationHistory(_))
|
||||
matches!(ev, EventMsg::ConversationPath(_))
|
||||
})
|
||||
.await;
|
||||
let history_after_first_fork = match &fork1_history {
|
||||
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { history, .. }) => {
|
||||
let got = history.get_rollout_items();
|
||||
assert_eq!(
|
||||
serde_json::to_value(&got).unwrap(),
|
||||
serde_json::to_value(&expected_after_first).unwrap()
|
||||
);
|
||||
history.clone()
|
||||
}
|
||||
let fork1_path = match &fork1_history {
|
||||
EventMsg::ConversationPath(ConversationPathResponseEvent { path, .. }) => path.clone(),
|
||||
_ => panic!("expected ConversationHistory event after first fork"),
|
||||
};
|
||||
|
||||
// GetHistory on fork1 flushed; the file is ready.
|
||||
let fork1_items = read_items(&fork1_path);
|
||||
pretty_assertions::assert_eq!(
|
||||
serde_json::to_value(&fork1_items).unwrap(),
|
||||
serde_json::to_value(&expected_after_first).unwrap()
|
||||
);
|
||||
|
||||
// Fork again with n=1 → drops the (new) last user message, leaving only the first.
|
||||
let NewConversation {
|
||||
conversation: codex_fork2,
|
||||
..
|
||||
} = conversation_manager
|
||||
.fork_conversation(history_after_first_fork.clone(), 1, config_for_fork.clone())
|
||||
.fork_conversation(1, config_for_fork.clone(), fork1_path.clone())
|
||||
.await
|
||||
.expect("fork 2");
|
||||
|
||||
codex_fork2.submit(Op::GetHistory).await.unwrap();
|
||||
codex_fork2.submit(Op::GetPath).await.unwrap();
|
||||
let fork2_history = wait_for_event(&codex_fork2, |ev| {
|
||||
matches!(ev, EventMsg::ConversationHistory(_))
|
||||
matches!(ev, EventMsg::ConversationPath(_))
|
||||
})
|
||||
.await;
|
||||
match &fork2_history {
|
||||
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { history, .. }) => {
|
||||
let got = history.get_rollout_items();
|
||||
assert_eq!(
|
||||
serde_json::to_value(&got).unwrap(),
|
||||
serde_json::to_value(&expected_after_second).unwrap()
|
||||
);
|
||||
}
|
||||
let fork2_path = match &fork2_history {
|
||||
EventMsg::ConversationPath(ConversationPathResponseEvent { path, .. }) => path.clone(),
|
||||
_ => panic!("expected ConversationHistory event after second fork"),
|
||||
}
|
||||
};
|
||||
// GetHistory on fork2 flushed; the file is ready.
|
||||
let fork1_items = read_items(&fork1_path);
|
||||
let fork1_user_inputs = find_user_input_positions(&fork1_items);
|
||||
let cut_last_on_fork1 = fork1_user_inputs
|
||||
.get(fork1_user_inputs.len().saturating_sub(1))
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let expected_after_second: Vec<RolloutItem> = fork1_items[..cut_last_on_fork1].to_vec();
|
||||
let fork2_items = read_items(&fork2_path);
|
||||
pretty_assertions::assert_eq!(
|
||||
serde_json::to_value(&fork2_items).unwrap(),
|
||||
serde_json::to_value(&expected_after_second).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
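The truncation rule exercised above (drop the nth-from-last user input and everything after it) can be illustrated standalone; this sketch mirrors the test's saturating_sub arithmetic, using plain integer positions rather than real RolloutItem values:

/// Given sorted positions of user inputs in a rollout, return the index at
/// which to cut when dropping the last input and everything after it.
fn cut_before_last_input(user_input_positions: &[usize]) -> usize {
    user_input_positions
        .get(user_input_positions.len().saturating_sub(1))
        .copied()
        .unwrap_or(0)
}

fn main() {
    // Rollout: [instructions, env, "first", "second", "third"] → user inputs at 2, 3, 4.
    assert_eq!(cut_before_last_input(&[2, 3, 4]), 4); // keep items[..4]
    // After the first fork the remaining inputs sit at 2 and 3.
    assert_eq!(cut_before_last_input(&[2, 3]), 3); // keep items[..3]
    // No user inputs → keep nothing.
    assert_eq!(cut_before_last_input(&[]), 0);
}
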
@@ -7,6 +7,7 @@ mod exec;
mod exec_stream_events;
mod fork_conversation;
mod live_cli;
mod model_overrides;
mod prompt_caching;
mod seatbelt;
mod stream_error_allows_next_turn;

92 codex-rs/core/tests/suite/model_overrides.rs Normal file
@@ -0,0 +1,92 @@
use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::protocol::EventMsg;
use codex_core::protocol::Op;
use codex_core::protocol_config_types::ReasoningEffort;
use core_test_support::load_default_config_for_test;
use core_test_support::wait_for_event;
use pretty_assertions::assert_eq;
use tempfile::TempDir;

const CONFIG_TOML: &str = "config.toml";

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn override_turn_context_does_not_persist_when_config_exists() {
    let codex_home = TempDir::new().unwrap();
    let config_path = codex_home.path().join(CONFIG_TOML);
    let initial_contents = "model = \"gpt-4o\"\n";
    tokio::fs::write(&config_path, initial_contents)
        .await
        .expect("seed config.toml");

    let mut config = load_default_config_for_test(&codex_home);
    config.model = "gpt-4o".to_string();

    let conversation_manager =
        ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
    let codex = conversation_manager
        .new_conversation(config)
        .await
        .expect("create conversation")
        .conversation;

    codex
        .submit(Op::OverrideTurnContext {
            cwd: None,
            approval_policy: None,
            sandbox_policy: None,
            model: Some("o3".to_string()),
            effort: Some(ReasoningEffort::High),
            summary: None,
        })
        .await
        .expect("submit override");

    codex.submit(Op::Shutdown).await.expect("request shutdown");
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;

    let contents = tokio::fs::read_to_string(&config_path)
        .await
        .expect("read config.toml after override");
    assert_eq!(contents, initial_contents);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn override_turn_context_does_not_create_config_file() {
    let codex_home = TempDir::new().unwrap();
    let config_path = codex_home.path().join(CONFIG_TOML);
    assert!(
        !config_path.exists(),
        "test setup should start without config"
    );

    let config = load_default_config_for_test(&codex_home);

    let conversation_manager =
        ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
    let codex = conversation_manager
        .new_conversation(config)
        .await
        .expect("create conversation")
        .conversation;

    codex
        .submit(Op::OverrideTurnContext {
            cwd: None,
            approval_policy: None,
            sandbox_policy: None,
            model: Some("o3".to_string()),
            effort: Some(ReasoningEffort::Medium),
            summary: None,
        })
        .await
        .expect("submit override");

    codex.submit(Op::Shutdown).await.expect("request shutdown");
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;

    assert!(
        !config_path.exists(),
        "override should not create config.toml"
    );
}
@@ -17,7 +17,6 @@ use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::wait_for_event;
use tempfile::TempDir;
use uuid::Uuid;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
@@ -270,9 +269,14 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
    let requests = server.received_requests().await.unwrap();
    assert_eq!(requests.len(), 2, "expected two POST requests");

    let shell = default_user_shell(Uuid::new_v4(), codex_home.path()).await;
    let shell = default_user_shell().await;
    let shell_line = match shell.name() {
        Some(name) => format!(" <shell>{name}</shell>\n"),
        None => String::new(),
    };

    let expected_env_text = format!(
    // Per-turn environment context includes the shell tag.
    let expected_env_text_turn = format!(
        r#"<environment_context>
<cwd>{}</cwd>
<approval_policy>on-request</approval_policy>
@@ -280,18 +284,15 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
<network_access>restricted</network_access>
{}</environment_context>"#,
        cwd.path().to_string_lossy(),
        match shell.name() {
            Some(name) => format!(" <shell>{name}</shell>\n"),
            None => String::new(),
        }
        shell_line.as_str(),
    );
    let expected_ui_text =
        "<user_instructions>\n\nbe consistent and helpful\n\n</user_instructions>";

    let expected_env_msg = serde_json::json!({
    let expected_env_msg_turn = serde_json::json!({
        "type": "message",
        "role": "user",
        "content": [ { "type": "input_text", "text": expected_env_text } ]
        "content": [ { "type": "input_text", "text": expected_env_text_turn } ]
    });
    let expected_ui_msg = serde_json::json!({
        "type": "message",
@@ -305,11 +306,29 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
        "content": [ { "type": "input_text", "text": "hello 1" } ]
    });
    let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
    let body1_input = body1["input"].as_array().unwrap();
    assert_eq!(
        body1["input"],
        serde_json::json!([expected_ui_msg, expected_env_msg, expected_user_message_1])
        serde_json::json!([
            expected_ui_msg,
            expected_env_msg_turn,
            expected_user_message_1
        ])
    );

    let env_texts: Vec<&str> = body1_input
        .iter()
        .filter_map(|msg| {
            msg.get("content")
                .and_then(|content| content.as_array())
                .and_then(|content| content.first())
                .and_then(|item| item.get("text"))
                .and_then(|text| text.as_str())
        })
        .filter(|text| text.starts_with("<environment_context>"))
        .collect();
    assert_eq!(env_texts, vec![expected_env_text_turn.as_str()]);

    let expected_user_message_2 = serde_json::json!({
        "type": "message",
        "role": "user",
@@ -319,7 +338,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
    let expected_body2 = serde_json::json!(
        [
            body1["input"].as_array().unwrap().as_slice(),
            [expected_user_message_2].as_slice(),
            [expected_env_msg_turn, expected_user_message_2].as_slice(),
        ]
        .concat()
    );
@@ -424,14 +443,29 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() {
        "role": "user",
        "content": [ { "type": "input_text", "text": "hello 2" } ]
    });
    let shell = default_user_shell().await;
    let shell_line = match shell.name() {
        Some(name) => format!(" <shell>{name}</shell>\n"),
        None => String::new(),
    };

    // After overriding the turn context, the environment context should be emitted again
    // reflecting the new approval policy and sandbox settings. Omit cwd because it did
    // not change.
    let expected_env_text_2 = r#"<environment_context>
    let expected_env_text_2 = format!(
        r#"<environment_context>
<cwd>{}</cwd>
<approval_policy>never</approval_policy>
<sandbox_mode>workspace-write</sandbox_mode>
<network_access>enabled</network_access>
</environment_context>"#;
<writable_roots>
<root>{}</root>
</writable_roots>
{}</environment_context>"#,
        cwd.path().to_string_lossy(),
        writable.path().to_string_lossy(),
        shell_line.as_str()
    );
    let expected_env_msg_2 = serde_json::json!({
        "type": "message",
        "role": "user",
@@ -541,12 +575,165 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() {
        "role": "user",
        "content": [ { "type": "input_text", "text": "hello 2" } ]
    });
    let shell = default_user_shell().await;
    let shell_line = match shell.name() {
        Some(name) => format!(" <shell>{name}</shell>\n"),
        None => String::new(),
    };
    let expected_env_text_2 = format!(
        r#"<environment_context>
<cwd>{}</cwd>
<approval_policy>never</approval_policy>
<sandbox_mode>workspace-write</sandbox_mode>
<network_access>enabled</network_access>
<writable_roots>
<root>{}</root>
</writable_roots>
{}</environment_context>"#,
        new_cwd.path().to_string_lossy(),
        writable.path().to_string_lossy(),
        shell_line.as_str()
    );
    let expected_env_msg_2 = serde_json::json!({
        "type": "message",
        "role": "user",
        "content": [ { "type": "input_text", "text": expected_env_text_2 } ]
    });
    let expected_body2 = serde_json::json!(
        [
            body1["input"].as_array().unwrap().as_slice(),
            [expected_user_message_2].as_slice(),
            [expected_env_msg_2, expected_user_message_2].as_slice(),
        ]
        .concat()
    );
    assert_eq!(body2["input"], expected_body2);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn tools_stable_across_all_approval_policy_transitions() {
    use pretty_assertions::assert_eq;

    let server = MockServer::start().await;

    let sse = sse_completed("resp");
    let template = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse, "text/event-stream");

    // Build all transitions FROM each to each other (exclude self transitions)
    let policies = vec![
        AskForApproval::UnlessTrusted,
        AskForApproval::OnFailure,
        AskForApproval::OnRequest,
        AskForApproval::Never,
    ];
    let mut transitions: Vec<(AskForApproval, AskForApproval)> = Vec::new();
    for &from in &policies {
        for &to in &policies {
            if from != to {
                transitions.push((from, to));
            }
        }
    }

    // Expect 2 POSTs per transition
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(template)
        .expect((transitions.len() * 2) as u64)
        .mount(&server)
        .await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
        ..built_in_model_providers()["openai"].clone()
    };

    let cwd = TempDir::new().unwrap();
    let codex_home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&codex_home);
    config.cwd = cwd.path().to_path_buf();
    config.model_provider = model_provider;
    config.user_instructions = Some("be consistent and helpful".to_string());
    // Keep tools stable and minimal
    config.include_plan_tool = false;
    config.include_apply_patch_tool = false;
    config.tools_web_search_request = false;
    config.use_experimental_unified_exec_tool = true; // policy-independent tool

    let conversation_manager =
        ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
    let codex = conversation_manager
        .new_conversation(config)
        .await
        .expect("create new conversation")
        .conversation;

    for (i, (from, to)) in transitions.iter().enumerate() {
        // Ensure a known starting policy for this pair
        codex
            .submit(Op::OverrideTurnContext {
                cwd: None,
                approval_policy: Some(*from),
                sandbox_policy: None,
                model: None,
                effort: None,
                summary: None,
            })
            .await
            .unwrap();

        codex
            .submit(Op::UserInput {
                items: vec![InputItem::Text {
                    text: format!("turn {i}-a"),
                }],
            })
            .await
            .unwrap();
        wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

        // Override to the target policy and send next turn
        codex
            .submit(Op::OverrideTurnContext {
                cwd: None,
                approval_policy: Some(*to),
                sandbox_policy: None,
                model: None,
                effort: None,
                summary: None,
            })
            .await
            .unwrap();

        codex
            .submit(Op::UserInput {
                items: vec![InputItem::Text {
                    text: format!("turn {i}-b"),
                }],
            })
            .await
            .unwrap();
        wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
    }

    // Verify tool arrays are identical across each pair of requests
    let requests = server.received_requests().await.unwrap();
    assert_eq!(
        requests.len(),
        transitions.len() * 2,
        "expected 2 requests per transition"
    );

    for i in 0..transitions.len() {
        let body_a = requests[2 * i].body_json::<serde_json::Value>().unwrap();
        let body_b = requests[2 * i + 1]
            .body_json::<serde_json::Value>()
            .unwrap();
        assert_eq!(
            body_a["tools"], body_b["tools"],
            "tools changed between requests for transition #{i}: {:?}",
            transitions[i]
        );
    }
}

@@ -280,7 +280,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
                parsed_cmd: _,
            }) => {
                self.call_id_to_command.insert(
                    call_id.clone(),
                    call_id,
                    ExecCommandBegin {
                        command: command.clone(),
                    },
@@ -382,7 +382,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
                // Store metadata so we can calculate duration later when we
                // receive the corresponding PatchApplyEnd event.
                self.call_id_to_patch.insert(
                    call_id.clone(),
                    call_id,
                    PatchApplyBegin {
                        start_time: Instant::now(),
                        auto_approved,
@@ -520,6 +520,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
                let SessionConfiguredEvent {
                    session_id: conversation_id,
                    model,
                    reasoning_effort: _,
                    history_log_id: _,
                    history_entry_count: _,
                    initial_messages: _,
@@ -559,7 +560,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
                }
            },
            EventMsg::ShutdownComplete => return CodexStatus::Shutdown,
            EventMsg::ConversationHistory(_) => {}
            EventMsg::ConversationPath(_) => {}
            EventMsg::UserMessage(_) => {}
        }
        CodexStatus::Running

@@ -187,10 +187,8 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
        std::process::exit(1);
    }

    let conversation_manager = ConversationManager::new(AuthManager::shared(
        config.codex_home.clone(),
        config.preferred_auth_method,
    ));
    let conversation_manager =
        ConversationManager::new(AuthManager::shared(config.codex_home.clone()));
    let NewConversation {
        conversation_id: _,
        conversation,

@@ -61,7 +61,7 @@ pub(crate) async fn run_e2e_exec_test(cwd: &Path, response_streams: Vec<String>)
        .context("should find binary for codex-exec")
        .expect("should find binary for codex-exec")
        .current_dir(cwd.clone())
        .env("CODEX_HOME", cwd.clone())
        .env("CODEX_HOME", cwd)
        .env("OPENAI_API_KEY", "dummy")
        .env("OPENAI_BASE_URL", format!("{uri}/v1"))
        .arg("--skip-git-repo-check")

@@ -88,7 +88,7 @@ impl ExecvChecker {
        let mut program = valid_exec.program.to_string();
        for system_path in valid_exec.system_path {
            if is_executable_file(&system_path) {
                program = system_path.to_string();
                program = system_path;
                break;
            }
        }
@@ -196,7 +196,7 @@ system_path=[{fake_cp:?}]
        let checker = setup(&fake_cp);
        let exec_call = ExecCall {
            program: "cp".into(),
            args: vec![source.clone(), dest.clone()],
            args: vec![source, dest.clone()],
        };
        let valid_exec = match checker.r#match(&exec_call)? {
            MatchedExec::Match { exec } => exec,
@@ -207,7 +207,7 @@ system_path=[{fake_cp:?}]
        assert_eq!(
            checker.check(valid_exec.clone(), &cwd, &[], &[]),
            Err(ReadablePathNotInReadableFolders {
                file: source_path.clone(),
                file: source_path,
                folders: vec![]
            }),
        );
@@ -229,7 +229,7 @@ system_path=[{fake_cp:?}]
        // Both readable and writeable folders specified.
        assert_eq!(
            checker.check(
                valid_exec.clone(),
                valid_exec,
                &cwd,
                std::slice::from_ref(&root_path),
                std::slice::from_ref(&root_path)
@@ -241,7 +241,7 @@ system_path=[{fake_cp:?}]
        // folders.
        let exec_call_folders_as_args = ExecCall {
            program: "cp".into(),
            args: vec![root.clone(), root.clone()],
            args: vec![root.clone(), root],
        };
        let valid_exec_call_folders_as_args = match checker.r#match(&exec_call_folders_as_args)? {
            MatchedExec::Match { exec } => exec,
@@ -254,7 +254,7 @@ system_path=[{fake_cp:?}]
                std::slice::from_ref(&root_path),
                std::slice::from_ref(&root_path)
            ),
            Ok(cp.clone()),
            Ok(cp),
        );

        // Specify a parent of a readable folder as input.

@@ -104,7 +104,7 @@ impl PolicyBuilder {
        info!("adding program spec: {program_spec:?}");
        let name = program_spec.program.clone();
        let mut programs = self.programs.borrow_mut();
        programs.insert(name.clone(), program_spec);
        programs.insert(name, program_spec);
    }

    fn add_forbidden_substrings(&self, substrings: &[String]) {

@@ -31,6 +31,13 @@ install:
    rustup show active-toolchain
    cargo fetch

# Run `cargo nextest` since it's faster than `cargo test`, though including
# --no-fail-fast is important to ensure all tests are run.
#
# Run `cargo install cargo-nextest` if you don't have it installed.
test:
    cargo nextest run --no-fail-fast

# Run the MCP server
mcp-server-run *args:
    cargo run -p codex-mcp-server -- "$@"

@@ -42,7 +42,7 @@ impl ServerOptions {
    pub fn new(codex_home: PathBuf, client_id: String) -> Self {
        Self {
            codex_home,
            client_id: client_id.to_string(),
            client_id,
            issuer: DEFAULT_ISSUER.to_string(),
            port: DEFAULT_PORT,
            open_browser: true,
@@ -126,7 +126,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
    let shutdown_notify = Arc::new(tokio::sync::Notify::new());
    let server_handle = {
        let shutdown_notify = shutdown_notify.clone();
        let server = server.clone();
        let server = server;
        tokio::spawn(async move {
            let result = loop {
                tokio::select! {

@@ -17,10 +17,10 @@ use anyhow::Context;
use anyhow::Result;
use codex_mcp_client::McpClient;
use mcp_types::ClientCapabilities;
use mcp_types::Implementation;
use mcp_types::InitializeRequestParams;
use mcp_types::ListToolsRequestParams;
use mcp_types::MCP_SCHEMA_VERSION;
use mcp_types::McpClientInfo;
use tracing_subscriber::EnvFilter;

#[tokio::main]
@@ -60,10 +60,13 @@ async fn main() -> Result<()> {
            sampling: None,
            elicitation: None,
        },
        client_info: McpClientInfo {
        client_info: Implementation {
            name: "codex-mcp-client".to_owned(),
            version: env!("CARGO_PKG_VERSION").to_owned(),
            title: Some("Codex".to_string()),
            // This field is used by Codex when it is an MCP server: it should
            // not be used when Codex is an MCP client.
            user_agent: None,
        },
        protocol_version: MCP_SCHEMA_VERSION.to_owned(),
    };

@@ -40,6 +40,7 @@ uuid = { version = "1", features = ["serde", "v4"] }

[dev-dependencies]
assert_cmd = "2"
base64 = "0.22"
mcp_test_support = { path = "tests/common" }
os_info = "3.12.0"
pretty_assertions = "1.4.1"

@@ -11,10 +11,16 @@ use codex_core::NewConversation;
use codex_core::RolloutRecorder;
use codex_core::SessionMeta;
use codex_core::auth::CLIENT_ID;
use codex_core::auth::get_auth_file;
use codex_core::auth::login_with_api_key;
use codex_core::auth::try_read_auth_json;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_core::config::ConfigToml;
use codex_core::config::load_config_as_toml;
use codex_core::config_edit::CONFIG_KEY_EFFORT;
use codex_core::config_edit::CONFIG_KEY_MODEL;
use codex_core::config_edit::persist_non_null_overrides;
use codex_core::default_client::get_codex_user_agent;
use codex_core::exec::ExecParams;
use codex_core::exec_env::create_env;
@@ -37,7 +43,6 @@ use codex_protocol::mcp_protocol::ApplyPatchApprovalParams;
use codex_protocol::mcp_protocol::ApplyPatchApprovalResponse;
use codex_protocol::mcp_protocol::ArchiveConversationParams;
use codex_protocol::mcp_protocol::ArchiveConversationResponse;
use codex_protocol::mcp_protocol::AuthMode;
use codex_protocol::mcp_protocol::AuthStatusChangeNotification;
use codex_protocol::mcp_protocol::ClientRequest;
use codex_protocol::mcp_protocol::ConversationId;
@@ -55,6 +60,8 @@ use codex_protocol::mcp_protocol::InterruptConversationParams;
use codex_protocol::mcp_protocol::InterruptConversationResponse;
use codex_protocol::mcp_protocol::ListConversationsParams;
use codex_protocol::mcp_protocol::ListConversationsResponse;
use codex_protocol::mcp_protocol::LoginApiKeyParams;
use codex_protocol::mcp_protocol::LoginApiKeyResponse;
use codex_protocol::mcp_protocol::LoginChatGptCompleteNotification;
use codex_protocol::mcp_protocol::LoginChatGptResponse;
use codex_protocol::mcp_protocol::NewConversationParams;
@@ -67,6 +74,9 @@ use codex_protocol::mcp_protocol::SendUserMessageResponse;
use codex_protocol::mcp_protocol::SendUserTurnParams;
use codex_protocol::mcp_protocol::SendUserTurnResponse;
use codex_protocol::mcp_protocol::ServerNotification;
use codex_protocol::mcp_protocol::SetDefaultModelParams;
use codex_protocol::mcp_protocol::SetDefaultModelResponse;
use codex_protocol::mcp_protocol::UserInfoResponse;
use codex_protocol::mcp_protocol::UserSavedConfig;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
@@ -169,6 +179,9 @@ impl CodexMessageProcessor {
            ClientRequest::GitDiffToRemote { request_id, params } => {
                self.git_diff_to_origin(request_id, params.cwd).await;
            }
            ClientRequest::LoginApiKey { request_id, params } => {
                self.login_api_key(request_id, params).await;
            }
            ClientRequest::LoginChatGpt { request_id } => {
                self.login_chatgpt(request_id).await;
            }
@@ -184,15 +197,54 @@ impl CodexMessageProcessor {
            ClientRequest::GetUserSavedConfig { request_id } => {
                self.get_user_saved_config(request_id).await;
            }
            ClientRequest::SetDefaultModel { request_id, params } => {
                self.set_default_model(request_id, params).await;
            }
            ClientRequest::GetUserAgent { request_id } => {
                self.get_user_agent(request_id).await;
            }
            ClientRequest::UserInfo { request_id } => {
                self.get_user_info(request_id).await;
            }
            ClientRequest::ExecOneOffCommand { request_id, params } => {
                self.exec_one_off_command(request_id, params).await;
            }
        }
    }

    async fn login_api_key(&mut self, request_id: RequestId, params: LoginApiKeyParams) {
        {
            let mut guard = self.active_login.lock().await;
            if let Some(active) = guard.take() {
                active.drop();
            }
        }

        match login_with_api_key(&self.config.codex_home, &params.api_key) {
            Ok(()) => {
                self.auth_manager.reload();
                self.outgoing
                    .send_response(request_id, LoginApiKeyResponse {})
                    .await;

                let payload = AuthStatusChangeNotification {
                    auth_method: self.auth_manager.auth().map(|auth| auth.mode),
                };
                self.outgoing
                    .send_server_notification(ServerNotification::AuthStatusChange(payload))
                    .await;
            }
            Err(err) => {
                let error = JSONRPCErrorError {
                    code: INTERNAL_ERROR_CODE,
                    message: format!("failed to save api key: {err}"),
                    data: None,
                };
                self.outgoing.send_error(request_id, error).await;
            }
        }
    }

    async fn login_chatgpt(&mut self, request_id: RequestId) {
        let config = self.config.as_ref();

@@ -346,7 +398,7 @@ impl CodexMessageProcessor {
            .await;

        // Send auth status change notification reflecting the current auth mode
        // after logout (which may fall back to API key via env var).
        // after logout.
        let current_auth_method = self.auth_manager.auth().map(|auth| auth.mode);
        let payload = AuthStatusChangeNotification {
            auth_method: current_auth_method,
@@ -361,7 +413,6 @@ impl CodexMessageProcessor {
        request_id: RequestId,
        params: codex_protocol::mcp_protocol::GetAuthStatusParams,
    ) {
        let preferred_auth_method: AuthMode = self.auth_manager.preferred_auth_method();
        let include_token = params.include_token.unwrap_or(false);
        let do_refresh = params.refresh_token.unwrap_or(false);

@@ -369,6 +420,11 @@ impl CodexMessageProcessor {
            tracing::warn!("failed to refresh token while getting auth status: {err}");
        }

        // Determine whether auth is required based on the active model provider.
        // If a custom provider is configured with `requires_openai_auth == false`,
        // then no auth step is required; otherwise, default to requiring auth.
        let requires_openai_auth = Some(self.config.model_provider.requires_openai_auth);

        let response = match self.auth_manager.auth() {
            Some(auth) => {
                let (reported_auth_method, token_opt) = match auth.get_token().await {
@@ -384,14 +440,14 @@ impl CodexMessageProcessor {
                };
                codex_protocol::mcp_protocol::GetAuthStatusResponse {
                    auth_method: reported_auth_method,
                    preferred_auth_method,
                    auth_token: token_opt,
                    requires_openai_auth,
                }
            }
            None => codex_protocol::mcp_protocol::GetAuthStatusResponse {
                auth_method: None,
                preferred_auth_method,
                auth_token: None,
                requires_openai_auth,
            },
        };

@@ -439,6 +495,52 @@ impl CodexMessageProcessor {
        self.outgoing.send_response(request_id, response).await;
    }

    async fn get_user_info(&self, request_id: RequestId) {
        // Read alleged user email from auth.json (best-effort; not verified).
        let auth_path = get_auth_file(&self.config.codex_home);
        let alleged_user_email = match try_read_auth_json(&auth_path) {
            Ok(auth) => auth.tokens.and_then(|t| t.id_token.email),
            Err(_) => None,
        };

        let response = UserInfoResponse { alleged_user_email };
        self.outgoing.send_response(request_id, response).await;
    }

    async fn set_default_model(&self, request_id: RequestId, params: SetDefaultModelParams) {
        let SetDefaultModelParams {
            model,
            reasoning_effort,
        } = params;
        let effort_str = reasoning_effort.map(|effort| effort.to_string());

        let overrides: [(&[&str], Option<&str>); 2] = [
            (&[CONFIG_KEY_MODEL], model.as_deref()),
            (&[CONFIG_KEY_EFFORT], effort_str.as_deref()),
        ];

        match persist_non_null_overrides(
            &self.config.codex_home,
            self.config.active_profile.as_deref(),
            &overrides,
        )
        .await
        {
            Ok(()) => {
                let response = SetDefaultModelResponse {};
                self.outgoing.send_response(request_id, response).await;
            }
            Err(err) => {
                let error = JSONRPCErrorError {
                    code: INTERNAL_ERROR_CODE,
                    message: format!("failed to persist overrides: {err}"),
                    data: None,
                };
                self.outgoing.send_error(request_id, error).await;
            }
        }
    }

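A minimal usage sketch of persist_non_null_overrides, inferred from the call site above (the signature and the skip-None behavior are assumptions, not confirmed by the diff); inside an async context:

    let overrides: [(&[&str], Option<&str>); 2] = [
        (&[CONFIG_KEY_MODEL], Some("o3")),
        (&[CONFIG_KEY_EFFORT], None), // assumed skipped: existing effort setting left alone
    ];
    // `codex_home` is a placeholder path; `None` means no active profile.
    if let Err(err) = persist_non_null_overrides(&codex_home, None, &overrides).await {
        eprintln!("failed to persist overrides: {err}");
    }
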
async fn exec_one_off_command(&self, request_id: RequestId, params: ExecOneOffCommandParams) {
|
||||
tracing::debug!("ExecOneOffCommand params: {params:?}");
|
||||
|
||||
@@ -533,6 +635,7 @@ impl CodexMessageProcessor {
|
||||
let response = NewConversationResponse {
|
||||
conversation_id,
|
||||
model: session_configured.model,
|
||||
reasoning_effort: session_configured.reasoning_effort,
|
||||
rollout_path: session_configured.rollout_path,
|
||||
};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
|
||||
@@ -222,7 +222,7 @@ async fn run_codex_tool_session_inner(
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
let text = match last_agent_message {
|
||||
Some(msg) => msg.clone(),
|
||||
Some(msg) => msg,
|
||||
None => "".to_string(),
};
let result = CallToolResult {
@@ -277,7 +277,7 @@ async fn run_codex_tool_session_inner(
| EventMsg::GetHistoryEntryResponse(_)
| EventMsg::PlanUpdate(_)
| EventMsg::TurnAborted(_)
| EventMsg::ConversationHistory(_)
| EventMsg::ConversationPath(_)
| EventMsg::UserMessage(_)
| EventMsg::ShutdownComplete => {
// For now, we do not do anything extra for these

@@ -56,8 +56,7 @@ impl MessageProcessor {
config: Arc<Config>,
) -> Self {
let outgoing = Arc::new(outgoing);
let auth_manager =
AuthManager::shared(config.codex_home.clone(), config.preferred_auth_method);
let auth_manager = AuthManager::shared(config.codex_home.clone());
let conversation_manager = Arc::new(ConversationManager::new(auth_manager.clone()));
let codex_message_processor = CodexMessageProcessor::new(
auth_manager,
@@ -234,11 +233,11 @@ impl MessageProcessor {
},
instructions: None,
protocol_version: params.protocol_version.clone(),
server_info: mcp_types::McpServerInfo {
server_info: mcp_types::Implementation {
name: "codex-mcp-server".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
title: Some("Codex".to_string()),
user_agent: get_codex_user_agent(),
user_agent: Some(get_codex_user_agent()),
},
};

@@ -532,7 +531,6 @@ impl MessageProcessor {

// Spawn the long-running reply handler.
tokio::spawn({
let codex = codex.clone();
let outgoing = outgoing.clone();
let prompt = prompt.clone();
let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone();

@@ -258,6 +258,7 @@ pub(crate) struct OutgoingError {
mod tests {
use codex_core::protocol::EventMsg;
use codex_core::protocol::SessionConfiguredEvent;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::mcp_protocol::LoginChatGptCompleteNotification;
use pretty_assertions::assert_eq;
@@ -279,6 +280,7 @@ mod tests {
msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
session_id: conversation_id,
model: "gpt-4o".to_string(),
reasoning_effort: ReasoningEffort::default(),
history_log_id: 1,
history_entry_count: 1000,
initial_messages: None,
@@ -299,7 +301,7 @@ mod tests {
let Ok(expected_params) = serde_json::to_value(&event) else {
panic!("Event must serialize");
};
assert_eq!(params, Some(expected_params.clone()));
assert_eq!(params, Some(expected_params));
}

#[tokio::test]
@@ -312,6 +314,7 @@ mod tests {
let session_configured_event = SessionConfiguredEvent {
session_id: conversation_id,
model: "gpt-4o".to_string(),
reasoning_effort: ReasoningEffort::default(),
history_log_id: 1,
history_entry_count: 1000,
initial_messages: None,
@@ -342,6 +345,7 @@ mod tests {
"msg": {
"session_id": session_configured_event.session_id,
"model": session_configured_event.model,
"reasoning_effort": session_configured_event.reasoning_effort,
"history_log_id": session_configured_event.history_log_id,
"history_entry_count": session_configured_event.history_entry_count,
"type": "session_configured",

@@ -18,21 +18,23 @@ use codex_protocol::mcp_protocol::CancelLoginChatGptParams;
use codex_protocol::mcp_protocol::GetAuthStatusParams;
use codex_protocol::mcp_protocol::InterruptConversationParams;
use codex_protocol::mcp_protocol::ListConversationsParams;
use codex_protocol::mcp_protocol::LoginApiKeyParams;
use codex_protocol::mcp_protocol::NewConversationParams;
use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
use codex_protocol::mcp_protocol::ResumeConversationParams;
use codex_protocol::mcp_protocol::SendUserMessageParams;
use codex_protocol::mcp_protocol::SendUserTurnParams;
use codex_protocol::mcp_protocol::SetDefaultModelParams;

use mcp_types::CallToolRequestParams;
use mcp_types::ClientCapabilities;
use mcp_types::Implementation;
use mcp_types::InitializeRequestParams;
use mcp_types::JSONRPC_VERSION;
use mcp_types::JSONRPCMessage;
use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCRequest;
use mcp_types::JSONRPCResponse;
use mcp_types::McpClientInfo;
use mcp_types::ModelContextProtocolNotification;
use mcp_types::ModelContextProtocolRequest;
use mcp_types::RequestId;
@@ -54,6 +56,18 @@ pub struct McpProcess {

impl McpProcess {
pub async fn new(codex_home: &Path) -> anyhow::Result<Self> {
Self::new_with_env(codex_home, &[]).await
}

/// Creates a new MCP process, allowing tests to override or remove
/// specific environment variables for the child process only.
///
/// Pass a tuple of (key, Some(value)) to set/override, or (key, None) to
/// remove a variable from the child's environment.
pub async fn new_with_env(
codex_home: &Path,
env_overrides: &[(&str, Option<&str>)],
) -> anyhow::Result<Self> {
// Use assert_cmd to locate the binary path and then switch to tokio::process::Command
let std_cmd = StdCommand::cargo_bin("codex-mcp-server")
.context("should find binary for codex-mcp-server")?;
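The doc comment above fully specifies the override semantics. As a quick illustration (not part of the diff, but mirroring calls made by the auth tests later in this changeset), a test can scrub an ambient credential from the child process only:

// Minimal usage sketch: spawn the server with OPENAI_API_KEY removed from the
// child's environment, so credentials on the host cannot leak into the test.
let mut mcp = McpProcess::new_with_env(codex_home.path(), &[("OPENAI_API_KEY", None)])
    .await
    .expect("spawn mcp process");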
@@ -68,6 +82,17 @@ impl McpProcess {
cmd.env("CODEX_HOME", codex_home);
cmd.env("RUST_LOG", "debug");

for (k, v) in env_overrides {
match v {
Some(val) => {
cmd.env(k, val);
}
None => {
cmd.env_remove(k);
}
}
}

let mut process = cmd
.kill_on_drop(true)
.spawn()
@@ -111,10 +136,11 @@ impl McpProcess {
roots: None,
sampling: None,
},
client_info: McpClientInfo {
client_info: Implementation {
name: "elicitation test".into(),
title: Some("Elicitation Test".into()),
version: "0.0.0".into(),
user_agent: None,
},
protocol_version: mcp_types::MCP_SCHEMA_VERSION.into(),
};
@@ -271,6 +297,20 @@ impl McpProcess {
self.send_request("getUserAgent", None).await
}

/// Send a `userInfo` JSON-RPC request.
pub async fn send_user_info_request(&mut self) -> anyhow::Result<i64> {
self.send_request("userInfo", None).await
}

/// Send a `setDefaultModel` JSON-RPC request.
pub async fn send_set_default_model_request(
&mut self,
params: SetDefaultModelParams,
) -> anyhow::Result<i64> {
let params = Some(serde_json::to_value(params)?);
self.send_request("setDefaultModel", params).await
}

/// Send a `listConversations` JSON-RPC request.
pub async fn send_list_conversations_request(
&mut self,
@@ -289,6 +329,15 @@ impl McpProcess {
self.send_request("resumeConversation", params).await
}

/// Send a `loginApiKey` JSON-RPC request.
pub async fn send_login_api_key_request(
&mut self,
params: LoginApiKeyParams,
) -> anyhow::Result<i64> {
let params = Some(serde_json::to_value(params)?);
self.send_request("loginApiKey", params).await
}

/// Send a `loginChatGpt` JSON-RPC request.
pub async fn send_login_chat_gpt_request(&mut self) -> anyhow::Result<i64> {
self.send_request("loginChatGpt", None).await

@@ -1,9 +1,10 @@
use std::path::Path;

use codex_core::auth::login_with_api_key;
use codex_protocol::mcp_protocol::AuthMode;
use codex_protocol::mcp_protocol::GetAuthStatusParams;
use codex_protocol::mcp_protocol::GetAuthStatusResponse;
use codex_protocol::mcp_protocol::LoginApiKeyParams;
use codex_protocol::mcp_protocol::LoginApiKeyResponse;
use mcp_test_support::McpProcess;
use mcp_test_support::to_response;
use mcp_types::JSONRPCResponse;
@@ -36,12 +37,31 @@ stream_max_retries = 0
)
}

async fn login_with_api_key_via_request(mcp: &mut McpProcess, api_key: &str) {
let request_id = mcp
.send_login_api_key_request(LoginApiKeyParams {
api_key: api_key.to_string(),
})
.await
.unwrap_or_else(|e| panic!("send loginApiKey: {e}"));

let resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
)
.await
.unwrap_or_else(|e| panic!("loginApiKey timeout: {e}"))
.unwrap_or_else(|e| panic!("loginApiKey response: {e}"));
let _: LoginApiKeyResponse =
to_response(resp).unwrap_or_else(|e| panic!("deserialize login response: {e}"));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn get_auth_status_no_auth() {
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
create_config_toml(codex_home.path()).expect("write config.toml");
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));

let mut mcp = McpProcess::new(codex_home.path())
let mut mcp = McpProcess::new_with_env(codex_home.path(), &[("OPENAI_API_KEY", None)])
.await
.expect("spawn mcp process");
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
@@ -72,8 +92,7 @@ async fn get_auth_status_no_auth() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn get_auth_status_with_api_key() {
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
create_config_toml(codex_home.path()).expect("write config.toml");
login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key");
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));

let mut mcp = McpProcess::new(codex_home.path())
.await
@@ -83,6 +102,8 @@ async fn get_auth_status_with_api_key() {
.expect("init timeout")
.expect("init failed");

login_with_api_key_via_request(&mut mcp, "sk-test-key").await;

let request_id = mcp
.send_get_auth_status_request(GetAuthStatusParams {
include_token: Some(true),
@@ -101,14 +122,12 @@ async fn get_auth_status_with_api_key() {
let status: GetAuthStatusResponse = to_response(resp).expect("deserialize status");
assert_eq!(status.auth_method, Some(AuthMode::ApiKey));
assert_eq!(status.auth_token, Some("sk-test-key".to_string()));
assert_eq!(status.preferred_auth_method, AuthMode::ChatGPT);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn get_auth_status_with_api_key_no_include_token() {
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
create_config_toml(codex_home.path()).expect("write config.toml");
login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key");
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));

let mut mcp = McpProcess::new(codex_home.path())
.await
@@ -118,6 +137,8 @@ async fn get_auth_status_with_api_key_no_include_token() {
.expect("init timeout")
.expect("init failed");

login_with_api_key_via_request(&mut mcp, "sk-test-key").await;

// Build params via struct so None field is omitted in wire JSON.
let params = GetAuthStatusParams {
include_token: None,
@@ -138,5 +159,4 @@ async fn get_auth_status_with_api_key_no_include_token() {
let status: GetAuthStatusResponse = to_response(resp).expect("deserialize status");
assert_eq!(status.auth_method, Some(AuthMode::ApiKey));
assert!(status.auth_token.is_none(), "token must be omitted");
assert_eq!(status.preferred_auth_method, AuthMode::ChatGPT);
}

@@ -90,6 +90,7 @@ async fn test_codex_jsonrpc_conversation_flow() {
let NewConversationResponse {
conversation_id,
model,
reasoning_effort: _,
rollout_path: _,
} = new_conv_resp;
assert_eq!(model, "mock-model");

@@ -59,6 +59,7 @@ async fn test_conversation_create_and_send_message_ok() {
let NewConversationResponse {
conversation_id,
model,
reasoning_effort: _,
rollout_path: _,
} = to_response::<NewConversationResponse>(new_conv_resp)
.expect("deserialize newConversation response");

@@ -1,7 +1,7 @@
use std::path::Path;
use std::time::Duration;

use codex_core::auth::login_with_api_key;
use codex_login::login_with_api_key;
use codex_protocol::mcp_protocol::CancelLoginChatGptParams;
use codex_protocol::mcp_protocol::CancelLoginChatGptResponse;
use codex_protocol::mcp_protocol::GetAuthStatusParams;
@@ -46,7 +46,7 @@ async fn logout_chatgpt_removes_auth() {
login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key");
assert!(codex_home.path().join("auth.json").exists());

let mut mcp = McpProcess::new(codex_home.path())
let mut mcp = McpProcess::new_with_env(codex_home.path(), &[("OPENAI_API_KEY", None)])
.await
.expect("spawn mcp process");
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
@@ -95,7 +95,7 @@ async fn logout_chatgpt_removes_auth() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn login_and_cancel_chatgpt() {
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
create_config_toml(codex_home.path()).expect("write config.toml");
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));

let mut mcp = McpProcess::new(codex_home.path())
.await

@@ -9,4 +9,6 @@ mod interrupt;
mod list_resume;
mod login;
mod send_message;
mod set_default_model;
mod user_agent;
mod user_info;

62
codex-rs/mcp-server/tests/suite/set_default_model.rs
Normal file
@@ -0,0 +1,62 @@
use codex_core::config::ConfigToml;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::mcp_protocol::SetDefaultModelParams;
use codex_protocol::mcp_protocol::SetDefaultModelResponse;
use mcp_test_support::McpProcess;
use mcp_test_support::to_response;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
use pretty_assertions::assert_eq;
use tempfile::TempDir;
use tokio::time::timeout;

const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn set_default_model_persists_overrides() {
    let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));

    let mut mcp = McpProcess::new(codex_home.path())
        .await
        .expect("spawn mcp process");
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
        .await
        .expect("init timeout")
        .expect("init failed");

    let params = SetDefaultModelParams {
        model: Some("o4-mini".to_string()),
        reasoning_effort: Some(ReasoningEffort::High),
    };

    let request_id = mcp
        .send_set_default_model_request(params)
        .await
        .expect("send setDefaultModel");

    let resp: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
    )
    .await
    .expect("setDefaultModel timeout")
    .expect("setDefaultModel response");

    let _: SetDefaultModelResponse =
        to_response(resp).expect("deserialize setDefaultModel response");

    let config_path = codex_home.path().join("config.toml");
    let config_contents = tokio::fs::read_to_string(&config_path)
        .await
        .expect("read config.toml");
    let config_toml: ConfigToml = toml::from_str(&config_contents).expect("parse config.toml");

    assert_eq!(
        ConfigToml {
            model: Some("o4-mini".to_string()),
            model_reasoning_effort: Some(ReasoningEffort::High),
            ..Default::default()
        },
        config_toml,
    );
}
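A note on what lands on disk (illustrative, not from the diff): the assertion above implies the persisted config.toml carries just the two overridden keys. Assuming ReasoningEffort serializes in lowercase (the protocol test elsewhere in this changeset emits "reasoning_effort": "medium"), it would look roughly like:

// Illustrative contents of codex_home/config.toml after the request (assumed):
//   model = "o4-mini"
//   model_reasoning_effort = "high"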
78
codex-rs/mcp-server/tests/suite/user_info.rs
Normal file
@@ -0,0 +1,78 @@
use std::time::Duration;

use anyhow::Context;
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use codex_core::auth::AuthDotJson;
use codex_core::auth::get_auth_file;
use codex_core::auth::write_auth_json;
use codex_core::token_data::IdTokenInfo;
use codex_core::token_data::TokenData;
use codex_protocol::mcp_protocol::UserInfoResponse;
use mcp_test_support::McpProcess;
use mcp_test_support::to_response;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::TempDir;
use tokio::time::timeout;

const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(10);

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn user_info_returns_email_from_auth_json() {
    let codex_home = TempDir::new().expect("create tempdir");

    let auth_path = get_auth_file(codex_home.path());
    let mut id_token = IdTokenInfo::default();
    id_token.email = Some("user@example.com".to_string());
    id_token.raw_jwt = encode_id_token_with_email("user@example.com").expect("encode id token");

    let auth = AuthDotJson {
        openai_api_key: None,
        tokens: Some(TokenData {
            id_token,
            access_token: "access".to_string(),
            refresh_token: "refresh".to_string(),
            account_id: None,
        }),
        last_refresh: None,
    };
    write_auth_json(&auth_path, &auth).expect("write auth.json");

    let mut mcp = McpProcess::new(codex_home.path())
        .await
        .expect("spawn mcp process");
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
        .await
        .expect("initialize timeout")
        .expect("initialize request");

    let request_id = mcp.send_user_info_request().await.expect("send userInfo");
    let response: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
    )
    .await
    .expect("userInfo timeout")
    .expect("userInfo response");

    let received: UserInfoResponse = to_response(response).expect("deserialize userInfo response");
    let expected = UserInfoResponse {
        alleged_user_email: Some("user@example.com".to_string()),
    };

    assert_eq!(received, expected);
}

fn encode_id_token_with_email(email: &str) -> anyhow::Result<String> {
    let header_b64 = URL_SAFE_NO_PAD.encode(
        serde_json::to_vec(&json!({ "alg": "none", "typ": "JWT" }))
            .context("serialize jwt header")?,
    );
    let payload =
        serde_json::to_vec(&json!({ "email": email })).context("serialize jwt payload")?;
    let payload_b64 = URL_SAFE_NO_PAD.encode(payload);
    Ok(format!("{header_b64}.{payload_b64}.signature"))
}
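The helper above builds an unsigned ("alg": "none") JWT: URL-safe base64 of a JSON header, a dot, base64 of the payload, and a dummy signature. A minimal sketch of the inverse (not in the diff; assumes the same base64, serde_json, and anyhow imports as the test file) shows how the email claim comes back out without any signature check:

fn decode_email_claim(jwt: &str) -> anyhow::Result<Option<String>> {
    // The payload is the second dot-separated segment.
    let payload_b64 = jwt.split('.').nth(1).context("jwt payload segment")?;
    let payload = URL_SAFE_NO_PAD.decode(payload_b64).context("decode jwt payload")?;
    let claims: serde_json::Value = serde_json::from_slice(&payload)?;
    Ok(claims.get("email").and_then(|v| v.as_str()).map(str::to_owned))
}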
21
codex-rs/mcp-types/check_lib_rs.py
Executable file
@@ -0,0 +1,21 @@
#!/usr/bin/env python3

import subprocess
import sys
from pathlib import Path


def main() -> int:
    crate_dir = Path(__file__).resolve().parent
    generator = crate_dir / "generate_mcp_types.py"

    result = subprocess.run(
        [sys.executable, str(generator), "--check"],
        cwd=crate_dir,
        check=False,
    )
    return result.returncode


if __name__ == "__main__":
    raise SystemExit(main())
@@ -5,15 +5,19 @@ import argparse
import json
import subprocess
import sys
import tempfile

from dataclasses import (
dataclass,
)
from difflib import unified_diff
from pathlib import Path
from shutil import copy2

# Helper first so it is defined when other functions call it.
from typing import Any, Literal


SCHEMA_VERSION = "2025-06-18"
JSONRPC_VERSION = "2.0"

@@ -43,16 +47,31 @@ def main() -> int:
default_schema_file = (
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
)
default_lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
parser.add_argument(
"schema_file",
nargs="?",
default=default_schema_file,
help="schema.json file to process",
)
parser.add_argument(
"--check",
action="store_true",
help="Regenerate lib.rs in a sandbox and ensure the checked-in file matches",
)
args = parser.parse_args()
schema_file = args.schema_file
schema_file = Path(args.schema_file)
crate_dir = Path(__file__).resolve().parent

lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
if args.check:
return run_check(schema_file, crate_dir, default_lib_rs)

generate_lib_rs(schema_file, default_lib_rs, fmt=True)
return 0


def generate_lib_rs(schema_file: Path, lib_rs: Path, fmt: bool) -> None:
lib_rs.parent.mkdir(parents=True, exist_ok=True)

global DEFINITIONS  # Allow helper functions to access the schema.

@@ -117,9 +136,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}

for req_name in CLIENT_REQUEST_TYPE_NAMES:
defn = definitions[req_name]
method_const = (
defn.get("properties", {}).get("method", {}).get("const", req_name)
)
method_const = defn.get("properties", {}).get("method", {}).get("const", req_name)
payload_type = f"<{req_name} as ModelContextProtocolRequest>::Params"
try_from_impl_lines.append(f' "{method_const}" => {{\n')
try_from_impl_lines.append(
@@ -128,9 +145,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
try_from_impl_lines.append(
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
)
try_from_impl_lines.append(
f" Ok(ClientRequest::{req_name}(params))\n"
)
try_from_impl_lines.append(f" Ok(ClientRequest::{req_name}(params))\n")
try_from_impl_lines.append(" },\n")

try_from_impl_lines.append(
@@ -144,9 +159,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}

# Generate TryFrom for ServerNotification
notif_impl_lines: list[str] = []
notif_impl_lines.append(
"impl TryFrom<JSONRPCNotification> for ServerNotification {\n"
)
notif_impl_lines.append("impl TryFrom<JSONRPCNotification> for ServerNotification {\n")
notif_impl_lines.append(" type Error = serde_json::Error;\n")
notif_impl_lines.append(
" fn try_from(n: JSONRPCNotification) -> std::result::Result<Self, Self::Error> {\n"
@@ -155,9 +168,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}

for notif_name in SERVER_NOTIFICATION_TYPE_NAMES:
n_def = definitions[notif_name]
method_const = (
n_def.get("properties", {}).get("method", {}).get("const", notif_name)
)
method_const = n_def.get("properties", {}).get("method", {}).get("const", notif_name)
payload_type = f"<{notif_name} as ModelContextProtocolNotification>::Params"
notif_impl_lines.append(f' "{method_const}" => {{\n')
# params may be optional
@@ -167,9 +178,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
notif_impl_lines.append(
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
)
notif_impl_lines.append(
f" Ok(ServerNotification::{notif_name}(params))\n"
)
notif_impl_lines.append(f" Ok(ServerNotification::{notif_name}(params))\n")
notif_impl_lines.append(" },\n")

notif_impl_lines.append(
@@ -185,13 +194,70 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
for chunk in out:
f.write(chunk)

subprocess.check_call(
["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
cwd=lib_rs.parent.parent,
stderr=subprocess.DEVNULL,
)
if fmt:
subprocess.check_call(
["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
cwd=lib_rs.parent.parent,
stderr=subprocess.DEVNULL,
)

return 0

def run_check(schema_file: Path, crate_dir: Path, checked_in_lib: Path) -> int:
config_path = crate_dir.parent / "rustfmt.toml"
eprint(f"Running --check with schema {schema_file}")

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
eprint(f"Created temporary workspace at {tmp_path}")
manifest_path = tmp_path / "Cargo.toml"
eprint(f"Copying Cargo.toml into {manifest_path}")
copy2(crate_dir / "Cargo.toml", manifest_path)
manifest_text = manifest_path.read_text(encoding="utf-8")
manifest_text = manifest_text.replace(
"version = { workspace = true }",
'version = "0.0.0"',
)
manifest_text = manifest_text.replace("\n[lints]\nworkspace = true\n", "\n")
manifest_path.write_text(manifest_text, encoding="utf-8")
src_dir = tmp_path / "src"
src_dir.mkdir(parents=True, exist_ok=True)
eprint(f"Generating lib.rs into {src_dir}")
generated_lib = src_dir / "lib.rs"

generate_lib_rs(schema_file, generated_lib, fmt=False)

eprint("Formatting generated lib.rs with rustfmt")
subprocess.check_call(
[
"rustfmt",
"--config-path",
str(config_path),
str(generated_lib),
],
cwd=tmp_path,
stderr=subprocess.DEVNULL,
)

eprint("Comparing generated lib.rs with checked-in version")
checked_in_contents = checked_in_lib.read_text(encoding="utf-8")
generated_contents = generated_lib.read_text(encoding="utf-8")

if checked_in_contents == generated_contents:
eprint("lib.rs matches checked-in version")
return 0

diff = unified_diff(
checked_in_contents.splitlines(keepends=True),
generated_contents.splitlines(keepends=True),
fromfile=str(checked_in_lib),
tofile=str(generated_lib),
)
diff_text = "".join(diff)
eprint("Generated lib.rs does not match the checked-in version. Diff:")
if diff_text:
eprint(diff_text, end="")
eprint("Re-run generate_mcp_types.py without --check to update src/lib.rs.")
return 1


def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> None:
@@ -265,8 +331,11 @@ class StructField:
name: str
type_name: str
serde: str | None = None
comment: str | None = None

def append(self, out: list[str], supports_const: bool) -> None:
if self.comment:
out.append(f" // {self.comment}\n")
if self.serde:
out.append(f" {self.serde}\n")
if self.viz == "const":
@@ -312,6 +381,18 @@ def define_struct(
else:
fields.append(StructField("pub", rs_prop.name, prop_type, rs_prop.serde))

# Special-case: add Codex-specific user_agent to Implementation
if name == "Implementation":
fields.append(
StructField(
"pub",
"user_agent",
"Option<String>",
'#[serde(default, skip_serializing_if = "Option::is_none")]',
"This is an extra field that the Codex MCP server sends as part of InitializeResult.",
)
)

if implements_request_trait(name):
add_trait_impl(name, "ModelContextProtocolRequest", fields, out)
elif implements_notification_trait(name):
@@ -406,15 +487,11 @@ def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> Non
case "integer":
out.append(" Integer(i64),\n")
case _:
raise ValueError(
f"Unknown type in untagged enum: {simple_type} in {name}"
)
raise ValueError(f"Unknown type in untagged enum: {simple_type} in {name}")
out.append("}\n\n")


def define_any_of(
name: str, list_of_refs: list[Any], description: str | None = None
) -> list[str]:
def define_any_of(name: str, list_of_refs: list[Any], description: str | None = None) -> list[str]:
"""Generate a Rust enum for a JSON-Schema `anyOf` union.

For most types we simply map each `$ref` inside the `anyOf` list to a
@@ -479,9 +556,7 @@ def define_any_of(
if name == "ClientRequest":
payload_type = f"<{ref_name} as ModelContextProtocolRequest>::Params"
else:
payload_type = (
f"<{ref_name} as ModelContextProtocolNotification>::Params"
)
payload_type = f"<{ref_name} as ModelContextProtocolNotification>::Params"

# Determine the wire value for `method` so we can annotate the
# variant appropriately. If for some reason the schema does not
@@ -489,9 +564,7 @@ def define_any_of(
# least compile (although deserialization will likely fail).
request_def = DEFINITIONS.get(ref_name, {})
method_const = (
request_def.get("properties", {})
.get("method", {})
.get("const", ref_name)
request_def.get("properties", {}).get("method", {}).get("const", ref_name)
)

out.append(f' #[serde(rename = "{method_const}")]\n')
@@ -541,7 +614,7 @@ def map_type(
if type_prop == "string":
if const_prop := typedef.get("const", None):
assert isinstance(const_prop, str)
return f'&\'static str = "{const_prop }"'
return f'&\'static str = "{const_prop}"'
else:
return "String"
elif type_prop == "integer":
@@ -617,7 +690,7 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:
serde_annotations.append('skip_serializing_if = "Option::is_none"')

if serde_annotations:
serde_str = f'#[serde({", ".join(serde_annotations)})]'
serde_str = f"#[serde({', '.join(serde_annotations)})]"
else:
serde_str = None
return RustProp(prop_name, serde_str)
@@ -625,9 +698,7 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:

def to_snake_case(name: str) -> str:
"""Convert a camelCase or PascalCase name to snake_case."""
snake_case = name[0].lower() + "".join(
"_" + c.lower() if c.isupper() else c for c in name[1:]
)
snake_case = name[0].lower() + "".join("_" + c.lower() if c.isupper() else c for c in name[1:])
if snake_case != name:
return snake_case
else:
@@ -663,5 +734,9 @@ def emit_doc_comment(text: str | None, out: list[str]) -> None:
out.append(f"/// {line.rstrip()}\n")


def eprint(*args: Any, **kwargs: Any) -> None:
print(*args, file=sys.stderr, **kwargs)


if __name__ == "__main__":
sys.exit(main())

@@ -482,21 +482,14 @@ pub struct ImageContent {

/// Describes the name and version of an MCP implementation, with an optional title for UI representation.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, TS)]
pub struct McpClientInfo {
pub struct Implementation {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub version: String,
}

/// Describes the name and version of an MCP implementation, with an optional title for UI representation.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, TS)]
pub struct McpServerInfo {
pub name: String,
// This is an extra field that the Codex MCP server sends as part of InitializeResult.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
pub version: String,
pub user_agent: String,
pub user_agent: Option<String>,
}

#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, TS)]
@@ -512,7 +505,7 @@ impl ModelContextProtocolRequest for InitializeRequest {
pub struct InitializeRequestParams {
pub capabilities: ClientCapabilities,
#[serde(rename = "clientInfo")]
pub client_info: McpClientInfo,
pub client_info: Implementation,
#[serde(rename = "protocolVersion")]
pub protocol_version: String,
}
@@ -526,7 +519,7 @@ pub struct InitializeResult {
#[serde(rename = "protocolVersion")]
pub protocol_version: String,
#[serde(rename = "serverInfo")]
pub server_info: McpServerInfo,
pub server_info: Implementation,
}

impl From<InitializeResult> for serde_json::Value {

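The net effect of this hunk: the separate McpClientInfo/McpServerInfo structs collapse back into the spec's single Implementation type, with user_agent kept as an optional Codex extension. A small sketch (not in the diff) of why `#[serde(default, skip_serializing_if = "Option::is_none")]` keeps the wire format compatible in both directions:

// Hedged sketch: a payload that predates user_agent still parses, and a None
// user_agent is omitted on serialization rather than sent as null.
let legacy = r#"{ "name": "codex-mcp-server", "version": "0.1.0" }"#;
let info: Implementation = serde_json::from_str(legacy).expect("older payloads still parse");
assert_eq!(info.user_agent, None);
assert_eq!(
    serde_json::to_string(&info).expect("serialize"),
    r#"{"name":"codex-mcp-server","version":"0.1.0"}"#
);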
@@ -1,10 +1,10 @@
use mcp_types::ClientCapabilities;
use mcp_types::ClientRequest;
use mcp_types::Implementation;
use mcp_types::InitializeRequestParams;
use mcp_types::JSONRPC_VERSION;
use mcp_types::JSONRPCMessage;
use mcp_types::JSONRPCRequest;
use mcp_types::McpClientInfo;
use mcp_types::RequestId;
use serde_json::json;

@@ -58,10 +58,11 @@ fn deserialize_initialize_request() {
sampling: None,
elicitation: None,
},
client_info: McpClientInfo {
client_info: Implementation {
name: "acme-client".into(),
title: Some("Acme".to_string()),
version: "1.2.3".into(),
user_agent: None,
},
protocol_version: "2025-06-18".into(),
}

@@ -31,6 +31,8 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> {
codex_protocol::mcp_protocol::SendUserTurnResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::InterruptConversationResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::GitDiffToRemoteResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::LoginApiKeyParams::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::LoginApiKeyResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::LoginChatGptResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::CancelLoginChatGptResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::LogoutChatGptResponse::export_all_to(out_dir)?;
@@ -38,7 +40,9 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> {
codex_protocol::mcp_protocol::ApplyPatchApprovalResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::ExecCommandApprovalResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::GetUserSavedConfigResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::SetDefaultModelResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::GetUserAgentResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::UserInfoResponse::export_all_to(out_dir)?;

// All notification types reachable from this enum will be generated by
// induction, so they do not need to be listed individually.

@@ -126,6 +126,11 @@ pub enum ClientRequest {
request_id: RequestId,
params: GitDiffToRemoteParams,
},
LoginApiKey {
#[serde(rename = "id")]
request_id: RequestId,
params: LoginApiKeyParams,
},
LoginChatGpt {
#[serde(rename = "id")]
request_id: RequestId,
@@ -148,10 +153,19 @@ pub enum ClientRequest {
#[serde(rename = "id")]
request_id: RequestId,
},
SetDefaultModel {
#[serde(rename = "id")]
request_id: RequestId,
params: SetDefaultModelParams,
},
GetUserAgent {
#[serde(rename = "id")]
request_id: RequestId,
},
UserInfo {
#[serde(rename = "id")]
request_id: RequestId,
},
/// Execute a command (argv vector) under the server's sandbox.
ExecOneOffCommand {
#[serde(rename = "id")]
@@ -208,6 +222,8 @@ pub struct NewConversationParams {
pub struct NewConversationResponse {
pub conversation_id: ConversationId,
pub model: String,
/// Note this could be ignored by the model.
pub reasoning_effort: ReasoningEffort,
pub rollout_path: PathBuf,
}

@@ -284,6 +300,16 @@ pub struct ArchiveConversationResponse {}
#[serde(rename_all = "camelCase")]
pub struct RemoveConversationSubscriptionResponse {}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct LoginApiKeyParams {
pub api_key: String,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct LoginApiKeyResponse {}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct LoginChatGptResponse {
@@ -363,9 +389,14 @@ pub struct ExecArbitraryCommandResponse {
pub struct GetAuthStatusResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub auth_method: Option<AuthMode>,
pub preferred_auth_method: AuthMode,
#[serde(skip_serializing_if = "Option::is_none")]
pub auth_token: Option<String>,

// Indicates that auth method must be valid to use the server.
// This can be false if using a custom provider that is configured
// with requires_openai_auth == false.
#[serde(skip_serializing_if = "Option::is_none")]
pub requires_openai_auth: Option<bool>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
@@ -374,12 +405,35 @@ pub struct GetUserAgentResponse {
pub user_agent: String,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct UserInfoResponse {
/// Note: `alleged_user_email` is not currently verified. We read it from
/// the local auth.json, which the user could theoretically modify. In the
/// future, we may add logic to verify the email against the server before
/// returning it.
pub alleged_user_email: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct GetUserSavedConfigResponse {
pub config: UserSavedConfig,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct SetDefaultModelParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<ReasoningEffort>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct SetDefaultModelResponse {}

/// UserSavedConfig contains a subset of the config. It is meant to expose mcp
/// client-configurable settings that can be specified in the NewConversation
/// and SendUserTurn requests.

|
||||
status: Option<String>,
|
||||
action: WebSearchAction,
|
||||
},
|
||||
|
||||
#[serde(other)]
|
||||
Other,
|
||||
}
|
||||
@@ -220,7 +219,7 @@ impl From<Vec<InputItem>> for ResponseInputItem {
|
||||
let mime = mime_guess::from_path(&path)
|
||||
.first()
|
||||
.map(|m| m.essence_str().to_owned())
|
||||
.unwrap_or_else(|| "application/octet-stream".to_string());
|
||||
.unwrap_or_else(|| "image".to_string());
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(bytes);
|
||||
Some(ContentItem::InputImage {
|
||||
image_url: format!("data:{mime};base64,{encoded}"),
|
||||
|
||||
@@ -149,7 +149,7 @@ pub enum Op {
|
||||
|
||||
/// Request the full in-memory conversation transcript for the current session.
|
||||
/// Reply is delivered via `EventMsg::ConversationHistory`.
|
||||
GetHistory,
|
||||
GetPath,
|
||||
|
||||
/// Request the list of MCP tools available across all configured servers.
|
||||
/// Reply is delivered via `EventMsg::McpListToolsResponse`.
|
||||
@@ -499,7 +499,7 @@ pub enum EventMsg {
|
||||
/// Notification that the agent is shutting down.
|
||||
ShutdownComplete,
|
||||
|
||||
ConversationHistory(ConversationHistoryResponseEvent),
|
||||
ConversationPath(ConversationPathResponseEvent),
|
||||
}
|
||||
|
||||
// Individual event payload types matching each `EventMsg` variant.
|
||||
@@ -801,9 +801,9 @@ pub struct WebSearchEndEvent {
|
||||
/// Response payload for `Op::GetHistory` containing the current session's
|
||||
/// in-memory transcript.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct ConversationHistoryResponseEvent {
|
||||
pub struct ConversationPathResponseEvent {
|
||||
pub conversation_id: ConversationId,
|
||||
pub history: InitialHistory,
|
||||
pub path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
@@ -897,9 +897,26 @@ pub struct SessionMetaLine {
|
||||
pub enum RolloutItem {
|
||||
SessionMeta(SessionMetaLine),
|
||||
ResponseItem(ResponseItem),
|
||||
Compacted(CompactedItem),
|
||||
TurnContext(TurnContextItem),
|
||||
EventMsg(EventMsg),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, TS)]
|
||||
pub struct CompactedItem {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, TS)]
|
||||
pub struct TurnContextItem {
|
||||
pub cwd: PathBuf,
|
||||
pub approval_policy: AskForApproval,
|
||||
pub sandbox_policy: SandboxPolicy,
|
||||
pub model: String,
|
||||
pub effort: ReasoningEffortConfig,
|
||||
pub summary: ReasoningSummaryConfig,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct RolloutLine {
|
||||
pub timestamp: String,
|
||||
@@ -1064,6 +1081,9 @@ pub struct SessionConfiguredEvent {
|
||||
/// Tell the client what model is being queried.
|
||||
pub model: String,
|
||||
|
||||
/// The effort the model is putting into reasoning about the user's request.
|
||||
pub reasoning_effort: ReasoningEffortConfig,
|
||||
|
||||
/// Identifier of the history log file (inode on Unix, 0 otherwise).
|
||||
pub history_log_id: u64,
|
||||
|
||||
@@ -1152,6 +1172,7 @@ mod tests {
|
||||
msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
|
||||
session_id: conversation_id,
|
||||
model: "codex-mini-latest".to_string(),
|
||||
reasoning_effort: ReasoningEffortConfig::default(),
|
||||
history_log_id: 0,
|
||||
history_entry_count: 0,
|
||||
initial_messages: None,
|
||||
@@ -1165,6 +1186,7 @@ mod tests {
|
||||
"type": "session_configured",
|
||||
"session_id": "67e55044-10b1-426f-9247-bb680e5fe0c8",
|
||||
"model": "codex-mini-latest",
|
||||
"reasoning_effort": "medium",
|
||||
"history_log_id": 0,
|
||||
"history_entry_count": 0,
|
||||
"rollout_path": format!("{}", rollout_file.path().display()),
|
||||
|
||||
@@ -79,7 +79,7 @@ tokio-stream = "0.1.17"
tracing = { version = "0.1.41", features = ["log"] }
tracing-appender = "0.2.3"
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
tui-markdown = "0.3.3"
pulldown-cmark = "0.10"
unicode-segmentation = "1.12.0"
unicode-width = "0.1"
url = "2"

@@ -11,7 +11,10 @@ use codex_ansi_escape::ansi_escape_line;
use codex_core::AuthManager;
use codex_core::ConversationManager;
use codex_core::config::Config;
use codex_core::config::persist_model_selection;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::TokenUsage;
use codex_core::protocol_config_types::ReasoningEffort as ReasoningEffortConfig;
use color_eyre::eyre::Result;
use color_eyre::eyre::WrapErr;
use crossterm::event::KeyCode;
@@ -37,6 +40,9 @@ pub(crate) struct App {

/// Config is stored here so we can recreate ChatWidgets as needed.
pub(crate) config: Config,
pub(crate) active_profile: Option<String>,
model_saved_to_profile: bool,
model_saved_to_global: bool,

pub(crate) file_search: FileSearchManager,

@@ -61,6 +67,7 @@ impl App {
tui: &mut tui::Tui,
auth_manager: Arc<AuthManager>,
config: Config,
active_profile: Option<String>,
initial_prompt: Option<String>,
initial_images: Vec<PathBuf>,
resume_selection: ResumeSelection,
@@ -119,6 +126,9 @@ impl App {
app_event_tx,
chat_widget,
config,
active_profile,
model_saved_to_profile: false,
model_saved_to_global: false,
file_search,
enhanced_keys_supported,
transcript_lines: Vec::new(),
@@ -288,10 +298,17 @@ impl App {
self.chat_widget.apply_file_search_result(query, matches);
}
AppEvent::UpdateReasoningEffort(effort) => {
self.chat_widget.set_reasoning_effort(effort);
self.on_update_reasoning_effort(effort);
}
AppEvent::UpdateModel(model) => {
self.chat_widget.set_model(model);
self.chat_widget.set_model(model.clone());
self.config.model = model.clone();
if let Some(family) = find_family_for_model(&model) {
self.config.model_family = family;
}
self.model_saved_to_profile = false;
self.model_saved_to_global = false;
self.show_model_save_hint();
}
AppEvent::UpdateAskForApprovalPolicy(policy) => {
self.chat_widget.set_approval_policy(policy);
@@ -304,7 +321,107 @@ impl App {
}

pub(crate) fn token_usage(&self) -> codex_core::protocol::TokenUsage {
self.chat_widget.token_usage().clone()
self.chat_widget.token_usage()
}

fn show_model_save_hint(&mut self) {
let model = self.config.model.clone();
if self.active_profile.is_some() {
self.chat_widget.add_info_message(format!(
"Model switched to {model}. Press Ctrl+S to save it for this profile, then press Ctrl+S again to set it as your global default."
));
} else {
self.chat_widget.add_info_message(format!(
"Model switched to {model}. Press Ctrl+S to save it as your global default."
));
}
}

fn on_update_reasoning_effort(&mut self, effort: ReasoningEffortConfig) {
let changed = self.config.model_reasoning_effort != effort;
self.chat_widget.set_reasoning_effort(effort);
self.config.model_reasoning_effort = effort;
if changed {
let show_hint = self.model_saved_to_profile || self.model_saved_to_global;
self.model_saved_to_profile = false;
self.model_saved_to_global = false;
if show_hint {
self.show_model_save_hint();
}
}
}

async fn persist_model_shortcut(&mut self) {
enum SaveScope<'a> {
Profile(&'a str),
Global,
AlreadySaved,
}

let scope = if let Some(profile) = self
.active_profile
.as_deref()
.filter(|_| !self.model_saved_to_profile)
{
SaveScope::Profile(profile)
} else if !self.model_saved_to_global {
SaveScope::Global
} else {
SaveScope::AlreadySaved
};

let model = self.config.model.clone();
let effort = self.config.model_reasoning_effort;
let codex_home = self.config.codex_home.clone();

match scope {
SaveScope::Profile(profile) => {
match persist_model_selection(&codex_home, Some(profile), &model, Some(effort))
.await
{
Ok(()) => {
self.model_saved_to_profile = true;
self.chat_widget.add_info_message(format!(
"Saved model {model} ({effort}) for profile `{profile}`. Press Ctrl+S again to make this your global default."
));
}
Err(err) => {
tracing::error!(
error = %err,
"failed to persist model selection via shortcut"
);
self.chat_widget.add_error_message(format!(
"Failed to save model preference for profile `{profile}`: {err}"
));
}
}
}
SaveScope::Global => {
match persist_model_selection(&codex_home, None, &model, Some(effort)).await {
Ok(()) => {
self.model_saved_to_global = true;
self.chat_widget.add_info_message(format!(
"Saved model {model} ({effort}) as your global default."
));
}
Err(err) => {
tracing::error!(
error = %err,
"failed to persist global model selection via shortcut"
);
self.chat_widget.add_error_message(format!(
"Failed to save global model preference: {err}"
));
}
}
}
SaveScope::AlreadySaved => {
self.chat_widget.add_info_message(
"Model preference already saved globally; no further action needed."
.to_string(),
);
}
}
}

async fn handle_key_event(&mut self, tui: &mut tui::Tui, key_event: KeyEvent) {
@@ -320,6 +437,14 @@ impl App {
self.overlay = Some(Overlay::new_transcript(self.transcript_lines.clone()));
tui.frame_requester().schedule_frame();
}
KeyEvent {
code: KeyCode::Char('s'),
modifiers: crossterm::event::KeyModifiers::CONTROL,
kind: KeyEventKind::Press,
..
} => {
self.persist_model_shortcut().await;
}
// Esc primes/advances backtracking only in normal (not working) mode
// with an empty composer. In any other state, forward Esc so the
// active UI (e.g. status indicator, modals, popups) handles it.
@@ -366,3 +491,67 @@ impl App {
};
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::app_backtrack::BacktrackState;
use crate::chatwidget::tests::make_chatwidget_manual_with_sender;
use crate::file_search::FileSearchManager;
use codex_core::CodexAuth;
use codex_core::ConversationManager;
use ratatui::text::Line;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;

fn make_test_app() -> App {
let (chat_widget, app_event_tx, _rx, _op_rx) = make_chatwidget_manual_with_sender();
let config = chat_widget.config_ref().clone();

let server = Arc::new(ConversationManager::with_auth(CodexAuth::from_api_key(
"Test API Key",
)));
let file_search = FileSearchManager::new(config.cwd.clone(), app_event_tx.clone());

App {
server,
app_event_tx,
chat_widget,
config,
active_profile: None,
model_saved_to_profile: false,
model_saved_to_global: false,
file_search,
transcript_lines: Vec::<Line<'static>>::new(),
overlay: None,
deferred_history_lines: Vec::new(),
has_emitted_history_lines: false,
enhanced_keys_supported: false,
commit_anim_running: Arc::new(AtomicBool::new(false)),
backtrack: BacktrackState::default(),
}
}

#[test]
fn update_reasoning_effort_updates_config_and_resets_flags() {
let mut app = make_test_app();
app.model_saved_to_profile = true;
app.model_saved_to_global = true;
app.config.model_reasoning_effort = ReasoningEffortConfig::Medium;
app.chat_widget
.set_reasoning_effort(ReasoningEffortConfig::Medium);

app.on_update_reasoning_effort(ReasoningEffortConfig::High);

assert_eq!(
app.config.model_reasoning_effort,
ReasoningEffortConfig::High
);
assert_eq!(
app.chat_widget.config_ref().model_reasoning_effort,
ReasoningEffortConfig::High
);
assert!(!app.model_saved_to_profile);
assert!(!app.model_saved_to_global);
}
}

@@ -1,10 +1,11 @@
use std::path::PathBuf;

use crate::app::App;
use crate::backtrack_helpers;
use crate::pager_overlay::Overlay;
use crate::tui;
use crate::tui::TuiEvent;
use codex_core::InitialHistory;
use codex_core::protocol::ConversationHistoryResponseEvent;
use codex_core::protocol::ConversationPathResponseEvent;
use codex_protocol::mcp_protocol::ConversationId;
use color_eyre::eyre::Result;
use crossterm::event::KeyCode;
@@ -99,7 +100,7 @@ impl App {
) {
self.backtrack.pending = Some((base_id, drop_last_messages, prefill));
self.app_event_tx.send(crate::app_event::AppEvent::CodexOp(
codex_core::protocol::Op::GetHistory,
codex_core::protocol::Op::GetPath,
));
}

@@ -266,7 +267,7 @@ impl App {
pub(crate) async fn on_conversation_history_for_backtrack(
&mut self,
tui: &mut tui::Tui,
ev: ConversationHistoryResponseEvent,
ev: ConversationPathResponseEvent,
) -> Result<()> {
if let Some((base_id, _, _)) = self.backtrack.pending.as_ref()
&& ev.conversation_id == *base_id
@@ -282,14 +283,14 @@ impl App {
async fn fork_and_switch_to_new_conversation(
&mut self,
tui: &mut tui::Tui,
ev: ConversationHistoryResponseEvent,
ev: ConversationPathResponseEvent,
drop_count: usize,
prefill: String,
) {
let cfg = self.chat_widget.config_ref().clone();
// Perform the fork via a thin wrapper for clarity/testability.
let result = self
.perform_fork(ev.history.clone(), drop_count, cfg.clone())
.perform_fork(ev.path.clone(), drop_count, cfg.clone())
.await;
match result {
Ok(new_conv) => {
@@ -302,13 +303,11 @@ impl App {
/// Thin wrapper around ConversationManager::fork_conversation.
async fn perform_fork(
&self,
entries: InitialHistory,
path: PathBuf,
drop_count: usize,
cfg: codex_core::config::Config,
) -> codex_core::error::Result<codex_core::NewConversation> {
self.server
.fork_conversation(entries, drop_count, cfg)
.await
self.server.fork_conversation(drop_count, cfg, path).await
}

/// Install a forked conversation into the ChatWidget and update UI to reflect selection.
@@ -336,7 +335,7 @@ impl App {
self.trim_transcript_for_backtrack(drop_count);
self.render_transcript_once(tui);
if !prefill.is_empty() {
self.chat_widget.insert_str(prefill);
self.chat_widget.set_composer_text(prefill.to_string());
}
tui.frame_requester().schedule_frame();
}

@@ -1,4 +1,4 @@
use codex_core::protocol::ConversationHistoryResponseEvent;
use codex_core::protocol::ConversationPathResponseEvent;
use codex_core::protocol::Event;
use codex_file_search::FileMatch;

@@ -58,5 +58,5 @@ pub(crate) enum AppEvent {
UpdateSandboxPolicy(SandboxPolicy),

/// Forwarded conversation history snapshot from the current conversation.
ConversationHistory(ConversationHistoryResponseEvent),
ConversationHistory(ConversationPathResponseEvent),
}

15
codex-rs/tui/src/bin/md-events.rs
Normal file
@@ -0,0 +1,15 @@
use std::io::Read;
use std::io::{self};

fn main() {
    let mut input = String::new();
    if let Err(err) = io::stdin().read_to_string(&mut input) {
        eprintln!("failed to read stdin: {err}");
        std::process::exit(1);
    }

    let parser = pulldown_cmark::Parser::new(&input);
    for event in parser {
        println!("{event:?}");
    }
}
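This small debug binary uses the pulldown-cmark dependency added in the Cargo.toml hunk above: pipe Markdown into stdin and it prints one parser event per line. A rough sketch of the output shape for the input `# Hi` (assumed from pulldown-cmark 0.10's Debug formatting, not taken from the diff):

// Start(Heading { level: H1, id: None, classes: [], attrs: [] })
// Text(Borrowed("Hi"))
// End(Heading(H1))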
@@ -243,6 +243,10 @@ impl ChatComposer {

/// Replace the entire composer content with `text` and reset cursor.
pub(crate) fn set_text_content(&mut self, text: String) {
// Clear any existing content, placeholders, and attachments first.
self.textarea.set_text("");
self.pending_pastes.clear();
self.attached_images.clear();
self.textarea.set_text(&text);
self.textarea.set_cursor(0);
self.sync_command_popup();
@@ -483,7 +487,7 @@ impl ChatComposer {
} => {
// Hide popup without modifying text, remember token to avoid immediate reopen.
if let Some(tok) = Self::current_at_token(&self.textarea) {
self.dismissed_file_popup_token = Some(tok.to_string());
self.dismissed_file_popup_token = Some(tok);
}
self.active_popup = ActivePopup::None;
(InputResult::None, true)
@@ -542,7 +546,7 @@ impl ChatComposer {
Some(ext) if ext == "jpg" || ext == "jpeg" => "JPEG",
_ => "IMG",
};
self.attach_image(path_buf.clone(), w, h, format_label);
self.attach_image(path_buf, w, h, format_label);
// Add a trailing space to keep typing fluid.
self.textarea.insert_str(" ");
} else {
@@ -2119,7 +2123,7 @@ mod tests {

// Re-add and test backspace in middle: should break the placeholder string
// and drop the image mapping (same as text placeholder behavior).
composer.attach_image(path.clone(), 20, 10, "PNG");
composer.attach_image(path, 20, 10, "PNG");
let placeholder2 = composer.attached_images[0].placeholder.clone();
// Move cursor to roughly middle of placeholder
if let Some(start_pos) = composer.textarea.text().find(&placeholder2) {
@@ -2178,7 +2182,7 @@ mod tests {
let path1 = PathBuf::from("/tmp/image_dup1.png");
let path2 = PathBuf::from("/tmp/image_dup2.png");

composer.attach_image(path1.clone(), 10, 5, "PNG");
composer.attach_image(path1, 10, 5, "PNG");
// separate placeholders with a space for clarity
composer.handle_paste(" ".into());
composer.attach_image(path2.clone(), 10, 5, "PNG");
@@ -2227,7 +2231,7 @@ mod tests {
assert!(composer.textarea.text().starts_with("[image 3x2 PNG] "));

let imgs = composer.take_recent_submission_images();
assert_eq!(imgs, vec![tmp_path.clone()]);
assert_eq!(imgs, vec![tmp_path]);
}

#[test]

@@ -564,7 +564,7 @@ mod tests {
         let (tx_raw, rx) = unbounded_channel::<AppEvent>();
         let tx = AppEventSender::new(tx_raw);
         let mut pane = BottomPane::new(BottomPaneParams {
-            app_event_tx: tx.clone(),
+            app_event_tx: tx,
             frame_requester: FrameRequester::test_dummy(),
             has_input_focus: true,
             enhanced_keys_supported: false,
@@ -649,9 +649,7 @@ impl TextArea {
     }

     fn add_element(&mut self, range: Range<usize>) {
-        let elem = TextElement {
-            range: range.clone(),
-        };
+        let elem = TextElement { range };
         self.elements.push(elem);
         self.elements.sort_by_key(|e| e.range.start);
     }
@@ -574,14 +574,14 @@ impl ChatWidget {
                 self.active_exec_cell = Some(history_cell::new_active_exec_command(
                     ev.call_id.clone(),
                     ev.command.clone(),
-                    ev.parsed_cmd.clone(),
+                    ev.parsed_cmd,
                 ));
             }
         } else {
             self.active_exec_cell = Some(history_cell::new_active_exec_command(
                 ev.call_id.clone(),
                 ev.command.clone(),
-                ev.parsed_cmd.clone(),
+                ev.parsed_cmd,
             ));
         }

@@ -804,7 +804,7 @@ impl ChatWidget {
             "attach_image path={path:?} width={width} height={height} format={format_label}",
         );
         self.bottom_pane
-            .attach_image(path.clone(), width, height, format_label);
+            .attach_image(path, width, height, format_label);
         self.request_redraw();
     }

@@ -986,7 +986,7 @@ impl ChatWidget {

         // Only show the text portion in conversation history.
         if !text.is_empty() {
-            self.add_to_history(history_cell::new_user_prompt(text.clone()));
+            self.add_to_history(history_cell::new_user_prompt(text));
         }
     }

@@ -1055,10 +1055,10 @@ impl ChatWidget {
             EventMsg::PlanUpdate(update) => self.on_plan_update(update),
             EventMsg::ExecApprovalRequest(ev) => {
                 // For replayed events, synthesize an empty id (these should not occur).
-                self.on_exec_approval_request(id.clone().unwrap_or_default(), ev)
+                self.on_exec_approval_request(id.unwrap_or_default(), ev)
             }
             EventMsg::ApplyPatchApprovalRequest(ev) => {
-                self.on_apply_patch_approval_request(id.clone().unwrap_or_default(), ev)
+                self.on_apply_patch_approval_request(id.unwrap_or_default(), ev)
             }
             EventMsg::ExecCommandBegin(ev) => self.on_exec_command_begin(ev),
             EventMsg::ExecCommandOutputDelta(delta) => self.on_exec_command_output_delta(delta),
@@ -1083,7 +1083,7 @@ impl ChatWidget {
                     self.on_user_message_event(ev);
                 }
             }
-            EventMsg::ConversationHistory(ev) => {
+            EventMsg::ConversationPath(ev) => {
                 self.app_event_tx
                     .send(crate::app_event::AppEvent::ConversationHistory(ev));
             }
@@ -1207,7 +1207,7 @@ impl ChatWidget {
         self.bottom_pane.show_selection_view(
             "Select model and reasoning level".to_string(),
             Some("Switch between OpenAI models for this and future Codex CLI session".to_string()),
-            Some("Press Enter to confirm or Esc to go back".to_string()),
+            Some("Press Enter to confirm, Esc to go back, Ctrl+S to save".to_string()),
             items,
         );
     }
@@ -1273,6 +1273,16 @@ impl ChatWidget {
         self.config.model = model;
     }

+    pub(crate) fn add_info_message(&mut self, message: String) {
+        self.add_to_history(history_cell::new_info_event(message));
+        self.request_redraw();
+    }
+
+    pub(crate) fn add_error_message(&mut self, message: String) {
+        self.add_to_history(history_cell::new_error_event(message));
+        self.request_redraw();
+    }
+
     pub(crate) fn add_mcp_output(&mut self) {
         if self.config.mcp_servers.is_empty() {
             self.add_to_history(history_cell::empty_mcp_output());
@@ -1316,6 +1326,11 @@ impl ChatWidget {
         self.bottom_pane.insert_str(text);
     }

+    /// Replace the composer content with the provided text and reset cursor.
+    pub(crate) fn set_composer_text(&mut self, text: String) {
+        self.bottom_pane.set_composer_text(text);
+    }
+
     pub(crate) fn show_esc_backtrack_hint(&mut self) {
         self.bottom_pane.show_esc_backtrack_hint();
     }
@@ -1436,4 +1451,4 @@ fn extract_first_bold(s: &str) -> Option<String> {
 }

 #[cfg(test)]
-mod tests;
+pub(crate) mod tests;
@@ -20,7 +20,7 @@ pub(crate) fn spawn_agent(
 ) -> UnboundedSender<Op> {
     let (codex_op_tx, mut codex_op_rx) = unbounded_channel::<Op>();

-    let app_event_tx_clone = app_event_tx.clone();
+    let app_event_tx_clone = app_event_tx;
     tokio::spawn(async move {
         let NewConversation {
             conversation_id: _,
@@ -71,7 +71,7 @@ pub(crate) fn spawn_agent_from_existing(
 ) -> UnboundedSender<Op> {
     let (codex_op_tx, mut codex_op_rx) = unbounded_channel::<Op>();

-    let app_event_tx_clone = app_event_tx.clone();
+    let app_event_tx_clone = app_event_tx;
     tokio::spawn(async move {
         // Forward the captured `SessionConfigured` event so it can be rendered in the UI.
         let ev = codex_core::protocol::Event {
@@ -0,0 +1,18 @@
+---
+source: tui/src/chatwidget/tests.rs
+expression: visual
+---
+> -- Indented code block (4 spaces)
+    SELECT *
+    FROM "users"
+    WHERE "email" LIKE '%@example.com';
+
+```sh
+printf 'fenced within fenced\n'
+```
+
+{
+  // comment allowed in jsonc
+  "path": "C:\\Program Files\\App",
+  "regex": "^foo.*(bar)?$"
+}
@@ -138,6 +138,7 @@ fn resumed_initial_messages_render_history() {
     let configured = codex_core::protocol::SessionConfiguredEvent {
         session_id: conversation_id,
         model: "test-model".to_string(),
+        reasoning_effort: ReasoningEffortConfig::default(),
         history_log_id: 0,
         history_entry_count: 0,
         initial_messages: Some(vec![
@@ -246,6 +247,17 @@ fn make_chatwidget_manual() -> (
     (widget, rx, op_rx)
 }

+pub(crate) fn make_chatwidget_manual_with_sender() -> (
+    ChatWidget,
+    AppEventSender,
+    tokio::sync::mpsc::UnboundedReceiver<AppEvent>,
+    tokio::sync::mpsc::UnboundedReceiver<Op>,
+) {
+    let (widget, rx, op_rx) = make_chatwidget_manual();
+    let app_event_tx = widget.app_event_tx.clone();
+    (widget, app_event_tx, rx, op_rx)
+}
+
 fn drain_insert_history(
     rx: &mut tokio::sync::mpsc::UnboundedReceiver<AppEvent>,
 ) -> Vec<Vec<ratatui::text::Line<'static>>> {
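A hedged sketch of how a test might use the new helper; the destructuring mirrors the tuple above, while everything else is illustrative:

```rust
// Hypothetical usage, not part of the diff: the extra AppEventSender is the
// widget's own sender, so events injected here interleave with the widget's
// output on `rx` exactly as in production.
let (mut chat, app_event_tx, mut rx, _op_rx) = make_chatwidget_manual_with_sender();
```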
@@ -352,7 +364,7 @@ fn exec_approval_decision_truncates_multiline_and_long_commands() {
     let long = format!("echo {}", "a".repeat(200));
     let ev_long = ExecApprovalRequestEvent {
         call_id: "call-long".into(),
-        command: vec!["bash".into(), "-lc".into(), long.clone()],
+        command: vec!["bash".into(), "-lc".into(), long],
         cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
         reason: None,
     };
@@ -1756,3 +1768,123 @@ fn chatwidget_exec_and_status_layout_vt100_snapshot() {
     let visual = vt_lines.join("\n");
     assert_snapshot!(visual);
 }
+
+// E2E vt100 snapshot for complex markdown with indented and nested fenced code blocks
+#[test]
+fn chatwidget_markdown_code_blocks_vt100_snapshot() {
+    let (mut chat, mut rx, _op_rx) = make_chatwidget_manual();
+
+    // Simulate a final agent message via streaming deltas instead of a single message
+
+    chat.handle_codex_event(Event {
+        id: "t1".into(),
+        msg: EventMsg::TaskStarted(TaskStartedEvent {
+            model_context_window: None,
+        }),
+    });
+    // Build a vt100 visual from the history insertions only (no UI overlay)
+    let width: u16 = 80;
+    let height: u16 = 50;
+    let backend = ratatui::backend::TestBackend::new(width, height);
+    let mut term = crate::custom_terminal::Terminal::with_options(backend).expect("terminal");
+    // Place viewport at the last line so that history lines insert above it
+    term.set_viewport_area(Rect::new(0, height - 1, width, 1));
+
+    let mut ansi: Vec<u8> = Vec::new();
+
+    // Simulate streaming via AgentMessageDelta in 2-character chunks (no final AgentMessage).
+    let source: &str = r#"
+
+    -- Indented code block (4 spaces)
+    SELECT *
+    FROM "users"
+    WHERE "email" LIKE '%@example.com';
+
+````markdown
+```sh
+printf 'fenced within fenced\n'
+```
+````
+
+```jsonc
+{
+  // comment allowed in jsonc
+  "path": "C:\\Program Files\\App",
+  "regex": "^foo.*(bar)?$"
+}
+```
+"#;
+
+    let mut it = source.chars();
+    loop {
+        let mut delta = String::new();
+        match it.next() {
+            Some(c) => delta.push(c),
+            None => break,
+        }
+        if let Some(c2) = it.next() {
+            delta.push(c2);
+        }
+
+        chat.handle_codex_event(Event {
+            id: "t1".into(),
+            msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
+        });
+        // Drive commit ticks and drain emitted history lines into the vt100 buffer.
+        loop {
+            chat.on_commit_tick();
+            let mut inserted_any = false;
+            while let Ok(app_ev) = rx.try_recv() {
+                if let AppEvent::InsertHistoryCell(cell) = app_ev {
+                    let lines = cell.display_lines(width);
+                    crate::insert_history::insert_history_lines_to_writer(
+                        &mut term, &mut ansi, lines,
+                    );
+                    inserted_any = true;
+                }
+            }
+            if !inserted_any {
+                break;
+            }
+        }
+    }
+
+    // Finalize the stream without sending a final AgentMessage, to flush any tail.
+    chat.handle_codex_event(Event {
+        id: "t1".into(),
+        msg: EventMsg::TaskComplete(TaskCompleteEvent {
+            last_agent_message: None,
+        }),
+    });
+    for lines in drain_insert_history(&mut rx) {
+        crate::insert_history::insert_history_lines_to_writer(&mut term, &mut ansi, lines);
+    }
+
+    let mut parser = vt100::Parser::new(height, width, 0);
+    parser.process(&ansi);
+
+    let mut vt_lines: Vec<String> = (0..height)
+        .map(|row| {
+            let mut s = String::with_capacity(width as usize);
+            for col in 0..width {
+                if let Some(cell) = parser.screen().cell(row, col) {
+                    if let Some(ch) = cell.contents().chars().next() {
+                        s.push(ch);
+                    } else {
+                        s.push(' ');
+                    }
+                } else {
+                    s.push(' ');
+                }
+            }
+            s.trim_end().to_string()
+        })
+        .collect();
+
+    // Compact trailing blank rows for a stable snapshot
+    while matches!(vt_lines.last(), Some(l) if l.trim().is_empty()) {
+        vt_lines.pop();
+    }
+    let visual = vt_lines.join("\n");
+    assert_snapshot!(visual);
+}
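The test above streams the source in at most two characters at a time to exercise the incremental markdown path. A self-contained reduction of that chunking loop, for reference:

```rust
// Splits a string into deltas of at most two chars, mirroring the loop above.
fn chunks_of_two(s: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut it = s.chars();
    while let Some(c) = it.next() {
        let mut delta = String::from(c);
        if let Some(c2) = it.next() {
            delta.push(c2);
        }
        out.push(delta);
    }
    out
}

fn main() {
    assert_eq!(chunks_of_two("abcde"), vec!["ab", "cd", "e"]);
}
```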
@@ -1 +0,0 @@
-
@@ -737,10 +737,10 @@ mod tests {

         let mut changes: HashMap<PathBuf, FileChange> = HashMap::new();
         changes.insert(
-            abs_old.clone(),
+            abs_old,
             FileChange::Update {
                 unified_diff: patch,
-                move_path: Some(abs_new.clone()),
+                move_path: Some(abs_new),
             },
         );

@@ -601,6 +601,7 @@ pub(crate) fn new_session_info(
 ) -> PlainHistoryCell {
     let SessionConfiguredEvent {
         model,
+        reasoning_effort: _,
         session_id: _,
         history_log_id: _,
         history_entry_count: _,
@@ -697,7 +698,7 @@ fn spinner(start_time: Option<Instant>) -> Span<'static> {

 pub(crate) fn new_active_mcp_tool_call(invocation: McpInvocation) -> PlainHistoryCell {
     let title_line = Line::from(vec!["tool".magenta(), " running...".dim()]);
-    let lines: Vec<Line> = vec![title_line, format_mcp_invocation(invocation.clone())];
+    let lines: Vec<Line> = vec![title_line, format_mcp_invocation(invocation)];

     PlainHistoryCell { lines }
 }
@@ -1052,6 +1053,12 @@ pub(crate) fn new_mcp_tools_output(
     PlainHistoryCell { lines }
 }

+pub(crate) fn new_info_event(message: String) -> PlainHistoryCell {
+    let lines: Vec<Line<'static>> =
+        vec![vec![padded_emoji("💾").green(), " ".into(), message.into()].into()];
+    PlainHistoryCell { lines }
+}
+
 pub(crate) fn new_error_event(message: String) -> PlainHistoryCell {
     // Use a hair space (U+200A) to create a subtle, near-invisible separation
     // before the text. VS16 is intentionally omitted to keep spacing tighter
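For orientation, a self-contained approximation of the line `new_info_event` builds; `padded_emoji` is the crate's own helper and is replaced here by a plain emoji span as an assumption:

```rust
use ratatui::style::Stylize;
use ratatui::text::Line;

fn main() {
    // Roughly the shape new_info_event produces: green emoji, space, message.
    let message = String::from("saved");
    let line: Line<'static> = vec!["💾".green(), " ".into(), message.into()].into();
    assert_eq!(line.spans.len(), 3);
}
```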
@@ -1324,7 +1331,7 @@ fn format_mcp_invocation<'a>(invocation: McpInvocation) -> Line<'a> {
     let invocation_spans = vec![
         invocation.server.clone().cyan(),
         ".".into(),
-        invocation.tool.clone().cyan(),
+        invocation.tool.cyan(),
         "(".into(),
         args_str.dim(),
         ")".into(),
@@ -97,7 +97,17 @@ pub fn insert_history_lines_to_writer<B, W>(

     for line in wrapped {
         queue!(writer, Print("\r\n")).ok();
-        write_spans(writer, &line).ok();
+        // Merge line-level style into each span so that ANSI colors reflect
+        // line styles (e.g., blockquotes with green fg).
+        let merged_spans: Vec<Span> = line
+            .spans
+            .iter()
+            .map(|s| Span {
+                style: s.style.patch(line.style),
+                content: s.content.clone(),
+            })
+            .collect();
+        write_spans(writer, merged_spans.iter()).ok();
     }

     queue!(writer, ResetScrollRegion).ok();
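The merge leans on ratatui's `Style::patch`, where fields set on the argument take precedence over the receiver, so an unstyled span inherits the line's color. A minimal sketch of that semantics:

```rust
use ratatui::style::{Color, Style};

fn main() {
    // An unstyled span merged with a green line style picks up the green fg.
    let merged = Style::default().patch(Style::default().fg(Color::Green));
    assert_eq!(merged.fg, Some(Color::Green));

    // A span keeps its own color when the patching style leaves fg unset.
    let kept = Style::default().fg(Color::Red).patch(Style::default());
    assert_eq!(kept.fg, Some(Color::Red));
}
```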
@@ -264,6 +274,10 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::markdown_render::render_markdown_text;
+    use ratatui::layout::Rect;
+    use ratatui::style::Color;
+    use vt100::Parser;

     #[test]
     fn writes_bold_then_regular_spans() {
@@ -292,4 +306,240 @@ mod tests {
             String::from_utf8(expected).unwrap()
         );
     }
+
+    #[test]
+    fn vt100_blockquote_line_emits_green_fg() {
+        // Set up a small off-screen terminal
+        let width: u16 = 40;
+        let height: u16 = 10;
+        let backend = ratatui::backend::TestBackend::new(width, height);
+        let mut term = crate::custom_terminal::Terminal::with_options(backend).expect("terminal");
+        // Place viewport on the last line so history inserts scroll upward
+        let viewport = Rect::new(0, height - 1, width, 1);
+        term.set_viewport_area(viewport);
+
+        // Build a blockquote-like line: apply line-level green style and prefix "> "
+        let mut line: Line<'static> = Line::from(vec!["> ".into(), "Hello world".into()]);
+        line = line.style(Color::Green);
+        let mut ansi: Vec<u8> = Vec::new();
+        insert_history_lines_to_writer(&mut term, &mut ansi, vec![line]);
+
+        // Parse ANSI using vt100 and assert at least one non-default fg color appears
+        let mut parser = Parser::new(height, width, 0);
+        parser.process(&ansi);
+
+        let mut saw_colored = false;
+        'outer: for row in 0..height {
+            for col in 0..width {
+                if let Some(cell) = parser.screen().cell(row, col)
+                    && cell.has_contents()
+                    && cell.fgcolor() != vt100::Color::Default
+                {
+                    saw_colored = true;
+                    break 'outer;
+                }
+            }
+        }
+        assert!(
+            saw_colored,
+            "expected at least one colored cell in vt100 output"
+        );
+    }
+
+    #[test]
+    fn vt100_blockquote_wrap_preserves_color_on_all_wrapped_lines() {
+        // Force wrapping by using a narrow viewport width and a long blockquote line.
+        let width: u16 = 20;
+        let height: u16 = 8;
+        let backend = ratatui::backend::TestBackend::new(width, height);
+        let mut term = crate::custom_terminal::Terminal::with_options(backend).expect("terminal");
+        // Viewport is the last line so history goes directly above it.
+        let viewport = Rect::new(0, height - 1, width, 1);
+        term.set_viewport_area(viewport);
+
+        // Create a long blockquote with a distinct prefix and enough text to wrap.
+        let mut line: Line<'static> = Line::from(vec![
+            "> ".into(),
+            "This is a long quoted line that should wrap".into(),
+        ]);
+        line = line.style(Color::Green);
+
+        let mut ansi: Vec<u8> = Vec::new();
+        insert_history_lines_to_writer(&mut term, &mut ansi, vec![line]);
+
+        // Parse and inspect the final screen buffer.
+        let mut parser = Parser::new(height, width, 0);
+        parser.process(&ansi);
+        let screen = parser.screen();
+
+        // Collect rows that are non-empty; these should correspond to our wrapped lines.
+        let mut non_empty_rows: Vec<u16> = Vec::new();
+        for row in 0..height {
+            let mut any = false;
+            for col in 0..width {
+                if let Some(cell) = screen.cell(row, col)
+                    && cell.has_contents()
+                    && cell.contents() != "\0"
+                    && cell.contents() != " "
+                {
+                    any = true;
+                    break;
+                }
+            }
+            if any {
+                non_empty_rows.push(row);
+            }
+        }
+
+        // Expect at least two rows due to wrapping.
+        assert!(
+            non_empty_rows.len() >= 2,
+            "expected wrapped output to span >=2 rows, got {non_empty_rows:?}",
+        );
+
+        // For each non-empty row, ensure all non-space cells are using a non-default fg color.
+        for row in non_empty_rows {
+            for col in 0..width {
+                if let Some(cell) = screen.cell(row, col) {
+                    let contents = cell.contents();
+                    if !contents.is_empty() && contents != " " {
+                        assert!(
+                            cell.fgcolor() != vt100::Color::Default,
+                            "expected non-default fg on row {row} col {col}, got {:?}",
+                            cell.fgcolor()
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn vt100_colored_prefix_then_plain_text_resets_color() {
+        let width: u16 = 40;
+        let height: u16 = 6;
+        let backend = ratatui::backend::TestBackend::new(width, height);
+        let mut term = crate::custom_terminal::Terminal::with_options(backend).expect("terminal");
+        let viewport = Rect::new(0, height - 1, width, 1);
+        term.set_viewport_area(viewport);
+
+        // First span colored, rest plain.
+        let line: Line<'static> = Line::from(vec![
+            Span::styled("1. ", ratatui::style::Style::default().fg(Color::LightBlue)),
+            Span::raw("Hello world"),
+        ]);
+
+        let mut ansi: Vec<u8> = Vec::new();
+        insert_history_lines_to_writer(&mut term, &mut ansi, vec![line]);
+
+        let mut parser = Parser::new(height, width, 0);
+        parser.process(&ansi);
+        let screen = parser.screen();
+
+        // Find the first non-empty row; verify first three cells are colored, following cells default.
+        'rows: for row in 0..height {
+            let mut has_text = false;
+            for col in 0..width {
+                if let Some(cell) = screen.cell(row, col)
+                    && cell.has_contents()
+                    && cell.contents() != " "
+                {
+                    has_text = true;
+                    break;
+                }
+            }
+            if !has_text {
+                continue;
+            }
+
+            // Expect "1. Hello world" starting at col 0.
+            for col in 0..3 {
+                let cell = screen.cell(row, col).unwrap();
+                assert!(
+                    cell.fgcolor() != vt100::Color::Default,
+                    "expected colored prefix at col {col}, got {:?}",
+                    cell.fgcolor()
+                );
+            }
+            for col in 3..(3 + "Hello world".len() as u16) {
+                let cell = screen.cell(row, col).unwrap();
+                assert_eq!(
+                    cell.fgcolor(),
+                    vt100::Color::Default,
+                    "expected default color for plain text at col {col}, got {:?}",
+                    cell.fgcolor()
+                );
+            }
+            break 'rows;
+        }
+    }
+
+    #[test]
+    fn vt100_deep_nested_mixed_list_third_level_marker_is_colored() {
+        // Markdown with five levels (ordered → unordered → ordered → unordered → unordered).
+        let md = "1. First\n - Second level\n 1. Third level (ordered)\n - Fourth level (bullet)\n - Fifth level to test indent consistency\n";
+        let text = render_markdown_text(md);
+        let lines: Vec<Line<'static>> = text.lines.clone();
+
+        let width: u16 = 60;
+        let height: u16 = 12;
+        let backend = ratatui::backend::TestBackend::new(width, height);
+        let mut term = crate::custom_terminal::Terminal::with_options(backend).expect("terminal");
+        let viewport = ratatui::layout::Rect::new(0, height - 1, width, 1);
+        term.set_viewport_area(viewport);
+
+        let mut ansi: Vec<u8> = Vec::new();
+        insert_history_lines_to_writer(&mut term, &mut ansi, lines);
+
+        let mut parser = Parser::new(height, width, 0);
+        parser.process(&ansi);
+        let screen = parser.screen();
+
+        // Reconstruct screen rows as strings to locate the 3rd level line.
+        let mut rows: Vec<String> = Vec::with_capacity(height as usize);
+        for row in 0..height {
+            let mut s = String::with_capacity(width as usize);
+            for col in 0..width {
+                if let Some(cell) = screen.cell(row, col) {
+                    if let Some(ch) = cell.contents().chars().next() {
+                        s.push(ch);
+                    } else {
+                        s.push(' ');
+                    }
+                } else {
+                    s.push(' ');
+                }
+            }
+            rows.push(s.trim_end().to_string());
+        }
+
+        let needle = "1. Third level (ordered)";
+        let row_idx = rows
+            .iter()
+            .position(|r| r.contains(needle))
+            .unwrap_or_else(|| {
+                panic!("expected to find row containing {needle:?}, have rows: {rows:?}")
+            });
+        let col_start = rows[row_idx].find(needle).unwrap() as u16; // column where '1' starts
+
+        // Verify that the numeric marker ("1.") at the third level is colored
+        // (non-default fg) and the content after the following space resets to default.
+        for c in [col_start, col_start + 1] {
+            let cell = screen.cell(row_idx as u16, c).unwrap();
+            assert!(
+                cell.fgcolor() != vt100::Color::Default,
+                "expected colored 3rd-level marker at row {row_idx} col {c}, got {:?}",
+                cell.fgcolor()
+            );
+        }
+        let content_col = col_start + 3; // skip '1', '.', and the space
+        if let Some(cell) = screen.cell(row_idx as u16, content_col) {
+            assert_eq!(
+                cell.fgcolor(),
+                vt100::Color::Default,
+                "expected default color for 3rd-level content at row {row_idx} col {content_col}, got {:?}",
+                cell.fgcolor()
+            );
+        }
+    }
 }
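All four tests above share one capture pattern: render into an off-screen terminal, collect the raw ANSI bytes, feed them to a vt100 parser, and assert on the resulting cell grid. A self-contained reduction of that pattern:

```rust
// Write raw ANSI to a vt100 parser, then read back contents and colors.
fn main() {
    let mut parser = vt100::Parser::new(5, 20, 0);
    parser.process(b"\x1b[32mhi\x1b[0m"); // SGR 32 = green foreground
    let cell = parser.screen().cell(0, 0).expect("cell");
    assert_eq!(cell.contents(), "h");
    assert_eq!(cell.fgcolor(), vt100::Color::Idx(2));
}
```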
Some files were not shown because too many files have changed in this diff.