mirror of
https://github.com/openai/codex.git
synced 2026-02-01 22:47:52 +00:00
Compare commits
281 Commits
pr2883
...
feature/co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6ed3c92ed1 | ||
|
|
a2cfb125dc | ||
|
|
002d877c02 | ||
|
|
c415827ac2 | ||
|
|
4e0550b995 | ||
|
|
f54a49157b | ||
|
|
dd56750612 | ||
|
|
8bc73a2bfd | ||
|
|
be366a31ab | ||
|
|
c75920a071 | ||
|
|
8daba53808 | ||
|
|
d2940bd4c3 | ||
|
|
76a9b11678 | ||
|
|
fa80bbb587 | ||
|
|
434eb4fd49 | ||
|
|
19f46439ae | ||
|
|
e258ca61b4 | ||
|
|
e5fe50d3ce | ||
|
|
14a115d488 | ||
|
|
5996ee0e5f | ||
|
|
a4ebd069e5 | ||
|
|
04504d8218 | ||
|
|
42d335deb8 | ||
|
|
ad0c2b4db3 | ||
|
|
ff389dc52f | ||
|
|
9b18875a42 | ||
|
|
881c7978f1 | ||
|
|
a7fda70053 | ||
|
|
de64f5f007 | ||
|
|
8595237505 | ||
|
|
62258df92f | ||
|
|
b34e906396 | ||
|
|
71038381aa | ||
|
|
277fc6254e | ||
|
|
992b531180 | ||
|
|
84a0ba9bf5 | ||
|
|
4a5d6f7c71 | ||
|
|
1b3c8b8e94 | ||
|
|
d4aba772cb | ||
|
|
4c97eeb32a | ||
|
|
c9505488a1 | ||
|
|
530382db05 | ||
|
|
208089e58e | ||
|
|
e5fdb5b0fd | ||
|
|
5332f6e215 | ||
|
|
5d87f5d24a | ||
|
|
791d7b125f | ||
|
|
72733e34c4 | ||
|
|
b8d2b1a576 | ||
|
|
7fe4021f95 | ||
|
|
11285655c4 | ||
|
|
244687303b | ||
|
|
5e2c4f7e35 | ||
|
|
a8026d3846 | ||
|
|
45bccd36b0 | ||
|
|
404c126fc3 | ||
|
|
88027552dd | ||
|
|
ca8bd09d56 | ||
|
|
39ed8a7d26 | ||
|
|
2df7f7efe5 | ||
|
|
0560079c41 | ||
|
|
0de154194d | ||
|
|
5c583fe89b | ||
|
|
cf63cbf153 | ||
|
|
b1c291e2bb | ||
|
|
934d728946 | ||
|
|
f037b2fd56 | ||
|
|
d60cbed691 | ||
|
|
6aafe37752 | ||
|
|
d555b68469 | ||
|
|
9baa5c33da | ||
|
|
fdf4a68646 | ||
|
|
adc9e1526b | ||
|
|
b9af1d2b16 | ||
|
|
2d52e3b40a | ||
|
|
6039f8a126 | ||
|
|
50262a44ce | ||
|
|
839b2ae7cf | ||
|
|
6a8e743d57 | ||
|
|
a797051921 | ||
|
|
d7d9d96d6c | ||
|
|
26f1246a89 | ||
|
|
6581da9b57 | ||
|
|
900bb01486 | ||
|
|
2ad6a37192 | ||
|
|
e5dd7f0934 | ||
|
|
b6673838e8 | ||
|
|
1823906215 | ||
|
|
5185d69f13 | ||
|
|
4dffa496ac | ||
|
|
ce984b2c71 | ||
|
|
c47febf221 | ||
|
|
76c37c5493 | ||
|
|
2aa84b8891 | ||
|
|
9177bdae5e | ||
|
|
a30e5e40ee | ||
|
|
99e1d33bd1 | ||
|
|
b2f6fc3b9a | ||
|
|
51f88fd04a | ||
|
|
916fdc2a37 | ||
|
|
863d9c237e | ||
|
|
7e1543f5d8 | ||
|
|
d701eb32d7 | ||
|
|
9baae77533 | ||
|
|
e932722292 | ||
|
|
bbea6bbf7e | ||
|
|
4891ee29c5 | ||
|
|
bac8a427f3 | ||
|
|
14ab1063a7 | ||
|
|
a77364bbaa | ||
|
|
19b4ed3c96 | ||
|
|
3d4acbaea0 | ||
|
|
414b8be8b6 | ||
|
|
90a0fd342f | ||
|
|
8d56d2f655 | ||
|
|
8408f3e8ed | ||
|
|
b8ccfe9b65 | ||
|
|
e3c6903199 | ||
|
|
5f6e95b592 | ||
|
|
a2e9cc5530 | ||
|
|
ea225df22e | ||
|
|
d4848e558b | ||
|
|
1a6a95fb2a | ||
|
|
c6fd056aa6 | ||
|
|
abdcb40f4c | ||
|
|
4ae6b9787a | ||
|
|
bba567cee9 | ||
|
|
ba6af23cb6 | ||
|
|
f805d17930 | ||
|
|
90965fbc84 | ||
|
|
c172e8e997 | ||
|
|
9bbeb75361 | ||
|
|
6ccd32c601 | ||
|
|
3b5a5412bb | ||
|
|
44bb53df1e | ||
|
|
8453915e02 | ||
|
|
44587c2443 | ||
|
|
8f7b22b652 | ||
|
|
027944c64e | ||
|
|
bec51f6c05 | ||
|
|
66967500bb | ||
|
|
167b4f0e25 | ||
|
|
167154178b | ||
|
|
674e3d3c90 | ||
|
|
114ce9ff4d | ||
|
|
e13b35ecb0 | ||
|
|
377af75730 | ||
|
|
86e0f31a7e | ||
|
|
8f837f1093 | ||
|
|
162e1235a8 | ||
|
|
c09ed74a16 | ||
|
|
65f3528cad | ||
|
|
44262d8fd8 | ||
|
|
95a9938d3a | ||
|
|
f69f07b028 | ||
|
|
8d766088e6 | ||
|
|
87654ec0b7 | ||
|
|
51d9e05de7 | ||
|
|
8068cc75f8 | ||
|
|
acb28bf914 | ||
|
|
97338de578 | ||
|
|
5200b7a95d | ||
|
|
64e6c4afbb | ||
|
|
39db113cc9 | ||
|
|
45bd5ca4b9 | ||
|
|
c13c3dadbf | ||
|
|
8636bff46d | ||
|
|
43809a454e | ||
|
|
5c48600bb3 | ||
|
|
de6559f2ab | ||
|
|
5bcc9d8b77 | ||
|
|
5eab4c7ab4 | ||
|
|
f656e192bf | ||
|
|
ee5ecae7c0 | ||
|
|
58bb2048ac | ||
|
|
ac8a3155d6 | ||
|
|
ace14e8d36 | ||
|
|
2a76a08a9e | ||
|
|
16309d6b68 | ||
|
|
62bd0e3d9d | ||
|
|
a9c68ea270 | ||
|
|
ac58749bd3 | ||
|
|
79cbd2ab1b | ||
|
|
5eaaf307e1 | ||
|
|
18330c2362 | ||
|
|
4c46490e53 | ||
|
|
5c1416d99b | ||
|
|
0525b48baa | ||
|
|
1f4f9cde8e | ||
|
|
cad37009e1 | ||
|
|
e2b3053b2b | ||
|
|
e47bd33689 | ||
|
|
6b878bea01 | ||
|
|
ca46510fd3 | ||
|
|
6efb52e545 | ||
|
|
d84a799ec0 | ||
|
|
c8fab51372 | ||
|
|
58d77ca4e7 | ||
|
|
0269096229 | ||
|
|
70a6d4b1b4 | ||
|
|
b1d5f7c0bd | ||
|
|
066c6cce02 | ||
|
|
bd65f81e54 | ||
|
|
ba9620aea7 | ||
|
|
45c3b20041 | ||
|
|
6cfc012e9d | ||
|
|
17a80d43c8 | ||
|
|
c11696f6b1 | ||
|
|
5775174ec2 | ||
|
|
ba631e7928 | ||
|
|
db3834733a | ||
|
|
d6182becbe | ||
|
|
323a5cb7e7 | ||
|
|
3f40fbc0a8 | ||
|
|
742feaf40f | ||
|
|
907d3dd348 | ||
|
|
7df9e9c664 | ||
|
|
b795fbe244 | ||
|
|
82ed7bd285 | ||
|
|
1c04e1314d | ||
|
|
bef7ed0ccc | ||
|
|
be23fe1353 | ||
|
|
2073fa7139 | ||
|
|
e60a44cbab | ||
|
|
075e385969 | ||
|
|
aa083b795d | ||
|
|
91708bb031 | ||
|
|
82dfec5b10 | ||
|
|
1e82bf9d98 | ||
|
|
0a83db5512 | ||
|
|
bd4fa85507 | ||
|
|
234c0a0469 | ||
|
|
0f4ae1b5b0 | ||
|
|
2b96f9f569 | ||
|
|
f2036572b6 | ||
|
|
bea64569c1 | ||
|
|
e83c5f429c | ||
|
|
ed0d23d560 | ||
|
|
4ae45a6c8d | ||
|
|
6b83c1c3f3 | ||
|
|
db5276f8e6 | ||
|
|
77fb9f3465 | ||
|
|
0e827b6598 | ||
|
|
daaadfb260 | ||
|
|
c636f821ae | ||
|
|
af338cc505 | ||
|
|
97000c6e6d | ||
|
|
fb5dfe3396 | ||
|
|
a56eb48195 | ||
|
|
d77b33ded7 | ||
|
|
9ad2e726fc | ||
|
|
6aa306c584 | ||
|
|
44dce748b6 | ||
|
|
d489690efe | ||
|
|
3f76220055 | ||
|
|
90725fe3d5 | ||
|
|
53413c728e | ||
|
|
b127a3643f | ||
|
|
a93a907c7e | ||
|
|
03e2796ca4 | ||
|
|
051f185ce3 | ||
|
|
6f75114695 | ||
|
|
3baccba0ac | ||
|
|
578ff09e17 | ||
|
|
0d5ffb000e | ||
|
|
431a10fc50 | ||
|
|
8b993b557d | ||
|
|
60fdfc5f14 | ||
|
|
13e5b567f5 | ||
|
|
46e35a2345 | ||
|
|
7bcdc5cc7c | ||
|
|
4b426f7e1e | ||
|
|
fcb62a0fa5 | ||
|
|
eb40fe3451 | ||
|
|
b32c79e371 | ||
|
|
e442ecedab | ||
|
|
3f8d6021ac | ||
|
|
7ac6194c22 | ||
|
|
619436c58f | ||
|
|
1cc6b97227 | ||
|
|
7eee69d821 |
4
.github/dotslash-config.json
vendored
4
.github/dotslash-config.json
vendored
@@ -21,6 +21,10 @@
|
||||
"windows-x86_64": {
|
||||
"regex": "^codex-x86_64-pc-windows-msvc\\.exe\\.zst$",
|
||||
"path": "codex.exe"
|
||||
},
|
||||
"windows-aarch64": {
|
||||
"regex": "^codex-aarch64-pc-windows-msvc\\.exe\\.zst$",
|
||||
"path": "codex.exe"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
@@ -1,6 +1,6 @@
|
||||
# External (non-OpenAI) Pull Request Requirements
|
||||
|
||||
Before opening this Pull Request, please read the "Contributing" section of the README or your PR may be closed:
|
||||
https://github.com/openai/codex#contributing
|
||||
Before opening this Pull Request, please read the dedicated "Contributing" markdown file or your PR may be closed:
|
||||
https://github.com/openai/codex/blob/main/docs/contributing.md
|
||||
|
||||
If your PR conforms to our contribution guidelines, replace this text with a detailed and high quality description of your changes.
|
||||
|
||||
23
.github/workflows/ci.yml
vendored
23
.github/workflows/ci.yml
vendored
@@ -14,33 +14,18 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10.8.1
|
||||
run_install: false
|
||||
|
||||
- name: Get pnpm store directory
|
||||
id: pnpm-cache
|
||||
shell: bash
|
||||
run: |
|
||||
echo "store_path=$(pnpm store path --silent)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Setup pnpm cache
|
||||
uses: actions/cache@v4
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v5
|
||||
with:
|
||||
path: ${{ steps.pnpm-cache.outputs.store_path }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
node-version: 22
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Run all tasks using workspace filters
|
||||
|
||||
|
||||
1
.github/workflows/codespell.yml
vendored
1
.github/workflows/codespell.yml
vendored
@@ -25,3 +25,4 @@ jobs:
|
||||
uses: codespell-project/actions-codespell@406322ec52dd7b488e48c1c4b82e2a8b3a1bf630 # v2
|
||||
with:
|
||||
ignore_words_file: .codespellignore
|
||||
skip: frame*.txt
|
||||
|
||||
46
.github/workflows/rust-ci.yml
vendored
46
.github/workflows/rust-ci.yml
vendored
@@ -62,6 +62,26 @@ jobs:
|
||||
components: rustfmt
|
||||
- name: cargo fmt
|
||||
run: cargo fmt -- --config imports_granularity=Item --check
|
||||
- name: Verify codegen for mcp-types
|
||||
run: ./mcp-types/check_lib_rs.py
|
||||
|
||||
cargo_shear:
|
||||
name: cargo shear
|
||||
runs-on: ubuntu-24.04
|
||||
needs: changed
|
||||
if: ${{ needs.changed.outputs.codex == 'true' || needs.changed.outputs.workflows == 'true' || github.event_name == 'push' }}
|
||||
defaults:
|
||||
run:
|
||||
working-directory: codex-rs
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: dtolnay/rust-toolchain@1.89
|
||||
- uses: taiki-e/install-action@0c5db7f7f897c03b771660e91d065338615679f4 # v2
|
||||
with:
|
||||
tool: cargo-shear
|
||||
version: 1.5.1
|
||||
- name: cargo shear
|
||||
run: cargo shear
|
||||
|
||||
# --- CI to validate on different os/targets --------------------------------
|
||||
lint_build_test:
|
||||
@@ -100,15 +120,26 @@ jobs:
|
||||
- runner: windows-latest
|
||||
target: x86_64-pc-windows-msvc
|
||||
profile: dev
|
||||
- runner: windows-11-arm
|
||||
target: aarch64-pc-windows-msvc
|
||||
profile: dev
|
||||
|
||||
# Also run representative release builds on Mac and Linux because
|
||||
# there could be release-only build errors we want to catch.
|
||||
# Hopefully this also pre-populates the build cache to speed up
|
||||
# releases.
|
||||
- runner: macos-14
|
||||
target: aarch64-apple-darwin
|
||||
profile: release
|
||||
- runner: ubuntu-24.04
|
||||
target: x86_64-unknown-linux-musl
|
||||
profile: release
|
||||
- runner: windows-latest
|
||||
target: x86_64-pc-windows-msvc
|
||||
profile: release
|
||||
- runner: windows-11-arm
|
||||
target: aarch64-pc-windows-msvc
|
||||
profile: release
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
@@ -149,12 +180,17 @@ jobs:
|
||||
find . -name Cargo.toml -mindepth 2 -maxdepth 2 -print0 \
|
||||
| xargs -0 -n1 -I{} bash -c 'cd "$(dirname "{}")" && cargo check --profile ${{ matrix.profile }}'
|
||||
|
||||
- name: cargo test
|
||||
- uses: taiki-e/install-action@0c5db7f7f897c03b771660e91d065338615679f4 # v2
|
||||
with:
|
||||
tool: nextest
|
||||
version: 0.9.103
|
||||
|
||||
- name: tests
|
||||
id: test
|
||||
# `cargo test` takes too long for release builds to run them on every PR
|
||||
# Tests take too long for release builds to run them on every PR.
|
||||
if: ${{ matrix.profile != 'release' }}
|
||||
continue-on-error: true
|
||||
run: cargo test --all-features --target ${{ matrix.target }} --profile ${{ matrix.profile }}
|
||||
run: cargo nextest run --all-features --no-fail-fast --target ${{ matrix.target }}
|
||||
env:
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
@@ -171,7 +207,7 @@ jobs:
|
||||
# --- Gatherer job that you mark as the ONLY required status -----------------
|
||||
results:
|
||||
name: CI results (required)
|
||||
needs: [changed, general, lint_build_test]
|
||||
needs: [changed, general, cargo_shear, lint_build_test]
|
||||
if: always()
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
@@ -179,6 +215,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
echo "general: ${{ needs.general.result }}"
|
||||
echo "shear : ${{ needs.cargo_shear.result }}"
|
||||
echo "matrix : ${{ needs.lint_build_test.result }}"
|
||||
|
||||
# If nothing relevant changed (PR touching only root README, etc.),
|
||||
@@ -190,4 +227,5 @@ jobs:
|
||||
|
||||
# Otherwise require the jobs to have succeeded
|
||||
[[ '${{ needs.general.result }}' == 'success' ]] || { echo 'general failed'; exit 1; }
|
||||
[[ '${{ needs.cargo_shear.result }}' == 'success' ]] || { echo 'cargo_shear failed'; exit 1; }
|
||||
[[ '${{ needs.lint_build_test.result }}' == 'success' ]] || { echo 'matrix failed'; exit 1; }
|
||||
|
||||
76
.github/workflows/rust-release.yml
vendored
76
.github/workflows/rust-release.yml
vendored
@@ -72,6 +72,8 @@ jobs:
|
||||
target: aarch64-unknown-linux-gnu
|
||||
- runner: windows-latest
|
||||
target: x86_64-pc-windows-msvc
|
||||
- runner: windows-11-arm
|
||||
target: aarch64-pc-windows-msvc
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
@@ -87,7 +89,7 @@ jobs:
|
||||
~/.cargo/registry/cache/
|
||||
~/.cargo/git/db/
|
||||
${{ github.workspace }}/codex-rs/target/
|
||||
key: cargo-release-${{ matrix.runner }}-${{ matrix.target }}-release-${{ hashFiles('**/Cargo.lock') }}
|
||||
key: cargo-${{ matrix.runner }}-${{ matrix.target }}-release-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}}
|
||||
name: Install musl build tools
|
||||
@@ -109,6 +111,11 @@ jobs:
|
||||
cp target/${{ matrix.target }}/release/codex "$dest/codex-${{ matrix.target }}"
|
||||
fi
|
||||
|
||||
- if: ${{ matrix.runner == 'windows-11-arm' }}
|
||||
name: Install zstd
|
||||
shell: powershell
|
||||
run: choco install -y zstandard
|
||||
|
||||
- name: Compress artifacts
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -160,6 +167,12 @@ jobs:
|
||||
needs: build
|
||||
name: release
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
actions: read
|
||||
outputs:
|
||||
version: ${{ steps.release_name.outputs.name }}
|
||||
tag: ${{ github.ref_name }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -212,3 +225,64 @@ jobs:
|
||||
with:
|
||||
tag: ${{ github.ref_name }}
|
||||
config: .github/dotslash-config.json
|
||||
|
||||
# Publish to npm using OIDC authentication.
|
||||
# July 31, 2025: https://github.blog/changelog/2025-07-31-npm-trusted-publishing-with-oidc-is-generally-available/
|
||||
# npm docs: https://docs.npmjs.com/trusted-publishers
|
||||
publish-npm:
|
||||
# Skip this step for pre-releases (alpha/beta).
|
||||
if: ${{ !contains(needs.release.outputs.version, '-') }}
|
||||
name: publish-npm
|
||||
needs: release
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
id-token: write # Required for OIDC
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v5
|
||||
with:
|
||||
node-version: 22
|
||||
registry-url: "https://registry.npmjs.org"
|
||||
scope: "@openai"
|
||||
|
||||
# Trusted publishing requires npm CLI version 11.5.1 or later.
|
||||
- name: Update npm
|
||||
run: npm install -g npm@latest
|
||||
|
||||
- name: Download npm tarball from release
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
version="${{ needs.release.outputs.version }}"
|
||||
tag="${{ needs.release.outputs.tag }}"
|
||||
mkdir -p dist/npm
|
||||
gh release download "$tag" \
|
||||
--repo "${GITHUB_REPOSITORY}" \
|
||||
--pattern "codex-npm-${version}.tgz" \
|
||||
--dir dist/npm
|
||||
|
||||
# No NODE_AUTH_TOKEN needed because we use OIDC.
|
||||
- name: Publish to npm
|
||||
run: npm publish "${GITHUB_WORKSPACE}/dist/npm/codex-npm-${{ needs.release.outputs.version }}.tgz"
|
||||
|
||||
update-branch:
|
||||
name: Update latest-alpha-cli branch
|
||||
permissions:
|
||||
contents: write
|
||||
needs: release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Update latest-alpha-cli branch
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
gh api \
|
||||
repos/${GITHUB_REPOSITORY}/git/refs/heads/latest-alpha-cli \
|
||||
-X PATCH \
|
||||
-f sha="${GITHUB_SHA}" \
|
||||
-F force=true
|
||||
|
||||
6
.vscode/extensions.json
vendored
6
.vscode/extensions.json
vendored
@@ -1,5 +1,11 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"rust-lang.rust-analyzer",
|
||||
"tamasfe.even-better-toml",
|
||||
"vadimcn.vscode-lldb",
|
||||
|
||||
// Useful if touching files in .github/workflows, though most
|
||||
// contributors will not be doing that?
|
||||
// "github.vscode-github-actions",
|
||||
]
|
||||
}
|
||||
|
||||
30
AGENTS.md
30
AGENTS.md
@@ -4,14 +4,15 @@ In the codex-rs folder where the rust code lives:
|
||||
|
||||
- Crate names are prefixed with `codex-`. For example, the `core` folder's crate is named `codex-core`
|
||||
- When using format! and you can inline variables into {}, always do that.
|
||||
- Install any commands the repo relies on (for example `just`, `rg`, or `cargo-insta`) if they aren't already available before running instructions here.
|
||||
- Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` or `CODEX_SANDBOX_ENV_VAR`.
|
||||
- You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
|
||||
- Similarly, when you spawn a process using Seatbelt (`/usr/bin/sandbox-exec`), `CODEX_SANDBOX=seatbelt` will be set on the child process. Integration tests that want to run Seatbelt themselves cannot be run under Seatbelt, so checks for `CODEX_SANDBOX=seatbelt` are also often used to early exit out of tests, as appropriate.
|
||||
|
||||
Before finalizing a change to `codex-rs`, run `just fmt` (in `codex-rs` directory) to format the code and `just fix -p <project>` (in `codex-rs` directory) to fix any linter issues in the code. Prefer scoping with `-p` to avoid slow workspace‑wide Clippy builds; only run `just fix` without `-p` if you changed shared crates. Additionally, run the tests:
|
||||
Run `just fmt` (in `codex-rs` directory) automatically after making Rust code changes; do not ask for approval to run it. Before finalizing a change to `codex-rs`, run `just fix -p <project>` (in `codex-rs` directory) to fix any linter issues in the code. Prefer scoping with `-p` to avoid slow workspace‑wide Clippy builds; only run `just fix` without `-p` if you changed shared crates. Additionally, run the tests:
|
||||
1. Run the test for the specific project that was changed. For example, if changes were made in `codex-rs/tui`, run `cargo test -p codex-tui`.
|
||||
2. Once those pass, if any changes were made in common, core, or protocol, run the complete test suite with `cargo test --all-features`.
|
||||
When running interactively, ask the user before running these commands to finalize.
|
||||
When running interactively, ask the user before running `just fix` to finalize. `just fmt` does not require approval. project-specific or individual tests can be run without asking the user, but do ask the user before running the complete test suite.
|
||||
|
||||
## TUI style conventions
|
||||
|
||||
@@ -26,7 +27,26 @@ See `codex-rs/tui/styles.md`.
|
||||
- Example: patch summary file lines
|
||||
- Desired: vec![" └ ".into(), "M".red(), " ".dim(), "tui/src/app.rs".dim()]
|
||||
|
||||
## Snapshot tests
|
||||
### TUI Styling (ratatui)
|
||||
- Prefer Stylize helpers: use "text".dim(), .bold(), .cyan(), .italic(), .underlined() instead of manual Style where possible.
|
||||
- Prefer simple conversions: use "text".into() for spans and vec![…].into() for lines; when inference is ambiguous (e.g., Paragraph::new/Cell::from), use Line::from(spans) or Span::from(text).
|
||||
- Computed styles: if the Style is computed at runtime, using `Span::styled` is OK (`Span::from(text).set_style(style)` is also acceptable).
|
||||
- Avoid hardcoded white: do not use `.white()`; prefer the default foreground (no color).
|
||||
- Chaining: combine helpers by chaining for readability (e.g., url.cyan().underlined()).
|
||||
- Single items: prefer "text".into(); use Line::from(text) or Span::from(text) only when the target type isn’t obvious from context, or when using .into() would require extra type annotations.
|
||||
- Building lines: use vec![…].into() to construct a Line when the target type is obvious and no extra type annotations are needed; otherwise use Line::from(vec![…]).
|
||||
- Avoid churn: don’t refactor between equivalent forms (Span::styled ↔ set_style, Line::from ↔ .into()) without a clear readability or functional gain; follow file‑local conventions and do not introduce type annotations solely to satisfy .into().
|
||||
- Compactness: prefer the form that stays on one line after rustfmt; if only one of Line::from(vec![…]) or vec![…].into() avoids wrapping, choose that. If both wrap, pick the one with fewer wrapped lines.
|
||||
|
||||
### Text wrapping
|
||||
- Always use textwrap::wrap to wrap plain strings.
|
||||
- If you have a ratatui Line and you want to wrap it, use the helpers in tui/src/wrapping.rs, e.g. word_wrap_lines / word_wrap_line.
|
||||
- If you need to indent wrapped lines, use the initial_indent / subsequent_indent options from RtOptions if you can, rather than writing custom logic.
|
||||
- If you have a list of lines and you need to prefix them all with some prefix (optionally different on the first vs subsequent lines), use the `prefix_lines` helper from line_utils.
|
||||
|
||||
## Tests
|
||||
|
||||
### Snapshot tests
|
||||
|
||||
This repo uses snapshot tests (via `insta`), especially in `codex-rs/tui`, to validate rendered output. When UI or text output changes intentionally, update the snapshots as follows:
|
||||
|
||||
@@ -41,3 +61,7 @@ This repo uses snapshot tests (via `insta`), especially in `codex-rs/tui`, to va
|
||||
|
||||
If you don’t have the tool:
|
||||
- `cargo install cargo-insta`
|
||||
|
||||
### Test assertions
|
||||
|
||||
- Tests should use pretty_assertions::assert_eq for clearer diffs. Import this at the top of the test module if it isn't already.
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
<p align="center"><code>npm i -g @openai/codex</code><br />or <code>brew install codex</code></p>
|
||||
|
||||
<p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.</br>If you are looking for the <em>cloud-based agent</em> from OpenAI, <strong>Codex Web</strong>, see <a href="https://chatgpt.com/codex">chatgpt.com/codex</a>.</p>
|
||||
<p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.
|
||||
</br>
|
||||
</br>If you want Codex in your code editor (VS Code, Cursor, Windsurf), <a href="https://developers.openai.com/codex/ide">install in your IDE</a>
|
||||
</br>If you are looking for the <em>cloud-based agent</em> from OpenAI, <strong>Codex Web</strong>, go to <a href="https://chatgpt.com/codex">chatgpt.com/codex</a></p>
|
||||
|
||||
<p align="center">
|
||||
<img src="./.github/codex-cli-splash.png" alt="Codex CLI splash" width="80%" />
|
||||
@@ -75,7 +78,7 @@ Codex CLI supports a rich set of configuration options, with preferences stored
|
||||
- [CLI usage](./docs/getting-started.md#cli-usage)
|
||||
- [Running with a prompt as input](./docs/getting-started.md#running-with-a-prompt-as-input)
|
||||
- [Example prompts](./docs/getting-started.md#example-prompts)
|
||||
- [Memory with AGENTS.md](./docs/getting-started.md#memory--project-docs)
|
||||
- [Memory with AGENTS.md](./docs/getting-started.md#memory-with-agentsmd)
|
||||
- [Configuration](./docs/config.md)
|
||||
- [**Sandbox & approvals**](./docs/sandbox.md)
|
||||
- [**Authentication**](./docs/authentication.md)
|
||||
|
||||
@@ -43,7 +43,8 @@ switch (platform) {
|
||||
targetTriple = "x86_64-pc-windows-msvc.exe";
|
||||
break;
|
||||
case "arm64":
|
||||
// We do not build this today, fall through...
|
||||
targetTriple = "aarch64-pc-windows-msvc.exe";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
],
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/openai/codex.git"
|
||||
"url": "git+https://github.com/openai/codex.git",
|
||||
"directory": "codex-cli"
|
||||
},
|
||||
"dependencies": {
|
||||
"@vscode/ripgrep": "^1.15.14"
|
||||
|
||||
@@ -20,7 +20,7 @@ CODEX_CLI_ROOT=""
|
||||
# Until we start publishing stable GitHub releases, we have to grab the binaries
|
||||
# from the GitHub Action that created them. Update the URL below to point to the
|
||||
# appropriate workflow run:
|
||||
WORKFLOW_URL="https://github.com/openai/codex/actions/runs/16840150768" # rust-v0.20.0-alpha.2
|
||||
WORKFLOW_URL="https://github.com/openai/codex/actions/runs/17417194663" # rust-v0.28.0
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
@@ -87,5 +87,8 @@ zstd -d "$ARTIFACTS_DIR/aarch64-apple-darwin/codex-aarch64-apple-darwin.zst" \
|
||||
# x64 Windows
|
||||
zstd -d "$ARTIFACTS_DIR/x86_64-pc-windows-msvc/codex-x86_64-pc-windows-msvc.exe.zst" \
|
||||
-o "$BIN_DIR/codex-x86_64-pc-windows-msvc.exe"
|
||||
# ARM64 Windows
|
||||
zstd -d "$ARTIFACTS_DIR/aarch64-pc-windows-msvc/codex-aarch64-pc-windows-msvc.exe.zst" \
|
||||
-o "$BIN_DIR/codex-aarch64-pc-windows-msvc.exe"
|
||||
|
||||
echo "Installed native dependencies into $BIN_DIR"
|
||||
|
||||
1216
codex-rs/Cargo.lock
generated
1216
codex-rs/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -29,14 +29,167 @@ version = "0.0.0"
|
||||
# edition.
|
||||
edition = "2024"
|
||||
|
||||
[workspace.dependencies]
|
||||
# Internal
|
||||
codex-ansi-escape = { path = "ansi-escape" }
|
||||
codex-apply-patch = { path = "apply-patch" }
|
||||
codex-arg0 = { path = "arg0" }
|
||||
codex-chatgpt = { path = "chatgpt" }
|
||||
codex-common = { path = "common" }
|
||||
codex-core = { path = "core" }
|
||||
codex-exec = { path = "exec" }
|
||||
codex-file-search = { path = "file-search" }
|
||||
codex-linux-sandbox = { path = "linux-sandbox" }
|
||||
codex-login = { path = "login" }
|
||||
codex-mcp-client = { path = "mcp-client" }
|
||||
codex-mcp-server = { path = "mcp-server" }
|
||||
codex-ollama = { path = "ollama" }
|
||||
codex-protocol = { path = "protocol" }
|
||||
codex-protocol-ts = { path = "protocol-ts" }
|
||||
codex-tui = { path = "tui" }
|
||||
core_test_support = { path = "core/tests/common" }
|
||||
mcp-types = { path = "mcp-types" }
|
||||
mcp_test_support = { path = "mcp-server/tests/common" }
|
||||
|
||||
# External
|
||||
allocative = "0.3.3"
|
||||
ansi-to-tui = "7.0.0"
|
||||
anyhow = "1"
|
||||
arboard = "3"
|
||||
askama = "0.12"
|
||||
assert_cmd = "2"
|
||||
async-channel = "2.3.1"
|
||||
async-stream = "0.3.6"
|
||||
base64 = "0.22.1"
|
||||
bytes = "1.10.1"
|
||||
chrono = "0.4.40"
|
||||
clap = "4"
|
||||
clap_complete = "4"
|
||||
color-eyre = "0.6.3"
|
||||
crossterm = "0.28.1"
|
||||
derive_more = "2"
|
||||
diffy = "0.4.2"
|
||||
dirs = "6"
|
||||
dotenvy = "0.15.7"
|
||||
env-flags = "0.1.1"
|
||||
env_logger = "0.11.5"
|
||||
eventsource-stream = "0.2.3"
|
||||
futures = "0.3"
|
||||
icu_decimal = "2.0.0"
|
||||
icu_locale_core = "2.0.0"
|
||||
ignore = "0.4.23"
|
||||
image = { version = "^0.25.8", default-features = false }
|
||||
insta = "1.43.2"
|
||||
itertools = "0.14.0"
|
||||
landlock = "0.4.1"
|
||||
lazy_static = "1"
|
||||
libc = "0.2.175"
|
||||
log = "0.4"
|
||||
maplit = "1.0.2"
|
||||
mime_guess = "2.0.5"
|
||||
multimap = "0.10.0"
|
||||
nucleo-matcher = "0.3.1"
|
||||
once_cell = "1"
|
||||
openssl-sys = "*"
|
||||
os_info = "3.12.0"
|
||||
owo-colors = "4.2.0"
|
||||
path-absolutize = "3.1.1"
|
||||
path-clean = "1.0.1"
|
||||
pathdiff = "0.2"
|
||||
portable-pty = "0.9.0"
|
||||
predicates = "3"
|
||||
pretty_assertions = "1.4.1"
|
||||
pulldown-cmark = "0.10"
|
||||
rand = "0.9"
|
||||
ratatui = "0.29.0"
|
||||
regex-lite = "0.1.7"
|
||||
reqwest = "0.12"
|
||||
schemars = "0.8.22"
|
||||
seccompiler = "0.5.0"
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
serde_with = "3.14"
|
||||
sha1 = "0.10.6"
|
||||
sha2 = "0.10"
|
||||
shlex = "1.3.0"
|
||||
similar = "2.7.0"
|
||||
starlark = "0.13.0"
|
||||
strum = "0.27.2"
|
||||
strum_macros = "0.27.2"
|
||||
supports-color = "3.0.2"
|
||||
sys-locale = "0.3.2"
|
||||
tempfile = "3.13.0"
|
||||
textwrap = "0.16.2"
|
||||
thiserror = "2.0.16"
|
||||
time = "0.3"
|
||||
tiny_http = "0.12"
|
||||
tokio = "1"
|
||||
tokio-stream = "0.1.17"
|
||||
tokio-test = "0.4"
|
||||
tokio-util = "0.7.16"
|
||||
toml = "0.9.5"
|
||||
toml_edit = "0.23.4"
|
||||
tracing = "0.1.41"
|
||||
tracing-appender = "0.2.3"
|
||||
tracing-subscriber = "0.3.20"
|
||||
tree-sitter = "0.25.9"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
ts-rs = "11"
|
||||
unicode-segmentation = "1.12.0"
|
||||
unicode-width = "0.1"
|
||||
url = "2"
|
||||
urlencoding = "2.1"
|
||||
uuid = "1"
|
||||
vt100 = "0.16.2"
|
||||
walkdir = "2.5.0"
|
||||
webbrowser = "1.0"
|
||||
which = "6"
|
||||
wildmatch = "2.5.0"
|
||||
wiremock = "0.6"
|
||||
|
||||
[workspace.lints]
|
||||
rust = {}
|
||||
|
||||
[workspace.lints.clippy]
|
||||
expect_used = "deny"
|
||||
identity_op = "deny"
|
||||
manual_clamp = "deny"
|
||||
manual_filter = "deny"
|
||||
manual_find = "deny"
|
||||
manual_flatten = "deny"
|
||||
manual_map = "deny"
|
||||
manual_memcpy = "deny"
|
||||
manual_non_exhaustive = "deny"
|
||||
manual_ok_or = "deny"
|
||||
manual_range_contains = "deny"
|
||||
manual_retain = "deny"
|
||||
manual_strip = "deny"
|
||||
manual_try_fold = "deny"
|
||||
manual_unwrap_or = "deny"
|
||||
needless_borrow = "deny"
|
||||
needless_borrowed_reference = "deny"
|
||||
needless_collect = "deny"
|
||||
needless_late_init = "deny"
|
||||
needless_option_as_deref = "deny"
|
||||
needless_question_mark = "deny"
|
||||
needless_update = "deny"
|
||||
redundant_clone = "deny"
|
||||
redundant_closure = "deny"
|
||||
redundant_closure_for_method_calls = "deny"
|
||||
redundant_static_lifetimes = "deny"
|
||||
trivially_copy_pass_by_ref = "deny"
|
||||
uninlined_format_args = "deny"
|
||||
unnecessary_filter_map = "deny"
|
||||
unnecessary_lazy_evaluations = "deny"
|
||||
unnecessary_sort_by = "deny"
|
||||
unnecessary_to_owned = "deny"
|
||||
unwrap_used = "deny"
|
||||
|
||||
# cargo-shear cannot see the platform-specific openssl-sys usage, so we
|
||||
# silence the false positive here instead of deleting a real dependency.
|
||||
[workspace.metadata.cargo-shear]
|
||||
ignored = ["openssl-sys"]
|
||||
|
||||
[profile.release]
|
||||
lto = "fat"
|
||||
# Because we bundle some of these executables with the TypeScript CLI, we
|
||||
|
||||
@@ -35,7 +35,7 @@ npx @modelcontextprotocol/inspector codex mcp
|
||||
|
||||
You can enable notifications by configuring a script that is run whenever the agent finishes a turn. The [notify documentation](../docs/config.md#notify) includes a detailed example that explains how to get desktop notifications via [terminal-notifier](https://github.com/julienXX/terminal-notifier) on macOS.
|
||||
|
||||
### `codex exec` to run Codex programmatially/non-interactively
|
||||
### `codex exec` to run Codex programmatically/non-interactively
|
||||
|
||||
To run Codex non-interactively, run `codex exec PROMPT` (you can also pass the prompt via `stdin`) and Codex will work on your task until it decides that it is done and exits. Output is printed to the terminal directly. You can set the `RUST_LOG` environment variable to see more about what's going on.
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@ name = "codex_ansi_escape"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
ansi-to-tui = "7.0.0"
|
||||
ratatui = { version = "0.29.0", features = [
|
||||
ansi-to-tui = { workspace = true }
|
||||
ratatui = { workspace = true, features = [
|
||||
"unstable-rendered-line-info",
|
||||
"unstable-widget-ref",
|
||||
] }
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
|
||||
@@ -9,7 +9,7 @@ use ratatui::text::Text;
|
||||
pub fn ansi_escape_line(s: &str) -> Line<'static> {
|
||||
let text = ansi_escape(s);
|
||||
match text.lines.as_slice() {
|
||||
[] => Line::from(""),
|
||||
[] => "".into(),
|
||||
[only] => only.clone(),
|
||||
[first, rest @ ..] => {
|
||||
tracing::warn!("ansi_escape_line: expected a single line, got {first:?} and {rest:?}");
|
||||
|
||||
@@ -15,13 +15,14 @@ path = "src/main.rs"
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
similar = "2.7.0"
|
||||
thiserror = "2.0.12"
|
||||
tree-sitter = "0.25.8"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
anyhow = { workspace = true }
|
||||
similar = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tree-sitter = { workspace = true }
|
||||
tree-sitter-bash = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3.13.0"
|
||||
assert_cmd = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::str::Utf8Error;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use once_cell::sync::Lazy;
|
||||
pub use parser::Hunk;
|
||||
pub use parser::ParseError;
|
||||
use parser::ParseError::*;
|
||||
@@ -18,6 +19,9 @@ use similar::TextDiff;
|
||||
use thiserror::Error;
|
||||
use tree_sitter::LanguageError;
|
||||
use tree_sitter::Parser;
|
||||
use tree_sitter::Query;
|
||||
use tree_sitter::QueryCursor;
|
||||
use tree_sitter::StreamingIterator;
|
||||
use tree_sitter_bash::LANGUAGE as BASH;
|
||||
|
||||
pub use standalone_executable::main;
|
||||
@@ -36,6 +40,11 @@ pub enum ApplyPatchError {
|
||||
/// Error that occurs while computing replacements when applying patch chunks
|
||||
#[error("{0}")]
|
||||
ComputeReplacements(String),
|
||||
/// A raw patch body was provided without an explicit `apply_patch` invocation.
|
||||
#[error(
|
||||
"patch detected without explicit call to apply_patch. Rerun as [\"apply_patch\", \"<patch>\"]"
|
||||
)]
|
||||
ImplicitInvocation,
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ApplyPatchError {
|
||||
@@ -84,26 +93,29 @@ pub enum MaybeApplyPatch {
|
||||
pub struct ApplyPatchArgs {
|
||||
pub patch: String,
|
||||
pub hunks: Vec<Hunk>,
|
||||
pub workdir: Option<String>,
|
||||
}
|
||||
|
||||
pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch {
|
||||
match argv {
|
||||
// Direct invocation: apply_patch <patch>
|
||||
[cmd, body] if APPLY_PATCH_COMMANDS.contains(&cmd.as_str()) => match parse_patch(body) {
|
||||
Ok(source) => MaybeApplyPatch::Body(source),
|
||||
Err(e) => MaybeApplyPatch::PatchParseError(e),
|
||||
},
|
||||
[bash, flag, script]
|
||||
if bash == "bash"
|
||||
&& flag == "-lc"
|
||||
&& APPLY_PATCH_COMMANDS
|
||||
.iter()
|
||||
.any(|cmd| script.trim_start().starts_with(cmd)) =>
|
||||
{
|
||||
match extract_heredoc_body_from_apply_patch_command(script) {
|
||||
Ok(body) => match parse_patch(&body) {
|
||||
Ok(source) => MaybeApplyPatch::Body(source),
|
||||
// Bash heredoc form: (optional `cd <path> &&`) apply_patch <<'EOF' ...
|
||||
[bash, flag, script] if bash == "bash" && flag == "-lc" => {
|
||||
match extract_apply_patch_from_bash(script) {
|
||||
Ok((body, workdir)) => match parse_patch(&body) {
|
||||
Ok(mut source) => {
|
||||
source.workdir = workdir;
|
||||
MaybeApplyPatch::Body(source)
|
||||
}
|
||||
Err(e) => MaybeApplyPatch::PatchParseError(e),
|
||||
},
|
||||
Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch) => {
|
||||
MaybeApplyPatch::NotApplyPatch
|
||||
}
|
||||
Err(e) => MaybeApplyPatch::ShellParseError(e),
|
||||
}
|
||||
}
|
||||
@@ -116,7 +128,9 @@ pub enum ApplyPatchFileChange {
|
||||
Add {
|
||||
content: String,
|
||||
},
|
||||
Delete,
|
||||
Delete {
|
||||
content: String,
|
||||
},
|
||||
Update {
|
||||
unified_diff: String,
|
||||
move_path: Option<PathBuf>,
|
||||
@@ -200,17 +214,63 @@ impl ApplyPatchAction {
|
||||
/// cwd must be an absolute path so that we can resolve relative paths in the
|
||||
/// patch.
|
||||
pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApplyPatchVerified {
|
||||
// Detect a raw patch body passed directly as the command or as the body of a bash -lc
|
||||
// script. In these cases, report an explicit error rather than applying the patch.
|
||||
match argv {
|
||||
[body] => {
|
||||
if parse_patch(body).is_ok() {
|
||||
return MaybeApplyPatchVerified::CorrectnessError(
|
||||
ApplyPatchError::ImplicitInvocation,
|
||||
);
|
||||
}
|
||||
}
|
||||
[bash, flag, script] if bash == "bash" && flag == "-lc" => {
|
||||
if parse_patch(script).is_ok() {
|
||||
return MaybeApplyPatchVerified::CorrectnessError(
|
||||
ApplyPatchError::ImplicitInvocation,
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
match maybe_parse_apply_patch(argv) {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { patch, hunks }) => {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs {
|
||||
patch,
|
||||
hunks,
|
||||
workdir,
|
||||
}) => {
|
||||
let effective_cwd = workdir
|
||||
.as_ref()
|
||||
.map(|dir| {
|
||||
let path = Path::new(dir);
|
||||
if path.is_absolute() {
|
||||
path.to_path_buf()
|
||||
} else {
|
||||
cwd.join(path)
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|| cwd.to_path_buf());
|
||||
let mut changes = HashMap::new();
|
||||
for hunk in hunks {
|
||||
let path = hunk.resolve_path(cwd);
|
||||
let path = hunk.resolve_path(&effective_cwd);
|
||||
match hunk {
|
||||
Hunk::AddFile { contents, .. } => {
|
||||
changes.insert(path, ApplyPatchFileChange::Add { content: contents });
|
||||
}
|
||||
Hunk::DeleteFile { .. } => {
|
||||
changes.insert(path, ApplyPatchFileChange::Delete);
|
||||
let content = match std::fs::read_to_string(&path) {
|
||||
Ok(content) => content,
|
||||
Err(e) => {
|
||||
return MaybeApplyPatchVerified::CorrectnessError(
|
||||
ApplyPatchError::IoError(IoError {
|
||||
context: format!("Failed to read {}", path.display()),
|
||||
source: e,
|
||||
}),
|
||||
);
|
||||
}
|
||||
};
|
||||
changes.insert(path, ApplyPatchFileChange::Delete { content });
|
||||
}
|
||||
Hunk::UpdateFile {
|
||||
move_path, chunks, ..
|
||||
@@ -238,7 +298,7 @@ pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApp
|
||||
MaybeApplyPatchVerified::Body(ApplyPatchAction {
|
||||
changes,
|
||||
patch,
|
||||
cwd: cwd.to_path_buf(),
|
||||
cwd: effective_cwd,
|
||||
})
|
||||
}
|
||||
MaybeApplyPatch::ShellParseError(e) => MaybeApplyPatchVerified::ShellParseError(e),
|
||||
@@ -247,33 +307,96 @@ pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApp
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to extract a heredoc_body object from a string bash command like:
|
||||
/// Optimistically
|
||||
/// Extract the heredoc body (and optional `cd` workdir) from a `bash -lc` script
|
||||
/// that invokes the apply_patch tool using a heredoc.
|
||||
///
|
||||
/// ```bash
|
||||
/// bash -lc 'apply_patch <<EOF\n***Begin Patch\n...EOF'
|
||||
/// ```
|
||||
/// Supported top‑level forms (must be the only top‑level statement):
|
||||
/// - `apply_patch <<'EOF'\n...\nEOF`
|
||||
/// - `cd <path> && apply_patch <<'EOF'\n...\nEOF`
|
||||
///
|
||||
/// # Arguments
|
||||
/// Notes about matching:
|
||||
/// - Parsed with Tree‑sitter Bash and a strict query that uses anchors so the
|
||||
/// heredoc‑redirected statement is the only top‑level statement.
|
||||
/// - The connector between `cd` and `apply_patch` must be `&&` (not `|` or `||`).
|
||||
/// - Exactly one positional `word` argument is allowed for `cd` (no flags, no quoted
|
||||
/// strings, no second argument).
|
||||
/// - The apply command is validated in‑query via `#any-of?` to allow `apply_patch`
|
||||
/// or `applypatch`.
|
||||
/// - Preceding or trailing commands (e.g., `echo ...;` or `... && echo done`) do not match.
|
||||
///
|
||||
/// * `src` - A string slice that holds the full command
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// This function returns a `Result` which is:
|
||||
///
|
||||
/// * `Ok(String)` - The heredoc body if the extraction is successful.
|
||||
/// * `Err(anyhow::Error)` - An error if the extraction fails.
|
||||
///
|
||||
fn extract_heredoc_body_from_apply_patch_command(
|
||||
/// Returns `(heredoc_body, Some(path))` when the `cd` variant matches, or
|
||||
/// `(heredoc_body, None)` for the direct form. Errors are returned if the script
|
||||
/// cannot be parsed or does not match the allowed patterns.
|
||||
fn extract_apply_patch_from_bash(
|
||||
src: &str,
|
||||
) -> std::result::Result<String, ExtractHeredocError> {
|
||||
if !APPLY_PATCH_COMMANDS
|
||||
.iter()
|
||||
.any(|cmd| src.trim_start().starts_with(cmd))
|
||||
{
|
||||
return Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch);
|
||||
}
|
||||
) -> std::result::Result<(String, Option<String>), ExtractHeredocError> {
|
||||
// This function uses a Tree-sitter query to recognize one of two
|
||||
// whole-script forms, each expressed as a single top-level statement:
|
||||
//
|
||||
// 1. apply_patch <<'EOF'\n...\nEOF
|
||||
// 2. cd <path> && apply_patch <<'EOF'\n...\nEOF
|
||||
//
|
||||
// Key ideas when reading the query:
|
||||
// - dots (`.`) between named nodes enforces adjacency among named children and
|
||||
// anchor to the start/end of the expression.
|
||||
// - we match a single redirected_statement directly under program with leading
|
||||
// and trailing anchors (`.`). This ensures it is the only top-level statement
|
||||
// (so prefixes like `echo ...;` or suffixes like `... && echo done` do not match).
|
||||
//
|
||||
// Overall, we want to be conservative and only match the intended forms, as other
|
||||
// forms are likely to be model errors, or incorrectly interpreted by later code.
|
||||
//
|
||||
// If you're editing this query, it's helpful to start by creating a debugging binary
|
||||
// which will let you see the AST of an arbitrary bash script passed in, and optionally
|
||||
// also run an arbitrary query against the AST. This is useful for understanding
|
||||
// how tree-sitter parses the script and whether the query syntax is correct. Be sure
|
||||
// to test both positive and negative cases.
|
||||
static APPLY_PATCH_QUERY: Lazy<Query> = Lazy::new(|| {
|
||||
let language = BASH.into();
|
||||
#[expect(clippy::expect_used)]
|
||||
Query::new(
|
||||
&language,
|
||||
r#"
|
||||
(
|
||||
program
|
||||
. (redirected_statement
|
||||
body: (command
|
||||
name: (command_name (word) @apply_name) .)
|
||||
(#any-of? @apply_name "apply_patch" "applypatch")
|
||||
redirect: (heredoc_redirect
|
||||
. (heredoc_start)
|
||||
. (heredoc_body) @heredoc
|
||||
. (heredoc_end)
|
||||
.))
|
||||
.)
|
||||
|
||||
(
|
||||
program
|
||||
. (redirected_statement
|
||||
body: (list
|
||||
. (command
|
||||
name: (command_name (word) @cd_name) .
|
||||
argument: [
|
||||
(word) @cd_path
|
||||
(string (string_content) @cd_path)
|
||||
(raw_string) @cd_raw_string
|
||||
] .)
|
||||
"&&"
|
||||
. (command
|
||||
name: (command_name (word) @apply_name))
|
||||
.)
|
||||
(#eq? @cd_name "cd")
|
||||
(#any-of? @apply_name "apply_patch" "applypatch")
|
||||
redirect: (heredoc_redirect
|
||||
. (heredoc_start)
|
||||
. (heredoc_body) @heredoc
|
||||
. (heredoc_end)
|
||||
.))
|
||||
.)
|
||||
"#,
|
||||
)
|
||||
.expect("valid bash query")
|
||||
});
|
||||
|
||||
let lang = BASH.into();
|
||||
let mut parser = Parser::new();
|
||||
@@ -285,26 +408,55 @@ fn extract_heredoc_body_from_apply_patch_command(
|
||||
.ok_or(ExtractHeredocError::FailedToParsePatchIntoAst)?;
|
||||
|
||||
let bytes = src.as_bytes();
|
||||
let mut c = tree.root_node().walk();
|
||||
let root = tree.root_node();
|
||||
|
||||
loop {
|
||||
let node = c.node();
|
||||
if node.kind() == "heredoc_body" {
|
||||
let text = node
|
||||
.utf8_text(bytes)
|
||||
.map_err(ExtractHeredocError::HeredocNotUtf8)?;
|
||||
return Ok(text.trim_end_matches('\n').to_owned());
|
||||
}
|
||||
let mut cursor = QueryCursor::new();
|
||||
let mut matches = cursor.matches(&APPLY_PATCH_QUERY, root, bytes);
|
||||
while let Some(m) = matches.next() {
|
||||
let mut heredoc_text: Option<String> = None;
|
||||
let mut cd_path: Option<String> = None;
|
||||
|
||||
if c.goto_first_child() {
|
||||
continue;
|
||||
}
|
||||
while !c.goto_next_sibling() {
|
||||
if !c.goto_parent() {
|
||||
return Err(ExtractHeredocError::FailedToFindHeredocBody);
|
||||
for capture in m.captures.iter() {
|
||||
let name = APPLY_PATCH_QUERY.capture_names()[capture.index as usize];
|
||||
match name {
|
||||
"heredoc" => {
|
||||
let text = capture
|
||||
.node
|
||||
.utf8_text(bytes)
|
||||
.map_err(ExtractHeredocError::HeredocNotUtf8)?
|
||||
.trim_end_matches('\n')
|
||||
.to_string();
|
||||
heredoc_text = Some(text);
|
||||
}
|
||||
"cd_path" => {
|
||||
let text = capture
|
||||
.node
|
||||
.utf8_text(bytes)
|
||||
.map_err(ExtractHeredocError::HeredocNotUtf8)?
|
||||
.to_string();
|
||||
cd_path = Some(text);
|
||||
}
|
||||
"cd_raw_string" => {
|
||||
let raw = capture
|
||||
.node
|
||||
.utf8_text(bytes)
|
||||
.map_err(ExtractHeredocError::HeredocNotUtf8)?;
|
||||
let trimmed = raw
|
||||
.strip_prefix('\'')
|
||||
.and_then(|s| s.strip_suffix('\''))
|
||||
.unwrap_or(raw);
|
||||
cd_path = Some(trimmed.to_string());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(heredoc) = heredoc_text {
|
||||
return Ok((heredoc, cd_path));
|
||||
}
|
||||
}
|
||||
|
||||
Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch)
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
@@ -496,21 +648,18 @@ fn derive_new_contents_from_chunks(
|
||||
}
|
||||
};
|
||||
|
||||
let mut original_lines: Vec<String> = original_contents
|
||||
.split('\n')
|
||||
.map(|s| s.to_string())
|
||||
.collect();
|
||||
let mut original_lines: Vec<String> = original_contents.split('\n').map(String::from).collect();
|
||||
|
||||
// Drop the trailing empty element that results from the final newline so
|
||||
// that line counts match the behaviour of standard `diff`.
|
||||
if original_lines.last().is_some_and(|s| s.is_empty()) {
|
||||
if original_lines.last().is_some_and(String::is_empty) {
|
||||
original_lines.pop();
|
||||
}
|
||||
|
||||
let replacements = compute_replacements(&original_lines, path, chunks)?;
|
||||
let new_lines = apply_replacements(original_lines, &replacements);
|
||||
let mut new_lines = new_lines;
|
||||
if !new_lines.last().is_some_and(|s| s.is_empty()) {
|
||||
if !new_lines.last().is_some_and(String::is_empty) {
|
||||
new_lines.push(String::new());
|
||||
}
|
||||
let new_contents = new_lines.join("\n");
|
||||
@@ -554,7 +703,7 @@ fn compute_replacements(
|
||||
if chunk.old_lines.is_empty() {
|
||||
// Pure addition (no old lines). We'll add them at the end or just
|
||||
// before the final empty line if one exists.
|
||||
let insertion_idx = if original_lines.last().is_some_and(|s| s.is_empty()) {
|
||||
let insertion_idx = if original_lines.last().is_some_and(String::is_empty) {
|
||||
original_lines.len() - 1
|
||||
} else {
|
||||
original_lines.len()
|
||||
@@ -580,11 +729,11 @@ fn compute_replacements(
|
||||
|
||||
let mut new_slice: &[String] = &chunk.new_lines;
|
||||
|
||||
if found.is_none() && pattern.last().is_some_and(|s| s.is_empty()) {
|
||||
if found.is_none() && pattern.last().is_some_and(String::is_empty) {
|
||||
// Retry without the trailing empty line which represents the final
|
||||
// newline in the file.
|
||||
pattern = &pattern[..pattern.len() - 1];
|
||||
if new_slice.last().is_some_and(|s| s.is_empty()) {
|
||||
if new_slice.last().is_some_and(String::is_empty) {
|
||||
new_slice = &new_slice[..new_slice.len() - 1];
|
||||
}
|
||||
|
||||
@@ -601,13 +750,15 @@ fn compute_replacements(
|
||||
line_index = start_idx + pattern.len();
|
||||
} else {
|
||||
return Err(ApplyPatchError::ComputeReplacements(format!(
|
||||
"Failed to find expected lines {:?} in {}",
|
||||
chunk.old_lines,
|
||||
path.display()
|
||||
"Failed to find expected lines in {}:\n{}",
|
||||
path.display(),
|
||||
chunk.old_lines.join("\n"),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
replacements.sort_by(|(lhs_idx, _, _), (rhs_idx, _, _)| lhs_idx.cmp(rhs_idx));
|
||||
|
||||
Ok(replacements)
|
||||
}
|
||||
|
||||
@@ -694,6 +845,7 @@ mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
use std::string::ToString;
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// Helper to construct a patch with the given body.
|
||||
@@ -702,7 +854,72 @@ mod tests {
|
||||
}
|
||||
|
||||
fn strs_to_strings(strs: &[&str]) -> Vec<String> {
|
||||
strs.iter().map(|s| s.to_string()).collect()
|
||||
strs.iter().map(ToString::to_string).collect()
|
||||
}
|
||||
|
||||
// Test helpers to reduce repetition when building bash -lc heredoc scripts
|
||||
fn args_bash(script: &str) -> Vec<String> {
|
||||
strs_to_strings(&["bash", "-lc", script])
|
||||
}
|
||||
|
||||
fn heredoc_script(prefix: &str) -> String {
|
||||
format!(
|
||||
"{prefix}apply_patch <<'PATCH'\n*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch\nPATCH"
|
||||
)
|
||||
}
|
||||
|
||||
fn heredoc_script_ps(prefix: &str, suffix: &str) -> String {
|
||||
format!(
|
||||
"{prefix}apply_patch <<'PATCH'\n*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch\nPATCH{suffix}"
|
||||
)
|
||||
}
|
||||
|
||||
fn expected_single_add() -> Vec<Hunk> {
|
||||
vec![Hunk::AddFile {
|
||||
path: PathBuf::from("foo"),
|
||||
contents: "hi\n".to_string(),
|
||||
}]
|
||||
}
|
||||
|
||||
fn assert_match(script: &str, expected_workdir: Option<&str>) {
|
||||
let args = args_bash(script);
|
||||
match maybe_parse_apply_patch(&args) {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { hunks, workdir, .. }) => {
|
||||
assert_eq!(workdir.as_deref(), expected_workdir);
|
||||
assert_eq!(hunks, expected_single_add());
|
||||
}
|
||||
result => panic!("expected MaybeApplyPatch::Body got {result:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_not_match(script: &str) {
|
||||
let args = args_bash(script);
|
||||
assert!(matches!(
|
||||
maybe_parse_apply_patch(&args),
|
||||
MaybeApplyPatch::NotApplyPatch
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_implicit_patch_single_arg_is_error() {
|
||||
let patch = "*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch".to_string();
|
||||
let args = vec![patch];
|
||||
let dir = tempdir().unwrap();
|
||||
assert!(matches!(
|
||||
maybe_parse_apply_patch_verified(&args, dir.path()),
|
||||
MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_implicit_patch_bash_script_is_error() {
|
||||
let script = "*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch";
|
||||
let args = args_bash(script);
|
||||
let dir = tempdir().unwrap();
|
||||
assert!(matches!(
|
||||
maybe_parse_apply_patch_verified(&args, dir.path()),
|
||||
MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -717,7 +934,7 @@ mod tests {
|
||||
]);
|
||||
|
||||
match maybe_parse_apply_patch(&args) {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { hunks, .. }) => {
|
||||
assert_eq!(
|
||||
hunks,
|
||||
vec![Hunk::AddFile {
|
||||
@@ -742,7 +959,7 @@ mod tests {
|
||||
]);
|
||||
|
||||
match maybe_parse_apply_patch(&args) {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { hunks, .. }) => {
|
||||
assert_eq!(
|
||||
hunks,
|
||||
vec![Hunk::AddFile {
|
||||
@@ -757,29 +974,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_heredoc() {
|
||||
let args = strs_to_strings(&[
|
||||
"bash",
|
||||
"-lc",
|
||||
r#"apply_patch <<'PATCH'
|
||||
*** Begin Patch
|
||||
*** Add File: foo
|
||||
+hi
|
||||
*** End Patch
|
||||
PATCH"#,
|
||||
]);
|
||||
|
||||
match maybe_parse_apply_patch(&args) {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => {
|
||||
assert_eq!(
|
||||
hunks,
|
||||
vec![Hunk::AddFile {
|
||||
path: PathBuf::from("foo"),
|
||||
contents: "hi\n".to_string()
|
||||
}]
|
||||
);
|
||||
}
|
||||
result => panic!("expected MaybeApplyPatch::Body got {result:?}"),
|
||||
}
|
||||
assert_match(&heredoc_script(""), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -796,7 +991,8 @@ PATCH"#,
|
||||
]);
|
||||
|
||||
match maybe_parse_apply_patch(&args) {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { hunks, workdir, .. }) => {
|
||||
assert_eq!(workdir, None);
|
||||
assert_eq!(
|
||||
hunks,
|
||||
vec![Hunk::AddFile {
|
||||
@@ -809,6 +1005,69 @@ PATCH"#,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_heredoc_with_leading_cd() {
|
||||
assert_match(&heredoc_script("cd foo && "), Some("foo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cd_with_semicolon_is_ignored() {
|
||||
assert_not_match(&heredoc_script("cd foo; "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cd_or_apply_patch_is_ignored() {
|
||||
assert_not_match(&heredoc_script("cd bar || "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cd_pipe_apply_patch_is_ignored() {
|
||||
assert_not_match(&heredoc_script("cd bar | "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cd_single_quoted_path_with_spaces() {
|
||||
assert_match(&heredoc_script("cd 'foo bar' && "), Some("foo bar"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cd_double_quoted_path_with_spaces() {
|
||||
assert_match(&heredoc_script("cd \"foo bar\" && "), Some("foo bar"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_echo_and_apply_patch_is_ignored() {
|
||||
assert_not_match(&heredoc_script("echo foo && "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_patch_with_arg_is_ignored() {
|
||||
let script = "apply_patch foo <<'PATCH'\n*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch\nPATCH";
|
||||
assert_not_match(script);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_double_cd_then_apply_patch_is_ignored() {
|
||||
assert_not_match(&heredoc_script("cd foo && cd bar && "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cd_two_args_is_ignored() {
|
||||
assert_not_match(&heredoc_script("cd foo bar && "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cd_then_apply_patch_then_extra_is_ignored() {
|
||||
let script = heredoc_script_ps("cd bar && ", " && echo done");
|
||||
assert_not_match(&script);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_echo_then_cd_and_apply_patch_is_ignored() {
|
||||
// Ensure preceding commands before the `cd && apply_patch <<...` sequence do not match.
|
||||
assert_not_match(&heredoc_script("echo foo; cd bar && "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_file_hunk_creates_file_with_contents() {
|
||||
let dir = tempdir().unwrap();
|
||||
@@ -1006,6 +1265,33 @@ PATCH"#,
|
||||
assert_eq!(contents, "a\nB\nc\nd\nE\nf\ng\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pure_addition_chunk_followed_by_removal() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("panic.txt");
|
||||
fs::write(&path, "line1\nline2\nline3\n").unwrap();
|
||||
let patch = wrap_patch(&format!(
|
||||
r#"*** Update File: {}
|
||||
@@
|
||||
+after-context
|
||||
+second-line
|
||||
@@
|
||||
line1
|
||||
-line2
|
||||
-line3
|
||||
+line2-replacement"#,
|
||||
path.display()
|
||||
));
|
||||
let mut stdout = Vec::new();
|
||||
let mut stderr = Vec::new();
|
||||
apply_patch(&patch, &mut stdout, &mut stderr).unwrap();
|
||||
let contents = fs::read_to_string(path).unwrap();
|
||||
assert_eq!(
|
||||
contents,
|
||||
"line1\nline2-replacement\nafter-context\nsecond-line\n"
|
||||
);
|
||||
}
|
||||
|
||||
/// Ensure that patches authored with ASCII characters can update lines that
|
||||
/// contain typographic Unicode punctuation (e.g. EN DASH, NON-BREAKING
|
||||
/// HYPHEN). Historically `git apply` succeeds in such scenarios but our
|
||||
|
||||
@@ -175,7 +175,11 @@ fn parse_patch_text(patch: &str, mode: ParseMode) -> Result<ApplyPatchArgs, Pars
|
||||
remaining_lines = &remaining_lines[hunk_lines..]
|
||||
}
|
||||
let patch = lines.join("\n");
|
||||
Ok(ApplyPatchArgs { hunks, patch })
|
||||
Ok(ApplyPatchArgs {
|
||||
hunks,
|
||||
patch,
|
||||
workdir: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Checks the start and end lines of the patch text for `apply_patch`,
|
||||
@@ -586,7 +590,8 @@ fn test_parse_patch_lenient() {
|
||||
parse_patch_text(&patch_text_in_heredoc, ParseMode::Lenient),
|
||||
Ok(ApplyPatchArgs {
|
||||
hunks: expected_patch.clone(),
|
||||
patch: patch_text.to_string()
|
||||
patch: patch_text.to_string(),
|
||||
workdir: None,
|
||||
})
|
||||
);
|
||||
|
||||
@@ -599,7 +604,8 @@ fn test_parse_patch_lenient() {
|
||||
parse_patch_text(&patch_text_in_single_quoted_heredoc, ParseMode::Lenient),
|
||||
Ok(ApplyPatchArgs {
|
||||
hunks: expected_patch.clone(),
|
||||
patch: patch_text.to_string()
|
||||
patch: patch_text.to_string(),
|
||||
workdir: None,
|
||||
})
|
||||
);
|
||||
|
||||
@@ -611,8 +617,9 @@ fn test_parse_patch_lenient() {
|
||||
assert_eq!(
|
||||
parse_patch_text(&patch_text_in_double_quoted_heredoc, ParseMode::Lenient),
|
||||
Ok(ApplyPatchArgs {
|
||||
hunks: expected_patch.clone(),
|
||||
patch: patch_text.to_string()
|
||||
hunks: expected_patch,
|
||||
patch: patch_text.to_string(),
|
||||
workdir: None,
|
||||
})
|
||||
);
|
||||
|
||||
@@ -630,7 +637,7 @@ fn test_parse_patch_lenient() {
|
||||
"<<EOF\n*** Begin Patch\n*** Update File: file2.py\nEOF\n".to_string();
|
||||
assert_eq!(
|
||||
parse_patch_text(&patch_text_with_missing_closing_heredoc, ParseMode::Strict),
|
||||
Err(expected_error.clone())
|
||||
Err(expected_error)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_patch_text(&patch_text_with_missing_closing_heredoc, ParseMode::Lenient),
|
||||
|
||||
@@ -112,9 +112,10 @@ pub(crate) fn seek_sequence(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::seek_sequence;
|
||||
use std::string::ToString;
|
||||
|
||||
fn to_vec(strings: &[&str]) -> Vec<String> {
|
||||
strings.iter().map(|s| s.to_string()).collect()
|
||||
strings.iter().map(ToString::to_string).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -11,10 +11,10 @@ path = "src/lib.rs"
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
codex-apply-patch = { path = "../apply-patch" }
|
||||
codex-core = { path = "../core" }
|
||||
codex-linux-sandbox = { path = "../linux-sandbox" }
|
||||
dotenvy = "0.15.7"
|
||||
tempfile = "3"
|
||||
tokio = { version = "1", features = ["rt-multi-thread"] }
|
||||
anyhow = { workspace = true }
|
||||
codex-apply-patch = { workspace = true }
|
||||
codex-core = { workspace = true }
|
||||
codex-linux-sandbox = { workspace = true }
|
||||
dotenvy = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread"] }
|
||||
|
||||
@@ -21,8 +21,7 @@ const MISSPELLED_APPLY_PATCH_ARG0: &str = "applypatch";
|
||||
/// `codex-linux-sandbox` we *directly* execute
|
||||
/// [`codex_linux_sandbox::run_main`] (which never returns). Otherwise we:
|
||||
///
|
||||
/// 1. Use [`dotenvy::from_path`] and [`dotenvy::dotenv`] to modify the
|
||||
/// environment before creating any threads.
|
||||
/// 1. Load `.env` values from `~/.codex/.env` before creating any threads.
|
||||
/// 2. Construct a Tokio multi-thread runtime.
|
||||
/// 3. Derive the path to the current executable (so children can re-invoke the
|
||||
/// sandbox) when running on Linux.
|
||||
@@ -55,7 +54,7 @@ where
|
||||
|
||||
let argv1 = args.next().unwrap_or_default();
|
||||
if argv1 == CODEX_APPLY_PATCH_ARG1 {
|
||||
let patch_arg = args.next().and_then(|s| s.to_str().map(|s| s.to_owned()));
|
||||
let patch_arg = args.next().and_then(|s| s.to_str().map(str::to_owned));
|
||||
let exit_code = match patch_arg {
|
||||
Some(patch_arg) => {
|
||||
let mut stdout = std::io::stdout();
|
||||
@@ -106,7 +105,7 @@ where
|
||||
|
||||
const ILLEGAL_ENV_VAR_PREFIX: &str = "CODEX_";
|
||||
|
||||
/// Load env vars from ~/.codex/.env and `$(pwd)/.env`.
|
||||
/// Load env vars from ~/.codex/.env.
|
||||
///
|
||||
/// Security: Do not allow `.env` files to create or modify any variables
|
||||
/// with names starting with `CODEX_`.
|
||||
@@ -116,10 +115,6 @@ fn load_dotenv() {
|
||||
{
|
||||
set_filtered(iter);
|
||||
}
|
||||
|
||||
if let Ok(iter) = dotenvy::dotenv_iter() {
|
||||
set_filtered(iter);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to set vars from a dotenvy iterator while filtering out `CODEX_` keys.
|
||||
|
||||
@@ -7,15 +7,13 @@ version = { workspace = true }
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
codex-core = { path = "../core" }
|
||||
codex-login = { path = "../login" }
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
codex-common = { workspace = true, features = ["cli"] }
|
||||
codex-core = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use codex_core::config::Config;
|
||||
use codex_core::user_agent::get_codex_user_agent;
|
||||
use codex_core::default_client::create_client;
|
||||
|
||||
use crate::chatgpt_token::get_chatgpt_token_data;
|
||||
use crate::chatgpt_token::init_chatgpt_token_from_auth;
|
||||
@@ -16,7 +16,7 @@ pub(crate) async fn chatgpt_get_request<T: DeserializeOwned>(
|
||||
init_chatgpt_token_from_auth(&config.codex_home).await?;
|
||||
|
||||
// Make direct HTTP request to ChatGPT backend API with the token
|
||||
let client = reqwest::Client::new();
|
||||
let client = create_client();
|
||||
let url = format!("{chatgpt_base_url}{path}");
|
||||
|
||||
let token =
|
||||
@@ -31,7 +31,6 @@ pub(crate) async fn chatgpt_get_request<T: DeserializeOwned>(
|
||||
.bearer_auth(&token.access_token)
|
||||
.header("chatgpt-account-id", account_id?)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("User-Agent", get_codex_user_agent(None))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request")?;
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use codex_login::AuthMode;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_core::CodexAuth;
|
||||
use std::path::Path;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use codex_login::TokenData;
|
||||
use codex_core::token_data::TokenData;
|
||||
|
||||
static CHATGPT_TOKEN: LazyLock<RwLock<Option<TokenData>>> = LazyLock::new(|| RwLock::new(None));
|
||||
|
||||
@@ -20,7 +19,7 @@ pub fn set_chatgpt_token_data(value: TokenData) {
|
||||
|
||||
/// Initialize the ChatGPT token from auth.json file
|
||||
pub async fn init_chatgpt_token_from_auth(codex_home: &Path) -> std::io::Result<()> {
|
||||
let auth = CodexAuth::from_codex_home(codex_home, AuthMode::ChatGPT)?;
|
||||
let auth = CodexAuth::from_codex_home(codex_home)?;
|
||||
if let Some(auth) = auth {
|
||||
let token_data = auth.get_token_data().await?;
|
||||
set_chatgpt_token_data(token_data);
|
||||
|
||||
@@ -15,26 +15,34 @@ path = "src/lib.rs"
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
clap_complete = "4"
|
||||
codex-arg0 = { path = "../arg0" }
|
||||
codex-chatgpt = { path = "../chatgpt" }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
codex-core = { path = "../core" }
|
||||
codex-exec = { path = "../exec" }
|
||||
codex-login = { path = "../login" }
|
||||
codex-mcp-server = { path = "../mcp-server" }
|
||||
codex-protocol = { path = "../protocol" }
|
||||
codex-tui = { path = "../tui" }
|
||||
serde_json = "1"
|
||||
tokio = { version = "1", features = [
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
clap_complete = { workspace = true }
|
||||
codex-arg0 = { workspace = true }
|
||||
codex-chatgpt = { workspace = true }
|
||||
codex-common = { workspace = true, features = ["cli"] }
|
||||
codex-core = { workspace = true }
|
||||
codex-exec = { workspace = true }
|
||||
codex-login = { workspace = true }
|
||||
codex-mcp-server = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-protocol-ts = { workspace = true }
|
||||
codex-tui = { workspace = true }
|
||||
owo-colors = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
supports-color = { workspace = true }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = "0.3.19"
|
||||
codex-protocol-ts = { path = "../protocol-ts" }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = { workspace = true }
|
||||
predicates = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -64,7 +64,6 @@ async fn run_command_under_sandbox(
|
||||
sandbox_type: SandboxType,
|
||||
) -> anyhow::Result<()> {
|
||||
let sandbox_mode = create_sandbox_mode(full_auto);
|
||||
let cwd = std::env::current_dir()?;
|
||||
let config = Config::load_with_cli_overrides(
|
||||
config_overrides
|
||||
.parse_overrides()
|
||||
@@ -75,13 +74,29 @@ async fn run_command_under_sandbox(
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
|
||||
// In practice, this should be `std::env::current_dir()` because this CLI
|
||||
// does not support `--cwd`, but let's use the config value for consistency.
|
||||
let cwd = config.cwd.clone();
|
||||
// For now, we always use the same cwd for both the command and the
|
||||
// sandbox policy. In the future, we could add a CLI option to set them
|
||||
// separately.
|
||||
let sandbox_policy_cwd = cwd.clone();
|
||||
|
||||
let stdio_policy = StdioPolicy::Inherit;
|
||||
let env = create_env(&config.shell_environment_policy);
|
||||
|
||||
let mut child = match sandbox_type {
|
||||
SandboxType::Seatbelt => {
|
||||
spawn_command_under_seatbelt(command, &config.sandbox_policy, cwd, stdio_policy, env)
|
||||
.await?
|
||||
spawn_command_under_seatbelt(
|
||||
command,
|
||||
cwd,
|
||||
&config.sandbox_policy,
|
||||
sandbox_policy_cwd.as_path(),
|
||||
stdio_policy,
|
||||
env,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
SandboxType::Landlock => {
|
||||
#[expect(clippy::expect_used)]
|
||||
@@ -91,8 +106,9 @@ async fn run_command_under_sandbox(
|
||||
spawn_command_under_linux_sandbox(
|
||||
codex_linux_sandbox_exe,
|
||||
command,
|
||||
&config.sandbox_policy,
|
||||
cwd,
|
||||
&config.sandbox_policy,
|
||||
sandbox_policy_cwd.as_path(),
|
||||
stdio_policy,
|
||||
env,
|
||||
)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::auth::CLIENT_ID;
|
||||
use codex_core::auth::login_with_api_key;
|
||||
use codex_core::auth::logout;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_login::AuthMode;
|
||||
use codex_login::CLIENT_ID;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_login::OPENAI_API_KEY_ENV_VAR;
|
||||
use codex_login::ServerOptions;
|
||||
use codex_login::login_with_api_key;
|
||||
use codex_login::logout;
|
||||
use codex_login::run_login_server;
|
||||
use std::env;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {
|
||||
@@ -60,19 +58,11 @@ pub async fn run_login_with_api_key(
|
||||
pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
|
||||
let config = load_config_or_exit(cli_config_overrides);
|
||||
|
||||
match CodexAuth::from_codex_home(&config.codex_home, config.preferred_auth_method) {
|
||||
match CodexAuth::from_codex_home(&config.codex_home) {
|
||||
Ok(Some(auth)) => match auth.mode {
|
||||
AuthMode::ApiKey => match auth.get_token().await {
|
||||
Ok(api_key) => {
|
||||
eprintln!("Logged in using an API key - {}", safe_format_key(&api_key));
|
||||
|
||||
if let Ok(env_api_key) = env::var(OPENAI_API_KEY_ENV_VAR)
|
||||
&& env_api_key == api_key
|
||||
{
|
||||
eprintln!(
|
||||
" API loaded from OPENAI_API_KEY environment variable or .env file"
|
||||
);
|
||||
}
|
||||
std::process::exit(0);
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
@@ -14,9 +14,15 @@ use codex_cli::login::run_logout;
|
||||
use codex_cli::proto;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_exec::Cli as ExecCli;
|
||||
use codex_tui::AppExitInfo;
|
||||
use codex_tui::Cli as TuiCli;
|
||||
use owo_colors::OwoColorize;
|
||||
use std::path::PathBuf;
|
||||
use supports_color::Stream;
|
||||
|
||||
mod mcp_cmd;
|
||||
|
||||
use crate::mcp_cmd::McpCli;
|
||||
use crate::proto::ProtoCli;
|
||||
|
||||
/// Codex CLI
|
||||
@@ -56,8 +62,8 @@ enum Subcommand {
|
||||
/// Remove stored authentication credentials.
|
||||
Logout(LogoutCommand),
|
||||
|
||||
/// Experimental: run Codex as an MCP server.
|
||||
Mcp,
|
||||
/// [experimental] Run Codex as an MCP server and manage MCP servers.
|
||||
Mcp(McpCli),
|
||||
|
||||
/// Run the Protocol stream via stdin/stdout
|
||||
#[clap(visible_alias = "p")]
|
||||
@@ -73,6 +79,9 @@ enum Subcommand {
|
||||
#[clap(visible_alias = "a")]
|
||||
Apply(ApplyCommand),
|
||||
|
||||
/// Resume a previous interactive session (picker by default; use --last to continue the most recent).
|
||||
Resume(ResumeCommand),
|
||||
|
||||
/// Internal: generate TypeScript protocol bindings.
|
||||
#[clap(hide = true)]
|
||||
GenerateTs(GenerateTsCommand),
|
||||
@@ -85,6 +94,21 @@ struct CompletionCommand {
|
||||
shell: Shell,
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
struct ResumeCommand {
|
||||
/// Conversation/session id (UUID). When provided, resumes this session.
|
||||
/// If omitted, use --last to pick the most recent recorded session.
|
||||
#[arg(value_name = "SESSION_ID")]
|
||||
session_id: Option<String>,
|
||||
|
||||
/// Continue the most recent session without showing the picker.
|
||||
#[arg(long = "last", default_value_t = false, conflicts_with = "session_id")]
|
||||
last: bool,
|
||||
|
||||
#[clap(flatten)]
|
||||
config_overrides: TuiCli,
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
struct DebugArgs {
|
||||
#[command(subcommand)]
|
||||
@@ -135,6 +159,41 @@ struct GenerateTsCommand {
|
||||
prettier: Option<PathBuf>,
|
||||
}
|
||||
|
||||
fn format_exit_messages(exit_info: AppExitInfo, color_enabled: bool) -> Vec<String> {
|
||||
let AppExitInfo {
|
||||
token_usage,
|
||||
conversation_id,
|
||||
} = exit_info;
|
||||
|
||||
if token_usage.is_zero() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut lines = vec![format!(
|
||||
"{}",
|
||||
codex_core::protocol::FinalOutput::from(token_usage)
|
||||
)];
|
||||
|
||||
if let Some(session_id) = conversation_id {
|
||||
let resume_cmd = format!("codex resume {session_id}");
|
||||
let command = if color_enabled {
|
||||
resume_cmd.cyan().to_string()
|
||||
} else {
|
||||
resume_cmd
|
||||
};
|
||||
lines.push(format!("To continue this session, run {command}."));
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
|
||||
fn print_exit_messages(exit_info: AppExitInfo) {
|
||||
let color_enabled = supports_color::on(Stream::Stdout).is_some();
|
||||
for line in format_exit_messages(exit_info, color_enabled) {
|
||||
println!("{line}");
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
|
||||
cli_main(codex_linux_sandbox_exe).await?;
|
||||
@@ -143,26 +202,52 @@ fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
let cli = MultitoolCli::parse();
|
||||
let MultitoolCli {
|
||||
config_overrides: root_config_overrides,
|
||||
mut interactive,
|
||||
subcommand,
|
||||
} = MultitoolCli::parse();
|
||||
|
||||
match cli.subcommand {
|
||||
match subcommand {
|
||||
None => {
|
||||
let mut tui_cli = cli.interactive;
|
||||
prepend_config_flags(&mut tui_cli.config_overrides, cli.config_overrides);
|
||||
let usage = codex_tui::run_main(tui_cli, codex_linux_sandbox_exe).await?;
|
||||
if !usage.is_zero() {
|
||||
println!("{}", codex_core::protocol::FinalOutput::from(usage));
|
||||
}
|
||||
prepend_config_flags(
|
||||
&mut interactive.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
);
|
||||
let exit_info = codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
|
||||
print_exit_messages(exit_info);
|
||||
}
|
||||
Some(Subcommand::Exec(mut exec_cli)) => {
|
||||
prepend_config_flags(&mut exec_cli.config_overrides, cli.config_overrides);
|
||||
prepend_config_flags(
|
||||
&mut exec_cli.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
);
|
||||
codex_exec::run_main(exec_cli, codex_linux_sandbox_exe).await?;
|
||||
}
|
||||
Some(Subcommand::Mcp) => {
|
||||
codex_mcp_server::run_main(codex_linux_sandbox_exe, cli.config_overrides).await?;
|
||||
Some(Subcommand::Mcp(mut mcp_cli)) => {
|
||||
// Propagate any root-level config overrides (e.g. `-c key=value`).
|
||||
prepend_config_flags(&mut mcp_cli.config_overrides, root_config_overrides.clone());
|
||||
mcp_cli.run(codex_linux_sandbox_exe).await?;
|
||||
}
|
||||
Some(Subcommand::Resume(ResumeCommand {
|
||||
session_id,
|
||||
last,
|
||||
config_overrides,
|
||||
})) => {
|
||||
interactive = finalize_resume_interactive(
|
||||
interactive,
|
||||
root_config_overrides.clone(),
|
||||
session_id,
|
||||
last,
|
||||
config_overrides,
|
||||
);
|
||||
codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
|
||||
}
|
||||
Some(Subcommand::Login(mut login_cli)) => {
|
||||
prepend_config_flags(&mut login_cli.config_overrides, cli.config_overrides);
|
||||
prepend_config_flags(
|
||||
&mut login_cli.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
);
|
||||
match login_cli.action {
|
||||
Some(LoginSubcommand::Status) => {
|
||||
run_login_status(login_cli.config_overrides).await;
|
||||
@@ -177,11 +262,17 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
}
|
||||
}
|
||||
Some(Subcommand::Logout(mut logout_cli)) => {
|
||||
prepend_config_flags(&mut logout_cli.config_overrides, cli.config_overrides);
|
||||
prepend_config_flags(
|
||||
&mut logout_cli.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
);
|
||||
run_logout(logout_cli.config_overrides).await;
|
||||
}
|
||||
Some(Subcommand::Proto(mut proto_cli)) => {
|
||||
prepend_config_flags(&mut proto_cli.config_overrides, cli.config_overrides);
|
||||
prepend_config_flags(
|
||||
&mut proto_cli.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
);
|
||||
proto::run_main(proto_cli).await?;
|
||||
}
|
||||
Some(Subcommand::Completion(completion_cli)) => {
|
||||
@@ -189,7 +280,10 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
}
|
||||
Some(Subcommand::Debug(debug_args)) => match debug_args.cmd {
|
||||
DebugCommand::Seatbelt(mut seatbelt_cli) => {
|
||||
prepend_config_flags(&mut seatbelt_cli.config_overrides, cli.config_overrides);
|
||||
prepend_config_flags(
|
||||
&mut seatbelt_cli.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
);
|
||||
codex_cli::debug_sandbox::run_command_under_seatbelt(
|
||||
seatbelt_cli,
|
||||
codex_linux_sandbox_exe,
|
||||
@@ -197,7 +291,10 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
.await?;
|
||||
}
|
||||
DebugCommand::Landlock(mut landlock_cli) => {
|
||||
prepend_config_flags(&mut landlock_cli.config_overrides, cli.config_overrides);
|
||||
prepend_config_flags(
|
||||
&mut landlock_cli.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
);
|
||||
codex_cli::debug_sandbox::run_command_under_landlock(
|
||||
landlock_cli,
|
||||
codex_linux_sandbox_exe,
|
||||
@@ -206,7 +303,10 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
}
|
||||
},
|
||||
Some(Subcommand::Apply(mut apply_cli)) => {
|
||||
prepend_config_flags(&mut apply_cli.config_overrides, cli.config_overrides);
|
||||
prepend_config_flags(
|
||||
&mut apply_cli.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
);
|
||||
run_apply_command(apply_cli, None).await?;
|
||||
}
|
||||
Some(Subcommand::GenerateTs(gen_cli)) => {
|
||||
@@ -228,8 +328,256 @@ fn prepend_config_flags(
|
||||
.splice(0..0, cli_config_overrides.raw_overrides);
|
||||
}
|
||||
|
||||
/// Build the final `TuiCli` for a `codex resume` invocation.
|
||||
fn finalize_resume_interactive(
|
||||
mut interactive: TuiCli,
|
||||
root_config_overrides: CliConfigOverrides,
|
||||
session_id: Option<String>,
|
||||
last: bool,
|
||||
resume_cli: TuiCli,
|
||||
) -> TuiCli {
|
||||
// Start with the parsed interactive CLI so resume shares the same
|
||||
// configuration surface area as `codex` without additional flags.
|
||||
let resume_session_id = session_id;
|
||||
interactive.resume_picker = resume_session_id.is_none() && !last;
|
||||
interactive.resume_last = last;
|
||||
interactive.resume_session_id = resume_session_id;
|
||||
|
||||
// Merge resume-scoped flags and overrides with highest precedence.
|
||||
merge_resume_cli_flags(&mut interactive, resume_cli);
|
||||
|
||||
// Propagate any root-level config overrides (e.g. `-c key=value`).
|
||||
prepend_config_flags(&mut interactive.config_overrides, root_config_overrides);
|
||||
|
||||
interactive
|
||||
}
|
||||
|
||||
/// Merge flags provided to `codex resume` so they take precedence over any
|
||||
/// root-level flags. Only overrides fields explicitly set on the resume-scoped
|
||||
/// CLI. Also appends `-c key=value` overrides with highest precedence.
|
||||
fn merge_resume_cli_flags(interactive: &mut TuiCli, resume_cli: TuiCli) {
|
||||
if let Some(model) = resume_cli.model {
|
||||
interactive.model = Some(model);
|
||||
}
|
||||
if resume_cli.oss {
|
||||
interactive.oss = true;
|
||||
}
|
||||
if let Some(profile) = resume_cli.config_profile {
|
||||
interactive.config_profile = Some(profile);
|
||||
}
|
||||
if let Some(sandbox) = resume_cli.sandbox_mode {
|
||||
interactive.sandbox_mode = Some(sandbox);
|
||||
}
|
||||
if let Some(approval) = resume_cli.approval_policy {
|
||||
interactive.approval_policy = Some(approval);
|
||||
}
|
||||
if resume_cli.full_auto {
|
||||
interactive.full_auto = true;
|
||||
}
|
||||
if resume_cli.dangerously_bypass_approvals_and_sandbox {
|
||||
interactive.dangerously_bypass_approvals_and_sandbox = true;
|
||||
}
|
||||
if let Some(cwd) = resume_cli.cwd {
|
||||
interactive.cwd = Some(cwd);
|
||||
}
|
||||
if resume_cli.web_search {
|
||||
interactive.web_search = true;
|
||||
}
|
||||
if !resume_cli.images.is_empty() {
|
||||
interactive.images = resume_cli.images;
|
||||
}
|
||||
if let Some(prompt) = resume_cli.prompt {
|
||||
interactive.prompt = Some(prompt);
|
||||
}
|
||||
|
||||
interactive
|
||||
.config_overrides
|
||||
.raw_overrides
|
||||
.extend(resume_cli.config_overrides.raw_overrides);
|
||||
}
|
||||
|
||||
fn print_completion(cmd: CompletionCommand) {
|
||||
let mut app = MultitoolCli::command();
|
||||
let name = "codex";
|
||||
generate(cmd.shell, &mut app, name, &mut std::io::stdout());
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
|
||||
fn finalize_from_args(args: &[&str]) -> TuiCli {
|
||||
let cli = MultitoolCli::try_parse_from(args).expect("parse");
|
||||
let MultitoolCli {
|
||||
interactive,
|
||||
config_overrides: root_overrides,
|
||||
subcommand,
|
||||
} = cli;
|
||||
|
||||
let Subcommand::Resume(ResumeCommand {
|
||||
session_id,
|
||||
last,
|
||||
config_overrides: resume_cli,
|
||||
}) = subcommand.expect("resume present")
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
finalize_resume_interactive(interactive, root_overrides, session_id, last, resume_cli)
|
||||
}
|
||||
|
||||
fn sample_exit_info(conversation: Option<&str>) -> AppExitInfo {
|
||||
let token_usage = TokenUsage {
|
||||
output_tokens: 2,
|
||||
total_tokens: 2,
|
||||
..Default::default()
|
||||
};
|
||||
AppExitInfo {
|
||||
token_usage,
|
||||
conversation_id: conversation
|
||||
.map(ConversationId::from_string)
|
||||
.map(Result::unwrap),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_exit_messages_skips_zero_usage() {
|
||||
let exit_info = AppExitInfo {
|
||||
token_usage: TokenUsage::default(),
|
||||
conversation_id: None,
|
||||
};
|
||||
let lines = format_exit_messages(exit_info, false);
|
||||
assert!(lines.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_exit_messages_includes_resume_hint_without_color() {
|
||||
let exit_info = sample_exit_info(Some("123e4567-e89b-12d3-a456-426614174000"));
|
||||
let lines = format_exit_messages(exit_info, false);
|
||||
assert_eq!(
|
||||
lines,
|
||||
vec![
|
||||
"Token usage: total=2 input=0 output=2".to_string(),
|
||||
"To continue this session, run codex resume 123e4567-e89b-12d3-a456-426614174000."
|
||||
.to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_exit_messages_applies_color_when_enabled() {
|
||||
let exit_info = sample_exit_info(Some("123e4567-e89b-12d3-a456-426614174000"));
|
||||
let lines = format_exit_messages(exit_info, true);
|
||||
assert_eq!(lines.len(), 2);
|
||||
assert!(lines[1].contains("\u{1b}[36m"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_model_flag_applies_when_no_root_flags() {
|
||||
let interactive = finalize_from_args(["codex", "resume", "-m", "gpt-5-test"].as_ref());
|
||||
|
||||
assert_eq!(interactive.model.as_deref(), Some("gpt-5-test"));
|
||||
assert!(interactive.resume_picker);
|
||||
assert!(!interactive.resume_last);
|
||||
assert_eq!(interactive.resume_session_id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_picker_logic_none_and_not_last() {
|
||||
let interactive = finalize_from_args(["codex", "resume"].as_ref());
|
||||
assert!(interactive.resume_picker);
|
||||
assert!(!interactive.resume_last);
|
||||
assert_eq!(interactive.resume_session_id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_picker_logic_last() {
|
||||
let interactive = finalize_from_args(["codex", "resume", "--last"].as_ref());
|
||||
assert!(!interactive.resume_picker);
|
||||
assert!(interactive.resume_last);
|
||||
assert_eq!(interactive.resume_session_id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_picker_logic_with_session_id() {
|
||||
let interactive = finalize_from_args(["codex", "resume", "1234"].as_ref());
|
||||
assert!(!interactive.resume_picker);
|
||||
assert!(!interactive.resume_last);
|
||||
assert_eq!(interactive.resume_session_id.as_deref(), Some("1234"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_merges_option_flags_and_full_auto() {
|
||||
let interactive = finalize_from_args(
|
||||
[
|
||||
"codex",
|
||||
"resume",
|
||||
"sid",
|
||||
"--oss",
|
||||
"--full-auto",
|
||||
"--search",
|
||||
"--sandbox",
|
||||
"workspace-write",
|
||||
"--ask-for-approval",
|
||||
"on-request",
|
||||
"-m",
|
||||
"gpt-5-test",
|
||||
"-p",
|
||||
"my-profile",
|
||||
"-C",
|
||||
"/tmp",
|
||||
"-i",
|
||||
"/tmp/a.png,/tmp/b.png",
|
||||
]
|
||||
.as_ref(),
|
||||
);
|
||||
|
||||
assert_eq!(interactive.model.as_deref(), Some("gpt-5-test"));
|
||||
assert!(interactive.oss);
|
||||
assert_eq!(interactive.config_profile.as_deref(), Some("my-profile"));
|
||||
assert!(matches!(
|
||||
interactive.sandbox_mode,
|
||||
Some(codex_common::SandboxModeCliArg::WorkspaceWrite)
|
||||
));
|
||||
assert!(matches!(
|
||||
interactive.approval_policy,
|
||||
Some(codex_common::ApprovalModeCliArg::OnRequest)
|
||||
));
|
||||
assert!(interactive.full_auto);
|
||||
assert_eq!(
|
||||
interactive.cwd.as_deref(),
|
||||
Some(std::path::Path::new("/tmp"))
|
||||
);
|
||||
assert!(interactive.web_search);
|
||||
let has_a = interactive
|
||||
.images
|
||||
.iter()
|
||||
.any(|p| p == std::path::Path::new("/tmp/a.png"));
|
||||
let has_b = interactive
|
||||
.images
|
||||
.iter()
|
||||
.any(|p| p == std::path::Path::new("/tmp/b.png"));
|
||||
assert!(has_a && has_b);
|
||||
assert!(!interactive.resume_picker);
|
||||
assert!(!interactive.resume_last);
|
||||
assert_eq!(interactive.resume_session_id.as_deref(), Some("sid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_merges_dangerously_bypass_flag() {
|
||||
let interactive = finalize_from_args(
|
||||
[
|
||||
"codex",
|
||||
"resume",
|
||||
"--dangerously-bypass-approvals-and-sandbox",
|
||||
]
|
||||
.as_ref(),
|
||||
);
|
||||
assert!(interactive.dangerously_bypass_approvals_and_sandbox);
|
||||
assert!(interactive.resume_picker);
|
||||
assert!(!interactive.resume_last);
|
||||
assert_eq!(interactive.resume_session_id, None);
|
||||
}
|
||||
}
|
||||
|
||||
384
codex-rs/cli/src/mcp_cmd.rs
Normal file
384
codex-rs/cli/src/mcp_cmd.rs
Normal file
@@ -0,0 +1,384 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use anyhow::bail;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::find_codex_home;
|
||||
use codex_core::config::load_global_mcp_servers;
|
||||
use codex_core::config::write_global_mcp_servers;
|
||||
use codex_core::config_types::McpServerConfig;
|
||||
|
||||
/// [experimental] Launch Codex as an MCP server or manage configured MCP servers.
|
||||
///
|
||||
/// Subcommands:
|
||||
/// - `serve` — run the MCP server on stdio
|
||||
/// - `list` — list configured servers (with `--json`)
|
||||
/// - `get` — show a single server (with `--json`)
|
||||
/// - `add` — add a server launcher entry to `~/.codex/config.toml`
|
||||
/// - `remove` — delete a server entry
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct McpCli {
|
||||
#[clap(flatten)]
|
||||
pub config_overrides: CliConfigOverrides,
|
||||
|
||||
#[command(subcommand)]
|
||||
pub cmd: Option<McpSubcommand>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Subcommand)]
|
||||
pub enum McpSubcommand {
|
||||
/// [experimental] Run the Codex MCP server (stdio transport).
|
||||
Serve,
|
||||
|
||||
/// [experimental] List configured MCP servers.
|
||||
List(ListArgs),
|
||||
|
||||
/// [experimental] Show details for a configured MCP server.
|
||||
Get(GetArgs),
|
||||
|
||||
/// [experimental] Add a global MCP server entry.
|
||||
Add(AddArgs),
|
||||
|
||||
/// [experimental] Remove a global MCP server entry.
|
||||
Remove(RemoveArgs),
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct ListArgs {
|
||||
/// Output the configured servers as JSON.
|
||||
#[arg(long)]
|
||||
pub json: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct GetArgs {
|
||||
/// Name of the MCP server to display.
|
||||
pub name: String,
|
||||
|
||||
/// Output the server configuration as JSON.
|
||||
#[arg(long)]
|
||||
pub json: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct AddArgs {
|
||||
/// Name for the MCP server configuration.
|
||||
pub name: String,
|
||||
|
||||
/// Environment variables to set when launching the server.
|
||||
#[arg(long, value_parser = parse_env_pair, value_name = "KEY=VALUE")]
|
||||
pub env: Vec<(String, String)>,
|
||||
|
||||
/// Command to launch the MCP server.
|
||||
#[arg(trailing_var_arg = true, num_args = 1..)]
|
||||
pub command: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct RemoveArgs {
|
||||
/// Name of the MCP server configuration to remove.
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl McpCli {
|
||||
pub async fn run(self, codex_linux_sandbox_exe: Option<PathBuf>) -> Result<()> {
|
||||
let McpCli {
|
||||
config_overrides,
|
||||
cmd,
|
||||
} = self;
|
||||
let subcommand = cmd.unwrap_or(McpSubcommand::Serve);
|
||||
|
||||
match subcommand {
|
||||
McpSubcommand::Serve => {
|
||||
codex_mcp_server::run_main(codex_linux_sandbox_exe, config_overrides).await?;
|
||||
}
|
||||
McpSubcommand::List(args) => {
|
||||
run_list(&config_overrides, args)?;
|
||||
}
|
||||
McpSubcommand::Get(args) => {
|
||||
run_get(&config_overrides, args)?;
|
||||
}
|
||||
McpSubcommand::Add(args) => {
|
||||
run_add(&config_overrides, args)?;
|
||||
}
|
||||
McpSubcommand::Remove(args) => {
|
||||
run_remove(&config_overrides, args)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
|
||||
// Validate any provided overrides even though they are not currently applied.
|
||||
config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
|
||||
let AddArgs { name, env, command } = add_args;
|
||||
|
||||
validate_server_name(&name)?;
|
||||
|
||||
let mut command_parts = command.into_iter();
|
||||
let command_bin = command_parts
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("command is required"))?;
|
||||
let command_args: Vec<String> = command_parts.collect();
|
||||
|
||||
let env_map = if env.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut map = HashMap::new();
|
||||
for (key, value) in env {
|
||||
map.insert(key, value);
|
||||
}
|
||||
Some(map)
|
||||
};
|
||||
|
||||
let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
|
||||
let mut servers = load_global_mcp_servers(&codex_home)
|
||||
.with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?;
|
||||
|
||||
let new_entry = McpServerConfig {
|
||||
command: command_bin,
|
||||
args: command_args,
|
||||
env: env_map,
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
};
|
||||
|
||||
servers.insert(name.clone(), new_entry);
|
||||
|
||||
write_global_mcp_servers(&codex_home, &servers)
|
||||
.with_context(|| format!("failed to write MCP servers to {}", codex_home.display()))?;
|
||||
|
||||
println!("Added global MCP server '{name}'.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_remove(config_overrides: &CliConfigOverrides, remove_args: RemoveArgs) -> Result<()> {
|
||||
config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
|
||||
let RemoveArgs { name } = remove_args;
|
||||
|
||||
validate_server_name(&name)?;
|
||||
|
||||
let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
|
||||
let mut servers = load_global_mcp_servers(&codex_home)
|
||||
.with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?;
|
||||
|
||||
let removed = servers.remove(&name).is_some();
|
||||
|
||||
if removed {
|
||||
write_global_mcp_servers(&codex_home, &servers)
|
||||
.with_context(|| format!("failed to write MCP servers to {}", codex_home.display()))?;
|
||||
}
|
||||
|
||||
if removed {
|
||||
println!("Removed global MCP server '{name}'.");
|
||||
} else {
|
||||
println!("No MCP server named '{name}' found.");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
|
||||
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
|
||||
.context("failed to load configuration")?;
|
||||
|
||||
let mut entries: Vec<_> = config.mcp_servers.iter().collect();
|
||||
entries.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
|
||||
if list_args.json {
|
||||
let json_entries: Vec<_> = entries
|
||||
.into_iter()
|
||||
.map(|(name, cfg)| {
|
||||
let env = cfg.env.as_ref().map(|env| {
|
||||
env.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect::<BTreeMap<_, _>>()
|
||||
});
|
||||
serde_json::json!({
|
||||
"name": name,
|
||||
"command": cfg.command,
|
||||
"args": cfg.args,
|
||||
"env": env,
|
||||
"startup_timeout_sec": cfg
|
||||
.startup_timeout_sec
|
||||
.map(|timeout| timeout.as_secs_f64()),
|
||||
"tool_timeout_sec": cfg
|
||||
.tool_timeout_sec
|
||||
.map(|timeout| timeout.as_secs_f64()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let output = serde_json::to_string_pretty(&json_entries)?;
|
||||
println!("{output}");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if entries.is_empty() {
|
||||
println!("No MCP servers configured yet. Try `codex mcp add my-tool -- my-command`.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut rows: Vec<[String; 4]> = Vec::new();
|
||||
for (name, cfg) in entries {
|
||||
let args = if cfg.args.is_empty() {
|
||||
"-".to_string()
|
||||
} else {
|
||||
cfg.args.join(" ")
|
||||
};
|
||||
|
||||
let env = match cfg.env.as_ref() {
|
||||
None => "-".to_string(),
|
||||
Some(map) if map.is_empty() => "-".to_string(),
|
||||
Some(map) => {
|
||||
let mut pairs: Vec<_> = map.iter().collect();
|
||||
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
pairs
|
||||
.into_iter()
|
||||
.map(|(k, v)| format!("{k}={v}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
};
|
||||
|
||||
rows.push([name.clone(), cfg.command.clone(), args, env]);
|
||||
}
|
||||
|
||||
let mut widths = ["Name".len(), "Command".len(), "Args".len(), "Env".len()];
|
||||
for row in &rows {
|
||||
for (i, cell) in row.iter().enumerate() {
|
||||
widths[i] = widths[i].max(cell.len());
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"{:<name_w$} {:<cmd_w$} {:<args_w$} {:<env_w$}",
|
||||
"Name",
|
||||
"Command",
|
||||
"Args",
|
||||
"Env",
|
||||
name_w = widths[0],
|
||||
cmd_w = widths[1],
|
||||
args_w = widths[2],
|
||||
env_w = widths[3],
|
||||
);
|
||||
|
||||
for row in rows {
|
||||
println!(
|
||||
"{:<name_w$} {:<cmd_w$} {:<args_w$} {:<env_w$}",
|
||||
row[0],
|
||||
row[1],
|
||||
row[2],
|
||||
row[3],
|
||||
name_w = widths[0],
|
||||
cmd_w = widths[1],
|
||||
args_w = widths[2],
|
||||
env_w = widths[3],
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<()> {
|
||||
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
|
||||
.context("failed to load configuration")?;
|
||||
|
||||
let Some(server) = config.mcp_servers.get(&get_args.name) else {
|
||||
bail!("No MCP server named '{name}' found.", name = get_args.name);
|
||||
};
|
||||
|
||||
if get_args.json {
|
||||
let env = server.env.as_ref().map(|env| {
|
||||
env.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect::<BTreeMap<_, _>>()
|
||||
});
|
||||
let output = serde_json::to_string_pretty(&serde_json::json!({
|
||||
"name": get_args.name,
|
||||
"command": server.command,
|
||||
"args": server.args,
|
||||
"env": env,
|
||||
"startup_timeout_sec": server
|
||||
.startup_timeout_sec
|
||||
.map(|timeout| timeout.as_secs_f64()),
|
||||
"tool_timeout_sec": server
|
||||
.tool_timeout_sec
|
||||
.map(|timeout| timeout.as_secs_f64()),
|
||||
}))?;
|
||||
println!("{output}");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!("{}", get_args.name);
|
||||
println!(" command: {}", server.command);
|
||||
let args = if server.args.is_empty() {
|
||||
"-".to_string()
|
||||
} else {
|
||||
server.args.join(" ")
|
||||
};
|
||||
println!(" args: {args}");
|
||||
let env_display = match server.env.as_ref() {
|
||||
None => "-".to_string(),
|
||||
Some(map) if map.is_empty() => "-".to_string(),
|
||||
Some(map) => {
|
||||
let mut pairs: Vec<_> = map.iter().collect();
|
||||
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
pairs
|
||||
.into_iter()
|
||||
.map(|(k, v)| format!("{k}={v}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
};
|
||||
println!(" env: {env_display}");
|
||||
if let Some(timeout) = server.startup_timeout_sec {
|
||||
println!(" startup_timeout_sec: {}", timeout.as_secs_f64());
|
||||
}
|
||||
if let Some(timeout) = server.tool_timeout_sec {
|
||||
println!(" tool_timeout_sec: {}", timeout.as_secs_f64());
|
||||
}
|
||||
println!(" remove: codex mcp remove {}", get_args.name);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_env_pair(raw: &str) -> Result<(String, String), String> {
|
||||
let mut parts = raw.splitn(2, '=');
|
||||
let key = parts
|
||||
.next()
|
||||
.map(str::trim)
|
||||
.filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| "environment entries must be in KEY=VALUE form".to_string())?;
|
||||
let value = parts
|
||||
.next()
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| "environment entries must be in KEY=VALUE form".to_string())?;
|
||||
|
||||
Ok((key.to_string(), value))
|
||||
}
|
||||
|
||||
fn validate_server_name(name: &str) -> Result<()> {
|
||||
let is_valid = !name.is_empty()
|
||||
&& name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_');
|
||||
|
||||
if is_valid {
|
||||
Ok(())
|
||||
} else {
|
||||
bail!("invalid server name '{name}' (use letters, numbers, '-', '_')");
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ use std::io::IsTerminal;
|
||||
|
||||
use clap::Parser;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::config::Config;
|
||||
@@ -9,7 +10,6 @@ use codex_core::config::ConfigOverrides;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::Submission;
|
||||
use codex_login::AuthManager;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tracing::error;
|
||||
@@ -37,10 +37,8 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> {
|
||||
|
||||
let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?;
|
||||
// Use conversation_manager API to start a conversation
|
||||
let conversation_manager = ConversationManager::new(AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
config.preferred_auth_method,
|
||||
));
|
||||
let conversation_manager =
|
||||
ConversationManager::new(AuthManager::shared(config.codex_home.clone()));
|
||||
let NewConversation {
|
||||
conversation_id: _,
|
||||
conversation,
|
||||
|
||||
86
codex-rs/cli/tests/mcp_add_remove.rs
Normal file
86
codex-rs/cli/tests/mcp_add_remove.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::config::load_global_mcp_servers;
|
||||
use predicates::str::contains;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn codex_command(codex_home: &Path) -> Result<assert_cmd::Command> {
|
||||
let mut cmd = assert_cmd::Command::cargo_bin("codex")?;
|
||||
cmd.env("CODEX_HOME", codex_home);
|
||||
Ok(cmd)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_and_remove_server_updates_global_config() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args(["mcp", "add", "docs", "--", "echo", "hello"])
|
||||
.assert()
|
||||
.success()
|
||||
.stdout(contains("Added global MCP server 'docs'."));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
assert_eq!(servers.len(), 1);
|
||||
let docs = servers.get("docs").expect("server should exist");
|
||||
assert_eq!(docs.command, "echo");
|
||||
assert_eq!(docs.args, vec!["hello".to_string()]);
|
||||
assert!(docs.env.is_none());
|
||||
|
||||
let mut remove_cmd = codex_command(codex_home.path())?;
|
||||
remove_cmd
|
||||
.args(["mcp", "remove", "docs"])
|
||||
.assert()
|
||||
.success()
|
||||
.stdout(contains("Removed global MCP server 'docs'."));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
let mut remove_again_cmd = codex_command(codex_home.path())?;
|
||||
remove_again_cmd
|
||||
.args(["mcp", "remove", "docs"])
|
||||
.assert()
|
||||
.success()
|
||||
.stdout(contains("No MCP server named 'docs' found."));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_with_env_preserves_key_order_and_values() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"envy",
|
||||
"--env",
|
||||
"FOO=bar",
|
||||
"--env",
|
||||
"ALPHA=beta",
|
||||
"--",
|
||||
"python",
|
||||
"server.py",
|
||||
])
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
let envy = servers.get("envy").expect("server should exist");
|
||||
let env = envy.env.as_ref().expect("env should be present");
|
||||
|
||||
assert_eq!(env.len(), 2);
|
||||
assert_eq!(env.get("FOO"), Some(&"bar".to_string()));
|
||||
assert_eq!(env.get("ALPHA"), Some(&"beta".to_string()));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
106
codex-rs/cli/tests/mcp_list.rs
Normal file
106
codex-rs/cli/tests/mcp_list.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::Result;
|
||||
use predicates::str::contains;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn codex_command(codex_home: &Path) -> Result<assert_cmd::Command> {
|
||||
let mut cmd = assert_cmd::Command::cargo_bin("codex")?;
|
||||
cmd.env("CODEX_HOME", codex_home);
|
||||
Ok(cmd)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_shows_empty_state() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut cmd = codex_command(codex_home.path())?;
|
||||
let output = cmd.args(["mcp", "list"]).output()?;
|
||||
assert!(output.status.success());
|
||||
let stdout = String::from_utf8(output.stdout)?;
|
||||
assert!(stdout.contains("No MCP servers configured yet."));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_and_get_render_expected_output() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add = codex_command(codex_home.path())?;
|
||||
add.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"docs",
|
||||
"--env",
|
||||
"TOKEN=secret",
|
||||
"--",
|
||||
"docs-server",
|
||||
"--port",
|
||||
"4000",
|
||||
])
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let mut list_cmd = codex_command(codex_home.path())?;
|
||||
let list_output = list_cmd.args(["mcp", "list"]).output()?;
|
||||
assert!(list_output.status.success());
|
||||
let stdout = String::from_utf8(list_output.stdout)?;
|
||||
assert!(stdout.contains("Name"));
|
||||
assert!(stdout.contains("docs"));
|
||||
assert!(stdout.contains("docs-server"));
|
||||
assert!(stdout.contains("TOKEN=secret"));
|
||||
|
||||
let mut list_json_cmd = codex_command(codex_home.path())?;
|
||||
let json_output = list_json_cmd.args(["mcp", "list", "--json"]).output()?;
|
||||
assert!(json_output.status.success());
|
||||
let stdout = String::from_utf8(json_output.stdout)?;
|
||||
let parsed: JsonValue = serde_json::from_str(&stdout)?;
|
||||
let array = parsed.as_array().expect("expected array");
|
||||
assert_eq!(array.len(), 1);
|
||||
let entry = &array[0];
|
||||
assert_eq!(entry.get("name"), Some(&JsonValue::String("docs".into())));
|
||||
assert_eq!(
|
||||
entry.get("command"),
|
||||
Some(&JsonValue::String("docs-server".into()))
|
||||
);
|
||||
|
||||
let args = entry
|
||||
.get("args")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("args array");
|
||||
assert_eq!(
|
||||
args,
|
||||
&vec![
|
||||
JsonValue::String("--port".into()),
|
||||
JsonValue::String("4000".into())
|
||||
]
|
||||
);
|
||||
|
||||
let env = entry
|
||||
.get("env")
|
||||
.and_then(|v| v.as_object())
|
||||
.expect("env map");
|
||||
assert_eq!(env.get("TOKEN"), Some(&JsonValue::String("secret".into())));
|
||||
|
||||
let mut get_cmd = codex_command(codex_home.path())?;
|
||||
let get_output = get_cmd.args(["mcp", "get", "docs"]).output()?;
|
||||
assert!(get_output.status.success());
|
||||
let stdout = String::from_utf8(get_output.stdout)?;
|
||||
assert!(stdout.contains("docs"));
|
||||
assert!(stdout.contains("command: docs-server"));
|
||||
assert!(stdout.contains("args: --port 4000"));
|
||||
assert!(stdout.contains("env: TOKEN=secret"));
|
||||
assert!(stdout.contains("remove: codex mcp remove docs"));
|
||||
|
||||
let mut get_json_cmd = codex_command(codex_home.path())?;
|
||||
get_json_cmd
|
||||
.args(["mcp", "get", "docs", "--json"])
|
||||
.assert()
|
||||
.success()
|
||||
.stdout(contains("\"name\": \"docs\""));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -7,11 +7,11 @@ version = { workspace = true }
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4", features = ["derive", "wrap_help"], optional = true }
|
||||
codex-core = { path = "../core" }
|
||||
codex-protocol = { path = "../protocol" }
|
||||
serde = { version = "1", optional = true }
|
||||
toml = { version = "0.9", optional = true }
|
||||
clap = { workspace = true, features = ["derive", "wrap_help"], optional = true }
|
||||
codex-core = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
serde = { workspace = true, optional = true }
|
||||
toml = { workspace = true, optional = true }
|
||||
|
||||
[features]
|
||||
# Separate feature so that `clap` is not a mandatory dependency.
|
||||
|
||||
@@ -17,7 +17,10 @@ pub fn create_config_summary_entries(config: &Config) -> Vec<(&'static str, Stri
|
||||
{
|
||||
entries.push((
|
||||
"reasoning effort",
|
||||
config.model_reasoning_effort.to_string(),
|
||||
config
|
||||
.model_reasoning_effort
|
||||
.map(|effort| effort.to_string())
|
||||
.unwrap_or_else(|| "none".to_string()),
|
||||
));
|
||||
entries.push((
|
||||
"reasoning summaries",
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Returns a string representing the elapsed time since `start_time` like
|
||||
/// "1m15s" or "1.50s".
|
||||
/// "1m 15s" or "1.50s".
|
||||
pub fn format_elapsed(start_time: Instant) -> String {
|
||||
format_duration(start_time.elapsed())
|
||||
}
|
||||
@@ -12,7 +12,7 @@ pub fn format_elapsed(start_time: Instant) -> String {
|
||||
/// Formatting rules:
|
||||
/// * < 1 s -> "{milli}ms"
|
||||
/// * < 60 s -> "{sec:.2}s" (two decimal places)
|
||||
/// * >= 60 s -> "{min}m{sec:02}s"
|
||||
/// * >= 60 s -> "{min}m {sec:02}s"
|
||||
pub fn format_duration(duration: Duration) -> String {
|
||||
let millis = duration.as_millis() as i64;
|
||||
format_elapsed_millis(millis)
|
||||
@@ -26,7 +26,7 @@ fn format_elapsed_millis(millis: i64) -> String {
|
||||
} else {
|
||||
let minutes = millis / 60_000;
|
||||
let seconds = (millis % 60_000) / 1000;
|
||||
format!("{minutes}m{seconds:02}s")
|
||||
format!("{minutes}m {seconds:02}s")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,12 +61,18 @@ mod tests {
|
||||
fn test_format_duration_minutes() {
|
||||
// Durations ≥ 1 minute should be printed mmss.
|
||||
let dur = Duration::from_millis(75_000); // 1m15s
|
||||
assert_eq!(format_duration(dur), "1m15s");
|
||||
assert_eq!(format_duration(dur), "1m 15s");
|
||||
|
||||
let dur_exact = Duration::from_millis(60_000); // 1m0s
|
||||
assert_eq!(format_duration(dur_exact), "1m00s");
|
||||
assert_eq!(format_duration(dur_exact), "1m 00s");
|
||||
|
||||
let dur_long = Duration::from_millis(3_601_000);
|
||||
assert_eq!(format_duration(dur_long), "60m01s");
|
||||
assert_eq!(format_duration(dur_long), "60m 01s");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_duration_one_hour_has_space() {
|
||||
let dur_hour = Duration::from_millis(3_600_000);
|
||||
assert_eq!(format_duration(dur_hour), "60m 00s");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use codex_core::config::GPT_5_CODEX_MEDIUM_MODEL;
|
||||
use codex_core::protocol_config_types::ReasoningEffort;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
|
||||
/// A simple preset pairing a model slug with a reasoning effort.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
@@ -12,43 +14,68 @@ pub struct ModelPreset {
|
||||
/// Model slug (e.g., "gpt-5").
|
||||
pub model: &'static str,
|
||||
/// Reasoning effort to apply for this preset.
|
||||
pub effort: ReasoningEffort,
|
||||
pub effort: Option<ReasoningEffort>,
|
||||
}
|
||||
|
||||
/// Built-in list of model presets that pair a model with a reasoning effort.
|
||||
///
|
||||
/// Keep this UI-agnostic so it can be reused by both TUI and MCP server.
|
||||
pub fn builtin_model_presets() -> &'static [ModelPreset] {
|
||||
// Order reflects effort from minimal to high.
|
||||
const PRESETS: &[ModelPreset] = &[
|
||||
ModelPreset {
|
||||
id: "gpt-5-minimal",
|
||||
label: "gpt-5 minimal",
|
||||
description: "— fastest responses with limited reasoning; ideal for coding, instructions, or lightweight tasks",
|
||||
model: "gpt-5",
|
||||
effort: ReasoningEffort::Minimal,
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-low",
|
||||
label: "gpt-5 low",
|
||||
description: "— balances speed with some reasoning; useful for straightforward queries and short explanations",
|
||||
model: "gpt-5",
|
||||
effort: ReasoningEffort::Low,
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-medium",
|
||||
label: "gpt-5 medium",
|
||||
description: "— default setting; provides a solid balance of reasoning depth and latency for general-purpose tasks",
|
||||
model: "gpt-5",
|
||||
effort: ReasoningEffort::Medium,
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-high",
|
||||
label: "gpt-5 high",
|
||||
description: "— maximizes reasoning depth for complex or ambiguous problems",
|
||||
model: "gpt-5",
|
||||
effort: ReasoningEffort::High,
|
||||
},
|
||||
];
|
||||
PRESETS
|
||||
const PRESETS: &[ModelPreset] = &[
|
||||
ModelPreset {
|
||||
id: "gpt-5-codex-low",
|
||||
label: "gpt-5-codex low",
|
||||
description: "",
|
||||
model: "gpt-5-codex",
|
||||
effort: Some(ReasoningEffort::Low),
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-codex-medium",
|
||||
label: "gpt-5-codex medium",
|
||||
description: "",
|
||||
model: "gpt-5-codex",
|
||||
effort: None,
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-codex-high",
|
||||
label: "gpt-5-codex high",
|
||||
description: "",
|
||||
model: "gpt-5-codex",
|
||||
effort: Some(ReasoningEffort::High),
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-minimal",
|
||||
label: "gpt-5 minimal",
|
||||
description: "— fastest responses with limited reasoning; ideal for coding, instructions, or lightweight tasks",
|
||||
model: "gpt-5",
|
||||
effort: Some(ReasoningEffort::Minimal),
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-low",
|
||||
label: "gpt-5 low",
|
||||
description: "— balances speed with some reasoning; useful for straightforward queries and short explanations",
|
||||
model: "gpt-5",
|
||||
effort: Some(ReasoningEffort::Low),
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-medium",
|
||||
label: "gpt-5 medium",
|
||||
description: "— default setting; provides a solid balance of reasoning depth and latency for general-purpose tasks",
|
||||
model: "gpt-5",
|
||||
effort: Some(ReasoningEffort::Medium),
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-high",
|
||||
label: "gpt-5 high",
|
||||
description: "— maximizes reasoning depth for complex or ambiguous problems",
|
||||
model: "gpt-5",
|
||||
effort: Some(ReasoningEffort::High),
|
||||
},
|
||||
];
|
||||
|
||||
pub fn builtin_model_presets(auth_mode: Option<AuthMode>) -> Vec<ModelPreset> {
|
||||
match auth_mode {
|
||||
Some(AuthMode::ApiKey) => PRESETS
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|p| p.model != GPT_5_CODEX_MEDIUM_MODEL)
|
||||
.collect(),
|
||||
_ => PRESETS.to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,85 +4,89 @@ name = "codex-core"
|
||||
version = { workspace = true }
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "codex_core"
|
||||
path = "src/lib.rs"
|
||||
doctest = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
async-channel = "2.3.1"
|
||||
base64 = "0.22"
|
||||
bytes = "1.10.1"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
codex-apply-patch = { path = "../apply-patch" }
|
||||
codex-login = { path = "../login" }
|
||||
codex-mcp-client = { path = "../mcp-client" }
|
||||
codex-protocol = { path = "../protocol" }
|
||||
dirs = "6"
|
||||
env-flags = "0.1.1"
|
||||
eventsource-stream = "0.2.3"
|
||||
futures = "0.3"
|
||||
libc = "0.2.175"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
mime_guess = "2.0"
|
||||
os_info = "3.12.0"
|
||||
portable-pty = "0.9.0"
|
||||
rand = "0.9"
|
||||
regex-lite = "0.1.6"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_bytes = "0.11"
|
||||
serde_json = "1"
|
||||
sha1 = "0.10.6"
|
||||
shlex = "1.3.0"
|
||||
similar = "2.7.0"
|
||||
strum_macros = "0.27.2"
|
||||
tempfile = "3"
|
||||
thiserror = "2.0.12"
|
||||
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
|
||||
tokio = { version = "1", features = [
|
||||
anyhow = { workspace = true }
|
||||
askama = { workspace = true }
|
||||
async-channel = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
codex-apply-patch = { workspace = true }
|
||||
codex-file-search = { workspace = true }
|
||||
codex-mcp-client = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
env-flags = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
mcp-types = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
portable-pty = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["json", "stream"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
sha1 = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
similar = { workspace = true }
|
||||
strum_macros = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
time = { workspace = true, features = [
|
||||
"formatting",
|
||||
"parsing",
|
||||
"local-offset",
|
||||
"macros",
|
||||
] }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
tokio-util = "0.7.16"
|
||||
toml = "0.9.5"
|
||||
toml_edit = "0.23.4"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tree-sitter = "0.25.8"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
whoami = "1.6.1"
|
||||
wildmatch = "2.4.0"
|
||||
tokio-util = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
toml_edit = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tree-sitter = { workspace = true }
|
||||
tree-sitter-bash = { workspace = true }
|
||||
uuid = { workspace = true, features = ["serde", "v4"] }
|
||||
which = { workspace = true }
|
||||
wildmatch = { workspace = true }
|
||||
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
landlock = "0.4.1"
|
||||
seccompiler = "0.5.0"
|
||||
landlock = { workspace = true }
|
||||
seccompiler = { workspace = true }
|
||||
|
||||
# Build OpenSSL from source for musl builds.
|
||||
[target.x86_64-unknown-linux-musl.dependencies]
|
||||
openssl-sys = { version = "*", features = ["vendored"] }
|
||||
openssl-sys = { workspace = true, features = ["vendored"] }
|
||||
|
||||
# Build OpenSSL from source for musl builds.
|
||||
[target.aarch64-unknown-linux-musl.dependencies]
|
||||
openssl-sys = { version = "*", features = ["vendored"] }
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
which = "6"
|
||||
openssl-sys = { workspace = true, features = ["vendored"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
core_test_support = { path = "tests/common" }
|
||||
maplit = "1.0.2"
|
||||
predicates = "3"
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
tokio-test = "0.4"
|
||||
walkdir = "2.5.0"
|
||||
wiremock = "0.6"
|
||||
assert_cmd = { workspace = true }
|
||||
core_test_support = { workspace = true }
|
||||
maplit = { workspace = true }
|
||||
predicates = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tokio-test = { workspace = true }
|
||||
walkdir = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
|
||||
[package.metadata.cargo-shear]
|
||||
ignored = ["openssl-sys"]
|
||||
|
||||
104
codex-rs/core/gpt_5_codex_prompt.md
Normal file
104
codex-rs/core/gpt_5_codex_prompt.md
Normal file
@@ -0,0 +1,104 @@
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- The arguments to `shell` will be passed to execvp(). Most terminal commands should be prefixed with ["bash", "-lc"].
|
||||
- Always set the `workdir` param when using the shell function. Do not use `cd` unless absolutely necessary.
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `with_escalated_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to "never", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `with_escalated_permissions` parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why you need to enable `with_escalated_permissions` in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a "review", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No "save/copy this file" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; add a language hint whenever obvious.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Line/column (1‑based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||
@@ -14,6 +14,18 @@ Within this context, Codex refers to the open-source agentic coding interface (n
|
||||
|
||||
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
|
||||
|
||||
# AGENTS.md spec
|
||||
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
|
||||
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
|
||||
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
|
||||
- Instructions in AGENTS.md files:
|
||||
- The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
|
||||
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
|
||||
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
|
||||
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
|
||||
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
|
||||
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
|
||||
|
||||
## Responsiveness
|
||||
|
||||
### Preamble messages
|
||||
@@ -228,7 +240,6 @@ You are producing plain text that will later be styled by the CLI. Follow these
|
||||
**Bullets**
|
||||
|
||||
- Use `-` followed by a space for every bullet.
|
||||
- Bold the keyword, then colon + concise description.
|
||||
- Merge related points when possible; avoid a bullet for every trivial detail.
|
||||
- Keep bullets to one line unless breaking for clarity is unavoidable.
|
||||
- Group into short lists (4–6 bullets) ordered by importance.
|
||||
@@ -240,6 +251,16 @@ You are producing plain text that will later be styled by the CLI. Follow these
|
||||
- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command.
|
||||
- Never mix monospace and bold markers; choose one based on whether it’s a keyword (`**`) or inline code/path (`` ` ``).
|
||||
|
||||
**File References**
|
||||
When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Line/column (1‑based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||
|
||||
**Structure**
|
||||
|
||||
- Place related bullets together; don’t mix unrelated concepts in the same section.
|
||||
|
||||
87
codex-rs/core/review_prompt.md
Normal file
87
codex-rs/core/review_prompt.md
Normal file
@@ -0,0 +1,87 @@
|
||||
# Review guidelines:
|
||||
|
||||
You are acting as a reviewer for a proposed code change made by another engineer.
|
||||
|
||||
Below are some default guidelines for determining whether the original author would appreciate the issue being flagged.
|
||||
|
||||
These are not the final word in determining whether an issue is a bug. In many cases, you will encounter other, more specific guidelines. These may be present elsewhere in a developer message, a user message, a file, or even elsewhere in this system message.
|
||||
Those guidelines should be considered to override these general instructions.
|
||||
|
||||
Here are the general guidelines for determining whether something is a bug and should be flagged.
|
||||
|
||||
1. It meaningfully impacts the accuracy, performance, security, or maintainability of the code.
|
||||
2. The bug is discrete and actionable (i.e. not a general issue with the codebase or a combination of multiple issues).
|
||||
3. Fixing the bug does not demand a level of rigor that is not present in the rest of the codebase (e.g. one doesn't need very detailed comments and input validation in a repository of one-off scripts in personal projects)
|
||||
4. The bug was introduced in the commit (pre-existing bugs should not be flagged).
|
||||
5. The author of the original PR would likely fix the issue if they were made aware of it.
|
||||
6. The bug does not rely on unstated assumptions about the codebase or author's intent.
|
||||
7. It is not enough to speculate that a change may disrupt another part of the codebase, to be considered a bug, one must identify the other parts of the code that are provably affected.
|
||||
8. The bug is clearly not just an intentional change by the original author.
|
||||
|
||||
When flagging a bug, you will also provide an accompanying comment. Once again, these guidelines are not the final word on how to construct a comment -- defer to any subsequent guidelines that you encounter.
|
||||
|
||||
1. The comment should be clear about why the issue is a bug.
|
||||
2. The comment should appropriately communicate the severity of the issue. It should not claim that an issue is more severe than it actually is.
|
||||
3. The comment should be brief. The body should be at most 1 paragraph. It should not introduce line breaks within the natural language flow unless it is necessary for the code fragment.
|
||||
4. The comment should not include any chunks of code longer than 3 lines. Any code chunks should be wrapped in markdown inline code tags or a code block.
|
||||
5. The comment should clearly and explicitly communicate the scenarios, environments, or inputs that are necessary for the bug to arise. The comment should immediately indicate that the issue's severity depends on these factors.
|
||||
6. The comment's tone should be matter-of-fact and not accusatory or overly positive. It should read as a helpful AI assistant suggestion without sounding too much like a human reviewer.
|
||||
7. The comment should be written such that the original author can immediately grasp the idea without close reading.
|
||||
8. The comment should avoid excessive flattery and comments that are not helpful to the original author. The comment should avoid phrasing like "Great job ...", "Thanks for ...".
|
||||
|
||||
Below are some more detailed guidelines that you should apply to this specific review.
|
||||
|
||||
HOW MANY FINDINGS TO RETURN:
|
||||
|
||||
Output all findings that the original author would fix if they knew about it. If there is no finding that a person would definitely love to see and fix, prefer outputting no findings. Do not stop at the first qualifying finding. Continue until you've listed every qualifying finding.
|
||||
|
||||
GUIDELINES:
|
||||
|
||||
- Ignore trivial style unless it obscures meaning or violates documented standards.
|
||||
- Use one comment per distinct issue (or a multi-line range if necessary).
|
||||
- Use ```suggestion blocks ONLY for concrete replacement code (minimal lines; no commentary inside the block).
|
||||
- In every ```suggestion block, preserve the exact leading whitespace of the replaced lines (spaces vs tabs, number of spaces).
|
||||
- Do NOT introduce or remove outer indentation levels unless that is the actual fix.
|
||||
|
||||
The comments will be presented in the code review as inline comments. You should avoid providing unnecessary location details in the comment body. Always keep the line range as short as possible for interpreting the issue. Avoid ranges longer than 5–10 lines; instead, choose the most suitable subrange that pinpoints the problem.
|
||||
|
||||
At the beginning of the finding title, tag the bug with priority level. For example "[P1] Un-padding slices along wrong tensor dimensions". [P0] – Drop everything to fix. Blocking release, operations, or major usage. Only use for universal issues that do not depend on any assumptions about the inputs. · [P1] – Urgent. Should be addressed in the next cycle · [P2] – Normal. To be fixed eventually · [P3] – Low. Nice to have.
|
||||
|
||||
Additionally, include a numeric priority field in the JSON output for each finding: set "priority" to 0 for P0, 1 for P1, 2 for P2, or 3 for P3. If a priority cannot be determined, omit the field or use null.
|
||||
|
||||
At the end of your findings, output an "overall correctness" verdict of whether or not the patch should be considered "correct".
|
||||
Correct implies that existing code and tests will not break, and the patch is free of bugs and other blocking issues.
|
||||
Ignore non-blocking issues such as style, formatting, typos, documentation, and other nits.
|
||||
|
||||
FORMATTING GUIDELINES:
|
||||
The finding description should be one paragraph.
|
||||
|
||||
OUTPUT FORMAT:
|
||||
|
||||
## Output schema — MUST MATCH *exactly*
|
||||
|
||||
```json
|
||||
{
|
||||
"findings": [
|
||||
{
|
||||
"title": "<≤ 80 chars, imperative>",
|
||||
"body": "<valid Markdown explaining *why* this is a problem; cite files/lines/functions>",
|
||||
"confidence_score": <float 0.0-1.0>,
|
||||
"priority": <int 0-3, optional>,
|
||||
"code_location": {
|
||||
"absolute_file_path": "<file path>",
|
||||
"line_range": {"start": <int>, "end": <int>}
|
||||
}
|
||||
}
|
||||
],
|
||||
"overall_correctness": "patch is correct" | "patch is incorrect",
|
||||
"overall_explanation": "<1-3 sentence explanation justifying the overall_correctness verdict>",
|
||||
"overall_confidence_score": <float 0.0-1.0>
|
||||
}
|
||||
```
|
||||
|
||||
* **Do not** wrap the JSON in markdown fences or extra prose.
|
||||
* The code_location field is required and must include absolute_file_path and line_range.
|
||||
*Line ranges must be as short as possible for interpreting the issue (avoid ranges over 5–10 lines; pick the most suitable subrange).
|
||||
* The code_location should overlap with the diff.
|
||||
* Do not generate a PR fix.
|
||||
@@ -109,7 +109,9 @@ pub(crate) fn convert_apply_patch_to_protocol(
|
||||
ApplyPatchFileChange::Add { content } => FileChange::Add {
|
||||
content: content.clone(),
|
||||
},
|
||||
ApplyPatchFileChange::Delete => FileChange::Delete,
|
||||
ApplyPatchFileChange::Delete { content } => FileChange::Delete {
|
||||
content: content.clone(),
|
||||
},
|
||||
ApplyPatchFileChange::Update {
|
||||
unified_diff,
|
||||
move_path,
|
||||
|
||||
659
codex-rs/core/src/auth.rs
Normal file
659
codex-rs/core/src/auth.rs
Normal file
@@ -0,0 +1,659 @@
|
||||
use chrono::DateTime;
|
||||
use chrono::Utc;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::env;
|
||||
use std::fs::File;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
|
||||
use crate::token_data::PlanType;
|
||||
use crate::token_data::TokenData;
|
||||
use crate::token_data::parse_id_token;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CodexAuth {
|
||||
pub mode: AuthMode,
|
||||
|
||||
pub(crate) api_key: Option<String>,
|
||||
pub(crate) auth_dot_json: Arc<Mutex<Option<AuthDotJson>>>,
|
||||
pub(crate) auth_file: PathBuf,
|
||||
pub(crate) client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl PartialEq for CodexAuth {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.mode == other.mode
|
||||
}
|
||||
}
|
||||
|
||||
impl CodexAuth {
|
||||
pub async fn refresh_token(&self) -> Result<String, std::io::Error> {
|
||||
let token_data = self
|
||||
.get_current_token_data()
|
||||
.ok_or(std::io::Error::other("Token data is not available."))?;
|
||||
let token = token_data.refresh_token;
|
||||
|
||||
let refresh_response = try_refresh_token(token, &self.client)
|
||||
.await
|
||||
.map_err(std::io::Error::other)?;
|
||||
|
||||
let updated = update_tokens(
|
||||
&self.auth_file,
|
||||
refresh_response.id_token,
|
||||
refresh_response.access_token,
|
||||
refresh_response.refresh_token,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Ok(mut auth_lock) = self.auth_dot_json.lock() {
|
||||
*auth_lock = Some(updated.clone());
|
||||
}
|
||||
|
||||
let access = match updated.tokens {
|
||||
Some(t) => t.access_token,
|
||||
None => {
|
||||
return Err(std::io::Error::other(
|
||||
"Token data is not available after refresh.",
|
||||
));
|
||||
}
|
||||
};
|
||||
Ok(access)
|
||||
}
|
||||
|
||||
/// Loads the available auth information from the auth.json.
|
||||
pub fn from_codex_home(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
|
||||
load_auth(codex_home)
|
||||
}
|
||||
|
||||
pub async fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
|
||||
let auth_dot_json: Option<AuthDotJson> = self.get_current_auth_json();
|
||||
match auth_dot_json {
|
||||
Some(AuthDotJson {
|
||||
tokens: Some(mut tokens),
|
||||
last_refresh: Some(last_refresh),
|
||||
..
|
||||
}) => {
|
||||
if last_refresh < Utc::now() - chrono::Duration::days(28) {
|
||||
let refresh_response = tokio::time::timeout(
|
||||
Duration::from_secs(60),
|
||||
try_refresh_token(tokens.refresh_token.clone(), &self.client),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
std::io::Error::other("timed out while refreshing OpenAI API key")
|
||||
})?
|
||||
.map_err(std::io::Error::other)?;
|
||||
|
||||
let updated_auth_dot_json = update_tokens(
|
||||
&self.auth_file,
|
||||
refresh_response.id_token,
|
||||
refresh_response.access_token,
|
||||
refresh_response.refresh_token,
|
||||
)
|
||||
.await?;
|
||||
|
||||
tokens = updated_auth_dot_json
|
||||
.tokens
|
||||
.clone()
|
||||
.ok_or(std::io::Error::other(
|
||||
"Token data is not available after refresh.",
|
||||
))?;
|
||||
|
||||
#[expect(clippy::unwrap_used)]
|
||||
let mut auth_lock = self.auth_dot_json.lock().unwrap();
|
||||
*auth_lock = Some(updated_auth_dot_json);
|
||||
}
|
||||
|
||||
Ok(tokens)
|
||||
}
|
||||
_ => Err(std::io::Error::other("Token data is not available.")),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_token(&self) -> Result<String, std::io::Error> {
|
||||
match self.mode {
|
||||
AuthMode::ApiKey => Ok(self.api_key.clone().unwrap_or_default()),
|
||||
AuthMode::ChatGPT => {
|
||||
let id_token = self.get_token_data().await?.access_token;
|
||||
Ok(id_token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_account_id(&self) -> Option<String> {
|
||||
self.get_current_token_data().and_then(|t| t.account_id)
|
||||
}
|
||||
|
||||
pub(crate) fn get_plan_type(&self) -> Option<PlanType> {
|
||||
self.get_current_token_data()
|
||||
.and_then(|t| t.id_token.chatgpt_plan_type)
|
||||
}
|
||||
|
||||
fn get_current_auth_json(&self) -> Option<AuthDotJson> {
|
||||
#[expect(clippy::unwrap_used)]
|
||||
self.auth_dot_json.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
fn get_current_token_data(&self) -> Option<TokenData> {
|
||||
self.get_current_auth_json().and_then(|t| t.tokens)
|
||||
}
|
||||
|
||||
/// Consider this private to integration tests.
|
||||
pub fn create_dummy_chatgpt_auth_for_testing() -> Self {
|
||||
let auth_dot_json = AuthDotJson {
|
||||
openai_api_key: None,
|
||||
tokens: Some(TokenData {
|
||||
id_token: Default::default(),
|
||||
access_token: "Access Token".to_string(),
|
||||
refresh_token: "test".to_string(),
|
||||
account_id: Some("account_id".to_string()),
|
||||
}),
|
||||
last_refresh: Some(Utc::now()),
|
||||
};
|
||||
|
||||
let auth_dot_json = Arc::new(Mutex::new(Some(auth_dot_json)));
|
||||
Self {
|
||||
api_key: None,
|
||||
mode: AuthMode::ChatGPT,
|
||||
auth_file: PathBuf::new(),
|
||||
auth_dot_json,
|
||||
client: crate::default_client::create_client(),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_api_key_with_client(api_key: &str, client: reqwest::Client) -> Self {
|
||||
Self {
|
||||
api_key: Some(api_key.to_owned()),
|
||||
mode: AuthMode::ApiKey,
|
||||
auth_file: PathBuf::new(),
|
||||
auth_dot_json: Arc::new(Mutex::new(None)),
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_api_key(api_key: &str) -> Self {
|
||||
Self::from_api_key_with_client(api_key, crate::default_client::create_client())
|
||||
}
|
||||
}
|
||||
|
||||
pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
|
||||
|
||||
pub fn read_openai_api_key_from_env() -> Option<String> {
|
||||
env::var(OPENAI_API_KEY_ENV_VAR)
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
}
|
||||
|
||||
pub fn get_auth_file(codex_home: &Path) -> PathBuf {
|
||||
codex_home.join("auth.json")
|
||||
}
|
||||
|
||||
/// Delete the auth.json file inside `codex_home` if it exists. Returns `Ok(true)`
|
||||
/// if a file was removed, `Ok(false)` if no auth file was present.
|
||||
pub fn logout(codex_home: &Path) -> std::io::Result<bool> {
|
||||
let auth_file = get_auth_file(codex_home);
|
||||
match std::fs::remove_file(&auth_file) {
|
||||
Ok(_) => Ok(true),
|
||||
Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes an `auth.json` that contains only the API key.
|
||||
pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<()> {
|
||||
let auth_dot_json = AuthDotJson {
|
||||
openai_api_key: Some(api_key.to_string()),
|
||||
tokens: None,
|
||||
last_refresh: None,
|
||||
};
|
||||
write_auth_json(&get_auth_file(codex_home), &auth_dot_json)
|
||||
}
|
||||
|
||||
fn load_auth(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
|
||||
let auth_file = get_auth_file(codex_home);
|
||||
let client = crate::default_client::create_client();
|
||||
let auth_dot_json = match try_read_auth_json(&auth_file) {
|
||||
Ok(auth) => auth,
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let AuthDotJson {
|
||||
openai_api_key: auth_json_api_key,
|
||||
tokens,
|
||||
last_refresh,
|
||||
} = auth_dot_json;
|
||||
|
||||
// Prefer AuthMode.ApiKey if it's set in the auth.json.
|
||||
if let Some(api_key) = &auth_json_api_key {
|
||||
return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
|
||||
}
|
||||
|
||||
Ok(Some(CodexAuth {
|
||||
api_key: None,
|
||||
mode: AuthMode::ChatGPT,
|
||||
auth_file,
|
||||
auth_dot_json: Arc::new(Mutex::new(Some(AuthDotJson {
|
||||
openai_api_key: None,
|
||||
tokens,
|
||||
last_refresh,
|
||||
}))),
|
||||
client,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Attempt to read and refresh the `auth.json` file in the given `CODEX_HOME` directory.
|
||||
/// Returns the full AuthDotJson structure after refreshing if necessary.
|
||||
pub fn try_read_auth_json(auth_file: &Path) -> std::io::Result<AuthDotJson> {
|
||||
let mut file = File::open(auth_file)?;
|
||||
let mut contents = String::new();
|
||||
file.read_to_string(&mut contents)?;
|
||||
let auth_dot_json: AuthDotJson = serde_json::from_str(&contents)?;
|
||||
|
||||
Ok(auth_dot_json)
|
||||
}
|
||||
|
||||
pub fn write_auth_json(auth_file: &Path, auth_dot_json: &AuthDotJson) -> std::io::Result<()> {
|
||||
let json_data = serde_json::to_string_pretty(auth_dot_json)?;
|
||||
let mut options = OpenOptions::new();
|
||||
options.truncate(true).write(true).create(true);
|
||||
#[cfg(unix)]
|
||||
{
|
||||
options.mode(0o600);
|
||||
}
|
||||
let mut file = options.open(auth_file)?;
|
||||
file.write_all(json_data.as_bytes())?;
|
||||
file.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_tokens(
|
||||
auth_file: &Path,
|
||||
id_token: String,
|
||||
access_token: Option<String>,
|
||||
refresh_token: Option<String>,
|
||||
) -> std::io::Result<AuthDotJson> {
|
||||
let mut auth_dot_json = try_read_auth_json(auth_file)?;
|
||||
|
||||
let tokens = auth_dot_json.tokens.get_or_insert_with(TokenData::default);
|
||||
tokens.id_token = parse_id_token(&id_token).map_err(std::io::Error::other)?;
|
||||
if let Some(access_token) = access_token {
|
||||
tokens.access_token = access_token;
|
||||
}
|
||||
if let Some(refresh_token) = refresh_token {
|
||||
tokens.refresh_token = refresh_token;
|
||||
}
|
||||
auth_dot_json.last_refresh = Some(Utc::now());
|
||||
write_auth_json(auth_file, &auth_dot_json)?;
|
||||
Ok(auth_dot_json)
|
||||
}
|
||||
|
||||
async fn try_refresh_token(
|
||||
refresh_token: String,
|
||||
client: &reqwest::Client,
|
||||
) -> std::io::Result<RefreshResponse> {
|
||||
let refresh_request = RefreshRequest {
|
||||
client_id: CLIENT_ID,
|
||||
grant_type: "refresh_token",
|
||||
refresh_token,
|
||||
scope: "openid profile email",
|
||||
};
|
||||
|
||||
// Use shared client factory to include standard headers
|
||||
let response = client
|
||||
.post("https://auth.openai.com/oauth/token")
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&refresh_request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(std::io::Error::other)?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let refresh_response = response
|
||||
.json::<RefreshResponse>()
|
||||
.await
|
||||
.map_err(std::io::Error::other)?;
|
||||
Ok(refresh_response)
|
||||
} else {
|
||||
Err(std::io::Error::other(format!(
|
||||
"Failed to refresh token: {}",
|
||||
response.status()
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RefreshRequest {
|
||||
client_id: &'static str,
|
||||
grant_type: &'static str,
|
||||
refresh_token: String,
|
||||
scope: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
struct RefreshResponse {
|
||||
id_token: String,
|
||||
access_token: Option<String>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
/// Expected structure for $CODEX_HOME/auth.json.
|
||||
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
|
||||
pub struct AuthDotJson {
|
||||
#[serde(rename = "OPENAI_API_KEY")]
|
||||
pub openai_api_key: Option<String>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tokens: Option<TokenData>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub last_refresh: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
// Shared constant for token refresh (client id used for oauth token refresh flow)
|
||||
pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
|
||||
|
||||
use std::sync::RwLock;
|
||||
|
||||
/// Internal cached auth state.
|
||||
#[derive(Clone, Debug)]
|
||||
struct CachedAuth {
|
||||
auth: Option<CodexAuth>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::token_data::IdTokenInfo;
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
use base64::Engine;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use tempfile::tempdir;
|
||||
|
||||
const LAST_REFRESH: &str = "2025-08-06T20:41:36.232376Z";
|
||||
|
||||
#[tokio::test]
|
||||
async fn roundtrip_auth_dot_json() {
|
||||
let codex_home = tempdir().unwrap();
|
||||
let _ = write_auth_file(
|
||||
AuthFileParams {
|
||||
openai_api_key: None,
|
||||
chatgpt_plan_type: "pro".to_string(),
|
||||
},
|
||||
codex_home.path(),
|
||||
)
|
||||
.expect("failed to write auth file");
|
||||
|
||||
let file = get_auth_file(codex_home.path());
|
||||
let auth_dot_json = try_read_auth_json(&file).unwrap();
|
||||
write_auth_json(&file, &auth_dot_json).unwrap();
|
||||
|
||||
let same_auth_dot_json = try_read_auth_json(&file).unwrap();
|
||||
assert_eq!(auth_dot_json, same_auth_dot_json);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn login_with_api_key_overwrites_existing_auth_json() {
|
||||
let dir = tempdir().unwrap();
|
||||
let auth_path = dir.path().join("auth.json");
|
||||
let stale_auth = json!({
|
||||
"OPENAI_API_KEY": "sk-old",
|
||||
"tokens": {
|
||||
"id_token": "stale.header.payload",
|
||||
"access_token": "stale-access",
|
||||
"refresh_token": "stale-refresh",
|
||||
"account_id": "stale-acc"
|
||||
}
|
||||
});
|
||||
std::fs::write(
|
||||
&auth_path,
|
||||
serde_json::to_string_pretty(&stale_auth).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
super::login_with_api_key(dir.path(), "sk-new").expect("login_with_api_key should succeed");
|
||||
|
||||
let auth = super::try_read_auth_json(&auth_path).expect("auth.json should parse");
|
||||
assert_eq!(auth.openai_api_key.as_deref(), Some("sk-new"));
|
||||
assert!(auth.tokens.is_none(), "tokens should be cleared");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pro_account_with_no_api_key_uses_chatgpt_auth() {
|
||||
let codex_home = tempdir().unwrap();
|
||||
let fake_jwt = write_auth_file(
|
||||
AuthFileParams {
|
||||
openai_api_key: None,
|
||||
chatgpt_plan_type: "pro".to_string(),
|
||||
},
|
||||
codex_home.path(),
|
||||
)
|
||||
.expect("failed to write auth file");
|
||||
|
||||
let CodexAuth {
|
||||
api_key,
|
||||
mode,
|
||||
auth_dot_json,
|
||||
auth_file: _,
|
||||
..
|
||||
} = super::load_auth(codex_home.path()).unwrap().unwrap();
|
||||
assert_eq!(None, api_key);
|
||||
assert_eq!(AuthMode::ChatGPT, mode);
|
||||
|
||||
let guard = auth_dot_json.lock().unwrap();
|
||||
let auth_dot_json = guard.as_ref().expect("AuthDotJson should exist");
|
||||
assert_eq!(
|
||||
&AuthDotJson {
|
||||
openai_api_key: None,
|
||||
tokens: Some(TokenData {
|
||||
id_token: IdTokenInfo {
|
||||
email: Some("user@example.com".to_string()),
|
||||
chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
|
||||
raw_jwt: fake_jwt,
|
||||
},
|
||||
access_token: "test-access-token".to_string(),
|
||||
refresh_token: "test-refresh-token".to_string(),
|
||||
account_id: None,
|
||||
}),
|
||||
last_refresh: Some(
|
||||
DateTime::parse_from_rfc3339(LAST_REFRESH)
|
||||
.unwrap()
|
||||
.with_timezone(&Utc)
|
||||
),
|
||||
},
|
||||
auth_dot_json
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn loads_api_key_from_auth_json() {
|
||||
let dir = tempdir().unwrap();
|
||||
let auth_file = dir.path().join("auth.json");
|
||||
std::fs::write(
|
||||
auth_file,
|
||||
r#"{"OPENAI_API_KEY":"sk-test-key","tokens":null,"last_refresh":null}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let auth = super::load_auth(dir.path()).unwrap().unwrap();
|
||||
assert_eq!(auth.mode, AuthMode::ApiKey);
|
||||
assert_eq!(auth.api_key, Some("sk-test-key".to_string()));
|
||||
|
||||
assert!(auth.get_token_data().await.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn logout_removes_auth_file() -> Result<(), std::io::Error> {
|
||||
let dir = tempdir()?;
|
||||
let auth_dot_json = AuthDotJson {
|
||||
openai_api_key: Some("sk-test-key".to_string()),
|
||||
tokens: None,
|
||||
last_refresh: None,
|
||||
};
|
||||
write_auth_json(&get_auth_file(dir.path()), &auth_dot_json)?;
|
||||
assert!(dir.path().join("auth.json").exists());
|
||||
let removed = logout(dir.path())?;
|
||||
assert!(removed);
|
||||
assert!(!dir.path().join("auth.json").exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct AuthFileParams {
|
||||
openai_api_key: Option<String>,
|
||||
chatgpt_plan_type: String,
|
||||
}
|
||||
|
||||
fn write_auth_file(params: AuthFileParams, codex_home: &Path) -> std::io::Result<String> {
|
||||
let auth_file = get_auth_file(codex_home);
|
||||
// Create a minimal valid JWT for the id_token field.
|
||||
#[derive(Serialize)]
|
||||
struct Header {
|
||||
alg: &'static str,
|
||||
typ: &'static str,
|
||||
}
|
||||
let header = Header {
|
||||
alg: "none",
|
||||
typ: "JWT",
|
||||
};
|
||||
let payload = serde_json::json!({
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
"https://api.openai.com/auth": {
|
||||
"chatgpt_account_id": "bc3618e3-489d-4d49-9362-1561dc53ba53",
|
||||
"chatgpt_plan_type": params.chatgpt_plan_type,
|
||||
"chatgpt_user_id": "user-12345",
|
||||
"user_id": "user-12345",
|
||||
}
|
||||
});
|
||||
let b64 = |b: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b);
|
||||
let header_b64 = b64(&serde_json::to_vec(&header)?);
|
||||
let payload_b64 = b64(&serde_json::to_vec(&payload)?);
|
||||
let signature_b64 = b64(b"sig");
|
||||
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
|
||||
|
||||
let auth_json_data = json!({
|
||||
"OPENAI_API_KEY": params.openai_api_key,
|
||||
"tokens": {
|
||||
"id_token": fake_jwt,
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh-token"
|
||||
},
|
||||
"last_refresh": LAST_REFRESH,
|
||||
});
|
||||
let auth_json = serde_json::to_string_pretty(&auth_json_data)?;
|
||||
std::fs::write(auth_file, auth_json)?;
|
||||
Ok(fake_jwt)
|
||||
}
|
||||
}
|
||||
|
||||
/// Central manager providing a single source of truth for auth.json derived
|
||||
/// authentication data. It loads once (or on preference change) and then
|
||||
/// hands out cloned `CodexAuth` values so the rest of the program has a
|
||||
/// consistent snapshot.
|
||||
///
|
||||
/// External modifications to `auth.json` will NOT be observed until
|
||||
/// `reload()` is called explicitly. This matches the design goal of avoiding
|
||||
/// different parts of the program seeing inconsistent auth data mid‑run.
|
||||
#[derive(Debug)]
|
||||
pub struct AuthManager {
|
||||
codex_home: PathBuf,
|
||||
inner: RwLock<CachedAuth>,
|
||||
}
|
||||
|
||||
impl AuthManager {
|
||||
/// Create a new manager loading the initial auth using the provided
|
||||
/// preferred auth method. Errors loading auth are swallowed; `auth()` will
|
||||
/// simply return `None` in that case so callers can treat it as an
|
||||
/// unauthenticated state.
|
||||
pub fn new(codex_home: PathBuf) -> Self {
|
||||
let auth = CodexAuth::from_codex_home(&codex_home).ok().flatten();
|
||||
Self {
|
||||
codex_home,
|
||||
inner: RwLock::new(CachedAuth { auth }),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an AuthManager with a specific CodexAuth, for testing only.
|
||||
pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
|
||||
let cached = CachedAuth { auth: Some(auth) };
|
||||
Arc::new(Self {
|
||||
codex_home: PathBuf::new(),
|
||||
inner: RwLock::new(cached),
|
||||
})
|
||||
}
|
||||
|
||||
/// Current cached auth (clone). May be `None` if not logged in or load failed.
|
||||
pub fn auth(&self) -> Option<CodexAuth> {
|
||||
self.inner.read().ok().and_then(|c| c.auth.clone())
|
||||
}
|
||||
|
||||
/// Force a reload of the auth information from auth.json. Returns
|
||||
/// whether the auth value changed.
|
||||
pub fn reload(&self) -> bool {
|
||||
let new_auth = CodexAuth::from_codex_home(&self.codex_home).ok().flatten();
|
||||
if let Ok(mut guard) = self.inner.write() {
|
||||
let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
|
||||
guard.auth = new_auth;
|
||||
changed
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn auths_equal(a: &Option<CodexAuth>, b: &Option<CodexAuth>) -> bool {
|
||||
match (a, b) {
|
||||
(None, None) => true,
|
||||
(Some(a), Some(b)) => a == b,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience constructor returning an `Arc` wrapper.
|
||||
pub fn shared(codex_home: PathBuf) -> Arc<Self> {
|
||||
Arc::new(Self::new(codex_home))
|
||||
}
|
||||
|
||||
/// Attempt to refresh the current auth token (if any). On success, reload
|
||||
/// the auth state from disk so other components observe refreshed token.
|
||||
pub async fn refresh_token(&self) -> std::io::Result<Option<String>> {
|
||||
let auth = match self.auth() {
|
||||
Some(a) => a,
|
||||
None => return Ok(None),
|
||||
};
|
||||
match auth.refresh_token().await {
|
||||
Ok(token) => {
|
||||
// Reload to pick up persisted changes.
|
||||
self.reload();
|
||||
Ok(Some(token))
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Log out by deleting the on‑disk auth.json (if present). Returns Ok(true)
|
||||
/// if a file was removed, Ok(false) if no auth file existed. On success,
|
||||
/// reloads the in‑memory auth cache so callers immediately observe the
|
||||
/// unauthenticated state.
|
||||
pub fn logout(&self) -> std::io::Result<bool> {
|
||||
let removed = super::auth::logout(&self.codex_home)?;
|
||||
// Always reload to clear any cached auth (even if file absent).
|
||||
self.reload();
|
||||
Ok(removed)
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
use tree_sitter::Node;
|
||||
use tree_sitter::Parser;
|
||||
use tree_sitter::Tree;
|
||||
use tree_sitter_bash::LANGUAGE as BASH;
|
||||
@@ -73,6 +74,9 @@ pub fn try_parse_word_only_commands_sequence(tree: &Tree, src: &str) -> Option<V
|
||||
}
|
||||
}
|
||||
|
||||
// Walk uses a stack (LIFO), so re-sort by position to restore source order.
|
||||
command_nodes.sort_by_key(Node::start_byte);
|
||||
|
||||
let mut commands = Vec::new();
|
||||
for node in command_nodes {
|
||||
if let Some(words) = parse_plain_command_from_node(node, src) {
|
||||
@@ -150,10 +154,10 @@ mod tests {
|
||||
let src = "ls && pwd; echo 'hi there' | wc -l";
|
||||
let cmds = parse_seq(src).unwrap();
|
||||
let expected: Vec<Vec<String>> = vec![
|
||||
vec!["wc".to_string(), "-l".to_string()],
|
||||
vec!["echo".to_string(), "hi there".to_string()],
|
||||
vec!["pwd".to_string()],
|
||||
vec!["ls".to_string()],
|
||||
vec!["pwd".to_string()],
|
||||
vec!["echo".to_string(), "hi there".to_string()],
|
||||
vec!["wc".to_string(), "-l".to_string()],
|
||||
];
|
||||
assert_eq!(cmds, expected);
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ use crate::client_common::ResponseEvent;
|
||||
use crate::client_common::ResponseStream;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::error_codes::CONTEXT_LENGTH_EXCEEDED;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::openai_tools::create_tools_json_for_chat_completions_api;
|
||||
use crate::util::backoff;
|
||||
@@ -28,6 +29,19 @@ use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
// Minimal error body used to parse structured provider error codes on
|
||||
// Chat Completions non‑2xx responses.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ChatErrorBody {
|
||||
error: ChatErrorInner,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ChatErrorInner {
|
||||
code: Option<String>,
|
||||
message: Option<String>,
|
||||
}
|
||||
|
||||
/// Implementation for the classic Chat Completions API.
|
||||
pub(crate) async fn stream_chat_completions(
|
||||
prompt: &Prompt,
|
||||
@@ -43,7 +57,107 @@ pub(crate) async fn stream_chat_completions(
|
||||
|
||||
let input = prompt.get_formatted_input();
|
||||
|
||||
// Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
|
||||
// - If the last emitted message is a user message, drop all reasoning.
|
||||
// - Otherwise, for each Reasoning item after the last user message, attach it
|
||||
// to the immediate previous assistant message (stop turns) or the immediate
|
||||
// next assistant anchor (tool-call turns: function/local shell call, or assistant message).
|
||||
let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
// Determine the last role that would be emitted to Chat Completions.
|
||||
let mut last_emitted_role: Option<&str> = None;
|
||||
for item in &input {
|
||||
match item {
|
||||
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
|
||||
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
|
||||
last_emitted_role = Some("assistant")
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
ResponseItem::CustomToolCall { .. } => {}
|
||||
ResponseItem::CustomToolCallOutput { .. } => {}
|
||||
ResponseItem::WebSearchCall { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Find the last user message index in the input.
|
||||
let mut last_user_index: Option<usize> = None;
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
if let ResponseItem::Message { role, .. } = item
|
||||
&& role == "user"
|
||||
{
|
||||
last_user_index = Some(idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Attach reasoning only if the conversation does not end with a user message.
|
||||
if !matches!(last_emitted_role, Some("user")) {
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
// Only consider reasoning that appears after the last user message.
|
||||
if let Some(u_idx) = last_user_index
|
||||
&& idx <= u_idx
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let ResponseItem::Reasoning {
|
||||
content: Some(items),
|
||||
..
|
||||
} = item
|
||||
{
|
||||
let mut text = String::new();
|
||||
for c in items {
|
||||
match c {
|
||||
ReasoningItemContent::ReasoningText { text: t }
|
||||
| ReasoningItemContent::Text { text: t } => text.push_str(t),
|
||||
}
|
||||
}
|
||||
if text.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Prefer immediate previous assistant message (stop turns)
|
||||
let mut attached = false;
|
||||
if idx > 0
|
||||
&& let ResponseItem::Message { role, .. } = &input[idx - 1]
|
||||
&& role == "assistant"
|
||||
{
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx - 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
attached = true;
|
||||
}
|
||||
|
||||
// Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
|
||||
if !attached && idx + 1 < input.len() {
|
||||
match &input[idx + 1] {
|
||||
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
ResponseItem::Message { role, .. } if role == "assistant" => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track last assistant text we emitted to avoid duplicate assistant messages
|
||||
// in the outbound Chat Completions payload (can happen if a final
|
||||
// aggregated assistant message was recorded alongside an earlier partial).
|
||||
let mut last_assistant_text: Option<String> = None;
|
||||
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
let mut text = String::new();
|
||||
@@ -56,7 +170,24 @@ pub(crate) async fn stream_chat_completions(
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
messages.push(json!({"role": role, "content": text}));
|
||||
// Skip exact-duplicate assistant messages.
|
||||
if role == "assistant" {
|
||||
if let Some(prev) = &last_assistant_text
|
||||
&& prev == &text
|
||||
{
|
||||
continue;
|
||||
}
|
||||
last_assistant_text = Some(text.clone());
|
||||
}
|
||||
|
||||
let mut msg = json!({"role": role, "content": text});
|
||||
if role == "assistant"
|
||||
&& let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCall {
|
||||
name,
|
||||
@@ -64,7 +195,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
messages.push(json!({
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
@@ -75,7 +206,13 @@ pub(crate) async fn stream_chat_completions(
|
||||
"arguments": arguments,
|
||||
}
|
||||
}]
|
||||
}));
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
@@ -84,7 +221,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
action,
|
||||
} => {
|
||||
// Confirm with API team.
|
||||
messages.push(json!({
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
@@ -93,7 +230,13 @@ pub(crate) async fn stream_chat_completions(
|
||||
"status": status,
|
||||
"action": action,
|
||||
}]
|
||||
}));
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { call_id, output } => {
|
||||
messages.push(json!({
|
||||
@@ -180,6 +323,16 @@ pub(crate) async fn stream_chat_completions(
|
||||
let status = res.status();
|
||||
if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
|
||||
let body = (res.text().await).unwrap_or_default();
|
||||
|
||||
// Attempt to parse a structured error and map known codes.
|
||||
if let Ok(parsed) = serde_json::from_str::<ChatErrorBody>(&body)
|
||||
&& parsed.error.code.as_deref() == Some(CONTEXT_LENGTH_EXCEEDED)
|
||||
{
|
||||
return Err(CodexErr::ContextLengthExceeded(
|
||||
parsed.error.message.unwrap_or_default(),
|
||||
));
|
||||
}
|
||||
|
||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
||||
}
|
||||
|
||||
@@ -331,7 +484,10 @@ async fn process_chat_sse<S>(
|
||||
// Some providers stream `reasoning` as a plain string while others
|
||||
// nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
|
||||
if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
|
||||
let mut maybe_text = reasoning_val.as_str().map(|s| s.to_string());
|
||||
let mut maybe_text = reasoning_val
|
||||
.as_str()
|
||||
.map(str::to_string)
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
if maybe_text.is_none() && reasoning_val.is_object() {
|
||||
if let Some(s) = reasoning_val
|
||||
@@ -350,12 +506,39 @@ async fn process_chat_sse<S>(
|
||||
}
|
||||
|
||||
if let Some(reasoning) = maybe_text {
|
||||
// Accumulate so we can emit a terminal Reasoning item at the end.
|
||||
reasoning_text.push_str(&reasoning);
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta(reasoning)))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// Some providers only include reasoning on the final message object.
|
||||
if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
|
||||
{
|
||||
// Accept either a plain string or an object with { text | content }
|
||||
if let Some(s) = message_reasoning.as_str() {
|
||||
if !s.is_empty() {
|
||||
reasoning_text.push_str(s);
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
|
||||
.await;
|
||||
}
|
||||
} else if let Some(obj) = message_reasoning.as_object()
|
||||
&& let Some(s) = obj
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.or_else(|| obj.get("content").and_then(|v| v.as_str()))
|
||||
&& !s.is_empty()
|
||||
{
|
||||
reasoning_text.push_str(s);
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle streaming function / tool calls.
|
||||
if let Some(tool_calls) = choice
|
||||
.get("delta")
|
||||
@@ -511,32 +694,55 @@ where
|
||||
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
|
||||
// away so downstream consumers see it.
|
||||
|
||||
let is_assistant_delta = matches!(&item, codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant");
|
||||
let is_assistant_message = matches!(
|
||||
&item,
|
||||
codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"
|
||||
);
|
||||
|
||||
if is_assistant_delta {
|
||||
// Only use the final assistant message if we have not
|
||||
// seen any deltas; otherwise, deltas already built the
|
||||
// cumulative text and this would duplicate it.
|
||||
if this.cumulative.is_empty()
|
||||
&& let codex_protocol::models::ResponseItem::Message { content, .. } =
|
||||
&item
|
||||
&& let Some(text) = content.iter().find_map(|c| match c {
|
||||
codex_protocol::models::ContentItem::OutputText { text } => {
|
||||
Some(text)
|
||||
if is_assistant_message {
|
||||
match this.mode {
|
||||
AggregateMode::AggregatedOnly => {
|
||||
// Only use the final assistant message if we have not
|
||||
// seen any deltas; otherwise, deltas already built the
|
||||
// cumulative text and this would duplicate it.
|
||||
if this.cumulative.is_empty()
|
||||
&& let codex_protocol::models::ResponseItem::Message {
|
||||
content,
|
||||
..
|
||||
} = &item
|
||||
&& let Some(text) = content.iter().find_map(|c| match c {
|
||||
codex_protocol::models::ContentItem::OutputText {
|
||||
text,
|
||||
} => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
this.cumulative.push_str(text);
|
||||
// Swallow assistant message here; emit on Completed.
|
||||
continue;
|
||||
}
|
||||
AggregateMode::Streaming => {
|
||||
// In streaming mode, if we have not seen any deltas, forward
|
||||
// the final assistant message directly. If deltas were seen,
|
||||
// suppress the final message to avoid duplication.
|
||||
if this.cumulative.is_empty() {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
||||
item,
|
||||
))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Swallow assistant message here; emit on Completed.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Not an assistant message – forward immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
@@ -563,6 +769,11 @@ where
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
// Always emit the final aggregated assistant message when any
|
||||
// content deltas have been observed. In AggregatedOnly mode this
|
||||
// is the sole assistant output; in Streaming mode this finalizes
|
||||
// the streamed deltas into a terminal OutputItemDone so callers
|
||||
// can persist/render the message once per turn.
|
||||
if !this.cumulative.is_empty() {
|
||||
let aggregated_message = codex_protocol::models::ResponseItem::Message {
|
||||
id: None,
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::auth::CodexAuth;
|
||||
use crate::error_codes::CONTEXT_LENGTH_EXCEEDED;
|
||||
use bytes::Bytes;
|
||||
use codex_login::AuthManager;
|
||||
use codex_login::AuthMode;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::prelude::*;
|
||||
use regex_lite::Regex;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
@@ -17,7 +23,6 @@ use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::chat_completions::AggregateStreamExt;
|
||||
use crate::chat_completions::stream_chat_completions;
|
||||
@@ -28,6 +33,7 @@ use crate::client_common::ResponsesApiRequest;
|
||||
use crate::client_common::create_reasoning_param_for_request;
|
||||
use crate::client_common::create_text_param_for_request;
|
||||
use crate::config::Config;
|
||||
use crate::default_client::create_client;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::error::UsageLimitReachedError;
|
||||
@@ -37,8 +43,9 @@ use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
use crate::openai_tools::create_tools_json_for_responses_api;
|
||||
use crate::protocol::RateLimitSnapshotEvent;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::user_agent::get_codex_user_agent;
|
||||
use crate::token_data::PlanType;
|
||||
use crate::util::backoff;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
@@ -53,10 +60,12 @@ struct ErrorResponse {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Error {
|
||||
r#type: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
code: Option<String>,
|
||||
message: Option<String>,
|
||||
|
||||
// Optional fields available on "usage_limit_reached" and "usage_not_included" errors
|
||||
plan_type: Option<String>,
|
||||
plan_type: Option<PlanType>,
|
||||
resets_in_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
@@ -66,8 +75,8 @@ pub struct ModelClient {
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
client: reqwest::Client,
|
||||
provider: ModelProviderInfo,
|
||||
session_id: Uuid,
|
||||
effort: ReasoningEffortConfig,
|
||||
conversation_id: ConversationId,
|
||||
effort: Option<ReasoningEffortConfig>,
|
||||
summary: ReasoningSummaryConfig,
|
||||
}
|
||||
|
||||
@@ -76,16 +85,18 @@ impl ModelClient {
|
||||
config: Arc<Config>,
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
provider: ModelProviderInfo,
|
||||
effort: ReasoningEffortConfig,
|
||||
effort: Option<ReasoningEffortConfig>,
|
||||
summary: ReasoningSummaryConfig,
|
||||
session_id: Uuid,
|
||||
conversation_id: ConversationId,
|
||||
) -> Self {
|
||||
let client = create_client();
|
||||
|
||||
Self {
|
||||
config,
|
||||
auth_manager,
|
||||
client: reqwest::Client::new(),
|
||||
client,
|
||||
provider,
|
||||
session_id,
|
||||
conversation_id,
|
||||
effort,
|
||||
summary,
|
||||
}
|
||||
@@ -97,6 +108,12 @@ impl ModelClient {
|
||||
.or_else(|| get_model_info(&self.config.model_family).map(|info| info.context_window))
|
||||
}
|
||||
|
||||
pub fn get_auto_compact_token_limit(&self) -> Option<i64> {
|
||||
self.config.model_auto_compact_token_limit.or_else(|| {
|
||||
get_model_info(&self.config.model_family).and_then(|info| info.auto_compact_token_limit)
|
||||
})
|
||||
}
|
||||
|
||||
/// Dispatches to either the Responses or Chat implementation depending on
|
||||
/// the provider config. Public callers always invoke `stream()` – the
|
||||
/// specialised helpers are private to avoid accidental misuse.
|
||||
@@ -151,14 +168,6 @@ impl ModelClient {
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
|
||||
let auth_mode = auth_manager
|
||||
.as_ref()
|
||||
.and_then(|m| m.auth())
|
||||
.as_ref()
|
||||
.map(|a| a.mode);
|
||||
|
||||
let store = prompt.store && auth_mode != Some(AuthMode::ChatGPT);
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(&self.config.model_family);
|
||||
let tools_json = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
let reasoning = create_reasoning_param_for_request(
|
||||
@@ -167,9 +176,7 @@ impl ModelClient {
|
||||
self.summary,
|
||||
);
|
||||
|
||||
// Request encrypted COT if we are not storing responses,
|
||||
// otherwise reasoning items will be referenced by ID
|
||||
let include: Vec<String> = if !store && reasoning.is_some() {
|
||||
let include: Vec<String> = if reasoning.is_some() {
|
||||
vec!["reasoning.encrypted_content".to_string()]
|
||||
} else {
|
||||
vec![]
|
||||
@@ -190,6 +197,15 @@ impl ModelClient {
|
||||
None
|
||||
};
|
||||
|
||||
// In general, we want to explicitly send `store: false` when using the Responses API,
|
||||
// but in practice, the Azure Responses API rejects `store: false`:
|
||||
//
|
||||
// - If store = false and id is sent an error is thrown that ID is not found
|
||||
// - If store = false and id is not sent an error is thrown that ID is required
|
||||
//
|
||||
// For Azure, we send `store: true` and preserve reasoning item IDs.
|
||||
let azure_workaround = self.provider.is_azure_responses_endpoint();
|
||||
|
||||
let payload = ResponsesApiRequest {
|
||||
model: &self.config.model,
|
||||
instructions: &full_instructions,
|
||||
@@ -198,13 +214,19 @@ impl ModelClient {
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning,
|
||||
store,
|
||||
store: azure_workaround,
|
||||
stream: true,
|
||||
include,
|
||||
prompt_cache_key: Some(self.session_id.to_string()),
|
||||
prompt_cache_key: Some(self.conversation_id.to_string()),
|
||||
text,
|
||||
};
|
||||
|
||||
let mut payload_json = serde_json::to_value(&payload)?;
|
||||
if azure_workaround {
|
||||
attach_item_ids(&mut payload_json, &input_with_instructions);
|
||||
}
|
||||
let payload_body = serde_json::to_string(&payload_json)?;
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = self.provider.request_max_retries();
|
||||
|
||||
@@ -217,7 +239,7 @@ impl ModelClient {
|
||||
trace!(
|
||||
"POST to {}: {}",
|
||||
self.provider.get_full_url(&auth),
|
||||
serde_json::to_string(&payload)?
|
||||
payload_body.as_str()
|
||||
);
|
||||
|
||||
let mut req_builder = self
|
||||
@@ -227,9 +249,11 @@ impl ModelClient {
|
||||
|
||||
req_builder = req_builder
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
.header("session_id", self.session_id.to_string())
|
||||
// Send session_id for compatibility.
|
||||
.header("conversation_id", self.conversation_id.to_string())
|
||||
.header("session_id", self.conversation_id.to_string())
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload);
|
||||
.json(&payload_json);
|
||||
|
||||
if let Some(auth) = auth.as_ref()
|
||||
&& auth.mode == AuthMode::ChatGPT
|
||||
@@ -238,17 +262,13 @@ impl ModelClient {
|
||||
req_builder = req_builder.header("chatgpt-account-id", account_id);
|
||||
}
|
||||
|
||||
let originator = &self.config.responses_originator_header;
|
||||
req_builder = req_builder.header("originator", originator);
|
||||
req_builder = req_builder.header("User-Agent", get_codex_user_agent(Some(originator)));
|
||||
|
||||
let res = req_builder.send().await;
|
||||
if let Ok(resp) = &res {
|
||||
trace!(
|
||||
"Response status: {}, request-id: {}",
|
||||
"Response status: {}, cf-ray: {}",
|
||||
resp.status(),
|
||||
resp.headers()
|
||||
.get("x-request-id")
|
||||
.get("cf-ray")
|
||||
.map(|v| v.to_str().unwrap_or_default())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
@@ -258,6 +278,15 @@ impl ModelClient {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
|
||||
if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers())
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::RateLimits(snapshot)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
debug!("receiver dropped rate limit snapshot event");
|
||||
}
|
||||
|
||||
// spawn task to process SSE
|
||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||
tokio::spawn(process_sse(
|
||||
@@ -310,7 +339,7 @@ impl ModelClient {
|
||||
// token.
|
||||
let plan_type = error
|
||||
.plan_type
|
||||
.or_else(|| auth.and_then(|a| a.get_plan_type()));
|
||||
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
|
||||
let resets_in_seconds = error.resets_in_seconds;
|
||||
return Err(CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||
plan_type,
|
||||
@@ -361,7 +390,7 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
/// Returns the current reasoning effort setting.
|
||||
pub fn get_reasoning_effort(&self) -> ReasoningEffortConfig {
|
||||
pub fn get_reasoning_effort(&self) -> Option<ReasoningEffortConfig> {
|
||||
self.effort
|
||||
}
|
||||
|
||||
@@ -406,9 +435,15 @@ impl From<ResponseCompletedUsage> for TokenUsage {
|
||||
fn from(val: ResponseCompletedUsage) -> Self {
|
||||
TokenUsage {
|
||||
input_tokens: val.input_tokens,
|
||||
cached_input_tokens: val.input_tokens_details.map(|d| d.cached_tokens),
|
||||
cached_input_tokens: val
|
||||
.input_tokens_details
|
||||
.map(|d| d.cached_tokens)
|
||||
.unwrap_or(0),
|
||||
output_tokens: val.output_tokens,
|
||||
reasoning_output_tokens: val.output_tokens_details.map(|d| d.reasoning_tokens),
|
||||
reasoning_output_tokens: val
|
||||
.output_tokens_details
|
||||
.map(|d| d.reasoning_tokens)
|
||||
.unwrap_or(0),
|
||||
total_tokens: val.total_tokens,
|
||||
}
|
||||
}
|
||||
@@ -424,6 +459,65 @@ struct ResponseCompletedOutputTokensDetails {
|
||||
reasoning_tokens: u64,
|
||||
}
|
||||
|
||||
fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) {
|
||||
let Some(input_value) = payload_json.get_mut("input") else {
|
||||
return;
|
||||
};
|
||||
let serde_json::Value::Array(items) = input_value else {
|
||||
return;
|
||||
};
|
||||
|
||||
for (value, item) in items.iter_mut().zip(original_items.iter()) {
|
||||
if let ResponseItem::Reasoning { id, .. }
|
||||
| ResponseItem::Message { id: Some(id), .. }
|
||||
| ResponseItem::WebSearchCall { id: Some(id), .. }
|
||||
| ResponseItem::FunctionCall { id: Some(id), .. }
|
||||
| ResponseItem::LocalShellCall { id: Some(id), .. }
|
||||
| ResponseItem::CustomToolCall { id: Some(id), .. } = item
|
||||
{
|
||||
if id.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(obj) = value.as_object_mut() {
|
||||
obj.insert("id".to_string(), Value::String(id.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshotEvent> {
|
||||
let primary_used_percent = parse_header_f64(headers, "x-codex-primary-used-percent")?;
|
||||
let secondary_used_percent = parse_header_f64(headers, "x-codex-secondary-used-percent")?;
|
||||
let primary_to_secondary_ratio_percent =
|
||||
parse_header_f64(headers, "x-codex-primary-over-secondary-limit-percent")?;
|
||||
let primary_window_minutes = parse_header_u64(headers, "x-codex-primary-window-minutes")?;
|
||||
let secondary_window_minutes = parse_header_u64(headers, "x-codex-secondary-window-minutes")?;
|
||||
|
||||
Some(RateLimitSnapshotEvent {
|
||||
primary_used_percent,
|
||||
secondary_used_percent,
|
||||
primary_to_secondary_ratio_percent,
|
||||
primary_window_minutes,
|
||||
secondary_window_minutes,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
|
||||
parse_header_str(headers, name)?
|
||||
.parse::<f64>()
|
||||
.ok()
|
||||
.filter(|v| v.is_finite())
|
||||
}
|
||||
|
||||
fn parse_header_u64(headers: &HeaderMap, name: &str) -> Option<u64> {
|
||||
parse_header_str(headers, name)?.parse::<u64>().ok()
|
||||
}
|
||||
|
||||
fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
|
||||
headers.get(name)?.to_str().ok()
|
||||
}
|
||||
|
||||
async fn process_sse<S>(
|
||||
stream: S,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent>>,
|
||||
@@ -564,8 +658,13 @@ async fn process_sse<S>(
|
||||
if let Some(error) = error {
|
||||
match serde_json::from_value::<Error>(error.clone()) {
|
||||
Ok(error) => {
|
||||
let delay = try_parse_retry_after(&error);
|
||||
let message = error.message.unwrap_or_default();
|
||||
response_error = Some(CodexErr::Stream(message, None));
|
||||
if error.code.as_deref() == Some(CONTEXT_LENGTH_EXCEEDED) {
|
||||
response_error = Some(CodexErr::ContextLengthExceeded(message));
|
||||
} else {
|
||||
response_error = Some(CodexErr::Stream(message, delay));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("failed to parse ErrorResponse: {e}");
|
||||
@@ -651,6 +750,40 @@ async fn stream_from_fixture(
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
|
||||
fn rate_limit_regex() -> &'static Regex {
|
||||
static RE: OnceLock<Regex> = OnceLock::new();
|
||||
|
||||
#[expect(clippy::unwrap_used)]
|
||||
RE.get_or_init(|| Regex::new(r"Please try again in (\d+(?:\.\d+)?)(s|ms)").unwrap())
|
||||
}
|
||||
|
||||
fn try_parse_retry_after(err: &Error) -> Option<Duration> {
|
||||
if err.code != Some("rate_limit_exceeded".to_string()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// parse the Please try again in 1.898s format using regex
|
||||
let re = rate_limit_regex();
|
||||
if let Some(message) = &err.message
|
||||
&& let Some(captures) = re.captures(message)
|
||||
{
|
||||
let seconds = captures.get(1);
|
||||
let unit = captures.get(2);
|
||||
|
||||
if let (Some(value), Some(unit)) = (seconds, unit) {
|
||||
let value = value.as_str().parse::<f64>().ok()?;
|
||||
let unit = unit.as_str();
|
||||
|
||||
if unit == "s" {
|
||||
return Some(Duration::from_secs_f64(value));
|
||||
} else if unit == "ms" {
|
||||
return Some(Duration::from_millis(value as u64));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -871,7 +1004,7 @@ mod tests {
|
||||
msg,
|
||||
"Rate limit reached for gpt-5 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."
|
||||
);
|
||||
assert_eq!(*delay, None);
|
||||
assert_eq!(*delay, Some(Duration::from_secs_f64(11.054)));
|
||||
}
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
}
|
||||
@@ -975,4 +1108,64 @@ mod tests {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after() {
|
||||
let err = Error {
|
||||
r#type: None,
|
||||
message: Some("Rate limit reached for gpt-5 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
|
||||
code: Some("rate_limit_exceeded".to_string()),
|
||||
plan_type: None,
|
||||
resets_in_seconds: None
|
||||
};
|
||||
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_millis(28)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after_no_delay() {
|
||||
let err = Error {
|
||||
r#type: None,
|
||||
message: Some("Rate limit reached for gpt-5 in organization <ORG> on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
|
||||
code: Some("rate_limit_exceeded".to_string()),
|
||||
plan_type: None,
|
||||
resets_in_seconds: None
|
||||
};
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_response_deserializes_old_schema_known_plan_type_and_serializes_back() {
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
|
||||
let json = r#"{"error":{"type":"usage_limit_reached","plan_type":"pro","resets_in_seconds":3600}}"#;
|
||||
let resp: ErrorResponse =
|
||||
serde_json::from_str(json).expect("should deserialize old schema");
|
||||
|
||||
assert!(matches!(
|
||||
resp.error.plan_type,
|
||||
Some(PlanType::Known(KnownPlan::Pro))
|
||||
));
|
||||
|
||||
let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
|
||||
assert_eq!(plan_json, "\"pro\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_response_deserializes_old_schema_unknown_plan_type_and_serializes_back() {
|
||||
use crate::token_data::PlanType;
|
||||
|
||||
let json =
|
||||
r#"{"error":{"type":"usage_limit_reached","plan_type":"vip","resets_in_seconds":60}}"#;
|
||||
let resp: ErrorResponse =
|
||||
serde_json::from_str(json).expect("should deserialize old schema");
|
||||
|
||||
assert!(matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip"));
|
||||
|
||||
let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
|
||||
assert_eq!(plan_json, "\"vip\"");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +1,24 @@
|
||||
use crate::config_types::Verbosity as VerbosityConfig;
|
||||
use crate::error::Result;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::openai_tools::OpenAiTool;
|
||||
use crate::protocol::RateLimitSnapshotEvent;
|
||||
use crate::protocol::TokenUsage;
|
||||
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::config_types::Verbosity as VerbosityConfig;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use futures::Stream;
|
||||
use serde::Serialize;
|
||||
use std::borrow::Cow;
|
||||
use std::ops::Deref;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// The `instructions` field in the payload sent to a model should always start
|
||||
/// with this content.
|
||||
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
|
||||
|
||||
/// wraps user instructions message in a tag for the model to parse more easily.
|
||||
const USER_INSTRUCTIONS_START: &str = "<user_instructions>\n\n";
|
||||
const USER_INSTRUCTIONS_END: &str = "\n\n</user_instructions>";
|
||||
/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
|
||||
pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");
|
||||
|
||||
/// API request payload for a single model turn
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -30,27 +26,23 @@ pub struct Prompt {
|
||||
/// Conversation context input items.
|
||||
pub input: Vec<ResponseItem>,
|
||||
|
||||
/// Whether to store response on server side (disable_response_storage = !store).
|
||||
pub store: bool,
|
||||
|
||||
/// Tools available to the model, including additional tools sourced from
|
||||
/// external MCP servers.
|
||||
pub tools: Vec<OpenAiTool>,
|
||||
pub(crate) tools: Vec<OpenAiTool>,
|
||||
|
||||
/// Optional override for the built-in BASE_INSTRUCTIONS.
|
||||
pub base_instructions_override: Option<String>,
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) fn get_full_instructions(&self, model: &ModelFamily) -> Cow<'_, str> {
|
||||
pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
|
||||
let base = self
|
||||
.base_instructions_override
|
||||
.as_deref()
|
||||
.unwrap_or(BASE_INSTRUCTIONS);
|
||||
let mut sections: Vec<&str> = vec![base];
|
||||
|
||||
// When there are no custom instructions, add apply_patch_tool_instructions if either:
|
||||
// - the model needs special instructions (4.1), or
|
||||
.unwrap_or(model.base_instructions.deref());
|
||||
// When there are no custom instructions, add apply_patch_tool_instructions if:
|
||||
// - the model needs special instructions (4.1)
|
||||
// AND
|
||||
// - there is no apply_patch tool present
|
||||
let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
|
||||
OpenAiTool::Function(f) => f.name == "apply_patch",
|
||||
@@ -58,27 +50,18 @@ impl Prompt {
|
||||
_ => false,
|
||||
});
|
||||
if self.base_instructions_override.is_none()
|
||||
&& (model.needs_special_apply_patch_instructions || !is_apply_patch_tool_present)
|
||||
&& model.needs_special_apply_patch_instructions
|
||||
&& !is_apply_patch_tool_present
|
||||
{
|
||||
sections.push(APPLY_PATCH_TOOL_INSTRUCTIONS);
|
||||
Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"))
|
||||
} else {
|
||||
Cow::Borrowed(base)
|
||||
}
|
||||
Cow::Owned(sections.join("\n"))
|
||||
}
|
||||
|
||||
pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
|
||||
self.input.clone()
|
||||
}
|
||||
|
||||
/// Creates a formatted user instructions message from a string
|
||||
pub(crate) fn format_user_instructions_message(ui: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: format!("{USER_INSTRUCTIONS_START}{ui}{USER_INSTRUCTIONS_END}"),
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -96,12 +79,15 @@ pub enum ResponseEvent {
|
||||
WebSearchCallBegin {
|
||||
call_id: String,
|
||||
},
|
||||
RateLimits(RateLimitSnapshotEvent),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(crate) struct Reasoning {
|
||||
pub(crate) effort: ReasoningEffortConfig,
|
||||
pub(crate) summary: ReasoningSummaryConfig,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) effort: Option<ReasoningEffortConfig>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) summary: Option<ReasoningSummaryConfig>,
|
||||
}
|
||||
|
||||
/// Controls under the `text` field in the Responses API for GPT-5.
|
||||
@@ -144,7 +130,6 @@ pub(crate) struct ResponsesApiRequest<'a> {
|
||||
pub(crate) tool_choice: &'static str,
|
||||
pub(crate) parallel_tool_calls: bool,
|
||||
pub(crate) reasoning: Option<Reasoning>,
|
||||
/// true when using the Responses API.
|
||||
pub(crate) store: bool,
|
||||
pub(crate) stream: bool,
|
||||
pub(crate) include: Vec<String>,
|
||||
@@ -156,14 +141,17 @@ pub(crate) struct ResponsesApiRequest<'a> {
|
||||
|
||||
pub(crate) fn create_reasoning_param_for_request(
|
||||
model_family: &ModelFamily,
|
||||
effort: ReasoningEffortConfig,
|
||||
effort: Option<ReasoningEffortConfig>,
|
||||
summary: ReasoningSummaryConfig,
|
||||
) -> Option<Reasoning> {
|
||||
if model_family.supports_reasoning_summaries {
|
||||
Some(Reasoning { effort, summary })
|
||||
} else {
|
||||
None
|
||||
if !model_family.supports_reasoning_summaries {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Reasoning {
|
||||
effort,
|
||||
summary: Some(summary),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn create_text_param_for_request(
|
||||
@@ -174,7 +162,7 @@ pub(crate) fn create_text_param_for_request(
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) struct ResponseStream {
|
||||
pub struct ResponseStream {
|
||||
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
|
||||
}
|
||||
|
||||
@@ -189,18 +177,64 @@ impl Stream for ResponseStream {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::model_family::find_family_for_model;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
struct InstructionsTestCase {
|
||||
pub slug: &'static str,
|
||||
pub expects_apply_patch_instructions: bool,
|
||||
}
|
||||
#[test]
|
||||
fn get_full_instructions_no_user_content() {
|
||||
let prompt = Prompt {
|
||||
..Default::default()
|
||||
};
|
||||
let expected = format!("{BASE_INSTRUCTIONS}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}");
|
||||
let model_family = find_family_for_model("gpt-4.1").expect("known model slug");
|
||||
let full = prompt.get_full_instructions(&model_family);
|
||||
assert_eq!(full, expected);
|
||||
let test_cases = vec![
|
||||
InstructionsTestCase {
|
||||
slug: "gpt-3.5",
|
||||
expects_apply_patch_instructions: true,
|
||||
},
|
||||
InstructionsTestCase {
|
||||
slug: "gpt-4.1",
|
||||
expects_apply_patch_instructions: true,
|
||||
},
|
||||
InstructionsTestCase {
|
||||
slug: "gpt-4o",
|
||||
expects_apply_patch_instructions: true,
|
||||
},
|
||||
InstructionsTestCase {
|
||||
slug: "gpt-5",
|
||||
expects_apply_patch_instructions: true,
|
||||
},
|
||||
InstructionsTestCase {
|
||||
slug: "codex-mini-latest",
|
||||
expects_apply_patch_instructions: true,
|
||||
},
|
||||
InstructionsTestCase {
|
||||
slug: "gpt-oss:120b",
|
||||
expects_apply_patch_instructions: false,
|
||||
},
|
||||
InstructionsTestCase {
|
||||
slug: "gpt-5-codex",
|
||||
expects_apply_patch_instructions: false,
|
||||
},
|
||||
];
|
||||
for test_case in test_cases {
|
||||
let model_family = find_family_for_model(test_case.slug).expect("known model slug");
|
||||
let expected = if test_case.expects_apply_patch_instructions {
|
||||
format!(
|
||||
"{}\n{}",
|
||||
model_family.clone().base_instructions,
|
||||
APPLY_PATCH_TOOL_INSTRUCTIONS
|
||||
)
|
||||
} else {
|
||||
model_family.clone().base_instructions
|
||||
};
|
||||
|
||||
let full = prompt.get_full_instructions(&model_family);
|
||||
assert_eq!(full, expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -215,7 +249,7 @@ mod tests {
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning: None,
|
||||
store: true,
|
||||
store: false,
|
||||
stream: true,
|
||||
include: vec![],
|
||||
prompt_cache_key: None,
|
||||
@@ -245,7 +279,7 @@ mod tests {
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning: None,
|
||||
store: true,
|
||||
store: false,
|
||||
stream: true,
|
||||
include: vec![],
|
||||
prompt_cache_key: None,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
473
codex-rs/core/src/codex/compact.rs
Normal file
473
codex-rs/core/src/codex/compact.rs
Normal file
@@ -0,0 +1,473 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::AgentTask;
|
||||
use super::Session;
|
||||
use super::TurnContext;
|
||||
use super::get_last_assistant_message_from_turn;
|
||||
use crate::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::protocol::AgentMessageEvent;
|
||||
use crate::protocol::CompactedItem;
|
||||
use crate::protocol::ErrorEvent;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TaskStartedEvent;
|
||||
use crate::protocol::TurnContextItem;
|
||||
use crate::truncate::truncate_middle;
|
||||
use crate::util::backoff;
|
||||
use askama::Template;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use futures::prelude::*;
|
||||
|
||||
pub(super) const COMPACT_TRIGGER_TEXT: &str = "Start Summarization";
|
||||
const SUMMARIZATION_PROMPT: &str = include_str!("../../templates/compact/prompt.md");
|
||||
const COMPACT_USER_MESSAGE_MAX_TOKENS: usize = 20_000;
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(path = "compact/history_bridge.md", escape = "none")]
|
||||
struct HistoryBridgeTemplate<'a> {
|
||||
user_messages_text: &'a str,
|
||||
summary_text: &'a str,
|
||||
}
|
||||
|
||||
pub(super) async fn spawn_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) {
|
||||
let task = AgentTask::compact(
|
||||
sess.clone(),
|
||||
turn_context,
|
||||
sub_id,
|
||||
input,
|
||||
SUMMARIZATION_PROMPT.to_string(),
|
||||
);
|
||||
sess.set_task(task).await;
|
||||
}
|
||||
|
||||
pub(super) async fn run_inline_auto_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
) {
|
||||
let sub_id = sess.next_internal_sub_id();
|
||||
let input = vec![InputItem::Text {
|
||||
text: COMPACT_TRIGGER_TEXT.to_string(),
|
||||
}];
|
||||
run_compact_task_inner(
|
||||
sess,
|
||||
turn_context,
|
||||
sub_id,
|
||||
input,
|
||||
SUMMARIZATION_PROMPT.to_string(),
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(super) async fn run_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
compact_instructions: String,
|
||||
) {
|
||||
let start_event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(start_event).await;
|
||||
run_compact_task_inner(
|
||||
sess.clone(),
|
||||
turn_context,
|
||||
sub_id.clone(),
|
||||
input,
|
||||
compact_instructions,
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: None,
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
|
||||
async fn run_compact_task_inner(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
compact_instructions: String,
|
||||
remove_task_on_completion: bool,
|
||||
) {
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
let instructions_override = compact_instructions;
|
||||
|
||||
// Build an in-memory snapshot of the current history; attempts will pop
|
||||
// from this vector on context-length errors without modifying the session
|
||||
// transcript.
|
||||
let mut working_history = {
|
||||
let state = sess.state.lock().await;
|
||||
state.history.contents()
|
||||
};
|
||||
|
||||
let max_retries = turn_context.client.get_provider().stream_max_retries();
|
||||
let mut retries = 0;
|
||||
|
||||
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: turn_context.client.get_model(),
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
loop {
|
||||
// Build prompt input = history + compact trigger
|
||||
let mut turn_input: Vec<ResponseItem> = Vec::with_capacity(working_history.len() + 1);
|
||||
turn_input.extend_from_slice(&working_history);
|
||||
turn_input.push(initial_input_for_turn.clone().into());
|
||||
|
||||
let prompt = Prompt {
|
||||
input: turn_input,
|
||||
tools: Vec::new(),
|
||||
base_instructions_override: Some(instructions_override.clone()),
|
||||
};
|
||||
|
||||
let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
|
||||
|
||||
match attempt_result {
|
||||
Ok(()) => {
|
||||
break;
|
||||
}
|
||||
Err(CodexErr::Interrupted) => {
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
// Special-case compaction overflows: trim and retry immediately with no backoff.
|
||||
if matches!(e, CodexErr::ContextLengthExceeded(_)) {
|
||||
if working_history.pop().is_some() {
|
||||
sess.notify_stream_error(
|
||||
&sub_id,
|
||||
"compact input exceeds context window; retrying with 1 fewer item…",
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
} else {
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message:
|
||||
"Unable to compact: context window too small for any history"
|
||||
.to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
if retries < max_retries {
|
||||
retries += 1;
|
||||
let delay = backoff(retries);
|
||||
sess.notify_stream_error(
|
||||
&sub_id,
|
||||
format!(
|
||||
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…"
|
||||
),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(delay).await;
|
||||
continue;
|
||||
} else {
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if remove_task_on_completion {
|
||||
sess.remove_task(&sub_id).await;
|
||||
}
|
||||
let history_snapshot = {
|
||||
let state = sess.state.lock().await;
|
||||
state.history.contents()
|
||||
};
|
||||
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
|
||||
let user_messages = collect_user_messages(&history_snapshot);
|
||||
let initial_context = sess.build_initial_context(turn_context.as_ref());
|
||||
let new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
|
||||
{
|
||||
let mut state = sess.state.lock().await;
|
||||
state.history.replace(new_history);
|
||||
}
|
||||
|
||||
let rollout_item = RolloutItem::Compacted(CompactedItem {
|
||||
message: summary_text.clone(),
|
||||
});
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "Compact task completed".to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
|
||||
pub fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
|
||||
let mut pieces = Vec::new();
|
||||
for item in content {
|
||||
match item {
|
||||
ContentItem::InputText { text } | ContentItem::OutputText { text } => {
|
||||
if !text.is_empty() {
|
||||
pieces.push(text.as_str());
|
||||
}
|
||||
}
|
||||
ContentItem::InputImage { .. } => {}
|
||||
}
|
||||
}
|
||||
if pieces.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(pieces.join("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
|
||||
items
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
ResponseItem::Message { role, content, .. } if role == "user" => {
|
||||
content_items_to_text(content)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.filter(|text| !is_session_prefix_message(text))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn is_session_prefix_message(text: &str) -> bool {
|
||||
matches!(
|
||||
InputMessageKind::from(("user", text)),
|
||||
InputMessageKind::UserInstructions | InputMessageKind::EnvironmentContext
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn build_compacted_history(
|
||||
initial_context: Vec<ResponseItem>,
|
||||
user_messages: &[String],
|
||||
summary_text: &str,
|
||||
) -> Vec<ResponseItem> {
|
||||
let mut history = initial_context;
|
||||
let mut user_messages_text = if user_messages.is_empty() {
|
||||
"(none)".to_string()
|
||||
} else {
|
||||
user_messages.join("\n\n")
|
||||
};
|
||||
// Truncate the concatenated prior user messages so the bridge message
|
||||
// stays well under the context window (approx. 4 bytes/token).
|
||||
let max_bytes = COMPACT_USER_MESSAGE_MAX_TOKENS * 4;
|
||||
if user_messages_text.len() > max_bytes {
|
||||
user_messages_text = truncate_middle(&user_messages_text, max_bytes).0;
|
||||
}
|
||||
let summary_text = if summary_text.is_empty() {
|
||||
"(no summary available)".to_string()
|
||||
} else {
|
||||
summary_text.to_string()
|
||||
};
|
||||
let Ok(bridge) = HistoryBridgeTemplate {
|
||||
user_messages_text: &user_messages_text,
|
||||
summary_text: &summary_text,
|
||||
}
|
||||
.render() else {
|
||||
return vec![];
|
||||
};
|
||||
history.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText { text: bridge }],
|
||||
});
|
||||
history
|
||||
}
|
||||
|
||||
async fn drain_to_completed(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let mut stream = turn_context.client.clone().stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
let Some(event) = maybe_event else {
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
None,
|
||||
));
|
||||
};
|
||||
match event {
|
||||
Ok(ResponseEvent::OutputItemDone(item)) => {
|
||||
let mut state = sess.state.lock().await;
|
||||
state.history.record_items(std::slice::from_ref(&item));
|
||||
}
|
||||
Ok(ResponseEvent::Completed { .. }) => {
|
||||
return Ok(());
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn content_items_to_text_joins_non_empty_segments() {
|
||||
let items = vec![
|
||||
ContentItem::InputText {
|
||||
text: "hello".to_string(),
|
||||
},
|
||||
ContentItem::OutputText {
|
||||
text: String::new(),
|
||||
},
|
||||
ContentItem::OutputText {
|
||||
text: "world".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let joined = content_items_to_text(&items);
|
||||
|
||||
assert_eq!(Some("hello\nworld".to_string()), joined);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn content_items_to_text_ignores_image_only_content() {
|
||||
let items = vec![ContentItem::InputImage {
|
||||
image_url: "file://image.png".to_string(),
|
||||
}];
|
||||
|
||||
let joined = content_items_to_text(&items);
|
||||
|
||||
assert_eq!(None, joined);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_user_messages_extracts_user_text_only() {
|
||||
let items = vec![
|
||||
ResponseItem::Message {
|
||||
id: Some("assistant".to_string()),
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: "ignored".to_string(),
|
||||
}],
|
||||
},
|
||||
ResponseItem::Message {
|
||||
id: Some("user".to_string()),
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
ContentItem::InputText {
|
||||
text: "first".to_string(),
|
||||
},
|
||||
ContentItem::OutputText {
|
||||
text: "second".to_string(),
|
||||
},
|
||||
],
|
||||
},
|
||||
ResponseItem::Other,
|
||||
];
|
||||
|
||||
let collected = collect_user_messages(&items);
|
||||
|
||||
assert_eq!(vec!["first\nsecond".to_string()], collected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_user_messages_filters_session_prefix_entries() {
|
||||
let items = vec![
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "<user_instructions>do things</user_instructions>".to_string(),
|
||||
}],
|
||||
},
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "<ENVIRONMENT_CONTEXT>cwd=/tmp</ENVIRONMENT_CONTEXT>".to_string(),
|
||||
}],
|
||||
},
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "real user message".to_string(),
|
||||
}],
|
||||
},
|
||||
];
|
||||
|
||||
let collected = collect_user_messages(&items);
|
||||
|
||||
assert_eq!(vec!["real user message".to_string()], collected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_compacted_history_truncates_overlong_user_messages() {
|
||||
// Prepare a very large prior user message so the aggregated
|
||||
// `user_messages_text` exceeds the truncation threshold used by
|
||||
// `build_compacted_history` (80k bytes).
|
||||
let big = "X".repeat(200_000);
|
||||
let history = build_compacted_history(Vec::new(), std::slice::from_ref(&big), "SUMMARY");
|
||||
|
||||
// Expect exactly one bridge message added to history (plus any initial context we provided, which is none).
|
||||
assert_eq!(history.len(), 1);
|
||||
|
||||
// Extract the text content of the bridge message.
|
||||
let bridge_text = match &history[0] {
|
||||
ResponseItem::Message { role, content, .. } if role == "user" => {
|
||||
content_items_to_text(content).unwrap_or_default()
|
||||
}
|
||||
other => panic!("unexpected item in history: {other:?}"),
|
||||
};
|
||||
|
||||
// The bridge should contain the truncation marker and not the full original payload.
|
||||
assert!(
|
||||
bridge_text.contains("tokens truncated"),
|
||||
"expected truncation marker in bridge message"
|
||||
);
|
||||
assert!(
|
||||
!bridge_text.contains(&big),
|
||||
"bridge should not include the full oversized user text"
|
||||
);
|
||||
assert!(
|
||||
bridge_text.contains("SUMMARY"),
|
||||
"bridge should include the provided summary text"
|
||||
);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
748
codex-rs/core/src/config_edit.rs
Normal file
748
codex-rs/core/src/config_edit.rs
Normal file
@@ -0,0 +1,748 @@
|
||||
use crate::config::CONFIG_TOML_FILE;
|
||||
use anyhow::Result;
|
||||
use std::path::Path;
|
||||
use tempfile::NamedTempFile;
|
||||
use toml_edit::DocumentMut;
|
||||
|
||||
pub const CONFIG_KEY_MODEL: &str = "model";
|
||||
pub const CONFIG_KEY_EFFORT: &str = "model_reasoning_effort";
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
enum NoneBehavior {
|
||||
Skip,
|
||||
Remove,
|
||||
}
|
||||
|
||||
/// Persist overrides into `config.toml` using explicit key segments per
|
||||
/// override. This avoids ambiguity with keys that contain dots or spaces.
|
||||
pub async fn persist_overrides(
|
||||
codex_home: &Path,
|
||||
profile: Option<&str>,
|
||||
overrides: &[(&[&str], &str)],
|
||||
) -> Result<()> {
|
||||
let with_options: Vec<(&[&str], Option<&str>)> = overrides
|
||||
.iter()
|
||||
.map(|(segments, value)| (*segments, Some(*value)))
|
||||
.collect();
|
||||
|
||||
persist_overrides_with_behavior(codex_home, profile, &with_options, NoneBehavior::Skip).await
|
||||
}
|
||||
|
||||
/// Persist overrides where values may be optional. Any entries with `None`
|
||||
/// values are skipped. If all values are `None`, this becomes a no-op and
|
||||
/// returns `Ok(())` without touching the file.
|
||||
pub async fn persist_non_null_overrides(
|
||||
codex_home: &Path,
|
||||
profile: Option<&str>,
|
||||
overrides: &[(&[&str], Option<&str>)],
|
||||
) -> Result<()> {
|
||||
persist_overrides_with_behavior(codex_home, profile, overrides, NoneBehavior::Skip).await
|
||||
}
|
||||
|
||||
/// Persist overrides where `None` values clear any existing values from the
|
||||
/// configuration file.
|
||||
pub async fn persist_overrides_and_clear_if_none(
|
||||
codex_home: &Path,
|
||||
profile: Option<&str>,
|
||||
overrides: &[(&[&str], Option<&str>)],
|
||||
) -> Result<()> {
|
||||
persist_overrides_with_behavior(codex_home, profile, overrides, NoneBehavior::Remove).await
|
||||
}
|
||||
|
||||
/// Apply a single override onto a `toml_edit` document while preserving
|
||||
/// existing formatting/comments.
|
||||
/// The key is expressed as explicit segments to correctly handle keys that
|
||||
/// contain dots or spaces.
|
||||
fn apply_toml_edit_override_segments(
|
||||
doc: &mut DocumentMut,
|
||||
segments: &[&str],
|
||||
value: toml_edit::Item,
|
||||
) {
|
||||
use toml_edit::Item;
|
||||
|
||||
if segments.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut current = doc.as_table_mut();
|
||||
for seg in &segments[..segments.len() - 1] {
|
||||
if !current.contains_key(seg) {
|
||||
current[*seg] = Item::Table(toml_edit::Table::new());
|
||||
if let Some(t) = current[*seg].as_table_mut() {
|
||||
t.set_implicit(true);
|
||||
}
|
||||
}
|
||||
|
||||
let maybe_item = current.get_mut(seg);
|
||||
let Some(item) = maybe_item else { return };
|
||||
|
||||
if !item.is_table() {
|
||||
*item = Item::Table(toml_edit::Table::new());
|
||||
if let Some(t) = item.as_table_mut() {
|
||||
t.set_implicit(true);
|
||||
}
|
||||
}
|
||||
|
||||
let Some(tbl) = item.as_table_mut() else {
|
||||
return;
|
||||
};
|
||||
current = tbl;
|
||||
}
|
||||
|
||||
let last = segments[segments.len() - 1];
|
||||
current[last] = value;
|
||||
}
|
||||
|
||||
async fn persist_overrides_with_behavior(
|
||||
codex_home: &Path,
|
||||
profile: Option<&str>,
|
||||
overrides: &[(&[&str], Option<&str>)],
|
||||
none_behavior: NoneBehavior,
|
||||
) -> Result<()> {
|
||||
if overrides.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let should_skip = match none_behavior {
|
||||
NoneBehavior::Skip => overrides.iter().all(|(_, value)| value.is_none()),
|
||||
NoneBehavior::Remove => false,
|
||||
};
|
||||
|
||||
if should_skip {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
|
||||
let read_result = tokio::fs::read_to_string(&config_path).await;
|
||||
let mut doc = match read_result {
|
||||
Ok(contents) => contents.parse::<DocumentMut>()?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
if overrides
|
||||
.iter()
|
||||
.all(|(_, value)| value.is_none() && matches!(none_behavior, NoneBehavior::Remove))
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tokio::fs::create_dir_all(codex_home).await?;
|
||||
DocumentMut::new()
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
let effective_profile = if let Some(p) = profile {
|
||||
Some(p.to_owned())
|
||||
} else {
|
||||
doc.get("profile")
|
||||
.and_then(|i| i.as_str())
|
||||
.map(str::to_string)
|
||||
};
|
||||
|
||||
let mut mutated = false;
|
||||
|
||||
for (segments, value) in overrides.iter().copied() {
|
||||
let mut seg_buf: Vec<&str> = Vec::new();
|
||||
let segments_to_apply: &[&str];
|
||||
|
||||
if let Some(ref name) = effective_profile {
|
||||
if segments.first().copied() == Some("profiles") {
|
||||
segments_to_apply = segments;
|
||||
} else {
|
||||
seg_buf.reserve(2 + segments.len());
|
||||
seg_buf.push("profiles");
|
||||
seg_buf.push(name.as_str());
|
||||
seg_buf.extend_from_slice(segments);
|
||||
segments_to_apply = seg_buf.as_slice();
|
||||
}
|
||||
} else {
|
||||
segments_to_apply = segments;
|
||||
}
|
||||
|
||||
match value {
|
||||
Some(v) => {
|
||||
let item_value = toml_edit::value(v);
|
||||
apply_toml_edit_override_segments(&mut doc, segments_to_apply, item_value);
|
||||
mutated = true;
|
||||
}
|
||||
None => {
|
||||
if matches!(none_behavior, NoneBehavior::Remove)
|
||||
&& remove_toml_edit_segments(&mut doc, segments_to_apply)
|
||||
{
|
||||
mutated = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !mutated {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let tmp_file = NamedTempFile::new_in(codex_home)?;
|
||||
tokio::fs::write(tmp_file.path(), doc.to_string()).await?;
|
||||
tmp_file.persist(config_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_toml_edit_segments(doc: &mut DocumentMut, segments: &[&str]) -> bool {
|
||||
use toml_edit::Item;
|
||||
|
||||
if segments.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut current = doc.as_table_mut();
|
||||
for seg in &segments[..segments.len() - 1] {
|
||||
let Some(item) = current.get_mut(seg) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
match item {
|
||||
Item::Table(table) => {
|
||||
current = table;
|
||||
}
|
||||
_ => {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
current.remove(segments[segments.len() - 1]).is_some()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// Verifies model and effort are written at top-level when no profile is set.
|
||||
#[tokio::test]
|
||||
async fn set_default_model_and_effort_top_level_when_no_profile() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "gpt-5"),
|
||||
(&[CONFIG_KEY_EFFORT], "high"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model = "gpt-5"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies values are written under the active profile when `profile` is set.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_profile_when_profile_set() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed config with a profile selection but without profiles table
|
||||
let seed = "profile = \"o3\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "o3"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "o3"
|
||||
|
||||
[profiles.o3]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies profile names with dots/spaces are preserved via explicit segments.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_profile_with_dot_and_space() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed config with a profile name that contains a dot and a space
|
||||
let seed = "profile = \"my.team name\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "o3"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "my.team name"
|
||||
|
||||
[profiles."my.team name"]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies explicit profile override writes under that profile even without active profile.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_when_profile_override_supplied() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// No profile key in config.toml
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), "")
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Persist with an explicit profile override
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
Some("o3"),
|
||||
&[(&[CONFIG_KEY_MODEL], "o3"), (&[CONFIG_KEY_EFFORT], "high")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.o3]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies nested tables are created as needed when applying overrides.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_creates_nested_tables() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&["a", "b", "c"], "v"),
|
||||
(&["x"], "y"),
|
||||
(&["profiles", "p1", CONFIG_KEY_MODEL], "gpt-5"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"x = "y"
|
||||
|
||||
[a.b]
|
||||
c = "v"
|
||||
|
||||
[profiles.p1]
|
||||
model = "gpt-5"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies a scalar key becomes a table when nested keys are written.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_replaces_scalar_with_table() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
let seed = "foo = \"bar\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(codex_home, None, &[(&["foo", "bar", "baz"], "ok")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[foo.bar]
|
||||
baz = "ok"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies comments and spacing are preserved when writing under active profile.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_preserve_comments() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed a config with comments and spacing we expect to preserve
|
||||
let seed = r#"# Global comment
|
||||
# Another line
|
||||
|
||||
profile = "o3"
|
||||
|
||||
# Profile settings
|
||||
[profiles.o3]
|
||||
# keep me
|
||||
existing = "keep"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Apply defaults; since profile is set, it should write under [profiles.o3]
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[(&[CONFIG_KEY_MODEL], "o3"), (&[CONFIG_KEY_EFFORT], "high")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"# Global comment
|
||||
# Another line
|
||||
|
||||
profile = "o3"
|
||||
|
||||
# Profile settings
|
||||
[profiles.o3]
|
||||
# keep me
|
||||
existing = "keep"
|
||||
model = "o3"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies comments and spacing are preserved when writing at top level.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_preserve_global_comments() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed a config WITHOUT a profile, containing comments and spacing
|
||||
let seed = r#"# Top-level comments
|
||||
# should be preserved
|
||||
|
||||
existing = "keep"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Since there is no profile, the defaults should be written at top-level
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "gpt-5"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"# Top-level comments
|
||||
# should be preserved
|
||||
|
||||
existing = "keep"
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies errors on invalid TOML propagate and file is not clobbered.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_errors_on_parse_failure() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Write an intentionally invalid TOML file
|
||||
let invalid = "invalid = [unclosed";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), invalid)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Attempting to persist should return an error and must not clobber the file.
|
||||
let res = persist_overrides(codex_home, None, &[(&["x"], "y")]).await;
|
||||
assert!(res.is_err(), "expected parse error to propagate");
|
||||
|
||||
// File should be unchanged
|
||||
let contents = read_config(codex_home).await;
|
||||
assert_eq!(contents, invalid);
|
||||
}
|
||||
|
||||
/// Verifies changing model only preserves existing effort at top-level.
|
||||
#[tokio::test]
|
||||
async fn changing_only_model_preserves_existing_effort_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with an effort value only
|
||||
let seed = "model_reasoning_effort = \"minimal\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Change only the model
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_MODEL], "o3")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model_reasoning_effort = "minimal"
|
||||
model = "o3"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing effort only preserves existing model at top-level.
|
||||
#[tokio::test]
|
||||
async fn changing_only_effort_preserves_existing_model_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with a model value only
|
||||
let seed = "model = \"gpt-5\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Change only the effort
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_EFFORT], "high")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model = "gpt-5"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing model only preserves existing effort in active profile.
|
||||
#[tokio::test]
|
||||
async fn changing_only_model_preserves_effort_in_active_profile() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with an active profile and an existing effort under that profile
|
||||
let seed = r#"profile = "p1"
|
||||
|
||||
[profiles.p1]
|
||||
model_reasoning_effort = "low"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_MODEL], "o4-mini")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "p1"
|
||||
|
||||
[profiles.p1]
|
||||
model_reasoning_effort = "low"
|
||||
model = "o4-mini"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing effort only preserves existing model in a profile override.
|
||||
#[tokio::test]
|
||||
async fn changing_only_effort_preserves_model_in_profile_override() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// No active profile key; we'll target an explicit override
|
||||
let seed = r#"[profiles.team]
|
||||
model = "gpt-5"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
Some("team"),
|
||||
&[(&[CONFIG_KEY_EFFORT], "minimal")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.team]
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies `persist_non_null_overrides` skips `None` entries and writes only present values at top-level.
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_skips_none_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], Some("gpt-5")),
|
||||
(&[CONFIG_KEY_EFFORT], None),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = "model = \"gpt-5\"\n";
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies no-op behavior when all provided overrides are `None` (no file created/modified).
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_noop_when_all_none() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[(&["a"], None), (&["profiles", "p", "x"], None)],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
// Should not create config.toml on a pure no-op
|
||||
assert!(!codex_home.join(CONFIG_TOML_FILE).exists());
|
||||
}
|
||||
|
||||
/// Verifies entries are written under the specified profile and `None` entries are skipped.
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_respects_profile_override() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
Some("team"),
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], Some("o3")),
|
||||
(&[CONFIG_KEY_EFFORT], None),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.team]
|
||||
model = "o3"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_clear_none_removes_top_level_value() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
let seed = r#"model = "gpt-5"
|
||||
model_reasoning_effort = "medium"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides_and_clear_if_none(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], None),
|
||||
(&[CONFIG_KEY_EFFORT], Some("high")),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = "model_reasoning_effort = \"high\"\n";
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_clear_none_respects_active_profile() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
let seed = r#"profile = "team"
|
||||
|
||||
[profiles.team]
|
||||
model = "gpt-4"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides_and_clear_if_none(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], None),
|
||||
(&[CONFIG_KEY_EFFORT], Some("high")),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "team"
|
||||
|
||||
[profiles.team]
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_clear_none_noop_when_file_missing() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_overrides_and_clear_if_none(codex_home, None, &[(&[CONFIG_KEY_MODEL], None)])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
assert!(!codex_home.join(CONFIG_TOML_FILE).exists());
|
||||
}
|
||||
|
||||
// Test helper moved to bottom per review guidance.
|
||||
async fn read_config(codex_home: &Path) -> String {
|
||||
let p = codex_home.join(CONFIG_TOML_FILE);
|
||||
tokio::fs::read_to_string(p).await.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
use serde::Deserialize;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::config_types::Verbosity;
|
||||
use crate::protocol::AskForApproval;
|
||||
use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::config_types::Verbosity;
|
||||
|
||||
/// Collection of common configuration options that a user can define as a unit
|
||||
/// in `config.toml`.
|
||||
@@ -15,10 +15,23 @@ pub struct ConfigProfile {
|
||||
/// [`ModelProviderInfo`] to use.
|
||||
pub model_provider: Option<String>,
|
||||
pub approval_policy: Option<AskForApproval>,
|
||||
pub disable_response_storage: Option<bool>,
|
||||
pub model_reasoning_effort: Option<ReasoningEffort>,
|
||||
pub model_reasoning_summary: Option<ReasoningSummary>,
|
||||
pub model_verbosity: Option<Verbosity>,
|
||||
pub chatgpt_base_url: Option<String>,
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl From<ConfigProfile> for codex_protocol::mcp_protocol::Profile {
|
||||
fn from(config_profile: ConfigProfile) -> Self {
|
||||
Self {
|
||||
model: config_profile.model,
|
||||
model_provider: config_profile.model_provider,
|
||||
approval_policy: config_profile.approval_policy,
|
||||
model_reasoning_effort: config_profile.model_reasoning_effort,
|
||||
model_reasoning_summary: config_profile.model_reasoning_summary,
|
||||
model_verbosity: config_profile.model_verbosity,
|
||||
chatgpt_base_url: config_profile.chatgpt_base_url,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,13 +5,15 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use wildmatch::WildMatchPattern;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
use serde::Serialize;
|
||||
use strum_macros::Display;
|
||||
use serde::de::Error as SerdeError;
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq)]
|
||||
#[derive(Serialize, Debug, Clone, PartialEq)]
|
||||
pub struct McpServerConfig {
|
||||
pub command: String,
|
||||
|
||||
@@ -20,6 +22,85 @@ pub struct McpServerConfig {
|
||||
|
||||
#[serde(default)]
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
|
||||
/// Startup timeout in seconds for initializing MCP server & initially listing tools.
|
||||
#[serde(
|
||||
default,
|
||||
with = "option_duration_secs",
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub startup_timeout_sec: Option<Duration>,
|
||||
|
||||
/// Default timeout for MCP tool calls initiated via this server.
|
||||
#[serde(default, with = "option_duration_secs")]
|
||||
pub tool_timeout_sec: Option<Duration>,
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
struct RawMcpServerConfig {
|
||||
command: String,
|
||||
#[serde(default)]
|
||||
args: Vec<String>,
|
||||
#[serde(default)]
|
||||
env: Option<HashMap<String, String>>,
|
||||
#[serde(default)]
|
||||
startup_timeout_sec: Option<f64>,
|
||||
#[serde(default)]
|
||||
startup_timeout_ms: Option<u64>,
|
||||
#[serde(default, with = "option_duration_secs")]
|
||||
tool_timeout_sec: Option<Duration>,
|
||||
}
|
||||
|
||||
let raw = RawMcpServerConfig::deserialize(deserializer)?;
|
||||
|
||||
let startup_timeout_sec = match (raw.startup_timeout_sec, raw.startup_timeout_ms) {
|
||||
(Some(sec), _) => {
|
||||
let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?;
|
||||
Some(duration)
|
||||
}
|
||||
(None, Some(ms)) => Some(Duration::from_millis(ms)),
|
||||
(None, None) => None,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
command: raw.command,
|
||||
args: raw.args,
|
||||
env: raw.env,
|
||||
startup_timeout_sec,
|
||||
tool_timeout_sec: raw.tool_timeout_sec,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
mod option_duration_secs {
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
use serde::Serializer;
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
match value {
|
||||
Some(duration) => serializer.serialize_some(&duration.as_secs_f64()),
|
||||
None => serializer.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let secs = Option::<f64>::deserialize(deserializer)?;
|
||||
secs.map(|secs| Duration::try_from_secs_f64(secs).map_err(serde::de::Error::custom))
|
||||
.transpose()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Copy, Clone, PartialEq)]
|
||||
@@ -74,9 +155,27 @@ pub enum HistoryPersistence {
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum Notifications {
|
||||
Enabled(bool),
|
||||
Custom(Vec<String>),
|
||||
}
|
||||
|
||||
impl Default for Notifications {
|
||||
fn default() -> Self {
|
||||
Self::Enabled(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Collection of settings that are specific to the TUI.
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Tui {}
|
||||
pub struct Tui {
|
||||
/// Enable desktop notifications from the TUI when the terminal is unfocused.
|
||||
/// Defaults to `false`.
|
||||
#[serde(default)]
|
||||
pub notifications: Notifications,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct SandboxWorkspaceWrite {
|
||||
@@ -90,6 +189,17 @@ pub struct SandboxWorkspaceWrite {
|
||||
pub exclude_slash_tmp: bool,
|
||||
}
|
||||
|
||||
impl From<SandboxWorkspaceWrite> for codex_protocol::mcp_protocol::SandboxSettings {
|
||||
fn from(sandbox_workspace_write: SandboxWorkspaceWrite) -> Self {
|
||||
Self {
|
||||
writable_roots: sandbox_workspace_write.writable_roots,
|
||||
network_access: Some(sandbox_workspace_write.network_access),
|
||||
exclude_tmpdir_env_var: Some(sandbox_workspace_write.exclude_tmpdir_env_var),
|
||||
exclude_slash_tmp: Some(sandbox_workspace_write.exclude_slash_tmp),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum ShellEnvironmentPolicyInherit {
|
||||
@@ -186,42 +296,10 @@ impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
|
||||
}
|
||||
}
|
||||
|
||||
/// See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning
|
||||
#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq, Display)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[strum(serialize_all = "lowercase")]
|
||||
pub enum ReasoningEffort {
|
||||
Low,
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Eq, Default, Hash)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum ReasoningSummaryFormat {
|
||||
#[default]
|
||||
Medium,
|
||||
High,
|
||||
/// Option to disable reasoning.
|
||||
None,
|
||||
}
|
||||
|
||||
/// A summary of the reasoning performed by the model. This can be useful for
|
||||
/// debugging and understanding the model's reasoning process.
|
||||
/// See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#reasoning-summaries
|
||||
#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq, Display)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[strum(serialize_all = "lowercase")]
|
||||
pub enum ReasoningSummary {
|
||||
#[default]
|
||||
Auto,
|
||||
Concise,
|
||||
Detailed,
|
||||
/// Option to disable reasoning summaries.
|
||||
None,
|
||||
}
|
||||
|
||||
/// Controls output length/detail on GPT-5 models via the Responses API.
|
||||
/// Serialized with lowercase values to match the OpenAI API.
|
||||
#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq, Display)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[strum(serialize_all = "lowercase")]
|
||||
pub enum Verbosity {
|
||||
Low,
|
||||
#[default]
|
||||
Medium,
|
||||
High,
|
||||
Experimental,
|
||||
}
|
||||
|
||||
@@ -32,32 +32,8 @@ impl ConversationHistory {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn keep_last_messages(&mut self, n: usize) {
|
||||
if n == 0 {
|
||||
self.items.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect the last N message items (assistant/user), newest to oldest.
|
||||
let mut kept: Vec<ResponseItem> = Vec::with_capacity(n);
|
||||
for item in self.items.iter().rev() {
|
||||
if let ResponseItem::Message { role, content, .. } = item {
|
||||
kept.push(ResponseItem::Message {
|
||||
// we need to remove the id or the model will complain that messages are sent without
|
||||
// their reasonings
|
||||
id: None,
|
||||
role: role.clone(),
|
||||
content: content.clone(),
|
||||
});
|
||||
if kept.len() == n {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Preserve chronological order (oldest to newest) within the kept slice.
|
||||
kept.reverse();
|
||||
self.items = kept;
|
||||
pub(crate) fn replace(&mut self, items: Vec<ResponseItem>) {
|
||||
self.items = items;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,8 +47,9 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::Reasoning { .. } => true,
|
||||
ResponseItem::WebSearchCall { .. } | ResponseItem::Other => false,
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => true,
|
||||
ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_login::AuthManager;
|
||||
use codex_login::CodexAuth;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::CodexAuth;
|
||||
use crate::codex::Codex;
|
||||
use crate::codex::CodexSpawnOk;
|
||||
use crate::codex::INITIAL_SUBMIT_ID;
|
||||
use crate::codex::compact::content_items_to_text;
|
||||
use crate::codex::compact::is_session_prefix_message;
|
||||
use crate::codex_conversation::CodexConversation;
|
||||
use crate::config::Config;
|
||||
use crate::error::CodexErr;
|
||||
@@ -16,12 +12,20 @@ use crate::error::Result as CodexResult;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Represents a newly created Codex conversation, including the first event
|
||||
/// (which is [`EventMsg::SessionConfigured`]).
|
||||
pub struct NewConversation {
|
||||
pub conversation_id: Uuid,
|
||||
pub conversation_id: ConversationId,
|
||||
pub conversation: Arc<CodexConversation>,
|
||||
pub session_configured: SessionConfiguredEvent,
|
||||
}
|
||||
@@ -29,7 +33,7 @@ pub struct NewConversation {
|
||||
/// [`ConversationManager`] is responsible for creating conversations and
|
||||
/// maintaining them in memory.
|
||||
pub struct ConversationManager {
|
||||
conversations: Arc<RwLock<HashMap<Uuid, Arc<CodexConversation>>>>,
|
||||
conversations: Arc<RwLock<HashMap<ConversationId, Arc<CodexConversation>>>>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
}
|
||||
|
||||
@@ -44,7 +48,7 @@ impl ConversationManager {
|
||||
/// Construct with a dummy AuthManager containing the provided CodexAuth.
|
||||
/// Used for integration tests: should not be used by ordinary business logic.
|
||||
pub fn with_auth(auth: CodexAuth) -> Self {
|
||||
Self::new(codex_login::AuthManager::from_auth_for_testing(auth))
|
||||
Self::new(crate::AuthManager::from_auth_for_testing(auth))
|
||||
}
|
||||
|
||||
pub async fn new_conversation(&self, config: Config) -> CodexResult<NewConversation> {
|
||||
@@ -59,18 +63,15 @@ impl ConversationManager {
|
||||
) -> CodexResult<NewConversation> {
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = {
|
||||
let initial_history = None;
|
||||
Codex::spawn(config, auth_manager, initial_history).await?
|
||||
};
|
||||
conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, InitialHistory::New).await?;
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
|
||||
async fn finalize_spawn(
|
||||
&self,
|
||||
codex: Codex,
|
||||
conversation_id: Uuid,
|
||||
conversation_id: ConversationId,
|
||||
) -> CodexResult<NewConversation> {
|
||||
// The first event must be `SessionInitialized`. Validate and forward it
|
||||
// to the caller so that they can display it in the conversation
|
||||
@@ -101,7 +102,7 @@ impl ConversationManager {
|
||||
|
||||
pub async fn get_conversation(
|
||||
&self,
|
||||
conversation_id: Uuid,
|
||||
conversation_id: ConversationId,
|
||||
) -> CodexResult<Arc<CodexConversation>> {
|
||||
let conversations = self.conversations.read().await;
|
||||
conversations
|
||||
@@ -110,71 +111,97 @@ impl ConversationManager {
|
||||
.ok_or_else(|| CodexErr::ConversationNotFound(conversation_id))
|
||||
}
|
||||
|
||||
pub async fn remove_conversation(&self, conversation_id: Uuid) {
|
||||
self.conversations.write().await.remove(&conversation_id);
|
||||
pub async fn resume_conversation_from_rollout(
|
||||
&self,
|
||||
config: Config,
|
||||
rollout_path: PathBuf,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
) -> CodexResult<NewConversation> {
|
||||
let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, initial_history).await?;
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
|
||||
/// Fork an existing conversation by dropping the last `drop_last_messages`
|
||||
/// user/assistant messages from its transcript and starting a new
|
||||
/// Removes the conversation from the manager's internal map, though the
|
||||
/// conversation is stored as `Arc<CodexConversation>`, it is possible that
|
||||
/// other references to it exist elsewhere. Returns the conversation if the
|
||||
/// conversation was found and removed.
|
||||
pub async fn remove_conversation(
|
||||
&self,
|
||||
conversation_id: &ConversationId,
|
||||
) -> Option<Arc<CodexConversation>> {
|
||||
self.conversations.write().await.remove(conversation_id)
|
||||
}
|
||||
|
||||
/// Fork an existing conversation by taking messages up to the given position
|
||||
/// (not including the message at the given position) and starting a new
|
||||
/// conversation with identical configuration (unless overridden by the
|
||||
/// caller's `config`). The new conversation will have a fresh id.
|
||||
pub async fn fork_conversation(
|
||||
&self,
|
||||
conversation_history: Vec<ResponseItem>,
|
||||
num_messages_to_drop: usize,
|
||||
nth_user_message: usize,
|
||||
config: Config,
|
||||
path: PathBuf,
|
||||
) -> CodexResult<NewConversation> {
|
||||
// Compute the prefix up to the cut point.
|
||||
let truncated_history =
|
||||
truncate_after_dropping_last_messages(conversation_history, num_messages_to_drop);
|
||||
let history = RolloutRecorder::get_rollout_history(&path).await?;
|
||||
let history = truncate_before_nth_user_message(history, nth_user_message);
|
||||
|
||||
// Spawn a new conversation with the computed initial history.
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, Some(truncated_history)).await?;
|
||||
conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, history).await?;
|
||||
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a prefix of `items` obtained by dropping the last `n` user messages
|
||||
/// and all items that follow them.
|
||||
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> Vec<ResponseItem> {
|
||||
if n == 0 || items.is_empty() {
|
||||
return items;
|
||||
}
|
||||
/// Return a prefix of `items` obtained by cutting strictly before the nth user message
|
||||
/// (0-based) and all items that follow it.
|
||||
fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> InitialHistory {
|
||||
// Work directly on rollout items, and cut the vector at the nth user message input.
|
||||
let items: Vec<RolloutItem> = history.get_rollout_items();
|
||||
|
||||
// Walk backwards counting only `user` Message items, find cut index.
|
||||
let mut count = 0usize;
|
||||
let mut cut_index = 0usize;
|
||||
for (idx, item) in items.iter().enumerate().rev() {
|
||||
if let ResponseItem::Message { role, .. } = item
|
||||
// Find indices of user message inputs in rollout order.
|
||||
let mut user_positions: Vec<usize> = Vec::new();
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. }) = item
|
||||
&& role == "user"
|
||||
&& content_items_to_text(content).is_some_and(|text| !is_session_prefix_message(&text))
|
||||
{
|
||||
count += 1;
|
||||
if count == n {
|
||||
// Cut everything from this user message to the end.
|
||||
cut_index = idx;
|
||||
break;
|
||||
}
|
||||
user_positions.push(idx);
|
||||
}
|
||||
}
|
||||
if count < n {
|
||||
// If fewer than n messages exist, drop everything.
|
||||
Vec::new()
|
||||
|
||||
// If fewer than or equal to n user messages exist, treat as empty (out of range).
|
||||
if user_positions.len() <= n {
|
||||
return InitialHistory::New;
|
||||
}
|
||||
|
||||
// Cut strictly before the nth user message (do not keep the nth itself).
|
||||
let cut_idx = user_positions[n];
|
||||
let rolled: Vec<RolloutItem> = items.into_iter().take(cut_idx).collect();
|
||||
|
||||
if rolled.is_empty() {
|
||||
InitialHistory::New
|
||||
} else {
|
||||
items.into_iter().take(cut_index).collect()
|
||||
InitialHistory::Forked(rolled)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::codex::make_session_and_context;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemReasoningSummary;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn user_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
@@ -220,13 +247,60 @@ mod tests {
|
||||
assistant_msg("a4"),
|
||||
];
|
||||
|
||||
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
|
||||
// Wrap as InitialHistory::Forked with response items only.
|
||||
let initial: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
let truncated = truncate_before_nth_user_message(InitialHistory::Forked(initial), 1);
|
||||
let got_items = truncated.get_rollout_items();
|
||||
let expected_items = vec![
|
||||
RolloutItem::ResponseItem(items[0].clone()),
|
||||
RolloutItem::ResponseItem(items[1].clone()),
|
||||
RolloutItem::ResponseItem(items[2].clone()),
|
||||
];
|
||||
assert_eq!(
|
||||
truncated,
|
||||
vec![items[0].clone(), items[1].clone(), items[2].clone()]
|
||||
serde_json::to_value(&got_items).unwrap(),
|
||||
serde_json::to_value(&expected_items).unwrap()
|
||||
);
|
||||
|
||||
let truncated2 = truncate_after_dropping_last_messages(items, 2);
|
||||
assert!(truncated2.is_empty());
|
||||
let initial2: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
let truncated2 = truncate_before_nth_user_message(InitialHistory::Forked(initial2), 2);
|
||||
assert!(matches!(truncated2, InitialHistory::New));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_session_prefix_messages_when_truncating() {
|
||||
let (session, turn_context) = make_session_and_context();
|
||||
let mut items = session.build_initial_context(&turn_context);
|
||||
items.push(user_msg("feature request"));
|
||||
items.push(assistant_msg("ack"));
|
||||
items.push(user_msg("second question"));
|
||||
items.push(assistant_msg("answer"));
|
||||
|
||||
let rollout_items: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
|
||||
let truncated = truncate_before_nth_user_message(InitialHistory::Forked(rollout_items), 1);
|
||||
let got_items = truncated.get_rollout_items();
|
||||
|
||||
let expected: Vec<RolloutItem> = vec![
|
||||
RolloutItem::ResponseItem(items[0].clone()),
|
||||
RolloutItem::ResponseItem(items[1].clone()),
|
||||
RolloutItem::ResponseItem(items[2].clone()),
|
||||
];
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_value(&got_items).unwrap(),
|
||||
serde_json::to_value(&expected).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ pub async fn discover_prompts_in_excluding(
|
||||
let Some(name) = path
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.to_string())
|
||||
.map(str::to_string)
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
212
codex-rs/core/src/default_client.rs
Normal file
212
codex-rs/core/src/default_client.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
use reqwest::header::HeaderValue;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Set this to add a suffix to the User-Agent string.
|
||||
///
|
||||
/// It is not ideal that we're using a global singleton for this.
|
||||
/// This is primarily designed to differentiate MCP clients from each other.
|
||||
/// Because there can only be one MCP server per process, it should be safe for this to be a global static.
|
||||
/// However, future users of this should use this with caution as a result.
|
||||
/// In addition, we want to be confident that this value is used for ALL clients and doing that requires a
|
||||
/// lot of wiring and it's easy to miss code paths by doing so.
|
||||
/// See https://github.com/openai/codex/pull/3388/files for an example of what that would look like.
|
||||
/// Finally, we want to make sure this is set for ALL mcp clients without needing to know a special env var
|
||||
/// or having to set data that they already specified in the mcp initialize request somewhere else.
|
||||
///
|
||||
/// A space is automatically added between the suffix and the rest of the User-Agent string.
|
||||
/// The full user agent string is returned from the mcp initialize response.
|
||||
/// Parenthesis will be added by Codex. This should only specify what goes inside of the parenthesis.
|
||||
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));
|
||||
|
||||
pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Originator {
|
||||
pub value: String,
|
||||
pub header_value: HeaderValue,
|
||||
}
|
||||
|
||||
pub static ORIGINATOR: LazyLock<Originator> = LazyLock::new(|| {
|
||||
let default = "codex_cli_rs";
|
||||
let value = std::env::var(CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR)
|
||||
.unwrap_or_else(|_| default.to_string());
|
||||
|
||||
match HeaderValue::from_str(&value) {
|
||||
Ok(header_value) => Originator {
|
||||
value,
|
||||
header_value,
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to turn originator override {value} into header value: {e}");
|
||||
Originator {
|
||||
value: default.to_string(),
|
||||
header_value: HeaderValue::from_static(default),
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
pub fn get_codex_user_agent() -> String {
|
||||
let build_version = env!("CARGO_PKG_VERSION");
|
||||
let os_info = os_info::get();
|
||||
let prefix = format!(
|
||||
"{}/{build_version} ({} {}; {}) {}",
|
||||
ORIGINATOR.value.as_str(),
|
||||
os_info.os_type(),
|
||||
os_info.version(),
|
||||
os_info.architecture().unwrap_or("unknown"),
|
||||
crate::terminal::user_agent()
|
||||
);
|
||||
let suffix = USER_AGENT_SUFFIX
|
||||
.lock()
|
||||
.ok()
|
||||
.and_then(|guard| guard.clone());
|
||||
let suffix = suffix
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map_or_else(String::new, |value| format!(" ({value})"));
|
||||
|
||||
let candidate = format!("{prefix}{suffix}");
|
||||
sanitize_user_agent(candidate, &prefix)
|
||||
}
|
||||
|
||||
/// Sanitize the user agent string.
|
||||
///
|
||||
/// Invalid characters are replaced with an underscore.
|
||||
///
|
||||
/// If the user agent fails to parse, it falls back to fallback and then to ORIGINATOR.
|
||||
fn sanitize_user_agent(candidate: String, fallback: &str) -> String {
|
||||
if HeaderValue::from_str(candidate.as_str()).is_ok() {
|
||||
return candidate;
|
||||
}
|
||||
|
||||
let sanitized: String = candidate
|
||||
.chars()
|
||||
.map(|ch| if matches!(ch, ' '..='~') { ch } else { '_' })
|
||||
.collect();
|
||||
if !sanitized.is_empty() && HeaderValue::from_str(sanitized.as_str()).is_ok() {
|
||||
tracing::warn!(
|
||||
"Sanitized Codex user agent because provided suffix contained invalid header characters"
|
||||
);
|
||||
sanitized
|
||||
} else if HeaderValue::from_str(fallback).is_ok() {
|
||||
tracing::warn!(
|
||||
"Falling back to base Codex user agent because provided suffix could not be sanitized"
|
||||
);
|
||||
fallback.to_string()
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Falling back to default Codex originator because base user agent string is invalid"
|
||||
);
|
||||
ORIGINATOR.value.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a reqwest client with default `originator` and `User-Agent` headers set.
|
||||
pub fn create_client() -> reqwest::Client {
|
||||
use reqwest::header::HeaderMap;
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("originator", ORIGINATOR.header_value.clone());
|
||||
let ua = get_codex_user_agent();
|
||||
|
||||
reqwest::Client::builder()
|
||||
// Set UA via dedicated helper to avoid header validation pitfalls
|
||||
.user_agent(ua)
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_get_codex_user_agent() {
|
||||
let user_agent = get_codex_user_agent();
|
||||
assert!(user_agent.starts_with("codex_cli_rs/"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_client_sets_default_headers() {
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
let client = create_client();
|
||||
|
||||
// Spin up a local mock server and capture a request.
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/"))
|
||||
.respond_with(ResponseTemplate::new(200))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let resp = client
|
||||
.get(server.uri())
|
||||
.send()
|
||||
.await
|
||||
.expect("failed to send request");
|
||||
assert!(resp.status().is_success());
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("failed to fetch received requests");
|
||||
assert!(!requests.is_empty());
|
||||
let headers = &requests[0].headers;
|
||||
|
||||
// originator header is set to the provided value
|
||||
let originator_header = headers
|
||||
.get("originator")
|
||||
.expect("originator header missing");
|
||||
assert_eq!(originator_header.to_str().unwrap(), "codex_cli_rs");
|
||||
|
||||
// User-Agent matches the computed Codex UA for that originator
|
||||
let expected_ua = get_codex_user_agent();
|
||||
let ua_header = headers
|
||||
.get("user-agent")
|
||||
.expect("user-agent header missing");
|
||||
assert_eq!(ua_header.to_str().unwrap(), expected_ua);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_suffix_is_sanitized() {
|
||||
let prefix = "codex_cli_rs/0.0.0";
|
||||
let suffix = "bad\rsuffix";
|
||||
|
||||
assert_eq!(
|
||||
sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
|
||||
"codex_cli_rs/0.0.0 (bad_suffix)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_suffix_is_sanitized2() {
|
||||
let prefix = "codex_cli_rs/0.0.0";
|
||||
let suffix = "bad\0suffix";
|
||||
|
||||
assert_eq!(
|
||||
sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
|
||||
"codex_cli_rs/0.0.0 (bad_suffix)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(target_os = "macos")]
|
||||
fn test_macos() {
|
||||
use regex_lite::Regex;
|
||||
let user_agent = get_codex_user_agent();
|
||||
let re = Regex::new(
|
||||
r"^codex_cli_rs/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$",
|
||||
)
|
||||
.unwrap();
|
||||
assert!(re.is_match(&user_agent));
|
||||
}
|
||||
}
|
||||
@@ -2,18 +2,17 @@ use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use strum_macros::Display as DeriveDisplay;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::shell::Shell;
|
||||
use codex_protocol::config_types::SandboxMode;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::ENVIRONMENT_CONTEXT_CLOSE_TAG;
|
||||
use codex_protocol::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// wraps environment context message in a tag for the model to parse more easily.
|
||||
pub(crate) const ENVIRONMENT_CONTEXT_START: &str = "<environment_context>";
|
||||
pub(crate) const ENVIRONMENT_CONTEXT_END: &str = "</environment_context>";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, DeriveDisplay)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
#[strum(serialize_all = "kebab-case")]
|
||||
@@ -28,6 +27,7 @@ pub(crate) struct EnvironmentContext {
|
||||
pub approval_policy: Option<AskForApproval>,
|
||||
pub sandbox_mode: Option<SandboxMode>,
|
||||
pub network_access: Option<NetworkAccess>,
|
||||
pub writable_roots: Option<Vec<PathBuf>>,
|
||||
pub shell: Option<Shell>,
|
||||
}
|
||||
|
||||
@@ -59,9 +59,52 @@ impl EnvironmentContext {
|
||||
}
|
||||
None => None,
|
||||
},
|
||||
writable_roots: match sandbox_policy {
|
||||
Some(SandboxPolicy::WorkspaceWrite { writable_roots, .. }) => {
|
||||
if writable_roots.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(writable_roots)
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
shell,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compares two environment contexts, ignoring the shell. Useful when
|
||||
/// comparing turn to turn, since the initial environment_context will
|
||||
/// include the shell, and then it is not configurable from turn to turn.
|
||||
pub fn equals_except_shell(&self, other: &EnvironmentContext) -> bool {
|
||||
let EnvironmentContext {
|
||||
cwd,
|
||||
approval_policy,
|
||||
sandbox_mode,
|
||||
network_access,
|
||||
writable_roots,
|
||||
// should compare all fields except shell
|
||||
shell: _,
|
||||
} = other;
|
||||
|
||||
self.cwd == *cwd
|
||||
&& self.approval_policy == *approval_policy
|
||||
&& self.sandbox_mode == *sandbox_mode
|
||||
&& self.network_access == *network_access
|
||||
&& self.writable_roots == *writable_roots
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&TurnContext> for EnvironmentContext {
|
||||
fn from(turn_context: &TurnContext) -> Self {
|
||||
Self::new(
|
||||
Some(turn_context.cwd.clone()),
|
||||
Some(turn_context.approval_policy),
|
||||
Some(turn_context.sandbox_policy.clone()),
|
||||
// Shell is not configurable from turn to turn
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EnvironmentContext {
|
||||
@@ -74,12 +117,13 @@ impl EnvironmentContext {
|
||||
/// <cwd>...</cwd>
|
||||
/// <approval_policy>...</approval_policy>
|
||||
/// <sandbox_mode>...</sandbox_mode>
|
||||
/// <writable_roots>...</writable_roots>
|
||||
/// <network_access>...</network_access>
|
||||
/// <shell>...</shell>
|
||||
/// </environment_context>
|
||||
/// ```
|
||||
pub fn serialize_to_xml(self) -> String {
|
||||
let mut lines = vec![ENVIRONMENT_CONTEXT_START.to_string()];
|
||||
let mut lines = vec![ENVIRONMENT_CONTEXT_OPEN_TAG.to_string()];
|
||||
if let Some(cwd) = self.cwd {
|
||||
lines.push(format!(" <cwd>{}</cwd>", cwd.to_string_lossy()));
|
||||
}
|
||||
@@ -96,12 +140,22 @@ impl EnvironmentContext {
|
||||
" <network_access>{network_access}</network_access>"
|
||||
));
|
||||
}
|
||||
if let Some(writable_roots) = self.writable_roots {
|
||||
lines.push(" <writable_roots>".to_string());
|
||||
for writable_root in writable_roots {
|
||||
lines.push(format!(
|
||||
" <root>{}</root>",
|
||||
writable_root.to_string_lossy()
|
||||
));
|
||||
}
|
||||
lines.push(" </writable_roots>".to_string());
|
||||
}
|
||||
if let Some(shell) = self.shell
|
||||
&& let Some(shell_name) = shell.name()
|
||||
{
|
||||
lines.push(format!(" <shell>{shell_name}</shell>"));
|
||||
}
|
||||
lines.push(ENVIRONMENT_CONTEXT_END.to_string());
|
||||
lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string());
|
||||
lines.join("\n")
|
||||
}
|
||||
}
|
||||
@@ -117,3 +171,158 @@ impl From<EnvironmentContext> for ResponseItem {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::shell::BashShell;
|
||||
use crate::shell::ZshShell;
|
||||
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: writable_roots.into_iter().map(PathBuf::from).collect(),
|
||||
network_access,
|
||||
exclude_tmpdir_env_var: false,
|
||||
exclude_slash_tmp: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_workspace_write_environment_context() {
|
||||
let context = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp"], false)),
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
<cwd>/repo</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>workspace-write</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
<writable_roots>
|
||||
<root>/repo</root>
|
||||
<root>/tmp</root>
|
||||
</writable_roots>
|
||||
</environment_context>"#;
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_read_only_environment_context() {
|
||||
let context = EnvironmentContext::new(
|
||||
None,
|
||||
Some(AskForApproval::Never),
|
||||
Some(SandboxPolicy::ReadOnly),
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
<approval_policy>never</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
</environment_context>"#;
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_full_access_environment_context() {
|
||||
let context = EnvironmentContext::new(
|
||||
None,
|
||||
Some(AskForApproval::OnFailure),
|
||||
Some(SandboxPolicy::DangerFullAccess),
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
<approval_policy>on-failure</approval_policy>
|
||||
<sandbox_mode>danger-full-access</sandbox_mode>
|
||||
<network_access>enabled</network_access>
|
||||
</environment_context>"#;
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_except_shell_compares_approval_policy() {
|
||||
// Approval policy
|
||||
let context1 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo"], false)),
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::Never),
|
||||
Some(workspace_write_policy(vec!["/repo"], true)),
|
||||
None,
|
||||
);
|
||||
assert!(!context1.equals_except_shell(&context2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_except_shell_compares_sandbox_policy() {
|
||||
let context1 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(SandboxPolicy::new_read_only_policy()),
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(SandboxPolicy::new_workspace_write_policy()),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(!context1.equals_except_shell(&context2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_except_shell_compares_workspace_write_policy() {
|
||||
let context1 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)),
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp"], true)),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(!context1.equals_except_shell(&context2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_except_shell_ignores_shell() {
|
||||
let context1 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo"], false)),
|
||||
Some(Shell::Bash(BashShell {
|
||||
shell_path: "/bin/bash".into(),
|
||||
bashrc_path: "/home/user/.bashrc".into(),
|
||||
})),
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo"], false)),
|
||||
Some(Shell::Zsh(ZshShell {
|
||||
shell_path: "/bin/zsh".into(),
|
||||
zshrc_path: "/home/user/.zshrc".into(),
|
||||
})),
|
||||
);
|
||||
|
||||
assert!(context1.equals_except_shell(&context2));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json;
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use tokio::task::JoinError;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, CodexErr>;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SandboxErr {
|
||||
/// Error from sandbox execution
|
||||
#[error("sandbox denied exec error, exit code: {0}, stdout: {1}, stderr: {2}")]
|
||||
Denied(i32, String, String),
|
||||
#[error(
|
||||
"sandbox denied exec error, exit code: {}, stdout: {}, stderr: {}",
|
||||
.output.exit_code, .output.stdout.text, .output.stderr.text
|
||||
)]
|
||||
Denied { output: Box<ExecToolCallOutput> },
|
||||
|
||||
/// Error from linux seccomp filter setup
|
||||
#[cfg(target_os = "linux")]
|
||||
@@ -26,7 +32,7 @@ pub enum SandboxErr {
|
||||
|
||||
/// Command timed out
|
||||
#[error("command timed out")]
|
||||
Timeout,
|
||||
Timeout { output: Box<ExecToolCallOutput> },
|
||||
|
||||
/// Command was killed by a signal
|
||||
#[error("command was killed by a signal")]
|
||||
@@ -49,7 +55,7 @@ pub enum CodexErr {
|
||||
Stream(String, Option<Duration>),
|
||||
|
||||
#[error("no conversation with id: {0}")]
|
||||
ConversationNotFound(Uuid),
|
||||
ConversationNotFound(ConversationId),
|
||||
|
||||
#[error("session configured event was not the first event in the stream")]
|
||||
SessionConfiguredNotFirstEvent,
|
||||
@@ -98,6 +104,10 @@ pub enum CodexErr {
|
||||
#[error("codex-linux-sandbox was required but not provided")]
|
||||
LandlockSandboxExecutableNotProvided,
|
||||
|
||||
/// Provider reported the input exceeds the model's context window.
|
||||
#[error("context_length_exceeded: {0}")]
|
||||
ContextLengthExceeded(String),
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// Automatic conversions for common external error types
|
||||
// -----------------------------------------------------------------
|
||||
@@ -127,38 +137,58 @@ pub enum CodexErr {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UsageLimitReachedError {
|
||||
pub plan_type: Option<String>,
|
||||
pub resets_in_seconds: Option<u64>,
|
||||
pub(crate) plan_type: Option<PlanType>,
|
||||
pub(crate) resets_in_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UsageLimitReachedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Base message differs slightly for legacy ChatGPT Plus plan users.
|
||||
if let Some(plan_type) = &self.plan_type
|
||||
&& plan_type == "plus"
|
||||
{
|
||||
write!(
|
||||
f,
|
||||
"You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing) or try again"
|
||||
)?;
|
||||
if let Some(secs) = self.resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
write!(f, " in {reset_duration}.")?;
|
||||
} else {
|
||||
write!(f, " later.")?;
|
||||
let message = match self.plan_type.as_ref() {
|
||||
Some(PlanType::Known(KnownPlan::Plus)) => format!(
|
||||
"You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing){}",
|
||||
retry_suffix_after_or(self.resets_in_seconds)
|
||||
),
|
||||
Some(PlanType::Known(KnownPlan::Team)) | Some(PlanType::Known(KnownPlan::Business)) => {
|
||||
format!(
|
||||
"You've hit your usage limit. To get more access now, send a request to your admin{}",
|
||||
retry_suffix_after_or(self.resets_in_seconds)
|
||||
)
|
||||
}
|
||||
} else {
|
||||
write!(f, "You've hit your usage limit.")?;
|
||||
|
||||
if let Some(secs) = self.resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
write!(f, " Try again in {reset_duration}.")?;
|
||||
} else {
|
||||
write!(f, " Try again later.")?;
|
||||
Some(PlanType::Known(KnownPlan::Free)) => {
|
||||
"To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
Some(PlanType::Known(KnownPlan::Pro))
|
||||
| Some(PlanType::Known(KnownPlan::Enterprise))
|
||||
| Some(PlanType::Known(KnownPlan::Edu)) => format!(
|
||||
"You've hit your usage limit.{}",
|
||||
retry_suffix(self.resets_in_seconds)
|
||||
),
|
||||
Some(PlanType::Unknown(_)) | None => format!(
|
||||
"You've hit your usage limit.{}",
|
||||
retry_suffix(self.resets_in_seconds)
|
||||
),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
write!(f, "{message}")
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_suffix(resets_in_seconds: Option<u64>) -> String {
|
||||
if let Some(secs) = resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
format!(" Try again in {reset_duration}.")
|
||||
} else {
|
||||
" Try again later.".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_suffix_after_or(resets_in_seconds: Option<u64>) -> String {
|
||||
if let Some(secs) = resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
format!(" or try again in {reset_duration}.")
|
||||
} else {
|
||||
" or try again later.".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,9 +253,12 @@ impl CodexErr {
|
||||
|
||||
pub fn get_error_message_ui(e: &CodexErr) -> String {
|
||||
match e {
|
||||
CodexErr::Sandbox(SandboxErr::Denied(_, _, stderr)) => stderr.to_string(),
|
||||
CodexErr::Sandbox(SandboxErr::Denied { output }) => output.stderr.text.clone(),
|
||||
// Timeouts are not sandbox errors from a UX perspective; present them plainly
|
||||
CodexErr::Sandbox(SandboxErr::Timeout) => "error: command timed out".to_string(),
|
||||
CodexErr::Sandbox(SandboxErr::Timeout { output }) => format!(
|
||||
"error: command timed out after {} ms",
|
||||
output.duration.as_millis()
|
||||
),
|
||||
_ => e.to_string(),
|
||||
}
|
||||
}
|
||||
@@ -237,7 +270,7 @@ mod tests {
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_plus_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some("plus".to_string()),
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Plus)),
|
||||
resets_in_seconds: None,
|
||||
};
|
||||
assert_eq!(
|
||||
@@ -246,6 +279,18 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_free_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Free)),
|
||||
resets_in_seconds: Some(3600),
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_default_when_none() {
|
||||
let err = UsageLimitReachedError {
|
||||
@@ -258,10 +303,34 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_team_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Team)),
|
||||
resets_in_seconds: Some(3600),
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"You've hit your usage limit. To get more access now, send a request to your admin or try again in 1 hour."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_business_plan_without_reset() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Business)),
|
||||
resets_in_seconds: None,
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"You've hit your usage limit. To get more access now, send a request to your admin or try again later."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_default_for_other_plans() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some("pro".to_string()),
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Pro)),
|
||||
resets_in_seconds: None,
|
||||
};
|
||||
assert_eq!(
|
||||
@@ -285,7 +354,7 @@ mod tests {
|
||||
#[test]
|
||||
fn usage_limit_reached_includes_hours_and_minutes() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some("plus".to_string()),
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Plus)),
|
||||
resets_in_seconds: Some(3 * 3600 + 32 * 60),
|
||||
};
|
||||
assert_eq!(
|
||||
|
||||
2
codex-rs/core/src/error_codes.rs
Normal file
2
codex-rs/core/src/error_codes.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
/// Known structured error codes returned by model providers.
|
||||
pub const CONTEXT_LENGTH_EXCEEDED: &str = "context_length_exceeded";
|
||||
167
codex-rs/core/src/event_mapping.rs
Normal file
167
codex-rs/core/src/event_mapping.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use crate::protocol::AgentMessageEvent;
|
||||
use crate::protocol::AgentReasoningEvent;
|
||||
use crate::protocol::AgentReasoningRawContentEvent;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use crate::protocol::UserMessageEvent;
|
||||
use crate::protocol::WebSearchEndEvent;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ReasoningItemReasoningSummary;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::models::WebSearchAction;
|
||||
|
||||
/// Convert a `ResponseItem` into zero or more `EventMsg` values that the UI can render.
|
||||
///
|
||||
/// When `show_raw_agent_reasoning` is false, raw reasoning content events are omitted.
|
||||
pub(crate) fn map_response_item_to_event_messages(
|
||||
item: &ResponseItem,
|
||||
show_raw_agent_reasoning: bool,
|
||||
) -> Vec<EventMsg> {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
// Do not surface system messages as user events.
|
||||
if role == "system" {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut events: Vec<EventMsg> = Vec::new();
|
||||
let mut message_parts: Vec<String> = Vec::new();
|
||||
let mut images: Vec<String> = Vec::new();
|
||||
let mut kind: Option<InputMessageKind> = None;
|
||||
|
||||
for content_item in content.iter() {
|
||||
match content_item {
|
||||
ContentItem::InputText { text } => {
|
||||
if kind.is_none() {
|
||||
let trimmed = text.trim_start();
|
||||
kind = if trimmed.starts_with("<environment_context>") {
|
||||
Some(InputMessageKind::EnvironmentContext)
|
||||
} else if trimmed.starts_with("<user_instructions>") {
|
||||
Some(InputMessageKind::UserInstructions)
|
||||
} else {
|
||||
Some(InputMessageKind::Plain)
|
||||
};
|
||||
}
|
||||
message_parts.push(text.clone());
|
||||
}
|
||||
ContentItem::InputImage { image_url } => {
|
||||
images.push(image_url.clone());
|
||||
}
|
||||
ContentItem::OutputText { text } => {
|
||||
events.push(EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: text.clone(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !message_parts.is_empty() || !images.is_empty() {
|
||||
let message = if message_parts.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
message_parts.join("")
|
||||
};
|
||||
let images = if images.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(images)
|
||||
};
|
||||
|
||||
events.push(EventMsg::UserMessage(UserMessageEvent {
|
||||
message,
|
||||
kind,
|
||||
images,
|
||||
}));
|
||||
}
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
ResponseItem::Reasoning {
|
||||
summary, content, ..
|
||||
} => {
|
||||
let mut events = Vec::new();
|
||||
for ReasoningItemReasoningSummary::SummaryText { text } in summary {
|
||||
events.push(EventMsg::AgentReasoning(AgentReasoningEvent {
|
||||
text: text.clone(),
|
||||
}));
|
||||
}
|
||||
if let Some(items) = content.as_ref().filter(|_| show_raw_agent_reasoning) {
|
||||
for c in items {
|
||||
let text = match c {
|
||||
ReasoningItemContent::ReasoningText { text }
|
||||
| ReasoningItemContent::Text { text } => text,
|
||||
};
|
||||
events.push(EventMsg::AgentReasoningRawContent(
|
||||
AgentReasoningRawContentEvent { text: text.clone() },
|
||||
));
|
||||
}
|
||||
}
|
||||
events
|
||||
}
|
||||
|
||||
ResponseItem::WebSearchCall { id, action, .. } => match action {
|
||||
WebSearchAction::Search { query } => {
|
||||
let call_id = id.clone().unwrap_or_else(|| "".to_string());
|
||||
vec![EventMsg::WebSearchEnd(WebSearchEndEvent {
|
||||
call_id,
|
||||
query: query.clone(),
|
||||
})]
|
||||
}
|
||||
WebSearchAction::Other => Vec::new(),
|
||||
},
|
||||
|
||||
// Variants that require side effects are handled by higher layers and do not emit events here.
|
||||
ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::Other => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::map_response_item_to_event_messages;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn maps_user_message_with_text_and_two_images() {
|
||||
let img1 = "https://example.com/one.png".to_string();
|
||||
let img2 = "https://example.com/two.jpg".to_string();
|
||||
|
||||
let item = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
ContentItem::InputText {
|
||||
text: "Hello world".to_string(),
|
||||
},
|
||||
ContentItem::InputImage {
|
||||
image_url: img1.clone(),
|
||||
},
|
||||
ContentItem::InputImage {
|
||||
image_url: img2.clone(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let events = map_response_item_to_event_messages(&item, false);
|
||||
assert_eq!(events.len(), 1, "expected a single user message event");
|
||||
|
||||
match &events[0] {
|
||||
EventMsg::UserMessage(user) => {
|
||||
assert_eq!(user.message, "Hello world");
|
||||
assert!(matches!(user.kind, Some(InputMessageKind::Plain)));
|
||||
assert_eq!(user.images, Some(vec![img1, img2]));
|
||||
}
|
||||
other => panic!("expected UserMessage, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ use std::os::unix::process::ExitStatusExt;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::ExitStatus;
|
||||
use std::time::Duration;
|
||||
@@ -26,7 +27,6 @@ use crate::protocol::SandboxPolicy;
|
||||
use crate::seatbelt::spawn_command_under_seatbelt;
|
||||
use crate::spawn::StdioPolicy;
|
||||
use crate::spawn::spawn_child_async;
|
||||
use serde_bytes::ByteBuf;
|
||||
|
||||
const DEFAULT_TIMEOUT_MS: u64 = 10_000;
|
||||
|
||||
@@ -35,6 +35,7 @@ const DEFAULT_TIMEOUT_MS: u64 = 10_000;
|
||||
const SIGKILL_CODE: i32 = 9;
|
||||
const TIMEOUT_CODE: i32 = 64;
|
||||
const EXIT_CODE_SIGNAL_BASE: i32 = 128; // conventional shell: 128 + signal
|
||||
const EXEC_TIMEOUT_EXIT_CODE: i32 = 124; // conventional timeout exit code
|
||||
|
||||
// I/O buffer sizing
|
||||
const READ_CHUNK_SIZE: usize = 8192; // bytes per read
|
||||
@@ -44,7 +45,7 @@ const AGGREGATE_BUFFER_INITIAL_CAPACITY: usize = 8 * 1024; // 8 KiB
|
||||
/// Aggregation still collects full output; only the live event stream is capped.
|
||||
pub(crate) const MAX_EXEC_OUTPUT_DELTAS_PER_CALL: usize = 10_000;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ExecParams {
|
||||
pub command: Vec<String>,
|
||||
pub cwd: PathBuf,
|
||||
@@ -82,33 +83,41 @@ pub async fn process_exec_tool_call(
|
||||
params: ExecParams,
|
||||
sandbox_type: SandboxType,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
sandbox_cwd: &Path,
|
||||
codex_linux_sandbox_exe: &Option<PathBuf>,
|
||||
stdout_stream: Option<StdoutStream>,
|
||||
) -> Result<ExecToolCallOutput> {
|
||||
let start = Instant::now();
|
||||
|
||||
let timeout_duration = params.timeout_duration();
|
||||
|
||||
let raw_output_result: std::result::Result<RawExecToolCallOutput, CodexErr> = match sandbox_type
|
||||
{
|
||||
SandboxType::None => exec(params, sandbox_policy, stdout_stream.clone()).await,
|
||||
SandboxType::MacosSeatbelt => {
|
||||
let timeout = params.timeout_duration();
|
||||
let ExecParams {
|
||||
command, cwd, env, ..
|
||||
command,
|
||||
cwd: command_cwd,
|
||||
env,
|
||||
..
|
||||
} = params;
|
||||
let child = spawn_command_under_seatbelt(
|
||||
command,
|
||||
command_cwd,
|
||||
sandbox_policy,
|
||||
cwd,
|
||||
sandbox_cwd,
|
||||
StdioPolicy::RedirectForShellTool,
|
||||
env,
|
||||
)
|
||||
.await?;
|
||||
consume_truncated_output(child, timeout, stdout_stream.clone()).await
|
||||
consume_truncated_output(child, timeout_duration, stdout_stream.clone()).await
|
||||
}
|
||||
SandboxType::LinuxSeccomp => {
|
||||
let timeout = params.timeout_duration();
|
||||
let ExecParams {
|
||||
command, cwd, env, ..
|
||||
command,
|
||||
cwd: command_cwd,
|
||||
env,
|
||||
..
|
||||
} = params;
|
||||
|
||||
let codex_linux_sandbox_exe = codex_linux_sandbox_exe
|
||||
@@ -117,48 +126,64 @@ pub async fn process_exec_tool_call(
|
||||
let child = spawn_command_under_linux_sandbox(
|
||||
codex_linux_sandbox_exe,
|
||||
command,
|
||||
command_cwd,
|
||||
sandbox_policy,
|
||||
cwd,
|
||||
sandbox_cwd,
|
||||
StdioPolicy::RedirectForShellTool,
|
||||
env,
|
||||
)
|
||||
.await?;
|
||||
|
||||
consume_truncated_output(child, timeout, stdout_stream).await
|
||||
consume_truncated_output(child, timeout_duration, stdout_stream).await
|
||||
}
|
||||
};
|
||||
let duration = start.elapsed();
|
||||
match raw_output_result {
|
||||
Ok(raw_output) => {
|
||||
let stdout = raw_output.stdout.from_utf8_lossy();
|
||||
let stderr = raw_output.stderr.from_utf8_lossy();
|
||||
#[allow(unused_mut)]
|
||||
let mut timed_out = raw_output.timed_out;
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
match raw_output.exit_status.signal() {
|
||||
Some(TIMEOUT_CODE) => return Err(CodexErr::Sandbox(SandboxErr::Timeout)),
|
||||
Some(signal) => {
|
||||
return Err(CodexErr::Sandbox(SandboxErr::Signal(signal)));
|
||||
{
|
||||
if let Some(signal) = raw_output.exit_status.signal() {
|
||||
if signal == TIMEOUT_CODE {
|
||||
timed_out = true;
|
||||
} else {
|
||||
return Err(CodexErr::Sandbox(SandboxErr::Signal(signal)));
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
let exit_code = raw_output.exit_status.code().unwrap_or(-1);
|
||||
|
||||
if exit_code != 0 && is_likely_sandbox_denied(sandbox_type, exit_code) {
|
||||
return Err(CodexErr::Sandbox(SandboxErr::Denied(
|
||||
exit_code,
|
||||
stdout.text,
|
||||
stderr.text,
|
||||
)));
|
||||
let mut exit_code = raw_output.exit_status.code().unwrap_or(-1);
|
||||
if timed_out {
|
||||
exit_code = EXEC_TIMEOUT_EXIT_CODE;
|
||||
}
|
||||
|
||||
Ok(ExecToolCallOutput {
|
||||
let stdout = raw_output.stdout.from_utf8_lossy();
|
||||
let stderr = raw_output.stderr.from_utf8_lossy();
|
||||
let aggregated_output = raw_output.aggregated_output.from_utf8_lossy();
|
||||
let exec_output = ExecToolCallOutput {
|
||||
exit_code,
|
||||
stdout,
|
||||
stderr,
|
||||
aggregated_output: raw_output.aggregated_output.from_utf8_lossy(),
|
||||
aggregated_output,
|
||||
duration,
|
||||
})
|
||||
timed_out,
|
||||
};
|
||||
|
||||
if timed_out {
|
||||
return Err(CodexErr::Sandbox(SandboxErr::Timeout {
|
||||
output: Box::new(exec_output),
|
||||
}));
|
||||
}
|
||||
|
||||
if exit_code != 0 && is_likely_sandbox_denied(sandbox_type, exit_code) {
|
||||
return Err(CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(exec_output),
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(exec_output)
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("exec error: {err}");
|
||||
@@ -198,6 +223,7 @@ struct RawExecToolCallOutput {
|
||||
pub stdout: StreamOutput<Vec<u8>>,
|
||||
pub stderr: StreamOutput<Vec<u8>>,
|
||||
pub aggregated_output: StreamOutput<Vec<u8>>,
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
impl StreamOutput<String> {
|
||||
@@ -230,6 +256,7 @@ pub struct ExecToolCallOutput {
|
||||
pub stderr: StreamOutput<String>,
|
||||
pub aggregated_output: StreamOutput<String>,
|
||||
pub duration: Duration,
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
async fn exec(
|
||||
@@ -299,22 +326,24 @@ async fn consume_truncated_output(
|
||||
Some(agg_tx.clone()),
|
||||
));
|
||||
|
||||
let exit_status = tokio::select! {
|
||||
let (exit_status, timed_out) = tokio::select! {
|
||||
result = tokio::time::timeout(timeout, child.wait()) => {
|
||||
match result {
|
||||
Ok(Ok(exit_status)) => exit_status,
|
||||
Ok(e) => e?,
|
||||
Ok(status_result) => {
|
||||
let exit_status = status_result?;
|
||||
(exit_status, false)
|
||||
}
|
||||
Err(_) => {
|
||||
// timeout
|
||||
child.start_kill()?;
|
||||
// Debatable whether `child.wait().await` should be called here.
|
||||
synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + TIMEOUT_CODE)
|
||||
(synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + TIMEOUT_CODE), true)
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
child.start_kill()?;
|
||||
synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + SIGKILL_CODE)
|
||||
(synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + SIGKILL_CODE), false)
|
||||
}
|
||||
};
|
||||
|
||||
@@ -337,6 +366,7 @@ async fn consume_truncated_output(
|
||||
stdout,
|
||||
stderr,
|
||||
aggregated_output,
|
||||
timed_out,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -369,7 +399,7 @@ async fn read_capped<R: AsyncRead + Unpin + Send + 'static>(
|
||||
} else {
|
||||
ExecOutputStream::Stdout
|
||||
},
|
||||
chunk: ByteBuf::from(chunk),
|
||||
chunk,
|
||||
});
|
||||
let event = Event {
|
||||
id: stream.sub_id.clone(),
|
||||
|
||||
@@ -24,6 +24,9 @@ pub(crate) struct ExecCommandSession {
|
||||
|
||||
/// JoinHandle for the child wait task.
|
||||
wait_handle: StdMutex<Option<JoinHandle<()>>>,
|
||||
|
||||
/// Tracks whether the underlying process has exited.
|
||||
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||
}
|
||||
|
||||
impl ExecCommandSession {
|
||||
@@ -34,15 +37,21 @@ impl ExecCommandSession {
|
||||
reader_handle: JoinHandle<()>,
|
||||
writer_handle: JoinHandle<()>,
|
||||
wait_handle: JoinHandle<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
writer_tx,
|
||||
output_tx,
|
||||
killer: StdMutex::new(Some(killer)),
|
||||
reader_handle: StdMutex::new(Some(reader_handle)),
|
||||
writer_handle: StdMutex::new(Some(writer_handle)),
|
||||
wait_handle: StdMutex::new(Some(wait_handle)),
|
||||
}
|
||||
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||
) -> (Self, broadcast::Receiver<Vec<u8>>) {
|
||||
let initial_output_rx = output_tx.subscribe();
|
||||
(
|
||||
Self {
|
||||
writer_tx,
|
||||
output_tx,
|
||||
killer: StdMutex::new(Some(killer)),
|
||||
reader_handle: StdMutex::new(Some(reader_handle)),
|
||||
writer_handle: StdMutex::new(Some(writer_handle)),
|
||||
wait_handle: StdMutex::new(Some(wait_handle)),
|
||||
exit_status,
|
||||
},
|
||||
initial_output_rx,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
|
||||
@@ -52,6 +61,10 @@ impl ExecCommandSession {
|
||||
pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
|
||||
self.output_tx.subscribe()
|
||||
}
|
||||
|
||||
pub(crate) fn has_exited(&self) -> bool {
|
||||
self.exit_status.load(std::sync::atomic::Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ExecCommandSession {
|
||||
|
||||
@@ -6,6 +6,7 @@ mod session_manager;
|
||||
|
||||
pub use exec_command_params::ExecCommandParams;
|
||||
pub use exec_command_params::WriteStdinParams;
|
||||
pub(crate) use exec_command_session::ExecCommandSession;
|
||||
pub use responses_api::EXEC_COMMAND_TOOL_NAME;
|
||||
pub use responses_api::WRITE_STDIN_TOOL_NAME;
|
||||
pub use responses_api::create_exec_command_tool_for_responses_api;
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::io::ErrorKind;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU32;
|
||||
|
||||
use portable_pty::CommandBuilder;
|
||||
@@ -19,6 +20,7 @@ use crate::exec_command::exec_command_params::ExecCommandParams;
|
||||
use crate::exec_command::exec_command_params::WriteStdinParams;
|
||||
use crate::exec_command::exec_command_session::ExecCommandSession;
|
||||
use crate::exec_command::session_id::SessionId;
|
||||
use crate::truncate::truncate_middle;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -91,18 +93,16 @@ impl SessionManager {
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
|
||||
);
|
||||
|
||||
let (session, mut exit_rx) =
|
||||
create_exec_command_session(params.clone())
|
||||
.await
|
||||
.map_err(|err| {
|
||||
format!(
|
||||
"failed to create exec command session for session id {}: {err}",
|
||||
session_id.0
|
||||
)
|
||||
})?;
|
||||
let (session, mut output_rx, mut exit_rx) = create_exec_command_session(params.clone())
|
||||
.await
|
||||
.map_err(|err| {
|
||||
format!(
|
||||
"failed to create exec command session for session id {}: {err}",
|
||||
session_id.0
|
||||
)
|
||||
})?;
|
||||
|
||||
// Insert into session map.
|
||||
let mut output_rx = session.output_receiver();
|
||||
self.sessions.lock().await.insert(session_id, session);
|
||||
|
||||
// Collect output until either timeout expires or process exits.
|
||||
@@ -243,7 +243,11 @@ impl SessionManager {
|
||||
/// Spawn PTY and child process per spawn_exec_command_session logic.
|
||||
async fn create_exec_command_session(
|
||||
params: ExecCommandParams,
|
||||
) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
|
||||
) -> anyhow::Result<(
|
||||
ExecCommandSession,
|
||||
tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
oneshot::Receiver<i32>,
|
||||
)> {
|
||||
let ExecCommandParams {
|
||||
cmd,
|
||||
yield_time_ms: _,
|
||||
@@ -277,7 +281,6 @@ async fn create_exec_command_session(
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||
// Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
|
||||
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
|
||||
|
||||
// Reader task: drain PTY and forward chunks to output channel.
|
||||
let mut reader = pair.master.try_clone_reader()?;
|
||||
let output_tx_clone = output_tx.clone();
|
||||
@@ -327,130 +330,28 @@ async fn create_exec_command_session(
|
||||
|
||||
// Keep the child alive until it exits, then signal exit code.
|
||||
let (exit_tx, exit_rx) = oneshot::channel::<i32>();
|
||||
let exit_status = Arc::new(AtomicBool::new(false));
|
||||
let wait_exit_status = exit_status.clone();
|
||||
let wait_handle = tokio::task::spawn_blocking(move || {
|
||||
let code = match child.wait() {
|
||||
Ok(status) => status.exit_code() as i32,
|
||||
Err(_) => -1,
|
||||
};
|
||||
wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||
let _ = exit_tx.send(code);
|
||||
});
|
||||
|
||||
// Create and store the session with channels.
|
||||
let session = ExecCommandSession::new(
|
||||
let (session, initial_output_rx) = ExecCommandSession::new(
|
||||
writer_tx,
|
||||
output_tx,
|
||||
killer,
|
||||
reader_handle,
|
||||
writer_handle,
|
||||
wait_handle,
|
||||
exit_status,
|
||||
);
|
||||
Ok((session, exit_rx))
|
||||
}
|
||||
|
||||
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
|
||||
/// preserving the beginning and the end. Returns the possibly truncated
|
||||
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
|
||||
/// if truncation occurred; otherwise returns the original string and `None`.
|
||||
fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
|
||||
// No truncation needed
|
||||
if s.len() <= max_bytes {
|
||||
return (s.to_string(), None);
|
||||
}
|
||||
let est_tokens = (s.len() as u64).div_ceil(4);
|
||||
if max_bytes == 0 {
|
||||
// Cannot keep any content; still return a full marker (never truncated).
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
// Helper to truncate a string to a given byte length on a char boundary.
|
||||
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
|
||||
if input.len() <= max_len {
|
||||
return input;
|
||||
}
|
||||
let mut end = max_len;
|
||||
while end > 0 && !input.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&input[..end]
|
||||
}
|
||||
|
||||
// Given a left/right budget, prefer newline boundaries; otherwise fall back
|
||||
// to UTF-8 char boundaries.
|
||||
fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
|
||||
if let Some(head) = s.get(..left_budget)
|
||||
&& let Some(i) = head.rfind('\n')
|
||||
{
|
||||
return i + 1; // keep the newline so suffix starts on a fresh line
|
||||
}
|
||||
truncate_on_boundary(s, left_budget).len()
|
||||
}
|
||||
|
||||
fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
|
||||
let start_tail = s.len().saturating_sub(right_budget);
|
||||
if let Some(tail) = s.get(start_tail..)
|
||||
&& let Some(i) = tail.find('\n')
|
||||
{
|
||||
return start_tail + i + 1; // start after newline
|
||||
}
|
||||
// Fall back to a char boundary at or after start_tail.
|
||||
let mut idx = start_tail.min(s.len());
|
||||
while idx < s.len() && !s.is_char_boundary(idx) {
|
||||
idx += 1;
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
// Refine marker length and budgets until stable. Marker is never truncated.
|
||||
let mut guess_tokens = est_tokens; // worst-case: everything truncated
|
||||
for _ in 0..4 {
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
// No room for any content within the cap; return a full, untruncated marker
|
||||
// that reflects the entire truncated content.
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let mut suffix_start = pick_suffix_start(s, right_budget);
|
||||
if suffix_start < prefix_end {
|
||||
suffix_start = prefix_end;
|
||||
}
|
||||
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
|
||||
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
|
||||
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
|
||||
if new_tokens == guess_tokens {
|
||||
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
// Place marker on its own line for symmetry when we keep line boundaries.
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
return (out, Some(est_tokens));
|
||||
}
|
||||
guess_tokens = new_tokens;
|
||||
}
|
||||
|
||||
// Fallback: use last guess to build output.
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let suffix_start = pick_suffix_start(s, right_budget);
|
||||
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
(out, Some(est_tokens))
|
||||
Ok((session, initial_output_rx, exit_rx))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -616,50 +517,4 @@ Output:
|
||||
abc"#;
|
||||
assert_eq!(expected, text);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_no_newlines_fallback() {
|
||||
// A long string with no newlines that exceeds the cap.
|
||||
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
||||
let max_bytes = 16; // force truncation
|
||||
let (out, original) = truncate_middle(s, max_bytes);
|
||||
// For very small caps, we return the full, untruncated marker,
|
||||
// even if it exceeds the cap.
|
||||
assert_eq!(out, "…16 tokens truncated…");
|
||||
// Original string length is 62 bytes => ceil(62/4) = 16 tokens.
|
||||
assert_eq!(original, Some(16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_prefers_newline_boundaries() {
|
||||
// Build a multi-line string of 20 numbered lines (each "NNN\n").
|
||||
let mut s = String::new();
|
||||
for i in 1..=20 {
|
||||
s.push_str(&format!("{i:03}\n"));
|
||||
}
|
||||
// Total length: 20 lines * 4 bytes per line = 80 bytes.
|
||||
assert_eq!(s.len(), 80);
|
||||
|
||||
// Choose a cap that forces truncation while leaving room for
|
||||
// a few lines on each side after accounting for the marker.
|
||||
let max_bytes = 64;
|
||||
// Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
|
||||
assert_eq!(
|
||||
truncate_middle(&s, max_bytes),
|
||||
(
|
||||
r#"001
|
||||
002
|
||||
003
|
||||
004
|
||||
…12 tokens truncated…
|
||||
017
|
||||
018
|
||||
019
|
||||
020
|
||||
"#
|
||||
.to_string(),
|
||||
Some(20)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_protocol::mcp_protocol::GitSha;
|
||||
use codex_protocol::protocol::GitInfo;
|
||||
use futures::future::join_all;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -10,24 +11,39 @@ use tokio::process::Command;
|
||||
use tokio::time::Duration as TokioDuration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::util::is_inside_git_repo;
|
||||
/// Return `true` if the project folder specified by the `Config` is inside a
|
||||
/// Git repository.
|
||||
///
|
||||
/// The check walks up the directory hierarchy looking for a `.git` file or
|
||||
/// directory (note `.git` can be a file that contains a `gitdir` entry). This
|
||||
/// approach does **not** require the `git` binary or the `git2` crate and is
|
||||
/// therefore fairly lightweight.
|
||||
///
|
||||
/// Note that this does **not** detect *work‑trees* created with
|
||||
/// `git worktree add` where the checkout lives outside the main repository
|
||||
/// directory. If you need Codex to work from such a checkout simply pass the
|
||||
/// `--allow-no-git-exec` CLI flag that disables the repo requirement.
|
||||
pub fn get_git_repo_root(base_dir: &Path) -> Option<PathBuf> {
|
||||
let mut dir = base_dir.to_path_buf();
|
||||
|
||||
loop {
|
||||
if dir.join(".git").exists() {
|
||||
return Some(dir);
|
||||
}
|
||||
|
||||
// Pop one component (go up one directory). `pop` returns false when
|
||||
// we have reached the filesystem root.
|
||||
if !dir.pop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Timeout for git commands to prevent freezing on large repositories
|
||||
const GIT_COMMAND_TIMEOUT: TokioDuration = TokioDuration::from_secs(5);
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct GitInfo {
|
||||
/// Current commit hash (SHA)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub commit_hash: Option<String>,
|
||||
/// Current branch name
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub branch: Option<String>,
|
||||
/// Repository URL (if available from remote)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub repository_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct GitDiffToRemote {
|
||||
pub sha: GitSha,
|
||||
@@ -92,11 +108,64 @@ pub async fn collect_git_info(cwd: &Path) -> Option<GitInfo> {
|
||||
Some(git_info)
|
||||
}
|
||||
|
||||
/// A minimal commit summary entry used for pickers (subject + timestamp + sha).
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CommitLogEntry {
|
||||
pub sha: String,
|
||||
/// Unix timestamp (seconds since epoch) of the commit time (committer time).
|
||||
pub timestamp: i64,
|
||||
/// Single-line subject of the commit message.
|
||||
pub subject: String,
|
||||
}
|
||||
|
||||
/// Return the last `limit` commits reachable from HEAD for the current branch.
|
||||
/// Each entry contains the SHA, commit timestamp (seconds), and subject line.
|
||||
/// Returns an empty vector if not in a git repo or on error/timeout.
|
||||
pub async fn recent_commits(cwd: &Path, limit: usize) -> Vec<CommitLogEntry> {
|
||||
// Ensure we're in a git repo first to avoid noisy errors.
|
||||
let Some(out) = run_git_command_with_timeout(&["rev-parse", "--git-dir"], cwd).await else {
|
||||
return Vec::new();
|
||||
};
|
||||
if !out.status.success() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let fmt = "%H%x1f%ct%x1f%s"; // <sha> <US> <commit_time> <US> <subject>
|
||||
let n = limit.max(1).to_string();
|
||||
let Some(log_out) =
|
||||
run_git_command_with_timeout(&["log", "-n", &n, &format!("--pretty=format:{fmt}")], cwd)
|
||||
.await
|
||||
else {
|
||||
return Vec::new();
|
||||
};
|
||||
if !log_out.status.success() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let text = String::from_utf8_lossy(&log_out.stdout);
|
||||
let mut entries: Vec<CommitLogEntry> = Vec::new();
|
||||
for line in text.lines() {
|
||||
let mut parts = line.split('\u{001f}');
|
||||
let sha = parts.next().unwrap_or("").trim();
|
||||
let ts_s = parts.next().unwrap_or("").trim();
|
||||
let subject = parts.next().unwrap_or("").trim();
|
||||
if sha.is_empty() || ts_s.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let timestamp = ts_s.parse::<i64>().unwrap_or(0);
|
||||
entries.push(CommitLogEntry {
|
||||
sha: sha.to_string(),
|
||||
timestamp,
|
||||
subject: subject.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
entries
|
||||
}
|
||||
|
||||
/// Returns the closest git sha to HEAD that is on a remote as well as the diff to that sha.
|
||||
pub async fn git_diff_to_remote(cwd: &Path) -> Option<GitDiffToRemote> {
|
||||
if !is_inside_git_repo(cwd) {
|
||||
return None;
|
||||
}
|
||||
get_git_repo_root(cwd)?;
|
||||
|
||||
let remotes = get_git_remotes(cwd).await?;
|
||||
let branches = branch_ancestry(cwd).await?;
|
||||
@@ -131,7 +200,7 @@ async fn get_git_remotes(cwd: &Path) -> Option<Vec<String>> {
|
||||
let mut remotes: Vec<String> = String::from_utf8(output.stdout)
|
||||
.ok()?
|
||||
.lines()
|
||||
.map(|s| s.to_string())
|
||||
.map(str::to_string)
|
||||
.collect();
|
||||
if let Some(pos) = remotes.iter().position(|r| r == "origin") {
|
||||
let origin = remotes.remove(pos);
|
||||
@@ -188,6 +257,11 @@ async fn get_default_branch(cwd: &Path) -> Option<String> {
|
||||
}
|
||||
|
||||
// No remote-derived default; try common local defaults if they exist
|
||||
get_default_branch_local(cwd).await
|
||||
}
|
||||
|
||||
/// Attempt to determine the repository's default branch name from local branches.
|
||||
async fn get_default_branch_local(cwd: &Path) -> Option<String> {
|
||||
for candidate in ["main", "master"] {
|
||||
if let Some(verify) = run_git_command_with_timeout(
|
||||
&[
|
||||
@@ -403,7 +477,7 @@ async fn diff_against_sha(cwd: &Path, sha: &GitSha) -> Option<String> {
|
||||
let untracked: Vec<String> = String::from_utf8(untracked_output.stdout)
|
||||
.ok()?
|
||||
.lines()
|
||||
.map(|s| s.to_string())
|
||||
.map(str::to_string)
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
|
||||
@@ -440,7 +514,7 @@ async fn diff_against_sha(cwd: &Path, sha: &GitSha) -> Option<String> {
|
||||
}
|
||||
|
||||
/// Resolve the path that should be used for trust checks. Similar to
|
||||
/// `[utils::is_inside_git_repo]`, but resolves to the root of the main
|
||||
/// `[get_git_repo_root]`, but resolves to the root of the main
|
||||
/// repository. Handles worktrees.
|
||||
pub fn resolve_root_git_project_for_trust(cwd: &Path) -> Option<PathBuf> {
|
||||
let base = if cwd.is_dir() { cwd } else { cwd.parent()? };
|
||||
@@ -471,6 +545,46 @@ pub fn resolve_root_git_project_for_trust(cwd: &Path) -> Option<PathBuf> {
|
||||
git_dir_path.parent().map(Path::to_path_buf)
|
||||
}
|
||||
|
||||
/// Returns a list of local git branches.
|
||||
/// Includes the default branch at the beginning of the list, if it exists.
|
||||
pub async fn local_git_branches(cwd: &Path) -> Vec<String> {
|
||||
let mut branches: Vec<String> = if let Some(out) =
|
||||
run_git_command_with_timeout(&["branch", "--format=%(refname:short)"], cwd).await
|
||||
&& out.status.success()
|
||||
{
|
||||
String::from_utf8_lossy(&out.stdout)
|
||||
.lines()
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
branches.sort_unstable();
|
||||
|
||||
if let Some(base) = get_default_branch_local(cwd).await
|
||||
&& let Some(pos) = branches.iter().position(|name| name == &base)
|
||||
{
|
||||
let base_branch = branches.remove(pos);
|
||||
branches.insert(0, base_branch);
|
||||
}
|
||||
|
||||
branches
|
||||
}
|
||||
|
||||
/// Returns the current checked out branch name.
|
||||
pub async fn current_branch_name(cwd: &Path) -> Option<String> {
|
||||
let out = run_git_command_with_timeout(&["branch", "--show-current"], cwd).await?;
|
||||
if !out.status.success() {
|
||||
return None;
|
||||
}
|
||||
String::from_utf8(out.stdout)
|
||||
.ok()
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|name| !name.is_empty())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -537,6 +651,80 @@ mod tests {
|
||||
repo_path
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_recent_commits_non_git_directory_returns_empty() {
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let entries = recent_commits(temp_dir.path(), 10).await;
|
||||
assert!(entries.is_empty(), "expected no commits outside a git repo");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_recent_commits_orders_and_limits() {
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let repo_path = create_test_git_repo(&temp_dir).await;
|
||||
|
||||
// Make three distinct commits with small delays to ensure ordering by timestamp.
|
||||
fs::write(repo_path.join("file.txt"), "one").unwrap();
|
||||
Command::new("git")
|
||||
.args(["add", "file.txt"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("git add");
|
||||
Command::new("git")
|
||||
.args(["commit", "-m", "first change"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("git commit 1");
|
||||
|
||||
sleep(Duration::from_millis(1100)).await;
|
||||
|
||||
fs::write(repo_path.join("file.txt"), "two").unwrap();
|
||||
Command::new("git")
|
||||
.args(["add", "file.txt"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("git add 2");
|
||||
Command::new("git")
|
||||
.args(["commit", "-m", "second change"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("git commit 2");
|
||||
|
||||
sleep(Duration::from_millis(1100)).await;
|
||||
|
||||
fs::write(repo_path.join("file.txt"), "three").unwrap();
|
||||
Command::new("git")
|
||||
.args(["add", "file.txt"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("git add 3");
|
||||
Command::new("git")
|
||||
.args(["commit", "-m", "third change"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("git commit 3");
|
||||
|
||||
// Request the latest 3 commits; should be our three changes in reverse time order.
|
||||
let entries = recent_commits(&repo_path, 3).await;
|
||||
assert_eq!(entries.len(), 3);
|
||||
assert_eq!(entries[0].subject, "third change");
|
||||
assert_eq!(entries[1].subject, "second change");
|
||||
assert_eq!(entries[2].subject, "first change");
|
||||
// Basic sanity on SHA formatting
|
||||
for e in entries {
|
||||
assert!(e.sha.len() >= 7 && e.sha.chars().all(|c| c.is_ascii_hexdigit()));
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_test_git_repo_with_remote(temp_dir: &TempDir) -> (PathBuf, String) {
|
||||
let repo_path = create_test_git_repo(temp_dir).await;
|
||||
let remote_path = temp_dir.path().join("remote.git");
|
||||
@@ -788,7 +976,7 @@ mod tests {
|
||||
async fn resolve_root_git_project_for_trust_regular_repo_returns_repo_root() {
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let repo_path = create_test_git_repo(&temp_dir).await;
|
||||
let expected = std::fs::canonicalize(&repo_path).unwrap().to_path_buf();
|
||||
let expected = std::fs::canonicalize(&repo_path).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
resolve_root_git_project_for_trust(&repo_path),
|
||||
@@ -796,10 +984,7 @@ mod tests {
|
||||
);
|
||||
let nested = repo_path.join("sub/dir");
|
||||
std::fs::create_dir_all(&nested).unwrap();
|
||||
assert_eq!(
|
||||
resolve_root_git_project_for_trust(&nested),
|
||||
Some(expected.clone())
|
||||
);
|
||||
assert_eq!(resolve_root_git_project_for_trust(&nested), Some(expected));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
76
codex-rs/core/src/internal_storage.rs
Normal file
76
codex-rs/core/src/internal_storage.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::io::ErrorKind;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub(crate) const INTERNAL_STORAGE_FILE: &str = "internal_storage.json";
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct InternalStorage {
|
||||
#[serde(skip)]
|
||||
storage_path: PathBuf,
|
||||
#[serde(default)]
|
||||
pub gpt_5_codex_model_prompt_seen: bool,
|
||||
}
|
||||
|
||||
// TODO(jif) generalise all the file writers and build proper async channel inserters.
|
||||
impl InternalStorage {
|
||||
pub fn load(codex_home: &Path) -> Self {
|
||||
let storage_path = codex_home.join(INTERNAL_STORAGE_FILE);
|
||||
|
||||
match std::fs::read_to_string(&storage_path) {
|
||||
Ok(serialized) => match serde_json::from_str::<Self>(&serialized) {
|
||||
Ok(mut storage) => {
|
||||
storage.storage_path = storage_path;
|
||||
storage
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!("failed to parse internal storage: {error:?}");
|
||||
Self::empty(storage_path)
|
||||
}
|
||||
},
|
||||
Err(error) => {
|
||||
if error.kind() == ErrorKind::NotFound {
|
||||
tracing::debug!(
|
||||
"internal storage not found at {}; initializing defaults",
|
||||
storage_path.display()
|
||||
);
|
||||
} else {
|
||||
tracing::warn!("failed to read internal storage: {error:?}");
|
||||
}
|
||||
Self::empty(storage_path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn empty(storage_path: PathBuf) -> Self {
|
||||
Self {
|
||||
storage_path,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn persist(&self) -> anyhow::Result<()> {
|
||||
let serialized = serde_json::to_string_pretty(self)?;
|
||||
|
||||
if let Some(parent) = self.storage_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await.with_context(|| {
|
||||
format!(
|
||||
"failed to create internal storage directory at {}",
|
||||
parent.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
tokio::fs::write(&self.storage_path, serialized)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to persist internal storage at {}",
|
||||
self.storage_path.display()
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -160,9 +160,10 @@ fn is_valid_sed_n_arg(arg: Option<&str>) -> bool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::string::ToString;
|
||||
|
||||
fn vec_str(args: &[&str]) -> Vec<String> {
|
||||
args.iter().map(|s| s.to_string()).collect()
|
||||
args.iter().map(ToString::to_string).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -16,21 +16,22 @@ use tokio::process::Child;
|
||||
pub async fn spawn_command_under_linux_sandbox<P>(
|
||||
codex_linux_sandbox_exe: P,
|
||||
command: Vec<String>,
|
||||
command_cwd: PathBuf,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
cwd: PathBuf,
|
||||
sandbox_policy_cwd: &Path,
|
||||
stdio_policy: StdioPolicy,
|
||||
env: HashMap<String, String>,
|
||||
) -> std::io::Result<Child>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
let args = create_linux_sandbox_command_args(command, sandbox_policy, &cwd);
|
||||
let args = create_linux_sandbox_command_args(command, sandbox_policy, sandbox_policy_cwd);
|
||||
let arg0 = Some("codex-linux-sandbox");
|
||||
spawn_child_async(
|
||||
codex_linux_sandbox_exe.as_ref().to_path_buf(),
|
||||
args,
|
||||
arg0,
|
||||
cwd,
|
||||
command_cwd,
|
||||
sandbox_policy,
|
||||
stdio_policy,
|
||||
env,
|
||||
@@ -42,10 +43,13 @@ where
|
||||
fn create_linux_sandbox_command_args(
|
||||
command: Vec<String>,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
cwd: &Path,
|
||||
sandbox_policy_cwd: &Path,
|
||||
) -> Vec<String> {
|
||||
#[expect(clippy::expect_used)]
|
||||
let sandbox_policy_cwd = cwd.to_str().expect("cwd must be valid UTF-8").to_string();
|
||||
let sandbox_policy_cwd = sandbox_policy_cwd
|
||||
.to_str()
|
||||
.expect("cwd must be valid UTF-8")
|
||||
.to_string();
|
||||
|
||||
#[expect(clippy::expect_used)]
|
||||
let sandbox_policy_json =
|
||||
|
||||
@@ -6,14 +6,17 @@
|
||||
#![deny(clippy::print_stdout, clippy::print_stderr)]
|
||||
|
||||
mod apply_patch;
|
||||
mod bash;
|
||||
pub mod auth;
|
||||
pub mod bash;
|
||||
mod chat_completions;
|
||||
mod client;
|
||||
mod client_common;
|
||||
pub mod codex;
|
||||
mod codex_conversation;
|
||||
pub mod token_data;
|
||||
pub use codex_conversation::CodexConversation;
|
||||
pub mod config;
|
||||
pub mod config_edit;
|
||||
pub mod config_profile;
|
||||
pub mod config_types;
|
||||
mod conversation_history;
|
||||
@@ -25,6 +28,7 @@ mod exec_command;
|
||||
pub mod exec_env;
|
||||
mod flags;
|
||||
pub mod git_info;
|
||||
pub mod internal_storage;
|
||||
mod is_safe_command;
|
||||
pub mod landlock;
|
||||
mod mcp_connection_manager;
|
||||
@@ -32,14 +36,25 @@ mod mcp_tool_call;
|
||||
mod message_history;
|
||||
mod model_provider_info;
|
||||
pub mod parse_command;
|
||||
mod truncate;
|
||||
mod unified_exec;
|
||||
mod user_instructions;
|
||||
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
||||
pub use model_provider_info::ModelProviderInfo;
|
||||
pub use model_provider_info::WireApi;
|
||||
pub use model_provider_info::built_in_model_providers;
|
||||
pub use model_provider_info::create_oss_provider_with_base_url;
|
||||
mod conversation_manager;
|
||||
mod event_mapping;
|
||||
pub mod review_format;
|
||||
pub use codex_protocol::protocol::InitialHistory;
|
||||
pub use conversation_manager::ConversationManager;
|
||||
pub use conversation_manager::NewConversation;
|
||||
// Re-export common auth types for workspace consumers
|
||||
pub use auth::AuthManager;
|
||||
pub use auth::CodexAuth;
|
||||
pub mod default_client;
|
||||
pub mod error_codes;
|
||||
pub mod model_family;
|
||||
mod openai_model_info;
|
||||
mod openai_tools;
|
||||
@@ -53,9 +68,17 @@ pub mod spawn;
|
||||
pub mod terminal;
|
||||
mod tool_apply_patch;
|
||||
pub mod turn_diff_tracker;
|
||||
pub mod user_agent;
|
||||
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
|
||||
pub use rollout::RolloutRecorder;
|
||||
pub use rollout::SESSIONS_SUBDIR;
|
||||
pub use rollout::SessionMeta;
|
||||
pub use rollout::find_conversation_path_by_id_str;
|
||||
pub use rollout::list::ConversationItem;
|
||||
pub use rollout::list::ConversationsPage;
|
||||
pub use rollout::list::Cursor;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
|
||||
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
|
||||
pub use safety::get_platform_sandbox;
|
||||
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
|
||||
@@ -64,3 +87,17 @@ pub use codex_protocol::protocol;
|
||||
// Re-export protocol config enums to ensure call sites can use the same types
|
||||
// as those in the protocol crate when constructing protocol messages.
|
||||
pub use codex_protocol::config_types as protocol_config_types;
|
||||
|
||||
pub use client::ModelClient;
|
||||
pub use client_common::Prompt;
|
||||
pub use client_common::REVIEW_PROMPT;
|
||||
pub use client_common::ResponseEvent;
|
||||
pub use client_common::ResponseStream;
|
||||
pub use codex::compact::content_items_to_text;
|
||||
pub use codex::compact::is_session_prefix_message;
|
||||
pub use codex_protocol::models::ContentItem;
|
||||
pub use codex_protocol::models::LocalShellAction;
|
||||
pub use codex_protocol::models::LocalShellExecAction;
|
||||
pub use codex_protocol::models::LocalShellStatus;
|
||||
pub use codex_protocol::models::ReasoningItemContent;
|
||||
pub use codex_protocol::models::ResponseItem;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::ffi::OsString;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -36,8 +37,11 @@ use crate::config_types::McpServerConfig;
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__";
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
/// Timeout for the `tools/list` request.
|
||||
const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
/// Default timeout for initializing MCP server & initially listing tools.
|
||||
const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
/// Default timeout for individual tool calls.
|
||||
const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
/// Map that holds a startup error for every MCP server that could **not** be
|
||||
/// spawned successfully.
|
||||
@@ -81,6 +85,12 @@ struct ToolInfo {
|
||||
tool: Tool,
|
||||
}
|
||||
|
||||
struct ManagedClient {
|
||||
client: Arc<McpClient>,
|
||||
startup_timeout: Duration,
|
||||
tool_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`McpClient`] instances.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct McpConnectionManager {
|
||||
@@ -88,7 +98,7 @@ pub(crate) struct McpConnectionManager {
|
||||
///
|
||||
/// The server name originates from the keys of the `mcp_servers` map in
|
||||
/// the user configuration.
|
||||
clients: HashMap<String, std::sync::Arc<McpClient>>,
|
||||
clients: HashMap<String, ManagedClient>,
|
||||
|
||||
/// Fully qualified tool name -> tool instance.
|
||||
tools: HashMap<String, ToolInfo>,
|
||||
@@ -126,8 +136,14 @@ impl McpConnectionManager {
|
||||
continue;
|
||||
}
|
||||
|
||||
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
|
||||
|
||||
let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
|
||||
|
||||
join_set.spawn(async move {
|
||||
let McpServerConfig { command, args, env } = cfg;
|
||||
let McpServerConfig {
|
||||
command, args, env, ..
|
||||
} = cfg;
|
||||
let client_res = McpClient::new_stdio_client(
|
||||
command.into(),
|
||||
args.into_iter().map(OsString::from).collect(),
|
||||
@@ -150,33 +166,52 @@ impl McpConnectionManager {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".into()),
|
||||
// This field is used by Codex when it is an MCP
|
||||
// server: it should not be used when Codex is
|
||||
// an MCP client.
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
let initialize_notification_params = None;
|
||||
let timeout = Some(Duration::from_secs(10));
|
||||
match client
|
||||
.initialize(params, initialize_notification_params, timeout)
|
||||
.await
|
||||
{
|
||||
Ok(_response) => (server_name, Ok(client)),
|
||||
Err(e) => (server_name, Err(e)),
|
||||
}
|
||||
let init_result = client
|
||||
.initialize(
|
||||
params,
|
||||
initialize_notification_params,
|
||||
Some(startup_timeout),
|
||||
)
|
||||
.await;
|
||||
(
|
||||
(server_name, tool_timeout),
|
||||
init_result.map(|_| (client, startup_timeout)),
|
||||
)
|
||||
}
|
||||
Err(e) => (server_name, Err(e.into())),
|
||||
Err(e) => ((server_name, tool_timeout), Err(e.into())),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let mut clients: HashMap<String, std::sync::Arc<McpClient>> =
|
||||
HashMap::with_capacity(join_set.len());
|
||||
let mut clients: HashMap<String, ManagedClient> = HashMap::with_capacity(join_set.len());
|
||||
|
||||
while let Some(res) = join_set.join_next().await {
|
||||
let (server_name, client_res) = res?; // JoinError propagation
|
||||
let ((server_name, tool_timeout), client_res) = match res {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
warn!("Task panic when starting MCP server: {e:#}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match client_res {
|
||||
Ok(client) => {
|
||||
clients.insert(server_name, std::sync::Arc::new(client));
|
||||
Ok((client, startup_timeout)) => {
|
||||
clients.insert(
|
||||
server_name,
|
||||
ManagedClient {
|
||||
client: Arc::new(client),
|
||||
startup_timeout,
|
||||
tool_timeout: Some(tool_timeout),
|
||||
},
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
errors.insert(server_name, e);
|
||||
@@ -184,7 +219,13 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
let all_tools = list_all_tools(&clients).await?;
|
||||
let all_tools = match list_all_tools(&clients).await {
|
||||
Ok(tools) => tools,
|
||||
Err(e) => {
|
||||
warn!("Failed to list tools from some MCP servers: {e:#}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
let tools = qualify_tools(all_tools);
|
||||
|
||||
@@ -206,13 +247,13 @@ impl McpConnectionManager {
|
||||
server: &str,
|
||||
tool: &str,
|
||||
arguments: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::CallToolResult> {
|
||||
let client = self
|
||||
let managed = self
|
||||
.clients
|
||||
.get(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
|
||||
.clone();
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
|
||||
let client = managed.client.clone();
|
||||
let timeout = managed.tool_timeout;
|
||||
|
||||
client
|
||||
.call_tool(tool.to_string(), arguments, timeout)
|
||||
@@ -229,21 +270,18 @@ impl McpConnectionManager {
|
||||
|
||||
/// Query every server for its available tools and return a single map that
|
||||
/// contains **all** tools. Each key is the fully-qualified name for the tool.
|
||||
async fn list_all_tools(
|
||||
clients: &HashMap<String, std::sync::Arc<McpClient>>,
|
||||
) -> Result<Vec<ToolInfo>> {
|
||||
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn one task per server so we can query them concurrently. This
|
||||
// keeps the overall latency roughly at the slowest server instead of
|
||||
// the cumulative latency.
|
||||
for (server_name, client) in clients {
|
||||
for (server_name, managed_client) in clients {
|
||||
let server_name_cloned = server_name.clone();
|
||||
let client_clone = client.clone();
|
||||
let client_clone = managed_client.client.clone();
|
||||
let startup_timeout = managed_client.startup_timeout;
|
||||
join_set.spawn(async move {
|
||||
let res = client_clone
|
||||
.list_tools(None, Some(LIST_TOOLS_TIMEOUT))
|
||||
.await;
|
||||
let res = client_clone.list_tools(None, Some(startup_timeout)).await;
|
||||
(server_name_cloned, res)
|
||||
});
|
||||
}
|
||||
@@ -251,8 +289,19 @@ async fn list_all_tools(
|
||||
let mut aggregated: Vec<ToolInfo> = Vec::with_capacity(join_set.len());
|
||||
|
||||
while let Some(join_res) = join_set.join_next().await {
|
||||
let (server_name, list_result) = join_res?;
|
||||
let list_result = list_result?;
|
||||
let (server_name, list_result) = if let Ok(result) = join_res {
|
||||
result
|
||||
} else {
|
||||
warn!("Task panic when listing tools for MCP server: {join_res:#?}");
|
||||
continue;
|
||||
};
|
||||
|
||||
let list_result = if let Ok(result) = list_result {
|
||||
result
|
||||
} else {
|
||||
warn!("Failed to list tools for MCP server '{server_name}': {list_result:#?}");
|
||||
continue;
|
||||
};
|
||||
|
||||
for tool in list_result.tools {
|
||||
let tool_info = ToolInfo {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use tracing::error;
|
||||
@@ -21,7 +20,6 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
server: String,
|
||||
tool_name: String,
|
||||
arguments: String,
|
||||
timeout: Option<Duration>,
|
||||
) -> ResponseInputItem {
|
||||
// Parse the `arguments` as JSON. An empty string is OK, but invalid JSON
|
||||
// is not.
|
||||
@@ -58,7 +56,7 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
let start = Instant::now();
|
||||
// Perform the tool call.
|
||||
let result = sess
|
||||
.call_tool(&server, &tool_name, arguments_value.clone(), timeout)
|
||||
.call_tool(&server, &tool_name, arguments_value.clone())
|
||||
.await
|
||||
.map_err(|e| format!("tool call error: {e}"));
|
||||
let tool_call_end_event = EventMsg::McpToolCallEnd(McpToolCallEndEvent {
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
//! JSON-Lines tooling. Each record has the following schema:
|
||||
//!
|
||||
//! ````text
|
||||
//! {"session_id":"<uuid>","ts":<unix_seconds>,"text":"<message>"}
|
||||
//! {"conversation_id":"<uuid>","ts":<unix_seconds>,"text":"<message>"}
|
||||
//! ````
|
||||
//!
|
||||
//! To minimise the chance of interleaved writes when multiple processes are
|
||||
@@ -22,14 +22,15 @@ use std::path::PathBuf;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
use std::time::Duration;
|
||||
use tokio::fs;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::config_types::HistoryPersistence;
|
||||
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
#[cfg(unix)]
|
||||
@@ -54,10 +55,14 @@ fn history_filepath(config: &Config) -> PathBuf {
|
||||
path
|
||||
}
|
||||
|
||||
/// Append a `text` entry associated with `session_id` to the history file. Uses
|
||||
/// Append a `text` entry associated with `conversation_id` to the history file. Uses
|
||||
/// advisory file locking to ensure that concurrent writes do not interleave,
|
||||
/// which entails a small amount of blocking I/O internally.
|
||||
pub(crate) async fn append_entry(text: &str, session_id: &Uuid, config: &Config) -> Result<()> {
|
||||
pub(crate) async fn append_entry(
|
||||
text: &str,
|
||||
conversation_id: &ConversationId,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
match config.history.persistence {
|
||||
HistoryPersistence::SaveAll => {
|
||||
// Save everything: proceed.
|
||||
@@ -84,7 +89,7 @@ pub(crate) async fn append_entry(text: &str, session_id: &Uuid, config: &Config)
|
||||
|
||||
// Construct the JSON line first so we can write it in a single syscall.
|
||||
let entry = HistoryEntry {
|
||||
session_id: session_id.to_string(),
|
||||
session_id: conversation_id.to_string(),
|
||||
ts,
|
||||
text: text.to_string(),
|
||||
};
|
||||
@@ -105,47 +110,34 @@ pub(crate) async fn append_entry(text: &str, session_id: &Uuid, config: &Config)
|
||||
// Ensure permissions.
|
||||
ensure_owner_only_permissions(&history_file).await?;
|
||||
|
||||
// Lock file.
|
||||
acquire_exclusive_lock_with_retry(&history_file).await?;
|
||||
|
||||
// We use sync I/O with spawn_blocking() because we are using a
|
||||
// [`std::fs::File`] instead of a [`tokio::fs::File`] to leverage an
|
||||
// advisory file locking API that is not available in the async API.
|
||||
// Perform a blocking write under an advisory write lock using std::fs.
|
||||
tokio::task::spawn_blocking(move || -> Result<()> {
|
||||
history_file.write_all(line.as_bytes())?;
|
||||
history_file.flush()?;
|
||||
Ok(())
|
||||
// Retry a few times to avoid indefinite blocking when contended.
|
||||
for _ in 0..MAX_RETRIES {
|
||||
match history_file.try_lock() {
|
||||
Ok(()) => {
|
||||
// While holding the exclusive lock, write the full line.
|
||||
history_file.write_all(line.as_bytes())?;
|
||||
history_file.flush()?;
|
||||
return Ok(());
|
||||
}
|
||||
Err(std::fs::TryLockError::WouldBlock) => {
|
||||
std::thread::sleep(RETRY_SLEEP);
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
}
|
||||
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WouldBlock,
|
||||
"could not acquire exclusive lock on history file after multiple attempts",
|
||||
))
|
||||
})
|
||||
.await??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Attempt to acquire an exclusive advisory lock on `file`, retrying up to 10
|
||||
/// times if the lock is currently held by another process. This prevents a
|
||||
/// potential indefinite wait while still giving other writers some time to
|
||||
/// finish their operation.
|
||||
async fn acquire_exclusive_lock_with_retry(file: &File) -> Result<()> {
|
||||
use tokio::time::sleep;
|
||||
|
||||
for _ in 0..MAX_RETRIES {
|
||||
match file.try_lock() {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => match e {
|
||||
std::fs::TryLockError::WouldBlock => {
|
||||
sleep(RETRY_SLEEP).await;
|
||||
}
|
||||
other => return Err(other.into()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WouldBlock,
|
||||
"could not acquire exclusive lock on history file after multiple attempts",
|
||||
))
|
||||
}
|
||||
|
||||
/// Asynchronously fetch the history file's *identifier* (inode on Unix) and
|
||||
/// the current number of entries by counting newline characters.
|
||||
pub(crate) async fn history_metadata(config: &Config) -> (u64, usize) {
|
||||
@@ -221,29 +213,42 @@ pub(crate) fn lookup(log_id: u64, offset: usize, config: &Config) -> Option<Hist
|
||||
return None;
|
||||
}
|
||||
|
||||
// Open & lock file for reading.
|
||||
if let Err(e) = acquire_shared_lock_with_retry(&file) {
|
||||
tracing::warn!(error = %e, "failed to acquire shared lock on history file");
|
||||
return None;
|
||||
}
|
||||
// Open & lock file for reading using a shared lock.
|
||||
// Retry a few times to avoid indefinite blocking.
|
||||
for _ in 0..MAX_RETRIES {
|
||||
let lock_result = file.try_lock_shared();
|
||||
|
||||
let reader = BufReader::new(&file);
|
||||
for (idx, line_res) in reader.lines().enumerate() {
|
||||
let line = match line_res {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to read line from history file");
|
||||
match lock_result {
|
||||
Ok(()) => {
|
||||
let reader = BufReader::new(&file);
|
||||
for (idx, line_res) in reader.lines().enumerate() {
|
||||
let line = match line_res {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to read line from history file");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if idx == offset {
|
||||
match serde_json::from_str::<HistoryEntry>(&line) {
|
||||
Ok(entry) => return Some(entry),
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to parse history entry");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Not found at requested offset.
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if idx == offset {
|
||||
match serde_json::from_str::<HistoryEntry>(&line) {
|
||||
Ok(entry) => return Some(entry),
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to parse history entry");
|
||||
return None;
|
||||
}
|
||||
Err(std::fs::TryLockError::WouldBlock) => {
|
||||
std::thread::sleep(RETRY_SLEEP);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to acquire shared lock on history file");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -258,26 +263,6 @@ pub(crate) fn lookup(log_id: u64, offset: usize, config: &Config) -> Option<Hist
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn acquire_shared_lock_with_retry(file: &File) -> Result<()> {
|
||||
for _ in 0..MAX_RETRIES {
|
||||
match file.try_lock_shared() {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => match e {
|
||||
std::fs::TryLockError::WouldBlock => {
|
||||
std::thread::sleep(RETRY_SLEEP);
|
||||
}
|
||||
other => return Err(other.into()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WouldBlock,
|
||||
"could not acquire shared lock on history file after multiple attempts",
|
||||
))
|
||||
}
|
||||
|
||||
/// On Unix systems ensure the file permissions are `0o600` (rw-------). If the
|
||||
/// permissions cannot be changed the error is propagated to the caller.
|
||||
#[cfg(unix)]
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
use crate::config_types::ReasoningSummaryFormat;
|
||||
use crate::tool_apply_patch::ApplyPatchToolType;
|
||||
|
||||
/// The `instructions` field in the payload sent to a model should always start
|
||||
/// with this content.
|
||||
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
|
||||
const GPT_5_CODEX_INSTRUCTIONS: &str = include_str!("../gpt_5_codex_prompt.md");
|
||||
|
||||
/// A model family is a group of models that share certain characteristics.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ModelFamily {
|
||||
@@ -20,6 +26,9 @@ pub struct ModelFamily {
|
||||
// `summary` is optional).
|
||||
pub supports_reasoning_summaries: bool,
|
||||
|
||||
// Define if we need a special handling of reasoning summary
|
||||
pub reasoning_summary_format: ReasoningSummaryFormat,
|
||||
|
||||
// This should be set to true when the model expects a tool named
|
||||
// "local_shell" to be provided. Its contract must be understood natively by
|
||||
// the model such that its description can be omitted.
|
||||
@@ -29,6 +38,9 @@ pub struct ModelFamily {
|
||||
/// Present if the model performs better when `apply_patch` is provided as
|
||||
/// a tool call instead of just a bash command
|
||||
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
|
||||
|
||||
// Instructions to use for querying the model
|
||||
pub base_instructions: String,
|
||||
}
|
||||
|
||||
macro_rules! model_family {
|
||||
@@ -41,8 +53,10 @@ macro_rules! model_family {
|
||||
family: $family.to_string(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries: false,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
uses_local_shell_tool: false,
|
||||
apply_patch_tool_type: None,
|
||||
base_instructions: BASE_INSTRUCTIONS.to_string(),
|
||||
};
|
||||
// apply overrides
|
||||
$(
|
||||
@@ -52,21 +66,6 @@ macro_rules! model_family {
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! simple_model_family {
|
||||
(
|
||||
$slug:expr, $family:expr
|
||||
) => {{
|
||||
Some(ModelFamily {
|
||||
slug: $slug.to_string(),
|
||||
family: $family.to_string(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries: false,
|
||||
uses_local_shell_tool: false,
|
||||
apply_patch_tool_type: None,
|
||||
})
|
||||
}};
|
||||
}
|
||||
|
||||
/// Returns a `ModelFamily` for the given model slug, or `None` if the slug
|
||||
/// does not match any known model family.
|
||||
pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
@@ -74,40 +73,59 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
model_family!(
|
||||
slug, "o3",
|
||||
supports_reasoning_summaries: true,
|
||||
needs_special_apply_patch_instructions: true,
|
||||
)
|
||||
} else if slug.starts_with("o4-mini") {
|
||||
model_family!(
|
||||
slug, "o4-mini",
|
||||
supports_reasoning_summaries: true,
|
||||
needs_special_apply_patch_instructions: true,
|
||||
)
|
||||
} else if slug.starts_with("codex-mini-latest") {
|
||||
model_family!(
|
||||
slug, "codex-mini-latest",
|
||||
supports_reasoning_summaries: true,
|
||||
uses_local_shell_tool: true,
|
||||
)
|
||||
} else if slug.starts_with("codex-") {
|
||||
model_family!(
|
||||
slug, slug,
|
||||
supports_reasoning_summaries: true,
|
||||
needs_special_apply_patch_instructions: true,
|
||||
)
|
||||
} else if slug.starts_with("gpt-4.1") {
|
||||
model_family!(
|
||||
slug, "gpt-4.1",
|
||||
needs_special_apply_patch_instructions: true,
|
||||
)
|
||||
} else if slug.starts_with("gpt-oss") {
|
||||
} else if slug.starts_with("gpt-oss") || slug.starts_with("openai/gpt-oss") {
|
||||
model_family!(slug, "gpt-oss", apply_patch_tool_type: Some(ApplyPatchToolType::Function))
|
||||
} else if slug.starts_with("gpt-4o") {
|
||||
simple_model_family!(slug, "gpt-4o")
|
||||
model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
|
||||
} else if slug.starts_with("gpt-3.5") {
|
||||
simple_model_family!(slug, "gpt-3.5")
|
||||
model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
|
||||
} else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") {
|
||||
model_family!(
|
||||
slug, slug,
|
||||
supports_reasoning_summaries: true,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
|
||||
)
|
||||
} else if slug.starts_with("gpt-5") {
|
||||
model_family!(
|
||||
slug, "gpt-5",
|
||||
supports_reasoning_summaries: true,
|
||||
needs_special_apply_patch_instructions: true,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn derive_default_model_family(model: &str) -> ModelFamily {
|
||||
ModelFamily {
|
||||
slug: model.to_string(),
|
||||
family: model.to_string(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries: false,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
uses_local_shell_tool: false,
|
||||
apply_patch_tool_type: None,
|
||||
base_instructions: BASE_INSTRUCTIONS.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
|
||||
//! key. These override or extend the defaults at runtime.
|
||||
|
||||
use codex_login::AuthMode;
|
||||
use codex_login::CodexAuth;
|
||||
use crate::CodexAuth;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
@@ -80,7 +80,10 @@ pub struct ModelProviderInfo {
|
||||
/// the connection as lost.
|
||||
pub stream_idle_timeout_ms: Option<u64>,
|
||||
|
||||
/// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
|
||||
/// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
|
||||
/// user is presented with login screen on first run, and login preference and token/key
|
||||
/// are stored in auth.json. If false (which is the default), login screen is skipped,
|
||||
/// and API key (if needed) comes from the "env_key" environment variable.
|
||||
#[serde(default)]
|
||||
pub requires_openai_auth: bool,
|
||||
}
|
||||
@@ -159,6 +162,21 @@ impl ModelProviderInfo {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
|
||||
if self.wire_api != WireApi::Responses {
|
||||
return false;
|
||||
}
|
||||
|
||||
if self.name.eq_ignore_ascii_case("azure") {
|
||||
return true;
|
||||
}
|
||||
|
||||
self.base_url
|
||||
.as_ref()
|
||||
.map(|base| matches_azure_responses_base_url(base))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Apply provider-specific HTTP headers (both static and environment-based)
|
||||
/// onto an existing `reqwest::RequestBuilder` and return the updated
|
||||
/// builder.
|
||||
@@ -326,6 +344,18 @@ pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
|
||||
}
|
||||
}
|
||||
|
||||
fn matches_azure_responses_base_url(base_url: &str) -> bool {
|
||||
let base = base_url.to_ascii_lowercase();
|
||||
const AZURE_MARKERS: [&str; 5] = [
|
||||
"openai.azure.",
|
||||
"cognitiveservices.azure.",
|
||||
"aoai.azure.",
|
||||
"azure-api.",
|
||||
"azurefd.",
|
||||
];
|
||||
AZURE_MARKERS.iter().any(|marker| base.contains(marker))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -416,4 +446,69 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
assert_eq!(expected_provider, provider);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_azure_responses_base_urls() {
|
||||
fn provider_for(base_url: &str) -> ModelProviderInfo {
|
||||
ModelProviderInfo {
|
||||
name: "test".into(),
|
||||
base_url: Some(base_url.into()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
}
|
||||
}
|
||||
|
||||
let positive_cases = [
|
||||
"https://foo.openai.azure.com/openai",
|
||||
"https://foo.openai.azure.us/openai/deployments/bar",
|
||||
"https://foo.cognitiveservices.azure.cn/openai",
|
||||
"https://foo.aoai.azure.com/openai",
|
||||
"https://foo.openai.azure-api.net/openai",
|
||||
"https://foo.z01.azurefd.net/",
|
||||
];
|
||||
for base_url in positive_cases {
|
||||
let provider = provider_for(base_url);
|
||||
assert!(
|
||||
provider.is_azure_responses_endpoint(),
|
||||
"expected {base_url} to be detected as Azure"
|
||||
);
|
||||
}
|
||||
|
||||
let named_provider = ModelProviderInfo {
|
||||
name: "Azure".into(),
|
||||
base_url: Some("https://example.com".into()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
assert!(named_provider.is_azure_responses_endpoint());
|
||||
|
||||
let negative_cases = [
|
||||
"https://api.openai.com/v1",
|
||||
"https://example.com/openai",
|
||||
"https://myproxy.azurewebsites.net/openai",
|
||||
];
|
||||
for base_url in negative_cases {
|
||||
let provider = provider_for(base_url);
|
||||
assert!(
|
||||
!provider.is_azure_responses_endpoint(),
|
||||
"expected {base_url} not to be detected as Azure"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,19 @@ pub(crate) struct ModelInfo {
|
||||
|
||||
/// Maximum number of output tokens that can be generated for the model.
|
||||
pub(crate) max_output_tokens: u64,
|
||||
|
||||
/// Token threshold where we should automatically compact conversation history.
|
||||
pub(crate) auto_compact_token_limit: Option<i64>,
|
||||
}
|
||||
|
||||
impl ModelInfo {
|
||||
const fn new(context_window: u64, max_output_tokens: u64) -> Self {
|
||||
Self {
|
||||
context_window,
|
||||
max_output_tokens,
|
||||
auto_compact_token_limit: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
|
||||
@@ -20,73 +33,37 @@ pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
|
||||
// OSS models have a 128k shared token pool.
|
||||
// Arbitrarily splitting it: 3/4 input context, 1/4 output.
|
||||
// https://openai.com/index/gpt-oss-model-card/
|
||||
"gpt-oss-20b" => Some(ModelInfo {
|
||||
context_window: 96_000,
|
||||
max_output_tokens: 32_000,
|
||||
}),
|
||||
"gpt-oss-120b" => Some(ModelInfo {
|
||||
context_window: 96_000,
|
||||
max_output_tokens: 32_000,
|
||||
}),
|
||||
"gpt-oss-20b" => Some(ModelInfo::new(96_000, 32_000)),
|
||||
"gpt-oss-120b" => Some(ModelInfo::new(96_000, 32_000)),
|
||||
// https://platform.openai.com/docs/models/o3
|
||||
"o3" => Some(ModelInfo {
|
||||
context_window: 200_000,
|
||||
max_output_tokens: 100_000,
|
||||
}),
|
||||
"o3" => Some(ModelInfo::new(200_000, 100_000)),
|
||||
|
||||
// https://platform.openai.com/docs/models/o4-mini
|
||||
"o4-mini" => Some(ModelInfo {
|
||||
context_window: 200_000,
|
||||
max_output_tokens: 100_000,
|
||||
}),
|
||||
"o4-mini" => Some(ModelInfo::new(200_000, 100_000)),
|
||||
|
||||
// https://platform.openai.com/docs/models/codex-mini-latest
|
||||
"codex-mini-latest" => Some(ModelInfo {
|
||||
context_window: 200_000,
|
||||
max_output_tokens: 100_000,
|
||||
}),
|
||||
"codex-mini-latest" => Some(ModelInfo::new(200_000, 100_000)),
|
||||
|
||||
// As of Jun 25, 2025, gpt-4.1 defaults to gpt-4.1-2025-04-14.
|
||||
// https://platform.openai.com/docs/models/gpt-4.1
|
||||
"gpt-4.1" | "gpt-4.1-2025-04-14" => Some(ModelInfo {
|
||||
context_window: 1_047_576,
|
||||
max_output_tokens: 32_768,
|
||||
}),
|
||||
"gpt-4.1" | "gpt-4.1-2025-04-14" => Some(ModelInfo::new(1_047_576, 32_768)),
|
||||
|
||||
// As of Jun 25, 2025, gpt-4o defaults to gpt-4o-2024-08-06.
|
||||
// https://platform.openai.com/docs/models/gpt-4o
|
||||
"gpt-4o" | "gpt-4o-2024-08-06" => Some(ModelInfo {
|
||||
context_window: 128_000,
|
||||
max_output_tokens: 16_384,
|
||||
}),
|
||||
"gpt-4o" | "gpt-4o-2024-08-06" => Some(ModelInfo::new(128_000, 16_384)),
|
||||
|
||||
// https://platform.openai.com/docs/models/gpt-4o?snapshot=gpt-4o-2024-05-13
|
||||
"gpt-4o-2024-05-13" => Some(ModelInfo {
|
||||
context_window: 128_000,
|
||||
max_output_tokens: 4_096,
|
||||
}),
|
||||
"gpt-4o-2024-05-13" => Some(ModelInfo::new(128_000, 4_096)),
|
||||
|
||||
// https://platform.openai.com/docs/models/gpt-4o?snapshot=gpt-4o-2024-11-20
|
||||
"gpt-4o-2024-11-20" => Some(ModelInfo {
|
||||
context_window: 128_000,
|
||||
max_output_tokens: 16_384,
|
||||
}),
|
||||
"gpt-4o-2024-11-20" => Some(ModelInfo::new(128_000, 16_384)),
|
||||
|
||||
// https://platform.openai.com/docs/models/gpt-3.5-turbo
|
||||
"gpt-3.5-turbo" => Some(ModelInfo {
|
||||
context_window: 16_385,
|
||||
max_output_tokens: 4_096,
|
||||
}),
|
||||
"gpt-3.5-turbo" => Some(ModelInfo::new(16_385, 4_096)),
|
||||
|
||||
"gpt-5" => Some(ModelInfo {
|
||||
context_window: 400_000,
|
||||
max_output_tokens: 128_000,
|
||||
}),
|
||||
_ if slug.starts_with("gpt-5") => Some(ModelInfo::new(272_000, 128_000)),
|
||||
|
||||
_ if slug.starts_with("codex-") => Some(ModelInfo {
|
||||
context_window: 400_000,
|
||||
max_output_tokens: 128_000,
|
||||
}),
|
||||
_ if slug.starts_with("codex-") => Some(ModelInfo::new(272_000, 128_000)),
|
||||
|
||||
_ => None,
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@ use std::collections::HashMap;
|
||||
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::plan_tool::PLAN_TOOL;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::tool_apply_patch::ApplyPatchToolType;
|
||||
use crate::tool_apply_patch::create_apply_patch_freeform_tool;
|
||||
use crate::tool_apply_patch::create_apply_patch_json_tool;
|
||||
@@ -49,7 +47,7 @@ pub(crate) enum OpenAiTool {
|
||||
LocalShell {},
|
||||
// TODO: Understand why we get an error on web_search although the API docs say it's supported.
|
||||
// https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C
|
||||
#[serde(rename = "web_search_preview")]
|
||||
#[serde(rename = "web_search")]
|
||||
WebSearch {},
|
||||
#[serde(rename = "custom")]
|
||||
Freeform(FreeformTool),
|
||||
@@ -57,10 +55,9 @@ pub(crate) enum OpenAiTool {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ConfigShellToolType {
|
||||
DefaultShell,
|
||||
ShellWithRequest { sandbox_policy: SandboxPolicy },
|
||||
LocalShell,
|
||||
StreamableShell,
|
||||
Default,
|
||||
Local,
|
||||
Streamable,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -70,43 +67,37 @@ pub(crate) struct ToolsConfig {
|
||||
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
|
||||
pub web_search_request: bool,
|
||||
pub include_view_image_tool: bool,
|
||||
pub experimental_unified_exec_tool: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct ToolsConfigParams<'a> {
|
||||
pub(crate) model_family: &'a ModelFamily,
|
||||
pub(crate) approval_policy: AskForApproval,
|
||||
pub(crate) sandbox_policy: SandboxPolicy,
|
||||
pub(crate) include_plan_tool: bool,
|
||||
pub(crate) include_apply_patch_tool: bool,
|
||||
pub(crate) include_web_search_request: bool,
|
||||
pub(crate) use_streamable_shell_tool: bool,
|
||||
pub(crate) include_view_image_tool: bool,
|
||||
pub(crate) experimental_unified_exec_tool: bool,
|
||||
}
|
||||
|
||||
impl ToolsConfig {
|
||||
pub fn new(params: &ToolsConfigParams) -> Self {
|
||||
let ToolsConfigParams {
|
||||
model_family,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
include_plan_tool,
|
||||
include_apply_patch_tool,
|
||||
include_web_search_request,
|
||||
use_streamable_shell_tool,
|
||||
include_view_image_tool,
|
||||
experimental_unified_exec_tool,
|
||||
} = params;
|
||||
let mut shell_type = if *use_streamable_shell_tool {
|
||||
ConfigShellToolType::StreamableShell
|
||||
let shell_type = if *use_streamable_shell_tool {
|
||||
ConfigShellToolType::Streamable
|
||||
} else if model_family.uses_local_shell_tool {
|
||||
ConfigShellToolType::LocalShell
|
||||
ConfigShellToolType::Local
|
||||
} else {
|
||||
ConfigShellToolType::DefaultShell
|
||||
ConfigShellToolType::Default
|
||||
};
|
||||
if matches!(approval_policy, AskForApproval::OnRequest) && !use_streamable_shell_tool {
|
||||
shell_type = ConfigShellToolType::ShellWithRequest {
|
||||
sandbox_policy: sandbox_policy.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
let apply_patch_tool_type = match model_family.apply_patch_tool_type {
|
||||
Some(ApplyPatchToolType::Freeform) => Some(ApplyPatchToolType::Freeform),
|
||||
@@ -126,6 +117,7 @@ impl ToolsConfig {
|
||||
apply_patch_tool_type,
|
||||
web_search_request: *include_web_search_request,
|
||||
include_view_image_tool: *include_view_image_tool,
|
||||
experimental_unified_exec_tool: *experimental_unified_exec_tool,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -166,6 +158,53 @@ pub(crate) enum JsonSchema {
|
||||
},
|
||||
}
|
||||
|
||||
fn create_unified_exec_tool() -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"input".to_string(),
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String { description: None }),
|
||||
description: Some(
|
||||
"When no session_id is provided, treat the array as the command and arguments \
|
||||
to launch. When session_id is set, concatenate the strings (in order) and write \
|
||||
them to the session's stdin."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"session_id".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Identifier for an existing interactive session. If omitted, a new command \
|
||||
is spawned."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"timeout_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Maximum time in milliseconds to wait for output after writing the input."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "unified_exec".to_string(),
|
||||
description:
|
||||
"Runs a command in a PTY. Provide a session_id to reuse an existing interactive session.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["input".to_string()]),
|
||||
additional_properties: Some(false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_shell_tool() -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
@@ -188,108 +227,22 @@ fn create_shell_tool() -> OpenAiTool {
|
||||
},
|
||||
);
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "shell".to_string(),
|
||||
description: "Runs a shell command and returns its output".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["command".to_string()]),
|
||||
additional_properties: Some(false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"command".to_string(),
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String { description: None }),
|
||||
description: Some("The command to execute".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"workdir".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("The working directory to execute the command in".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"timeout_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("The timeout for the command in milliseconds".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
if matches!(sandbox_policy, SandboxPolicy::WorkspaceWrite { .. }) {
|
||||
properties.insert(
|
||||
"with_escalated_permissions".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some("Whether to request escalated permissions. Set to true if command needs to be run without sandbox restrictions".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
properties.insert(
|
||||
"justification".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Only set if with_escalated_permissions is true. 1-sentence explanation of why we want to run this command.".to_string()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let description = match sandbox_policy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
network_access,
|
||||
..
|
||||
} => {
|
||||
format!(
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a landlock sandbox, and some shell commands will require escalated privileges:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Reading files outside the current directory
|
||||
- Writing files outside the current directory, and protected folders like .git or .env{}
|
||||
- Examples of commands that require escalated privileges:
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#,
|
||||
if !network_access {
|
||||
"\n - Commands that require network access\n"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)
|
||||
}
|
||||
SandboxPolicy::DangerFullAccess => {
|
||||
"Runs a shell command and returns its output.".to_string()
|
||||
}
|
||||
SandboxPolicy::ReadOnly => {
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a landlock sandbox, and some shell commands (including apply_patch) will require escalated permissions:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Reading files outside the current directory
|
||||
- Writing files
|
||||
- Applying patches
|
||||
- Examples of commands that require escalated privileges:
|
||||
- apply_patch
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#.to_string()
|
||||
}
|
||||
};
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "shell".to_string(),
|
||||
description,
|
||||
description: "Runs a shell command and returns its output.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
@@ -332,7 +285,7 @@ pub(crate) struct ApplyPatchToolArgs {
|
||||
/// Responses API:
|
||||
/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
|
||||
pub fn create_tools_json_for_responses_api(
|
||||
tools: &Vec<OpenAiTool>,
|
||||
tools: &[OpenAiTool],
|
||||
) -> crate::error::Result<Vec<serde_json::Value>> {
|
||||
let mut tools_json = Vec::new();
|
||||
|
||||
@@ -347,7 +300,7 @@ pub fn create_tools_json_for_responses_api(
|
||||
/// Chat Completions API:
|
||||
/// https://platform.openai.com/docs/guides/function-calling?api-mode=chat
|
||||
pub(crate) fn create_tools_json_for_chat_completions_api(
|
||||
tools: &Vec<OpenAiTool>,
|
||||
tools: &[OpenAiTool],
|
||||
) -> crate::error::Result<Vec<serde_json::Value>> {
|
||||
// We start with the JSON for the Responses API and than rewrite it to match
|
||||
// the chat completions tool call format.
|
||||
@@ -447,10 +400,7 @@ fn sanitize_json_schema(value: &mut JsonValue) {
|
||||
}
|
||||
|
||||
// Normalize/ensure type
|
||||
let mut ty = map
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let mut ty = map.get("type").and_then(|v| v.as_str()).map(str::to_string);
|
||||
|
||||
// If type is an array (union), pick first supported; else leave to inference
|
||||
if ty.is_none()
|
||||
@@ -532,23 +482,24 @@ pub(crate) fn get_openai_tools(
|
||||
) -> Vec<OpenAiTool> {
|
||||
let mut tools: Vec<OpenAiTool> = Vec::new();
|
||||
|
||||
match &config.shell_type {
|
||||
ConfigShellToolType::DefaultShell => {
|
||||
tools.push(create_shell_tool());
|
||||
}
|
||||
ConfigShellToolType::ShellWithRequest { sandbox_policy } => {
|
||||
tools.push(create_shell_tool_for_sandbox(sandbox_policy));
|
||||
}
|
||||
ConfigShellToolType::LocalShell => {
|
||||
tools.push(OpenAiTool::LocalShell {});
|
||||
}
|
||||
ConfigShellToolType::StreamableShell => {
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_exec_command_tool_for_responses_api(),
|
||||
));
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_write_stdin_tool_for_responses_api(),
|
||||
));
|
||||
if config.experimental_unified_exec_tool {
|
||||
tools.push(create_unified_exec_tool());
|
||||
} else {
|
||||
match &config.shell_type {
|
||||
ConfigShellToolType::Default => {
|
||||
tools.push(create_shell_tool());
|
||||
}
|
||||
ConfigShellToolType::Local => {
|
||||
tools.push(OpenAiTool::LocalShell {});
|
||||
}
|
||||
ConfigShellToolType::Streamable => {
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_exec_command_tool_for_responses_api(),
|
||||
));
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_write_stdin_tool_for_responses_api(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -575,10 +526,8 @@ pub(crate) fn get_openai_tools(
|
||||
if config.include_view_image_tool {
|
||||
tools.push(create_view_image_tool());
|
||||
}
|
||||
|
||||
if let Some(mcp_tools) = mcp_tools {
|
||||
// Ensure deterministic ordering to maximize prompt cache hits.
|
||||
// HashMap iteration order is non-deterministic, so sort by fully-qualified tool name.
|
||||
let mut entries: Vec<(String, mcp_types::Tool)> = mcp_tools.into_iter().collect();
|
||||
entries.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
@@ -633,19 +582,18 @@ mod tests {
|
||||
.expect("codex-mini-latest should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: true,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["local_shell", "update_plan", "web_search", "view_image"],
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
);
|
||||
}
|
||||
|
||||
@@ -654,19 +602,18 @@ mod tests {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: true,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["shell", "update_plan", "web_search", "view_image"],
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
);
|
||||
}
|
||||
|
||||
@@ -675,13 +622,12 @@ mod tests {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let tools = get_openai_tools(
|
||||
&config,
|
||||
@@ -724,7 +670,7 @@ mod tests {
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&[
|
||||
"shell",
|
||||
"unified_exec",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"test_server/do_something_cool",
|
||||
@@ -780,13 +726,12 @@ mod tests {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
// Intentionally construct a map with keys that would sort alphabetically.
|
||||
@@ -839,11 +784,11 @@ mod tests {
|
||||
]);
|
||||
|
||||
let tools = get_openai_tools(&config, Some(tools_map));
|
||||
// Expect shell first, followed by MCP tools sorted by fully-qualified name.
|
||||
// Expect unified_exec first, followed by MCP tools sorted by fully-qualified name.
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&[
|
||||
"shell",
|
||||
"unified_exec",
|
||||
"view_image",
|
||||
"test_server/cool",
|
||||
"test_server/do",
|
||||
@@ -857,13 +802,12 @@ mod tests {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -891,7 +835,7 @@ mod tests {
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["shell", "web_search", "view_image", "dash/search"],
|
||||
&["unified_exec", "web_search", "view_image", "dash/search"],
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
@@ -919,13 +863,12 @@ mod tests {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -951,7 +894,7 @@ mod tests {
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["shell", "web_search", "view_image", "dash/paginate"],
|
||||
&["unified_exec", "web_search", "view_image", "dash/paginate"],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
@@ -976,13 +919,12 @@ mod tests {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -1006,7 +948,10 @@ mod tests {
|
||||
)])),
|
||||
);
|
||||
|
||||
assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/tags"]);
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["unified_exec", "web_search", "view_image", "dash/tags"],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
@@ -1033,13 +978,12 @@ mod tests {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -1063,7 +1007,10 @@ mod tests {
|
||||
)])),
|
||||
);
|
||||
|
||||
assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/value"]);
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["unified_exec", "web_search", "view_image", "dash/value"],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
@@ -1081,4 +1028,19 @@ mod tests {
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_tool() {
|
||||
let tool = super::create_shell_tool();
|
||||
let OpenAiTool::Function(ResponsesApiTool {
|
||||
description, name, ..
|
||||
}) = &tool
|
||||
else {
|
||||
panic!("expected function tool");
|
||||
};
|
||||
assert_eq!(name, "shell");
|
||||
|
||||
let expected = "Runs a shell command and returns its output.";
|
||||
assert_eq!(description, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,22 +20,6 @@ pub enum ParsedCommand {
|
||||
query: Option<String>,
|
||||
path: Option<String>,
|
||||
},
|
||||
Format {
|
||||
cmd: String,
|
||||
tool: Option<String>,
|
||||
targets: Option<Vec<String>>,
|
||||
},
|
||||
Test {
|
||||
cmd: String,
|
||||
},
|
||||
Lint {
|
||||
cmd: String,
|
||||
tool: Option<String>,
|
||||
targets: Option<Vec<String>>,
|
||||
},
|
||||
Noop {
|
||||
cmd: String,
|
||||
},
|
||||
Unknown {
|
||||
cmd: String,
|
||||
},
|
||||
@@ -50,17 +34,13 @@ impl From<ParsedCommand> for codex_protocol::parse_command::ParsedCommand {
|
||||
ParsedCommand::Read { cmd, name } => P::Read { cmd, name },
|
||||
ParsedCommand::ListFiles { cmd, path } => P::ListFiles { cmd, path },
|
||||
ParsedCommand::Search { cmd, query, path } => P::Search { cmd, query, path },
|
||||
ParsedCommand::Format { cmd, tool, targets } => P::Format { cmd, tool, targets },
|
||||
ParsedCommand::Test { cmd } => P::Test { cmd },
|
||||
ParsedCommand::Lint { cmd, tool, targets } => P::Lint { cmd, tool, targets },
|
||||
ParsedCommand::Noop { cmd } => P::Noop { cmd },
|
||||
ParsedCommand::Unknown { cmd } => P::Unknown { cmd },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn shlex_join(tokens: &[String]) -> String {
|
||||
shlex_try_join(tokens.iter().map(|s| s.as_str()))
|
||||
shlex_try_join(tokens.iter().map(String::as_str))
|
||||
.unwrap_or_else(|_| "<command included NUL byte>".to_string())
|
||||
}
|
||||
|
||||
@@ -92,13 +72,14 @@ pub fn parse_command(command: &[String]) -> Vec<ParsedCommand> {
|
||||
/// Tests are at the top to encourage using TDD + Codex to fix the implementation.
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::string::ToString;
|
||||
|
||||
fn shlex_split_safe(s: &str) -> Vec<String> {
|
||||
shlex_split(s).unwrap_or_else(|| s.split_whitespace().map(|s| s.to_string()).collect())
|
||||
shlex_split(s).unwrap_or_else(|| s.split_whitespace().map(ToString::to_string).collect())
|
||||
}
|
||||
|
||||
fn vec_str(args: &[&str]) -> Vec<String> {
|
||||
args.iter().map(|s| s.to_string()).collect()
|
||||
args.iter().map(ToString::to_string).collect()
|
||||
}
|
||||
|
||||
fn assert_parsed(args: &[String], expected: Vec<ParsedCommand>) {
|
||||
@@ -122,7 +103,7 @@ mod tests {
|
||||
assert_parsed(
|
||||
&vec_str(&["bash", "-lc", inner]),
|
||||
vec![ParsedCommand::Unknown {
|
||||
cmd: "git status | wc -l".to_string(),
|
||||
cmd: "git status".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -244,6 +225,39 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cd_then_cat_is_single_read() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("cd foo && cat foo.txt"),
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "cat foo.txt".to_string(),
|
||||
name: "foo.txt".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_cd_then_bar_is_same_as_bar() {
|
||||
// Ensure a leading `cd` inside bash -lc is dropped when followed by another command.
|
||||
assert_parsed(
|
||||
&shlex_split_safe("bash -lc 'cd foo && bar'"),
|
||||
vec![ParsedCommand::Unknown {
|
||||
cmd: "bar".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_cd_then_cat_is_read() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("bash -lc 'cd foo && cat foo.txt'"),
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "cat foo.txt".to_string(),
|
||||
name: "foo.txt".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_ls_with_pipe() {
|
||||
let inner = "ls -la | sed -n '1,120p'";
|
||||
@@ -315,27 +329,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_npm_run_with_forwarded_args() {
|
||||
assert_parsed(
|
||||
&vec_str(&[
|
||||
"npm",
|
||||
"run",
|
||||
"lint",
|
||||
"--",
|
||||
"--max-warnings",
|
||||
"0",
|
||||
"--format",
|
||||
"json",
|
||||
]),
|
||||
vec![ParsedCommand::Lint {
|
||||
cmd: "npm run lint -- --max-warnings 0 --format json".to_string(),
|
||||
tool: Some("npm-script:lint".to_string()),
|
||||
targets: None,
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_grep_recursive_current_dir() {
|
||||
assert_parsed(
|
||||
@@ -396,173 +389,10 @@ mod tests {
|
||||
fn supports_cd_and_rg_files() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("cd codex-rs && rg --files"),
|
||||
vec![
|
||||
ParsedCommand::Unknown {
|
||||
cmd: "cd codex-rs".to_string(),
|
||||
},
|
||||
ParsedCommand::Search {
|
||||
cmd: "rg --files".to_string(),
|
||||
query: None,
|
||||
path: None,
|
||||
},
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn echo_then_cargo_test_sequence() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("echo Running tests... && cargo test --all-features --quiet"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "cargo test --all-features --quiet".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_cargo_fmt_and_test_with_config() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe(
|
||||
"cargo fmt -- --config imports_granularity=Item && cargo test -p core --all-features",
|
||||
),
|
||||
vec![
|
||||
ParsedCommand::Format {
|
||||
cmd: "cargo fmt -- --config 'imports_granularity=Item'".to_string(),
|
||||
tool: Some("cargo fmt".to_string()),
|
||||
targets: None,
|
||||
},
|
||||
ParsedCommand::Test {
|
||||
cmd: "cargo test -p core --all-features".to_string(),
|
||||
},
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recognizes_rustfmt_and_clippy() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("rustfmt src/main.rs"),
|
||||
vec![ParsedCommand::Format {
|
||||
cmd: "rustfmt src/main.rs".to_string(),
|
||||
tool: Some("rustfmt".to_string()),
|
||||
targets: Some(vec!["src/main.rs".to_string()]),
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("cargo clippy -p core --all-features -- -D warnings"),
|
||||
vec![ParsedCommand::Lint {
|
||||
cmd: "cargo clippy -p core --all-features -- -D warnings".to_string(),
|
||||
tool: Some("cargo clippy".to_string()),
|
||||
targets: None,
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recognizes_pytest_go_and_tools() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe(
|
||||
"pytest -k 'Login and not slow' tests/test_login.py::TestLogin::test_ok",
|
||||
),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "pytest -k 'Login and not slow' tests/test_login.py::TestLogin::test_ok"
|
||||
.to_string(),
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("go fmt ./..."),
|
||||
vec![ParsedCommand::Format {
|
||||
cmd: "go fmt ./...".to_string(),
|
||||
tool: Some("go fmt".to_string()),
|
||||
targets: Some(vec!["./...".to_string()]),
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("go test ./pkg -run TestThing"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "go test ./pkg -run TestThing".to_string(),
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("eslint . --max-warnings 0"),
|
||||
vec![ParsedCommand::Lint {
|
||||
cmd: "eslint . --max-warnings 0".to_string(),
|
||||
tool: Some("eslint".to_string()),
|
||||
targets: Some(vec![".".to_string()]),
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("prettier -w ."),
|
||||
vec![ParsedCommand::Format {
|
||||
cmd: "prettier -w .".to_string(),
|
||||
tool: Some("prettier".to_string()),
|
||||
targets: Some(vec![".".to_string()]),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recognizes_jest_and_vitest_filters() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("jest -t 'should work' src/foo.test.ts"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "jest -t 'should work' src/foo.test.ts".to_string(),
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("vitest -t 'runs' src/foo.test.tsx"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "vitest -t runs src/foo.test.tsx".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recognizes_npx_and_scripts() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("npx eslint src"),
|
||||
vec![ParsedCommand::Lint {
|
||||
cmd: "npx eslint src".to_string(),
|
||||
tool: Some("eslint".to_string()),
|
||||
targets: Some(vec!["src".to_string()]),
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("npx prettier -c ."),
|
||||
vec![ParsedCommand::Format {
|
||||
cmd: "npx prettier -c .".to_string(),
|
||||
tool: Some("prettier".to_string()),
|
||||
targets: Some(vec![".".to_string()]),
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("pnpm run lint -- --max-warnings 0"),
|
||||
vec![ParsedCommand::Lint {
|
||||
cmd: "pnpm run lint -- --max-warnings 0".to_string(),
|
||||
tool: Some("pnpm-script:lint".to_string()),
|
||||
targets: None,
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("npm test"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "npm test".to_string(),
|
||||
}],
|
||||
);
|
||||
|
||||
assert_parsed(
|
||||
&shlex_split_safe("yarn test"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "yarn test".to_string(),
|
||||
vec![ParsedCommand::Search {
|
||||
cmd: "rg --files".to_string(),
|
||||
query: None,
|
||||
path: None,
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -770,6 +600,51 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_mixed_sequence_with_pipes_semicolons_and_or() {
|
||||
// Provided long command sequence combining sequencing, pipelines, and ORs.
|
||||
let inner = "pwd; ls -la; rg --files -g '!target' | wc -l; rg -n '^\\[workspace\\]' -n Cargo.toml || true; rg -n '^\\[package\\]' -n */Cargo.toml || true; cargo --version; rustc --version; cargo clippy --workspace --all-targets --all-features -q";
|
||||
let args = vec_str(&["bash", "-lc", inner]);
|
||||
|
||||
let expected = vec![
|
||||
ParsedCommand::Unknown {
|
||||
cmd: "pwd".to_string(),
|
||||
},
|
||||
ParsedCommand::ListFiles {
|
||||
cmd: shlex_join(&shlex_split_safe("ls -la")),
|
||||
path: None,
|
||||
},
|
||||
ParsedCommand::Search {
|
||||
cmd: shlex_join(&shlex_split_safe("rg --files -g '!target'")),
|
||||
query: None,
|
||||
path: Some("!target".to_string()),
|
||||
},
|
||||
ParsedCommand::Search {
|
||||
cmd: shlex_join(&shlex_split_safe("rg -n '^\\[workspace\\]' -n Cargo.toml")),
|
||||
query: Some("^\\[workspace\\]".to_string()),
|
||||
path: Some("Cargo.toml".to_string()),
|
||||
},
|
||||
ParsedCommand::Search {
|
||||
cmd: shlex_join(&shlex_split_safe("rg -n '^\\[package\\]' -n */Cargo.toml")),
|
||||
query: Some("^\\[package\\]".to_string()),
|
||||
path: Some("Cargo.toml".to_string()),
|
||||
},
|
||||
ParsedCommand::Unknown {
|
||||
cmd: shlex_join(&shlex_split_safe("cargo --version")),
|
||||
},
|
||||
ParsedCommand::Unknown {
|
||||
cmd: shlex_join(&shlex_split_safe("rustc --version")),
|
||||
},
|
||||
ParsedCommand::Unknown {
|
||||
cmd: shlex_join(&shlex_split_safe(
|
||||
"cargo clippy --workspace --all-targets --all-features -q",
|
||||
)),
|
||||
},
|
||||
];
|
||||
|
||||
assert_parsed(&args, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strips_true_in_sequence() {
|
||||
// `true` should be dropped from parsed sequences
|
||||
@@ -867,159 +742,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pnpm_test_is_parsed_as_test() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("pnpm test"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "pnpm test".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pnpm_exec_vitest_is_unknown() {
|
||||
// From commands_combined: cd codex-cli && pnpm exec vitest run tests/... --threads=false --passWithNoTests
|
||||
let inner = "cd codex-cli && pnpm exec vitest run tests/file-tag-utils.test.ts --threads=false --passWithNoTests";
|
||||
assert_parsed(
|
||||
&shlex_split_safe(inner),
|
||||
vec![
|
||||
ParsedCommand::Unknown {
|
||||
cmd: "cd codex-cli".to_string(),
|
||||
},
|
||||
ParsedCommand::Unknown {
|
||||
cmd: "pnpm exec vitest run tests/file-tag-utils.test.ts '--threads=false' --passWithNoTests".to_string(),
|
||||
},
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cargo_test_with_crate() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("cargo test -p codex-core parse_command::"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "cargo test -p codex-core parse_command::".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cargo_test_with_crate_2() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe(
|
||||
"cd core && cargo test -q parse_command::tests::bash_dash_c_pipeline_parsing parse_command::tests::fd_file_finder_variants",
|
||||
),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "cargo test -q parse_command::tests::bash_dash_c_pipeline_parsing parse_command::tests::fd_file_finder_variants".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cargo_test_with_crate_3() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("cd core && cargo test -q parse_command::tests"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "cargo test -q parse_command::tests".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cargo_test_with_crate_4() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("cd core && cargo test --all-features parse_command -- --nocapture"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "cargo test --all-features parse_command -- --nocapture".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
// Additional coverage for other common tools/frameworks
|
||||
#[test]
|
||||
fn recognizes_black_and_ruff() {
|
||||
// black formats Python code
|
||||
assert_parsed(
|
||||
&shlex_split_safe("black src"),
|
||||
vec![ParsedCommand::Format {
|
||||
cmd: "black src".to_string(),
|
||||
tool: Some("black".to_string()),
|
||||
targets: Some(vec!["src".to_string()]),
|
||||
}],
|
||||
);
|
||||
|
||||
// ruff check is a linter; ensure we collect targets
|
||||
assert_parsed(
|
||||
&shlex_split_safe("ruff check ."),
|
||||
vec![ParsedCommand::Lint {
|
||||
cmd: "ruff check .".to_string(),
|
||||
tool: Some("ruff".to_string()),
|
||||
targets: Some(vec![".".to_string()]),
|
||||
}],
|
||||
);
|
||||
|
||||
// ruff format is a formatter
|
||||
assert_parsed(
|
||||
&shlex_split_safe("ruff format pkg/"),
|
||||
vec![ParsedCommand::Format {
|
||||
cmd: "ruff format pkg/".to_string(),
|
||||
tool: Some("ruff".to_string()),
|
||||
targets: Some(vec!["pkg/".to_string()]),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recognizes_pnpm_monorepo_test_and_npm_format_script() {
|
||||
// pnpm -r test in a monorepo should still parse as a test action
|
||||
assert_parsed(
|
||||
&shlex_split_safe("pnpm -r test"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "pnpm -r test".to_string(),
|
||||
}],
|
||||
);
|
||||
|
||||
// npm run format should be recognized as a format action
|
||||
assert_parsed(
|
||||
&shlex_split_safe("npm run format -- -w ."),
|
||||
vec![ParsedCommand::Format {
|
||||
cmd: "npm run format -- -w .".to_string(),
|
||||
tool: Some("npm-script:format".to_string()),
|
||||
targets: None,
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn yarn_test_is_parsed_as_test() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("yarn test"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "yarn test".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pytest_file_only_and_go_run_regex() {
|
||||
// pytest invoked with a file path should be captured as a filter
|
||||
assert_parsed(
|
||||
&shlex_split_safe("pytest tests/test_example.py"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "pytest tests/test_example.py".to_string(),
|
||||
}],
|
||||
);
|
||||
|
||||
// go test with -run regex should capture the filter
|
||||
assert_parsed(
|
||||
&shlex_split_safe("go test ./... -run '^TestFoo$'"),
|
||||
vec![ParsedCommand::Test {
|
||||
cmd: "go test ./... -run '^TestFoo$'".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grep_with_query_and_path() {
|
||||
assert_parsed(
|
||||
@@ -1090,30 +812,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eslint_with_config_path_and_target() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("eslint -c .eslintrc.json src"),
|
||||
vec![ParsedCommand::Lint {
|
||||
cmd: "eslint -c .eslintrc.json src".to_string(),
|
||||
tool: Some("eslint".to_string()),
|
||||
targets: Some(vec!["src".to_string()]),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npx_eslint_with_config_path_and_target() {
|
||||
assert_parsed(
|
||||
&shlex_split_safe("npx eslint -c .eslintrc src"),
|
||||
vec![ParsedCommand::Lint {
|
||||
cmd: "npx eslint -c .eslintrc src".to_string(),
|
||||
tool: Some("eslint".to_string()),
|
||||
targets: Some(vec!["src".to_string()]),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fd_file_finder_variants() {
|
||||
assert_parsed(
|
||||
@@ -1171,7 +869,7 @@ pub fn parse_command_impl(command: &[String]) -> Vec<ParsedCommand> {
|
||||
let parts = if contains_connectors(&normalized) {
|
||||
split_on_connectors(&normalized)
|
||||
} else {
|
||||
vec![normalized.clone()]
|
||||
vec![normalized]
|
||||
};
|
||||
|
||||
// Preserve left-to-right execution order for all commands, including bash -c/-lc
|
||||
@@ -1197,21 +895,18 @@ fn simplify_once(commands: &[ParsedCommand]) -> Option<Vec<ParsedCommand>> {
|
||||
|
||||
// echo ... && ...rest => ...rest
|
||||
if let ParsedCommand::Unknown { cmd } = &commands[0]
|
||||
&& shlex_split(cmd).is_some_and(|t| t.first().map(|s| s.as_str()) == Some("echo"))
|
||||
&& shlex_split(cmd).is_some_and(|t| t.first().map(String::as_str) == Some("echo"))
|
||||
{
|
||||
return Some(commands[1..].to_vec());
|
||||
}
|
||||
|
||||
// cd foo && [any Test command] => [any Test command]
|
||||
// cd foo && [any command] => [any command] (keep non-cd when a cd is followed by something)
|
||||
if let Some(idx) = commands.iter().position(|pc| match pc {
|
||||
ParsedCommand::Unknown { cmd } => {
|
||||
shlex_split(cmd).is_some_and(|t| t.first().map(|s| s.as_str()) == Some("cd"))
|
||||
shlex_split(cmd).is_some_and(|t| t.first().map(String::as_str) == Some("cd"))
|
||||
}
|
||||
_ => false,
|
||||
}) && commands
|
||||
.iter()
|
||||
.skip(idx + 1)
|
||||
.any(|pc| matches!(pc, ParsedCommand::Test { .. }))
|
||||
}) && commands.len() > idx + 1
|
||||
{
|
||||
let mut out = Vec::with_capacity(commands.len() - 1);
|
||||
out.extend_from_slice(&commands[..idx]);
|
||||
@@ -1220,10 +915,10 @@ fn simplify_once(commands: &[ParsedCommand]) -> Option<Vec<ParsedCommand>> {
|
||||
}
|
||||
|
||||
// cmd || true => cmd
|
||||
if let Some(idx) = commands.iter().position(|pc| match pc {
|
||||
ParsedCommand::Noop { cmd } => cmd == "true",
|
||||
_ => false,
|
||||
}) {
|
||||
if let Some(idx) = commands
|
||||
.iter()
|
||||
.position(|pc| matches!(pc, ParsedCommand::Unknown { cmd } if cmd == "true"))
|
||||
{
|
||||
let mut out = Vec::with_capacity(commands.len() - 1);
|
||||
out.extend_from_slice(&commands[..idx]);
|
||||
out.extend_from_slice(&commands[idx + 1..]);
|
||||
@@ -1341,7 +1036,7 @@ fn short_display_path(path: &str) -> String {
|
||||
});
|
||||
parts
|
||||
.next()
|
||||
.map(|s| s.to_string())
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| trimmed.to_string())
|
||||
}
|
||||
|
||||
@@ -1377,75 +1072,6 @@ fn skip_flag_values<'a>(args: &'a [String], flags_with_vals: &[&str]) -> Vec<&'a
|
||||
out
|
||||
}
|
||||
|
||||
/// Common flags for ESLint that take a following value and should not be
|
||||
/// considered positional targets.
|
||||
const ESLINT_FLAGS_WITH_VALUES: &[&str] = &[
|
||||
"-c",
|
||||
"--config",
|
||||
"--parser",
|
||||
"--parser-options",
|
||||
"--rulesdir",
|
||||
"--plugin",
|
||||
"--max-warnings",
|
||||
"--format",
|
||||
];
|
||||
|
||||
fn collect_non_flag_targets(args: &[String]) -> Option<Vec<String>> {
|
||||
let mut targets = Vec::new();
|
||||
let mut skip_next = false;
|
||||
for (i, a) in args.iter().enumerate() {
|
||||
if a == "--" {
|
||||
break;
|
||||
}
|
||||
if skip_next {
|
||||
skip_next = false;
|
||||
continue;
|
||||
}
|
||||
if a == "-p"
|
||||
|| a == "--package"
|
||||
|| a == "--features"
|
||||
|| a == "-C"
|
||||
|| a == "--config"
|
||||
|| a == "--config-path"
|
||||
|| a == "--out-dir"
|
||||
|| a == "-o"
|
||||
|| a == "--run"
|
||||
|| a == "--max-warnings"
|
||||
|| a == "--format"
|
||||
{
|
||||
if i + 1 < args.len() {
|
||||
skip_next = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if a.starts_with('-') {
|
||||
continue;
|
||||
}
|
||||
targets.push(a.clone());
|
||||
}
|
||||
if targets.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(targets)
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_non_flag_targets_with_flags(
|
||||
args: &[String],
|
||||
flags_with_vals: &[&str],
|
||||
) -> Option<Vec<String>> {
|
||||
let targets: Vec<String> = skip_flag_values(args, flags_with_vals)
|
||||
.into_iter()
|
||||
.filter(|a| !a.starts_with('-'))
|
||||
.cloned()
|
||||
.collect();
|
||||
if targets.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(targets)
|
||||
}
|
||||
}
|
||||
|
||||
fn is_pathish(s: &str) -> bool {
|
||||
s == "."
|
||||
|| s == ".."
|
||||
@@ -1514,47 +1140,6 @@ fn parse_find_query_and_path(tail: &[String]) -> (Option<String>, Option<String>
|
||||
(query, path)
|
||||
}
|
||||
|
||||
fn classify_npm_like(tool: &str, tail: &[String], full_cmd: &[String]) -> Option<ParsedCommand> {
|
||||
let mut r = tail;
|
||||
if tool == "pnpm" && r.first().map(|s| s.as_str()) == Some("-r") {
|
||||
r = &r[1..];
|
||||
}
|
||||
let mut script_name: Option<String> = None;
|
||||
if r.first().map(|s| s.as_str()) == Some("run") {
|
||||
script_name = r.get(1).cloned();
|
||||
} else {
|
||||
let is_test_cmd = (tool == "npm" && r.first().map(|s| s.as_str()) == Some("t"))
|
||||
|| ((tool == "npm" || tool == "pnpm" || tool == "yarn")
|
||||
&& r.first().map(|s| s.as_str()) == Some("test"));
|
||||
if is_test_cmd {
|
||||
script_name = Some("test".to_string());
|
||||
}
|
||||
}
|
||||
if let Some(name) = script_name {
|
||||
let lname = name.to_lowercase();
|
||||
if lname == "test" || lname == "unit" || lname == "jest" || lname == "vitest" {
|
||||
return Some(ParsedCommand::Test {
|
||||
cmd: shlex_join(full_cmd),
|
||||
});
|
||||
}
|
||||
if lname == "lint" || lname == "eslint" {
|
||||
return Some(ParsedCommand::Lint {
|
||||
cmd: shlex_join(full_cmd),
|
||||
tool: Some(format!("{tool}-script:{name}")),
|
||||
targets: None,
|
||||
});
|
||||
}
|
||||
if lname == "format" || lname == "fmt" || lname == "prettier" {
|
||||
return Some(ParsedCommand::Format {
|
||||
cmd: shlex_join(full_cmd),
|
||||
tool: Some(format!("{tool}-script:{name}")),
|
||||
targets: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
let [bash, flag, script] = original else {
|
||||
return None;
|
||||
@@ -1572,10 +1157,8 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
// bias toward the primary command when pipelines are present.
|
||||
// First, drop obvious small formatting helpers (e.g., wc/awk/etc).
|
||||
let had_multiple_commands = all_commands.len() > 1;
|
||||
// The bash AST walker yields commands in right-to-left order for
|
||||
// connector/pipeline sequences. Reverse to reflect actual execution order.
|
||||
let mut filtered_commands = drop_small_formatting_commands(all_commands);
|
||||
filtered_commands.reverse();
|
||||
// Commands arrive in source order; drop formatting helpers while preserving it.
|
||||
let filtered_commands = drop_small_formatting_commands(all_commands);
|
||||
if filtered_commands.is_empty() {
|
||||
return Some(vec![ParsedCommand::Unknown {
|
||||
cmd: script.clone(),
|
||||
@@ -1586,7 +1169,11 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
.map(|tokens| summarize_main_tokens(&tokens))
|
||||
.collect();
|
||||
if commands.len() > 1 {
|
||||
commands.retain(|pc| !matches!(pc, ParsedCommand::Noop { .. }));
|
||||
commands.retain(|pc| !matches!(pc, ParsedCommand::Unknown { cmd } if cmd == "true"));
|
||||
// Apply the same simplifications used for non-bash parsing, e.g., drop leading `cd`.
|
||||
while let Some(next) = simplify_once(&commands) {
|
||||
commands = next;
|
||||
}
|
||||
}
|
||||
if commands.len() == 1 {
|
||||
// If we reduced to a single command, attribute the full original script
|
||||
@@ -1604,8 +1191,8 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
if had_connectors {
|
||||
let has_pipe = script_tokens.iter().any(|t| t == "|");
|
||||
let has_sed_n = script_tokens.windows(2).any(|w| {
|
||||
w.first().map(|s| s.as_str()) == Some("sed")
|
||||
&& w.get(1).map(|s| s.as_str()) == Some("-n")
|
||||
w.first().map(String::as_str) == Some("sed")
|
||||
&& w.get(1).map(String::as_str) == Some("-n")
|
||||
});
|
||||
if has_pipe && has_sed_n {
|
||||
ParsedCommand::Read {
|
||||
@@ -1613,10 +1200,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read {
|
||||
cmd: cmd.clone(),
|
||||
name,
|
||||
}
|
||||
ParsedCommand::Read { cmd, name }
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read {
|
||||
@@ -1627,10 +1211,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
}
|
||||
ParsedCommand::ListFiles { path, cmd, .. } => {
|
||||
if had_connectors {
|
||||
ParsedCommand::ListFiles {
|
||||
cmd: cmd.clone(),
|
||||
path,
|
||||
}
|
||||
ParsedCommand::ListFiles { cmd, path }
|
||||
} else {
|
||||
ParsedCommand::ListFiles {
|
||||
cmd: shlex_join(&script_tokens),
|
||||
@@ -1642,11 +1223,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
query, path, cmd, ..
|
||||
} => {
|
||||
if had_connectors {
|
||||
ParsedCommand::Search {
|
||||
cmd: cmd.clone(),
|
||||
query,
|
||||
path,
|
||||
}
|
||||
ParsedCommand::Search { cmd, query, path }
|
||||
} else {
|
||||
ParsedCommand::Search {
|
||||
cmd: shlex_join(&script_tokens),
|
||||
@@ -1655,27 +1232,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
}
|
||||
}
|
||||
}
|
||||
ParsedCommand::Format {
|
||||
tool, targets, cmd, ..
|
||||
} => ParsedCommand::Format {
|
||||
cmd: cmd.clone(),
|
||||
tool,
|
||||
targets,
|
||||
},
|
||||
ParsedCommand::Test { cmd, .. } => ParsedCommand::Test { cmd: cmd.clone() },
|
||||
ParsedCommand::Lint {
|
||||
tool, targets, cmd, ..
|
||||
} => ParsedCommand::Lint {
|
||||
cmd: cmd.clone(),
|
||||
tool,
|
||||
targets,
|
||||
},
|
||||
ParsedCommand::Unknown { .. } => ParsedCommand::Unknown {
|
||||
cmd: script.clone(),
|
||||
},
|
||||
ParsedCommand::Noop { .. } => ParsedCommand::Noop {
|
||||
cmd: script.clone(),
|
||||
},
|
||||
other => other,
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
@@ -1715,7 +1272,7 @@ fn is_small_formatting_command(tokens: &[String]) -> bool {
|
||||
// Keep `sed -n <range> file` (treated as a file read elsewhere);
|
||||
// otherwise consider it a formatting helper in a pipeline.
|
||||
tokens.len() < 4
|
||||
|| !(tokens[1] == "-n" && is_valid_sed_n_arg(tokens.get(2).map(|s| s.as_str())))
|
||||
|| !(tokens[1] == "-n" && is_valid_sed_n_arg(tokens.get(2).map(String::as_str)))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
@@ -1728,124 +1285,6 @@ fn drop_small_formatting_commands(mut commands: Vec<Vec<String>>) -> Vec<Vec<Str
|
||||
|
||||
fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
match main_cmd.split_first() {
|
||||
Some((head, tail)) if head == "true" && tail.is_empty() => ParsedCommand::Noop {
|
||||
cmd: shlex_join(main_cmd),
|
||||
},
|
||||
// (sed-specific logic handled below in dedicated arm returning Read)
|
||||
Some((head, tail))
|
||||
if head == "cargo" && tail.first().map(|s| s.as_str()) == Some("fmt") =>
|
||||
{
|
||||
ParsedCommand::Format {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("cargo fmt".to_string()),
|
||||
targets: collect_non_flag_targets(&tail[1..]),
|
||||
}
|
||||
}
|
||||
Some((head, tail))
|
||||
if head == "cargo" && tail.first().map(|s| s.as_str()) == Some("clippy") =>
|
||||
{
|
||||
ParsedCommand::Lint {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("cargo clippy".to_string()),
|
||||
targets: collect_non_flag_targets(&tail[1..]),
|
||||
}
|
||||
}
|
||||
Some((head, tail))
|
||||
if head == "cargo" && tail.first().map(|s| s.as_str()) == Some("test") =>
|
||||
{
|
||||
ParsedCommand::Test {
|
||||
cmd: shlex_join(main_cmd),
|
||||
}
|
||||
}
|
||||
Some((head, tail)) if head == "rustfmt" => ParsedCommand::Format {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("rustfmt".to_string()),
|
||||
targets: collect_non_flag_targets(tail),
|
||||
},
|
||||
Some((head, tail)) if head == "go" && tail.first().map(|s| s.as_str()) == Some("fmt") => {
|
||||
ParsedCommand::Format {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("go fmt".to_string()),
|
||||
targets: collect_non_flag_targets(&tail[1..]),
|
||||
}
|
||||
}
|
||||
Some((head, tail)) if head == "go" && tail.first().map(|s| s.as_str()) == Some("test") => {
|
||||
ParsedCommand::Test {
|
||||
cmd: shlex_join(main_cmd),
|
||||
}
|
||||
}
|
||||
Some((head, _)) if head == "pytest" => ParsedCommand::Test {
|
||||
cmd: shlex_join(main_cmd),
|
||||
},
|
||||
Some((head, tail)) if head == "eslint" => {
|
||||
// Treat configuration flags with values (e.g. `-c .eslintrc`) as non-targets.
|
||||
let targets = collect_non_flag_targets_with_flags(tail, ESLINT_FLAGS_WITH_VALUES);
|
||||
ParsedCommand::Lint {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("eslint".to_string()),
|
||||
targets,
|
||||
}
|
||||
}
|
||||
Some((head, tail)) if head == "prettier" => ParsedCommand::Format {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("prettier".to_string()),
|
||||
targets: collect_non_flag_targets(tail),
|
||||
},
|
||||
Some((head, tail)) if head == "black" => ParsedCommand::Format {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("black".to_string()),
|
||||
targets: collect_non_flag_targets(tail),
|
||||
},
|
||||
Some((head, tail))
|
||||
if head == "ruff" && tail.first().map(|s| s.as_str()) == Some("check") =>
|
||||
{
|
||||
ParsedCommand::Lint {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("ruff".to_string()),
|
||||
targets: collect_non_flag_targets(&tail[1..]),
|
||||
}
|
||||
}
|
||||
Some((head, tail))
|
||||
if head == "ruff" && tail.first().map(|s| s.as_str()) == Some("format") =>
|
||||
{
|
||||
ParsedCommand::Format {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("ruff".to_string()),
|
||||
targets: collect_non_flag_targets(&tail[1..]),
|
||||
}
|
||||
}
|
||||
Some((head, _)) if (head == "jest" || head == "vitest") => ParsedCommand::Test {
|
||||
cmd: shlex_join(main_cmd),
|
||||
},
|
||||
Some((head, tail))
|
||||
if head == "npx" && tail.first().map(|s| s.as_str()) == Some("eslint") =>
|
||||
{
|
||||
let targets = collect_non_flag_targets_with_flags(&tail[1..], ESLINT_FLAGS_WITH_VALUES);
|
||||
ParsedCommand::Lint {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("eslint".to_string()),
|
||||
targets,
|
||||
}
|
||||
}
|
||||
Some((head, tail))
|
||||
if head == "npx" && tail.first().map(|s| s.as_str()) == Some("prettier") =>
|
||||
{
|
||||
ParsedCommand::Format {
|
||||
cmd: shlex_join(main_cmd),
|
||||
tool: Some("prettier".to_string()),
|
||||
targets: collect_non_flag_targets(&tail[1..]),
|
||||
}
|
||||
}
|
||||
// NPM-like scripts including yarn
|
||||
Some((tool, tail)) if (tool == "pnpm" || tool == "npm" || tool == "yarn") => {
|
||||
if let Some(cmd) = classify_npm_like(tool, tail, main_cmd) {
|
||||
cmd
|
||||
} else {
|
||||
ParsedCommand::Unknown {
|
||||
cmd: shlex_join(main_cmd),
|
||||
}
|
||||
}
|
||||
}
|
||||
Some((head, tail)) if head == "ls" => {
|
||||
// Avoid treating option values as paths (e.g., ls -I "*.test.js").
|
||||
let candidates = skip_flag_values(
|
||||
@@ -1880,7 +1319,7 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
(None, non_flags.first().map(|s| short_display_path(s)))
|
||||
} else {
|
||||
(
|
||||
non_flags.first().cloned().map(|s| s.to_string()),
|
||||
non_flags.first().cloned().map(String::from),
|
||||
non_flags.get(1).map(|s| short_display_path(s)),
|
||||
)
|
||||
};
|
||||
@@ -1915,7 +1354,7 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
.collect();
|
||||
// Do not shorten the query: grep patterns may legitimately contain slashes
|
||||
// and should be preserved verbatim. Only paths should be shortened.
|
||||
let query = non_flags.first().cloned().map(|s| s.to_string());
|
||||
let query = non_flags.first().cloned().map(String::from);
|
||||
let path = non_flags.get(1).map(|s| short_display_path(s));
|
||||
ParsedCommand::Search {
|
||||
cmd: shlex_join(main_cmd),
|
||||
@@ -1925,7 +1364,7 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
}
|
||||
Some((head, tail)) if head == "cat" => {
|
||||
// Support both `cat <file>` and `cat -- <file>` forms.
|
||||
let effective_tail: &[String] = if tail.first().map(|s| s.as_str()) == Some("--") {
|
||||
let effective_tail: &[String] = if tail.first().map(String::as_str) == Some("--") {
|
||||
&tail[1..]
|
||||
} else {
|
||||
tail
|
||||
@@ -2041,7 +1480,7 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
if head == "sed"
|
||||
&& tail.len() >= 3
|
||||
&& tail[0] == "-n"
|
||||
&& is_valid_sed_n_arg(tail.get(1).map(|s| s.as_str())) =>
|
||||
&& is_valid_sed_n_arg(tail.get(1).map(String::as_str)) =>
|
||||
{
|
||||
if let Some(path) = tail.get(2) {
|
||||
let name = short_display_path(path);
|
||||
|
||||
@@ -115,7 +115,7 @@ pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBu
|
||||
// Build chain from cwd upwards and detect git root.
|
||||
let mut chain: Vec<PathBuf> = vec![dir.clone()];
|
||||
let mut git_root: Option<PathBuf> = None;
|
||||
let mut cursor = dir.clone();
|
||||
let mut cursor = dir;
|
||||
while let Some(parent) = cursor.parent() {
|
||||
let git_marker = cursor.join(".git");
|
||||
let git_exists = match std::fs::metadata(&git_marker) {
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
You are a summarization assistant. A conversation follows between a user and a coding-focused AI (Codex). Your task is to generate a clear summary capturing:
|
||||
|
||||
• High-level objective or problem being solved
|
||||
• Key instructions or design decisions given by the user
|
||||
• Main code actions or behaviors from the AI
|
||||
• Important variables, functions, modules, or outputs discussed
|
||||
• Any unresolved questions or next steps
|
||||
|
||||
Produce the summary in a structured format like:
|
||||
|
||||
**Objective:** …
|
||||
|
||||
**User instructions:** … (bulleted)
|
||||
|
||||
**AI actions / code behavior:** … (bulleted)
|
||||
|
||||
**Important entities:** … (e.g. function names, variables, files)
|
||||
|
||||
**Open issues / next steps:** … (if any)
|
||||
|
||||
**Summary (concise):** (one or two sentences)
|
||||
55
codex-rs/core/src/review_format.rs
Normal file
55
codex-rs/core/src/review_format.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use crate::protocol::ReviewFinding;
|
||||
|
||||
// Note: We keep this module UI-agnostic. It returns plain strings that
|
||||
// higher layers (e.g., TUI) may style as needed.
|
||||
|
||||
fn format_location(item: &ReviewFinding) -> String {
|
||||
let path = item.code_location.absolute_file_path.display();
|
||||
let start = item.code_location.line_range.start;
|
||||
let end = item.code_location.line_range.end;
|
||||
format!("{path}:{start}-{end}")
|
||||
}
|
||||
|
||||
/// Format a full review findings block as plain text lines.
|
||||
///
|
||||
/// - When `selection` is `Some`, each item line includes a checkbox marker:
|
||||
/// "[x]" for selected items and "[ ]" for unselected. Missing indices
|
||||
/// default to selected.
|
||||
/// - When `selection` is `None`, the marker is omitted and a simple bullet is
|
||||
/// rendered ("- Title — path:start-end").
|
||||
pub fn format_review_findings_block(
|
||||
findings: &[ReviewFinding],
|
||||
selection: Option<&[bool]>,
|
||||
) -> String {
|
||||
let mut lines: Vec<String> = Vec::new();
|
||||
lines.push(String::new());
|
||||
|
||||
// Header
|
||||
if findings.len() > 1 {
|
||||
lines.push("Full review comments:".to_string());
|
||||
} else {
|
||||
lines.push("Review comment:".to_string());
|
||||
}
|
||||
|
||||
for (idx, item) in findings.iter().enumerate() {
|
||||
lines.push(String::new());
|
||||
|
||||
let title = &item.title;
|
||||
let location = format_location(item);
|
||||
|
||||
if let Some(flags) = selection {
|
||||
// Default to selected if index is out of bounds.
|
||||
let checked = flags.get(idx).copied().unwrap_or(true);
|
||||
let marker = if checked { "[x]" } else { "[ ]" };
|
||||
lines.push(format!("- {marker} {title} — {location}"));
|
||||
} else {
|
||||
lines.push(format!("- {title} — {location}"));
|
||||
}
|
||||
|
||||
for body_line in item.body.lines() {
|
||||
lines.push(format!(" {body_line}"));
|
||||
}
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
@@ -1,368 +0,0 @@
|
||||
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.
|
||||
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Error as IoError;
|
||||
use std::path::Path;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::mpsc::{self};
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::git_info::GitInfo;
|
||||
use crate::git_info::collect_git_info;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
const SESSIONS_SUBDIR: &str = "sessions";
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Default)]
|
||||
pub struct SessionMeta {
|
||||
pub id: Uuid,
|
||||
pub timestamp: String,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SessionMetaWithGit {
|
||||
#[serde(flatten)]
|
||||
meta: SessionMeta,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
git: Option<GitInfo>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SavedSession {
|
||||
pub session: SessionMeta,
|
||||
#[serde(default)]
|
||||
pub items: Vec<ResponseItem>,
|
||||
#[serde(default)]
|
||||
pub state: SessionStateSnapshot,
|
||||
pub session_id: Uuid,
|
||||
}
|
||||
|
||||
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
|
||||
/// every update.
|
||||
///
|
||||
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
|
||||
///
|
||||
/// ```ignore
|
||||
/// $ jq -C . ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
|
||||
/// $ fx ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
|
||||
/// ```
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RolloutRecorder {
|
||||
tx: Sender<RolloutCmd>,
|
||||
}
|
||||
|
||||
enum RolloutCmd {
|
||||
AddItems(Vec<ResponseItem>),
|
||||
UpdateState(SessionStateSnapshot),
|
||||
Shutdown { ack: oneshot::Sender<()> },
|
||||
}
|
||||
|
||||
impl RolloutRecorder {
|
||||
/// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
|
||||
/// cannot be created or the rollout file cannot be opened we return the
|
||||
/// error so the caller can decide whether to disable persistence.
|
||||
pub async fn new(
|
||||
config: &Config,
|
||||
uuid: Uuid,
|
||||
instructions: Option<String>,
|
||||
) -> std::io::Result<Self> {
|
||||
let LogFileInfo {
|
||||
file,
|
||||
session_id,
|
||||
timestamp,
|
||||
} = create_log_file(config, uuid)?;
|
||||
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = timestamp
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
// Clone the cwd for the spawned task to collect git info asynchronously
|
||||
let cwd = config.cwd.clone();
|
||||
|
||||
// A reasonably-sized bounded channel. If the buffer fills up the send
|
||||
// future will yield, which is fine – we only need to ensure we do not
|
||||
// perform *blocking* I/O on the caller's thread.
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
|
||||
// Spawn a Tokio task that owns the file handle and performs async
|
||||
// writes. Using `tokio::fs::File` keeps everything on the async I/O
|
||||
// driver instead of blocking the runtime.
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
Some(SessionMeta {
|
||||
timestamp,
|
||||
id: session_id,
|
||||
instructions,
|
||||
}),
|
||||
cwd,
|
||||
));
|
||||
|
||||
Ok(Self { tx })
|
||||
}
|
||||
|
||||
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
match item {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
// "fully qualified MCP tool calls," so we could consider
|
||||
// reformatting them in that case.
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::Reasoning { .. } => filtered.push(item.clone()),
|
||||
ResponseItem::WebSearchCall { .. } | ResponseItem::Other => {
|
||||
// These should never be serialized.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
if filtered.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.tx
|
||||
.send(RolloutCmd::AddItems(filtered))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
|
||||
}
|
||||
|
||||
pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
|
||||
self.tx
|
||||
.send(RolloutCmd::UpdateState(state))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
||||
}
|
||||
|
||||
pub async fn resume(
|
||||
path: &Path,
|
||||
cwd: std::path::PathBuf,
|
||||
) -> std::io::Result<(Self, SavedSession)> {
|
||||
info!("Resuming rollout from {path:?}");
|
||||
let text = tokio::fs::read_to_string(path).await?;
|
||||
let mut lines = text.lines();
|
||||
let meta_line = lines
|
||||
.next()
|
||||
.ok_or_else(|| IoError::other("empty session file"))?;
|
||||
let session: SessionMeta = serde_json::from_str(meta_line)
|
||||
.map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
|
||||
let mut items = Vec::new();
|
||||
let mut state = SessionStateSnapshot::default();
|
||||
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if v.get("record_type")
|
||||
.and_then(|rt| rt.as_str())
|
||||
.map(|s| s == "state")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
|
||||
state = s
|
||||
}
|
||||
continue;
|
||||
}
|
||||
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
Ok(item) => match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::Reasoning { .. } => items.push(item),
|
||||
ResponseItem::WebSearchCall { .. } | ResponseItem::Other => {}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("failed to parse item: {v:?}, error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let saved = SavedSession {
|
||||
session: session.clone(),
|
||||
items: items.clone(),
|
||||
state: state.clone(),
|
||||
session_id: session.id,
|
||||
};
|
||||
|
||||
let file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.read(true)
|
||||
.open(path)?;
|
||||
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
None,
|
||||
cwd,
|
||||
));
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
Ok((Self { tx }, saved))
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> std::io::Result<()> {
|
||||
let (tx_done, rx_done) = oneshot::channel();
|
||||
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
|
||||
Ok(_) => rx_done
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
|
||||
Err(e) => {
|
||||
warn!("failed to send rollout shutdown command: {e}");
|
||||
Err(IoError::other(format!(
|
||||
"failed to send rollout shutdown command: {e}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct LogFileInfo {
|
||||
/// Opened file handle to the rollout file.
|
||||
file: File,
|
||||
|
||||
/// Session ID (also embedded in filename).
|
||||
session_id: Uuid,
|
||||
|
||||
/// Timestamp for the start of the session.
|
||||
timestamp: OffsetDateTime,
|
||||
}
|
||||
|
||||
fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFileInfo> {
|
||||
// Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
|
||||
let timestamp = OffsetDateTime::now_local()
|
||||
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
|
||||
let mut dir = config.codex_home.clone();
|
||||
dir.push(SESSIONS_SUBDIR);
|
||||
dir.push(timestamp.year().to_string());
|
||||
dir.push(format!("{:02}", u8::from(timestamp.month())));
|
||||
dir.push(format!("{:02}", timestamp.day()));
|
||||
fs::create_dir_all(&dir)?;
|
||||
|
||||
// Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for
|
||||
// compatibility with filesystems that do not allow colons in filenames.
|
||||
let format: &[FormatItem] =
|
||||
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
|
||||
let date_str = timestamp
|
||||
.format(format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
let filename = format!("rollout-{date_str}-{session_id}.jsonl");
|
||||
|
||||
let path = dir.join(filename);
|
||||
let file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.create(true)
|
||||
.open(&path)?;
|
||||
|
||||
Ok(LogFileInfo {
|
||||
file,
|
||||
session_id,
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
async fn rollout_writer(
|
||||
file: tokio::fs::File,
|
||||
mut rx: mpsc::Receiver<RolloutCmd>,
|
||||
mut meta: Option<SessionMeta>,
|
||||
cwd: std::path::PathBuf,
|
||||
) -> std::io::Result<()> {
|
||||
let mut writer = JsonlWriter { file };
|
||||
|
||||
// If we have a meta, collect git info asynchronously and write meta first
|
||||
if let Some(session_meta) = meta.take() {
|
||||
let git_info = collect_git_info(&cwd).await;
|
||||
let session_meta_with_git = SessionMetaWithGit {
|
||||
meta: session_meta,
|
||||
git: git_info,
|
||||
};
|
||||
|
||||
// Write the SessionMeta as the first item in the file
|
||||
writer.write_line(&session_meta_with_git).await?;
|
||||
}
|
||||
|
||||
// Process rollout commands
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
match cmd {
|
||||
RolloutCmd::AddItems(items) => {
|
||||
for item in items {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::Reasoning { .. } => {
|
||||
writer.write_line(&item).await?;
|
||||
}
|
||||
ResponseItem::WebSearchCall { .. } | ResponseItem::Other => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
RolloutCmd::UpdateState(state) => {
|
||||
#[derive(Serialize)]
|
||||
struct StateLine<'a> {
|
||||
record_type: &'static str,
|
||||
#[serde(flatten)]
|
||||
state: &'a SessionStateSnapshot,
|
||||
}
|
||||
writer
|
||||
.write_line(&StateLine {
|
||||
record_type: "state",
|
||||
state: &state,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
RolloutCmd::Shutdown { ack } => {
|
||||
let _ = ack.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct JsonlWriter {
|
||||
file: tokio::fs::File,
|
||||
}
|
||||
|
||||
impl JsonlWriter {
|
||||
async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
|
||||
let mut json = serde_json::to_string(item)?;
|
||||
json.push('\n');
|
||||
let _ = self.file.write_all(json.as_bytes()).await;
|
||||
self.file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
385
codex-rs/core/src/rollout/list.rs
Normal file
385
codex-rs/core/src/rollout/list.rs
Normal file
@@ -0,0 +1,385 @@
|
||||
use std::cmp::Reverse;
|
||||
use std::io::{self};
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_file_search as file_search;
|
||||
use std::num::NonZero;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use time::OffsetDateTime;
|
||||
use time::PrimitiveDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::SESSIONS_SUBDIR;
|
||||
use crate::protocol::EventMsg;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
|
||||
/// Returned page of conversation summaries.
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
pub struct ConversationsPage {
|
||||
/// Conversation summaries ordered newest first.
|
||||
pub items: Vec<ConversationItem>,
|
||||
/// Opaque pagination token to resume after the last item, or `None` if end.
|
||||
pub next_cursor: Option<Cursor>,
|
||||
/// Total number of files touched while scanning this request.
|
||||
pub num_scanned_files: usize,
|
||||
/// True if a hard scan cap was hit; consider resuming with `next_cursor`.
|
||||
pub reached_scan_cap: bool,
|
||||
}
|
||||
|
||||
/// Summary information for a conversation rollout file.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct ConversationItem {
|
||||
/// Absolute path to the rollout file.
|
||||
pub path: PathBuf,
|
||||
/// First up to 5 JSONL records parsed as JSON (includes meta line).
|
||||
pub head: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Hard cap to bound worst‑case work per request.
|
||||
const MAX_SCAN_FILES: usize = 100;
|
||||
const HEAD_RECORD_LIMIT: usize = 10;
|
||||
|
||||
/// Pagination cursor identifying a file by timestamp and UUID.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Cursor {
|
||||
ts: OffsetDateTime,
|
||||
id: Uuid,
|
||||
}
|
||||
|
||||
impl Cursor {
|
||||
fn new(ts: OffsetDateTime, id: Uuid) -> Self {
|
||||
Self { ts, id }
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for Cursor {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let ts_str = self
|
||||
.ts
|
||||
.format(&format_description!(
|
||||
"[year]-[month]-[day]T[hour]-[minute]-[second]"
|
||||
))
|
||||
.map_err(|e| serde::ser::Error::custom(format!("format error: {e}")))?;
|
||||
serializer.serialize_str(&format!("{ts_str}|{}", self.id))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for Cursor {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
parse_cursor(&s).ok_or_else(|| serde::de::Error::custom("invalid cursor"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve recorded conversation file paths with token pagination. The returned `next_cursor`
|
||||
/// can be supplied on the next call to resume after the last returned item, resilient to
|
||||
/// concurrent new sessions being appended. Ordering is stable by timestamp desc, then UUID desc.
|
||||
pub(crate) async fn get_conversations(
|
||||
codex_home: &Path,
|
||||
page_size: usize,
|
||||
cursor: Option<&Cursor>,
|
||||
) -> io::Result<ConversationsPage> {
|
||||
let mut root = codex_home.to_path_buf();
|
||||
root.push(SESSIONS_SUBDIR);
|
||||
|
||||
if !root.exists() {
|
||||
return Ok(ConversationsPage {
|
||||
items: Vec::new(),
|
||||
next_cursor: None,
|
||||
num_scanned_files: 0,
|
||||
reached_scan_cap: false,
|
||||
});
|
||||
}
|
||||
|
||||
let anchor = cursor.cloned();
|
||||
|
||||
let result = traverse_directories_for_paths(root.clone(), page_size, anchor).await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Load the full contents of a single conversation session file at `path`.
|
||||
/// Returns the entire file contents as a String.
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn get_conversation(path: &Path) -> io::Result<String> {
|
||||
tokio::fs::read_to_string(path).await
|
||||
}
|
||||
|
||||
/// Load conversation file paths from disk using directory traversal.
|
||||
///
|
||||
/// Directory layout: `~/.codex/sessions/YYYY/MM/DD/rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl`
|
||||
/// Returned newest (latest) first.
|
||||
async fn traverse_directories_for_paths(
|
||||
root: PathBuf,
|
||||
page_size: usize,
|
||||
anchor: Option<Cursor>,
|
||||
) -> io::Result<ConversationsPage> {
|
||||
let mut items: Vec<ConversationItem> = Vec::with_capacity(page_size);
|
||||
let mut scanned_files = 0usize;
|
||||
let mut anchor_passed = anchor.is_none();
|
||||
let (anchor_ts, anchor_id) = match anchor {
|
||||
Some(c) => (c.ts, c.id),
|
||||
None => (OffsetDateTime::UNIX_EPOCH, Uuid::nil()),
|
||||
};
|
||||
|
||||
let year_dirs = collect_dirs_desc(&root, |s| s.parse::<u16>().ok()).await?;
|
||||
|
||||
'outer: for (_year, year_path) in year_dirs.iter() {
|
||||
if scanned_files >= MAX_SCAN_FILES {
|
||||
break;
|
||||
}
|
||||
let month_dirs = collect_dirs_desc(year_path, |s| s.parse::<u8>().ok()).await?;
|
||||
for (_month, month_path) in month_dirs.iter() {
|
||||
if scanned_files >= MAX_SCAN_FILES {
|
||||
break 'outer;
|
||||
}
|
||||
let day_dirs = collect_dirs_desc(month_path, |s| s.parse::<u8>().ok()).await?;
|
||||
for (_day, day_path) in day_dirs.iter() {
|
||||
if scanned_files >= MAX_SCAN_FILES {
|
||||
break 'outer;
|
||||
}
|
||||
let mut day_files = collect_files(day_path, |name_str, path| {
|
||||
if !name_str.starts_with("rollout-") || !name_str.ends_with(".jsonl") {
|
||||
return None;
|
||||
}
|
||||
|
||||
parse_timestamp_uuid_from_filename(name_str)
|
||||
.map(|(ts, id)| (ts, id, name_str.to_string(), path.to_path_buf()))
|
||||
})
|
||||
.await?;
|
||||
// Stable ordering within the same second: (timestamp desc, uuid desc)
|
||||
day_files.sort_by_key(|(ts, sid, _name_str, _path)| (Reverse(*ts), Reverse(*sid)));
|
||||
for (ts, sid, _name_str, path) in day_files.into_iter() {
|
||||
scanned_files += 1;
|
||||
if scanned_files >= MAX_SCAN_FILES && items.len() >= page_size {
|
||||
break 'outer;
|
||||
}
|
||||
if !anchor_passed {
|
||||
if ts < anchor_ts || (ts == anchor_ts && sid < anchor_id) {
|
||||
anchor_passed = true;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if items.len() == page_size {
|
||||
break 'outer;
|
||||
}
|
||||
// Read head and simultaneously detect message events within the same
|
||||
// first N JSONL records to avoid a second file read.
|
||||
let (head, saw_session_meta, saw_user_event) =
|
||||
read_head_and_flags(&path, HEAD_RECORD_LIMIT)
|
||||
.await
|
||||
.unwrap_or((Vec::new(), false, false));
|
||||
// Apply filters: must have session meta and at least one user message event
|
||||
if saw_session_meta && saw_user_event {
|
||||
items.push(ConversationItem { path, head });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let next = build_next_cursor(&items);
|
||||
Ok(ConversationsPage {
|
||||
items,
|
||||
next_cursor: next,
|
||||
num_scanned_files: scanned_files,
|
||||
reached_scan_cap: scanned_files >= MAX_SCAN_FILES,
|
||||
})
|
||||
}
|
||||
|
||||
/// Pagination cursor token format: "<file_ts>|<uuid>" where `file_ts` matches the
|
||||
/// filename timestamp portion (YYYY-MM-DDThh-mm-ss) used in rollout filenames.
|
||||
/// The cursor orders files by timestamp desc, then UUID desc.
|
||||
fn parse_cursor(token: &str) -> Option<Cursor> {
|
||||
let (file_ts, uuid_str) = token.split_once('|')?;
|
||||
|
||||
let Ok(uuid) = Uuid::parse_str(uuid_str) else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let format: &[FormatItem] =
|
||||
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
|
||||
let ts = PrimitiveDateTime::parse(file_ts, format).ok()?.assume_utc();
|
||||
|
||||
Some(Cursor::new(ts, uuid))
|
||||
}
|
||||
|
||||
fn build_next_cursor(items: &[ConversationItem]) -> Option<Cursor> {
|
||||
let last = items.last()?;
|
||||
let file_name = last.path.file_name()?.to_string_lossy();
|
||||
let (ts, id) = parse_timestamp_uuid_from_filename(&file_name)?;
|
||||
Some(Cursor::new(ts, id))
|
||||
}
|
||||
|
||||
/// Collects immediate subdirectories of `parent`, parses their (string) names with `parse`,
|
||||
/// and returns them sorted descending by the parsed key.
|
||||
async fn collect_dirs_desc<T, F>(parent: &Path, parse: F) -> io::Result<Vec<(T, PathBuf)>>
|
||||
where
|
||||
T: Ord + Copy,
|
||||
F: Fn(&str) -> Option<T>,
|
||||
{
|
||||
let mut dir = tokio::fs::read_dir(parent).await?;
|
||||
let mut vec: Vec<(T, PathBuf)> = Vec::new();
|
||||
while let Some(entry) = dir.next_entry().await? {
|
||||
if entry
|
||||
.file_type()
|
||||
.await
|
||||
.map(|ft| ft.is_dir())
|
||||
.unwrap_or(false)
|
||||
&& let Some(s) = entry.file_name().to_str()
|
||||
&& let Some(v) = parse(s)
|
||||
{
|
||||
vec.push((v, entry.path()));
|
||||
}
|
||||
}
|
||||
vec.sort_by_key(|(v, _)| Reverse(*v));
|
||||
Ok(vec)
|
||||
}
|
||||
|
||||
/// Collects files in a directory and parses them with `parse`.
|
||||
async fn collect_files<T, F>(parent: &Path, parse: F) -> io::Result<Vec<T>>
|
||||
where
|
||||
F: Fn(&str, &Path) -> Option<T>,
|
||||
{
|
||||
let mut dir = tokio::fs::read_dir(parent).await?;
|
||||
let mut collected: Vec<T> = Vec::new();
|
||||
while let Some(entry) = dir.next_entry().await? {
|
||||
if entry
|
||||
.file_type()
|
||||
.await
|
||||
.map(|ft| ft.is_file())
|
||||
.unwrap_or(false)
|
||||
&& let Some(s) = entry.file_name().to_str()
|
||||
&& let Some(v) = parse(s, &entry.path())
|
||||
{
|
||||
collected.push(v);
|
||||
}
|
||||
}
|
||||
Ok(collected)
|
||||
}
|
||||
|
||||
fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uuid)> {
|
||||
// Expected: rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl
|
||||
let core = name.strip_prefix("rollout-")?.strip_suffix(".jsonl")?;
|
||||
|
||||
// Scan from the right for a '-' such that the suffix parses as a UUID.
|
||||
let (sep_idx, uuid) = core
|
||||
.match_indices('-')
|
||||
.rev()
|
||||
.find_map(|(i, _)| Uuid::parse_str(&core[i + 1..]).ok().map(|u| (i, u)))?;
|
||||
|
||||
let ts_str = &core[..sep_idx];
|
||||
let format: &[FormatItem] =
|
||||
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
|
||||
let ts = PrimitiveDateTime::parse(ts_str, format).ok()?.assume_utc();
|
||||
Some((ts, uuid))
|
||||
}
|
||||
|
||||
async fn read_head_and_flags(
|
||||
path: &Path,
|
||||
max_records: usize,
|
||||
) -> io::Result<(Vec<serde_json::Value>, bool, bool)> {
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
|
||||
let file = tokio::fs::File::open(path).await?;
|
||||
let reader = tokio::io::BufReader::new(file);
|
||||
let mut lines = reader.lines();
|
||||
let mut head: Vec<serde_json::Value> = Vec::new();
|
||||
let mut saw_session_meta = false;
|
||||
let mut saw_user_event = false;
|
||||
|
||||
while head.len() < max_records {
|
||||
let line_opt = lines.next_line().await?;
|
||||
let Some(line) = line_opt else { break };
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let parsed: Result<RolloutLine, _> = serde_json::from_str(trimmed);
|
||||
let Ok(rollout_line) = parsed else { continue };
|
||||
|
||||
match rollout_line.item {
|
||||
RolloutItem::SessionMeta(session_meta_line) => {
|
||||
if let Ok(val) = serde_json::to_value(session_meta_line) {
|
||||
head.push(val);
|
||||
saw_session_meta = true;
|
||||
}
|
||||
}
|
||||
RolloutItem::ResponseItem(item) => {
|
||||
if let Ok(val) = serde_json::to_value(item) {
|
||||
head.push(val);
|
||||
}
|
||||
}
|
||||
RolloutItem::TurnContext(_) => {
|
||||
// Not included in `head`; skip.
|
||||
}
|
||||
RolloutItem::Compacted(_) => {
|
||||
// Not included in `head`; skip.
|
||||
}
|
||||
RolloutItem::EventMsg(ev) => {
|
||||
if matches!(ev, EventMsg::UserMessage(_)) {
|
||||
saw_user_event = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((head, saw_session_meta, saw_user_event))
|
||||
}
|
||||
|
||||
/// Locate a recorded conversation rollout file by its UUID string using the existing
|
||||
/// paginated listing implementation. Returns `Ok(Some(path))` if found, `Ok(None)` if not present
|
||||
/// or the id is invalid.
|
||||
pub async fn find_conversation_path_by_id_str(
|
||||
codex_home: &Path,
|
||||
id_str: &str,
|
||||
) -> io::Result<Option<PathBuf>> {
|
||||
// Validate UUID format early.
|
||||
if Uuid::parse_str(id_str).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut root = codex_home.to_path_buf();
|
||||
root.push(SESSIONS_SUBDIR);
|
||||
if !root.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
// This is safe because we know the values are valid.
|
||||
#[allow(clippy::unwrap_used)]
|
||||
let limit = NonZero::new(1).unwrap();
|
||||
// This is safe because we know the values are valid.
|
||||
#[allow(clippy::unwrap_used)]
|
||||
let threads = NonZero::new(2).unwrap();
|
||||
let cancel = Arc::new(AtomicBool::new(false));
|
||||
let exclude: Vec<String> = Vec::new();
|
||||
let compute_indices = false;
|
||||
|
||||
let results = file_search::run(
|
||||
id_str,
|
||||
limit,
|
||||
&root,
|
||||
exclude,
|
||||
threads,
|
||||
cancel,
|
||||
compute_indices,
|
||||
)
|
||||
.map_err(|e| io::Error::other(format!("file search failed: {e}")))?;
|
||||
|
||||
Ok(results
|
||||
.matches
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|m| root.join(m.path)))
|
||||
}
|
||||
16
codex-rs/core/src/rollout/mod.rs
Normal file
16
codex-rs/core/src/rollout/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
//! Rollout module: persistence and discovery of session rollout files.
|
||||
|
||||
pub const SESSIONS_SUBDIR: &str = "sessions";
|
||||
pub const ARCHIVED_SESSIONS_SUBDIR: &str = "archived_sessions";
|
||||
|
||||
pub mod list;
|
||||
pub(crate) mod policy;
|
||||
pub mod recorder;
|
||||
|
||||
pub use codex_protocol::protocol::SessionMeta;
|
||||
pub use list::find_conversation_path_by_id_str;
|
||||
pub use recorder::RolloutRecorder;
|
||||
pub use recorder::RolloutRecorderParams;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
76
codex-rs/core/src/rollout/policy.rs
Normal file
76
codex-rs/core/src/rollout/policy.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::RolloutItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
/// Whether a rollout `item` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub(crate) fn is_persisted_response_item(item: &RolloutItem) -> bool {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(item) => should_persist_response_item(item),
|
||||
RolloutItem::EventMsg(ev) => should_persist_event_msg(ev),
|
||||
// Persist Codex executive markers so we can analyze flows (e.g., compaction, API turns).
|
||||
RolloutItem::Compacted(_) | RolloutItem::TurnContext(_) | RolloutItem::SessionMeta(_) => {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether a `ResponseItem` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub(crate) fn should_persist_response_item(item: &ResponseItem) -> bool {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => true,
|
||||
ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether an `EventMsg` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool {
|
||||
match ev {
|
||||
EventMsg::UserMessage(_)
|
||||
| EventMsg::AgentMessage(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::AgentReasoningRawContent(_)
|
||||
| EventMsg::TokenCount(_)
|
||||
| EventMsg::EnteredReviewMode(_)
|
||||
| EventMsg::ExitedReviewMode(_)
|
||||
| EventMsg::TurnAborted(_) => true,
|
||||
EventMsg::Error(_)
|
||||
| EventMsg::TaskStarted(_)
|
||||
| EventMsg::TaskComplete(_)
|
||||
| EventMsg::AgentMessageDelta(_)
|
||||
| EventMsg::AgentReasoningDelta(_)
|
||||
| EventMsg::AgentReasoningRawContentDelta(_)
|
||||
| EventMsg::AgentReasoningSectionBreak(_)
|
||||
| EventMsg::SessionConfigured(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::WebSearchBegin(_)
|
||||
| EventMsg::WebSearchEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::ExecApprovalRequest(_)
|
||||
| EventMsg::ApplyPatchApprovalRequest(_)
|
||||
| EventMsg::CompactApprovalRequest(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::StreamError(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::McpListToolsResponse(_)
|
||||
| EventMsg::ListCustomPromptsResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::ShutdownComplete
|
||||
| EventMsg::ConversationPath(_) => false,
|
||||
}
|
||||
}
|
||||
423
codex-rs/core/src/rollout/recorder.rs
Normal file
423
codex-rs/core/src/rollout/recorder.rs
Normal file
@@ -0,0 +1,423 @@
|
||||
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.
|
||||
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Error as IoError;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::mpsc::{self};
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use super::SESSIONS_SUBDIR;
|
||||
use super::list::ConversationsPage;
|
||||
use super::list::Cursor;
|
||||
use super::list::get_conversations;
|
||||
use super::policy::is_persisted_response_item;
|
||||
use crate::config::Config;
|
||||
use crate::default_client::ORIGINATOR;
|
||||
use crate::git_info::collect_git_info;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::protocol::ResumedHistory;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
use codex_protocol::protocol::SessionMeta;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SavedSession {
|
||||
pub session: SessionMeta,
|
||||
#[serde(default)]
|
||||
pub items: Vec<ResponseItem>,
|
||||
#[serde(default)]
|
||||
pub state: SessionStateSnapshot,
|
||||
pub session_id: ConversationId,
|
||||
}
|
||||
|
||||
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
|
||||
/// every update.
|
||||
///
|
||||
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
|
||||
///
|
||||
/// ```ignore
|
||||
/// $ jq -C . ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
|
||||
/// $ fx ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
|
||||
/// ```
|
||||
#[derive(Clone)]
|
||||
pub struct RolloutRecorder {
|
||||
tx: Sender<RolloutCmd>,
|
||||
pub(crate) rollout_path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum RolloutRecorderParams {
|
||||
Create {
|
||||
conversation_id: ConversationId,
|
||||
instructions: Option<String>,
|
||||
},
|
||||
Resume {
|
||||
path: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
enum RolloutCmd {
|
||||
AddItems(Vec<RolloutItem>),
|
||||
/// Ensure all prior writes are processed; respond when flushed.
|
||||
Flush {
|
||||
ack: oneshot::Sender<()>,
|
||||
},
|
||||
Shutdown {
|
||||
ack: oneshot::Sender<()>,
|
||||
},
|
||||
}
|
||||
|
||||
impl RolloutRecorderParams {
|
||||
pub fn new(conversation_id: ConversationId, instructions: Option<String>) -> Self {
|
||||
Self::Create {
|
||||
conversation_id,
|
||||
instructions,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resume(path: PathBuf) -> Self {
|
||||
Self::Resume { path }
|
||||
}
|
||||
}
|
||||
|
||||
impl RolloutRecorder {
|
||||
/// List conversations (rollout files) under the provided Codex home directory.
|
||||
pub async fn list_conversations(
|
||||
codex_home: &Path,
|
||||
page_size: usize,
|
||||
cursor: Option<&Cursor>,
|
||||
) -> std::io::Result<ConversationsPage> {
|
||||
get_conversations(codex_home, page_size, cursor).await
|
||||
}
|
||||
|
||||
/// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
|
||||
/// cannot be created or the rollout file cannot be opened we return the
|
||||
/// error so the caller can decide whether to disable persistence.
|
||||
pub async fn new(config: &Config, params: RolloutRecorderParams) -> std::io::Result<Self> {
|
||||
let (file, rollout_path, meta) = match params {
|
||||
RolloutRecorderParams::Create {
|
||||
conversation_id,
|
||||
instructions,
|
||||
} => {
|
||||
let LogFileInfo {
|
||||
file,
|
||||
path,
|
||||
conversation_id: session_id,
|
||||
timestamp,
|
||||
} = create_log_file(config, conversation_id)?;
|
||||
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = timestamp
|
||||
.to_offset(time::UtcOffset::UTC)
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
(
|
||||
tokio::fs::File::from_std(file),
|
||||
path,
|
||||
Some(SessionMeta {
|
||||
id: session_id,
|
||||
timestamp,
|
||||
cwd: config.cwd.clone(),
|
||||
originator: ORIGINATOR.value.clone(),
|
||||
cli_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
instructions,
|
||||
}),
|
||||
)
|
||||
}
|
||||
RolloutRecorderParams::Resume { path } => (
|
||||
tokio::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.open(&path)
|
||||
.await?,
|
||||
path,
|
||||
None,
|
||||
),
|
||||
};
|
||||
|
||||
// Clone the cwd for the spawned task to collect git info asynchronously
|
||||
let cwd = config.cwd.clone();
|
||||
|
||||
// A reasonably-sized bounded channel. If the buffer fills up the send
|
||||
// future will yield, which is fine – we only need to ensure we do not
|
||||
// perform *blocking* I/O on the caller's thread.
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
|
||||
// Spawn a Tokio task that owns the file handle and performs async
|
||||
// writes. Using `tokio::fs::File` keeps everything on the async I/O
|
||||
// driver instead of blocking the runtime.
|
||||
tokio::task::spawn(rollout_writer(file, rx, meta, cwd));
|
||||
|
||||
Ok(Self { tx, rollout_path })
|
||||
}
|
||||
|
||||
pub(crate) async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
// "fully qualified MCP tool calls," so we could consider
|
||||
// reformatting them in that case.
|
||||
if is_persisted_response_item(item) {
|
||||
filtered.push(item.clone());
|
||||
}
|
||||
}
|
||||
if filtered.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.tx
|
||||
.send(RolloutCmd::AddItems(filtered))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
|
||||
}
|
||||
|
||||
/// Flush all queued writes and wait until they are committed by the writer task.
|
||||
pub async fn flush(&self) -> std::io::Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(RolloutCmd::Flush { ack: tx })
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout flush: {e}")))?;
|
||||
rx.await
|
||||
.map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
|
||||
}
|
||||
|
||||
pub(crate) async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
|
||||
info!("Resuming rollout from {path:?}");
|
||||
let text = tokio::fs::read_to_string(path).await?;
|
||||
if text.trim().is_empty() {
|
||||
return Err(IoError::other("empty session file"));
|
||||
}
|
||||
|
||||
let mut items: Vec<RolloutItem> = Vec::new();
|
||||
let mut conversation_id: Option<ConversationId> = None;
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
warn!("failed to parse line as JSON: {line:?}, error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse the rollout line structure
|
||||
match serde_json::from_value::<RolloutLine>(v.clone()) {
|
||||
Ok(rollout_line) => match rollout_line.item {
|
||||
RolloutItem::SessionMeta(session_meta_line) => {
|
||||
// Use the FIRST SessionMeta encountered in the file as the canonical
|
||||
// conversation id and main session information. Keep all items intact.
|
||||
if conversation_id.is_none() {
|
||||
conversation_id = Some(session_meta_line.meta.id);
|
||||
}
|
||||
items.push(RolloutItem::SessionMeta(session_meta_line));
|
||||
}
|
||||
RolloutItem::ResponseItem(item) => {
|
||||
items.push(RolloutItem::ResponseItem(item));
|
||||
}
|
||||
RolloutItem::Compacted(item) => {
|
||||
items.push(RolloutItem::Compacted(item));
|
||||
}
|
||||
RolloutItem::TurnContext(item) => {
|
||||
items.push(RolloutItem::TurnContext(item));
|
||||
}
|
||||
RolloutItem::EventMsg(_ev) => {
|
||||
items.push(RolloutItem::EventMsg(_ev));
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("failed to parse rollout line: {v:?}, error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"Resumed rollout with {} items, conversation ID: {:?}",
|
||||
items.len(),
|
||||
conversation_id
|
||||
);
|
||||
let conversation_id = conversation_id
|
||||
.ok_or_else(|| IoError::other("failed to parse conversation ID from rollout file"))?;
|
||||
|
||||
if items.is_empty() {
|
||||
return Ok(InitialHistory::New);
|
||||
}
|
||||
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
Ok(InitialHistory::Resumed(ResumedHistory {
|
||||
conversation_id,
|
||||
history: items,
|
||||
rollout_path: path.to_path_buf(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn get_rollout_path(&self) -> PathBuf {
|
||||
self.rollout_path.clone()
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> std::io::Result<()> {
|
||||
let (tx_done, rx_done) = oneshot::channel();
|
||||
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
|
||||
Ok(_) => rx_done
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
|
||||
Err(e) => {
|
||||
warn!("failed to send rollout shutdown command: {e}");
|
||||
Err(IoError::other(format!(
|
||||
"failed to send rollout shutdown command: {e}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct LogFileInfo {
|
||||
/// Opened file handle to the rollout file.
|
||||
file: File,
|
||||
|
||||
/// Full path to the rollout file.
|
||||
path: PathBuf,
|
||||
|
||||
/// Session ID (also embedded in filename).
|
||||
conversation_id: ConversationId,
|
||||
|
||||
/// Timestamp for the start of the session.
|
||||
timestamp: OffsetDateTime,
|
||||
}
|
||||
|
||||
fn create_log_file(
|
||||
config: &Config,
|
||||
conversation_id: ConversationId,
|
||||
) -> std::io::Result<LogFileInfo> {
|
||||
// Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
|
||||
let timestamp = OffsetDateTime::now_local()
|
||||
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
|
||||
let mut dir = config.codex_home.clone();
|
||||
dir.push(SESSIONS_SUBDIR);
|
||||
dir.push(timestamp.year().to_string());
|
||||
dir.push(format!("{:02}", u8::from(timestamp.month())));
|
||||
dir.push(format!("{:02}", timestamp.day()));
|
||||
fs::create_dir_all(&dir)?;
|
||||
|
||||
// Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for
|
||||
// compatibility with filesystems that do not allow colons in filenames.
|
||||
let format: &[FormatItem] =
|
||||
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
|
||||
let date_str = timestamp
|
||||
.format(format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
let filename = format!("rollout-{date_str}-{conversation_id}.jsonl");
|
||||
|
||||
let path = dir.join(filename);
|
||||
let file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.create(true)
|
||||
.open(&path)?;
|
||||
|
||||
Ok(LogFileInfo {
|
||||
file,
|
||||
path,
|
||||
conversation_id,
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
async fn rollout_writer(
|
||||
file: tokio::fs::File,
|
||||
mut rx: mpsc::Receiver<RolloutCmd>,
|
||||
mut meta: Option<SessionMeta>,
|
||||
cwd: std::path::PathBuf,
|
||||
) -> std::io::Result<()> {
|
||||
let mut writer = JsonlWriter { file };
|
||||
|
||||
// If we have a meta, collect git info asynchronously and write meta first
|
||||
if let Some(session_meta) = meta.take() {
|
||||
let git_info = collect_git_info(&cwd).await;
|
||||
let session_meta_line = SessionMetaLine {
|
||||
meta: session_meta,
|
||||
git: git_info,
|
||||
};
|
||||
|
||||
// Write the SessionMeta as the first item in the file, wrapped in a rollout line
|
||||
writer
|
||||
.write_rollout_item(RolloutItem::SessionMeta(session_meta_line))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Process rollout commands
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
match cmd {
|
||||
RolloutCmd::AddItems(items) => {
|
||||
for item in items {
|
||||
if is_persisted_response_item(&item) {
|
||||
writer.write_rollout_item(item).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
RolloutCmd::Flush { ack } => {
|
||||
// Ensure underlying file is flushed and then ack.
|
||||
if let Err(e) = writer.file.flush().await {
|
||||
let _ = ack.send(());
|
||||
return Err(e);
|
||||
}
|
||||
let _ = ack.send(());
|
||||
}
|
||||
RolloutCmd::Shutdown { ack } => {
|
||||
let _ = ack.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct JsonlWriter {
|
||||
file: tokio::fs::File,
|
||||
}
|
||||
|
||||
impl JsonlWriter {
|
||||
async fn write_rollout_item(&mut self, rollout_item: RolloutItem) -> std::io::Result<()> {
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = OffsetDateTime::now_utc()
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
let line = RolloutLine {
|
||||
timestamp,
|
||||
item: rollout_item,
|
||||
};
|
||||
self.write_line(&line).await
|
||||
}
|
||||
async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
|
||||
let mut json = serde_json::to_string(item)?;
|
||||
json.push('\n');
|
||||
self.file.write_all(json.as_bytes()).await?;
|
||||
self.file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
445
codex-rs/core/src/rollout/tests.rs
Normal file
445
codex-rs/core/src/rollout/tests.rs
Normal file
@@ -0,0 +1,445 @@
|
||||
#![allow(clippy::unwrap_used, clippy::expect_used)]
|
||||
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
|
||||
use tempfile::TempDir;
|
||||
use time::OffsetDateTime;
|
||||
use time::PrimitiveDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::rollout::list::ConversationItem;
|
||||
use crate::rollout::list::ConversationsPage;
|
||||
use crate::rollout::list::Cursor;
|
||||
use crate::rollout::list::get_conversation;
|
||||
use crate::rollout::list::get_conversations;
|
||||
|
||||
fn write_session_file(
|
||||
root: &Path,
|
||||
ts_str: &str,
|
||||
uuid: Uuid,
|
||||
num_records: usize,
|
||||
) -> std::io::Result<(OffsetDateTime, Uuid)> {
|
||||
let format: &[FormatItem] =
|
||||
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
|
||||
let dt = PrimitiveDateTime::parse(ts_str, format)
|
||||
.unwrap()
|
||||
.assume_utc();
|
||||
let dir = root
|
||||
.join("sessions")
|
||||
.join(format!("{:04}", dt.year()))
|
||||
.join(format!("{:02}", u8::from(dt.month())))
|
||||
.join(format!("{:02}", dt.day()));
|
||||
fs::create_dir_all(&dir)?;
|
||||
|
||||
let filename = format!("rollout-{ts_str}-{uuid}.jsonl");
|
||||
let file_path = dir.join(filename);
|
||||
let mut file = File::create(file_path)?;
|
||||
|
||||
let meta = serde_json::json!({
|
||||
"timestamp": ts_str,
|
||||
"type": "session_meta",
|
||||
"payload": {
|
||||
"id": uuid,
|
||||
"timestamp": ts_str,
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
}
|
||||
});
|
||||
writeln!(file, "{meta}")?;
|
||||
|
||||
// Include at least one user message event to satisfy listing filters
|
||||
let user_event = serde_json::json!({
|
||||
"timestamp": ts_str,
|
||||
"type": "event_msg",
|
||||
"payload": {
|
||||
"type": "user_message",
|
||||
"message": "Hello from user",
|
||||
"kind": "plain"
|
||||
}
|
||||
});
|
||||
writeln!(file, "{user_event}")?;
|
||||
|
||||
for i in 0..num_records {
|
||||
let rec = serde_json::json!({
|
||||
"record_type": "response",
|
||||
"index": i
|
||||
});
|
||||
writeln!(file, "{rec}")?;
|
||||
}
|
||||
Ok((dt, uuid))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_conversations_latest_first() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let home = temp.path();
|
||||
|
||||
// Fixed UUIDs for deterministic expectations
|
||||
let u1 = Uuid::from_u128(1);
|
||||
let u2 = Uuid::from_u128(2);
|
||||
let u3 = Uuid::from_u128(3);
|
||||
|
||||
// Create three sessions across three days
|
||||
write_session_file(home, "2025-01-01T12-00-00", u1, 3).unwrap();
|
||||
write_session_file(home, "2025-01-02T12-00-00", u2, 3).unwrap();
|
||||
write_session_file(home, "2025-01-03T12-00-00", u3, 3).unwrap();
|
||||
|
||||
let page = get_conversations(home, 10, None).await.unwrap();
|
||||
|
||||
// Build expected objects
|
||||
let p1 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("01")
|
||||
.join("03")
|
||||
.join(format!("rollout-2025-01-03T12-00-00-{u3}.jsonl"));
|
||||
let p2 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("01")
|
||||
.join("02")
|
||||
.join(format!("rollout-2025-01-02T12-00-00-{u2}.jsonl"));
|
||||
let p3 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("01")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-01-01T12-00-00-{u1}.jsonl"));
|
||||
|
||||
let head_3 = vec![serde_json::json!({
|
||||
"id": u3,
|
||||
"timestamp": "2025-01-03T12-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_2 = vec![serde_json::json!({
|
||||
"id": u2,
|
||||
"timestamp": "2025-01-02T12-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_1 = vec![serde_json::json!({
|
||||
"id": u1,
|
||||
"timestamp": "2025-01-01T12-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
|
||||
let expected_cursor: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-01-01T12-00-00|{u1}\"")).unwrap();
|
||||
|
||||
let expected = ConversationsPage {
|
||||
items: vec![
|
||||
ConversationItem {
|
||||
path: p1,
|
||||
head: head_3,
|
||||
},
|
||||
ConversationItem {
|
||||
path: p2,
|
||||
head: head_2,
|
||||
},
|
||||
ConversationItem {
|
||||
path: p3,
|
||||
head: head_1,
|
||||
},
|
||||
],
|
||||
next_cursor: Some(expected_cursor),
|
||||
num_scanned_files: 3,
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
|
||||
assert_eq!(page, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pagination_cursor() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let home = temp.path();
|
||||
|
||||
// Fixed UUIDs for deterministic expectations
|
||||
let u1 = Uuid::from_u128(11);
|
||||
let u2 = Uuid::from_u128(22);
|
||||
let u3 = Uuid::from_u128(33);
|
||||
let u4 = Uuid::from_u128(44);
|
||||
let u5 = Uuid::from_u128(55);
|
||||
|
||||
// Oldest to newest
|
||||
write_session_file(home, "2025-03-01T09-00-00", u1, 1).unwrap();
|
||||
write_session_file(home, "2025-03-02T09-00-00", u2, 1).unwrap();
|
||||
write_session_file(home, "2025-03-03T09-00-00", u3, 1).unwrap();
|
||||
write_session_file(home, "2025-03-04T09-00-00", u4, 1).unwrap();
|
||||
write_session_file(home, "2025-03-05T09-00-00", u5, 1).unwrap();
|
||||
|
||||
let page1 = get_conversations(home, 2, None).await.unwrap();
|
||||
let p5 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("03")
|
||||
.join("05")
|
||||
.join(format!("rollout-2025-03-05T09-00-00-{u5}.jsonl"));
|
||||
let p4 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("03")
|
||||
.join("04")
|
||||
.join(format!("rollout-2025-03-04T09-00-00-{u4}.jsonl"));
|
||||
let head_5 = vec![serde_json::json!({
|
||||
"id": u5,
|
||||
"timestamp": "2025-03-05T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_4 = vec![serde_json::json!({
|
||||
"id": u4,
|
||||
"timestamp": "2025-03-04T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor1: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-03-04T09-00-00|{u4}\"")).unwrap();
|
||||
let expected_page1 = ConversationsPage {
|
||||
items: vec![
|
||||
ConversationItem {
|
||||
path: p5,
|
||||
head: head_5,
|
||||
},
|
||||
ConversationItem {
|
||||
path: p4,
|
||||
head: head_4,
|
||||
},
|
||||
],
|
||||
next_cursor: Some(expected_cursor1.clone()),
|
||||
num_scanned_files: 3, // scanned 05, 04, and peeked at 03 before breaking
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
assert_eq!(page1, expected_page1);
|
||||
|
||||
let page2 = get_conversations(home, 2, page1.next_cursor.as_ref())
|
||||
.await
|
||||
.unwrap();
|
||||
let p3 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("03")
|
||||
.join("03")
|
||||
.join(format!("rollout-2025-03-03T09-00-00-{u3}.jsonl"));
|
||||
let p2 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("03")
|
||||
.join("02")
|
||||
.join(format!("rollout-2025-03-02T09-00-00-{u2}.jsonl"));
|
||||
let head_3 = vec![serde_json::json!({
|
||||
"id": u3,
|
||||
"timestamp": "2025-03-03T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_2 = vec![serde_json::json!({
|
||||
"id": u2,
|
||||
"timestamp": "2025-03-02T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor2: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-03-02T09-00-00|{u2}\"")).unwrap();
|
||||
let expected_page2 = ConversationsPage {
|
||||
items: vec![
|
||||
ConversationItem {
|
||||
path: p3,
|
||||
head: head_3,
|
||||
},
|
||||
ConversationItem {
|
||||
path: p2,
|
||||
head: head_2,
|
||||
},
|
||||
],
|
||||
next_cursor: Some(expected_cursor2.clone()),
|
||||
num_scanned_files: 5, // scanned 05, 04 (anchor), 03, 02, and peeked at 01
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
assert_eq!(page2, expected_page2);
|
||||
|
||||
let page3 = get_conversations(home, 2, page2.next_cursor.as_ref())
|
||||
.await
|
||||
.unwrap();
|
||||
let p1 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("03")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-03-01T09-00-00-{u1}.jsonl"));
|
||||
let head_1 = vec![serde_json::json!({
|
||||
"id": u1,
|
||||
"timestamp": "2025-03-01T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor3: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-03-01T09-00-00|{u1}\"")).unwrap();
|
||||
let expected_page3 = ConversationsPage {
|
||||
items: vec![ConversationItem {
|
||||
path: p1,
|
||||
head: head_1,
|
||||
}],
|
||||
next_cursor: Some(expected_cursor3),
|
||||
num_scanned_files: 5, // scanned 05, 04 (anchor), 03, 02 (anchor), 01
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
assert_eq!(page3, expected_page3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_conversation_contents() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let home = temp.path();
|
||||
|
||||
let uuid = Uuid::new_v4();
|
||||
let ts = "2025-04-01T10-30-00";
|
||||
write_session_file(home, ts, uuid, 2).unwrap();
|
||||
|
||||
let page = get_conversations(home, 1, None).await.unwrap();
|
||||
let path = &page.items[0].path;
|
||||
|
||||
let content = get_conversation(path).await.unwrap();
|
||||
|
||||
// Page equality (single item)
|
||||
let expected_path = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("04")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-04-01T10-30-00-{uuid}.jsonl"));
|
||||
let expected_head = vec![serde_json::json!({
|
||||
"id": uuid,
|
||||
"timestamp": ts,
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor: Cursor = serde_json::from_str(&format!("\"{ts}|{uuid}\"")).unwrap();
|
||||
let expected_page = ConversationsPage {
|
||||
items: vec![ConversationItem {
|
||||
path: expected_path,
|
||||
head: expected_head,
|
||||
}],
|
||||
next_cursor: Some(expected_cursor),
|
||||
num_scanned_files: 1,
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
assert_eq!(page, expected_page);
|
||||
|
||||
// Entire file contents equality
|
||||
let meta = serde_json::json!({"timestamp": ts, "type": "session_meta", "payload": {"id": uuid, "timestamp": ts, "instructions": null, "cwd": ".", "originator": "test_originator", "cli_version": "test_version"}});
|
||||
let user_event = serde_json::json!({
|
||||
"timestamp": ts,
|
||||
"type": "event_msg",
|
||||
"payload": {"type": "user_message", "message": "Hello from user", "kind": "plain"}
|
||||
});
|
||||
let rec0 = serde_json::json!({"record_type": "response", "index": 0});
|
||||
let rec1 = serde_json::json!({"record_type": "response", "index": 1});
|
||||
let expected_content = format!("{meta}\n{user_event}\n{rec0}\n{rec1}\n");
|
||||
assert_eq!(content, expected_content);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stable_ordering_same_second_pagination() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let home = temp.path();
|
||||
|
||||
let ts = "2025-07-01T00-00-00";
|
||||
let u1 = Uuid::from_u128(1);
|
||||
let u2 = Uuid::from_u128(2);
|
||||
let u3 = Uuid::from_u128(3);
|
||||
|
||||
write_session_file(home, ts, u1, 0).unwrap();
|
||||
write_session_file(home, ts, u2, 0).unwrap();
|
||||
write_session_file(home, ts, u3, 0).unwrap();
|
||||
|
||||
let page1 = get_conversations(home, 2, None).await.unwrap();
|
||||
|
||||
let p3 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("07")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-07-01T00-00-00-{u3}.jsonl"));
|
||||
let p2 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("07")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-07-01T00-00-00-{u2}.jsonl"));
|
||||
let head = |u: Uuid| -> Vec<serde_json::Value> {
|
||||
vec![serde_json::json!({
|
||||
"id": u,
|
||||
"timestamp": ts,
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})]
|
||||
};
|
||||
let expected_cursor1: Cursor = serde_json::from_str(&format!("\"{ts}|{u2}\"")).unwrap();
|
||||
let expected_page1 = ConversationsPage {
|
||||
items: vec![
|
||||
ConversationItem {
|
||||
path: p3,
|
||||
head: head(u3),
|
||||
},
|
||||
ConversationItem {
|
||||
path: p2,
|
||||
head: head(u2),
|
||||
},
|
||||
],
|
||||
next_cursor: Some(expected_cursor1.clone()),
|
||||
num_scanned_files: 3, // scanned u3, u2, peeked u1
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
assert_eq!(page1, expected_page1);
|
||||
|
||||
let page2 = get_conversations(home, 2, page1.next_cursor.as_ref())
|
||||
.await
|
||||
.unwrap();
|
||||
let p1 = home
|
||||
.join("sessions")
|
||||
.join("2025")
|
||||
.join("07")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-07-01T00-00-00-{u1}.jsonl"));
|
||||
let expected_cursor2: Cursor = serde_json::from_str(&format!("\"{ts}|{u1}\"")).unwrap();
|
||||
let expected_page2 = ConversationsPage {
|
||||
items: vec![ConversationItem {
|
||||
path: p1,
|
||||
head: head(u1),
|
||||
}],
|
||||
next_cursor: Some(expected_cursor2),
|
||||
num_scanned_files: 3, // scanned u3, u2 (anchor), u1
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
assert_eq!(page2, expected_page2);
|
||||
}
|
||||
@@ -53,6 +53,13 @@ pub fn assess_patch_safety(
|
||||
// paths outside the project.
|
||||
match get_platform_sandbox() {
|
||||
Some(sandbox_type) => SafetyCheck::AutoApprove { sandbox_type },
|
||||
None if sandbox_policy == &SandboxPolicy::DangerFullAccess => {
|
||||
// If the user has explicitly requested DangerFullAccess, then
|
||||
// we can auto-approve even without a sandbox.
|
||||
SafetyCheck::AutoApprove {
|
||||
sandbox_type: SandboxType::None,
|
||||
}
|
||||
}
|
||||
None => SafetyCheck::AskUser,
|
||||
}
|
||||
} else if policy == AskForApproval::Never {
|
||||
@@ -222,7 +229,7 @@ fn is_write_patch_constrained_to_writable_paths(
|
||||
|
||||
for (path, change) in action.changes() {
|
||||
match change {
|
||||
ApplyPatchFileChange::Add { .. } | ApplyPatchFileChange::Delete => {
|
||||
ApplyPatchFileChange::Add { .. } | ApplyPatchFileChange::Delete { .. } => {
|
||||
if !is_path_writable(path) {
|
||||
return false;
|
||||
}
|
||||
@@ -286,7 +293,7 @@ mod tests {
|
||||
// With the parent dir explicitly added as a writable root, the
|
||||
// outside write should be permitted.
|
||||
let policy_with_parent = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![parent.clone()],
|
||||
writable_roots: vec![parent],
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: true,
|
||||
exclude_slash_tmp: true,
|
||||
|
||||
@@ -18,19 +18,20 @@ const MACOS_PATH_TO_SEATBELT_EXECUTABLE: &str = "/usr/bin/sandbox-exec";
|
||||
|
||||
pub async fn spawn_command_under_seatbelt(
|
||||
command: Vec<String>,
|
||||
command_cwd: PathBuf,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
cwd: PathBuf,
|
||||
sandbox_policy_cwd: &Path,
|
||||
stdio_policy: StdioPolicy,
|
||||
mut env: HashMap<String, String>,
|
||||
) -> std::io::Result<Child> {
|
||||
let args = create_seatbelt_command_args(command, sandbox_policy, &cwd);
|
||||
let args = create_seatbelt_command_args(command, sandbox_policy, sandbox_policy_cwd);
|
||||
let arg0 = None;
|
||||
env.insert(CODEX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string());
|
||||
spawn_child_async(
|
||||
PathBuf::from(MACOS_PATH_TO_SEATBELT_EXECUTABLE),
|
||||
args,
|
||||
arg0,
|
||||
cwd,
|
||||
command_cwd,
|
||||
sandbox_policy,
|
||||
stdio_policy,
|
||||
env,
|
||||
@@ -41,7 +42,7 @@ pub async fn spawn_command_under_seatbelt(
|
||||
fn create_seatbelt_command_args(
|
||||
command: Vec<String>,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
cwd: &Path,
|
||||
sandbox_policy_cwd: &Path,
|
||||
) -> Vec<String> {
|
||||
let (file_write_policy, extra_cli_args) = {
|
||||
if sandbox_policy.has_full_disk_write_access() {
|
||||
@@ -51,7 +52,7 @@ fn create_seatbelt_command_args(
|
||||
Vec::<String>::new(),
|
||||
)
|
||||
} else {
|
||||
let writable_roots = sandbox_policy.get_writable_roots_with_cwd(cwd);
|
||||
let writable_roots = sandbox_policy.get_writable_roots_with_cwd(sandbox_policy_cwd);
|
||||
|
||||
let mut writable_folder_policies: Vec<String> = Vec::new();
|
||||
let mut cli_args: Vec<String> = Vec::new();
|
||||
@@ -153,7 +154,7 @@ mod tests {
|
||||
// Build a policy that only includes the two test roots as writable and
|
||||
// does not automatically include defaults TMPDIR or /tmp.
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![root_with_git.clone(), root_without_git.clone()],
|
||||
writable_roots: vec![root_with_git, root_without_git],
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: true,
|
||||
exclude_slash_tmp: true,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
; inspired by Chrome's sandbox policy:
|
||||
; https://source.chromium.org/chromium/chromium/src/+/main:sandbox/policy/mac/common.sb;l=273-319;drc=7b3962fe2e5fc9e2ee58000dc8fbf3429d84d3bd
|
||||
; https://source.chromium.org/chromium/chromium/src/+/main:sandbox/policy/mac/renderer.sb;l=64;drc=7b3962fe2e5fc9e2ee58000dc8fbf3429d84d3bd
|
||||
|
||||
; start with closed-by-default
|
||||
(deny default)
|
||||
@@ -9,7 +10,13 @@
|
||||
; child processes inherit the policy of their parent
|
||||
(allow process-exec)
|
||||
(allow process-fork)
|
||||
(allow signal (target self))
|
||||
(allow signal (target same-sandbox))
|
||||
|
||||
; Allow cf prefs to work.
|
||||
(allow user-preference-read)
|
||||
|
||||
; process-info
|
||||
(allow process-info* (target same-sandbox))
|
||||
|
||||
(allow file-write-data
|
||||
(require-all
|
||||
@@ -32,28 +39,22 @@
|
||||
(sysctl-name "hw.l3cachesize_compat")
|
||||
(sysctl-name "hw.logicalcpu_max")
|
||||
(sysctl-name "hw.machine")
|
||||
(sysctl-name "hw.memsize")
|
||||
(sysctl-name "hw.ncpu")
|
||||
(sysctl-name "hw.nperflevels")
|
||||
(sysctl-name "hw.optional.arm.FEAT_BF16")
|
||||
(sysctl-name "hw.optional.arm.FEAT_DotProd")
|
||||
(sysctl-name "hw.optional.arm.FEAT_FCMA")
|
||||
(sysctl-name "hw.optional.arm.FEAT_FHM")
|
||||
(sysctl-name "hw.optional.arm.FEAT_FP16")
|
||||
(sysctl-name "hw.optional.arm.FEAT_I8MM")
|
||||
(sysctl-name "hw.optional.arm.FEAT_JSCVT")
|
||||
(sysctl-name "hw.optional.arm.FEAT_LSE")
|
||||
(sysctl-name "hw.optional.arm.FEAT_RDM")
|
||||
(sysctl-name "hw.optional.arm.FEAT_SHA512")
|
||||
(sysctl-name "hw.optional.armv8_2_sha512")
|
||||
(sysctl-name "hw.memsize")
|
||||
(sysctl-name "hw.pagesize")
|
||||
; Chrome locks these CPU feature detection down a bit more tightly,
|
||||
; but mostly for fingerprinting concerns which isn't an issue for codex.
|
||||
(sysctl-name-prefix "hw.optional.arm.")
|
||||
(sysctl-name-prefix "hw.optional.armv8_")
|
||||
(sysctl-name "hw.packages")
|
||||
(sysctl-name "hw.pagesize_compat")
|
||||
(sysctl-name "hw.pagesize")
|
||||
(sysctl-name "hw.physicalcpu_max")
|
||||
(sysctl-name "hw.tbfrequency_compat")
|
||||
(sysctl-name "hw.vectorunit")
|
||||
(sysctl-name "kern.hostname")
|
||||
(sysctl-name "kern.maxfilesperproc")
|
||||
(sysctl-name "kern.maxproc")
|
||||
(sysctl-name "kern.osproductversion")
|
||||
(sysctl-name "kern.osrelease")
|
||||
(sysctl-name "kern.ostype")
|
||||
@@ -63,9 +64,27 @@
|
||||
(sysctl-name "kern.usrstack64")
|
||||
(sysctl-name "kern.version")
|
||||
(sysctl-name "sysctl.proc_cputype")
|
||||
(sysctl-name "vm.loadavg")
|
||||
(sysctl-name-prefix "hw.perflevel")
|
||||
(sysctl-name-prefix "kern.proc.pgrp.")
|
||||
(sysctl-name-prefix "kern.proc.pid.")
|
||||
(sysctl-name-prefix "net.routetable.")
|
||||
)
|
||||
|
||||
; IOKit
|
||||
(allow iokit-open
|
||||
(iokit-registry-entry-class "RootDomainUserClient")
|
||||
)
|
||||
|
||||
; needed to look up user info, see https://crbug.com/792228
|
||||
(allow mach-lookup
|
||||
(global-name "com.apple.system.opendirectoryd.libinfo")
|
||||
)
|
||||
|
||||
; Added on top of Chrome profile
|
||||
; Needed for python multiprocessing on MacOS for the SemLock
|
||||
(allow ipc-posix-sem)
|
||||
|
||||
(allow mach-lookup
|
||||
(global-name "com.apple.PowerManagement.control")
|
||||
)
|
||||
|
||||
@@ -5,19 +5,26 @@ use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct ZshShell {
|
||||
shell_path: String,
|
||||
zshrc_path: String,
|
||||
pub(crate) shell_path: String,
|
||||
pub(crate) zshrc_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct BashShell {
|
||||
pub(crate) shell_path: String,
|
||||
pub(crate) bashrc_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct PowerShellConfig {
|
||||
exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
|
||||
bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
|
||||
pub(crate) exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
|
||||
pub(crate) bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub enum Shell {
|
||||
Zsh(ZshShell),
|
||||
Bash(BashShell),
|
||||
PowerShell(PowerShellConfig),
|
||||
Unknown,
|
||||
}
|
||||
@@ -25,27 +32,19 @@ pub enum Shell {
|
||||
impl Shell {
|
||||
pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
|
||||
match self {
|
||||
Shell::Zsh(zsh) => {
|
||||
if !std::path::Path::new(&zsh.zshrc_path).exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut result = vec![zsh.shell_path.clone()];
|
||||
result.push("-lc".to_string());
|
||||
|
||||
let joined = strip_bash_lc(&command)
|
||||
.or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok());
|
||||
|
||||
if let Some(joined) = joined {
|
||||
result.push(format!("source {} && ({joined})", zsh.zshrc_path));
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
Some(result)
|
||||
}
|
||||
Shell::Zsh(zsh) => format_shell_invocation_with_rc(
|
||||
command.as_slice(),
|
||||
&zsh.shell_path,
|
||||
&zsh.zshrc_path,
|
||||
),
|
||||
Shell::Bash(bash) => format_shell_invocation_with_rc(
|
||||
command.as_slice(),
|
||||
&bash.shell_path,
|
||||
&bash.bashrc_path,
|
||||
),
|
||||
Shell::PowerShell(ps) => {
|
||||
// If model generated a bash command, prefer a detected bash fallback
|
||||
if let Some(script) = strip_bash_lc(&command) {
|
||||
if let Some(script) = strip_bash_lc(command.as_slice()) {
|
||||
return match &ps.bash_exe_fallback {
|
||||
Some(bash) => Some(vec![
|
||||
bash.to_string_lossy().to_string(),
|
||||
@@ -74,7 +73,7 @@ impl Shell {
|
||||
return Some(command);
|
||||
}
|
||||
|
||||
let joined = shlex::try_join(command.iter().map(|s| s.as_str())).ok();
|
||||
let joined = shlex::try_join(command.iter().map(String::as_str)).ok();
|
||||
return joined.map(|arg| {
|
||||
vec![
|
||||
ps.exe.clone(),
|
||||
@@ -97,14 +96,34 @@ impl Shell {
|
||||
Shell::Zsh(zsh) => std::path::Path::new(&zsh.shell_path)
|
||||
.file_name()
|
||||
.map(|s| s.to_string_lossy().to_string()),
|
||||
Shell::Bash(bash) => std::path::Path::new(&bash.shell_path)
|
||||
.file_name()
|
||||
.map(|s| s.to_string_lossy().to_string()),
|
||||
Shell::PowerShell(ps) => Some(ps.exe.clone()),
|
||||
Shell::Unknown => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn strip_bash_lc(command: &Vec<String>) -> Option<String> {
|
||||
match command.as_slice() {
|
||||
fn format_shell_invocation_with_rc(
|
||||
command: &[String],
|
||||
shell_path: &str,
|
||||
rc_path: &str,
|
||||
) -> Option<Vec<String>> {
|
||||
let joined = strip_bash_lc(command)
|
||||
.or_else(|| shlex::try_join(command.iter().map(String::as_str)).ok())?;
|
||||
|
||||
let rc_command = if std::path::Path::new(rc_path).exists() {
|
||||
format!("source {rc_path} && ({joined})")
|
||||
} else {
|
||||
joined
|
||||
};
|
||||
|
||||
Some(vec![shell_path.to_string(), "-lc".to_string(), rc_command])
|
||||
}
|
||||
|
||||
fn strip_bash_lc(command: &[String]) -> Option<String> {
|
||||
match command {
|
||||
// exactly three items
|
||||
[first, second, third]
|
||||
// first two must be "bash", "-lc"
|
||||
@@ -116,44 +135,43 @@ fn strip_bash_lc(command: &Vec<String>) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub async fn default_user_shell() -> Shell {
|
||||
use tokio::process::Command;
|
||||
use whoami;
|
||||
#[cfg(unix)]
|
||||
fn detect_default_user_shell() -> Shell {
|
||||
use libc::getpwuid;
|
||||
use libc::getuid;
|
||||
use std::ffi::CStr;
|
||||
|
||||
let user = whoami::username();
|
||||
let home = format!("/Users/{user}");
|
||||
let output = Command::new("dscl")
|
||||
.args([".", "-read", &home, "UserShell"])
|
||||
.output()
|
||||
.await
|
||||
.ok();
|
||||
match output {
|
||||
Some(o) => {
|
||||
if !o.status.success() {
|
||||
return Shell::Unknown;
|
||||
}
|
||||
let stdout = String::from_utf8_lossy(&o.stdout);
|
||||
for line in stdout.lines() {
|
||||
if let Some(shell_path) = line.strip_prefix("UserShell: ")
|
||||
&& shell_path.ends_with("/zsh")
|
||||
{
|
||||
return Shell::Zsh(ZshShell {
|
||||
shell_path: shell_path.to_string(),
|
||||
zshrc_path: format!("{home}/.zshrc"),
|
||||
});
|
||||
}
|
||||
unsafe {
|
||||
let uid = getuid();
|
||||
let pw = getpwuid(uid);
|
||||
|
||||
if !pw.is_null() {
|
||||
let shell_path = CStr::from_ptr((*pw).pw_shell)
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
let home_path = CStr::from_ptr((*pw).pw_dir).to_string_lossy().into_owned();
|
||||
|
||||
if shell_path.ends_with("/zsh") {
|
||||
return Shell::Zsh(ZshShell {
|
||||
shell_path,
|
||||
zshrc_path: format!("{home_path}/.zshrc"),
|
||||
});
|
||||
}
|
||||
|
||||
Shell::Unknown
|
||||
if shell_path.ends_with("/bash") {
|
||||
return Shell::Bash(BashShell {
|
||||
shell_path,
|
||||
bashrc_path: format!("{home_path}/.bashrc"),
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => Shell::Unknown,
|
||||
}
|
||||
Shell::Unknown
|
||||
}
|
||||
|
||||
#[cfg(all(not(target_os = "macos"), not(target_os = "windows")))]
|
||||
#[cfg(unix)]
|
||||
pub async fn default_user_shell() -> Shell {
|
||||
Shell::Unknown
|
||||
detect_default_user_shell()
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
@@ -196,11 +214,17 @@ pub async fn default_user_shell() -> Shell {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(not(target_os = "windows"), not(unix)))]
|
||||
pub async fn default_user_shell() -> Shell {
|
||||
Shell::Unknown
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(target_os = "macos")]
|
||||
#[cfg(unix)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::process::Command;
|
||||
use std::string::ToString;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_current_shell_detects_zsh() {
|
||||
@@ -230,9 +254,126 @@ mod tests {
|
||||
zshrc_path: "/does/not/exist/.zshrc".to_string(),
|
||||
});
|
||||
let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
|
||||
assert_eq!(actual_cmd, None);
|
||||
assert_eq!(
|
||||
actual_cmd,
|
||||
Some(vec![
|
||||
"/bin/zsh".to_string(),
|
||||
"-lc".to_string(),
|
||||
"myecho".to_string()
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_run_with_profile_bashrc_not_exists() {
|
||||
let shell = Shell::Bash(BashShell {
|
||||
shell_path: "/bin/bash".to_string(),
|
||||
bashrc_path: "/does/not/exist/.bashrc".to_string(),
|
||||
});
|
||||
let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
|
||||
assert_eq!(
|
||||
actual_cmd,
|
||||
Some(vec![
|
||||
"/bin/bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"myecho".to_string()
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_run_with_profile_bash_escaping_and_execution() {
|
||||
let shell_path = "/bin/bash";
|
||||
|
||||
let cases = vec![
|
||||
(
|
||||
vec!["myecho"],
|
||||
vec![shell_path, "-lc", "source BASHRC_PATH && (myecho)"],
|
||||
Some("It works!\n"),
|
||||
),
|
||||
(
|
||||
vec!["bash", "-lc", "echo 'single' \"double\""],
|
||||
vec![
|
||||
shell_path,
|
||||
"-lc",
|
||||
"source BASHRC_PATH && (echo 'single' \"double\")",
|
||||
],
|
||||
Some("single double\n"),
|
||||
),
|
||||
];
|
||||
|
||||
for (input, expected_cmd, expected_output) in cases {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::exec::ExecParams;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::exec::process_exec_tool_call;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
|
||||
let temp_home = tempfile::tempdir().unwrap();
|
||||
let bashrc_path = temp_home.path().join(".bashrc");
|
||||
std::fs::write(
|
||||
&bashrc_path,
|
||||
r#"
|
||||
set -x
|
||||
function myecho {
|
||||
echo 'It works!'
|
||||
}
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let shell = Shell::Bash(BashShell {
|
||||
shell_path: shell_path.to_string(),
|
||||
bashrc_path: bashrc_path.to_str().unwrap().to_string(),
|
||||
});
|
||||
|
||||
let actual_cmd = shell
|
||||
.format_default_shell_invocation(input.iter().map(ToString::to_string).collect());
|
||||
let expected_cmd = expected_cmd
|
||||
.iter()
|
||||
.map(|s| s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap()))
|
||||
.collect();
|
||||
|
||||
assert_eq!(actual_cmd, Some(expected_cmd));
|
||||
|
||||
let output = process_exec_tool_call(
|
||||
ExecParams {
|
||||
command: actual_cmd.unwrap(),
|
||||
cwd: PathBuf::from(temp_home.path()),
|
||||
timeout_ms: None,
|
||||
env: HashMap::from([(
|
||||
"HOME".to_string(),
|
||||
temp_home.path().to_str().unwrap().to_string(),
|
||||
)]),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
},
|
||||
SandboxType::None,
|
||||
&SandboxPolicy::DangerFullAccess,
|
||||
temp_home.path(),
|
||||
&None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
|
||||
if let Some(expected) = expected_output {
|
||||
assert_eq!(
|
||||
output.stdout.text, expected,
|
||||
"input: {input:?} output: {output:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(target_os = "macos")]
|
||||
mod macos_tests {
|
||||
use super::*;
|
||||
use std::string::ToString;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_run_with_profile_escaping_and_execution() {
|
||||
let shell_path = "/bin/zsh";
|
||||
@@ -295,13 +436,10 @@ mod tests {
|
||||
});
|
||||
|
||||
let actual_cmd = shell
|
||||
.format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
|
||||
.format_default_shell_invocation(input.iter().map(ToString::to_string).collect());
|
||||
let expected_cmd = expected_cmd
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap())
|
||||
.to_string()
|
||||
})
|
||||
.map(|s| s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap()))
|
||||
.collect();
|
||||
|
||||
assert_eq!(actual_cmd, Some(expected_cmd));
|
||||
@@ -320,6 +458,7 @@ mod tests {
|
||||
},
|
||||
SandboxType::None,
|
||||
&SandboxPolicy::DangerFullAccess,
|
||||
temp_home.path(),
|
||||
&None,
|
||||
None,
|
||||
)
|
||||
@@ -422,10 +561,10 @@ mod tests_windows {
|
||||
|
||||
for (shell, input, expected_cmd) in cases {
|
||||
let actual_cmd = shell
|
||||
.format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
|
||||
.format_default_shell_invocation(input.iter().map(|s| (*s).to_string()).collect());
|
||||
assert_eq!(
|
||||
actual_cmd,
|
||||
Some(expected_cmd.iter().map(|s| s.to_string()).collect())
|
||||
Some(expected_cmd.iter().map(|s| (*s).to_string()).collect())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::AuthMode;
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
|
||||
pub struct TokenData {
|
||||
/// Flat info parsed from the JWT in auth.json.
|
||||
@@ -22,43 +20,13 @@ pub struct TokenData {
|
||||
pub account_id: Option<String>,
|
||||
}
|
||||
|
||||
impl TokenData {
|
||||
/// Returns true if this is a plan that should use the traditional
|
||||
/// "metered" billing via an API key.
|
||||
pub(crate) fn should_use_api_key(
|
||||
&self,
|
||||
preferred_auth_method: AuthMode,
|
||||
is_openai_email: bool,
|
||||
) -> bool {
|
||||
if preferred_auth_method == AuthMode::ApiKey {
|
||||
return true;
|
||||
}
|
||||
// If the email is an OpenAI email, use AuthMode::ChatGPT unless preferred_auth_method is AuthMode::ApiKey.
|
||||
if is_openai_email {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.id_token
|
||||
.chatgpt_plan_type
|
||||
.as_ref()
|
||||
.is_none_or(|plan| plan.is_plan_that_should_use_api_key())
|
||||
}
|
||||
|
||||
pub fn is_openai_email(&self) -> bool {
|
||||
self.id_token
|
||||
.email
|
||||
.as_deref()
|
||||
.is_some_and(|email| email.trim().to_ascii_lowercase().ends_with("@openai.com"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Flat subset of useful claims in id_token from auth.json.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub struct IdTokenInfo {
|
||||
pub email: Option<String>,
|
||||
/// The ChatGPT subscription plan type
|
||||
/// (e.g., "free", "plus", "pro", "business", "enterprise", "edu").
|
||||
/// (Note: ae has not verified that those are the exact values.)
|
||||
/// (Note: values may vary by backend.)
|
||||
pub(crate) chatgpt_plan_type: Option<PlanType>,
|
||||
pub raw_jwt: String,
|
||||
}
|
||||
@@ -79,28 +47,6 @@ pub(crate) enum PlanType {
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
impl PlanType {
|
||||
fn is_plan_that_should_use_api_key(&self) -> bool {
|
||||
match self {
|
||||
Self::Known(known) => {
|
||||
use KnownPlan::*;
|
||||
!matches!(known, Free | Plus | Pro | Team)
|
||||
}
|
||||
Self::Unknown(_) => {
|
||||
// Unknown plans should use the API key.
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_string(&self) -> String {
|
||||
match self {
|
||||
Self::Known(known) => format!("{known:?}").to_lowercase(),
|
||||
Self::Unknown(s) => s.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub(crate) enum KnownPlan {
|
||||
@@ -137,7 +83,7 @@ pub enum IdTokenInfoError {
|
||||
Json(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub(crate) fn parse_id_token(id_token: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
|
||||
pub fn parse_id_token(id_token: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
|
||||
// JWT format: header.payload.signature
|
||||
let mut parts = id_token.split('.');
|
||||
let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
|
||||
@@ -204,9 +150,33 @@ mod tests {
|
||||
|
||||
let info = parse_id_token(&fake_jwt).expect("should parse");
|
||||
assert_eq!(info.email.as_deref(), Some("user@example.com"));
|
||||
assert_eq!(
|
||||
info.chatgpt_plan_type,
|
||||
Some(PlanType::Known(KnownPlan::Pro))
|
||||
);
|
||||
assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn id_token_info_handles_missing_fields() {
|
||||
#[derive(Serialize)]
|
||||
struct Header {
|
||||
alg: &'static str,
|
||||
typ: &'static str,
|
||||
}
|
||||
let header = Header {
|
||||
alg: "none",
|
||||
typ: "JWT",
|
||||
};
|
||||
let payload = serde_json::json!({ "sub": "123" });
|
||||
|
||||
fn b64url_no_pad(bytes: &[u8]) -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
|
||||
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
|
||||
let signature_b64 = b64url_no_pad(b"sig");
|
||||
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
|
||||
|
||||
let info = parse_id_token(&fake_jwt).expect("should parse");
|
||||
assert!(info.email.is_none());
|
||||
assert!(info.get_chatgpt_plan_type().is_none());
|
||||
}
|
||||
}
|
||||
19
codex-rs/core/src/tool_apply_patch.lark
Normal file
19
codex-rs/core/src/tool_apply_patch.lark
Normal file
@@ -0,0 +1,19 @@
|
||||
start: begin_patch hunk+ end_patch
|
||||
begin_patch: "*** Begin Patch" LF
|
||||
end_patch: "*** End Patch" LF?
|
||||
|
||||
hunk: add_hunk | delete_hunk | update_hunk
|
||||
add_hunk: "*** Add File: " filename LF add_line+
|
||||
delete_hunk: "*** Delete File: " filename LF
|
||||
update_hunk: "*** Update File: " filename LF change_move? change?
|
||||
|
||||
filename: /(.+)/
|
||||
add_line: "+" /(.*)/ LF -> line
|
||||
|
||||
change_move: "*** Move to: " filename LF
|
||||
change: (change_context | change_line)+ eof_line?
|
||||
change_context: ("@@" | "@@ " /(.+)/) LF
|
||||
change_line: ("+" | "-" | " ") /(.*)/ LF
|
||||
eof_line: "*** End of File" LF
|
||||
|
||||
%import common.LF
|
||||
@@ -8,6 +8,8 @@ use crate::openai_tools::JsonSchema;
|
||||
use crate::openai_tools::OpenAiTool;
|
||||
use crate::openai_tools::ResponsesApiTool;
|
||||
|
||||
const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark");
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub(crate) struct ApplyPatchToolArgs {
|
||||
pub(crate) input: String,
|
||||
@@ -29,27 +31,7 @@ pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool {
|
||||
format: FreeformToolFormat {
|
||||
r#type: "grammar".to_string(),
|
||||
syntax: "lark".to_string(),
|
||||
definition: r#"start: begin_patch hunk+ end_patch
|
||||
begin_patch: "*** Begin Patch" LF
|
||||
end_patch: "*** End Patch" LF?
|
||||
|
||||
hunk: add_hunk | delete_hunk | update_hunk
|
||||
add_hunk: "*** Add File: " filename LF add_line+
|
||||
delete_hunk: "*** Delete File: " filename LF
|
||||
update_hunk: "*** Update File: " filename LF change_move? change?
|
||||
|
||||
filename: /(.+)/
|
||||
add_line: "+" /(.+)/ LF -> line
|
||||
|
||||
change_move: "*** Move to: " filename LF
|
||||
change: (change_context | change_line)+ eof_line?
|
||||
change_context: ("@@" | "@@ " /(.+)/) LF
|
||||
change_line: ("+" | "-" | " ") /(.+)/ LF
|
||||
eof_line: "*** End of File" LF
|
||||
|
||||
%import common.LF
|
||||
"#
|
||||
.to_string(),
|
||||
definition: APPLY_PATCH_LARK_GRAMMAR.to_string(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
180
codex-rs/core/src/truncate.rs
Normal file
180
codex-rs/core/src/truncate.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
//! Utilities for truncating large chunks of output while preserving a prefix
|
||||
//! and suffix on UTF-8 boundaries.
|
||||
|
||||
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
|
||||
/// preserving the beginning and the end. Returns the possibly truncated
|
||||
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
|
||||
/// if truncation occurred; otherwise returns the original string and `None`.
|
||||
pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
|
||||
if s.len() <= max_bytes {
|
||||
return (s.to_string(), None);
|
||||
}
|
||||
|
||||
let est_tokens = (s.len() as u64).div_ceil(4);
|
||||
if max_bytes == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
|
||||
if input.len() <= max_len {
|
||||
return input;
|
||||
}
|
||||
let mut end = max_len;
|
||||
while end > 0 && !input.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&input[..end]
|
||||
}
|
||||
|
||||
fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
|
||||
if let Some(head) = s.get(..left_budget)
|
||||
&& let Some(i) = head.rfind('\n')
|
||||
{
|
||||
return i + 1;
|
||||
}
|
||||
truncate_on_boundary(s, left_budget).len()
|
||||
}
|
||||
|
||||
fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
|
||||
let start_tail = s.len().saturating_sub(right_budget);
|
||||
if let Some(tail) = s.get(start_tail..)
|
||||
&& let Some(i) = tail.find('\n')
|
||||
{
|
||||
return start_tail + i + 1;
|
||||
}
|
||||
|
||||
let mut idx = start_tail.min(s.len());
|
||||
while idx < s.len() && !s.is_char_boundary(idx) {
|
||||
idx += 1;
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
let mut guess_tokens = est_tokens;
|
||||
for _ in 0..4 {
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let mut suffix_start = pick_suffix_start(s, right_budget);
|
||||
if suffix_start < prefix_end {
|
||||
suffix_start = prefix_end;
|
||||
}
|
||||
|
||||
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
|
||||
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
|
||||
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
|
||||
|
||||
if new_tokens == guess_tokens {
|
||||
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
return (out, Some(est_tokens));
|
||||
}
|
||||
|
||||
guess_tokens = new_tokens;
|
||||
}
|
||||
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let suffix_start = pick_suffix_start(s, right_budget);
|
||||
|
||||
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
(out, Some(est_tokens))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::truncate_middle;
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_no_newlines_fallback() {
|
||||
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
|
||||
let max_bytes = 32;
|
||||
let (out, original) = truncate_middle(s, max_bytes);
|
||||
assert!(out.starts_with("abc"));
|
||||
assert!(out.contains("tokens truncated"));
|
||||
assert!(out.ends_with("XYZ*"));
|
||||
assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_prefers_newline_boundaries() {
|
||||
let mut s = String::new();
|
||||
for i in 1..=20 {
|
||||
s.push_str(&format!("{i:03}\n"));
|
||||
}
|
||||
assert_eq!(s.len(), 80);
|
||||
|
||||
let max_bytes = 64;
|
||||
let (out, tokens) = truncate_middle(&s, max_bytes);
|
||||
assert!(out.starts_with("001\n002\n003\n004\n"));
|
||||
assert!(out.contains("tokens truncated"));
|
||||
assert!(out.ends_with("017\n018\n019\n020\n"));
|
||||
assert_eq!(tokens, Some(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_handles_utf8_content() {
|
||||
let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
|
||||
let max_bytes = 32;
|
||||
let (out, tokens) = truncate_middle(s, max_bytes);
|
||||
|
||||
assert!(out.contains("tokens truncated"));
|
||||
assert!(!out.contains('\u{fffd}'));
|
||||
assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_prefers_newline_boundaries_2() {
|
||||
// Build a multi-line string of 20 numbered lines (each "NNN\n").
|
||||
let mut s = String::new();
|
||||
for i in 1..=20 {
|
||||
s.push_str(&format!("{i:03}\n"));
|
||||
}
|
||||
// Total length: 20 lines * 4 bytes per line = 80 bytes.
|
||||
assert_eq!(s.len(), 80);
|
||||
|
||||
// Choose a cap that forces truncation while leaving room for
|
||||
// a few lines on each side after accounting for the marker.
|
||||
let max_bytes = 64;
|
||||
// Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
|
||||
assert_eq!(
|
||||
truncate_middle(&s, max_bytes),
|
||||
(
|
||||
r#"001
|
||||
002
|
||||
003
|
||||
004
|
||||
…12 tokens truncated…
|
||||
017
|
||||
018
|
||||
019
|
||||
020
|
||||
"#
|
||||
.to_string(),
|
||||
Some(20)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -65,7 +65,7 @@ impl TurnDiffTracker {
|
||||
let baseline_file_info = if path.exists() {
|
||||
let mode = file_mode_for_path(path);
|
||||
let mode_val = mode.unwrap_or(FileMode::Regular);
|
||||
let content = blob_bytes(path, &mode_val).unwrap_or_default();
|
||||
let content = blob_bytes(path, mode_val).unwrap_or_default();
|
||||
let oid = if mode == Some(FileMode::Symlink) {
|
||||
format!("{:x}", git_blob_sha1_hex_bytes(&content))
|
||||
} else {
|
||||
@@ -266,7 +266,7 @@ impl TurnDiffTracker {
|
||||
};
|
||||
|
||||
let current_mode = file_mode_for_path(¤t_external_path).unwrap_or(FileMode::Regular);
|
||||
let right_bytes = blob_bytes(¤t_external_path, ¤t_mode);
|
||||
let right_bytes = blob_bytes(¤t_external_path, current_mode);
|
||||
|
||||
// Compute displays with &mut self before borrowing any baseline content.
|
||||
let left_display = self.relative_to_git_root_str(&baseline_external_path);
|
||||
@@ -388,7 +388,7 @@ enum FileMode {
|
||||
}
|
||||
|
||||
impl FileMode {
|
||||
fn as_str(&self) -> &'static str {
|
||||
fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
FileMode::Regular => "100644",
|
||||
#[cfg(unix)]
|
||||
@@ -427,9 +427,9 @@ fn file_mode_for_path(_path: &Path) -> Option<FileMode> {
|
||||
Some(FileMode::Regular)
|
||||
}
|
||||
|
||||
fn blob_bytes(path: &Path, mode: &FileMode) -> Option<Vec<u8>> {
|
||||
fn blob_bytes(path: &Path, mode: FileMode) -> Option<Vec<u8>> {
|
||||
if path.exists() {
|
||||
let contents = if *mode == FileMode::Symlink {
|
||||
let contents = if mode == FileMode::Symlink {
|
||||
symlink_blob_bytes(path)
|
||||
.ok_or_else(|| anyhow!("failed to read symlink target for {}", path.display()))
|
||||
} else {
|
||||
@@ -578,7 +578,12 @@ index {ZERO_OID}..{right_oid}
|
||||
fs::write(&file, "x\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let del_changes = HashMap::from([(file.clone(), FileChange::Delete)]);
|
||||
let del_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Delete {
|
||||
content: "x\n".to_string(),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&del_changes);
|
||||
|
||||
// Simulate apply: delete the file from disk.
|
||||
@@ -673,7 +678,7 @@ index {left_oid}..{right_oid}
|
||||
let dest = dir.path().join("dest.txt");
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv = HashMap::from([(
|
||||
src.clone(),
|
||||
src,
|
||||
FileChange::Update {
|
||||
unified_diff: "".into(),
|
||||
move_path: Some(dest.clone()),
|
||||
@@ -741,7 +746,12 @@ index {left_oid}..{right_oid}
|
||||
assert_eq!(first, expected_first);
|
||||
|
||||
// Next: introduce a brand-new path b.txt into baseline snapshots via a delete change.
|
||||
let del_b = HashMap::from([(b.clone(), FileChange::Delete)]);
|
||||
let del_b = HashMap::from([(
|
||||
b.clone(),
|
||||
FileChange::Delete {
|
||||
content: "z\n".to_string(),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&del_b);
|
||||
// Simulate apply: delete b.txt.
|
||||
let baseline_mode = file_mode_for_path(&b).unwrap_or(FileMode::Regular);
|
||||
|
||||
22
codex-rs/core/src/unified_exec/errors.rs
Normal file
22
codex-rs/core/src/unified_exec/errors.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum UnifiedExecError {
|
||||
#[error("Failed to create unified exec session: {pty_error}")]
|
||||
CreateSession {
|
||||
#[source]
|
||||
pty_error: anyhow::Error,
|
||||
},
|
||||
#[error("Unknown session id {session_id}")]
|
||||
UnknownSessionId { session_id: i32 },
|
||||
#[error("failed to write to stdin")]
|
||||
WriteToStdin,
|
||||
#[error("missing command line for unified exec request")]
|
||||
MissingCommandLine,
|
||||
}
|
||||
|
||||
impl UnifiedExecError {
|
||||
pub(crate) fn create_session(error: anyhow::Error) -> Self {
|
||||
Self::CreateSession { pty_error: error }
|
||||
}
|
||||
}
|
||||
645
codex-rs/core/src/unified_exec/mod.rs
Normal file
645
codex-rs/core/src/unified_exec/mod.rs
Normal file
@@ -0,0 +1,645 @@
|
||||
use portable_pty::CommandBuilder;
|
||||
use portable_pty::PtySize;
|
||||
use portable_pty::native_pty_system;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicI32;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::exec_command::ExecCommandSession;
|
||||
use crate::truncate::truncate_middle;
|
||||
|
||||
mod errors;
|
||||
|
||||
pub(crate) use errors::UnifiedExecError;
|
||||
|
||||
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
|
||||
const MAX_TIMEOUT_MS: u64 = 60_000;
|
||||
const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 128 KiB
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UnifiedExecRequest<'a> {
|
||||
pub session_id: Option<i32>,
|
||||
pub input_chunks: &'a [String],
|
||||
pub timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) struct UnifiedExecResult {
|
||||
pub session_id: Option<i32>,
|
||||
pub output: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct UnifiedExecSessionManager {
|
||||
next_session_id: AtomicI32,
|
||||
sessions: Mutex<HashMap<i32, ManagedUnifiedExecSession>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ManagedUnifiedExecSession {
|
||||
session: ExecCommandSession,
|
||||
output_buffer: OutputBuffer,
|
||||
/// Notifies waiters whenever new output has been appended to
|
||||
/// `output_buffer`, allowing clients to poll for fresh data.
|
||||
output_notify: Arc<Notify>,
|
||||
output_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct OutputBufferState {
|
||||
chunks: VecDeque<Vec<u8>>,
|
||||
total_bytes: usize,
|
||||
}
|
||||
|
||||
impl OutputBufferState {
|
||||
fn push_chunk(&mut self, chunk: Vec<u8>) {
|
||||
self.total_bytes = self.total_bytes.saturating_add(chunk.len());
|
||||
self.chunks.push_back(chunk);
|
||||
|
||||
let mut excess = self
|
||||
.total_bytes
|
||||
.saturating_sub(UNIFIED_EXEC_OUTPUT_MAX_BYTES);
|
||||
|
||||
while excess > 0 {
|
||||
match self.chunks.front_mut() {
|
||||
Some(front) if excess >= front.len() => {
|
||||
excess -= front.len();
|
||||
self.total_bytes = self.total_bytes.saturating_sub(front.len());
|
||||
self.chunks.pop_front();
|
||||
}
|
||||
Some(front) => {
|
||||
front.drain(..excess);
|
||||
self.total_bytes = self.total_bytes.saturating_sub(excess);
|
||||
break;
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn drain(&mut self) -> Vec<Vec<u8>> {
|
||||
let drained: Vec<Vec<u8>> = self.chunks.drain(..).collect();
|
||||
self.total_bytes = 0;
|
||||
drained
|
||||
}
|
||||
}
|
||||
|
||||
type OutputBuffer = Arc<Mutex<OutputBufferState>>;
|
||||
type OutputHandles = (OutputBuffer, Arc<Notify>);
|
||||
|
||||
impl ManagedUnifiedExecSession {
|
||||
fn new(
|
||||
session: ExecCommandSession,
|
||||
initial_output_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
) -> Self {
|
||||
let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
|
||||
let output_notify = Arc::new(Notify::new());
|
||||
let mut receiver = initial_output_rx;
|
||||
let buffer_clone = Arc::clone(&output_buffer);
|
||||
let notify_clone = Arc::clone(&output_notify);
|
||||
let output_task = tokio::spawn(async move {
|
||||
while let Ok(chunk) = receiver.recv().await {
|
||||
let mut guard = buffer_clone.lock().await;
|
||||
guard.push_chunk(chunk);
|
||||
drop(guard);
|
||||
notify_clone.notify_waiters();
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
session,
|
||||
output_buffer,
|
||||
output_notify,
|
||||
output_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
|
||||
self.session.writer_sender()
|
||||
}
|
||||
|
||||
fn output_handles(&self) -> OutputHandles {
|
||||
(
|
||||
Arc::clone(&self.output_buffer),
|
||||
Arc::clone(&self.output_notify),
|
||||
)
|
||||
}
|
||||
|
||||
fn has_exited(&self) -> bool {
|
||||
self.session.has_exited()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ManagedUnifiedExecSession {
|
||||
fn drop(&mut self) {
|
||||
self.output_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl UnifiedExecSessionManager {
|
||||
pub async fn handle_request(
|
||||
&self,
|
||||
request: UnifiedExecRequest<'_>,
|
||||
) -> Result<UnifiedExecResult, UnifiedExecError> {
|
||||
let (timeout_ms, timeout_warning) = match request.timeout_ms {
|
||||
Some(requested) if requested > MAX_TIMEOUT_MS => (
|
||||
MAX_TIMEOUT_MS,
|
||||
Some(format!(
|
||||
"Warning: requested timeout {requested}ms exceeds maximum of {MAX_TIMEOUT_MS}ms; clamping to {MAX_TIMEOUT_MS}ms.\n"
|
||||
)),
|
||||
),
|
||||
Some(requested) => (requested, None),
|
||||
None => (DEFAULT_TIMEOUT_MS, None),
|
||||
};
|
||||
|
||||
let mut new_session: Option<ManagedUnifiedExecSession> = None;
|
||||
let session_id;
|
||||
let writer_tx;
|
||||
let output_buffer;
|
||||
let output_notify;
|
||||
|
||||
if let Some(existing_id) = request.session_id {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
match sessions.get(&existing_id) {
|
||||
Some(session) => {
|
||||
if session.has_exited() {
|
||||
sessions.remove(&existing_id);
|
||||
return Err(UnifiedExecError::UnknownSessionId {
|
||||
session_id: existing_id,
|
||||
});
|
||||
}
|
||||
let (buffer, notify) = session.output_handles();
|
||||
session_id = existing_id;
|
||||
writer_tx = session.writer_sender();
|
||||
output_buffer = buffer;
|
||||
output_notify = notify;
|
||||
}
|
||||
None => {
|
||||
return Err(UnifiedExecError::UnknownSessionId {
|
||||
session_id: existing_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
drop(sessions);
|
||||
} else {
|
||||
let command = request.input_chunks.to_vec();
|
||||
let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
|
||||
let (session, initial_output_rx) = create_unified_exec_session(&command).await?;
|
||||
let managed_session = ManagedUnifiedExecSession::new(session, initial_output_rx);
|
||||
let (buffer, notify) = managed_session.output_handles();
|
||||
writer_tx = managed_session.writer_sender();
|
||||
output_buffer = buffer;
|
||||
output_notify = notify;
|
||||
session_id = new_id;
|
||||
new_session = Some(managed_session);
|
||||
};
|
||||
|
||||
if request.session_id.is_some() {
|
||||
let joined_input = request.input_chunks.join(" ");
|
||||
if !joined_input.is_empty() && writer_tx.send(joined_input.into_bytes()).await.is_err()
|
||||
{
|
||||
return Err(UnifiedExecError::WriteToStdin);
|
||||
}
|
||||
}
|
||||
|
||||
let mut collected: Vec<u8> = Vec::with_capacity(4096);
|
||||
let start = Instant::now();
|
||||
let deadline = start + Duration::from_millis(timeout_ms);
|
||||
|
||||
loop {
|
||||
let drained_chunks;
|
||||
let mut wait_for_output = None;
|
||||
{
|
||||
let mut guard = output_buffer.lock().await;
|
||||
drained_chunks = guard.drain();
|
||||
if drained_chunks.is_empty() {
|
||||
wait_for_output = Some(output_notify.notified());
|
||||
}
|
||||
}
|
||||
|
||||
if drained_chunks.is_empty() {
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
if remaining == Duration::ZERO {
|
||||
break;
|
||||
}
|
||||
|
||||
let notified = wait_for_output.unwrap_or_else(|| output_notify.notified());
|
||||
tokio::pin!(notified);
|
||||
tokio::select! {
|
||||
_ = &mut notified => {}
|
||||
_ = tokio::time::sleep(remaining) => break,
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
for chunk in drained_chunks {
|
||||
collected.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
if Instant::now() >= deadline {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let (output, _maybe_tokens) = truncate_middle(
|
||||
&String::from_utf8_lossy(&collected),
|
||||
UNIFIED_EXEC_OUTPUT_MAX_BYTES,
|
||||
);
|
||||
let output = if let Some(warning) = timeout_warning {
|
||||
format!("{warning}{output}")
|
||||
} else {
|
||||
output
|
||||
};
|
||||
|
||||
let should_store_session = if let Some(session) = new_session.as_ref() {
|
||||
!session.has_exited()
|
||||
} else if request.session_id.is_some() {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(existing) = sessions.get(&session_id) {
|
||||
if existing.has_exited() {
|
||||
sessions.remove(&session_id);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if should_store_session {
|
||||
if let Some(session) = new_session {
|
||||
self.sessions.lock().await.insert(session_id, session);
|
||||
}
|
||||
Ok(UnifiedExecResult {
|
||||
session_id: Some(session_id),
|
||||
output,
|
||||
})
|
||||
} else {
|
||||
Ok(UnifiedExecResult {
|
||||
session_id: None,
|
||||
output,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_unified_exec_session(
|
||||
command: &[String],
|
||||
) -> Result<
|
||||
(
|
||||
ExecCommandSession,
|
||||
tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
),
|
||||
UnifiedExecError,
|
||||
> {
|
||||
if command.is_empty() {
|
||||
return Err(UnifiedExecError::MissingCommandLine);
|
||||
}
|
||||
|
||||
let pty_system = native_pty_system();
|
||||
|
||||
let pair = pty_system
|
||||
.openpty(PtySize {
|
||||
rows: 24,
|
||||
cols: 80,
|
||||
pixel_width: 0,
|
||||
pixel_height: 0,
|
||||
})
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
|
||||
// Safe thanks to the check at the top of the function.
|
||||
let mut command_builder = CommandBuilder::new(command[0].clone());
|
||||
for arg in &command[1..] {
|
||||
command_builder.arg(arg);
|
||||
}
|
||||
|
||||
let mut child = pair
|
||||
.slave
|
||||
.spawn_command(command_builder)
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
let killer = child.clone_killer();
|
||||
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
|
||||
|
||||
let mut reader = pair
|
||||
.master
|
||||
.try_clone_reader()
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
let output_tx_clone = output_tx.clone();
|
||||
let reader_handle = tokio::task::spawn_blocking(move || {
|
||||
let mut buf = [0u8; 8192];
|
||||
loop {
|
||||
match reader.read(&mut buf) {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
let _ = output_tx_clone.send(buf[..n].to_vec());
|
||||
}
|
||||
Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
|
||||
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
continue;
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let writer = pair
|
||||
.master
|
||||
.take_writer()
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
let writer = Arc::new(StdMutex::new(writer));
|
||||
let writer_handle = tokio::spawn({
|
||||
let writer = writer.clone();
|
||||
async move {
|
||||
while let Some(bytes) = writer_rx.recv().await {
|
||||
let writer = writer.clone();
|
||||
let _ = tokio::task::spawn_blocking(move || {
|
||||
if let Ok(mut guard) = writer.lock() {
|
||||
use std::io::Write;
|
||||
let _ = guard.write_all(&bytes);
|
||||
let _ = guard.flush();
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let exit_status = Arc::new(AtomicBool::new(false));
|
||||
let wait_exit_status = Arc::clone(&exit_status);
|
||||
let wait_handle = tokio::task::spawn_blocking(move || {
|
||||
let _ = child.wait();
|
||||
wait_exit_status.store(true, Ordering::SeqCst);
|
||||
});
|
||||
|
||||
let (session, initial_output_rx) = ExecCommandSession::new(
|
||||
writer_tx,
|
||||
output_tx,
|
||||
killer,
|
||||
reader_handle,
|
||||
writer_handle,
|
||||
wait_handle,
|
||||
exit_status,
|
||||
);
|
||||
Ok((session, initial_output_rx))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn push_chunk_trims_only_excess_bytes() {
|
||||
let mut buffer = OutputBufferState::default();
|
||||
buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]);
|
||||
buffer.push_chunk(vec![b'b']);
|
||||
buffer.push_chunk(vec![b'c']);
|
||||
|
||||
assert_eq!(buffer.total_bytes, UNIFIED_EXEC_OUTPUT_MAX_BYTES);
|
||||
assert_eq!(buffer.chunks.len(), 3);
|
||||
assert_eq!(
|
||||
buffer.chunks.front().unwrap().len(),
|
||||
UNIFIED_EXEC_OUTPUT_MAX_BYTES - 2
|
||||
);
|
||||
assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'c']);
|
||||
assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'b']);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_persists_across_requests_jif() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
let session_id = open_shell.session_id.expect("expected session_id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &[
|
||||
"export".to_string(),
|
||||
"CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
|
||||
],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let out_2 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
assert!(out_2.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn multi_unified_exec_sessions() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let shell_a = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
let session_a = shell_a.session_id.expect("expected session id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_a),
|
||||
input_chunks: &["export CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string()],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let out_2 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &[
|
||||
"echo".to_string(),
|
||||
"$CODEX_INTERACTIVE_SHELL_VAR\n".to_string(),
|
||||
],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
assert!(!out_2.output.contains("codex"));
|
||||
|
||||
let out_3 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_a),
|
||||
input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
assert!(out_3.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn unified_exec_timeouts() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
let session_id = open_shell.session_id.expect("expected session id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &[
|
||||
"export".to_string(),
|
||||
"CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
|
||||
],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let out_2 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &["sleep 5 && echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||
timeout_ms: Some(10),
|
||||
})
|
||||
.await?;
|
||||
assert!(!out_2.output.contains("codex"));
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(7)).await;
|
||||
|
||||
let empty = Vec::new();
|
||||
let out_3 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &empty,
|
||||
timeout_ms: Some(100),
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert!(out_3.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
#[ignore] // Ignored while we have a better way to test this.
|
||||
async fn requests_with_large_timeout_are_capped() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let result = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["echo".to_string(), "codex".to_string()],
|
||||
timeout_ms: Some(120_000),
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert!(result.output.starts_with(
|
||||
"Warning: requested timeout 120000ms exceeds maximum of 60000ms; clamping to 60000ms.\n"
|
||||
));
|
||||
assert!(result.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
#[ignore] // Ignored while we have a better way to test this.
|
||||
async fn completed_commands_do_not_persist_sessions() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
let result = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["/bin/echo".to_string(), "codex".to_string()],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert!(result.session_id.is_none());
|
||||
assert!(result.output.contains("codex"));
|
||||
|
||||
assert!(manager.sessions.lock().await.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn reusing_completed_session_returns_unknown_session() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
let session_id = open_shell.session_id.expect("expected session id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &["exit\n".to_string()],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
let err = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &[],
|
||||
timeout_ms: Some(100),
|
||||
})
|
||||
.await
|
||||
.expect_err("expected unknown session error");
|
||||
|
||||
match err {
|
||||
UnifiedExecError::UnknownSessionId { session_id: err_id } => {
|
||||
assert_eq!(err_id, session_id);
|
||||
}
|
||||
other => panic!("expected UnknownSessionId, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(!manager.sessions.lock().await.contains_key(&session_id));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";
|
||||
|
||||
pub fn get_codex_user_agent(originator: Option<&str>) -> String {
|
||||
let build_version = env!("CARGO_PKG_VERSION");
|
||||
let os_info = os_info::get();
|
||||
format!(
|
||||
"{}/{build_version} ({} {}; {}) {}",
|
||||
originator.unwrap_or(DEFAULT_ORIGINATOR),
|
||||
os_info.os_type(),
|
||||
os_info.version(),
|
||||
os_info.architecture().unwrap_or("unknown"),
|
||||
crate::terminal::user_agent()
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_get_codex_user_agent() {
|
||||
let user_agent = get_codex_user_agent(None);
|
||||
assert!(user_agent.starts_with("codex_cli_rs/"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(target_os = "macos")]
|
||||
fn test_macos() {
|
||||
use regex_lite::Regex;
|
||||
let user_agent = get_codex_user_agent(None);
|
||||
let re = Regex::new(
|
||||
r"^codex_cli_rs/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$",
|
||||
)
|
||||
.unwrap();
|
||||
assert!(re.is_match(&user_agent));
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user