Mirror of https://github.com/openai/codex.git (synced 2026-05-12 15:22:39 +00:00)

Compare commits: 72 commits, `ae/slow...codex/task`
Commits included in this compare (SHA1):

0431598cbc, 41900e9d0f, c1bde2a4ef, 6b0c486861, 44ceaf085b, c03e31ecf5, 6915ba2100, 50f53e7071, 40fba1bb4c, bdda762deb, da5492694b, a5d48a775b, 78f2785595, fc1723f131, ed5b0bfeb3, 4b01f0f50a, 0139f6780c, 86ba270926, c146585cdb, 5fa7844ad7, 84c9b574f9, 272e13dd90, 18d00e36b9, 17550fee9e, 995f5c3614, 9b53a306e3, 0016346dfb, f38ad65254, 774892c6d7, 897d4d5f17, 8a281cd1f4, e8863b233b, 8fed0b53c4, 00debb6399, 0a0a10d8b3, 13035561cd, 9be704a934, f7b4e29609, d6c5df9a0a, 8662162f45, 57584d6f34, 268a10f917, 5346cc422d, 26f7c46856, 90af046c5c, 961ed31901, 85e7357973, f98fa85b44, ddcaf3dccd, 56296cad82, 95b41dd7f1, bf82353f45, 0308febc23, 7b4a4c2219, 3ddd4d47d0, ca6a0358de, 0026b12615, 4300236681, ec238a2c39, b6165aee0c, f4bc03d7c0, 3c5e12e2a4, c89229db97, d3820f4782, e896db1180, 96acb8a74e, 687a13bbe5, fe8122e514, 876d4f450a, f52320be86, a43ae86b6c, 496cb801e1
.github/ISSUE_TEMPLATE/2-bug-report.yml (vendored, 22 lines changed)

```
@@ -20,6 +20,14 @@ body:
    attributes:
      label: What version of Codex is running?
      description: Copy the output of `codex --version`
    validations:
      required: true
  - type: input
    id: plan
    attributes:
      label: What subscription do you have?
    validations:
      required: true
  - type: input
    id: model
    attributes:
@@ -32,11 +40,18 @@ body:
      description: |
        For MacOS and Linux: copy the output of `uname -mprs`
        For Windows: copy the output of `"$([Environment]::OSVersion | ForEach-Object VersionString) $(if ([Environment]::Is64BitOperatingSystem) { "x64" } else { "x86" })"` in the PowerShell console
  - type: textarea
    id: actual
    attributes:
      label: What issue are you seeing?
      description: Please include the full error messages and prompts with PII redacted. If possible, please provide text instead of a screenshot.
    validations:
      required: true
  - type: textarea
    id: steps
    attributes:
      label: What steps can reproduce the bug?
      description: Explain the bug and provide a code snippet that can reproduce it.
      description: Explain the bug and provide a code snippet that can reproduce it. Please include session id, token limit usage, context window usage if applicable.
    validations:
      required: true
  - type: textarea
@@ -44,11 +59,6 @@ body:
    attributes:
      label: What is the expected behavior?
      description: If possible, please provide text instead of a screenshot.
  - type: textarea
    id: actual
    attributes:
      label: What do you see instead?
      description: If possible, please provide text instead of a screenshot.
  - type: textarea
    id: notes
    attributes:
```
.github/ISSUE_TEMPLATE/5-vs-code-extension.yml (vendored, 24 lines changed)

```
@@ -14,11 +14,21 @@ body:
    id: version
    attributes:
      label: What version of the VS Code extension are you using?
    validations:
      required: true
  - type: input
    id: plan
    attributes:
      label: What subscription do you have?
    validations:
      required: true
  - type: input
    id: ide
    attributes:
      label: Which IDE are you using?
      description: Like `VS Code`, `Cursor`, `Windsurf`, etc.
    validations:
      required: true
  - type: input
    id: platform
    attributes:
@@ -26,11 +36,18 @@ body:
      description: |
        For MacOS and Linux: copy the output of `uname -mprs`
        For Windows: copy the output of `"$([Environment]::OSVersion | ForEach-Object VersionString) $(if ([Environment]::Is64BitOperatingSystem) { "x64" } else { "x86" })"` in the PowerShell console
  - type: textarea
    id: actual
    attributes:
      label: What issue are you seeing?
      description: Please include the full error messages and prompts with PII redacted. If possible, please provide text instead of a screenshot.
    validations:
      required: true
  - type: textarea
    id: steps
    attributes:
      label: What steps can reproduce the bug?
      description: Explain the bug and provide a code snippet that can reproduce it.
      description: Explain the bug and provide a code snippet that can reproduce it. Please include session id, token limit usage, context window usage if applicable.
    validations:
      required: true
  - type: textarea
@@ -38,11 +55,6 @@ body:
    attributes:
      label: What is the expected behavior?
      description: If possible, please provide text instead of a screenshot.
  - type: textarea
    id: actual
    attributes:
      label: What do you see instead?
      description: If possible, please provide text instead of a screenshot.
  - type: textarea
    id: notes
    attributes:
```
.github/workflows/rust-ci.yml (vendored, 42 lines changed)

```
@@ -148,15 +148,26 @@ jobs:
          targets: ${{ matrix.target }}
          components: clippy

      - uses: actions/cache@v4
      # Explicit cache restore: split cargo home vs target, so we can
      # avoid caching the large target dir on the gnu-dev job.
      - name: Restore cargo home cache
        id: cache_cargo_home_restore
        uses: actions/cache/restore@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            ${{ github.workspace }}/codex-rs/target/
          key: cargo-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}
          key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}

      - name: Restore target cache (except gnu-dev)
        id: cache_target_restore
        if: ${{ !(matrix.target == 'x86_64-unknown-linux-gnu' && matrix.profile != 'release') }}
        uses: actions/cache/restore@v4
        with:
          path: ${{ github.workspace }}/codex-rs/target/
          key: cargo-target-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}

      - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}}
        name: Install musl build tools
@@ -194,6 +205,31 @@ jobs:
        env:
          RUST_BACKTRACE: 1

      # Save caches explicitly; make non-fatal so cache packaging
      # never fails the overall job. Only save when key wasn't hit.
      - name: Save cargo home cache
        if: always() && !cancelled() && steps.cache_cargo_home_restore.outputs.cache-hit != 'true'
        continue-on-error: true
        uses: actions/cache/save@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
          key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}

      - name: Save target cache (except gnu-dev)
        if: >-
          always() && !cancelled() &&
          (steps.cache_target_restore.outputs.cache-hit != 'true') &&
          !(matrix.target == 'x86_64-unknown-linux-gnu' && matrix.profile != 'release')
        continue-on-error: true
        uses: actions/cache/save@v4
        with:
          path: ${{ github.workspace }}/codex-rs/target/
          key: cargo-target-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}

      # Fail the job if any of the previous steps failed.
      - name: verify all steps passed
        if: |
```
.github/workflows/rust-release.yml (vendored, 197 lines changed)

```
@@ -47,7 +47,7 @@ jobs:

  build:
    needs: tag-check
    name: ${{ matrix.runner }} - ${{ matrix.target }}
    name: Build - ${{ matrix.runner }} - ${{ matrix.target }}
    runs-on: ${{ matrix.runner }}
    timeout-minutes: 30
    defaults:
@@ -94,11 +94,181 @@ jobs:
      - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}}
        name: Install musl build tools
        run: |
          sudo apt install -y musl-tools pkg-config
          sudo apt-get update
          sudo apt-get install -y musl-tools pkg-config

      - name: Cargo build
        run: cargo build --target ${{ matrix.target }} --release --bin codex --bin codex-responses-api-proxy

      - if: ${{ matrix.runner == 'macos-14' }}
        name: Configure Apple code signing
        shell: bash
        env:
          KEYCHAIN_PASSWORD: actions
          APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE_P12 }}
          APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }}
        run: |
          set -euo pipefail

          if [[ -z "${APPLE_CERTIFICATE:-}" ]]; then
            echo "APPLE_CERTIFICATE is required for macOS signing"
            exit 1
          fi

          if [[ -z "${APPLE_CERTIFICATE_PASSWORD:-}" ]]; then
            echo "APPLE_CERTIFICATE_PASSWORD is required for macOS signing"
            exit 1
          fi

          cert_path="${RUNNER_TEMP}/apple_signing_certificate.p12"
          echo "$APPLE_CERTIFICATE" | base64 -d > "$cert_path"

          keychain_path="${RUNNER_TEMP}/codex-signing.keychain-db"
          security create-keychain -p "$KEYCHAIN_PASSWORD" "$keychain_path"
          security set-keychain-settings -lut 21600 "$keychain_path"
          security unlock-keychain -p "$KEYCHAIN_PASSWORD" "$keychain_path"

          keychain_args=()
          cleanup_keychain() {
            if ((${#keychain_args[@]} > 0)); then
              security list-keychains -s "${keychain_args[@]}" || true
              security default-keychain -s "${keychain_args[0]}" || true
            else
              security list-keychains -s || true
            fi
            if [[ -f "$keychain_path" ]]; then
              security delete-keychain "$keychain_path" || true
            fi
          }

          while IFS= read -r keychain; do
            [[ -n "$keychain" ]] && keychain_args+=("$keychain")
          done < <(security list-keychains | sed 's/^[[:space:]]*//;s/[[:space:]]*$//;s/"//g')

          if ((${#keychain_args[@]} > 0)); then
            security list-keychains -s "$keychain_path" "${keychain_args[@]}"
          else
            security list-keychains -s "$keychain_path"
          fi

          security default-keychain -s "$keychain_path"
          security import "$cert_path" -k "$keychain_path" -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign -T /usr/bin/security
          security set-key-partition-list -S apple-tool:,apple: -s -k "$KEYCHAIN_PASSWORD" "$keychain_path" > /dev/null

          codesign_hashes=()
          while IFS= read -r hash; do
            [[ -n "$hash" ]] && codesign_hashes+=("$hash")
          done < <(security find-identity -v -p codesigning "$keychain_path" \
            | sed -n 's/.*\([0-9A-F]\{40\}\).*/\1/p' \
            | sort -u)

          if ((${#codesign_hashes[@]} == 0)); then
            echo "No signing identities found in $keychain_path"
            cleanup_keychain
            rm -f "$cert_path"
            exit 1
          fi

          if ((${#codesign_hashes[@]} > 1)); then
            echo "Multiple signing identities found in $keychain_path:"
            printf ' %s\n' "${codesign_hashes[@]}"
            cleanup_keychain
            rm -f "$cert_path"
            exit 1
          fi

          APPLE_CODESIGN_IDENTITY="${codesign_hashes[0]}"

          rm -f "$cert_path"

          echo "APPLE_CODESIGN_IDENTITY=$APPLE_CODESIGN_IDENTITY" >> "$GITHUB_ENV"
          echo "APPLE_CODESIGN_KEYCHAIN=$keychain_path" >> "$GITHUB_ENV"
          echo "::add-mask::$APPLE_CODESIGN_IDENTITY"

      - if: ${{ matrix.runner == 'macos-14' }}
        name: Sign macOS binaries
        shell: bash
        run: |
          set -euo pipefail

          if [[ -z "${APPLE_CODESIGN_IDENTITY:-}" ]]; then
            echo "APPLE_CODESIGN_IDENTITY is required for macOS signing"
            exit 1
          fi

          keychain_args=()
          if [[ -n "${APPLE_CODESIGN_KEYCHAIN:-}" && -f "${APPLE_CODESIGN_KEYCHAIN}" ]]; then
            keychain_args+=(--keychain "${APPLE_CODESIGN_KEYCHAIN}")
          fi

          for binary in codex codex-responses-api-proxy; do
            path="target/${{ matrix.target }}/release/${binary}"
            codesign --force --options runtime --timestamp --sign "$APPLE_CODESIGN_IDENTITY" "${keychain_args[@]}" "$path"
          done

      - if: ${{ matrix.runner == 'macos-14' }}
        name: Notarize macOS binaries
        shell: bash
        env:
          APPLE_NOTARIZATION_KEY_P8: ${{ secrets.APPLE_NOTARIZATION_KEY_P8 }}
          APPLE_NOTARIZATION_KEY_ID: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }}
          APPLE_NOTARIZATION_ISSUER_ID: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }}
        run: |
          set -euo pipefail

          for var in APPLE_NOTARIZATION_KEY_P8 APPLE_NOTARIZATION_KEY_ID APPLE_NOTARIZATION_ISSUER_ID; do
            if [[ -z "${!var:-}" ]]; then
              echo "$var is required for notarization"
              exit 1
            fi
          done

          notary_key_path="${RUNNER_TEMP}/notarytool.key.p8"
          echo "$APPLE_NOTARIZATION_KEY_P8" | base64 -d > "$notary_key_path"
          cleanup_notary() {
            rm -f "$notary_key_path"
          }
          trap cleanup_notary EXIT

          notarize_binary() {
            local binary="$1"
            local source_path="target/${{ matrix.target }}/release/${binary}"
            local archive_path="${RUNNER_TEMP}/${binary}.zip"

            if [[ ! -f "$source_path" ]]; then
              echo "Binary $source_path not found"
              exit 1
            fi

            rm -f "$archive_path"
            ditto -c -k --keepParent "$source_path" "$archive_path"

            submission_json=$(xcrun notarytool submit "$archive_path" \
              --key "$notary_key_path" \
              --key-id "$APPLE_NOTARIZATION_KEY_ID" \
              --issuer "$APPLE_NOTARIZATION_ISSUER_ID" \
              --output-format json \
              --wait)

            status=$(printf '%s\n' "$submission_json" | jq -r '.status // "Unknown"')
            submission_id=$(printf '%s\n' "$submission_json" | jq -r '.id // ""')

            if [[ -z "$submission_id" ]]; then
              echo "Failed to retrieve submission ID for $binary"
              exit 1
            fi

            echo "::notice title=Notarization::$binary submission ${submission_id} completed with status ${status}"

            if [[ "$status" != "Accepted" ]]; then
              echo "Notarization failed for ${binary} (submission ${submission_id}, status ${status})"
              exit 1
            fi
          }

          notarize_binary "codex"
          notarize_binary "codex-responses-api-proxy"

      - name: Stage artifacts
        shell: bash
        run: |
@@ -157,6 +327,29 @@ jobs:
            zstd -T0 -19 --rm "$dest/$base"
          done

      - name: Remove signing keychain
        if: ${{ always() && matrix.runner == 'macos-14' }}
        shell: bash
        env:
          APPLE_CODESIGN_KEYCHAIN: ${{ env.APPLE_CODESIGN_KEYCHAIN }}
        run: |
          set -euo pipefail
          if [[ -n "${APPLE_CODESIGN_KEYCHAIN:-}" ]]; then
            keychain_args=()
            while IFS= read -r keychain; do
              [[ "$keychain" == "$APPLE_CODESIGN_KEYCHAIN" ]] && continue
              [[ -n "$keychain" ]] && keychain_args+=("$keychain")
            done < <(security list-keychains | sed 's/^[[:space:]]*//;s/[[:space:]]*$//;s/"//g')
            if ((${#keychain_args[@]} > 0)); then
              security list-keychains -s "${keychain_args[@]}"
              security default-keychain -s "${keychain_args[0]}"
            fi

            if [[ -f "$APPLE_CODESIGN_KEYCHAIN" ]]; then
              security delete-keychain "$APPLE_CODESIGN_KEYCHAIN"
            fi
          fi

      - uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.target }}
```
.gitignore (vendored, 1 line changed)

```
@@ -30,6 +30,7 @@ result
# cli tools
CLAUDE.md
.claude/
AGENTS.override.md

# caches
.cache/
```
```
@@ -12,6 +12,7 @@ In the codex-rs folder where the rust code lives:
- Always inline format! args when possible per https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args
- Use method references over closures when possible per https://rust-lang.github.io/rust-clippy/master/index.html#redundant_closure_for_method_calls
- When writing tests, prefer comparing the equality of entire objects over fields one by one.
- When making a change that adds or changes an API, ensure that the documentation in the `docs/` folder is up to date if applicable.

Run `just fmt` (in `codex-rs` directory) automatically after making Rust code changes; do not ask for approval to run it. Before finalizing a change to `codex-rs`, run `just fix -p <project>` (in `codex-rs` directory) to fix any linter issues in the code. Prefer scoping with `-p` to avoid slow workspace-wide Clippy builds; only run `just fix` without `-p` if you changed shared crates. Additionally, run the tests:
```
codex-cli/bin/codex.js (executable file → normal file, 35 lines changed)

```
@@ -80,6 +80,32 @@ function getUpdatedPath(newDirs) {
  return updatedPath;
}

/**
 * Use heuristics to detect the package manager that was used to install Codex
 * in order to give the user a hint about how to update it.
 */
function detectPackageManager() {
  const userAgent = process.env.npm_config_user_agent || "";
  if (/\bbun\//.test(userAgent)) {
    return "bun";
  }

  const execPath = process.env.npm_execpath || "";
  if (execPath.includes("bun")) {
    return "bun";
  }

  if (
    process.env.BUN_INSTALL ||
    process.env.BUN_INSTALL_GLOBAL_DIR ||
    process.env.BUN_INSTALL_BIN_DIR
  ) {
    return "bun";
  }

  return userAgent ? "npm" : null;
}

const additionalDirs = [];
const pathDir = path.join(archRoot, "path");
if (existsSync(pathDir)) {
@@ -87,9 +113,16 @@ if (existsSync(pathDir)) {
}
const updatedPath = getUpdatedPath(additionalDirs);

const env = { ...process.env, PATH: updatedPath };
const packageManagerEnvVar =
  detectPackageManager() === "bun"
    ? "CODEX_MANAGED_BY_BUN"
    : "CODEX_MANAGED_BY_NPM";
env[packageManagerEnvVar] = "1";

const child = spawn(binaryPath, process.argv.slice(2), {
  stdio: "inherit",
  env: { ...process.env, PATH: updatedPath, CODEX_MANAGED_BY_NPM: "1" },
  env,
});

child.on("error", (err) => {
```
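For context, a hedged Rust-side sketch of how a launcher marker like this could be consumed when choosing an upgrade hint. The helper below is hypothetical; the actual `UpdateAction` plumbing lives elsewhere in the codex-rs workspace, and only the environment variable names come from the diff above.

```rust
use std::env;

/// Hypothetical helper: pick an upgrade command hint based on which
/// package-manager marker the Node launcher exported.
fn upgrade_hint() -> Option<&'static str> {
    if env::var_os("CODEX_MANAGED_BY_BUN").is_some() {
        Some("bun add -g @openai/codex@latest")
    } else if env::var_os("CODEX_MANAGED_BY_NPM").is_some() {
        Some("npm install -g @openai/codex@latest")
    } else {
        // Installed some other way (e.g. a downloaded binary): no hint.
        None
    }
}

fn main() {
    match upgrade_hint() {
        Some(cmd) => println!("To update Codex, run `{cmd}`."),
        None => println!("Codex was not installed via npm or bun."),
    }
}
```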
codex-rs/Cargo.lock (generated, 311 lines changed)

```
@@ -899,6 +899,16 @@ dependencies = [
 "tokio",
]

[[package]]
name = "codex-async-utils"
version = "0.0.0"
dependencies = [
 "async-trait",
 "pretty_assertions",
 "tokio",
 "tokio-util",
]

[[package]]
name = "codex-backend-client"
version = "0.0.0"
@@ -1037,6 +1047,7 @@ dependencies = [
 "chrono",
 "codex-app-server-protocol",
 "codex-apply-patch",
 "codex-async-utils",
 "codex-file-search",
 "codex-mcp-client",
 "codex-otel",
@@ -1073,6 +1084,7 @@ dependencies = [
 "similar",
 "strum_macros 0.27.2",
 "tempfile",
 "test-log",
 "thiserror 2.0.16",
 "time",
 "tokio",
@@ -1144,6 +1156,17 @@ dependencies = [
 "tempfile",
]

[[package]]
name = "codex-feedback"
version = "0.0.0"
dependencies = [
 "anyhow",
 "codex-protocol",
 "pretty_assertions",
 "sentry",
 "tracing-subscriber",
]

[[package]]
name = "codex-file-search"
version = "0.0.0"
@@ -1352,7 +1375,9 @@ version = "0.0.0"
dependencies = [
 "anyhow",
 "axum",
 "codex-protocol",
 "dirs",
 "escargot",
 "futures",
 "keyring",
 "mcp-types",
@@ -1362,6 +1387,7 @@ dependencies = [
 "rmcp",
 "serde",
 "serde_json",
 "serial_test",
 "sha2",
 "tempfile",
 "tiny_http",
@@ -1387,6 +1413,7 @@ dependencies = [
 "codex-arg0",
 "codex-common",
 "codex-core",
 "codex-feedback",
 "codex-file-search",
 "codex-git-tooling",
 "codex-login",
@@ -1421,6 +1448,7 @@ dependencies = [
 "textwrap 0.16.2",
 "tokio",
 "tokio-stream",
 "toml",
 "tracing",
 "tracing-appender",
 "tracing-subscriber",
@@ -1576,10 +1604,12 @@ dependencies = [
 "anyhow",
 "assert_cmd",
 "codex-core",
 "notify",
 "regex-lite",
 "serde_json",
 "tempfile",
 "tokio",
 "walkdir",
 "wiremock",
]

@@ -1820,6 +1850,16 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"

[[package]]
name = "debugid"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d"
dependencies = [
 "serde",
 "uuid",
]

[[package]]
name = "debugserver-types"
version = "0.5.0"
@@ -2298,6 +2338,18 @@ dependencies = [
 "winapi",
]

[[package]]
name = "findshlibs"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40b9e59cd0f7e0806cca4be089683ecb6434e602038df21fe6bf6711b2f07f64"
dependencies = [
 "cc",
 "lazy_static",
 "libc",
 "winapi",
]

[[package]]
name = "fixed_decimal"
version = "0.7.0"
@@ -2370,6 +2422,15 @@ dependencies = [
 "percent-encoding",
]

[[package]]
name = "fsevent-sys"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2"
dependencies = [
 "libc",
]

[[package]]
name = "futures"
version = "0.3.31"
@@ -2657,6 +2718,17 @@ dependencies = [
 "windows-sys 0.59.0",
]

[[package]]
name = "hostname"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a56f203cd1c76362b69e3863fd987520ac36cf70a8c92627449b2f64a8cf7d65"
dependencies = [
 "cfg-if",
 "libc",
 "windows-link 0.1.3",
]

[[package]]
name = "http"
version = "1.3.1"
@@ -3056,6 +3128,26 @@ version = "2.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"

[[package]]
name = "inotify"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3"
dependencies = [
 "bitflags 2.9.1",
 "inotify-sys",
 "libc",
]

[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
 "libc",
]

[[package]]
name = "inout"
version = "0.1.4"
@@ -3256,6 +3348,26 @@ dependencies = [
 "zeroize",
]

[[package]]
name = "kqueue"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eac30106d7dce88daf4a3fcb4879ea939476d5074a9b7ddd0fb97fa4bed5596a"
dependencies = [
 "kqueue-sys",
 "libc",
]

[[package]]
name = "kqueue-sys"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b"
dependencies = [
 "bitflags 1.3.2",
 "libc",
]

[[package]]
name = "lalrpop"
version = "0.19.12"
@@ -3655,6 +3767,30 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be"

[[package]]
name = "notify"
version = "8.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d3d07927151ff8575b7087f245456e549fea62edf0ec4e565a5ee50c8402bc3"
dependencies = [
 "bitflags 2.9.1",
 "fsevent-sys",
 "inotify",
 "kqueue",
 "libc",
 "log",
 "mio",
 "notify-types",
 "walkdir",
 "windows-sys 0.60.2",
]

[[package]]
name = "notify-types"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e0826a989adedc2a244799e823aece04662b66609d96af8dff7ac6df9a8925d"

[[package]]
name = "nu-ansi-term"
version = "0.50.1"
@@ -4772,9 +4908,9 @@ dependencies = [

[[package]]
name = "rmcp"
version = "0.8.0"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "583d060e99feb3a3683fb48a1e4bf5f8d4a50951f429726f330ee5ff548837f8"
checksum = "6f35acda8f89fca5fd8c96cae3c6d5b4c38ea0072df4c8030915f3b5ff469c1c"
dependencies = [
 "base64",
 "bytes",
@@ -4806,9 +4942,9 @@ dependencies = [

[[package]]
name = "rmcp-macros"
version = "0.8.0"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "421d8b0ba302f479214889486f9550e63feca3af310f1190efcf6e2016802693"
checksum = "c9f1d5220aaa23b79c3d02e18f7a554403b3ccea544bbb6c69d6bcb3e854a274"
dependencies = [
 "darling 0.21.3",
 "proc-macro2",
@@ -4829,6 +4965,15 @@ version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"

[[package]]
name = "rustc_version"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [
 "semver",
]

[[package]]
name = "rustix"
version = "0.38.44"
@@ -5143,6 +5288,120 @@ dependencies = [
 "libc",
]

[[package]]
name = "semver"
version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"

[[package]]
name = "sentry"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5484316556650182f03b43d4c746ce0e3e48074a21e2f51244b648b6542e1066"
dependencies = [
 "httpdate",
 "native-tls",
 "reqwest",
 "sentry-backtrace",
 "sentry-contexts",
 "sentry-core",
 "sentry-debug-images",
 "sentry-panic",
 "sentry-tracing",
 "tokio",
 "ureq",
]

[[package]]
name = "sentry-backtrace"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40aa225bb41e2ec9d7c90886834367f560efc1af028f1c5478a6cce6a59c463a"
dependencies = [
 "backtrace",
 "once_cell",
 "regex",
 "sentry-core",
]

[[package]]
name = "sentry-contexts"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a8dd746da3d16cb8c39751619cefd4fcdbd6df9610f3310fd646b55f6e39910"
dependencies = [
 "hostname",
 "libc",
 "os_info",
 "rustc_version",
 "sentry-core",
 "uname",
]

[[package]]
name = "sentry-core"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "161283cfe8e99c8f6f236a402b9ccf726b201f365988b5bb637ebca0abbd4a30"
dependencies = [
 "once_cell",
 "rand 0.8.5",
 "sentry-types",
 "serde",
 "serde_json",
]

[[package]]
name = "sentry-debug-images"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc6b25e945fcaa5e97c43faee0267eebda9f18d4b09a251775d8fef1086238a"
dependencies = [
 "findshlibs",
 "once_cell",
 "sentry-core",
]

[[package]]
name = "sentry-panic"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc74f229c7186dd971a9491ffcbe7883544aa064d1589bd30b83fb856cd22d63"
dependencies = [
 "sentry-backtrace",
 "sentry-core",
]

[[package]]
name = "sentry-tracing"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd3c5faf2103cd01eeda779ea439b68c4ee15adcdb16600836e97feafab362ec"
dependencies = [
 "sentry-backtrace",
 "sentry-core",
 "tracing-core",
 "tracing-subscriber",
]

[[package]]
name = "sentry-types"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d68cdf6bc41b8ff3ae2a9c4671e97426dcdd154cc1d4b6b72813f285d6b163f"
dependencies = [
 "debugid",
 "hex",
 "rand 0.8.5",
 "serde",
 "serde_json",
 "thiserror 1.0.69",
 "time",
 "url",
 "uuid",
]

[[package]]
name = "serde"
version = "1.0.226"
@@ -5775,6 +6034,28 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"

[[package]]
name = "test-log"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b"
dependencies = [
 "env_logger",
 "test-log-macros",
 "tracing-subscriber",
]

[[package]]
name = "test-log-macros"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36"
dependencies = [
 "proc-macro2",
 "quote",
 "syn 2.0.104",
]

[[package]]
name = "textwrap"
version = "0.11.0"
@@ -6350,6 +6631,15 @@ dependencies = [
 "winapi",
]

[[package]]
name = "uname"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b72f89f0ca32e4db1c04e2a72f5345d59796d4866a1ee0609084569f73683dc8"
dependencies = [
 "libc",
]

[[package]]
name = "unicase"
version = "2.8.1"
@@ -6409,6 +6699,19 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"

[[package]]
name = "ureq"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d"
dependencies = [
 "base64",
 "log",
 "native-tls",
 "once_cell",
 "url",
]

[[package]]
name = "url"
version = "2.5.4"
```
```
@@ -2,10 +2,12 @@
members = [
    "backend-client",
    "ansi-escape",
    "async-utils",
    "app-server",
    "app-server-protocol",
    "apply-patch",
    "arg0",
    "feedback",
    "codex-backend-openapi-models",
    "cloud-tasks",
    "cloud-tasks-client",
@@ -55,7 +57,9 @@ codex-arg0 = { path = "arg0" }
codex-chatgpt = { path = "chatgpt" }
codex-common = { path = "common" }
codex-core = { path = "core" }
codex-async-utils = { path = "async-utils" }
codex-exec = { path = "exec" }
codex-feedback = { path = "feedback" }
codex-file-search = { path = "file-search" }
codex-git-tooling = { path = "git-tooling" }
codex-linux-sandbox = { path = "linux-sandbox" }
@@ -83,8 +87,8 @@ ansi-to-tui = "7.0.0"
anyhow = "1"
arboard = "3"
askama = "0.12"
assert_matches = "1.5.0"
assert_cmd = "2"
assert_matches = "1.5.0"
async-channel = "2.3.1"
async-stream = "0.3.6"
async-trait = "0.1.89"
@@ -122,6 +126,7 @@ log = "0.4"
maplit = "1.0.2"
mime_guess = "2.0.5"
multimap = "0.10.0"
notify = "8.2.0"
nucleo-matcher = "0.3.1"
openssl-sys = "*"
opentelemetry = "0.30.0"
@@ -146,6 +151,7 @@ reqwest = "0.12"
rmcp = { version = "0.8.0", default-features = false }
schemars = "0.8.22"
seccompiler = "0.5.0"
sentry = "0.34.0"
serde = "1"
serde_json = "1"
serde_with = "3.14"
@@ -160,6 +166,7 @@ strum_macros = "0.27.2"
supports-color = "3.0.2"
sys-locale = "0.3.2"
tempfile = "3.23.0"
test-log = "0.2.18"
textwrap = "0.16.2"
thiserror = "2.0.16"
time = "0.3"
```
```
@@ -3,11 +3,30 @@ use ansi_to_tui::IntoText;
use ratatui::text::Line;
use ratatui::text::Text;

// Expand tabs in a best-effort way for transcript rendering.
// Tabs can interact poorly with left-gutter prefixes in our TUI and CLI
// transcript views (e.g., `nl` separates line numbers from content with a tab).
// Replacing tabs with spaces avoids odd visual artifacts without changing
// semantics for our use cases.
fn expand_tabs(s: &str) -> std::borrow::Cow<'_, str> {
    if s.contains('\t') {
        // Keep it simple: replace each tab with 4 spaces.
        // We do not try to align to tab stops since most usages (like `nl`)
        // look acceptable with a fixed substitution and this avoids stateful math
        // across spans.
        std::borrow::Cow::Owned(s.replace('\t', "    "))
    } else {
        std::borrow::Cow::Borrowed(s)
    }
}

/// This function should be used when the contents of `s` are expected to match
/// a single line. If multiple lines are found, a warning is logged and only the
/// first line is returned.
pub fn ansi_escape_line(s: &str) -> Line<'static> {
    let text = ansi_escape(s);
    // Normalize tabs to spaces to avoid odd gutter collisions in transcript mode.
    let s = expand_tabs(s);
    let text = ansi_escape(&s);
    match text.lines.as_slice() {
        [] => "".into(),
        [only] => only.clone(),
```
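For reference, a minimal standalone sketch of the tab-normalization behavior introduced above. This is a local copy of the helper for illustration, not the crate's public API:

```rust
use std::borrow::Cow;

// Standalone copy of the fixed-width substitution used above: every tab
// becomes four spaces, with no attempt to align to tab stops.
fn expand_tabs(s: &str) -> Cow<'_, str> {
    if s.contains('\t') {
        Cow::Owned(s.replace('\t', "    "))
    } else {
        Cow::Borrowed(s)
    }
}

fn main() {
    // Typical `nl`-style output: a line number separated from content by a tab.
    assert_eq!(expand_tabs("     1\tfn main() {}"), "     1    fn main() {}");
    // Lines without tabs are borrowed unchanged (no allocation).
    assert!(matches!(expand_tabs("no tabs here"), Cow::Borrowed(_)));
    println!("tab expansion behaves as expected");
}
```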
```
@@ -9,6 +9,7 @@ use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::config_types::Verbosity;
use codex_protocol::parse_command::ParsedCommand;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::FileChange;
@@ -697,6 +698,7 @@ pub struct ExecCommandApprovalParams {
    pub cwd: PathBuf,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    pub parsed_cmd: Vec<ParsedCommand>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
@@ -904,6 +906,9 @@ mod tests {
            command: vec!["echo".to_string(), "hello".to_string()],
            cwd: PathBuf::from("/tmp"),
            reason: Some("because tests".to_string()),
            parsed_cmd: vec![ParsedCommand::Unknown {
                cmd: "echo hello".to_string(),
            }],
        };
        let request = ServerRequest::ExecCommandApproval {
            request_id: RequestId::Integer(7),
@@ -920,6 +925,12 @@ mod tests {
                "command": ["echo", "hello"],
                "cwd": "/tmp",
                "reason": "because tests",
                "parsedCmd": [
                    {
                        "type": "unknown",
                        "cmd": "echo hello"
                    }
                ]
            }
        }),
        serde_json::to_value(&request)?,
```
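As a side note, here is a minimal sketch of the `skip_serializing_if` behavior that the optional `reason` field relies on above. The `Params` struct below is a hypothetical stand-in, not the real `ExecCommandApprovalParams`:

```rust
use serde::Serialize;

// Hypothetical stand-in for the params struct: an optional `reason` that is
// omitted from the JSON entirely when it is `None`.
#[derive(Serialize)]
struct Params {
    command: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reason: Option<String>,
}

fn main() -> Result<(), serde_json::Error> {
    let with_reason = Params {
        command: vec!["echo".into(), "hello".into()],
        reason: Some("because tests".into()),
    };
    let without_reason = Params {
        command: vec!["echo".into(), "hello".into()],
        reason: None,
    };
    // `reason` appears in the output only when it is present.
    assert_eq!(
        serde_json::to_string(&with_reason)?,
        r#"{"command":["echo","hello"],"reason":"because tests"}"#
    );
    assert_eq!(
        serde_json::to_string(&without_reason)?,
        r#"{"command":["echo","hello"]}"#
    );
    Ok(())
}
```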
```
@@ -1284,6 +1284,7 @@ async fn apply_bespoke_event_handling(
            command,
            cwd,
            reason,
            parsed_cmd,
        }) => {
            let params = ExecCommandApprovalParams {
                conversation_id,
@@ -1291,6 +1292,7 @@ async fn apply_bespoke_event_handling(
                command,
                cwd,
                reason,
                parsed_cmd,
            };
            let rx = outgoing
                .send_request(ServerRequestPayload::ExecCommandApproval(params))
```
```
@@ -27,6 +27,7 @@ use codex_core::protocol_config_types::ReasoningEffort;
use codex_core::protocol_config_types::ReasoningSummary;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::parse_command::ParsedCommand;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::InputMessageKind;
@@ -311,6 +312,9 @@ async fn test_send_user_turn_changes_approval_policy_behavior() {
        ],
        cwd: working_directory.clone(),
        reason: None,
        parsed_cmd: vec![ParsedCommand::Unknown {
            cmd: "python3 -c 'print(42)'".to_string()
        }],
    },
    params
);
```
codex-rs/async-utils/Cargo.toml (new file, 15 lines)

```
@@ -0,0 +1,15 @@
[package]
edition.workspace = true
name = "codex-async-utils"
version.workspace = true

[lints]
workspace = true

[dependencies]
async-trait.workspace = true
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "time"] }
tokio-util.workspace = true

[dev-dependencies]
pretty_assertions.workspace = true
```
codex-rs/async-utils/src/lib.rs (new file, 86 lines)

```
@@ -0,0 +1,86 @@
use async_trait::async_trait;
use std::future::Future;
use tokio_util::sync::CancellationToken;

#[derive(Debug, PartialEq, Eq)]
pub enum CancelErr {
    Cancelled,
}

#[async_trait]
pub trait OrCancelExt: Sized {
    type Output;

    async fn or_cancel(self, token: &CancellationToken) -> Result<Self::Output, CancelErr>;
}

#[async_trait]
impl<F> OrCancelExt for F
where
    F: Future + Send,
    F::Output: Send,
{
    type Output = F::Output;

    async fn or_cancel(self, token: &CancellationToken) -> Result<Self::Output, CancelErr> {
        tokio::select! {
            _ = token.cancelled() => Err(CancelErr::Cancelled),
            res = self => Ok(res),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::time::Duration;
    use tokio::task;
    use tokio::time::sleep;

    #[tokio::test]
    async fn returns_ok_when_future_completes_first() {
        let token = CancellationToken::new();
        let value = async { 42 };

        let result = value.or_cancel(&token).await;

        assert_eq!(Ok(42), result);
    }

    #[tokio::test]
    async fn returns_err_when_token_cancelled_first() {
        let token = CancellationToken::new();
        let token_clone = token.clone();

        let cancel_handle = task::spawn(async move {
            sleep(Duration::from_millis(10)).await;
            token_clone.cancel();
        });

        let result = async {
            sleep(Duration::from_millis(100)).await;
            7
        }
        .or_cancel(&token)
        .await;

        cancel_handle.await.expect("cancel task panicked");
        assert_eq!(Err(CancelErr::Cancelled), result);
    }

    #[tokio::test]
    async fn returns_err_when_token_already_cancelled() {
        let token = CancellationToken::new();
        token.cancel();

        let result = async {
            sleep(Duration::from_millis(50)).await;
            5
        }
        .or_cancel(&token)
        .await;

        assert_eq!(Err(CancelErr::Cancelled), result);
    }
}
```
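A short usage sketch of the `OrCancelExt` extension trait added above. It assumes `codex-async-utils` is a dependency of the calling crate; the surrounding task setup is illustrative only:

```rust
use codex_async_utils::{CancelErr, OrCancelExt};
use std::time::Duration;
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();

    // Cancel the token shortly after starting the work.
    let canceller = {
        let token = token.clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(10)).await;
            token.cancel();
        })
    };

    // The slow future is abandoned as soon as the token fires.
    let result = async {
        sleep(Duration::from_secs(5)).await;
        "finished"
    }
    .or_cancel(&token)
    .await;

    canceller.await.expect("canceller task panicked");
    assert_eq!(result, Err(CancelErr::Cancelled));
}
```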
```
@@ -19,6 +19,7 @@ use codex_exec::Cli as ExecCli;
use codex_responses_api_proxy::Args as ResponsesApiProxyArgs;
use codex_tui::AppExitInfo;
use codex_tui::Cli as TuiCli;
use codex_tui::UpdateAction;
use owo_colors::OwoColorize;
use std::path::PathBuf;
use supports_color::Stream;
@@ -26,6 +27,8 @@ use supports_color::Stream;
mod mcp_cmd;

use crate::mcp_cmd::McpCli;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;

/// Codex CLI
///
@@ -45,6 +48,9 @@ struct MultitoolCli {
    #[clap(flatten)]
    pub config_overrides: CliConfigOverrides,

    #[clap(flatten)]
    pub feature_toggles: FeatureToggles,

    #[clap(flatten)]
    interactive: TuiCli,

@@ -97,6 +103,9 @@ enum Subcommand {
    /// Internal: run the responses API proxy.
    #[clap(hide = true)]
    ResponsesApiProxy(ResponsesApiProxyArgs),

    /// Inspect feature flags.
    Features(FeaturesCli),
}

#[derive(Debug, Parser)]
@@ -157,7 +166,7 @@ struct LoginCommand {
    )]
    api_key: Option<String>,

    #[arg(long = "use-device-code")]
    #[arg(long = "device-auth")]
    use_device_code: bool,

    /// EXPERIMENTAL: Use custom OAuth issuer base URL (advanced)
@@ -200,6 +209,7 @@ fn format_exit_messages(exit_info: AppExitInfo, color_enabled: bool) -> Vec<Stri
    let AppExitInfo {
        token_usage,
        conversation_id,
        ..
    } = exit_info;

    if token_usage.is_zero() {
@@ -224,11 +234,79 @@ fn format_exit_messages(exit_info: AppExitInfo, color_enabled: bool) -> Vec<Stri
    lines
}

fn print_exit_messages(exit_info: AppExitInfo) {
/// Handle the app exit and print the results. Optionally run the update action.
fn handle_app_exit(exit_info: AppExitInfo) -> anyhow::Result<()> {
    let update_action = exit_info.update_action;
    let color_enabled = supports_color::on(Stream::Stdout).is_some();
    for line in format_exit_messages(exit_info, color_enabled) {
        println!("{line}");
    }
    if let Some(action) = update_action {
        run_update_action(action)?;
    }
    Ok(())
}

/// Run the update action and print the result.
fn run_update_action(action: UpdateAction) -> anyhow::Result<()> {
    println!();
    let (cmd, args) = action.command_args();
    let cmd_str = action.command_str();
    println!("Updating Codex via `{cmd_str}`...");
    let status = std::process::Command::new(cmd).args(args).status()?;
    if !status.success() {
        anyhow::bail!("`{cmd_str}` failed with status {status}");
    }
    println!();
    println!("🎉 Update ran successfully! Please restart Codex.");
    Ok(())
}

#[derive(Debug, Default, Parser, Clone)]
struct FeatureToggles {
    /// Enable a feature (repeatable). Equivalent to `-c features.<name>=true`.
    #[arg(long = "enable", value_name = "FEATURE", action = clap::ArgAction::Append, global = true)]
    enable: Vec<String>,

    /// Disable a feature (repeatable). Equivalent to `-c features.<name>=false`.
    #[arg(long = "disable", value_name = "FEATURE", action = clap::ArgAction::Append, global = true)]
    disable: Vec<String>,
}

impl FeatureToggles {
    fn to_overrides(&self) -> Vec<String> {
        let mut v = Vec::new();
        for k in &self.enable {
            v.push(format!("features.{k}=true"));
        }
        for k in &self.disable {
            v.push(format!("features.{k}=false"));
        }
        v
    }
}

#[derive(Debug, Parser)]
struct FeaturesCli {
    #[command(subcommand)]
    sub: FeaturesSubcommand,
}

#[derive(Debug, Parser)]
enum FeaturesSubcommand {
    /// List known features with their stage and effective state.
    List,
}

fn stage_str(stage: codex_core::features::Stage) -> &'static str {
    use codex_core::features::Stage;
    match stage {
        Stage::Experimental => "experimental",
        Stage::Beta => "beta",
        Stage::Stable => "stable",
        Stage::Deprecated => "deprecated",
        Stage::Removed => "removed",
    }
}

/// As early as possible in the process lifecycle, apply hardening measures. We
@@ -248,11 +326,17 @@ fn main() -> anyhow::Result<()> {

async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
    let MultitoolCli {
        config_overrides: root_config_overrides,
        config_overrides: mut root_config_overrides,
        feature_toggles,
        mut interactive,
        subcommand,
    } = MultitoolCli::parse();

    // Fold --enable/--disable into config overrides so they flow to all subcommands.
    root_config_overrides
        .raw_overrides
        .extend(feature_toggles.to_overrides());

    match subcommand {
        None => {
            prepend_config_flags(
@@ -260,7 +344,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
                root_config_overrides.clone(),
            );
            let exit_info = codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
            print_exit_messages(exit_info);
            handle_app_exit(exit_info)?;
        }
        Some(Subcommand::Exec(mut exec_cli)) => {
            prepend_config_flags(
@@ -293,7 +377,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
                config_overrides,
            );
            let exit_info = codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
            print_exit_messages(exit_info);
            handle_app_exit(exit_info)?;
        }
        Some(Subcommand::Login(mut login_cli)) => {
            prepend_config_flags(
@@ -381,6 +465,30 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
        Some(Subcommand::GenerateTs(gen_cli)) => {
            codex_protocol_ts::generate_ts(&gen_cli.out_dir, gen_cli.prettier.as_deref())?;
        }
        Some(Subcommand::Features(FeaturesCli { sub })) => match sub {
            FeaturesSubcommand::List => {
                // Respect root-level `-c` overrides plus top-level flags like `--profile`.
                let cli_kv_overrides = root_config_overrides
                    .parse_overrides()
                    .map_err(|e| anyhow::anyhow!(e))?;

                // Thread through relevant top-level flags (at minimum, `--profile`).
                // Also honor `--search` since it maps to a feature toggle.
                let overrides = ConfigOverrides {
                    config_profile: interactive.config_profile.clone(),
                    tools_web_search_request: interactive.web_search.then_some(true),
                    ..Default::default()
                };

                let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides).await?;
                for def in codex_core::features::FEATURES.iter() {
                    let name = def.key;
                    let stage = stage_str(def.stage);
                    let enabled = config.features.enabled(def.id);
                    println!("{name}\t{stage}\t{enabled}");
                }
            }
        },
    }

    Ok(())
@@ -484,6 +592,7 @@ mod tests {
        interactive,
        config_overrides: root_overrides,
        subcommand,
        feature_toggles: _,
    } = cli;

    let Subcommand::Resume(ResumeCommand {
@@ -509,6 +618,7 @@ mod tests {
        conversation_id: conversation
            .map(ConversationId::from_string)
            .map(Result::unwrap),
        update_action: None,
    }
}

@@ -517,6 +627,7 @@ mod tests {
    let exit_info = AppExitInfo {
        token_usage: TokenUsage::default(),
        conversation_id: None,
        update_action: None,
    };
    let lines = format_exit_messages(exit_info, false);
    assert!(lines.is_empty());
```
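A quick sketch of the flag-to-override mapping that `FeatureToggles::to_overrides` implements above. This is a standalone re-implementation for illustration, not the CLI's actual code path, and the feature names used are hypothetical:

```rust
// Standalone copy of the mapping: each `--enable foo` becomes
// `features.foo=true` and each `--disable bar` becomes `features.bar=false`,
// in the same key=value form as a `-c` config override.
fn to_overrides(enable: &[&str], disable: &[&str]) -> Vec<String> {
    let mut overrides = Vec::new();
    for name in enable {
        overrides.push(format!("features.{name}=true"));
    }
    for name in disable {
        overrides.push(format!("features.{name}=false"));
    }
    overrides
}

fn main() {
    // Something like `codex --enable rmcp_client --disable web_search ...`
    // (illustrative names) folds into the raw overrides like this:
    let overrides = to_overrides(&["rmcp_client"], &["web_search"]);
    assert_eq!(
        overrides,
        vec![
            "features.rmcp_client=true".to_string(),
            "features.web_search=false".to_string(),
        ]
    );
}
```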
@@ -4,7 +4,9 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use anyhow::bail;
|
||||
use clap::ArgGroup;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_common::format_env_display::format_env_display;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::find_codex_home;
|
||||
@@ -12,8 +14,12 @@ use codex_core::config::load_global_mcp_servers;
|
||||
use codex_core::config::write_global_mcp_servers;
|
||||
use codex_core::config_types::McpServerConfig;
|
||||
use codex_core::config_types::McpServerTransportConfig;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::mcp::auth::compute_auth_statuses;
|
||||
use codex_core::protocol::McpAuthStatus;
|
||||
use codex_rmcp_client::delete_oauth_tokens;
|
||||
use codex_rmcp_client::perform_oauth_login;
|
||||
use codex_rmcp_client::supports_oauth_login;
|
||||
|
||||
/// [experimental] Launch Codex as an MCP server or manage configured MCP servers.
|
||||
///
|
||||
@@ -77,13 +83,61 @@ pub struct AddArgs {
|
||||
/// Name for the MCP server configuration.
|
||||
pub name: String,
|
||||
|
||||
/// Environment variables to set when launching the server.
|
||||
#[arg(long, value_parser = parse_env_pair, value_name = "KEY=VALUE")]
|
||||
pub env: Vec<(String, String)>,
|
||||
#[command(flatten)]
|
||||
pub transport_args: AddMcpTransportArgs,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
#[command(
|
||||
group(
|
||||
ArgGroup::new("transport")
|
||||
.args(["command", "url"])
|
||||
.required(true)
|
||||
.multiple(false)
|
||||
)
|
||||
)]
|
||||
pub struct AddMcpTransportArgs {
|
||||
#[command(flatten)]
|
||||
pub stdio: Option<AddMcpStdioArgs>,
|
||||
|
||||
#[command(flatten)]
|
||||
pub streamable_http: Option<AddMcpStreamableHttpArgs>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct AddMcpStdioArgs {
|
||||
/// Command to launch the MCP server.
|
||||
#[arg(trailing_var_arg = true, num_args = 1..)]
|
||||
/// Use --url for a streamable HTTP server.
|
||||
#[arg(
|
||||
trailing_var_arg = true,
|
||||
num_args = 0..,
|
||||
)]
|
||||
pub command: Vec<String>,
|
||||
|
||||
/// Environment variables to set when launching the server.
|
||||
/// Only valid with stdio servers.
|
||||
#[arg(
|
||||
long,
|
||||
value_parser = parse_env_pair,
|
||||
value_name = "KEY=VALUE",
|
||||
)]
|
||||
pub env: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct AddMcpStreamableHttpArgs {
|
||||
/// URL for a streamable HTTP MCP server.
|
||||
#[arg(long)]
|
||||
pub url: String,
|
||||
|
||||
/// Optional environment variable to read for a bearer token.
|
||||
/// Only valid with streamable HTTP servers.
|
||||
#[arg(
|
||||
long = "bearer-token-env-var",
|
||||
value_name = "ENV_VAR",
|
||||
requires = "url"
|
||||
)]
|
||||
pub bearer_token_env_var: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
@@ -138,39 +192,65 @@ impl McpCli {
|
||||
|
||||
async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
|
||||
// Validate any provided overrides even though they are not currently applied.
|
||||
config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
|
||||
.await
|
||||
.context("failed to load configuration")?;
|
||||
|
||||
let AddArgs { name, env, command } = add_args;
|
||||
let AddArgs {
|
||||
name,
|
||||
transport_args,
|
||||
} = add_args;
|
||||
|
||||
validate_server_name(&name)?;
|
||||
|
||||
let mut command_parts = command.into_iter();
|
||||
let command_bin = command_parts
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("command is required"))?;
|
||||
let command_args: Vec<String> = command_parts.collect();
|
||||
|
||||
let env_map = if env.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut map = HashMap::new();
|
||||
for (key, value) in env {
|
||||
map.insert(key, value);
|
||||
}
|
||||
Some(map)
|
||||
};
|
||||
|
||||
let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
|
||||
let mut servers = load_global_mcp_servers(&codex_home)
|
||||
.await
|
||||
.with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?;
|
||||
|
||||
let new_entry = McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: command_bin,
|
||||
args: command_args,
|
||||
env: env_map,
|
||||
let transport = match transport_args {
|
||||
AddMcpTransportArgs {
|
||||
stdio: Some(stdio), ..
|
||||
} => {
|
||||
let mut command_parts = stdio.command.into_iter();
|
||||
let command_bin = command_parts
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("command is required"))?;
|
||||
let command_args: Vec<String> = command_parts.collect();
|
||||
|
||||
let env_map = if stdio.env.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stdio.env.into_iter().collect::<HashMap<_, _>>())
|
||||
};
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: command_bin,
|
||||
args: command_args,
|
||||
env: env_map,
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
}
|
||||
}
|
||||
AddMcpTransportArgs {
|
||||
streamable_http:
|
||||
Some(AddMcpStreamableHttpArgs {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
}),
|
||||
..
|
||||
} => McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
},
|
||||
AddMcpTransportArgs { .. } => bail!("exactly one of --command or --url must be provided"),
|
||||
};
|
||||
|
||||
let new_entry = McpServerConfig {
|
||||
transport: transport.clone(),
|
||||
enabled: true,
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
};
|
||||
@@ -182,6 +262,26 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re

println!("Added global MCP server '{name}'.");

if let McpServerTransportConfig::StreamableHttp {
url,
bearer_token_env_var: None,
http_headers,
env_http_headers,
} = transport
&& matches!(supports_oauth_login(&url).await, Ok(true))
{
println!("Detected OAuth support. Starting OAuth flow…");
perform_oauth_login(
&name,
&url,
config.mcp_oauth_credentials_store_mode,
http_headers.clone(),
env_http_headers.clone(),
)
.await?;
println!("Successfully logged in.");
}

Ok(())
}
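As a reference point, here is a minimal sketch of the two transport values run_add can now persist. The variant and field names are taken from the diff above; the concrete values mirror the tests further down and are otherwise illustrative.

let stdio = McpServerConfig {
    transport: McpServerTransportConfig::Stdio {
        command: "docs-server".to_string(),
        args: vec!["--port".to_string(), "4000".to_string()],
        env: None,
        env_vars: Vec::new(),
        cwd: None,
    },
    enabled: true,
    startup_timeout_sec: None,
    tool_timeout_sec: None,
};

let streamable_http = McpServerConfig {
    transport: McpServerTransportConfig::StreamableHttp {
        url: "https://example.com/mcp".to_string(),
        bearer_token_env_var: Some("GITHUB_TOKEN".to_string()),
        http_headers: None,
        env_http_headers: None,
    },
    enabled: true,
    startup_timeout_sec: None,
    tool_timeout_sec: None,
};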
@@ -219,7 +319,7 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs)
.await
.context("failed to load configuration")?;

if !config.use_experimental_use_rmcp_client {
if !config.features.enabled(Feature::RmcpClient) {
bail!(
"OAuth login is only supported when experimental_use_rmcp_client is true in config.toml."
);
@@ -231,12 +331,24 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs)
bail!("No MCP server named '{name}' found.");
};

let url = match &server.transport {
McpServerTransportConfig::StreamableHttp { url, .. } => url.clone(),
let (url, http_headers, env_http_headers) = match &server.transport {
McpServerTransportConfig::StreamableHttp {
url,
http_headers,
env_http_headers,
..
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
_ => bail!("OAuth login is only supported for streamable HTTP servers."),
};

perform_oauth_login(&name, &url).await?;
perform_oauth_login(
&name,
&url,
config.mcp_oauth_credentials_store_mode,
http_headers,
env_http_headers,
)
.await?;
println!("Successfully logged in to MCP server '{name}'.");
Ok(())
}
@@ -259,7 +371,7 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr
_ => bail!("OAuth logout is only supported for streamable_http transports."),
};

match delete_oauth_tokens(&name, &url) {
match delete_oauth_tokens(&name, &url, config.mcp_oauth_credentials_store_mode) {
Ok(true) => println!("Removed OAuth credentials for '{name}'."),
Ok(false) => println!("No OAuth credentials stored for '{name}'."),
Err(err) => return Err(anyhow!("failed to delete OAuth credentials: {err}")),
@@ -276,29 +388,54 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
|
||||
let mut entries: Vec<_> = config.mcp_servers.iter().collect();
|
||||
entries.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
let auth_statuses = compute_auth_statuses(
|
||||
config.mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
)
|
||||
.await;
|
||||
|
||||
if list_args.json {
|
||||
let json_entries: Vec<_> = entries
|
||||
.into_iter()
|
||||
.map(|(name, cfg)| {
|
||||
let auth_status = auth_statuses
|
||||
.get(name.as_str())
|
||||
.copied()
|
||||
.unwrap_or(McpAuthStatus::Unsupported);
|
||||
let transport = match &cfg.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} => serde_json::json!({
|
||||
"type": "stdio",
|
||||
"command": command,
|
||||
"args": args,
|
||||
"env": env,
|
||||
"env_vars": env_vars,
|
||||
"cwd": cwd,
|
||||
}),
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} => {
|
||||
serde_json::json!({
|
||||
"type": "streamable_http",
|
||||
"url": url,
|
||||
"bearer_token": bearer_token,
|
||||
"bearer_token_env_var": bearer_token_env_var,
|
||||
"http_headers": http_headers,
|
||||
"env_http_headers": env_http_headers,
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
serde_json::json!({
|
||||
"name": name,
|
||||
"enabled": cfg.enabled,
|
||||
"transport": transport,
|
||||
"startup_timeout_sec": cfg
|
||||
.startup_timeout_sec
|
||||
@@ -306,6 +443,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
"tool_timeout_sec": cfg
|
||||
.tool_timeout_sec
|
||||
.map(|timeout| timeout.as_secs_f64()),
|
||||
"auth_status": auth_status,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -319,45 +457,85 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut stdio_rows: Vec<[String; 4]> = Vec::new();
|
||||
let mut http_rows: Vec<[String; 3]> = Vec::new();
|
||||
let mut stdio_rows: Vec<[String; 7]> = Vec::new();
|
||||
let mut http_rows: Vec<[String; 5]> = Vec::new();
|
||||
|
||||
for (name, cfg) in entries {
|
||||
match &cfg.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} => {
|
||||
let args_display = if args.is_empty() {
|
||||
"-".to_string()
|
||||
} else {
|
||||
args.join(" ")
|
||||
};
|
||||
let env_display = match env.as_ref() {
|
||||
None => "-".to_string(),
|
||||
Some(map) if map.is_empty() => "-".to_string(),
|
||||
Some(map) => {
|
||||
let mut pairs: Vec<_> = map.iter().collect();
|
||||
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
pairs
|
||||
.into_iter()
|
||||
.map(|(k, v)| format!("{k}={v}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
};
|
||||
stdio_rows.push([name.clone(), command.clone(), args_display, env_display]);
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
let has_bearer = if bearer_token.is_some() {
|
||||
"True"
|
||||
let env_display = format_env_display(env.as_ref(), env_vars);
|
||||
let cwd_display = cwd
|
||||
.as_ref()
|
||||
.map(|path| path.display().to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
let status = if cfg.enabled {
|
||||
"enabled".to_string()
|
||||
} else {
|
||||
"False"
|
||||
"disabled".to_string()
|
||||
};
|
||||
http_rows.push([name.clone(), url.clone(), has_bearer.into()]);
|
||||
let auth_status = auth_statuses
|
||||
.get(name.as_str())
|
||||
.copied()
|
||||
.unwrap_or(McpAuthStatus::Unsupported)
|
||||
.to_string();
|
||||
stdio_rows.push([
|
||||
name.clone(),
|
||||
command.clone(),
|
||||
args_display,
|
||||
env_display,
|
||||
cwd_display,
|
||||
status,
|
||||
auth_status,
|
||||
]);
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
..
|
||||
} => {
|
||||
let status = if cfg.enabled {
|
||||
"enabled".to_string()
|
||||
} else {
|
||||
"disabled".to_string()
|
||||
};
|
||||
let auth_status = auth_statuses
|
||||
.get(name.as_str())
|
||||
.copied()
|
||||
.unwrap_or(McpAuthStatus::Unsupported)
|
||||
.to_string();
|
||||
http_rows.push([
|
||||
name.clone(),
|
||||
url.clone(),
|
||||
bearer_token_env_var.clone().unwrap_or("-".to_string()),
|
||||
status,
|
||||
auth_status,
|
||||
]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !stdio_rows.is_empty() {
|
||||
let mut widths = ["Name".len(), "Command".len(), "Args".len(), "Env".len()];
|
||||
let mut widths = [
|
||||
"Name".len(),
|
||||
"Command".len(),
|
||||
"Args".len(),
|
||||
"Env".len(),
|
||||
"Cwd".len(),
|
||||
"Status".len(),
|
||||
"Auth".len(),
|
||||
];
|
||||
for row in &stdio_rows {
|
||||
for (i, cell) in row.iter().enumerate() {
|
||||
widths[i] = widths[i].max(cell.len());
|
||||
@@ -365,28 +543,40 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
}
|
||||
|
||||
println!(
|
||||
"{:<name_w$} {:<cmd_w$} {:<args_w$} {:<env_w$}",
|
||||
"Name",
|
||||
"Command",
|
||||
"Args",
|
||||
"Env",
|
||||
"{name:<name_w$} {command:<cmd_w$} {args:<args_w$} {env:<env_w$} {cwd:<cwd_w$} {status:<status_w$} {auth:<auth_w$}",
|
||||
name = "Name",
|
||||
command = "Command",
|
||||
args = "Args",
|
||||
env = "Env",
|
||||
cwd = "Cwd",
|
||||
status = "Status",
|
||||
auth = "Auth",
|
||||
name_w = widths[0],
|
||||
cmd_w = widths[1],
|
||||
args_w = widths[2],
|
||||
env_w = widths[3],
|
||||
cwd_w = widths[4],
|
||||
status_w = widths[5],
|
||||
auth_w = widths[6],
|
||||
);
|
||||
|
||||
for row in &stdio_rows {
|
||||
println!(
|
||||
"{:<name_w$} {:<cmd_w$} {:<args_w$} {:<env_w$}",
|
||||
row[0],
|
||||
row[1],
|
||||
row[2],
|
||||
row[3],
|
||||
"{name:<name_w$} {command:<cmd_w$} {args:<args_w$} {env:<env_w$} {cwd:<cwd_w$} {status:<status_w$} {auth:<auth_w$}",
|
||||
name = row[0].as_str(),
|
||||
command = row[1].as_str(),
|
||||
args = row[2].as_str(),
|
||||
env = row[3].as_str(),
|
||||
cwd = row[4].as_str(),
|
||||
status = row[5].as_str(),
|
||||
auth = row[6].as_str(),
|
||||
name_w = widths[0],
|
||||
cmd_w = widths[1],
|
||||
args_w = widths[2],
|
||||
env_w = widths[3],
|
||||
cwd_w = widths[4],
|
||||
status_w = widths[5],
|
||||
auth_w = widths[6],
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -396,7 +586,13 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
}
|
||||
|
||||
if !http_rows.is_empty() {
|
||||
let mut widths = ["Name".len(), "Url".len(), "Has Bearer Token".len()];
|
||||
let mut widths = [
|
||||
"Name".len(),
|
||||
"Url".len(),
|
||||
"Bearer Token Env Var".len(),
|
||||
"Status".len(),
|
||||
"Auth".len(),
|
||||
];
|
||||
for row in &http_rows {
|
||||
for (i, cell) in row.iter().enumerate() {
|
||||
widths[i] = widths[i].max(cell.len());
|
||||
@@ -404,24 +600,32 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
}
|
||||
|
||||
println!(
|
||||
"{:<name_w$} {:<url_w$} {:<token_w$}",
|
||||
"Name",
|
||||
"Url",
|
||||
"Has Bearer Token",
|
||||
"{name:<name_w$} {url:<url_w$} {token:<token_w$} {status:<status_w$} {auth:<auth_w$}",
|
||||
name = "Name",
|
||||
url = "Url",
|
||||
token = "Bearer Token Env Var",
|
||||
status = "Status",
|
||||
auth = "Auth",
|
||||
name_w = widths[0],
|
||||
url_w = widths[1],
|
||||
token_w = widths[2],
|
||||
status_w = widths[3],
|
||||
auth_w = widths[4],
|
||||
);
|
||||
|
||||
for row in &http_rows {
|
||||
println!(
|
||||
"{:<name_w$} {:<url_w$} {:<token_w$}",
|
||||
row[0],
|
||||
row[1],
|
||||
row[2],
|
||||
"{name:<name_w$} {url:<url_w$} {token:<token_w$} {status:<status_w$} {auth:<auth_w$}",
|
||||
name = row[0].as_str(),
|
||||
url = row[1].as_str(),
|
||||
token = row[2].as_str(),
|
||||
status = row[3].as_str(),
|
||||
auth = row[4].as_str(),
|
||||
name_w = widths[0],
|
||||
url_w = widths[1],
|
||||
token_w = widths[2],
|
||||
status_w = widths[3],
|
||||
auth_w = widths[4],
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -441,20 +645,36 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
|
||||
if get_args.json {
|
||||
let transport = match &server.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} => serde_json::json!({
|
||||
"type": "stdio",
|
||||
"command": command,
|
||||
"args": args,
|
||||
"env": env,
|
||||
"env_vars": env_vars,
|
||||
"cwd": cwd,
|
||||
}),
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => serde_json::json!({
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} => serde_json::json!({
|
||||
"type": "streamable_http",
|
||||
"url": url,
|
||||
"bearer_token": bearer_token,
|
||||
"bearer_token_env_var": bearer_token_env_var,
|
||||
"http_headers": http_headers,
|
||||
"env_http_headers": env_http_headers,
|
||||
}),
|
||||
};
|
||||
let output = serde_json::to_string_pretty(&serde_json::json!({
|
||||
"name": get_args.name,
|
||||
"enabled": server.enabled,
|
||||
"transport": transport,
|
||||
"startup_timeout_sec": server
|
||||
.startup_timeout_sec
|
||||
@@ -468,8 +688,15 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
}
|
||||
|
||||
println!("{}", get_args.name);
|
||||
println!(" enabled: {}", server.enabled);
|
||||
match &server.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} => {
|
||||
println!(" transport: stdio");
|
||||
println!(" command: {command}");
|
||||
let args_display = if args.is_empty() {
|
||||
@@ -478,10 +705,27 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
args.join(" ")
|
||||
};
|
||||
println!(" args: {args_display}");
|
||||
let env_display = match env.as_ref() {
|
||||
None => "-".to_string(),
|
||||
Some(map) if map.is_empty() => "-".to_string(),
|
||||
Some(map) => {
|
||||
let cwd_display = cwd
|
||||
.as_ref()
|
||||
.map(|path| path.display().to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
println!(" cwd: {cwd_display}");
|
||||
let env_display = format_env_display(env.as_ref(), env_vars);
|
||||
println!(" env: {env_display}");
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} => {
|
||||
println!(" transport: streamable_http");
|
||||
println!(" url: {url}");
|
||||
let env_var = bearer_token_env_var.as_deref().unwrap_or("-");
|
||||
println!(" bearer_token_env_var: {env_var}");
|
||||
let headers_display = match http_headers {
|
||||
Some(map) if !map.is_empty() => {
|
||||
let mut pairs: Vec<_> = map.iter().collect();
|
||||
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
pairs
|
||||
@@ -490,14 +734,22 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
_ => "-".to_string(),
|
||||
};
|
||||
println!(" env: {env_display}");
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
println!(" transport: streamable_http");
|
||||
println!(" url: {url}");
|
||||
let bearer = bearer_token.as_deref().unwrap_or("-");
|
||||
println!(" bearer_token: {bearer}");
|
||||
println!(" http_headers: {headers_display}");
|
||||
let env_headers_display = match env_http_headers {
|
||||
Some(map) if !map.is_empty() => {
|
||||
let mut pairs: Vec<_> = map.iter().collect();
|
||||
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
pairs
|
||||
.into_iter()
|
||||
.map(|(k, v)| format!("{k}={v}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
_ => "-".to_string(),
|
||||
};
|
||||
println!(" env_http_headers: {env_headers_display}");
|
||||
}
|
||||
}
|
||||
if let Some(timeout) = server.startup_timeout_sec {
|
||||
|
||||
@@ -28,13 +28,22 @@ async fn add_and_remove_server_updates_global_config() -> Result<()> {
|
||||
assert_eq!(servers.len(), 1);
|
||||
let docs = servers.get("docs").expect("server should exist");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} => {
|
||||
assert_eq!(command, "echo");
|
||||
assert_eq!(args, &vec!["hello".to_string()]);
|
||||
assert!(env.is_none());
|
||||
assert!(env_vars.is_empty());
|
||||
assert!(cwd.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport: {other:?}"),
|
||||
}
|
||||
assert!(docs.enabled);
|
||||
|
||||
let mut remove_cmd = codex_command(codex_home.path())?;
|
||||
remove_cmd
|
||||
@@ -90,6 +99,130 @@ async fn add_with_env_preserves_key_order_and_values() -> Result<()> {
|
||||
assert_eq!(env.len(), 2);
|
||||
assert_eq!(env.get("FOO"), Some(&"bar".to_string()));
|
||||
assert_eq!(env.get("ALPHA"), Some(&"beta".to_string()));
|
||||
assert!(envy.enabled);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_streamable_http_without_manual_token() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args(["mcp", "add", "github", "--url", "https://example.com/mcp"])
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let github = servers.get("github").expect("github server should exist");
|
||||
match &github.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert!(bearer_token_env_var.is_none());
|
||||
assert!(http_headers.is_none());
|
||||
assert!(env_http_headers.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport: {other:?}"),
|
||||
}
|
||||
assert!(github.enabled);
|
||||
|
||||
assert!(!codex_home.path().join(".credentials.json").exists());
|
||||
assert!(!codex_home.path().join(".env").exists());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_streamable_http_with_custom_env_var() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"issues",
|
||||
"--url",
|
||||
"https://example.com/issues",
|
||||
"--bearer-token-env-var",
|
||||
"GITHUB_TOKEN",
|
||||
])
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let issues = servers.get("issues").expect("issues server should exist");
|
||||
match &issues.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/issues");
|
||||
assert_eq!(bearer_token_env_var.as_deref(), Some("GITHUB_TOKEN"));
|
||||
assert!(http_headers.is_none());
|
||||
assert!(env_http_headers.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport: {other:?}"),
|
||||
}
|
||||
assert!(issues.enabled);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_streamable_http_rejects_removed_flag() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"github",
|
||||
"--url",
|
||||
"https://example.com/mcp",
|
||||
"--with-bearer-token",
|
||||
])
|
||||
.assert()
|
||||
.failure()
|
||||
.stderr(contains("--with-bearer-token"));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
async fn add_cant_add_command_and_url() -> Result<()> {
let codex_home = TempDir::new()?;

let mut add_cmd = codex_command(codex_home.path())?;
add_cmd
.args([
"mcp",
"add",
"github",
"--url",
"https://example.com/mcp",
"--command",
"--",
"echo",
"hello",
])
.assert()
.failure()
.stderr(contains("unexpected argument '--command' found"));

let servers = load_global_mcp_servers(codex_home.path()).await?;
assert!(servers.is_empty());

Ok(())
}
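One way to get this "exactly one of --command or --url" behaviour at the clap level is an argument group. The real AddArgs/AddMcpTransportArgs definitions are not part of this diff, so the flag names and group wiring below are assumptions, shown only to illustrate the design the test exercises.

use clap::ArgGroup;
use clap::Parser;

// Sketch only; field and group names are hypothetical.
#[derive(Debug, Parser)]
#[command(group(ArgGroup::new("transport").required(true).multiple(false)))]
struct AddArgsSketch {
    /// Server name, e.g. "github" or "docs".
    name: String,

    /// Streamable HTTP endpoint.
    #[arg(long, group = "transport")]
    url: Option<String>,

    /// Stdio launch command, written after `--`.
    #[arg(trailing_var_arg = true, group = "transport", value_name = "COMMAND")]
    command: Vec<String>,
}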
|
||||
@@ -1,6 +1,10 @@
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::config::load_global_mcp_servers;
|
||||
use codex_core::config::write_global_mcp_servers;
|
||||
use codex_core::config_types::McpServerTransportConfig;
|
||||
use predicates::prelude::PredicateBooleanExt;
|
||||
use predicates::str::contains;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value as JsonValue;
|
||||
@@ -26,8 +30,8 @@ fn list_shows_empty_state() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_and_get_render_expected_output() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn list_and_get_render_expected_output() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add = codex_command(codex_home.path())?;
|
||||
@@ -45,6 +49,18 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let mut servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let docs_entry = servers
|
||||
.get_mut("docs")
|
||||
.expect("docs server should exist after add");
|
||||
match &mut docs_entry.transport {
|
||||
McpServerTransportConfig::Stdio { env_vars, .. } => {
|
||||
*env_vars = vec!["APP_TOKEN".to_string(), "WORKSPACE_ID".to_string()];
|
||||
}
|
||||
other => panic!("unexpected transport: {other:?}"),
|
||||
}
|
||||
write_global_mcp_servers(codex_home.path(), &servers)?;
|
||||
|
||||
let mut list_cmd = codex_command(codex_home.path())?;
|
||||
let list_output = list_cmd.args(["mcp", "list"]).output()?;
|
||||
assert!(list_output.status.success());
|
||||
@@ -53,6 +69,12 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
assert!(stdout.contains("docs"));
|
||||
assert!(stdout.contains("docs-server"));
|
||||
assert!(stdout.contains("TOKEN=secret"));
|
||||
assert!(stdout.contains("APP_TOKEN=$APP_TOKEN"));
|
||||
assert!(stdout.contains("WORKSPACE_ID=$WORKSPACE_ID"));
|
||||
assert!(stdout.contains("Status"));
|
||||
assert!(stdout.contains("Auth"));
|
||||
assert!(stdout.contains("enabled"));
|
||||
assert!(stdout.contains("Unsupported"));
|
||||
|
||||
let mut list_json_cmd = codex_command(codex_home.path())?;
|
||||
let json_output = list_json_cmd.args(["mcp", "list", "--json"]).output()?;
|
||||
@@ -64,6 +86,7 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
json!([
|
||||
{
|
||||
"name": "docs",
|
||||
"enabled": true,
|
||||
"transport": {
|
||||
"type": "stdio",
|
||||
"command": "docs-server",
|
||||
@@ -73,10 +96,16 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
],
|
||||
"env": {
|
||||
"TOKEN": "secret"
|
||||
}
|
||||
},
|
||||
"env_vars": [
|
||||
"APP_TOKEN",
|
||||
"WORKSPACE_ID"
|
||||
],
|
||||
"cwd": null
|
||||
},
|
||||
"startup_timeout_sec": null,
|
||||
"tool_timeout_sec": null
|
||||
"tool_timeout_sec": null,
|
||||
"auth_status": "unsupported"
|
||||
}
|
||||
]
|
||||
)
|
||||
@@ -91,6 +120,9 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
assert!(stdout.contains("command: docs-server"));
|
||||
assert!(stdout.contains("args: --port 4000"));
|
||||
assert!(stdout.contains("env: TOKEN=secret"));
|
||||
assert!(stdout.contains("APP_TOKEN=$APP_TOKEN"));
|
||||
assert!(stdout.contains("WORKSPACE_ID=$WORKSPACE_ID"));
|
||||
assert!(stdout.contains("enabled: true"));
|
||||
assert!(stdout.contains("remove: codex mcp remove docs"));
|
||||
|
||||
let mut get_json_cmd = codex_command(codex_home.path())?;
|
||||
@@ -98,7 +130,7 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
.args(["mcp", "get", "docs", "--json"])
|
||||
.assert()
|
||||
.success()
|
||||
.stdout(contains("\"name\": \"docs\""));
|
||||
.stdout(contains("\"name\": \"docs\"").and(contains("\"enabled\": true")));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
use clap::Args;
use clap::Parser;
use codex_common::CliConfigOverrides;

@@ -6,4 +7,43 @@ use codex_common::CliConfigOverrides;
pub struct Cli {
#[clap(skip)]
pub config_overrides: CliConfigOverrides,

#[command(subcommand)]
pub command: Option<Command>,
}

#[derive(Debug, clap::Subcommand)]
pub enum Command {
/// Submit a new Codex Cloud task without launching the TUI.
Exec(ExecCommand),
}

#[derive(Debug, Args)]
pub struct ExecCommand {
/// Task prompt to run in Codex Cloud.
#[arg(value_name = "QUERY")]
pub query: Option<String>,

/// Target environment identifier (see `codex cloud` to browse).
#[arg(long = "env", value_name = "ENV_ID")]
pub environment: String,

/// Number of assistant attempts (best-of-N).
#[arg(
long = "attempts",
default_value_t = 1usize,
value_parser = parse_attempts
)]
pub attempts: usize,
}

fn parse_attempts(input: &str) -> Result<usize, String> {
let value: usize = input
.parse()
.map_err(|_| "attempts must be an integer between 1 and 4".to_string())?;
if (1..=4).contains(&value) {
Ok(value)
} else {
Err("attempts must be between 1 and 4".to_string())
}
}
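A small usage sketch of the bound enforced by parse_attempts. The function is private, so the calls below assume it is visible for illustration; the expected results follow directly from the 1..=4 range check above.

assert_eq!(parse_attempts("1"), Ok(1));
assert_eq!(parse_attempts("4"), Ok(4));
assert!(parse_attempts("0").is_err());   // below the 1..=4 range
assert!(parse_attempts("5").is_err());   // above the 1..=4 range
assert!(parse_attempts("abc").is_err()); // not an integer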
|
||||
@@ -7,7 +7,9 @@ mod ui;
|
||||
pub mod util;
|
||||
pub use cli::Cli;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use std::io::IsTerminal;
|
||||
use std::io::Read;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -23,6 +25,175 @@ struct ApplyJob {
|
||||
diff_override: Option<String>,
|
||||
}
|
||||
|
||||
struct BackendContext {
|
||||
backend: Arc<dyn codex_cloud_tasks_client::CloudBackend>,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
async fn init_backend(user_agent_suffix: &str) -> anyhow::Result<BackendContext> {
|
||||
let use_mock = matches!(
|
||||
std::env::var("CODEX_CLOUD_TASKS_MODE").ok().as_deref(),
|
||||
Some("mock") | Some("MOCK")
|
||||
);
|
||||
let base_url = std::env::var("CODEX_CLOUD_TASKS_BASE_URL")
|
||||
.unwrap_or_else(|_| "https://chatgpt.com/backend-api".to_string());
|
||||
|
||||
set_user_agent_suffix(user_agent_suffix);
|
||||
|
||||
if use_mock {
|
||||
return Ok(BackendContext {
|
||||
backend: Arc::new(codex_cloud_tasks_client::MockClient),
|
||||
base_url,
|
||||
});
|
||||
}
|
||||
|
||||
let ua = codex_core::default_client::get_codex_user_agent();
|
||||
let mut http = codex_cloud_tasks_client::HttpClient::new(base_url.clone())?.with_user_agent(ua);
|
||||
let style = if base_url.contains("/backend-api") {
|
||||
"wham"
|
||||
} else {
|
||||
"codex-api"
|
||||
};
|
||||
append_error_log(format!("startup: base_url={base_url} path_style={style}"));
|
||||
|
||||
let auth = match codex_core::config::find_codex_home()
|
||||
.ok()
|
||||
.map(|home| codex_login::AuthManager::new(home, false))
|
||||
.and_then(|am| am.auth())
|
||||
{
|
||||
Some(auth) => auth,
|
||||
None => {
|
||||
eprintln!(
|
||||
"Not signed in. Please run 'codex login' to sign in with ChatGPT, then re-run 'codex cloud'."
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(acc) = auth.get_account_id() {
|
||||
append_error_log(format!("auth: mode=ChatGPT account_id={acc}"));
|
||||
}
|
||||
|
||||
let token = match auth.get_token().await {
|
||||
Ok(t) if !t.is_empty() => t,
|
||||
_ => {
|
||||
eprintln!(
|
||||
"Not signed in. Please run 'codex login' to sign in with ChatGPT, then re-run 'codex cloud'."
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
http = http.with_bearer_token(token.clone());
|
||||
if let Some(acc) = auth
|
||||
.get_account_id()
|
||||
.or_else(|| util::extract_chatgpt_account_id(&token))
|
||||
{
|
||||
append_error_log(format!("auth: set ChatGPT-Account-Id header: {acc}"));
|
||||
http = http.with_chatgpt_account_id(acc);
|
||||
}
|
||||
|
||||
Ok(BackendContext {
|
||||
backend: Arc::new(http),
|
||||
base_url,
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_exec_command(args: crate::cli::ExecCommand) -> anyhow::Result<()> {
|
||||
let crate::cli::ExecCommand {
|
||||
query,
|
||||
environment,
|
||||
attempts,
|
||||
} = args;
|
||||
let ctx = init_backend("codex_cloud_tasks_exec").await?;
|
||||
let prompt = resolve_query_input(query)?;
|
||||
let env_id = resolve_environment_id(&ctx, &environment).await?;
|
||||
let created = codex_cloud_tasks_client::CloudBackend::create_task(
|
||||
&*ctx.backend,
|
||||
&env_id,
|
||||
&prompt,
|
||||
"main",
|
||||
false,
|
||||
attempts,
|
||||
)
|
||||
.await?;
|
||||
let url = util::task_url(&ctx.base_url, &created.id.0);
|
||||
println!("{url}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn resolve_environment_id(ctx: &BackendContext, requested: &str) -> anyhow::Result<String> {
let trimmed = requested.trim();
if trimmed.is_empty() {
return Err(anyhow!("environment id must not be empty"));
}
let normalized = util::normalize_base_url(&ctx.base_url);
let headers = util::build_chatgpt_headers().await;
let environments = crate::env_detect::list_environments(&normalized, &headers).await?;
if environments.is_empty() {
return Err(anyhow!(
"no cloud environments are available for this workspace"
));
}

if let Some(row) = environments.iter().find(|row| row.id == trimmed) {
return Ok(row.id.clone());
}

let label_matches = environments
.iter()
.filter(|row| {
row.label
.as_deref()
.map(|label| label.eq_ignore_ascii_case(trimmed))
.unwrap_or(false)
})
.collect::<Vec<_>>();
match label_matches.as_slice() {
[] => Err(anyhow!(
"environment '{trimmed}' not found; run `codex cloud` to list available environments"
)),
[single] => Ok(single.id.clone()),
[first, rest @ ..] => {
let first_id = &first.id;
if rest.iter().all(|row| row.id == *first_id) {
Ok(first_id.clone())
} else {
Err(anyhow!(
"environment label '{trimmed}' is ambiguous; run `codex cloud` to pick the desired environment id"
))
}
}
}
}
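A hypothetical walk-through of the matching rules above, assuming three environments; the ids and labels are made up for illustration.

// rows: (id, label)
//   ("env_123", Some("Prod"))
//   ("env_456", Some("staging"))
//   ("env_789", Some("staging"))
//
// resolve_environment_id(&ctx, "env_456") -> Ok("env_456")      // exact id match wins
// resolve_environment_id(&ctx, "PROD")    -> Ok("env_123")      // labels compare case-insensitively
// resolve_environment_id(&ctx, "staging") -> Err(...)           // ambiguous: two distinct ids share the label
// resolve_environment_id(&ctx, "missing") -> Err(...)           // not found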
|
||||
fn resolve_query_input(query_arg: Option<String>) -> anyhow::Result<String> {
match query_arg {
Some(q) if q != "-" => Ok(q),
maybe_dash => {
let force_stdin = matches!(maybe_dash.as_deref(), Some("-"));
if std::io::stdin().is_terminal() && !force_stdin {
return Err(anyhow!(
"no query provided. Pass one as an argument or pipe it via stdin."
));
}
if !force_stdin {
eprintln!("Reading query from stdin...");
}
let mut buffer = String::new();
std::io::stdin()
.read_to_string(&mut buffer)
.map_err(|e| anyhow!("failed to read query from stdin: {e}"))?;
if buffer.trim().is_empty() {
return Err(anyhow!(
"no query provided via stdin (received empty input)."
));
}
Ok(buffer)
}
}
}
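The query can arrive as an argument or via stdin; a short sketch of the cases handled above (the `codex cloud exec` invocations are illustrative, not taken from the diff):

// codex cloud exec "add retries to the uploader" --env env_123
//   -> Ok("add retries to the uploader")                      // plain argument
// echo "add retries to the uploader" | codex cloud exec - --env env_123
//   -> reads stdin because the argument is "-"
// codex cloud exec --env env_123        (no argument, stdin is a TTY)
//   -> Err("no query provided. Pass one as an argument or pipe it via stdin.")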
|
||||
fn level_from_status(status: codex_cloud_tasks_client::ApplyStatus) -> app::ApplyResultLevel {
|
||||
match status {
|
||||
codex_cloud_tasks_client::ApplyStatus::Success => app::ApplyResultLevel::Success,
|
||||
@@ -148,7 +319,14 @@ fn spawn_apply(
|
||||
// (no standalone patch summarizer needed – UI displays raw diffs)
|
||||
|
||||
/// Entry point for the `codex cloud` subcommand.
|
||||
pub async fn run_main(_cli: Cli, _codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
pub async fn run_main(cli: Cli, _codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
if let Some(command) = cli.command {
|
||||
return match command {
|
||||
crate::cli::Command::Exec(args) => run_exec_command(args).await,
|
||||
};
|
||||
}
|
||||
let Cli { .. } = cli;
|
||||
|
||||
// Very minimal logging setup; mirrors other crates' pattern.
|
||||
let default_level = "error";
|
||||
let _ = tracing_subscriber::fmt()
|
||||
@@ -162,72 +340,8 @@ pub async fn run_main(_cli: Cli, _codex_linux_sandbox_exe: Option<PathBuf>) -> a
|
||||
.try_init();
|
||||
|
||||
info!("Launching Cloud Tasks list UI");
|
||||
set_user_agent_suffix("codex_cloud_tasks_tui");
|
||||
|
||||
// Default to online unless explicitly configured to use mock.
|
||||
let use_mock = matches!(
|
||||
std::env::var("CODEX_CLOUD_TASKS_MODE").ok().as_deref(),
|
||||
Some("mock") | Some("MOCK")
|
||||
);
|
||||
|
||||
let backend: Arc<dyn codex_cloud_tasks_client::CloudBackend> = if use_mock {
|
||||
Arc::new(codex_cloud_tasks_client::MockClient)
|
||||
} else {
|
||||
// Build an HTTP client against the configured (or default) base URL.
|
||||
let base_url = std::env::var("CODEX_CLOUD_TASKS_BASE_URL")
|
||||
.unwrap_or_else(|_| "https://chatgpt.com/backend-api".to_string());
|
||||
let ua = codex_core::default_client::get_codex_user_agent();
|
||||
let mut http =
|
||||
codex_cloud_tasks_client::HttpClient::new(base_url.clone())?.with_user_agent(ua);
|
||||
// Log which base URL and path style we're going to use.
|
||||
let style = if base_url.contains("/backend-api") {
|
||||
"wham"
|
||||
} else {
|
||||
"codex-api"
|
||||
};
|
||||
append_error_log(format!("startup: base_url={base_url} path_style={style}"));
|
||||
|
||||
// Require ChatGPT login (SWIC). Exit with a clear message if missing.
|
||||
let _token = match codex_core::config::find_codex_home()
|
||||
.ok()
|
||||
.map(|home| codex_login::AuthManager::new(home, false))
|
||||
.and_then(|am| am.auth())
|
||||
{
|
||||
Some(auth) => {
|
||||
// Log account context for debugging workspace selection.
|
||||
if let Some(acc) = auth.get_account_id() {
|
||||
append_error_log(format!("auth: mode=ChatGPT account_id={acc}"));
|
||||
}
|
||||
match auth.get_token().await {
|
||||
Ok(t) if !t.is_empty() => {
|
||||
// Attach token and ChatGPT-Account-Id header if available
|
||||
http = http.with_bearer_token(t.clone());
|
||||
if let Some(acc) = auth
|
||||
.get_account_id()
|
||||
.or_else(|| util::extract_chatgpt_account_id(&t))
|
||||
{
|
||||
append_error_log(format!("auth: set ChatGPT-Account-Id header: {acc}"));
|
||||
http = http.with_chatgpt_account_id(acc);
|
||||
}
|
||||
t
|
||||
}
|
||||
_ => {
|
||||
eprintln!(
|
||||
"Not signed in. Please run 'codex login' to sign in with ChatGPT, then re-run 'codex cloud'."
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
eprintln!(
|
||||
"Not signed in. Please run 'codex login' to sign in with ChatGPT, then re-run 'codex cloud'."
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
Arc::new(http)
|
||||
};
|
||||
let BackendContext { backend, .. } = init_backend("codex_cloud_tasks_tui").await?;
|
||||
let backend = backend;
|
||||
|
||||
// Terminal setup
|
||||
use crossterm::ExecutableCommand;
|
||||
|
||||
@@ -91,3 +91,18 @@ pub async fn build_chatgpt_headers() -> HeaderMap {
}
headers
}

/// Construct a browser-friendly task URL for the given backend base URL.
pub fn task_url(base_url: &str, task_id: &str) -> String {
let normalized = normalize_base_url(base_url);
if let Some(root) = normalized.strip_suffix("/backend-api") {
return format!("{root}/codex/tasks/{task_id}");
}
if let Some(root) = normalized.strip_suffix("/api/codex") {
return format!("{root}/codex/tasks/{task_id}");
}
if normalized.ends_with("/codex") {
return format!("{normalized}/tasks/{task_id}");
}
format!("{normalized}/codex/tasks/{task_id}")
}
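Given the suffix handling above, the mapping works out roughly as follows. This is a sketch: the exact behaviour also depends on normalize_base_url, which is not shown in this hunk and is assumed here to leave these URLs unchanged.

assert_eq!(
    task_url("https://chatgpt.com/backend-api", "task_123"),
    "https://chatgpt.com/codex/tasks/task_123"
);
assert_eq!(
    task_url("https://example.com/api/codex", "task_123"),
    "https://example.com/codex/tasks/task_123"
);
assert_eq!(
    task_url("https://example.com/codex", "task_123"),
    "https://example.com/codex/tasks/task_123"
);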
codex-rs/common/src/format_env_display.rs (new file)
@@ -0,0 +1,66 @@
use std::collections::HashMap;

pub fn format_env_display(env: Option<&HashMap<String, String>>, env_vars: &[String]) -> String {
let mut parts: Vec<String> = Vec::new();

if let Some(map) = env {
let mut pairs: Vec<_> = map.iter().collect();
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
parts.extend(
pairs
.into_iter()
.map(|(key, value)| format!("{key}={value}")),
);
}

if !env_vars.is_empty() {
parts.extend(env_vars.iter().map(|var| format!("{var}=${var}")));
}

if parts.is_empty() {
"-".to_string()
} else {
parts.join(", ")
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn returns_dash_when_empty() {
assert_eq!(format_env_display(None, &[]), "-");

let empty_map = HashMap::new();
assert_eq!(format_env_display(Some(&empty_map), &[]), "-");
}

#[test]
fn formats_sorted_env_pairs() {
let mut env = HashMap::new();
env.insert("B".to_string(), "two".to_string());
env.insert("A".to_string(), "one".to_string());

assert_eq!(format_env_display(Some(&env), &[]), "A=one, B=two");
}

#[test]
fn formats_env_vars_with_dollar_prefix() {
let vars = vec!["TOKEN".to_string(), "PATH".to_string()];

assert_eq!(format_env_display(None, &vars), "TOKEN=$TOKEN, PATH=$PATH");
}

#[test]
fn combines_env_pairs_and_vars() {
let mut env = HashMap::new();
env.insert("HOME".to_string(), "/tmp".to_string());
let vars = vec!["TOKEN".to_string()];

assert_eq!(
    format_env_display(Some(&env), &vars),
    "HOME=/tmp, TOKEN=$TOKEN"
);
}
}
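How the MCP CLI is expected to use this helper for its Env column, per the changes earlier in this diff: explicit key=value pairs come first (sorted), then env_vars rendered as VAR=$VAR to mark variables forwarded from the parent environment rather than literal values. A minimal sketch:

use std::collections::HashMap;

let mut env = HashMap::new();
env.insert("TOKEN".to_string(), "secret".to_string());
let env_vars = vec!["APP_TOKEN".to_string()];

// Rendered as a single cell in `codex mcp list` / `codex mcp get`.
assert_eq!(
    format_env_display(Some(&env), &env_vars),
    "TOKEN=secret, APP_TOKEN=$APP_TOKEN"
);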
@@ -13,6 +13,9 @@ mod sandbox_mode_cli_arg;
#[cfg(feature = "cli")]
pub use sandbox_mode_cli_arg::SandboxModeCliArg;

#[cfg(feature = "cli")]
pub mod format_env_display;

#[cfg(any(feature = "cli", test))]
mod config_override;
@@ -3,4 +3,4 @@
This file has moved. Please see the latest configuration documentation here:

- Full config docs: [docs/config.md](../docs/config.md)
- MCP servers section: [docs/config.md#mcp_servers](../docs/config.md#mcp_servers)
- MCP servers section: [docs/config.md#connecting-to-mcp-servers](../docs/config.md#connecting-to-mcp-servers)
@@ -26,6 +26,7 @@ codex-mcp-client = { workspace = true }
codex-otel = { workspace = true, features = ["otel"] }
codex-protocol = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-async-utils = { workspace = true }
codex-utils-string = { workspace = true }
dirs = { workspace = true }
dunce = { workspace = true }
@@ -47,6 +48,7 @@ shlex = { workspace = true }
similar = { workspace = true }
strum_macros = { workspace = true }
tempfile = { workspace = true }
test-log = { workspace = true }
thiserror = { workspace = true }
time = { workspace = true, features = [
"formatting",
@@ -135,6 +135,10 @@ impl CodexAuth {
self.get_current_token_data().and_then(|t| t.account_id)
}

pub fn get_account_email(&self) -> Option<String> {
self.get_current_token_data().and_then(|t| t.id_token.email)
}

pub(crate) fn get_plan_type(&self) -> Option<PlanType> {
self.get_current_token_data()
.and_then(|t| t.id_token.chatgpt_plan_type)
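The new accessor mirrors get_account_id and is read from the cached ID token; the core session change later in this diff threads it into the OTel event metadata. A minimal sketch of the call-site shape:

let account_id = auth_manager.auth().and_then(|a| a.get_account_id());
let account_email = auth_manager.auth().and_then(|a| a.get_account_email());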
@@ -5,6 +5,8 @@ use crate::client_common::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
use crate::client_common::ResponseStream;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::ConnectionFailedError;
|
||||
use crate::error::ResponseStreamFailed;
|
||||
use crate::error::Result;
|
||||
use crate::error::RetryLimitReachedError;
|
||||
use crate::error::UnexpectedResponseError;
|
||||
@@ -309,7 +311,12 @@ pub(crate) async fn stream_chat_completions(
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||
let stream = resp.bytes_stream().map_err(|e| {
|
||||
CodexErr::ResponseStreamFailed(ResponseStreamFailed {
|
||||
source: e,
|
||||
request_id: None,
|
||||
})
|
||||
});
|
||||
tokio::spawn(process_chat_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
@@ -349,7 +356,9 @@ pub(crate) async fn stream_chat_completions(
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt > max_retries {
|
||||
return Err(e.into());
|
||||
return Err(CodexErr::ConnectionFailed(ConnectionFailedError {
|
||||
source: e,
|
||||
}));
|
||||
}
|
||||
let delay = backoff(attempt);
|
||||
tokio::time::sleep(delay).await;
|
||||
@@ -389,10 +398,12 @@ async fn process_chat_sse<S>(
|
||||
let mut reasoning_text = String::new();
|
||||
|
||||
loop {
|
||||
let sse = match otel_event_manager
|
||||
.log_sse_event(|| timeout(idle_timeout, stream.next()))
|
||||
.await
|
||||
{
|
||||
let start = std::time::Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
let duration = start.elapsed();
|
||||
otel_event_manager.log_sse_event(&response, duration);
|
||||
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(ev))) => ev,
|
||||
Ok(Some(Err(e))) => {
|
||||
let _ = tx_event
|
||||
|
||||
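The SSE loop change above reworks how timing reaches OTel: log_sse_event previously wrapped the awaited future in a closure, while the caller now measures the await itself and passes both the completed response and the elapsed duration. A minimal sketch of the new pattern, with the stream and manager types left abstract:

let start = std::time::Instant::now();
let response = timeout(idle_timeout, stream.next()).await;
let duration = start.elapsed();
otel_event_manager.log_sse_event(&response, duration);
// `response` is then matched exactly as before (event, stream error, idle timeout).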
@@ -5,6 +5,8 @@ use std::time::Duration;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::auth::CodexAuth;
|
||||
use crate::error::ConnectionFailedError;
|
||||
use crate::error::ResponseStreamFailed;
|
||||
use crate::error::RetryLimitReachedError;
|
||||
use crate::error::UnexpectedResponseError;
|
||||
use bytes::Bytes;
|
||||
@@ -47,6 +49,7 @@ use crate::openai_tools::create_tools_json_for_responses_api;
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::RateLimitWindow;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::state::TaskKind;
|
||||
use crate::token_data::PlanType;
|
||||
use crate::util::backoff;
|
||||
use codex_otel::otel_event_manager::OtelEventManager;
|
||||
@@ -123,8 +126,16 @@ impl ModelClient {
|
||||
/// the provider config. Public callers always invoke `stream()` – the
|
||||
/// specialised helpers are private to avoid accidental misuse.
|
||||
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
self.stream_with_task_kind(prompt, TaskKind::Regular).await
|
||||
}
|
||||
|
||||
pub(crate) async fn stream_with_task_kind(
|
||||
&self,
|
||||
prompt: &Prompt,
|
||||
task_kind: TaskKind,
|
||||
) -> Result<ResponseStream> {
|
||||
match self.provider.wire_api {
|
||||
WireApi::Responses => self.stream_responses(prompt).await,
|
||||
WireApi::Responses => self.stream_responses(prompt, task_kind).await,
|
||||
WireApi::Chat => {
|
||||
// Create the raw streaming connection first.
|
||||
let response_stream = stream_chat_completions(
|
||||
@@ -165,7 +176,11 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
/// Implementation for the OpenAI *Responses* experimental API.
|
||||
async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
async fn stream_responses(
|
||||
&self,
|
||||
prompt: &Prompt,
|
||||
task_kind: TaskKind,
|
||||
) -> Result<ResponseStream> {
|
||||
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
|
||||
// short circuit for tests
|
||||
warn!(path, "Streaming from fixture");
|
||||
@@ -244,7 +259,7 @@ impl ModelClient {
|
||||
let max_attempts = self.provider.request_max_retries();
|
||||
for attempt in 0..=max_attempts {
|
||||
match self
|
||||
.attempt_stream_responses(attempt, &payload_json, &auth_manager)
|
||||
.attempt_stream_responses(attempt, &payload_json, &auth_manager, task_kind)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => {
|
||||
@@ -272,6 +287,7 @@ impl ModelClient {
|
||||
attempt: u64,
|
||||
payload_json: &Value,
|
||||
auth_manager: &Option<Arc<AuthManager>>,
|
||||
task_kind: TaskKind,
|
||||
) -> std::result::Result<ResponseStream, StreamAttemptError> {
|
||||
// Always fetch the latest auth in case a prior attempt refreshed the token.
|
||||
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
||||
@@ -294,6 +310,7 @@ impl ModelClient {
|
||||
.header("conversation_id", self.conversation_id.to_string())
|
||||
.header("session_id", self.conversation_id.to_string())
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.header("Codex-Task-Type", task_kind.header_value())
|
||||
.json(payload_json);
|
||||
|
||||
if let Some(auth) = auth.as_ref()
|
||||
@@ -336,7 +353,12 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
// spawn task to process SSE
|
||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||
let stream = resp.bytes_stream().map_err(move |e| {
|
||||
CodexErr::ResponseStreamFailed(ResponseStreamFailed {
|
||||
source: e,
|
||||
request_id: request_id.clone(),
|
||||
})
|
||||
});
|
||||
tokio::spawn(process_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
@@ -416,7 +438,9 @@ impl ModelClient {
|
||||
request_id,
|
||||
})
|
||||
}
|
||||
Err(e) => Err(StreamAttemptError::RetryableTransportError(e.into())),
|
||||
Err(e) => Err(StreamAttemptError::RetryableTransportError(
|
||||
CodexErr::ConnectionFailed(ConnectionFailedError { source: e }),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -649,10 +673,12 @@ async fn process_sse<S>(
|
||||
let mut response_error: Option<CodexErr> = None;
|
||||
|
||||
loop {
|
||||
let sse = match otel_event_manager
|
||||
.log_sse_event(|| timeout(idle_timeout, stream.next()))
|
||||
.await
|
||||
{
|
||||
let start = std::time::Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
let duration = start.elapsed();
|
||||
otel_event_manager.log_sse_event(&response, duration);
|
||||
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(sse))) => sse,
|
||||
Ok(Some(Err(e))) => {
|
||||
debug!("SSE Error: {e:#}");
|
||||
@@ -1013,6 +1039,7 @@ mod tests {
|
||||
"test",
|
||||
"test",
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
Some(AuthMode::ChatGPT),
|
||||
false,
|
||||
"test".to_string(),
|
||||
|
||||
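The Responses path now sends a Codex-Task-Type request header derived from TaskKind::header_value(). The diff does not show the enum definition or the exact header strings, so the mapping below is an assumption for illustration only; the real type lives in crate::state::TaskKind.

// Hypothetical sketch; variant names beyond Regular and the string values are assumed.
enum TaskKindSketch {
    Regular,
    Compact,
    Review,
}

impl TaskKindSketch {
    fn header_value(&self) -> &'static str {
        match self {
            TaskKindSketch::Regular => "regular", // assumed
            TaskKindSketch::Compact => "compact", // assumed
            TaskKindSketch::Review => "review",   // assumed
        }
    }
}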
@@ -17,6 +17,7 @@ use codex_apply_patch::ApplyPatchAction;
|
||||
use codex_protocol::ConversationId;
|
||||
use codex_protocol::protocol::ConversationPathResponseEvent;
|
||||
use codex_protocol::protocol::ExitedReviewModeEvent;
|
||||
use codex_protocol::protocol::McpAuthStatus;
|
||||
use codex_protocol::protocol::ReviewRequest;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
@@ -27,10 +28,17 @@ use futures::future::BoxFuture;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::FuturesOrdered;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ListResourceTemplatesRequestParams;
|
||||
use mcp_types::ListResourceTemplatesResult;
|
||||
use mcp_types::ListResourcesRequestParams;
|
||||
use mcp_types::ListResourcesResult;
|
||||
use mcp_types::ReadResourceRequestParams;
|
||||
use mcp_types::ReadResourceResult;
|
||||
use serde_json;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
@@ -57,6 +65,7 @@ use crate::exec_command::WriteStdinParams;
|
||||
use crate::executor::Executor;
|
||||
use crate::executor::ExecutorConfig;
|
||||
use crate::executor::normalize_exec_result;
|
||||
use crate::mcp::auth::compute_auth_statuses;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::model_family::find_family_for_model;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
@@ -98,6 +107,7 @@ use crate::rollout::RolloutRecorderParams;
|
||||
use crate::shell;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::state::SessionServices;
|
||||
use crate::state::TaskKind;
|
||||
use crate::tasks::CompactTask;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
@@ -110,6 +120,7 @@ use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_instructions::UserInstructions;
|
||||
use crate::user_notification::UserNotification;
|
||||
use crate::util::backoff;
|
||||
use codex_async_utils::OrCancelExt;
|
||||
use codex_otel::otel_event_manager::OtelEventManager;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
@@ -363,14 +374,32 @@ impl Session {
|
||||
|
||||
let mcp_fut = McpConnectionManager::new(
|
||||
config.mcp_servers.clone(),
|
||||
config.use_experimental_use_rmcp_client,
|
||||
config
|
||||
.features
|
||||
.enabled(crate::features::Feature::RmcpClient),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
);
|
||||
let default_shell_fut = shell::default_user_shell();
|
||||
let history_meta_fut = crate::message_history::history_metadata(&config);
|
||||
let auth_statuses_fut = compute_auth_statuses(
|
||||
config.mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
);
|
||||
|
||||
// Join all independent futures.
|
||||
let (rollout_recorder, mcp_res, default_shell, (history_log_id, history_entry_count)) =
|
||||
tokio::join!(rollout_fut, mcp_fut, default_shell_fut, history_meta_fut);
|
||||
let (
|
||||
rollout_recorder,
|
||||
mcp_res,
|
||||
default_shell,
|
||||
(history_log_id, history_entry_count),
|
||||
auth_statuses,
|
||||
) = tokio::join!(
|
||||
rollout_fut,
|
||||
mcp_fut,
|
||||
default_shell_fut,
|
||||
history_meta_fut,
|
||||
auth_statuses_fut
|
||||
);
|
||||
|
||||
let rollout_recorder = rollout_recorder.map_err(|e| {
|
||||
error!("failed to initialize rollout recorder: {e:#}");
|
||||
@@ -397,11 +426,27 @@ impl Session {
|
||||
// Surface individual client start-up failures to the user.
|
||||
if !failed_clients.is_empty() {
|
||||
for (server_name, err) in failed_clients {
|
||||
let message = format!("MCP client for `{server_name}` failed to start: {err:#}");
|
||||
error!("{message}");
|
||||
let auth_status = auth_statuses.get(&server_name);
|
||||
let requires_login = match auth_status {
|
||||
Some(McpAuthStatus::NotLoggedIn) => true,
|
||||
Some(McpAuthStatus::OAuth) => is_mcp_client_auth_required_error(&err),
|
||||
_ => false,
|
||||
};
|
||||
let log_message =
|
||||
format!("MCP client for `{server_name}` failed to start: {err:#}");
|
||||
error!("{log_message}");
|
||||
let display_message = if requires_login {
|
||||
format!(
|
||||
"The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}` to log in."
|
||||
)
|
||||
} else {
|
||||
log_message
|
||||
};
|
||||
post_session_configured_error_events.push(Event {
|
||||
id: INITIAL_SUBMIT_ID.to_owned(),
|
||||
msg: EventMsg::Error(ErrorEvent { message }),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: display_message,
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -411,6 +456,7 @@ impl Session {
|
||||
config.model.as_str(),
|
||||
config.model_family.slug.as_str(),
|
||||
auth_manager.auth().and_then(|a| a.get_account_id()),
|
||||
auth_manager.auth().and_then(|a| a.get_account_email()),
|
||||
auth_manager.auth().map(|a| a.mode),
|
||||
config.otel.log_user_prompt,
|
||||
terminal::user_agent(),
|
||||
@@ -444,12 +490,7 @@ impl Session {
|
||||
client,
|
||||
tools_config: ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &config.model_family,
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
features: &config.features,
|
||||
}),
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
@@ -591,6 +632,7 @@ impl Session {
|
||||
warn!("Overwriting existing pending approval for sub_id: {event_id}");
|
||||
}
|
||||
|
||||
let parsed_cmd = parse_command(&command);
|
||||
let event = Event {
|
||||
id: event_id,
|
||||
msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
@@ -598,6 +640,7 @@ impl Session {
|
||||
command,
|
||||
cwd,
|
||||
reason,
|
||||
parsed_cmd,
|
||||
}),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
@@ -853,10 +896,7 @@ impl Session {
|
||||
call_id,
|
||||
command: command_for_display.clone(),
|
||||
cwd,
|
||||
parsed_cmd: parse_command(&command_for_display)
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect(),
|
||||
parsed_cmd: parse_command(&command_for_display),
|
||||
}),
|
||||
};
|
||||
let event = Event {
|
||||
@@ -881,6 +921,7 @@ impl Session {
|
||||
duration,
|
||||
exit_code,
|
||||
timed_out: _,
|
||||
..
|
||||
} = output;
|
||||
// Send full stdout/stderr to clients; do not truncate.
|
||||
let stdout = stdout.text.clone();
|
||||
@@ -939,21 +980,42 @@ impl Session {
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
prepared: PreparedExec,
|
||||
approval_policy: AskForApproval,
|
||||
config: ExecutorConfig,
|
||||
) -> Result<ExecToolCallOutput, ExecError> {
|
||||
let PreparedExec { context, request } = prepared;
|
||||
let is_apply_patch = context.apply_patch.is_some();
|
||||
let sub_id = context.sub_id.clone();
|
||||
let call_id = context.call_id.clone();
|
||||
|
||||
self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone())
|
||||
.await;
|
||||
|
||||
let begin_turn_diff = turn_diff_tracker.clone();
|
||||
let begin_context = context.clone();
|
||||
let session = self;
|
||||
let result = self
|
||||
.services
|
||||
.executor
|
||||
.run(request, self, approval_policy, &context)
|
||||
.run(
|
||||
request,
|
||||
self,
|
||||
approval_policy,
|
||||
&context,
|
||||
config,
|
||||
move || {
|
||||
let turn_diff = begin_turn_diff.clone();
|
||||
let ctx = begin_context.clone();
|
||||
async move {
|
||||
session.on_exec_command_begin(turn_diff, ctx).await;
|
||||
}
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
if matches!(
|
||||
&result,
|
||||
Err(ExecError::Function(FunctionCallError::Denied(_)))
|
||||
) {
|
||||
return result;
|
||||
}
|
||||
|
||||
let normalized = normalize_exec_result(&result);
|
||||
let borrowed = normalized.event_output();
|
||||
|
||||
@@ -1028,6 +1090,39 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_resources(
|
||||
&self,
|
||||
server: &str,
|
||||
params: Option<ListResourcesRequestParams>,
|
||||
) -> anyhow::Result<ListResourcesResult> {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.list_resources(server, params)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn list_resource_templates(
|
||||
&self,
|
||||
server: &str,
|
||||
params: Option<ListResourceTemplatesRequestParams>,
|
||||
) -> anyhow::Result<ListResourceTemplatesResult> {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.list_resource_templates(server, params)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn read_resource(
|
||||
&self,
|
||||
server: &str,
|
||||
params: ReadResourceRequestParams,
|
||||
) -> anyhow::Result<ReadResourceResult> {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read_resource(server, params)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
server: &str,
|
||||
@@ -1088,19 +1183,6 @@ impl Session {
|
||||
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
}
|
||||
|
||||
fn interrupt_task_sync(&self) {
|
||||
if let Ok(mut active) = self.active_turn.try_lock()
|
||||
&& let Some(at) = active.as_mut()
|
||||
{
|
||||
at.try_clear_pending_sync();
|
||||
let tasks = at.drain_tasks();
|
||||
*active = None;
|
||||
for (_sub_id, task) in tasks {
|
||||
task.handle.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn notifier(&self) -> &UserNotifier {
|
||||
&self.services.notifier
|
||||
}
|
||||
@@ -1114,12 +1196,6 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Session {
|
||||
fn drop(&mut self) {
|
||||
self.interrupt_task_sync();
|
||||
}
|
||||
}
|
||||
|
||||
async fn submission_loop(
|
||||
sess: Arc<Session>,
|
||||
turn_context: TurnContext,
|
||||
@@ -1193,12 +1269,7 @@ async fn submission_loop(
|
||||
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &effective_family,
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
features: &config.features,
|
||||
});
|
||||
|
||||
let new_turn_context = TurnContext {
|
||||
@@ -1295,14 +1366,7 @@ async fn submission_loop(
|
||||
client,
|
||||
tools_config: ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config
|
||||
.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config
|
||||
.use_experimental_unified_exec_tool,
|
||||
features: &config.features,
|
||||
}),
|
||||
user_instructions: turn_context.user_instructions.clone(),
|
||||
base_instructions: turn_context.base_instructions.clone(),
|
||||
@@ -1402,10 +1466,25 @@ async fn submission_loop(
|
||||
|
||||
// This is a cheap lookup from the connection manager's cache.
|
||||
let tools = sess.services.mcp_connection_manager.list_all_tools();
|
||||
let (auth_statuses, resources, resource_templates) = tokio::join!(
|
||||
compute_auth_statuses(
|
||||
config.mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
),
|
||||
sess.services.mcp_connection_manager.list_all_resources(),
|
||||
sess.services
|
||||
.mcp_connection_manager
|
||||
.list_all_resource_templates()
|
||||
);
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::McpListToolsResponse(
|
||||
crate::protocol::McpListToolsResponseEvent { tools },
|
||||
crate::protocol::McpListToolsResponseEvent {
|
||||
tools,
|
||||
resources,
|
||||
resource_templates,
|
||||
auth_statuses,
|
||||
},
|
||||
),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
@@ -1526,14 +1605,15 @@ async fn spawn_review_thread(
|
||||
let model = config.review_model.clone();
|
||||
let review_model_family = find_family_for_model(&model)
|
||||
.unwrap_or_else(|| parent_turn_context.client.get_model_family());
|
||||
// For reviews, disable the plan, web_search, view_image, and streamable shell tools regardless of global settings.
|
||||
let mut review_features = config.features.clone();
|
||||
review_features.disable(crate::features::Feature::PlanTool);
|
||||
review_features.disable(crate::features::Feature::WebSearchRequest);
|
||||
review_features.disable(crate::features::Feature::ViewImageTool);
|
||||
review_features.disable(crate::features::Feature::StreamableShell);
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &review_model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: false,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
features: &review_features,
|
||||
});
|
||||
|
||||
let base_instructions = REVIEW_PROMPT.to_string();
|
||||
@@ -1624,6 +1704,8 @@ pub(crate) async fn run_task(
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
if input.is_empty() {
|
||||
return None;
|
||||
@@ -1707,6 +1789,8 @@ pub(crate) async fn run_task(
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id.clone(),
|
||||
turn_input,
|
||||
task_kind,
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -1859,6 +1943,7 @@ pub(crate) async fn run_task(
|
||||
);
|
||||
sess.notifier()
|
||||
.notify(&UserNotification::AgentTurnComplete {
|
||||
thread_id: sess.conversation_id.to_string(),
|
||||
turn_id: sub_id.clone(),
|
||||
input_messages: turn_input_messages,
|
||||
last_assistant_message: last_agent_message.clone(),
|
||||
@@ -1867,6 +1952,10 @@ pub(crate) async fn run_task(
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Err(CodexErr::TurnAborted) => {
|
||||
// Aborted turn is reported via a different event.
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
info!("Turn error: {e:#}");
|
||||
let event = Event {
|
||||
@@ -1932,6 +2021,8 @@ async fn run_turn(
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
@@ -1961,10 +2052,13 @@ async fn run_turn(
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&sub_id,
|
||||
&prompt,
|
||||
task_kind,
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => return Ok(output),
|
||||
Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted),
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||
@@ -1998,9 +2092,7 @@ async fn run_turn(
|
||||
// at a seemingly frozen screen.
|
||||
sess.notify_stream_error(
|
||||
&sub_id,
|
||||
format!(
|
||||
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…"
|
||||
),
|
||||
format!("Re-connecting... {retries}/{max_retries}"),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -2029,6 +2121,7 @@ struct TurnRunResult {
|
||||
total_token_usage: Option<TokenUsage>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn try_run_turn(
|
||||
router: Arc<ToolRouter>,
|
||||
sess: Arc<Session>,
|
||||
@@ -2036,6 +2129,8 @@ async fn try_run_turn(
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
// call_ids that are part of this response.
|
||||
let completed_call_ids = prompt
|
||||
@@ -2101,7 +2196,12 @@ async fn try_run_turn(
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context.client.clone().stream(&prompt).await?;
|
||||
let mut stream = turn_context
|
||||
.client
|
||||
.clone()
|
||||
.stream_with_task_kind(prompt.as_ref(), task_kind)
|
||||
.or_cancel(&cancellation_token)
|
||||
.await??;
|
||||
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
Arc::clone(&router),
|
||||
@@ -2117,7 +2217,8 @@ async fn try_run_turn(
|
||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||
let event = stream.next().await;
|
||||
let event = stream.next().or_cancel(&cancellation_token).await?;
|
||||
|
||||
let event = match event {
|
||||
Some(res) => res?,
|
||||
None => {
|
||||
@@ -2182,7 +2283,8 @@ async fn try_run_turn(
|
||||
response: Some(response),
|
||||
});
|
||||
}
|
||||
Err(FunctionCallError::RespondToModel(message)) => {
|
||||
Err(FunctionCallError::RespondToModel(message))
|
||||
| Err(FunctionCallError::Denied(message)) => {
|
||||
let response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
@@ -2221,7 +2323,10 @@ async fn try_run_turn(
|
||||
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
|
||||
.await;
|
||||
|
||||
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;
|
||||
let processed_items = output
|
||||
.try_collect()
|
||||
.or_cancel(&cancellation_token)
|
||||
.await??;
|
||||
|
||||
let unified_diff = {
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
@@ -2430,6 +2535,11 @@ pub(crate) async fn exit_review_mode(
|
||||
.await;
|
||||
}
|
||||
|
||||
fn is_mcp_client_auth_required_error(error: &anyhow::Error) -> bool {
|
||||
// StreamableHttpError::AuthRequired from the MCP SDK.
|
||||
error.to_string().contains("Auth required")
|
||||
}
|
||||
|
||||
use crate::executor::errors::ExecError;
|
||||
use crate::executor::linkers::PreparedExec;
|
||||
use crate::tools::context::ApplyPatchCommandContext;
|
||||
@@ -2459,6 +2569,8 @@ mod tests {
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::TextContent;
|
||||
@@ -2468,8 +2580,6 @@ mod tests {
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration as StdDuration;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[test]
|
||||
fn reconstruct_history_matches_live_compactions() {
|
||||
@@ -2711,6 +2821,7 @@ mod tests {
|
||||
config.model.as_str(),
|
||||
config.model_family.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
Some(AuthMode::ChatGPT),
|
||||
false,
|
||||
"test".to_string(),
|
||||
@@ -2740,12 +2851,7 @@ mod tests {
|
||||
);
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &config.model_family,
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
features: &config.features,
|
||||
});
|
||||
let turn_context = TurnContext {
|
||||
client,
|
||||
@@ -2813,12 +2919,7 @@ mod tests {
|
||||
);
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &config.model_family,
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
features: &config.features,
|
||||
});
|
||||
let turn_context = Arc::new(TurnContext {
|
||||
client,
|
||||
@@ -2858,12 +2959,15 @@ mod tests {
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct NeverEndingTask(TaskKind);
|
||||
struct NeverEndingTask {
|
||||
kind: TaskKind,
|
||||
listen_to_cancellation_token: bool,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SessionTask for NeverEndingTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
self.0
|
||||
self.kind
|
||||
}
|
||||
|
||||
async fn run(
|
||||
@@ -2872,20 +2976,26 @@ mod tests {
|
||||
_ctx: Arc<TurnContext>,
|
||||
_sub_id: String,
|
||||
_input: Vec<InputItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
if self.listen_to_cancellation_token {
|
||||
cancellation_token.cancelled().await;
|
||||
return None;
|
||||
}
|
||||
loop {
|
||||
sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
if let TaskKind::Review = self.0 {
|
||||
if let TaskKind::Review = self.kind {
|
||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[test_log::test]
|
||||
async fn abort_regular_task_emits_turn_aborted_only() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-regular".to_string();
|
||||
@@ -2896,7 +3006,41 @@ mod tests {
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask(TaskKind::Regular),
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Regular,
|
||||
listen_to_cancellation_token: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||
.await
|
||||
.expect("timeout waiting for event")
|
||||
.expect("event");
|
||||
match evt.msg {
|
||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
assert!(rx.try_recv().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_gracefuly_emits_turn_aborted_only() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-regular".to_string();
|
||||
let input = vec![InputItem::Text {
|
||||
text: "hello".to_string(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Regular,
|
||||
listen_to_cancellation_token: true,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -2910,7 +3054,7 @@ mod tests {
|
||||
assert!(rx.try_recv().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-review".to_string();
|
||||
@@ -2921,18 +3065,27 @@ mod tests {
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask(TaskKind::Review),
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Review,
|
||||
listen_to_cancellation_token: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
let first = rx.recv().await.expect("first event");
|
||||
let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||
.await
|
||||
.expect("timeout waiting for first event")
|
||||
.expect("first event");
|
||||
match first.msg {
|
||||
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
|
||||
other => panic!("unexpected first event: {other:?}"),
|
||||
}
|
||||
let second = rx.recv().await.expect("second event");
|
||||
let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||
.await
|
||||
.expect("timeout waiting for second event")
|
||||
.expect("second event");
|
||||
match second.msg {
|
||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
|
||||
@@ -16,6 +16,7 @@ use crate::protocol::InputItem;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use crate::protocol::TaskStartedEvent;
|
||||
use crate::protocol::TurnContextItem;
|
||||
use crate::state::TaskKind;
|
||||
use crate::truncate::truncate_middle;
|
||||
use crate::util::backoff;
|
||||
use askama::Template;
|
||||
@@ -70,14 +71,10 @@ async fn run_compact_task_inner(
|
||||
input: Vec<InputItem>,
|
||||
) {
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
let turn_input = sess
|
||||
let mut turn_input = sess
|
||||
.turn_input_with_history(vec![initial_input_for_turn.clone().into()])
|
||||
.await;
|
||||
|
||||
let prompt = Prompt {
|
||||
input: turn_input,
|
||||
..Default::default()
|
||||
};
|
||||
let mut truncated_count = 0usize;
|
||||
|
||||
let max_retries = turn_context.client.get_provider().stream_max_retries();
|
||||
let mut retries = 0;
|
||||
@@ -93,17 +90,36 @@ async fn run_compact_task_inner(
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
loop {
|
||||
let prompt = Prompt {
|
||||
input: turn_input.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
let attempt_result =
|
||||
drain_to_completed(&sess, turn_context.as_ref(), &sub_id, &prompt).await;
|
||||
|
||||
match attempt_result {
|
||||
Ok(()) => {
|
||||
if truncated_count > 0 {
|
||||
sess.notify_background_event(
|
||||
&sub_id,
|
||||
format!(
|
||||
"Trimmed {truncated_count} older conversation item(s) before compacting so the prompt fits the model context window."
|
||||
),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(CodexErr::Interrupted) => {
|
||||
return;
|
||||
}
|
||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||
if turn_input.len() > 1 {
|
||||
turn_input.remove(0);
|
||||
truncated_count += 1;
|
||||
retries = 0;
|
||||
continue;
|
||||
}
|
||||
sess.set_total_tokens_full(&sub_id, turn_context.as_ref())
|
||||
.await;
|
||||
let event = Event {
|
||||
@@ -121,9 +137,7 @@ async fn run_compact_task_inner(
|
||||
let delay = backoff(retries);
|
||||
sess.notify_stream_error(
|
||||
&sub_id,
|
||||
format!(
|
||||
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…"
|
||||
),
|
||||
format!("Re-connecting... {retries}/{max_retries}"),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(delay).await;
|
||||
@@ -245,7 +259,11 @@ async fn drain_to_completed(
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let mut stream = turn_context.client.clone().stream(prompt).await?;
|
||||
let mut stream = turn_context
|
||||
.client
|
||||
.clone()
|
||||
.stream_with_task_kind(prompt, TaskKind::Compact)
|
||||
.await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
let Some(event) = maybe_event else {
|
||||
|
||||
File diff suppressed because it is too large
@@ -20,6 +20,18 @@ pub struct ConfigProfile {
|
||||
pub model_verbosity: Option<Verbosity>,
|
||||
pub chatgpt_base_url: Option<String>,
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
pub include_plan_tool: Option<bool>,
|
||||
pub include_apply_patch_tool: Option<bool>,
|
||||
pub include_view_image_tool: Option<bool>,
|
||||
pub experimental_use_unified_exec_tool: Option<bool>,
|
||||
pub experimental_use_exec_command_tool: Option<bool>,
|
||||
pub experimental_use_rmcp_client: Option<bool>,
|
||||
pub experimental_use_freeform_apply_patch: Option<bool>,
|
||||
pub tools_web_search: Option<bool>,
|
||||
pub tools_view_image: Option<bool>,
|
||||
/// Optional feature toggles scoped to this profile.
|
||||
#[serde(default)]
|
||||
pub features: Option<crate::features::FeaturesToml>,
|
||||
}
|
||||
|
||||
impl From<ConfigProfile> for codex_app_server_protocol::Profile {
|
||||
|
||||
@@ -20,6 +20,10 @@ pub struct McpServerConfig {
|
||||
#[serde(flatten)]
|
||||
pub transport: McpServerTransportConfig,
|
||||
|
||||
/// When `false`, Codex skips initializing this MCP server.
|
||||
#[serde(default = "default_enabled")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Startup timeout in seconds for initializing the MCP server and initially listing its tools.
|
||||
#[serde(
|
||||
default,
|
||||
@@ -40,21 +44,34 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
struct RawMcpServerConfig {
|
||||
// stdio
|
||||
command: Option<String>,
|
||||
#[serde(default)]
|
||||
args: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
env: Option<HashMap<String, String>>,
|
||||
#[serde(default)]
|
||||
env_vars: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
cwd: Option<PathBuf>,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
#[serde(default)]
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
|
||||
// streamable_http
|
||||
url: Option<String>,
|
||||
bearer_token: Option<String>,
|
||||
bearer_token_env_var: Option<String>,
|
||||
|
||||
// shared
|
||||
#[serde(default)]
|
||||
startup_timeout_sec: Option<f64>,
|
||||
#[serde(default)]
|
||||
startup_timeout_ms: Option<u64>,
|
||||
#[serde(default, with = "option_duration_secs")]
|
||||
tool_timeout_sec: Option<Duration>,
|
||||
#[serde(default)]
|
||||
enabled: Option<bool>,
|
||||
}
|
||||
|
||||
let raw = RawMcpServerConfig::deserialize(deserializer)?;
|
||||
@@ -85,30 +102,58 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
command: Some(command),
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
url,
|
||||
bearer_token,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
..
|
||||
} => {
|
||||
throw_if_set("stdio", "url", url.as_ref())?;
|
||||
throw_if_set("stdio", "bearer_token", bearer_token.as_ref())?;
|
||||
throw_if_set(
|
||||
"stdio",
|
||||
"bearer_token_env_var",
|
||||
bearer_token_env_var.as_ref(),
|
||||
)?;
|
||||
throw_if_set("stdio", "http_headers", http_headers.as_ref())?;
|
||||
throw_if_set("stdio", "env_http_headers", env_http_headers.as_ref())?;
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args: args.unwrap_or_default(),
|
||||
env,
|
||||
env_vars: env_vars.unwrap_or_default(),
|
||||
cwd,
|
||||
}
|
||||
}
|
||||
RawMcpServerConfig {
|
||||
url: Some(url),
|
||||
bearer_token,
|
||||
bearer_token_env_var,
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
..
|
||||
env_vars,
|
||||
cwd,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
startup_timeout_sec: _,
|
||||
tool_timeout_sec: _,
|
||||
startup_timeout_ms: _,
|
||||
enabled: _,
|
||||
} => {
|
||||
throw_if_set("streamable_http", "command", command.as_ref())?;
|
||||
throw_if_set("streamable_http", "args", args.as_ref())?;
|
||||
throw_if_set("streamable_http", "env", env.as_ref())?;
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token }
|
||||
throw_if_set("streamable_http", "env_vars", env_vars.as_ref())?;
|
||||
throw_if_set("streamable_http", "cwd", cwd.as_ref())?;
|
||||
throw_if_set("streamable_http", "bearer_token", bearer_token.as_ref())?;
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
}
|
||||
}
|
||||
_ => return Err(SerdeError::custom("invalid transport")),
|
||||
};
|
||||
@@ -117,10 +162,15 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
transport,
|
||||
startup_timeout_sec,
|
||||
tool_timeout_sec: raw.tool_timeout_sec,
|
||||
enabled: raw.enabled.unwrap_or_else(default_enabled),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const fn default_enabled() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")]
|
||||
pub enum McpServerTransportConfig {
|
||||
@@ -131,15 +181,25 @@ pub enum McpServerTransportConfig {
|
||||
args: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
env: Option<HashMap<String, String>>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
env_vars: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
cwd: Option<PathBuf>,
|
||||
},
|
||||
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http
|
||||
StreamableHttp {
|
||||
url: String,
|
||||
/// A plain text bearer token to use for authentication.
|
||||
/// This bearer token will be included in the HTTP request header as an `Authorization: Bearer <token>` header.
|
||||
/// This should be used with caution because it lives on disk in clear text.
|
||||
/// Name of the environment variable to read for an HTTP bearer token.
|
||||
/// When set, requests will include the token via `Authorization: Bearer <token>`.
|
||||
/// The actual secret value must be provided via the environment.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
bearer_token: Option<String>,
|
||||
bearer_token_env_var: Option<String>,
|
||||
/// Additional HTTP headers to include in requests to this server.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
/// HTTP headers where the value is sourced from an environment variable.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -301,6 +361,20 @@ pub struct Tui {
|
||||
pub notifications: Notifications,
|
||||
}
|
||||
|
||||
/// Settings for notices we display to users via the tui and app-server clients
|
||||
/// (primarily the Codex IDE extension). NOTE: these are different from
|
||||
/// notifications - notices are warnings, NUX screens, acknowledgements, etc.
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Notice {
|
||||
/// Tracks whether the user has acknowledged the full access warning prompt.
|
||||
pub hide_full_access_warning: Option<bool>,
|
||||
}
|
||||
|
||||
impl Notice {
|
||||
/// Used by `set_hide_full_access_warning` until config updates are refactored.
|
||||
pub(crate) const TABLE_KEY: &'static str = "notice";
|
||||
}
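A minimal test-style sketch of how this table deserializes, written in the same `toml::from_str` style as the MCP config tests below; it assumes `Notice` is in scope of the test module via `super::*` and sets the field explicitly:

```rust
#[cfg(test)]
mod notice_sketch {
    use super::*;

    #[test]
    fn deserialize_notice_with_acknowledged_full_access_warning() {
        // Illustrative only: in config.toml this lives under the table named by
        // `Notice::TABLE_KEY` ("notice"); here we deserialize just the table body.
        let notice: Notice = toml::from_str(
            r#"
            hide_full_access_warning = true
            "#,
        )
        .expect("should deserialize notice table");

        assert_eq!(notice.hide_full_access_warning, Some(true));
    }
}
```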
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct SandboxWorkspaceWrite {
|
||||
#[serde(default)]
|
||||
@@ -447,9 +521,12 @@ mod tests {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec![],
|
||||
env: None
|
||||
env: None,
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -467,9 +544,12 @@ mod tests {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec!["hello".to_string(), "world".to_string()],
|
||||
env: None
|
||||
env: None,
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -488,9 +568,69 @@ mod tests {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec!["hello".to_string(), "world".to_string()],
|
||||
env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())]))
|
||||
env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])),
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_stdio_command_server_config_with_env_vars() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
command = "echo"
|
||||
env_vars = ["FOO", "BAR"]
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize command config with env_vars");
|
||||
|
||||
assert_eq!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec![],
|
||||
env: None,
|
||||
env_vars: vec!["FOO".to_string(), "BAR".to_string()],
|
||||
cwd: None,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_stdio_command_server_config_with_cwd() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
command = "echo"
|
||||
cwd = "/tmp"
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize command config with cwd");
|
||||
|
||||
assert_eq!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec![],
|
||||
env: None,
|
||||
env_vars: Vec::new(),
|
||||
cwd: Some(PathBuf::from("/tmp")),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_disabled_server_config() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
command = "echo"
|
||||
enabled = false
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize disabled server config");
|
||||
|
||||
assert!(!cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -506,17 +646,20 @@ mod tests {
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: None
|
||||
bearer_token_env_var: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_streamable_http_server_config_with_bearer_token() {
|
||||
fn deserialize_streamable_http_server_config_with_env_var() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token = "secret"
|
||||
bearer_token_env_var = "GITHUB_TOKEN"
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize http config");
|
||||
@@ -525,7 +668,35 @@ mod tests {
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: Some("secret".to_string())
|
||||
bearer_token_env_var: Some("GITHUB_TOKEN".to_string()),
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_streamable_http_server_config_with_headers() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
url = "https://example.com/mcp"
|
||||
http_headers = { "X-Foo" = "bar" }
|
||||
env_http_headers = { "X-Token" = "TOKEN_ENV" }
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize http config with headers");
|
||||
|
||||
assert_eq!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token_env_var: None,
|
||||
http_headers: Some(HashMap::from([("X-Foo".to_string(), "bar".to_string())])),
|
||||
env_http_headers: Some(HashMap::from([(
|
||||
"X-Token".to_string(),
|
||||
"TOKEN_ENV".to_string()
|
||||
)])),
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -553,13 +724,37 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_rejects_bearer_token_for_stdio_transport() {
|
||||
fn deserialize_rejects_headers_for_stdio() {
|
||||
toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
command = "echo"
|
||||
http_headers = { "X-Foo" = "bar" }
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject http_headers for stdio transport");
|
||||
|
||||
toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
command = "echo"
|
||||
env_http_headers = { "X-Foo" = "BAR_ENV" }
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject env_http_headers for stdio transport");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_rejects_inline_bearer_token_field() {
|
||||
let err = toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
url = "https://example.com"
|
||||
bearer_token = "secret"
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject bearer token for stdio transport");
|
||||
.expect_err("should reject bearer_token field");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("bearer_token is not supported"),
|
||||
"unexpected error: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
use crate::truncate::truncate_middle;
|
||||
use codex_async_utils::CancelErr;
|
||||
use codex_protocol::ConversationId;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use reqwest::StatusCode;
|
||||
@@ -12,6 +14,9 @@ use tokio::task::JoinError;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, CodexErr>;
|
||||
|
||||
/// Limit UI error messages to a reasonable size while keeping useful context.
|
||||
const ERROR_MESSAGE_UI_MAX_BYTES: usize = 2 * 1024; // 2 KiB
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SandboxErr {
|
||||
/// Error from sandbox execution
|
||||
@@ -46,6 +51,9 @@ pub enum SandboxErr {
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum CodexErr {
|
||||
#[error("turn aborted")]
|
||||
TurnAborted,
|
||||
|
||||
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
|
||||
/// handshake has succeeded but **before** it finished emitting `response.completed`.
|
||||
///
|
||||
@@ -87,6 +95,12 @@ pub enum CodexErr {
|
||||
#[error("{0}")]
|
||||
UsageLimitReached(UsageLimitReachedError),
|
||||
|
||||
#[error("{0}")]
|
||||
ResponseStreamFailed(ResponseStreamFailed),
|
||||
|
||||
#[error("{0}")]
|
||||
ConnectionFailed(ConnectionFailedError),
|
||||
|
||||
#[error(
|
||||
"To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
|
||||
)]
|
||||
@@ -122,9 +136,6 @@ pub enum CodexErr {
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
@@ -143,6 +154,43 @@ pub enum CodexErr {
|
||||
EnvVar(EnvVarError),
|
||||
}
|
||||
|
||||
impl From<CancelErr> for CodexErr {
|
||||
fn from(_: CancelErr) -> Self {
|
||||
CodexErr::TurnAborted
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectionFailedError {
|
||||
pub source: reqwest::Error,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ConnectionFailedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Connection failed: {}", self.source)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ResponseStreamFailed {
|
||||
pub source: reqwest::Error,
|
||||
pub request_id: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ResponseStreamFailed {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Error while reading the server response: {}{}",
|
||||
self.source,
|
||||
self.request_id
|
||||
.as_ref()
|
||||
.map(|id| format!(", request id: {id}"))
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UnexpectedResponseError {
|
||||
pub status: StatusCode,
|
||||
@@ -304,21 +352,44 @@ impl CodexErr {
|
||||
}
|
||||
|
||||
pub fn get_error_message_ui(e: &CodexErr) -> String {
|
||||
match e {
|
||||
CodexErr::Sandbox(SandboxErr::Denied { output }) => output.stderr.text.clone(),
|
||||
let message = match e {
|
||||
CodexErr::Sandbox(SandboxErr::Denied { output }) => {
|
||||
let aggregated = output.aggregated_output.text.trim();
|
||||
if !aggregated.is_empty() {
|
||||
output.aggregated_output.text.clone()
|
||||
} else {
|
||||
let stderr = output.stderr.text.trim();
|
||||
let stdout = output.stdout.text.trim();
|
||||
match (stderr.is_empty(), stdout.is_empty()) {
|
||||
(false, false) => format!("{stderr}\n{stdout}"),
|
||||
(false, true) => output.stderr.text.clone(),
|
||||
(true, false) => output.stdout.text.clone(),
|
||||
(true, true) => format!(
|
||||
"command failed inside sandbox with exit code {}",
|
||||
output.exit_code
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
// Timeouts are not sandbox errors from a UX perspective; present them plainly
|
||||
CodexErr::Sandbox(SandboxErr::Timeout { output }) => format!(
|
||||
"error: command timed out after {} ms",
|
||||
output.duration.as_millis()
|
||||
),
|
||||
CodexErr::Sandbox(SandboxErr::Timeout { output }) => {
|
||||
format!(
|
||||
"error: command timed out after {} ms",
|
||||
output.duration.as_millis()
|
||||
)
|
||||
}
|
||||
_ => e.to_string(),
|
||||
}
|
||||
};
|
||||
|
||||
truncate_middle(&message, ERROR_MESSAGE_UI_MAX_BYTES).0
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::exec::StreamOutput;
|
||||
use codex_protocol::protocol::RateLimitWindow;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn rate_limit_snapshot() -> RateLimitSnapshot {
|
||||
RateLimitSnapshot {
|
||||
@@ -348,6 +419,73 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_denied_uses_aggregated_output_when_stderr_empty() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 77,
|
||||
stdout: StreamOutput::new(String::new()),
|
||||
stderr: StreamOutput::new(String::new()),
|
||||
aggregated_output: StreamOutput::new("aggregate detail".to_string()),
|
||||
duration: Duration::from_millis(10),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
});
|
||||
assert_eq!(get_error_message_ui(&err), "aggregate detail");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_denied_reports_both_streams_when_available() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 9,
|
||||
stdout: StreamOutput::new("stdout detail".to_string()),
|
||||
stderr: StreamOutput::new("stderr detail".to_string()),
|
||||
aggregated_output: StreamOutput::new(String::new()),
|
||||
duration: Duration::from_millis(10),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
});
|
||||
assert_eq!(get_error_message_ui(&err), "stderr detail\nstdout detail");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_denied_reports_stdout_when_no_stderr() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 11,
|
||||
stdout: StreamOutput::new("stdout only".to_string()),
|
||||
stderr: StreamOutput::new(String::new()),
|
||||
aggregated_output: StreamOutput::new(String::new()),
|
||||
duration: Duration::from_millis(8),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
});
|
||||
assert_eq!(get_error_message_ui(&err), "stdout only");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_denied_reports_exit_code_when_no_output_available() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 13,
|
||||
stdout: StreamOutput::new(String::new()),
|
||||
stderr: StreamOutput::new(String::new()),
|
||||
aggregated_output: StreamOutput::new(String::new()),
|
||||
duration: Duration::from_millis(5),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
});
|
||||
assert_eq!(
|
||||
get_error_message_ui(&err),
|
||||
"command failed inside sandbox with exit code 13"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_free_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
|
||||
@@ -177,7 +177,7 @@ pub async fn process_exec_tool_call(
|
||||
}));
|
||||
}
|
||||
|
||||
if exit_code != 0 && is_likely_sandbox_denied(sandbox_type, exit_code) {
|
||||
if is_likely_sandbox_denied(sandbox_type, &exec_output) {
|
||||
return Err(CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(exec_output),
|
||||
}));
|
||||
@@ -195,21 +195,57 @@ pub async fn process_exec_tool_call(
|
||||
/// We don't have a fully deterministic way to tell if our command failed
|
||||
/// because of the sandbox - a command in the user's zshrc file might hit an
|
||||
/// error, but the command itself might fail or succeed for other reasons.
|
||||
/// For now, we conservatively check for 'command not found' (exit code 127),
|
||||
/// and can add additional cases as necessary.
|
||||
fn is_likely_sandbox_denied(sandbox_type: SandboxType, exit_code: i32) -> bool {
|
||||
if sandbox_type == SandboxType::None {
|
||||
/// For now, we conservatively check for well-known command failure exit codes and
|
||||
/// also look for common sandbox denial keywords in the command output.
|
||||
fn is_likely_sandbox_denied(sandbox_type: SandboxType, exec_output: &ExecToolCallOutput) -> bool {
|
||||
if sandbox_type == SandboxType::None || exec_output.exit_code == 0 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Quick rejects: well-known non-sandbox shell exit codes
|
||||
// 127: command not found, 2: misuse of shell builtins
|
||||
if exit_code == 127 {
|
||||
// 2: misuse of shell builtins
|
||||
// 126: permission denied
|
||||
// 127: command not found
|
||||
const QUICK_REJECT_EXIT_CODES: [i32; 3] = [2, 126, 127];
|
||||
if QUICK_REJECT_EXIT_CODES.contains(&exec_output.exit_code) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// For all other cases, we assume the sandbox is the cause
|
||||
true
|
||||
const SANDBOX_DENIED_KEYWORDS: [&str; 6] = [
|
||||
"operation not permitted",
|
||||
"permission denied",
|
||||
"read-only file system",
|
||||
"seccomp",
|
||||
"sandbox",
|
||||
"landlock",
|
||||
];
|
||||
|
||||
if [
|
||||
&exec_output.stderr.text,
|
||||
&exec_output.stdout.text,
|
||||
&exec_output.aggregated_output.text,
|
||||
]
|
||||
.into_iter()
|
||||
.any(|section| {
|
||||
let lower = section.to_lowercase();
|
||||
SANDBOX_DENIED_KEYWORDS
|
||||
.iter()
|
||||
.any(|needle| lower.contains(needle))
|
||||
}) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
const SIGSYS_CODE: i32 = libc::SIGSYS;
|
||||
if sandbox_type == SandboxType::LinuxSeccomp
|
||||
&& exec_output.exit_code == EXIT_CODE_SIGNAL_BASE + SIGSYS_CODE
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -436,3 +472,77 @@ fn synthetic_exit_status(code: i32) -> ExitStatus {
|
||||
#[expect(clippy::unwrap_used)]
|
||||
std::process::ExitStatus::from_raw(code.try_into().unwrap())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
fn make_exec_output(
|
||||
exit_code: i32,
|
||||
stdout: &str,
|
||||
stderr: &str,
|
||||
aggregated: &str,
|
||||
) -> ExecToolCallOutput {
|
||||
ExecToolCallOutput {
|
||||
exit_code,
|
||||
stdout: StreamOutput::new(stdout.to_string()),
|
||||
stderr: StreamOutput::new(stderr.to_string()),
|
||||
aggregated_output: StreamOutput::new(aggregated.to_string()),
|
||||
duration: Duration::from_millis(1),
|
||||
timed_out: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_requires_keywords() {
|
||||
let output = make_exec_output(1, "", "", "");
|
||||
assert!(!is_likely_sandbox_denied(
|
||||
SandboxType::LinuxSeccomp,
|
||||
&output
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_identifies_keyword_in_stderr() {
|
||||
let output = make_exec_output(1, "", "Operation not permitted", "");
|
||||
assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_respects_quick_reject_exit_codes() {
|
||||
let output = make_exec_output(127, "", "command not found", "");
|
||||
assert!(!is_likely_sandbox_denied(
|
||||
SandboxType::LinuxSeccomp,
|
||||
&output
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_ignores_non_sandbox_mode() {
|
||||
let output = make_exec_output(1, "", "Operation not permitted", "");
|
||||
assert!(!is_likely_sandbox_denied(SandboxType::None, &output));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_uses_aggregated_output() {
|
||||
let output = make_exec_output(
|
||||
101,
|
||||
"",
|
||||
"",
|
||||
"cargo failed: Read-only file system when writing target",
|
||||
);
|
||||
assert!(is_likely_sandbox_denied(
|
||||
SandboxType::MacosSeatbelt,
|
||||
&output
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn sandbox_detection_flags_sigsys_exit_code() {
|
||||
let exit_code = EXIT_CODE_SIGNAL_BASE + libc::SIGSYS;
|
||||
let output = make_exec_output(exit_code, "", "", "");
|
||||
assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ use async_trait::async_trait;
|
||||
use crate::CODEX_APPLY_PATCH_ARG1;
|
||||
use crate::apply_patch::ApplyPatchExec;
|
||||
use crate::exec::ExecParams;
|
||||
use crate::executor::ExecutorConfig;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
|
||||
pub(crate) enum ExecutionMode {
|
||||
@@ -22,6 +23,7 @@ pub(crate) trait ExecutionBackend: Send + Sync {
|
||||
params: ExecParams,
|
||||
// Required for downcasting the apply_patch.
|
||||
mode: &ExecutionMode,
|
||||
config: &ExecutorConfig,
|
||||
) -> Result<ExecParams, FunctionCallError>;
|
||||
|
||||
fn stream_stdout(&self, _mode: &ExecutionMode) -> bool {
|
||||
@@ -47,6 +49,7 @@ impl ExecutionBackend for ShellBackend {
|
||||
&self,
|
||||
params: ExecParams,
|
||||
mode: &ExecutionMode,
|
||||
_config: &ExecutorConfig,
|
||||
) -> Result<ExecParams, FunctionCallError> {
|
||||
match mode {
|
||||
ExecutionMode::Shell => Ok(params),
|
||||
@@ -65,17 +68,22 @@ impl ExecutionBackend for ApplyPatchBackend {
|
||||
&self,
|
||||
params: ExecParams,
|
||||
mode: &ExecutionMode,
|
||||
config: &ExecutorConfig,
|
||||
) -> Result<ExecParams, FunctionCallError> {
|
||||
match mode {
|
||||
ExecutionMode::ApplyPatch(exec) => {
|
||||
let path_to_codex = env::current_exe()
|
||||
.ok()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(
|
||||
"failed to determine path to codex executable".to_string(),
|
||||
)
|
||||
})?;
|
||||
let path_to_codex = if let Some(exe_path) = &config.codex_exe {
|
||||
exe_path.to_string_lossy().to_string()
|
||||
} else {
|
||||
env::current_exe()
|
||||
.ok()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(
|
||||
"failed to determine path to codex executable".to_string(),
|
||||
)
|
||||
})?
|
||||
};
|
||||
|
||||
let patch = exec.action.patch.clone();
|
||||
Ok(ExecParams {
|
||||
|
||||
@@ -60,5 +60,9 @@ pub mod errors {
|
||||
pub(crate) fn rejection(msg: impl Into<String>) -> Self {
|
||||
FunctionCallError::RespondToModel(msg.into()).into()
|
||||
}
|
||||
|
||||
pub(crate) fn denied(msg: impl Into<String>) -> Self {
|
||||
FunctionCallError::Denied(msg.into()).into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::future::Future;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
@@ -30,19 +31,19 @@ use codex_otel::otel_event_manager::ToolDecisionSource;
|
||||
pub(crate) struct ExecutorConfig {
|
||||
pub(crate) sandbox_policy: SandboxPolicy,
|
||||
pub(crate) sandbox_cwd: PathBuf,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub(crate) codex_exe: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl ExecutorConfig {
|
||||
pub(crate) fn new(
|
||||
sandbox_policy: SandboxPolicy,
|
||||
sandbox_cwd: PathBuf,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
codex_exe: Option<PathBuf>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sandbox_policy,
|
||||
sandbox_cwd,
|
||||
codex_linux_sandbox_exe,
|
||||
codex_exe,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,23 +65,42 @@ impl Executor {
|
||||
|
||||
/// Updates the sandbox policy and working directory used for future
|
||||
/// executions without recreating the executor.
|
||||
pub(crate) fn update_environment(&self, sandbox_policy: SandboxPolicy, sandbox_cwd: PathBuf) {
|
||||
if let Ok(mut cfg) = self.config.write() {
|
||||
cfg.sandbox_policy = sandbox_policy;
|
||||
cfg.sandbox_cwd = sandbox_cwd;
|
||||
pub(crate) fn update_environment(
|
||||
&self,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
sandbox_cwd: PathBuf,
|
||||
) -> ExecutorConfig {
|
||||
match self.config.write() {
|
||||
Ok(mut cfg) => {
|
||||
cfg.sandbox_policy = sandbox_policy;
|
||||
cfg.sandbox_cwd = sandbox_cwd;
|
||||
cfg.clone()
|
||||
}
|
||||
Err(err) => {
|
||||
let mut cfg = err.into_inner();
|
||||
cfg.sandbox_policy = sandbox_policy;
|
||||
cfg.sandbox_cwd = sandbox_cwd;
|
||||
cfg.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
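The `Err` arm above relies on the standard poisoned-lock recovery idiom: `PoisonError::into_inner` hands back the guard even after another thread panicked while holding the lock, so configuration updates keep working. A minimal standalone sketch of that pattern (the `set_flag` helper is hypothetical and not part of the executor):

```rust
use std::sync::RwLock;

// Hypothetical helper demonstrating recovery from a poisoned RwLock.
fn set_flag(lock: &RwLock<bool>, value: bool) -> bool {
    let mut guard = match lock.write() {
        Ok(guard) => guard,
        // The data behind a poisoned lock is still usable; take the guard and continue.
        Err(poisoned) => poisoned.into_inner(),
    };
    *guard = value;
    *guard
}
```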
|
||||
/// Runs a prepared execution request end-to-end: prepares parameters, decides on
|
||||
/// sandbox placement (prompting the user when necessary), launches the command,
|
||||
/// and lets the backend post-process the final output.
|
||||
pub(crate) async fn run(
|
||||
pub(crate) async fn run<F, Fut>(
|
||||
&self,
|
||||
mut request: ExecutionRequest,
|
||||
session: &Session,
|
||||
approval_policy: AskForApproval,
|
||||
context: &ExecCommandContext,
|
||||
) -> Result<ExecToolCallOutput, ExecError> {
|
||||
config: ExecutorConfig,
|
||||
on_exec_begin: F,
|
||||
) -> Result<ExecToolCallOutput, ExecError>
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
Fut: Future<Output = ()>,
|
||||
{
|
||||
if matches!(request.mode, ExecutionMode::Shell) {
|
||||
request.params =
|
||||
maybe_translate_shell_command(request.params, session, request.use_shell_profile);
|
||||
@@ -94,17 +114,10 @@ impl Executor {
|
||||
None
|
||||
};
|
||||
request.params = backend
|
||||
.prepare(request.params, &request.mode)
|
||||
.prepare(request.params, &request.mode, &config)
|
||||
.map_err(ExecError::from)?;
|
||||
|
||||
// Step 2: Snapshot sandbox configuration so it stays stable for this run.
|
||||
let config = self
|
||||
.config
|
||||
.read()
|
||||
.map_err(|_| ExecError::rejection("executor config poisoned"))?
|
||||
.clone();
|
||||
|
||||
// Step 3: Decide sandbox placement, prompting for approval when needed.
|
||||
// Step 2: Decide sandbox placement, prompting for approval when needed.
|
||||
let sandbox_decision = select_sandbox(
|
||||
&request,
|
||||
approval_policy,
|
||||
@@ -119,8 +132,8 @@ impl Executor {
|
||||
if sandbox_decision.record_session_approval {
|
||||
self.approval_cache.insert(request.approval_command.clone());
|
||||
}
|
||||
|
||||
// Step 4: Launch the command within the chosen sandbox.
|
||||
on_exec_begin().await;
|
||||
// Step 3: Launch the command within the chosen sandbox.
|
||||
let first_attempt = self
|
||||
.spawn(
|
||||
request.params.clone(),
|
||||
@@ -130,7 +143,7 @@ impl Executor {
|
||||
)
|
||||
.await;
|
||||
|
||||
// Step 5: Handle sandbox outcomes, optionally escalating to an unsandboxed retry.
|
||||
// Step 4: Handle sandbox outcomes, optionally escalating to an unsandboxed retry.
|
||||
match first_attempt {
|
||||
Ok(output) => Ok(output),
|
||||
Err(CodexErr::Sandbox(SandboxErr::Timeout { output })) => {
|
||||
@@ -210,7 +223,7 @@ impl Executor {
|
||||
Ok(retry_output)
|
||||
}
|
||||
ReviewDecision::Denied | ReviewDecision::Abort => {
|
||||
Err(ExecError::rejection("exec command rejected by user"))
|
||||
Err(ExecError::denied("exec command rejected by user"))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -227,7 +240,7 @@ impl Executor {
|
||||
sandbox,
|
||||
&config.sandbox_policy,
|
||||
&config.sandbox_cwd,
|
||||
&config.codex_linux_sandbox_exe,
|
||||
&config.codex_exe,
|
||||
stdout_stream,
|
||||
)
|
||||
.await
|
||||
@@ -301,7 +314,8 @@ pub(crate) fn normalize_exec_result(
|
||||
}
|
||||
Err(err) => {
|
||||
let message = match err {
|
||||
ExecError::Function(FunctionCallError::RespondToModel(msg)) => msg.clone(),
|
||||
ExecError::Function(FunctionCallError::RespondToModel(msg))
|
||||
| ExecError::Function(FunctionCallError::Denied(msg)) => msg.clone(),
|
||||
ExecError::Codex(e) => get_error_message_ui(e),
|
||||
err => err.to_string(),
|
||||
};
|
||||
@@ -380,6 +394,23 @@ mod tests {
|
||||
assert_eq!(message, "failed in sandbox: sandbox stderr");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_failure_message_falls_back_to_aggregated_output() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 101,
|
||||
stdout: StreamOutput::new(String::new()),
|
||||
stderr: StreamOutput::new(String::new()),
|
||||
aggregated_output: StreamOutput::new("aggregate text".to_string()),
|
||||
duration: Duration::from_millis(10),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
};
|
||||
let message = sandbox_failure_message(err);
|
||||
assert_eq!(message, "failed in sandbox: aggregate text");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_function_error_synthesizes_payload() {
|
||||
let err = FunctionCallError::RespondToModel("boom".to_string());
|
||||
|
||||
@@ -149,7 +149,7 @@ async fn select_shell_sandbox(
|
||||
ReviewDecision::Approved => Ok(SandboxDecision::user_override(false)),
|
||||
ReviewDecision::ApprovedForSession => Ok(SandboxDecision::user_override(true)),
|
||||
ReviewDecision::Denied | ReviewDecision::Abort => {
|
||||
Err(ExecError::rejection("exec command rejected by user"))
|
||||
Err(ExecError::denied("exec command rejected by user"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
codex-rs/core/src/features.rs (new file, 258 lines)
@@ -0,0 +1,258 @@
|
||||
//! Centralized feature flags and metadata.
|
||||
//!
|
||||
//! This module defines a small set of toggles that gate experimental and
|
||||
//! optional behavior across the codebase. Instead of wiring individual
|
||||
//! booleans through multiple types, call sites consult a single `Features`
|
||||
//! container attached to `Config`.
|
||||
|
||||
use crate::config::ConfigToml;
|
||||
use crate::config_profile::ConfigProfile;
|
||||
use serde::Deserialize;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
mod legacy;
|
||||
pub(crate) use legacy::LegacyFeatureToggles;
|
||||
|
||||
/// High-level lifecycle stage for a feature.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Stage {
|
||||
Experimental,
|
||||
Beta,
|
||||
Stable,
|
||||
Deprecated,
|
||||
Removed,
|
||||
}
|
||||
|
||||
/// Unique features toggled via configuration.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub enum Feature {
|
||||
/// Use the single unified PTY-backed exec tool.
|
||||
UnifiedExec,
|
||||
/// Use the streamable exec-command/write-stdin tool pair.
|
||||
StreamableShell,
|
||||
/// Use the official Rust MCP client (rmcp).
|
||||
RmcpClient,
|
||||
/// Include the plan tool.
|
||||
PlanTool,
|
||||
/// Include the freeform apply_patch tool.
|
||||
ApplyPatchFreeform,
|
||||
/// Include the view_image tool.
|
||||
ViewImageTool,
|
||||
/// Allow the model to request web searches.
|
||||
WebSearchRequest,
|
||||
/// Automatically approve all approval requests from the harness.
|
||||
ApproveAll,
|
||||
}
|
||||
|
||||
impl Feature {
|
||||
pub fn key(self) -> &'static str {
|
||||
self.info().key
|
||||
}
|
||||
|
||||
pub fn stage(self) -> Stage {
|
||||
self.info().stage
|
||||
}
|
||||
|
||||
pub fn default_enabled(self) -> bool {
|
||||
        self.info().default_enabled
    }

    fn info(self) -> &'static FeatureSpec {
        FEATURES
            .iter()
            .find(|spec| spec.id == self)
            .unwrap_or_else(|| unreachable!("missing FeatureSpec for {:?}", self))
    }
}

/// Holds the effective set of enabled features.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Features {
    enabled: BTreeSet<Feature>,
}

#[derive(Debug, Clone, Default)]
pub struct FeatureOverrides {
    pub include_plan_tool: Option<bool>,
    pub include_apply_patch_tool: Option<bool>,
    pub include_view_image_tool: Option<bool>,
    pub web_search_request: Option<bool>,
}

impl FeatureOverrides {
    fn apply(self, features: &mut Features) {
        LegacyFeatureToggles {
            include_plan_tool: self.include_plan_tool,
            include_apply_patch_tool: self.include_apply_patch_tool,
            include_view_image_tool: self.include_view_image_tool,
            tools_web_search: self.web_search_request,
            ..Default::default()
        }
        .apply(features);
    }
}

impl Features {
    /// Starts with built-in defaults.
    pub fn with_defaults() -> Self {
        let mut set = BTreeSet::new();
        for spec in FEATURES {
            if spec.default_enabled {
                set.insert(spec.id);
            }
        }
        Self { enabled: set }
    }

    pub fn enabled(&self, f: Feature) -> bool {
        self.enabled.contains(&f)
    }

    pub fn enable(&mut self, f: Feature) {
        self.enabled.insert(f);
    }

    pub fn disable(&mut self, f: Feature) {
        self.enabled.remove(&f);
    }

    /// Apply a table of key -> bool toggles (e.g. from TOML).
    pub fn apply_map(&mut self, m: &BTreeMap<String, bool>) {
        for (k, v) in m {
            match feature_for_key(k) {
                Some(feat) => {
                    if *v {
                        self.enable(feat);
                    } else {
                        self.disable(feat);
                    }
                }
                None => {
                    tracing::warn!("unknown feature key in config: {k}");
                }
            }
        }
    }

    pub fn from_config(
        cfg: &ConfigToml,
        config_profile: &ConfigProfile,
        overrides: FeatureOverrides,
    ) -> Self {
        let mut features = Features::with_defaults();

        let base_legacy = LegacyFeatureToggles {
            experimental_use_freeform_apply_patch: cfg.experimental_use_freeform_apply_patch,
            experimental_use_exec_command_tool: cfg.experimental_use_exec_command_tool,
            experimental_use_unified_exec_tool: cfg.experimental_use_unified_exec_tool,
            experimental_use_rmcp_client: cfg.experimental_use_rmcp_client,
            tools_web_search: cfg.tools.as_ref().and_then(|t| t.web_search),
            tools_view_image: cfg.tools.as_ref().and_then(|t| t.view_image),
            ..Default::default()
        };
        base_legacy.apply(&mut features);

        if let Some(base_features) = cfg.features.as_ref() {
            features.apply_map(&base_features.entries);
        }

        let profile_legacy = LegacyFeatureToggles {
            include_plan_tool: config_profile.include_plan_tool,
            include_apply_patch_tool: config_profile.include_apply_patch_tool,
            include_view_image_tool: config_profile.include_view_image_tool,
            experimental_use_freeform_apply_patch: config_profile
                .experimental_use_freeform_apply_patch,
            experimental_use_exec_command_tool: config_profile.experimental_use_exec_command_tool,
            experimental_use_unified_exec_tool: config_profile.experimental_use_unified_exec_tool,
            experimental_use_rmcp_client: config_profile.experimental_use_rmcp_client,
            tools_web_search: config_profile.tools_web_search,
            tools_view_image: config_profile.tools_view_image,
        };
        profile_legacy.apply(&mut features);
        if let Some(profile_features) = config_profile.features.as_ref() {
            features.apply_map(&profile_features.entries);
        }

        overrides.apply(&mut features);

        features
    }
}
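A minimal sketch (not part of the diff) of how a `[features]` table flows through `apply_map`. The key strings are the canonical `FeatureSpec::key` values from the registry below; the chosen values are illustrative only.

use std::collections::BTreeMap;

let mut features = Features::with_defaults();
let mut table = BTreeMap::new();
table.insert("web_search_request".to_string(), true);
table.insert("view_image_tool".to_string(), false);
features.apply_map(&table);
// The table overrides the built-in defaults in place.
assert!(features.enabled(Feature::WebSearchRequest));
assert!(!features.enabled(Feature::ViewImageTool));

In `from_config`, the same mechanism is applied in layers: defaults, then base-config legacy toggles, the base `[features]` table, profile legacy toggles, the profile `[features]` table, and finally CLI/API overrides, so later layers win.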
|
||||
/// Keys accepted in `[features]` tables.
|
||||
fn feature_for_key(key: &str) -> Option<Feature> {
|
||||
for spec in FEATURES {
|
||||
if spec.key == key {
|
||||
return Some(spec.id);
|
||||
}
|
||||
}
|
||||
legacy::feature_for_key(key)
|
||||
}
|
||||
|
||||
/// Deserializable features table for TOML.
|
||||
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
|
||||
pub struct FeaturesToml {
|
||||
#[serde(flatten)]
|
||||
pub entries: BTreeMap<String, bool>,
|
||||
}
|
||||
|
||||
/// Single, easy-to-read registry of all feature definitions.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct FeatureSpec {
|
||||
pub id: Feature,
|
||||
pub key: &'static str,
|
||||
pub stage: Stage,
|
||||
pub default_enabled: bool,
|
||||
}
|
||||
|
||||
pub const FEATURES: &[FeatureSpec] = &[
|
||||
FeatureSpec {
|
||||
id: Feature::UnifiedExec,
|
||||
key: "unified_exec",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::StreamableShell,
|
||||
key: "streamable_shell",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::RmcpClient,
|
||||
key: "rmcp_client",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::PlanTool,
|
||||
key: "plan_tool",
|
||||
stage: Stage::Stable,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::ApplyPatchFreeform,
|
||||
key: "apply_patch_freeform",
|
||||
stage: Stage::Beta,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::ViewImageTool,
|
||||
key: "view_image_tool",
|
||||
stage: Stage::Stable,
|
||||
default_enabled: true,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::WebSearchRequest,
|
||||
key: "web_search_request",
|
||||
stage: Stage::Stable,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::ApproveAll,
|
||||
key: "approve_all",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: false,
|
||||
},
|
||||
];
|
||||
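Given this registry, `Features::with_defaults()` enables only the entries marked `default_enabled: true`, which is currently just the view-image tool. A quick check (a sketch, not from the diff):

let defaults = Features::with_defaults();
assert!(defaults.enabled(Feature::ViewImageTool));
assert!(!defaults.enabled(Feature::UnifiedExec));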
codex-rs/core/src/features/legacy.rs (new file, 158 lines)
@@ -0,0 +1,158 @@
use super::Feature;
use super::Features;
use tracing::info;

#[derive(Clone, Copy)]
struct Alias {
    legacy_key: &'static str,
    feature: Feature,
}

const ALIASES: &[Alias] = &[
    Alias { legacy_key: "experimental_use_unified_exec_tool", feature: Feature::UnifiedExec },
    Alias { legacy_key: "experimental_use_exec_command_tool", feature: Feature::StreamableShell },
    Alias { legacy_key: "experimental_use_rmcp_client", feature: Feature::RmcpClient },
    Alias { legacy_key: "experimental_use_freeform_apply_patch", feature: Feature::ApplyPatchFreeform },
    Alias { legacy_key: "include_apply_patch_tool", feature: Feature::ApplyPatchFreeform },
    Alias { legacy_key: "include_plan_tool", feature: Feature::PlanTool },
    Alias { legacy_key: "include_view_image_tool", feature: Feature::ViewImageTool },
    Alias { legacy_key: "web_search", feature: Feature::WebSearchRequest },
];

pub(crate) fn feature_for_key(key: &str) -> Option<Feature> {
    ALIASES
        .iter()
        .find(|alias| alias.legacy_key == key)
        .map(|alias| {
            log_alias(alias.legacy_key, alias.feature);
            alias.feature
        })
}

#[derive(Debug, Default)]
pub struct LegacyFeatureToggles {
    pub include_plan_tool: Option<bool>,
    pub include_apply_patch_tool: Option<bool>,
    pub include_view_image_tool: Option<bool>,
    pub experimental_use_freeform_apply_patch: Option<bool>,
    pub experimental_use_exec_command_tool: Option<bool>,
    pub experimental_use_unified_exec_tool: Option<bool>,
    pub experimental_use_rmcp_client: Option<bool>,
    pub tools_web_search: Option<bool>,
    pub tools_view_image: Option<bool>,
}

impl LegacyFeatureToggles {
    pub fn apply(self, features: &mut Features) {
        set_if_some(features, Feature::PlanTool, self.include_plan_tool, "include_plan_tool");
        set_if_some(features, Feature::ApplyPatchFreeform, self.include_apply_patch_tool, "include_apply_patch_tool");
        set_if_some(features, Feature::ApplyPatchFreeform, self.experimental_use_freeform_apply_patch, "experimental_use_freeform_apply_patch");
        set_if_some(features, Feature::StreamableShell, self.experimental_use_exec_command_tool, "experimental_use_exec_command_tool");
        set_if_some(features, Feature::UnifiedExec, self.experimental_use_unified_exec_tool, "experimental_use_unified_exec_tool");
        set_if_some(features, Feature::RmcpClient, self.experimental_use_rmcp_client, "experimental_use_rmcp_client");
        set_if_some(features, Feature::WebSearchRequest, self.tools_web_search, "tools.web_search");
        set_if_some(features, Feature::ViewImageTool, self.include_view_image_tool, "include_view_image_tool");
        set_if_some(features, Feature::ViewImageTool, self.tools_view_image, "tools.view_image");
    }
}

fn set_if_some(
    features: &mut Features,
    feature: Feature,
    maybe_value: Option<bool>,
    alias_key: &'static str,
) {
    if let Some(enabled) = maybe_value {
        set_feature(features, feature, enabled);
        log_alias(alias_key, feature);
    }
}

fn set_feature(features: &mut Features, feature: Feature, enabled: bool) {
    if enabled {
        features.enable(feature);
    } else {
        features.disable(feature);
    }
}

fn log_alias(alias: &str, feature: Feature) {
    let canonical = feature.key();
    if alias == canonical {
        return;
    }
    info!(
        %alias,
        canonical,
        "legacy feature toggle detected; prefer `[features].{canonical}`"
    );
}
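A quick sketch (not from the diff) of the alias mapping in practice: a legacy key resolves to the same feature as its canonical `[features]` key, and `log_alias` nudges users toward the new spelling.

assert_eq!(feature_for_key("web_search"), Some(Feature::WebSearchRequest));
assert_eq!(feature_for_key("include_plan_tool"), Some(Feature::PlanTool));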
@@ -4,6 +4,8 @@ use thiserror::Error;
|
||||
pub enum FunctionCallError {
|
||||
#[error("{0}")]
|
||||
RespondToModel(String),
|
||||
#[error("{0}")]
|
||||
Denied(String),
|
||||
#[error("LocalShellCall without call_id or id")]
|
||||
MissingLocalShellCallId,
|
||||
#[error("Fatal error: {0}")]
|
||||
|
||||
@@ -13,7 +13,6 @@ mod client;
|
||||
mod client_common;
|
||||
pub mod codex;
|
||||
mod codex_conversation;
|
||||
pub mod token_data;
|
||||
pub use codex_conversation::CodexConversation;
|
||||
mod command_safety;
|
||||
pub mod config;
|
||||
@@ -29,14 +28,17 @@ pub mod exec;
|
||||
mod exec_command;
|
||||
pub mod exec_env;
|
||||
pub mod executor;
|
||||
pub mod features;
|
||||
mod flags;
|
||||
pub mod git_info;
|
||||
pub mod landlock;
|
||||
pub mod mcp;
|
||||
mod mcp_connection_manager;
|
||||
mod mcp_tool_call;
|
||||
mod message_history;
|
||||
mod model_provider_info;
|
||||
pub mod parse_command;
|
||||
pub mod token_data;
|
||||
mod truncate;
|
||||
mod unified_exec;
|
||||
mod user_instructions;
|
||||
@@ -105,5 +107,4 @@ pub use codex_protocol::models::LocalShellExecAction;
|
||||
pub use codex_protocol::models::LocalShellStatus;
|
||||
pub use codex_protocol::models::ReasoningItemContent;
|
||||
pub use codex_protocol::models::ResponseItem;
|
||||
|
||||
pub mod otel_init;
|
||||
|
||||
codex-rs/core/src/mcp/auth.rs (new file, 62 lines)
@@ -0,0 +1,62 @@
use std::collections::HashMap;

use anyhow::Result;
use codex_protocol::protocol::McpAuthStatus;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use codex_rmcp_client::determine_streamable_http_auth_status;
use futures::future::join_all;
use tracing::warn;

use crate::config_types::McpServerConfig;
use crate::config_types::McpServerTransportConfig;

pub async fn compute_auth_statuses<'a, I>(
    servers: I,
    store_mode: OAuthCredentialsStoreMode,
) -> HashMap<String, McpAuthStatus>
where
    I: IntoIterator<Item = (&'a String, &'a McpServerConfig)>,
{
    let futures = servers.into_iter().map(|(name, config)| {
        let name = name.clone();
        let config = config.clone();
        async move {
            let status = match compute_auth_status(&name, &config, store_mode).await {
                Ok(status) => status,
                Err(error) => {
                    warn!("failed to determine auth status for MCP server `{name}`: {error:?}");
                    McpAuthStatus::Unsupported
                }
            };
            (name, status)
        }
    });

    join_all(futures).await.into_iter().collect()
}

async fn compute_auth_status(
    server_name: &str,
    config: &McpServerConfig,
    store_mode: OAuthCredentialsStoreMode,
) -> Result<McpAuthStatus> {
    match &config.transport {
        McpServerTransportConfig::Stdio { .. } => Ok(McpAuthStatus::Unsupported),
        McpServerTransportConfig::StreamableHttp {
            url,
            bearer_token_env_var,
            http_headers,
            env_http_headers,
        } => {
            determine_streamable_http_auth_status(
                server_name,
                url,
                bearer_token_env_var.as_deref(),
                http_headers.clone(),
                env_http_headers.clone(),
                store_mode,
            )
            .await
        }
    }
}
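A minimal usage sketch (not part of the diff; it assumes an existing `servers: HashMap<String, McpServerConfig>` and a `store_mode` value, and the server name "docs" is hypothetical):

let statuses = compute_auth_statuses(servers.iter(), store_mode).await;
// Stdio servers, and servers whose status could not be determined,
// come back as McpAuthStatus::Unsupported.
let docs_status = statuses.get("docs");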
codex-rs/core/src/mcp/mod.rs (new file, 1 line)
@@ -0,0 +1 @@
pub mod auth;
@@ -8,7 +8,9 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -16,9 +18,18 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_mcp_client::McpClient;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
use mcp_types::ListResourceTemplatesRequestParams;
|
||||
use mcp_types::ListResourceTemplatesResult;
|
||||
use mcp_types::ListResourcesRequestParams;
|
||||
use mcp_types::ListResourcesResult;
|
||||
use mcp_types::ReadResourceRequestParams;
|
||||
use mcp_types::ReadResourceResult;
|
||||
use mcp_types::Resource;
|
||||
use mcp_types::ResourceTemplate;
|
||||
use mcp_types::Tool;
|
||||
|
||||
use serde_json::json;
|
||||
@@ -100,34 +111,51 @@ enum McpClientAdapter {
|
||||
}
|
||||
|
||||
impl McpClientAdapter {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn new_stdio_client(
|
||||
use_rmcp_client: bool,
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
env_vars: Vec<String>,
|
||||
cwd: Option<PathBuf>,
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
) -> Result<Self> {
|
||||
if use_rmcp_client {
|
||||
let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?);
|
||||
let client =
|
||||
Arc::new(RmcpClient::new_stdio_client(program, args, env, &env_vars, cwd).await?);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Rmcp(client))
|
||||
} else {
|
||||
let client = Arc::new(McpClient::new_stdio_client(program, args, env).await?);
|
||||
let client =
|
||||
Arc::new(McpClient::new_stdio_client(program, args, env, &env_vars, cwd).await?);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Legacy(client))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn new_streamable_http_client(
|
||||
server_name: String,
|
||||
url: String,
|
||||
bearer_token: Option<String>,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Self> {
|
||||
let client = Arc::new(
|
||||
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token).await?,
|
||||
RmcpClient::new_streamable_http_client(
|
||||
&server_name,
|
||||
&url,
|
||||
bearer_token,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Rmcp(client))
|
||||
@@ -144,6 +172,47 @@ impl McpClientAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resources(
|
||||
&self,
|
||||
params: Option<mcp_types::ListResourcesRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::ListResourcesResult> {
|
||||
match self {
|
||||
McpClientAdapter::Legacy(_) => Ok(ListResourcesResult {
|
||||
next_cursor: None,
|
||||
resources: Vec::new(),
|
||||
}),
|
||||
McpClientAdapter::Rmcp(client) => client.list_resources(params, timeout).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_resource(
|
||||
&self,
|
||||
params: mcp_types::ReadResourceRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::ReadResourceResult> {
|
||||
match self {
|
||||
McpClientAdapter::Legacy(_) => Err(anyhow!(
|
||||
"resources/read is not supported by legacy MCP clients"
|
||||
)),
|
||||
McpClientAdapter::Rmcp(client) => client.read_resource(params, timeout).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resource_templates(
|
||||
&self,
|
||||
params: Option<mcp_types::ListResourceTemplatesRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::ListResourceTemplatesResult> {
|
||||
match self {
|
||||
McpClientAdapter::Legacy(_) => Ok(ListResourceTemplatesResult {
|
||||
next_cursor: None,
|
||||
resource_templates: Vec::new(),
|
||||
}),
|
||||
McpClientAdapter::Rmcp(client) => client.list_resource_templates(params, timeout).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn call_tool(
|
||||
&self,
|
||||
name: String,
|
||||
@@ -182,6 +251,7 @@ impl McpConnectionManager {
|
||||
pub async fn new(
|
||||
mcp_servers: HashMap<String, McpServerConfig>,
|
||||
use_rmcp_client: bool,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<(Self, ClientStartErrors)> {
|
||||
// Early exit if no servers are configured.
|
||||
if mcp_servers.is_empty() {
|
||||
@@ -202,9 +272,21 @@ impl McpConnectionManager {
|
||||
continue;
|
||||
}
|
||||
|
||||
if !cfg.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
|
||||
let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
|
||||
|
||||
let resolved_bearer_token = match &cfg.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
bearer_token_env_var,
|
||||
..
|
||||
} => resolve_bearer_token(&server_name, bearer_token_env_var.as_deref()),
|
||||
_ => Ok(None),
|
||||
};
|
||||
|
||||
join_set.spawn(async move {
|
||||
let McpServerConfig { transport, .. } = cfg;
|
||||
let params = mcp_types::InitializeRequestParams {
|
||||
@@ -229,7 +311,13 @@ impl McpConnectionManager {
|
||||
};
|
||||
|
||||
let client = match transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} => {
|
||||
let command_os: OsString = command.into();
|
||||
let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
|
||||
McpClientAdapter::new_stdio_client(
|
||||
@@ -237,18 +325,28 @@ impl McpConnectionManager {
|
||||
command_os,
|
||||
args_os,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
params,
|
||||
startup_timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
..
|
||||
} => {
|
||||
McpClientAdapter::new_streamable_http_client(
|
||||
server_name.clone(),
|
||||
url,
|
||||
bearer_token,
|
||||
resolved_bearer_token.unwrap_or_default(),
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
params,
|
||||
startup_timeout,
|
||||
store_mode,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -300,7 +398,7 @@ impl McpConnectionManager {
|
||||
Ok((Self { clients, tools }, errors))
|
||||
}
|
||||
|
||||
/// Returns a single map that contains **all** tools. Each key is the
|
||||
/// Returns a single map that contains all tools. Each key is the
|
||||
/// fully-qualified name for the tool.
|
||||
pub fn list_all_tools(&self) -> HashMap<String, Tool> {
|
||||
self.tools
|
||||
@@ -309,6 +407,133 @@ impl McpConnectionManager {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns a single map that contains all resources. Each key is the
|
||||
/// server name and the value is a vector of resources.
|
||||
pub async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for (server_name, managed_client) in &self.clients {
|
||||
let server_name_cloned = server_name.clone();
|
||||
let client_clone = managed_client.client.clone();
|
||||
let timeout = managed_client.tool_timeout;
|
||||
|
||||
join_set.spawn(async move {
|
||||
let mut collected: Vec<Resource> = Vec::new();
|
||||
let mut cursor: Option<String> = None;
|
||||
|
||||
loop {
|
||||
let params = cursor.as_ref().map(|next| ListResourcesRequestParams {
|
||||
cursor: Some(next.clone()),
|
||||
});
|
||||
let response = match client_clone.list_resources(params, timeout).await {
|
||||
Ok(result) => result,
|
||||
Err(err) => return (server_name_cloned, Err(err)),
|
||||
};
|
||||
|
||||
collected.extend(response.resources);
|
||||
|
||||
match response.next_cursor {
|
||||
Some(next) => {
|
||||
if cursor.as_ref() == Some(&next) {
|
||||
return (
|
||||
server_name_cloned,
|
||||
Err(anyhow!("resources/list returned duplicate cursor")),
|
||||
);
|
||||
}
|
||||
cursor = Some(next);
|
||||
}
|
||||
None => return (server_name_cloned, Ok(collected)),
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let mut aggregated: HashMap<String, Vec<Resource>> = HashMap::new();
|
||||
|
||||
while let Some(join_res) = join_set.join_next().await {
|
||||
match join_res {
|
||||
Ok((server_name, Ok(resources))) => {
|
||||
aggregated.insert(server_name, resources);
|
||||
}
|
||||
Ok((server_name, Err(err))) => {
|
||||
warn!("Failed to list resources for MCP server '{server_name}': {err:#}");
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Task panic when listing resources for MCP server: {err:#}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
aggregated
|
||||
}
|
||||
|
||||
/// Returns a single map that contains all resource templates. Each key is the
|
||||
/// server name and the value is a vector of resource templates.
|
||||
pub async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for (server_name, managed_client) in &self.clients {
|
||||
let server_name_cloned = server_name.clone();
|
||||
let client_clone = managed_client.client.clone();
|
||||
let timeout = managed_client.tool_timeout;
|
||||
|
||||
join_set.spawn(async move {
|
||||
let mut collected: Vec<ResourceTemplate> = Vec::new();
|
||||
let mut cursor: Option<String> = None;
|
||||
|
||||
loop {
|
||||
let params = cursor
|
||||
.as_ref()
|
||||
.map(|next| ListResourceTemplatesRequestParams {
|
||||
cursor: Some(next.clone()),
|
||||
});
|
||||
let response = match client_clone.list_resource_templates(params, timeout).await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => return (server_name_cloned, Err(err)),
|
||||
};
|
||||
|
||||
collected.extend(response.resource_templates);
|
||||
|
||||
match response.next_cursor {
|
||||
Some(next) => {
|
||||
if cursor.as_ref() == Some(&next) {
|
||||
return (
|
||||
server_name_cloned,
|
||||
Err(anyhow!(
|
||||
"resources/templates/list returned duplicate cursor"
|
||||
)),
|
||||
);
|
||||
}
|
||||
cursor = Some(next);
|
||||
}
|
||||
None => return (server_name_cloned, Ok(collected)),
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let mut aggregated: HashMap<String, Vec<ResourceTemplate>> = HashMap::new();
|
||||
|
||||
while let Some(join_res) = join_set.join_next().await {
|
||||
match join_res {
|
||||
Ok((server_name, Ok(templates))) => {
|
||||
aggregated.insert(server_name, templates);
|
||||
}
|
||||
Ok((server_name, Err(err))) => {
|
||||
warn!(
|
||||
"Failed to list resource templates for MCP server '{server_name}': {err:#}"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Task panic when listing resource templates for MCP server: {err:#}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
aggregated
|
||||
}
|
||||
|
||||
/// Invoke the tool indicated by the (server, tool) pair.
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
@@ -320,7 +545,7 @@ impl McpConnectionManager {
|
||||
.clients
|
||||
.get(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
|
||||
let client = managed.client.clone();
|
||||
let client = &managed.client;
|
||||
let timeout = managed.tool_timeout;
|
||||
|
||||
client
|
||||
@@ -329,6 +554,64 @@ impl McpConnectionManager {
|
||||
.with_context(|| format!("tool call failed for `{server}/{tool}`"))
|
||||
}
|
||||
|
||||
/// List resources from the specified server.
|
||||
pub async fn list_resources(
|
||||
&self,
|
||||
server: &str,
|
||||
params: Option<ListResourcesRequestParams>,
|
||||
) -> Result<ListResourcesResult> {
|
||||
let managed = self
|
||||
.clients
|
||||
.get(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
|
||||
let client = managed.client.clone();
|
||||
let timeout = managed.tool_timeout;
|
||||
|
||||
client
|
||||
.list_resources(params, timeout)
|
||||
.await
|
||||
.with_context(|| format!("resources/list failed for `{server}`"))
|
||||
}
|
||||
|
||||
/// List resource templates from the specified server.
|
||||
pub async fn list_resource_templates(
|
||||
&self,
|
||||
server: &str,
|
||||
params: Option<ListResourceTemplatesRequestParams>,
|
||||
) -> Result<ListResourceTemplatesResult> {
|
||||
let managed = self
|
||||
.clients
|
||||
.get(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
|
||||
let client = managed.client.clone();
|
||||
let timeout = managed.tool_timeout;
|
||||
|
||||
client
|
||||
.list_resource_templates(params, timeout)
|
||||
.await
|
||||
.with_context(|| format!("resources/templates/list failed for `{server}`"))
|
||||
}
|
||||
|
||||
/// Read a resource from the specified server.
|
||||
pub async fn read_resource(
|
||||
&self,
|
||||
server: &str,
|
||||
params: ReadResourceRequestParams,
|
||||
) -> Result<ReadResourceResult> {
|
||||
let managed = self
|
||||
.clients
|
||||
.get(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
|
||||
let client = managed.client.clone();
|
||||
let timeout = managed.tool_timeout;
|
||||
let uri = params.uri.clone();
|
||||
|
||||
client
|
||||
.read_resource(params, timeout)
|
||||
.await
|
||||
.with_context(|| format!("resources/read failed for `{server}` ({uri})"))
|
||||
}
|
||||
|
||||
pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
|
||||
self.tools
|
||||
.get(tool_name)
|
||||
@@ -336,8 +619,35 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_bearer_token(
|
||||
server_name: &str,
|
||||
bearer_token_env_var: Option<&str>,
|
||||
) -> Result<Option<String>> {
|
||||
let Some(env_var) = bearer_token_env_var else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
match env::var(env_var) {
|
||||
Ok(value) => {
|
||||
if value.is_empty() {
|
||||
Err(anyhow!(
|
||||
"Environment variable {env_var} for MCP server '{server_name}' is empty"
|
||||
))
|
||||
} else {
|
||||
Ok(Some(value))
|
||||
}
|
||||
}
|
||||
Err(env::VarError::NotPresent) => Err(anyhow!(
|
||||
"Environment variable {env_var} for MCP server '{server_name}' is not set"
|
||||
)),
|
||||
Err(env::VarError::NotUnicode(_)) => Err(anyhow!(
|
||||
"Environment variable {env_var} for MCP server '{server_name}' contains invalid Unicode"
|
||||
)),
|
||||
}
|
||||
}
|
||||
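A small sketch (not from the diff) of how a caller might consume `resolve_bearer_token`; the server name and environment variable name are hypothetical:

match resolve_bearer_token("docs", Some("DOCS_MCP_TOKEN")) {
    Ok(Some(token)) => { /* pass `token` to the streamable HTTP client */ }
    Ok(None) => { /* no bearer_token_env_var configured for this server */ }
    // Unset, empty, or non-Unicode variables are surfaced as errors rather
    // than being silently ignored.
    Err(err) => tracing::warn!("bearer token misconfigured: {err:#}"),
}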
|
||||
/// Query every server for its available tools and return a single map that
|
||||
/// contains **all** tools. Each key is the fully-qualified name for the tool.
|
||||
/// contains all tools. Each key is the fully-qualified name for the tool.
|
||||
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
|
||||
@@ -58,7 +58,10 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
let result = sess
|
||||
.call_tool(&server, &tool_name, arguments_value.clone())
|
||||
.await
|
||||
.map_err(|e| format!("tool call error: {e}"));
|
||||
.map_err(|e| format!("tool call error: {e:?}"));
|
||||
if let Err(e) = &result {
|
||||
tracing::warn!("MCP tool call error: {e:?}");
|
||||
}
|
||||
let tool_call_end_event = EventMsg::McpToolCallEnd(McpToolCallEndEvent {
|
||||
call_id: call_id.clone(),
|
||||
invocation,
|
||||
|
||||
@@ -119,9 +119,10 @@ pub fn find_family_for_model(mut slug: &str) -> Option<ModelFamily> {
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
|
||||
experimental_supported_tools: vec![
|
||||
"read_file".to_string(),
|
||||
"grep_files".to_string(),
|
||||
"list_dir".to_string(),
|
||||
"test_sync_tool".to_string()
|
||||
"read_file".to_string(),
|
||||
"test_sync_tool".to_string(),
|
||||
],
|
||||
supports_parallel_tool_calls: true,
|
||||
)
|
||||
@@ -134,7 +135,11 @@ pub fn find_family_for_model(mut slug: &str) -> Option<ModelFamily> {
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
|
||||
apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
|
||||
experimental_supported_tools: vec!["read_file".to_string(), "list_dir".to_string()],
|
||||
experimental_supported_tools: vec![
|
||||
"grep_files".to_string(),
|
||||
"list_dir".to_string(),
|
||||
"read_file".to_string(),
|
||||
],
|
||||
supports_parallel_tool_calls: true,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,43 +1,9 @@
|
||||
use crate::bash::try_parse_bash;
|
||||
use crate::bash::try_parse_word_only_commands_sequence;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use codex_protocol::parse_command::ParsedCommand;
|
||||
use shlex::split as shlex_split;
|
||||
use shlex::try_join as shlex_try_join;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub enum ParsedCommand {
|
||||
Read {
|
||||
cmd: String,
|
||||
name: String,
|
||||
},
|
||||
ListFiles {
|
||||
cmd: String,
|
||||
path: Option<String>,
|
||||
},
|
||||
Search {
|
||||
cmd: String,
|
||||
query: Option<String>,
|
||||
path: Option<String>,
|
||||
},
|
||||
Unknown {
|
||||
cmd: String,
|
||||
},
|
||||
}
|
||||
|
||||
// Convert core's parsed command enum into the protocol's simplified type so
|
||||
// events can carry the canonical representation across process boundaries.
|
||||
impl From<ParsedCommand> for codex_protocol::parse_command::ParsedCommand {
|
||||
fn from(v: ParsedCommand) -> Self {
|
||||
use codex_protocol::parse_command::ParsedCommand as P;
|
||||
match v {
|
||||
ParsedCommand::Read { cmd, name } => P::Read { cmd, name },
|
||||
ParsedCommand::ListFiles { cmd, path } => P::ListFiles { cmd, path },
|
||||
ParsedCommand::Search { cmd, query, path } => P::Search { cmd, query, path },
|
||||
ParsedCommand::Unknown { cmd } => P::Unknown { cmd },
|
||||
}
|
||||
}
|
||||
}
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn shlex_join(tokens: &[String]) -> String {
|
||||
shlex_try_join(tokens.iter().map(String::as_str))
|
||||
@@ -72,6 +38,7 @@ pub fn parse_command(command: &[String]) -> Vec<ParsedCommand> {
|
||||
/// Tests are at the top to encourage using TDD + Codex to fix the implementation.
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::path::PathBuf;
|
||||
use std::string::ToString;
|
||||
|
||||
fn shlex_split_safe(s: &str) -> Vec<String> {
|
||||
@@ -221,6 +188,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: inner.to_string(),
|
||||
name: "README.md".to_string(),
|
||||
path: PathBuf::from("webview/README.md"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -232,6 +200,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "cat foo.txt".to_string(),
|
||||
name: "foo.txt".to_string(),
|
||||
path: PathBuf::from("foo/foo.txt"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -254,6 +223,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "cat foo.txt".to_string(),
|
||||
name: "foo.txt".to_string(),
|
||||
path: PathBuf::from("foo/foo.txt"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -278,6 +248,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: inner.to_string(),
|
||||
name: "Cargo.toml".to_string(),
|
||||
path: PathBuf::from("Cargo.toml"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -290,6 +261,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: inner.to_string(),
|
||||
name: "Cargo.toml".to_string(),
|
||||
path: PathBuf::from("tui/Cargo.toml"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -302,6 +274,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: inner.to_string(),
|
||||
name: "README.md".to_string(),
|
||||
path: PathBuf::from("README.md"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -315,6 +288,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: inner.to_string(),
|
||||
name: "README.md".to_string(),
|
||||
path: PathBuf::from("README.md"),
|
||||
}]
|
||||
);
|
||||
}
|
||||
@@ -484,6 +458,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: inner.to_string(),
|
||||
name: "parse_command.rs".to_string(),
|
||||
path: PathBuf::from("core/src/parse_command.rs"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -496,6 +471,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: inner.to_string(),
|
||||
name: "history_cell.rs".to_string(),
|
||||
path: PathBuf::from("tui/src/history_cell.rs"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -509,6 +485,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "cat -- ansi-escape/Cargo.toml".to_string(),
|
||||
name: "Cargo.toml".to_string(),
|
||||
path: PathBuf::from("ansi-escape/Cargo.toml"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -538,6 +515,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "sed -n '260,640p' exec/src/event_processor_with_human_output.rs".to_string(),
|
||||
name: "event_processor_with_human_output.rs".to_string(),
|
||||
path: PathBuf::from("exec/src/event_processor_with_human_output.rs"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -697,6 +675,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: r#"cat "pkg\\src\\main.rs""#.to_string(),
|
||||
name: "main.rs".to_string(),
|
||||
path: PathBuf::from(r#"pkg\src\main.rs"#),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -708,6 +687,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "head -n50 Cargo.toml".to_string(),
|
||||
name: "Cargo.toml".to_string(),
|
||||
path: PathBuf::from("Cargo.toml"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -738,6 +718,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "tail -n+10 README.md".to_string(),
|
||||
name: "README.md".to_string(),
|
||||
path: PathBuf::from("README.md"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -774,6 +755,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "cat -- ./-strange-file-name".to_string(),
|
||||
name: "-strange-file-name".to_string(),
|
||||
path: PathBuf::from("./-strange-file-name"),
|
||||
}],
|
||||
);
|
||||
|
||||
@@ -783,6 +765,7 @@ mod tests {
|
||||
vec![ParsedCommand::Read {
|
||||
cmd: "sed -n '12,20p' Cargo.toml".to_string(),
|
||||
name: "Cargo.toml".to_string(),
|
||||
path: PathBuf::from("Cargo.toml"),
|
||||
}],
|
||||
);
|
||||
}
|
||||
@@ -875,11 +858,39 @@ pub fn parse_command_impl(command: &[String]) -> Vec<ParsedCommand> {
|
||||
// Preserve left-to-right execution order for all commands, including bash -c/-lc
|
||||
// so summaries reflect the order they will run.
|
||||
|
||||
// Map each pipeline segment to its parsed summary.
|
||||
let mut commands: Vec<ParsedCommand> = parts
|
||||
.iter()
|
||||
.map(|tokens| summarize_main_tokens(tokens))
|
||||
.collect();
|
||||
// Map each pipeline segment to its parsed summary, tracking `cd` to compute paths.
|
||||
let mut commands: Vec<ParsedCommand> = Vec::new();
|
||||
let mut cwd: Option<String> = None;
|
||||
for tokens in &parts {
|
||||
if let Some((head, tail)) = tokens.split_first()
|
||||
&& head == "cd"
|
||||
{
|
||||
if let Some(dir) = tail.first() {
|
||||
cwd = Some(match &cwd {
|
||||
Some(base) => join_paths(base, dir),
|
||||
None => dir.clone(),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
let parsed = summarize_main_tokens(tokens);
|
||||
let parsed = match parsed {
|
||||
ParsedCommand::Read { cmd, name, path } => {
|
||||
if let Some(base) = &cwd {
|
||||
let full = join_paths(base, &path.to_string_lossy());
|
||||
ParsedCommand::Read {
|
||||
cmd,
|
||||
name,
|
||||
path: PathBuf::from(full),
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read { cmd, name, path }
|
||||
}
|
||||
}
|
||||
other => other,
|
||||
};
|
||||
commands.push(parsed);
|
||||
}
|
||||
|
||||
while let Some(next) = simplify_once(&commands) {
|
||||
commands = next;
|
||||
@@ -1164,10 +1175,39 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
cmd: script.clone(),
|
||||
}]);
|
||||
}
|
||||
let mut commands: Vec<ParsedCommand> = filtered_commands
|
||||
.into_iter()
|
||||
.map(|tokens| summarize_main_tokens(&tokens))
|
||||
.collect();
|
||||
// Build parsed commands, tracking `cd` segments to compute effective file paths.
|
||||
let mut commands: Vec<ParsedCommand> = Vec::new();
|
||||
let mut cwd: Option<String> = None;
|
||||
for tokens in filtered_commands.into_iter() {
|
||||
if let Some((head, tail)) = tokens.split_first()
|
||||
&& head == "cd"
|
||||
{
|
||||
if let Some(dir) = tail.first() {
|
||||
cwd = Some(match &cwd {
|
||||
Some(base) => join_paths(base, dir),
|
||||
None => dir.clone(),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
let parsed = summarize_main_tokens(&tokens);
|
||||
let parsed = match parsed {
|
||||
ParsedCommand::Read { cmd, name, path } => {
|
||||
if let Some(base) = &cwd {
|
||||
let full = join_paths(base, &path.to_string_lossy());
|
||||
ParsedCommand::Read {
|
||||
cmd,
|
||||
name,
|
||||
path: PathBuf::from(full),
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read { cmd, name, path }
|
||||
}
|
||||
}
|
||||
other => other,
|
||||
};
|
||||
commands.push(parsed);
|
||||
}
|
||||
if commands.len() > 1 {
|
||||
commands.retain(|pc| !matches!(pc, ParsedCommand::Unknown { cmd } if cmd == "true"));
|
||||
// Apply the same simplifications used for non-bash parsing, e.g., drop leading `cd`.
|
||||
@@ -1187,7 +1227,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
commands = commands
|
||||
.into_iter()
|
||||
.map(|pc| match pc {
|
||||
ParsedCommand::Read { name, cmd, .. } => {
|
||||
ParsedCommand::Read { name, cmd, path } => {
|
||||
if had_connectors {
|
||||
let has_pipe = script_tokens.iter().any(|t| t == "|");
|
||||
let has_sed_n = script_tokens.windows(2).any(|w| {
|
||||
@@ -1198,14 +1238,16 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
ParsedCommand::Read {
|
||||
cmd: script.clone(),
|
||||
name,
|
||||
path,
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read { cmd, name }
|
||||
ParsedCommand::Read { cmd, name, path }
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read {
|
||||
cmd: shlex_join(&script_tokens),
|
||||
name,
|
||||
path,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1370,10 +1412,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
tail
|
||||
};
|
||||
if effective_tail.len() == 1 {
|
||||
let name = short_display_path(&effective_tail[0]);
|
||||
let path = effective_tail[0].clone();
|
||||
let name = short_display_path(&path);
|
||||
ParsedCommand::Read {
|
||||
cmd: shlex_join(main_cmd),
|
||||
name,
|
||||
path: PathBuf::from(path),
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Unknown {
|
||||
@@ -1408,10 +1452,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
i += 1;
|
||||
}
|
||||
if let Some(p) = candidates.into_iter().find(|p| !p.starts_with('-')) {
|
||||
let name = short_display_path(p);
|
||||
let path = p.clone();
|
||||
let name = short_display_path(&path);
|
||||
return ParsedCommand::Read {
|
||||
cmd: shlex_join(main_cmd),
|
||||
name,
|
||||
path: PathBuf::from(path),
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1450,10 +1496,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
i += 1;
|
||||
}
|
||||
if let Some(p) = candidates.into_iter().find(|p| !p.starts_with('-')) {
|
||||
let name = short_display_path(p);
|
||||
let path = p.clone();
|
||||
let name = short_display_path(&path);
|
||||
return ParsedCommand::Read {
|
||||
cmd: shlex_join(main_cmd),
|
||||
name,
|
||||
path: PathBuf::from(path),
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1465,10 +1513,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
// Avoid treating option values as paths (e.g., nl -s " ").
|
||||
let candidates = skip_flag_values(tail, &["-s", "-w", "-v", "-i", "-b"]);
|
||||
if let Some(p) = candidates.into_iter().find(|p| !p.starts_with('-')) {
|
||||
let name = short_display_path(p);
|
||||
let path = p.clone();
|
||||
let name = short_display_path(&path);
|
||||
ParsedCommand::Read {
|
||||
cmd: shlex_join(main_cmd),
|
||||
name,
|
||||
path: PathBuf::from(path),
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Unknown {
|
||||
@@ -1483,10 +1533,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
&& is_valid_sed_n_arg(tail.get(1).map(String::as_str)) =>
|
||||
{
|
||||
if let Some(path) = tail.get(2) {
|
||||
let name = short_display_path(path);
|
||||
let path = path.clone();
|
||||
let name = short_display_path(&path);
|
||||
ParsedCommand::Read {
|
||||
cmd: shlex_join(main_cmd),
|
||||
name,
|
||||
path: PathBuf::from(path),
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Unknown {
|
||||
@@ -1500,3 +1552,30 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn is_abs_like(path: &str) -> bool {
|
||||
if std::path::Path::new(path).is_absolute() {
|
||||
return true;
|
||||
}
|
||||
let mut chars = path.chars();
|
||||
match (chars.next(), chars.next(), chars.next()) {
|
||||
// Windows drive path like C:\
|
||||
(Some(d), Some(':'), Some('\\')) if d.is_ascii_alphabetic() => return true,
|
||||
// UNC path like \\server\share
|
||||
(Some('\\'), Some('\\'), _) => return true,
|
||||
_ => {}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn join_paths(base: &str, rel: &str) -> String {
|
||||
if is_abs_like(rel) {
|
||||
return rel.to_string();
|
||||
}
|
||||
if base.is_empty() {
|
||||
return rel.to_string();
|
||||
}
|
||||
let mut buf = PathBuf::from(base);
|
||||
buf.push(rel);
|
||||
buf.to_string_lossy().to_string()
|
||||
}
|
||||
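A quick sketch (not part of the diff) of the `cd`-tracking behavior these helpers enable; it assumes Unix-style path separators:

assert_eq!(join_paths("core", "src/lib.rs"), "core/src/lib.rs");
assert_eq!(join_paths("core", "/etc/hosts"), "/etc/hosts"); // absolute paths win
assert!(is_abs_like(r"C:\repo\Cargo.toml")); // Windows drive paths count as absolute
// So `bash -lc "cd core && cat src/lib.rs"` is summarized as a Read of core/src/lib.rs.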
|
||||
@@ -21,6 +21,8 @@ use tracing::error;
|
||||
|
||||
/// Default filename scanned for project-level docs.
|
||||
pub const DEFAULT_PROJECT_DOC_FILENAME: &str = "AGENTS.md";
|
||||
/// Preferred local override for project-level docs.
|
||||
pub const LOCAL_PROJECT_DOC_FILENAME: &str = "AGENTS.override.md";
|
||||
|
||||
/// When both `Config::instructions` and the project doc are present, they will
|
||||
/// be concatenated with the following separator.
|
||||
@@ -178,7 +180,8 @@ pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBu
|
||||
|
||||
fn candidate_filenames<'a>(config: &'a Config) -> Vec<&'a str> {
|
||||
let mut names: Vec<&'a str> =
|
||||
Vec::with_capacity(1 + config.project_doc_fallback_filenames.len());
|
||||
Vec::with_capacity(2 + config.project_doc_fallback_filenames.len());
|
||||
names.push(LOCAL_PROJECT_DOC_FILENAME);
|
||||
names.push(DEFAULT_PROJECT_DOC_FILENAME);
|
||||
for candidate in &config.project_doc_fallback_filenames {
|
||||
let candidate = candidate.as_str();
|
||||
@@ -381,6 +384,29 @@ mod tests {
|
||||
assert_eq!(res, "root doc\n\ncrate doc");
|
||||
}
|
||||
|
||||
/// AGENTS.override.md is preferred over AGENTS.md when both are present.
|
||||
#[tokio::test]
|
||||
async fn agents_local_md_preferred() {
|
||||
let tmp = tempfile::tempdir().expect("tempdir");
|
||||
fs::write(tmp.path().join(DEFAULT_PROJECT_DOC_FILENAME), "versioned").unwrap();
|
||||
fs::write(tmp.path().join(LOCAL_PROJECT_DOC_FILENAME), "local").unwrap();
|
||||
|
||||
let cfg = make_config(&tmp, 4096, None);
|
||||
|
||||
let res = get_user_instructions(&cfg)
|
||||
.await
|
||||
.expect("local doc expected");
|
||||
|
||||
assert_eq!(res, "local");
|
||||
|
||||
let discovery = discover_project_doc_paths(&cfg).expect("discover paths");
|
||||
assert_eq!(discovery.len(), 1);
|
||||
assert_eq!(
|
||||
discovery[0].file_name().unwrap().to_string_lossy(),
|
||||
LOCAL_PROJECT_DOC_FILENAME
|
||||
);
|
||||
}
|
||||
|
||||
/// When AGENTS.md is absent but a configured fallback exists, the fallback is used.
|
||||
#[tokio::test]
|
||||
async fn uses_configured_fallback_when_agents_missing() {
|
||||
|
||||
@@ -4,7 +4,9 @@ use indexmap::IndexMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::AbortHandle;
|
||||
use tokio::sync::Notify;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use tokio::sync::oneshot;
|
||||
@@ -34,11 +36,23 @@ pub(crate) enum TaskKind {
|
||||
Compact,
|
||||
}
|
||||
|
||||
impl TaskKind {
|
||||
pub(crate) fn header_value(self) -> &'static str {
|
||||
match self {
|
||||
TaskKind::Regular => "standard",
|
||||
TaskKind::Review => "review",
|
||||
TaskKind::Compact => "compact",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RunningTask {
|
||||
pub(crate) handle: AbortHandle,
|
||||
pub(crate) done: Arc<Notify>,
|
||||
pub(crate) kind: TaskKind,
|
||||
pub(crate) task: Arc<dyn SessionTask>,
|
||||
pub(crate) cancellation_token: CancellationToken,
|
||||
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
@@ -105,11 +119,16 @@ impl ActiveTurn {
|
||||
let mut ts = self.turn_state.lock().await;
|
||||
ts.clear_pending();
|
||||
}
|
||||
}
|
||||
|
||||
/// Best-effort, non-blocking variant for synchronous contexts (Drop/interrupt).
|
||||
pub(crate) fn try_clear_pending_sync(&self) {
|
||||
if let Ok(mut ts) = self.turn_state.try_lock() {
|
||||
ts.clear_pending();
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::TaskKind;
|
||||
|
||||
#[test]
|
||||
fn header_value_matches_expected_labels() {
|
||||
assert_eq!(TaskKind::Regular.header_value(), "standard");
|
||||
assert_eq!(TaskKind::Review.header_value(), "review");
|
||||
assert_eq!(TaskKind::Compact.header_value(), "compact");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::compact;
|
||||
@@ -25,6 +26,7 @@ impl SessionTask for CompactTask {
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
_cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
|
||||
}
|
||||
|
||||
@@ -3,9 +3,15 @@ mod regular;
|
||||
mod review;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::select;
|
||||
use tokio::sync::Notify;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
@@ -23,6 +29,8 @@ pub(crate) use compact::CompactTask;
|
||||
pub(crate) use regular::RegularTask;
|
||||
pub(crate) use review::ReviewTask;
|
||||
|
||||
const GRACEFULL_INTERRUPTION_TIMEOUT_MS: u64 = 100;
|
||||
|
||||
/// Thin wrapper that exposes the parts of [`Session`] task runners need.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SessionTaskContext {
|
||||
@@ -49,6 +57,7 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String>;
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
@@ -69,26 +78,42 @@ impl Session {
|
||||
let task: Arc<dyn SessionTask> = Arc::new(task);
|
||||
let task_kind = task.kind();
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let done = Arc::new(Notify::new());
|
||||
|
||||
let done_clone = Arc::clone(&done);
|
||||
let handle = {
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
let ctx = Arc::clone(&turn_context);
|
||||
let task_for_run = Arc::clone(&task);
|
||||
let sub_clone = sub_id.clone();
|
||||
let task_cancellation_token = cancellation_token.child_token();
|
||||
tokio::spawn(async move {
|
||||
let last_agent_message = task_for_run
|
||||
.run(Arc::clone(&session_ctx), ctx, sub_clone.clone(), input)
|
||||
.run(
|
||||
Arc::clone(&session_ctx),
|
||||
ctx,
|
||||
sub_clone.clone(),
|
||||
input,
|
||||
task_cancellation_token.child_token(),
|
||||
)
|
||||
.await;
|
||||
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
||||
let sess = session_ctx.clone_session();
|
||||
sess.on_task_finished(sub_clone, last_agent_message).await;
|
||||
|
||||
if !task_cancellation_token.is_cancelled() {
|
||||
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
||||
let sess = session_ctx.clone_session();
|
||||
sess.on_task_finished(sub_clone, last_agent_message).await;
|
||||
}
|
||||
done_clone.notify_waiters();
|
||||
})
|
||||
.abort_handle()
|
||||
};
|
||||
|
||||
let running_task = RunningTask {
|
||||
handle,
|
||||
done,
|
||||
handle: Arc::new(AbortOnDropHandle::new(handle)),
|
||||
kind: task_kind,
|
||||
task,
|
||||
cancellation_token,
|
||||
};
|
||||
self.register_new_active_task(sub_id, running_task).await;
|
||||
}
|
||||
@@ -143,14 +168,24 @@ impl Session {
|
||||
task: RunningTask,
|
||||
reason: TurnAbortReason,
|
||||
) {
|
||||
if task.handle.is_finished() {
|
||||
if task.cancellation_token.is_cancelled() {
|
||||
return;
|
||||
}
|
||||
|
||||
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
|
||||
task.cancellation_token.cancel();
|
||||
let session_task = task.task;
|
||||
let handle = task.handle;
|
||||
handle.abort();
|
||||
|
||||
select! {
|
||||
_ = task.done.notified() => {
|
||||
},
|
||||
_ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => {
|
||||
warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS);
|
||||
}
|
||||
}
|
||||
|
||||
task.handle.abort();
|
||||
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
session_task.abort(session_ctx, &sub_id).await;
|
||||
|
||||
|
||||
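For orientation, a minimal sketch (not from the diff) of how a task body can observe the child `CancellationToken` it now receives, so it can wind down before the hard `abort()` above fires; `do_work` is a hypothetical unit of work:

use tokio_util::sync::CancellationToken;

async fn run_until_cancelled(token: CancellationToken) {
    tokio::select! {
        _ = token.cancelled() => {
            // Clean up and return early; the spawn site skips on_task_finished
            // for cancelled tasks.
        }
        _ = do_work() => {
            // Normal completion path.
        }
    }
}

async fn do_work() { /* hypothetical unit of work */ }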
@@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::run_task;
|
||||
@@ -25,8 +26,17 @@ impl SessionTask for RegularTask {
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
run_task(sess, ctx, sub_id, input).await
|
||||
run_task(
|
||||
sess,
|
||||
ctx,
|
||||
sub_id,
|
||||
input,
|
||||
TaskKind::Regular,
|
||||
cancellation_token,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::exit_review_mode;
|
||||
@@ -26,9 +27,18 @@ impl SessionTask for ReviewTask {
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
run_task(sess, ctx, sub_id, input).await
|
||||
run_task(
|
||||
sess,
|
||||
ctx,
|
||||
sub_id,
|
||||
input,
|
||||
TaskKind::Review,
|
||||
cancellation_token,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
|
||||
codex-rs/core/src/tools/handlers/grep_files.rs (new file, 272 lines)
@@ -0,0 +1,272 @@
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
|
||||
pub struct GrepFilesHandler;
|
||||
|
||||
const DEFAULT_LIMIT: usize = 100;
|
||||
const MAX_LIMIT: usize = 2000;
|
||||
const COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
fn default_limit() -> usize {
|
||||
DEFAULT_LIMIT
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct GrepFilesArgs {
|
||||
pattern: String,
|
||||
#[serde(default)]
|
||||
include: Option<String>,
|
||||
#[serde(default)]
|
||||
path: Option<String>,
|
||||
#[serde(default = "default_limit")]
|
||||
limit: usize,
|
||||
}
|
||||
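A sketch (not part of the diff) of the arguments the model would pass for this tool, matching the struct above; the pattern, glob, and path values are illustrative:

let args: GrepFilesArgs = serde_json::from_str(
    r#"{ "pattern": "TODO\\(", "include": "*.rs", "path": "core/src", "limit": 50 }"#,
)
.expect("valid grep_files arguments");
assert_eq!(args.limit, 50); // omitted, it would fall back to DEFAULT_LIMIT (100)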
|
||||
#[async_trait]
|
||||
impl ToolHandler for GrepFilesHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation { payload, turn, .. } = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
ToolPayload::Function { arguments } => arguments,
|
||||
_ => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"grep_files handler received unsupported payload".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let args: GrepFilesArgs = serde_json::from_str(&arguments).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to parse function arguments: {err:?}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let pattern = args.pattern.trim();
|
||||
if pattern.is_empty() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"pattern must not be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if args.limit == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"limit must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let limit = args.limit.min(MAX_LIMIT);
|
||||
let search_path = turn.resolve_path(args.path.clone());
|
||||
|
||||
verify_path_exists(&search_path).await?;
|
||||
|
||||
let include = args.include.as_deref().map(str::trim).and_then(|val| {
|
||||
if val.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(val.to_string())
|
||||
}
|
||||
});
|
||||
|
||||
let search_results =
|
||||
run_rg_search(pattern, include.as_deref(), &search_path, limit, &turn.cwd).await?;
|
||||
|
||||
if search_results.is_empty() {
|
||||
Ok(ToolOutput::Function {
|
||||
content: "No matches found.".to_string(),
|
||||
success: Some(false),
|
||||
})
|
||||
} else {
|
||||
Ok(ToolOutput::Function {
|
||||
content: search_results.join("\n"),
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn verify_path_exists(path: &Path) -> Result<(), FunctionCallError> {
|
||||
tokio::fs::metadata(path).await.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("unable to access `{}`: {err}", path.display()))
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_rg_search(
|
||||
pattern: &str,
|
||||
include: Option<&str>,
|
||||
search_path: &Path,
|
||||
limit: usize,
|
||||
cwd: &Path,
|
||||
) -> Result<Vec<String>, FunctionCallError> {
|
||||
let mut command = Command::new("rg");
|
||||
command
|
||||
.current_dir(cwd)
|
||||
.arg("--files-with-matches")
|
||||
.arg("--sortr=modified")
|
||||
.arg("--regexp")
|
||||
.arg(pattern)
|
||||
.arg("--no-messages");
|
||||
|
||||
if let Some(glob) = include {
|
||||
command.arg("--glob").arg(glob);
|
||||
}
|
||||
|
||||
command.arg("--").arg(search_path);
|
||||
|
||||
let output = timeout(COMMAND_TIMEOUT, command.output())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
FunctionCallError::RespondToModel("rg timed out after 30 seconds".to_string())
|
||||
})?
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to launch rg: {err}. Ensure ripgrep is installed and on PATH."
|
||||
))
|
||||
})?;
|
||||
|
||||
match output.status.code() {
|
||||
Some(0) => Ok(parse_results(&output.stdout, limit)),
|
||||
Some(1) => Ok(Vec::new()),
|
||||
_ => {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
Err(FunctionCallError::RespondToModel(format!(
|
||||
"rg failed: {stderr}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_results(stdout: &[u8], limit: usize) -> Vec<String> {
|
||||
let mut results = Vec::new();
|
||||
for line in stdout.split(|byte| *byte == b'\n') {
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(text) = std::str::from_utf8(line) {
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
results.push(text.to_string());
|
||||
if results.len() == limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::process::Command as StdCommand;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn parses_basic_results() {
|
||||
let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n";
|
||||
let parsed = parse_results(stdout, 10);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_truncates_after_limit() {
|
||||
let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n/tmp/file_c.rs\n";
|
||||
let parsed = parse_results(stdout, 2);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_search_returns_results() -> anyhow::Result<()> {
|
||||
if !rg_available() {
|
||||
return Ok(());
|
||||
}
|
||||
let temp = tempdir().expect("create temp dir");
|
||||
let dir = temp.path();
|
||||
std::fs::write(dir.join("match_one.txt"), "alpha beta gamma").unwrap();
|
||||
std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap();
|
||||
std::fs::write(dir.join("other.txt"), "omega").unwrap();
|
||||
|
||||
let results = run_rg_search("alpha", None, dir, 10, dir).await?;
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results.iter().any(|path| path.ends_with("match_one.txt")));
|
||||
assert!(results.iter().any(|path| path.ends_with("match_two.txt")));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_search_with_glob_filter() -> anyhow::Result<()> {
|
||||
if !rg_available() {
|
||||
return Ok(());
|
||||
}
|
||||
let temp = tempdir().expect("create temp dir");
|
||||
let dir = temp.path();
|
||||
std::fs::write(dir.join("match_one.rs"), "alpha beta gamma").unwrap();
|
||||
std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap();
|
||||
|
||||
let results = run_rg_search("alpha", Some("*.rs"), dir, 10, dir).await?;
|
||||
assert_eq!(results.len(), 1);
|
||||
assert!(results.iter().all(|path| path.ends_with("match_one.rs")));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_search_respects_limit() -> anyhow::Result<()> {
|
||||
if !rg_available() {
|
||||
return Ok(());
|
||||
}
|
||||
let temp = tempdir().expect("create temp dir");
|
||||
let dir = temp.path();
|
||||
std::fs::write(dir.join("one.txt"), "alpha one").unwrap();
|
||||
std::fs::write(dir.join("two.txt"), "alpha two").unwrap();
|
||||
std::fs::write(dir.join("three.txt"), "alpha three").unwrap();
|
||||
|
||||
let results = run_rg_search("alpha", None, dir, 2, dir).await?;
|
||||
assert_eq!(results.len(), 2);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_search_handles_no_matches() -> anyhow::Result<()> {
|
||||
if !rg_available() {
|
||||
return Ok(());
|
||||
}
|
||||
let temp = tempdir().expect("create temp dir");
|
||||
let dir = temp.path();
|
||||
std::fs::write(dir.join("one.txt"), "omega").unwrap();
|
||||
|
||||
let results = run_rg_search("alpha", None, dir, 5, dir).await?;
|
||||
assert!(results.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rg_available() -> bool {
|
||||
StdCommand::new("rg")
|
||||
.arg("--version")
|
||||
.output()
|
||||
.map(|output| output.status.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
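A minimal standalone sketch of the ripgrep invocation that run_rg_search builds above (the pattern "alpha", glob "*.rs", and search path "." are placeholders, not values from this change):

use std::process::Command;

fn main() -> std::io::Result<()> {
    // Mirrors the flags assembled in run_rg_search: list matching files,
    // newest first, optional glob filter, suppress per-file error messages.
    let output = Command::new("rg")
        .arg("--files-with-matches")
        .arg("--sortr=modified")
        .arg("--regexp")
        .arg("alpha")
        .arg("--no-messages")
        .arg("--glob")
        .arg("*.rs")
        .arg("--")
        .arg(".")
        .output()?;

    // rg exits 0 when matches are found and 1 when there are none,
    // which is the distinction run_rg_search branches on.
    match output.status.code() {
        Some(0) => print!("{}", String::from_utf8_lossy(&output.stdout)),
        Some(1) => println!("no matches"),
        _ => eprintln!("{}", String::from_utf8_lossy(&output.stderr)),
    }
    Ok(())
}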
773
codex-rs/core/src/tools/handlers/mcp_resource.rs
Normal file
@@ -0,0 +1,773 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::ListResourceTemplatesRequestParams;
|
||||
use mcp_types::ListResourceTemplatesResult;
|
||||
use mcp_types::ListResourcesRequestParams;
|
||||
use mcp_types::ListResourcesResult;
|
||||
use mcp_types::ReadResourceRequestParams;
|
||||
use mcp_types::ReadResourceResult;
|
||||
use mcp_types::Resource;
|
||||
use mcp_types::ResourceTemplate;
|
||||
use mcp_types::TextContent;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::McpInvocation;
|
||||
use crate::protocol::McpToolCallBeginEvent;
|
||||
use crate::protocol::McpToolCallEndEvent;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
|
||||
pub struct McpResourceHandler;
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
struct ListResourcesArgs {
|
||||
/// Lists all resources from all servers if not specified.
|
||||
#[serde(default)]
|
||||
server: Option<String>,
|
||||
#[serde(default)]
|
||||
cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
struct ListResourceTemplatesArgs {
|
||||
/// Lists all resource templates from all servers if not specified.
|
||||
#[serde(default)]
|
||||
server: Option<String>,
|
||||
#[serde(default)]
|
||||
cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ReadResourceArgs {
|
||||
server: String,
|
||||
uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ResourceWithServer {
|
||||
server: String,
|
||||
#[serde(flatten)]
|
||||
resource: Resource,
|
||||
}
|
||||
|
||||
impl ResourceWithServer {
|
||||
fn new(server: String, resource: Resource) -> Self {
|
||||
Self { server, resource }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ResourceTemplateWithServer {
|
||||
server: String,
|
||||
#[serde(flatten)]
|
||||
template: ResourceTemplate,
|
||||
}
|
||||
|
||||
impl ResourceTemplateWithServer {
|
||||
fn new(server: String, template: ResourceTemplate) -> Self {
|
||||
Self { server, template }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct ListResourcesPayload {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
server: Option<String>,
|
||||
resources: Vec<ResourceWithServer>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
next_cursor: Option<String>,
|
||||
}
|
||||
|
||||
impl ListResourcesPayload {
|
||||
fn from_single_server(server: String, result: ListResourcesResult) -> Self {
|
||||
let resources = result
|
||||
.resources
|
||||
.into_iter()
|
||||
.map(|resource| ResourceWithServer::new(server.clone(), resource))
|
||||
.collect();
|
||||
Self {
|
||||
server: Some(server),
|
||||
resources,
|
||||
next_cursor: result.next_cursor,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_all_servers(resources_by_server: HashMap<String, Vec<Resource>>) -> Self {
|
||||
let mut entries: Vec<(String, Vec<Resource>)> = resources_by_server.into_iter().collect();
|
||||
entries.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
let mut resources = Vec::new();
|
||||
for (server, server_resources) in entries {
|
||||
for resource in server_resources {
|
||||
resources.push(ResourceWithServer::new(server.clone(), resource));
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
server: None,
|
||||
resources,
|
||||
next_cursor: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct ListResourceTemplatesPayload {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
server: Option<String>,
|
||||
resource_templates: Vec<ResourceTemplateWithServer>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
next_cursor: Option<String>,
|
||||
}
|
||||
|
||||
impl ListResourceTemplatesPayload {
|
||||
fn from_single_server(server: String, result: ListResourceTemplatesResult) -> Self {
|
||||
let resource_templates = result
|
||||
.resource_templates
|
||||
.into_iter()
|
||||
.map(|template| ResourceTemplateWithServer::new(server.clone(), template))
|
||||
.collect();
|
||||
Self {
|
||||
server: Some(server),
|
||||
resource_templates,
|
||||
next_cursor: result.next_cursor,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_all_servers(templates_by_server: HashMap<String, Vec<ResourceTemplate>>) -> Self {
|
||||
let mut entries: Vec<(String, Vec<ResourceTemplate>)> =
|
||||
templates_by_server.into_iter().collect();
|
||||
entries.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
let mut resource_templates = Vec::new();
|
||||
for (server, server_templates) in entries {
|
||||
for template in server_templates {
|
||||
resource_templates.push(ResourceTemplateWithServer::new(server.clone(), template));
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
server: None,
|
||||
resource_templates,
|
||||
next_cursor: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ReadResourcePayload {
|
||||
server: String,
|
||||
uri: String,
|
||||
#[serde(flatten)]
|
||||
result: ReadResourceResult,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for McpResourceHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
sub_id,
|
||||
call_id,
|
||||
tool_name,
|
||||
payload,
|
||||
..
|
||||
} = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
ToolPayload::Function { arguments } => arguments,
|
||||
_ => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"mcp_resource handler received unsupported payload".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let arguments_value = parse_arguments(arguments.as_str())?;
|
||||
|
||||
match tool_name.as_str() {
|
||||
"list_mcp_resources" => {
|
||||
handle_list_resources(
|
||||
Arc::clone(&session),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
arguments_value.clone(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
"list_mcp_resource_templates" => {
|
||||
handle_list_resource_templates(
|
||||
Arc::clone(&session),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
arguments_value.clone(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
"read_mcp_resource" => {
|
||||
handle_read_resource(Arc::clone(&session), sub_id, call_id, arguments_value).await
|
||||
}
|
||||
other => Err(FunctionCallError::RespondToModel(format!(
|
||||
"unsupported MCP resource tool: {other}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_list_resources(
|
||||
session: Arc<Session>,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
arguments: Option<Value>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
let args: ListResourcesArgs = parse_args_with_default(arguments.clone())?;
|
||||
let ListResourcesArgs { server, cursor } = args;
|
||||
let server = normalize_optional_string(server);
|
||||
let cursor = normalize_optional_string(cursor);
|
||||
|
||||
let invocation = McpInvocation {
|
||||
server: server.clone().unwrap_or_else(|| "codex".to_string()),
|
||||
tool: "list_mcp_resources".to_string(),
|
||||
arguments: arguments.clone(),
|
||||
};
|
||||
|
||||
emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
|
||||
let start = Instant::now();
|
||||
|
||||
let payload_result: Result<ListResourcesPayload, FunctionCallError> = async {
|
||||
if let Some(server_name) = server.clone() {
|
||||
let params = cursor.clone().map(|value| ListResourcesRequestParams {
|
||||
cursor: Some(value),
|
||||
});
|
||||
let result = session
|
||||
.list_resources(&server_name, params)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("resources/list failed: {err:#}"))
|
||||
})?;
|
||||
Ok(ListResourcesPayload::from_single_server(
|
||||
server_name,
|
||||
result,
|
||||
))
|
||||
} else {
|
||||
if cursor.is_some() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"cursor can only be used when a server is specified".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let resources = session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.list_all_resources()
|
||||
.await;
|
||||
Ok(ListResourcesPayload::from_all_servers(resources))
|
||||
}
|
||||
}
|
||||
.await;
|
||||
|
||||
match payload_result {
|
||||
Ok(payload) => match serialize_function_output(payload) {
|
||||
Ok(output) => {
|
||||
let ToolOutput::Function { content, success } = &output else {
|
||||
unreachable!("MCP resource handler should return function output");
|
||||
};
|
||||
let duration = start.elapsed();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
Ok(call_tool_result_from_content(content, *success)),
|
||||
)
|
||||
.await;
|
||||
Ok(output)
|
||||
}
|
||||
Err(err) => {
|
||||
let duration = start.elapsed();
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
Err(message.clone()),
|
||||
)
|
||||
.await;
|
||||
Err(err)
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
let duration = start.elapsed();
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
Err(message.clone()),
|
||||
)
|
||||
.await;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_list_resource_templates(
|
||||
session: Arc<Session>,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
arguments: Option<Value>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
let args: ListResourceTemplatesArgs = parse_args_with_default(arguments.clone())?;
|
||||
let ListResourceTemplatesArgs { server, cursor } = args;
|
||||
let server = normalize_optional_string(server);
|
||||
let cursor = normalize_optional_string(cursor);
|
||||
|
||||
let invocation = McpInvocation {
|
||||
server: server.clone().unwrap_or_else(|| "codex".to_string()),
|
||||
tool: "list_mcp_resource_templates".to_string(),
|
||||
arguments: arguments.clone(),
|
||||
};
|
||||
|
||||
emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
|
||||
let start = Instant::now();
|
||||
|
||||
let payload_result: Result<ListResourceTemplatesPayload, FunctionCallError> = async {
|
||||
if let Some(server_name) = server.clone() {
|
||||
let params = cursor
|
||||
.clone()
|
||||
.map(|value| ListResourceTemplatesRequestParams {
|
||||
cursor: Some(value),
|
||||
});
|
||||
let result = session
|
||||
.list_resource_templates(&server_name, params)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"resources/templates/list failed: {err:#}"
|
||||
))
|
||||
})?;
|
||||
Ok(ListResourceTemplatesPayload::from_single_server(
|
||||
server_name,
|
||||
result,
|
||||
))
|
||||
} else {
|
||||
if cursor.is_some() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"cursor can only be used when a server is specified".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let templates = session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.list_all_resource_templates()
|
||||
.await;
|
||||
Ok(ListResourceTemplatesPayload::from_all_servers(templates))
|
||||
}
|
||||
}
|
||||
.await;
|
||||
|
||||
match payload_result {
|
||||
Ok(payload) => match serialize_function_output(payload) {
|
||||
Ok(output) => {
|
||||
let ToolOutput::Function { content, success } = &output else {
|
||||
unreachable!("MCP resource handler should return function output");
|
||||
};
|
||||
let duration = start.elapsed();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
Ok(call_tool_result_from_content(content, *success)),
|
||||
)
|
||||
.await;
|
||||
Ok(output)
|
||||
}
|
||||
Err(err) => {
|
||||
let duration = start.elapsed();
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
Err(message.clone()),
|
||||
)
|
||||
.await;
|
||||
Err(err)
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
let duration = start.elapsed();
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
Err(message.clone()),
|
||||
)
|
||||
.await;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_read_resource(
|
||||
session: Arc<Session>,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
arguments: Option<Value>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
let args: ReadResourceArgs = parse_args(arguments.clone())?;
|
||||
let ReadResourceArgs { server, uri } = args;
|
||||
let server = normalize_required_string("server", server)?;
|
||||
let uri = normalize_required_string("uri", uri)?;
|
||||
|
||||
let invocation = McpInvocation {
|
||||
server: server.clone(),
|
||||
tool: "read_mcp_resource".to_string(),
|
||||
arguments: arguments.clone(),
|
||||
};
|
||||
|
||||
emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
|
||||
let start = Instant::now();
|
||||
|
||||
let payload_result: Result<ReadResourcePayload, FunctionCallError> = async {
|
||||
let result = session
|
||||
.read_resource(&server, ReadResourceRequestParams { uri: uri.clone() })
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("resources/read failed: {err:#}"))
|
||||
})?;
|
||||
|
||||
Ok(ReadResourcePayload {
|
||||
server,
|
||||
uri,
|
||||
result,
|
||||
})
|
||||
}
|
||||
.await;
|
||||
|
||||
match payload_result {
|
||||
Ok(payload) => match serialize_function_output(payload) {
|
||||
Ok(output) => {
|
||||
let ToolOutput::Function { content, success } = &output else {
|
||||
unreachable!("MCP resource handler should return function output");
|
||||
};
|
||||
let duration = start.elapsed();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
Ok(call_tool_result_from_content(content, *success)),
|
||||
)
|
||||
.await;
|
||||
Ok(output)
|
||||
}
|
||||
Err(err) => {
|
||||
let duration = start.elapsed();
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
Err(message.clone()),
|
||||
)
|
||||
.await;
|
||||
Err(err)
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
let duration = start.elapsed();
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
Err(message.clone()),
|
||||
)
|
||||
.await;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn call_tool_result_from_content(content: &str, success: Option<bool>) -> CallToolResult {
|
||||
CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
annotations: None,
|
||||
text: content.to_string(),
|
||||
r#type: "text".to_string(),
|
||||
})],
|
||||
is_error: success.map(|value| !value),
|
||||
structured_content: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn emit_tool_call_begin(
|
||||
session: &Arc<Session>,
|
||||
sub_id: &str,
|
||||
call_id: &str,
|
||||
invocation: McpInvocation,
|
||||
) {
|
||||
session
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
||||
call_id: call_id.to_string(),
|
||||
invocation,
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn emit_tool_call_end(
|
||||
session: &Arc<Session>,
|
||||
sub_id: &str,
|
||||
call_id: &str,
|
||||
invocation: McpInvocation,
|
||||
duration: Duration,
|
||||
result: Result<CallToolResult, String>,
|
||||
) {
|
||||
session
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::McpToolCallEnd(McpToolCallEndEvent {
|
||||
call_id: call_id.to_string(),
|
||||
invocation,
|
||||
duration,
|
||||
result,
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
fn normalize_optional_string(input: Option<String>) -> Option<String> {
|
||||
input.and_then(|value| {
|
||||
let trimmed = value.trim().to_string();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn normalize_required_string(field: &str, value: String) -> Result<String, FunctionCallError> {
|
||||
match normalize_optional_string(Some(value)) {
|
||||
Some(normalized) => Ok(normalized),
|
||||
None => Err(FunctionCallError::RespondToModel(format!(
|
||||
"{field} must be provided"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_function_output<T>(payload: T) -> Result<ToolOutput, FunctionCallError>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let content = serde_json::to_string(&payload).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to serialize MCP resource response: {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_arguments(raw_args: &str) -> Result<Option<Value>, FunctionCallError> {
|
||||
if raw_args.trim().is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
serde_json::from_str(raw_args).map(Some).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("failed to parse function arguments: {err}"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_args<T>(arguments: Option<Value>) -> Result<T, FunctionCallError>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
match arguments {
|
||||
Some(value) => serde_json::from_value(value).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("failed to parse function arguments: {err}"))
|
||||
}),
|
||||
None => Err(FunctionCallError::RespondToModel(
|
||||
"failed to parse function arguments: expected value".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_args_with_default<T>(arguments: Option<Value>) -> Result<T, FunctionCallError>
|
||||
where
|
||||
T: DeserializeOwned + Default,
|
||||
{
|
||||
match arguments {
|
||||
Some(value) => parse_args(Some(value)),
|
||||
None => Ok(T::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mcp_types::ListResourcesResult;
|
||||
use mcp_types::ResourceTemplate;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
fn resource(uri: &str, name: &str) -> Resource {
|
||||
Resource {
|
||||
annotations: None,
|
||||
description: None,
|
||||
mime_type: None,
|
||||
name: name.to_string(),
|
||||
size: None,
|
||||
title: None,
|
||||
uri: uri.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn template(uri_template: &str, name: &str) -> ResourceTemplate {
|
||||
ResourceTemplate {
|
||||
annotations: None,
|
||||
description: None,
|
||||
mime_type: None,
|
||||
name: name.to_string(),
|
||||
title: None,
|
||||
uri_template: uri_template.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resource_with_server_serializes_server_field() {
|
||||
let entry = ResourceWithServer::new("test".to_string(), resource("memo://id", "memo"));
|
||||
let value = serde_json::to_value(&entry).expect("serialize resource");
|
||||
|
||||
assert_eq!(value["server"], json!("test"));
|
||||
assert_eq!(value["uri"], json!("memo://id"));
|
||||
assert_eq!(value["name"], json!("memo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_resources_payload_from_single_server_copies_next_cursor() {
|
||||
let result = ListResourcesResult {
|
||||
next_cursor: Some("cursor-1".to_string()),
|
||||
resources: vec![resource("memo://id", "memo")],
|
||||
};
|
||||
let payload = ListResourcesPayload::from_single_server("srv".to_string(), result);
|
||||
let value = serde_json::to_value(&payload).expect("serialize payload");
|
||||
|
||||
assert_eq!(value["server"], json!("srv"));
|
||||
assert_eq!(value["nextCursor"], json!("cursor-1"));
|
||||
let resources = value["resources"].as_array().expect("resources array");
|
||||
assert_eq!(resources.len(), 1);
|
||||
assert_eq!(resources[0]["server"], json!("srv"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_resources_payload_from_all_servers_is_sorted() {
|
||||
let mut map = HashMap::new();
|
||||
map.insert("beta".to_string(), vec![resource("memo://b-1", "b-1")]);
|
||||
map.insert(
|
||||
"alpha".to_string(),
|
||||
vec![resource("memo://a-1", "a-1"), resource("memo://a-2", "a-2")],
|
||||
);
|
||||
|
||||
let payload = ListResourcesPayload::from_all_servers(map);
|
||||
let value = serde_json::to_value(&payload).expect("serialize payload");
|
||||
let uris: Vec<String> = value["resources"]
|
||||
.as_array()
|
||||
.expect("resources array")
|
||||
.iter()
|
||||
.map(|entry| entry["uri"].as_str().unwrap().to_string())
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
uris,
|
||||
vec![
|
||||
"memo://a-1".to_string(),
|
||||
"memo://a-2".to_string(),
|
||||
"memo://b-1".to_string()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_tool_result_from_content_marks_success() {
|
||||
let result = call_tool_result_from_content("{}", Some(true));
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
assert_eq!(result.content.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_arguments_handles_empty_and_json() {
|
||||
assert!(
|
||||
parse_arguments(" \n\t").unwrap().is_none(),
|
||||
"expected None for empty arguments"
|
||||
);
|
||||
|
||||
let value = parse_arguments(r#"{"server":"figma"}"#)
|
||||
.expect("parse json")
|
||||
.expect("value present");
|
||||
assert_eq!(value["server"], json!("figma"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn template_with_server_serializes_server_field() {
|
||||
let entry =
|
||||
ResourceTemplateWithServer::new("srv".to_string(), template("memo://{id}", "memo"));
|
||||
let value = serde_json::to_value(&entry).expect("serialize template");
|
||||
|
||||
assert_eq!(
|
||||
value,
|
||||
json!({
|
||||
"server": "srv",
|
||||
"uriTemplate": "memo://{id}",
|
||||
"name": "memo"
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
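A rough sketch of the argument handling in this handler, assuming only serde_json: an empty argument string becomes None, and blank server/cursor values are normalized away before use (the names and values below are placeholders):

use serde_json::{Value, json};

fn normalize(value: Option<&str>) -> Option<String> {
    value
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(str::to_string)
}

fn main() {
    // Hypothetical raw arguments as a model might send them.
    let raw = r#"{"server":"docs","cursor":"   "}"#;
    let args: Value = if raw.trim().is_empty() {
        Value::Null
    } else {
        serde_json::from_str(raw).expect("valid JSON")
    };

    let server = normalize(args["server"].as_str());
    let cursor = normalize(args["cursor"].as_str());
    assert_eq!(server.as_deref(), Some("docs"));
    assert_eq!(cursor, None); // whitespace-only cursor is treated as absent

    // Successful results are serialized back to a JSON string for the model.
    let payload = json!({ "server": "docs", "resources": [] });
    println!("{}", serde_json::to_string(&payload).unwrap());
}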
@@ -1,7 +1,9 @@
|
||||
pub mod apply_patch;
|
||||
mod exec_stream;
|
||||
mod grep_files;
|
||||
mod list_dir;
|
||||
mod mcp;
|
||||
mod mcp_resource;
|
||||
mod plan;
|
||||
mod read_file;
|
||||
mod shell;
|
||||
@@ -13,8 +15,10 @@ pub use plan::PLAN_TOOL;
|
||||
|
||||
pub use apply_patch::ApplyPatchHandler;
|
||||
pub use exec_stream::ExecStreamHandler;
|
||||
pub use grep_files::GrepFilesHandler;
|
||||
pub use list_dir::ListDirHandler;
|
||||
pub use mcp::McpHandler;
|
||||
pub use mcp_resource::McpResourceHandler;
|
||||
pub use plan::PlanHandler;
|
||||
pub use read_file::ReadFileHandler;
|
||||
pub use shell::ShellHandler;
|
||||
|
||||
File diff suppressed because it is too large
@@ -129,7 +129,7 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
None => ExecutionMode::Shell,
|
||||
};
|
||||
|
||||
sess.services.executor.update_environment(
|
||||
let executor_config = sess.services.executor.update_environment(
|
||||
turn_context.sandbox_policy.clone(),
|
||||
turn_context.cwd.clone(),
|
||||
);
|
||||
@@ -152,6 +152,7 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
turn_diff_tracker.clone(),
|
||||
prepared_exec,
|
||||
turn_context.approval_policy,
|
||||
executor_config,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -238,6 +239,7 @@ fn truncate_function_error(err: FunctionCallError) -> FunctionCallError {
|
||||
FunctionCallError::RespondToModel(msg) => {
|
||||
FunctionCallError::RespondToModel(format_exec_output(&msg))
|
||||
}
|
||||
FunctionCallError::Denied(msg) => FunctionCallError::Denied(format_exec_output(&msg)),
|
||||
FunctionCallError::Fatal(msg) => FunctionCallError::Fatal(format_exec_output(&msg)),
|
||||
other => other,
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use crate::client_common::tools::ResponsesApiTool;
|
||||
use crate::client_common::tools::ToolSpec;
|
||||
use crate::features::Feature;
|
||||
use crate::features::Features;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::tools::handlers::PLAN_TOOL;
|
||||
use crate::tools::handlers::apply_patch::ApplyPatchToolType;
|
||||
@@ -33,26 +35,23 @@ pub(crate) struct ToolsConfig {
|
||||
|
||||
pub(crate) struct ToolsConfigParams<'a> {
|
||||
pub(crate) model_family: &'a ModelFamily,
|
||||
pub(crate) include_plan_tool: bool,
|
||||
pub(crate) include_apply_patch_tool: bool,
|
||||
pub(crate) include_web_search_request: bool,
|
||||
pub(crate) use_streamable_shell_tool: bool,
|
||||
pub(crate) include_view_image_tool: bool,
|
||||
pub(crate) experimental_unified_exec_tool: bool,
|
||||
pub(crate) features: &'a Features,
|
||||
}
|
||||
|
||||
impl ToolsConfig {
|
||||
pub fn new(params: &ToolsConfigParams) -> Self {
|
||||
let ToolsConfigParams {
|
||||
model_family,
|
||||
include_plan_tool,
|
||||
include_apply_patch_tool,
|
||||
include_web_search_request,
|
||||
use_streamable_shell_tool,
|
||||
include_view_image_tool,
|
||||
experimental_unified_exec_tool,
|
||||
features,
|
||||
} = params;
|
||||
let shell_type = if *use_streamable_shell_tool {
|
||||
let use_streamable_shell_tool = features.enabled(Feature::StreamableShell);
|
||||
let experimental_unified_exec_tool = features.enabled(Feature::UnifiedExec);
|
||||
let include_plan_tool = features.enabled(Feature::PlanTool);
|
||||
let include_apply_patch_tool = features.enabled(Feature::ApplyPatchFreeform);
|
||||
let include_web_search_request = features.enabled(Feature::WebSearchRequest);
|
||||
let include_view_image_tool = features.enabled(Feature::ViewImageTool);
|
||||
|
||||
let shell_type = if use_streamable_shell_tool {
|
||||
ConfigShellToolType::Streamable
|
||||
} else if model_family.uses_local_shell_tool {
|
||||
ConfigShellToolType::Local
|
||||
@@ -64,7 +63,7 @@ impl ToolsConfig {
|
||||
Some(ApplyPatchToolType::Freeform) => Some(ApplyPatchToolType::Freeform),
|
||||
Some(ApplyPatchToolType::Function) => Some(ApplyPatchToolType::Function),
|
||||
None => {
|
||||
if *include_apply_patch_tool {
|
||||
if include_apply_patch_tool {
|
||||
Some(ApplyPatchToolType::Freeform)
|
||||
} else {
|
||||
None
|
||||
@@ -74,11 +73,11 @@ impl ToolsConfig {
|
||||
|
||||
Self {
|
||||
shell_type,
|
||||
plan_tool: *include_plan_tool,
|
||||
plan_tool: include_plan_tool,
|
||||
apply_patch_tool_type,
|
||||
web_search_request: *include_web_search_request,
|
||||
include_view_image_tool: *include_view_image_tool,
|
||||
experimental_unified_exec_tool: *experimental_unified_exec_tool,
|
||||
web_search_request: include_web_search_request,
|
||||
include_view_image_tool,
|
||||
experimental_unified_exec_tool,
|
||||
experimental_supported_tools: model_family.experimental_supported_tools.clone(),
|
||||
}
|
||||
}
|
||||
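The change above derives each tool toggle from a Features set rather than from per-tool booleans. A rough self-contained analogue of that pattern (not the crate's actual types):

use std::collections::HashSet;

#[derive(Clone, Copy, Hash, PartialEq, Eq)]
enum Feature {
    PlanTool,
    WebSearchRequest,
    UnifiedExec,
}

struct Features(HashSet<Feature>);

impl Features {
    fn enabled(&self, feature: Feature) -> bool {
        self.0.contains(&feature)
    }
}

fn main() {
    // Placeholder feature set; which features default on is decided elsewhere.
    let features = Features(HashSet::from([Feature::PlanTool, Feature::UnifiedExec]));

    // Tool toggles fall out of the feature set, as in ToolsConfig::new above.
    let plan_tool = features.enabled(Feature::PlanTool);
    let web_search_request = features.enabled(Feature::WebSearchRequest);
    let unified_exec = features.enabled(Feature::UnifiedExec);
    println!("plan_tool={plan_tool} web_search_request={web_search_request} unified_exec={unified_exec}");
}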
@@ -320,6 +319,56 @@ fn create_test_sync_tool() -> ToolSpec {
|
||||
})
|
||||
}
|
||||
|
||||
fn create_grep_files_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"pattern".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Regular expression pattern to search for.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"include".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Optional glob that limits which files are searched (e.g. \"*.rs\" or \
|
||||
\"*.{ts,tsx}\")."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"path".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Directory or file path to search. Defaults to the session's working directory."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"limit".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Maximum number of file paths to return (defaults to 100).".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "grep_files".to_string(),
|
||||
description: "Finds files whose contents match the pattern and lists them by modification \
|
||||
time."
|
||||
.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["pattern".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_read_file_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
@@ -342,11 +391,72 @@ fn create_read_file_tool() -> ToolSpec {
|
||||
description: Some("The maximum number of lines to return.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"mode".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Optional mode selector: \"slice\" for simple ranges (default) or \"indentation\" \
|
||||
to expand around an anchor line."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
let mut indentation_properties = BTreeMap::new();
|
||||
indentation_properties.insert(
|
||||
"anchor_line".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Anchor line to center the indentation lookup on (defaults to offset).".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
indentation_properties.insert(
|
||||
"max_levels".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"How many parent indentation levels (smaller indents) to include.".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
indentation_properties.insert(
|
||||
"include_siblings".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some(
|
||||
"When true, include additional blocks that share the anchor indentation."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
indentation_properties.insert(
|
||||
"include_header".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some(
|
||||
"Include doc comments or attributes directly above the selected block.".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
indentation_properties.insert(
|
||||
"max_lines".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Hard cap on the number of lines returned when using indentation mode.".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"indentation".to_string(),
|
||||
JsonSchema::Object {
|
||||
properties: indentation_properties,
|
||||
required: None,
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "read_file".to_string(),
|
||||
description:
|
||||
"Reads a local file with 1-indexed line numbers and returns up to the requested number of lines."
|
||||
"Reads a local file with 1-indexed line numbers, supporting slice and indentation-aware block modes."
|
||||
.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -401,6 +511,107 @@ fn create_list_dir_tool() -> ToolSpec {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_list_mcp_resources_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"server".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Optional MCP server name. When omitted, lists resources from every configured server."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"cursor".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Opaque cursor returned by a previous list_mcp_resources call for the same server."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "list_mcp_resources".to_string(),
|
||||
description: "Lists resources provided by MCP servers. Resources allow servers to share data that provides context to language models, such as files, database schemas, or application-specific information. Prefer resources over web search when possible.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: None,
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_list_mcp_resource_templates_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"server".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Optional MCP server name. When omitted, lists resource templates from all configured servers."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"cursor".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Opaque cursor returned by a previous list_mcp_resource_templates call for the same server."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "list_mcp_resource_templates".to_string(),
|
||||
description: "Lists resource templates provided by MCP servers. Parameterized resource templates allow servers to share data that takes parameters and provides context to language models, such as files, database schemas, or application-specific information. Prefer resource templates over web search when possible.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: None,
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_read_mcp_resource_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"server".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"MCP server name exactly as configured. Must match the 'server' field returned by list_mcp_resources."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"uri".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Resource URI to read. Must be one of the URIs returned by list_mcp_resources."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "read_mcp_resource".to_string(),
|
||||
description:
|
||||
"Read a specific resource from an MCP server given the server name and resource URI."
|
||||
.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["server".to_string(), "uri".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
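Taken together, the three tool specs above accept arguments shaped roughly as follows (server names, cursors, and URIs are placeholders):

use serde_json::json;

fn main() {
    // list_mcp_resources: both fields optional; omit "server" to list every server.
    let list_all = json!({});
    let list_one = json!({ "server": "docs", "cursor": "cursor-1" });

    // list_mcp_resource_templates: same optional shape.
    let templates = json!({ "server": "docs" });

    // read_mcp_resource: "server" and "uri" are both required.
    let read = json!({ "server": "docs", "uri": "memo://example" });

    for args in [list_all, list_one, templates, read] {
        println!("{args}");
    }
}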
/// TODO(dylan): deprecate once we get rid of json tool
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub(crate) struct ApplyPatchToolArgs {
|
||||
@@ -610,8 +821,10 @@ pub(crate) fn build_specs(
|
||||
use crate::exec_command::create_write_stdin_tool_for_responses_api;
|
||||
use crate::tools::handlers::ApplyPatchHandler;
|
||||
use crate::tools::handlers::ExecStreamHandler;
|
||||
use crate::tools::handlers::GrepFilesHandler;
|
||||
use crate::tools::handlers::ListDirHandler;
|
||||
use crate::tools::handlers::McpHandler;
|
||||
use crate::tools::handlers::McpResourceHandler;
|
||||
use crate::tools::handlers::PlanHandler;
|
||||
use crate::tools::handlers::ReadFileHandler;
|
||||
use crate::tools::handlers::ShellHandler;
|
||||
@@ -629,6 +842,7 @@ pub(crate) fn build_specs(
|
||||
let apply_patch_handler = Arc::new(ApplyPatchHandler);
|
||||
let view_image_handler = Arc::new(ViewImageHandler);
|
||||
let mcp_handler = Arc::new(McpHandler);
|
||||
let mcp_resource_handler = Arc::new(McpResourceHandler);
|
||||
|
||||
if config.experimental_unified_exec_tool {
|
||||
builder.push_spec(create_unified_exec_tool());
|
||||
@@ -659,6 +873,13 @@ pub(crate) fn build_specs(
|
||||
builder.register_handler("container.exec", shell_handler.clone());
|
||||
builder.register_handler("local_shell", shell_handler);
|
||||
|
||||
builder.push_spec_with_parallel_support(create_list_mcp_resources_tool(), true);
|
||||
builder.push_spec_with_parallel_support(create_list_mcp_resource_templates_tool(), true);
|
||||
builder.push_spec_with_parallel_support(create_read_mcp_resource_tool(), true);
|
||||
builder.register_handler("list_mcp_resources", mcp_resource_handler.clone());
|
||||
builder.register_handler("list_mcp_resource_templates", mcp_resource_handler.clone());
|
||||
builder.register_handler("read_mcp_resource", mcp_resource_handler);
|
||||
|
||||
if config.plan_tool {
|
||||
builder.push_spec(PLAN_TOOL.clone());
|
||||
builder.register_handler("update_plan", plan_handler);
|
||||
@@ -678,8 +899,16 @@ pub(crate) fn build_specs(
|
||||
|
||||
if config
|
||||
.experimental_supported_tools
|
||||
.iter()
|
||||
.any(|tool| tool == "read_file")
|
||||
.contains(&"grep_files".to_string())
|
||||
{
|
||||
let grep_files_handler = Arc::new(GrepFilesHandler);
|
||||
builder.push_spec_with_parallel_support(create_grep_files_tool(), true);
|
||||
builder.register_handler("grep_files", grep_files_handler);
|
||||
}
|
||||
|
||||
if config
|
||||
.experimental_supported_tools
|
||||
.contains(&"read_file".to_string())
|
||||
{
|
||||
let read_file_handler = Arc::new(ReadFileHandler);
|
||||
builder.push_spec_with_parallel_support(create_read_file_tool(), true);
|
||||
@@ -698,8 +927,7 @@ pub(crate) fn build_specs(
|
||||
|
||||
if config
|
||||
.experimental_supported_tools
|
||||
.iter()
|
||||
.any(|tool| tool == "test_sync_tool")
|
||||
.contains(&"test_sync_tool".to_string())
|
||||
{
|
||||
let test_sync_handler = Arc::new(TestSyncHandler);
|
||||
builder.push_spec_with_parallel_support(create_test_sync_tool(), true);
|
||||
@@ -787,40 +1015,54 @@ mod tests {
|
||||
fn test_build_specs() {
|
||||
let model_family = find_family_for_model("codex-mini-latest")
|
||||
.expect("codex-mini-latest should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.enable(Feature::PlanTool);
|
||||
features.enable(Feature::WebSearchRequest);
|
||||
features.enable(Feature::UnifiedExec);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: true,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, Some(HashMap::new())).build();
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
&[
|
||||
"unified_exec",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"update_plan",
|
||||
"web_search",
|
||||
"view_image",
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_specs_default_shell() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.enable(Feature::PlanTool);
|
||||
features.enable(Feature::WebSearchRequest);
|
||||
features.enable(Feature::UnifiedExec);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: true,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, Some(HashMap::new())).build();
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
&[
|
||||
"unified_exec",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"update_plan",
|
||||
"web_search",
|
||||
"view_image",
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
@@ -829,34 +1071,30 @@ mod tests {
|
||||
fn test_parallel_support_flags() {
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("codex-mini-latest should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.disable(Feature::ViewImageTool);
|
||||
features.enable(Feature::UnifiedExec);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: false,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, None).build();
|
||||
|
||||
assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls);
|
||||
assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
|
||||
assert!(find_tool(&tools, "grep_files").supports_parallel_tool_calls);
|
||||
assert!(find_tool(&tools, "list_dir").supports_parallel_tool_calls);
|
||||
assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_test_model_family_includes_sync_tool() {
|
||||
let model_family = find_family_for_model("test-gpt-5-codex")
|
||||
.expect("test-gpt-5-codex should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.disable(Feature::ViewImageTool);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: false,
|
||||
experimental_unified_exec_tool: false,
|
||||
features: &features,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, None).build();
|
||||
|
||||
@@ -870,20 +1108,23 @@ mod tests {
|
||||
.iter()
|
||||
.any(|tool| tool_name(&tool.spec) == "read_file")
|
||||
);
|
||||
assert!(
|
||||
tools
|
||||
.iter()
|
||||
.any(|tool| tool_name(&tool.spec) == "grep_files")
|
||||
);
|
||||
assert!(tools.iter().any(|tool| tool_name(&tool.spec) == "list_dir"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_specs_mcp_tools() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.enable(Feature::UnifiedExec);
|
||||
features.enable(Feature::WebSearchRequest);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
let (tools, _) = build_specs(
|
||||
&config,
|
||||
@@ -928,15 +1169,19 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"test_server/do_something_cool",
|
||||
],
|
||||
);
|
||||
|
||||
let tool = find_tool(&tools, "test_server/do_something_cool");
|
||||
assert_eq!(
|
||||
tools[3].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
&tool.spec,
|
||||
&ToolSpec::Function(ResponsesApiTool {
|
||||
name: "test_server/do_something_cool".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
properties: BTreeMap::from([
|
||||
@@ -981,14 +1226,11 @@ mod tests {
|
||||
#[test]
|
||||
fn test_build_specs_mcp_tools_sorted_by_name() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.enable(Feature::UnifiedExec);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
|
||||
// Intentionally construct a map with keys that would sort alphabetically.
|
||||
@@ -1046,6 +1288,9 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"view_image",
|
||||
"test_server/cool",
|
||||
"test_server/do",
|
||||
@@ -1058,14 +1303,12 @@ mod tests {
|
||||
fn test_mcp_tool_property_missing_type_defaults_to_string() {
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.enable(Feature::UnifiedExec);
|
||||
features.enable(Feature::WebSearchRequest);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
|
||||
let (tools, _) = build_specs(
|
||||
@@ -1096,6 +1339,9 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
@@ -1104,7 +1350,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
tools[4].spec,
|
||||
tools[7].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/search".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1127,14 +1373,12 @@ mod tests {
|
||||
fn test_mcp_tool_integer_normalized_to_number() {
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.enable(Feature::UnifiedExec);
|
||||
features.enable(Feature::WebSearchRequest);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
|
||||
let (tools, _) = build_specs(
|
||||
@@ -1163,6 +1407,9 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
@@ -1170,7 +1417,7 @@ mod tests {
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[4].spec,
|
||||
tools[7].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/paginate".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1191,14 +1438,13 @@ mod tests {
|
||||
fn test_mcp_tool_array_without_items_gets_default_string_items() {
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.enable(Feature::UnifiedExec);
|
||||
features.enable(Feature::WebSearchRequest);
|
||||
features.enable(Feature::ApplyPatchFreeform);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: true,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
|
||||
let (tools, _) = build_specs(
|
||||
@@ -1227,6 +1473,9 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
@@ -1234,7 +1483,7 @@ mod tests {
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[4].spec,
|
||||
tools[7].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/tags".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1258,14 +1507,12 @@ mod tests {
|
||||
fn test_mcp_tool_anyof_defaults_to_string() {
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.enable(Feature::UnifiedExec);
|
||||
features.enable(Feature::WebSearchRequest);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
|
||||
let (tools, _) = build_specs(
|
||||
@@ -1294,6 +1541,9 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
@@ -1301,7 +1551,7 @@ mod tests {
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[4].spec,
|
||||
tools[7].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/value".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1337,14 +1587,12 @@ mod tests {
|
||||
fn test_get_openai_tools_mcp_tools_with_additional_properties_schema() {
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let mut features = Features::with_defaults();
|
||||
features.enable(Feature::UnifiedExec);
|
||||
features.enable(Feature::WebSearchRequest);
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
features: &features,
|
||||
});
|
||||
let (tools, _) = build_specs(
|
||||
&config,
|
||||
@@ -1398,6 +1646,9 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
@@ -1406,7 +1657,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
tools[4].spec,
|
||||
tools[7].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "test_server/do_something_cool".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
|
||||
@@ -110,11 +110,22 @@ impl ManagedUnifiedExecSession {
|
||||
let buffer_clone = Arc::clone(&output_buffer);
|
||||
let notify_clone = Arc::clone(&output_notify);
|
||||
let output_task = tokio::spawn(async move {
|
||||
while let Ok(chunk) = receiver.recv().await {
|
||||
let mut guard = buffer_clone.lock().await;
|
||||
guard.push_chunk(chunk);
|
||||
drop(guard);
|
||||
notify_clone.notify_waiters();
|
||||
loop {
|
||||
match receiver.recv().await {
|
||||
Ok(chunk) => {
|
||||
let mut guard = buffer_clone.lock().await;
|
||||
guard.push_chunk(chunk);
|
||||
drop(guard);
|
||||
notify_clone.notify_waiters();
|
||||
}
|
||||
// If we lag behind the broadcast buffer, skip missed
|
||||
// messages but keep the task alive to continue streaming.
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
|
||||
continue;
|
||||
}
|
||||
// When the sender closes, exit the task.
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ impl UserNotifier {
|
||||
pub(crate) enum UserNotification {
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
AgentTurnComplete {
|
||||
thread_id: String,
|
||||
turn_id: String,
|
||||
|
||||
/// Messages that the user sent to the agent to initiate the turn.
|
||||
@@ -67,6 +68,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_user_notification() -> Result<()> {
|
||||
let notification = UserNotification::AgentTurnComplete {
|
||||
thread_id: "b5f6c1c2-1111-2222-3333-444455556666".to_string(),
|
||||
turn_id: "12345".to_string(),
|
||||
input_messages: vec!["Rename `foo` to `bar` and update the callsites.".to_string()],
|
||||
last_assistant_message: Some(
|
||||
@@ -76,7 +78,7 @@ mod tests {
|
||||
let serialized = serde_json::to_string(¬ification)?;
|
||||
assert_eq!(
|
||||
serialized,
|
||||
r#"{"type":"agent-turn-complete","turn-id":"12345","input-messages":["Rename `foo` to `bar` and update the callsites."],"last-assistant-message":"Rename complete and verified `cargo build` succeeds."}"#
|
||||
r#"{"type":"agent-turn-complete","thread-id":"b5f6c1c2-1111-2222-3333-444455556666","turn-id":"12345","input-messages":["Rename `foo` to `bar` and update the callsites."],"last-assistant-message":"Rename complete and verified `cargo build` succeeds."}"#
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -79,6 +79,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
|
||||
config.model.as_str(),
|
||||
config.model_family.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
Some(AuthMode::ChatGPT),
|
||||
false,
|
||||
"test".to_string(),
|
||||
|
||||
@@ -78,6 +78,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
|
||||
config.model.as_str(),
|
||||
config.model_family.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
Some(AuthMode::ChatGPT),
|
||||
false,
|
||||
"test".to_string(),
|
||||
|
||||
@@ -10,8 +10,10 @@ path = "lib.rs"
|
||||
anyhow = { workspace = true }
|
||||
assert_cmd = { workspace = true }
|
||||
codex-core = { workspace = true }
|
||||
notify = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tokio = { workspace = true, features = ["time"] }
|
||||
walkdir = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
|
||||
@@ -164,6 +164,149 @@ pub fn sandbox_network_env_var() -> &'static str {
|
||||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||||
}
|
||||
|
||||
pub mod fs_wait {
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use notify::RecursiveMode;
|
||||
use notify::Watcher;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::mpsc;
|
||||
use std::sync::mpsc::RecvTimeoutError;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tokio::task;
|
||||
use walkdir::WalkDir;
|
||||
|
||||
pub async fn wait_for_path_exists(
|
||||
path: impl Into<PathBuf>,
|
||||
timeout: Duration,
|
||||
) -> Result<PathBuf> {
|
||||
let path = path.into();
|
||||
task::spawn_blocking(move || wait_for_path_exists_blocking(path, timeout)).await?
|
||||
}
|
||||
|
||||
pub async fn wait_for_matching_file(
|
||||
root: impl Into<PathBuf>,
|
||||
timeout: Duration,
|
||||
predicate: impl FnMut(&Path) -> bool + Send + 'static,
|
||||
) -> Result<PathBuf> {
|
||||
let root = root.into();
|
||||
task::spawn_blocking(move || {
|
||||
let mut predicate = predicate;
|
||||
blocking_find_matching_file(root, timeout, &mut predicate)
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
fn wait_for_path_exists_blocking(path: PathBuf, timeout: Duration) -> Result<PathBuf> {
|
||||
if path.exists() {
|
||||
return Ok(path);
|
||||
}
|
||||
|
||||
let watch_root = nearest_existing_ancestor(&path);
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut watcher = notify::recommended_watcher(move |res| {
|
||||
let _ = tx.send(res);
|
||||
})?;
|
||||
watcher.watch(&watch_root, RecursiveMode::Recursive)?;
|
||||
|
||||
let deadline = Instant::now() + timeout;
|
||||
loop {
|
||||
if path.exists() {
|
||||
return Ok(path.clone());
|
||||
}
|
||||
let now = Instant::now();
|
||||
if now >= deadline {
|
||||
break;
|
||||
}
|
||||
let remaining = deadline.saturating_duration_since(now);
|
||||
match rx.recv_timeout(remaining) {
|
||||
Ok(Ok(_event)) => {
|
||||
if path.exists() {
|
||||
return Ok(path.clone());
|
||||
}
|
||||
}
|
||||
Ok(Err(err)) => return Err(err.into()),
|
||||
Err(RecvTimeoutError::Timeout) => break,
|
||||
Err(RecvTimeoutError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
|
||||
if path.exists() {
|
||||
Ok(path)
|
||||
} else {
|
||||
Err(anyhow!("timed out waiting for {:?}", path))
|
||||
}
|
||||
}
|
||||
|
||||
fn blocking_find_matching_file(
|
||||
root: PathBuf,
|
||||
timeout: Duration,
|
||||
predicate: &mut impl FnMut(&Path) -> bool,
|
||||
) -> Result<PathBuf> {
|
||||
let root = wait_for_path_exists_blocking(root, timeout)?;
|
||||
|
||||
if let Some(found) = scan_for_match(&root, predicate) {
|
||||
return Ok(found);
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut watcher = notify::recommended_watcher(move |res| {
|
||||
let _ = tx.send(res);
|
||||
})?;
|
||||
watcher.watch(&root, RecursiveMode::Recursive)?;
|
||||
|
||||
let deadline = Instant::now() + timeout;
|
||||
|
||||
while Instant::now() < deadline {
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
match rx.recv_timeout(remaining) {
|
||||
Ok(Ok(_event)) => {
|
||||
if let Some(found) = scan_for_match(&root, predicate) {
|
||||
return Ok(found);
|
||||
}
|
||||
}
|
||||
Ok(Err(err)) => return Err(err.into()),
|
||||
Err(RecvTimeoutError::Timeout) => break,
|
||||
Err(RecvTimeoutError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(found) = scan_for_match(&root, predicate) {
|
||||
Ok(found)
|
||||
} else {
|
||||
Err(anyhow!("timed out waiting for matching file in {:?}", root))
|
||||
}
|
||||
}
|
||||
|
||||
fn scan_for_match(root: &Path, predicate: &mut impl FnMut(&Path) -> bool) -> Option<PathBuf> {
|
||||
for entry in WalkDir::new(root).into_iter().filter_map(Result::ok) {
|
||||
let path = entry.path();
|
||||
if !entry.file_type().is_file() {
|
||||
continue;
|
||||
}
|
||||
if predicate(path) {
|
||||
return Some(path.to_path_buf());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn nearest_existing_ancestor(path: &Path) -> PathBuf {
|
||||
let mut current = path;
|
||||
loop {
|
||||
if current.exists() {
|
||||
return current.to_path_buf();
|
||||
}
|
||||
match current.parent() {
|
||||
Some(parent) => current = parent,
|
||||
None => return PathBuf::from("."),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
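A minimal usage sketch (illustrative, not part of this diff) of the fs_wait helpers introduced above; the session-file integration test later in this change makes the same two calls against a real sessions directory.

use std::time::Duration;

use anyhow::Result;
use core_test_support::fs_wait;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn example_waits_for_rollout_file() -> Result<()> {
    let sessions_dir = std::env::temp_dir().join("codex-sessions-example");

    // Wait for the directory to appear, then for any .jsonl file inside it.
    fs_wait::wait_for_path_exists(&sessions_dir, Duration::from_secs(5)).await?;
    let rollout = fs_wait::wait_for_matching_file(&sessions_dir, Duration::from_secs(10), |p| {
        p.extension().and_then(|ext| ext.to_str()) == Some("jsonl")
    })
    .await?;

    println!("found rollout file at {}", rollout.display());
    Ok(())
}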
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! skip_if_sandbox {
|
||||
() => {{
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::mem::swap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::CodexAuth;
|
||||
@@ -39,6 +40,12 @@ impl TestCodexBuilder {
|
||||
let mut config = load_default_config_for_test(&home);
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.codex_linux_sandbox_exe = Some(PathBuf::from(
|
||||
assert_cmd::Command::cargo_bin("codex")?
|
||||
.get_program()
|
||||
.to_os_string(),
|
||||
));
|
||||
|
||||
let mut mutators = vec![];
|
||||
swap(&mut self.config_mutators, &mut mutators);
|
||||
|
||||
|
||||
codex-rs/core/tests/responses_headers.rs (new file, 103 lines)
@@ -0,0 +1,103 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_core::ContentItem;
|
||||
use codex_core::ModelClient;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::Prompt;
|
||||
use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::WireApi;
|
||||
use codex_otel::otel_event_manager::OtelEventManager;
|
||||
use codex_protocol::ConversationId;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::responses;
|
||||
use futures::StreamExt;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::matchers::header;
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_stream_includes_task_type_header() {
|
||||
core_test_support::skip_if_no_network!();
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let response_body = responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]);
|
||||
|
||||
let request_recorder = responses::mount_sse_once_match(
|
||||
&server,
|
||||
header("Codex-Task-Type", "standard"),
|
||||
response_body,
|
||||
)
|
||||
.await;
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "mock".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().expect("failed to create TempDir");
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider_id = provider.name.clone();
|
||||
config.model_provider = provider.clone();
|
||||
let effort = config.model_reasoning_effort;
|
||||
let summary = config.model_reasoning_summary;
|
||||
let config = Arc::new(config);
|
||||
|
||||
let conversation_id = ConversationId::new();
|
||||
|
||||
let otel_event_manager = OtelEventManager::new(
|
||||
conversation_id,
|
||||
config.model.as_str(),
|
||||
config.model_family.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
Some(AuthMode::ChatGPT),
|
||||
false,
|
||||
"test".to_string(),
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
otel_event_manager,
|
||||
provider,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
);
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}];
|
||||
|
||||
let mut stream = client.stream(&prompt).await.expect("stream failed");
|
||||
while let Some(event) = stream.next().await {
|
||||
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let request = request_recorder.single_request();
|
||||
assert_eq!(
|
||||
request.header("Codex-Task-Type").as_deref(),
|
||||
Some("standard")
|
||||
);
|
||||
}
|
||||
@@ -1,12 +1,11 @@
|
||||
use assert_cmd::Command as AssertCommand;
|
||||
use codex_core::RolloutRecorder;
|
||||
use codex_core::protocol::GitInfo;
|
||||
use core_test_support::fs_wait;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
use walkdir::WalkDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
@@ -211,12 +210,12 @@ async fn responses_api_stream_cli() {
|
||||
|
||||
/// End-to-end: create a session (writes rollout), verify the file, then resume and confirm append.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn integration_creates_and_checks_session_file() {
|
||||
async fn integration_creates_and_checks_session_file() -> anyhow::Result<()> {
|
||||
// Honor sandbox network restrictions for CI parity with the other tests.
|
||||
skip_if_no_network!();
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
// 1. Temp home so we read/write isolated session files.
|
||||
let home = TempDir::new().unwrap();
|
||||
let home = TempDir::new()?;
|
||||
|
||||
// 2. Unique marker we'll look for in the session log.
|
||||
let marker = format!("integration-test-{}", Uuid::new_v4());
|
||||
@@ -254,63 +253,20 @@ async fn integration_creates_and_checks_session_file() {
|
||||
|
||||
// Wait for sessions dir to appear.
|
||||
let sessions_dir = home.path().join("sessions");
|
||||
let dir_deadline = Instant::now() + Duration::from_secs(5);
|
||||
while !sessions_dir.exists() && Instant::now() < dir_deadline {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
assert!(sessions_dir.exists(), "sessions directory never appeared");
|
||||
fs_wait::wait_for_path_exists(&sessions_dir, Duration::from_secs(5)).await?;
|
||||
|
||||
// Find the session file that contains `marker`.
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
let mut matching_path: Option<std::path::PathBuf> = None;
|
||||
while Instant::now() < deadline && matching_path.is_none() {
|
||||
for entry in WalkDir::new(&sessions_dir) {
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if !entry.file_type().is_file() {
|
||||
continue;
|
||||
}
|
||||
if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
|
||||
continue;
|
||||
}
|
||||
let path = entry.path();
|
||||
let Ok(content) = std::fs::read_to_string(path) else {
|
||||
continue;
|
||||
};
|
||||
let mut lines = content.lines();
|
||||
if lines.next().is_none() {
|
||||
continue;
|
||||
}
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let item: serde_json::Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("response_item")
|
||||
&& let Some(payload) = item.get("payload")
|
||||
&& payload.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& let Some(c) = payload.get("content")
|
||||
&& c.to_string().contains(&marker)
|
||||
{
|
||||
matching_path = Some(path.to_path_buf());
|
||||
break;
|
||||
}
|
||||
}
|
||||
let marker_clone = marker.clone();
|
||||
let path = fs_wait::wait_for_matching_file(&sessions_dir, Duration::from_secs(10), move |p| {
|
||||
if p.extension().and_then(|ext| ext.to_str()) != Some("jsonl") {
|
||||
return false;
|
||||
}
|
||||
if matching_path.is_none() {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
}
|
||||
|
||||
let path = match matching_path {
|
||||
Some(p) => p,
|
||||
None => panic!("No session file containing the marker was found"),
|
||||
};
|
||||
let Ok(content) = std::fs::read_to_string(p) else {
|
||||
return false;
|
||||
};
|
||||
content.contains(&marker_clone)
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Basic sanity checks on location and metadata.
|
||||
let rel = match path.strip_prefix(&sessions_dir) {
|
||||
@@ -418,42 +374,25 @@ async fn integration_creates_and_checks_session_file() {
|
||||
assert!(output2.status.success(), "resume codex-cli run failed");
|
||||
|
||||
// Find the new session file containing the resumed marker.
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
let mut resumed_path: Option<std::path::PathBuf> = None;
|
||||
while Instant::now() < deadline && resumed_path.is_none() {
|
||||
for entry in WalkDir::new(&sessions_dir) {
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if !entry.file_type().is_file() {
|
||||
continue;
|
||||
let marker2_clone = marker2.clone();
|
||||
let resumed_path =
|
||||
fs_wait::wait_for_matching_file(&sessions_dir, Duration::from_secs(10), move |p| {
|
||||
if p.extension().and_then(|ext| ext.to_str()) != Some("jsonl") {
|
||||
return false;
|
||||
}
|
||||
if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
|
||||
continue;
|
||||
}
|
||||
let p = entry.path();
|
||||
let Ok(c) = std::fs::read_to_string(p) else {
|
||||
continue;
|
||||
};
|
||||
if c.contains(&marker2) {
|
||||
resumed_path = Some(p.to_path_buf());
|
||||
break;
|
||||
}
|
||||
}
|
||||
if resumed_path.is_none() {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
}
|
||||
std::fs::read_to_string(p)
|
||||
.map(|content| content.contains(&marker2_clone))
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let resumed_path = resumed_path.expect("No resumed session file found containing the marker2");
|
||||
// Resume should write to the existing log file.
|
||||
assert_eq!(
|
||||
resumed_path, path,
|
||||
"resume should create a new session file"
|
||||
);
|
||||
|
||||
let resumed_content = std::fs::read_to_string(&resumed_path).unwrap();
|
||||
let resumed_content = std::fs::read_to_string(&resumed_path)?;
|
||||
assert!(
|
||||
resumed_content.contains(&marker),
|
||||
"resumed file missing original marker"
|
||||
@@ -462,6 +401,7 @@ async fn integration_creates_and_checks_session_file() {
|
||||
resumed_content.contains(&marker2),
|
||||
"resumed file missing resumed marker"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Integration test to verify git info is collected and recorded in session files.
|
||||
|
||||
@@ -657,6 +657,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
||||
config.model.as_str(),
|
||||
config.model_family.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
Some(AuthMode::ChatGPT),
|
||||
false,
|
||||
"test".to_string(),
|
||||
|
||||
@@ -22,6 +22,7 @@ use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::mount_sse_once_match;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::sse_failed;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
// --- Test helpers -----------------------------------------------------------
|
||||
@@ -38,6 +39,8 @@ const SECOND_LARGE_REPLY: &str = "SECOND_LARGE_REPLY";
|
||||
const FIRST_AUTO_SUMMARY: &str = "FIRST_AUTO_SUMMARY";
|
||||
const SECOND_AUTO_SUMMARY: &str = "SECOND_AUTO_SUMMARY";
|
||||
const FINAL_REPLY: &str = "FINAL_REPLY";
|
||||
const CONTEXT_LIMIT_MESSAGE: &str =
|
||||
"Your input exceeds the context window of this model. Please adjust your input and try again.";
|
||||
const DUMMY_FUNCTION_NAME: &str = "unsupported_tool";
|
||||
const DUMMY_CALL_ID: &str = "call-multi-auto";
|
||||
|
||||
@@ -622,6 +625,130 @@ async fn auto_compact_stops_after_failed_attempt() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn manual_compact_retries_after_context_window_error() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let user_turn = sse(vec![
|
||||
ev_assistant_message("m1", FIRST_REPLY),
|
||||
ev_completed("r1"),
|
||||
]);
|
||||
let compact_failed = sse_failed(
|
||||
"resp-fail",
|
||||
"context_length_exceeded",
|
||||
CONTEXT_LIMIT_MESSAGE,
|
||||
);
|
||||
let compact_succeeds = sse(vec![
|
||||
ev_assistant_message("m2", SUMMARY_TEXT),
|
||||
ev_completed("r2"),
|
||||
]);
|
||||
|
||||
let request_log = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
user_turn.clone(),
|
||||
compact_failed.clone(),
|
||||
compact_succeeds.clone(),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&home);
|
||||
config.model_provider = model_provider;
|
||||
config.model_auto_compact_token_limit = Some(200_000);
|
||||
let codex = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"))
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.unwrap()
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "first turn".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
codex.submit(Op::Compact).await.unwrap();
|
||||
|
||||
let EventMsg::BackgroundEvent(event) =
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::BackgroundEvent(_))).await
|
||||
else {
|
||||
panic!("expected background event after compact retry");
|
||||
};
|
||||
assert!(
|
||||
event.message.contains("Trimmed 1 older conversation item"),
|
||||
"background event should mention trimmed item count: {}",
|
||||
event.message
|
||||
);
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = request_log.requests();
|
||||
assert_eq!(
|
||||
requests.len(),
|
||||
3,
|
||||
"expected user turn and two compact attempts"
|
||||
);
|
||||
|
||||
let compact_attempt = requests[1].body_json();
|
||||
let retry_attempt = requests[2].body_json();
|
||||
|
||||
let compact_input = compact_attempt["input"]
|
||||
.as_array()
|
||||
.unwrap_or_else(|| panic!("compact attempt missing input array: {compact_attempt}"));
|
||||
let retry_input = retry_attempt["input"]
|
||||
.as_array()
|
||||
.unwrap_or_else(|| panic!("retry attempt missing input array: {retry_attempt}"));
|
||||
assert_eq!(
|
||||
compact_input
|
||||
.last()
|
||||
.and_then(|item| item.get("content"))
|
||||
.and_then(|v| v.as_array())
|
||||
.and_then(|items| items.first())
|
||||
.and_then(|entry| entry.get("text"))
|
||||
.and_then(|text| text.as_str()),
|
||||
Some(SUMMARIZATION_PROMPT),
|
||||
"compact attempt should include summarization prompt"
|
||||
);
|
||||
assert_eq!(
|
||||
retry_input
|
||||
.last()
|
||||
.and_then(|item| item.get("content"))
|
||||
.and_then(|v| v.as_array())
|
||||
.and_then(|items| items.first())
|
||||
.and_then(|entry| entry.get("text"))
|
||||
.and_then(|text| text.as_str()),
|
||||
Some(SUMMARIZATION_PROMPT),
|
||||
"retry attempt should include summarization prompt"
|
||||
);
|
||||
assert_eq!(
|
||||
retry_input.len(),
|
||||
compact_input.len().saturating_sub(1),
|
||||
"retry should drop exactly one history item (before {} vs after {})",
|
||||
compact_input.len(),
|
||||
retry_input.len()
|
||||
);
|
||||
if let (Some(first_before), Some(first_after)) = (compact_input.first(), retry_input.first()) {
|
||||
assert_ne!(
|
||||
first_before, first_after,
|
||||
"retry should drop the oldest conversation item"
|
||||
);
|
||||
} else {
|
||||
panic!("expected non-empty compact inputs");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_events() {
|
||||
skip_if_no_network!();
|
||||
|
||||
codex-rs/core/tests/suite/grep_files.rs (new file, 237 lines)
@@ -0,0 +1,237 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
use std::path::Path;
|
||||
use std::process::Command as StdCommand;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
const MODEL_WITH_TOOL: &str = "test-gpt-5-codex";
|
||||
|
||||
fn ripgrep_available() -> bool {
|
||||
StdCommand::new("rg")
|
||||
.arg("--version")
|
||||
.output()
|
||||
.map(|output| output.status.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
macro_rules! skip_if_ripgrep_missing {
|
||||
($ret:expr $(,)?) => {{
|
||||
if !ripgrep_available() {
|
||||
eprintln!("rg not available in PATH; skipping test");
|
||||
return $ret;
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn grep_files_tool_collects_matches() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_ripgrep_missing!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let test = build_test_codex(&server).await?;
|
||||
|
||||
let search_dir = test.cwd.path().join("src");
|
||||
std::fs::create_dir_all(&search_dir)?;
|
||||
let alpha = search_dir.join("alpha.rs");
|
||||
let beta = search_dir.join("beta.rs");
|
||||
let gamma = search_dir.join("gamma.txt");
|
||||
std::fs::write(&alpha, "alpha needle\n")?;
|
||||
std::fs::write(&beta, "beta needle\n")?;
|
||||
std::fs::write(&gamma, "needle in text but excluded\n")?;
|
||||
|
||||
let call_id = "grep-files-collect";
|
||||
let arguments = serde_json::json!({
|
||||
"pattern": "needle",
|
||||
"path": search_dir.to_string_lossy(),
|
||||
"include": "*.rs",
|
||||
})
|
||||
.to_string();
|
||||
|
||||
mount_tool_sequence(&server, call_id, &arguments, "grep_files").await;
|
||||
submit_turn(&test, "please find uses of needle").await?;
|
||||
|
||||
let bodies = recorded_bodies(&server).await?;
|
||||
let tool_output = find_tool_output(&bodies, call_id).expect("tool output present");
|
||||
let payload = tool_output.get("output").expect("output field present");
|
||||
let (content_opt, success_opt) = extract_content_and_success(payload);
|
||||
let content = content_opt.expect("content present");
|
||||
let success = success_opt.unwrap_or(true);
|
||||
assert!(success, "expected success for matches, got {payload:?}");
|
||||
|
||||
let entries = collect_file_names(content);
|
||||
assert_eq!(entries.len(), 2, "content: {content}");
|
||||
assert!(
|
||||
entries.contains("alpha.rs"),
|
||||
"missing alpha.rs in {entries:?}"
|
||||
);
|
||||
assert!(
|
||||
entries.contains("beta.rs"),
|
||||
"missing beta.rs in {entries:?}"
|
||||
);
|
||||
assert!(
|
||||
!entries.contains("gamma.txt"),
|
||||
"txt file should be filtered out: {entries:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn grep_files_tool_reports_empty_results() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_ripgrep_missing!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let test = build_test_codex(&server).await?;
|
||||
|
||||
let search_dir = test.cwd.path().join("logs");
|
||||
std::fs::create_dir_all(&search_dir)?;
|
||||
std::fs::write(search_dir.join("output.txt"), "no hits here")?;
|
||||
|
||||
let call_id = "grep-files-empty";
|
||||
let arguments = serde_json::json!({
|
||||
"pattern": "needle",
|
||||
"path": search_dir.to_string_lossy(),
|
||||
"limit": 5,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
mount_tool_sequence(&server, call_id, &arguments, "grep_files").await;
|
||||
submit_turn(&test, "search again").await?;
|
||||
|
||||
let bodies = recorded_bodies(&server).await?;
|
||||
let tool_output = find_tool_output(&bodies, call_id).expect("tool output present");
|
||||
let payload = tool_output.get("output").expect("output field present");
|
||||
let (content_opt, success_opt) = extract_content_and_success(payload);
|
||||
let content = content_opt.expect("content present");
|
||||
if let Some(success) = success_opt {
|
||||
assert!(!success, "expected success=false payload: {payload:?}");
|
||||
}
|
||||
assert_eq!(content, "No matches found.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
async fn build_test_codex(server: &wiremock::MockServer) -> Result<TestCodex> {
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = MODEL_WITH_TOOL.to_string();
|
||||
config.model_family =
|
||||
find_family_for_model(MODEL_WITH_TOOL).expect("model family for test model");
|
||||
});
|
||||
builder.build(server).await
|
||||
}
|
||||
|
||||
async fn submit_turn(test: &TestCodex, prompt: &str) -> Result<()> {
|
||||
let session_model = test.session_configured.model.clone();
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: prompt.into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: test.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TaskComplete(_))
|
||||
})
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn mount_tool_sequence(
|
||||
server: &wiremock::MockServer,
|
||||
call_id: &str,
|
||||
arguments: &str,
|
||||
tool_name: &str,
|
||||
) {
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, tool_name, arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(server, any(), second_response).await;
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
async fn recorded_bodies(server: &wiremock::MockServer) -> Result<Vec<Value>> {
|
||||
let requests = server.received_requests().await.expect("requests recorded");
|
||||
Ok(requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn find_tool_output<'a>(requests: &'a [Value], call_id: &str) -> Option<&'a Value> {
|
||||
requests.iter().find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn collect_file_names(content: &str) -> HashSet<String> {
|
||||
content
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
if line.trim().is_empty() {
|
||||
return None;
|
||||
}
|
||||
Path::new(line)
|
||||
.file_name()
|
||||
.map(|name| name.to_string_lossy().into_owned())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_content_and_success(value: &Value) -> (Option<&str>, Option<bool>) {
|
||||
match value {
|
||||
Value::String(text) => (Some(text.as_str()), None),
|
||||
Value::Object(obj) => (
|
||||
obj.get("content").and_then(Value::as_str),
|
||||
obj.get("success").and_then(Value::as_bool),
|
||||
),
|
||||
_ => (None, None),
|
||||
}
|
||||
}
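For context, a small illustrative sketch (an assumption, not part of this diff) of the two payload shapes that extract_content_and_success above is written to accept: a bare string, or an object carrying content plus an optional success flag.

use serde_json::json;

fn example_payload_shapes() {
    // Plain-string payload: content only, success unknown.
    let as_string = json!("No matches found.");
    // Object payload: content plus an explicit success flag.
    let as_object = json!({ "content": "src/alpha.rs\nsrc/beta.rs", "success": true });
    let _ = (as_string, as_object);
}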
|
||||
@@ -9,6 +9,7 @@ mod compact_resume_fork;
|
||||
mod exec;
|
||||
mod exec_stream_events;
|
||||
mod fork_conversation;
|
||||
mod grep_files;
|
||||
mod json_result;
|
||||
mod list_dir;
|
||||
mod live_cli;
|
||||
|
||||
@@ -4,6 +4,7 @@ use codex_core::CodexAuth;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
@@ -56,12 +57,12 @@ async fn collect_tool_identifiers_for_model(model: &str) -> Vec<String> {
|
||||
config.model = model.to_string();
|
||||
config.model_family =
|
||||
find_family_for_model(model).unwrap_or_else(|| panic!("unknown model family for {model}"));
|
||||
config.include_plan_tool = false;
|
||||
config.include_apply_patch_tool = false;
|
||||
config.include_view_image_tool = false;
|
||||
config.tools_web_search_request = false;
|
||||
config.use_experimental_streamable_shell_tool = false;
|
||||
config.use_experimental_unified_exec_tool = false;
|
||||
config.features.disable(Feature::PlanTool);
|
||||
config.features.disable(Feature::ApplyPatchFreeform);
|
||||
config.features.disable(Feature::ViewImageTool);
|
||||
config.features.disable(Feature::WebSearchRequest);
|
||||
config.features.disable(Feature::StreamableShell);
|
||||
config.features.disable(Feature::UnifiedExec);
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
@@ -93,21 +94,37 @@ async fn model_selects_expected_tools() {
|
||||
let codex_tools = collect_tool_identifiers_for_model("codex-mini-latest").await;
|
||||
assert_eq!(
|
||||
codex_tools,
|
||||
vec!["local_shell".to_string()],
|
||||
vec![
|
||||
"local_shell".to_string(),
|
||||
"list_mcp_resources".to_string(),
|
||||
"list_mcp_resource_templates".to_string(),
|
||||
"read_mcp_resource".to_string()
|
||||
],
|
||||
"codex-mini-latest should expose the local shell tool",
|
||||
);
|
||||
|
||||
let o3_tools = collect_tool_identifiers_for_model("o3").await;
|
||||
assert_eq!(
|
||||
o3_tools,
|
||||
vec!["shell".to_string()],
|
||||
vec![
|
||||
"shell".to_string(),
|
||||
"list_mcp_resources".to_string(),
|
||||
"list_mcp_resource_templates".to_string(),
|
||||
"read_mcp_resource".to_string()
|
||||
],
|
||||
"o3 should expose the generic shell tool",
|
||||
);
|
||||
|
||||
let gpt5_codex_tools = collect_tool_identifiers_for_model("gpt-5-codex").await;
|
||||
assert_eq!(
|
||||
gpt5_codex_tools,
|
||||
vec!["shell".to_string(), "apply_patch".to_string(),],
|
||||
vec![
|
||||
"shell".to_string(),
|
||||
"list_mcp_resources".to_string(),
|
||||
"list_mcp_resource_templates".to_string(),
|
||||
"read_mcp_resource".to_string(),
|
||||
"apply_patch".to_string()
|
||||
],
|
||||
"gpt-5-codex should expose the apply_patch tool",
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::config::OPENAI_DEFAULT_MODEL;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
@@ -99,10 +100,10 @@ async fn codex_mini_latest_tools() {
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.user_instructions = Some("be consistent and helpful".to_string());
|
||||
config.features.disable(Feature::ApplyPatchFreeform);
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
config.include_apply_patch_tool = false;
|
||||
config.model = "codex-mini-latest".to_string();
|
||||
config.model_family = find_family_for_model("codex-mini-latest").unwrap();
|
||||
|
||||
@@ -185,7 +186,7 @@ async fn prompt_tools_are_consistent_across_requests() {
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.user_instructions = Some("be consistent and helpful".to_string());
|
||||
config.include_plan_tool = true;
|
||||
config.features.enable(Feature::PlanTool);
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
@@ -222,10 +223,28 @@ async fn prompt_tools_are_consistent_across_requests() {
|
||||
// our internal implementation is responsible for keeping tools in sync
|
||||
// with the OpenAI schema, so we just verify the tool presence here
|
||||
let tools_by_model: HashMap<&'static str, Vec<&'static str>> = HashMap::from([
|
||||
("gpt-5", vec!["shell", "update_plan", "view_image"]),
|
||||
(
|
||||
"gpt-5",
|
||||
vec![
|
||||
"shell",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"update_plan",
|
||||
"view_image",
|
||||
],
|
||||
),
|
||||
(
|
||||
"gpt-5-codex",
|
||||
vec!["shell", "update_plan", "apply_patch", "view_image"],
|
||||
vec![
|
||||
"shell",
|
||||
"list_mcp_resources",
|
||||
"list_mcp_resource_templates",
|
||||
"read_mcp_resource",
|
||||
"update_plan",
|
||||
"apply_patch",
|
||||
"view_image",
|
||||
],
|
||||
),
|
||||
]);
|
||||
let expected_tools_names = tools_by_model
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsStr;
|
||||
use std::ffi::OsString;
|
||||
use std::fs;
|
||||
use std::net::TcpListener;
|
||||
@@ -9,6 +10,7 @@ use std::time::UNIX_EPOCH;
|
||||
|
||||
use codex_core::config_types::McpServerConfig;
|
||||
use codex_core::config_types::McpServerTransportConfig;
|
||||
use codex_core::features::Feature;
|
||||
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
@@ -34,6 +36,7 @@ use tokio::time::sleep;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
#[serial(mcp_test_value)]
|
||||
async fn stdio_server_round_trip() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
@@ -74,7 +77,7 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
config.use_experimental_use_rmcp_client = true;
|
||||
config.features.enable(Feature::RmcpClient);
|
||||
config.mcp_servers.insert(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
@@ -85,7 +88,10 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
|
||||
"MCP_TEST_VALUE".to_string(),
|
||||
expected_env_value.to_string(),
|
||||
)])),
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
@@ -104,7 +110,143 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
|
||||
final_output_json_schema: None,
|
||||
cwd: fixture.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let begin_event = wait_for_event_with_timeout(
|
||||
&fixture.codex,
|
||||
|ev| matches!(ev, EventMsg::McpToolCallBegin(_)),
|
||||
Duration::from_secs(10),
|
||||
)
|
||||
.await;
|
||||
|
||||
let EventMsg::McpToolCallBegin(begin) = begin_event else {
|
||||
unreachable!("event guard guarantees McpToolCallBegin");
|
||||
};
|
||||
assert_eq!(begin.invocation.server, server_name);
|
||||
assert_eq!(begin.invocation.tool, "echo");
|
||||
|
||||
let end_event = wait_for_event(&fixture.codex, |ev| {
|
||||
matches!(ev, EventMsg::McpToolCallEnd(_))
|
||||
})
|
||||
.await;
|
||||
let EventMsg::McpToolCallEnd(end) = end_event else {
|
||||
unreachable!("event guard guarantees McpToolCallEnd");
|
||||
};
|
||||
|
||||
let result = end
|
||||
.result
|
||||
.as_ref()
|
||||
.expect("rmcp echo tool should return success");
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
assert!(
|
||||
result.content.is_empty(),
|
||||
"content should default to an empty array"
|
||||
);
|
||||
|
||||
let structured = result
|
||||
.structured_content
|
||||
.as_ref()
|
||||
.expect("structured content");
|
||||
let Value::Object(map) = structured else {
|
||||
panic!("structured content should be an object: {structured:?}");
|
||||
};
|
||||
let echo_value = map
|
||||
.get("echo")
|
||||
.and_then(Value::as_str)
|
||||
.expect("echo payload present");
|
||||
assert_eq!(echo_value, "ECHOING: ping");
|
||||
let env_value = map
|
||||
.get("env")
|
||||
.and_then(Value::as_str)
|
||||
.expect("env snapshot inserted");
|
||||
assert_eq!(env_value, expected_env_value);
|
||||
|
||||
wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
server.verify().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
#[serial(mcp_test_value)]
|
||||
async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
|
||||
let call_id = "call-1234";
|
||||
let server_name = "rmcp_whitelist";
|
||||
let tool_name = format!("{server_name}__echo");
|
||||
|
||||
mount_sse_once_match(
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
mount_sse_once_match(
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
responses::ev_assistant_message("msg-1", "rmcp echo tool completed successfully."),
|
||||
responses::ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected_env_value = "propagated-env-from-whitelist";
|
||||
let _guard = EnvVarGuard::set("MCP_TEST_VALUE", OsStr::new(expected_env_value));
|
||||
let rmcp_test_server_bin = CargoBuild::new()
|
||||
.package("codex-rmcp-client")
|
||||
.bin("test_stdio_server")
|
||||
.run()?
|
||||
.path()
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
config.features.enable(Feature::RmcpClient);
|
||||
config.mcp_servers.insert(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: rmcp_test_server_bin,
|
||||
args: Vec::new(),
|
||||
env: None,
|
||||
env_vars: vec!["MCP_TEST_VALUE".to_string()],
|
||||
cwd: None,
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
fixture
|
||||
.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "call the rmcp echo tool".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: fixture.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
@@ -226,14 +368,17 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
config.use_experimental_use_rmcp_client = true;
|
||||
config.features.enable(Feature::RmcpClient);
|
||||
config.mcp_servers.insert(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token: None,
|
||||
bearer_token_env_var: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
@@ -252,7 +397,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
final_output_json_schema: None,
|
||||
cwd: fixture.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
@@ -406,14 +551,17 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
config.use_experimental_use_rmcp_client = true;
|
||||
config.features.enable(Feature::RmcpClient);
|
||||
config.mcp_servers.insert(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token: None,
|
||||
bearer_token_env_var: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
@@ -432,7 +580,7 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
|
||||
final_output_json_schema: None,
|
||||
cwd: fixture.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
@@ -9,9 +10,12 @@ use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::assert_regex_match;
|
||||
use core_test_support::responses::ev_apply_patch_function_call;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_custom_tool_call;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_local_shell_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
@@ -20,8 +24,11 @@ use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use regex_lite::Regex;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::fs;
|
||||
|
||||
async fn submit_turn(test: &TestCodex, prompt: &str, sandbox_policy: SandboxPolicy) -> Result<()> {
|
||||
let session_model = test.session_configured.model.clone();
|
||||
@@ -71,13 +78,28 @@ fn find_function_call_output<'a>(bodies: &'a [Value], call_id: &str) -> Option<&
|
||||
None
|
||||
}
|
||||
|
||||
fn find_custom_tool_call_output<'a>(bodies: &'a [Value], call_id: &str) -> Option<&'a Value> {
|
||||
for body in bodies {
|
||||
if let Some(items) = body.get("input").and_then(Value::as_array) {
|
||||
for item in items {
|
||||
if item.get("type").and_then(Value::as_str) == Some("custom_tool_call_output")
|
||||
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
|
||||
{
|
||||
return Some(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_output_stays_json_without_freeform_apply_patch() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = false;
|
||||
config.features.disable(Feature::ApplyPatchFreeform);
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a model family");
|
||||
});
|
||||
@@ -119,7 +141,12 @@ async fn shell_output_stays_json_without_freeform_apply_patch() -> Result<()> {
|
||||
.and_then(Value::as_str)
|
||||
.expect("shell output string");
|
||||
|
||||
let parsed: Value = serde_json::from_str(output)?;
|
||||
let mut parsed: Value = serde_json::from_str(output)?;
|
||||
if let Some(metadata) = parsed.get_mut("metadata").and_then(Value::as_object_mut) {
|
||||
// duration_seconds is non-deterministic; remove it for deep equality
|
||||
let _ = metadata.remove("duration_seconds");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
parsed
|
||||
.get("metadata")
|
||||
@@ -143,7 +170,7 @@ async fn shell_output_is_structured_with_freeform_apply_patch() -> Result<()> {
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
config.features.enable(Feature::ApplyPatchFreeform);
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
@@ -198,6 +225,83 @@ freeform shell
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_output_for_freeform_tool_records_duration() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
let sleep_cmd = vec!["/bin/bash", "-c", "sleep 1"];
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
let sleep_cmd = vec!["/bin/bash", "-c", "sleep 1"];
|
||||
|
||||
#[cfg(windows)]
|
||||
let sleep_cmd = "timeout 1";
|
||||
|
||||
let call_id = "shell-structured";
|
||||
let args = json!({
|
||||
"command": sleep_cmd,
|
||||
"timeout_ms": 2_000,
|
||||
});
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"run the structured shell command",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item =
|
||||
find_function_call_output(&bodies, call_id).expect("structured output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("structured output string");
|
||||
|
||||
let expected_pattern = r#"(?s)^Exit code: 0
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
$"#;
|
||||
assert_regex_match(expected_pattern, output);
|
||||
|
||||
let wall_time_regex = Regex::new(r"(?m)^Wall (?:time|Clock): ([0-9]+(?:\.[0-9]+)?) seconds$")
|
||||
.expect("compile wall time regex");
|
||||
let wall_time_seconds = wall_time_regex
|
||||
.captures(output)
|
||||
.and_then(|caps| caps.get(1))
|
||||
.and_then(|value| value.as_str().parse::<f32>().ok())
|
||||
.expect("expected structured shell output to contain wall time seconds");
|
||||
assert!(
|
||||
wall_time_seconds > 0.5,
|
||||
"expected wall time to be greater than zero seconds, got {wall_time_seconds}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_output_reserializes_truncated_content() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -213,7 +317,7 @@ async fn shell_output_reserializes_truncated_content() -> Result<()> {
|
||||
let call_id = "shell-truncated";
|
||||
let args = json!({
|
||||
"command": ["/bin/sh", "-c", "seq 1 400"],
|
||||
"timeout_ms": 1_000,
|
||||
"timeout_ms": 5_000,
|
||||
});
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
@@ -275,3 +379,428 @@ $"#;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn apply_patch_custom_tool_output_is_structured() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "apply-patch-structured";
|
||||
let file_name = "structured.txt";
|
||||
let patch = format!(
|
||||
r#"*** Begin Patch
|
||||
*** Add File: {file_name}
|
||||
+from custom tool
|
||||
*** End Patch
|
||||
"#
|
||||
);
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_custom_tool_call(call_id, "apply_patch", &patch),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"apply the patch via custom tool",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item =
|
||||
find_custom_tool_call_output(&bodies, call_id).expect("apply_patch output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("apply_patch output string");
|
||||
|
||||
let expected_pattern = format!(
|
||||
r"(?s)^Exit code: 0
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
Success. Updated the following files:
|
||||
A {file_name}
|
||||
?$"
|
||||
);
|
||||
assert_regex_match(&expected_pattern, output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn apply_patch_custom_tool_call_creates_file() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "apply-patch-add-file";
|
||||
let file_name = "custom_tool_apply_patch.txt";
|
||||
let patch = format!(
|
||||
"*** Begin Patch\n*** Add File: {file_name}\n+custom tool content\n*** End Patch\n"
|
||||
);
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_custom_tool_call(call_id, "apply_patch", &patch),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "apply_patch done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"apply the patch via custom tool to create a file",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item =
|
||||
find_custom_tool_call_output(&bodies, call_id).expect("apply_patch output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("apply_patch output string");
|
||||
|
||||
let expected_pattern = format!(
|
||||
r"(?s)^Exit code: 0
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
Success. Updated the following files:
|
||||
A {file_name}
|
||||
?$"
|
||||
);
|
||||
assert_regex_match(&expected_pattern, output);
|
||||
|
||||
let new_file_path = test.cwd.path().join(file_name);
|
||||
let created_contents = fs::read_to_string(&new_file_path)?;
|
||||
assert_eq!(
|
||||
created_contents, "custom tool content\n",
|
||||
"expected file contents for {file_name}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn apply_patch_custom_tool_call_updates_existing_file() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "apply-patch-update-file";
|
||||
let file_name = "custom_tool_apply_patch_existing.txt";
|
||||
let file_path = test.cwd.path().join(file_name);
|
||||
fs::write(&file_path, "before\n")?;
|
||||
let patch = format!(
|
||||
"*** Begin Patch\n*** Update File: {file_name}\n@@\n-before\n+after\n*** End Patch\n"
|
||||
);
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_custom_tool_call(call_id, "apply_patch", &patch),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "apply_patch update done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"apply the patch via custom tool to update a file",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item =
|
||||
find_custom_tool_call_output(&bodies, call_id).expect("apply_patch output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("apply_patch output string");
|
||||
|
||||
let expected_pattern = format!(
|
||||
r"(?s)^Exit code: 0
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
Success. Updated the following files:
|
||||
M {file_name}
|
||||
?$"
|
||||
);
|
||||
assert_regex_match(&expected_pattern, output);
|
||||
|
||||
let updated_contents = fs::read_to_string(file_path)?;
|
||||
assert_eq!(updated_contents, "after\n", "expected updated file content");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn apply_patch_custom_tool_call_reports_failure_output() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "apply-patch-failure";
|
||||
let missing_file = "missing_custom_tool_apply_patch.txt";
|
||||
let patch = format!(
|
||||
"*** Begin Patch\n*** Update File: {missing_file}\n@@\n-before\n+after\n*** End Patch\n"
|
||||
);
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_custom_tool_call(call_id, "apply_patch", &patch),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "apply_patch failure done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"attempt a failing apply_patch via custom tool",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item =
|
||||
find_custom_tool_call_output(&bodies, call_id).expect("apply_patch output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("apply_patch output string");
|
||||
|
||||
let expected_output = format!(
|
||||
"apply_patch verification failed: Failed to read file to update {}/{missing_file}: No such file or directory (os error 2)",
|
||||
test.cwd.path().to_string_lossy()
|
||||
);
|
||||
assert_eq!(output, expected_output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn apply_patch_function_call_output_is_structured() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "apply-patch-function";
|
||||
let file_name = "function_apply_patch.txt";
|
||||
let patch =
|
||||
format!("*** Begin Patch\n*** Add File: {file_name}\n+via function call\n*** End Patch\n");
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_apply_patch_function_call(call_id, &patch),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "apply_patch function done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"apply the patch via function-call apply_patch",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item =
|
||||
find_function_call_output(&bodies, call_id).expect("apply_patch function output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("apply_patch output string");
|
||||
|
||||
let expected_pattern = format!(
|
||||
r"(?s)^Exit code: 0
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
Success. Updated the following files:
|
||||
A {file_name}
|
||||
?$"
|
||||
);
|
||||
assert_regex_match(&expected_pattern, output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_output_is_structured_for_nonzero_exit() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = "gpt-5-codex".to_string();
|
||||
config.model_family =
|
||||
find_family_for_model("gpt-5-codex").expect("gpt-5-codex is a model family");
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "shell-nonzero-exit";
|
||||
let args = json!({
|
||||
"command": ["/bin/sh", "-c", "exit 42"],
|
||||
"timeout_ms": 1_000,
|
||||
});
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "shell failure handled"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"run the failing shell command",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item = find_function_call_output(&bodies, call_id).expect("shell output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("shell output string");
|
||||
|
||||
let expected_pattern = r"(?s)^Exit code: 42
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
?$";
|
||||
assert_regex_match(expected_pattern, output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn local_shell_call_output_is_structured() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = "gpt-5-codex".to_string();
|
||||
config.model_family =
|
||||
find_family_for_model("gpt-5-codex").expect("gpt-5-codex is a model family");
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "local-shell-call";
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_local_shell_call(call_id, "completed", vec!["/bin/echo", "local shell"]),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "local shell done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"run the local shell command",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item =
|
||||
find_function_call_output(&bodies, call_id).expect("local shell output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("local shell output string");
|
||||
|
||||
let expected_pattern = r"(?s)^Exit code: 0
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
local shell
|
||||
?$";
|
||||
assert_regex_match(expected_pattern, output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use std::fs;
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
@@ -104,7 +107,7 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_plan_tool = true;
|
||||
config.features.enable(Feature::PlanTool);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
@@ -191,7 +194,7 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_plan_tool = true;
|
||||
config.features.enable(Feature::PlanTool);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
@@ -285,7 +288,7 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
config.features.enable(Feature::ApplyPatchFreeform);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
@@ -294,15 +297,19 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
|
||||
..
|
||||
} = builder.build(&server).await?;
|
||||
|
||||
let file_name = "notes.txt";
|
||||
let file_path = cwd.path().join(file_name);
|
||||
let call_id = "apply-patch-call";
|
||||
-let patch_content = r#"*** Begin Patch
-*** Add File: notes.txt
+let patch_content = format!(
+    r#"*** Begin Patch
+*** Add File: {file_name}
 +Tool harness apply patch
-*** End Patch"#;
+*** End Patch"#
+);
|
||||
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
-ev_apply_patch_function_call(call_id, patch_content),
+ev_apply_patch_function_call(call_id, &patch_content),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
@@ -351,6 +358,7 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
|
||||
assert!(saw_patch_begin, "expected PatchApplyBegin event");
|
||||
let patch_end_success =
|
||||
patch_end_success.expect("expected PatchApplyEnd event to capture success flag");
|
||||
assert!(patch_end_success);
|
||||
|
||||
let req = second_mock.single_request();
|
||||
let output_item = req.function_call_output(call_id);
|
||||
@@ -360,38 +368,21 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
|
||||
);
|
||||
let output_text = extract_output_text(&output_item).expect("output text present");
|
||||
|
||||
if let Ok(exec_output) = serde_json::from_str::<Value>(output_text) {
|
||||
let exit_code = exec_output["metadata"]["exit_code"]
|
||||
.as_i64()
|
||||
.expect("exit_code present");
|
||||
let summary = exec_output["output"].as_str().expect("output field");
|
||||
assert_eq!(
|
||||
exit_code, 0,
|
||||
"expected apply_patch exit_code=0, got {exit_code}, summary: {summary:?}"
|
||||
);
|
||||
assert!(
|
||||
patch_end_success,
|
||||
"expected PatchApplyEnd success flag, summary: {summary:?}"
|
||||
);
|
||||
assert!(
|
||||
summary.contains("Success."),
|
||||
"expected apply_patch summary to note success, got {summary:?}"
|
||||
);
|
||||
let expected_pattern = format!(
|
||||
r"(?s)^Exit code: 0
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
Success. Updated the following files:
|
||||
A {file_name}
|
||||
?$"
|
||||
);
|
||||
assert_regex_match(&expected_pattern, output_text);
|
||||
|
||||
let patched_path = cwd.path().join("notes.txt");
|
||||
let contents = std::fs::read_to_string(&patched_path)
|
||||
.unwrap_or_else(|e| panic!("failed reading {}: {e}", patched_path.display()));
|
||||
assert_eq!(contents, "Tool harness apply patch\n");
|
||||
} else {
|
||||
assert!(
|
||||
output_text.contains("codex-run-as-apply-patch"),
|
||||
"expected apply_patch failure message to mention codex-run-as-apply-patch, got {output_text:?}"
|
||||
);
|
||||
assert!(
|
||||
!patch_end_success,
|
||||
"expected PatchApplyEnd to report success=false when apply_patch invocation fails"
|
||||
);
|
||||
}
|
||||
let updated_contents = fs::read_to_string(file_path)?;
|
||||
assert_eq!(
|
||||
updated_contents, "Tool harness apply patch\n",
|
||||
"expected updated file content"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -403,7 +394,7 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
config.features.enable(Feature::ApplyPatchFreeform);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
|
||||
@@ -63,8 +63,9 @@ async fn build_codex_with_test_tool(server: &wiremock::MockServer) -> anyhow::Re
|
||||
}
|
||||
|
||||
fn assert_parallel_duration(actual: Duration) {
|
||||
// Allow headroom for runtime overhead while still differentiating from serial execution.
|
||||
assert!(
|
||||
-actual < Duration::from_millis(500),
+actual < Duration::from_millis(750),
|
||||
"expected parallel execution to finish quickly, got {actual:?}"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#![allow(clippy::unwrap_used, clippy::expect_used)]
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
@@ -293,7 +294,11 @@ async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
|
||||
let mock = mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let mut builder = test_codex().with_config(move |config| {
|
||||
config.use_experimental_unified_exec_tool = use_unified_exec;
|
||||
if use_unified_exec {
|
||||
config.features.enable(Feature::UnifiedExec);
|
||||
} else {
|
||||
config.features.disable(Feature::UnifiedExec);
|
||||
}
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
@@ -403,79 +408,6 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_sandbox_denied_truncates_error_output() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex();
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "shell-denied";
|
||||
let long_line = "this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz";
|
||||
let script = format!(
|
||||
"for i in $(seq 1 500); do >&2 echo '{long_line}'; done; cat <<'EOF' > denied.txt\ncontent\nEOF",
|
||||
);
|
||||
let args = json!({
|
||||
"command": ["/bin/sh", "-c", script],
|
||||
"timeout_ms": 1_000,
|
||||
});
|
||||
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let second_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"attempt to write in read-only sandbox",
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::ReadOnly,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let denied_item = second_mock.single_request().function_call_output(call_id);
|
||||
|
||||
let output = denied_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("denied output string");
|
||||
|
||||
let sandbox_pattern = r#"(?s)^Exit code: -?\d+
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Total output lines: \d+
|
||||
Output:
|
||||
|
||||
failed in sandbox: .*?(?:Operation not permitted|Permission denied|Read-only file system).*?
|
||||
\[\.{3} omitted \d+ of \d+ lines \.{3}\]
|
||||
.*this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz.*
|
||||
\n?$"#;
|
||||
let sandbox_regex = Regex::new(sandbox_pattern)?;
|
||||
if !sandbox_regex.is_match(output) {
|
||||
let fallback_pattern = r#"(?s)^Total output lines: \d+
|
||||
|
||||
failed in sandbox: this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz
|
||||
.*this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz.*
|
||||
.*(?:Operation not permitted|Permission denied|Read-only file system).*$"#;
|
||||
assert_regex_match(fallback_pattern, output);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_spawn_failure_truncates_exec_error() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
@@ -42,7 +43,13 @@ fn collect_tool_outputs(bodies: &[Value]) -> Result<HashMap<String, Value>> {
|
||||
if let Some(call_id) = item.get("call_id").and_then(Value::as_str) {
|
||||
let content = extract_output_text(item)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing tool output content"))?;
|
||||
-let parsed: Value = serde_json::from_str(content)?;
+let trimmed = content.trim();
+if trimmed.is_empty() {
+    continue;
+}
+let parsed: Value = serde_json::from_str(trimmed).map_err(|err| {
+    anyhow::anyhow!("failed to parse tool output content {trimmed:?}: {err}")
+})?;
|
||||
outputs.insert(call_id.to_string(), parsed);
|
||||
}
|
||||
}
|
||||
@@ -59,7 +66,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.use_experimental_unified_exec_tool = true;
|
||||
config.features.enable(Feature::UnifiedExec);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
@@ -168,7 +175,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
-async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
+async fn unified_exec_streams_after_lagged_output() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
@@ -176,6 +183,132 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.use_experimental_unified_exec_tool = true;
|
||||
config.features.enable(Feature::UnifiedExec);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder.build(&server).await?;
|
||||
|
||||
let script = r#"python3 - <<'PY'
|
||||
import sys
|
||||
import time
|
||||
|
||||
chunk = b'x' * (1 << 20)
|
||||
for _ in range(4):
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.flush()
|
||||
|
||||
time.sleep(0.2)
|
||||
for _ in range(5):
|
||||
sys.stdout.write("TAIL-MARKER\n")
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.05)
|
||||
|
||||
time.sleep(0.2)
|
||||
PY
|
||||
"#;
|
||||
|
||||
let first_call_id = "uexec-lag-start";
|
||||
let first_args = serde_json::json!({
|
||||
"input": ["/bin/sh", "-c", script],
|
||||
"timeout_ms": 25,
|
||||
});
|
||||
|
||||
let second_call_id = "uexec-lag-poll";
|
||||
let second_args = serde_json::json!({
|
||||
"input": Vec::<String>::new(),
|
||||
"session_id": "0",
|
||||
"timeout_ms": 2_000,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(
|
||||
first_call_id,
|
||||
"unified_exec",
|
||||
&serde_json::to_string(&first_args)?,
|
||||
),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_function_call(
|
||||
second_call_id,
|
||||
"unified_exec",
|
||||
&serde_json::to_string(&second_args)?,
|
||||
),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "lag handled"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "exercise lag handling".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
|
||||
let start_output = outputs
|
||||
.get(first_call_id)
|
||||
.expect("missing initial unified_exec output");
|
||||
let session_id = start_output["session_id"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
!session_id.is_empty(),
|
||||
"expected session id from initial unified_exec response"
|
||||
);
|
||||
|
||||
let poll_output = outputs
|
||||
.get(second_call_id)
|
||||
.expect("missing poll unified_exec output");
|
||||
let poll_text = poll_output["output"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
poll_text.contains("TAIL-MARKER"),
|
||||
"expected poll output to contain tail marker, got {poll_text:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.features.enable(Feature::UnifiedExec);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::os::unix::fs::PermissionsExt;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::fs_wait;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
@@ -17,8 +18,7 @@ use responses::ev_assistant_message;
|
||||
use responses::ev_completed;
|
||||
use responses::sse;
|
||||
use responses::start_mock_server;
|
||||
-use tokio::time::Duration;
-use tokio::time::sleep;
+use std::time::Duration;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn summarize_context_three_requests_and_instructions() -> anyhow::Result<()> {
|
||||
@@ -60,14 +60,7 @@ echo -n "${@: -1}" > $(dirname "${0}")/notify.txt"#,
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// We fork the notify script, so we need to wait for it to write to the file.
-for _ in 0..100u32 {
-    if notify_file.exists() {
-        break;
-    }
-    sleep(Duration::from_millis(100)).await;
-}
-
-assert!(notify_file.exists());
+fs_wait::wait_for_path_exists(&notify_file, Duration::from_secs(5)).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,42 +1,31 @@
|
||||
{ pkgs, monorep-deps ? [], ... }:
|
||||
let
|
||||
{
|
||||
openssl,
|
||||
rustPlatform,
|
||||
pkg-config,
|
||||
lib,
|
||||
...
|
||||
}:
|
||||
rustPlatform.buildRustPackage (finalAttrs: {
|
||||
env = {
|
||||
PKG_CONFIG_PATH = "${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH";
|
||||
PKG_CONFIG_PATH = "${openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH";
|
||||
};
|
||||
in
|
||||
rec {
|
||||
package = pkgs.rustPlatform.buildRustPackage {
|
||||
inherit env;
|
||||
pname = "codex-rs";
|
||||
version = "0.1.0";
|
||||
cargoLock.lockFile = ./Cargo.lock;
|
||||
doCheck = false;
|
||||
src = ./.;
|
||||
nativeBuildInputs = with pkgs; [
|
||||
pkg-config
|
||||
openssl
|
||||
];
|
||||
meta = with pkgs.lib; {
|
||||
description = "OpenAI Codex command‑line interface rust implementation";
|
||||
license = licenses.asl20;
|
||||
homepage = "https://github.com/openai/codex";
|
||||
};
|
||||
pname = "codex-rs";
|
||||
version = "0.1.0";
|
||||
cargoLock.lockFile = ./Cargo.lock;
|
||||
doCheck = false;
|
||||
src = ./.;
|
||||
nativeBuildInputs = [
|
||||
pkg-config
|
||||
openssl
|
||||
];
|
||||
|
||||
cargoLock.outputHashes = {
|
||||
"ratatui-0.29.0" = "sha256-HBvT5c8GsiCxMffNjJGLmHnvG77A6cqEL+1ARurBXho=";
|
||||
};
|
||||
devShell = pkgs.mkShell {
|
||||
inherit env;
|
||||
name = "codex-rs-dev";
|
||||
packages = monorep-deps ++ [
|
||||
pkgs.cargo
|
||||
package
|
||||
];
|
||||
shellHook = ''
|
||||
echo "Entering development shell for codex-rs"
|
||||
alias codex="cd ${package.src}/tui; cargo run; cd -"
|
||||
${pkgs.rustPlatform.cargoSetupHook}
|
||||
'';
|
||||
|
||||
meta = with lib; {
|
||||
description = "OpenAI Codex command‑line interface rust implementation";
|
||||
license = licenses.asl20;
|
||||
homepage = "https://github.com/openai/codex";
|
||||
};
|
||||
app = {
|
||||
type = "app";
|
||||
program = "${package}/bin/codex";
|
||||
};
|
||||
}
|
||||
})
|
||||
|
||||
@@ -17,6 +17,7 @@ use codex_core::ConversationManager;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::git_info::get_git_repo_root;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::Event;
|
||||
@@ -168,8 +169,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
model,
|
||||
review_model: None,
|
||||
config_profile,
|
||||
// This CLI is intended to be headless and has no affordances for asking
|
||||
// the user for approval.
|
||||
// Default to never ask for approvals in headless mode. Feature flags can override.
|
||||
approval_policy: Some(AskForApproval::Never),
|
||||
sandbox_mode,
|
||||
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
|
||||
@@ -192,6 +192,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
};
|
||||
|
||||
let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides).await?;
|
||||
let approve_all_enabled = config.features.enabled(Feature::ApproveAll);
|
||||
|
||||
let otel = codex_core::otel_init::build_provider(&config, env!("CARGO_PKG_VERSION"));
|
||||
|
||||
@@ -360,6 +361,34 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
if matches!(event.msg, EventMsg::Error(_)) {
|
||||
error_seen = true;
|
||||
}
|
||||
// Auto-approve requests when the approve_all feature is enabled.
|
||||
if approve_all_enabled {
|
||||
match &event.msg {
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
if let Err(e) = conversation
|
||||
.submit(Op::ExecApproval {
|
||||
id: event.id.clone(),
|
||||
decision: codex_core::protocol::ReviewDecision::Approved,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to auto-approve exec: {e}");
|
||||
}
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(_) => {
|
||||
if let Err(e) = conversation
|
||||
.submit(Op::PatchApproval {
|
||||
id: event.id.clone(),
|
||||
decision: codex_core::protocol::ReviewDecision::Approved,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to auto-approve patch: {e}");
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let shutdown: CodexStatus = event_processor.process_event(event);
|
||||
match shutdown {
|
||||
CodexStatus::Running => continue,
|
||||
|
||||
81
codex-rs/exec/tests/suite/approve_all.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use anyhow::Result;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex_exec::test_codex_exec;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
|
||||
async fn run_exec_with_args(args: &[&str]) -> Result<String> {
|
||||
let test = test_codex_exec();
|
||||
|
||||
let call_id = "exec-approve";
|
||||
let exec_args = json!({
|
||||
"command": [
|
||||
if cfg!(windows) { "cmd.exe" } else { "/bin/sh" },
|
||||
if cfg!(windows) { "/C" } else { "-lc" },
|
||||
"echo approve-all-ok",
|
||||
],
|
||||
"timeout_ms": 1500,
|
||||
"with_escalated_permissions": true
|
||||
});
|
||||
|
||||
let response_streams = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&exec_args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let mock = mount_sse_sequence(&server, response_streams).await;
|
||||
|
||||
test.cmd_with_server(&server).args(args).assert().success();
|
||||
|
||||
let requests = mock.requests();
|
||||
assert!(requests.len() >= 2, "expected at least two responses POSTs");
|
||||
let item = requests[1].function_call_output(call_id);
|
||||
let output_str = item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("function_call_output.output should be a string");
|
||||
|
||||
Ok(output_str.to_string())
|
||||
}
|
||||
|
||||
/// Setting `features.approve_all=true` should switch to auto-approvals.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn approve_all_auto_accepts_exec() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let output = run_exec_with_args(&[
|
||||
"--skip-git-repo-check",
|
||||
"-c",
|
||||
"features.approve_all=true",
|
||||
"train",
|
||||
])
|
||||
.await?;
|
||||
assert!(
|
||||
output.contains("Exit code: 0"),
|
||||
"expected Exit code: 0 in output: {output}"
|
||||
);
|
||||
assert!(
|
||||
output.contains("approve-all-ok"),
|
||||
"expected command output in response: {output}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
// Aggregates all former standalone integration tests as modules.
|
||||
mod apply_patch;
|
||||
mod approve_all;
|
||||
mod auth_env;
|
||||
mod originator;
|
||||
mod output_schema;
|
||||
|
||||
13
codex-rs/feedback/Cargo.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[package]
|
||||
edition.workspace = true
|
||||
name = "codex-feedback"
|
||||
version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
sentry = { version = "0.34" }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
231
codex-rs/feedback/src/lib.rs
Normal file
@@ -0,0 +1,231 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::io::{self};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_protocol::ConversationId;
|
||||
use tracing_subscriber::fmt::writer::MakeWriter;
|
||||
|
||||
const DEFAULT_MAX_BYTES: usize = 2 * 1024 * 1024; // 2 MiB
|
||||
const SENTRY_DSN: &str =
|
||||
"https://ae32ed50620d7a7792c1ce5df38b3e3e@o33249.ingest.us.sentry.io/4510195390611458";
|
||||
const UPLOAD_TIMEOUT_SECS: u64 = 10;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CodexFeedback {
|
||||
inner: Arc<FeedbackInner>,
|
||||
}
|
||||
|
||||
impl Default for CodexFeedback {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CodexFeedback {
|
||||
pub fn new() -> Self {
|
||||
Self::with_capacity(DEFAULT_MAX_BYTES)
|
||||
}
|
||||
|
||||
pub(crate) fn with_capacity(max_bytes: usize) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(FeedbackInner::new(max_bytes)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_writer(&self) -> FeedbackMakeWriter {
|
||||
FeedbackMakeWriter {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn snapshot(&self, session_id: Option<ConversationId>) -> CodexLogSnapshot {
|
||||
let bytes = {
|
||||
let guard = self.inner.ring.lock().expect("mutex poisoned");
|
||||
guard.snapshot_bytes()
|
||||
};
|
||||
CodexLogSnapshot {
|
||||
bytes,
|
||||
thread_id: session_id
|
||||
.map(|id| id.to_string())
|
||||
.unwrap_or("no-active-thread-".to_string() + &ConversationId::new().to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct FeedbackInner {
|
||||
ring: Mutex<RingBuffer>,
|
||||
}
|
||||
|
||||
impl FeedbackInner {
|
||||
fn new(max_bytes: usize) -> Self {
|
||||
Self {
|
||||
ring: Mutex::new(RingBuffer::new(max_bytes)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FeedbackMakeWriter {
|
||||
inner: Arc<FeedbackInner>,
|
||||
}
|
||||
|
||||
impl<'a> MakeWriter<'a> for FeedbackMakeWriter {
|
||||
type Writer = FeedbackWriter;
|
||||
|
||||
fn make_writer(&'a self) -> Self::Writer {
|
||||
FeedbackWriter {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FeedbackWriter {
|
||||
inner: Arc<FeedbackInner>,
|
||||
}
|
||||
|
||||
impl Write for FeedbackWriter {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let mut guard = self.inner.ring.lock().map_err(|_| io::ErrorKind::Other)?;
|
||||
guard.push_bytes(buf);
|
||||
Ok(buf.len())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct RingBuffer {
|
||||
max: usize,
|
||||
buf: VecDeque<u8>,
|
||||
}
|
||||
|
||||
impl RingBuffer {
|
||||
fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
max: capacity,
|
||||
buf: VecDeque::with_capacity(capacity),
|
||||
}
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.buf.len()
|
||||
}
|
||||
|
||||
fn push_bytes(&mut self, data: &[u8]) {
|
||||
if data.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// If the incoming chunk is larger than capacity, keep only the trailing bytes.
|
||||
if data.len() >= self.max {
|
||||
self.buf.clear();
|
||||
let start = data.len() - self.max;
|
||||
self.buf.extend(data[start..].iter().copied());
|
||||
return;
|
||||
}
|
||||
|
||||
// Evict from the front if we would exceed capacity.
|
||||
let needed = self.len() + data.len();
|
||||
if needed > self.max {
|
||||
let to_drop = needed - self.max;
|
||||
for _ in 0..to_drop {
|
||||
let _ = self.buf.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
self.buf.extend(data.iter().copied());
|
||||
}
|
||||
|
||||
fn snapshot_bytes(&self) -> Vec<u8> {
|
||||
self.buf.iter().copied().collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CodexLogSnapshot {
|
||||
bytes: Vec<u8>,
|
||||
pub thread_id: String,
|
||||
}
|
||||
|
||||
impl CodexLogSnapshot {
|
||||
pub(crate) fn as_bytes(&self) -> &[u8] {
|
||||
&self.bytes
|
||||
}
|
||||
|
||||
pub fn save_to_temp_file(&self) -> io::Result<PathBuf> {
|
||||
let dir = std::env::temp_dir();
|
||||
let filename = format!("codex-feedback-{}.log", self.thread_id);
|
||||
let path = dir.join(filename);
|
||||
fs::write(&path, self.as_bytes())?;
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
pub fn upload_to_sentry(&self) -> Result<()> {
|
||||
use std::collections::BTreeMap;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use sentry::Client;
|
||||
use sentry::ClientOptions;
|
||||
use sentry::protocol::Attachment;
|
||||
use sentry::protocol::Envelope;
|
||||
use sentry::protocol::EnvelopeItem;
|
||||
use sentry::protocol::Event;
|
||||
use sentry::protocol::Level;
|
||||
use sentry::transports::DefaultTransportFactory;
|
||||
use sentry::types::Dsn;
|
||||
|
||||
let client = Client::from_config(ClientOptions {
|
||||
dsn: Some(Dsn::from_str(SENTRY_DSN).map_err(|e| anyhow!("invalid DSN: {}", e))?),
|
||||
transport: Some(Arc::new(DefaultTransportFactory {})),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
let tags = BTreeMap::from([(String::from("thread_id"), self.thread_id.to_string())]);
|
||||
|
||||
let event = Event {
|
||||
level: Level::Error,
|
||||
message: Some("Codex Log Upload ".to_string() + &self.thread_id),
|
||||
tags,
|
||||
..Default::default()
|
||||
};
|
||||
let mut envelope = Envelope::new();
|
||||
envelope.add_item(EnvelopeItem::Event(event));
|
||||
envelope.add_item(EnvelopeItem::Attachment(Attachment {
|
||||
buffer: self.bytes.clone(),
|
||||
filename: String::from("codex-logs.log"),
|
||||
content_type: Some("text/plain".to_string()),
|
||||
ty: None,
|
||||
}));
|
||||
|
||||
client.send_envelope(envelope);
|
||||
client.flush(Some(Duration::from_secs(UPLOAD_TIMEOUT_SECS)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn ring_buffer_drops_front_when_full() {
|
||||
let fb = CodexFeedback::with_capacity(8);
|
||||
{
|
||||
let mut w = fb.make_writer().make_writer();
|
||||
w.write_all(b"abcdefgh").unwrap();
|
||||
w.write_all(b"ij").unwrap();
|
||||
}
|
||||
let snap = fb.snapshot(None);
|
||||
// Capacity 8: after writing 10 bytes, we should keep the last 8.
|
||||
pretty_assertions::assert_eq!(std::str::from_utf8(snap.as_bytes()).unwrap(), "cdefghij");
|
||||
}
|
||||
}
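For orientation, a minimal sketch of how this buffer could be wired into logging follows; it assumes the crate is consumed as `codex_feedback` and that the standard tracing-subscriber fmt builder is used (the actual integration point is not part of this diff).

// Hypothetical wiring sketch: route tracing output into the in-memory feedback
// ring buffer so a later snapshot()/upload_to_sentry() call can attach recent logs.
use tracing_subscriber::fmt;

fn init_feedback_logging(feedback: &codex_feedback::CodexFeedback) {
    fmt()
        .with_writer(feedback.make_writer())
        .with_ansi(false)
        .init();
}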
|
||||
@@ -49,7 +49,7 @@ async fn main() -> Result<()> {
|
||||
// Spawn the subprocess and connect the client.
|
||||
let program = args.remove(0);
|
||||
let env = None;
|
||||
-let client = McpClient::new_stdio_client(program, args, env)
+let client = McpClient::new_stdio_client(program, args, env, &[], None)
|
||||
.await
|
||||
.with_context(|| format!("failed to spawn subprocess: {original_args:?}"))?;
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
@@ -86,19 +87,26 @@ impl McpClient {
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
env_vars: &[String],
|
||||
cwd: Option<PathBuf>,
|
||||
) -> std::io::Result<Self> {
|
||||
-let mut child = Command::new(program)
+let mut command = Command::new(program);
|
||||
command
|
||||
.args(args)
|
||||
.env_clear()
|
||||
-.envs(create_env_for_mcp_server(env))
+.envs(create_env_for_mcp_server(env, env_vars))
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::null())
|
||||
// As noted in the `kill_on_drop` documentation, the Tokio runtime makes
|
||||
// a "best effort" to reap-after-exit to avoid zombie processes, but it
|
||||
// is not a guarantee.
|
||||
-.kill_on_drop(true)
-.spawn()?;
+.kill_on_drop(true);
|
||||
if let Some(cwd) = cwd {
|
||||
command.current_dir(cwd);
|
||||
}
|
||||
|
||||
let mut child = command.spawn()?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
@@ -447,12 +455,16 @@ const DEFAULT_ENV_VARS: &[&str] = &[
|
||||
/// `config.toml`.
|
||||
fn create_env_for_mcp_server(
|
||||
extra_env: Option<HashMap<String, String>>,
|
||||
env_vars: &[String],
|
||||
) -> HashMap<String, String> {
|
||||
DEFAULT_ENV_VARS
|
||||
.iter()
|
||||
-.filter_map(|var| match std::env::var(var) {
-    Ok(value) => Some((var.to_string(), value)),
-    Err(_) => None,
+.copied()
+.chain(env_vars.iter().map(String::as_str))
+.filter_map(|var| {
+    std::env::var(var)
+        .ok()
+        .map(|value| (var.to_string(), value))
|
||||
})
|
||||
.chain(extra_env.unwrap_or_default())
|
||||
.collect::<HashMap<_, _>>()
|
||||
@@ -462,14 +474,36 @@ fn create_env_for_mcp_server(
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn set_env_var(key: &str, value: &str) {
|
||||
unsafe {
|
||||
std::env::set_var(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_env_var(key: &str) {
|
||||
unsafe {
|
||||
std::env::remove_var(key);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_env_for_mcp_server() {
|
||||
let env_var = "USER";
|
||||
let env_var_existing_value = std::env::var(env_var).unwrap_or_default();
|
||||
let env_var_new_value = format!("{env_var_existing_value}-extra");
|
||||
let extra_env = HashMap::from([(env_var.to_owned(), env_var_new_value.clone())]);
|
||||
-let mcp_server_env = create_env_for_mcp_server(Some(extra_env));
+let mcp_server_env = create_env_for_mcp_server(Some(extra_env), &[]);
|
||||
assert!(mcp_server_env.contains_key("PATH"));
|
||||
assert_eq!(Some(&env_var_new_value), mcp_server_env.get(env_var));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_env_for_mcp_server_includes_extra_whitelisted_vars() {
|
||||
let custom_var = "CUSTOM_TEST_VAR";
|
||||
let value = "value".to_string();
|
||||
set_env_var(custom_var, &value);
|
||||
let mcp_server_env = create_env_for_mcp_server(None, &[custom_var.to_string()]);
|
||||
assert_eq!(Some(&value), mcp_server_env.get(custom_var));
|
||||
remove_env_var(custom_var);
|
||||
}
|
||||
}
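To illustrate the widened constructor, a call under the new signature might look like the sketch below; the program name, argument, whitelisted variable, and working directory are placeholders for illustration, not values taken from this diff, and `McpClient` is assumed to be in scope.

// Hypothetical caller sketch: spawn a stdio MCP server, forwarding one extra
// whitelisted env var and pinning the child's working directory.
use std::path::PathBuf;

async fn spawn_example() -> std::io::Result<McpClient> {
    McpClient::new_stdio_client(
        "my-mcp-server".into(),          // program (placeholder)
        vec!["--stdio".into()],          // args (placeholder)
        None,                            // no extra env map
        &["MY_EXTRA_VAR".to_string()],   // additional whitelisted env vars
        Some(PathBuf::from("/tmp")),     // optional working directory
    )
    .await
}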
|
||||
|
||||
@@ -178,6 +178,7 @@ async fn run_codex_tool_session_inner(
|
||||
cwd,
|
||||
call_id,
|
||||
reason: _,
|
||||
parsed_cmd,
|
||||
}) => {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
@@ -188,6 +189,7 @@ async fn run_codex_tool_session_inner(
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
call_id,
|
||||
parsed_cmd,
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::sync::Arc;
|
||||
use codex_core::CodexConversation;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_protocol::parse_command::ParsedCommand;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
@@ -35,6 +36,7 @@ pub struct ExecApprovalElicitRequestParams {
|
||||
pub codex_call_id: String,
|
||||
pub codex_command: Vec<String>,
|
||||
pub codex_cwd: PathBuf,
|
||||
pub codex_parsed_cmd: Vec<ParsedCommand>,
|
||||
}
|
||||
|
||||
// TODO(mbolin): ExecApprovalResponse does not conform to ElicitResult. See:
|
||||
@@ -56,6 +58,7 @@ pub(crate) async fn handle_exec_approval_request(
|
||||
tool_call_id: String,
|
||||
event_id: String,
|
||||
call_id: String,
|
||||
codex_parsed_cmd: Vec<ParsedCommand>,
|
||||
) {
|
||||
let escaped_command =
|
||||
shlex::try_join(command.iter().map(String::as_str)).unwrap_or_else(|_| command.join(" "));
|
||||
@@ -77,6 +80,7 @@ pub(crate) async fn handle_exec_approval_request(
|
||||
codex_call_id: call_id,
|
||||
codex_command: command,
|
||||
codex_cwd: cwd,
|
||||
codex_parsed_cmd,
|
||||
};
|
||||
let params_json = match serde_json::to_value(¶ms) {
|
||||
Ok(value) => value,
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::env;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_core::parse_command;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
@@ -176,6 +177,7 @@ fn create_expected_elicitation_request(
|
||||
shlex::try_join(command.iter().map(std::convert::AsRef::as_ref))?,
|
||||
workdir.to_string_lossy()
|
||||
);
|
||||
let codex_parsed_cmd = parse_command::parse_command(&command);
|
||||
Ok(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: elicitation_request_id,
|
||||
@@ -193,6 +195,7 @@ fn create_expected_elicitation_request(
|
||||
codex_command: command,
|
||||
codex_cwd: workdir.to_path_buf(),
|
||||
codex_call_id: "call1234".to_string(),
|
||||
codex_parsed_cmd,
|
||||
})?),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ pub struct OtelEventMetadata {
|
||||
conversation_id: ConversationId,
|
||||
auth_mode: Option<String>,
|
||||
account_id: Option<String>,
|
||||
account_email: Option<String>,
|
||||
model: String,
|
||||
slug: String,
|
||||
log_user_prompts: bool,
|
||||
@@ -46,11 +47,13 @@ pub struct OtelEventManager {
|
||||
}
|
||||
|
||||
impl OtelEventManager {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
conversation_id: ConversationId,
|
||||
model: &str,
|
||||
slug: &str,
|
||||
account_id: Option<String>,
|
||||
account_email: Option<String>,
|
||||
auth_mode: Option<AuthMode>,
|
||||
log_user_prompts: bool,
|
||||
terminal_type: String,
|
||||
@@ -60,6 +63,7 @@ impl OtelEventManager {
|
||||
conversation_id,
|
||||
auth_mode: auth_mode.map(|m| m.to_string()),
|
||||
account_id,
|
||||
account_email,
|
||||
model: model.to_owned(),
|
||||
slug: slug.to_owned(),
|
||||
log_user_prompts,
|
||||
@@ -98,6 +102,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -136,6 +141,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -148,21 +154,15 @@ impl OtelEventManager {
|
||||
response
|
||||
}
|
||||
|
||||
-pub async fn log_sse_event<Next, Fut, E>(
+pub fn log_sse_event<E>(
&self,
-next: Next,
-) -> Result<Option<Result<StreamEvent, StreamError<E>>>, Elapsed>
-where
-Next: FnOnce() -> Fut,
-Fut: Future<Output = Result<Option<Result<StreamEvent, StreamError<E>>>, Elapsed>>,
+response: &Result<Option<Result<StreamEvent, StreamError<E>>>, Elapsed>,
+duration: Duration,
+) where
E: Display,
{
-let start = std::time::Instant::now();
-let response = next().await;
-let duration = start.elapsed();
|
||||
|
||||
match response {
|
||||
-Ok(Some(Ok(ref sse))) => {
+Ok(Some(Ok(sse))) => {
|
||||
if sse.data.trim() == "[DONE]" {
|
||||
self.sse_event(&sse.event, duration);
|
||||
} else {
|
||||
@@ -191,7 +191,7 @@ impl OtelEventManager {
|
||||
}
|
||||
}
|
||||
}
|
||||
-Ok(Some(Err(ref error))) => {
+Ok(Some(Err(error))) => {
|
||||
self.sse_event_failed(None, duration, error);
|
||||
}
|
||||
Ok(None) => {}
|
||||
@@ -199,8 +199,6 @@ impl OtelEventManager {
|
||||
self.sse_event_failed(None, duration, &"idle timeout waiting for SSE");
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
fn sse_event(&self, kind: &str, duration: Duration) {
|
||||
@@ -213,6 +211,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -234,6 +233,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -248,6 +248,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -270,6 +271,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -294,6 +296,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -328,6 +331,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -351,6 +355,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -391,7 +396,8 @@ impl OtelEventManager {
|
||||
conversation.id = %self.metadata.conversation_id,
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.account_id= self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -416,6 +422,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
@@ -445,6 +452,7 @@ impl OtelEventManager {
|
||||
app.version = %self.metadata.app_version,
|
||||
auth_mode = self.metadata.auth_mode,
|
||||
user.account_id = self.metadata.account_id,
|
||||
user.email = self.metadata.account_email,
|
||||
terminal.type = %self.metadata.terminal_type,
|
||||
model = %self.metadata.model,
|
||||
slug = %self.metadata.slug,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::path::PathBuf;
|
||||
use ts_rs::TS;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, TS)]
|
||||
@@ -8,6 +9,11 @@ pub enum ParsedCommand {
|
||||
Read {
|
||||
cmd: String,
|
||||
name: String,
|
||||
/// (Best effort) Path to the file being read by the command. When
|
||||
/// possible, this is an absolute path, though when relative, it should
|
||||
/// be resolved against the `cwd` that will be used to run the command
|
||||
/// to derive the absolute path.
|
||||
path: PathBuf,
|
||||
},
|
||||
ListFiles {
|
||||
cmd: String,
|
||||
|
||||
@@ -21,6 +21,8 @@ use crate::num_format::format_with_separators;
|
||||
use crate::parse_command::ParsedCommand;
|
||||
use crate::plan_tool::UpdatePlanArgs;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::Resource as McpResource;
|
||||
use mcp_types::ResourceTemplate as McpResourceTemplate;
|
||||
use mcp_types::Tool as McpTool;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -1178,6 +1180,7 @@ pub struct ExecApprovalRequestEvent {
|
||||
/// Optional human-readable reason for the approval (e.g. retry without sandbox).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reason: Option<String>,
|
||||
pub parsed_cmd: Vec<ParsedCommand>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
@@ -1203,6 +1206,11 @@ pub struct StreamErrorEvent {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct StreamInfoEvent {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct PatchApplyBeginEvent {
|
||||
/// Identifier so this can be paired with the PatchApplyEnd event.
|
||||
@@ -1243,6 +1251,34 @@ pub struct GetHistoryEntryResponseEvent {
|
||||
pub struct McpListToolsResponseEvent {
|
||||
/// Fully qualified tool name -> tool definition.
|
||||
pub tools: std::collections::HashMap<String, McpTool>,
|
||||
/// Known resources grouped by server name.
|
||||
pub resources: std::collections::HashMap<String, Vec<McpResource>>,
|
||||
/// Known resource templates grouped by server name.
|
||||
pub resource_templates: std::collections::HashMap<String, Vec<McpResourceTemplate>>,
|
||||
/// Authentication status for each configured MCP server.
|
||||
pub auth_statuses: std::collections::HashMap<String, McpAuthStatus>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, TS)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[ts(rename_all = "snake_case")]
|
||||
pub enum McpAuthStatus {
|
||||
Unsupported,
|
||||
NotLoggedIn,
|
||||
BearerToken,
|
||||
OAuth,
|
||||
}
|
||||
|
||||
impl fmt::Display for McpAuthStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let text = match self {
|
||||
McpAuthStatus::Unsupported => "Unsupported",
|
||||
McpAuthStatus::NotLoggedIn => "Not logged in",
|
||||
McpAuthStatus::BearerToken => "Bearer token",
|
||||
McpAuthStatus::OAuth => "OAuth",
|
||||
};
|
||||
f.write_str(text)
|
||||
}
|
||||
}
|
||||
|
||||
/// Response payload for `Op::ListCustomPrompts`.
|
||||
|
||||
@@ -12,6 +12,7 @@ axum = { workspace = true, default-features = false, features = [
|
||||
"http1",
|
||||
"tokio",
|
||||
] }
|
||||
codex-protocol = { workspace = true }
|
||||
keyring = { workspace = true, features = [
|
||||
"apple-native",
|
||||
"crypto-rust",
|
||||
@@ -56,5 +57,7 @@ urlencoding = { workspace = true }
|
||||
webbrowser = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
escargot = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
serial_test = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
|
||||
138
codex-rs/rmcp-client/src/auth_status.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use std::collections::HashMap;
use std::time::Duration;

use anyhow::Error;
use anyhow::Result;
use codex_protocol::protocol::McpAuthStatus;
use reqwest::Client;
use reqwest::StatusCode;
use reqwest::Url;
use reqwest::header::HeaderMap;
use serde::Deserialize;
use tracing::debug;

use crate::OAuthCredentialsStoreMode;
use crate::oauth::has_oauth_tokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;

const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
const OAUTH_DISCOVERY_VERSION: &str = "2024-11-05";

/// Determine the authentication status for a streamable HTTP MCP server.
pub async fn determine_streamable_http_auth_status(
    server_name: &str,
    url: &str,
    bearer_token_env_var: Option<&str>,
    http_headers: Option<HashMap<String, String>>,
    env_http_headers: Option<HashMap<String, String>>,
    store_mode: OAuthCredentialsStoreMode,
) -> Result<McpAuthStatus> {
    if bearer_token_env_var.is_some() {
        return Ok(McpAuthStatus::BearerToken);
    }

    if has_oauth_tokens(server_name, url, store_mode)? {
        return Ok(McpAuthStatus::OAuth);
    }

    let default_headers = build_default_headers(http_headers, env_http_headers)?;

    match supports_oauth_login_with_headers(url, &default_headers).await {
        Ok(true) => Ok(McpAuthStatus::NotLoggedIn),
        Ok(false) => Ok(McpAuthStatus::Unsupported),
        Err(error) => {
            debug!(
                "failed to detect OAuth support for MCP server `{server_name}` at {url}: {error:?}"
            );
            Ok(McpAuthStatus::Unsupported)
        }
    }
}
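Illustrative usage (not part of the diff; the server name, URL, and surrounding map are hypothetical): the status returned here is what a caller would key by server name in McpListToolsResponseEvent::auth_statuses.

    // Sketch only: assumes a configured server named "docs" and an
    // `auth_statuses: HashMap<String, McpAuthStatus>` being assembled.
    let status = determine_streamable_http_auth_status(
        "docs",
        "https://mcp.example.com/mcp",
        None,       // bearer_token_env_var
        None,       // http_headers
        None,       // env_http_headers
        store_mode, // an OAuthCredentialsStoreMode taken from configuration
    )
    .await?;
    auth_statuses.insert("docs".to_string(), status);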
/// Attempt to determine whether a streamable HTTP MCP server advertises OAuth login.
pub async fn supports_oauth_login(url: &str) -> Result<bool> {
    supports_oauth_login_with_headers(url, &HeaderMap::new()).await
}

async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMap) -> Result<bool> {
    let base_url = Url::parse(url)?;
    let builder = Client::builder().timeout(DISCOVERY_TIMEOUT);
    let client = apply_default_headers(builder, default_headers).build()?;

    let mut last_error: Option<Error> = None;
    for candidate_path in discovery_paths(base_url.path()) {
        let mut discovery_url = base_url.clone();
        discovery_url.set_path(&candidate_path);

        let response = match client
            .get(discovery_url.clone())
            .header(OAUTH_DISCOVERY_HEADER, OAUTH_DISCOVERY_VERSION)
            .send()
            .await
        {
            Ok(response) => response,
            Err(err) => {
                last_error = Some(err.into());
                continue;
            }
        };

        if response.status() != StatusCode::OK {
            continue;
        }

        let metadata = match response.json::<OAuthDiscoveryMetadata>().await {
            Ok(metadata) => metadata,
            Err(err) => {
                last_error = Some(err.into());
                continue;
            }
        };

        if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
            return Ok(true);
        }
    }

    if let Some(err) = last_error {
        debug!("OAuth discovery requests failed for {url}: {err:?}");
    }

    Ok(false)
}

#[derive(Debug, Deserialize)]
struct OAuthDiscoveryMetadata {
    #[serde(default)]
    authorization_endpoint: Option<String>,
    #[serde(default)]
    token_endpoint: Option<String>,
}
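For orientation (illustrative, not part of the diff): the discovery check above only requires both endpoints to be present in the metadata document; the URLs below are hypothetical.

    // Sketch: a 200 response with this body makes the check return Ok(true).
    let body = r#"{
        "authorization_endpoint": "https://mcp.example.com/authorize",
        "token_endpoint": "https://mcp.example.com/token"
    }"#;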
/// Implements RFC 8414 section 3.1 for discovering well-known oauth endpoints.
/// This is a requirement for MCP servers to support OAuth.
/// https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
/// https://github.com/modelcontextprotocol/rust-sdk/blob/main/crates/rmcp/src/transport/auth.rs#L182
fn discovery_paths(base_path: &str) -> Vec<String> {
    let trimmed = base_path.trim_start_matches('/').trim_end_matches('/');
    let canonical = "/.well-known/oauth-authorization-server".to_string();

    if trimmed.is_empty() {
        return vec![canonical];
    }

    let mut candidates = Vec::new();
    let mut push_unique = |candidate: String| {
        if !candidates.contains(&candidate) {
            candidates.push(candidate);
        }
    };

    push_unique(format!("{canonical}/{trimmed}"));
    push_unique(format!("/{trimmed}/.well-known/oauth-authorization-server"));
    push_unique(canonical);

    candidates
}
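A quick worked example (illustrative, not part of the diff): for a server mounted at /mcp, the candidates are tried path-suffixed first, then path-prefixed, then the canonical root.

    assert_eq!(
        discovery_paths("/mcp"),
        vec![
            "/.well-known/oauth-authorization-server/mcp".to_string(),
            "/mcp/.well-known/oauth-authorization-server".to_string(),
            "/.well-known/oauth-authorization-server".to_string(),
        ]
    );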
Some files were not shown because too many files have changed in this diff.