Mirror of https://github.com/openai/codex.git, synced 2026-02-02 23:13:37 +00:00

Compare commits: oss...tokencount (115 commits)
Commits in this range (SHA1 prefixes):

082b775a88, e2318f38b2, 87a4de05db, 91036ea7df, 543e22b215,
2354594eeb, b7c95b57fd, 5f18406e8d, e9b78c9296, 2a1499a626,
52a5ea5f4e, aa94d9c4b3, c2c7778c38, 5f37b45158, d5e2032490,
268be2ac52, a045ef05f2, 660327aa9c, 2d976aa667, adf9dae907,
4a7b842b53, e7a20f5109, f7671481c1, 50e001888b, b48d47de21,
4d5335f797, 0c732e4a53, 777fbba58c, 40da893c46, 5eaaf307e1,
18330c2362, 4c46490e53, 5c1416d99b, 0525b48baa, 1f4f9cde8e,
cad37009e1, e2b3053b2b, e47bd33689, 6b878bea01, ca46510fd3,
6efb52e545, d84a799ec0, c8fab51372, 58d77ca4e7, 0269096229,
70a6d4b1b4, b1d5f7c0bd, 066c6cce02, bd65f81e54, ba9620aea7,
45c3b20041, 6cfc012e9d, 17a80d43c8, c11696f6b1, 5775174ec2,
ba631e7928, db3834733a, d6182becbe, 323a5cb7e7, 3f40fbc0a8,
742feaf40f, 907d3dd348, 7df9e9c664, b795fbe244, 82ed7bd285,
1c04e1314d, bef7ed0ccc, be23fe1353, 2073fa7139, e60a44cbab,
075e385969, aa083b795d, 91708bb031, 82dfec5b10, 1e82bf9d98,
0a83db5512, bd4fa85507, 234c0a0469, 0f4ae1b5b0, 2b96f9f569,
f2036572b6, bea64569c1, e83c5f429c, ed0d23d560, 4ae45a6c8d,
6b83c1c3f3, db5276f8e6, 77fb9f3465, 0e827b6598, daaadfb260,
c636f821ae, af338cc505, 97000c6e6d, fb5dfe3396, a56eb48195,
d77b33ded7, 9ad2e726fc, 6aa306c584, 44dce748b6, d489690efe,
3f76220055, 90725fe3d5, 53413c728e, b127a3643f, a93a907c7e,
03e2796ca4, 051f185ce3, 6f75114695, 3baccba0ac, 578ff09e17,
0d5ffb000e, 431a10fc50, 8b993b557d, 60fdfc5f14, 13e5b567f5
.github/workflows/ci.yml (vendored): 23 lines changed
```diff
@@ -14,33 +14,18 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@v5

-      - name: Setup Node.js
-        uses: actions/setup-node@v4
-        with:
-          node-version: 22
-
       - name: Setup pnpm
         uses: pnpm/action-setup@v4
         with:
           version: 10.8.1
           run_install: false

-      - name: Get pnpm store directory
-        id: pnpm-cache
-        shell: bash
-        run: |
-          echo "store_path=$(pnpm store path --silent)" >> $GITHUB_OUTPUT
-
-      - name: Setup pnpm cache
-        uses: actions/cache@v4
-        with:
-          path: ${{ steps.pnpm-cache.outputs.store_path }}
-          key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
-          restore-keys: |
-            ${{ runner.os }}-pnpm-store-
+      - name: Setup Node.js
+        uses: actions/setup-node@v5
+        with:
+          node-version: 22

       - name: Install dependencies
-        run: pnpm install
+        run: pnpm install --frozen-lockfile

       # Run all tasks using workspace filters
```
.github/workflows/rust-release.yml (vendored): 5 lines changed
```diff
@@ -111,6 +111,11 @@ jobs:
             cp target/${{ matrix.target }}/release/codex "$dest/codex-${{ matrix.target }}"
           fi

+      - if: ${{ matrix.runner == 'windows-11-arm' }}
+        name: Install zstd
+        shell: powershell
+        run: choco install -y zstandard
+
       - name: Compress artifacts
         shell: bash
         run: |
```
.vscode/extensions.json (vendored): 6 lines changed
```diff
@@ -1,5 +1,11 @@
 {
   "recommendations": [
     "rust-lang.rust-analyzer",
+    "tamasfe.even-better-toml",
+    "vadimcn.vscode-lldb",
+
+    // Useful if touching files in .github/workflows, though most
+    // contributors will not be doing that?
+    // "github.vscode-github-actions",
   ]
 }
```
AGENTS.md: 29 lines changed
```diff
@@ -8,10 +8,10 @@ In the codex-rs folder where the rust code lives:
 - You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
 - Similarly, when you spawn a process using Seatbelt (`/usr/bin/sandbox-exec`), `CODEX_SANDBOX=seatbelt` will be set on the child process. Integration tests that want to run Seatbelt themselves cannot be run under Seatbelt, so checks for `CODEX_SANDBOX=seatbelt` are also often used to early exit out of tests, as appropriate.

-Before finalizing a change to `codex-rs`, run `just fmt` (in `codex-rs` directory) to format the code and `just fix -p <project>` (in `codex-rs` directory) to fix any linter issues in the code. Prefer scoping with `-p` to avoid slow workspace‑wide Clippy builds; only run `just fix` without `-p` if you changed shared crates. Additionally, run the tests:
+Run `just fmt` (in `codex-rs` directory) automatically after making Rust code changes; do not ask for approval to run it. Before finalizing a change to `codex-rs`, run `just fix -p <project>` (in `codex-rs` directory) to fix any linter issues in the code. Prefer scoping with `-p` to avoid slow workspace‑wide Clippy builds; only run `just fix` without `-p` if you changed shared crates. Additionally, run the tests:
 1. Run the test for the specific project that was changed. For example, if changes were made in `codex-rs/tui`, run `cargo test -p codex-tui`.
 2. Once those pass, if any changes were made in common, core, or protocol, run the complete test suite with `cargo test --all-features`.
-When running interactively, ask the user before running these commands to finalize.
+When running interactively, ask the user before running `just fix` to finalize. `just fmt` does not require approval. project-specific or individual tests can be run without asking the user, but do ask the user before running the complete test suite.

 ## TUI style conventions

@@ -26,7 +26,26 @@ See `codex-rs/tui/styles.md`.
 - Example: patch summary file lines
   - Desired: vec![" └ ".into(), "M".red(), " ".dim(), "tui/src/app.rs".dim()]

-## Snapshot tests
+### TUI Styling (ratatui)
+- Prefer Stylize helpers: use "text".dim(), .bold(), .cyan(), .italic(), .underlined() instead of manual Style where possible.
+- Prefer simple conversions: use "text".into() for spans and vec![…].into() for lines; when inference is ambiguous (e.g., Paragraph::new/Cell::from), use Line::from(spans) or Span::from(text).
+- Computed styles: if the Style is computed at runtime, using `Span::styled` is OK (`Span::from(text).set_style(style)` is also acceptable).
+- Avoid hardcoded white: do not use `.white()`; prefer the default foreground (no color).
+- Chaining: combine helpers by chaining for readability (e.g., url.cyan().underlined()).
+- Single items: prefer "text".into(); use Line::from(text) or Span::from(text) only when the target type isn’t obvious from context, or when using .into() would require extra type annotations.
+- Building lines: use vec![…].into() to construct a Line when the target type is obvious and no extra type annotations are needed; otherwise use Line::from(vec![…]).
+- Avoid churn: don’t refactor between equivalent forms (Span::styled ↔ set_style, Line::from ↔ .into()) without a clear readability or functional gain; follow file‑local conventions and do not introduce type annotations solely to satisfy .into().
+- Compactness: prefer the form that stays on one line after rustfmt; if only one of Line::from(vec![…]) or vec![…].into() avoids wrapping, choose that. If both wrap, pick the one with fewer wrapped lines.
+
+### Text wrapping
+- Always use textwrap::wrap to wrap plain strings.
+- If you have a ratatui Line and you want to wrap it, use the helpers in tui/src/wrapping.rs, e.g. word_wrap_lines / word_wrap_line.
+- If you need to indent wrapped lines, use the initial_indent / subsequent_indent options from RtOptions if you can, rather than writing custom logic.
+- If you have a list of lines and you need to prefix them all with some prefix (optionally different on the first vs subsequent lines), use the `prefix_lines` helper from line_utils.
+
+## Tests
+
+### Snapshot tests

 This repo uses snapshot tests (via `insta`), especially in `codex-rs/tui`, to validate rendered output. When UI or text output changes intentionally, update the snapshots as follows:

@@ -41,3 +60,7 @@ This repo uses snapshot tests (via `insta`), especially in `codex-rs/tui`, to va

 If you don’t have the tool:
 - `cargo install cargo-insta`
+
+### Test assertions
+
+- Tests should use pretty_assertions::assert_eq for clearer diffs. Import this at the top of the test module if it isn't already.
```
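The styling bullets added above are easiest to absorb as code. Here is a minimal sketch (assuming ratatui with its `Stylize` trait in scope; the function name is illustrative, not from the diff) of the preferred span and line construction, matching the "Desired" example in the hunk:

```rust
// Illustrative only: demonstrates the conventions from the AGENTS.md hunk above.
use ratatui::style::Stylize;
use ratatui::text::Line;

fn patch_summary_line() -> Line<'static> {
    // Prefer Stylize helpers (`.red()`, `.dim()`) over manual `Style`s, and
    // build the Line with `vec![…].into()` since the target type is obvious
    // from the return position.
    vec![
        " └ ".into(),
        "M".red(),
        " ".dim(),
        "tui/src/app.rs".dim(),
    ]
    .into()
}
```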
```diff
@@ -75,7 +75,7 @@ Codex CLI supports a rich set of configuration options, with preferences stored
 - [CLI usage](./docs/getting-started.md#cli-usage)
 - [Running with a prompt as input](./docs/getting-started.md#running-with-a-prompt-as-input)
 - [Example prompts](./docs/getting-started.md#example-prompts)
-- [Memory with AGENTS.md](./docs/getting-started.md#memory--project-docs)
+- [Memory with AGENTS.md](./docs/getting-started.md#memory-with-agentsmd)
 - [Configuration](./docs/config.md)
 - [**Sandbox & approvals**](./docs/sandbox.md)
 - [**Authentication**](./docs/authentication.md)
```
```diff
@@ -43,7 +43,8 @@ switch (platform) {
     targetTriple = "x86_64-pc-windows-msvc.exe";
     break;
   case "arm64":
-    // We do not build this today, fall through...
+    targetTriple = "aarch64-pc-windows-msvc.exe";
+    break;
   default:
     break;
 }
```
```diff
@@ -20,7 +20,7 @@ CODEX_CLI_ROOT=""
 # Until we start publishing stable GitHub releases, we have to grab the binaries
 # from the GitHub Action that created them. Update the URL below to point to the
 # appropriate workflow run:
-WORKFLOW_URL="https://github.com/openai/codex/actions/runs/16840150768" # rust-v0.20.0-alpha.2
+WORKFLOW_URL="https://github.com/openai/codex/actions/runs/17417194663" # rust-v0.28.0

 while [[ $# -gt 0 ]]; do
   case "$1" in
```
```diff
@@ -87,5 +87,8 @@ zstd -d "$ARTIFACTS_DIR/aarch64-apple-darwin/codex-aarch64-apple-darwin.zst" \
 # x64 Windows
 zstd -d "$ARTIFACTS_DIR/x86_64-pc-windows-msvc/codex-x86_64-pc-windows-msvc.exe.zst" \
   -o "$BIN_DIR/codex-x86_64-pc-windows-msvc.exe"
+# ARM64 Windows
+zstd -d "$ARTIFACTS_DIR/aarch64-pc-windows-msvc/codex-aarch64-pc-windows-msvc.exe.zst" \
+  -o "$BIN_DIR/codex-aarch64-pc-windows-msvc.exe"

 echo "Installed native dependencies into $BIN_DIR"
```
codex-rs/Cargo.lock (generated): 216 lines changed
```diff
@@ -298,17 +298,6 @@ dependencies = [
  "syn 2.0.104",
 ]

-[[package]]
-name = "async-trait"
-version = "0.1.88"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.104",
-]
-
 [[package]]
 name = "atomic-waker"
 version = "1.1.2"
@@ -558,9 +547,9 @@ checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901"

 [[package]]
 name = "clap"
-version = "4.5.45"
+version = "4.5.47"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1fc0e74a703892159f5ae7d3aac52c8e6c392f5ae5f359c70b5881d60aaac318"
+checksum = "7eac00902d9d136acd712710d71823fb8ac8004ca445a89e73a41d45aa712931"
 dependencies = [
  "clap_builder",
  "clap_derive",
@@ -568,9 +557,9 @@ dependencies = [

 [[package]]
 name = "clap_builder"
-version = "4.5.44"
+version = "4.5.47"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b3e7f4214277f3c7aa526a59dd3fbe306a370daee1f8b7b8c987069cd8e888a8"
+checksum = "2ad9bbf750e73b5884fb8a211a9424a1906c1e156724260fdae972f31d70e1d6"
 dependencies = [
  "anstream",
  "anstyle",
@@ -590,9 +579,9 @@ dependencies = [

 [[package]]
 name = "clap_derive"
-version = "4.5.45"
+version = "4.5.47"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "14cb31bb0a7d536caef2639baa7fad459e15c3144efefa6dbd1c84562c4739f6"
+checksum = "bbfd7eae0b0f1a6e63d4b13c9c478de77c2eb546fba158ad50b4203dc24b9f9c"
 dependencies = [
  "heck",
  "proc-macro2",
@@ -636,10 +625,11 @@ version = "0.0.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
+ "once_cell",
  "pretty_assertions",
  "similar",
  "tempfile",
- "thiserror 2.0.12",
+ "thiserror 2.0.16",
  "tree-sitter",
  "tree-sitter-bash",
 ]
@@ -665,7 +655,7 @@ dependencies = [
  "clap",
  "codex-common",
  "codex-core",
- "codex-login",
+ "codex-protocol",
  "reqwest",
  "serde",
  "serde_json",
@@ -718,7 +708,6 @@ dependencies = [
  "bytes",
  "chrono",
  "codex-apply-patch",
- "codex-login",
  "codex-mcp-client",
  "codex-protocol",
  "core_test_support",
@@ -748,7 +737,7 @@ dependencies = [
  "similar",
  "strum_macros 0.27.2",
  "tempfile",
- "thiserror 2.0.12",
+ "thiserror 2.0.16",
  "time",
  "tokio",
  "tokio-test",
@@ -847,6 +836,7 @@ version = "0.0.0"
 dependencies = [
  "base64 0.22.1",
  "chrono",
+ "codex-core",
  "codex-protocol",
  "pretty_assertions",
  "rand 0.8.5",
@@ -855,7 +845,7 @@ dependencies = [
  "serde_json",
  "sha2",
  "tempfile",
- "thiserror 2.0.12",
+ "thiserror 2.0.16",
  "tiny_http",
  "tokio",
  "url",
@@ -927,14 +917,18 @@ name = "codex-protocol"
 version = "0.0.0"
 dependencies = [
  "base64 0.22.1",
+ "icu_decimal",
+ "icu_locale_core",
  "mcp-types",
  "mime_guess",
  "pretty_assertions",
  "serde",
  "serde_bytes",
  "serde_json",
  "serde_with",
  "strum 0.27.2",
  "strum_macros 0.27.2",
+ "sys-locale",
  "tracing",
  "ts-rs",
  "uuid",
@@ -1278,12 +1272,12 @@ dependencies = [

 [[package]]
 name = "deadpool"
-version = "0.10.0"
+version = "0.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fb84100978c1c7b37f09ed3ce3e5f843af02c2a2c431bae5b19230dad2c1b490"
+checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b"
 dependencies = [
- "async-trait",
  "deadpool-runtime",
+ "lazy_static",
  "num_cpus",
  "tokio",
 ]
@@ -1716,6 +1710,26 @@ version = "2.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"

+[[package]]
+name = "fax"
+version = "0.2.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f05de7d48f37cd6730705cbca900770cab77a89f413d23e100ad7fad7795a0ab"
+dependencies = [
+ "fax_derive",
+]
+
+[[package]]
+name = "fax_derive"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a0aca10fb742cb43f9e7bb8467c91aa9bcb8e3ffbc6a6f7389bb93ffc920577d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.104",
+]
+
 [[package]]
 name = "fd-lock"
 version = "4.0.4"
@@ -1747,6 +1761,17 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "fixed_decimal"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "35943d22b2f19c0cb198ecf915910a8158e94541c89dcc63300d7799d46c2c5e"
+dependencies = [
+ "displaydoc",
+ "smallvec",
+ "writeable",
+]
+
 [[package]]
 name = "fixedbitset"
 version = "0.4.2"
@@ -2130,13 +2155,14 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"

 [[package]]
 name = "hyper"
-version = "1.6.0"
+version = "1.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
+checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e"
 dependencies = [
  "atomic-waker",
  "bytes",
  "futures-channel",
- "futures-util",
+ "futures-core",
  "h2",
  "http",
  "http-body",
@@ -2144,6 +2170,7 @@ dependencies = [
  "httpdate",
  "itoa",
  "pin-project-lite",
+ "pin-utils",
  "smallvec",
  "tokio",
  "want",
@@ -2244,6 +2271,45 @@ dependencies = [
  "zerovec",
 ]

+[[package]]
+name = "icu_decimal"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fec61c43fdc4e368a9f450272833123a8ef0d7083a44597660ce94d791b8a2e2"
+dependencies = [
+ "displaydoc",
+ "fixed_decimal",
+ "icu_decimal_data",
+ "icu_locale",
+ "icu_locale_core",
+ "icu_provider",
+ "tinystr",
+ "writeable",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_decimal_data"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b70963bc35f9bdf1bc66a5c1f458f4991c1dc71760e00fa06016b2c76b2738d5"
+
+[[package]]
+name = "icu_locale"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ae5921528335e91da1b6c695dbf1ec37df5ac13faa3f91e5640be93aa2fbefd"
+dependencies = [
+ "displaydoc",
+ "icu_collections",
+ "icu_locale_core",
+ "icu_locale_data",
+ "icu_provider",
+ "potential_utf",
+ "tinystr",
+ "zerovec",
+]
+
 [[package]]
 name = "icu_locale_core"
 version = "2.0.0"
@@ -2257,6 +2323,12 @@ dependencies = [
  "zerovec",
 ]

+[[package]]
+name = "icu_locale_data"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4fdef0c124749d06a743c69e938350816554eb63ac979166590e2b4ee4252765"
+
 [[package]]
 name = "icu_normalizer"
 version = "2.0.0"
@@ -2368,9 +2440,9 @@ dependencies = [

 [[package]]
 name = "image"
-version = "0.25.6"
+version = "0.25.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "db35664ce6b9810857a38a906215e75a9c879f0696556a39f59c62829710251a"
+checksum = "529feb3e6769d234375c4cf1ee2ce713682b8e76538cb13f9fc23e1400a591e7"
 dependencies = [
  "bytemuck",
  "byteorder-lite",
@@ -2378,6 +2450,7 @@ dependencies = [
  "exr",
  "gif",
  "image-webp",
+ "moxcms",
  "num-traits",
  "png",
  "qoi",
@@ -2441,9 +2514,9 @@ checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"

 [[package]]
 name = "insta"
-version = "1.43.1"
+version = "1.43.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "154934ea70c58054b556dd430b99a98c2a7ff5309ac9891597e339b5c28f4371"
+checksum = "46fdb647ebde000f43b5b53f773c30cf9b0cb4300453208713fa38b2c70935a0"
 dependencies = [
  "console",
  "once_cell",
@@ -2631,12 +2704,6 @@ dependencies = [
  "libc",
 ]

-[[package]]
-name = "jpeg-decoder"
-version = "0.3.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "00810f1d8b74be64b13dbf3db89ac67740615d6c891f0e7b6179326533011a07"
-
 [[package]]
 name = "js-sys"
 version = "0.3.77"
@@ -2686,7 +2753,7 @@ checksum = "b3d2ef408b88e913bfc6594f5e693d57676f6463ded7d8bf994175364320c706"
 dependencies = [
  "enumflags2",
  "libc",
- "thiserror 2.0.12",
+ "thiserror 2.0.16",
 ]
@@ -2935,6 +3002,16 @@ dependencies = [
  "windows-sys 0.59.0",
 ]

+[[package]]
+name = "moxcms"
+version = "0.7.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ddd32fa8935aeadb8a8a6b6b351e40225570a37c43de67690383d87ef170cd08"
+dependencies = [
+ "num-traits",
+ "pxfm",
+]
+
 [[package]]
 name = "multimap"
 version = "0.10.1"
@@ -3443,11 +3520,11 @@ dependencies = [

 [[package]]
 name = "png"
-version = "0.17.16"
+version = "0.18.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526"
+checksum = "97baced388464909d42d89643fe4361939af9b7ce7a31ee32a168f832a70f2a0"
 dependencies = [
- "bitflags 1.3.2",
+ "bitflags 2.9.1",
  "crc32fast",
  "fdeflate",
  "flate2",
@@ -3496,6 +3573,7 @@ version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585"
 dependencies = [
+ "serde",
  "zerovec",
 ]
@@ -3616,6 +3694,15 @@ version = "0.11.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "007d8adb5ddab6f8e3f491ac63566a7d5002cc7ed73901f72057943fa71ae1ae"

+[[package]]
+name = "pxfm"
+version = "0.1.23"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f55f4fedc84ed39cb7a489322318976425e42a147e2be79d8f878e2884f94e84"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "qoi"
 version = "0.4.1"
@@ -3858,7 +3945,7 @@ checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b"
 dependencies = [
  "getrandom 0.2.16",
  "libredox",
- "thiserror 2.0.12",
+ "thiserror 2.0.16",
 ]
@@ -4823,6 +4910,15 @@ dependencies = [
  "yaml-rust",
 ]

+[[package]]
+name = "sys-locale"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8eab9a99a024a169fe8a903cf9d4a3b3601109bcc13bd9e3c6fff259138626c4"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "system-configuration"
 version = "0.6.1"
@@ -4943,11 +5039,11 @@ dependencies = [

 [[package]]
 name = "thiserror"
-version = "2.0.12"
+version = "2.0.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
+checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0"
 dependencies = [
- "thiserror-impl 2.0.12",
+ "thiserror-impl 2.0.16",
 ]
@@ -4963,9 +5059,9 @@ dependencies = [

 [[package]]
 name = "thiserror-impl"
-version = "2.0.12"
+version = "2.0.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
+checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -4983,13 +5079,16 @@ dependencies = [

 [[package]]
 name = "tiff"
-version = "0.9.1"
+version = "0.10.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ba1310fcea54c6a9a4fd1aad794ecc02c31682f6bfbecdf460bf19533eed1e3e"
+checksum = "af9605de7fee8d9551863fd692cce7637f548dbd9db9180fcc07ccc6d26c336f"
 dependencies = [
+ "fax",
  "flate2",
- "jpeg-decoder",
+ "half",
+ "quick-error",
  "weezl",
+ "zune-jpeg",
 ]
@@ -5361,9 +5460,9 @@ dependencies = [

 [[package]]
 name = "tree-sitter"
-version = "0.25.8"
+version = "0.25.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6d7b8994f367f16e6fa14b5aebbcb350de5d7cbea82dc5b00ae997dd71680dd2"
+checksum = "ccd2a058a86cfece0bf96f7cce1021efef9c8ed0e892ab74639173e5ed7a34fa"
 dependencies = [
  "cc",
  "regex",
@@ -5402,7 +5501,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6ef1b7a6d914a34127ed8e1fa927eb7088903787bcded4fa3eef8f85ee1568be"
 dependencies = [
  "serde_json",
- "thiserror 2.0.12",
+ "thiserror 2.0.16",
  "ts-rs-macros",
  "uuid",
 ]
@@ -5542,9 +5641,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"

 [[package]]
 name = "uuid"
-version = "1.17.0"
+version = "1.18.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d"
+checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2"
 dependencies = [
  "getrandom 0.3.3",
  "js-sys",
@@ -6271,12 +6370,11 @@ checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904"

 [[package]]
 name = "wiremock"
-version = "0.6.4"
+version = "0.6.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a2b8b99d4cdbf36b239a9532e31fe4fb8acc38d1897c1761e161550a7dc78e6a"
+checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031"
 dependencies = [
  "assert-json-diff",
- "async-trait",
  "base64 0.22.1",
  "deadpool",
  "futures",
```
```diff
@@ -9,7 +9,7 @@ use ratatui::text::Text;
 pub fn ansi_escape_line(s: &str) -> Line<'static> {
     let text = ansi_escape(s);
     match text.lines.as_slice() {
-        [] => Line::from(""),
+        [] => "".into(),
         [only] => only.clone(),
         [first, rest @ ..] => {
             tracing::warn!("ansi_escape_line: expected a single line, got {first:?} and {rest:?}");
```
```diff
@@ -17,9 +17,10 @@ workspace = true
 [dependencies]
 anyhow = "1"
 similar = "2.7.0"
-thiserror = "2.0.12"
-tree-sitter = "0.25.8"
+thiserror = "2.0.16"
+tree-sitter = "0.25.9"
 tree-sitter-bash = "0.25.0"
+once_cell = "1"

 [dev-dependencies]
 assert_cmd = "2"
```
```diff
@@ -9,6 +9,7 @@ use std::str::Utf8Error;

 use anyhow::Context;
 use anyhow::Result;
+use once_cell::sync::Lazy;
 pub use parser::Hunk;
 pub use parser::ParseError;
 use parser::ParseError::*;
@@ -18,6 +19,9 @@ use similar::TextDiff;
 use thiserror::Error;
 use tree_sitter::LanguageError;
 use tree_sitter::Parser;
+use tree_sitter::Query;
+use tree_sitter::QueryCursor;
+use tree_sitter::StreamingIterator;
 use tree_sitter_bash::LANGUAGE as BASH;

 pub use standalone_executable::main;
@@ -84,6 +88,7 @@ pub enum MaybeApplyPatch {
 pub struct ApplyPatchArgs {
     pub patch: String,
     pub hunks: Vec<Hunk>,
+    pub workdir: Option<String>,
 }

 pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch {
@@ -92,18 +97,18 @@ pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch {
             Ok(source) => MaybeApplyPatch::Body(source),
             Err(e) => MaybeApplyPatch::PatchParseError(e),
         },
-        [bash, flag, script]
-            if bash == "bash"
-                && flag == "-lc"
-                && APPLY_PATCH_COMMANDS
-                    .iter()
-                    .any(|cmd| script.trim_start().starts_with(cmd)) =>
-        {
-            match extract_heredoc_body_from_apply_patch_command(script) {
-                Ok(body) => match parse_patch(&body) {
-                    Ok(source) => MaybeApplyPatch::Body(source),
+        [bash, flag, script] if bash == "bash" && flag == "-lc" => {
+            match extract_apply_patch_from_bash(script) {
+                Ok((body, workdir)) => match parse_patch(&body) {
+                    Ok(mut source) => {
+                        source.workdir = workdir;
+                        MaybeApplyPatch::Body(source)
+                    }
                     Err(e) => MaybeApplyPatch::PatchParseError(e),
                 },
                 Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch) => {
                     MaybeApplyPatch::NotApplyPatch
                 }
                 Err(e) => MaybeApplyPatch::ShellParseError(e),
             }
         }
```
```diff
@@ -203,10 +208,25 @@ impl ApplyPatchAction {
 /// patch.
 pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApplyPatchVerified {
     match maybe_parse_apply_patch(argv) {
-        MaybeApplyPatch::Body(ApplyPatchArgs { patch, hunks }) => {
+        MaybeApplyPatch::Body(ApplyPatchArgs {
+            patch,
+            hunks,
+            workdir,
+        }) => {
+            let effective_cwd = workdir
+                .as_ref()
+                .map(|dir| {
+                    let path = Path::new(dir);
+                    if path.is_absolute() {
+                        path.to_path_buf()
+                    } else {
+                        cwd.join(path)
+                    }
+                })
+                .unwrap_or_else(|| cwd.to_path_buf());
             let mut changes = HashMap::new();
             for hunk in hunks {
-                let path = hunk.resolve_path(cwd);
+                let path = hunk.resolve_path(&effective_cwd);
                 match hunk {
                     Hunk::AddFile { contents, .. } => {
                         changes.insert(path, ApplyPatchFileChange::Add { content: contents });
@@ -251,7 +271,7 @@ pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApp
             MaybeApplyPatchVerified::Body(ApplyPatchAction {
                 changes,
                 patch,
-                cwd: cwd.to_path_buf(),
+                cwd: effective_cwd,
             })
         }
         MaybeApplyPatch::ShellParseError(e) => MaybeApplyPatchVerified::ShellParseError(e),
```
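The new `effective_cwd` computation above is the core behavioral change: a `cd` target captured from the script now re-anchors every hunk path. A minimal standalone sketch of the same rule (the function name is ours, not from the diff):

```rust
use std::path::{Path, PathBuf};

/// Join an optional `cd` target onto the session cwd, mirroring the
/// `effective_cwd` logic in the hunk above: absolute paths win, relative
/// paths are resolved against `cwd`, and `None` keeps `cwd` unchanged.
fn resolve_workdir(cwd: &Path, workdir: Option<&str>) -> PathBuf {
    match workdir {
        Some(dir) => {
            let path = Path::new(dir);
            if path.is_absolute() {
                path.to_path_buf()
            } else {
                cwd.join(path)
            }
        }
        None => cwd.to_path_buf(),
    }
}

fn main() {
    // `cd foo && apply_patch <<'EOF' ...` inside a session rooted at /repo:
    assert_eq!(
        resolve_workdir(Path::new("/repo"), Some("foo")),
        PathBuf::from("/repo/foo")
    );
}
```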
```diff
@@ -260,33 +280,96 @@ pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApp
     }
 }

-/// Attempts to extract a heredoc_body object from a string bash command like:
+/// Optimistically
+/// Extract the heredoc body (and optional `cd` workdir) from a `bash -lc` script
+/// that invokes the apply_patch tool using a heredoc.
 ///
-/// ```bash
-/// bash -lc 'apply_patch <<EOF\n***Begin Patch\n...EOF'
-/// ```
+/// Supported top‑level forms (must be the only top‑level statement):
+/// - `apply_patch <<'EOF'\n...\nEOF`
+/// - `cd <path> && apply_patch <<'EOF'\n...\nEOF`
 ///
-/// # Arguments
+/// Notes about matching:
+/// - Parsed with Tree‑sitter Bash and a strict query that uses anchors so the
+///   heredoc‑redirected statement is the only top‑level statement.
+/// - The connector between `cd` and `apply_patch` must be `&&` (not `|` or `||`).
+/// - Exactly one positional `word` argument is allowed for `cd` (no flags, no quoted
+///   strings, no second argument).
+/// - The apply command is validated in‑query via `#any-of?` to allow `apply_patch`
+///   or `applypatch`.
+/// - Preceding or trailing commands (e.g., `echo ...;` or `... && echo done`) do not match.
 ///
-/// * `src` - A string slice that holds the full command
-///
-/// # Returns
-///
-/// This function returns a `Result` which is:
-///
-/// * `Ok(String)` - The heredoc body if the extraction is successful.
-/// * `Err(anyhow::Error)` - An error if the extraction fails.
-///
-fn extract_heredoc_body_from_apply_patch_command(
+/// Returns `(heredoc_body, Some(path))` when the `cd` variant matches, or
+/// `(heredoc_body, None)` for the direct form. Errors are returned if the script
+/// cannot be parsed or does not match the allowed patterns.
+fn extract_apply_patch_from_bash(
     src: &str,
-) -> std::result::Result<String, ExtractHeredocError> {
-    if !APPLY_PATCH_COMMANDS
-        .iter()
-        .any(|cmd| src.trim_start().starts_with(cmd))
-    {
-        return Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch);
-    }
+) -> std::result::Result<(String, Option<String>), ExtractHeredocError> {
+    // This function uses a Tree-sitter query to recognize one of two
+    // whole-script forms, each expressed as a single top-level statement:
+    //
+    //   1. apply_patch <<'EOF'\n...\nEOF
+    //   2. cd <path> && apply_patch <<'EOF'\n...\nEOF
+    //
+    // Key ideas when reading the query:
+    // - dots (`.`) between named nodes enforces adjacency among named children and
+    //   anchor to the start/end of the expression.
+    // - we match a single redirected_statement directly under program with leading
+    //   and trailing anchors (`.`). This ensures it is the only top-level statement
+    //   (so prefixes like `echo ...;` or suffixes like `... && echo done` do not match).
+    //
+    // Overall, we want to be conservative and only match the intended forms, as other
+    // forms are likely to be model errors, or incorrectly interpreted by later code.
+    //
+    // If you're editing this query, it's helpful to start by creating a debugging binary
+    // which will let you see the AST of an arbitrary bash script passed in, and optionally
+    // also run an arbitrary query against the AST. This is useful for understanding
+    // how tree-sitter parses the script and whether the query syntax is correct. Be sure
+    // to test both positive and negative cases.
+    static APPLY_PATCH_QUERY: Lazy<Query> = Lazy::new(|| {
+        let language = BASH.into();
+        #[expect(clippy::expect_used)]
+        Query::new(
+            &language,
+            r#"
+            (
+              program
+                . (redirected_statement
+                    body: (command
+                            name: (command_name (word) @apply_name) .)
+                    (#any-of? @apply_name "apply_patch" "applypatch")
+                    redirect: (heredoc_redirect
+                        . (heredoc_start)
+                        . (heredoc_body) @heredoc
+                        . (heredoc_end)
+                        .))
+                .)
+
+            (
+              program
+                . (redirected_statement
+                    body: (list
+                            . (command
+                                name: (command_name (word) @cd_name) .
+                                argument: [
+                                  (word) @cd_path
+                                  (string (string_content) @cd_path)
+                                  (raw_string) @cd_raw_string
+                                ] .)
+                            "&&"
+                            . (command
+                                name: (command_name (word) @apply_name))
+                            .)
+                    (#eq? @cd_name "cd")
+                    (#any-of? @apply_name "apply_patch" "applypatch")
+                    redirect: (heredoc_redirect
+                        . (heredoc_start)
+                        . (heredoc_body) @heredoc
+                        . (heredoc_end)
+                        .))
+                .)
+            "#,
+        )
+        .expect("valid bash query")
+    });

     let lang = BASH.into();
     let mut parser = Parser::new();
@@ -298,26 +381,55 @@ fn extract_apply_patch_from_bash(
         .ok_or(ExtractHeredocError::FailedToParsePatchIntoAst)?;

     let bytes = src.as_bytes();
-    let mut c = tree.root_node().walk();
+    let root = tree.root_node();

-    loop {
-        let node = c.node();
-        if node.kind() == "heredoc_body" {
-            let text = node
-                .utf8_text(bytes)
-                .map_err(ExtractHeredocError::HeredocNotUtf8)?;
-            return Ok(text.trim_end_matches('\n').to_owned());
-        }
+    let mut cursor = QueryCursor::new();
+    let mut matches = cursor.matches(&APPLY_PATCH_QUERY, root, bytes);
+    while let Some(m) = matches.next() {
+        let mut heredoc_text: Option<String> = None;
+        let mut cd_path: Option<String> = None;

-        if c.goto_first_child() {
-            continue;
-        }
-        while !c.goto_next_sibling() {
-            if !c.goto_parent() {
-                return Err(ExtractHeredocError::FailedToFindHeredocBody);
+        for capture in m.captures.iter() {
+            let name = APPLY_PATCH_QUERY.capture_names()[capture.index as usize];
+            match name {
+                "heredoc" => {
+                    let text = capture
+                        .node
+                        .utf8_text(bytes)
+                        .map_err(ExtractHeredocError::HeredocNotUtf8)?
+                        .trim_end_matches('\n')
+                        .to_string();
+                    heredoc_text = Some(text);
+                }
+                "cd_path" => {
+                    let text = capture
+                        .node
+                        .utf8_text(bytes)
+                        .map_err(ExtractHeredocError::HeredocNotUtf8)?
+                        .to_string();
+                    cd_path = Some(text);
+                }
+                "cd_raw_string" => {
+                    let raw = capture
+                        .node
+                        .utf8_text(bytes)
+                        .map_err(ExtractHeredocError::HeredocNotUtf8)?;
+                    let trimmed = raw
+                        .strip_prefix('\'')
+                        .and_then(|s| s.strip_suffix('\''))
+                        .unwrap_or(raw);
+                    cd_path = Some(trimmed.to_string());
+                }
+                _ => {}
             }
         }
+
+        if let Some(heredoc) = heredoc_text {
+            return Ok((heredoc, cd_path));
+        }
     }
+
+    Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch)
 }

 #[derive(Debug, PartialEq)]
```
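The comment block above suggests building a small debugging binary while editing the query. Here is a sketch of what that could look like, assuming the same `tree-sitter` and `tree-sitter-bash` crates this crate already depends on (the binary itself is not part of the diff):

```rust
use tree_sitter::Parser;
use tree_sitter_bash::LANGUAGE as BASH;

fn main() {
    // Pass a bash script as the first argument and print the AST that
    // tree-sitter produces; useful for checking the query against both
    // positive and negative cases.
    let src = std::env::args().nth(1).expect("usage: bash-ast '<script>'");
    let mut parser = Parser::new();
    parser
        .set_language(&BASH.into())
        .expect("load bash grammar");
    let tree = parser.parse(&src, None).expect("parse script");
    println!("{}", tree.root_node().to_sexp());
}
```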
```diff
@@ -718,6 +830,49 @@ mod tests {
         strs.iter().map(|s| s.to_string()).collect()
     }

+    // Test helpers to reduce repetition when building bash -lc heredoc scripts
+    fn args_bash(script: &str) -> Vec<String> {
+        strs_to_strings(&["bash", "-lc", script])
+    }
+
+    fn heredoc_script(prefix: &str) -> String {
+        format!(
+            "{prefix}apply_patch <<'PATCH'\n*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch\nPATCH"
+        )
+    }
+
+    fn heredoc_script_ps(prefix: &str, suffix: &str) -> String {
+        format!(
+            "{prefix}apply_patch <<'PATCH'\n*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch\nPATCH{suffix}"
+        )
+    }
+
+    fn expected_single_add() -> Vec<Hunk> {
+        vec![Hunk::AddFile {
+            path: PathBuf::from("foo"),
+            contents: "hi\n".to_string(),
+        }]
+    }
+
+    fn assert_match(script: &str, expected_workdir: Option<&str>) {
+        let args = args_bash(script);
+        match maybe_parse_apply_patch(&args) {
+            MaybeApplyPatch::Body(ApplyPatchArgs { hunks, workdir, .. }) => {
+                assert_eq!(workdir.as_deref(), expected_workdir);
+                assert_eq!(hunks, expected_single_add());
+            }
+            result => panic!("expected MaybeApplyPatch::Body got {result:?}"),
+        }
+    }
+
+    fn assert_not_match(script: &str) {
+        let args = args_bash(script);
+        assert!(matches!(
+            maybe_parse_apply_patch(&args),
+            MaybeApplyPatch::NotApplyPatch
+        ));
+    }
+
     #[test]
     fn test_literal() {
         let args = strs_to_strings(&[
@@ -730,7 +885,7 @@ mod tests {
         ]);

         match maybe_parse_apply_patch(&args) {
-            MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => {
+            MaybeApplyPatch::Body(ApplyPatchArgs { hunks, .. }) => {
                 assert_eq!(
                     hunks,
                     vec![Hunk::AddFile {
@@ -755,7 +910,7 @@ mod tests {
         ]);

         match maybe_parse_apply_patch(&args) {
-            MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => {
+            MaybeApplyPatch::Body(ApplyPatchArgs { hunks, .. }) => {
                 assert_eq!(
                     hunks,
                     vec![Hunk::AddFile {
@@ -770,29 +925,7 @@ mod tests {

     #[test]
     fn test_heredoc() {
-        let args = strs_to_strings(&[
-            "bash",
-            "-lc",
-            r#"apply_patch <<'PATCH'
-*** Begin Patch
-*** Add File: foo
-+hi
-*** End Patch
-PATCH"#,
-        ]);
-
-        match maybe_parse_apply_patch(&args) {
-            MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => {
-                assert_eq!(
-                    hunks,
-                    vec![Hunk::AddFile {
-                        path: PathBuf::from("foo"),
-                        contents: "hi\n".to_string()
-                    }]
-                );
-            }
-            result => panic!("expected MaybeApplyPatch::Body got {result:?}"),
-        }
+        assert_match(&heredoc_script(""), None);
     }

     #[test]
@@ -809,7 +942,8 @@ PATCH"#,
         ]);

         match maybe_parse_apply_patch(&args) {
-            MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => {
+            MaybeApplyPatch::Body(ApplyPatchArgs { hunks, workdir, .. }) => {
+                assert_eq!(workdir, None);
                 assert_eq!(
                     hunks,
                     vec![Hunk::AddFile {
@@ -822,6 +956,69 @@ PATCH"#,
         }
     }

+    #[test]
+    fn test_heredoc_with_leading_cd() {
+        assert_match(&heredoc_script("cd foo && "), Some("foo"));
+    }
+
+    #[test]
+    fn test_cd_with_semicolon_is_ignored() {
+        assert_not_match(&heredoc_script("cd foo; "));
+    }
+
+    #[test]
+    fn test_cd_or_apply_patch_is_ignored() {
+        assert_not_match(&heredoc_script("cd bar || "));
+    }
+
+    #[test]
+    fn test_cd_pipe_apply_patch_is_ignored() {
+        assert_not_match(&heredoc_script("cd bar | "));
+    }
+
+    #[test]
+    fn test_cd_single_quoted_path_with_spaces() {
+        assert_match(&heredoc_script("cd 'foo bar' && "), Some("foo bar"));
+    }
+
+    #[test]
+    fn test_cd_double_quoted_path_with_spaces() {
+        assert_match(&heredoc_script("cd \"foo bar\" && "), Some("foo bar"));
+    }
+
+    #[test]
+    fn test_echo_and_apply_patch_is_ignored() {
+        assert_not_match(&heredoc_script("echo foo && "));
+    }
+
+    #[test]
+    fn test_apply_patch_with_arg_is_ignored() {
+        let script = "apply_patch foo <<'PATCH'\n*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch\nPATCH";
+        assert_not_match(script);
+    }
+
+    #[test]
+    fn test_double_cd_then_apply_patch_is_ignored() {
+        assert_not_match(&heredoc_script("cd foo && cd bar && "));
+    }
+
+    #[test]
+    fn test_cd_two_args_is_ignored() {
+        assert_not_match(&heredoc_script("cd foo bar && "));
+    }
+
+    #[test]
+    fn test_cd_then_apply_patch_then_extra_is_ignored() {
+        let script = heredoc_script_ps("cd bar && ", " && echo done");
+        assert_not_match(&script);
+    }
+
+    #[test]
+    fn test_echo_then_cd_and_apply_patch_is_ignored() {
+        // Ensure preceding commands before the `cd && apply_patch <<...` sequence do not match.
+        assert_not_match(&heredoc_script("echo foo; cd bar && "));
+    }
+
     #[test]
     fn test_add_file_hunk_creates_file_with_contents() {
         let dir = tempdir().unwrap();
```
```diff
@@ -175,7 +175,11 @@ fn parse_patch_text(patch: &str, mode: ParseMode) -> Result<ApplyPatchArgs, Pars
         remaining_lines = &remaining_lines[hunk_lines..]
     }
     let patch = lines.join("\n");
-    Ok(ApplyPatchArgs { hunks, patch })
+    Ok(ApplyPatchArgs {
+        hunks,
+        patch,
+        workdir: None,
+    })
 }

 /// Checks the start and end lines of the patch text for `apply_patch`,
```
```diff
@@ -586,7 +590,8 @@ fn test_parse_patch_lenient()
         parse_patch_text(&patch_text_in_heredoc, ParseMode::Lenient),
         Ok(ApplyPatchArgs {
             hunks: expected_patch.clone(),
-            patch: patch_text.to_string()
+            patch: patch_text.to_string(),
+            workdir: None,
         })
     );
@@ -599,7 +604,8 @@ fn test_parse_patch_lenient()
         parse_patch_text(&patch_text_in_single_quoted_heredoc, ParseMode::Lenient),
         Ok(ApplyPatchArgs {
             hunks: expected_patch.clone(),
-            patch: patch_text.to_string()
+            patch: patch_text.to_string(),
+            workdir: None,
         })
     );
@@ -612,7 +618,8 @@ fn test_parse_patch_lenient()
         parse_patch_text(&patch_text_in_double_quoted_heredoc, ParseMode::Lenient),
         Ok(ApplyPatchArgs {
             hunks: expected_patch.clone(),
-            patch: patch_text.to_string()
+            patch: patch_text.to_string(),
+            workdir: None,
         })
     );
```
```diff
@@ -21,8 +21,7 @@ const MISSPELLED_APPLY_PATCH_ARG0: &str = "applypatch";
 /// `codex-linux-sandbox` we *directly* execute
 /// [`codex_linux_sandbox::run_main`] (which never returns). Otherwise we:
 ///
-/// 1. Use [`dotenvy::from_path`] and [`dotenvy::dotenv`] to modify the
-///    environment before creating any threads.
+/// 1. Load `.env` values from `~/.codex/.env` before creating any threads.
 /// 2. Construct a Tokio multi-thread runtime.
 /// 3. Derive the path to the current executable (so children can re-invoke the
 ///    sandbox) when running on Linux.
```
```diff
@@ -106,7 +105,7 @@ where

 const ILLEGAL_ENV_VAR_PREFIX: &str = "CODEX_";

-/// Load env vars from ~/.codex/.env and `$(pwd)/.env`.
+/// Load env vars from ~/.codex/.env.
 ///
 /// Security: Do not allow `.env` files to create or modify any variables
 /// with names starting with `CODEX_`.
@@ -116,10 +115,6 @@ fn load_dotenv() {
     {
         set_filtered(iter);
     }
-
-    if let Ok(iter) = dotenvy::dotenv_iter() {
-        set_filtered(iter);
-    }
 }

 /// Helper to set vars from a dotenvy iterator while filtering out `CODEX_` keys.
```
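For context, here is a plausible shape of the `set_filtered` helper named above. This is an assumption: only its call sites and the `CODEX_` rule are visible in this diff.

```rust
const ILLEGAL_ENV_VAR_PREFIX: &str = "CODEX_";

/// Apply vars from a dotenvy iterator, refusing any key that starts with
/// `CODEX_` (hypothetical reconstruction of the helper referenced above).
fn set_filtered<I>(iter: I)
where
    I: IntoIterator<Item = Result<(String, String), dotenvy::Error>>,
{
    for (key, value) in iter.into_iter().flatten() {
        if !key.starts_with(ILLEGAL_ENV_VAR_PREFIX) {
            // The doc comment above notes this runs before any threads exist.
            // (On edition 2024 this call must be wrapped in `unsafe`.)
            std::env::set_var(&key, &value);
        }
    }
}
```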
```diff
@@ -11,7 +11,7 @@ anyhow = "1"
 clap = { version = "4", features = ["derive"] }
 codex-common = { path = "../common", features = ["cli"] }
 codex-core = { path = "../core" }
-codex-login = { path = "../login" }
+codex-protocol = { path = "../protocol" }
 reqwest = { version = "0.12", features = ["json", "stream"] }
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
```
```diff
@@ -31,7 +31,7 @@ pub async fn run_apply_command(
         ConfigOverrides::default(),
     )?;

-    init_chatgpt_token_from_auth(&config.codex_home).await?;
+    init_chatgpt_token_from_auth(&config.codex_home, &config.responses_originator_header).await?;

     let task_response = get_task(&config, apply_cli.task_id).await?;
     apply_diff_from_task(task_response, cwd).await
```
```diff
@@ -1,5 +1,5 @@
 use codex_core::config::Config;
-use codex_core::user_agent::get_codex_user_agent;
+use codex_core::default_client::create_client;

 use crate::chatgpt_token::get_chatgpt_token_data;
 use crate::chatgpt_token::init_chatgpt_token_from_auth;
@@ -13,10 +13,10 @@ pub(crate) async fn chatgpt_get_request<T: DeserializeOwned>(
     path: String,
 ) -> anyhow::Result<T> {
     let chatgpt_base_url = &config.chatgpt_base_url;
-    init_chatgpt_token_from_auth(&config.codex_home).await?;
+    init_chatgpt_token_from_auth(&config.codex_home, &config.responses_originator_header).await?;

     // Make direct HTTP request to ChatGPT backend API with the token
-    let client = reqwest::Client::new();
+    let client = create_client(&config.responses_originator_header);
     let url = format!("{chatgpt_base_url}{path}");

     let token =
@@ -31,7 +31,6 @@ pub(crate) async fn chatgpt_get_request<T: DeserializeOwned>(
         .bearer_auth(&token.access_token)
         .header("chatgpt-account-id", account_id?)
         .header("Content-Type", "application/json")
-        .header("User-Agent", get_codex_user_agent(None))
         .send()
         .await
         .context("Failed to send request")?;
```
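`create_client` replaces the ad-hoc `reqwest::Client::new()` plus manual `User-Agent` header. Its internals are not shown in this diff; the following is a hedged sketch of what such a constructor could look like, where the header name and user-agent format are assumptions rather than the actual codex-core code:

```rust
use reqwest::header::{HeaderMap, HeaderValue};

// Hypothetical reconstruction: bake the originator and a User-Agent into
// default headers so call sites (like chatgpt_get_request above) no longer
// set them by hand.
fn create_client(originator: &str) -> reqwest::Client {
    let mut headers = HeaderMap::new();
    if let Ok(value) = HeaderValue::from_str(originator) {
        headers.insert("originator", value); // header name is a guess
    }
    reqwest::Client::builder()
        .user_agent(format!("codex (originator: {originator})")) // assumed format
        .default_headers(headers)
        .build()
        .expect("reqwest client with static configuration")
}
```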
```diff
@@ -1,10 +1,10 @@
-use codex_login::AuthMode;
-use codex_login::CodexAuth;
+use codex_core::CodexAuth;
+use codex_protocol::mcp_protocol::AuthMode;
 use std::path::Path;
 use std::sync::LazyLock;
 use std::sync::RwLock;

-use codex_login::TokenData;
+use codex_core::token_data::TokenData;

 static CHATGPT_TOKEN: LazyLock<RwLock<Option<TokenData>>> = LazyLock::new(|| RwLock::new(None));
```
```diff
@@ -19,8 +19,11 @@ pub fn set_chatgpt_token_data(value: TokenData) {
 }

 /// Initialize the ChatGPT token from auth.json file
-pub async fn init_chatgpt_token_from_auth(codex_home: &Path) -> std::io::Result<()> {
-    let auth = CodexAuth::from_codex_home(codex_home, AuthMode::ChatGPT)?;
+pub async fn init_chatgpt_token_from_auth(
+    codex_home: &Path,
+    originator: &str,
+) -> std::io::Result<()> {
+    let auth = CodexAuth::from_codex_home(codex_home, AuthMode::ChatGPT, originator)?;
     if let Some(auth) = auth {
         let token_data = auth.get_token_data().await?;
         set_chatgpt_token_data(token_data);
```
```diff
@@ -1,19 +1,19 @@
 use codex_common::CliConfigOverrides;
+use codex_core::CodexAuth;
+use codex_core::auth::CLIENT_ID;
+use codex_core::auth::OPENAI_API_KEY_ENV_VAR;
+use codex_core::auth::login_with_api_key;
+use codex_core::auth::logout;
 use codex_core::config::Config;
 use codex_core::config::ConfigOverrides;
-use codex_login::AuthMode;
-use codex_login::CLIENT_ID;
-use codex_login::CodexAuth;
-use codex_login::OPENAI_API_KEY_ENV_VAR;
 use codex_login::ServerOptions;
-use codex_login::login_with_api_key;
-use codex_login::logout;
 use codex_login::run_login_server;
+use codex_protocol::mcp_protocol::AuthMode;
 use std::env;
 use std::path::PathBuf;

-pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {
-    let opts = ServerOptions::new(codex_home, CLIENT_ID.to_string());
+pub async fn login_with_chatgpt(codex_home: PathBuf, originator: String) -> std::io::Result<()> {
+    let opts = ServerOptions::new(codex_home, CLIENT_ID.to_string(), originator);
     let server = run_login_server(opts)?;

     eprintln!(
```
```diff
@@ -27,7 +27,12 @@ pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {
 pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! {
     let config = load_config_or_exit(cli_config_overrides);

-    match login_with_chatgpt(config.codex_home).await {
+    match login_with_chatgpt(
+        config.codex_home,
+        config.responses_originator_header.clone(),
+    )
+    .await
+    {
         Ok(_) => {
             eprintln!("Successfully logged in");
             std::process::exit(0);
```
```diff
@@ -60,7 +65,11 @@ pub async fn run_login_with_api_key(
 pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
     let config = load_config_or_exit(cli_config_overrides);

-    match CodexAuth::from_codex_home(&config.codex_home, config.preferred_auth_method) {
+    match CodexAuth::from_codex_home(
+        &config.codex_home,
+        config.preferred_auth_method,
+        &config.responses_originator_header,
+    ) {
         Ok(Some(auth)) => match auth.mode {
             AuthMode::ApiKey => match auth.get_token().await {
                 Ok(api_key) => {
```
```diff
@@ -2,6 +2,7 @@ use std::io::IsTerminal;

 use clap::Parser;
 use codex_common::CliConfigOverrides;
+use codex_core::AuthManager;
 use codex_core::ConversationManager;
 use codex_core::NewConversation;
 use codex_core::config::Config;
@@ -9,7 +10,6 @@ use codex_core::config::ConfigOverrides;
 use codex_core::protocol::Event;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Submission;
-use codex_login::AuthManager;
 use tokio::io::AsyncBufReadExt;
 use tokio::io::BufReader;
 use tracing::error;
@@ -40,6 +40,7 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> {
     let conversation_manager = ConversationManager::new(AuthManager::shared(
         config.codex_home.clone(),
         config.preferred_auth_method,
+        config.responses_originator_header.clone(),
     ));
     let NewConversation {
         conversation_id: _,
```
```diff
@@ -18,7 +18,6 @@ base64 = "0.22"
 bytes = "1.10.1"
 chrono = { version = "0.4", features = ["serde"] }
 codex-apply-patch = { path = "../apply-patch" }
-codex-login = { path = "../login" }
 codex-mcp-client = { path = "../mcp-client" }
 codex-protocol = { path = "../protocol" }
 dirs = "6"
@@ -41,8 +40,8 @@ shlex = "1.3.0"
 similar = "2.7.0"
 strum_macros = "0.27.2"
 tempfile = "3"
-thiserror = "2.0.12"
-time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
+thiserror = "2.0.16"
+time = { version = "0.3", features = ["formatting", "parsing", "local-offset", "macros"] }
 tokio = { version = "1", features = [
     "io-std",
     "macros",
@@ -54,7 +53,7 @@ tokio-util = "0.7.16"
 toml = "0.9.5"
 toml_edit = "0.23.4"
 tracing = { version = "0.1.41", features = ["log"] }
-tree-sitter = "0.25.8"
+tree-sitter = "0.25.9"
 tree-sitter-bash = "0.25.0"
 uuid = { version = "1", features = ["serde", "v4"] }
 whoami = "1.6.1"
```
@@ -14,6 +14,18 @@ Within this context, Codex refers to the open-source agentic coding interface (n
|
||||
|
||||
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
|
||||
|
||||
# AGENTS.md spec
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
- Instructions in AGENTS.md files:
  - The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
  - For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
  - Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
  - More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
  - Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.

## Responsiveness

### Preamble messages
@@ -228,7 +240,6 @@ You are producing plain text that will later be styled by the CLI. Follow these
**Bullets**

- Use `-` followed by a space for every bullet.
- Bold the keyword, then colon + concise description.
- Merge related points when possible; avoid a bullet for every trivial detail.
- Keep bullets to one line unless breaking for clarity is unavoidable.
- Group into short lists (4–6 bullets) ordered by importance.

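The AGENTS.md scoping rules in the hunk above amount to a nearest-ancestor lookup in which more deeply nested files win on conflicts. A minimal sketch of that resolution, using a hypothetical helper that is not part of this changeset:

// Collect every AGENTS.md whose scope covers `file`, ordered root-first so
// that later (more deeply nested) entries take precedence on conflicts.
use std::path::{Path, PathBuf};

fn applicable_agents_files(repo_root: &Path, file: &Path) -> Vec<PathBuf> {
    let mut found = Vec::new();
    let mut dir = file.parent();
    while let Some(d) = dir {
        let candidate = d.join("AGENTS.md");
        if candidate.is_file() {
            found.push(candidate);
        }
        if d == repo_root {
            break;
        }
        dir = d.parent();
    }
    // We walked upward from the file, so reverse to get root-first order.
    found.reverse();
    found
}
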
codex-rs/core/src/auth.rs (Normal file, 793 lines)
@@ -0,0 +1,793 @@
use chrono::DateTime;
use chrono::Utc;
use serde::Deserialize;
use serde::Serialize;
use std::env;
use std::fs::File;
use std::fs::OpenOptions;
use std::io::Read;
use std::io::Write;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;

use codex_protocol::mcp_protocol::AuthMode;

use crate::token_data::TokenData;
use crate::token_data::parse_id_token;

#[derive(Debug, Clone)]
pub struct CodexAuth {
    pub mode: AuthMode,

    pub(crate) api_key: Option<String>,
    pub(crate) auth_dot_json: Arc<Mutex<Option<AuthDotJson>>>,
    pub(crate) auth_file: PathBuf,
    pub(crate) client: reqwest::Client,
}

impl PartialEq for CodexAuth {
    fn eq(&self, other: &Self) -> bool {
        self.mode == other.mode
    }
}

impl CodexAuth {
    pub async fn refresh_token(&self) -> Result<String, std::io::Error> {
        let token_data = self
            .get_current_token_data()
            .ok_or(std::io::Error::other("Token data is not available."))?;
        let token = token_data.refresh_token;

        let refresh_response = try_refresh_token(token, &self.client)
            .await
            .map_err(std::io::Error::other)?;

        let updated = update_tokens(
            &self.auth_file,
            refresh_response.id_token,
            refresh_response.access_token,
            refresh_response.refresh_token,
        )
        .await?;

        if let Ok(mut auth_lock) = self.auth_dot_json.lock() {
            *auth_lock = Some(updated.clone());
        }

        let access = match updated.tokens {
            Some(t) => t.access_token,
            None => {
                return Err(std::io::Error::other(
                    "Token data is not available after refresh.",
                ));
            }
        };
        Ok(access)
    }

    /// Loads the available auth information from the auth.json or
    /// OPENAI_API_KEY environment variable.
    pub fn from_codex_home(
        codex_home: &Path,
        preferred_auth_method: AuthMode,
        originator: &str,
    ) -> std::io::Result<Option<CodexAuth>> {
        load_auth(codex_home, true, preferred_auth_method, originator)
    }

    pub async fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
        let auth_dot_json: Option<AuthDotJson> = self.get_current_auth_json();
        match auth_dot_json {
            Some(AuthDotJson {
                tokens: Some(mut tokens),
                last_refresh: Some(last_refresh),
                ..
            }) => {
                if last_refresh < Utc::now() - chrono::Duration::days(28) {
                    let refresh_response = tokio::time::timeout(
                        Duration::from_secs(60),
                        try_refresh_token(tokens.refresh_token.clone(), &self.client),
                    )
                    .await
                    .map_err(|_| {
                        std::io::Error::other("timed out while refreshing OpenAI API key")
                    })?
                    .map_err(std::io::Error::other)?;

                    let updated_auth_dot_json = update_tokens(
                        &self.auth_file,
                        refresh_response.id_token,
                        refresh_response.access_token,
                        refresh_response.refresh_token,
                    )
                    .await?;

                    tokens = updated_auth_dot_json
                        .tokens
                        .clone()
                        .ok_or(std::io::Error::other(
                            "Token data is not available after refresh.",
                        ))?;

                    #[expect(clippy::unwrap_used)]
                    let mut auth_lock = self.auth_dot_json.lock().unwrap();
                    *auth_lock = Some(updated_auth_dot_json);
                }

                Ok(tokens)
            }
            _ => Err(std::io::Error::other("Token data is not available.")),
        }
    }

    pub async fn get_token(&self) -> Result<String, std::io::Error> {
        match self.mode {
            AuthMode::ApiKey => Ok(self.api_key.clone().unwrap_or_default()),
            AuthMode::ChatGPT => {
                let id_token = self.get_token_data().await?.access_token;
                Ok(id_token)
            }
        }
    }

    pub fn get_account_id(&self) -> Option<String> {
        self.get_current_token_data()
            .and_then(|t| t.account_id.clone())
    }

    pub fn get_plan_type(&self) -> Option<String> {
        self.get_current_token_data()
            .and_then(|t| t.id_token.chatgpt_plan_type.as_ref().map(|p| p.as_string()))
    }

    fn get_current_auth_json(&self) -> Option<AuthDotJson> {
        #[expect(clippy::unwrap_used)]
        self.auth_dot_json.lock().unwrap().clone()
    }

    fn get_current_token_data(&self) -> Option<TokenData> {
        self.get_current_auth_json().and_then(|t| t.tokens.clone())
    }

    /// Consider this private to integration tests.
    pub fn create_dummy_chatgpt_auth_for_testing() -> Self {
        let auth_dot_json = AuthDotJson {
            openai_api_key: None,
            tokens: Some(TokenData {
                id_token: Default::default(),
                access_token: "Access Token".to_string(),
                refresh_token: "test".to_string(),
                account_id: Some("account_id".to_string()),
            }),
            last_refresh: Some(Utc::now()),
        };

        let auth_dot_json = Arc::new(Mutex::new(Some(auth_dot_json)));
        Self {
            api_key: None,
            mode: AuthMode::ChatGPT,
            auth_file: PathBuf::new(),
            auth_dot_json,
            client: crate::default_client::create_client("codex_cli_rs"),
        }
    }

    fn from_api_key_with_client(api_key: &str, client: reqwest::Client) -> Self {
        Self {
            api_key: Some(api_key.to_owned()),
            mode: AuthMode::ApiKey,
            auth_file: PathBuf::new(),
            auth_dot_json: Arc::new(Mutex::new(None)),
            client,
        }
    }

    pub fn from_api_key(api_key: &str) -> Self {
        Self::from_api_key_with_client(
            api_key,
            crate::default_client::create_client(crate::default_client::DEFAULT_ORIGINATOR),
        )
    }
}

pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";

fn read_openai_api_key_from_env() -> Option<String> {
    env::var(OPENAI_API_KEY_ENV_VAR)
        .ok()
        .filter(|s| !s.is_empty())
}

pub fn get_auth_file(codex_home: &Path) -> PathBuf {
    codex_home.join("auth.json")
}

/// Delete the auth.json file inside `codex_home` if it exists. Returns `Ok(true)`
/// if a file was removed, `Ok(false)` if no auth file was present.
pub fn logout(codex_home: &Path) -> std::io::Result<bool> {
    let auth_file = get_auth_file(codex_home);
    match std::fs::remove_file(&auth_file) {
        Ok(_) => Ok(true),
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
        Err(err) => Err(err),
    }
}

/// Writes an `auth.json` that contains only the API key. Intended for CLI use.
pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<()> {
    let auth_dot_json = AuthDotJson {
        openai_api_key: Some(api_key.to_string()),
        tokens: None,
        last_refresh: None,
    };
    write_auth_json(&get_auth_file(codex_home), &auth_dot_json)
}

fn load_auth(
    codex_home: &Path,
    include_env_var: bool,
    preferred_auth_method: AuthMode,
    originator: &str,
) -> std::io::Result<Option<CodexAuth>> {
    // First, check to see if there is a valid auth.json file. If not, we fall
    // back to AuthMode::ApiKey using the OPENAI_API_KEY environment variable
    // (if it is set).
    let auth_file = get_auth_file(codex_home);
    let client = crate::default_client::create_client(originator);
    let auth_dot_json = match try_read_auth_json(&auth_file) {
        Ok(auth) => auth,
        // If auth.json does not exist, try to read the OPENAI_API_KEY from the
        // environment variable.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound && include_env_var => {
            return match read_openai_api_key_from_env() {
                Some(api_key) => Ok(Some(CodexAuth::from_api_key_with_client(&api_key, client))),
                None => Ok(None),
            };
        }
        // Though if auth.json exists but is malformed, do not fall back to the
        // env var because the user may be expecting to use AuthMode::ChatGPT.
        Err(e) => {
            return Err(e);
        }
    };

    let AuthDotJson {
        openai_api_key: auth_json_api_key,
        tokens,
        last_refresh,
    } = auth_dot_json;

    // If the auth.json has an API key AND does not appear to be on a plan that
    // should prefer AuthMode::ChatGPT, use AuthMode::ApiKey.
    if let Some(api_key) = &auth_json_api_key {
        // Should any of these be AuthMode::ChatGPT with the api_key set?
        // Does AuthMode::ChatGPT indicate that there is an auth.json that is
        // "refreshable" even if we are using the API key for auth?
        match &tokens {
            Some(tokens) => {
                if tokens.should_use_api_key(preferred_auth_method, tokens.is_openai_email()) {
                    return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
                } else {
                    // Ignore the API key and fall through to ChatGPT auth.
                }
            }
            None => {
                // We have an API key but no tokens in the auth.json file.
                // Perhaps the user ran `codex login --api-key <KEY>` or updated
                // auth.json by hand. Either way, let's assume they are trying
                // to use their API key.
                return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
            }
        }
    }

    // For the AuthMode::ChatGPT variant, perhaps neither api_key nor
    // openai_api_key should exist?
    Ok(Some(CodexAuth {
        api_key: None,
        mode: AuthMode::ChatGPT,
        auth_file,
        auth_dot_json: Arc::new(Mutex::new(Some(AuthDotJson {
            openai_api_key: None,
            tokens,
            last_refresh,
        }))),
        client,
    }))
}

/// Attempt to read and refresh the `auth.json` file in the given `CODEX_HOME` directory.
/// Returns the full AuthDotJson structure after refreshing if necessary.
pub fn try_read_auth_json(auth_file: &Path) -> std::io::Result<AuthDotJson> {
    let mut file = File::open(auth_file)?;
    let mut contents = String::new();
    file.read_to_string(&mut contents)?;
    let auth_dot_json: AuthDotJson = serde_json::from_str(&contents)?;

    Ok(auth_dot_json)
}

pub fn write_auth_json(auth_file: &Path, auth_dot_json: &AuthDotJson) -> std::io::Result<()> {
    let json_data = serde_json::to_string_pretty(auth_dot_json)?;
    let mut options = OpenOptions::new();
    options.truncate(true).write(true).create(true);
    #[cfg(unix)]
    {
        options.mode(0o600);
    }
    let mut file = options.open(auth_file)?;
    file.write_all(json_data.as_bytes())?;
    file.flush()?;
    Ok(())
}

async fn update_tokens(
    auth_file: &Path,
    id_token: String,
    access_token: Option<String>,
    refresh_token: Option<String>,
) -> std::io::Result<AuthDotJson> {
    let mut auth_dot_json = try_read_auth_json(auth_file)?;

    let tokens = auth_dot_json.tokens.get_or_insert_with(TokenData::default);
    tokens.id_token = parse_id_token(&id_token).map_err(std::io::Error::other)?;
    if let Some(access_token) = access_token {
        tokens.access_token = access_token.to_string();
    }
    if let Some(refresh_token) = refresh_token {
        tokens.refresh_token = refresh_token.to_string();
    }
    auth_dot_json.last_refresh = Some(Utc::now());
    write_auth_json(auth_file, &auth_dot_json)?;
    Ok(auth_dot_json)
}

async fn try_refresh_token(
    refresh_token: String,
    client: &reqwest::Client,
) -> std::io::Result<RefreshResponse> {
    let refresh_request = RefreshRequest {
        client_id: CLIENT_ID,
        grant_type: "refresh_token",
        refresh_token,
        scope: "openid profile email",
    };

    // Use shared client factory to include standard headers
    let response = client
        .post("https://auth.openai.com/oauth/token")
        .header("Content-Type", "application/json")
        .json(&refresh_request)
        .send()
        .await
        .map_err(std::io::Error::other)?;

    if response.status().is_success() {
        let refresh_response = response
            .json::<RefreshResponse>()
            .await
            .map_err(std::io::Error::other)?;
        Ok(refresh_response)
    } else {
        Err(std::io::Error::other(format!(
            "Failed to refresh token: {}",
            response.status()
        )))
    }
}

#[derive(Serialize)]
struct RefreshRequest {
    client_id: &'static str,
    grant_type: &'static str,
    refresh_token: String,
    scope: &'static str,
}

#[derive(Deserialize, Clone)]
struct RefreshResponse {
    id_token: String,
    access_token: Option<String>,
    refresh_token: Option<String>,
}

/// Expected structure for $CODEX_HOME/auth.json.
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
pub struct AuthDotJson {
    #[serde(rename = "OPENAI_API_KEY")]
    pub openai_api_key: Option<String>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens: Option<TokenData>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_refresh: Option<DateTime<Utc>>,
}

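// Illustrative only, not part of the diff: with the serde attributes above, a
// ChatGPT-mode auth.json on disk looks roughly like
//
// {
//   "OPENAI_API_KEY": null,
//   "tokens": {
//     "id_token": "<jwt>",
//     "access_token": "<access token>",
//     "refresh_token": "<refresh token>",
//     "account_id": "<account id>"
//   },
//   "last_refresh": "2025-08-06T20:41:36.232376Z"
// }
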
// Shared constant for token refresh (client id used for oauth token refresh flow)
pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";

use std::sync::RwLock;

/// Internal cached auth state.
#[derive(Clone, Debug)]
struct CachedAuth {
    preferred_auth_mode: AuthMode,
    auth: Option<CodexAuth>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_data::IdTokenInfo;
    use crate::token_data::KnownPlan;
    use crate::token_data::PlanType;
    use base64::Engine;
    use pretty_assertions::assert_eq;
    use serde::Serialize;
    use serde_json::json;
    use tempfile::tempdir;

    const LAST_REFRESH: &str = "2025-08-06T20:41:36.232376Z";

    #[tokio::test]
    async fn roundtrip_auth_dot_json() {
        let codex_home = tempdir().unwrap();
        let _ = write_auth_file(
            AuthFileParams {
                openai_api_key: None,
                chatgpt_plan_type: "pro".to_string(),
            },
            codex_home.path(),
        )
        .expect("failed to write auth file");

        let file = get_auth_file(codex_home.path());
        let auth_dot_json = try_read_auth_json(&file).unwrap();
        write_auth_json(&file, &auth_dot_json).unwrap();

        let same_auth_dot_json = try_read_auth_json(&file).unwrap();
        assert_eq!(auth_dot_json, same_auth_dot_json);
    }

    #[tokio::test]
    async fn pro_account_with_no_api_key_uses_chatgpt_auth() {
        let codex_home = tempdir().unwrap();
        let fake_jwt = write_auth_file(
            AuthFileParams {
                openai_api_key: None,
                chatgpt_plan_type: "pro".to_string(),
            },
            codex_home.path(),
        )
        .expect("failed to write auth file");

        let CodexAuth {
            api_key,
            mode,
            auth_dot_json,
            auth_file: _,
            ..
        } = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
            .unwrap()
            .unwrap();
        assert_eq!(None, api_key);
        assert_eq!(AuthMode::ChatGPT, mode);

        let guard = auth_dot_json.lock().unwrap();
        let auth_dot_json = guard.as_ref().expect("AuthDotJson should exist");
        assert_eq!(
            &AuthDotJson {
                openai_api_key: None,
                tokens: Some(TokenData {
                    id_token: IdTokenInfo {
                        email: Some("user@example.com".to_string()),
                        chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
                        raw_jwt: fake_jwt,
                    },
                    access_token: "test-access-token".to_string(),
                    refresh_token: "test-refresh-token".to_string(),
                    account_id: None,
                }),
                last_refresh: Some(
                    DateTime::parse_from_rfc3339(LAST_REFRESH)
                        .unwrap()
                        .with_timezone(&Utc)
                ),
            },
            auth_dot_json
        )
    }

    /// Even if the OPENAI_API_KEY is set in auth.json, if the plan is not in
    /// [`TokenData::is_plan_that_should_use_api_key`], it should use
    /// [`AuthMode::ChatGPT`].
    #[tokio::test]
    async fn pro_account_with_api_key_still_uses_chatgpt_auth() {
        let codex_home = tempdir().unwrap();
        let fake_jwt = write_auth_file(
            AuthFileParams {
                openai_api_key: Some("sk-test-key".to_string()),
                chatgpt_plan_type: "pro".to_string(),
            },
            codex_home.path(),
        )
        .expect("failed to write auth file");

        let CodexAuth {
            api_key,
            mode,
            auth_dot_json,
            auth_file: _,
            ..
        } = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
            .unwrap()
            .unwrap();
        assert_eq!(None, api_key);
        assert_eq!(AuthMode::ChatGPT, mode);

        let guard = auth_dot_json.lock().unwrap();
        let auth_dot_json = guard.as_ref().expect("AuthDotJson should exist");
        assert_eq!(
            &AuthDotJson {
                openai_api_key: None,
                tokens: Some(TokenData {
                    id_token: IdTokenInfo {
                        email: Some("user@example.com".to_string()),
                        chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
                        raw_jwt: fake_jwt,
                    },
                    access_token: "test-access-token".to_string(),
                    refresh_token: "test-refresh-token".to_string(),
                    account_id: None,
                }),
                last_refresh: Some(
                    DateTime::parse_from_rfc3339(LAST_REFRESH)
                        .unwrap()
                        .with_timezone(&Utc)
                ),
            },
            auth_dot_json
        )
    }

    /// If the OPENAI_API_KEY is set in auth.json and it is an enterprise
    /// account, then it should use [`AuthMode::ApiKey`].
    #[tokio::test]
    async fn enterprise_account_with_api_key_uses_apikey_auth() {
        let codex_home = tempdir().unwrap();
        write_auth_file(
            AuthFileParams {
                openai_api_key: Some("sk-test-key".to_string()),
                chatgpt_plan_type: "enterprise".to_string(),
            },
            codex_home.path(),
        )
        .expect("failed to write auth file");

        let CodexAuth {
            api_key,
            mode,
            auth_dot_json,
            auth_file: _,
            ..
        } = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
            .unwrap()
            .unwrap();
        assert_eq!(Some("sk-test-key".to_string()), api_key);
        assert_eq!(AuthMode::ApiKey, mode);

        let guard = auth_dot_json.lock().expect("should unwrap");
        assert!(guard.is_none(), "auth_dot_json should be None");
    }

    #[tokio::test]
    async fn loads_api_key_from_auth_json() {
        let dir = tempdir().unwrap();
        let auth_file = dir.path().join("auth.json");
        std::fs::write(
            auth_file,
            r#"{"OPENAI_API_KEY":"sk-test-key","tokens":null,"last_refresh":null}"#,
        )
        .unwrap();

        let auth = super::load_auth(dir.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
            .unwrap()
            .unwrap();
        assert_eq!(auth.mode, AuthMode::ApiKey);
        assert_eq!(auth.api_key, Some("sk-test-key".to_string()));

        assert!(auth.get_token_data().await.is_err());
    }

    #[test]
    fn logout_removes_auth_file() -> Result<(), std::io::Error> {
        let dir = tempdir()?;
        let auth_dot_json = AuthDotJson {
            openai_api_key: Some("sk-test-key".to_string()),
            tokens: None,
            last_refresh: None,
        };
        write_auth_json(&get_auth_file(dir.path()), &auth_dot_json)?;
        assert!(dir.path().join("auth.json").exists());
        let removed = logout(dir.path())?;
        assert!(removed);
        assert!(!dir.path().join("auth.json").exists());
        Ok(())
    }

    struct AuthFileParams {
        openai_api_key: Option<String>,
        chatgpt_plan_type: String,
    }

    fn write_auth_file(params: AuthFileParams, codex_home: &Path) -> std::io::Result<String> {
        let auth_file = get_auth_file(codex_home);
        // Create a minimal valid JWT for the id_token field.
        #[derive(Serialize)]
        struct Header {
            alg: &'static str,
            typ: &'static str,
        }
        let header = Header {
            alg: "none",
            typ: "JWT",
        };
        let payload = serde_json::json!({
            "email": "user@example.com",
            "email_verified": true,
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "bc3618e3-489d-4d49-9362-1561dc53ba53",
                "chatgpt_plan_type": params.chatgpt_plan_type,
                "chatgpt_user_id": "user-12345",
                "user_id": "user-12345",
            }
        });
        let b64 = |b: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b);
        let header_b64 = b64(&serde_json::to_vec(&header)?);
        let payload_b64 = b64(&serde_json::to_vec(&payload)?);
        let signature_b64 = b64(b"sig");
        let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");

        let auth_json_data = json!({
            "OPENAI_API_KEY": params.openai_api_key,
            "tokens": {
                "id_token": fake_jwt,
                "access_token": "test-access-token",
                "refresh_token": "test-refresh-token"
            },
            "last_refresh": LAST_REFRESH,
        });
        let auth_json = serde_json::to_string_pretty(&auth_json_data)?;
        std::fs::write(auth_file, auth_json)?;
        Ok(fake_jwt)
    }
}

/// Central manager providing a single source of truth for auth.json derived
/// authentication data. It loads once (or on preference change) and then
/// hands out cloned `CodexAuth` values so the rest of the program has a
/// consistent snapshot.
///
/// External modifications to `auth.json` will NOT be observed until
/// `reload()` is called explicitly. This matches the design goal of avoiding
/// different parts of the program seeing inconsistent auth data mid‑run.
#[derive(Debug)]
pub struct AuthManager {
    codex_home: PathBuf,
    originator: String,
    inner: RwLock<CachedAuth>,
}

impl AuthManager {
    /// Create a new manager loading the initial auth using the provided
    /// preferred auth method. Errors loading auth are swallowed; `auth()` will
    /// simply return `None` in that case so callers can treat it as an
    /// unauthenticated state.
    pub fn new(codex_home: PathBuf, preferred_auth_mode: AuthMode, originator: String) -> Self {
        let auth = CodexAuth::from_codex_home(&codex_home, preferred_auth_mode, &originator)
            .ok()
            .flatten();
        Self {
            codex_home,
            originator,
            inner: RwLock::new(CachedAuth {
                preferred_auth_mode,
                auth,
            }),
        }
    }

    /// Create an AuthManager with a specific CodexAuth, for testing only.
    pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
        let preferred_auth_mode = auth.mode;
        let cached = CachedAuth {
            preferred_auth_mode,
            auth: Some(auth),
        };
        Arc::new(Self {
            codex_home: PathBuf::new(),
            originator: "codex_cli_rs".to_string(),
            inner: RwLock::new(cached),
        })
    }

    /// Current cached auth (clone). May be `None` if not logged in or load failed.
    pub fn auth(&self) -> Option<CodexAuth> {
        self.inner.read().ok().and_then(|c| c.auth.clone())
    }

    /// Preferred auth method used when (re)loading.
    pub fn preferred_auth_method(&self) -> AuthMode {
        self.inner
            .read()
            .map(|c| c.preferred_auth_mode)
            .unwrap_or(AuthMode::ApiKey)
    }

    /// Force a reload using the existing preferred auth method. Returns
    /// whether the auth value changed.
    pub fn reload(&self) -> bool {
        let preferred = self.preferred_auth_method();
        let new_auth = CodexAuth::from_codex_home(&self.codex_home, preferred, &self.originator)
            .ok()
            .flatten();
        if let Ok(mut guard) = self.inner.write() {
            let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
            guard.auth = new_auth;
            changed
        } else {
            false
        }
    }

    fn auths_equal(a: &Option<CodexAuth>, b: &Option<CodexAuth>) -> bool {
        match (a, b) {
            (None, None) => true,
            (Some(a), Some(b)) => a == b,
            _ => false,
        }
    }

    /// Convenience constructor returning an `Arc` wrapper.
    pub fn shared(
        codex_home: PathBuf,
        preferred_auth_mode: AuthMode,
        originator: String,
    ) -> Arc<Self> {
        Arc::new(Self::new(codex_home, preferred_auth_mode, originator))
    }

    /// Attempt to refresh the current auth token (if any). On success, reload
    /// the auth state from disk so other components observe the refreshed token.
    pub async fn refresh_token(&self) -> std::io::Result<Option<String>> {
        let auth = match self.auth() {
            Some(a) => a,
            None => return Ok(None),
        };
        match auth.refresh_token().await {
            Ok(token) => {
                // Reload to pick up persisted changes.
                self.reload();
                Ok(Some(token))
            }
            Err(e) => Err(e),
        }
    }

    /// Log out by deleting the on‑disk auth.json (if present). Returns Ok(true)
    /// if a file was removed, Ok(false) if no auth file existed. On success,
    /// reloads the in‑memory auth cache so callers immediately observe the
    /// unauthenticated state.
    pub fn logout(&self) -> std::io::Result<bool> {
        let removed = super::auth::logout(&self.codex_home)?;
        // Always reload to clear any cached auth (even if file absent).
        self.reload();
        Ok(removed)
    }
}
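Taken together, the new auth.rs gives the rest of the crate one entry point for credentials. A usage sketch, assuming the imports from the file above and a CODEX_HOME path chosen purely for illustration (this is not code from the changeset):

use std::path::PathBuf;

async fn demo() -> std::io::Result<()> {
    let manager = AuthManager::shared(
        PathBuf::from("/home/user/.codex"), // hypothetical CODEX_HOME
        AuthMode::ChatGPT,
        "codex_cli_rs".to_string(),
    );
    if let Some(auth) = manager.auth() {
        // For ChatGPT auth this transparently refreshes stale tokens.
        let _bearer = auth.get_token().await?;
    }
    // External edits to auth.json are only observed via an explicit reload.
    let _changed = manager.reload();
    Ok(())
}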
@@ -19,7 +19,6 @@ use crate::ModelProviderInfo;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
-use crate::config::Config;
use crate::error::CodexErr;
use crate::error::Result;
use crate::model_family::ModelFamily;
@@ -35,7 +34,6 @@ pub(crate) async fn stream_chat_completions(
    model_family: &ModelFamily,
    client: &reqwest::Client,
    provider: &ModelProviderInfo,
-    config: &Config,
) -> Result<ResponseStream> {
    // Build messages array
    let mut messages = Vec::<serde_json::Value>::new();
@@ -45,7 +43,107 @@ pub(crate) async fn stream_chat_completions(

    let input = prompt.get_formatted_input();

    // Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
    // - If the last emitted message is a user message, drop all reasoning.
    // - Otherwise, for each Reasoning item after the last user message, attach it
    //   to the immediate previous assistant message (stop turns) or the immediate
    //   next assistant anchor (tool-call turns: function/local shell call, or assistant message).
    let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
        std::collections::HashMap::new();

    // Determine the last role that would be emitted to Chat Completions.
    let mut last_emitted_role: Option<&str> = None;
    for item in &input {
        match item {
            ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
            ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
                last_emitted_role = Some("assistant")
            }
            ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
            ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
            ResponseItem::CustomToolCall { .. } => {}
            ResponseItem::CustomToolCallOutput { .. } => {}
            ResponseItem::WebSearchCall { .. } => {}
        }
    }

    // Find the last user message index in the input.
    let mut last_user_index: Option<usize> = None;
    for (idx, item) in input.iter().enumerate() {
        if let ResponseItem::Message { role, .. } = item
            && role == "user"
        {
            last_user_index = Some(idx);
        }
    }

    // Attach reasoning only if the conversation does not end with a user message.
    if !matches!(last_emitted_role, Some("user")) {
        for (idx, item) in input.iter().enumerate() {
            // Only consider reasoning that appears after the last user message.
            if let Some(u_idx) = last_user_index
                && idx <= u_idx
            {
                continue;
            }

            if let ResponseItem::Reasoning {
                content: Some(items),
                ..
            } = item
            {
                let mut text = String::new();
                for c in items {
                    match c {
                        ReasoningItemContent::ReasoningText { text: t }
                        | ReasoningItemContent::Text { text: t } => text.push_str(t),
                    }
                }
                if text.trim().is_empty() {
                    continue;
                }

                // Prefer immediate previous assistant message (stop turns)
                let mut attached = false;
                if idx > 0
                    && let ResponseItem::Message { role, .. } = &input[idx - 1]
                    && role == "assistant"
                {
                    reasoning_by_anchor_index
                        .entry(idx - 1)
                        .and_modify(|v| v.push_str(&text))
                        .or_insert(text.clone());
                    attached = true;
                }

                // Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
                if !attached && idx + 1 < input.len() {
                    match &input[idx + 1] {
                        ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
                            reasoning_by_anchor_index
                                .entry(idx + 1)
                                .and_modify(|v| v.push_str(&text))
                                .or_insert(text.clone());
                        }
                        ResponseItem::Message { role, .. } if role == "assistant" => {
                            reasoning_by_anchor_index
                                .entry(idx + 1)
                                .and_modify(|v| v.push_str(&text))
                                .or_insert(text.clone());
                        }
                        _ => {}
                    }
                }
            }
        }
    }

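    // Illustration (editorial note, not part of the diff): for a history like
    //   [user, Reasoning("...plan..."), FunctionCall, FunctionCallOutput, assistant]
    // the pre-scan above maps the reasoning text to index 2, the FunctionCall
    // anchor, since the item immediately before the Reasoning is the user
    // message rather than an assistant message. If the history ended with a
    // user message instead, all reasoning would be dropped.
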
    // Track last assistant text we emitted to avoid duplicate assistant messages
    // in the outbound Chat Completions payload (can happen if a final
    // aggregated assistant message was recorded alongside an earlier partial).
    let mut last_assistant_text: Option<String> = None;

    for (idx, item) in input.iter().enumerate() {
        match item {
            ResponseItem::Message { role, content, .. } => {
                let mut text = String::new();
@@ -58,7 +156,24 @@ pub(crate) async fn stream_chat_completions(
                        _ => {}
                    }
                }
-                messages.push(json!({"role": role, "content": text}));
+                // Skip exact-duplicate assistant messages.
+                if role == "assistant" {
+                    if let Some(prev) = &last_assistant_text
+                        && prev == &text
+                    {
+                        continue;
+                    }
+                    last_assistant_text = Some(text.clone());
+                }
+
+                let mut msg = json!({"role": role, "content": text});
+                if role == "assistant"
+                    && let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
+                    && let Some(obj) = msg.as_object_mut()
+                {
+                    obj.insert("reasoning".to_string(), json!(reasoning));
+                }
+                messages.push(msg);
            }
            ResponseItem::FunctionCall {
                name,
@@ -66,7 +181,7 @@ pub(crate) async fn stream_chat_completions(
                call_id,
                ..
            } => {
-                messages.push(json!({
+                let mut msg = json!({
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
@@ -77,7 +192,13 @@ pub(crate) async fn stream_chat_completions(
                        "arguments": arguments,
                    }
                    }]
-                }));
+                });
+                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
+                    && let Some(obj) = msg.as_object_mut()
+                {
+                    obj.insert("reasoning".to_string(), json!(reasoning));
+                }
+                messages.push(msg);
            }
            ResponseItem::LocalShellCall {
                id,
@@ -86,7 +207,7 @@ pub(crate) async fn stream_chat_completions(
                action,
            } => {
                // Confirm with API team.
-                messages.push(json!({
+                let mut msg = json!({
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
@@ -95,7 +216,13 @@ pub(crate) async fn stream_chat_completions(
                        "status": status,
                        "action": action,
                    }]
-                }));
+                });
+                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
+                    && let Some(obj) = msg.as_object_mut()
+                {
+                    obj.insert("reasoning".to_string(), json!(reasoning));
+                }
+                messages.push(msg);
            }
            ResponseItem::FunctionCallOutput { call_id, output } => {
                messages.push(json!({
@@ -131,26 +258,10 @@ pub(crate) async fn stream_chat_completions(
                    "content": output,
                }));
            }
-            ResponseItem::Reasoning {
-                id: _,
-                summary,
-                content,
-                encrypted_content: _,
-            } => {
-                if !config.skip_reasoning_in_chat_completions {
-                    // There is no clear way of sending reasoning items over chat completions.
-                    // We are sending it as an assistant message.
-                    tracing::info!("reasoning item: {:?}", item);
-                    let reasoning =
-                        format!("Reasoning Summary: {summary:?}, Reasoning Content: {content:?}");
-                    messages.push(json!({
-                        "role": "assistant",
-                        "content": reasoning,
-                    }));
-                }
-            }
-            ResponseItem::WebSearchCall { .. } | ResponseItem::Other => {
-                tracing::info!("omitting item from chat completions: {:?}", item);
+            ResponseItem::Reasoning { .. }
+            | ResponseItem::WebSearchCall { .. }
+            | ResponseItem::Other => {
+                // Omit these items from the conversation history.
+                continue;
            }
        }
@@ -349,7 +460,10 @@ async fn process_chat_sse<S>(
        // Some providers stream `reasoning` as a plain string while others
        // nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
        if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
-            let mut maybe_text = reasoning_val.as_str().map(|s| s.to_string());
+            let mut maybe_text = reasoning_val
+                .as_str()
+                .map(|s| s.to_string())
+                .filter(|s| !s.is_empty());

            if maybe_text.is_none() && reasoning_val.is_object() {
                if let Some(s) = reasoning_val
@@ -368,7 +482,7 @@ async fn process_chat_sse<S>(
            }

            if let Some(reasoning) = maybe_text {
-                // Accumulate so we can emit a terminal Reasoning item at end-of-turn.
+                // Accumulate so we can emit a terminal Reasoning item at the end.
                reasoning_text.push_str(&reasoning);
                let _ = tx_event
                    .send(Ok(ResponseEvent::ReasoningContentDelta(reasoning)))
@@ -376,6 +490,31 @@ async fn process_chat_sse<S>(
            }
        }

        // Some providers only include reasoning on the final message object.
        if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
        {
            // Accept either a plain string or an object with { text | content }
            if let Some(s) = message_reasoning.as_str() {
                if !s.is_empty() {
                    reasoning_text.push_str(s);
                    let _ = tx_event
                        .send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
                        .await;
                }
            } else if let Some(obj) = message_reasoning.as_object()
                && let Some(s) = obj
                    .get("text")
                    .and_then(|v| v.as_str())
                    .or_else(|| obj.get("content").and_then(|v| v.as_str()))
                && !s.is_empty()
            {
                reasoning_text.push_str(s);
                let _ = tx_event
                    .send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
                    .await;
            }
        }

        // Handle streaming function / tool calls.
        if let Some(tool_calls) = choice
            .get("delta")
@@ -531,27 +670,47 @@ where
        // do NOT emit yet. Forward any other item (e.g. FunctionCall) right
        // away so downstream consumers see it.

        let is_assistant_delta = matches!(&item, codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant");
        let is_assistant_message = matches!(
            &item,
            codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"
        );

        if is_assistant_delta {
            // Only use the final assistant message if we have not
            // seen any deltas; otherwise, deltas already built the
            // cumulative text and this would duplicate it.
            if this.cumulative.is_empty()
                && let codex_protocol::models::ResponseItem::Message { content, .. } =
                    &item
                && let Some(text) = content.iter().find_map(|c| match c {
                    codex_protocol::models::ContentItem::OutputText { text } => {
                        Some(text)
        if is_assistant_message {
            match this.mode {
                AggregateMode::AggregatedOnly => {
                    // Only use the final assistant message if we have not
                    // seen any deltas; otherwise, deltas already built the
                    // cumulative text and this would duplicate it.
                    if this.cumulative.is_empty()
                        && let codex_protocol::models::ResponseItem::Message {
                            content,
                            ..
                        } = &item
                        && let Some(text) = content.iter().find_map(|c| match c {
                            codex_protocol::models::ContentItem::OutputText {
                                text,
                            } => Some(text),
                            _ => None,
                        })
                    {
                        this.cumulative.push_str(text);
                    }
                    _ => None,
                })
            {
                this.cumulative.push_str(text);
                    // Swallow assistant message here; emit on Completed.
                    continue;
                }
                AggregateMode::Streaming => {
                    // In streaming mode, if we have not seen any deltas, forward
                    // the final assistant message directly. If deltas were seen,
                    // suppress the final message to avoid duplication.
                    if this.cumulative.is_empty() {
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                            item,
                        ))));
                    } else {
                        continue;
                    }
                }
            }

            // Swallow assistant message here; emit on Completed.
            continue;
        }

        // Not an assistant message – forward immediately.
@@ -583,6 +742,11 @@ where
        emitted_any = true;
    }

    // Always emit the final aggregated assistant message when any
    // content deltas have been observed. In AggregatedOnly mode this
    // is the sole assistant output; in Streaming mode this finalizes
    // the streamed deltas into a terminal OutputItemDone so callers
    // can persist/render the message once per turn.
    if !this.cumulative.is_empty() {
        let aggregated_message = codex_protocol::models::ResponseItem::Message {
            id: None,

@@ -1,12 +1,15 @@
use std::io::BufRead;
use std::path::Path;
use std::sync::OnceLock;
use std::time::Duration;

use crate::AuthManager;
use bytes::Bytes;
-use codex_login::AuthManager;
-use codex_login::AuthMode;
use codex_protocol::mcp_protocol::AuthMode;
use codex_protocol::mcp_protocol::ConversationId;
use eventsource_stream::Eventsource;
use futures::prelude::*;
use regex_lite::Regex;
use reqwest::StatusCode;
use serde::Deserialize;
use serde::Serialize;
@@ -17,7 +20,6 @@ use tokio_util::io::ReaderStream;
use tracing::debug;
use tracing::trace;
use tracing::warn;
-use uuid::Uuid;

use crate::chat_completions::AggregateStreamExt;
use crate::chat_completions::stream_chat_completions;
@@ -28,6 +30,7 @@ use crate::client_common::ResponsesApiRequest;
use crate::client_common::create_reasoning_param_for_request;
use crate::client_common::create_text_param_for_request;
use crate::config::Config;
use crate::default_client::create_client;
use crate::error::CodexErr;
use crate::error::Result;
use crate::error::UsageLimitReachedError;
@@ -38,7 +41,6 @@ use crate::model_provider_info::WireApi;
use crate::openai_model_info::get_model_info;
use crate::openai_tools::create_tools_json_for_responses_api;
use crate::protocol::TokenUsage;
-use crate::user_agent::get_codex_user_agent;
use crate::util::backoff;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
@@ -53,6 +55,8 @@ struct ErrorResponse {
#[derive(Debug, Deserialize)]
struct Error {
    r#type: Option<String>,
    #[allow(dead_code)]
    code: Option<String>,
    message: Option<String>,

    // Optional fields available on "usage_limit_reached" and "usage_not_included" errors
@@ -66,7 +70,7 @@ pub struct ModelClient {
    auth_manager: Option<Arc<AuthManager>>,
    client: reqwest::Client,
    provider: ModelProviderInfo,
-    session_id: Uuid,
+    conversation_id: ConversationId,
    effort: ReasoningEffortConfig,
    summary: ReasoningSummaryConfig,
}
@@ -78,14 +82,16 @@ impl ModelClient {
        provider: ModelProviderInfo,
        effort: ReasoningEffortConfig,
        summary: ReasoningSummaryConfig,
-        session_id: Uuid,
+        conversation_id: ConversationId,
    ) -> Self {
        let client = create_client(&config.responses_originator_header);

        Self {
            config,
            auth_manager,
-            client: reqwest::Client::new(),
+            client,
            provider,
-            session_id,
+            conversation_id,
            effort,
            summary,
        }
@@ -110,7 +116,6 @@ impl ModelClient {
            &self.config.model_family,
            &self.client,
            &self.provider,
-            &self.config,
        )
        .await?;

@@ -152,14 +157,6 @@ impl ModelClient {

        let auth_manager = self.auth_manager.clone();

-        let auth_mode = auth_manager
-            .as_ref()
-            .and_then(|m| m.auth())
-            .as_ref()
-            .map(|a| a.mode);
-
-        let store = prompt.store && auth_mode != Some(AuthMode::ChatGPT);

        let full_instructions = prompt.get_full_instructions(&self.config.model_family);
        let tools_json = create_tools_json_for_responses_api(&prompt.tools)?;
        let reasoning = create_reasoning_param_for_request(
@@ -168,9 +165,7 @@ impl ModelClient {
            self.summary,
        );

-        // Request encrypted COT if we are not storing responses,
-        // otherwise reasoning items will be referenced by ID
-        let include: Vec<String> = if !store && reasoning.is_some() {
+        let include: Vec<String> = if reasoning.is_some() {
            vec!["reasoning.encrypted_content".to_string()]
        } else {
            vec![]
@@ -199,10 +194,10 @@ impl ModelClient {
            tool_choice: "auto",
            parallel_tool_calls: false,
            reasoning,
-            store,
+            store: false,
            stream: true,
            include,
-            prompt_cache_key: Some(self.session_id.to_string()),
+            prompt_cache_key: Some(self.conversation_id.to_string()),
            text,
        };

@@ -228,7 +223,9 @@ impl ModelClient {

        req_builder = req_builder
            .header("OpenAI-Beta", "responses=experimental")
-            .header("session_id", self.session_id.to_string())
+            // Send session_id for compatibility.
+            .header("conversation_id", self.conversation_id.to_string())
+            .header("session_id", self.conversation_id.to_string())
            .header(reqwest::header::ACCEPT, "text/event-stream")
            .json(&payload);

@@ -239,10 +236,6 @@ impl ModelClient {
            req_builder = req_builder.header("chatgpt-account-id", account_id);
        }

-        let originator = &self.config.responses_originator_header;
-        req_builder = req_builder.header("originator", originator);
-        req_builder = req_builder.header("User-Agent", get_codex_user_agent(Some(originator)));

        let res = req_builder.send().await;
        if let Ok(resp) = &res {
            trace!(
@@ -407,9 +400,15 @@ impl From<ResponseCompletedUsage> for TokenUsage {
    fn from(val: ResponseCompletedUsage) -> Self {
        TokenUsage {
            input_tokens: val.input_tokens,
-            cached_input_tokens: val.input_tokens_details.map(|d| d.cached_tokens),
+            cached_input_tokens: val
+                .input_tokens_details
+                .map(|d| d.cached_tokens)
+                .unwrap_or(0),
            output_tokens: val.output_tokens,
-            reasoning_output_tokens: val.output_tokens_details.map(|d| d.reasoning_tokens),
+            reasoning_output_tokens: val
+                .output_tokens_details
+                .map(|d| d.reasoning_tokens)
+                .unwrap_or(0),
            total_tokens: val.total_tokens,
        }
    }
@@ -565,8 +564,9 @@ async fn process_sse<S>(
                if let Some(error) = error {
                    match serde_json::from_value::<Error>(error.clone()) {
                        Ok(error) => {
+                            let delay = try_parse_retry_after(&error);
                            let message = error.message.unwrap_or_default();
-                            response_error = Some(CodexErr::Stream(message, None));
+                            response_error = Some(CodexErr::Stream(message, delay));
                        }
                        Err(e) => {
                            debug!("failed to parse ErrorResponse: {e}");
@@ -652,6 +652,40 @@ async fn stream_from_fixture(
    Ok(ResponseStream { rx_event })
}

fn rate_limit_regex() -> &'static Regex {
    static RE: OnceLock<Regex> = OnceLock::new();

    #[expect(clippy::unwrap_used)]
    RE.get_or_init(|| Regex::new(r"Please try again in (\d+(?:\.\d+)?)(s|ms)").unwrap())
}

fn try_parse_retry_after(err: &Error) -> Option<Duration> {
    if err.code != Some("rate_limit_exceeded".to_string()) {
        return None;
    }

    // Parse the "Please try again in 1.898s" style message with a regex.
    let re = rate_limit_regex();
    if let Some(message) = &err.message
        && let Some(captures) = re.captures(message)
    {
        let seconds = captures.get(1);
        let unit = captures.get(2);

        if let (Some(value), Some(unit)) = (seconds, unit) {
            let value = value.as_str().parse::<f64>().ok()?;
            let unit = unit.as_str();

            if unit == "s" {
                return Some(Duration::from_secs_f64(value));
            } else if unit == "ms" {
                return Some(Duration::from_millis(value as u64));
            }
        }
    }
    None
}

#[cfg(test)]
mod tests {
    use super::*;
@@ -872,7 +906,7 @@ mod tests {
                msg,
                "Rate limit reached for gpt-5 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."
            );
-            assert_eq!(*delay, None);
+            assert_eq!(*delay, Some(Duration::from_secs_f64(11.054)));
        }
        other => panic!("unexpected second event: {other:?}"),
    }
@@ -976,4 +1010,31 @@ mod tests {
        );
    }
}

    #[test]
    fn test_try_parse_retry_after() {
        let err = Error {
            r#type: None,
            message: Some("Rate limit reached for gpt-5 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
            code: Some("rate_limit_exceeded".to_string()),
            plan_type: None,
            resets_in_seconds: None
        };

        let delay = try_parse_retry_after(&err);
        assert_eq!(delay, Some(Duration::from_millis(28)));
    }

    #[test]
    fn test_try_parse_retry_after_seconds() {
        let err = Error {
            r#type: None,
            message: Some("Rate limit reached for gpt-5 in organization <ORG> on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
            code: Some("rate_limit_exceeded".to_string()),
            plan_type: None,
            resets_in_seconds: None
        };
        let delay = try_parse_retry_after(&err);
        assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
    }
}

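The parsed delay rides along on CodexErr::Stream, so a retry loop can honor the server's hint. A sketch of that consumption, assuming backoff(attempt) from crate::util returns a Duration (the repo's actual retry loop is not shown in this diff):

async fn wait_before_retry(attempt: u64, server_delay: Option<Duration>) {
    // Prefer the server-suggested delay; otherwise fall back to exponential backoff.
    let delay = server_delay.unwrap_or_else(|| backoff(attempt));
    tokio::time::sleep(delay).await;
}
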
@@ -1,4 +1,3 @@
-use crate::config_types::Verbosity as VerbosityConfig;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::openai_tools::OpenAiTool;
@@ -6,7 +5,7 @@ use crate::protocol::TokenUsage;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
-use codex_protocol::models::ContentItem;
+use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use futures::Stream;
use serde::Serialize;
@@ -20,22 +19,15 @@ use tokio::sync::mpsc;
/// with this content.
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");

-/// wraps user instructions message in a tag for the model to parse more easily.
-const USER_INSTRUCTIONS_START: &str = "<user_instructions>\n\n";
-const USER_INSTRUCTIONS_END: &str = "\n\n</user_instructions>";

/// API request payload for a single model turn
#[derive(Default, Debug, Clone)]
pub struct Prompt {
    /// Conversation context input items.
    pub input: Vec<ResponseItem>,

-    /// Whether to store response on server side (disable_response_storage = !store).
-    pub store: bool,

    /// Tools available to the model, including additional tools sourced from
    /// external MCP servers.
-    pub tools: Vec<OpenAiTool>,
+    pub(crate) tools: Vec<OpenAiTool>,

    /// Optional override for the built-in BASE_INSTRUCTIONS.
    pub base_instructions_override: Option<String>,
@@ -68,17 +60,6 @@ impl Prompt {
    pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
        self.input.clone()
    }

-    /// Creates a formatted user instructions message from a string
-    pub(crate) fn format_user_instructions_message(ui: &str) -> ResponseItem {
-        ResponseItem::Message {
-            id: None,
-            role: "user".to_string(),
-            content: vec![ContentItem::InputText {
-                text: format!("{USER_INSTRUCTIONS_START}{ui}{USER_INSTRUCTIONS_END}"),
-            }],
-        }
-    }
}

#[derive(Debug)]
@@ -144,7 +125,6 @@ pub(crate) struct ResponsesApiRequest<'a> {
    pub(crate) tool_choice: &'static str,
    pub(crate) parallel_tool_calls: bool,
    pub(crate) reasoning: Option<Reasoning>,
-    /// true when using the Responses API.
    pub(crate) store: bool,
    pub(crate) stream: bool,
    pub(crate) include: Vec<String>,
@@ -174,7 +154,7 @@ pub(crate) fn create_text_param_for_request(
    })
}

-pub(crate) struct ResponseStream {
+pub struct ResponseStream {
    pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}

@@ -215,7 +195,7 @@ mod tests {
            tool_choice: "auto",
            parallel_tool_calls: false,
            reasoning: None,
-            store: true,
+            store: false,
            stream: true,
            include: vec![],
            prompt_cache_key: None,
@@ -245,7 +225,7 @@ mod tests {
            tool_choice: "auto",
            parallel_tool_calls: false,
            reasoning: None,
-            store: true,
+            store: false,
            stream: true,
            include: vec![],
            prompt_cache_key: None,

(File diff suppressed because it is too large.)
@@ -1,12 +1,12 @@
|
||||
use crate::config_profile::ConfigProfile;
|
||||
use crate::config_types::History;
|
||||
use crate::config_types::McpServerConfig;
|
||||
use crate::config_types::ReasoningSummaryFormat;
|
||||
use crate::config_types::SandboxWorkspaceWrite;
|
||||
use crate::config_types::ShellEnvironmentPolicy;
|
||||
use crate::config_types::ShellEnvironmentPolicyToml;
|
||||
use crate::config_types::Tui;
|
||||
use crate::config_types::UriBasedFileOpener;
|
||||
use crate::config_types::Verbosity;
|
||||
use crate::git_info::resolve_root_git_project_for_trust;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::model_family::find_family_for_model;
|
||||
@@ -15,10 +15,13 @@ use crate::model_provider_info::built_in_model_providers;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use codex_login::AuthMode;
|
||||
use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::config_types::SandboxMode;
|
||||
use codex_protocol::config_types::Verbosity;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use codex_protocol::mcp_protocol::Tools;
|
||||
use codex_protocol::mcp_protocol::UserSavedConfig;
|
||||
use dirs::home_dir;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
@@ -75,11 +78,6 @@ pub struct Config {
|
||||
/// Defaults to `false`.
|
||||
pub show_raw_agent_reasoning: bool,
|
||||
|
||||
/// Disable server-side response storage (sends the full conversation
|
||||
/// context with every request). Currently necessary for OpenAI customers
|
||||
/// who have opted into Zero Data Retention (ZDR).
|
||||
pub disable_response_storage: bool,
|
||||
|
||||
/// User-provided instructions from AGENTS.md.
|
||||
pub user_instructions: Option<String>,
|
||||
|
||||
@@ -185,10 +183,6 @@ pub struct Config {
|
||||
/// All characters are inserted as they are received, and no buffering
|
||||
/// or placeholder replacement will occur for fast keypress bursts.
|
||||
pub disable_paste_burst: bool,
|
||||
|
||||
/// When `true`, reasoning items in Chat Completions input will be skipped.
|
||||
/// Defaults to `false`.
|
||||
pub skip_reasoning_in_chat_completions: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -418,11 +412,6 @@ pub struct ConfigToml {
    /// Sandbox configuration to apply if `sandbox` is `WorkspaceWrite`.
    pub sandbox_workspace_write: Option<SandboxWorkspaceWrite>,

    /// Disable server-side response storage (sends the full conversation
    /// context with every request). Currently necessary for OpenAI customers
    /// who have opted into Zero Data Retention (ZDR).
    pub disable_response_storage: Option<bool>,

    /// Optional external command to spawn for end-user notifications.
    #[serde(default)]
    pub notify: Option<Vec<String>>,
@@ -475,6 +464,9 @@ pub struct ConfigToml {
    /// Override to force-enable reasoning summaries for the configured model.
    pub model_supports_reasoning_summaries: Option<bool>,

    /// Override to force reasoning summary format for the configured model.
    pub model_reasoning_summary_format: Option<ReasoningSummaryFormat>,

    /// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
    pub chatgpt_base_url: Option<String>,

@@ -501,10 +493,29 @@ pub struct ConfigToml {
    /// All characters are inserted as they are received, and no buffering
    /// or placeholder replacement will occur for fast keypress bursts.
    pub disable_paste_burst: Option<bool>,

    /// When set to `true`, reasoning items will be skipped from Chat Completions input.
    /// Defaults to `false`.
    pub skip_reasoning_in_chat_completions: Option<bool>,
}

impl From<ConfigToml> for UserSavedConfig {
    fn from(config_toml: ConfigToml) -> Self {
        let profiles = config_toml
            .profiles
            .into_iter()
            .map(|(k, v)| (k, v.into()))
            .collect();

        Self {
            approval_policy: config_toml.approval_policy,
            sandbox_mode: config_toml.sandbox_mode,
            sandbox_settings: config_toml.sandbox_workspace_write.map(From::from),
            model: config_toml.model,
            model_reasoning_effort: config_toml.model_reasoning_effort,
            model_reasoning_summary: config_toml.model_reasoning_summary,
            model_verbosity: config_toml.model_verbosity,
            tools: config_toml.tools.map(From::from),
            profile: config_toml.profile,
            profiles,
        }
    }
}

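Note: the new `skip_reasoning_in_chat_completions` key above is an ordinary optional TOML field that is resolved to `false` when absent. A minimal, self-contained sketch of that deserialize-then-default pattern, assuming the `serde` and `toml` crates as the tests in this diff do (the struct name here is illustrative, not the crate's type):

    use serde::Deserialize;

    #[derive(Deserialize, Debug, Default)]
    struct ConfigTomlSketch {
        // Absent in the TOML -> `None`; resolved to `false` at load time.
        skip_reasoning_in_chat_completions: Option<bool>,
    }

    fn main() {
        let empty: ConfigTomlSketch = toml::from_str("").expect("valid TOML");
        assert!(!empty.skip_reasoning_in_chat_completions.unwrap_or(false));

        let set: ConfigTomlSketch =
            toml::from_str("skip_reasoning_in_chat_completions = true").expect("valid TOML");
        assert_eq!(set.skip_reasoning_in_chat_completions, Some(true));
    }
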
#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
@@ -522,6 +533,15 @@ pub struct ToolsToml {
    pub view_image: Option<bool>,
}

impl From<ToolsToml> for Tools {
    fn from(tools_toml: ToolsToml) -> Self {
        Self {
            web_search: tools_toml.web_search,
            view_image: tools_toml.view_image,
        }
    }
}

impl ConfigToml {
    /// Derive the effective sandbox policy from the configuration.
    fn derive_sandbox_policy(&self, sandbox_mode_override: Option<SandboxMode>) -> SandboxPolicy {
@@ -610,7 +630,6 @@ pub struct ConfigOverrides {
    pub include_plan_tool: Option<bool>,
    pub include_apply_patch_tool: Option<bool>,
    pub include_view_image_tool: Option<bool>,
    pub disable_response_storage: Option<bool>,
    pub show_raw_agent_reasoning: Option<bool>,
    pub tools_web_search_request: Option<bool>,
}
@@ -638,7 +657,6 @@ impl Config {
            include_plan_tool,
            include_apply_patch_tool,
            include_view_image_tool,
            disable_response_storage,
            show_raw_agent_reasoning,
            tools_web_search_request: override_tools_web_search_request,
        } = overrides;
@@ -714,19 +732,24 @@ impl Config {
            .or(config_profile.model)
            .or(cfg.model)
            .unwrap_or_else(default_model);
        let model_family = find_family_for_model(&model).unwrap_or_else(|| {
            let supports_reasoning_summaries =
                cfg.model_supports_reasoning_summaries.unwrap_or(false);
            ModelFamily {
                slug: model.clone(),
                family: model.clone(),
                needs_special_apply_patch_instructions: false,
                supports_reasoning_summaries,
                uses_local_shell_tool: false,
                apply_patch_tool_type: None,
            }

        let mut model_family = find_family_for_model(&model).unwrap_or_else(|| ModelFamily {
            slug: model.clone(),
            family: model.clone(),
            needs_special_apply_patch_instructions: false,
            supports_reasoning_summaries: false,
            reasoning_summary_format: ReasoningSummaryFormat::None,
            uses_local_shell_tool: false,
            apply_patch_tool_type: None,
        });

        if let Some(supports_reasoning_summaries) = cfg.model_supports_reasoning_summaries {
            model_family.supports_reasoning_summaries = supports_reasoning_summaries;
        }
        if let Some(model_reasoning_summary_format) = cfg.model_reasoning_summary_format {
            model_family.reasoning_summary_format = model_reasoning_summary_format;
        }

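Note: the hunk above changes the fallback shape. Instead of deciding `supports_reasoning_summaries` inside the closure, the family is now built with conservative defaults and explicit config overrides are applied afterwards. A reduced sketch of that fallback-then-override flow (stand-in types, not the crate's `ModelFamily`):

    struct FamilySketch {
        supports_reasoning_summaries: bool,
    }

    fn resolve(known: Option<FamilySketch>, override_flag: Option<bool>) -> FamilySketch {
        // Start from the known family, or a conservative default...
        let mut family = known.unwrap_or(FamilySketch {
            supports_reasoning_summaries: false,
        });
        // ...then let an explicit config override win.
        if let Some(flag) = override_flag {
            family.supports_reasoning_summaries = flag;
        }
        family
    }

    fn main() {
        assert!(resolve(None, Some(true)).supports_reasoning_summaries);
        assert!(!resolve(None, None).supports_reasoning_summaries);
    }
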
        let openai_model_info = get_model_info(&model_family);
        let model_context_window = cfg
            .model_context_window
@@ -768,11 +791,6 @@ impl Config {
                .unwrap_or_else(AskForApproval::default),
            sandbox_policy,
            shell_environment_policy,
            disable_response_storage: config_profile
                .disable_response_storage
                .or(cfg.disable_response_storage)
                .or(disable_response_storage)
                .unwrap_or(false),
            notify: cfg.notify,
            user_instructions,
            base_instructions,
@@ -815,9 +833,6 @@ impl Config {
                .unwrap_or(false),
            include_view_image_tool,
            disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
            skip_reasoning_in_chat_completions: cfg
                .skip_reasoning_in_chat_completions
                .unwrap_or(false),
        };
        Ok(config)
    }
@@ -1040,7 +1055,6 @@ exclude_slash_tmp = true
        let toml = r#"
model = "o3"
approval_policy = "untrusted"
disable_response_storage = false

# Can be used to determine which profile to use if not specified by
# `ConfigOverrides`.
@@ -1070,7 +1084,14 @@ model_provider = "openai-chat-completions"
model = "o3"
model_provider = "openai"
approval_policy = "on-failure"
disable_response_storage = true

[profiles.gpt5]
model = "gpt-5"
model_provider = "openai"
approval_policy = "on-failure"
model_reasoning_effort = "high"
model_reasoning_summary = "detailed"
model_verbosity = "high"
"#;

        let cfg: ConfigToml = toml::from_str(toml).expect("TOML deserialization should succeed");
@@ -1160,7 +1181,6 @@ disable_response_storage = true
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::new_read_only_policy(),
            shell_environment_policy: ShellEnvironmentPolicy::default(),
            disable_response_storage: false,
            user_instructions: None,
            notify: None,
            cwd: fixture.cwd(),
@@ -1188,7 +1208,6 @@ disable_response_storage = true
            use_experimental_streamable_shell_tool: false,
            include_view_image_tool: true,
            disable_paste_burst: false,
            skip_reasoning_in_chat_completions: false,
        },
        o3_profile_config
    );
@@ -1219,7 +1238,6 @@ disable_response_storage = true
            approval_policy: AskForApproval::UnlessTrusted,
            sandbox_policy: SandboxPolicy::new_read_only_policy(),
            shell_environment_policy: ShellEnvironmentPolicy::default(),
            disable_response_storage: false,
            user_instructions: None,
            notify: None,
            cwd: fixture.cwd(),
@@ -1247,7 +1265,6 @@ disable_response_storage = true
            use_experimental_streamable_shell_tool: false,
            include_view_image_tool: true,
            disable_paste_burst: false,
            skip_reasoning_in_chat_completions: false,
        };

        assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
@@ -1293,7 +1310,6 @@ disable_response_storage = true
            approval_policy: AskForApproval::OnFailure,
            sandbox_policy: SandboxPolicy::new_read_only_policy(),
            shell_environment_policy: ShellEnvironmentPolicy::default(),
            disable_response_storage: true,
            user_instructions: None,
            notify: None,
            cwd: fixture.cwd(),
@@ -1321,7 +1337,6 @@ disable_response_storage = true
            use_experimental_streamable_shell_tool: false,
            include_view_image_tool: true,
            disable_paste_burst: false,
            skip_reasoning_in_chat_completions: false,
        };

        assert_eq!(expected_zdr_profile_config, zdr_profile_config);
@@ -1329,6 +1344,64 @@ disable_response_storage = true
        Ok(())
    }

    #[test]
    fn test_precedence_fixture_with_gpt5_profile() -> std::io::Result<()> {
        let fixture = create_test_fixture()?;

        let gpt5_profile_overrides = ConfigOverrides {
            config_profile: Some("gpt5".to_string()),
            cwd: Some(fixture.cwd()),
            ..Default::default()
        };
        let gpt5_profile_config = Config::load_from_base_config_with_overrides(
            fixture.cfg.clone(),
            gpt5_profile_overrides,
            fixture.codex_home(),
        )?;
        let expected_gpt5_profile_config = Config {
            model: "gpt-5".to_string(),
            model_family: find_family_for_model("gpt-5").expect("known model slug"),
            model_context_window: Some(272_000),
            model_max_output_tokens: Some(128_000),
            model_provider_id: "openai".to_string(),
            model_provider: fixture.openai_provider.clone(),
            approval_policy: AskForApproval::OnFailure,
            sandbox_policy: SandboxPolicy::new_read_only_policy(),
            shell_environment_policy: ShellEnvironmentPolicy::default(),
            user_instructions: None,
            notify: None,
            cwd: fixture.cwd(),
            mcp_servers: HashMap::new(),
            model_providers: fixture.model_provider_map.clone(),
            project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
            codex_home: fixture.codex_home(),
            history: History::default(),
            file_opener: UriBasedFileOpener::VsCode,
            tui: Tui::default(),
            codex_linux_sandbox_exe: None,
            hide_agent_reasoning: false,
            show_raw_agent_reasoning: false,
            model_reasoning_effort: ReasoningEffort::High,
            model_reasoning_summary: ReasoningSummary::Detailed,
            model_verbosity: Some(Verbosity::High),
            chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
            experimental_resume: None,
            base_instructions: None,
            include_plan_tool: false,
            include_apply_patch_tool: false,
            tools_web_search_request: false,
            responses_originator_header: "codex_cli_rs".to_string(),
            preferred_auth_method: AuthMode::ChatGPT,
            use_experimental_streamable_shell_tool: false,
            include_view_image_tool: true,
            disable_paste_burst: false,
        };

        assert_eq!(expected_gpt5_profile_config, gpt5_profile_config);

        Ok(())
    }

    #[test]
    fn test_set_project_trusted_writes_explicit_tables() -> anyhow::Result<()> {
        let codex_home = TempDir::new().unwrap();

@@ -1,10 +1,10 @@
use serde::Deserialize;
use std::path::PathBuf;

use crate::config_types::Verbosity;
use crate::protocol::AskForApproval;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::Verbosity;

/// Collection of common configuration options that a user can define as a unit
/// in `config.toml`.
@@ -15,10 +15,23 @@ pub struct ConfigProfile {
    /// [`ModelProviderInfo`] to use.
    pub model_provider: Option<String>,
    pub approval_policy: Option<AskForApproval>,
    pub disable_response_storage: Option<bool>,
    pub model_reasoning_effort: Option<ReasoningEffort>,
    pub model_reasoning_summary: Option<ReasoningSummary>,
    pub model_verbosity: Option<Verbosity>,
    pub chatgpt_base_url: Option<String>,
    pub experimental_instructions_file: Option<PathBuf>,
}

impl From<ConfigProfile> for codex_protocol::mcp_protocol::Profile {
    fn from(config_profile: ConfigProfile) -> Self {
        Self {
            model: config_profile.model,
            model_provider: config_profile.model_provider,
            approval_policy: config_profile.approval_policy,
            model_reasoning_effort: config_profile.model_reasoning_effort,
            model_reasoning_summary: config_profile.model_reasoning_summary,
            model_verbosity: config_profile.model_verbosity,
            chatgpt_base_url: config_profile.chatgpt_base_url,
        }
    }
}

@@ -8,8 +8,6 @@ use std::path::PathBuf;
use wildmatch::WildMatchPattern;

use serde::Deserialize;
use serde::Serialize;
use strum_macros::Display;

#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct McpServerConfig {
@@ -20,6 +18,10 @@ pub struct McpServerConfig {

    #[serde(default)]
    pub env: Option<HashMap<String, String>>,

    /// Startup timeout in milliseconds for initializing MCP server & initially listing tools.
    #[serde(default)]
    pub startup_timeout_ms: Option<u64>,
}

#[derive(Deserialize, Debug, Copy, Clone, PartialEq)]
@@ -90,6 +92,17 @@ pub struct SandboxWorkspaceWrite {
    pub exclude_slash_tmp: bool,
}

impl From<SandboxWorkspaceWrite> for codex_protocol::mcp_protocol::SandboxSettings {
    fn from(sandbox_workspace_write: SandboxWorkspaceWrite) -> Self {
        Self {
            writable_roots: sandbox_workspace_write.writable_roots,
            network_access: Some(sandbox_workspace_write.network_access),
            exclude_tmpdir_env_var: Some(sandbox_workspace_write.exclude_tmpdir_env_var),
            exclude_slash_tmp: Some(sandbox_workspace_write.exclude_slash_tmp),
        }
    }
}

#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum ShellEnvironmentPolicyInherit {
@@ -186,42 +199,10 @@ impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
    }
}

/// See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning
#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq, Display)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum ReasoningEffort {
    Low,
#[derive(Deserialize, Debug, Clone, PartialEq, Eq, Default, Hash)]
#[serde(rename_all = "kebab-case")]
pub enum ReasoningSummaryFormat {
    #[default]
    Medium,
    High,
    /// Option to disable reasoning.
    None,
}

/// A summary of the reasoning performed by the model. This can be useful for
/// debugging and understanding the model's reasoning process.
/// See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#reasoning-summaries
#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq, Display)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum ReasoningSummary {
    #[default]
    Auto,
    Concise,
    Detailed,
    /// Option to disable reasoning summaries.
    None,
}

/// Controls output length/detail on GPT-5 models via the Responses API.
/// Serialized with lowercase values to match the OpenAI API.
#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq, Display)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum Verbosity {
    Low,
    #[default]
    Medium,
    High,
    Experimental,
}

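Note: the new `startup_timeout_ms` field above is an ordinary optional key under an `mcp_servers` entry. A hedged sketch of reading it with a stand-in struct (field names mirror the diff; the 10-second fallback matches the `DEFAULT_STARTUP_TIMEOUT` introduced later in this comparison):

    use serde::Deserialize;
    use std::time::Duration;

    #[derive(Deserialize)]
    struct ServerSketch {
        command: String,
        #[serde(default)]
        startup_timeout_ms: Option<u64>,
    }

    fn main() {
        let toml_src = r#"
            command = "my-mcp-server"
            startup_timeout_ms = 20000
        "#;
        let cfg: ServerSketch = toml::from_str(toml_src).expect("valid TOML");
        // Missing key -> the default; present key -> the configured value.
        let timeout = cfg
            .startup_timeout_ms
            .map(Duration::from_millis)
            .unwrap_or(Duration::from_secs(10));
        assert_eq!(timeout, Duration::from_secs(20));
        let _ = cfg.command;
    }
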
@@ -1,11 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;

use codex_login::AuthManager;
use codex_login::CodexAuth;
use tokio::sync::RwLock;
use uuid::Uuid;

use crate::AuthManager;
use crate::CodexAuth;
use crate::codex::Codex;
use crate::codex::CodexSpawnOk;
use crate::codex::INITIAL_SUBMIT_ID;
@@ -16,12 +10,69 @@ use crate::error::Result as CodexResult;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::SessionConfiguredEvent;
use crate::rollout::RolloutItem;
use crate::rollout::RolloutRecorder;
use crate::rollout::recorder::RolloutItemSliceExt;
use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::models::ResponseItem;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;

#[derive(Debug, Clone)]
pub struct ResumedHistory {
    pub conversation_id: ConversationId,
    pub history: Vec<RolloutItem>,
    pub rollout_path: PathBuf,
}

#[derive(Debug, Clone)]
pub enum InitialHistory {
    New,
    Resumed(ResumedHistory),
    Forked(Vec<ResponseItem>),
}

impl PartialEq for InitialHistory {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (InitialHistory::New, InitialHistory::New) => true,
            (InitialHistory::Forked(a), InitialHistory::Forked(b)) => a == b,
            (InitialHistory::Resumed(_), InitialHistory::Resumed(_)) => true,
            _ => false,
        }
    }
}

impl InitialHistory {
    /// Return all response items contained in this initial history.
    pub fn get_response_items(&self) -> Vec<ResponseItem> {
        match self {
            InitialHistory::New => Vec::new(),
            InitialHistory::Forked(_) => Vec::new(),
            InitialHistory::Resumed(items) => {
                <[_] as RolloutItemSliceExt>::get_response_items(items.history.as_slice())
            }
        }
    }

    /// Return all events contained in this initial history.
    pub fn get_events(&self) -> Vec<crate::protocol::EventMsg> {
        match self {
            InitialHistory::New => Vec::new(),
            InitialHistory::Forked(_) => Vec::new(),
            InitialHistory::Resumed(items) => {
                <[_] as RolloutItemSliceExt>::get_events(items.history.as_slice())
            }
        }
    }
}

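Note: `InitialHistory` is effectively a three-way switch over how a conversation starts. A self-contained sketch of consuming it (names mirror the diff; the payload types are simplified stand-ins):

    #[derive(Debug)]
    enum InitialHistorySketch {
        New,
        Resumed(Vec<String>), // stands in for `ResumedHistory`
        Forked(Vec<String>),  // stands in for `Vec<ResponseItem>`
    }

    fn item_count(history: &InitialHistorySketch) -> usize {
        match history {
            // A brand-new conversation starts empty.
            InitialHistorySketch::New => 0,
            // Resumed sessions replay the recorded rollout items.
            InitialHistorySketch::Resumed(items) => items.len(),
            // Forks carry a truncated prefix forward.
            InitialHistorySketch::Forked(items) => items.len(),
        }
    }

    fn main() {
        assert_eq!(item_count(&InitialHistorySketch::New), 0);
        assert_eq!(item_count(&InitialHistorySketch::Forked(vec!["hi".into()])), 1);
    }
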
/// Represents a newly created Codex conversation, including the first event
/// (which is [`EventMsg::SessionConfigured`]).
pub struct NewConversation {
    pub conversation_id: Uuid,
    pub conversation_id: ConversationId,
    pub conversation: Arc<CodexConversation>,
    pub session_configured: SessionConfiguredEvent,
}
@@ -29,7 +80,7 @@ pub struct NewConversation {
/// [`ConversationManager`] is responsible for creating conversations and
/// maintaining them in memory.
pub struct ConversationManager {
    conversations: Arc<RwLock<HashMap<Uuid, Arc<CodexConversation>>>>,
    conversations: Arc<RwLock<HashMap<ConversationId, Arc<CodexConversation>>>>,
    auth_manager: Arc<AuthManager>,
}

@@ -44,7 +95,7 @@ impl ConversationManager {
    /// Construct with a dummy AuthManager containing the provided CodexAuth.
    /// Used for integration tests: should not be used by ordinary business logic.
    pub fn with_auth(auth: CodexAuth) -> Self {
        Self::new(codex_login::AuthManager::from_auth_for_testing(auth))
        Self::new(crate::AuthManager::from_auth_for_testing(auth))
    }

    pub async fn new_conversation(&self, config: Config) -> CodexResult<NewConversation> {
@@ -57,20 +108,27 @@ impl ConversationManager {
        config: Config,
        auth_manager: Arc<AuthManager>,
    ) -> CodexResult<NewConversation> {
        let CodexSpawnOk {
            codex,
            session_id: conversation_id,
        } = {
            let initial_history = None;
            Codex::spawn(config, auth_manager, initial_history).await?
        };
        self.finalize_spawn(codex, conversation_id).await
        // TO BE REFACTORED: use the config experimental_resume field until we have a mainstream way.
        if let Some(resume_path) = config.experimental_resume.as_ref() {
            let initial_history = RolloutRecorder::get_rollout_history(resume_path).await?;
            let CodexSpawnOk {
                codex,
                conversation_id,
            } = Codex::spawn(config, auth_manager, initial_history).await?;
            self.finalize_spawn(codex, conversation_id).await
        } else {
            let CodexSpawnOk {
                codex,
                conversation_id,
            } = Codex::spawn(config, auth_manager, InitialHistory::New).await?;
            self.finalize_spawn(codex, conversation_id).await
        }
    }

    async fn finalize_spawn(
        &self,
        codex: Codex,
        conversation_id: Uuid,
        conversation_id: ConversationId,
    ) -> CodexResult<NewConversation> {
        // The first event must be `SessionInitialized`. Validate and forward it
        // to the caller so that they can display it in the conversation
@@ -101,7 +159,7 @@ impl ConversationManager {

    pub async fn get_conversation(
        &self,
        conversation_id: Uuid,
        conversation_id: ConversationId,
    ) -> CodexResult<Arc<CodexConversation>> {
        let conversations = self.conversations.read().await;
        conversations
@@ -110,7 +168,21 @@ impl ConversationManager {
            .ok_or_else(|| CodexErr::ConversationNotFound(conversation_id))
    }

    pub async fn remove_conversation(&self, conversation_id: Uuid) {
    pub async fn resume_conversation_from_rollout(
        &self,
        config: Config,
        rollout_path: PathBuf,
        auth_manager: Arc<AuthManager>,
    ) -> CodexResult<NewConversation> {
        let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
        let CodexSpawnOk {
            codex,
            conversation_id,
        } = Codex::spawn(config, auth_manager, initial_history).await?;
        self.finalize_spawn(codex, conversation_id).await
    }

    pub async fn remove_conversation(&self, conversation_id: ConversationId) {
        self.conversations.write().await.remove(&conversation_id);
    }

@@ -120,30 +192,52 @@ impl ConversationManager {
    /// caller's `config`). The new conversation will have a fresh id.
    pub async fn fork_conversation(
        &self,
        conversation_history: Vec<ResponseItem>,
        base_rollout_path: PathBuf,
        _base_conversation_id: ConversationId,
        num_messages_to_drop: usize,
        config: Config,
    ) -> CodexResult<NewConversation> {
        // Compute the prefix up to the cut point.
        let truncated_history =
            truncate_after_dropping_last_messages(conversation_history, num_messages_to_drop);
        // Read prior responses from the rollout file (tolerate both tagged and legacy formats).
        let text = tokio::fs::read_to_string(&base_rollout_path)
            .await
            .map_err(|e| CodexErr::Io(std::io::Error::other(format!("read rollout: {e}"))))?;
        let mut responses: Vec<ResponseItem> = Vec::new();
        for line in text.lines() {
            if line.trim().is_empty() {
                continue;
            }
            let v: serde_json::Value = match serde_json::from_str(line) {
                Ok(v) => v,
                Err(_) => continue,
            };
            // Only consider response items (legacy lines have no record_type)
            match v.get("record_type").and_then(|s| s.as_str()) {
                Some("response") | None => {
                    if let Ok(item) = serde_json::from_value::<ResponseItem>(v) {
                        responses.push(item);
                    }
                }
                _ => {}
            }
        }

        let kept = truncate_after_dropping_last_messages(responses, num_messages_to_drop);

        // Spawn a new conversation with the computed initial history.
        let auth_manager = self.auth_manager.clone();
        let CodexSpawnOk {
            codex,
            session_id: conversation_id,
        } = Codex::spawn(config, auth_manager, Some(truncated_history)).await?;

            conversation_id,
        } = Codex::spawn(config, auth_manager, kept).await?;
        self.finalize_spawn(codex, conversation_id).await
    }
}

/// Return a prefix of `items` obtained by dropping the last `n` user messages
/// and all items that follow them.
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> Vec<ResponseItem> {
    if n == 0 || items.is_empty() {
        return items;
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> InitialHistory {
    if n == 0 {
        return InitialHistory::Forked(items);
    }

    // Walk backwards counting only `user` Message items, find cut index.
@@ -161,11 +255,11 @@ fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) ->
            }
        }
    }
    if count < n {
        // If fewer than n messages exist, drop everything.
        Vec::new()
    if cut_index == 0 {
        // No prefix remains after dropping; start a new conversation.
        InitialHistory::New
    } else {
        items.into_iter().take(cut_index).collect()
        InitialHistory::Forked(items.into_iter().take(cut_index).collect())
    }
}

@@ -223,10 +317,10 @@ mod tests {
        let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
        assert_eq!(
            truncated,
            vec![items[0].clone(), items[1].clone(), items[2].clone()]
            InitialHistory::Forked(vec![items[0].clone(), items[1].clone(), items[2].clone(),])
        );

        let truncated2 = truncate_after_dropping_last_messages(items, 2);
        assert!(truncated2.is_empty());
        assert!(matches!(truncated2, InitialHistory::New));
    }
}

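Note: the truncation rule above can be stated compactly: walk backwards, count user messages, and cut just before the n-th one from the end; an empty remainder becomes `InitialHistory::New`. A standalone sketch over plain role strings (illustrative, not the crate's `ResponseItem`):

    fn cut_index(roles: &[&str], n: usize) -> usize {
        let mut count = 0;
        for (idx, role) in roles.iter().enumerate().rev() {
            if *role == "user" {
                count += 1;
                if count == n {
                    return idx; // keep everything before this user message
                }
            }
        }
        0
    }

    fn main() {
        let roles = ["user", "assistant", "user", "assistant"];
        assert_eq!(cut_index(&roles, 1), 2); // keep the first two items
        assert_eq!(cut_index(&roles, 2), 0); // nothing left -> new conversation
    }
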
codex-rs/core/src/default_client.rs (new file, 106 lines)
@@ -0,0 +1,106 @@
pub const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";

pub fn get_codex_user_agent(originator: Option<&str>) -> String {
    let build_version = env!("CARGO_PKG_VERSION");
    let os_info = os_info::get();
    format!(
        "{}/{build_version} ({} {}; {}) {}",
        originator.unwrap_or(DEFAULT_ORIGINATOR),
        os_info.os_type(),
        os_info.version(),
        os_info.architecture().unwrap_or("unknown"),
        crate::terminal::user_agent()
    )
}

/// Create a reqwest client with default `originator` and `User-Agent` headers set.
pub fn create_client(originator: &str) -> reqwest::Client {
    use reqwest::header::HeaderMap;
    use reqwest::header::HeaderValue;

    let mut headers = HeaderMap::new();
    let originator_value = HeaderValue::from_str(originator)
        .unwrap_or_else(|_| HeaderValue::from_static(DEFAULT_ORIGINATOR));
    headers.insert("originator", originator_value);
    let ua = get_codex_user_agent(Some(originator));

    match reqwest::Client::builder()
        // Set UA via dedicated helper to avoid header validation pitfalls
        .user_agent(ua)
        .default_headers(headers)
        .build()
    {
        Ok(client) => client,
        Err(_) => reqwest::Client::new(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_codex_user_agent() {
        let user_agent = get_codex_user_agent(None);
        assert!(user_agent.starts_with("codex_cli_rs/"));
    }

    #[tokio::test]
    async fn test_create_client_sets_default_headers() {
        use wiremock::Mock;
        use wiremock::MockServer;
        use wiremock::ResponseTemplate;
        use wiremock::matchers::method;
        use wiremock::matchers::path;

        let originator = "test_originator";
        let client = create_client(originator);

        // Spin up a local mock server and capture a request.
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let resp = client
            .get(server.uri())
            .send()
            .await
            .expect("failed to send request");
        assert!(resp.status().is_success());

        let requests = server
            .received_requests()
            .await
            .expect("failed to fetch received requests");
        assert!(!requests.is_empty());
        let headers = &requests[0].headers;

        // originator header is set to the provided value
        let originator_header = headers
            .get("originator")
            .expect("originator header missing");
        assert_eq!(originator_header.to_str().unwrap(), originator);

        // User-Agent matches the computed Codex UA for that originator
        let expected_ua = get_codex_user_agent(Some(originator));
        let ua_header = headers
            .get("user-agent")
            .expect("user-agent header missing");
        assert_eq!(ua_header.to_str().unwrap(), expected_ua);
    }

    #[test]
    #[cfg(target_os = "macos")]
    fn test_macos() {
        use regex_lite::Regex;
        let user_agent = get_codex_user_agent(None);
        let re = Regex::new(
            r"^codex_cli_rs/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$",
        )
        .unwrap();
        assert!(re.is_match(&user_agent));
    }
}
@@ -8,12 +8,10 @@ use crate::shell::Shell;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::ENVIRONMENT_CONTEXT_CLOSE_TAG;
use codex_protocol::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG;
use std::path::PathBuf;

/// wraps environment context message in a tag for the model to parse more easily.
pub(crate) const ENVIRONMENT_CONTEXT_START: &str = "<environment_context>";
pub(crate) const ENVIRONMENT_CONTEXT_END: &str = "</environment_context>";

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, DeriveDisplay)]
#[serde(rename_all = "kebab-case")]
#[strum(serialize_all = "kebab-case")]
@@ -79,7 +77,7 @@ impl EnvironmentContext {
    /// </environment_context>
    /// ```
    pub fn serialize_to_xml(self) -> String {
        let mut lines = vec![ENVIRONMENT_CONTEXT_START.to_string()];
        let mut lines = vec![ENVIRONMENT_CONTEXT_OPEN_TAG.to_string()];
        if let Some(cwd) = self.cwd {
            lines.push(format!("  <cwd>{}</cwd>", cwd.to_string_lossy()));
        }
@@ -101,7 +99,7 @@ impl EnvironmentContext {
        {
            lines.push(format!("  <shell>{shell_name}</shell>"));
        }
        lines.push(ENVIRONMENT_CONTEXT_END.to_string());
        lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string());
        lines.join("\n")
    }
}

@@ -1,10 +1,10 @@
use codex_protocol::mcp_protocol::ConversationId;
use reqwest::StatusCode;
use serde_json;
use std::io;
use std::time::Duration;
use thiserror::Error;
use tokio::task::JoinError;
use uuid::Uuid;

pub type Result<T> = std::result::Result<T, CodexErr>;

@@ -49,7 +49,7 @@ pub enum CodexErr {
    Stream(String, Option<Duration>),

    #[error("no conversation with id: {0}")]
    ConversationNotFound(Uuid),
    ConversationNotFound(ConversationId),

    #[error("session configured event was not the first event in the stream")]
    SessionConfiguredNotFirstEvent,

codex-rs/core/src/event_mapping.rs (new file, 98 lines)
@@ -0,0 +1,98 @@
use crate::protocol::AgentMessageEvent;
use crate::protocol::AgentReasoningEvent;
use crate::protocol::AgentReasoningRawContentEvent;
use crate::protocol::EventMsg;
use crate::protocol::InputMessageKind;
use crate::protocol::UserMessageEvent;
use crate::protocol::WebSearchEndEvent;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ReasoningItemReasoningSummary;
use codex_protocol::models::ResponseItem;
use codex_protocol::models::WebSearchAction;

/// Convert a `ResponseItem` into zero or more `EventMsg` values that the UI can render.
///
/// When `show_raw_agent_reasoning` is false, raw reasoning content events are omitted.
pub(crate) fn map_response_item_to_event_messages(
    item: &ResponseItem,
    show_raw_agent_reasoning: bool,
) -> Vec<EventMsg> {
    match item {
        ResponseItem::Message { role, content, .. } => {
            // Do not surface system messages as user events.
            if role == "system" {
                return Vec::new();
            }

            let events: Vec<EventMsg> = content
                .iter()
                .filter_map(|content_item| match content_item {
                    ContentItem::OutputText { text } => {
                        Some(EventMsg::AgentMessage(AgentMessageEvent {
                            message: text.clone(),
                        }))
                    }
                    ContentItem::InputText { text } => {
                        let trimmed = text.trim_start();
                        let kind = if trimmed.starts_with("<environment_context>") {
                            Some(InputMessageKind::EnvironmentContext)
                        } else if trimmed.starts_with("<user_instructions>") {
                            Some(InputMessageKind::UserInstructions)
                        } else {
                            Some(InputMessageKind::Plain)
                        };
                        Some(EventMsg::UserMessage(UserMessageEvent {
                            message: text.clone(),
                            kind,
                        }))
                    }
                    _ => None,
                })
                .collect();
            events
        }

        ResponseItem::Reasoning {
            summary, content, ..
        } => {
            let mut events = Vec::new();
            for ReasoningItemReasoningSummary::SummaryText { text } in summary {
                events.push(EventMsg::AgentReasoning(AgentReasoningEvent {
                    text: text.clone(),
                }));
            }
            if let Some(items) = content.as_ref().filter(|_| show_raw_agent_reasoning) {
                for c in items {
                    let text = match c {
                        ReasoningItemContent::ReasoningText { text }
                        | ReasoningItemContent::Text { text } => text,
                    };
                    events.push(EventMsg::AgentReasoningRawContent(
                        AgentReasoningRawContentEvent { text: text.clone() },
                    ));
                }
            }
            events
        }

        ResponseItem::WebSearchCall { id, action, .. } => match action {
            WebSearchAction::Search { query } => {
                let call_id = id.clone().unwrap_or_else(|| "".to_string());
                vec![EventMsg::WebSearchEnd(WebSearchEndEvent {
                    call_id,
                    query: query.clone(),
                })]
            }
            WebSearchAction::Other => Vec::new(),
        },

        // Variants that require side effects are handled by higher layers and do not emit events here.
        ResponseItem::FunctionCall { .. }
        | ResponseItem::FunctionCallOutput { .. }
        | ResponseItem::LocalShellCall { .. }
        | ResponseItem::CustomToolCall { .. }
        | ResponseItem::CustomToolCallOutput { .. }
        | ResponseItem::Other => Vec::new(),
    }
}
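Note: the mapping has two simple invariants worth calling out: system messages never become UI events, and model output becomes an agent message while user input becomes a user message. A reduced, self-contained sketch of just those rules (the types here are stand-ins, not the crate's `EventMsg`):

    #[derive(Debug, PartialEq)]
    enum EventSketch {
        AgentMessage(String),
        UserMessage(String),
    }

    fn map_message(role: &str, text: &str, from_model: bool) -> Vec<EventSketch> {
        if role == "system" {
            return Vec::new(); // never surfaced to the UI
        }
        if from_model {
            vec![EventSketch::AgentMessage(text.to_string())]
        } else {
            vec![EventSketch::UserMessage(text.to_string())]
        }
    }

    fn main() {
        assert!(map_message("system", "internal", false).is_empty());
        assert_eq!(
            map_message("assistant", "hi", true),
            vec![EventSketch::AgentMessage("hi".into())]
        );
    }
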
@@ -26,7 +26,6 @@ use crate::protocol::SandboxPolicy;
use crate::seatbelt::spawn_command_under_seatbelt;
use crate::spawn::StdioPolicy;
use crate::spawn::spawn_child_async;
use serde_bytes::ByteBuf;

const DEFAULT_TIMEOUT_MS: u64 = 10_000;

@@ -369,7 +368,7 @@ async fn read_capped<R: AsyncRead + Unpin + Send + 'static>(
            } else {
                ExecOutputStream::Stdout
            },
            chunk: ByteBuf::from(chunk),
            chunk,
        });
        let event = Event {
            id: stream.sub_id.clone(),

@@ -6,12 +6,14 @@
#![deny(clippy::print_stdout, clippy::print_stderr)]

mod apply_patch;
mod bash;
pub mod auth;
pub mod bash;
mod chat_completions;
mod client;
mod client_common;
pub mod codex;
mod codex_conversation;
pub mod token_data;
pub use codex_conversation::CodexConversation;
pub mod config;
pub mod config_profile;
@@ -32,14 +34,20 @@ mod mcp_tool_call;
mod message_history;
mod model_provider_info;
pub mod parse_command;
mod user_instructions;
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
mod conversation_manager;
mod event_mapping;
pub use conversation_manager::ConversationManager;
pub use conversation_manager::NewConversation;
// Re-export common auth types for workspace consumers
pub use auth::AuthManager;
pub use auth::CodexAuth;
pub mod default_client;
pub mod model_family;
mod openai_model_info;
mod openai_tools;
@@ -53,7 +61,11 @@ pub mod spawn;
pub mod terminal;
mod tool_apply_patch;
pub mod turn_diff_tracker;
pub mod user_agent;
pub use rollout::RolloutRecorder;
pub use rollout::SessionMeta;
pub use rollout::list::ConversationItem;
pub use rollout::list::ConversationsPage;
pub use rollout::list::Cursor;
mod user_notification;
pub mod util;
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
@@ -64,3 +76,14 @@ pub use codex_protocol::protocol;
// Re-export protocol config enums to ensure call sites can use the same types
// as those in the protocol crate when constructing protocol messages.
pub use codex_protocol::config_types as protocol_config_types;

pub use client::ModelClient;
pub use client_common::Prompt;
pub use client_common::ResponseEvent;
pub use client_common::ResponseStream;
pub use codex_protocol::models::ContentItem;
pub use codex_protocol::models::LocalShellAction;
pub use codex_protocol::models::LocalShellExecAction;
pub use codex_protocol::models::LocalShellStatus;
pub use codex_protocol::models::ReasoningItemContent;
pub use codex_protocol::models::ResponseItem;

@@ -9,6 +9,7 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::ffi::OsString;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;
@@ -36,8 +37,8 @@ use crate::config_types::McpServerConfig;
const MCP_TOOL_NAME_DELIMITER: &str = "__";
const MAX_TOOL_NAME_LENGTH: usize = 64;

/// Timeout for the `tools/list` request.
const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
/// Default timeout for initializing MCP server & initially listing tools.
const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);

/// Map that holds a startup error for every MCP server that could **not** be
/// spawned successfully.
@@ -81,6 +82,11 @@ struct ToolInfo {
    tool: Tool,
}

struct ManagedClient {
    client: Arc<McpClient>,
    startup_timeout: Duration,
}

/// A thin wrapper around a set of running [`McpClient`] instances.
#[derive(Default)]
pub(crate) struct McpConnectionManager {
@@ -88,7 +94,7 @@ pub(crate) struct McpConnectionManager {
    ///
    /// The server name originates from the keys of the `mcp_servers` map in
    /// the user configuration.
    clients: HashMap<String, std::sync::Arc<McpClient>>,
    clients: HashMap<String, ManagedClient>,

    /// Fully qualified tool name -> tool instance.
    tools: HashMap<String, ToolInfo>,
@@ -126,8 +132,15 @@ impl McpConnectionManager {
                continue;
            }

            let startup_timeout = cfg
                .startup_timeout_ms
                .map(Duration::from_millis)
                .unwrap_or(DEFAULT_STARTUP_TIMEOUT);

            join_set.spawn(async move {
                let McpServerConfig { command, args, env } = cfg;
                let McpServerConfig {
                    command, args, env, ..
                } = cfg;
                let client_res = McpClient::new_stdio_client(
                    command.into(),
                    args.into_iter().map(OsString::from).collect(),
@@ -154,12 +167,15 @@ impl McpConnectionManager {
                    protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
                };
                let initialize_notification_params = None;
                let timeout = Some(Duration::from_secs(10));
                match client
                    .initialize(params, initialize_notification_params, timeout)
                    .initialize(
                        params,
                        initialize_notification_params,
                        Some(startup_timeout),
                    )
                    .await
                {
                    Ok(_response) => (server_name, Ok(client)),
                    Ok(_response) => (server_name, Ok((client, startup_timeout))),
                    Err(e) => (server_name, Err(e)),
                }
            }
@@ -168,15 +184,26 @@ impl McpConnectionManager {
            });
        }

        let mut clients: HashMap<String, std::sync::Arc<McpClient>> =
            HashMap::with_capacity(join_set.len());
        let mut clients: HashMap<String, ManagedClient> = HashMap::with_capacity(join_set.len());

        while let Some(res) = join_set.join_next().await {
            let (server_name, client_res) = res?; // JoinError propagation
            let (server_name, client_res) = match res {
                Ok((server_name, client_res)) => (server_name, client_res),
                Err(e) => {
                    warn!("Task panic when starting MCP server: {e:#}");
                    continue;
                }
            };

            match client_res {
                Ok(client) => {
                    clients.insert(server_name, std::sync::Arc::new(client));
                Ok((client, startup_timeout)) => {
                    clients.insert(
                        server_name,
                        ManagedClient {
                            client: Arc::new(client),
                            startup_timeout,
                        },
                    );
                }
                Err(e) => {
                    errors.insert(server_name, e);
@@ -184,7 +211,13 @@ impl McpConnectionManager {
            }
        }

        let all_tools = list_all_tools(&clients).await?;
        let all_tools = match list_all_tools(&clients).await {
            Ok(tools) => tools,
            Err(e) => {
                warn!("Failed to list tools from some MCP servers: {e:#}");
                Vec::new()
            }
        };

        let tools = qualify_tools(all_tools);

@@ -212,6 +245,7 @@ impl McpConnectionManager {
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
            .client
            .clone();

        client
@@ -229,21 +263,18 @@ impl McpConnectionManager {

/// Query every server for its available tools and return a single map that
/// contains **all** tools. Each key is the fully-qualified name for the tool.
async fn list_all_tools(
    clients: &HashMap<String, std::sync::Arc<McpClient>>,
) -> Result<Vec<ToolInfo>> {
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
    let mut join_set = JoinSet::new();

    // Spawn one task per server so we can query them concurrently. This
    // keeps the overall latency roughly at the slowest server instead of
    // the cumulative latency.
    for (server_name, client) in clients {
    for (server_name, managed_client) in clients {
        let server_name_cloned = server_name.clone();
        let client_clone = client.clone();
        let client_clone = managed_client.client.clone();
        let startup_timeout = managed_client.startup_timeout;
        join_set.spawn(async move {
            let res = client_clone
                .list_tools(None, Some(LIST_TOOLS_TIMEOUT))
                .await;
            let res = client_clone.list_tools(None, Some(startup_timeout)).await;
            (server_name_cloned, res)
        });
    }
@@ -251,8 +282,19 @@ async fn list_all_tools(
    let mut aggregated: Vec<ToolInfo> = Vec::with_capacity(join_set.len());

    while let Some(join_res) = join_set.join_next().await {
        let (server_name, list_result) = join_res?;
        let list_result = list_result?;
        let (server_name, list_result) = if let Ok(result) = join_res {
            result
        } else {
            warn!("Task panic when listing tools for MCP server: {join_res:#?}");
            continue;
        };

        let list_result = if let Ok(result) = list_result {
            result
        } else {
            warn!("Failed to list tools for MCP server '{server_name}': {list_result:#?}");
            continue;
        };

        for tool in list_result.tools {
            let tool_info = ToolInfo {

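Note: a theme of this hunk is graceful degradation: a panicked startup task or a failed `tools/list` now logs a warning and skips that server instead of aborting the whole manager with `?`. The shape of that change, reduced to a standalone sketch:

    fn aggregate(results: Vec<(String, Result<u32, String>)>) -> Vec<u32> {
        let mut ok = Vec::new();
        for (server, res) in results {
            match res {
                Ok(tool_count) => ok.push(tool_count),
                // A single bad server no longer takes down the rest.
                Err(e) => eprintln!("skipping MCP server '{server}': {e}"),
            }
        }
        ok
    }

    fn main() {
        let results = vec![
            ("good".to_string(), Ok(3)),
            ("bad".to_string(), Err("timeout".to_string())),
        ];
        assert_eq!(aggregate(results), vec![3]);
    }
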
@@ -5,7 +5,7 @@
//! JSON-Lines tooling. Each record has the following schema:
//!
//! ````text
//! {"session_id":"<uuid>","ts":<unix_seconds>,"text":"<message>"}
//! {"conversation_id":"<uuid>","ts":<unix_seconds>,"text":"<message>"}
//! ````
//!
//! To minimise the chance of interleaved writes when multiple processes are
@@ -22,14 +22,15 @@ use std::path::PathBuf;

use serde::Deserialize;
use serde::Serialize;

use std::time::Duration;
use tokio::fs;
use tokio::io::AsyncReadExt;
use uuid::Uuid;

use crate::config::Config;
use crate::config_types::HistoryPersistence;

use codex_protocol::mcp_protocol::ConversationId;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
#[cfg(unix)]
@@ -54,10 +55,14 @@ fn history_filepath(config: &Config) -> PathBuf {
    path
}

/// Append a `text` entry associated with `session_id` to the history file. Uses
/// Append a `text` entry associated with `conversation_id` to the history file. Uses
/// advisory file locking to ensure that concurrent writes do not interleave,
/// which entails a small amount of blocking I/O internally.
pub(crate) async fn append_entry(text: &str, session_id: &Uuid, config: &Config) -> Result<()> {
pub(crate) async fn append_entry(
    text: &str,
    conversation_id: &ConversationId,
    config: &Config,
) -> Result<()> {
    match config.history.persistence {
        HistoryPersistence::SaveAll => {
            // Save everything: proceed.
@@ -84,7 +89,7 @@ pub(crate) async fn append_entry(text: &str, session_id: &Uuid, config: &Config)

    // Construct the JSON line first so we can write it in a single syscall.
    let entry = HistoryEntry {
        session_id: session_id.to_string(),
        session_id: conversation_id.to_string(),
        ts,
        text: text.to_string(),
    };
@@ -105,47 +110,34 @@ pub(crate) async fn append_entry(text: &str, session_id: &Uuid, config: &Config)
    // Ensure permissions.
    ensure_owner_only_permissions(&history_file).await?;

    // Lock file.
    acquire_exclusive_lock_with_retry(&history_file).await?;

    // We use sync I/O with spawn_blocking() because we are using a
    // [`std::fs::File`] instead of a [`tokio::fs::File`] to leverage an
    // advisory file locking API that is not available in the async API.
    // Perform a blocking write under an advisory write lock using std::fs.
    tokio::task::spawn_blocking(move || -> Result<()> {
        history_file.write_all(line.as_bytes())?;
        history_file.flush()?;
        Ok(())
        // Retry a few times to avoid indefinite blocking when contended.
        for _ in 0..MAX_RETRIES {
            match history_file.try_lock() {
                Ok(()) => {
                    // While holding the exclusive lock, write the full line.
                    history_file.write_all(line.as_bytes())?;
                    history_file.flush()?;
                    return Ok(());
                }
                Err(std::fs::TryLockError::WouldBlock) => {
                    std::thread::sleep(RETRY_SLEEP);
                }
                Err(e) => return Err(e.into()),
            }
        }

        Err(std::io::Error::new(
            std::io::ErrorKind::WouldBlock,
            "could not acquire exclusive lock on history file after multiple attempts",
        ))
    })
    .await??;

    Ok(())
}

/// Attempt to acquire an exclusive advisory lock on `file`, retrying up to 10
/// times if the lock is currently held by another process. This prevents a
/// potential indefinite wait while still giving other writers some time to
/// finish their operation.
async fn acquire_exclusive_lock_with_retry(file: &File) -> Result<()> {
    use tokio::time::sleep;

    for _ in 0..MAX_RETRIES {
        match file.try_lock() {
            Ok(()) => return Ok(()),
            Err(e) => match e {
                std::fs::TryLockError::WouldBlock => {
                    sleep(RETRY_SLEEP).await;
                }
                other => return Err(other.into()),
            },
        }
    }

    Err(std::io::Error::new(
        std::io::ErrorKind::WouldBlock,
        "could not acquire exclusive lock on history file after multiple attempts",
    ))
}

/// Asynchronously fetch the history file's *identifier* (inode on Unix) and
/// the current number of entries by counting newline characters.
pub(crate) async fn history_metadata(config: &Config) -> (u64, usize) {
@@ -221,29 +213,42 @@ pub(crate) fn lookup(log_id: u64, offset: usize, config: &Config) -> Option<Hist
        return None;
    }

    // Open & lock file for reading.
    if let Err(e) = acquire_shared_lock_with_retry(&file) {
        tracing::warn!(error = %e, "failed to acquire shared lock on history file");
        return None;
    }
    // Open & lock file for reading using a shared lock.
    // Retry a few times to avoid indefinite blocking.
    for _ in 0..MAX_RETRIES {
        let lock_result = file.try_lock_shared();

    let reader = BufReader::new(&file);
    for (idx, line_res) in reader.lines().enumerate() {
        let line = match line_res {
            Ok(l) => l,
            Err(e) => {
                tracing::warn!(error = %e, "failed to read line from history file");
        match lock_result {
            Ok(()) => {
                let reader = BufReader::new(&file);
                for (idx, line_res) in reader.lines().enumerate() {
                    let line = match line_res {
                        Ok(l) => l,
                        Err(e) => {
                            tracing::warn!(error = %e, "failed to read line from history file");
                            return None;
                        }
                    };

                    if idx == offset {
                        match serde_json::from_str::<HistoryEntry>(&line) {
                            Ok(entry) => return Some(entry),
                            Err(e) => {
                                tracing::warn!(error = %e, "failed to parse history entry");
                                return None;
                            }
                        }
                    }
                }
                // Not found at requested offset.
                return None;
            }
        };

        if idx == offset {
            match serde_json::from_str::<HistoryEntry>(&line) {
                Ok(entry) => return Some(entry),
                Err(e) => {
                    tracing::warn!(error = %e, "failed to parse history entry");
                    return None;
                }
            Err(std::fs::TryLockError::WouldBlock) => {
                std::thread::sleep(RETRY_SLEEP);
            }
            Err(e) => {
                tracing::warn!(error = %e, "failed to acquire shared lock on history file");
                return None;
            }
        }
    }
@@ -258,26 +263,6 @@ pub(crate) fn lookup(log_id: u64, offset: usize, config: &Config) -> Option<Hist
    None
}

#[cfg(unix)]
fn acquire_shared_lock_with_retry(file: &File) -> Result<()> {
    for _ in 0..MAX_RETRIES {
        match file.try_lock_shared() {
            Ok(()) => return Ok(()),
            Err(e) => match e {
                std::fs::TryLockError::WouldBlock => {
                    std::thread::sleep(RETRY_SLEEP);
                }
                other => return Err(other.into()),
            },
        }
    }

    Err(std::io::Error::new(
        std::io::ErrorKind::WouldBlock,
        "could not acquire shared lock on history file after multiple attempts",
    ))
}

/// On Unix systems ensure the file permissions are `0o600` (rw-------). If the
/// permissions cannot be changed the error is propagated to the caller.
#[cfg(unix)]

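Note: both the append and lookup paths now inline the same bounded-retry pattern around the standard library's advisory file locks. A standalone sketch of the exclusive-write side, assuming a toolchain where `std::fs::File::try_lock` is stable as this diff does (the retry constants below are illustrative, not necessarily the crate's values):

    use std::fs::{File, TryLockError};
    use std::io::{Error, ErrorKind, Write};
    use std::time::Duration;

    fn append_locked(file: &mut File, line: &str) -> std::io::Result<()> {
        const MAX_RETRIES: usize = 10;
        const RETRY_SLEEP: Duration = Duration::from_millis(100);
        for _ in 0..MAX_RETRIES {
            match file.try_lock() {
                Ok(()) => {
                    // The lock is held until the handle is dropped or unlocked.
                    file.write_all(line.as_bytes())?;
                    return file.flush();
                }
                // Another writer holds the lock; back off briefly and retry.
                Err(TryLockError::WouldBlock) => std::thread::sleep(RETRY_SLEEP),
                Err(TryLockError::Error(e)) => return Err(e),
            }
        }
        Err(Error::new(ErrorKind::WouldBlock, "lock busy after retries"))
    }

    fn main() -> std::io::Result<()> {
        let mut f = File::create("history.jsonl")?;
        append_locked(&mut f, "{\"text\":\"hello\"}\n")
    }
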
@@ -1,3 +1,4 @@
use crate::config_types::ReasoningSummaryFormat;
use crate::tool_apply_patch::ApplyPatchToolType;

/// A model family is a group of models that share certain characteristics.
@@ -20,6 +21,9 @@ pub struct ModelFamily {
    // `summary` is optional).
    pub supports_reasoning_summaries: bool,

    // Define if we need a special handling of reasoning summary
    pub reasoning_summary_format: ReasoningSummaryFormat,

    // This should be set to true when the model expects a tool named
    // "local_shell" to be provided. Its contract must be understood natively by
    // the model such that its description can be omitted.
@@ -41,6 +45,7 @@ macro_rules! model_family {
            family: $family.to_string(),
            needs_special_apply_patch_instructions: false,
            supports_reasoning_summaries: false,
            reasoning_summary_format: ReasoningSummaryFormat::None,
            uses_local_shell_tool: false,
            apply_patch_tool_type: None,
        };
@@ -61,6 +66,7 @@ macro_rules! simple_model_family {
            family: $family.to_string(),
            needs_special_apply_patch_instructions: false,
            supports_reasoning_summaries: false,
            reasoning_summary_format: ReasoningSummaryFormat::None,
            uses_local_shell_tool: false,
            apply_patch_tool_type: None,
        })
@@ -90,6 +96,7 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
        )
    } else if slug.starts_with("gpt-4.1") {
        model_family!(

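Note: the macro change gives every family the conservative `ReasoningSummaryFormat::None` default, with specific slug families opting in to `Experimental`. A reduced restatement of that defaults-plus-opt-in pattern (the slug condition below is a stand-in; the real matching lives in `find_family_for_model`):

    #[derive(Default, PartialEq, Debug)]
    enum SummaryFormatSketch {
        #[default]
        None,
        Experimental,
    }

    struct FamilySketch {
        supports_reasoning_summaries: bool,
        reasoning_summary_format: SummaryFormatSketch,
    }

    fn family_for(slug: &str) -> FamilySketch {
        let mut f = FamilySketch {
            supports_reasoning_summaries: false,
            reasoning_summary_format: SummaryFormatSketch::None,
        };
        // Hypothetical condition standing in for the real slug matching.
        if slug.starts_with("gpt-5") {
            f.supports_reasoning_summaries = true;
            f.reasoning_summary_format = SummaryFormatSketch::Experimental;
        }
        f
    }

    fn main() {
        assert!(family_for("gpt-5").supports_reasoning_summaries);
        assert!(!family_for("gpt-4.1").supports_reasoning_summaries);
    }
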
@@ -5,8 +5,8 @@
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
//!    key. These override or extend the defaults at runtime.

use codex_login::AuthMode;
use codex_login::CodexAuth;
use crate::CodexAuth;
use codex_protocol::mcp_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;

@@ -79,12 +79,12 @@ pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
        }),

        "gpt-5" => Some(ModelInfo {
            context_window: 400_000,
            context_window: 272_000,
            max_output_tokens: 128_000,
        }),

        _ if slug.starts_with("codex-") => Some(ModelInfo {
            context_window: 400_000,
            context_window: 272_000,
            max_output_tokens: 128_000,
        }),

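Note: dropping the advertised window from 400_000 to 272_000 matters for token budgeting: once the 128_000-token output reservation is subtracted, less room remains for input. Whether the crate computes budgets exactly this way is not shown in this diff; the sketch just illustrates the arithmetic:

    fn input_budget(context_window: u64, max_output_tokens: u64) -> u64 {
        context_window.saturating_sub(max_output_tokens)
    }

    fn main() {
        // gpt-5 / codex- models per the hunk above:
        assert_eq!(input_budget(272_000, 128_000), 144_000);
    }
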
@@ -49,7 +49,7 @@ pub(crate) enum OpenAiTool {
|
||||
LocalShell {},
|
||||
// TODO: Understand why we get an error on web_search although the API docs say it's supported.
|
||||
// https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C
|
||||
#[serde(rename = "web_search_preview")]
|
||||
#[serde(rename = "web_search")]
|
||||
WebSearch {},
|
||||
#[serde(rename = "custom")]
|
||||
Freeform(FreeformTool),
|
||||
@@ -240,15 +240,17 @@ fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
    let description = match sandbox_policy {
        SandboxPolicy::WorkspaceWrite {
            network_access,
+           writable_roots,
            ..
        } => {
            format!(
                r#"
The shell tool is used to execute shell commands.
-- When invoking the shell tool, your call will be running in a landlock sandbox, and some shell commands will require escalated privileges:
+- When invoking the shell tool, your call will be running in a sandbox, and some shell commands will require escalated privileges:
- Types of actions that require escalated privileges:
-  - Reading files outside the current directory
-  - Writing files outside the current directory, and protected folders like .git or .env{}
+  - Writing files other than those in the writable roots
+    - writable roots:
+{}{}
- Examples of commands that require escalated privileges:
  - git commit
  - npm install or pnpm install

@@ -257,8 +259,9 @@ The shell tool is used to execute shell commands.
- When invoking a command that will require escalated privileges:
  - Provide the with_escalated_permissions parameter with the boolean value true
  - Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#,
+               writable_roots.iter().map(|wr| format!("      - {}", wr.to_string_lossy())).collect::<Vec<String>>().join("\n"),
                if !network_access {
                    "\n  - Commands that require network access\n"
                } else {
                    ""
                }

@@ -270,9 +273,8 @@ The shell tool is used to execute shell commands.
        SandboxPolicy::ReadOnly => {
            r#"
The shell tool is used to execute shell commands.
-- When invoking the shell tool, your call will be running in a landlock sandbox, and some shell commands (including apply_patch) will require escalated permissions:
+- When invoking the shell tool, your call will be running in a sandbox, and some shell commands (including apply_patch) will require escalated permissions:
- Types of actions that require escalated privileges:
-  - Reading files outside the current directory
  - Writing files
  - Applying patches
- Examples of commands that require escalated privileges:

@@ -1081,4 +1083,84 @@ mod tests {
            })
        );
    }

    #[test]
    fn test_shell_tool_for_sandbox_workspace_write() {
        let sandbox_policy = SandboxPolicy::WorkspaceWrite {
            writable_roots: vec!["workspace".into()],
            network_access: false,
            exclude_tmpdir_env_var: false,
            exclude_slash_tmp: false,
        };
        let tool = super::create_shell_tool_for_sandbox(&sandbox_policy);
        let OpenAiTool::Function(ResponsesApiTool {
            description, name, ..
        }) = &tool
        else {
            panic!("expected function tool");
        };
        assert_eq!(name, "shell");

        let expected = r#"
The shell tool is used to execute shell commands.
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands will require escalated privileges:
- Types of actions that require escalated privileges:
  - Writing files other than those in the writable roots
    - writable roots:
      - workspace
  - Commands that require network access

- Examples of commands that require escalated privileges:
  - git commit
  - npm install or pnpm install
  - cargo build
  - cargo test
- When invoking a command that will require escalated privileges:
  - Provide the with_escalated_permissions parameter with the boolean value true
  - Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#;
        assert_eq!(description, expected);
    }

    #[test]
    fn test_shell_tool_for_sandbox_readonly() {
        let tool = super::create_shell_tool_for_sandbox(&SandboxPolicy::ReadOnly);
        let OpenAiTool::Function(ResponsesApiTool {
            description, name, ..
        }) = &tool
        else {
            panic!("expected function tool");
        };
        assert_eq!(name, "shell");

        let expected = r#"
The shell tool is used to execute shell commands.
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands (including apply_patch) will require escalated permissions:
- Types of actions that require escalated privileges:
  - Writing files
  - Applying patches
- Examples of commands that require escalated privileges:
  - apply_patch
  - git commit
  - npm install or pnpm install
  - cargo build
  - cargo test
- When invoking a command that will require escalated privileges:
  - Provide the with_escalated_permissions parameter with the boolean value true
  - Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#;
        assert_eq!(description, expected);
    }

    #[test]
    fn test_shell_tool_for_sandbox_danger_full_access() {
        let tool = super::create_shell_tool_for_sandbox(&SandboxPolicy::DangerFullAccess);
        let OpenAiTool::Function(ResponsesApiTool {
            description, name, ..
        }) = &tool
        else {
            panic!("expected function tool");
        };
        assert_eq!(name, "shell");

        assert_eq!(description, "Runs a shell command and returns its output.");
    }
}
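The escalation contract above is just two extra function-call arguments; a minimal sketch of the argument payload a model might send for a privileged command — field names come from the description text above, while the `ShellToolArgs` mirror struct is hypothetical:

use serde::Deserialize;

// Hypothetical mirror of the shell tool's argument schema described above.
#[derive(Deserialize, Debug)]
struct ShellToolArgs {
    command: Vec<String>,
    #[serde(default)]
    with_escalated_permissions: bool,
    #[serde(default)]
    justification: Option<String>,
}

fn main() {
    // What a tool call that needs privileges might look like on the wire.
    let raw = r#"{
        "command": ["git", "commit", "-m", "update"],
        "with_escalated_permissions": true,
        "justification": "git commit writes to the protected .git directory"
    }"#;
    let args: ShellToolArgs = serde_json::from_str(raw).unwrap();
    assert!(args.with_escalated_permissions);
    println!("{args:?}");
}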
@@ -235,6 +235,28 @@ mod tests {
        );
    }

    #[test]
    fn bash_cd_then_bar_is_same_as_bar() {
        // Ensure a leading `cd` inside bash -lc is dropped when followed by another command.
        assert_parsed(
            &shlex_split_safe("bash -lc 'cd foo && bar'"),
            vec![ParsedCommand::Unknown {
                cmd: "bar".to_string(),
            }],
        );
    }

    #[test]
    fn bash_cd_then_cat_is_read() {
        assert_parsed(
            &shlex_split_safe("bash -lc 'cd foo && cat foo.txt'"),
            vec![ParsedCommand::Read {
                cmd: "cat foo.txt".to_string(),
                name: "foo.txt".to_string(),
            }],
        );
    }

    #[test]
    fn supports_ls_with_pipe() {
        let inner = "ls -la | sed -n '1,120p'";

@@ -1149,6 +1171,10 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
        .collect();
    if commands.len() > 1 {
        commands.retain(|pc| !matches!(pc, ParsedCommand::Unknown { cmd } if cmd == "true"));
+       // Apply the same simplifications used for non-bash parsing, e.g., drop leading `cd`.
+       while let Some(next) = simplify_once(&commands) {
+           commands = next;
+       }
    }
    if commands.len() == 1 {
        // If we reduced to a single command, attribute the full original script
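The `while let Some(next) = simplify_once(...)` loop above runs a rewrite step to a fixed point: each pass either returns a simplified list or `None` to stop. A self-contained sketch of the same idea — the `simplify_once` below is a stand-in rule (drop one leading `cd`), not the parser's real rule set:

// Stand-in rewrite step mirroring the "drop leading `cd`" simplification
// named in the diff. Returns None when nothing changed, which terminates
// the fixed-point loop.
fn simplify_once(cmds: &[String]) -> Option<Vec<String>> {
    match cmds.split_first() {
        Some((first, rest)) if first.starts_with("cd ") && !rest.is_empty() => {
            Some(rest.to_vec())
        }
        _ => None,
    }
}

fn main() {
    let mut commands: Vec<String> = vec!["cd foo".into(), "cd bar".into(), "cat x".into()];
    // Re-apply the step until it reports no further change; each step shrinks
    // the list, so the loop always terminates.
    while let Some(next) = simplify_once(&commands) {
        commands = next;
    }
    assert_eq!(commands, vec!["cat x".to_string()]);
}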
@@ -1,368 +0,0 @@
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.

use std::fs::File;
use std::fs::{self};
use std::io::Error as IoError;
use std::path::Path;

use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use time::OffsetDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::{self};
use tokio::sync::oneshot;
use tracing::info;
use tracing::warn;
use uuid::Uuid;

use crate::config::Config;
use crate::git_info::GitInfo;
use crate::git_info::collect_git_info;
use codex_protocol::models::ResponseItem;

const SESSIONS_SUBDIR: &str = "sessions";

#[derive(Serialize, Deserialize, Clone, Default)]
pub struct SessionMeta {
    pub id: Uuid,
    pub timestamp: String,
    pub instructions: Option<String>,
}

#[derive(Serialize)]
struct SessionMetaWithGit {
    #[serde(flatten)]
    meta: SessionMeta,
    #[serde(skip_serializing_if = "Option::is_none")]
    git: Option<GitInfo>,
}

#[derive(Serialize, Deserialize, Default, Clone)]
pub struct SessionStateSnapshot {}

#[derive(Serialize, Deserialize, Default, Clone)]
pub struct SavedSession {
    pub session: SessionMeta,
    #[serde(default)]
    pub items: Vec<ResponseItem>,
    #[serde(default)]
    pub state: SessionStateSnapshot,
    pub session_id: Uuid,
}

/// Records all [`ResponseItem`]s for a session and flushes them to disk after
/// every update.
///
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
///
/// ```ignore
/// $ jq -C . ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// $ fx ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// ```
#[derive(Clone)]
pub(crate) struct RolloutRecorder {
    tx: Sender<RolloutCmd>,
}

enum RolloutCmd {
    AddItems(Vec<ResponseItem>),
    UpdateState(SessionStateSnapshot),
    Shutdown { ack: oneshot::Sender<()> },
}

impl RolloutRecorder {
    /// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
    /// cannot be created or the rollout file cannot be opened we return the
    /// error so the caller can decide whether to disable persistence.
    pub async fn new(
        config: &Config,
        uuid: Uuid,
        instructions: Option<String>,
    ) -> std::io::Result<Self> {
        let LogFileInfo {
            file,
            session_id,
            timestamp,
        } = create_log_file(config, uuid)?;

        let timestamp_format: &[FormatItem] = format_description!(
            "[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
        );
        let timestamp = timestamp
            .format(timestamp_format)
            .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;

        // Clone the cwd for the spawned task to collect git info asynchronously
        let cwd = config.cwd.clone();

        // A reasonably-sized bounded channel. If the buffer fills up the send
        // future will yield, which is fine – we only need to ensure we do not
        // perform *blocking* I/O on the caller's thread.
        let (tx, rx) = mpsc::channel::<RolloutCmd>(256);

        // Spawn a Tokio task that owns the file handle and performs async
        // writes. Using `tokio::fs::File` keeps everything on the async I/O
        // driver instead of blocking the runtime.
        tokio::task::spawn(rollout_writer(
            tokio::fs::File::from_std(file),
            rx,
            Some(SessionMeta {
                timestamp,
                id: session_id,
                instructions,
            }),
            cwd,
        ));

        Ok(Self { tx })
    }

    pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
        let mut filtered = Vec::new();
        for item in items {
            match item {
                // Note that function calls may look a bit strange if they are
                // "fully qualified MCP tool calls," so we could consider
                // reformatting them in that case.
                ResponseItem::Message { .. }
                | ResponseItem::LocalShellCall { .. }
                | ResponseItem::FunctionCall { .. }
                | ResponseItem::FunctionCallOutput { .. }
                | ResponseItem::CustomToolCall { .. }
                | ResponseItem::CustomToolCallOutput { .. }
                | ResponseItem::Reasoning { .. } => filtered.push(item.clone()),
                ResponseItem::WebSearchCall { .. } | ResponseItem::Other => {
                    // These should never be serialized.
                    continue;
                }
            }
        }
        if filtered.is_empty() {
            return Ok(());
        }
        self.tx
            .send(RolloutCmd::AddItems(filtered))
            .await
            .map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
    }

    pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
        self.tx
            .send(RolloutCmd::UpdateState(state))
            .await
            .map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
    }

    pub async fn resume(
        path: &Path,
        cwd: std::path::PathBuf,
    ) -> std::io::Result<(Self, SavedSession)> {
        info!("Resuming rollout from {path:?}");
        let text = tokio::fs::read_to_string(path).await?;
        let mut lines = text.lines();
        let meta_line = lines
            .next()
            .ok_or_else(|| IoError::other("empty session file"))?;
        let session: SessionMeta = serde_json::from_str(meta_line)
            .map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
        let mut items = Vec::new();
        let mut state = SessionStateSnapshot::default();

        for line in lines {
            if line.trim().is_empty() {
                continue;
            }
            let v: Value = match serde_json::from_str(line) {
                Ok(v) => v,
                Err(_) => continue,
            };
            if v.get("record_type")
                .and_then(|rt| rt.as_str())
                .map(|s| s == "state")
                .unwrap_or(false)
            {
                if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
                    state = s
                }
                continue;
            }
            match serde_json::from_value::<ResponseItem>(v.clone()) {
                Ok(item) => match item {
                    ResponseItem::Message { .. }
                    | ResponseItem::LocalShellCall { .. }
                    | ResponseItem::FunctionCall { .. }
                    | ResponseItem::FunctionCallOutput { .. }
                    | ResponseItem::CustomToolCall { .. }
                    | ResponseItem::CustomToolCallOutput { .. }
                    | ResponseItem::Reasoning { .. } => items.push(item),
                    ResponseItem::WebSearchCall { .. } | ResponseItem::Other => {}
                },
                Err(e) => {
                    warn!("failed to parse item: {v:?}, error: {e}");
                }
            }
        }

        let saved = SavedSession {
            session: session.clone(),
            items: items.clone(),
            state: state.clone(),
            session_id: session.id,
        };

        let file = std::fs::OpenOptions::new()
            .append(true)
            .read(true)
            .open(path)?;

        let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
        tokio::task::spawn(rollout_writer(
            tokio::fs::File::from_std(file),
            rx,
            None,
            cwd,
        ));
        info!("Resumed rollout successfully from {path:?}");
        Ok((Self { tx }, saved))
    }

    pub async fn shutdown(&self) -> std::io::Result<()> {
        let (tx_done, rx_done) = oneshot::channel();
        match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
            Ok(_) => rx_done
                .await
                .map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
            Err(e) => {
                warn!("failed to send rollout shutdown command: {e}");
                Err(IoError::other(format!(
                    "failed to send rollout shutdown command: {e}"
                )))
            }
        }
    }
}

struct LogFileInfo {
    /// Opened file handle to the rollout file.
    file: File,

    /// Session ID (also embedded in filename).
    session_id: Uuid,

    /// Timestamp for the start of the session.
    timestamp: OffsetDateTime,
}

fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFileInfo> {
    // Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
    let timestamp = OffsetDateTime::now_local()
        .map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
    let mut dir = config.codex_home.clone();
    dir.push(SESSIONS_SUBDIR);
    dir.push(timestamp.year().to_string());
    dir.push(format!("{:02}", u8::from(timestamp.month())));
    dir.push(format!("{:02}", timestamp.day()));
    fs::create_dir_all(&dir)?;

    // Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for
    // compatibility with filesystems that do not allow colons in filenames.
    let format: &[FormatItem] =
        format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
    let date_str = timestamp
        .format(format)
        .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;

    let filename = format!("rollout-{date_str}-{session_id}.jsonl");

    let path = dir.join(filename);
    let file = std::fs::OpenOptions::new()
        .append(true)
        .create(true)
        .open(&path)?;

    Ok(LogFileInfo {
        file,
        session_id,
        timestamp,
    })
}

async fn rollout_writer(
    file: tokio::fs::File,
    mut rx: mpsc::Receiver<RolloutCmd>,
    mut meta: Option<SessionMeta>,
    cwd: std::path::PathBuf,
) -> std::io::Result<()> {
    let mut writer = JsonlWriter { file };

    // If we have a meta, collect git info asynchronously and write meta first
    if let Some(session_meta) = meta.take() {
        let git_info = collect_git_info(&cwd).await;
        let session_meta_with_git = SessionMetaWithGit {
            meta: session_meta,
            git: git_info,
        };

        // Write the SessionMeta as the first item in the file
        writer.write_line(&session_meta_with_git).await?;
    }

    // Process rollout commands
    while let Some(cmd) = rx.recv().await {
        match cmd {
            RolloutCmd::AddItems(items) => {
                for item in items {
                    match item {
                        ResponseItem::Message { .. }
                        | ResponseItem::LocalShellCall { .. }
                        | ResponseItem::FunctionCall { .. }
                        | ResponseItem::FunctionCallOutput { .. }
                        | ResponseItem::CustomToolCall { .. }
                        | ResponseItem::CustomToolCallOutput { .. }
                        | ResponseItem::Reasoning { .. } => {
                            writer.write_line(&item).await?;
                        }
                        ResponseItem::WebSearchCall { .. } | ResponseItem::Other => {}
                    }
                }
            }
            RolloutCmd::UpdateState(state) => {
                #[derive(Serialize)]
                struct StateLine<'a> {
                    record_type: &'static str,
                    #[serde(flatten)]
                    state: &'a SessionStateSnapshot,
                }
                writer
                    .write_line(&StateLine {
                        record_type: "state",
                        state: &state,
                    })
                    .await?;
            }
            RolloutCmd::Shutdown { ack } => {
                let _ = ack.send(());
            }
        }
    }

    Ok(())
}

struct JsonlWriter {
    file: tokio::fs::File,
}

impl JsonlWriter {
    async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
        let mut json = serde_json::to_string(item)?;
        json.push('\n');
        let _ = self.file.write_all(json.as_bytes()).await;
        self.file.flush().await?;
        Ok(())
    }
}
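The deleted recorder above (and its replacement below) uses a common Tokio shape: one spawned task owns the file handle, and all writes arrive serially over a bounded channel, with a oneshot ack for shutdown. A minimal self-contained sketch of that actor shape, independent of the Codex types (the `Cmd` enum and its variants are stand-ins):

use tokio::sync::{mpsc, oneshot};

// Commands the writer task understands; Shutdown carries an ack channel so
// the caller can wait until all queued work has been handled.
enum Cmd {
    Write(String),
    Shutdown { ack: oneshot::Sender<()> },
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Cmd>(256);

    // The spawned task is the only owner of the output; everything it
    // receives is applied strictly in send order.
    let writer = tokio::spawn(async move {
        while let Some(cmd) = rx.recv().await {
            match cmd {
                Cmd::Write(line) => println!("{line}"),
                Cmd::Shutdown { ack } => {
                    let _ = ack.send(());
                    break;
                }
            }
        }
    });

    tx.send(Cmd::Write("first".into())).await.unwrap();
    tx.send(Cmd::Write("second".into())).await.unwrap();

    let (ack_tx, ack_rx) = oneshot::channel();
    tx.send(Cmd::Shutdown { ack: ack_tx }).await.unwrap();
    ack_rx.await.unwrap(); // all prior writes are done once this resolves
    writer.await.unwrap();
}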
335  codex-rs/core/src/rollout/list.rs  Normal file
@@ -0,0 +1,335 @@
use std::cmp::Reverse;
use std::io::{self};
use std::path::Path;
use std::path::PathBuf;

use time::OffsetDateTime;
use time::PrimitiveDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use uuid::Uuid;

use super::SESSIONS_SUBDIR;
use super::recorder::SessionMetaWithGit;

/// Returned page of conversation summaries.
#[derive(Debug, Default, PartialEq)]
pub struct ConversationsPage {
    /// Conversation summaries ordered newest first.
    pub items: Vec<ConversationItem>,
    /// Opaque pagination token to resume after the last item, or `None` if end.
    pub next_cursor: Option<Cursor>,
    /// Total number of files touched while scanning this request.
    pub num_scanned_files: usize,
    /// True if a hard scan cap was hit; consider resuming with `next_cursor`.
    pub reached_scan_cap: bool,
}

/// Summary information for a conversation rollout file.
#[derive(Debug, PartialEq)]
pub struct ConversationItem {
    /// Absolute path to the rollout file.
    pub path: PathBuf,
    /// First up to `HEAD_RECORD_LIMIT` (10) JSONL records parsed as JSON (includes the meta line).
    pub head: Vec<serde_json::Value>,
}

/// Hard cap to bound worst-case work per request.
const MAX_SCAN_FILES: usize = 10_000;
const HEAD_RECORD_LIMIT: usize = 10;

/// Pagination cursor identifying a file by timestamp and UUID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Cursor {
    ts: OffsetDateTime,
    id: Uuid,
}

impl Cursor {
    fn new(ts: OffsetDateTime, id: Uuid) -> Self {
        Self { ts, id }
    }
}

impl serde::Serialize for Cursor {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let ts_str = self
            .ts
            .format(&format_description!(
                "[year]-[month]-[day]T[hour]-[minute]-[second]"
            ))
            .map_err(|e| serde::ser::Error::custom(format!("format error: {e}")))?;
        serializer.serialize_str(&format!("{ts_str}|{}", self.id))
    }
}

impl<'de> serde::Deserialize<'de> for Cursor {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        parse_cursor(&s).ok_or_else(|| serde::de::Error::custom("invalid cursor"))
    }
}

/// Retrieve recorded conversation file paths with token pagination. The returned `next_cursor`
/// can be supplied on the next call to resume after the last returned item, resilient to
/// concurrent new sessions being appended. Ordering is stable by timestamp desc, then UUID desc.
pub(crate) async fn get_conversations(
    codex_home: &Path,
    page_size: usize,
    cursor: Option<&Cursor>,
) -> io::Result<ConversationsPage> {
    let mut root = codex_home.to_path_buf();
    root.push(SESSIONS_SUBDIR);

    if !root.exists() {
        return Ok(ConversationsPage {
            items: Vec::new(),
            next_cursor: None,
            num_scanned_files: 0,
            reached_scan_cap: false,
        });
    }

    let anchor = cursor.cloned();

    let result = traverse_directories_for_paths(root.clone(), page_size, anchor).await?;
    Ok(result)
}

/// Load the full contents of a single conversation session file at `path`.
/// Returns the entire file contents as a String.
#[allow(dead_code)]
pub(crate) async fn get_conversation(path: &Path) -> io::Result<String> {
    tokio::fs::read_to_string(path).await
}

/// Load conversation file paths from disk using directory traversal.
///
/// Directory layout: `~/.codex/sessions/YYYY/MM/DD/rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl`
/// Returned newest (latest) first.
async fn traverse_directories_for_paths(
    root: PathBuf,
    page_size: usize,
    anchor: Option<Cursor>,
) -> io::Result<ConversationsPage> {
    let mut items: Vec<ConversationItem> = Vec::with_capacity(page_size);
    let mut scanned_files = 0usize;
    let mut anchor_passed = anchor.is_none();
    let (anchor_ts, anchor_id) = match anchor {
        Some(c) => (c.ts, c.id),
        None => (OffsetDateTime::UNIX_EPOCH, Uuid::nil()),
    };

    let year_dirs = collect_dirs_desc(&root, |s| s.parse::<u16>().ok()).await?;

    'outer: for (_year, year_path) in year_dirs.iter() {
        if scanned_files >= MAX_SCAN_FILES {
            break;
        }
        let month_dirs = collect_dirs_desc(year_path, |s| s.parse::<u8>().ok()).await?;
        for (_month, month_path) in month_dirs.iter() {
            if scanned_files >= MAX_SCAN_FILES {
                break 'outer;
            }
            let day_dirs = collect_dirs_desc(month_path, |s| s.parse::<u8>().ok()).await?;
            for (_day, day_path) in day_dirs.iter() {
                if scanned_files >= MAX_SCAN_FILES {
                    break 'outer;
                }
                let mut day_files = collect_files(day_path, |name_str, path| {
                    if !name_str.starts_with("rollout-") || !name_str.ends_with(".jsonl") {
                        return None;
                    }

                    parse_timestamp_uuid_from_filename(name_str)
                        .map(|(ts, id)| (ts, id, name_str.to_string(), path.to_path_buf()))
                })
                .await?;
                // Stable ordering within the same second: (timestamp desc, uuid desc)
                day_files.sort_by_key(|(ts, sid, _name_str, _path)| (Reverse(*ts), Reverse(*sid)));
                for (ts, sid, _name_str, path) in day_files.into_iter() {
                    scanned_files += 1;
                    if scanned_files >= MAX_SCAN_FILES && items.len() >= page_size {
                        break 'outer;
                    }
                    if !anchor_passed {
                        if ts < anchor_ts || (ts == anchor_ts && sid < anchor_id) {
                            anchor_passed = true;
                        } else {
                            continue;
                        }
                    }
                    if items.len() == page_size {
                        break 'outer;
                    }
                    let head = read_first_jsonl_records(&path, HEAD_RECORD_LIMIT)
                        .await
                        .unwrap_or_default();
                    if should_include_session(&head) {
                        items.push(ConversationItem { path, head });
                    }
                }
            }
        }
    }

    let next = build_next_cursor(&items);
    Ok(ConversationsPage {
        items,
        next_cursor: next,
        num_scanned_files: scanned_files,
        reached_scan_cap: scanned_files >= MAX_SCAN_FILES,
    })
}

/// Pagination cursor token format: "<file_ts>|<uuid>" where `file_ts` matches the
/// filename timestamp portion (YYYY-MM-DDThh-mm-ss) used in rollout filenames.
/// The cursor orders files by timestamp desc, then UUID desc.
fn parse_cursor(token: &str) -> Option<Cursor> {
    let (file_ts, uuid_str) = token.split_once('|')?;

    let Ok(uuid) = Uuid::parse_str(uuid_str) else {
        return None;
    };

    let format: &[FormatItem] =
        format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
    let ts = PrimitiveDateTime::parse(file_ts, format).ok()?.assume_utc();

    Some(Cursor::new(ts, uuid))
}

fn build_next_cursor(items: &[ConversationItem]) -> Option<Cursor> {
    let last = items.last()?;
    let file_name = last.path.file_name()?.to_string_lossy();
    let (ts, id) = parse_timestamp_uuid_from_filename(&file_name)?;
    Some(Cursor::new(ts, id))
}

/// Collects immediate subdirectories of `parent`, parses their (string) names with `parse`,
/// and returns them sorted descending by the parsed key.
async fn collect_dirs_desc<T, F>(parent: &Path, parse: F) -> io::Result<Vec<(T, PathBuf)>>
where
    T: Ord + Copy,
    F: Fn(&str) -> Option<T>,
{
    let mut dir = tokio::fs::read_dir(parent).await?;
    let mut vec: Vec<(T, PathBuf)> = Vec::new();
    while let Some(entry) = dir.next_entry().await? {
        if entry
            .file_type()
            .await
            .map(|ft| ft.is_dir())
            .unwrap_or(false)
            && let Some(s) = entry.file_name().to_str()
            && let Some(v) = parse(s)
        {
            vec.push((v, entry.path()));
        }
    }
    vec.sort_by_key(|(v, _)| Reverse(*v));
    Ok(vec)
}

/// Collects files in a directory and parses them with `parse`.
async fn collect_files<T, F>(parent: &Path, parse: F) -> io::Result<Vec<T>>
where
    F: Fn(&str, &Path) -> Option<T>,
{
    let mut dir = tokio::fs::read_dir(parent).await?;
    let mut collected: Vec<T> = Vec::new();
    while let Some(entry) = dir.next_entry().await? {
        if entry
            .file_type()
            .await
            .map(|ft| ft.is_file())
            .unwrap_or(false)
            && let Some(s) = entry.file_name().to_str()
            && let Some(v) = parse(s, &entry.path())
        {
            collected.push(v);
        }
    }
    Ok(collected)
}

fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uuid)> {
    // Expected: rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl
    let core = name.strip_prefix("rollout-")?.strip_suffix(".jsonl")?;

    // Scan from the right for a '-' such that the suffix parses as a UUID.
    let (sep_idx, uuid) = core
        .match_indices('-')
        .rev()
        .find_map(|(i, _)| Uuid::parse_str(&core[i + 1..]).ok().map(|u| (i, u)))?;

    let ts_str = &core[..sep_idx];
    let format: &[FormatItem] =
        format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
    let ts = PrimitiveDateTime::parse(ts_str, format).ok()?.assume_utc();
    Some((ts, uuid))
}

async fn read_first_jsonl_records(
    path: &Path,
    max_records: usize,
) -> io::Result<Vec<serde_json::Value>> {
    use tokio::io::AsyncBufReadExt;

    let file = tokio::fs::File::open(path).await?;
    let reader = tokio::io::BufReader::new(file);
    let mut lines = reader.lines();
    let mut head: Vec<serde_json::Value> = Vec::new();
    while head.len() < max_records {
        let line_opt = lines.next_line().await?;
        let Some(line) = line_opt else { break };
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        if let Ok(v) = serde_json::from_str::<serde_json::Value>(trimmed) {
            head.push(v);
        }
    }
    Ok(head)
}

/// Return true if this conversation should be included in the listing.
///
/// Current rule: include only when the first JSON object is a session meta record
/// (i.e., has `{"record_type": "session_meta", ...}`), which is how rollout
/// files are written. Empty or malformed heads are excluded.
fn should_include_session(head: &[serde_json::Value]) -> bool {
    let Some(first) = head.first() else {
        return false;
    };
    passes_session_meta_filter(first)
}

/// Validate that the first record is a fully-formed session meta line.
///
/// Requirements:
/// - `record_type == "session_meta"`
/// - Remaining fields (after removing `record_type`) deserialize into
///   `SessionMetaWithGit`.
fn passes_session_meta_filter(first: &serde_json::Value) -> bool {
    let Some(obj) = first.as_object() else {
        return false;
    };
    let record_type = obj.get("record_type").and_then(|v| v.as_str());
    if record_type != Some("session_meta") {
        return false;
    }

    // Remove the marker field and validate the remainder matches SessionMetaWithGit
    let mut cleaned = obj.clone();
    cleaned.remove("record_type");
    let val = serde_json::Value::Object(cleaned);
    serde_json::from_value::<SessionMetaWithGit>(val).is_ok()
}
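The cursor token is just the filename timestamp and UUID joined by a pipe. A small standalone sketch of the same parse, using the `time` and `uuid` crates as the module does; this mirrors the private `parse_cursor` above rather than calling it:

use time::PrimitiveDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use uuid::Uuid;

// Mirrors the module-private parse_cursor: "<YYYY-MM-DDThh-mm-ss>|<uuid>".
fn parse_token(token: &str) -> Option<(time::OffsetDateTime, Uuid)> {
    let (file_ts, uuid_str) = token.split_once('|')?;
    let uuid = Uuid::parse_str(uuid_str).ok()?;
    let format: &[FormatItem] =
        format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
    let ts = PrimitiveDateTime::parse(file_ts, format).ok()?.assume_utc();
    Some((ts, uuid))
}

fn main() {
    let token = "2025-01-03T12-00-00|00000000-0000-0000-0000-000000000003";
    let (ts, id) = parse_token(token).expect("valid cursor token");
    assert_eq!(ts.year(), 2025);
    assert_eq!(id, Uuid::from_u128(3));
}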
15  codex-rs/core/src/rollout/mod.rs  Normal file
@@ -0,0 +1,15 @@
//! Rollout module: persistence and discovery of session rollout files.

pub(crate) const SESSIONS_SUBDIR: &str = "sessions";

pub mod list;
pub(crate) mod policy;
pub mod recorder;

pub use recorder::RolloutItem;
pub use recorder::RolloutRecorder;
pub use recorder::SessionMeta;
pub use recorder::SessionStateSnapshot;

#[cfg(test)]
pub mod tests;
57  codex-rs/core/src/rollout/policy.rs  Normal file
@@ -0,0 +1,57 @@
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;

/// Whether a `ResponseItem` should be persisted in rollout files.
#[inline]
pub(crate) fn is_persisted_response_item(item: &ResponseItem) -> bool {
    match item {
        ResponseItem::Message { .. }
        | ResponseItem::Reasoning { .. }
        | ResponseItem::LocalShellCall { .. }
        | ResponseItem::FunctionCall { .. }
        | ResponseItem::FunctionCallOutput { .. }
        | ResponseItem::CustomToolCall { .. }
        | ResponseItem::CustomToolCallOutput { .. } => true,
        ResponseItem::WebSearchCall { .. } | ResponseItem::Other => false,
    }
}

pub(crate) fn is_persisted_event(event: &Event) -> bool {
    match event.msg {
        EventMsg::ExecApprovalRequest(_)
        | EventMsg::ApplyPatchApprovalRequest(_)
        | EventMsg::AgentReasoningDelta(_)
        | EventMsg::AgentReasoningRawContentDelta(_)
        | EventMsg::ExecCommandOutputDelta(_)
        | EventMsg::GetHistoryEntryResponse(_)
        | EventMsg::AgentMessageDelta(_)
        | EventMsg::TaskStarted(_)
        | EventMsg::TaskComplete(_)
        | EventMsg::McpToolCallBegin(_)
        | EventMsg::McpToolCallEnd(_)
        | EventMsg::WebSearchBegin(_)
        | EventMsg::WebSearchEnd(_)
        | EventMsg::ExecCommandBegin(_)
        | EventMsg::ExecCommandEnd(_)
        | EventMsg::PatchApplyBegin(_)
        | EventMsg::PatchApplyEnd(_)
        | EventMsg::TurnDiff(_)
        | EventMsg::BackgroundEvent(_)
        | EventMsg::McpListToolsResponse(_)
        | EventMsg::ListCustomPromptsResponse(_)
        | EventMsg::ShutdownComplete
        | EventMsg::ConversationHistory(_)
        | EventMsg::PlanUpdate(_)
        | EventMsg::TurnAborted(_)
        | EventMsg::StreamError(_)
        | EventMsg::Error(_)
        | EventMsg::AgentReasoningSectionBreak(_)
        | EventMsg::SessionConfigured(_) => false,
        EventMsg::UserMessage(_)
        | EventMsg::AgentMessage(_)
        | EventMsg::AgentReasoning(_)
        | EventMsg::AgentReasoningRawContent(_)
        | EventMsg::TokenCount(_) => true,
    }
}
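Centralizing the persistence decision in a predicate keeps call sites to a one-liner and, because the match is exhaustive, adding a new variant forces a conscious persist/skip decision at compile time. A minimal sketch of that filtering shape, with a simplified enum standing in for `ResponseItem`:

// Simplified stand-in for ResponseItem to show predicate-based filtering.
#[derive(Debug, PartialEq)]
enum Item {
    Message(String),
    WebSearchCall,
    Other,
}

// Same shape as is_persisted_response_item: one exhaustive match.
fn is_persisted(item: &Item) -> bool {
    match item {
        Item::Message(_) => true,
        Item::WebSearchCall | Item::Other => false,
    }
}

fn main() {
    let mut items = vec![Item::Message("hi".into()), Item::WebSearchCall, Item::Other];
    items.retain(is_persisted);
    assert_eq!(items, vec![Item::Message("hi".into())]);
}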
504  codex-rs/core/src/rollout/recorder.rs  Normal file
@@ -0,0 +1,504 @@
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.

use std::fs::File;
use std::fs::{self};
use std::io::Error as IoError;
use std::path::Path;
use std::path::PathBuf;

use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use serde::Deserialize;
use serde::Serialize;
use time::OffsetDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::{self};
use tokio::sync::oneshot;
use tracing::info;
use tracing::warn;

use super::SESSIONS_SUBDIR;
use super::list::ConversationsPage;
use super::list::Cursor;
use super::list::get_conversations;
use super::policy::is_persisted_response_item;
use crate::config::Config;
use crate::conversation_manager::InitialHistory;
use crate::conversation_manager::ResumedHistory;
use crate::git_info::GitInfo;
use crate::git_info::collect_git_info;
use crate::rollout::policy::is_persisted_event;
use codex_protocol::models::ResponseItem;

#[derive(Serialize, Deserialize, Clone, Default, Debug)]
pub struct SessionMeta {
    pub id: ConversationId,
    pub timestamp: String,
    pub cwd: String,
    pub originator: String,
    pub cli_version: String,
    pub instructions: Option<String>,
}

// SessionMetaWithGit is used in writes and reads; ensure it implements Debug.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SessionMetaWithGit {
    #[serde(flatten)]
    meta: SessionMeta,
    #[serde(skip_serializing_if = "Option::is_none")]
    git: Option<GitInfo>,
}

#[derive(Serialize, Deserialize, Default, Clone, Debug)]
pub struct SessionStateSnapshot {}

#[derive(Serialize, Deserialize, Default, Clone)]
pub struct SavedSession {
    pub session: SessionMeta,
    #[serde(default)]
    pub items: Vec<ResponseItem>,
    #[serde(default)]
    pub state: SessionStateSnapshot,
    pub session_id: ConversationId,
}

/// Records all [`ResponseItem`]s for a session and flushes them to disk after
/// every update.
///
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
///
/// ```ignore
/// $ jq -C . ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// $ fx ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// ```
#[derive(Clone)]
pub struct RolloutRecorder {
    tx: Sender<RolloutCmd>,
    path: PathBuf,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "record_type", rename_all = "snake_case")]
enum TaggedLine {
    Response {
        #[serde(flatten)]
        item: ResponseItem,
    },
    Event {
        #[serde(flatten)]
        event: Event,
    },
    SessionMeta {
        #[serde(flatten)]
        meta: SessionMetaWithGit,
    },
    PrevSessionMeta {
        #[serde(flatten)]
        meta: SessionMetaWithGit,
    },
    State {
        #[serde(flatten)]
        state: SessionStateSnapshot,
    },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
struct TimestampedLine {
    timestamp: String,
    #[serde(flatten)]
    record: TaggedLine,
}

#[derive(Debug, Clone)]
pub enum RolloutItem {
    ResponseItem(ResponseItem),
    Event(Event),
    SessionMeta(SessionMetaWithGit),
}

impl From<ResponseItem> for RolloutItem {
    fn from(item: ResponseItem) -> Self {
        RolloutItem::ResponseItem(item)
    }
}

impl From<Event> for RolloutItem {
    fn from(event: Event) -> Self {
        RolloutItem::Event(event)
    }
}

/// Convenience helpers to extract typed items from a list of rollout items.
pub trait RolloutItemSliceExt {
    fn get_response_items(&self) -> Vec<ResponseItem>;
    fn get_events(&self) -> Vec<EventMsg>;
}

impl RolloutItemSliceExt for [RolloutItem] {
    fn get_response_items(&self) -> Vec<ResponseItem> {
        self.iter()
            .filter_map(|it| match it {
                RolloutItem::ResponseItem(ri) => Some(ri.clone()),
                _ => None,
            })
            .collect()
    }

    fn get_events(&self) -> Vec<EventMsg> {
        self.iter()
            .filter_map(|it| match it {
                RolloutItem::Event(ev) => Some(ev.msg.clone()),
                _ => None,
            })
            .collect()
    }
}

enum RolloutCmd {
    AddResponseItems(Vec<ResponseItem>),
    AddEvents(Vec<Event>),
    AddSessionMeta(SessionMetaWithGit),
    Flush { ack: oneshot::Sender<()> },
    Shutdown { ack: oneshot::Sender<()> },
}

impl RolloutRecorder {
    pub fn path(&self) -> &Path {
        &self.path
    }

    #[allow(dead_code)]
    /// List conversations (rollout files) under the provided Codex home directory.
    pub async fn list_conversations(
        codex_home: &Path,
        page_size: usize,
        cursor: Option<&Cursor>,
    ) -> std::io::Result<ConversationsPage> {
        get_conversations(codex_home, page_size, cursor).await
    }

    /// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
    /// cannot be created or the rollout file cannot be opened we return the
    /// error so the caller can decide whether to disable persistence.
    pub async fn new(
        config: &Config,
        conversation_id: ConversationId,
        instructions: Option<String>,
    ) -> std::io::Result<Self> {
        let LogFileInfo {
            file,
            conversation_id: session_id,
            timestamp,
            path,
        } = create_log_file(config, conversation_id)?;

        let timestamp_format: &[FormatItem] =
            format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z");
        let timestamp = timestamp
            .to_offset(time::UtcOffset::UTC)
            .format(timestamp_format)
            .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;

        let cwd = config.cwd.to_path_buf();

        let (tx, rx) = mpsc::channel(100);

        tokio::task::spawn(rollout_writer(
            tokio::fs::File::from_std(file),
            rx,
            Some(SessionMeta {
                timestamp,
                id: session_id,
                cwd: config.cwd.to_string_lossy().to_string(),
                originator: config.responses_originator_header.clone(),
                cli_version: env!("CARGO_PKG_VERSION").to_string(),
                instructions,
            }),
            cwd,
        ));

        Ok(Self { tx, path })
    }

    pub(crate) async fn record_items(&self, item: RolloutItem) -> std::io::Result<()> {
        match item {
            RolloutItem::ResponseItem(item) => self.record_response_item(&item).await,
            RolloutItem::Event(event) => self.record_event(&event).await,
            RolloutItem::SessionMeta(meta) => self.record_session_meta(&meta).await,
        }
    }

    /// Ensure all writes up to this point have been processed by the writer task.
    ///
    /// This is a sequencing barrier for readers that plan to open and read the
    /// rollout file immediately after calling this method. The background writer
    /// processes the channel serially; when it dequeues `Flush`, all prior
    /// `AddResponseItems`/`AddEvents`/`AddSessionMeta` have already been written
    /// via `write_line`, which calls `file.flush()` (OS-buffer flush).
    pub async fn flush(&self) -> std::io::Result<()> {
        let (tx_done, rx_done) = oneshot::channel();
        self.tx
            .send(RolloutCmd::Flush { ack: tx_done })
            .await
            .map_err(|e| IoError::other(format!("failed to queue rollout flush: {e}")))?;
        rx_done
            .await
            .map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
    }

    async fn record_response_item(&self, item: &ResponseItem) -> std::io::Result<()> {
        // Note that function calls may look a bit strange if they are
        // "fully qualified MCP tool calls," so we could consider
        // reformatting them in that case.
        if !is_persisted_response_item(item) {
            return Ok(());
        }
        self.tx
            .send(RolloutCmd::AddResponseItems(vec![item.clone()]))
            .await
            .map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
    }

    async fn record_event(&self, event: &Event) -> std::io::Result<()> {
        if !is_persisted_event(event) {
            return Ok(());
        }
        self.tx
            .send(RolloutCmd::AddEvents(vec![event.clone()]))
            .await
            .map_err(|e| IoError::other(format!("failed to queue rollout event: {e}")))
    }

    async fn record_session_meta(&self, meta: &SessionMetaWithGit) -> std::io::Result<()> {
        self.tx
            .send(RolloutCmd::AddSessionMeta(meta.clone()))
            .await
            .map_err(|e| IoError::other(format!("failed to queue rollout session meta: {e}")))
    }

    pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
        info!("Resuming rollout from {path:?}");
        let text = tokio::fs::read_to_string(path).await?;
        let mut lines = text.lines();
        let first_line = lines
            .next()
            .ok_or_else(|| IoError::other("empty session file"))?;
        let conversation_id = if let Ok(TimestampedLine {
            record: TaggedLine::SessionMeta { meta },
            ..
        }) = serde_json::from_str::<TimestampedLine>(first_line)
        {
            Some(meta.meta.id)
        } else if let Ok(meta) = serde_json::from_str::<SessionMetaWithGit>(first_line) {
            Some(meta.meta.id)
        } else if let Ok(meta) = serde_json::from_str::<SessionMeta>(first_line) {
            Some(meta.id)
        } else {
            return Err(IoError::other(
                "failed to parse first line of rollout file as SessionMeta",
            ));
        };

        let mut items: Vec<RolloutItem> = Vec::new();
        for line in lines {
            if line.trim().is_empty() {
                continue;
            }
            match serde_json::from_str::<TimestampedLine>(line) {
                Ok(TimestampedLine {
                    record: TaggedLine::State { .. },
                    ..
                }) => {}
                Ok(TimestampedLine {
                    record: TaggedLine::Event { event },
                    ..
                }) => items.push(RolloutItem::Event(event)),
                Ok(TimestampedLine {
                    record: TaggedLine::SessionMeta { meta },
                    ..
                })
                | Ok(TimestampedLine {
                    record: TaggedLine::PrevSessionMeta { meta },
                    ..
                }) => items.push(RolloutItem::SessionMeta(meta)),
                Ok(TimestampedLine {
                    record: TaggedLine::Response { item },
                    ..
                }) => {
                    if is_persisted_response_item(&item) {
                        items.push(RolloutItem::ResponseItem(item));
                    }
                }
                Err(_) => warn!("failed to parse rollout line: {line}"),
            }
        }

        info!(
            "Resumed rollout with {} items, conversation ID: {:?}",
            items.len(),
            conversation_id
        );
        let conversation_id = conversation_id
            .ok_or_else(|| IoError::other("failed to parse conversation ID from rollout file"))?;

        if items.is_empty() {
            return Ok(InitialHistory::New);
        }

        info!("Resumed rollout successfully from {path:?}");
        Ok(InitialHistory::Resumed(ResumedHistory {
            conversation_id,
            history: items,
            rollout_path: path.to_path_buf(),
        }))
    }

    pub async fn shutdown(&self) -> std::io::Result<()> {
        let (tx_done, rx_done) = oneshot::channel();
        match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
            Ok(_) => rx_done
                .await
                .map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
            Err(e) => {
                warn!("failed to send rollout shutdown command: {e}");
                Err(IoError::other(format!(
                    "failed to send rollout shutdown command: {e}"
                )))
            }
        }
    }
}

struct LogFileInfo {
    /// Opened file handle to the rollout file.
    file: File,

    /// Session ID (also embedded in filename).
    conversation_id: ConversationId,

    /// Timestamp for the start of the session.
    timestamp: OffsetDateTime,

    /// Full filesystem path to the rollout file.
    path: PathBuf,
}

fn create_log_file(
    config: &Config,
    conversation_id: ConversationId,
) -> std::io::Result<LogFileInfo> {
    // Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
    let timestamp = OffsetDateTime::now_utc();
    let mut dir = config.codex_home.clone();
    dir.push(SESSIONS_SUBDIR);
    dir.push(timestamp.year().to_string());
    dir.push(format!("{:02}", u8::from(timestamp.month())));
    dir.push(format!("{:02}", timestamp.day()));
    fs::create_dir_all(&dir)?;

    // Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for
    // compatibility with filesystems that do not allow colons in filenames.
    let format: &[FormatItem] =
        format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
    let date_str = timestamp
        .format(format)
        .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;

    let filename = format!("rollout-{date_str}-{conversation_id}.jsonl");

    let path = dir.join(filename);
    let file = std::fs::OpenOptions::new()
        .append(true)
        .create(true)
        .open(&path)?;

    Ok(LogFileInfo {
        file,
        conversation_id,
        timestamp,
        path,
    })
}

async fn rollout_writer(
    file: tokio::fs::File,
    mut rx: mpsc::Receiver<RolloutCmd>,
    mut meta: Option<SessionMeta>,
    cwd: std::path::PathBuf,
) -> std::io::Result<()> {
    let mut writer = JsonlWriter { file };

    // If we have a meta, collect git info asynchronously and write meta first
    if let Some(session_meta) = meta.take() {
        let git_info = collect_git_info(&cwd).await;
        let session_meta_with_git = SessionMetaWithGit {
            meta: session_meta,
            git: git_info,
        };
        // Write the SessionMeta as the first item in the file
        writer
            .write_tagged(TaggedLine::SessionMeta {
                meta: session_meta_with_git,
            })
            .await?;
    }

    // Process rollout commands
    while let Some(cmd) = rx.recv().await {
        match cmd {
            RolloutCmd::AddResponseItems(items) => {
                for item in items {
                    if is_persisted_response_item(&item) {
                        writer.write_tagged(TaggedLine::Response { item }).await?;
                    }
                }
            }
            RolloutCmd::AddEvents(events) => {
                for event in events {
                    writer.write_tagged(TaggedLine::Event { event }).await?;
                }
            }
            // Sequencing barrier: by the time we handle `Flush`, all previously
            // queued writes have been applied and flushed to OS buffers.
            RolloutCmd::Flush { ack } => {
                let _ = ack.send(());
            }
            RolloutCmd::AddSessionMeta(meta) => {
                writer
                    .write_tagged(TaggedLine::PrevSessionMeta { meta })
                    .await?;
            }
            RolloutCmd::Shutdown { ack } => {
                let _ = ack.send(());
            }
        }
    }

    Ok(())
}

struct JsonlWriter {
    file: tokio::fs::File,
}

impl JsonlWriter {
    async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
        let mut json = serde_json::to_string(item)?;
        json.push('\n');
        self.file.write_all(json.as_bytes()).await?;
        self.file.flush().await?;
        Ok(())
    }

    async fn write_tagged(&mut self, record: TaggedLine) -> std::io::Result<()> {
        let timestamp = time::OffsetDateTime::now_utc()
            .format(&time::format_description::well_known::Rfc3339)
            .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
        let line = TimestampedLine { timestamp, record };
        self.write_line(&line).await
    }
}
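Each line on disk is a `TimestampedLine` wrapping a `TaggedLine`, so `record_type` and the payload fields sit flat beside `timestamp` in one JSON object per line. A minimal sketch of that serde tag-plus-flatten layout; the `State { key }` payload here is a stand-in for the real flattened `ResponseItem`/`Event`/meta payloads:

use serde::Serialize;

#[derive(Serialize)]
#[serde(tag = "record_type", rename_all = "snake_case")]
enum Tagged {
    // Stand-in payload for illustration only.
    State { key: String },
}

#[derive(Serialize)]
struct Timestamped {
    timestamp: String,
    #[serde(flatten)]
    record: Tagged,
}

fn main() {
    let line = Timestamped {
        timestamp: "2025-05-07T17:24:21Z".into(),
        record: Tagged::State { key: "v".into() },
    };
    // Tag and payload are flattened beside the timestamp.
    assert_eq!(
        serde_json::to_string(&line).unwrap(),
        r#"{"timestamp":"2025-05-07T17:24:21Z","record_type":"state","key":"v"}"#
    );
}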
399  codex-rs/core/src/rollout/tests.rs  Normal file
@@ -0,0 +1,399 @@
|
||||
#![allow(clippy::unwrap_used, clippy::expect_used)]
|
||||
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
|
||||
use tempfile::TempDir;
|
||||
use time::OffsetDateTime;
|
||||
use time::PrimitiveDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::rollout::list::ConversationItem;
|
||||
use crate::rollout::list::ConversationsPage;
|
||||
use crate::rollout::list::Cursor;
|
||||
use crate::rollout::list::get_conversation;
|
||||
use crate::rollout::list::get_conversations;
|
||||
|
||||
fn write_session_file(
|
||||
root: &Path,
|
||||
ts_str: &str,
|
||||
uuid: Uuid,
|
||||
num_records: usize,
|
||||
) -> std::io::Result<(OffsetDateTime, Uuid)> {
|
||||
let format: &[FormatItem] =
|
||||
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
|
||||
let dt = PrimitiveDateTime::parse(ts_str, format)
|
||||
.unwrap()
|
||||
.assume_utc();
|
||||
let dir = root
|
||||
.join("sessions")
|
||||
.join(format!("{:04}", dt.year()))
|
||||
.join(format!("{:02}", u8::from(dt.month())))
|
||||
.join(format!("{:02}", dt.day()));
|
||||
fs::create_dir_all(&dir)?;
|
||||
|
||||
let filename = format!("rollout-{ts_str}-{uuid}.jsonl");
|
||||
let file_path = dir.join(filename);
|
||||
let mut file = File::create(file_path)?;
|
||||
|
||||
let meta = serde_json::json!({
|
||||
"record_type": "session_meta",
|
||||
"timestamp": ts_str,
|
||||
"id": uuid.to_string(),
|
||||
"cwd": "/",
|
||||
"originator": "test",
|
||||
"cli_version": "0.0.0",
|
||||
"instructions": null
|
||||
});
|
||||
writeln!(file, "{meta}")?;
|
||||
|
||||
for i in 0..num_records {
|
||||
let rec = serde_json::json!({
|
||||
"record_type": "response",
|
||||
"index": i
|
||||
});
|
||||
writeln!(file, "{rec}")?;
|
||||
}
|
||||
Ok((dt, uuid))
|
||||
}
|
||||
|
||||
fn expected_session_meta(ts: &str, uuid: Uuid) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"record_type": "session_meta",
|
||||
"timestamp": ts,
|
||||
"id": uuid.to_string(),
|
||||
"cwd": "/",
|
||||
"originator": "test",
|
||||
"cli_version": "0.0.0",
|
||||
"instructions": null
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_conversations_latest_first() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let home = temp.path();
|
||||
|
||||
// Fixed UUIDs for deterministic expectations
|
||||
let u1 = Uuid::from_u128(1);
|
||||
let u2 = Uuid::from_u128(2);
|
||||
let u3 = Uuid::from_u128(3);
|
||||
|
||||
// Create three sessions across three days
|
||||
write_session_file(home, "2025-01-01T12-00-00", u1, 3).unwrap();
|
||||
write_session_file(home, "2025-01-02T12-00-00", u2, 3).unwrap();
|
||||
write_session_file(home, "2025-01-03T12-00-00", u3, 3).unwrap();
|
||||
|
||||
let page = get_conversations(home, 10, None).await.unwrap();
|
||||
|
||||
// Build expected objects
|
||||
let p1 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("01")
|
||||
.join("03")
|
||||
.join(format!("rollout-2025-01-03T12-00-00-{u3}.jsonl"));
|
||||
let p2 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("01")
|
||||
.join("02")
|
||||
.join(format!("rollout-2025-01-02T12-00-00-{u2}.jsonl"));
|
||||
let p3 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("01")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-01-01T12-00-00-{u1}.jsonl"));
|
||||
|
||||
let head_3 = vec![
|
||||
expected_session_meta("2025-01-03T12-00-00", u3),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
serde_json::json!({"record_type": "response", "index": 2}),
|
||||
];
|
||||
let head_2 = vec![
|
||||
expected_session_meta("2025-01-02T12-00-00", u2),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
serde_json::json!({"record_type": "response", "index": 2}),
|
||||
];
|
||||
let head_1 = vec![
|
||||
expected_session_meta("2025-01-01T12-00-00", u1),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
serde_json::json!({"record_type": "response", "index": 2}),
|
||||
];
|
||||
    let expected_cursor: Cursor =
        serde_json::from_str(&format!("\"2025-01-01T12-00-00|{u1}\"")).unwrap();

    let expected = ConversationsPage {
        items: vec![
            ConversationItem {
                path: p1,
                head: head_3,
            },
            ConversationItem {
                path: p2,
                head: head_2,
            },
            ConversationItem {
                path: p3,
                head: head_1,
            },
        ],
        next_cursor: Some(expected_cursor),
        num_scanned_files: 3,
        reached_scan_cap: false,
    };

    assert_eq!(page, expected);
}

#[tokio::test]
async fn test_pagination_cursor() {
    let temp = TempDir::new().unwrap();
    let home = temp.path();

    // Fixed UUIDs for deterministic expectations
    let u1 = Uuid::from_u128(11);
    let u2 = Uuid::from_u128(22);
    let u3 = Uuid::from_u128(33);
    let u4 = Uuid::from_u128(44);
    let u5 = Uuid::from_u128(55);

    // Oldest to newest
    write_session_file(home, "2025-03-01T09-00-00", u1, 1).unwrap();
    write_session_file(home, "2025-03-02T09-00-00", u2, 1).unwrap();
    write_session_file(home, "2025-03-03T09-00-00", u3, 1).unwrap();
    write_session_file(home, "2025-03-04T09-00-00", u4, 1).unwrap();
    write_session_file(home, "2025-03-05T09-00-00", u5, 1).unwrap();

    let page1 = get_conversations(home, 2, None).await.unwrap();
    let p5 = home
        .join("sessions")
        .join("2025")
        .join("03")
        .join("05")
        .join(format!("rollout-2025-03-05T09-00-00-{u5}.jsonl"));
    let p4 = home
        .join("sessions")
        .join("2025")
        .join("03")
        .join("04")
        .join(format!("rollout-2025-03-04T09-00-00-{u4}.jsonl"));
    let head_5 = vec![
        expected_session_meta("2025-03-05T09-00-00", u5),
        serde_json::json!({"record_type": "response", "index": 0}),
    ];
    let head_4 = vec![
        expected_session_meta("2025-03-04T09-00-00", u4),
        serde_json::json!({"record_type": "response", "index": 0}),
    ];
    let expected_cursor1: Cursor =
        serde_json::from_str(&format!("\"2025-03-04T09-00-00|{u4}\"")).unwrap();
    let expected_page1 = ConversationsPage {
        items: vec![
            ConversationItem {
                path: p5,
                head: head_5,
            },
            ConversationItem {
                path: p4,
                head: head_4,
            },
        ],
        next_cursor: Some(expected_cursor1.clone()),
        num_scanned_files: 3, // scanned 05, 04, and peeked at 03 before breaking
        reached_scan_cap: false,
    };
    assert_eq!(page1, expected_page1);

    let page2 = get_conversations(home, 2, page1.next_cursor.as_ref())
        .await
        .unwrap();
    let p3 = home
        .join("sessions")
        .join("2025")
        .join("03")
        .join("03")
        .join(format!("rollout-2025-03-03T09-00-00-{u3}.jsonl"));
    let p2 = home
        .join("sessions")
        .join("2025")
        .join("03")
        .join("02")
        .join(format!("rollout-2025-03-02T09-00-00-{u2}.jsonl"));
    let head_3 = vec![
        expected_session_meta("2025-03-03T09-00-00", u3),
        serde_json::json!({"record_type": "response", "index": 0}),
    ];
    let head_2 = vec![
        expected_session_meta("2025-03-02T09-00-00", u2),
        serde_json::json!({"record_type": "response", "index": 0}),
    ];
    let expected_cursor2: Cursor =
        serde_json::from_str(&format!("\"2025-03-02T09-00-00|{u2}\"")).unwrap();
    let expected_page2 = ConversationsPage {
        items: vec![
            ConversationItem {
                path: p3,
                head: head_3,
            },
            ConversationItem {
                path: p2,
                head: head_2,
            },
        ],
        next_cursor: Some(expected_cursor2.clone()),
        num_scanned_files: 5, // scanned 05, 04 (anchor), 03, 02, and peeked at 01
        reached_scan_cap: false,
    };
    assert_eq!(page2, expected_page2);

    let page3 = get_conversations(home, 2, page2.next_cursor.as_ref())
        .await
        .unwrap();
    let p1 = home
        .join("sessions")
        .join("2025")
        .join("03")
        .join("01")
        .join(format!("rollout-2025-03-01T09-00-00-{u1}.jsonl"));
    let head_1 = vec![
        expected_session_meta("2025-03-01T09-00-00", u1),
        serde_json::json!({"record_type": "response", "index": 0}),
    ];
    let expected_cursor3: Cursor =
        serde_json::from_str(&format!("\"2025-03-01T09-00-00|{u1}\"")).unwrap();
    let expected_page3 = ConversationsPage {
        items: vec![ConversationItem {
            path: p1,
            head: head_1,
        }],
        next_cursor: Some(expected_cursor3.clone()),
        num_scanned_files: 5, // scanned 05, 04 (anchor), 03, 02 (anchor), 01
        reached_scan_cap: false,
    };
    assert_eq!(page3, expected_page3);
}
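
A detail these tests pin down: a page cursor serializes as a single JSON string of the form "<file timestamp>|<uuid>", which is why the expectations build cursors with serde_json::from_str. A minimal standalone sketch of that round trip (illustrative only, not the crate's real Cursor type; the UUID is just Uuid::from_u128(44) written out):

fn main() {
    let ts = "2025-03-04T09-00-00";
    let id = "00000000-0000-0000-0000-00000000002c";
    // Encoding: the cursor is one JSON string joining timestamp and UUID with '|'.
    let encoded = serde_json::to_string(&format!("{ts}|{id}")).unwrap();
    assert_eq!(encoded, format!("\"{ts}|{id}\""));
    // Decoding: split on '|' to recover the anchor used to resume the scan.
    let decoded: String = serde_json::from_str(&encoded).unwrap();
    let (anchor_ts, anchor_id) = decoded.split_once('|').unwrap();
    assert_eq!((anchor_ts, anchor_id), (ts, id));
}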

#[tokio::test]
async fn test_get_conversation_contents() {
    let temp = TempDir::new().unwrap();
    let home = temp.path();

    let uuid = Uuid::new_v4();
    let ts = "2025-04-01T10-30-00";
    write_session_file(home, ts, uuid, 2).unwrap();

    let page = get_conversations(home, 1, None).await.unwrap();
    let path = &page.items[0].path;

    let content = get_conversation(path).await.unwrap();

    // Page equality (single item)
    let expected_path = home
        .join("sessions")
        .join("2025")
        .join("04")
        .join("01")
        .join(format!("rollout-2025-04-01T10-30-00-{uuid}.jsonl"));
    let expected_head = vec![
        expected_session_meta(ts, uuid),
        serde_json::json!({"record_type": "response", "index": 0}),
        serde_json::json!({"record_type": "response", "index": 1}),
    ];
    let expected_cursor: Cursor = serde_json::from_str(&format!("\"{ts}|{uuid}\"")).unwrap();
    let expected_page = ConversationsPage {
        items: vec![ConversationItem {
            path: expected_path.clone(),
            head: expected_head,
        }],
        next_cursor: Some(expected_cursor),
        num_scanned_files: 1,
        reached_scan_cap: false,
    };
    assert_eq!(page, expected_page);

    // Entire file contents equality
    let meta = expected_session_meta(ts, uuid);
    let rec0 = serde_json::json!({"record_type": "response", "index": 0});
    let rec1 = serde_json::json!({"record_type": "response", "index": 1});
    let expected_content = format!("{meta}\n{rec0}\n{rec1}\n");
    assert_eq!(content, expected_content);
}

#[tokio::test]
async fn test_stable_ordering_same_second_pagination() {
    let temp = TempDir::new().unwrap();
    let home = temp.path();

    let ts = "2025-07-01T00-00-00";
    let u1 = Uuid::from_u128(1);
    let u2 = Uuid::from_u128(2);
    let u3 = Uuid::from_u128(3);

    write_session_file(home, ts, u1, 0).unwrap();
    write_session_file(home, ts, u2, 0).unwrap();
    write_session_file(home, ts, u3, 0).unwrap();

    let page1 = get_conversations(home, 2, None).await.unwrap();

    let p3 = home
        .join("sessions")
        .join("2025")
        .join("07")
        .join("01")
        .join(format!("rollout-2025-07-01T00-00-00-{u3}.jsonl"));
    let p2 = home
        .join("sessions")
        .join("2025")
        .join("07")
        .join("01")
        .join(format!("rollout-2025-07-01T00-00-00-{u2}.jsonl"));
    let head = |u: Uuid| -> Vec<serde_json::Value> { vec![expected_session_meta(ts, u)] };
    let expected_cursor1: Cursor = serde_json::from_str(&format!("\"{ts}|{u2}\"")).unwrap();
    let expected_page1 = ConversationsPage {
        items: vec![
            ConversationItem {
                path: p3,
                head: head(u3),
            },
            ConversationItem {
                path: p2,
                head: head(u2),
            },
        ],
        next_cursor: Some(expected_cursor1.clone()),
        num_scanned_files: 3, // scanned u3, u2, peeked u1
        reached_scan_cap: false,
    };
    assert_eq!(page1, expected_page1);

    let page2 = get_conversations(home, 2, page1.next_cursor.as_ref())
        .await
        .unwrap();
    let p1 = home
        .join("sessions")
        .join("2025")
        .join("07")
        .join("01")
        .join(format!("rollout-2025-07-01T00-00-00-{u1}.jsonl"));
    let expected_cursor2: Cursor = serde_json::from_str(&format!("\"{ts}|{u1}\"")).unwrap();
    let expected_page2 = ConversationsPage {
        items: vec![ConversationItem {
            path: p1,
            head: head(u1),
        }],
        next_cursor: Some(expected_cursor2.clone()),
        num_scanned_files: 3, // scanned u3, u2 (anchor), u1
        reached_scan_cap: false,
    };
    assert_eq!(page2, expected_page2);
}

@@ -53,6 +53,13 @@ pub fn assess_patch_safety(
        // paths outside the project.
        match get_platform_sandbox() {
            Some(sandbox_type) => SafetyCheck::AutoApprove { sandbox_type },
            None if sandbox_policy == &SandboxPolicy::DangerFullAccess => {
                // If the user has explicitly requested DangerFullAccess, then
                // we can auto-approve even without a sandbox.
                SafetyCheck::AutoApprove {
                    sandbox_type: SandboxType::None,
                }
            }
            None => SafetyCheck::AskUser,
        }
    } else if policy == AskForApproval::Never {

@@ -9,6 +9,12 @@ pub struct ZshShell {
    zshrc_path: String,
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct BashShell {
    shell_path: String,
    bashrc_path: String,
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct PowerShellConfig {
    exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".

@@ -18,6 +24,7 @@ pub struct PowerShellConfig {
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Shell {
    Zsh(ZshShell),
    Bash(BashShell),
    PowerShell(PowerShellConfig),
    Unknown,
}

@@ -26,22 +33,10 @@ impl Shell {
    pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
        match self {
            Shell::Zsh(zsh) => {
                if !std::path::Path::new(&zsh.zshrc_path).exists() {
                    return None;
                }

                let mut result = vec![zsh.shell_path.clone()];
                result.push("-lc".to_string());

                let joined = strip_bash_lc(&command)
                    .or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok());

                if let Some(joined) = joined {
                    result.push(format!("source {} && ({joined})", zsh.zshrc_path));
                } else {
                    return None;
                }
                Some(result)
                format_shell_invocation_with_rc(&command, &zsh.shell_path, &zsh.zshrc_path)
            }
            Shell::Bash(bash) => {
                format_shell_invocation_with_rc(&command, &bash.shell_path, &bash.bashrc_path)
            }
            Shell::PowerShell(ps) => {
                // If model generated a bash command, prefer a detected bash fallback

@@ -97,12 +92,32 @@ impl Shell {
            Shell::Zsh(zsh) => std::path::Path::new(&zsh.shell_path)
                .file_name()
                .map(|s| s.to_string_lossy().to_string()),
            Shell::Bash(bash) => std::path::Path::new(&bash.shell_path)
                .file_name()
                .map(|s| s.to_string_lossy().to_string()),
            Shell::PowerShell(ps) => Some(ps.exe.clone()),
            Shell::Unknown => None,
        }
    }
}

fn format_shell_invocation_with_rc(
    command: &Vec<String>,
    shell_path: &str,
    rc_path: &str,
) -> Option<Vec<String>> {
    let joined = strip_bash_lc(command)
        .or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok())?;

    let rc_command = if std::path::Path::new(rc_path).exists() {
        format!("source {rc_path} && ({joined})")
    } else {
        joined
    };

    Some(vec![shell_path.to_string(), "-lc".to_string(), rc_command])
}
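
To make the helper's behavior concrete, here is what format_shell_invocation_with_rc produces for a typical input (paths illustrative; this mirrors the test expectations further down):

// Input command, as the model would emit it:
//   ["bash", "-lc", "echo hi"]
// strip_bash_lc unwraps the inner script, so with an existing rc file the
// result sources it before running the command in a subshell:
//   ["/bin/bash", "-lc", "source /home/user/.bashrc && (echo hi)"]
// If the rc file does not exist, the joined command passes through as-is:
//   ["/bin/bash", "-lc", "echo hi"]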

fn strip_bash_lc(command: &Vec<String>) -> Option<String> {
    match command.as_slice() {
        // exactly three items

@@ -116,44 +131,43 @@ fn strip_bash_lc(command: &Vec<String>) -> Option<String> {
    }
}

#[cfg(target_os = "macos")]
pub async fn default_user_shell() -> Shell {
    use tokio::process::Command;
    use whoami;
#[cfg(unix)]
fn detect_default_user_shell() -> Shell {
    use libc::getpwuid;
    use libc::getuid;
    use std::ffi::CStr;

    let user = whoami::username();
    let home = format!("/Users/{user}");
    let output = Command::new("dscl")
        .args([".", "-read", &home, "UserShell"])
        .output()
        .await
        .ok();
    match output {
        Some(o) => {
            if !o.status.success() {
                return Shell::Unknown;
            }
            let stdout = String::from_utf8_lossy(&o.stdout);
            for line in stdout.lines() {
                if let Some(shell_path) = line.strip_prefix("UserShell: ")
                    && shell_path.ends_with("/zsh")
                {
                    return Shell::Zsh(ZshShell {
                        shell_path: shell_path.to_string(),
                        zshrc_path: format!("{home}/.zshrc"),
                    });
                }
    unsafe {
        let uid = getuid();
        let pw = getpwuid(uid);

        if !pw.is_null() {
            let shell_path = CStr::from_ptr((*pw).pw_shell)
                .to_string_lossy()
                .into_owned();
            let home_path = CStr::from_ptr((*pw).pw_dir).to_string_lossy().into_owned();

            if shell_path.ends_with("/zsh") {
                return Shell::Zsh(ZshShell {
                    shell_path,
                    zshrc_path: format!("{home_path}/.zshrc"),
                });
            }

            Shell::Unknown
            if shell_path.ends_with("/bash") {
                return Shell::Bash(BashShell {
                    shell_path,
                    bashrc_path: format!("{home_path}/.bashrc"),
                });
            }
        }
        _ => Shell::Unknown,
    }
    Shell::Unknown
}

#[cfg(all(not(target_os = "macos"), not(target_os = "windows")))]
#[cfg(unix)]
pub async fn default_user_shell() -> Shell {
    Shell::Unknown
    detect_default_user_shell()
}

#[cfg(target_os = "windows")]
@@ -196,8 +210,13 @@ pub async fn default_user_shell() -> Shell {
    }
}

#[cfg(all(not(target_os = "windows"), not(unix)))]
pub async fn default_user_shell() -> Shell {
    Shell::Unknown
}

#[cfg(test)]
#[cfg(target_os = "macos")]
#[cfg(unix)]
mod tests {
    use super::*;
    use std::process::Command;

@@ -230,9 +249,127 @@ mod tests {
            zshrc_path: "/does/not/exist/.zshrc".to_string(),
        });
        let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
        assert_eq!(actual_cmd, None);
        assert_eq!(
            actual_cmd,
            Some(vec![
                "/bin/zsh".to_string(),
                "-lc".to_string(),
                "myecho".to_string()
            ])
        );
    }

    #[tokio::test]
    async fn test_run_with_profile_bashrc_not_exists() {
        let shell = Shell::Bash(BashShell {
            shell_path: "/bin/bash".to_string(),
            bashrc_path: "/does/not/exist/.bashrc".to_string(),
        });
        let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
        assert_eq!(
            actual_cmd,
            Some(vec![
                "/bin/bash".to_string(),
                "-lc".to_string(),
                "myecho".to_string()
            ])
        );
    }

    #[tokio::test]
    async fn test_run_with_profile_bash_escaping_and_execution() {
        let shell_path = "/bin/bash";

        let cases = vec![
            (
                vec!["myecho"],
                vec![shell_path, "-lc", "source BASHRC_PATH && (myecho)"],
                Some("It works!\n"),
            ),
            (
                vec!["bash", "-lc", "echo 'single' \"double\""],
                vec![
                    shell_path,
                    "-lc",
                    "source BASHRC_PATH && (echo 'single' \"double\")",
                ],
                Some("single double\n"),
            ),
        ];

        for (input, expected_cmd, expected_output) in cases {
            use std::collections::HashMap;

            use crate::exec::ExecParams;
            use crate::exec::SandboxType;
            use crate::exec::process_exec_tool_call;
            use crate::protocol::SandboxPolicy;

            let temp_home = tempfile::tempdir().unwrap();
            let bashrc_path = temp_home.path().join(".bashrc");
            std::fs::write(
                &bashrc_path,
                r#"
set -x
function myecho {
  echo 'It works!'
}
"#,
            )
            .unwrap();
            let shell = Shell::Bash(BashShell {
                shell_path: shell_path.to_string(),
                bashrc_path: bashrc_path.to_str().unwrap().to_string(),
            });

            let actual_cmd = shell
                .format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
            let expected_cmd = expected_cmd
                .iter()
                .map(|s| {
                    s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap())
                        .to_string()
                })
                .collect();

            assert_eq!(actual_cmd, Some(expected_cmd));

            let output = process_exec_tool_call(
                ExecParams {
                    command: actual_cmd.unwrap(),
                    cwd: PathBuf::from(temp_home.path()),
                    timeout_ms: None,
                    env: HashMap::from([(
                        "HOME".to_string(),
                        temp_home.path().to_str().unwrap().to_string(),
                    )]),
                    with_escalated_permissions: None,
                    justification: None,
                },
                SandboxType::None,
                &SandboxPolicy::DangerFullAccess,
                &None,
                None,
            )
            .await
            .unwrap();

            assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
            if let Some(expected) = expected_output {
                assert_eq!(
                    output.stdout.text, expected,
                    "input: {input:?} output: {output:?}"
                );
            }
        }
    }
}

#[cfg(test)]
#[cfg(target_os = "macos")]
mod macos_tests {
    use super::*;

    #[tokio::test]
    async fn test_run_with_profile_escaping_and_execution() {
        let shell_path = "/bin/zsh";

@@ -3,7 +3,7 @@ use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;

use crate::AuthMode;
use codex_protocol::mcp_protocol::AuthMode;

#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
pub struct TokenData {

@@ -58,7 +58,7 @@ pub struct IdTokenInfo {
    pub email: Option<String>,
    /// The ChatGPT subscription plan type
    /// (e.g., "free", "plus", "pro", "business", "enterprise", "edu").
    /// (Note: ae has not verified that those are the exact values.)
    /// (Note: values may vary by backend.)
    pub(crate) chatgpt_plan_type: Option<PlanType>,
    pub raw_jwt: String,
}

@@ -137,7 +137,7 @@ pub enum IdTokenInfoError {
    Json(#[from] serde_json::Error),
}

pub(crate) fn parse_id_token(id_token: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
pub fn parse_id_token(id_token: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
    // JWT format: header.payload.signature
    let mut parts = id_token.split('.');
    let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {

@@ -204,9 +204,33 @@ mod tests {

        let info = parse_id_token(&fake_jwt).expect("should parse");
        assert_eq!(info.email.as_deref(), Some("user@example.com"));
        assert_eq!(
            info.chatgpt_plan_type,
            Some(PlanType::Known(KnownPlan::Pro))
        );
        assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro"));
    }

    #[test]
    fn id_token_info_handles_missing_fields() {
        #[derive(Serialize)]
        struct Header {
            alg: &'static str,
            typ: &'static str,
        }
        let header = Header {
            alg: "none",
            typ: "JWT",
        };
        let payload = serde_json::json!({ "sub": "123" });

        fn b64url_no_pad(bytes: &[u8]) -> String {
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
        }

        let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
        let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
        let signature_b64 = b64url_no_pad(b"sig");
        let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");

        let info = parse_id_token(&fake_jwt).expect("should parse");
        assert!(info.email.is_none());
        assert!(info.get_chatgpt_plan_type().is_none());
    }
}

19
codex-rs/core/src/tool_apply_patch.lark
Normal file
@@ -0,0 +1,19 @@
start: begin_patch hunk+ end_patch
begin_patch: "*** Begin Patch" LF
end_patch: "*** End Patch" LF?

hunk: add_hunk | delete_hunk | update_hunk
add_hunk: "*** Add File: " filename LF add_line+
delete_hunk: "*** Delete File: " filename LF
update_hunk: "*** Update File: " filename LF change_move? change?

filename: /(.+)/
add_line: "+" /(.*)/ LF -> line

change_move: "*** Move to: " filename LF
change: (change_context | change_line)+ eof_line?
change_context: ("@@" | "@@ " /(.+)/) LF
change_line: ("+" | "-" | " ") /(.*)/ LF
eof_line: "*** End of File" LF

%import common.LF
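
For reference, a minimal patch that this grammar accepts (file names and contents illustrative):

*** Begin Patch
*** Add File: hello.txt
+hello world
*** Update File: src/lib.rs
@@
-old line
+new line
*** End Patch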

@@ -8,6 +8,8 @@ use crate::openai_tools::JsonSchema;
use crate::openai_tools::OpenAiTool;
use crate::openai_tools::ResponsesApiTool;

const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark");

#[derive(Serialize, Deserialize)]
pub(crate) struct ApplyPatchToolArgs {
    pub(crate) input: String,

@@ -29,27 +31,7 @@ pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool {
        format: FreeformToolFormat {
            r#type: "grammar".to_string(),
            syntax: "lark".to_string(),
            definition: r#"start: begin_patch hunk+ end_patch
begin_patch: "*** Begin Patch" LF
end_patch: "*** End Patch" LF?

hunk: add_hunk | delete_hunk | update_hunk
add_hunk: "*** Add File: " filename LF add_line+
delete_hunk: "*** Delete File: " filename LF
update_hunk: "*** Update File: " filename LF change_move? change?

filename: /(.+)/
add_line: "+" /(.+)/ LF -> line

change_move: "*** Move to: " filename LF
change: (change_context | change_line)+ eof_line?
change_context: ("@@" | "@@ " /(.+)/) LF
change_line: ("+" | "-" | " ") /(.+)/ LF
eof_line: "*** End of File" LF

%import common.LF
"#
            .to_string(),
            definition: APPLY_PATCH_LARK_GRAMMAR.to_string(),
        },
    })
}

@@ -1,37 +0,0 @@
const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";

pub fn get_codex_user_agent(originator: Option<&str>) -> String {
    let build_version = env!("CARGO_PKG_VERSION");
    let os_info = os_info::get();
    format!(
        "{}/{build_version} ({} {}; {}) {}",
        originator.unwrap_or(DEFAULT_ORIGINATOR),
        os_info.os_type(),
        os_info.version(),
        os_info.architecture().unwrap_or("unknown"),
        crate::terminal::user_agent()
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_codex_user_agent() {
        let user_agent = get_codex_user_agent(None);
        assert!(user_agent.starts_with("codex_cli_rs/"));
    }

    #[test]
    #[cfg(target_os = "macos")]
    fn test_macos() {
        use regex_lite::Regex;
        let user_agent = get_codex_user_agent(None);
        let re = Regex::new(
            r"^codex_cli_rs/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$",
        )
        .unwrap();
        assert!(re.is_match(&user_agent));
    }
}

42
codex-rs/core/src/user_instructions.rs
Normal file
@@ -0,0 +1,42 @@
use serde::Deserialize;
use serde::Serialize;

use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::USER_INSTRUCTIONS_CLOSE_TAG;
use codex_protocol::protocol::USER_INSTRUCTIONS_OPEN_TAG;

/// Wraps user instructions in a tag so the model can classify them easily.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename = "user_instructions", rename_all = "snake_case")]
pub(crate) struct UserInstructions {
    text: String,
}

impl UserInstructions {
    pub fn new<T: Into<String>>(text: T) -> Self {
        Self { text: text.into() }
    }

    /// Serializes the user instructions to an XML-like tagged block that starts
    /// with <user_instructions> so clients can classify it.
    pub fn serialize_to_xml(self) -> String {
        format!(
            "{USER_INSTRUCTIONS_OPEN_TAG}\n\n{}\n\n{USER_INSTRUCTIONS_CLOSE_TAG}",
            self.text
        )
    }
}

impl From<UserInstructions> for ResponseItem {
    fn from(ui: UserInstructions) -> Self {
        ResponseItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![ContentItem::InputText {
                text: ui.serialize_to_xml(),
            }],
        }
    }
}
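
A quick usage sketch for the type above (not part of the diff; assumes the tag constants expand to <user_instructions> and </user_instructions>, as the doc comment suggests):

// Converting UserInstructions into a ResponseItem yields a user-role message
// whose text is wrapped in the tag pair, with blank lines around the body:
let item: ResponseItem = UserInstructions::new("be concise").into();
// item's text reads:
//   <user_instructions>
//
//   be concise
//
//   </user_instructions>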

345
codex-rs/core/tests/chat_completions_payload.rs
Normal file
@@ -0,0 +1,345 @@
use std::sync::Arc;

use codex_core::ContentItem;
use codex_core::LocalShellAction;
use codex_core::LocalShellExecAction;
use codex_core::LocalShellStatus;
use codex_core::ModelClient;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ReasoningItemContent;
use codex_core::ResponseItem;
use codex_core::WireApi;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_protocol::mcp_protocol::ConversationId;
use core_test_support::load_default_config_for_test;
use futures::StreamExt;
use serde_json::Value;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

fn network_disabled() -> bool {
    std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok()
}

async fn run_request(input: Vec<ResponseItem>) -> Value {
    let server = MockServer::start().await;

    let template = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(
            "data: {\"choices\":[{\"delta\":{}}]}\n\ndata: [DONE]\n\n",
            "text/event-stream",
        );

    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(template)
        .expect(1)
        .mount(&server)
        .await;

    let provider = ModelProviderInfo {
        name: "mock".into(),
        base_url: Some(format!("{}/v1", server.uri())),
        env_key: None,
        env_key_instructions: None,
        wire_api: WireApi::Chat,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: Some(0),
        stream_max_retries: Some(0),
        stream_idle_timeout_ms: Some(5_000),
        requires_openai_auth: false,
    };

    let codex_home = match TempDir::new() {
        Ok(dir) => dir,
        Err(e) => panic!("failed to create TempDir: {e}"),
    };
    let mut config = load_default_config_for_test(&codex_home);
    config.model_provider_id = provider.name.clone();
    config.model_provider = provider.clone();
    config.show_raw_agent_reasoning = true;
    let effort = config.model_reasoning_effort;
    let summary = config.model_reasoning_summary;
    let config = Arc::new(config);

    let client = ModelClient::new(
        Arc::clone(&config),
        None,
        provider,
        effort,
        summary,
        ConversationId::new(),
    );

    let mut prompt = Prompt::default();
    prompt.input = input;

    let mut stream = match client.stream(&prompt).await {
        Ok(s) => s,
        Err(e) => panic!("stream chat failed: {e}"),
    };
    while let Some(event) = stream.next().await {
        if let Err(e) = event {
            panic!("stream event error: {e}");
        }
    }

    let requests = match server.received_requests().await {
        Some(reqs) => reqs,
        None => panic!("request not made"),
    };
    match requests[0].body_json() {
        Ok(v) => v,
        Err(e) => panic!("invalid json body: {e}"),
    }
}

fn user_message(text: &str) -> ResponseItem {
    ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputText {
            text: text.to_string(),
        }],
    }
}

fn assistant_message(text: &str) -> ResponseItem {
    ResponseItem::Message {
        id: None,
        role: "assistant".to_string(),
        content: vec![ContentItem::OutputText {
            text: text.to_string(),
        }],
    }
}

fn reasoning_item(text: &str) -> ResponseItem {
    ResponseItem::Reasoning {
        id: String::new(),
        summary: Vec::new(),
        content: Some(vec![ReasoningItemContent::ReasoningText {
            text: text.to_string(),
        }]),
        encrypted_content: None,
    }
}

fn function_call() -> ResponseItem {
    ResponseItem::FunctionCall {
        id: None,
        name: "f".to_string(),
        arguments: "{}".to_string(),
        call_id: "c1".to_string(),
    }
}

fn local_shell_call() -> ResponseItem {
    ResponseItem::LocalShellCall {
        id: Some("id1".to_string()),
        call_id: None,
        status: LocalShellStatus::InProgress,
        action: LocalShellAction::Exec(LocalShellExecAction {
            command: vec!["echo".to_string()],
            timeout_ms: Some(1_000),
            working_directory: None,
            env: None,
            user: None,
        }),
    }
}

fn messages_from(body: &Value) -> Vec<Value> {
    match body["messages"].as_array() {
        Some(arr) => arr.clone(),
        None => panic!("messages array missing"),
    }
}

fn first_assistant(messages: &[Value]) -> &Value {
    match messages.iter().find(|msg| msg["role"] == "assistant") {
        Some(v) => v,
        None => panic!("assistant message not present"),
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn omits_reasoning_when_none_present() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let body = run_request(vec![user_message("u1"), assistant_message("a1")]).await;
    let messages = messages_from(&body);
    let assistant = first_assistant(&messages);

    assert_eq!(assistant["content"], Value::String("a1".into()));
    assert!(assistant.get("reasoning").is_none());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn attaches_reasoning_to_previous_assistant() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let body = run_request(vec![
        user_message("u1"),
        assistant_message("a1"),
        reasoning_item("rA"),
    ])
    .await;
    let messages = messages_from(&body);
    let assistant = first_assistant(&messages);

    assert_eq!(assistant["content"], Value::String("a1".into()));
    assert_eq!(assistant["reasoning"], Value::String("rA".into()));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn attaches_reasoning_to_function_call_anchor() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let body = run_request(vec![
        user_message("u1"),
        reasoning_item("rFunc"),
        function_call(),
    ])
    .await;
    let messages = messages_from(&body);
    let assistant = first_assistant(&messages);

    assert_eq!(assistant["reasoning"], Value::String("rFunc".into()));
    let tool_calls = match assistant["tool_calls"].as_array() {
        Some(arr) => arr,
        None => panic!("tool call list missing"),
    };
    assert_eq!(tool_calls.len(), 1);
    assert_eq!(tool_calls[0]["type"], Value::String("function".into()));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn attaches_reasoning_to_local_shell_call() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let body = run_request(vec![
        user_message("u1"),
        reasoning_item("rShell"),
        local_shell_call(),
    ])
    .await;
    let messages = messages_from(&body);
    let assistant = first_assistant(&messages);

    assert_eq!(assistant["reasoning"], Value::String("rShell".into()));
    assert_eq!(
        assistant["tool_calls"][0]["type"],
        Value::String("local_shell_call".into())
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn drops_reasoning_when_last_role_is_user() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let body = run_request(vec![
        assistant_message("aPrev"),
        reasoning_item("rHist"),
        user_message("uNew"),
    ])
    .await;
    let messages = messages_from(&body);
    assert!(messages.iter().all(|msg| msg.get("reasoning").is_none()));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn ignores_reasoning_before_last_user() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let body = run_request(vec![
        user_message("u1"),
        assistant_message("a1"),
        user_message("u2"),
        reasoning_item("rAfterU1"),
    ])
    .await;
    let messages = messages_from(&body);
    assert!(messages.iter().all(|msg| msg.get("reasoning").is_none()));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn skips_empty_reasoning_segments() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let body = run_request(vec![
        user_message("u1"),
        assistant_message("a1"),
        reasoning_item(""),
        reasoning_item(" "),
    ])
    .await;
    let messages = messages_from(&body);
    let assistant = first_assistant(&messages);
    assert!(assistant.get("reasoning").is_none());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn suppresses_duplicate_assistant_messages() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let body = run_request(vec![assistant_message("dup"), assistant_message("dup")]).await;
    let messages = messages_from(&body);
    let assistant_messages: Vec<_> = messages
        .iter()
        .filter(|msg| msg["role"] == "assistant")
        .collect();
    assert_eq!(assistant_messages.len(), 1);
    assert_eq!(
        assistant_messages[0]["content"],
        Value::String("dup".into())
    );
}
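
Taken together, these tests pin down the Chat Completions payload shape: reasoning rides along as a plain string field on the assistant message that anchors it, and tool calls sit in a tool_calls array on the same message. A sketch of the request body inferred from the assertions above (illustrative combination; unrelated fields omitted):

{
  "messages": [
    { "role": "user", "content": "u1" },
    {
      "role": "assistant",
      "content": "a1",
      "reasoning": "rA",
      "tool_calls": [ { "type": "function" } ]
    }
  ]
}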

320
codex-rs/core/tests/chat_completions_sse.rs
Normal file
@@ -0,0 +1,320 @@
use std::sync::Arc;

use codex_core::ContentItem;
use codex_core::ModelClient;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::WireApi;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_protocol::mcp_protocol::ConversationId;
use core_test_support::load_default_config_for_test;
use futures::StreamExt;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

fn network_disabled() -> bool {
    std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok()
}

async fn run_stream(sse_body: &str) -> Vec<ResponseEvent> {
    let server = MockServer::start().await;

    let template = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse_body.to_string(), "text/event-stream");

    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(template)
        .expect(1)
        .mount(&server)
        .await;

    let provider = ModelProviderInfo {
        name: "mock".into(),
        base_url: Some(format!("{}/v1", server.uri())),
        env_key: None,
        env_key_instructions: None,
        wire_api: WireApi::Chat,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: Some(0),
        stream_max_retries: Some(0),
        stream_idle_timeout_ms: Some(5_000),
        requires_openai_auth: false,
    };

    let codex_home = match TempDir::new() {
        Ok(dir) => dir,
        Err(e) => panic!("failed to create TempDir: {e}"),
    };
    let mut config = load_default_config_for_test(&codex_home);
    config.model_provider_id = provider.name.clone();
    config.model_provider = provider.clone();
    config.show_raw_agent_reasoning = true;
    let effort = config.model_reasoning_effort;
    let summary = config.model_reasoning_summary;
    let config = Arc::new(config);

    let client = ModelClient::new(
        Arc::clone(&config),
        None,
        provider,
        effort,
        summary,
        ConversationId::new(),
    );

    let mut prompt = Prompt::default();
    prompt.input = vec![ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputText {
            text: "hello".to_string(),
        }],
    }];

    let mut stream = match client.stream(&prompt).await {
        Ok(s) => s,
        Err(e) => panic!("stream chat failed: {e}"),
    };
    let mut events = Vec::new();
    while let Some(event) = stream.next().await {
        match event {
            Ok(ev) => events.push(ev),
            Err(e) => panic!("stream event error: {e}"),
        }
    }
    events
}

fn assert_message(item: &ResponseItem, expected: &str) {
    if let ResponseItem::Message { content, .. } = item {
        let text = content.iter().find_map(|part| match part {
            ContentItem::OutputText { text } | ContentItem::InputText { text } => Some(text),
            _ => None,
        });
        let Some(text) = text else {
            panic!("message missing text: {item:?}");
        };
        assert_eq!(text, expected);
    } else {
        panic!("expected message item, got: {item:?}");
    }
}

fn assert_reasoning(item: &ResponseItem, expected: &str) {
    if let ResponseItem::Reasoning {
        content: Some(parts),
        ..
    } = item
    {
        let mut combined = String::new();
        for part in parts {
            match part {
                codex_core::ReasoningItemContent::ReasoningText { text }
                | codex_core::ReasoningItemContent::Text { text } => combined.push_str(text),
            }
        }
        assert_eq!(combined, expected);
    } else {
        panic!("expected reasoning item, got: {item:?}");
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn streams_text_without_reasoning() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let sse = concat!(
        "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{}}]}\n\n",
        "data: [DONE]\n\n",
    );

    let events = run_stream(sse).await;
    assert_eq!(events.len(), 3, "unexpected events: {events:?}");

    match &events[0] {
        ResponseEvent::OutputTextDelta(text) => assert_eq!(text, "hi"),
        other => panic!("expected text delta, got {other:?}"),
    }

    match &events[1] {
        ResponseEvent::OutputItemDone(item) => assert_message(item, "hi"),
        other => panic!("expected terminal message, got {other:?}"),
    }

    assert!(matches!(events[2], ResponseEvent::Completed { .. }));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn streams_reasoning_from_string_delta() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let sse = concat!(
        "data: {\"choices\":[{\"delta\":{\"reasoning\":\"think1\"}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{} ,\"finish_reason\":\"stop\"}]}\n\n",
    );

    let events = run_stream(sse).await;
    assert_eq!(events.len(), 5, "unexpected events: {events:?}");

    match &events[0] {
        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "think1"),
        other => panic!("expected reasoning delta, got {other:?}"),
    }

    match &events[1] {
        ResponseEvent::OutputTextDelta(text) => assert_eq!(text, "ok"),
        other => panic!("expected text delta, got {other:?}"),
    }

    match &events[2] {
        ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "think1"),
        other => panic!("expected reasoning item, got {other:?}"),
    }

    match &events[3] {
        ResponseEvent::OutputItemDone(item) => assert_message(item, "ok"),
        other => panic!("expected message item, got {other:?}"),
    }

    assert!(matches!(events[4], ResponseEvent::Completed { .. }));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn streams_reasoning_from_object_delta() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let sse = concat!(
        "data: {\"choices\":[{\"delta\":{\"reasoning\":{\"text\":\"partA\"}}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{\"reasoning\":{\"content\":\"partB\"}}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{\"content\":\"answer\"}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{} ,\"finish_reason\":\"stop\"}]}\n\n",
    );

    let events = run_stream(sse).await;
    assert_eq!(events.len(), 6, "unexpected events: {events:?}");

    match &events[0] {
        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "partA"),
        other => panic!("expected reasoning delta, got {other:?}"),
    }

    match &events[1] {
        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "partB"),
        other => panic!("expected reasoning delta, got {other:?}"),
    }

    match &events[2] {
        ResponseEvent::OutputTextDelta(text) => assert_eq!(text, "answer"),
        other => panic!("expected text delta, got {other:?}"),
    }

    match &events[3] {
        ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "partApartB"),
        other => panic!("expected reasoning item, got {other:?}"),
    }

    match &events[4] {
        ResponseEvent::OutputItemDone(item) => assert_message(item, "answer"),
        other => panic!("expected message item, got {other:?}"),
    }

    assert!(matches!(events[5], ResponseEvent::Completed { .. }));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn streams_reasoning_from_final_message() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let sse = "data: {\"choices\":[{\"message\":{\"reasoning\":\"final-cot\"},\"finish_reason\":\"stop\"}]}\n\n";

    let events = run_stream(sse).await;
    assert_eq!(events.len(), 3, "unexpected events: {events:?}");

    match &events[0] {
        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "final-cot"),
        other => panic!("expected reasoning delta, got {other:?}"),
    }

    match &events[1] {
        ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "final-cot"),
        other => panic!("expected reasoning item, got {other:?}"),
    }

    assert!(matches!(events[2], ResponseEvent::Completed { .. }));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn streams_reasoning_before_tool_call() {
    if network_disabled() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let sse = concat!(
        "data: {\"choices\":[{\"delta\":{\"reasoning\":\"pre-tool\"}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"run\",\"arguments\":\"{}\"}}]},\"finish_reason\":\"tool_calls\"}]}\n\n",
    );

    let events = run_stream(sse).await;
    assert_eq!(events.len(), 4, "unexpected events: {events:?}");

    match &events[0] {
        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "pre-tool"),
        other => panic!("expected reasoning delta, got {other:?}"),
    }

    match &events[1] {
        ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "pre-tool"),
        other => panic!("expected reasoning item, got {other:?}"),
    }

    match &events[2] {
        ResponseEvent::OutputItemDone(ResponseItem::FunctionCall {
            name,
            arguments,
            call_id,
            ..
        }) => {
            assert_eq!(name, "run");
            assert_eq!(arguments, "{}");
            assert_eq!(call_id, "call_1");
        }
        other => panic!("expected function call, got {other:?}"),
    }

    assert!(matches!(events[3], ResponseEvent::Completed { .. }));
}

@@ -388,8 +388,7 @@ async fn integration_creates_and_checks_session_file() {
        "No message found in session file containing the marker"
    );

    // Second run: resume and append.
    let orig_len = content.lines().count();
    // Second run: resume should update the existing file.
    let marker2 = format!("integration-resume-{}", Uuid::new_v4());
    let prompt2 = format!("echo {marker2}");
    // Cross‑platform safe resume override. On Windows, backslashes in a TOML string must be escaped
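    // For example (path hypothetical, and assuming the config field name below
    // doubles as the TOML key), the override would be written either as a basic
    // string with doubled backslashes or as a literal string:
    //   experimental_resume = "C:\\Users\\me\\session.jsonl"
    //   experimental_resume = 'C:\Users\me\session.jsonl'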
|
||||
@@ -419,31 +418,50 @@ async fn integration_creates_and_checks_session_file() {
|
||||
let output2 = cmd2.output().unwrap();
|
||||
assert!(output2.status.success(), "resume codex-cli run failed");
|
||||
|
||||
// The rollout writer runs on a background async task; give it a moment to flush.
|
||||
let mut new_len = orig_len;
|
||||
let deadline = Instant::now() + Duration::from_secs(5);
|
||||
let mut content2 = String::new();
|
||||
while Instant::now() < deadline {
|
||||
if let Ok(c) = std::fs::read_to_string(&path) {
|
||||
let count = c.lines().count();
|
||||
if count > orig_len {
|
||||
content2 = c;
|
||||
new_len = count;
|
||||
// Find the new session file containing the resumed marker.
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
let mut resumed_path: Option<std::path::PathBuf> = None;
|
||||
while Instant::now() < deadline && resumed_path.is_none() {
|
||||
for entry in WalkDir::new(&sessions_dir) {
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if !entry.file_type().is_file() {
|
||||
continue;
|
||||
}
|
||||
if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
|
||||
continue;
|
||||
}
|
||||
let p = entry.path();
|
||||
let Ok(c) = std::fs::read_to_string(p) else {
|
||||
continue;
|
||||
};
|
||||
if c.contains(&marker2) {
|
||||
resumed_path = Some(p.to_path_buf());
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
if resumed_path.is_none() {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
}
|
||||
if content2.is_empty() {
|
||||
// last attempt
|
||||
content2 = std::fs::read_to_string(&path).unwrap();
|
||||
new_len = content2.lines().count();
|
||||
}
|
||||
assert!(new_len > orig_len, "rollout file did not grow after resume");
|
||||
assert!(content2.contains(&marker), "rollout lost original marker");
|
||||
|
||||
let resumed_path = resumed_path.expect("No resumed session file found containing the marker2");
|
||||
// Resume should write to the existing log file.
|
||||
assert_eq!(
|
||||
resumed_path, path,
|
||||
"resume should create a new session file"
|
||||
);
|
||||
|
||||
let resumed_content = std::fs::read_to_string(&resumed_path).unwrap();
|
||||
assert!(
|
||||
content2.contains(&marker2),
|
||||
"rollout missing resumed marker"
|
||||
resumed_content.contains(&marker),
|
||||
"resumed file missing original marker"
|
||||
);
|
||||
assert!(
|
||||
resumed_content.contains(&marker2),
|
||||
"resumed file missing resumed marker"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::InputMessageKind;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::UserMessageEvent;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_login::AuthMode;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
@@ -96,7 +101,7 @@ fn write_auth_json(
|
||||
"OPENAI_API_KEY": openai_api_key,
|
||||
"tokens": tokens,
|
||||
// RFC3339 datetime; value doesn't matter for these tests
|
||||
"last_refresh": "2025-08-06T20:41:36.232376Z",
|
||||
"last_refresh": chrono::Utc::now(),
|
||||
});
|
||||
|
||||
std::fs::write(
|
||||
@@ -109,7 +114,198 @@ fn write_auth_json(
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_session_id_and_model_headers_in_request() {
|
||||
async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a fake rollout session file with prior user + system + assistant messages.
|
||||
let tmpdir = TempDir::new().unwrap();
|
||||
let session_path = tmpdir.path().join("resume-session.jsonl");
|
||||
let mut f = std::fs::File::create(&session_path).unwrap();
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
json!({
|
||||
"record_type": "session_meta",
|
||||
"id": Uuid::new_v4(),
|
||||
"timestamp": "2024-01-01T00:00:00Z",
|
||||
"cwd": tmpdir.path().to_string_lossy(),
|
||||
"originator": "test",
|
||||
"cli_version": "0.0.0-test"
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Prior item: user message (should be delivered)
|
||||
let prior_user = codex_protocol::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![codex_protocol::models::ContentItem::InputText {
|
||||
text: "resumed user message".to_string(),
|
||||
}],
|
||||
};
|
||||
let mut prior_user_obj = serde_json::to_value(&prior_user)
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone();
|
||||
prior_user_obj.insert("record_type".to_string(), serde_json::json!("response"));
|
||||
prior_user_obj.insert(
|
||||
"timestamp".to_string(),
|
||||
serde_json::json!("2025-01-01T00:00:00Z"),
|
||||
);
|
||||
writeln!(f, "{}", serde_json::Value::Object(prior_user_obj)).unwrap();
|
||||
|
||||
// Also include a matching user message event to preserve ordering at resume
|
||||
let prior_user_event = EventMsg::UserMessage(UserMessageEvent {
|
||||
message: "resumed user message".to_string(),
|
||||
kind: Some(InputMessageKind::Plain),
|
||||
});
|
||||
let prior_user_event_line = serde_json::json!({
|
||||
"timestamp": "2025-01-01T00:00:00Z",
|
||||
"record_type": "event",
|
||||
"id": "resume-0",
|
||||
"msg": prior_user_event,
|
||||
});
|
||||
writeln!(f, "{prior_user_event_line}").unwrap();
|
||||
|
||||
// Prior item: system message (excluded from API history)
|
||||
let prior_system = codex_protocol::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "system".to_string(),
|
||||
content: vec![codex_protocol::models::ContentItem::OutputText {
|
||||
text: "resumed system instruction".to_string(),
|
||||
}],
|
||||
};
|
||||
let mut prior_system_obj = serde_json::to_value(&prior_system)
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone();
|
||||
prior_system_obj.insert("record_type".to_string(), serde_json::json!("response"));
|
||||
prior_system_obj.insert(
|
||||
"timestamp".to_string(),
|
||||
serde_json::json!("2025-01-01T00:00:00Z"),
|
||||
);
|
||||
writeln!(f, "{}", serde_json::Value::Object(prior_system_obj)).unwrap();
|
||||
|
||||
// Prior item: assistant message
|
||||
let prior_item = codex_protocol::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![codex_protocol::models::ContentItem::OutputText {
|
||||
text: "resumed assistant message".to_string(),
|
||||
}],
|
||||
};
|
||||
let mut prior_item_obj = serde_json::to_value(&prior_item)
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone();
|
||||
prior_item_obj.insert("record_type".to_string(), serde_json::json!("response"));
|
||||
prior_item_obj.insert(
|
||||
"timestamp".to_string(),
|
||||
serde_json::json!("2025-01-01T00:00:00Z"),
|
||||
);
|
||||
writeln!(f, "{}", serde_json::Value::Object(prior_item_obj)).unwrap();
|
||||
let prior_item_event = EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "resumed assistant message".to_string(),
|
||||
});
|
||||
let prior_event_line = serde_json::json!({
|
||||
"timestamp": "2025-01-01T00:00:00Z",
|
||||
"record_type": "event",
|
||||
"id": "resume-1",
|
||||
"msg": prior_item_event,
|
||||
});
|
||||
writeln!(f, "{prior_event_line}").unwrap();
|
||||
drop(f);
|
||||
|
||||
// Mock server that will receive the resumed request
|
||||
let server = MockServer::start().await;
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Configure Codex to resume from our file
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.experimental_resume = Some(session_path.clone());
|
||||
// Also configure user instructions to ensure they are NOT delivered on resume.
|
||||
config.user_instructions = Some("be nice".to_string());
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let NewConversation {
|
||||
conversation: codex,
|
||||
session_configured,
|
||||
..
|
||||
} = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation");
|
||||
|
||||
// 1) Assert initial_messages contains the prior user + assistant messages as EventMsg entries
|
||||
let initial_msgs = session_configured
|
||||
.initial_messages
|
||||
.clone()
|
||||
.expect("expected initial messages for resumed session");
|
||||
let initial_json = serde_json::to_value(&initial_msgs).unwrap();
|
||||
let expected_initial_json = json!([
|
||||
{ "type": "user_message", "message": "resumed user message", "kind": "plain" },
|
||||
{ "type": "agent_message", "message": "resumed assistant message" }
|
||||
]);
|
||||
assert_eq!(initial_json, expected_initial_json);
|
||||
|
||||
// 2) Submit new input; the request body must include the prior item followed by the new user input.
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let expected_input = json!([
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{ "type": "input_text", "text": "resumed user message" }]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{ "type": "output_text", "text": "resumed assistant message" }]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{ "type": "input_text", "text": "hello" }]
|
||||
}
|
||||
]);
|
||||
assert_eq!(request_body["input"], expected_input);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn includes_conversation_id_and_model_headers_in_request() {
    if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
@@ -166,12 +362,12 @@ async fn includes_session_id_and_model_headers_in_request() {

    // get request from the server
    let request = &server.received_requests().await.unwrap()[0];
    let request_session_id = request.headers.get("session_id").unwrap();
    let request_conversation_id = request.headers.get("conversation_id").unwrap();
    let request_authorization = request.headers.get("authorization").unwrap();
    let request_originator = request.headers.get("originator").unwrap();

    assert_eq!(
        request_session_id.to_str().unwrap(),
        request_conversation_id.to_str().unwrap(),
        conversation_id.to_string()
    );
    assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
@@ -344,14 +540,14 @@ async fn chatgpt_auth_sends_correct_request() {

    // get request from the server
    let request = &server.received_requests().await.unwrap()[0];
    let request_session_id = request.headers.get("session_id").unwrap();
    let request_conversation_id = request.headers.get("conversation_id").unwrap();
    let request_authorization = request.headers.get("authorization").unwrap();
    let request_originator = request.headers.get("originator").unwrap();
    let request_chatgpt_account_id = request.headers.get("chatgpt-account-id").unwrap();
    let request_body = request.body_json::<serde_json::Value>().unwrap();

    assert_eq!(
        request_session_id.to_str().unwrap(),
        request_conversation_id.to_str().unwrap(),
        conversation_id.to_string()
    );
    assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
@@ -360,7 +556,6 @@ async fn chatgpt_auth_sends_correct_request() {
        "Bearer Access Token"
    );
    assert_eq!(request_chatgpt_account_id.to_str().unwrap(), "account_id");
    assert!(!request_body["store"].as_bool().unwrap());
    assert!(request_body["stream"].as_bool().unwrap());
    assert_eq!(
        request_body["include"][0].as_str().unwrap(),
@@ -414,12 +609,15 @@ async fn prefers_chatgpt_token_when_config_prefers_chatgpt() {
    config.model_provider = model_provider;
    config.preferred_auth_method = AuthMode::ChatGPT;

    let auth_manager =
        match CodexAuth::from_codex_home(codex_home.path(), config.preferred_auth_method) {
            Ok(Some(auth)) => codex_login::AuthManager::from_auth_for_testing(auth),
            Ok(None) => panic!("No CodexAuth found in codex_home"),
            Err(e) => panic!("Failed to load CodexAuth: {e}"),
        };
    let auth_manager = match CodexAuth::from_codex_home(
        codex_home.path(),
        config.preferred_auth_method,
        &config.responses_originator_header,
    ) {
        Ok(Some(auth)) => codex_core::AuthManager::from_auth_for_testing(auth),
        Ok(None) => panic!("No CodexAuth found in codex_home"),
        Err(e) => panic!("Failed to load CodexAuth: {e}"),
    };
    let conversation_manager = ConversationManager::new(auth_manager);
    let NewConversation {
        conversation: codex,
@@ -439,14 +637,6 @@ async fn prefers_chatgpt_token_when_config_prefers_chatgpt() {
        .unwrap();

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    // verify request body flags
    let request = &server.received_requests().await.unwrap()[0];
    let request_body = request.body_json::<serde_json::Value>().unwrap();
    assert!(
        !request_body["store"].as_bool().unwrap(),
        "store should be false for ChatGPT auth"
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
@@ -495,12 +685,15 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
    config.model_provider = model_provider;
    config.preferred_auth_method = AuthMode::ApiKey;

    let auth_manager =
        match CodexAuth::from_codex_home(codex_home.path(), config.preferred_auth_method) {
            Ok(Some(auth)) => codex_login::AuthManager::from_auth_for_testing(auth),
            Ok(None) => panic!("No CodexAuth found in codex_home"),
            Err(e) => panic!("Failed to load CodexAuth: {e}"),
        };
    let auth_manager = match CodexAuth::from_codex_home(
        codex_home.path(),
        config.preferred_auth_method,
        &config.responses_originator_header,
    ) {
        Ok(Some(auth)) => codex_core::AuthManager::from_auth_for_testing(auth),
        Ok(None) => panic!("No CodexAuth found in codex_home"),
        Err(e) => panic!("Failed to load CodexAuth: {e}"),
    };
    let conversation_manager = ConversationManager::new(auth_manager);
    let NewConversation {
        conversation: codex,
@@ -520,14 +713,6 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
        .unwrap();

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    // verify request body flags
    let request = &server.received_requests().await.unwrap()[0];
    let request_body = request.body_json::<serde_json::Value>().unwrap();
    assert!(
        request_body["store"].as_bool().unwrap(),
        "store should be true for API key auth"
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
@@ -845,34 +1030,29 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
    assert_eq!(requests.len(), 3, "expected 3 requests (one per turn)");

    // Replace full-array compare with tail-only raw JSON compare using a single hard-coded value.
    let r3_tail_expected = serde_json::json!([
    let r3_tail_expected = json!([
        {
            "type": "message",
            "id": null,
            "role": "user",
            "content": [{"type":"input_text","text":"U1"}]
        },
        {
            "type": "message",
            "id": null,
            "role": "assistant",
            "content": [{"type":"output_text","text":"Hey there!\n"}]
        },
        {
            "type": "message",
            "id": null,
            "role": "user",
            "content": [{"type":"input_text","text":"U2"}]
        },
        {
            "type": "message",
            "id": null,
            "role": "assistant",
            "content": [{"type":"output_text","text":"Hey there!\n"}]
        },
        {
            "type": "message",
            "id": null,
            "role": "user",
            "content": [{"type":"input_text","text":"U3"}]
        }

@@ -1,5 +1,6 @@
#![expect(clippy::unwrap_used)]

use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::built_in_model_providers;
@@ -7,7 +8,6 @@ use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_login::CodexAuth;
use core_test_support::load_default_config_for_test;
use core_test_support::wait_for_event;
use serde_json::Value;

codex-rs/core/tests/suite/fork_conversation.rs (new file, 176 lines)
@@ -0,0 +1,176 @@
use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::NewConversation;
use codex_core::built_in_model_providers;
use codex_core::protocol::ConversationPathResponseEvent;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_protocol::models::ResponseItem;
use core_test_support::load_default_config_for_test;
use core_test_support::wait_for_event;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

/// Build minimal SSE stream with completed marker using the JSON fixture.
fn sse_completed(id: &str) -> String {
    core_test_support::load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn fork_conversation_twice_drops_to_first_message() {
    // Start a mock server that completes three turns.
    let server = MockServer::start().await;
    let sse = sse_completed("resp");
    let first = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse.clone(), "text/event-stream");

    // Expect three calls to /v1/responses – one per user input.
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(first)
        .expect(3)
        .mount(&server)
        .await;

    // Configure Codex to use the mock server.
    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
        ..built_in_model_providers()["openai"].clone()
    };

    let home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&home);
    config.model_provider = model_provider.clone();
    let config_for_fork = config.clone();

    let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
    let NewConversation {
        conversation: codex,
        ..
    } = conversation_manager
        .new_conversation(config)
        .await
        .expect("create conversation");

    // Send three user messages; wait for three completed turns.
    for text in ["first", "second", "third"] {
        codex
            .submit(Op::UserInput {
                items: vec![InputItem::Text {
                    text: text.to_string(),
                }],
            })
            .await
            .unwrap();
        let _ = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
    }

    // Request history from the base conversation.
    codex.submit(Op::GetConversationPath).await.unwrap();
    let base_history =
        wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationHistory(_))).await;

    // Capture path/id from the base history and compute expected prefixes after each fork.
    let (base_conv_id, base_path) = match &base_history {
        EventMsg::ConversationHistory(ConversationPathResponseEvent {
            conversation_id,
            path,
        }) => (*conversation_id, path.clone()),
        _ => panic!("expected ConversationHistory event"),
    };

    // Read entries from rollout file.
    async fn read_response_entries(path: &std::path::Path) -> Vec<ResponseItem> {
        let text = tokio::fs::read_to_string(path).await.unwrap_or_default();
        let mut out = Vec::new();
        for line in text.lines() {
            if line.trim().is_empty() {
                continue;
            }
            if let Ok(item) = serde_json::from_str::<ResponseItem>(line) {
                out.push(item);
            }
        }
        out
    }
    let entries_after_three: Vec<ResponseItem> = read_response_entries(&base_path).await;
    // History layout for this test:
    // [0] user instructions,
    // [1] environment context,
    // [2] "first" user message,
    // [3] "second" user message,
    // [4] "third" user message.

    // Fork 1: drops the last user message and everything after.
    let expected_after_first = vec![
        entries_after_three[0].clone(),
        entries_after_three[1].clone(),
        entries_after_three[2].clone(),
        entries_after_three[3].clone(),
    ];

    // Fork 2: drops the last user message and everything after.
    // [0] user instructions,
    // [1] environment context,
    // [2] "first" user message,
    let expected_after_second = vec![
        entries_after_three[0].clone(),
        entries_after_three[1].clone(),
        entries_after_three[2].clone(),
    ];

    // Fork once with n=1 → drops the last user message and everything after.
    let NewConversation {
        conversation: codex_fork1,
        ..
    } = conversation_manager
        .fork_conversation(base_path.clone(), base_conv_id, 1, config_for_fork.clone())
        .await
        .expect("fork 1");

    codex_fork1.submit(Op::GetConversationPath).await.unwrap();
    let fork1_history = wait_for_event(&codex_fork1, |ev| {
        matches!(ev, EventMsg::ConversationHistory(_))
    })
    .await;
    let (fork1_id, fork1_path) = match &fork1_history {
        EventMsg::ConversationHistory(ConversationPathResponseEvent {
            conversation_id,
            path,
        }) => (*conversation_id, path.clone()),
        _ => panic!("expected ConversationHistory event after first fork"),
    };
    let entries_after_first_fork: Vec<ResponseItem> = read_response_entries(&fork1_path).await;
    assert_eq!(entries_after_first_fork, expected_after_first);

    // Fork again with n=1 → drops the (new) last user message, leaving only the first.
    let NewConversation {
        conversation: codex_fork2,
        ..
    } = conversation_manager
        .fork_conversation(fork1_path.clone(), fork1_id, 1, config_for_fork.clone())
        .await
        .expect("fork 2");

    codex_fork2.submit(Op::GetConversationPath).await.unwrap();
    let fork2_history = wait_for_event(&codex_fork2, |ev| {
        matches!(ev, EventMsg::ConversationHistory(_))
    })
    .await;
    let (_fork2_id, fork2_path) = match &fork2_history {
        EventMsg::ConversationHistory(ConversationPathResponseEvent {
            conversation_id,
            path,
        }) => (*conversation_id, path.clone()),
        _ => panic!("expected ConversationHistory event after second fork"),
    };
    let entries_after_second_fork: Vec<ResponseItem> = read_response_entries(&fork2_path).await;
    assert_eq!(entries_after_second_fork, expected_after_second);
}
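
The two forks above are easiest to read as arithmetic on the rollout tail. A sketch of the expected entry counts, assuming the five-entry layout the test comments describe (illustration only, not part of the diff):

// fork_conversation with n=1 trims everything from the last user message onward:
// base rollout:  [instructions, env, "first", "second", "third"]  -> 5 entries
// fork #1 (n=1): [instructions, env, "first", "second"]           -> 4 entries
// fork #2 (n=1): [instructions, env, "first"]                     -> 3 entries
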
@@ -5,6 +5,7 @@ mod client;
mod compact;
mod exec;
mod exec_stream_events;
mod fork_conversation;
mod live_cli;
mod prompt_caching;
mod seatbelt;

@@ -1,5 +1,6 @@
#![allow(clippy::unwrap_used)]

use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::built_in_model_providers;
@@ -12,7 +13,6 @@ use codex_core::protocol::SandboxPolicy;
use codex_core::protocol_config_types::ReasoningEffort;
use codex_core::protocol_config_types::ReasoningSummary;
use codex_core::shell::default_user_shell;
use codex_login::CodexAuth;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::wait_for_event;
@@ -289,20 +289,17 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests

    let expected_env_msg = serde_json::json!({
        "type": "message",
        "id": serde_json::Value::Null,
        "role": "user",
        "content": [ { "type": "input_text", "text": expected_env_text } ]
    });
    let expected_ui_msg = serde_json::json!({
        "type": "message",
        "id": serde_json::Value::Null,
        "role": "user",
        "content": [ { "type": "input_text", "text": expected_ui_text } ]
    });

    let expected_user_message_1 = serde_json::json!({
        "type": "message",
        "id": serde_json::Value::Null,
        "role": "user",
        "content": [ { "type": "input_text", "text": "hello 1" } ]
    });
@@ -314,7 +311,6 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests

    let expected_user_message_2 = serde_json::json!({
        "type": "message",
        "id": serde_json::Value::Null,
        "role": "user",
        "content": [ { "type": "input_text", "text": "hello 2" } ]
    });
@@ -424,7 +420,6 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() {
    // as the prefix of the second request, ensuring cache hit potential.
    let expected_user_message_2 = serde_json::json!({
        "type": "message",
        "id": serde_json::Value::Null,
        "role": "user",
        "content": [ { "type": "input_text", "text": "hello 2" } ]
    });
@@ -438,7 +433,6 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() {
</environment_context>"#;
    let expected_env_msg_2 = serde_json::json!({
        "type": "message",
        "id": serde_json::Value::Null,
        "role": "user",
        "content": [ { "type": "input_text", "text": expected_env_text_2 } ]
    });
@@ -543,7 +537,6 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() {
    // as the prefix of the second request.
    let expected_user_message_2 = serde_json::json!({
        "type": "message",
        "id": serde_json::Value::Null,
        "role": "user",
        "content": [ { "type": "input_text", "text": "hello 2" } ]
    });

@@ -1,5 +1,6 @@
use std::time::Duration;

use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::WireApi;
@@ -7,7 +8,6 @@ use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_login::CodexAuth;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::wait_for_event_with_timeout;

@@ -3,13 +3,13 @@

use std::time::Duration;

use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_login::CodexAuth;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture;
use core_test_support::load_sse_fixture_with_id;

@@ -26,6 +26,7 @@ use codex_core::protocol::TurnAbortReason;
use codex_core::protocol::TurnDiffEvent;
use codex_core::protocol::WebSearchBeginEvent;
use codex_core::protocol::WebSearchEndEvent;
use codex_protocol::num_format::format_with_separators;
use owo_colors::OwoColorize;
use owo_colors::Style;
use shlex::try_join;
@@ -189,8 +190,14 @@ impl EventProcessor for EventProcessorWithHumanOutput {
                }
                return CodexStatus::InitiateShutdown;
            }
            EventMsg::TokenCount(token_usage) => {
                ts_println!(self, "tokens used: {}", token_usage.blended_total());
            EventMsg::TokenCount(ev) => {
                if let Some(usage_info) = ev.info {
                    ts_println!(
                        self,
                        "tokens used: {}",
                        format_with_separators(usage_info.total_token_usage.blended_total())
                    );
                }
            }
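            // Illustration of the intended effect (assumed output format, not
            // verified against codex_protocol::num_format): for a blended total
            // of 1234567, format_with_separators would render something like
            // "1,234,567", so the line reads "tokens used: 1,234,567" instead
            // of a bare integer.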
            EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
                if !self.answer_started {
@@ -511,17 +518,18 @@ impl EventProcessor for EventProcessorWithHumanOutput {
            }
            EventMsg::SessionConfigured(session_configured_event) => {
                let SessionConfiguredEvent {
                    session_id,
                    session_id: conversation_id,
                    model,
                    history_log_id: _,
                    history_entry_count: _,
                    initial_messages: _,
                } = session_configured_event;

                ts_println!(
                    self,
                    "{} {}",
                    "codex session".style(self.magenta).style(self.bold),
                    session_id.to_string().style(self.dimmed)
                    conversation_id.to_string().style(self.dimmed)
                );

                ts_println!(self, "model: {}", model);
@@ -551,6 +559,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
            },
            EventMsg::ShutdownComplete => return CodexStatus::Shutdown,
            EventMsg::ConversationHistory(_) => {}
            EventMsg::UserMessage(_) => {}
        }
        CodexStatus::Running
    }

@@ -8,6 +8,7 @@ use std::io::Read;
use std::path::PathBuf;

pub use cli::Cli;
use codex_core::AuthManager;
use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
use codex_core::ConversationManager;
use codex_core::NewConversation;
@@ -20,7 +21,6 @@ use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::TaskCompleteEvent;
use codex_login::AuthManager;
use codex_ollama::DEFAULT_OSS_MODEL;
use codex_protocol::config_types::SandboxMode;
use event_processor_with_human_output::EventProcessorWithHumanOutput;
@@ -149,7 +149,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
        include_plan_tool: None,
        include_apply_patch_tool: None,
        include_view_image_tool: None,
        disable_response_storage: oss.then_some(true),
        show_raw_agent_reasoning: oss.then_some(true),
        tools_web_search_request: None,
    };
@@ -191,6 +190,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
    let conversation_manager = ConversationManager::new(AuthManager::shared(
        config.codex_home.clone(),
        config.preferred_auth_method,
        config.responses_originator_header.clone(),
    ));
    let NewConversation {
        conversation_id: _,

@@ -151,7 +151,13 @@ pub fn run(
    // Use the same tree-walker library that ripgrep uses. We use it directly so
    // that we can leverage the parallelism it provides.
    let mut walk_builder = WalkBuilder::new(search_directory);
    walk_builder.threads(num_walk_builder_threads);
    walk_builder
        .threads(num_walk_builder_threads)
        // Allow hidden entries.
        .hidden(false)
        // Don't require git to be present to apply git-related ignore rules.
        .require_git(false);

    if !exclude.is_empty() {
        let mut override_builder = OverrideBuilder::new(search_directory);
        for exclude in exclude {

@@ -30,3 +30,7 @@ fix *args:
install:
    rustup show active-toolchain
    cargo fetch

# Run the MCP server
mcp-server-run *args:
    cargo run -p codex-mcp-server -- "$@"

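Usage sketch for the new recipe (assuming `just` is invoked from the workspace root):

    just mcp-server-run

builds and starts the `codex-mcp-server` crate via cargo; any arguments given after the recipe name are forwarded to the binary.
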
@@ -9,6 +9,7 @@ workspace = true
[dependencies]
base64 = "0.22"
chrono = { version = "0.4", features = ["serde"] }
codex-core = { path = "../core" }
codex-protocol = { path = "../protocol" }
rand = "0.8"
reqwest = { version = "0.12", features = ["json", "blocking"] }
@@ -16,7 +17,7 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
sha2 = "0.10"
tempfile = "3"
thiserror = "2.0.12"
thiserror = "2.0.16"
tiny_http = "0.12"
tokio = { version = "1", features = [
    "io-std",

@@ -1,129 +0,0 @@
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::RwLock;

use crate::AuthMode;
use crate::CodexAuth;

/// Internal cached auth state.
#[derive(Clone, Debug)]
struct CachedAuth {
    preferred_auth_mode: AuthMode,
    auth: Option<CodexAuth>,
}

/// Central manager providing a single source of truth for auth.json derived
/// authentication data. It loads once (or on preference change) and then
/// hands out cloned `CodexAuth` values so the rest of the program has a
/// consistent snapshot.
///
/// External modifications to `auth.json` will NOT be observed until
/// `reload()` is called explicitly. This matches the design goal of avoiding
/// different parts of the program seeing inconsistent auth data mid‑run.
#[derive(Debug)]
pub struct AuthManager {
    codex_home: PathBuf,
    inner: RwLock<CachedAuth>,
}

impl AuthManager {
    /// Create a new manager loading the initial auth using the provided
    /// preferred auth method. Errors loading auth are swallowed; `auth()` will
    /// simply return `None` in that case so callers can treat it as an
    /// unauthenticated state.
    pub fn new(codex_home: PathBuf, preferred_auth_mode: AuthMode) -> Self {
        let auth = crate::CodexAuth::from_codex_home(&codex_home, preferred_auth_mode)
            .ok()
            .flatten();
        Self {
            codex_home,
            inner: RwLock::new(CachedAuth {
                preferred_auth_mode,
                auth,
            }),
        }
    }

    /// Create an AuthManager with a specific CodexAuth, for testing only.
    pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
        let preferred_auth_mode = auth.mode;
        let cached = CachedAuth {
            preferred_auth_mode,
            auth: Some(auth),
        };
        Arc::new(Self {
            codex_home: PathBuf::new(),
            inner: RwLock::new(cached),
        })
    }

    /// Current cached auth (clone). May be `None` if not logged in or load failed.
    pub fn auth(&self) -> Option<CodexAuth> {
        self.inner.read().ok().and_then(|c| c.auth.clone())
    }

    /// Preferred auth method used when (re)loading.
    pub fn preferred_auth_method(&self) -> AuthMode {
        self.inner
            .read()
            .map(|c| c.preferred_auth_mode)
            .unwrap_or(AuthMode::ApiKey)
    }

    /// Force a reload using the existing preferred auth method. Returns
    /// whether the auth value changed.
    pub fn reload(&self) -> bool {
        let preferred = self.preferred_auth_method();
        let new_auth = crate::CodexAuth::from_codex_home(&self.codex_home, preferred)
            .ok()
            .flatten();
        if let Ok(mut guard) = self.inner.write() {
            let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
            guard.auth = new_auth;
            changed
        } else {
            false
        }
    }

    fn auths_equal(a: &Option<CodexAuth>, b: &Option<CodexAuth>) -> bool {
        match (a, b) {
            (None, None) => true,
            (Some(a), Some(b)) => a == b,
            _ => false,
        }
    }

    /// Convenience constructor returning an `Arc` wrapper.
    pub fn shared(codex_home: PathBuf, preferred_auth_mode: AuthMode) -> Arc<Self> {
        Arc::new(Self::new(codex_home, preferred_auth_mode))
    }

    /// Attempt to refresh the current auth token (if any). On success, reload
    /// the auth state from disk so other components observe refreshed token.
    pub async fn refresh_token(&self) -> std::io::Result<Option<String>> {
        let auth = match self.auth() {
            Some(a) => a,
            None => return Ok(None),
        };
        match auth.refresh_token().await {
            Ok(token) => {
                // Reload to pick up persisted changes.
                self.reload();
                Ok(Some(token))
            }
            Err(e) => Err(e),
        }
    }

    /// Log out by deleting the on‑disk auth.json (if present). Returns Ok(true)
    /// if a file was removed, Ok(false) if no auth file existed. On success,
    /// reloads the in‑memory auth cache so callers immediately observe the
    /// unauthenticated state.
    pub fn logout(&self) -> std::io::Result<bool> {
        let removed = crate::logout(&self.codex_home)?;
        // Always reload to clear any cached auth (even if file absent).
        self.reload();
        Ok(removed)
    }
}
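
The deleted manager's snapshot-and-reload contract carries over to the codex-core replacement that the re-exports below point at. A minimal usage sketch, assuming the `shared`/`auth`/`reload` surface shown above plus the originator parameter visible in the CLI hunk earlier:

// Sketch only; not part of the diff.
let manager = codex_core::AuthManager::shared(
    codex_home.clone(),                 // $CODEX_HOME
    AuthMode::ChatGPT,                  // preferred auth method
    "codex_cli_rs".to_string(),         // responses originator header
);
let auth = manager.auth();              // cloned snapshot; None when logged out
if manager.reload() {
    // auth.json changed on disk; later auth() calls observe the new state.
}
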
@@ -1,693 +1,21 @@
use chrono::DateTime;
use chrono::Utc;
use serde::Deserialize;
use serde::Serialize;
use std::env;
use std::fs::File;
use std::fs::OpenOptions;
use std::fs::remove_file;
use std::io::Read;
use std::io::Write;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;

pub use crate::server::LoginServer;
pub use crate::server::ServerOptions;
pub use crate::server::ShutdownHandle;
pub use crate::server::run_login_server;
pub use crate::token_data::TokenData;
use crate::token_data::parse_id_token;

mod auth_manager;
mod pkce;
mod server;
mod token_data;

pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
pub use auth_manager::AuthManager;
pub use server::LoginServer;
pub use server::ServerOptions;
pub use server::ShutdownHandle;
pub use server::run_login_server;

// Re-export commonly used auth types and helpers from codex-core for compatibility
pub use codex_core::AuthManager;
pub use codex_core::CodexAuth;
pub use codex_core::auth::AuthDotJson;
pub use codex_core::auth::CLIENT_ID;
pub use codex_core::auth::OPENAI_API_KEY_ENV_VAR;
pub use codex_core::auth::get_auth_file;
pub use codex_core::auth::login_with_api_key;
pub use codex_core::auth::logout;
pub use codex_core::auth::try_read_auth_json;
pub use codex_core::auth::write_auth_json;
pub use codex_core::token_data::TokenData;
pub use codex_protocol::mcp_protocol::AuthMode;

#[derive(Debug, Clone)]
pub struct CodexAuth {
    pub mode: AuthMode,

    api_key: Option<String>,
    auth_dot_json: Arc<Mutex<Option<AuthDotJson>>>,
    auth_file: PathBuf,
}

impl PartialEq for CodexAuth {
    fn eq(&self, other: &Self) -> bool {
        self.mode == other.mode
    }
}

impl CodexAuth {
    pub fn from_api_key(api_key: &str) -> Self {
        Self {
            api_key: Some(api_key.to_owned()),
            mode: AuthMode::ApiKey,
            auth_file: PathBuf::new(),
            auth_dot_json: Arc::new(Mutex::new(None)),
        }
    }

    pub async fn refresh_token(&self) -> Result<String, std::io::Error> {
        let token_data = self
            .get_current_token_data()
            .ok_or(std::io::Error::other("Token data is not available."))?;
        let token = token_data.refresh_token;

        let refresh_response = try_refresh_token(token)
            .await
            .map_err(std::io::Error::other)?;

        let updated = update_tokens(
            &self.auth_file,
            refresh_response.id_token,
            refresh_response.access_token,
            refresh_response.refresh_token,
        )
        .await?;

        if let Ok(mut auth_lock) = self.auth_dot_json.lock() {
            *auth_lock = Some(updated.clone());
        }

        let access = match updated.tokens {
            Some(t) => t.access_token,
            None => {
                return Err(std::io::Error::other(
                    "Token data is not available after refresh.",
                ));
            }
        };
        Ok(access)
    }

    /// Loads the available auth information from the auth.json or
    /// OPENAI_API_KEY environment variable.
    pub fn from_codex_home(
        codex_home: &Path,
        preferred_auth_method: AuthMode,
    ) -> std::io::Result<Option<CodexAuth>> {
        load_auth(codex_home, true, preferred_auth_method)
    }

    pub async fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
        let auth_dot_json: Option<AuthDotJson> = self.get_current_auth_json();
        match auth_dot_json {
            Some(AuthDotJson {
                tokens: Some(mut tokens),
                last_refresh: Some(last_refresh),
                ..
            }) => {
                if last_refresh < Utc::now() - chrono::Duration::days(28) {
                    let refresh_response = tokio::time::timeout(
                        Duration::from_secs(60),
                        try_refresh_token(tokens.refresh_token.clone()),
                    )
                    .await
                    .map_err(|_| {
                        std::io::Error::other("timed out while refreshing OpenAI API key")
                    })?
                    .map_err(std::io::Error::other)?;

                    let updated_auth_dot_json = update_tokens(
                        &self.auth_file,
                        refresh_response.id_token,
                        refresh_response.access_token,
                        refresh_response.refresh_token,
                    )
                    .await?;

                    tokens = updated_auth_dot_json
                        .tokens
                        .clone()
                        .ok_or(std::io::Error::other(
                            "Token data is not available after refresh.",
                        ))?;

                    #[expect(clippy::unwrap_used)]
                    let mut auth_lock = self.auth_dot_json.lock().unwrap();
                    *auth_lock = Some(updated_auth_dot_json);
                }

                Ok(tokens)
            }
            _ => Err(std::io::Error::other("Token data is not available.")),
        }
    }

    pub async fn get_token(&self) -> Result<String, std::io::Error> {
        match self.mode {
            AuthMode::ApiKey => Ok(self.api_key.clone().unwrap_or_default()),
            AuthMode::ChatGPT => {
                let id_token = self.get_token_data().await?.access_token;

                Ok(id_token)
            }
        }
    }

    pub fn get_account_id(&self) -> Option<String> {
        self.get_current_token_data()
            .and_then(|t| t.account_id.clone())
    }

    pub fn get_plan_type(&self) -> Option<String> {
        self.get_current_token_data()
            .and_then(|t| t.id_token.chatgpt_plan_type.as_ref().map(|p| p.as_string()))
    }

    fn get_current_auth_json(&self) -> Option<AuthDotJson> {
        #[expect(clippy::unwrap_used)]
        self.auth_dot_json.lock().unwrap().clone()
    }

    fn get_current_token_data(&self) -> Option<TokenData> {
        self.get_current_auth_json().and_then(|t| t.tokens.clone())
    }

    /// Consider this private to integration tests.
    pub fn create_dummy_chatgpt_auth_for_testing() -> Self {
        let auth_dot_json = AuthDotJson {
            openai_api_key: None,
            tokens: Some(TokenData {
                id_token: Default::default(),
                access_token: "Access Token".to_string(),
                refresh_token: "test".to_string(),
                account_id: Some("account_id".to_string()),
            }),
            last_refresh: Some(Utc::now()),
        };

        let auth_dot_json = Arc::new(Mutex::new(Some(auth_dot_json)));
        Self {
            api_key: None,
            mode: AuthMode::ChatGPT,
            auth_file: PathBuf::new(),
            auth_dot_json,
        }
    }
}

fn load_auth(
    codex_home: &Path,
    include_env_var: bool,
    preferred_auth_method: AuthMode,
) -> std::io::Result<Option<CodexAuth>> {
    // First, check to see if there is a valid auth.json file. If not, we fall
    // back to AuthMode::ApiKey using the OPENAI_API_KEY environment variable
    // (if it is set).
    let auth_file = get_auth_file(codex_home);
    let auth_dot_json = match try_read_auth_json(&auth_file) {
        Ok(auth) => auth,
        // If auth.json does not exist, try to read the OPENAI_API_KEY from the
        // environment variable.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound && include_env_var => {
            return match read_openai_api_key_from_env() {
                Some(api_key) => Ok(Some(CodexAuth::from_api_key(&api_key))),
                None => Ok(None),
            };
        }
        // Though if auth.json exists but is malformed, do not fall back to the
        // env var because the user may be expecting to use AuthMode::ChatGPT.
        Err(e) => {
            return Err(e);
        }
    };

    let AuthDotJson {
        openai_api_key: auth_json_api_key,
        tokens,
        last_refresh,
    } = auth_dot_json;

    // If the auth.json has an API key AND does not appear to be on a plan that
    // should prefer AuthMode::ChatGPT, use AuthMode::ApiKey.
    if let Some(api_key) = &auth_json_api_key {
        // Should any of these be AuthMode::ChatGPT with the api_key set?
        // Does AuthMode::ChatGPT indicate that there is an auth.json that is
        // "refreshable" even if we are using the API key for auth?
        match &tokens {
            Some(tokens) => {
                if tokens.should_use_api_key(preferred_auth_method, tokens.is_openai_email()) {
                    return Ok(Some(CodexAuth::from_api_key(api_key)));
                } else {
                    // Ignore the API key and fall through to ChatGPT auth.
                }
            }
            None => {
                // We have an API key but no tokens in the auth.json file.
                // Perhaps the user ran `codex login --api-key <KEY>` or updated
                // auth.json by hand. Either way, let's assume they are trying
                // to use their API key.
                return Ok(Some(CodexAuth::from_api_key(api_key)));
            }
        }
    }

    // For the AuthMode::ChatGPT variant, perhaps neither api_key nor
    // openai_api_key should exist?
    Ok(Some(CodexAuth {
        api_key: None,
        mode: AuthMode::ChatGPT,
        auth_file,
        auth_dot_json: Arc::new(Mutex::new(Some(AuthDotJson {
            openai_api_key: None,
            tokens,
            last_refresh,
        }))),
    }))
}

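// Summary sketch of the precedence implemented by load_auth above
// (illustrative comment, not new behavior):
//   1. auth.json missing   -> fall back to OPENAI_API_KEY, if set.
//   2. auth.json malformed -> hard error; never silently use the env var.
//   3. API key + tokens    -> use the key only when tokens.should_use_api_key
//      agrees; otherwise ChatGPT auth wins.
//   4. API key, no tokens  -> assume `codex login --api-key` and use the key.
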
fn read_openai_api_key_from_env() -> Option<String> {
    env::var(OPENAI_API_KEY_ENV_VAR)
        .ok()
        .filter(|s| !s.is_empty())
}

pub fn get_auth_file(codex_home: &Path) -> PathBuf {
    codex_home.join("auth.json")
}

/// Delete the auth.json file inside `codex_home` if it exists. Returns `Ok(true)`
/// if a file was removed, `Ok(false)` if no auth file was present.
pub fn logout(codex_home: &Path) -> std::io::Result<bool> {
    let auth_file = get_auth_file(codex_home);
    match remove_file(&auth_file) {
        Ok(_) => Ok(true),
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
        Err(err) => Err(err),
    }
}

pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<()> {
    let auth_dot_json = AuthDotJson {
        openai_api_key: Some(api_key.to_string()),
        tokens: None,
        last_refresh: None,
    };
    write_auth_json(&get_auth_file(codex_home), &auth_dot_json)
}

/// Attempt to read and refresh the `auth.json` file in the given `CODEX_HOME` directory.
/// Returns the full AuthDotJson structure after refreshing if necessary.
pub fn try_read_auth_json(auth_file: &Path) -> std::io::Result<AuthDotJson> {
    let mut file = File::open(auth_file)?;
    let mut contents = String::new();
    file.read_to_string(&mut contents)?;
    let auth_dot_json: AuthDotJson = serde_json::from_str(&contents)?;

    Ok(auth_dot_json)
}

fn write_auth_json(auth_file: &Path, auth_dot_json: &AuthDotJson) -> std::io::Result<()> {
    let json_data = serde_json::to_string_pretty(auth_dot_json)?;
    let mut options = OpenOptions::new();
    options.truncate(true).write(true).create(true);
    #[cfg(unix)]
    {
        options.mode(0o600);
    }
    let mut file = options.open(auth_file)?;
    file.write_all(json_data.as_bytes())?;
    file.flush()?;
    Ok(())
}

async fn update_tokens(
    auth_file: &Path,
    id_token: String,
    access_token: Option<String>,
    refresh_token: Option<String>,
) -> std::io::Result<AuthDotJson> {
    let mut auth_dot_json = try_read_auth_json(auth_file)?;

    let tokens = auth_dot_json.tokens.get_or_insert_with(TokenData::default);
    tokens.id_token = parse_id_token(&id_token).map_err(std::io::Error::other)?;
    if let Some(access_token) = access_token {
        tokens.access_token = access_token.to_string();
    }
    if let Some(refresh_token) = refresh_token {
        tokens.refresh_token = refresh_token.to_string();
    }
    auth_dot_json.last_refresh = Some(Utc::now());
    write_auth_json(auth_file, &auth_dot_json)?;
    Ok(auth_dot_json)
}

async fn try_refresh_token(refresh_token: String) -> std::io::Result<RefreshResponse> {
    let refresh_request = RefreshRequest {
        client_id: CLIENT_ID,
        grant_type: "refresh_token",
        refresh_token,
        scope: "openid profile email",
    };

    let client = reqwest::Client::new();
    let response = client
        .post("https://auth.openai.com/oauth/token")
        .header("Content-Type", "application/json")
        .json(&refresh_request)
        .send()
        .await
        .map_err(std::io::Error::other)?;

    if response.status().is_success() {
        let refresh_response = response
            .json::<RefreshResponse>()
            .await
            .map_err(std::io::Error::other)?;
        Ok(refresh_response)
    } else {
        Err(std::io::Error::other(format!(
            "Failed to refresh token: {}",
            response.status()
        )))
    }
}

#[derive(Serialize)]
struct RefreshRequest {
    client_id: &'static str,
    grant_type: &'static str,
    refresh_token: String,
    scope: &'static str,
}

#[derive(Deserialize, Clone)]
struct RefreshResponse {
    id_token: String,
    access_token: Option<String>,
    refresh_token: Option<String>,
}

/// Expected structure for $CODEX_HOME/auth.json.
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
pub struct AuthDotJson {
    #[serde(rename = "OPENAI_API_KEY")]
    pub openai_api_key: Option<String>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens: Option<TokenData>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_refresh: Option<DateTime<Utc>>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_data::IdTokenInfo;
    use crate::token_data::KnownPlan;
    use crate::token_data::PlanType;
    use base64::Engine;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use tempfile::tempdir;

    const LAST_REFRESH: &str = "2025-08-06T20:41:36.232376Z";

    #[test]
    fn writes_api_key_and_loads_auth() {
        let dir = tempdir().unwrap();
        login_with_api_key(dir.path(), "sk-test-key").unwrap();
        let auth = load_auth(dir.path(), false, AuthMode::ChatGPT)
            .unwrap()
            .unwrap();
        assert_eq!(auth.mode, AuthMode::ApiKey);
        assert_eq!(auth.api_key.as_deref(), Some("sk-test-key"));
    }

    #[test]
    fn loads_from_env_var_if_env_var_exists() {
        let dir = tempdir().unwrap();

        let env_var = std::env::var(OPENAI_API_KEY_ENV_VAR);

        if let Ok(env_var) = env_var {
            let auth = load_auth(dir.path(), true, AuthMode::ChatGPT)
                .unwrap()
                .unwrap();
            assert_eq!(auth.mode, AuthMode::ApiKey);
            assert_eq!(auth.api_key, Some(env_var));
        }
    }

    #[tokio::test]
    async fn roundtrip_auth_dot_json() {
        let codex_home = tempdir().unwrap();
        write_auth_file(
            AuthFileParams {
                openai_api_key: None,
                chatgpt_plan_type: "pro".to_string(),
            },
            codex_home.path(),
        )
        .expect("failed to write auth file");

        let file = get_auth_file(codex_home.path());
        let auth_dot_json = try_read_auth_json(&file).unwrap();
        write_auth_json(&file, &auth_dot_json).unwrap();

        let same_auth_dot_json = try_read_auth_json(&file).unwrap();
        assert_eq!(auth_dot_json, same_auth_dot_json);
    }

    #[tokio::test]
    async fn pro_account_with_no_api_key_uses_chatgpt_auth() {
        let codex_home = tempdir().unwrap();
        let fake_jwt = write_auth_file(
            AuthFileParams {
                openai_api_key: None,
                chatgpt_plan_type: "pro".to_string(),
            },
            codex_home.path(),
        )
        .expect("failed to write auth file");

        let CodexAuth {
            api_key,
            mode,
            auth_dot_json,
            auth_file: _,
        } = load_auth(codex_home.path(), false, AuthMode::ChatGPT)
            .unwrap()
            .unwrap();
        assert_eq!(None, api_key);
        assert_eq!(AuthMode::ChatGPT, mode);

        let guard = auth_dot_json.lock().unwrap();
        let auth_dot_json = guard.as_ref().expect("AuthDotJson should exist");
        assert_eq!(
            &AuthDotJson {
                openai_api_key: None,
                tokens: Some(TokenData {
                    id_token: IdTokenInfo {
                        email: Some("user@example.com".to_string()),
                        chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
                        raw_jwt: fake_jwt,
                    },
                    access_token: "test-access-token".to_string(),
                    refresh_token: "test-refresh-token".to_string(),
                    account_id: None,
                }),
                last_refresh: Some(
                    DateTime::parse_from_rfc3339(LAST_REFRESH)
                        .unwrap()
                        .with_timezone(&Utc)
                ),
            },
            auth_dot_json
        )
    }

    /// Even if the OPENAI_API_KEY is set in auth.json, if the plan is not in
    /// [`TokenData::is_plan_that_should_use_api_key`], it should use
    /// [`AuthMode::ChatGPT`].
    #[tokio::test]
    async fn pro_account_with_api_key_still_uses_chatgpt_auth() {
        let codex_home = tempdir().unwrap();
        let fake_jwt = write_auth_file(
            AuthFileParams {
                openai_api_key: Some("sk-test-key".to_string()),
                chatgpt_plan_type: "pro".to_string(),
            },
            codex_home.path(),
        )
        .expect("failed to write auth file");

        let CodexAuth {
            api_key,
            mode,
            auth_dot_json,
            auth_file: _,
        } = load_auth(codex_home.path(), false, AuthMode::ChatGPT)
            .unwrap()
            .unwrap();
        assert_eq!(None, api_key);
        assert_eq!(AuthMode::ChatGPT, mode);

        let guard = auth_dot_json.lock().unwrap();
        let auth_dot_json = guard.as_ref().expect("AuthDotJson should exist");
        assert_eq!(
            &AuthDotJson {
                openai_api_key: None,
                tokens: Some(TokenData {
                    id_token: IdTokenInfo {
                        email: Some("user@example.com".to_string()),
                        chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
                        raw_jwt: fake_jwt,
                    },
                    access_token: "test-access-token".to_string(),
                    refresh_token: "test-refresh-token".to_string(),
                    account_id: None,
                }),
                last_refresh: Some(
                    DateTime::parse_from_rfc3339(LAST_REFRESH)
                        .unwrap()
                        .with_timezone(&Utc)
                ),
            },
            auth_dot_json
        )
    }

    /// If the OPENAI_API_KEY is set in auth.json and it is an enterprise
    /// account, then it should use [`AuthMode::ApiKey`].
    #[tokio::test]
    async fn enterprise_account_with_api_key_uses_api_key_auth() {
        let codex_home = tempdir().unwrap();
        write_auth_file(
            AuthFileParams {
                openai_api_key: Some("sk-test-key".to_string()),
                chatgpt_plan_type: "enterprise".to_string(),
            },
            codex_home.path(),
        )
        .expect("failed to write auth file");

        let CodexAuth {
            api_key,
            mode,
            auth_dot_json,
            auth_file: _,
        } = load_auth(codex_home.path(), false, AuthMode::ChatGPT)
            .unwrap()
            .unwrap();
        assert_eq!(Some("sk-test-key".to_string()), api_key);
        assert_eq!(AuthMode::ApiKey, mode);

        let guard = auth_dot_json.lock().expect("should unwrap");
        assert!(guard.is_none(), "auth_dot_json should be None");
    }

    struct AuthFileParams {
        openai_api_key: Option<String>,
        chatgpt_plan_type: String,
    }

    fn write_auth_file(params: AuthFileParams, codex_home: &Path) -> std::io::Result<String> {
        let auth_file = get_auth_file(codex_home);
        // Create a minimal valid JWT for the id_token field.
        #[derive(Serialize)]
        struct Header {
            alg: &'static str,
            typ: &'static str,
        }
        let header = Header {
            alg: "none",
            typ: "JWT",
        };
        let payload = serde_json::json!({
            "email": "user@example.com",
            "email_verified": true,
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "bc3618e3-489d-4d49-9362-1561dc53ba53",
                "chatgpt_plan_type": params.chatgpt_plan_type,
                "chatgpt_user_id": "user-12345",
                "user_id": "user-12345",
            }
        });
        let b64 = |b: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b);
        let header_b64 = b64(&serde_json::to_vec(&header)?);
        let payload_b64 = b64(&serde_json::to_vec(&payload)?);
        let signature_b64 = b64(b"sig");
        let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");

        let auth_json_data = json!({
            "OPENAI_API_KEY": params.openai_api_key,
            "tokens": {
                "id_token": fake_jwt,
                "access_token": "test-access-token",
                "refresh_token": "test-refresh-token"
            },
            "last_refresh": LAST_REFRESH,
        });
        let auth_json = serde_json::to_string_pretty(&auth_json_data)?;
        std::fs::write(auth_file, auth_json)?;

        Ok(fake_jwt)
    }

    #[test]
    fn id_token_info_handles_missing_fields() {
        // Payload without email or plan should yield None values.
        let header = serde_json::json!({"alg": "none", "typ": "JWT"});
        let payload = serde_json::json!({"sub": "123"});
        let header_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(&header).unwrap());
        let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(&payload).unwrap());
        let signature_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"sig");
        let jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");

        let info = parse_id_token(&jwt).expect("should parse");
        assert!(info.email.is_none());
        assert!(info.chatgpt_plan_type.is_none());
    }

    #[tokio::test]
    async fn loads_api_key_from_auth_json() {
        let dir = tempdir().unwrap();
        let auth_file = dir.path().join("auth.json");
        std::fs::write(
            auth_file,
            r#"
            {
                "OPENAI_API_KEY": "sk-test-key",
                "tokens": null,
                "last_refresh": null
            }
            "#,
        )
        .unwrap();

        let auth = load_auth(dir.path(), false, AuthMode::ChatGPT)
            .unwrap()
            .unwrap();
        assert_eq!(auth.mode, AuthMode::ApiKey);
        assert_eq!(auth.api_key, Some("sk-test-key".to_string()));

        assert!(auth.get_token_data().await.is_err());
    }

    #[test]
    fn logout_removes_auth_file() -> Result<(), std::io::Error> {
        let dir = tempdir()?;
        login_with_api_key(dir.path(), "sk-test-key")?;
        assert!(dir.path().join("auth.json").exists());
        let removed = logout(dir.path())?;
        assert!(removed);
        assert!(!dir.path().join("auth.json").exists());
        Ok(())
    }
}

@@ -1,16 +1,23 @@
use std::io::Cursor;
use std::io::Read;
use std::io::Write;
use std::io::{self};
use std::net::SocketAddr;
use std::net::TcpStream;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use crate::AuthDotJson;
use crate::get_auth_file;
use crate::pkce::PkceCodes;
use crate::pkce::generate_pkce;
use base64::Engine;
use chrono::Utc;
use codex_core::auth::AuthDotJson;
use codex_core::auth::get_auth_file;
use codex_core::token_data::TokenData;
use codex_core::token_data::parse_id_token;
use rand::RngCore;
use tiny_http::Header;
use tiny_http::Request;
@@ -28,10 +35,11 @@ pub struct ServerOptions {
    pub port: u16,
    pub open_browser: bool,
    pub force_state: Option<String>,
    pub originator: String,
}

impl ServerOptions {
    pub fn new(codex_home: PathBuf, client_id: String) -> Self {
    pub fn new(codex_home: PathBuf, client_id: String, originator: String) -> Self {
        Self {
            codex_home,
            client_id: client_id.to_string(),
@@ -39,6 +47,7 @@ impl ServerOptions {
            port: DEFAULT_PORT,
            open_browser: true,
            force_state: None,
            originator,
        }
    }
}
@@ -81,7 +90,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
    let pkce = generate_pkce();
    let state = opts.force_state.clone().unwrap_or_else(generate_state);

    let server = Server::http(format!("127.0.0.1:{}", opts.port)).map_err(io::Error::other)?;
    let server = bind_server(opts.port)?;
    let actual_port = match server.server_addr().to_ip() {
        Some(addr) => addr.port(),
        None => {
@@ -94,7 +103,14 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
    let server = Arc::new(server);

    let redirect_uri = format!("http://localhost:{actual_port}/auth/callback");
    let auth_url = build_authorize_url(&opts.issuer, &opts.client_id, &redirect_uri, &pkce, &state);
    let auth_url = build_authorize_url(
        &opts.issuer,
        &opts.client_id,
        &redirect_uri,
        &pkce,
        &state,
        &opts.originator,
    );

    if opts.open_browser {
        let _ = webbrowser::open(&auth_url);
@@ -134,19 +150,24 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
            let response =
                process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;

            let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
            match response {
                HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
                    let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
            let exit_result = match response {
                HandledRequest::Response(response) => {
                    let _ = tokio::task::spawn_blocking(move || req.respond(response)).await;
                    None
                }
                HandledRequest::ResponseAndExit { response, result } => {
                    let _ = tokio::task::spawn_blocking(move || req.respond(response)).await;
                    Some(result)
                }
                HandledRequest::RedirectWithHeader(header) => {
                    let redirect = Response::empty(302).with_header(header);
                    let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
                    None
                }
            };

            if is_login_complete {
                break Ok(());
            if let Some(result) = exit_result {
                break result;
            }
        }
    }
@@ -170,7 +191,10 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
enum HandledRequest {
    Response(Response<Cursor<Vec<u8>>>),
    RedirectWithHeader(Header),
    ResponseAndExit(Response<Cursor<Vec<u8>>>),
    ResponseAndExit {
        response: Response<Cursor<Vec<u8>>>,
        result: io::Result<()>,
    },
}

async fn process_request(
@@ -265,8 +289,18 @@ async fn process_request(
            ) {
                resp.add_header(h);
            }
            HandledRequest::ResponseAndExit(resp)
            HandledRequest::ResponseAndExit {
                response: resp,
                result: Ok(()),
            }
        }
        "/cancel" => HandledRequest::ResponseAndExit {
            response: Response::from_string("Login cancelled"),
            result: Err(io::Error::new(
                io::ErrorKind::Interrupted,
                "Login cancelled",
            )),
        },
        _ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)),
    }
}
@@ -277,6 +311,7 @@ fn build_authorize_url(
    redirect_uri: &str,
    pkce: &PkceCodes,
    state: &str,
    originator: &str,
) -> String {
    let query = vec![
        ("response_type", "code"),
@@ -288,6 +323,7 @@ fn build_authorize_url(
        ("id_token_add_organizations", "true"),
        ("codex_cli_simplified_flow", "true"),
        ("state", state),
        ("originator", originator),
    ];
    let qs = query
        .into_iter()
@@ -303,6 +339,68 @@ fn generate_state() -> String {
    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}

fn send_cancel_request(port: u16) -> io::Result<()> {
    let addr: SocketAddr = format!("127.0.0.1:{port}")
        .parse()
        .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
    let mut stream = TcpStream::connect_timeout(&addr, Duration::from_secs(2))?;
    stream.set_read_timeout(Some(Duration::from_secs(2)))?;
    stream.set_write_timeout(Some(Duration::from_secs(2)))?;

    stream.write_all(b"GET /cancel HTTP/1.1\r\n")?;
    stream.write_all(format!("Host: 127.0.0.1:{port}\r\n").as_bytes())?;
    stream.write_all(b"Connection: close\r\n\r\n")?;

    let mut buf = [0u8; 64];
    let _ = stream.read(&mut buf);
    Ok(())
}

fn bind_server(port: u16) -> io::Result<Server> {
    let bind_address = format!("127.0.0.1:{port}");
    let mut cancel_attempted = false;
    let mut attempts = 0;
    const MAX_ATTEMPTS: u32 = 10;
    const RETRY_DELAY: Duration = Duration::from_millis(200);

    loop {
        match Server::http(&bind_address) {
            Ok(server) => return Ok(server),
            Err(err) => {
                attempts += 1;
                let is_addr_in_use = err
                    .downcast_ref::<io::Error>()
                    .map(|io_err| io_err.kind() == io::ErrorKind::AddrInUse)
                    .unwrap_or(false);

                // If the address is in use, there is probably another instance of the login server
                // running. Attempt to cancel it and retry.
                if is_addr_in_use {
                    if !cancel_attempted {
                        cancel_attempted = true;
                        if let Err(cancel_err) = send_cancel_request(port) {
                            eprintln!("Failed to cancel previous login server: {cancel_err}");
                        }
                    }

                    thread::sleep(RETRY_DELAY);

                    if attempts >= MAX_ATTEMPTS {
                        return Err(io::Error::new(
                            io::ErrorKind::AddrInUse,
format!("Port {bind_address} is already in use"),
|
||||
));
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(io::Error::other(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
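
The raw-TCP helper above avoids pulling an HTTP client into the login crate, but any HTTP client can hit the same endpoint. A sketch mirroring what the new integration test below does with reqwest:

// Sketch: cancelling a running login server over HTTP (see the test below).
let cancel_url = format!("http://127.0.0.1:{login_port}/cancel");
let resp = reqwest::Client::new().get(cancel_url).send().await?;
assert!(resp.status().is_success());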
struct ExchangedTokens {
    id_token: String,
    access_token: String,
@@ -374,10 +472,8 @@ async fn persist_tokens_async(
        if let Some(key) = api_key {
            auth.openai_api_key = Some(key);
        }
        let tokens = auth
            .tokens
            .get_or_insert_with(crate::token_data::TokenData::default);
        tokens.id_token = crate::token_data::parse_id_token(&id_token).map_err(io::Error::other)?;
        let tokens = auth.tokens.get_or_insert_with(TokenData::default);
        tokens.id_token = parse_id_token(&id_token).map_err(io::Error::other)?;
        // Persist chatgpt_account_id if present in claims
        if let Some(acc) = jwt_auth_claims(&id_token)
            .get("chatgpt_account_id")
@@ -392,14 +488,14 @@ async fn persist_tokens_async(
            tokens.refresh_token = rt;
        }
        auth.last_refresh = Some(Utc::now());
        super::write_auth_json(&auth_file, &auth)
        codex_core::auth::write_auth_json(&auth_file, &auth)
    })
    .await
    .map_err(|e| io::Error::other(format!("persist task failed: {e}")))?
}

fn read_or_default(path: &Path) -> AuthDotJson {
    match super::try_read_auth_json(path) {
    match codex_core::auth::try_read_auth_json(path) {
        Ok(auth) => auth,
        Err(_) => AuthDotJson {
            openai_api_key: None,
@@ -1,7 +1,9 @@
#![allow(clippy::unwrap_used)]
use std::io;
use std::net::SocketAddr;
use std::net::TcpListener;
use std::thread;
use std::time::Duration;

use base64::Engine;
use codex_login::ServerOptions;
@@ -100,6 +102,7 @@ async fn end_to_end_login_flow_persists_auth_json() {
        port: 0,
        open_browser: false,
        force_state: Some(state),
        originator: "test_originator".to_string(),
    };
    let server = run_login_server(opts).unwrap();
    let login_port = server.actual_port;
@@ -158,6 +161,7 @@ async fn creates_missing_codex_home_dir() {
        port: 0,
        open_browser: false,
        force_state: Some(state),
        originator: "test_originator".to_string(),
    };
    let server = run_login_server(opts).unwrap();
    let login_port = server.actual_port;
@@ -175,3 +179,67 @@ async fn creates_missing_codex_home_dir() {
        "auth.json should be created even if parent dir was missing"
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn cancels_previous_login_server_when_port_is_in_use() {
    if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let (issuer_addr, _issuer_handle) = start_mock_issuer();
    let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port());

    let first_tmp = tempdir().unwrap();
    let first_codex_home = first_tmp.path().to_path_buf();

    let first_opts = ServerOptions {
        codex_home: first_codex_home,
        client_id: codex_login::CLIENT_ID.to_string(),
        issuer: issuer.clone(),
        port: 0,
        open_browser: false,
        force_state: Some("cancel_state".to_string()),
        originator: "test_originator".to_string(),
    };

    let first_server = run_login_server(first_opts).unwrap();
    let login_port = first_server.actual_port;
    let first_server_task = tokio::spawn(async move { first_server.block_until_done().await });

    tokio::time::sleep(Duration::from_millis(100)).await;

    let second_tmp = tempdir().unwrap();
    let second_codex_home = second_tmp.path().to_path_buf();

    let second_opts = ServerOptions {
        codex_home: second_codex_home,
        client_id: codex_login::CLIENT_ID.to_string(),
        issuer,
        port: login_port,
        open_browser: false,
        force_state: Some("cancel_state_2".to_string()),
        originator: "test_originator".to_string(),
    };

    let second_server = run_login_server(second_opts).unwrap();
    assert_eq!(second_server.actual_port, login_port);

    let cancel_result = first_server_task
        .await
        .expect("first login server task panicked")
        .expect_err("login server should report cancellation");
    assert_eq!(cancel_result.kind(), io::ErrorKind::Interrupted);

    let client = reqwest::Client::new();
    let cancel_url = format!("http://127.0.0.1:{login_port}/cancel");
    let resp = client.get(cancel_url).send().await.unwrap();
    assert!(resp.status().is_success());

    second_server
        .block_until_done()
        .await
        .expect_err("second login server should report cancellation");
}
@@ -1,39 +1,32 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use codex_core::CodexConversation;
use codex_core::ConversationManager;
use codex_core::NewConversation;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_core::config::ConfigToml;
use codex_core::config::load_config_as_toml;
use codex_core::git_info::git_diff_to_remote;
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
use codex_core::protocol::Event;
use codex_core::protocol::EventMsg;
use codex_core::protocol::ExecApprovalRequestEvent;
use codex_core::protocol::ReviewDecision;
use codex_login::AuthManager;
use codex_protocol::mcp_protocol::AuthMode;
use codex_protocol::mcp_protocol::GitDiffToRemoteResponse;
use mcp_types::JSONRPCErrorError;
use mcp_types::RequestId;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tracing::error;
use uuid::Uuid;

use crate::error_code::INTERNAL_ERROR_CODE;
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use crate::json_to_toml::json_to_toml;
use crate::outgoing_message::OutgoingMessageSender;
use crate::outgoing_message::OutgoingNotification;
use codex_core::AuthManager;
use codex_core::CodexConversation;
use codex_core::ConversationManager;
use codex_core::Cursor as RolloutCursor;
use codex_core::NewConversation;
use codex_core::RolloutRecorder;
use codex_core::SessionMeta;
use codex_core::auth::CLIENT_ID;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_core::config::ConfigToml;
use codex_core::config::load_config_as_toml;
use codex_core::default_client::get_codex_user_agent;
use codex_core::exec::ExecParams;
use codex_core::exec_env::create_env;
use codex_core::get_platform_sandbox;
use codex_core::git_info::git_diff_to_remote;
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
use codex_core::protocol::Event;
use codex_core::protocol::EventMsg;
use codex_core::protocol::ExecApprovalRequestEvent;
use codex_core::protocol::InputItem as CoreInputItem;
use codex_core::protocol::Op;
use codex_login::CLIENT_ID;
use codex_core::protocol::ReviewDecision;
use codex_login::ServerOptions as LoginServerOptions;
use codex_login::ShutdownHandle;
use codex_login::run_login_server;
@@ -42,27 +35,51 @@ use codex_protocol::mcp_protocol::AddConversationListenerParams;
use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse;
use codex_protocol::mcp_protocol::ApplyPatchApprovalParams;
use codex_protocol::mcp_protocol::ApplyPatchApprovalResponse;
use codex_protocol::mcp_protocol::AuthMode;
use codex_protocol::mcp_protocol::AuthStatusChangeNotification;
use codex_protocol::mcp_protocol::ClientRequest;
use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::mcp_protocol::ConversationSummary;
use codex_protocol::mcp_protocol::EXEC_COMMAND_APPROVAL_METHOD;
use codex_protocol::mcp_protocol::ExecArbitraryCommandResponse;
use codex_protocol::mcp_protocol::ExecCommandApprovalParams;
use codex_protocol::mcp_protocol::ExecCommandApprovalResponse;
use codex_protocol::mcp_protocol::GetConfigTomlResponse;
use codex_protocol::mcp_protocol::ExecOneOffCommandParams;
use codex_protocol::mcp_protocol::GetUserAgentResponse;
use codex_protocol::mcp_protocol::GetUserSavedConfigResponse;
use codex_protocol::mcp_protocol::GitDiffToRemoteResponse;
use codex_protocol::mcp_protocol::InputItem as WireInputItem;
use codex_protocol::mcp_protocol::InterruptConversationParams;
use codex_protocol::mcp_protocol::InterruptConversationResponse;
use codex_protocol::mcp_protocol::ListConversationsParams;
use codex_protocol::mcp_protocol::ListConversationsResponse;
use codex_protocol::mcp_protocol::LoginChatGptCompleteNotification;
use codex_protocol::mcp_protocol::LoginChatGptResponse;
use codex_protocol::mcp_protocol::NewConversationParams;
use codex_protocol::mcp_protocol::NewConversationResponse;
use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
use codex_protocol::mcp_protocol::RemoveConversationSubscriptionResponse;
use codex_protocol::mcp_protocol::ResumeConversationParams;
use codex_protocol::mcp_protocol::SendUserMessageParams;
use codex_protocol::mcp_protocol::SendUserMessageResponse;
use codex_protocol::mcp_protocol::SendUserTurnParams;
use codex_protocol::mcp_protocol::SendUserTurnResponse;
use codex_protocol::mcp_protocol::ServerNotification;
use codex_protocol::mcp_protocol::UserSavedConfig;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::InputMessageKind;
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
use mcp_types::JSONRPCErrorError;
use mcp_types::RequestId;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tracing::error;
use uuid::Uuid;
// Duration before a ChatGPT login attempt is abandoned.
const LOGIN_CHATGPT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
@@ -88,7 +105,7 @@ pub(crate) struct CodexMessageProcessor {
    conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
    active_login: Arc<Mutex<Option<ActiveLogin>>>,
    // Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
    pending_interrupts: Arc<Mutex<HashMap<Uuid, Vec<RequestId>>>>,
    pending_interrupts: Arc<Mutex<HashMap<ConversationId, Vec<RequestId>>>>,
}
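
Much of the rest of this diff swaps raw `Uuid` keys and arguments for `ConversationId`. Its definition lives in `codex_protocol::mcp_protocol` and is not part of this diff; judging from its call sites here (tuple construction, `ConversationId::new()`, `From<Uuid>`, `Copy` semantics, `Display` formatting, `HashMap` keys), it behaves roughly like this sketch:

// Sketch only: a plausible shape for ConversationId, inferred from its
// uses in this diff; the real definition may differ in detail.
use std::fmt;
use uuid::Uuid;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConversationId(pub Uuid);

impl ConversationId {
    pub fn new() -> Self {
        Self(Uuid::new_v4())
    }
}

impl From<Uuid> for ConversationId {
    fn from(uuid: Uuid) -> Self {
        Self(uuid)
    }
}

impl fmt::Display for ConversationId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}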
impl CodexMessageProcessor {
@@ -119,6 +136,12 @@ impl CodexMessageProcessor {
                // created before processing any subsequent messages.
                self.process_new_conversation(request_id, params).await;
            }
            ClientRequest::ListConversations { request_id, params } => {
                self.handle_list_conversations(request_id, params).await;
            }
            ClientRequest::ResumeConversation { request_id, params } => {
                self.handle_resume_conversation(request_id, params).await;
            }
            ClientRequest::SendUserMessage { request_id, params } => {
                self.send_user_message(request_id, params).await;
            }
@@ -149,8 +172,14 @@ impl CodexMessageProcessor {
            ClientRequest::GetAuthStatus { request_id, params } => {
                self.get_auth_status(request_id, params).await;
            }
            ClientRequest::GetConfigToml { request_id } => {
                self.get_config_toml(request_id).await;
            ClientRequest::GetUserSavedConfig { request_id } => {
                self.get_user_saved_config(request_id).await;
            }
            ClientRequest::GetUserAgent { request_id } => {
                self.get_user_agent(request_id).await;
            }
            ClientRequest::ExecOneOffCommand { request_id, params } => {
                self.exec_one_off_command(request_id, params).await;
            }
        }
    }
@@ -160,7 +189,11 @@ impl CodexMessageProcessor {
        let opts = LoginServerOptions {
            open_browser: false,
            ..LoginServerOptions::new(config.codex_home.clone(), CLIENT_ID.to_string())
            ..LoginServerOptions::new(
                config.codex_home.clone(),
                CLIENT_ID.to_string(),
                config.responses_originator_header.clone(),
            )
        };
        enum LoginChatGptReply {
@@ -360,7 +393,13 @@ impl CodexMessageProcessor {
        self.outgoing.send_response(request_id, response).await;
    }

    async fn get_config_toml(&self, request_id: RequestId) {
    async fn get_user_agent(&self, request_id: RequestId) {
        let user_agent = get_codex_user_agent(Some(&self.config.responses_originator_header));
        let response = GetUserAgentResponse { user_agent };
        self.outgoing.send_response(request_id, response).await;
    }
    async fn get_user_saved_config(&self, request_id: RequestId) {
        let toml_value = match load_config_as_toml(&self.config.codex_home) {
            Ok(val) => val,
            Err(err) => {
@@ -387,33 +426,82 @@ impl CodexMessageProcessor {
            }
        };

        let profiles: HashMap<String, codex_protocol::config_types::ConfigProfile> = cfg
            .profiles
            .into_iter()
            .map(|(k, v)| {
                (
                    k,
                    // Define this explicitly here to avoid the need to
                    // implement `From<codex_core::config_profile::ConfigProfile>`
                    // for the `ConfigProfile` type and introduce a dependency on codex_core
                    codex_protocol::config_types::ConfigProfile {
                        model: v.model,
                        approval_policy: v.approval_policy,
                        model_reasoning_effort: v.model_reasoning_effort,
                    },
                )
            })
            .collect();
        let user_saved_config: UserSavedConfig = cfg.into();

        let response = GetConfigTomlResponse {
            approval_policy: cfg.approval_policy,
            sandbox_mode: cfg.sandbox_mode,
            model_reasoning_effort: cfg.model_reasoning_effort,
            profile: cfg.profile,
            profiles: Some(profiles),
        let response = GetUserSavedConfigResponse {
            config: user_saved_config,
        };
        self.outgoing.send_response(request_id, response).await;
    }
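
The hand-rolled per-profile mapping above collapses into a single `cfg.into()`, which presupposes a `From<ConfigToml> for UserSavedConfig` impl defined elsewhere in this change. The refactor pattern itself, in a self-contained toy (these names are illustrative only, not the real types):

// Toy illustration of the pattern: moving a field-by-field conversion
// behind a From impl so call sites shrink to `.into()`.
struct TomlProfile { model: Option<String> }
struct WireProfile { model: Option<String> }

impl From<TomlProfile> for WireProfile {
    fn from(p: TomlProfile) -> Self {
        WireProfile { model: p.model }
    }
}

fn main() {
    let cfg = TomlProfile { model: Some("gpt-4o".to_string()) };
    let wire: WireProfile = cfg.into(); // was: WireProfile { model: cfg.model }
    assert_eq!(wire.model.as_deref(), Some("gpt-4o"));
}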
    async fn exec_one_off_command(&self, request_id: RequestId, params: ExecOneOffCommandParams) {
        tracing::debug!("ExecOneOffCommand params: {params:?}");

        if params.command.is_empty() {
            let error = JSONRPCErrorError {
                code: INVALID_REQUEST_ERROR_CODE,
                message: "command must not be empty".to_string(),
                data: None,
            };
            self.outgoing.send_error(request_id, error).await;
            return;
        }

        let cwd = params.cwd.unwrap_or_else(|| self.config.cwd.clone());
        let env = create_env(&self.config.shell_environment_policy);
        let timeout_ms = params.timeout_ms;
        let exec_params = ExecParams {
            command: params.command,
            cwd,
            timeout_ms,
            env,
            with_escalated_permissions: None,
            justification: None,
        };

        self.outgoing.send_response(request_id, response).await;
        let effective_policy = params
            .sandbox_policy
            .unwrap_or_else(|| self.config.sandbox_policy.clone());

        let sandbox_type = match &effective_policy {
            codex_core::protocol::SandboxPolicy::DangerFullAccess => {
                codex_core::exec::SandboxType::None
            }
            _ => get_platform_sandbox().unwrap_or(codex_core::exec::SandboxType::None),
        };
        tracing::debug!("Sandbox type: {sandbox_type:?}");
        let codex_linux_sandbox_exe = self.config.codex_linux_sandbox_exe.clone();
        let outgoing = self.outgoing.clone();
        let req_id = request_id;

        tokio::spawn(async move {
            match codex_core::exec::process_exec_tool_call(
                exec_params,
                sandbox_type,
                &effective_policy,
                &codex_linux_sandbox_exe,
                None,
            )
            .await
            {
                Ok(output) => {
                    let response = ExecArbitraryCommandResponse {
                        exit_code: output.exit_code,
                        stdout: output.stdout.text,
                        stderr: output.stderr.text,
                    };
                    outgoing.send_response(req_id, response).await;
                }
                Err(err) => {
                    let error = JSONRPCErrorError {
                        code: INTERNAL_ERROR_CODE,
                        message: format!("exec failed: {err}"),
                        data: None,
                    };
                    outgoing.send_error(req_id, error).await;
                }
            }
        });
    }
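
A sketch of what a caller supplies for the new handler (the field set is read off the handler above; the exact field types are assumptions):

// Hypothetical params for execOneOffCommand; None fields fall back to the
// server's Config, as the handler above shows.
let params = ExecOneOffCommandParams {
    command: vec!["echo".to_string(), "hello".to_string()],
    cwd: None,            // defaults to config.cwd
    timeout_ms: Some(5_000),
    sandbox_policy: None, // defaults to config.sandbox_policy
};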
    async fn process_new_conversation(&self, request_id: RequestId, params: NewConversationParams) {
@@ -438,7 +526,7 @@ impl CodexMessageProcessor {
            ..
        } = conversation_id;
        let response = NewConversationResponse {
            conversation_id: ConversationId(conversation_id),
            conversation_id,
            model: session_configured.model,
        };
        self.outgoing.send_response(request_id, response).await;
@@ -454,6 +542,133 @@ impl CodexMessageProcessor {
        }
    }
    async fn handle_list_conversations(
        &self,
        request_id: RequestId,
        params: ListConversationsParams,
    ) {
        let page_size = params.page_size.unwrap_or(25);
        // Decode the optional cursor string to a Cursor via serde (Cursor implements Deserialize from string)
        let cursor_obj: Option<RolloutCursor> = match params.cursor {
            Some(s) => serde_json::from_str::<RolloutCursor>(&format!("\"{s}\"")).ok(),
            None => None,
        };
        let cursor_ref = cursor_obj.as_ref();

        let page = match RolloutRecorder::list_conversations(
            &self.config.codex_home,
            page_size,
            cursor_ref,
        )
        .await
        {
            Ok(p) => p,
            Err(err) => {
                let error = JSONRPCErrorError {
                    code: INTERNAL_ERROR_CODE,
                    message: format!("failed to list conversations: {err}"),
                    data: None,
                };
                self.outgoing.send_error(request_id, error).await;
                return;
            }
        };

        let items = page
            .items
            .into_iter()
            .filter_map(|it| extract_conversation_summary(it.path, &it.head))
            .collect();

        // Encode next_cursor as a plain string
        let next_cursor = match page.next_cursor {
            Some(c) => match serde_json::to_value(&c) {
                Ok(serde_json::Value::String(s)) => Some(s),
                _ => None,
            },
            None => None,
        };

        let response = ListConversationsResponse { items, next_cursor };
        self.outgoing.send_response(request_id, response).await;
    }
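
The quoting dance works because `Cursor` serializes to a bare JSON string; wrapping the raw cursor text in quotes turns it back into a JSON document serde can parse. A sketch of the round trip:

// Sketch: Cursor <-> plain string, per the encode/decode above.
let s = match serde_json::to_value(&cursor)? {
    serde_json::Value::String(s) => s,
    _ => unreachable!("Cursor serializes to a JSON string"),
};
let decoded: RolloutCursor = serde_json::from_str(&format!("\"{s}\""))?;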
    async fn handle_resume_conversation(
        &self,
        request_id: RequestId,
        params: ResumeConversationParams,
    ) {
        // Derive a Config using the same logic as new conversation, honoring overrides if provided.
        let config = match params.overrides {
            Some(overrides) => {
                derive_config_from_params(overrides, self.codex_linux_sandbox_exe.clone())
            }
            None => Ok(self.config.as_ref().clone()),
        };
        let config = match config {
            Ok(cfg) => cfg,
            Err(err) => {
                let error = JSONRPCErrorError {
                    code: INVALID_REQUEST_ERROR_CODE,
                    message: format!("error deriving config: {err}"),
                    data: None,
                };
                self.outgoing.send_error(request_id, error).await;
                return;
            }
        };

        match self
            .conversation_manager
            .resume_conversation_from_rollout(
                config,
                params.path.clone(),
                self.auth_manager.clone(),
            )
            .await
        {
            Ok(NewConversation {
                conversation_id,
                session_configured,
                ..
            }) => {
                let event = Event {
                    id: "".to_string(),
                    msg: EventMsg::SessionConfigured(session_configured.clone()),
                };
                self.outgoing.send_event_as_notification(&event, None).await;
                let initial_messages = session_configured.initial_messages.map(|msgs| {
                    msgs.into_iter()
                        .filter(|event| {
                            // Don't send non-plain user messages (like user instructions
                            // or environment context) back so they don't get rendered.
                            if let EventMsg::UserMessage(user_message) = event {
                                return matches!(user_message.kind, Some(InputMessageKind::Plain));
                            }
                            true
                        })
                        .collect()
                });

                // Reply with conversation id + model and initial messages (when present)
                let response = codex_protocol::mcp_protocol::ResumeConversationResponse {
                    conversation_id,
                    model: session_configured.model.clone(),
                    initial_messages,
                };
                self.outgoing.send_response(request_id, response).await;
            }
            Err(err) => {
                let error = JSONRPCErrorError {
                    code: INTERNAL_ERROR_CODE,
                    message: format!("error resuming conversation: {err}"),
                    data: None,
                };
                self.outgoing.send_error(request_id, error).await;
            }
        }
    }
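
A sketch of the resume request from the client's side (`overrides` reuses the `NewConversationParams` shape, as the new test suite at the end of this diff notes; the path here is made up):

let params = ResumeConversationParams {
    path: PathBuf::from("/home/user/.codex/sessions/rollout-2025-01-02.jsonl"),
    overrides: None, // or Some(NewConversationParams { .. }) to override config
};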
    async fn send_user_message(&self, request_id: RequestId, params: SendUserMessageParams) {
        let SendUserMessageParams {
            conversation_id,
@@ -461,7 +676,7 @@ impl CodexMessageProcessor {
        } = params;
        let Ok(conversation) = self
            .conversation_manager
            .get_conversation(conversation_id.0)
            .get_conversation(conversation_id)
            .await
        else {
            let error = JSONRPCErrorError {
@@ -509,7 +724,7 @@ impl CodexMessageProcessor {

        let Ok(conversation) = self
            .conversation_manager
            .get_conversation(conversation_id.0)
            .get_conversation(conversation_id)
            .await
        else {
            let error = JSONRPCErrorError {
@@ -555,7 +770,7 @@ impl CodexMessageProcessor {
        let InterruptConversationParams { conversation_id } = params;
        let Ok(conversation) = self
            .conversation_manager
            .get_conversation(conversation_id.0)
            .get_conversation(conversation_id)
            .await
        else {
            let error = JSONRPCErrorError {
@@ -570,7 +785,7 @@ impl CodexMessageProcessor {
        // Record the pending interrupt so we can reply when TurnAborted arrives.
        {
            let mut map = self.pending_interrupts.lock().await;
            map.entry(conversation_id.0).or_default().push(request_id);
            map.entry(conversation_id).or_default().push(request_id);
        }

        // Submit the interrupt; we'll respond upon TurnAborted.
@@ -585,12 +800,12 @@ impl CodexMessageProcessor {
        let AddConversationListenerParams { conversation_id } = params;
        let Ok(conversation) = self
            .conversation_manager
            .get_conversation(conversation_id.0)
            .get_conversation(conversation_id)
            .await
        else {
            let error = JSONRPCErrorError {
                code: INVALID_REQUEST_ERROR_CODE,
                message: format!("conversation not found: {}", conversation_id.0),
                message: format!("conversation not found: {conversation_id}"),
                data: None,
            };
            self.outgoing.send_error(request_id, error).await;
@@ -620,18 +835,18 @@ impl CodexMessageProcessor {
            };

            // For now, we send a notification for every event,
            // JSON-serializing the `Event` as-is, but we will move
            // to creating a special enum for notifications with a
            // stable wire format.
            // JSON-serializing the `Event` as-is, but these should
            // be migrated to be variants of `ServerNotification`
            // instead.
            let method = format!("codex/event/{}", event.msg);
            let mut params = match serde_json::to_value(event.clone()) {
                Ok(serde_json::Value::Object(map)) => map,
                Ok(_) => {
                    tracing::error!("event did not serialize to an object");
                    error!("event did not serialize to an object");
                    continue;
                }
                Err(err) => {
                    tracing::error!("failed to serialize event: {err}");
                    error!("failed to serialize event: {err}");
                    continue;
                }
            };
@@ -703,7 +918,7 @@ async fn apply_bespoke_event_handling(
    conversation_id: ConversationId,
    conversation: Arc<CodexConversation>,
    outgoing: Arc<OutgoingMessageSender>,
    pending_interrupts: Arc<Mutex<HashMap<Uuid, Vec<RequestId>>>>,
    pending_interrupts: Arc<Mutex<HashMap<ConversationId, Vec<RequestId>>>>,
) {
    let Event { id: event_id, msg } = event;
    match msg {
@@ -756,7 +971,7 @@ async fn apply_bespoke_event_handling(
        EventMsg::TurnAborted(turn_aborted_event) => {
            let pending = {
                let mut map = pending_interrupts.lock().await;
                map.remove(&conversation_id.0).unwrap_or_default()
                map.remove(&conversation_id).unwrap_or_default()
            };
            if !pending.is_empty() {
                let response = InterruptConversationResponse {
@@ -799,7 +1014,6 @@ fn derive_config_from_params(
        include_plan_tool,
        include_apply_patch_tool,
        include_view_image_tool: None,
        disable_response_storage: None,
        show_raw_agent_reasoning: None,
        tools_web_search_request: None,
    };
@@ -815,7 +1029,7 @@ fn derive_config_from_params(

async fn on_patch_approval_response(
    event_id: String,
    receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
    receiver: oneshot::Receiver<mcp_types::Result>,
    codex: Arc<CodexConversation>,
) {
    let response = receiver.await;
@@ -857,14 +1071,14 @@ async fn on_patch_approval_response(

async fn on_exec_approval_response(
    event_id: String,
    receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
    receiver: oneshot::Receiver<mcp_types::Result>,
    conversation: Arc<CodexConversation>,
) {
    let response = receiver.await;
    let value = match response {
        Ok(value) => value,
        Err(err) => {
            tracing::error!("request failed: {err:?}");
            error!("request failed: {err:?}");
            return;
        }
    };
@@ -890,3 +1104,100 @@ async fn on_exec_approval_response(
        error!("failed to submit ExecApproval: {err}");
    }
}
fn extract_conversation_summary(
    path: PathBuf,
    head: &[serde_json::Value],
) -> Option<ConversationSummary> {
    let session_meta = match head.first() {
        Some(first_line) => match serde_json::from_value::<SessionMeta>(first_line.clone()) {
            Ok(session_meta) => session_meta,
            Err(..) => return None,
        },
        None => return None,
    };

    let preview = head
        .iter()
        .filter_map(|value| serde_json::from_value::<ResponseItem>(value.clone()).ok())
        .find_map(|item| match item {
            ResponseItem::Message { content, .. } => {
                content.into_iter().find_map(|content| match content {
                    ContentItem::InputText { text } => {
                        match InputMessageKind::from(("user", &text)) {
                            InputMessageKind::Plain => Some(text),
                            _ => None,
                        }
                    }
                    _ => None,
                })
            }
            _ => None,
        })?;

    let preview = match preview.find(USER_MESSAGE_BEGIN) {
        Some(idx) => preview[idx + USER_MESSAGE_BEGIN.len()..].trim(),
        None => preview.as_str(),
    };

    let timestamp = if session_meta.timestamp.is_empty() {
        None
    } else {
        Some(session_meta.timestamp.clone())
    };

    Some(ConversationSummary {
        conversation_id: session_meta.id,
        timestamp,
        path,
        preview: preview.to_string(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    #[test]
    fn extract_conversation_summary_prefers_plain_user_messages() {
        let conversation_id =
            ConversationId(Uuid::parse_str("3f941c35-29b3-493b-b0a4-e25800d9aeb0").unwrap());
        let timestamp = Some("2025-09-05T16:53:11.850Z".to_string());
        let path = PathBuf::from("rollout.jsonl");

        let head = vec![
            json!({
                "id": conversation_id.0,
                "timestamp": timestamp,
            }),
            json!({
                "type": "message",
                "role": "user",
                "content": [{
                    "type": "input_text",
                    "text": "<user_instructions>\n<AGENTS.md contents>\n</user_instructions>".to_string(),
                }],
            }),
            json!({
                "type": "message",
                "role": "user",
                "content": [{
                    "type": "input_text",
                    "text": format!("<prior context> {USER_MESSAGE_BEGIN}Count to 5"),
                }],
            }),
        ];

        let summary = extract_conversation_summary(path.clone(), &head).expect("summary");

        assert_eq!(summary.conversation_id, conversation_id);
        assert_eq!(
            summary.timestamp,
            Some("2025-09-05T16:53:11.850Z".to_string())
        );
        assert_eq!(summary.path, path);
        assert_eq!(summary.preview, "Count to 5");
    }
}
@@ -162,7 +162,6 @@ impl CodexToolCallParam {
        include_plan_tool,
        include_apply_patch_tool: None,
        include_view_image_tool: None,
        disable_response_storage: None,
        show_raw_agent_reasoning: None,
        tools_web_search_request: None,
    };
@@ -182,8 +181,8 @@ impl CodexToolCallParam {
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct CodexToolCallReplyParam {
    /// The *session id* for this conversation.
    pub session_id: String,
    /// The conversation id for this Codex session.
    pub conversation_id: String,

    /// The *next user prompt* to continue the Codex conversation.
    pub prompt: String,
@@ -214,7 +213,8 @@ pub(crate) fn create_tool_for_codex_tool_call_reply_param() -> Tool {
        input_schema: tool_input_schema,
        output_schema: None,
        description: Some(
            "Continue a Codex session by providing the session id and prompt.".to_string(),
            "Continue a Codex conversation by providing the conversation id and prompt."
                .to_string(),
        ),
        annotations: None,
    }
@@ -309,21 +309,21 @@ mod tests {
        let tool = create_tool_for_codex_tool_call_reply_param();
        let tool_json = serde_json::to_value(&tool).expect("tool serializes");
        let expected_tool_json = serde_json::json!({
            "description": "Continue a Codex session by providing the session id and prompt.",
            "description": "Continue a Codex conversation by providing the conversation id and prompt.",
            "inputSchema": {
                "properties": {
                    "conversationId": {
                        "description": "The conversation id for this Codex session.",
                        "type": "string"
                    },
                    "prompt": {
                        "description": "The *next user prompt* to continue the Codex conversation.",
                        "type": "string"
                    },
                    "sessionId": {
                        "description": "The *session id* for this conversation.",
                        "type": "string"
                    },
                },
                "required": [
                    "conversationId",
                    "prompt",
                    "sessionId",
                ],
                "type": "object",
            },
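
With the `camelCase` rename in force, a `codex-reply` tools/call now carries `conversationId` rather than `sessionId`. A sketch of the arguments object (the UUID is just an example value):

let arguments = serde_json::json!({
    "conversationId": "3f941c35-29b3-493b-b0a4-e25800d9aeb0",
    "prompt": "Count to 5",
});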
@@ -5,6 +5,10 @@
use std::collections::HashMap;
use std::sync::Arc;

use crate::exec_approval::handle_exec_approval_request;
use crate::outgoing_message::OutgoingMessageSender;
use crate::outgoing_message::OutgoingNotificationMeta;
use crate::patch_approval::handle_patch_approval_request;
use codex_core::CodexConversation;
use codex_core::ConversationManager;
use codex_core::NewConversation;
@@ -18,18 +22,13 @@ use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::Submission;
use codex_core::protocol::TaskCompleteEvent;
use codex_protocol::mcp_protocol::ConversationId;
use mcp_types::CallToolResult;
use mcp_types::ContentBlock;
use mcp_types::RequestId;
use mcp_types::TextContent;
use serde_json::json;
use tokio::sync::Mutex;
use uuid::Uuid;

use crate::exec_approval::handle_exec_approval_request;
use crate::outgoing_message::OutgoingMessageSender;
use crate::outgoing_message::OutgoingNotificationMeta;
use crate::patch_approval::handle_patch_approval_request;

pub(crate) const INVALID_PARAMS_ERROR_CODE: i64 = -32602;

@@ -43,7 +42,7 @@ pub async fn run_codex_tool_session(
    config: CodexConfig,
    outgoing: Arc<OutgoingMessageSender>,
    conversation_manager: Arc<ConversationManager>,
    running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
    running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, ConversationId>>>,
) {
    let NewConversation {
        conversation_id,
@@ -119,13 +118,13 @@ pub async fn run_codex_tool_session_reply(
    outgoing: Arc<OutgoingMessageSender>,
    request_id: RequestId,
    prompt: String,
    running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
    session_id: Uuid,
    running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, ConversationId>>>,
    conversation_id: ConversationId,
) {
    running_requests_id_to_codex_uuid
        .lock()
        .await
        .insert(request_id.clone(), session_id);
        .insert(request_id.clone(), conversation_id);
    if let Err(e) = conversation
        .submit(Op::UserInput {
            items: vec![InputItem::Text { text: prompt }],
@@ -154,7 +153,7 @@ async fn run_codex_tool_session_inner(
    codex: Arc<CodexConversation>,
    outgoing: Arc<OutgoingMessageSender>,
    request_id: RequestId,
    running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
    running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, ConversationId>>>,
) {
    let request_id_str = match &request_id {
        RequestId::String(s) => s.clone(),
@@ -279,6 +278,7 @@ async fn run_codex_tool_session_inner(
        | EventMsg::PlanUpdate(_)
        | EventMsg::TurnAborted(_)
        | EventMsg::ConversationHistory(_)
        | EventMsg::UserMessage(_)
        | EventMsg::ShutdownComplete => {
            // For now, we do not do anything extra for these
            // events. Note that
@@ -9,11 +9,12 @@ use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param;
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use crate::outgoing_message::OutgoingMessageSender;
use codex_protocol::mcp_protocol::ClientRequest;
use codex_protocol::mcp_protocol::ConversationId;

use codex_core::AuthManager;
use codex_core::ConversationManager;
use codex_core::config::Config;
use codex_core::protocol::Submission;
use codex_login::AuthManager;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::ClientRequest as McpClientRequest;
@@ -41,7 +42,7 @@ pub(crate) struct MessageProcessor {
    initialized: bool,
    codex_linux_sandbox_exe: Option<PathBuf>,
    conversation_manager: Arc<ConversationManager>,
    running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
    running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, ConversationId>>>,
}

impl MessageProcessor {
@@ -53,8 +54,11 @@ impl MessageProcessor {
        config: Arc<Config>,
    ) -> Self {
        let outgoing = Arc::new(outgoing);
        let auth_manager =
            AuthManager::shared(config.codex_home.clone(), config.preferred_auth_method);
        let auth_manager = AuthManager::shared(
            config.codex_home.clone(),
            config.preferred_auth_method,
            config.responses_originator_header.clone(),
        );
        let conversation_manager = Arc::new(ConversationManager::new(auth_manager.clone()));
        let codex_message_processor = CodexMessageProcessor::new(
            auth_manager,
@@ -433,7 +437,10 @@ impl MessageProcessor {
        tracing::info!("tools/call -> params: {:?}", arguments);

        // parse arguments
        let CodexToolCallReplyParam { session_id, prompt } = match arguments {
        let CodexToolCallReplyParam {
            conversation_id,
            prompt,
        } = match arguments {
            Some(json_val) => match serde_json::from_value::<CodexToolCallReplyParam>(json_val) {
                Ok(params) => params,
                Err(e) => {
@@ -454,12 +461,12 @@ impl MessageProcessor {
            },
            None => {
                tracing::error!(
                    "Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required."
                    "Missing arguments for codex-reply tool-call; the `conversation_id` and `prompt` fields are required."
                );
                let result = CallToolResult {
                    content: vec![ContentBlock::TextContent(TextContent {
                        r#type: "text".to_owned(),
                        text: "Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required.".to_owned(),
                        text: "Missing arguments for codex-reply tool-call; the `conversation_id` and `prompt` fields are required.".to_owned(),
                        annotations: None,
                    })],
                    is_error: Some(true),
@@ -470,14 +477,14 @@ impl MessageProcessor {
                return;
            }
        };
        let session_id = match Uuid::parse_str(&session_id) {
            Ok(id) => id,
        let conversation_id = match Uuid::parse_str(&conversation_id) {
            Ok(id) => ConversationId::from(id),
            Err(e) => {
                tracing::error!("Failed to parse session_id: {e}");
                tracing::error!("Failed to parse conversation_id: {e}");
                let result = CallToolResult {
                    content: vec![ContentBlock::TextContent(TextContent {
                        r#type: "text".to_owned(),
                        text: format!("Failed to parse session_id: {e}"),
                        text: format!("Failed to parse conversation_id: {e}"),
                        annotations: None,
                    })],
                    is_error: Some(true),
@@ -493,14 +500,18 @@ impl MessageProcessor {
        let outgoing = self.outgoing.clone();
        let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();

        let codex = match self.conversation_manager.get_conversation(session_id).await {
        let codex = match self
            .conversation_manager
            .get_conversation(conversation_id)
            .await
        {
            Ok(c) => c,
            Err(_) => {
                tracing::warn!("Session not found for session_id: {session_id}");
                tracing::warn!("Session not found for conversation_id: {conversation_id}");
                let result = CallToolResult {
                    content: vec![ContentBlock::TextContent(TextContent {
                        r#type: "text".to_owned(),
                        text: format!("Session not found for session_id: {session_id}"),
                        text: format!("Session not found for conversation_id: {conversation_id}"),
                        annotations: None,
                    })],
                    is_error: Some(true),
@@ -525,7 +536,7 @@ impl MessageProcessor {
            request_id,
            prompt,
            running_requests_id_to_codex_uuid,
            session_id,
            conversation_id,
        )
        .await;
    }
@@ -561,24 +572,28 @@ impl MessageProcessor {
            RequestId::Integer(i) => i.to_string(),
        };

        // Obtain the session_id while holding the first lock, then release.
        let session_id = {
        // Obtain the conversation id while holding the first lock, then release.
        let conversation_id = {
            let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
            match map_guard.get(&request_id) {
                Some(id) => *id, // Uuid is Copy
                Some(id) => *id,
                None => {
                    tracing::warn!("Session not found for request_id: {}", request_id_string);
                    return;
                }
            }
        };
        tracing::info!("session_id: {session_id}");
        tracing::info!("conversation_id: {conversation_id}");

        // Obtain the Codex conversation from the server.
        let codex_arc = match self.conversation_manager.get_conversation(session_id).await {
        let codex_arc = match self
            .conversation_manager
            .get_conversation(conversation_id)
            .await
        {
            Ok(c) => c,
            Err(_) => {
                tracing::warn!("Session not found for session_id: {session_id}");
                tracing::warn!("Session not found for conversation_id: {conversation_id}");
                return;
            }
        };
@@ -97,6 +97,9 @@ impl OutgoingMessageSender {
        }
    }

    /// This is used with the MCP server, but not the more general JSON-RPC app
    /// server. Prefer [`OutgoingMessageSender::send_server_notification`] where
    /// possible.
    pub(crate) async fn send_event_as_notification(
        &self,
        event: &Event,
@@ -123,14 +126,9 @@ impl OutgoingMessageSender {
    }

    pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
        let method = format!("codex/event/{notification}");
        let params = match serde_json::to_value(&notification) {
            Ok(serde_json::Value::Object(mut map)) => map.remove("data"),
            _ => None,
        };
        let outgoing_message =
            OutgoingMessage::Notification(OutgoingNotification { method, params });
        let _ = self.sender.send(outgoing_message);
        let _ = self
            .sender
            .send(OutgoingMessage::AppServerNotification(notification));
    }
    pub(crate) async fn send_notification(&self, notification: OutgoingNotification) {
@@ -148,6 +146,9 @@ impl OutgoingMessageSender {
pub(crate) enum OutgoingMessage {
    Request(OutgoingRequest),
    Notification(OutgoingNotification),
    /// AppServerNotification is specific to the case where this is run as an
    /// "app server" as opposed to an MCP server.
    AppServerNotification(ServerNotification),
    Response(OutgoingResponse),
    Error(OutgoingError),
}
@@ -171,6 +172,21 @@ impl From<OutgoingMessage> for JSONRPCMessage {
                    params,
                })
            }
            AppServerNotification(notification) => {
                let method = notification.to_string();
                let params = match notification.to_params() {
                    Ok(params) => Some(params),
                    Err(err) => {
                        warn!("failed to serialize notification params: {err}");
                        None
                    }
                };
                JSONRPCMessage::Notification(JSONRPCNotification {
                    jsonrpc: JSONRPC_VERSION.into(),
                    method,
                    params,
                })
            }
            Response(OutgoingResponse { id, result }) => {
                JSONRPCMessage::Response(JSONRPCResponse {
                    jsonrpc: JSONRPC_VERSION.into(),
@@ -242,6 +258,8 @@ pub(crate) struct OutgoingError {
mod tests {
    use codex_core::protocol::EventMsg;
    use codex_core::protocol::SessionConfiguredEvent;
    use codex_protocol::mcp_protocol::ConversationId;
    use codex_protocol::mcp_protocol::LoginChatGptCompleteNotification;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use uuid::Uuid;
@@ -253,13 +271,15 @@ mod tests {
    let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<OutgoingMessage>();
    let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);

    let conversation_id = ConversationId::new();
    let event = Event {
        id: "1".to_string(),
        msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
            session_id: Uuid::new_v4(),
            session_id: conversation_id,
            model: "gpt-4o".to_string(),
            history_log_id: 1,
            history_entry_count: 1000,
            initial_messages: None,
        }),
    };

@@ -284,11 +304,13 @@ mod tests {
    let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<OutgoingMessage>();
    let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);

    let conversation_id = ConversationId::new();
    let session_configured_event = SessionConfiguredEvent {
        session_id: Uuid::new_v4(),
        session_id: conversation_id,
        model: "gpt-4o".to_string(),
        history_log_id: 1,
        history_entry_count: 1000,
        initial_messages: None,
    };
    let event = Event {
        id: "1".to_string(),
@@ -322,4 +344,29 @@ mod tests {
        });
        assert_eq!(params.unwrap(), expected_params);
    }

    #[test]
    fn verify_server_notification_serialization() {
        let notification =
            ServerNotification::LoginChatGptComplete(LoginChatGptCompleteNotification {
                login_id: Uuid::nil(),
                success: true,
                error: None,
            });

        let jsonrpc_notification: JSONRPCMessage =
            OutgoingMessage::AppServerNotification(notification).into();
        assert_eq!(
            JSONRPCMessage::Notification(JSONRPCNotification {
                jsonrpc: "2.0".into(),
                method: "loginChatGptComplete".into(),
                params: Some(json!({
                    "loginId": Uuid::nil(),
                    "success": true,
                })),
            }),
            jsonrpc_notification,
            "ensure the strum macros serialize the method field correctly"
        );
    }
}
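
The notification method string now comes from `notification.to_string()`; the assertion message above credits "the strum macros" for the camelCase method names. A self-contained sketch of that mechanism (the real enum carries payloads and may use different attributes):

// Toy sketch of strum-based method naming, assumed from the test above.
use strum_macros::Display;

#[derive(Display)]
#[strum(serialize_all = "camelCase")]
enum ServerNotification {
    LoginChatGptComplete,
    AuthStatusChange,
}

fn main() {
    assert_eq!(
        ServerNotification::LoginChatGptComplete.to_string(),
        "loginChatGptComplete"
    );
}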
@@ -16,8 +16,10 @@ use codex_protocol::mcp_protocol::AddConversationListenerParams;
use codex_protocol::mcp_protocol::CancelLoginChatGptParams;
use codex_protocol::mcp_protocol::GetAuthStatusParams;
use codex_protocol::mcp_protocol::InterruptConversationParams;
use codex_protocol::mcp_protocol::ListConversationsParams;
use codex_protocol::mcp_protocol::NewConversationParams;
use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
use codex_protocol::mcp_protocol::ResumeConversationParams;
use codex_protocol::mcp_protocol::SendUserMessageParams;
use codex_protocol::mcp_protocol::SendUserTurnParams;

@@ -240,9 +242,32 @@ impl McpProcess {
        self.send_request("getAuthStatus", params).await
    }

    /// Send a `getConfigToml` JSON-RPC request.
    pub async fn send_get_config_toml_request(&mut self) -> anyhow::Result<i64> {
        self.send_request("getConfigToml", None).await
    /// Send a `getUserSavedConfig` JSON-RPC request.
    pub async fn send_get_user_saved_config_request(&mut self) -> anyhow::Result<i64> {
        self.send_request("getUserSavedConfig", None).await
    }

    /// Send a `getUserAgent` JSON-RPC request.
    pub async fn send_get_user_agent_request(&mut self) -> anyhow::Result<i64> {
        self.send_request("getUserAgent", None).await
    }

    /// Send a `listConversations` JSON-RPC request.
    pub async fn send_list_conversations_request(
        &mut self,
        params: ListConversationsParams,
    ) -> anyhow::Result<i64> {
        let params = Some(serde_json::to_value(params)?);
        self.send_request("listConversations", params).await
    }

    /// Send a `resumeConversation` JSON-RPC request.
    pub async fn send_resume_conversation_request(
        &mut self,
        params: ResumeConversationParams,
    ) -> anyhow::Result<i64> {
        let params = Some(serde_json::to_value(params)?);
        self.send_request("resumeConversation", params).await
    }

    /// Send a `loginChatGpt` JSON-RPC request.
@@ -1,6 +1,6 @@
use std::path::Path;

use codex_login::login_with_api_key;
use codex_core::auth::login_with_api_key;
use codex_protocol::mcp_protocol::AuthMode;
use codex_protocol::mcp_protocol::GetAuthStatusParams;
use codex_protocol::mcp_protocol::GetAuthStatusResponse;
@@ -2,10 +2,15 @@ use std::collections::HashMap;
use std::path::Path;

use codex_core::protocol::AskForApproval;
use codex_protocol::config_types::ConfigProfile;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::mcp_protocol::GetConfigTomlResponse;
use codex_protocol::config_types::Verbosity;
use codex_protocol::mcp_protocol::GetUserSavedConfigResponse;
use codex_protocol::mcp_protocol::Profile;
use codex_protocol::mcp_protocol::SandboxSettings;
use codex_protocol::mcp_protocol::Tools;
use codex_protocol::mcp_protocol::UserSavedConfig;
use mcp_test_support::McpProcess;
use mcp_test_support::to_response;
use mcp_types::JSONRPCResponse;
@@ -21,22 +26,38 @@ fn create_config_toml(codex_home: &Path) -> std::io::Result<()> {
    std::fs::write(
        config_toml,
        r#"
model = "gpt-5"
approval_policy = "on-request"
sandbox_mode = "workspace-write"
model_reasoning_summary = "detailed"
model_reasoning_effort = "high"
model_verbosity = "medium"
profile = "test"

[sandbox_workspace_write]
writable_roots = ["/tmp"]
network_access = true
exclude_tmpdir_env_var = true
exclude_slash_tmp = true

[tools]
web_search = false
view_image = true

[profiles.test]
model = "gpt-4o"
approval_policy = "on-request"
model_reasoning_effort = "high"
model_reasoning_summary = "detailed"
model_verbosity = "medium"
model_provider = "openai"
chatgpt_base_url = "https://api.chatgpt.com"
"#,
    )
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn get_config_toml_returns_subset() {
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn get_config_toml_parses_all_fields() {
    let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
    create_config_toml(codex_home.path()).expect("write config.toml");

@@ -49,32 +70,94 @@ async fn get_config_toml_returns_subset() {
        .expect("init failed");

    let request_id = mcp
        .send_get_config_toml_request()
        .send_get_user_saved_config_request()
        .await
        .expect("send getConfigToml");
        .expect("send getUserSavedConfig");
    let resp: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
    )
    .await
    .expect("getConfigToml timeout")
    .expect("getConfigToml response");
    .expect("getUserSavedConfig timeout")
    .expect("getUserSavedConfig response");

    let config: GetConfigTomlResponse = to_response(resp).expect("deserialize config");
    let expected = GetConfigTomlResponse {
        approval_policy: Some(AskForApproval::OnRequest),
        sandbox_mode: Some(SandboxMode::WorkspaceWrite),
        model_reasoning_effort: Some(ReasoningEffort::High),
        profile: Some("test".to_string()),
        profiles: Some(HashMap::from([(
            "test".into(),
            ConfigProfile {
                model: Some("gpt-4o".into()),
                approval_policy: Some(AskForApproval::OnRequest),
                model_reasoning_effort: Some(ReasoningEffort::High),
            },
        )])),
    let config: GetUserSavedConfigResponse = to_response(resp).expect("deserialize config");
    let expected = GetUserSavedConfigResponse {
        config: UserSavedConfig {
            approval_policy: Some(AskForApproval::OnRequest),
            sandbox_mode: Some(SandboxMode::WorkspaceWrite),
            sandbox_settings: Some(SandboxSettings {
                writable_roots: vec!["/tmp".into()],
                network_access: Some(true),
                exclude_tmpdir_env_var: Some(true),
                exclude_slash_tmp: Some(true),
            }),
            model: Some("gpt-5".into()),
            model_reasoning_effort: Some(ReasoningEffort::High),
            model_reasoning_summary: Some(ReasoningSummary::Detailed),
            model_verbosity: Some(Verbosity::Medium),
            tools: Some(Tools {
                web_search: Some(false),
                view_image: Some(true),
            }),
            profile: Some("test".to_string()),
            profiles: HashMap::from([(
                "test".into(),
                Profile {
                    model: Some("gpt-4o".into()),
                    approval_policy: Some(AskForApproval::OnRequest),
                    model_reasoning_effort: Some(ReasoningEffort::High),
                    model_reasoning_summary: Some(ReasoningSummary::Detailed),
                    model_verbosity: Some(Verbosity::Medium),
                    model_provider: Some("openai".into()),
                    chatgpt_base_url: Some("https://api.chatgpt.com".into()),
                },
            )]),
        },
    };

    assert_eq!(expected, config);
    assert_eq!(config, expected);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn get_config_toml_empty() {
    let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));

    let mut mcp = McpProcess::new(codex_home.path())
        .await
        .expect("spawn mcp process");
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
        .await
        .expect("init timeout")
        .expect("init failed");

    let request_id = mcp
        .send_get_user_saved_config_request()
        .await
        .expect("send getUserSavedConfig");
    let resp: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
    )
    .await
    .expect("getUserSavedConfig timeout")
    .expect("getUserSavedConfig response");

    let config: GetUserSavedConfigResponse = to_response(resp).expect("deserialize config");
    let expected = GetUserSavedConfigResponse {
        config: UserSavedConfig {
            approval_policy: None,
            sandbox_mode: None,
            sandbox_settings: None,
            model: None,
            model_reasoning_effort: None,
            model_reasoning_summary: None,
            model_verbosity: None,
            tools: None,
            profile: None,
            profiles: HashMap::new(),
        },
    };

    assert_eq!(config, expected);
}
181 codex-rs/mcp-server/tests/suite/list_resume.rs Normal file
@@ -0,0 +1,181 @@
|
||||
use std::fs;
use std::path::Path;

use codex_protocol::mcp_protocol::ListConversationsParams;
use codex_protocol::mcp_protocol::ListConversationsResponse;
use codex_protocol::mcp_protocol::NewConversationParams; // reused for overrides shape
use codex_protocol::mcp_protocol::ResumeConversationParams;
use codex_protocol::mcp_protocol::ResumeConversationResponse;
use mcp_test_support::McpProcess;
use mcp_test_support::to_response;
use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::TempDir;
use tokio::time::timeout;
use uuid::Uuid;

const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_list_and_resume_conversations() {
    // Prepare a temporary CODEX_HOME with a few fake rollout files.
    let codex_home = TempDir::new().expect("create temp dir");
    create_fake_rollout(
        codex_home.path(),
        "2025-01-02T12-00-00",
        "2025-01-02T12:00:00Z",
        "Hello A",
    );
    create_fake_rollout(
        codex_home.path(),
        "2025-01-01T13-00-00",
        "2025-01-01T13:00:00Z",
        "Hello B",
    );
    create_fake_rollout(
        codex_home.path(),
        "2025-01-01T12-00-00",
        "2025-01-01T12:00:00Z",
        "Hello C",
    );

    let mut mcp = McpProcess::new(codex_home.path())
        .await
        .expect("spawn mcp process");
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
        .await
        .expect("init timeout")
        .expect("init failed");

    // Request first page with size 2
    let req_id = mcp
        .send_list_conversations_request(ListConversationsParams {
            page_size: Some(2),
            cursor: None,
        })
        .await
        .expect("send listConversations");
    let resp: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(req_id)),
    )
    .await
    .expect("listConversations timeout")
    .expect("listConversations resp");
    let ListConversationsResponse { items, next_cursor } =
        to_response::<ListConversationsResponse>(resp).expect("deserialize response");

    assert_eq!(items.len(), 2);
    // Newest first; preview text should match
    assert_eq!(items[0].preview, "Hello A");
    assert_eq!(items[1].preview, "Hello B");
    assert!(items[0].path.is_absolute());
    assert!(next_cursor.is_some());

    // Request the next page using the cursor
    let req_id2 = mcp
        .send_list_conversations_request(ListConversationsParams {
            page_size: Some(2),
            cursor: next_cursor,
        })
        .await
        .expect("send listConversations page 2");
    let resp2: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(req_id2)),
    )
    .await
    .expect("listConversations page 2 timeout")
    .expect("listConversations page 2 resp");
    let ListConversationsResponse {
        items: items2,
        next_cursor: next2,
        ..
    } = to_response::<ListConversationsResponse>(resp2).expect("deserialize response");
    assert_eq!(items2.len(), 1);
    assert_eq!(items2[0].preview, "Hello C");
    assert!(next2.is_some());

    // Now resume one of the sessions and expect a SessionConfigured notification and response.
    let resume_req_id = mcp
        .send_resume_conversation_request(ResumeConversationParams {
            path: items[0].path.clone(),
            overrides: Some(NewConversationParams {
                model: Some("o3".to_string()),
                ..Default::default()
            }),
        })
        .await
        .expect("send resumeConversation");

    // Expect a codex/event notification with msg.type == session_configured
    let notification: JSONRPCNotification = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_notification_message("codex/event"),
    )
    .await
    .expect("session_configured notification timeout")
    .expect("session_configured notification");
    // Basic shape assertion: ensure event type is session_configured
    let msg_type = notification
        .params
        .as_ref()
        .and_then(|p| p.get("msg"))
        .and_then(|m| m.get("type"))
        .and_then(|t| t.as_str())
        .unwrap_or("");
    assert_eq!(msg_type, "session_configured");

    // Then the response for resumeConversation
    let resume_resp: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(resume_req_id)),
    )
    .await
    .expect("resumeConversation timeout")
    .expect("resumeConversation resp");
    let ResumeConversationResponse {
        conversation_id, ..
    } = to_response::<ResumeConversationResponse>(resume_resp)
        .expect("deserialize resumeConversation response");
    // conversation id should be a valid UUID
    let _: uuid::Uuid = conversation_id.into();
}

fn create_fake_rollout(codex_home: &Path, filename_ts: &str, meta_rfc3339: &str, preview: &str) {
    let uuid = Uuid::new_v4();
    // sessions/YYYY/MM/DD/ derived from filename_ts (YYYY-MM-DDThh-mm-ss)
    let year = &filename_ts[0..4];
    let month = &filename_ts[5..7];
    let day = &filename_ts[8..10];
    let dir = codex_home.join("sessions").join(year).join(month).join(day);
    fs::create_dir_all(&dir).unwrap_or_else(|e| panic!("create sessions dir: {e}"));

    let file_path = dir.join(format!("rollout-{filename_ts}-{uuid}.jsonl"));
    let mut lines = Vec::new();
    lines.push(
        json!({
            "record_type": "session_meta",
            "id": uuid,
            "timestamp": meta_rfc3339,
            "cwd": codex_home.to_string_lossy(),
            "originator": "test",
            "cli_version": "0.0.0-test"
        })
        .to_string(),
    );
    // Minimal user message entry as a persisted response item
    lines.push(
        json!({
            "type":"message",
            "role":"user",
            "content":[{"type":"input_text","text": preview}]
        })
        .to_string(),
    );
    fs::write(file_path, lines.join("\n") + "\n")
        .unwrap_or_else(|e| panic!("write rollout file: {e}"));
}
@@ -1,7 +1,7 @@
use std::path::Path;
use std::time::Duration;

use codex_login::login_with_api_key;
use codex_core::auth::login_with_api_key;
use codex_protocol::mcp_protocol::CancelLoginChatGptParams;
use codex_protocol::mcp_protocol::CancelLoginChatGptResponse;
use codex_protocol::mcp_protocol::GetAuthStatusParams;

@@ -5,5 +5,7 @@ mod codex_tool;
mod config;
mod create_conversation;
mod interrupt;
mod list_resume;
mod login;
mod send_message;
mod user_agent;
45
codex-rs/mcp-server/tests/suite/user_agent.rs
Normal file
@@ -0,0 +1,45 @@
use codex_core::default_client::DEFAULT_ORIGINATOR;
use codex_core::default_client::get_codex_user_agent;
use codex_protocol::mcp_protocol::GetUserAgentResponse;
use mcp_test_support::McpProcess;
use mcp_test_support::to_response;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
use pretty_assertions::assert_eq;
use tempfile::TempDir;
use tokio::time::timeout;

const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn get_user_agent_returns_current_codex_user_agent() {
    let codex_home = TempDir::new().unwrap_or_else(|err| panic!("create tempdir: {err}"));

    let mut mcp = McpProcess::new(codex_home.path())
        .await
        .expect("spawn mcp process");
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
        .await
        .expect("initialize timeout")
        .expect("initialize request");

    let request_id = mcp
        .send_get_user_agent_request()
        .await
        .expect("send getUserAgent");
    let response: JSONRPCResponse = timeout(
        DEFAULT_READ_TIMEOUT,
        mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
    )
    .await
    .expect("getUserAgent timeout")
    .expect("getUserAgent response");

    let received: GetUserAgentResponse =
        to_response(response).expect("deserialize getUserAgent response");
    let expected = GetUserAgentResponse {
        user_agent: get_codex_user_agent(Some(DEFAULT_ORIGINATOR)),
    };

    assert_eq!(received, expected);
}
@@ -9,4 +9,4 @@ workspace = true
[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = "1"
ts-rs = { version = "11", features = ["serde-json-impl"] }
ts-rs = { version = "11", features = ["serde-json-impl", "no-serde-warnings"] }

@@ -20,33 +20,25 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> {
    codex_protocol::mcp_protocol::InputItem::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::ClientRequest::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::ServerRequest::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::NewConversationParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::NewConversationResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::AddConversationListenerParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::AddConversationSubscriptionResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::RemoveConversationListenerParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::RemoveConversationSubscriptionResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::SendUserMessageParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::SendUserMessageResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::SendUserTurnParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::SendUserTurnResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::InterruptConversationParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::InterruptConversationResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::GitDiffToRemoteParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::GitDiffToRemoteResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::LoginChatGptResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::LoginChatGptCompleteNotification::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::CancelLoginChatGptParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::CancelLoginChatGptResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::LogoutChatGptParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::LogoutChatGptResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::GetAuthStatusParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::GetAuthStatusResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::ApplyPatchApprovalParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::ApplyPatchApprovalResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::ExecCommandApprovalParams::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::ExecCommandApprovalResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::GetUserSavedConfigResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::GetUserAgentResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::ServerNotification::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::ListConversationsResponse::export_all_to(out_dir)?;
    codex_protocol::mcp_protocol::ResumeConversationResponse::export_all_to(out_dir)?;

    generate_index_ts(out_dir)?;

@@ -12,15 +12,19 @@ workspace = true

[dependencies]
base64 = "0.22.1"
icu_decimal = "2.0.0"
icu_locale_core = "2.0.0"
mcp-types = { path = "../mcp-types" }
mime_guess = "2.0.5"
serde = { version = "1", features = ["derive"] }
serde_bytes = "0.11"
serde_json = "1"
serde_with = { version = "3.14.0", features = ["macros", "base64"] }
strum = "0.27.2"
strum_macros = "0.27.2"
sys-locale = "0.3.2"
tracing = "0.1.41"
ts-rs = { version = "11", features = ["uuid-impl", "serde-json-impl"] }
ts-rs = { version = "11", features = ["uuid-impl", "serde-json-impl", "no-serde-warnings"] }
uuid = { version = "1", features = ["serde", "v4"] }

[dev-dependencies]
@@ -4,8 +4,6 @@ use strum_macros::Display;
use strum_macros::EnumIter;
use ts_rs::TS;

use crate::protocol::AskForApproval;

/// See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning
#[derive(
    Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq, Display, TS, EnumIter,

@@ -35,6 +33,18 @@ pub enum ReasoningSummary {
    None,
}

/// Controls output length/detail on GPT-5 models via the Responses API.
/// Serialized with lowercase values to match the OpenAI API.
#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq, Display, TS)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum Verbosity {
    Low,
    #[default]
    Medium,
    High,
}

#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Default, Serialize, Display, TS)]
#[serde(rename_all = "kebab-case")]
#[strum(serialize_all = "kebab-case")]

@@ -49,13 +59,3 @@ pub enum SandboxMode {
    #[serde(rename = "danger-full-access")]
    DangerFullAccess,
}

/// Collection of common configuration options that a user can define as a unit
/// in `config.toml`. Currently only a subset of the fields are supported.
#[derive(Deserialize, Debug, Clone, PartialEq, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct ConfigProfile {
    pub model: Option<String>,
    pub approval_policy: Option<AskForApproval>,
    pub model_reasoning_effort: Option<ReasoningEffort>,
}
@@ -1,8 +1,9 @@
use serde::Deserialize;
use serde::Serialize;
use std::path::PathBuf;
use ts_rs::TS;

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, TS)]
pub struct CustomPrompt {
    pub name: String,
    pub path: PathBuf,

@@ -3,6 +3,7 @@ pub mod custom_prompts;
pub mod mcp_protocol;
pub mod message_history;
pub mod models;
pub mod num_format;
pub mod parse_command;
pub mod plan_tool;
pub mod protocol;
@@ -2,11 +2,12 @@ use std::collections::HashMap;
use std::fmt::Display;
use std::path::PathBuf;

use crate::config_types::ConfigProfile;
use crate::config_types::ReasoningEffort;
use crate::config_types::ReasoningSummary;
use crate::config_types::SandboxMode;
use crate::config_types::Verbosity;
use crate::protocol::AskForApproval;
use crate::protocol::EventMsg;
use crate::protocol::FileChange;
use crate::protocol::ReviewDecision;
use crate::protocol::SandboxPolicy;

@@ -18,16 +19,40 @@ use strum_macros::Display;
use ts_rs::TS;
use uuid::Uuid;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, TS)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, TS, Hash)]
#[ts(type = "string")]
pub struct ConversationId(pub Uuid);

impl ConversationId {
    pub fn new() -> Self {
        Self(Uuid::new_v4())
    }
}

impl Default for ConversationId {
    fn default() -> Self {
        Self::new()
    }
}

impl Display for ConversationId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<Uuid> for ConversationId {
    fn from(value: Uuid) -> Self {
        Self(value)
    }
}

impl From<ConversationId> for Uuid {
    fn from(value: ConversationId) -> Self {
        value.0
    }
}
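The `ConversationId` newtype above now derives `Hash` and converts to and from `Uuid` in both directions. A minimal usage sketch (illustrative, not part of the diff) of what those impls permit:

use std::collections::HashMap;
use uuid::Uuid;

fn conversation_id_sketch() {
    let id = ConversationId::new();        // fresh random v4 UUID
    let raw: Uuid = id.into();             // From<ConversationId> for Uuid
    let back: ConversationId = raw.into(); // From<Uuid> for ConversationId
    assert_eq!(id, back);

    // The new Hash derive lets the id serve as a map key.
    let mut turn_counts: HashMap<ConversationId, u32> = HashMap::new();
    turn_counts.insert(back, 1);
}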
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, TS)]
|
||||
#[ts(type = "string")]
|
||||
pub struct GitSha(pub String);
|
||||
@@ -54,6 +79,18 @@ pub enum ClientRequest {
|
||||
request_id: RequestId,
|
||||
params: NewConversationParams,
|
||||
},
|
||||
/// List recorded Codex conversations (rollouts) with optional pagination and search.
|
||||
ListConversations {
|
||||
#[serde(rename = "id")]
|
||||
request_id: RequestId,
|
||||
params: ListConversationsParams,
|
||||
},
|
||||
/// Resume a recorded Codex conversation from a rollout file.
|
||||
ResumeConversation {
|
||||
#[serde(rename = "id")]
|
||||
request_id: RequestId,
|
||||
params: ResumeConversationParams,
|
||||
},
|
||||
SendUserMessage {
|
||||
#[serde(rename = "id")]
|
||||
request_id: RequestId,
|
||||
@@ -102,10 +139,20 @@ pub enum ClientRequest {
|
||||
request_id: RequestId,
|
||||
params: GetAuthStatusParams,
|
||||
},
|
||||
GetConfigToml {
|
||||
GetUserSavedConfig {
|
||||
#[serde(rename = "id")]
|
||||
request_id: RequestId,
|
||||
},
|
||||
GetUserAgent {
|
||||
#[serde(rename = "id")]
|
||||
request_id: RequestId,
|
||||
},
|
||||
/// Execute a command (argv vector) under the server's sandbox.
|
||||
ExecOneOffCommand {
|
||||
#[serde(rename = "id")]
|
||||
request_id: RequestId,
|
||||
params: ExecOneOffCommandParams,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, TS)]
|
||||
@@ -158,6 +205,57 @@ pub struct NewConversationResponse {
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResumeConversationResponse {
|
||||
pub conversation_id: ConversationId,
|
||||
pub model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub initial_messages: Option<Vec<EventMsg>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ListConversationsParams {
|
||||
/// Optional page size; defaults to a reasonable server-side value.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub page_size: Option<usize>,
|
||||
/// Opaque pagination cursor returned by a previous call.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConversationSummary {
|
||||
pub conversation_id: ConversationId,
|
||||
pub path: PathBuf,
|
||||
pub preview: String,
|
||||
/// RFC3339 timestamp string for the session start, if available.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timestamp: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ListConversationsResponse {
|
||||
pub items: Vec<ConversationSummary>,
|
||||
/// Opaque cursor to pass to the next call to continue after the last item.
|
||||
/// if None, there are no more items to return.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub next_cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResumeConversationParams {
|
||||
/// Absolute path to the rollout JSONL file.
|
||||
pub path: PathBuf,
|
||||
/// Optional overrides to apply when spawning the resumed session.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub overrides: Option<NewConversationParams>,
|
||||
}
|
||||
|
||||
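The cursor contract implied by `ListConversationsParams` and `ListConversationsResponse`: feed each returned `next_cursor` back into the next request until it comes back as `None`. A client-side sketch, assuming a hypothetical `fetch_page` transport helper (not part of this API):

fn drain_all_conversations(
    mut fetch_page: impl FnMut(ListConversationsParams) -> ListConversationsResponse,
) -> Vec<ConversationSummary> {
    let mut all = Vec::new();
    let mut cursor: Option<String> = None;
    loop {
        let resp = fetch_page(ListConversationsParams {
            page_size: Some(25),
            cursor: cursor.take(),
        });
        all.extend(resp.items);
        match resp.next_cursor {
            Some(next) => cursor = Some(next), // more pages may remain
            None => break,                     // listing exhausted
        }
    }
    all
}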
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct AddConversationSubscriptionResponse {

@@ -218,6 +316,30 @@ pub struct GetAuthStatusParams {
    pub refresh_token: Option<bool>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ExecOneOffCommandParams {
    /// Command argv to execute.
    pub command: Vec<String>,
    /// Timeout of the command in milliseconds.
    /// If not specified, a sensible default is used server-side.
    pub timeout_ms: Option<u64>,
    /// Optional working directory for the process. Defaults to server config cwd.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwd: Option<PathBuf>,
    /// Optional explicit sandbox policy overriding the server default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sandbox_policy: Option<SandboxPolicy>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ExecArbitraryCommandResponse {
    pub exit_code: i32,
    pub stdout: String,
    pub stderr: String,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct GetAuthStatusResponse {

@@ -230,22 +352,87 @@ pub struct GetAuthStatusResponse {

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct GetConfigTomlResponse {
pub struct GetUserAgentResponse {
    pub user_agent: String,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct GetUserSavedConfigResponse {
    pub config: UserSavedConfig,
}

/// UserSavedConfig contains a subset of the config. It is meant to expose MCP
/// client-configurable settings that can be specified in the NewConversation
/// and SendUserTurn requests.
#[derive(Deserialize, Debug, Clone, PartialEq, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct UserSavedConfig {
    /// Approvals
    #[serde(skip_serializing_if = "Option::is_none")]
    pub approval_policy: Option<AskForApproval>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sandbox_mode: Option<SandboxMode>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sandbox_settings: Option<SandboxSettings>,

    /// Relevant model configuration
    /// Model-specific configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_reasoning_summary: Option<ReasoningSummary>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_verbosity: Option<Verbosity>,

    /// Tools
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Tools>,

    /// Profiles
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profile: Option<String>,
    #[serde(default)]
    pub profiles: HashMap<String, Profile>,
}

/// MCP representation of a [`codex_core::config_profile::ConfigProfile`].
#[derive(Deserialize, Debug, Clone, PartialEq, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct Profile {
    pub model: Option<String>,
    /// The key in the `model_providers` map identifying the
    /// [`ModelProviderInfo`] to use.
    pub model_provider: Option<String>,
    pub approval_policy: Option<AskForApproval>,
    pub model_reasoning_effort: Option<ReasoningEffort>,
    pub model_reasoning_summary: Option<ReasoningSummary>,
    pub model_verbosity: Option<Verbosity>,
    pub chatgpt_base_url: Option<String>,
}

/// MCP representation of a [`codex_core::config::ToolsToml`].
#[derive(Deserialize, Debug, Clone, PartialEq, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct Tools {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profiles: Option<HashMap<String, ConfigProfile>>,
    pub web_search: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub view_image: Option<bool>,
}

/// MCP representation of a [`codex_core::config_types::SandboxWorkspaceWrite`].
#[derive(Deserialize, Debug, Clone, PartialEq, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SandboxSettings {
    #[serde(default)]
    pub writable_roots: Vec<PathBuf>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub network_access: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude_tmpdir_env_var: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude_slash_tmp: Option<bool>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]

@@ -398,8 +585,8 @@ pub struct AuthStatusChangeNotification {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS, Display)]
#[serde(tag = "type", content = "data", rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
#[strum(serialize_all = "camelCase")]
pub enum ServerNotification {
    /// Authentication status changed
    AuthStatusChange(AuthStatusChangeNotification),

@@ -408,6 +595,15 @@ pub enum ServerNotification {
    LoginChatGptComplete(LoginChatGptCompleteNotification),
}

impl ServerNotification {
    pub fn to_params(self) -> Result<serde_json::Value, serde_json::Error> {
        match self {
            ServerNotification::AuthStatusChange(params) => serde_json::to_value(params),
            ServerNotification::LoginChatGptComplete(params) => serde_json::to_value(params),
        }
    }
}
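With the serde tag renamed to `method` and content to `params`, a `ServerNotification` maps naturally onto a JSON-RPC notification frame. A sketch of that wiring, assuming the strum `Display` output is used as the method name (an assumption, not shown in this diff):

fn to_jsonrpc_notification(n: ServerNotification) -> Result<serde_json::Value, serde_json::Error> {
    let method = n.to_string();  // camelCase via strum, e.g. "authStatusChange"
    let params = n.to_params()?; // payload only, via the helper above
    Ok(serde_json::json!({ "jsonrpc": "2.0", "method": method, "params": params }))
}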
#[cfg(test)]
mod tests {
    use super::*;

@@ -442,4 +638,10 @@ mod tests {
            serde_json::to_value(&request).unwrap(),
        );
    }

    #[test]
    fn test_conversation_id_default_is_not_zeroes() {
        let id = ConversationId::default();
        assert_ne!(id.0, Uuid::nil());
    }
}
@@ -1,9 +1,10 @@
use serde::Deserialize;
use serde::Serialize;
use ts_rs::TS;

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, TS)]
pub struct HistoryEntry {
    pub session_id: String,
    pub conversation_id: String,
    pub ts: u64,
    pub text: String,
}
@@ -6,10 +6,11 @@ use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::ser::Serializer;
use ts_rs::TS;

use crate::protocol::InputItem;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseInputItem {
    Message {

@@ -30,7 +31,7 @@ pub enum ResponseInputItem {
    },
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentItem {
    InputText { text: String },

@@ -38,15 +39,17 @@ pub enum ContentItem {
    OutputText { text: String },
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseItem {
    Message {
        #[serde(skip_serializing)]
        id: Option<String>,
        role: String,
        content: Vec<ContentItem>,
    },
    Reasoning {
        #[serde(default)]
        id: String,
        summary: Vec<ReasoningItemReasoningSummary>,
        #[serde(default, skip_serializing_if = "should_serialize_reasoning_content")]

@@ -55,6 +58,7 @@ pub enum ResponseItem {
    },
    LocalShellCall {
        /// Set when using the chat completions API.
        #[serde(skip_serializing)]
        id: Option<String>,
        /// Set when using the Responses API.
        call_id: Option<String>,

@@ -62,6 +66,7 @@ pub enum ResponseItem {
        action: LocalShellAction,
    },
    FunctionCall {
        #[serde(skip_serializing)]
        id: Option<String>,
        name: String,
        // The Responses API returns the function call arguments as a *string* that contains

@@ -82,7 +87,7 @@ pub enum ResponseItem {
        output: FunctionCallOutputPayload,
    },
    CustomToolCall {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        #[serde(skip_serializing)]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        status: Option<String>,

@@ -104,7 +109,7 @@ pub enum ResponseItem {
    // "action": {"type":"search","query":"weather: San Francisco, CA"}
    // }
    WebSearchCall {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        #[serde(skip_serializing)]
        id: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        status: Option<String>,

@@ -155,7 +160,7 @@ impl From<ResponseInputItem> for ResponseItem {
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
#[serde(rename_all = "snake_case")]
pub enum LocalShellStatus {
    Completed,

@@ -163,13 +168,13 @@ pub enum LocalShellStatus {
    Incomplete,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LocalShellAction {
    Exec(LocalShellExecAction),
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
pub struct LocalShellExecAction {
    pub command: Vec<String>,
    pub timeout_ms: Option<u64>,

@@ -178,7 +183,7 @@ pub struct LocalShellExecAction {
    pub user: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WebSearchAction {
    Search {

@@ -188,13 +193,13 @@ pub enum WebSearchAction {
    Other,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningItemReasoningSummary {
    SummaryText { text: String },
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningItemContent {
    ReasoningText { text: String },

@@ -238,7 +243,7 @@ impl From<Vec<InputItem>> for ResponseInputItem {

/// If the `name` of a `ResponseItem::FunctionCall` is either `container.exec`
/// or `shell`, the `arguments` field should deserialize to this struct.
#[derive(Deserialize, Debug, Clone, PartialEq)]
#[derive(Deserialize, Debug, Clone, PartialEq, TS)]
pub struct ShellToolCallParams {
    pub command: Vec<String>,
    pub workdir: Option<String>,

@@ -252,7 +257,7 @@ pub struct ShellToolCallParams {
    pub justification: Option<String>,
}

#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq, TS)]
pub struct FunctionCallOutputPayload {
    pub content: String,
    pub success: Option<bool>,

@@ -309,6 +314,8 @@ impl std::ops::Deref for FunctionCallOutputPayload {
    }
}

// (Moved event mapping logic into codex-core to avoid coupling protocol to UI-facing events.)

#[cfg(test)]
mod tests {
    use super::*;
98
codex-rs/protocol/src/num_format.rs
Normal file
@@ -0,0 +1,98 @@
use std::sync::OnceLock;

use icu_decimal::DecimalFormatter;
use icu_decimal::input::Decimal;
use icu_decimal::options::DecimalFormatterOptions;
use icu_locale_core::Locale;

fn make_local_formatter() -> Option<DecimalFormatter> {
    let loc: Locale = sys_locale::get_locale()?.parse().ok()?;
    DecimalFormatter::try_new(loc.into(), DecimalFormatterOptions::default()).ok()
}

fn make_en_us_formatter() -> DecimalFormatter {
    #![allow(clippy::expect_used)]
    let loc: Locale = "en-US".parse().expect("en-US wasn't a valid locale");
    DecimalFormatter::try_new(loc.into(), DecimalFormatterOptions::default())
        .expect("en-US wasn't a valid locale")
}

fn formatter() -> &'static DecimalFormatter {
    static FORMATTER: OnceLock<DecimalFormatter> = OnceLock::new();
    FORMATTER.get_or_init(|| make_local_formatter().unwrap_or_else(make_en_us_formatter))
}

/// Format a u64 with locale-aware digit separators (e.g. "12345" -> "12,345"
/// for en-US).
pub fn format_with_separators(n: u64) -> String {
    formatter().format(&Decimal::from(n)).to_string()
}

fn format_si_suffix_with_formatter(n: u64, formatter: &DecimalFormatter) -> String {
    if n < 1000 {
        return formatter.format(&Decimal::from(n)).to_string();
    }

    // Format `n / scale` with the requested number of fractional digits.
    let format_scaled = |n: u64, scale: u64, frac_digits: u32| -> String {
        let value = n as f64 / scale as f64;
        let scaled: u64 = (value * 10f64.powi(frac_digits as i32)).round() as u64;
        let mut dec = Decimal::from(scaled);
        dec.multiply_pow10(-(frac_digits as i16));
        formatter.format(&dec).to_string()
    };

    const UNITS: [(u64, &str); 3] = [(1_000, "K"), (1_000_000, "M"), (1_000_000_000, "G")];
    let f = n as f64;
    for &(scale, suffix) in &UNITS {
        if (100.0 * f / scale as f64).round() < 1000.0 {
            return format!("{}{}", format_scaled(n, scale, 2), suffix);
        } else if (10.0 * f / scale as f64).round() < 1000.0 {
            return format!("{}{}", format_scaled(n, scale, 1), suffix);
        } else if (f / scale as f64).round() < 1000.0 {
            return format!("{}{}", format_scaled(n, scale, 0), suffix);
        }
    }

    // Above 1000G, keep whole‑G precision.
    format!(
        "{}G",
        format_with_separators(((n as f64) / 1e9).round() as u64)
    )
}

/// Format token counts to 3 significant figures, using base-10 SI suffixes.
///
/// Examples (en-US):
/// - 999 -> "999"
/// - 1200 -> "1.20K"
/// - 123456789 -> "123M"
pub fn format_si_suffix(n: u64) -> String {
    format_si_suffix_with_formatter(n, formatter())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kmg() {
        let formatter = make_en_us_formatter();
        let fmt = |n: u64| format_si_suffix_with_formatter(n, &formatter);
        assert_eq!(fmt(0), "0");
        assert_eq!(fmt(999), "999");
        assert_eq!(fmt(1_000), "1.00K");
        assert_eq!(fmt(1_200), "1.20K");
        assert_eq!(fmt(10_000), "10.0K");
        assert_eq!(fmt(100_000), "100K");
        assert_eq!(fmt(999_500), "1.00M");
        assert_eq!(fmt(1_000_000), "1.00M");
        assert_eq!(fmt(1_234_000), "1.23M");
        assert_eq!(fmt(12_345_678), "12.3M");
        assert_eq!(fmt(999_950_000), "1.00G");
        assert_eq!(fmt(1_000_000_000), "1.00G");
        assert_eq!(fmt(1_234_000_000), "1.23G");
        // Above 1000G we keep whole‑G precision (no higher unit supported here).
        assert_eq!(fmt(1_234_000_000_000), "1,234G");
    }
}
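A worked pass through the tier selection above for n = 12_345_678, matching the `kmg` test (assumes an en-US formatter):

fn si_suffix_walkthrough() {
    // K tier: even zero fractional digits would round to 12_346 (>= 1000), so K is skipped.
    // M tier: the two-digit guard fails ((100.0 * n / 1e6).round() = 1235 >= 1000),
    // but the one-digit guard passes ((10.0 * n / 1e6).round() = 123 < 1000),
    // so format_scaled(n, 1_000_000, 1) yields three significant figures.
    assert_eq!(format_si_suffix(12_345_678), "12.3M");
}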
@@ -1,7 +1,8 @@
use serde::Deserialize;
use serde::Serialize;
use ts_rs::TS;

#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ParsedCommand {
    Read {
@@ -1,8 +1,9 @@
use serde::Deserialize;
use serde::Serialize;
use ts_rs::TS;

// Types for the TODO tool arguments matching codex-vscode/todo-mcp/src/main.rs
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[serde(rename_all = "snake_case")]
pub enum StepStatus {
    Pending,

@@ -10,14 +11,14 @@ pub enum StepStatus {
    Completed,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[serde(deny_unknown_fields)]
pub struct PlanItemArg {
    pub step: String,
    pub status: StepStatus,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[serde(deny_unknown_fields)]
pub struct UpdatePlanArgs {
    #[serde(default)]
@@ -10,22 +10,29 @@ use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;

use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
use crate::custom_prompts::CustomPrompt;
use crate::mcp_protocol::ConversationId;
use crate::message_history::HistoryEntry;
use crate::num_format::format_with_separators;
use crate::parse_command::ParsedCommand;
use crate::plan_tool::UpdatePlanArgs;
use mcp_types::CallToolResult;
use mcp_types::Tool as McpTool;
use serde::Deserialize;
use serde::Serialize;
use serde_bytes::ByteBuf;
use serde_with::serde_as;
use strum_macros::Display;
use ts_rs::TS;
use uuid::Uuid;

use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
use crate::message_history::HistoryEntry;
use crate::models::ResponseItem;
use crate::parse_command::ParsedCommand;
use crate::plan_tool::UpdatePlanArgs;

/// Open/close tags for special user-input blocks. Used across crates to avoid
/// duplicated hardcoded strings.
pub const USER_INSTRUCTIONS_OPEN_TAG: &str = "<user_instructions>";
pub const USER_INSTRUCTIONS_CLOSE_TAG: &str = "</user_instructions>";
pub const ENVIRONMENT_CONTEXT_OPEN_TAG: &str = "<environment_context>";
pub const ENVIRONMENT_CONTEXT_CLOSE_TAG: &str = "</environment_context>";
pub const USER_MESSAGE_BEGIN: &str = "## My request for Codex:";

/// Submission Queue Entry - requests from user
#[derive(Debug, Clone, Deserialize, Serialize)]

@@ -141,7 +148,7 @@ pub enum Op {

    /// Request the full in-memory conversation transcript for the current session.
    /// Reply is delivered via `EventMsg::ConversationHistory`.
    GetHistory,
    GetConversationPath,

    /// Request the list of MCP tools available across all configured servers.
    /// Reply is delivered via `EventMsg::McpListToolsResponse`.

@@ -397,7 +404,7 @@ pub struct Event {
}

/// Response event from the agent
#[derive(Debug, Clone, Deserialize, Serialize, Display)]
#[derive(Debug, Clone, Deserialize, Serialize, Display, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum EventMsg {

@@ -410,13 +417,16 @@ pub enum EventMsg {
    /// Agent has completed all actions
    TaskComplete(TaskCompleteEvent),

    /// Token count event, sent periodically to report the number of tokens
    /// used in the current session.
    TokenCount(TokenUsage),
    /// Usage update for the current session, including totals and last turn.
    /// Optional means unknown — UIs should not display when `None`.
    TokenCount(TokenCountEvent),

    /// Agent text output message
    AgentMessage(AgentMessageEvent),

    /// User/system input message (what was sent to the model).
    UserMessage(UserMessageEvent),

    /// Agent text output delta message
    AgentMessageDelta(AgentMessageDeltaEvent),

@@ -488,42 +498,87 @@ pub enum EventMsg {
    /// Notification that the agent is shutting down.
    ShutdownComplete,

    ConversationHistory(ConversationHistoryResponseEvent),
    ConversationHistory(ConversationPathResponseEvent),
}

// Individual event payload types matching each `EventMsg` variant.

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct ErrorEvent {
    pub message: String,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct TaskCompleteEvent {
    pub last_agent_message: Option<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct TaskStartedEvent {
    pub model_context_window: Option<u64>,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[derive(Debug, Clone, Deserialize, Serialize, Default, TS)]
pub struct TokenUsage {
    pub input_tokens: u64,
    pub cached_input_tokens: Option<u64>,
    pub cached_input_tokens: u64,
    pub output_tokens: u64,
    pub reasoning_output_tokens: Option<u64>,
    pub reasoning_output_tokens: u64,
    pub total_tokens: u64,
}

#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct TokenUsageInfo {
    pub total_token_usage: TokenUsage,
    pub last_token_usage: TokenUsage,
    pub model_context_window: Option<u64>,
}

impl TokenUsageInfo {
    pub fn new_or_append(
        info: &Option<TokenUsageInfo>,
        last: &Option<TokenUsage>,
        model_context_window: Option<u64>,
    ) -> Option<Self> {
        if info.is_none() && last.is_none() {
            return None;
        }

        let mut info = match info {
            Some(info) => info.clone(),
            None => Self {
                total_token_usage: TokenUsage::default(),
                last_token_usage: TokenUsage::default(),
                model_context_window,
            },
        };
        if let Some(last) = last {
            info.append_last_usage(last);
        }
        Some(info)
    }

    pub fn append_last_usage(&mut self, last: &TokenUsage) {
        self.total_token_usage.add_assign(last);
        self.last_token_usage = last.clone();
    }
}

#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct TokenCountEvent {
    pub info: Option<TokenUsageInfo>,
}

// Includes prompts, tools and space to call compact.
const BASELINE_TOKENS: u64 = 12000;

impl TokenUsage {
    pub fn is_zero(&self) -> bool {
        self.total_tokens == 0
    }

    pub fn cached_input(&self) -> u64 {
        self.cached_input_tokens.unwrap_or(0)
        self.cached_input_tokens
    }

    pub fn non_cached_input(&self) -> u64 {

@@ -541,35 +596,40 @@ impl TokenUsage {
    /// This will be off for the current turn and pending function calls.
    pub fn tokens_in_context_window(&self) -> u64 {
        self.total_tokens
            .saturating_sub(self.reasoning_output_tokens.unwrap_or(0))
            .saturating_sub(self.reasoning_output_tokens)
    }

    /// Estimate the remaining user-controllable percentage of the model's context window.
    ///
    /// `context_window` is the total size of the model's context window.
    /// `baseline_used_tokens` should capture tokens that are always present in
    /// `BASELINE_TOKENS` should capture tokens that are always present in
    /// the context (e.g., system prompt and fixed tool instructions) so that
    /// the percentage reflects the portion the user can influence.
    ///
    /// This normalizes both the numerator and denominator by subtracting the
    /// baseline, so immediately after the first prompt the UI shows 100% left
    /// and trends toward 0% as the user fills the effective window.
    pub fn percent_of_context_window_remaining(
        &self,
        context_window: u64,
        baseline_used_tokens: u64,
    ) -> u8 {
        if context_window <= baseline_used_tokens {
    pub fn percent_of_context_window_remaining(&self, context_window: u64) -> u8 {
        if context_window <= BASELINE_TOKENS {
            return 0;
        }

        let effective_window = context_window - baseline_used_tokens;
        let effective_window = context_window - BASELINE_TOKENS;
        let used = self
            .tokens_in_context_window()
            .saturating_sub(baseline_used_tokens);
            .saturating_sub(BASELINE_TOKENS);
        let remaining = effective_window.saturating_sub(used);
        ((remaining as f32 / effective_window as f32) * 100.0).clamp(0.0, 100.0) as u8
    }

    /// In-place element-wise sum of token counts.
    pub fn add_assign(&mut self, other: &TokenUsage) {
        self.input_tokens += other.input_tokens;
        self.cached_input_tokens += other.cached_input_tokens;
        self.output_tokens += other.output_tokens;
        self.reasoning_output_tokens += other.reasoning_output_tokens;
        self.total_tokens += other.total_tokens;
    }
}
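A worked example of the baseline normalization described above, using hypothetical numbers (sketch assumes module scope, next to the private `BASELINE_TOKENS`):

fn context_percent_walkthrough() {
    // Effective window: 200_000 - 12_000 (BASELINE_TOKENS) = 188_000.
    // In-window tokens: 110_000 total - 4_000 reasoning = 106_000.
    // Used above baseline: 106_000 - 12_000 = 94_000, so 94_000 / 188_000 = 50% left.
    let usage = TokenUsage {
        input_tokens: 100_000,
        cached_input_tokens: 0,
        output_tokens: 10_000,
        reasoning_output_tokens: 4_000,
        total_tokens: 110_000,
    };
    assert_eq!(usage.percent_of_context_window_remaining(200_000), 50);
}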
#[derive(Debug, Clone, Deserialize, Serialize)]

@@ -586,59 +646,108 @@ impl From<TokenUsage> for FinalOutput {
impl fmt::Display for FinalOutput {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token_usage = &self.token_usage;

        write!(
            f,
            "Token usage: total={} input={}{} output={}{}",
            token_usage.blended_total(),
            token_usage.non_cached_input(),
            format_with_separators(token_usage.blended_total()),
            format_with_separators(token_usage.non_cached_input()),
            if token_usage.cached_input() > 0 {
                format!(" (+ {} cached)", token_usage.cached_input())
                format!(
                    " (+ {} cached)",
                    format_with_separators(token_usage.cached_input())
                )
            } else {
                String::new()
            },
            token_usage.output_tokens,
            token_usage
                .reasoning_output_tokens
                .map(|r| format!(" (reasoning {r})"))
                .unwrap_or_default()
            format_with_separators(token_usage.output_tokens),
            if token_usage.reasoning_output_tokens > 0 {
                format!(
                    " (reasoning {})",
                    format_with_separators(token_usage.reasoning_output_tokens)
                )
            } else {
                String::new()
            }
        )
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct AgentMessageEvent {
    pub message: String,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
#[serde(rename_all = "snake_case")]
pub enum InputMessageKind {
    /// Plain user text (default)
    Plain,
    /// XML-wrapped user instructions (<user_instructions>...)
    UserInstructions,
    /// XML-wrapped environment context (<environment_context>...)
    EnvironmentContext,
}

#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct UserMessageEvent {
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kind: Option<InputMessageKind>,
}

impl<T, U> From<(T, U)> for InputMessageKind
where
    T: AsRef<str>,
    U: AsRef<str>,
{
    fn from(value: (T, U)) -> Self {
        let (_role, message) = value;
        let message = message.as_ref();
        let trimmed = message.trim();
        if trimmed.starts_with(ENVIRONMENT_CONTEXT_OPEN_TAG)
            && trimmed.ends_with(ENVIRONMENT_CONTEXT_CLOSE_TAG)
        {
            InputMessageKind::EnvironmentContext
        } else if trimmed.starts_with(USER_INSTRUCTIONS_OPEN_TAG)
            && trimmed.ends_with(USER_INSTRUCTIONS_CLOSE_TAG)
        {
            InputMessageKind::UserInstructions
        } else {
            InputMessageKind::Plain
        }
    }
}
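The tuple conversion above classifies a message purely by its trimmed body; the role half of the pair is currently ignored. A small illustrative check (not part of the diff):

fn input_message_kind_sketch() {
    let tagged = InputMessageKind::from(("user", "<user_instructions>be brief</user_instructions>"));
    assert!(matches!(tagged, InputMessageKind::UserInstructions));

    let plain = InputMessageKind::from(("user", "how do I exit vim?"));
    assert!(matches!(plain, InputMessageKind::Plain));
}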
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct AgentMessageDeltaEvent {
|
||||
pub delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct AgentReasoningEvent {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct AgentReasoningRawContentEvent {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct AgentReasoningRawContentDeltaEvent {
|
||||
pub delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct AgentReasoningSectionBreakEvent {}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct AgentReasoningDeltaEvent {
|
||||
pub delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct McpInvocation {
|
||||
/// Name of the MCP server as defined in the config.
|
||||
pub server: String,
|
||||
@@ -648,18 +757,19 @@ pub struct McpInvocation {
|
||||
pub arguments: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct McpToolCallBeginEvent {
|
||||
/// Identifier so this can be paired with the McpToolCallEnd event.
|
||||
pub call_id: String,
|
||||
pub invocation: McpInvocation,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct McpToolCallEndEvent {
|
||||
/// Identifier for the corresponding McpToolCallBegin that finished.
|
||||
pub call_id: String,
|
||||
pub invocation: McpInvocation,
|
||||
#[ts(type = "string")]
|
||||
pub duration: Duration,
|
||||
/// Result of the tool call. Note this could be an error.
|
||||
pub result: Result<CallToolResult, String>,
|
||||
@@ -674,12 +784,12 @@ impl McpToolCallEndEvent {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct WebSearchBeginEvent {
|
||||
pub call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct WebSearchEndEvent {
|
||||
pub call_id: String,
|
||||
pub query: String,
|
||||
@@ -687,13 +797,13 @@ pub struct WebSearchEndEvent {
|
||||
|
||||
/// Response payload for `Op::GetHistory` containing the current session's
|
||||
/// in-memory transcript.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ConversationHistoryResponseEvent {
|
||||
pub conversation_id: Uuid,
|
||||
pub entries: Vec<ResponseItem>,
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct ConversationPathResponseEvent {
|
||||
pub conversation_id: ConversationId,
|
||||
pub path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct ExecCommandBeginEvent {
|
||||
/// Identifier so this can be paired with the ExecCommandEnd event.
|
||||
pub call_id: String,
|
||||
@@ -704,7 +814,7 @@ pub struct ExecCommandBeginEvent {
|
||||
pub parsed_cmd: Vec<ParsedCommand>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct ExecCommandEndEvent {
|
||||
/// Identifier for the ExecCommandBegin that finished.
|
||||
pub call_id: String,
|
||||
@@ -718,30 +828,33 @@ pub struct ExecCommandEndEvent {
|
||||
/// The command's exit code.
|
||||
pub exit_code: i32,
|
||||
/// The duration of the command execution.
|
||||
#[ts(type = "string")]
|
||||
pub duration: Duration,
|
||||
/// Formatted output from the command, as seen by the model.
|
||||
pub formatted_output: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, TS)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ExecOutputStream {
|
||||
Stdout,
|
||||
Stderr,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, TS)]
|
||||
pub struct ExecCommandOutputDeltaEvent {
|
||||
/// Identifier for the ExecCommandBegin that produced this chunk.
|
||||
pub call_id: String,
|
||||
/// Which stream produced this chunk.
|
||||
pub stream: ExecOutputStream,
|
||||
/// Raw bytes from the stream (may not be valid UTF-8).
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub chunk: ByteBuf,
|
||||
#[serde_as(as = "serde_with::base64::Base64")]
|
||||
#[ts(type = "string")]
|
||||
pub chunk: Vec<u8>,
|
||||
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct ExecApprovalRequestEvent {
    /// Identifier for the associated exec call, if available.
    pub call_id: String,
@@ -754,7 +867,7 @@ pub struct ExecApprovalRequestEvent {
    pub reason: Option<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct ApplyPatchApprovalRequestEvent {
    /// Responses API call id for the associated patch apply call, if available.
    pub call_id: String,
@@ -767,17 +880,17 @@ pub struct ApplyPatchApprovalRequestEvent {
    pub grant_root: Option<PathBuf>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct BackgroundEventEvent {
    pub message: String,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct StreamErrorEvent {
    pub message: String,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct PatchApplyBeginEvent {
    /// Identifier so this can be paired with the PatchApplyEnd event.
    pub call_id: String,
@@ -787,7 +900,7 @@ pub struct PatchApplyBeginEvent {
    pub changes: HashMap<PathBuf, FileChange>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct PatchApplyEndEvent {
    /// Identifier for the PatchApplyBegin that finished.
    pub call_id: String,
@@ -799,12 +912,12 @@ pub struct PatchApplyEndEvent {
    pub success: bool,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct TurnDiffEvent {
    pub unified_diff: String,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct GetHistoryEntryResponseEvent {
    pub offset: usize,
    pub log_id: u64,
@@ -814,22 +927,22 @@ pub struct GetHistoryEntryResponseEvent {
}

/// Response payload for `Op::ListMcpTools`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct McpListToolsResponseEvent {
    /// Fully qualified tool name -> tool definition.
    pub tools: std::collections::HashMap<String, McpTool>,
}

/// Response payload for `Op::ListCustomPrompts`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct ListCustomPromptsResponseEvent {
    pub custom_prompts: Vec<CustomPrompt>,
}

#[derive(Debug, Default, Clone, Deserialize, Serialize)]
#[derive(Debug, Default, Clone, Deserialize, Serialize, TS)]
pub struct SessionConfiguredEvent {
    /// Unique id for this session.
    pub session_id: Uuid,
    /// Name left as session_id instead of conversation_id for backwards compatibility.
    pub session_id: ConversationId,

    /// Tell the client what model is being queried.
    pub model: String,
@@ -839,6 +952,11 @@ pub struct SessionConfiguredEvent {

    /// Current number of entries in the history log.
    pub history_entry_count: usize,

    /// Optional initial messages (as events) for resumed sessions.
    /// When present, UIs can use these to seed the history.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub initial_messages: Option<Vec<EventMsg>>,
}
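
The new `initial_messages` field is gated by `skip_serializing_if = "Option::is_none"`, so sessions that are not resumed produce byte-identical JSON to before; that is also why the `serialize_event` expectation further down gains no new key. A minimal sketch of the mechanism, with simplified field types, assuming serde and serde_json:

use serde::Serialize;

#[derive(Serialize)]
struct Configured {
    model: String,
    // Dropped from the JSON object entirely when None, rather than emitted as null.
    #[serde(skip_serializing_if = "Option::is_none")]
    initial_messages: Option<Vec<String>>, // stand-in for Vec<EventMsg>
}

fn main() {
    let c = Configured {
        model: "codex-mini-latest".into(),
        initial_messages: None,
    };
    assert_eq!(
        serde_json::to_string(&c).unwrap(),
        r#"{"model":"codex-mini-latest"}"#
    );
}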

/// User's decision in response to an ExecApprovalRequest.
@@ -878,7 +996,7 @@ pub enum FileChange {
    },
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct Chunk {
    /// 1-based line index of the first line in the original file
    pub orig_index: u32,
@@ -886,7 +1004,7 @@ pub struct Chunk {
    pub inserted_lines: Vec<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct TurnAbortedEvent {
    pub reason: TurnAbortReason,
}
@@ -906,14 +1024,15 @@ mod tests {
    /// amount of nesting.
    #[test]
    fn serialize_event() {
        let session_id: Uuid = uuid::uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8");
        let conversation_id = ConversationId(uuid::uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"));
        let event = Event {
            id: "1234".to_string(),
            msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
                session_id,
                session_id: conversation_id,
                model: "codex-mini-latest".to_string(),
                history_log_id: 0,
                history_entry_count: 0,
                initial_messages: None,
            }),
        };
        let serialized = serde_json::to_string(&event).unwrap();
@@ -922,4 +1041,21 @@ mod tests {
            r#"{"id":"1234","msg":{"type":"session_configured","session_id":"67e55044-10b1-426f-9247-bb680e5fe0c8","model":"codex-mini-latest","history_log_id":0,"history_entry_count":0}}"#
        );
    }
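
The expected JSON above still renders `session_id` as a bare UUID string even though the field's type changed to `ConversationId`. That holds whenever the newtype serializes as its inner value, which serde does for newtype structs in JSON; the actual `ConversationId` definition is outside this diff, so the sketch below is an assumption (requires the uuid crate with its serde feature):

use serde::Serialize;
use uuid::Uuid;

#[derive(Serialize)]
#[serde(transparent)] // explicit; serde_json already unwraps newtype structs
struct ConversationId(Uuid);

fn main() {
    let id = ConversationId(uuid::uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"));
    // Same output as serializing the inner Uuid directly.
    assert_eq!(
        serde_json::to_string(&id).unwrap(),
        r#""67e55044-10b1-426f-9247-bb680e5fe0c8""#
    );
}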

    #[test]
    fn vec_u8_as_base64_serialization_and_deserialization() {
        let event = ExecCommandOutputDeltaEvent {
            call_id: "call21".to_string(),
            stream: ExecOutputStream::Stdout,
            chunk: vec![1, 2, 3, 4, 5],
        };
        let serialized = serde_json::to_string(&event).unwrap();
        assert_eq!(
            r#"{"call_id":"call21","stream":"stdout","chunk":"AQIDBAU="}"#,
            serialized,
        );

        let deserialized: ExecCommandOutputDeltaEvent = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, event);
    }
}