mirror of
https://github.com/openai/codex.git
synced 2026-05-12 23:32:44 +00:00
Compare commits
100 Commits
jif/parall
...
file-notif
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21ace71f45 | ||
|
|
6b05624040 | ||
|
|
26f7c46856 | ||
|
|
90af046c5c | ||
|
|
961ed31901 | ||
|
|
85e7357973 | ||
|
|
f98fa85b44 | ||
|
|
ddcaf3dccd | ||
|
|
56296cad82 | ||
|
|
95b41dd7f1 | ||
|
|
bf82353f45 | ||
|
|
0308febc23 | ||
|
|
7b4a4c2219 | ||
|
|
3ddd4d47d0 | ||
|
|
ca6a0358de | ||
|
|
0026b12615 | ||
|
|
4300236681 | ||
|
|
ec238a2c39 | ||
|
|
b6165aee0c | ||
|
|
f4bc03d7c0 | ||
|
|
3c5e12e2a4 | ||
|
|
c89229db97 | ||
|
|
d3820f4782 | ||
|
|
e896db1180 | ||
|
|
96acb8a74e | ||
|
|
687a13bbe5 | ||
|
|
fe8122e514 | ||
|
|
876d4f450a | ||
|
|
f52320be86 | ||
|
|
a43ae86b6c | ||
|
|
496cb801e1 | ||
|
|
abd517091f | ||
|
|
b8b04514bc | ||
|
|
0e5d72cc57 | ||
|
|
60f9e85c16 | ||
|
|
b016a3e7d8 | ||
|
|
a0d56541cf | ||
|
|
226215f36d | ||
|
|
338c2c873c | ||
|
|
4b0f5eb6a8 | ||
|
|
75176dae70 | ||
|
|
12fd2b4160 | ||
|
|
f2555422b9 | ||
|
|
27f169bb91 | ||
|
|
b16c985ed2 | ||
|
|
35a770e871 | ||
|
|
b09f62a1c3 | ||
|
|
5833508a17 | ||
|
|
d73055c5b1 | ||
|
|
7e3a272b29 | ||
|
|
661663c98a | ||
|
|
721003c552 | ||
|
|
36f1cca1b1 | ||
|
|
d3e1beb26c | ||
|
|
c264ae6021 | ||
|
|
8cd882c4bd | ||
|
|
90fe5e4a7e | ||
|
|
a90a58f7a1 | ||
|
|
b2d81a7cac | ||
|
|
77a8b7fdeb | ||
|
|
7fa5e95c1f | ||
|
|
191d620707 | ||
|
|
53504a38d2 | ||
|
|
5c42419b02 | ||
|
|
aecbe0f333 | ||
|
|
a30a902db5 | ||
|
|
f3b4a26f32 | ||
|
|
dc3c6bf62a | ||
|
|
3203862167 | ||
|
|
06853d94f0 | ||
|
|
cc2f4aafd7 | ||
|
|
356ea6ea34 | ||
|
|
4764fc1ee7 | ||
|
|
90ef94d3b3 | ||
|
|
6c2969d22d | ||
|
|
0ad1b0782b | ||
|
|
d7acd146fb | ||
|
|
c5465aed60 | ||
|
|
a95605a867 | ||
|
|
848058f05b | ||
|
|
a4f1c9d67e | ||
|
|
665341c9b1 | ||
|
|
fae0e6c52c | ||
|
|
1b4a79f03c | ||
|
|
640192ac3d | ||
|
|
205c36e393 | ||
|
|
d13ee79c41 | ||
|
|
bde468ff8d | ||
|
|
e292d1ed21 | ||
|
|
de8d77274a | ||
|
|
a5b7675e42 | ||
|
|
9823de3cc6 | ||
|
|
c32e9cfe86 | ||
|
|
1d17ca1fa3 | ||
|
|
bfe3328129 | ||
|
|
e0b38bd7a2 | ||
|
|
153338c20f | ||
|
|
3495a7dc37 | ||
|
|
042d4d55d9 | ||
|
|
5af08e0719 |
6
.github/ISSUE_TEMPLATE/4-feature-request.yml
vendored
6
.github/ISSUE_TEMPLATE/4-feature-request.yml
vendored
@@ -2,7 +2,6 @@ name: 🎁 Feature Request
|
||||
description: Propose a new feature for Codex
|
||||
labels:
|
||||
- enhancement
|
||||
- needs triage
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
@@ -19,11 +18,6 @@ body:
|
||||
label: What feature would you like to see?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: author
|
||||
attributes:
|
||||
label: Are you interested in implementing this feature?
|
||||
description: Please wait for acknowledgement before implementing or opening a PR.
|
||||
- type: textarea
|
||||
id: notes
|
||||
attributes:
|
||||
|
||||
3
.github/workflows/ci.yml
vendored
3
.github/workflows/ci.yml
vendored
@@ -60,3 +60,6 @@ jobs:
|
||||
run: ./scripts/asciicheck.py codex-cli/README.md
|
||||
- name: Check codex-cli/README ToC
|
||||
run: python3 scripts/readme_toc.py codex-cli/README.md
|
||||
|
||||
- name: Prettier (run `pnpm run format:fix` to fix)
|
||||
run: pnpm run format
|
||||
|
||||
60
.github/workflows/issue-deduplicator.yml
vendored
60
.github/workflows/issue-deduplicator.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
outputs:
|
||||
codex_output: ${{ steps.codex.outputs.final_message }}
|
||||
codex_output: ${{ steps.codex.outputs.final-message }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
@@ -44,11 +44,38 @@ jobs:
|
||||
- id: codex
|
||||
uses: openai/codex-action@main
|
||||
with:
|
||||
openai_api_key: ${{ secrets.CODEX_OPENAI_API_KEY }}
|
||||
prompt_file: .github/prompts/issue-deduplicator.txt
|
||||
require_repo_write: false
|
||||
codex_version: 0.43.0-alpha.16
|
||||
codex_args: -m gpt-5
|
||||
openai-api-key: ${{ secrets.CODEX_OPENAI_API_KEY }}
|
||||
allow-users: "*"
|
||||
model: gpt-5
|
||||
prompt: |
|
||||
You are an assistant that triages new GitHub issues by identifying potential duplicates.
|
||||
|
||||
You will receive the following JSON files located in the current working directory:
|
||||
- `codex-current-issue.json`: JSON object describing the newly created issue (fields: number, title, body).
|
||||
- `codex-existing-issues.json`: JSON array of recent issues (each element includes number, title, body, createdAt).
|
||||
|
||||
Instructions:
|
||||
- Compare the current issue against the existing issues to find up to five that appear to describe the same underlying problem or request.
|
||||
- Focus on the underlying intent and context of each issue—such as reported symptoms, feature requests, reproduction steps, or error messages—rather than relying solely on string similarity or synthetic metrics.
|
||||
- After your analysis, validate your results in 1-2 lines explaining your decision to return the selected matches.
|
||||
- When unsure, prefer returning fewer matches.
|
||||
- Include at most five numbers.
|
||||
|
||||
output-schema: |
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"issues": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"reason": { "type": "string" }
|
||||
},
|
||||
"required": ["issues", "reason"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
|
||||
comment-on-issue:
|
||||
name: Comment with potential duplicates
|
||||
@@ -66,22 +93,33 @@ jobs:
|
||||
with:
|
||||
github-token: ${{ github.token }}
|
||||
script: |
|
||||
let numbers;
|
||||
const raw = process.env.CODEX_OUTPUT ?? '';
|
||||
let parsed;
|
||||
try {
|
||||
numbers = JSON.parse(process.env.CODEX_OUTPUT);
|
||||
parsed = JSON.parse(raw);
|
||||
} catch (error) {
|
||||
core.info(`Codex output was not valid JSON. Raw output: ${raw}`);
|
||||
core.info(`Parse error: ${error.message}`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (numbers.length === 0) {
|
||||
const issues = Array.isArray(parsed?.issues) ? parsed.issues : [];
|
||||
const currentIssueNumber = String(context.payload.issue.number);
|
||||
|
||||
console.log(`Current issue number: ${currentIssueNumber}`);
|
||||
console.log(issues);
|
||||
|
||||
const filteredIssues = issues.filter((value) => String(value) !== currentIssueNumber);
|
||||
|
||||
if (filteredIssues.length === 0) {
|
||||
core.info('Codex reported no potential duplicates.');
|
||||
return;
|
||||
}
|
||||
|
||||
const lines = [
|
||||
'Potential duplicates detected:'
|
||||
...numbers.map((value) => `- #${value}`),
|
||||
'Potential duplicates detected. Please review them and close your issue if it is a duplicate.',
|
||||
'',
|
||||
...filteredIssues.map((value) => `- #${String(value)}`),
|
||||
'',
|
||||
'*Powered by [Codex Action](https://github.com/openai/codex-action)*'];
|
||||
|
||||
|
||||
63
.github/workflows/issue-labeler.yml
vendored
63
.github/workflows/issue-labeler.yml
vendored
@@ -13,23 +13,60 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
env:
|
||||
ISSUE_NUMBER: ${{ github.event.issue.number }}
|
||||
ISSUE_TITLE: ${{ github.event.issue.title }}
|
||||
ISSUE_BODY: ${{ github.event.issue.body }}
|
||||
REPO_FULL_NAME: ${{ github.repository }}
|
||||
outputs:
|
||||
codex_output: ${{ steps.codex.outputs.final_message }}
|
||||
codex_output: ${{ steps.codex.outputs.final-message }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: codex
|
||||
uses: openai/codex-action@main
|
||||
with:
|
||||
openai_api_key: ${{ secrets.CODEX_OPENAI_API_KEY }}
|
||||
prompt_file: .github/prompts/issue-labeler.txt
|
||||
require_repo_write: false
|
||||
codex_version: 0.43.0-alpha.16
|
||||
openai-api-key: ${{ secrets.CODEX_OPENAI_API_KEY }}
|
||||
allow-users: "*"
|
||||
prompt: |
|
||||
You are an assistant that reviews GitHub issues for the repository.
|
||||
|
||||
Your job is to choose the most appropriate existing labels for the issue described later in this prompt.
|
||||
Follow these rules:
|
||||
- Only pick labels out of the list below.
|
||||
- Prefer a small set of precise labels over many broad ones.
|
||||
|
||||
Labels to apply:
|
||||
1. bug — Reproducible defects in Codex products (CLI, VS Code extension, web, auth).
|
||||
2. enhancement — Feature requests or usability improvements that ask for new capabilities, better ergonomics, or quality-of-life tweaks.
|
||||
3. extension — VS Code (or other IDE) extension-specific issues.
|
||||
4. windows-os — Bugs or friction specific to Windows environments (always when PowerShell is mentioned, path handling, copy/paste, OS-specific auth or tooling failures).
|
||||
5. mcp — Topics involving Model Context Protocol servers/clients.
|
||||
6. codex-web — Issues targeting the Codex web UI/Cloud experience.
|
||||
8. azure — Problems or requests tied to Azure OpenAI deployments.
|
||||
9. documentation — Updates or corrections needed in docs/README/config references (broken links, missing examples, outdated keys, clarification requests).
|
||||
10. model-behavior — Undesirable LLM behavior: forgetting goals, refusing work, hallucinating environment details, quota misreports, or other reasoning/performance anomalies.
|
||||
|
||||
Issue number: ${{ github.event.issue.number }}
|
||||
|
||||
Issue title:
|
||||
${{ github.event.issue.title }}
|
||||
|
||||
Issue body:
|
||||
${{ github.event.issue.body }}
|
||||
|
||||
Repository full name:
|
||||
${{ github.repository }}
|
||||
|
||||
output-schema: |
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"labels": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["labels"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
|
||||
apply-labels:
|
||||
name: Apply labels from Codex output
|
||||
@@ -53,12 +90,12 @@ jobs:
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if ! printf '%s' "$json" | jq -e 'type == "array"' >/dev/null 2>&1; then
|
||||
echo "Codex output was not a JSON array. Raw output: $json"
|
||||
if ! printf '%s' "$json" | jq -e 'type == "object" and (.labels | type == "array")' >/dev/null 2>&1; then
|
||||
echo "Codex output did not include a labels array. Raw output: $json"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
labels=$(printf '%s' "$json" | jq -r '.[] | tostring')
|
||||
labels=$(printf '%s' "$json" | jq -r '.labels[] | tostring')
|
||||
if [ -z "$labels" ]; then
|
||||
echo "Codex returned an empty array. Nothing to do."
|
||||
exit 0
|
||||
|
||||
42
.github/workflows/rust-ci.yml
vendored
42
.github/workflows/rust-ci.yml
vendored
@@ -148,15 +148,26 @@ jobs:
|
||||
targets: ${{ matrix.target }}
|
||||
components: clippy
|
||||
|
||||
- uses: actions/cache@v4
|
||||
# Explicit cache restore: split cargo home vs target, so we can
|
||||
# avoid caching the large target dir on the gnu-dev job.
|
||||
- name: Restore cargo home cache
|
||||
id: cache_cargo_home_restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/bin/
|
||||
~/.cargo/registry/index/
|
||||
~/.cargo/registry/cache/
|
||||
~/.cargo/git/db/
|
||||
${{ github.workspace }}/codex-rs/target/
|
||||
key: cargo-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- name: Restore target cache (except gnu-dev)
|
||||
id: cache_target_restore
|
||||
if: ${{ !(matrix.target == 'x86_64-unknown-linux-gnu' && matrix.profile != 'release') }}
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}/codex-rs/target/
|
||||
key: cargo-target-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}}
|
||||
name: Install musl build tools
|
||||
@@ -194,6 +205,31 @@ jobs:
|
||||
env:
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
# Save caches explicitly; make non-fatal so cache packaging
|
||||
# never fails the overall job. Only save when key wasn't hit.
|
||||
- name: Save cargo home cache
|
||||
if: always() && !cancelled() && steps.cache_cargo_home_restore.outputs.cache-hit != 'true'
|
||||
continue-on-error: true
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/bin/
|
||||
~/.cargo/registry/index/
|
||||
~/.cargo/registry/cache/
|
||||
~/.cargo/git/db/
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- name: Save target cache (except gnu-dev)
|
||||
if: >-
|
||||
always() && !cancelled() &&
|
||||
(steps.cache_target_restore.outputs.cache-hit != 'true') &&
|
||||
!(matrix.target == 'x86_64-unknown-linux-gnu' && matrix.profile != 'release')
|
||||
continue-on-error: true
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}/codex-rs/target/
|
||||
key: cargo-target-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
# Fail the job if any of the previous steps failed.
|
||||
- name: verify all steps passed
|
||||
if: |
|
||||
|
||||
134
.github/workflows/rust-release.yml
vendored
134
.github/workflows/rust-release.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
||||
|
||||
build:
|
||||
needs: tag-check
|
||||
name: ${{ matrix.runner }} - ${{ matrix.target }}
|
||||
name: Build - ${{ matrix.runner }} - ${{ matrix.target }}
|
||||
runs-on: ${{ matrix.runner }}
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
@@ -94,11 +94,118 @@ jobs:
|
||||
- if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}}
|
||||
name: Install musl build tools
|
||||
run: |
|
||||
sudo apt install -y musl-tools pkg-config
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y musl-tools pkg-config
|
||||
|
||||
- name: Cargo build
|
||||
run: cargo build --target ${{ matrix.target }} --release --bin codex --bin codex-responses-api-proxy
|
||||
|
||||
- if: ${{ matrix.runner == 'macos-14' }}
|
||||
name: Configure Apple code signing
|
||||
shell: bash
|
||||
env:
|
||||
KEYCHAIN_PASSWORD: actions
|
||||
APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE_P12 }}
|
||||
APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ -z "${APPLE_CERTIFICATE:-}" ]]; then
|
||||
echo "APPLE_CERTIFICATE is required for macOS signing"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -z "${APPLE_CERTIFICATE_PASSWORD:-}" ]]; then
|
||||
echo "APPLE_CERTIFICATE_PASSWORD is required for macOS signing"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cert_path="${RUNNER_TEMP}/apple_signing_certificate.p12"
|
||||
echo "$APPLE_CERTIFICATE" | base64 -d > "$cert_path"
|
||||
|
||||
keychain_path="${RUNNER_TEMP}/codex-signing.keychain-db"
|
||||
security create-keychain -p "$KEYCHAIN_PASSWORD" "$keychain_path"
|
||||
security set-keychain-settings -lut 21600 "$keychain_path"
|
||||
security unlock-keychain -p "$KEYCHAIN_PASSWORD" "$keychain_path"
|
||||
|
||||
keychain_args=()
|
||||
cleanup_keychain() {
|
||||
if ((${#keychain_args[@]} > 0)); then
|
||||
security list-keychains -s "${keychain_args[@]}" || true
|
||||
security default-keychain -s "${keychain_args[0]}" || true
|
||||
else
|
||||
security list-keychains -s || true
|
||||
fi
|
||||
if [[ -f "$keychain_path" ]]; then
|
||||
security delete-keychain "$keychain_path" || true
|
||||
fi
|
||||
}
|
||||
|
||||
while IFS= read -r keychain; do
|
||||
[[ -n "$keychain" ]] && keychain_args+=("$keychain")
|
||||
done < <(security list-keychains | sed 's/^[[:space:]]*//;s/[[:space:]]*$//;s/"//g')
|
||||
|
||||
if ((${#keychain_args[@]} > 0)); then
|
||||
security list-keychains -s "$keychain_path" "${keychain_args[@]}"
|
||||
else
|
||||
security list-keychains -s "$keychain_path"
|
||||
fi
|
||||
|
||||
security default-keychain -s "$keychain_path"
|
||||
security import "$cert_path" -k "$keychain_path" -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign -T /usr/bin/security
|
||||
security set-key-partition-list -S apple-tool:,apple: -s -k "$KEYCHAIN_PASSWORD" "$keychain_path" > /dev/null
|
||||
|
||||
codesign_hashes=()
|
||||
while IFS= read -r hash; do
|
||||
[[ -n "$hash" ]] && codesign_hashes+=("$hash")
|
||||
done < <(security find-identity -v -p codesigning "$keychain_path" \
|
||||
| sed -n 's/.*\([0-9A-F]\{40\}\).*/\1/p' \
|
||||
| sort -u)
|
||||
|
||||
if ((${#codesign_hashes[@]} == 0)); then
|
||||
echo "No signing identities found in $keychain_path"
|
||||
cleanup_keychain
|
||||
rm -f "$cert_path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ((${#codesign_hashes[@]} > 1)); then
|
||||
echo "Multiple signing identities found in $keychain_path:"
|
||||
printf ' %s\n' "${codesign_hashes[@]}"
|
||||
cleanup_keychain
|
||||
rm -f "$cert_path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
APPLE_CODESIGN_IDENTITY="${codesign_hashes[0]}"
|
||||
|
||||
rm -f "$cert_path"
|
||||
|
||||
echo "APPLE_CODESIGN_IDENTITY=$APPLE_CODESIGN_IDENTITY" >> "$GITHUB_ENV"
|
||||
echo "APPLE_CODESIGN_KEYCHAIN=$keychain_path" >> "$GITHUB_ENV"
|
||||
echo "::add-mask::$APPLE_CODESIGN_IDENTITY"
|
||||
|
||||
- if: ${{ matrix.runner == 'macos-14' }}
|
||||
name: Sign macOS binaries
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ -z "${APPLE_CODESIGN_IDENTITY:-}" ]]; then
|
||||
echo "APPLE_CODESIGN_IDENTITY is required for macOS signing"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
keychain_args=()
|
||||
if [[ -n "${APPLE_CODESIGN_KEYCHAIN:-}" && -f "${APPLE_CODESIGN_KEYCHAIN}" ]]; then
|
||||
keychain_args+=(--keychain "${APPLE_CODESIGN_KEYCHAIN}")
|
||||
fi
|
||||
|
||||
for binary in codex codex-responses-api-proxy; do
|
||||
path="target/${{ matrix.target }}/release/${binary}"
|
||||
codesign --force --options runtime --timestamp --sign "$APPLE_CODESIGN_IDENTITY" "${keychain_args[@]}" "$path"
|
||||
done
|
||||
|
||||
- name: Stage artifacts
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -157,6 +264,29 @@ jobs:
|
||||
zstd -T0 -19 --rm "$dest/$base"
|
||||
done
|
||||
|
||||
- name: Remove signing keychain
|
||||
if: ${{ always() && matrix.runner == 'macos-14' }}
|
||||
shell: bash
|
||||
env:
|
||||
APPLE_CODESIGN_KEYCHAIN: ${{ env.APPLE_CODESIGN_KEYCHAIN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if [[ -n "${APPLE_CODESIGN_KEYCHAIN:-}" ]]; then
|
||||
keychain_args=()
|
||||
while IFS= read -r keychain; do
|
||||
[[ "$keychain" == "$APPLE_CODESIGN_KEYCHAIN" ]] && continue
|
||||
[[ -n "$keychain" ]] && keychain_args+=("$keychain")
|
||||
done < <(security list-keychains | sed 's/^[[:space:]]*//;s/[[:space:]]*$//;s/"//g')
|
||||
if ((${#keychain_args[@]} > 0)); then
|
||||
security list-keychains -s "${keychain_args[@]}"
|
||||
security default-keychain -s "${keychain_args[0]}"
|
||||
fi
|
||||
|
||||
if [[ -f "$APPLE_CODESIGN_KEYCHAIN" ]]; then
|
||||
security delete-keychain "$APPLE_CODESIGN_KEYCHAIN"
|
||||
fi
|
||||
fi
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.target }}
|
||||
|
||||
35
AGENTS.md
35
AGENTS.md
@@ -8,11 +8,16 @@ In the codex-rs folder where the rust code lives:
|
||||
- Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` or `CODEX_SANDBOX_ENV_VAR`.
|
||||
- You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
|
||||
- Similarly, when you spawn a process using Seatbelt (`/usr/bin/sandbox-exec`), `CODEX_SANDBOX=seatbelt` will be set on the child process. Integration tests that want to run Seatbelt themselves cannot be run under Seatbelt, so checks for `CODEX_SANDBOX=seatbelt` are also often used to early exit out of tests, as appropriate.
|
||||
- Always collapse if statements per https://rust-lang.github.io/rust-clippy/master/index.html#collapsible_if
|
||||
- Always inline format! args when possible per https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args
|
||||
- Use method references over closures when possible per https://rust-lang.github.io/rust-clippy/master/index.html#redundant_closure_for_method_calls
|
||||
- When writing tests, prefer comparing the equality of entire objects over fields one by one.
|
||||
|
||||
Run `just fmt` (in `codex-rs` directory) automatically after making Rust code changes; do not ask for approval to run it. Before finalizing a change to `codex-rs`, run `just fix -p <project>` (in `codex-rs` directory) to fix any linter issues in the code. Prefer scoping with `-p` to avoid slow workspace‑wide Clippy builds; only run `just fix` without `-p` if you changed shared crates. Additionally, run the tests:
|
||||
|
||||
1. Run the test for the specific project that was changed. For example, if changes were made in `codex-rs/tui`, run `cargo test -p codex-tui`.
|
||||
2. Once those pass, if any changes were made in common, core, or protocol, run the complete test suite with `cargo test --all-features`.
|
||||
When running interactively, ask the user before running `just fix` to finalize. `just fmt` does not require approval. project-specific or individual tests can be run without asking the user, but do ask the user before running the complete test suite.
|
||||
When running interactively, ask the user before running `just fix` to finalize. `just fmt` does not require approval. project-specific or individual tests can be run without asking the user, but do ask the user before running the complete test suite.
|
||||
|
||||
## TUI style conventions
|
||||
|
||||
@@ -28,6 +33,7 @@ See `codex-rs/tui/styles.md`.
|
||||
- Desired: vec![" └ ".into(), "M".red(), " ".dim(), "tui/src/app.rs".dim()]
|
||||
|
||||
### TUI Styling (ratatui)
|
||||
|
||||
- Prefer Stylize helpers: use "text".dim(), .bold(), .cyan(), .italic(), .underlined() instead of manual Style where possible.
|
||||
- Prefer simple conversions: use "text".into() for spans and vec![…].into() for lines; when inference is ambiguous (e.g., Paragraph::new/Cell::from), use Line::from(spans) or Span::from(text).
|
||||
- Computed styles: if the Style is computed at runtime, using `Span::styled` is OK (`Span::from(text).set_style(style)` is also acceptable).
|
||||
@@ -39,6 +45,7 @@ See `codex-rs/tui/styles.md`.
|
||||
- Compactness: prefer the form that stays on one line after rustfmt; if only one of Line::from(vec![…]) or vec![…].into() avoids wrapping, choose that. If both wrap, pick the one with fewer wrapped lines.
|
||||
|
||||
### Text wrapping
|
||||
|
||||
- Always use textwrap::wrap to wrap plain strings.
|
||||
- If you have a ratatui Line and you want to wrap it, use the helpers in tui/src/wrapping.rs, e.g. word_wrap_lines / word_wrap_line.
|
||||
- If you need to indent wrapped lines, use the initial_indent / subsequent_indent options from RtOptions if you can, rather than writing custom logic.
|
||||
@@ -60,8 +67,34 @@ This repo uses snapshot tests (via `insta`), especially in `codex-rs/tui`, to va
|
||||
- `cargo insta accept -p codex-tui`
|
||||
|
||||
If you don’t have the tool:
|
||||
|
||||
- `cargo install cargo-insta`
|
||||
|
||||
### Test assertions
|
||||
|
||||
- Tests should use pretty_assertions::assert_eq for clearer diffs. Import this at the top of the test module if it isn't already.
|
||||
|
||||
### Integration tests (core)
|
||||
|
||||
- Prefer the utilities in `core_test_support::responses` when writing end-to-end Codex tests.
|
||||
|
||||
- All `mount_sse*` helpers return a `ResponseMock`; hold onto it so you can assert against outbound `/responses` POST bodies.
|
||||
- Use `ResponseMock::single_request()` when a test should only issue one POST, or `ResponseMock::requests()` to inspect every captured `ResponsesRequest`.
|
||||
- `ResponsesRequest` exposes helpers (`body_json`, `input`, `function_call_output`, `custom_tool_call_output`, `call_output`, `header`, `path`, `query_param`) so assertions can target structured payloads instead of manual JSON digging.
|
||||
- Build SSE payloads with the provided `ev_*` constructors and the `sse(...)`.
|
||||
|
||||
- Typical pattern:
|
||||
|
||||
```rust
|
||||
let mock = responses::mount_sse_once(&server, responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
responses::ev_completed("resp-1"),
|
||||
])).await;
|
||||
|
||||
codex.submit(Op::UserTurn { ... }).await?;
|
||||
|
||||
// Assert request body if needed.
|
||||
let request = mock.single_request();
|
||||
// assert using request.function_call_output(call_id) or request.json_body() or other helpers.
|
||||
```
|
||||
|
||||
10
README.md
10
README.md
@@ -1,4 +1,3 @@
|
||||
|
||||
<p align="center"><code>npm i -g @openai/codex</code><br />or <code>brew install codex</code></p>
|
||||
|
||||
<p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.
|
||||
@@ -62,8 +61,7 @@ You can also use Codex with an API key, but this requires [additional setup](./d
|
||||
|
||||
### Model Context Protocol (MCP)
|
||||
|
||||
Codex CLI supports [MCP servers](./docs/advanced.md#model-context-protocol-mcp). Enable by adding an `mcp_servers` section to your `~/.codex/config.toml`.
|
||||
|
||||
Codex can access MCP servers. To configure them, refer to the [config docs](./docs/config.md#mcp_servers).
|
||||
|
||||
### Configuration
|
||||
|
||||
@@ -83,9 +81,11 @@ Codex CLI supports a rich set of configuration options, with preferences stored
|
||||
- [**Authentication**](./docs/authentication.md)
|
||||
- [Auth methods](./docs/authentication.md#forcing-a-specific-auth-method-advanced)
|
||||
- [Login on a "Headless" machine](./docs/authentication.md#connecting-on-a-headless-machine)
|
||||
- [**Non-interactive mode**](./docs/exec.md)
|
||||
- **Automating Codex**
|
||||
- [GitHub Action](https://github.com/openai/codex-action)
|
||||
- [TypeScript SDK](./sdk/typescript/README.md)
|
||||
- [Non-interactive mode (`codex exec`)](./docs/exec.md)
|
||||
- [**Advanced**](./docs/advanced.md)
|
||||
- [Non-interactive / CI mode](./docs/advanced.md#non-interactive--ci-mode)
|
||||
- [Tracing / verbose logging](./docs/advanced.md#tracing--verbose-logging)
|
||||
- [Model Context Protocol (MCP)](./docs/advanced.md#model-context-protocol-mcp)
|
||||
- [**Zero data retention (ZDR)**](./docs/zdr.md)
|
||||
|
||||
1556
codex-rs/Cargo.lock
generated
1556
codex-rs/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -76,6 +76,7 @@ codex-utils-string = { path = "utils/string" }
|
||||
core_test_support = { path = "core/tests/common" }
|
||||
mcp-types = { path = "mcp-types" }
|
||||
mcp_test_support = { path = "mcp-server/tests/common" }
|
||||
notify = "6"
|
||||
|
||||
# External
|
||||
allocative = "0.3.3"
|
||||
@@ -83,10 +84,12 @@ ansi-to-tui = "7.0.0"
|
||||
anyhow = "1"
|
||||
arboard = "3"
|
||||
askama = "0.12"
|
||||
assert_matches = "1.5.0"
|
||||
assert_cmd = "2"
|
||||
async-channel = "2.3.1"
|
||||
async-stream = "0.3.6"
|
||||
async-trait = "0.1.89"
|
||||
axum = { version = "0.8", default-features = false }
|
||||
base64 = "0.22.1"
|
||||
bytes = "1.10.1"
|
||||
chrono = "0.4.42"
|
||||
@@ -104,7 +107,7 @@ env-flags = "0.1.1"
|
||||
env_logger = "0.11.5"
|
||||
escargot = "0.5"
|
||||
eventsource-stream = "0.2.3"
|
||||
futures = "0.3"
|
||||
futures = { version = "0.3", default-features = false }
|
||||
icu_decimal = "2.0.0"
|
||||
icu_locale_core = "2.0.0"
|
||||
ignore = "0.4.23"
|
||||
@@ -112,6 +115,7 @@ image = { version = "^0.25.8", default-features = false }
|
||||
indexmap = "2.6.0"
|
||||
insta = "1.43.2"
|
||||
itertools = "0.14.0"
|
||||
keyring = "3.6"
|
||||
landlock = "0.4.1"
|
||||
lazy_static = "1"
|
||||
libc = "0.2.175"
|
||||
@@ -140,11 +144,13 @@ rand = "0.9"
|
||||
ratatui = "0.29.0"
|
||||
regex-lite = "0.1.7"
|
||||
reqwest = "0.12"
|
||||
rmcp = { version = "0.8.0", default-features = false }
|
||||
schemars = "0.8.22"
|
||||
seccompiler = "0.5.0"
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
serde_with = "3.14"
|
||||
serial_test = "3.2.0"
|
||||
sha1 = "0.10.6"
|
||||
sha2 = "0.10"
|
||||
shlex = "1.3.0"
|
||||
@@ -170,8 +176,9 @@ tracing = "0.1.41"
|
||||
tracing-appender = "0.2.3"
|
||||
tracing-subscriber = "0.3.20"
|
||||
tracing-test = "0.2.5"
|
||||
tree-sitter = "0.25.9"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
tree-sitter = "0.25.10"
|
||||
tree-sitter-bash = "0.25"
|
||||
tree-sitter-highlight = "0.25.10"
|
||||
ts-rs = "11"
|
||||
unicode-segmentation = "1.12.0"
|
||||
unicode-width = "0.2"
|
||||
@@ -239,5 +246,9 @@ strip = "symbols"
|
||||
codegen-units = 1
|
||||
|
||||
[patch.crates-io]
|
||||
# Uncomment to debug local changes.
|
||||
# ratatui = { path = "../../ratatui" }
|
||||
ratatui = { git = "https://github.com/nornagon/ratatui", branch = "nornagon-v0.29.0-patch" }
|
||||
|
||||
# Uncomment to debug local changes.
|
||||
# rmcp = { path = "../../rust-sdk/crates/rmcp" }
|
||||
|
||||
@@ -23,9 +23,15 @@ Codex supports a rich set of configuration options. Note that the Rust CLI uses
|
||||
|
||||
### Model Context Protocol Support
|
||||
|
||||
Codex CLI functions as an MCP client that can connect to MCP servers on startup. See the [`mcp_servers`](../docs/config.md#mcp_servers) section in the configuration documentation for details.
|
||||
#### MCP client
|
||||
|
||||
It is still experimental, but you can also launch Codex as an MCP _server_ by running `codex mcp-server`. Use the [`@modelcontextprotocol/inspector`](https://github.com/modelcontextprotocol/inspector) to try it out:
|
||||
Codex CLI functions as an MCP client that allows the Codex CLI and IDE extension to connect to MCP servers on startup. See the [`configuration documentation`](../docs/config.md#mcp_servers) for details.
|
||||
|
||||
#### MCP server (experimental)
|
||||
|
||||
Codex can be launched as an MCP _server_ by running `codex mcp-server`. This allows _other_ MCP clients to use Codex as a tool for another agent.
|
||||
|
||||
Use the [`@modelcontextprotocol/inspector`](https://github.com/modelcontextprotocol/inspector) to try it out:
|
||||
|
||||
```shell
|
||||
npx @modelcontextprotocol/inspector codex mcp-server
|
||||
@@ -71,9 +77,13 @@ To test to see what happens when a command is run under the sandbox provided by
|
||||
|
||||
```
|
||||
# macOS
|
||||
codex debug seatbelt [--full-auto] [COMMAND]...
|
||||
codex sandbox macos [--full-auto] [COMMAND]...
|
||||
|
||||
# Linux
|
||||
codex sandbox linux [--full-auto] [COMMAND]...
|
||||
|
||||
# Legacy aliases
|
||||
codex debug seatbelt [--full-auto] [COMMAND]...
|
||||
codex debug landlock [--full-auto] [COMMAND]...
|
||||
```
|
||||
|
||||
|
||||
@@ -3,11 +3,30 @@ use ansi_to_tui::IntoText;
|
||||
use ratatui::text::Line;
|
||||
use ratatui::text::Text;
|
||||
|
||||
// Expand tabs in a best-effort way for transcript rendering.
|
||||
// Tabs can interact poorly with left-gutter prefixes in our TUI and CLI
|
||||
// transcript views (e.g., `nl` separates line numbers from content with a tab).
|
||||
// Replacing tabs with spaces avoids odd visual artifacts without changing
|
||||
// semantics for our use cases.
|
||||
fn expand_tabs(s: &str) -> std::borrow::Cow<'_, str> {
|
||||
if s.contains('\t') {
|
||||
// Keep it simple: replace each tab with 4 spaces.
|
||||
// We do not try to align to tab stops since most usages (like `nl`)
|
||||
// look acceptable with a fixed substitution and this avoids stateful math
|
||||
// across spans.
|
||||
std::borrow::Cow::Owned(s.replace('\t', " "))
|
||||
} else {
|
||||
std::borrow::Cow::Borrowed(s)
|
||||
}
|
||||
}
|
||||
|
||||
/// This function should be used when the contents of `s` are expected to match
|
||||
/// a single line. If multiple lines are found, a warning is logged and only the
|
||||
/// first line is returned.
|
||||
pub fn ansi_escape_line(s: &str) -> Line<'static> {
|
||||
let text = ansi_escape(s);
|
||||
// Normalize tabs to spaces to avoid odd gutter collisions in transcript mode.
|
||||
let s = expand_tabs(s);
|
||||
let text = ansi_escape(&s);
|
||||
match text.lines.as_slice() {
|
||||
[] => "".into(),
|
||||
[only] => only.clone(),
|
||||
|
||||
15
codex-rs/app-server/README.md
Normal file
15
codex-rs/app-server/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# codex-app-server
|
||||
|
||||
`codex app-server` is the harness Codex uses to power rich interfaces such as the [Codex VS Code extension](https://marketplace.visualstudio.com/items?itemName=openai.chatgpt). The message schema is currently unstable, but those who wish to build experimental UIs on top of Codex may find it valuable.
|
||||
|
||||
## Protocol
|
||||
|
||||
Similar to [MCP](https://modelcontextprotocol.io/), `codex app-server` supports bidirectional communication, streaming JSONL over stdio. The protocol is JSON-RPC 2.0, though the `"jsonrpc":"2.0"` header is omitted.
|
||||
|
||||
## Message Schema
|
||||
|
||||
Currently, you can dump a TypeScript version of the schema using `codex generate-ts`. It is specific to the version of Codex you used to run `generate-ts`, so the two are guaranteed to be compatible.
|
||||
|
||||
```
|
||||
codex generate-ts --out DIR
|
||||
```
|
||||
@@ -500,7 +500,7 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
|
||||
async fn get_user_saved_config(&self, request_id: RequestId) {
|
||||
let toml_value = match load_config_as_toml(&self.config.codex_home) {
|
||||
let toml_value = match load_config_as_toml(&self.config.codex_home).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
@@ -653,18 +653,19 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
|
||||
async fn process_new_conversation(&self, request_id: RequestId, params: NewConversationParams) {
|
||||
let config = match derive_config_from_params(params, self.codex_linux_sandbox_exe.clone()) {
|
||||
Ok(config) => config,
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!("error deriving config: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
let config =
|
||||
match derive_config_from_params(params, self.codex_linux_sandbox_exe.clone()).await {
|
||||
Ok(config) => config,
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!("error deriving config: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match self.conversation_manager.new_conversation(config).await {
|
||||
Ok(conversation_id) => {
|
||||
@@ -752,7 +753,7 @@ impl CodexMessageProcessor {
|
||||
// Derive a Config using the same logic as new conversation, honoring overrides if provided.
|
||||
let config = match params.overrides {
|
||||
Some(overrides) => {
|
||||
derive_config_from_params(overrides, self.codex_linux_sandbox_exe.clone())
|
||||
derive_config_from_params(overrides, self.codex_linux_sandbox_exe.clone()).await
|
||||
}
|
||||
None => Ok(self.config.as_ref().clone()),
|
||||
};
|
||||
@@ -1320,7 +1321,7 @@ async fn apply_bespoke_event_handling(
|
||||
}
|
||||
}
|
||||
|
||||
fn derive_config_from_params(
|
||||
async fn derive_config_from_params(
|
||||
params: NewConversationParams,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
) -> std::io::Result<Config> {
|
||||
@@ -1358,7 +1359,7 @@ fn derive_config_from_params(
|
||||
.map(|(k, v)| (k, json_to_toml(v)))
|
||||
.collect();
|
||||
|
||||
Config::load_with_cli_overrides(cli_overrides, overrides)
|
||||
Config::load_with_cli_overrides(cli_overrides, overrides).await
|
||||
}
|
||||
|
||||
async fn on_patch_approval_response(
|
||||
|
||||
@@ -81,6 +81,7 @@ pub async fn run_main(
|
||||
)
|
||||
})?;
|
||||
let config = Config::load_with_cli_overrides(cli_kv_overrides, ConfigOverrides::default())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
std::io::Error::new(ErrorKind::InvalidData, format!("error loading config: {e}"))
|
||||
})?;
|
||||
|
||||
@@ -23,5 +23,6 @@ tree-sitter-bash = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = { workspace = true }
|
||||
assert_matches = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -843,6 +843,7 @@ pub fn print_summary(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_matches::assert_matches;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
use std::string::ToString;
|
||||
@@ -894,10 +895,10 @@ mod tests {
|
||||
|
||||
fn assert_not_match(script: &str) {
|
||||
let args = args_bash(script);
|
||||
assert!(matches!(
|
||||
assert_matches!(
|
||||
maybe_parse_apply_patch(&args),
|
||||
MaybeApplyPatch::NotApplyPatch
|
||||
));
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -905,10 +906,10 @@ mod tests {
|
||||
let patch = "*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch".to_string();
|
||||
let args = vec![patch];
|
||||
let dir = tempdir().unwrap();
|
||||
assert!(matches!(
|
||||
assert_matches!(
|
||||
maybe_parse_apply_patch_verified(&args, dir.path()),
|
||||
MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation)
|
||||
));
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -916,10 +917,10 @@ mod tests {
|
||||
let script = "*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch";
|
||||
let args = args_bash(script);
|
||||
let dir = tempdir().unwrap();
|
||||
assert!(matches!(
|
||||
assert_matches!(
|
||||
maybe_parse_apply_patch_verified(&args, dir.path()),
|
||||
MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation)
|
||||
));
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -29,7 +29,8 @@ pub async fn run_apply_command(
|
||||
.parse_overrides()
|
||||
.map_err(anyhow::Error::msg)?,
|
||||
ConfigOverrides::default(),
|
||||
)?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
init_chatgpt_token_from_auth(&config.codex_home).await?;
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ codex-app-server-protocol = { workspace = true }
|
||||
codex-protocol-ts = { workspace = true }
|
||||
codex-responses-api-proxy = { workspace = true }
|
||||
codex-tui = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-cloud-tasks = { path = "../cloud-tasks" }
|
||||
ctor = { workspace = true }
|
||||
owo-colors = { workspace = true }
|
||||
@@ -46,6 +47,7 @@ tokio = { workspace = true, features = [
|
||||
] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_matches = { workspace = true }
|
||||
assert_cmd = { workspace = true }
|
||||
predicates = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
|
||||
@@ -73,7 +73,8 @@ async fn run_command_under_sandbox(
|
||||
codex_linux_sandbox_exe,
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
// In practice, this should be `std::env::current_dir()` because this CLI
|
||||
// does not support `--cwd`, but let's use the config value for consistency.
|
||||
|
||||
@@ -26,7 +26,7 @@ pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {
|
||||
}
|
||||
|
||||
pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! {
|
||||
let config = load_config_or_exit(cli_config_overrides);
|
||||
let config = load_config_or_exit(cli_config_overrides).await;
|
||||
|
||||
match login_with_chatgpt(config.codex_home).await {
|
||||
Ok(_) => {
|
||||
@@ -44,7 +44,7 @@ pub async fn run_login_with_api_key(
|
||||
cli_config_overrides: CliConfigOverrides,
|
||||
api_key: String,
|
||||
) -> ! {
|
||||
let config = load_config_or_exit(cli_config_overrides);
|
||||
let config = load_config_or_exit(cli_config_overrides).await;
|
||||
|
||||
match login_with_api_key(&config.codex_home, &api_key) {
|
||||
Ok(_) => {
|
||||
@@ -91,7 +91,7 @@ pub async fn run_login_with_device_code(
|
||||
issuer_base_url: Option<String>,
|
||||
client_id: Option<String>,
|
||||
) -> ! {
|
||||
let config = load_config_or_exit(cli_config_overrides);
|
||||
let config = load_config_or_exit(cli_config_overrides).await;
|
||||
let mut opts = ServerOptions::new(
|
||||
config.codex_home,
|
||||
client_id.unwrap_or(CLIENT_ID.to_string()),
|
||||
@@ -112,7 +112,7 @@ pub async fn run_login_with_device_code(
|
||||
}
|
||||
|
||||
pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
|
||||
let config = load_config_or_exit(cli_config_overrides);
|
||||
let config = load_config_or_exit(cli_config_overrides).await;
|
||||
|
||||
match CodexAuth::from_codex_home(&config.codex_home) {
|
||||
Ok(Some(auth)) => match auth.mode {
|
||||
@@ -143,7 +143,7 @@ pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
|
||||
}
|
||||
|
||||
pub async fn run_logout(cli_config_overrides: CliConfigOverrides) -> ! {
|
||||
let config = load_config_or_exit(cli_config_overrides);
|
||||
let config = load_config_or_exit(cli_config_overrides).await;
|
||||
|
||||
match logout(&config.codex_home) {
|
||||
Ok(true) => {
|
||||
@@ -161,7 +161,7 @@ pub async fn run_logout(cli_config_overrides: CliConfigOverrides) -> ! {
|
||||
}
|
||||
}
|
||||
|
||||
fn load_config_or_exit(cli_config_overrides: CliConfigOverrides) -> Config {
|
||||
async fn load_config_or_exit(cli_config_overrides: CliConfigOverrides) -> Config {
|
||||
let cli_overrides = match cli_config_overrides.parse_overrides() {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
@@ -171,7 +171,7 @@ fn load_config_or_exit(cli_config_overrides: CliConfigOverrides) -> Config {
|
||||
};
|
||||
|
||||
let config_overrides = ConfigOverrides::default();
|
||||
match Config::load_with_cli_overrides(cli_overrides, config_overrides) {
|
||||
match Config::load_with_cli_overrides(cli_overrides, config_overrides).await {
|
||||
Ok(config) => config,
|
||||
Err(e) => {
|
||||
eprintln!("Error loading configuration: {e}");
|
||||
|
||||
@@ -76,8 +76,9 @@ enum Subcommand {
|
||||
/// Generate shell completion scripts.
|
||||
Completion(CompletionCommand),
|
||||
|
||||
/// Internal debugging commands.
|
||||
Debug(DebugArgs),
|
||||
/// Run commands within a Codex-provided sandbox.
|
||||
#[clap(visible_alias = "debug")]
|
||||
Sandbox(SandboxArgs),
|
||||
|
||||
/// Apply the latest diff produced by Codex agent as a `git apply` to your local working tree.
|
||||
#[clap(visible_alias = "a")]
|
||||
@@ -121,18 +122,20 @@ struct ResumeCommand {
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
struct DebugArgs {
|
||||
struct SandboxArgs {
|
||||
#[command(subcommand)]
|
||||
cmd: DebugCommand,
|
||||
cmd: SandboxCommand,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Subcommand)]
|
||||
enum DebugCommand {
|
||||
enum SandboxCommand {
|
||||
/// Run a command under Seatbelt (macOS only).
|
||||
Seatbelt(SeatbeltCommand),
|
||||
#[clap(visible_alias = "seatbelt")]
|
||||
Macos(SeatbeltCommand),
|
||||
|
||||
/// Run a command under Landlock+seccomp (Linux only).
|
||||
Landlock(LandlockCommand),
|
||||
#[clap(visible_alias = "landlock")]
|
||||
Linux(LandlockCommand),
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
@@ -154,9 +157,7 @@ struct LoginCommand {
|
||||
)]
|
||||
api_key: Option<String>,
|
||||
|
||||
/// EXPERIMENTAL: Use device code flow (not yet supported)
|
||||
/// This feature is experimental and may changed in future releases.
|
||||
#[arg(long = "experimental_use-device-code", hide = true)]
|
||||
#[arg(long = "device-auth")]
|
||||
use_device_code: bool,
|
||||
|
||||
/// EXPERIMENTAL: Use custom OAuth issuer base URL (advanced)
|
||||
@@ -291,7 +292,8 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
last,
|
||||
config_overrides,
|
||||
);
|
||||
codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
|
||||
let exit_info = codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
|
||||
print_exit_messages(exit_info);
|
||||
}
|
||||
Some(Subcommand::Login(mut login_cli)) => {
|
||||
prepend_config_flags(
|
||||
@@ -341,8 +343,8 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
);
|
||||
codex_cloud_tasks::run_main(cloud_cli, codex_linux_sandbox_exe).await?;
|
||||
}
|
||||
Some(Subcommand::Debug(debug_args)) => match debug_args.cmd {
|
||||
DebugCommand::Seatbelt(mut seatbelt_cli) => {
|
||||
Some(Subcommand::Sandbox(sandbox_args)) => match sandbox_args.cmd {
|
||||
SandboxCommand::Macos(mut seatbelt_cli) => {
|
||||
prepend_config_flags(
|
||||
&mut seatbelt_cli.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
@@ -353,7 +355,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
DebugCommand::Landlock(mut landlock_cli) => {
|
||||
SandboxCommand::Linux(mut landlock_cli) => {
|
||||
prepend_config_flags(
|
||||
&mut landlock_cli.config_overrides,
|
||||
root_config_overrides.clone(),
|
||||
@@ -472,6 +474,7 @@ fn print_completion(cmd: CompletionCommand) {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use codex_protocol::ConversationId;
|
||||
|
||||
@@ -604,14 +607,14 @@ mod tests {
|
||||
assert_eq!(interactive.model.as_deref(), Some("gpt-5-test"));
|
||||
assert!(interactive.oss);
|
||||
assert_eq!(interactive.config_profile.as_deref(), Some("my-profile"));
|
||||
assert!(matches!(
|
||||
assert_matches!(
|
||||
interactive.sandbox_mode,
|
||||
Some(codex_common::SandboxModeCliArg::WorkspaceWrite)
|
||||
));
|
||||
assert!(matches!(
|
||||
);
|
||||
assert_matches!(
|
||||
interactive.approval_policy,
|
||||
Some(codex_common::ApprovalModeCliArg::OnRequest)
|
||||
));
|
||||
);
|
||||
assert!(interactive.full_auto);
|
||||
assert_eq!(
|
||||
interactive.cwd.as_deref(),
|
||||
|
||||
@@ -4,6 +4,7 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use anyhow::bail;
|
||||
use clap::ArgGroup;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
@@ -12,6 +13,10 @@ use codex_core::config::load_global_mcp_servers;
|
||||
use codex_core::config::write_global_mcp_servers;
|
||||
use codex_core::config_types::McpServerConfig;
|
||||
use codex_core::config_types::McpServerTransportConfig;
|
||||
use codex_core::mcp::auth::compute_auth_statuses;
|
||||
use codex_core::protocol::McpAuthStatus;
|
||||
use codex_rmcp_client::delete_oauth_tokens;
|
||||
use codex_rmcp_client::perform_oauth_login;
|
||||
|
||||
/// [experimental] Launch Codex as an MCP server or manage configured MCP servers.
|
||||
///
|
||||
@@ -43,6 +48,14 @@ pub enum McpSubcommand {
|
||||
|
||||
/// [experimental] Remove a global MCP server entry.
|
||||
Remove(RemoveArgs),
|
||||
|
||||
/// [experimental] Authenticate with a configured MCP server via OAuth.
|
||||
/// Requires experimental_use_rmcp_client = true in config.toml.
|
||||
Login(LoginArgs),
|
||||
|
||||
/// [experimental] Remove stored OAuth credentials for a server.
|
||||
/// Requires experimental_use_rmcp_client = true in config.toml.
|
||||
Logout(LogoutArgs),
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
@@ -67,13 +80,61 @@ pub struct AddArgs {
|
||||
/// Name for the MCP server configuration.
|
||||
pub name: String,
|
||||
|
||||
/// Environment variables to set when launching the server.
|
||||
#[arg(long, value_parser = parse_env_pair, value_name = "KEY=VALUE")]
|
||||
pub env: Vec<(String, String)>,
|
||||
#[command(flatten)]
|
||||
pub transport_args: AddMcpTransportArgs,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
#[command(
|
||||
group(
|
||||
ArgGroup::new("transport")
|
||||
.args(["command", "url"])
|
||||
.required(true)
|
||||
.multiple(false)
|
||||
)
|
||||
)]
|
||||
pub struct AddMcpTransportArgs {
|
||||
#[command(flatten)]
|
||||
pub stdio: Option<AddMcpStdioArgs>,
|
||||
|
||||
#[command(flatten)]
|
||||
pub streamable_http: Option<AddMcpStreamableHttpArgs>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct AddMcpStdioArgs {
|
||||
/// Command to launch the MCP server.
|
||||
#[arg(trailing_var_arg = true, num_args = 1..)]
|
||||
/// Use --url for a streamable HTTP server.
|
||||
#[arg(
|
||||
trailing_var_arg = true,
|
||||
num_args = 0..,
|
||||
)]
|
||||
pub command: Vec<String>,
|
||||
|
||||
/// Environment variables to set when launching the server.
|
||||
/// Only valid with stdio servers.
|
||||
#[arg(
|
||||
long,
|
||||
value_parser = parse_env_pair,
|
||||
value_name = "KEY=VALUE",
|
||||
)]
|
||||
pub env: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct AddMcpStreamableHttpArgs {
|
||||
/// URL for a streamable HTTP MCP server.
|
||||
#[arg(long)]
|
||||
pub url: String,
|
||||
|
||||
/// Optional environment variable to read for a bearer token.
|
||||
/// Only valid with streamable HTTP servers.
|
||||
#[arg(
|
||||
long = "bearer-token-env-var",
|
||||
value_name = "ENV_VAR",
|
||||
requires = "url"
|
||||
)]
|
||||
pub bearer_token_env_var: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
@@ -82,6 +143,18 @@ pub struct RemoveArgs {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct LoginArgs {
|
||||
/// Name of the MCP server to authenticate with oauth.
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct LogoutArgs {
|
||||
/// Name of the MCP server to deauthenticate.
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl McpCli {
|
||||
pub async fn run(self) -> Result<()> {
|
||||
let McpCli {
|
||||
@@ -91,16 +164,22 @@ impl McpCli {
|
||||
|
||||
match subcommand {
|
||||
McpSubcommand::List(args) => {
|
||||
run_list(&config_overrides, args)?;
|
||||
run_list(&config_overrides, args).await?;
|
||||
}
|
||||
McpSubcommand::Get(args) => {
|
||||
run_get(&config_overrides, args)?;
|
||||
run_get(&config_overrides, args).await?;
|
||||
}
|
||||
McpSubcommand::Add(args) => {
|
||||
run_add(&config_overrides, args)?;
|
||||
run_add(&config_overrides, args).await?;
|
||||
}
|
||||
McpSubcommand::Remove(args) => {
|
||||
run_remove(&config_overrides, args)?;
|
||||
run_remove(&config_overrides, args).await?;
|
||||
}
|
||||
McpSubcommand::Login(args) => {
|
||||
run_login(&config_overrides, args).await?;
|
||||
}
|
||||
McpSubcommand::Logout(args) => {
|
||||
run_logout(&config_overrides, args).await?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,40 +187,56 @@ impl McpCli {
|
||||
}
|
||||
}
|
||||
|
||||
fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
|
||||
async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
|
||||
// Validate any provided overrides even though they are not currently applied.
|
||||
config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
|
||||
let AddArgs { name, env, command } = add_args;
|
||||
let AddArgs {
|
||||
name,
|
||||
transport_args,
|
||||
} = add_args;
|
||||
|
||||
validate_server_name(&name)?;
|
||||
|
||||
let mut command_parts = command.into_iter();
|
||||
let command_bin = command_parts
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("command is required"))?;
|
||||
let command_args: Vec<String> = command_parts.collect();
|
||||
|
||||
let env_map = if env.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut map = HashMap::new();
|
||||
for (key, value) in env {
|
||||
map.insert(key, value);
|
||||
}
|
||||
Some(map)
|
||||
};
|
||||
|
||||
let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
|
||||
let mut servers = load_global_mcp_servers(&codex_home)
|
||||
.await
|
||||
.with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?;
|
||||
|
||||
let new_entry = McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: command_bin,
|
||||
args: command_args,
|
||||
env: env_map,
|
||||
let transport = match transport_args {
|
||||
AddMcpTransportArgs {
|
||||
stdio: Some(stdio), ..
|
||||
} => {
|
||||
let mut command_parts = stdio.command.into_iter();
|
||||
let command_bin = command_parts
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("command is required"))?;
|
||||
let command_args: Vec<String> = command_parts.collect();
|
||||
|
||||
let env_map = if stdio.env.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stdio.env.into_iter().collect::<HashMap<_, _>>())
|
||||
};
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: command_bin,
|
||||
args: command_args,
|
||||
env: env_map,
|
||||
}
|
||||
}
|
||||
AddMcpTransportArgs {
|
||||
streamable_http: Some(streamable_http),
|
||||
..
|
||||
} => McpServerTransportConfig::StreamableHttp {
|
||||
url: streamable_http.url,
|
||||
bearer_token_env_var: streamable_http.bearer_token_env_var,
|
||||
},
|
||||
AddMcpTransportArgs { .. } => bail!("exactly one of --command or --url must be provided"),
|
||||
};
|
||||
|
||||
let new_entry = McpServerConfig {
|
||||
transport,
|
||||
enabled: true,
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
};
|
||||
@@ -156,7 +251,7 @@ fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_remove(config_overrides: &CliConfigOverrides, remove_args: RemoveArgs) -> Result<()> {
|
||||
async fn run_remove(config_overrides: &CliConfigOverrides, remove_args: RemoveArgs) -> Result<()> {
|
||||
config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
|
||||
let RemoveArgs { name } = remove_args;
|
||||
@@ -165,6 +260,7 @@ fn run_remove(config_overrides: &CliConfigOverrides, remove_args: RemoveArgs) ->
|
||||
|
||||
let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
|
||||
let mut servers = load_global_mcp_servers(&codex_home)
|
||||
.await
|
||||
.with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?;
|
||||
|
||||
let removed = servers.remove(&name).is_some();
|
||||
@@ -183,18 +279,83 @@ fn run_remove(config_overrides: &CliConfigOverrides, remove_args: RemoveArgs) ->
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
|
||||
async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) -> Result<()> {
|
||||
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
|
||||
.await
|
||||
.context("failed to load configuration")?;
|
||||
|
||||
if !config.use_experimental_use_rmcp_client {
|
||||
bail!(
|
||||
"OAuth login is only supported when experimental_use_rmcp_client is true in config.toml."
|
||||
);
|
||||
}
|
||||
|
||||
let LoginArgs { name } = login_args;
|
||||
|
||||
let Some(server) = config.mcp_servers.get(&name) else {
|
||||
bail!("No MCP server named '{name}' found.");
|
||||
};
|
||||
|
||||
let url = match &server.transport {
|
||||
McpServerTransportConfig::StreamableHttp { url, .. } => url.clone(),
|
||||
_ => bail!("OAuth login is only supported for streamable HTTP servers."),
|
||||
};
|
||||
|
||||
perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?;
|
||||
println!("Successfully logged in to MCP server '{name}'.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutArgs) -> Result<()> {
|
||||
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
|
||||
.await
|
||||
.context("failed to load configuration")?;
|
||||
|
||||
let LogoutArgs { name } = logout_args;
|
||||
|
||||
let server = config
|
||||
.mcp_servers
|
||||
.get(&name)
|
||||
.ok_or_else(|| anyhow!("No MCP server named '{name}' found in configuration."))?;
|
||||
|
||||
let url = match &server.transport {
|
||||
McpServerTransportConfig::StreamableHttp { url, .. } => url.clone(),
|
||||
_ => bail!("OAuth logout is only supported for streamable_http transports."),
|
||||
};
|
||||
|
||||
match delete_oauth_tokens(&name, &url, config.mcp_oauth_credentials_store_mode) {
|
||||
Ok(true) => println!("Removed OAuth credentials for '{name}'."),
|
||||
Ok(false) => println!("No OAuth credentials stored for '{name}'."),
|
||||
Err(err) => return Err(anyhow!("failed to delete OAuth credentials: {err}")),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
|
||||
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
|
||||
.await
|
||||
.context("failed to load configuration")?;
|
||||
|
||||
let mut entries: Vec<_> = config.mcp_servers.iter().collect();
|
||||
entries.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
let auth_statuses = compute_auth_statuses(
|
||||
config.mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
)
|
||||
.await;
|
||||
|
||||
if list_args.json {
|
||||
let json_entries: Vec<_> = entries
|
||||
.into_iter()
|
||||
.map(|(name, cfg)| {
|
||||
let auth_status = auth_statuses
|
||||
.get(name.as_str())
|
||||
.copied()
|
||||
.unwrap_or(McpAuthStatus::Unsupported);
|
||||
let transport = match &cfg.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({
|
||||
"type": "stdio",
|
||||
@@ -202,17 +363,21 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
|
||||
"args": args,
|
||||
"env": env,
|
||||
}),
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
serde_json::json!({
|
||||
"type": "streamable_http",
|
||||
"url": url,
|
||||
"bearer_token": bearer_token,
|
||||
"bearer_token_env_var": bearer_token_env_var,
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
serde_json::json!({
|
||||
"name": name,
|
||||
"enabled": cfg.enabled,
|
||||
"transport": transport,
|
||||
"startup_timeout_sec": cfg
|
||||
.startup_timeout_sec
|
||||
@@ -220,6 +385,7 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
|
||||
"tool_timeout_sec": cfg
|
||||
.tool_timeout_sec
|
||||
.map(|timeout| timeout.as_secs_f64()),
|
||||
"auth_status": auth_status,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -233,8 +399,8 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut stdio_rows: Vec<[String; 4]> = Vec::new();
|
||||
let mut http_rows: Vec<[String; 3]> = Vec::new();
|
||||
let mut stdio_rows: Vec<[String; 6]> = Vec::new();
|
||||
let mut http_rows: Vec<[String; 5]> = Vec::new();
|
||||
|
||||
for (name, cfg) in entries {
|
||||
match &cfg.transport {
|
||||
@@ -257,21 +423,59 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
|
||||
.join(", ")
|
||||
}
|
||||
};
|
||||
stdio_rows.push([name.clone(), command.clone(), args_display, env_display]);
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
let has_bearer = if bearer_token.is_some() {
|
||||
"True"
|
||||
let status = if cfg.enabled {
|
||||
"enabled".to_string()
|
||||
} else {
|
||||
"False"
|
||||
"disabled".to_string()
|
||||
};
|
||||
http_rows.push([name.clone(), url.clone(), has_bearer.into()]);
|
||||
let auth_status = auth_statuses
|
||||
.get(name.as_str())
|
||||
.copied()
|
||||
.unwrap_or(McpAuthStatus::Unsupported)
|
||||
.to_string();
|
||||
stdio_rows.push([
|
||||
name.clone(),
|
||||
command.clone(),
|
||||
args_display,
|
||||
env_display,
|
||||
status,
|
||||
auth_status,
|
||||
]);
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
let status = if cfg.enabled {
|
||||
"enabled".to_string()
|
||||
} else {
|
||||
"disabled".to_string()
|
||||
};
|
||||
let auth_status = auth_statuses
|
||||
.get(name.as_str())
|
||||
.copied()
|
||||
.unwrap_or(McpAuthStatus::Unsupported)
|
||||
.to_string();
|
||||
http_rows.push([
|
||||
name.clone(),
|
||||
url.clone(),
|
||||
bearer_token_env_var.clone().unwrap_or("-".to_string()),
|
||||
status,
|
||||
auth_status,
|
||||
]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !stdio_rows.is_empty() {
|
||||
let mut widths = ["Name".len(), "Command".len(), "Args".len(), "Env".len()];
|
||||
let mut widths = [
|
||||
"Name".len(),
|
||||
"Command".len(),
|
||||
"Args".len(),
|
||||
"Env".len(),
|
||||
"Status".len(),
|
||||
"Auth".len(),
|
||||
];
|
||||
for row in &stdio_rows {
|
||||
for (i, cell) in row.iter().enumerate() {
|
||||
widths[i] = widths[i].max(cell.len());
|
||||
@@ -279,28 +483,36 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
|
||||
}
|
||||
|
||||
println!(
|
||||
"{:<name_w$} {:<cmd_w$} {:<args_w$} {:<env_w$}",
|
||||
"Name",
|
||||
"Command",
|
||||
"Args",
|
||||
"Env",
|
||||
"{name:<name_w$} {command:<cmd_w$} {args:<args_w$} {env:<env_w$} {status:<status_w$} {auth:<auth_w$}",
|
||||
name = "Name",
|
||||
command = "Command",
|
||||
args = "Args",
|
||||
env = "Env",
|
||||
status = "Status",
|
||||
auth = "Auth",
|
||||
name_w = widths[0],
|
||||
cmd_w = widths[1],
|
||||
args_w = widths[2],
|
||||
env_w = widths[3],
|
||||
status_w = widths[4],
|
||||
auth_w = widths[5],
|
||||
);
|
||||
|
||||
for row in &stdio_rows {
|
||||
println!(
|
||||
"{:<name_w$} {:<cmd_w$} {:<args_w$} {:<env_w$}",
|
||||
row[0],
|
||||
row[1],
|
||||
row[2],
|
||||
row[3],
|
||||
"{name:<name_w$} {command:<cmd_w$} {args:<args_w$} {env:<env_w$} {status:<status_w$} {auth:<auth_w$}",
|
||||
name = row[0].as_str(),
|
||||
command = row[1].as_str(),
|
||||
args = row[2].as_str(),
|
||||
env = row[3].as_str(),
|
||||
status = row[4].as_str(),
|
||||
auth = row[5].as_str(),
|
||||
name_w = widths[0],
|
||||
cmd_w = widths[1],
|
||||
args_w = widths[2],
|
||||
env_w = widths[3],
|
||||
status_w = widths[4],
|
||||
auth_w = widths[5],
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -310,7 +522,13 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
|
||||
}
|
||||
|
||||
if !http_rows.is_empty() {
|
||||
let mut widths = ["Name".len(), "Url".len(), "Has Bearer Token".len()];
|
||||
let mut widths = [
|
||||
"Name".len(),
|
||||
"Url".len(),
|
||||
"Bearer Token Env Var".len(),
|
||||
"Status".len(),
|
||||
"Auth".len(),
|
||||
];
|
||||
for row in &http_rows {
|
||||
for (i, cell) in row.iter().enumerate() {
|
||||
widths[i] = widths[i].max(cell.len());
|
||||
@@ -318,24 +536,32 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
|
||||
}
|
||||
|
||||
println!(
|
||||
"{:<name_w$} {:<url_w$} {:<token_w$}",
|
||||
"Name",
|
||||
"Url",
|
||||
"Has Bearer Token",
|
||||
"{name:<name_w$} {url:<url_w$} {token:<token_w$} {status:<status_w$} {auth:<auth_w$}",
|
||||
name = "Name",
|
||||
url = "Url",
|
||||
token = "Bearer Token Env Var",
|
||||
status = "Status",
|
||||
auth = "Auth",
|
||||
name_w = widths[0],
|
||||
url_w = widths[1],
|
||||
token_w = widths[2],
|
||||
status_w = widths[3],
|
||||
auth_w = widths[4],
|
||||
);
|
||||
|
||||
for row in &http_rows {
|
||||
println!(
|
||||
"{:<name_w$} {:<url_w$} {:<token_w$}",
|
||||
row[0],
|
||||
row[1],
|
||||
row[2],
|
||||
"{name:<name_w$} {url:<url_w$} {token:<token_w$} {status:<status_w$} {auth:<auth_w$}",
|
||||
name = row[0].as_str(),
|
||||
url = row[1].as_str(),
|
||||
token = row[2].as_str(),
|
||||
status = row[3].as_str(),
|
||||
auth = row[4].as_str(),
|
||||
name_w = widths[0],
|
||||
url_w = widths[1],
|
||||
token_w = widths[2],
|
||||
status_w = widths[3],
|
||||
auth_w = widths[4],
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -343,9 +569,10 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<()> {
|
||||
async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<()> {
|
||||
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
|
||||
.await
|
||||
.context("failed to load configuration")?;
|
||||
|
||||
let Some(server) = config.mcp_servers.get(&get_args.name) else {
|
||||
@@ -360,14 +587,18 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<(
|
||||
"args": args,
|
||||
"env": env,
|
||||
}),
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => serde_json::json!({
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => serde_json::json!({
|
||||
"type": "streamable_http",
|
||||
"url": url,
|
||||
"bearer_token": bearer_token,
|
||||
"bearer_token_env_var": bearer_token_env_var,
|
||||
}),
|
||||
};
|
||||
let output = serde_json::to_string_pretty(&serde_json::json!({
|
||||
"name": get_args.name,
|
||||
"enabled": server.enabled,
|
||||
"transport": transport,
|
||||
"startup_timeout_sec": server
|
||||
.startup_timeout_sec
|
||||
@@ -381,6 +612,7 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<(
|
||||
}
|
||||
|
||||
println!("{}", get_args.name);
|
||||
println!(" enabled: {}", server.enabled);
|
||||
match &server.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
println!(" transport: stdio");
|
||||
@@ -406,11 +638,14 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<(
|
||||
};
|
||||
println!(" env: {env_display}");
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
println!(" transport: streamable_http");
|
||||
println!(" url: {url}");
|
||||
let bearer = bearer_token.as_deref().unwrap_or("-");
|
||||
println!(" bearer_token: {bearer}");
|
||||
let env_var = bearer_token_env_var.as_deref().unwrap_or("-");
|
||||
println!(" bearer_token_env_var: {env_var}");
|
||||
}
|
||||
}
|
||||
if let Some(timeout) = server.startup_timeout_sec {
|
||||
|
||||
@@ -13,8 +13,8 @@ fn codex_command(codex_home: &Path) -> Result<assert_cmd::Command> {
|
||||
Ok(cmd)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_and_remove_server_updates_global_config() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn add_and_remove_server_updates_global_config() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
@@ -24,7 +24,7 @@ fn add_and_remove_server_updates_global_config() -> Result<()> {
|
||||
.success()
|
||||
.stdout(contains("Added global MCP server 'docs'."));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert_eq!(servers.len(), 1);
|
||||
let docs = servers.get("docs").expect("server should exist");
|
||||
match &docs.transport {
|
||||
@@ -35,6 +35,7 @@ fn add_and_remove_server_updates_global_config() -> Result<()> {
|
||||
}
|
||||
other => panic!("unexpected transport: {other:?}"),
|
||||
}
|
||||
assert!(docs.enabled);
|
||||
|
||||
let mut remove_cmd = codex_command(codex_home.path())?;
|
||||
remove_cmd
|
||||
@@ -43,7 +44,7 @@ fn add_and_remove_server_updates_global_config() -> Result<()> {
|
||||
.success()
|
||||
.stdout(contains("Removed global MCP server 'docs'."));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
let mut remove_again_cmd = codex_command(codex_home.path())?;
|
||||
@@ -53,14 +54,14 @@ fn add_and_remove_server_updates_global_config() -> Result<()> {
|
||||
.success()
|
||||
.stdout(contains("No MCP server named 'docs' found."));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_with_env_preserves_key_order_and_values() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn add_with_env_preserves_key_order_and_values() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
@@ -80,7 +81,7 @@ fn add_with_env_preserves_key_order_and_values() -> Result<()> {
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let envy = servers.get("envy").expect("server should exist");
|
||||
let env = match &envy.transport {
|
||||
McpServerTransportConfig::Stdio { env: Some(env), .. } => env,
|
||||
@@ -90,6 +91,122 @@ fn add_with_env_preserves_key_order_and_values() -> Result<()> {
|
||||
assert_eq!(env.len(), 2);
|
||||
assert_eq!(env.get("FOO"), Some(&"bar".to_string()));
|
||||
assert_eq!(env.get("ALPHA"), Some(&"beta".to_string()));
|
||||
assert!(envy.enabled);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_streamable_http_without_manual_token() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args(["mcp", "add", "github", "--url", "https://example.com/mcp"])
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let github = servers.get("github").expect("github server should exist");
|
||||
match &github.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert!(bearer_token_env_var.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport: {other:?}"),
|
||||
}
|
||||
assert!(github.enabled);
|
||||
|
||||
assert!(!codex_home.path().join(".credentials.json").exists());
|
||||
assert!(!codex_home.path().join(".env").exists());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_streamable_http_with_custom_env_var() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"issues",
|
||||
"--url",
|
||||
"https://example.com/issues",
|
||||
"--bearer-token-env-var",
|
||||
"GITHUB_TOKEN",
|
||||
])
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let issues = servers.get("issues").expect("issues server should exist");
|
||||
match &issues.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/issues");
|
||||
assert_eq!(bearer_token_env_var.as_deref(), Some("GITHUB_TOKEN"));
|
||||
}
|
||||
other => panic!("unexpected transport: {other:?}"),
|
||||
}
|
||||
assert!(issues.enabled);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_streamable_http_rejects_removed_flag() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"github",
|
||||
"--url",
|
||||
"https://example.com/mcp",
|
||||
"--with-bearer-token",
|
||||
])
|
||||
.assert()
|
||||
.failure()
|
||||
.stderr(contains("--with-bearer-token"));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_cant_add_command_and_url() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"github",
|
||||
"--url",
|
||||
"https://example.com/mcp",
|
||||
"--command",
|
||||
"--",
|
||||
"echo",
|
||||
"hello",
|
||||
])
|
||||
.assert()
|
||||
.failure()
|
||||
.stderr(contains("unexpected argument '--command' found"));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::Result;
|
||||
use predicates::prelude::PredicateBooleanExt;
|
||||
use predicates::str::contains;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value as JsonValue;
|
||||
@@ -53,6 +54,10 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
assert!(stdout.contains("docs"));
|
||||
assert!(stdout.contains("docs-server"));
|
||||
assert!(stdout.contains("TOKEN=secret"));
|
||||
assert!(stdout.contains("Status"));
|
||||
assert!(stdout.contains("Auth"));
|
||||
assert!(stdout.contains("enabled"));
|
||||
assert!(stdout.contains("Unsupported"));
|
||||
|
||||
let mut list_json_cmd = codex_command(codex_home.path())?;
|
||||
let json_output = list_json_cmd.args(["mcp", "list", "--json"]).output()?;
|
||||
@@ -64,6 +69,7 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
json!([
|
||||
{
|
||||
"name": "docs",
|
||||
"enabled": true,
|
||||
"transport": {
|
||||
"type": "stdio",
|
||||
"command": "docs-server",
|
||||
@@ -76,7 +82,8 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
}
|
||||
},
|
||||
"startup_timeout_sec": null,
|
||||
"tool_timeout_sec": null
|
||||
"tool_timeout_sec": null,
|
||||
"auth_status": "unsupported"
|
||||
}
|
||||
]
|
||||
)
|
||||
@@ -91,6 +98,7 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
assert!(stdout.contains("command: docs-server"));
|
||||
assert!(stdout.contains("args: --port 4000"));
|
||||
assert!(stdout.contains("env: TOKEN=secret"));
|
||||
assert!(stdout.contains("enabled: true"));
|
||||
assert!(stdout.contains("remove: codex mcp remove docs"));
|
||||
|
||||
let mut get_json_cmd = codex_command(codex_home.path())?;
|
||||
@@ -98,7 +106,7 @@ fn list_and_get_render_expected_output() -> Result<()> {
|
||||
.args(["mcp", "get", "docs", "--json"])
|
||||
.assert()
|
||||
.success()
|
||||
.stdout(contains("\"name\": \"docs\""));
|
||||
.stdout(contains("\"name\": \"docs\"").and(contains("\"enabled\": true")));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-cloud-tasks"
|
||||
version = { workspace = true }
|
||||
edition = "2024"
|
||||
|
||||
[lib]
|
||||
name = "codex_cloud_tasks"
|
||||
@@ -11,26 +11,28 @@ path = "src/lib.rs"
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
anyhow = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
codex-cloud-tasks-client = { path = "../cloud-tasks-client", features = [
|
||||
"mock",
|
||||
"online",
|
||||
] }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
codex-cloud-tasks-client = { path = "../cloud-tasks-client", features = ["mock", "online"] }
|
||||
ratatui = { version = "0.29.0" }
|
||||
crossterm = { version = "0.28.1", features = ["event-stream"] }
|
||||
tokio-stream = "0.1.17"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
codex-login = { path = "../login" }
|
||||
codex-core = { path = "../core" }
|
||||
throbber-widgets-tui = "0.8.0"
|
||||
base64 = "0.22"
|
||||
serde_json = "1"
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
unicode-width = "0.1"
|
||||
codex-login = { path = "../login" }
|
||||
codex-tui = { path = "../tui" }
|
||||
crossterm = { workspace = true, features = ["event-stream"] }
|
||||
ratatui = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["json"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
|
||||
tokio-stream = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter"] }
|
||||
unicode-width = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
async-trait = "0.1"
|
||||
async-trait = { workspace = true }
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
// Environment filter data models for the TUI
|
||||
#[derive(Clone, Debug, Default)]
|
||||
@@ -42,15 +43,13 @@ use crate::scrollable_diff::ScrollableDiff;
|
||||
use codex_cloud_tasks_client::CloudBackend;
|
||||
use codex_cloud_tasks_client::TaskId;
|
||||
use codex_cloud_tasks_client::TaskSummary;
|
||||
use throbber_widgets_tui::ThrobberState;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct App {
|
||||
pub tasks: Vec<TaskSummary>,
|
||||
pub selected: usize,
|
||||
pub status: String,
|
||||
pub diff_overlay: Option<DiffOverlay>,
|
||||
pub throbber: ThrobberState,
|
||||
pub spinner_start: Option<Instant>,
|
||||
pub refresh_inflight: bool,
|
||||
pub details_inflight: bool,
|
||||
// Environment filter state
|
||||
@@ -82,7 +81,7 @@ impl App {
|
||||
selected: 0,
|
||||
status: "Press r to refresh".to_string(),
|
||||
diff_overlay: None,
|
||||
throbber: ThrobberState::default(),
|
||||
spinner_start: None,
|
||||
refresh_inflight: false,
|
||||
details_inflight: false,
|
||||
env_filter: None,
|
||||
|
||||
@@ -400,16 +400,20 @@ pub async fn run_main(_cli: Cli, _codex_linux_sandbox_exe: Option<PathBuf>) -> a
|
||||
let _ = frame_tx.send(Instant::now() + codex_tui::ComposerInput::recommended_flush_delay());
|
||||
}
|
||||
}
|
||||
// Advance throbber only while loading.
|
||||
// Keep spinner pulsing only while loading.
|
||||
if app.refresh_inflight
|
||||
|| app.details_inflight
|
||||
|| app.env_loading
|
||||
|| app.apply_preflight_inflight
|
||||
|| app.apply_inflight
|
||||
{
|
||||
app.throbber.calc_next();
|
||||
if app.spinner_start.is_none() {
|
||||
app.spinner_start = Some(Instant::now());
|
||||
}
|
||||
needs_redraw = true;
|
||||
let _ = frame_tx.send(Instant::now() + Duration::from_millis(100));
|
||||
let _ = frame_tx.send(Instant::now() + Duration::from_millis(600));
|
||||
} else {
|
||||
app.spinner_start = None;
|
||||
}
|
||||
render_if_needed(&mut terminal, &mut app, &mut needs_redraw)?;
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ use ratatui::widgets::ListState;
|
||||
use ratatui::widgets::Padding;
|
||||
use ratatui::widgets::Paragraph;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::app::App;
|
||||
use crate::app::AttemptView;
|
||||
@@ -229,7 +230,7 @@ fn draw_list(frame: &mut Frame, area: Rect, app: &mut App) {
|
||||
|
||||
// In-box spinner during initial/refresh loads
|
||||
if app.refresh_inflight {
|
||||
draw_centered_spinner(frame, inner, &mut app.throbber, "Loading tasks…");
|
||||
draw_centered_spinner(frame, inner, &mut app.spinner_start, "Loading tasks…");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,7 +292,7 @@ fn draw_footer(frame: &mut Frame, area: Rect, app: &mut App) {
|
||||
|| app.apply_preflight_inflight
|
||||
|| app.apply_inflight
|
||||
{
|
||||
draw_inline_spinner(frame, top[1], &mut app.throbber, "Loading…");
|
||||
draw_inline_spinner(frame, top[1], &mut app.spinner_start, "Loading…");
|
||||
} else {
|
||||
frame.render_widget(Clear, top[1]);
|
||||
}
|
||||
@@ -449,7 +450,12 @@ fn draw_diff_overlay(frame: &mut Frame, area: Rect, app: &mut App) {
|
||||
.map(|o| o.sd.wrapped_lines().is_empty())
|
||||
.unwrap_or(true);
|
||||
if app.details_inflight && raw_empty {
|
||||
draw_centered_spinner(frame, content_area, &mut app.throbber, "Loading details…");
|
||||
draw_centered_spinner(
|
||||
frame,
|
||||
content_area,
|
||||
&mut app.spinner_start,
|
||||
"Loading details…",
|
||||
);
|
||||
} else {
|
||||
let scroll = app
|
||||
.diff_overlay
|
||||
@@ -494,11 +500,11 @@ pub fn draw_apply_modal(frame: &mut Frame, area: Rect, app: &mut App) {
|
||||
frame.render_widget(header, rows[0]);
|
||||
// Body: spinner while preflight/apply runs; otherwise show result message and path lists
|
||||
if app.apply_preflight_inflight {
|
||||
draw_centered_spinner(frame, rows[1], &mut app.throbber, "Checking…");
|
||||
draw_centered_spinner(frame, rows[1], &mut app.spinner_start, "Checking…");
|
||||
} else if app.apply_inflight {
|
||||
draw_centered_spinner(frame, rows[1], &mut app.throbber, "Applying…");
|
||||
draw_centered_spinner(frame, rows[1], &mut app.spinner_start, "Applying…");
|
||||
} else if m.result_message.is_none() {
|
||||
draw_centered_spinner(frame, rows[1], &mut app.throbber, "Loading…");
|
||||
draw_centered_spinner(frame, rows[1], &mut app.spinner_start, "Loading…");
|
||||
} else if let Some(msg) = &m.result_message {
|
||||
let mut body_lines: Vec<Line> = Vec::new();
|
||||
let first = match m.result_level {
|
||||
@@ -859,29 +865,29 @@ fn format_relative_time(ts: chrono::DateTime<Utc>) -> String {
|
||||
fn draw_inline_spinner(
|
||||
frame: &mut Frame,
|
||||
area: Rect,
|
||||
state: &mut throbber_widgets_tui::ThrobberState,
|
||||
spinner_start: &mut Option<Instant>,
|
||||
label: &str,
|
||||
) {
|
||||
use ratatui::style::Style;
|
||||
use throbber_widgets_tui::BRAILLE_EIGHT;
|
||||
use throbber_widgets_tui::Throbber;
|
||||
use throbber_widgets_tui::WhichUse;
|
||||
let w = Throbber::default()
|
||||
.label(label)
|
||||
.style(Style::default().cyan())
|
||||
.throbber_style(Style::default().magenta().bold())
|
||||
.throbber_set(BRAILLE_EIGHT)
|
||||
.use_type(WhichUse::Spin);
|
||||
frame.render_stateful_widget(w, area, state);
|
||||
use ratatui::widgets::Paragraph;
|
||||
let start = spinner_start.get_or_insert_with(Instant::now);
|
||||
let blink_on = (start.elapsed().as_millis() / 600).is_multiple_of(2);
|
||||
let dot = if blink_on {
|
||||
"• ".into()
|
||||
} else {
|
||||
"◦ ".dim()
|
||||
};
|
||||
let label = label.cyan();
|
||||
let line = Line::from(vec![dot, label]);
|
||||
frame.render_widget(Paragraph::new(line), area);
|
||||
}
|
||||
|
||||
fn draw_centered_spinner(
|
||||
frame: &mut Frame,
|
||||
area: Rect,
|
||||
state: &mut throbber_widgets_tui::ThrobberState,
|
||||
spinner_start: &mut Option<Instant>,
|
||||
label: &str,
|
||||
) {
|
||||
// Center a 1xN throbber within the given rect
|
||||
// Center a 1xN spinner within the given rect
|
||||
let rows = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
@@ -898,7 +904,7 @@ fn draw_centered_spinner(
|
||||
Constraint::Percentage(50),
|
||||
])
|
||||
.split(rows[1]);
|
||||
draw_inline_spinner(frame, cols[1], state, label);
|
||||
draw_inline_spinner(frame, cols[1], spinner_start, label);
|
||||
}
|
||||
|
||||
// Styling helpers for diff rendering live inline where used.
|
||||
@@ -918,7 +924,12 @@ pub fn draw_env_modal(frame: &mut Frame, area: Rect, app: &mut App) {
|
||||
let content = overlay_content(inner);
|
||||
|
||||
if app.env_loading {
|
||||
draw_centered_spinner(frame, content, &mut app.throbber, "Loading environments…");
|
||||
draw_centered_spinner(
|
||||
frame,
|
||||
content,
|
||||
&mut app.spinner_start,
|
||||
"Loading environments…",
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -19,13 +19,13 @@ async-trait = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-apply-patch = { workspace = true }
|
||||
codex-file-search = { workspace = true }
|
||||
codex-mcp-client = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-otel = { workspace = true, features = ["otel"] }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-utils-string = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
dunce = { workspace = true }
|
||||
@@ -35,6 +35,7 @@ futures = { workspace = true }
|
||||
indexmap = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
mcp-types = { workspace = true }
|
||||
notify = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
portable-pty = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
@@ -61,7 +62,7 @@ tokio = { workspace = true, features = [
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
tokio-util = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["rt"] }
|
||||
toml = { workspace = true }
|
||||
toml_edit = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
@@ -76,6 +77,9 @@ wildmatch = { workspace = true }
|
||||
landlock = { workspace = true }
|
||||
seccompiler = { workspace = true }
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
core-foundation = "0.9"
|
||||
|
||||
# Build OpenSSL from source for musl builds.
|
||||
[target.x86_64-unknown-linux-musl.dependencies]
|
||||
openssl-sys = { workspace = true, features = ["vendored"] }
|
||||
@@ -86,16 +90,18 @@ openssl-sys = { workspace = true, features = ["vendored"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = { workspace = true }
|
||||
assert_matches = { workspace = true }
|
||||
core_test_support = { workspace = true }
|
||||
escargot = { workspace = true }
|
||||
maplit = { workspace = true }
|
||||
predicates = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
serial_test = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tokio-test = { workspace = true }
|
||||
tracing-test = { workspace = true, features = ["no-env-filter"] }
|
||||
walkdir = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
tracing-test = { workspace = true, features = ["no-env-filter"] }
|
||||
|
||||
[package.metadata.cargo-shear]
|
||||
ignored = ["openssl-sys"]
|
||||
|
||||
@@ -12,7 +12,7 @@ Expects `/usr/bin/sandbox-exec` to be present.
|
||||
|
||||
### Linux
|
||||
|
||||
Expects the binary containing `codex-core` to run the equivalent of `codex debug landlock` when `arg0` is `codex-linux-sandbox`. See the `codex-arg0` crate for details.
|
||||
Expects the binary containing `codex-core` to run the equivalent of `codex sandbox linux` (legacy alias: `codex debug landlock`) when `arg0` is `codex-linux-sandbox`. See the `codex-arg0` crate for details.
|
||||
|
||||
### All Platforms
|
||||
|
||||
|
||||
@@ -10,12 +10,18 @@ You are Codex, based on GPT-5. You are running as a coding agent in the Codex CL
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## File change notifications
|
||||
|
||||
- Environment contexts may include a `<changed_files>` section listing paths that changed since your last turn. Treat this list as mandatory review input: inspect every listed file before planning or applying additional edits.
|
||||
|
||||
## Plan tool
|
||||
|
||||
|
||||
@@ -23,9 +23,13 @@ Your default personality and tone is concise, direct, and friendly. You communic
|
||||
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
|
||||
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
|
||||
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
|
||||
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
|
||||
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
|
||||
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
|
||||
|
||||
## File change notifications
|
||||
|
||||
- Environment contexts may include a `<changed_files>` section listing paths that changed since your last turn. Treat this list as mandatory review input: inspect every listed file before planning or applying additional edits.
|
||||
|
||||
## Responsiveness
|
||||
|
||||
### Preamble messages
|
||||
|
||||
@@ -389,10 +389,12 @@ async fn process_chat_sse<S>(
|
||||
let mut reasoning_text = String::new();
|
||||
|
||||
loop {
|
||||
let sse = match otel_event_manager
|
||||
.log_sse_event(|| timeout(idle_timeout, stream.next()))
|
||||
.await
|
||||
{
|
||||
let start = std::time::Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
let duration = start.elapsed();
|
||||
otel_event_manager.log_sse_event(&response, duration);
|
||||
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(ev))) => ev,
|
||||
Ok(Some(Err(e))) => {
|
||||
let _ = tx_event
|
||||
|
||||
@@ -63,7 +63,6 @@ struct ErrorResponse {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Error {
|
||||
r#type: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
code: Option<String>,
|
||||
message: Option<String>,
|
||||
|
||||
@@ -228,7 +227,7 @@ impl ModelClient {
|
||||
input: &input_with_instructions,
|
||||
tools: &tools_json,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
parallel_tool_calls: prompt.parallel_tool_calls,
|
||||
reasoning,
|
||||
store: azure_workaround,
|
||||
stream: true,
|
||||
@@ -650,10 +649,12 @@ async fn process_sse<S>(
|
||||
let mut response_error: Option<CodexErr> = None;
|
||||
|
||||
loop {
|
||||
let sse = match otel_event_manager
|
||||
.log_sse_event(|| timeout(idle_timeout, stream.next()))
|
||||
.await
|
||||
{
|
||||
let start = std::time::Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
let duration = start.elapsed();
|
||||
otel_event_manager.log_sse_event(&response, duration);
|
||||
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(sse))) => sse,
|
||||
Ok(Some(Err(e))) => {
|
||||
debug!("SSE Error: {e:#}");
|
||||
@@ -794,9 +795,13 @@ async fn process_sse<S>(
|
||||
if let Some(error) = error {
|
||||
match serde_json::from_value::<Error>(error.clone()) {
|
||||
Ok(error) => {
|
||||
let delay = try_parse_retry_after(&error);
|
||||
let message = error.message.unwrap_or_default();
|
||||
response_error = Some(CodexErr::Stream(message, delay));
|
||||
if is_context_window_error(&error) {
|
||||
response_error = Some(CodexErr::ContextWindowExceeded);
|
||||
} else {
|
||||
let delay = try_parse_retry_after(&error);
|
||||
let message = error.message.clone().unwrap_or_default();
|
||||
response_error = Some(CodexErr::Stream(message, delay));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let error = format!("failed to parse ErrorResponse: {e}");
|
||||
@@ -922,9 +927,14 @@ fn try_parse_retry_after(err: &Error) -> Option<Duration> {
|
||||
None
|
||||
}
|
||||
|
||||
fn is_context_window_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("context_length_exceeded")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_matches::assert_matches;
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_test::io::Builder as IoBuilder;
|
||||
@@ -1179,6 +1189,74 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn context_window_error_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
let provider = ModelProviderInfo {
|
||||
name: "test".to_string(),
|
||||
base_url: Some("https://test.com".to_string()),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let otel_event_manager = otel_event_manager();
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Err(err @ CodexErr::ContextWindowExceeded) => {
|
||||
assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
|
||||
}
|
||||
other => panic!("unexpected context window event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn context_window_error_with_newline_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
let provider = ModelProviderInfo {
|
||||
name: "test".to_string(),
|
||||
base_url: Some("https://test.com".to_string()),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let otel_event_manager = otel_event_manager();
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Err(err @ CodexErr::ContextWindowExceeded) => {
|
||||
assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
|
||||
}
|
||||
other => panic!("unexpected context window event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ────────────────────────────
|
||||
// Table-driven test from `main`
|
||||
// ────────────────────────────
|
||||
@@ -1316,10 +1394,7 @@ mod tests {
|
||||
let resp: ErrorResponse =
|
||||
serde_json::from_str(json).expect("should deserialize old schema");
|
||||
|
||||
assert!(matches!(
|
||||
resp.error.plan_type,
|
||||
Some(PlanType::Known(KnownPlan::Pro))
|
||||
));
|
||||
assert_matches!(resp.error.plan_type, Some(PlanType::Known(KnownPlan::Pro)));
|
||||
|
||||
let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
|
||||
assert_eq!(plan_json, "\"pro\"");
|
||||
@@ -1334,7 +1409,7 @@ mod tests {
|
||||
let resp: ErrorResponse =
|
||||
serde_json::from_str(json).expect("should deserialize old schema");
|
||||
|
||||
assert!(matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip"));
|
||||
assert_matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip");
|
||||
|
||||
let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
|
||||
assert_eq!(plan_json, "\"vip\"");
|
||||
|
||||
@@ -9,9 +9,11 @@ use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use codex_protocol::config_types::Verbosity as VerbosityConfig;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use futures::Stream;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashSet;
|
||||
use std::ops::Deref;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
@@ -31,6 +33,9 @@ pub struct Prompt {
|
||||
/// external MCP servers.
|
||||
pub(crate) tools: Vec<ToolSpec>,
|
||||
|
||||
/// Whether parallel tool calls are permitted for this prompt.
|
||||
pub(crate) parallel_tool_calls: bool,
|
||||
|
||||
/// Optional override for the built-in BASE_INSTRUCTIONS.
|
||||
pub base_instructions_override: Option<String>,
|
||||
|
||||
@@ -64,10 +69,125 @@ impl Prompt {
|
||||
}
|
||||
|
||||
pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
|
||||
self.input.clone()
|
||||
let mut input = self.input.clone();
|
||||
|
||||
// when using the *Freeform* apply_patch tool specifically, tool outputs
|
||||
// should be structured text, not json. Do NOT reserialize when using
|
||||
// the Function tool - note that this differs from the check above for
|
||||
// instructions. We declare the result as a named variable for clarity.
|
||||
let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
|
||||
ToolSpec::Freeform(f) => f.name == "apply_patch",
|
||||
_ => false,
|
||||
});
|
||||
if is_freeform_apply_patch_tool_present {
|
||||
reserialize_shell_outputs(&mut input);
|
||||
}
|
||||
|
||||
input
|
||||
}
|
||||
}
|
||||
|
||||
fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
|
||||
let mut shell_call_ids: HashSet<String> = HashSet::new();
|
||||
|
||||
items.iter_mut().for_each(|item| match item {
|
||||
ResponseItem::LocalShellCall { call_id, id, .. } => {
|
||||
if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
|
||||
shell_call_ids.insert(identifier);
|
||||
}
|
||||
}
|
||||
ResponseItem::CustomToolCall {
|
||||
id: _,
|
||||
status: _,
|
||||
call_id,
|
||||
name,
|
||||
input: _,
|
||||
} => {
|
||||
if name == "apply_patch" {
|
||||
shell_call_ids.insert(call_id.clone());
|
||||
}
|
||||
}
|
||||
ResponseItem::CustomToolCallOutput { call_id, output } => {
|
||||
if shell_call_ids.remove(call_id)
|
||||
&& let Some(structured) = parse_structured_shell_output(output)
|
||||
{
|
||||
*output = structured
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCall { name, call_id, .. }
|
||||
if is_shell_tool_name(name) || name == "apply_patch" =>
|
||||
{
|
||||
shell_call_ids.insert(call_id.clone());
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { call_id, output } => {
|
||||
if shell_call_ids.remove(call_id)
|
||||
&& let Some(structured) = parse_structured_shell_output(&output.content)
|
||||
{
|
||||
output.content = structured
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
}
|
||||
|
||||
fn is_shell_tool_name(name: &str) -> bool {
|
||||
matches!(name, "shell" | "container.exec")
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExecOutputJson {
|
||||
output: String,
|
||||
metadata: ExecOutputMetadataJson,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExecOutputMetadataJson {
|
||||
exit_code: i32,
|
||||
duration_seconds: f32,
|
||||
}
|
||||
|
||||
fn parse_structured_shell_output(raw: &str) -> Option<String> {
|
||||
let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
|
||||
Some(build_structured_output(&parsed))
|
||||
}
|
||||
|
||||
fn build_structured_output(parsed: &ExecOutputJson) -> String {
|
||||
let mut sections = Vec::new();
|
||||
sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
|
||||
sections.push(format!(
|
||||
"Wall time: {} seconds",
|
||||
parsed.metadata.duration_seconds
|
||||
));
|
||||
|
||||
let mut output = parsed.output.clone();
|
||||
if let Some(total_lines) = extract_total_output_lines(&parsed.output) {
|
||||
sections.push(format!("Total output lines: {total_lines}"));
|
||||
if let Some(stripped) = strip_total_output_header(&output) {
|
||||
output = stripped.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
sections.push("Output:".to_string());
|
||||
sections.push(output);
|
||||
|
||||
sections.join("\n")
|
||||
}
|
||||
|
||||
fn extract_total_output_lines(output: &str) -> Option<u32> {
|
||||
let marker_start = output.find("[... omitted ")?;
|
||||
let marker = &output[marker_start..];
|
||||
let (_, after_of) = marker.split_once(" of ")?;
|
||||
let (total_segment, _) = after_of.split_once(' ')?;
|
||||
total_segment.parse::<u32>().ok()
|
||||
}
|
||||
|
||||
fn strip_total_output_header(output: &str) -> Option<&str> {
|
||||
let after_prefix = output.strip_prefix("Total output lines: ")?;
|
||||
let (_, remainder) = after_prefix.split_once('\n')?;
|
||||
let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
|
||||
Some(remainder)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ResponseEvent {
|
||||
Created,
|
||||
@@ -182,6 +302,17 @@ pub(crate) mod tools {
|
||||
Freeform(FreeformTool),
|
||||
}
|
||||
|
||||
impl ToolSpec {
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
match self {
|
||||
ToolSpec::Function(tool) => tool.name.as_str(),
|
||||
ToolSpec::LocalShell {} => "local_shell",
|
||||
ToolSpec::WebSearch {} => "web_search",
|
||||
ToolSpec::Freeform(tool) => tool.name.as_str(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct FreeformTool {
|
||||
pub(crate) name: String,
|
||||
@@ -327,7 +458,7 @@ mod tests {
|
||||
input: &input,
|
||||
tools: &tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
parallel_tool_calls: true,
|
||||
reasoning: None,
|
||||
store: false,
|
||||
stream: true,
|
||||
@@ -368,7 +499,7 @@ mod tests {
|
||||
input: &input,
|
||||
tools: &tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
parallel_tool_calls: true,
|
||||
reasoning: None,
|
||||
store: false,
|
||||
stream: true,
|
||||
@@ -404,7 +535,7 @@ mod tests {
|
||||
input: &input,
|
||||
tools: &tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
parallel_tool_calls: true,
|
||||
reasoning: None,
|
||||
store: false,
|
||||
stream: true,
|
||||
|
||||
@@ -23,7 +23,9 @@ use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::TaskStartedEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnContextItem;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::FuturesOrdered;
|
||||
use mcp_types::CallToolResult;
|
||||
use serde_json;
|
||||
use serde_json::Value;
|
||||
@@ -55,6 +57,8 @@ use crate::exec_command::WriteStdinParams;
|
||||
use crate::executor::Executor;
|
||||
use crate::executor::ExecutorConfig;
|
||||
use crate::executor::normalize_exec_result;
|
||||
use crate::file_change_notifier::FileChangeWatcher;
|
||||
use crate::mcp::auth::compute_auth_statuses;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::model_family::find_family_for_model;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
@@ -100,7 +104,9 @@ use crate::tasks::CompactTask;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::format_exec_output_str;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_instructions::UserInstructions;
|
||||
@@ -360,6 +366,7 @@ impl Session {
|
||||
let mcp_fut = McpConnectionManager::new(
|
||||
config.mcp_servers.clone(),
|
||||
config.use_experimental_use_rmcp_client,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
);
|
||||
let default_shell_fut = shell::default_user_shell();
|
||||
let history_meta_fut = crate::message_history::history_metadata(&config);
|
||||
@@ -456,6 +463,14 @@ impl Session {
|
||||
is_review_mode: false,
|
||||
final_output_json_schema: None,
|
||||
};
|
||||
let (file_change_watcher, file_change_collector) =
|
||||
match FileChangeWatcher::start(turn_context.cwd.clone()) {
|
||||
Ok((watcher, collector)) => (Some(watcher), Some(collector)),
|
||||
Err(err) => {
|
||||
warn!("file change notifications disabled: {err:#}");
|
||||
(None, None)
|
||||
}
|
||||
};
|
||||
let services = SessionServices {
|
||||
mcp_connection_manager,
|
||||
session_manager: ExecSessionManager::default(),
|
||||
@@ -469,6 +484,8 @@ impl Session {
|
||||
turn_context.cwd.clone(),
|
||||
config.codex_linux_sandbox_exe.clone(),
|
||||
)),
|
||||
file_change_collector,
|
||||
_file_change_watcher: file_change_watcher,
|
||||
};
|
||||
|
||||
let sess = Arc::new(Session {
|
||||
@@ -723,6 +740,7 @@ impl Session {
|
||||
Some(turn_context.approval_policy),
|
||||
Some(turn_context.sandbox_policy.clone()),
|
||||
Some(self.user_shell().clone()),
|
||||
None,
|
||||
)));
|
||||
items
|
||||
}
|
||||
@@ -782,6 +800,17 @@ impl Session {
|
||||
self.send_event(event).await;
|
||||
}
|
||||
|
||||
async fn set_total_tokens_full(&self, sub_id: &str, turn_context: &TurnContext) {
|
||||
let context_window = turn_context.client.get_model_context_window();
|
||||
if let Some(context_window) = context_window {
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_token_usage_full(context_window);
|
||||
}
|
||||
self.send_token_count_event(sub_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a user input item to conversation history and also persist a
|
||||
/// corresponding UserMessage EventMsg to rollout.
|
||||
async fn record_input_and_rollout_usermsg(&self, response_input: &ResponseInputItem) {
|
||||
@@ -807,7 +836,7 @@ impl Session {
|
||||
|
||||
async fn on_exec_command_begin(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
exec_command_context: ExecCommandContext,
|
||||
) {
|
||||
let ExecCommandContext {
|
||||
@@ -823,7 +852,10 @@ impl Session {
|
||||
user_explicitly_approved_this_action,
|
||||
changes,
|
||||
}) => {
|
||||
turn_diff_tracker.on_patch_begin(&changes);
|
||||
{
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
tracker.on_patch_begin(&changes);
|
||||
}
|
||||
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
@@ -850,7 +882,7 @@ impl Session {
|
||||
|
||||
async fn on_exec_command_end(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: &str,
|
||||
call_id: &str,
|
||||
output: &ExecToolCallOutput,
|
||||
@@ -898,7 +930,10 @@ impl Session {
|
||||
// If this is an apply_patch, after we emit the end patch, emit a second event
|
||||
// with the full turn diff if there is one.
|
||||
if is_apply_patch {
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
let unified_diff = {
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
tracker.get_unified_diff()
|
||||
};
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
@@ -915,7 +950,7 @@ impl Session {
|
||||
/// Returns the output of the exec tool call.
|
||||
pub(crate) async fn run_exec_with_events(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
prepared: PreparedExec,
|
||||
approval_policy: AskForApproval,
|
||||
) -> Result<ExecToolCallOutput, ExecError> {
|
||||
@@ -924,7 +959,7 @@ impl Session {
|
||||
let sub_id = context.sub_id.clone();
|
||||
let call_id = context.call_id.clone();
|
||||
|
||||
self.on_exec_command_begin(turn_diff_tracker, context.clone())
|
||||
self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone())
|
||||
.await;
|
||||
|
||||
let result = self
|
||||
@@ -963,6 +998,13 @@ impl Session {
|
||||
self.send_event(event).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn take_changed_files(&self) -> Vec<PathBuf> {
|
||||
match &self.services.file_change_collector {
|
||||
Some(collector) => collector.drain().await,
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn notify_stream_error(&self, sub_id: &str, message: impl Into<String>) {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
@@ -1204,6 +1246,7 @@ async fn submission_loop(
|
||||
sandbox_policy,
|
||||
// Shell is not configurable from turn to turn
|
||||
None,
|
||||
None,
|
||||
))])
|
||||
.await;
|
||||
}
|
||||
@@ -1381,10 +1424,18 @@ async fn submission_loop(
|
||||
|
||||
// This is a cheap lookup from the connection manager's cache.
|
||||
let tools = sess.services.mcp_connection_manager.list_all_tools();
|
||||
let auth_statuses = compute_auth_statuses(
|
||||
config.mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
)
|
||||
.await;
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::McpListToolsResponse(
|
||||
crate::protocol::McpListToolsResponseEvent { tools },
|
||||
crate::protocol::McpListToolsResponseEvent {
|
||||
tools,
|
||||
auth_statuses,
|
||||
},
|
||||
),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
@@ -1565,7 +1616,7 @@ async fn spawn_review_thread(
|
||||
|
||||
// Seed the child task with the review prompt as the initial user message.
|
||||
let input: Vec<InputItem> = vec![InputItem::Text {
|
||||
text: format!("{base_instructions}\n\n---\n\nNow, here's your task: {review_prompt}"),
|
||||
text: review_prompt,
|
||||
}];
|
||||
let tc = Arc::new(review_turn_context);
|
||||
|
||||
@@ -1628,12 +1679,28 @@ pub(crate) async fn run_task(
|
||||
} else {
|
||||
sess.record_input_and_rollout_usermsg(&initial_input_for_turn)
|
||||
.await;
|
||||
let changed_files = sess.take_changed_files().await;
|
||||
if !changed_files.is_empty() {
|
||||
let response_item: ResponseItem =
|
||||
EnvironmentContext::new(None, None, None, None, Some(changed_files)).into();
|
||||
sess.record_conversation_items(std::slice::from_ref(&response_item))
|
||||
.await;
|
||||
for msg in
|
||||
map_response_item_to_event_messages(&response_item, sess.show_raw_agent_reasoning())
|
||||
{
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let mut turn_diff_tracker = TurnDiffTracker::new();
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let mut auto_compact_recently_attempted = false;
|
||||
|
||||
loop {
|
||||
@@ -1681,9 +1748,9 @@ pub(crate) async fn run_task(
|
||||
})
|
||||
.collect();
|
||||
match run_turn(
|
||||
&sess,
|
||||
turn_context.as_ref(),
|
||||
&mut turn_diff_tracker,
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id.clone(),
|
||||
turn_input,
|
||||
)
|
||||
@@ -1906,18 +1973,27 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
|
||||
}
|
||||
|
||||
async fn run_turn(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
|
||||
let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools));
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(mcp_tools),
|
||||
));
|
||||
|
||||
let model_supports_parallel = turn_context
|
||||
.client
|
||||
.get_model_family()
|
||||
.supports_parallel_tool_calls;
|
||||
let parallel_tool_calls = model_supports_parallel;
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
tools: router.specs().to_vec(),
|
||||
tools: router.specs(),
|
||||
parallel_tool_calls,
|
||||
base_instructions_override: turn_context.base_instructions.clone(),
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
};
|
||||
@@ -1925,10 +2001,10 @@ async fn run_turn(
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match try_run_turn(
|
||||
&router,
|
||||
sess,
|
||||
turn_context,
|
||||
turn_diff_tracker,
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&sub_id,
|
||||
&prompt,
|
||||
)
|
||||
@@ -1938,6 +2014,10 @@ async fn run_turn(
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||
sess.set_total_tokens_full(&sub_id, &turn_context).await;
|
||||
return Err(e);
|
||||
}
|
||||
Err(CodexErr::UsageLimitReached(e)) => {
|
||||
let rate_limits = e.rate_limits.clone();
|
||||
if let Some(rate_limits) = rate_limits {
|
||||
@@ -1964,9 +2044,7 @@ async fn run_turn(
|
||||
// at a seemingly frozen screen.
|
||||
sess.notify_stream_error(
|
||||
&sub_id,
|
||||
format!(
|
||||
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…"
|
||||
),
|
||||
format!("Re-connecting... {retries}/{max_retries}"),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1984,9 +2062,9 @@ async fn run_turn(
|
||||
/// "handled" such that it produces a `ResponseInputItem` that needs to be
|
||||
/// sent back to the model on the next turn.
|
||||
#[derive(Debug)]
|
||||
struct ProcessedResponseItem {
|
||||
item: ResponseItem,
|
||||
response: Option<ResponseInputItem>,
|
||||
pub(crate) struct ProcessedResponseItem {
|
||||
pub(crate) item: ResponseItem,
|
||||
pub(crate) response: Option<ResponseInputItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -1996,10 +2074,10 @@ struct TurnRunResult {
|
||||
}
|
||||
|
||||
async fn try_run_turn(
|
||||
router: &crate::tools::ToolRouter,
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
router: Arc<ToolRouter>,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
@@ -2069,44 +2147,102 @@ async fn try_run_turn(
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context.client.clone().stream(&prompt).await?;
|
||||
|
||||
let mut output = Vec::new();
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id.to_string(),
|
||||
);
|
||||
let mut output: FuturesOrdered<BoxFuture<CodexResult<ProcessedResponseItem>>> =
|
||||
FuturesOrdered::new();
|
||||
|
||||
loop {
|
||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||
let event = stream.next().await;
|
||||
let Some(event) = event else {
|
||||
// Channel closed without yielding a final Completed event or explicit error.
|
||||
// Treat as a disconnected stream so the caller can retry.
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
None,
|
||||
));
|
||||
let event = match event {
|
||||
Some(res) => res?,
|
||||
None => {
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
None,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Ok(ev) => ev,
|
||||
Err(e) => {
|
||||
// Propagate the underlying stream error to the caller (run_turn), which
|
||||
// will apply the configured `stream_max_retries` policy.
|
||||
return Err(e);
|
||||
}
|
||||
let add_completed = &mut |response_item: ProcessedResponseItem| {
|
||||
output.push_back(future::ready(Ok(response_item)).boxed());
|
||||
};
|
||||
|
||||
match event {
|
||||
ResponseEvent::Created => {}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
let response = handle_response_item(
|
||||
router,
|
||||
sess,
|
||||
turn_context,
|
||||
turn_diff_tracker,
|
||||
sub_id,
|
||||
item.clone(),
|
||||
)
|
||||
.await?;
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) {
|
||||
Ok(Some(call)) => {
|
||||
let payload_preview = call.payload.log_payload().into_owned();
|
||||
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||
|
||||
let response = tool_runtime.handle_tool_call(call);
|
||||
|
||||
output.push_back(
|
||||
async move {
|
||||
Ok(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(response.await?),
|
||||
})
|
||||
}
|
||||
.boxed(),
|
||||
);
|
||||
}
|
||||
Ok(None) => {
|
||||
let response = handle_non_tool_response_item(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
sub_id,
|
||||
item.clone(),
|
||||
)
|
||||
.await?;
|
||||
add_completed(ProcessedResponseItem { item, response });
|
||||
}
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let msg = "LocalShellCall without call_id or id";
|
||||
turn_context
|
||||
.client
|
||||
.get_otel_event_manager()
|
||||
.log_tool_failed("local_shell", msg);
|
||||
error!(msg);
|
||||
|
||||
let response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg.to_string(),
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
add_completed(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(response),
|
||||
});
|
||||
}
|
||||
Err(FunctionCallError::RespondToModel(message)) => {
|
||||
let response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: message,
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
add_completed(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(response),
|
||||
});
|
||||
}
|
||||
Err(FunctionCallError::Fatal(message)) => {
|
||||
return Err(CodexErr::Fatal(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
ResponseEvent::WebSearchCallBegin { call_id } => {
|
||||
let _ = sess
|
||||
@@ -2126,10 +2262,15 @@ async fn try_run_turn(
|
||||
response_id: _,
|
||||
token_usage,
|
||||
} => {
|
||||
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
|
||||
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
|
||||
.await;
|
||||
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;
|
||||
|
||||
let unified_diff = {
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
tracker.get_unified_diff()
|
||||
};
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
@@ -2140,7 +2281,7 @@ async fn try_run_turn(
|
||||
}
|
||||
|
||||
let result = TurnRunResult {
|
||||
processed_items: output,
|
||||
processed_items,
|
||||
total_token_usage: token_usage.clone(),
|
||||
};
|
||||
|
||||
@@ -2188,88 +2329,40 @@ async fn try_run_turn(
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_response_item(
|
||||
router: &crate::tools::ToolRouter,
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
async fn handle_non_tool_response_item(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: &str,
|
||||
item: ResponseItem,
|
||||
) -> CodexResult<Option<ResponseInputItem>> {
|
||||
debug!(?item, "Output item");
|
||||
|
||||
match ToolRouter::build_tool_call(sess, item.clone()) {
|
||||
Ok(Some(call)) => {
|
||||
let payload_preview = call.payload.log_payload().into_owned();
|
||||
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||
match router
|
||||
.dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(Some(response)),
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
Err(other) => unreachable!("non-fatal tool error returned: {other:?}"),
|
||||
match &item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => {
|
||||
let msgs = match &item {
|
||||
ResponseItem::Message { .. } if turn_context.is_review_mode => {
|
||||
trace!("suppressing assistant Message in review mode");
|
||||
Vec::new()
|
||||
}
|
||||
_ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
|
||||
};
|
||||
for msg in msgs {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
match &item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => {
|
||||
let msgs = match &item {
|
||||
ResponseItem::Message { .. } if turn_context.is_review_mode => {
|
||||
trace!("suppressing assistant Message in review mode");
|
||||
Vec::new()
|
||||
}
|
||||
_ => map_response_item_to_event_messages(
|
||||
&item,
|
||||
sess.show_raw_agent_reasoning(),
|
||||
),
|
||||
};
|
||||
for msg in msgs {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. } => {
|
||||
debug!("unexpected tool output from stream");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
|
||||
debug!("unexpected tool output from stream");
|
||||
}
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let msg = "LocalShellCall without call_id or id";
|
||||
turn_context
|
||||
.client
|
||||
.get_otel_event_manager()
|
||||
.log_tool_failed("local_shell", msg);
|
||||
error!(msg);
|
||||
|
||||
Ok(Some(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg.to_string(),
|
||||
success: None,
|
||||
},
|
||||
}))
|
||||
}
|
||||
Err(FunctionCallError::RespondToModel(msg)) => {
|
||||
Ok(Some(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg,
|
||||
success: None,
|
||||
},
|
||||
}))
|
||||
}
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
|
||||
@@ -2393,6 +2486,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::config::ConfigOverrides;
|
||||
use crate::config::ConfigToml;
|
||||
use crate::file_change_notifier::collector_for_tests;
|
||||
|
||||
use crate::protocol::CompactedItem;
|
||||
use crate::protocol::InitialHistory;
|
||||
@@ -2505,13 +2599,19 @@ mod tests {
|
||||
|
||||
let out = format_exec_output_str(&exec);
|
||||
|
||||
// Strip truncation header if present for subsequent assertions
|
||||
let body = out
|
||||
.strip_prefix("Total output lines: ")
|
||||
.and_then(|rest| rest.split_once("\n\n").map(|x| x.1))
|
||||
.unwrap_or(out.as_str());
|
||||
|
||||
// Expect elision marker with correct counts
|
||||
let omitted = 400 - MODEL_FORMAT_MAX_LINES; // 144
|
||||
let marker = format!("\n[... omitted {omitted} of 400 lines ...]\n\n");
|
||||
assert!(out.contains(&marker), "missing marker: {out}");
|
||||
|
||||
// Validate head and tail
|
||||
let parts: Vec<&str> = out.split(&marker).collect();
|
||||
let parts: Vec<&str> = body.split(&marker).collect();
|
||||
assert_eq!(parts.len(), 2, "expected one marker split");
|
||||
let head = parts[0];
|
||||
let tail = parts[1];
|
||||
@@ -2547,14 +2647,19 @@ mod tests {
|
||||
};
|
||||
|
||||
let out = format_exec_output_str(&exec);
|
||||
assert!(out.len() <= MODEL_FORMAT_MAX_BYTES, "exceeds byte budget");
|
||||
// Keep strict budget on the truncated body (excluding header)
|
||||
let body = out
|
||||
.strip_prefix("Total output lines: ")
|
||||
.and_then(|rest| rest.split_once("\n\n").map(|x| x.1))
|
||||
.unwrap_or(out.as_str());
|
||||
assert!(body.len() <= MODEL_FORMAT_MAX_BYTES, "exceeds byte budget");
|
||||
assert!(out.contains("omitted"), "should contain elision marker");
|
||||
|
||||
// Ensure head and tail are drawn from the original
|
||||
assert!(full.starts_with(out.chars().take(8).collect::<String>().as_str()));
|
||||
assert!(full.starts_with(body.chars().take(8).collect::<String>().as_str()));
|
||||
assert!(
|
||||
full.ends_with(
|
||||
out.chars()
|
||||
body.chars()
|
||||
.rev()
|
||||
.take(8)
|
||||
.collect::<String>()
|
||||
@@ -2712,6 +2817,8 @@ mod tests {
|
||||
turn_context.cwd.clone(),
|
||||
None,
|
||||
)),
|
||||
file_change_collector: None,
|
||||
_file_change_watcher: None,
|
||||
};
|
||||
let session = Session {
|
||||
conversation_id,
|
||||
@@ -2785,6 +2892,8 @@ mod tests {
|
||||
config.cwd.clone(),
|
||||
None,
|
||||
)),
|
||||
file_change_collector: None,
|
||||
_file_change_watcher: None,
|
||||
};
|
||||
let session = Arc::new(Session {
|
||||
conversation_id,
|
||||
@@ -2901,13 +3010,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn fatal_tool_error_stops_turn_and_reports_error() {
|
||||
let (session, turn_context, _rx) = make_session_and_context_with_rx();
|
||||
let session_ref = session.as_ref();
|
||||
let turn_context_ref = turn_context.as_ref();
|
||||
let router = ToolRouter::from_config(
|
||||
&turn_context_ref.tools_config,
|
||||
Some(session_ref.services.mcp_connection_manager.list_all_tools()),
|
||||
&turn_context.tools_config,
|
||||
Some(session.services.mcp_connection_manager.list_all_tools()),
|
||||
);
|
||||
let mut tracker = TurnDiffTracker::new();
|
||||
let item = ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
@@ -2916,25 +3022,69 @@ mod tests {
|
||||
input: "{}".to_string(),
|
||||
};
|
||||
|
||||
let err = handle_response_item(
|
||||
&router,
|
||||
session_ref,
|
||||
turn_context_ref,
|
||||
&mut tracker,
|
||||
"sub-id",
|
||||
item,
|
||||
)
|
||||
.await
|
||||
.expect_err("expected fatal error");
|
||||
let call = ToolRouter::build_tool_call(session.as_ref(), item.clone())
|
||||
.expect("build tool call")
|
||||
.expect("tool call present");
|
||||
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let err = router
|
||||
.dispatch_tool_call(
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
tracker,
|
||||
"sub-id".to_string(),
|
||||
call,
|
||||
)
|
||||
.await
|
||||
.expect_err("expected fatal error");
|
||||
|
||||
match err {
|
||||
CodexErr::Fatal(message) => {
|
||||
FunctionCallError::Fatal(message) => {
|
||||
assert_eq!(message, "tool shell invoked with incompatible payload");
|
||||
}
|
||||
other => panic!("expected CodexErr::Fatal, got {other:?}"),
|
||||
other => panic!("expected FunctionCallError::Fatal, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_task_emits_environment_context_with_changed_files() {
|
||||
let (mut session, turn_context) = make_session_and_context();
|
||||
let collector = collector_for_tests(turn_context.cwd.clone());
|
||||
collector
|
||||
.record_for_tests(PathBuf::from("tracked.rs"))
|
||||
.await;
|
||||
session.services.file_change_collector = Some(collector);
|
||||
session.services._file_change_watcher = None;
|
||||
|
||||
let session = Arc::new(session);
|
||||
let turn_context = Arc::new(turn_context);
|
||||
let input = vec![InputItem::Text {
|
||||
text: "hello".to_string(),
|
||||
}];
|
||||
let _ = run_task(
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
"sub".to_string(),
|
||||
input,
|
||||
)
|
||||
.await;
|
||||
|
||||
let history = session.history_snapshot().await;
|
||||
let mut found = false;
|
||||
for item in history {
|
||||
if let ResponseItem::Message { content, .. } = item {
|
||||
for part in content {
|
||||
if let ContentItem::InputText { text } = part
|
||||
&& text.contains("<changed_files>")
|
||||
&& text.contains("tracked.rs")
|
||||
{
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
assert!(found, "missing changed_files environment context");
|
||||
}
|
||||
|
||||
fn sample_rollout(
|
||||
session: &Session,
|
||||
turn_context: &TurnContext,
|
||||
@@ -3045,9 +3195,11 @@ mod tests {
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let (session, mut turn_context) = make_session_and_context();
|
||||
let (session, mut turn_context_raw) = make_session_and_context();
|
||||
// Ensure policy is NOT OnRequest so the early rejection path triggers
|
||||
turn_context.approval_policy = AskForApproval::OnFailure;
|
||||
turn_context_raw.approval_policy = AskForApproval::OnFailure;
|
||||
let session = Arc::new(session);
|
||||
let mut turn_context = Arc::new(turn_context_raw);
|
||||
|
||||
let params = ExecParams {
|
||||
command: if cfg!(windows) {
|
||||
@@ -3075,7 +3227,7 @@ mod tests {
|
||||
..params.clone()
|
||||
};
|
||||
|
||||
let mut turn_diff_tracker = TurnDiffTracker::new();
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
|
||||
let tool_name = "shell";
|
||||
let sub_id = "test-sub".to_string();
|
||||
@@ -3084,9 +3236,9 @@ mod tests {
|
||||
let resp = handle_container_exec_with_params(
|
||||
tool_name,
|
||||
params,
|
||||
&session,
|
||||
&turn_context,
|
||||
&mut turn_diff_tracker,
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id,
|
||||
call_id,
|
||||
)
|
||||
@@ -3105,14 +3257,16 @@ mod tests {
|
||||
|
||||
// Now retry the same command WITHOUT escalated permissions; should succeed.
|
||||
// Force DangerFullAccess to avoid platform sandbox dependencies in tests.
|
||||
turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess;
|
||||
Arc::get_mut(&mut turn_context)
|
||||
.expect("unique turn context Arc")
|
||||
.sandbox_policy = SandboxPolicy::DangerFullAccess;
|
||||
|
||||
let resp2 = handle_container_exec_with_params(
|
||||
tool_name,
|
||||
params2,
|
||||
&session,
|
||||
&turn_context,
|
||||
&mut turn_diff_tracker,
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
"test-sub".to_string(),
|
||||
"test-call-2".to_string(),
|
||||
)
|
||||
|
||||
@@ -70,14 +70,10 @@ async fn run_compact_task_inner(
|
||||
input: Vec<InputItem>,
|
||||
) {
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
let turn_input = sess
|
||||
let mut turn_input = sess
|
||||
.turn_input_with_history(vec![initial_input_for_turn.clone().into()])
|
||||
.await;
|
||||
|
||||
let prompt = Prompt {
|
||||
input: turn_input,
|
||||
..Default::default()
|
||||
};
|
||||
let mut truncated_count = 0usize;
|
||||
|
||||
let max_retries = turn_context.client.get_provider().stream_max_retries();
|
||||
let mut retries = 0;
|
||||
@@ -93,25 +89,54 @@ async fn run_compact_task_inner(
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
loop {
|
||||
let prompt = Prompt {
|
||||
input: turn_input.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
let attempt_result =
|
||||
drain_to_completed(&sess, turn_context.as_ref(), &sub_id, &prompt).await;
|
||||
|
||||
match attempt_result {
|
||||
Ok(()) => {
|
||||
if truncated_count > 0 {
|
||||
sess.notify_background_event(
|
||||
&sub_id,
|
||||
format!(
|
||||
"Trimmed {truncated_count} older conversation item(s) before compacting so the prompt fits the model context window."
|
||||
),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(CodexErr::Interrupted) => {
|
||||
return;
|
||||
}
|
||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||
if turn_input.len() > 1 {
|
||||
turn_input.remove(0);
|
||||
truncated_count += 1;
|
||||
retries = 0;
|
||||
continue;
|
||||
}
|
||||
sess.set_total_tokens_full(&sub_id, turn_context.as_ref())
|
||||
.await;
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
if retries < max_retries {
|
||||
retries += 1;
|
||||
let delay = backoff(retries);
|
||||
sess.notify_stream_error(
|
||||
&sub_id,
|
||||
format!(
|
||||
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…"
|
||||
),
|
||||
format!("Re-connecting... {retries}/{max_retries}"),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(delay).await;
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
use crate::config_loader::LoadedConfigLayers;
|
||||
pub use crate::config_loader::load_config_as_toml;
|
||||
use crate::config_loader::load_config_layers_with_overrides;
|
||||
use crate::config_loader::merge_toml_values;
|
||||
use crate::config_profile::ConfigProfile;
|
||||
use crate::config_types::DEFAULT_OTEL_ENVIRONMENT;
|
||||
use crate::config_types::History;
|
||||
@@ -29,12 +33,15 @@ use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::config_types::SandboxMode;
|
||||
use codex_protocol::config_types::Verbosity;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use dirs::home_dir;
|
||||
use serde::Deserialize;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::io::ErrorKind;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use tempfile::NamedTempFile;
|
||||
use toml::Value as TomlValue;
|
||||
use toml_edit::Array as TomlArray;
|
||||
@@ -42,7 +49,10 @@ use toml_edit::DocumentMut;
|
||||
use toml_edit::Item as TomlItem;
|
||||
use toml_edit::Table as TomlTable;
|
||||
|
||||
const OPENAI_DEFAULT_MODEL: &str = "gpt-5-codex";
|
||||
#[cfg(target_os = "windows")]
|
||||
pub const OPENAI_DEFAULT_MODEL: &str = "gpt-5";
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
pub const OPENAI_DEFAULT_MODEL: &str = "gpt-5-codex";
|
||||
const OPENAI_DEFAULT_REVIEW_MODEL: &str = "gpt-5-codex";
|
||||
pub const GPT_5_CODEX_MEDIUM_MODEL: &str = "gpt-5-codex";
|
||||
|
||||
@@ -135,6 +145,15 @@ pub struct Config {
|
||||
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// Preferred store for MCP OAuth credentials.
|
||||
/// keyring: Use an OS-specific keyring service.
|
||||
/// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.
|
||||
/// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2
|
||||
/// file: CODEX_HOME/.credentials.json
|
||||
/// This file will be readable to Codex and other applications running as the same user.
|
||||
/// auto (default): keyring if available, otherwise file.
|
||||
pub mcp_oauth_credentials_store_mode: OAuthCredentialsStoreMode,
|
||||
|
||||
/// Combined provider map (defaults merged with user-defined overrides).
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
|
||||
@@ -202,6 +221,9 @@ pub struct Config {
|
||||
/// The active profile name used to derive this `Config` (if any).
|
||||
pub active_profile: Option<String>,
|
||||
|
||||
/// Tracks whether the Windows onboarding screen has been acknowledged.
|
||||
pub windows_wsl_setup_acknowledged: bool,
|
||||
|
||||
/// When true, disables burst-paste detection for typed input entirely.
|
||||
/// All characters are inserted as they are received, and no buffering
|
||||
/// or placeholder replacement will occur for fast keypress bursts.
|
||||
@@ -212,50 +234,38 @@ pub struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Load configuration with *generic* CLI overrides (`-c key=value`) applied
|
||||
/// **in between** the values parsed from `config.toml` and the
|
||||
/// strongly-typed overrides specified via [`ConfigOverrides`].
|
||||
///
|
||||
/// The precedence order is therefore: `config.toml` < `-c` overrides <
|
||||
/// `ConfigOverrides`.
|
||||
pub fn load_with_cli_overrides(
|
||||
pub async fn load_with_cli_overrides(
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
overrides: ConfigOverrides,
|
||||
) -> std::io::Result<Self> {
|
||||
// Resolve the directory that stores Codex state (e.g. ~/.codex or the
|
||||
// value of $CODEX_HOME) so we can embed it into the resulting
|
||||
// `Config` instance.
|
||||
let codex_home = find_codex_home()?;
|
||||
|
||||
// Step 1: parse `config.toml` into a generic JSON value.
|
||||
let mut root_value = load_config_as_toml(&codex_home)?;
|
||||
let root_value = load_resolved_config(
|
||||
&codex_home,
|
||||
cli_overrides,
|
||||
crate::config_loader::LoaderOverrides::default(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Step 2: apply the `-c` overrides.
|
||||
for (path, value) in cli_overrides.into_iter() {
|
||||
apply_toml_override(&mut root_value, &path, value);
|
||||
}
|
||||
|
||||
// Step 3: deserialize into `ConfigToml` so that Serde can enforce the
|
||||
// correct types.
|
||||
let cfg: ConfigToml = root_value.try_into().map_err(|e| {
|
||||
tracing::error!("Failed to deserialize overridden config: {e}");
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidData, e)
|
||||
})?;
|
||||
|
||||
// Step 4: merge with the strongly-typed overrides.
|
||||
Self::load_from_base_config_with_overrides(cfg, overrides, codex_home)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_config_as_toml_with_cli_overrides(
|
||||
pub async fn load_config_as_toml_with_cli_overrides(
|
||||
codex_home: &Path,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
) -> std::io::Result<ConfigToml> {
|
||||
let mut root_value = load_config_as_toml(codex_home)?;
|
||||
|
||||
for (path, value) in cli_overrides.into_iter() {
|
||||
apply_toml_override(&mut root_value, &path, value);
|
||||
}
|
||||
let root_value = load_resolved_config(
|
||||
codex_home,
|
||||
cli_overrides,
|
||||
crate::config_loader::LoaderOverrides::default(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let cfg: ConfigToml = root_value.try_into().map_err(|e| {
|
||||
tracing::error!("Failed to deserialize overridden config: {e}");
|
||||
@@ -265,43 +275,73 @@ pub fn load_config_as_toml_with_cli_overrides(
|
||||
Ok(cfg)
|
||||
}
|
||||
|
||||
/// Read `CODEX_HOME/config.toml` and return it as a generic TOML value. Returns
|
||||
/// an empty TOML table when the file does not exist.
|
||||
pub fn load_config_as_toml(codex_home: &Path) -> std::io::Result<TomlValue> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
match std::fs::read_to_string(&config_path) {
|
||||
Ok(contents) => match toml::from_str::<TomlValue>(&contents) {
|
||||
Ok(val) => Ok(val),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to parse config.toml: {e}");
|
||||
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
|
||||
}
|
||||
},
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
tracing::info!("config.toml not found, using defaults");
|
||||
Ok(TomlValue::Table(Default::default()))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read config.toml: {e}");
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
async fn load_resolved_config(
|
||||
codex_home: &Path,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
overrides: crate::config_loader::LoaderOverrides,
|
||||
) -> std::io::Result<TomlValue> {
|
||||
let layers = load_config_layers_with_overrides(codex_home, overrides).await?;
|
||||
Ok(apply_overlays(layers, cli_overrides))
|
||||
}
|
||||
|
||||
pub fn load_global_mcp_servers(
|
||||
fn apply_overlays(
|
||||
layers: LoadedConfigLayers,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
) -> TomlValue {
|
||||
let LoadedConfigLayers {
|
||||
mut base,
|
||||
managed_config,
|
||||
managed_preferences,
|
||||
} = layers;
|
||||
|
||||
for (path, value) in cli_overrides.into_iter() {
|
||||
apply_toml_override(&mut base, &path, value);
|
||||
}
|
||||
|
||||
for overlay in [managed_config, managed_preferences].into_iter().flatten() {
|
||||
merge_toml_values(&mut base, &overlay);
|
||||
}
|
||||
|
||||
base
|
||||
}
|
||||
|
||||
pub async fn load_global_mcp_servers(
|
||||
codex_home: &Path,
|
||||
) -> std::io::Result<BTreeMap<String, McpServerConfig>> {
|
||||
let root_value = load_config_as_toml(codex_home)?;
|
||||
let root_value = load_config_as_toml(codex_home).await?;
|
||||
let Some(servers_value) = root_value.get("mcp_servers") else {
|
||||
return Ok(BTreeMap::new());
|
||||
};
|
||||
|
||||
ensure_no_inline_bearer_tokens(servers_value)?;
|
||||
|
||||
servers_value
|
||||
.clone()
|
||||
.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
|
||||
}
|
||||
|
||||
/// We briefly allowed plain text bearer_token fields in MCP server configs.
|
||||
/// We want to warn people who recently added these fields but can remove this after a few months.
|
||||
fn ensure_no_inline_bearer_tokens(value: &TomlValue) -> std::io::Result<()> {
|
||||
let Some(servers_table) = value.as_table() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for (server_name, server_value) in servers_table {
|
||||
if let Some(server_table) = server_value.as_table()
|
||||
&& server_table.contains_key("bearer_token")
|
||||
{
|
||||
let message = format!(
|
||||
"mcp_servers.{server_name} uses unsupported `bearer_token`; set `bearer_token_env_var`."
|
||||
);
|
||||
return Err(std::io::Error::new(ErrorKind::InvalidData, message));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_global_mcp_servers(
|
||||
codex_home: &Path,
|
||||
servers: &BTreeMap<String, McpServerConfig>,
|
||||
@@ -350,14 +390,21 @@ pub fn write_global_mcp_servers(
|
||||
entry["env"] = TomlItem::Table(env_table);
|
||||
}
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
entry["url"] = toml_edit::value(url.clone());
|
||||
if let Some(token) = bearer_token {
|
||||
entry["bearer_token"] = toml_edit::value(token.clone());
|
||||
if let Some(env_var) = bearer_token_env_var {
|
||||
entry["bearer_token_env_var"] = toml_edit::value(env_var.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !config.enabled {
|
||||
entry["enabled"] = toml_edit::value(false);
|
||||
}
|
||||
|
||||
if let Some(timeout) = config.startup_timeout_sec {
|
||||
entry["startup_timeout_sec"] = toml_edit::value(timeout.as_secs_f64());
|
||||
}
|
||||
@@ -469,6 +516,29 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Persist the acknowledgement flag for the Windows onboarding screen.
|
||||
pub fn set_windows_wsl_setup_acknowledged(
|
||||
codex_home: &Path,
|
||||
acknowledged: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
let mut doc = match std::fs::read_to_string(config_path.clone()) {
|
||||
Ok(s) => s.parse::<DocumentMut>()?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
doc["windows_wsl_setup_acknowledged"] = toml_edit::value(acknowledged);
|
||||
|
||||
std::fs::create_dir_all(codex_home)?;
|
||||
|
||||
let tmp_file = NamedTempFile::new_in(codex_home)?;
|
||||
std::fs::write(tmp_file.path(), doc.to_string())?;
|
||||
tmp_file.persist(config_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_profile_table<'a>(
|
||||
doc: &'a mut DocumentMut,
|
||||
profile_name: &str,
|
||||
@@ -666,6 +736,14 @@ pub struct ConfigToml {
|
||||
#[serde(default)]
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// Preferred backend for storing MCP OAuth credentials.
|
||||
/// keyring: Use an OS-specific keyring service.
|
||||
/// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2
|
||||
/// file: Use a file in the Codex home directory.
|
||||
/// auto (default): Use the OS-specific keyring service if available, otherwise use a file.
|
||||
#[serde(default)]
|
||||
pub mcp_oauth_credentials_store: Option<OAuthCredentialsStoreMode>,
|
||||
|
||||
/// User-defined provider entries that extend/override the built-in list.
|
||||
#[serde(default)]
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
@@ -722,6 +800,7 @@ pub struct ConfigToml {
|
||||
pub experimental_use_exec_command_tool: Option<bool>,
|
||||
pub experimental_use_unified_exec_tool: Option<bool>,
|
||||
pub experimental_use_rmcp_client: Option<bool>,
|
||||
pub experimental_use_freeform_apply_patch: Option<bool>,
|
||||
|
||||
pub projects: Option<HashMap<String, ProjectConfig>>,
|
||||
|
||||
@@ -735,6 +814,9 @@ pub struct ConfigToml {
|
||||
|
||||
/// OTEL configuration.
|
||||
pub otel: Option<crate::config_types::OtelConfigToml>,
|
||||
|
||||
/// Tracks whether the Windows onboarding screen has been acknowledged.
|
||||
pub windows_wsl_setup_acknowledged: Option<bool>,
|
||||
}
|
||||
|
||||
impl From<ConfigToml> for UserSavedConfig {
|
||||
@@ -1042,6 +1124,9 @@ impl Config {
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
mcp_servers: cfg.mcp_servers,
|
||||
// The config.toml omits "_mode" because it's a config file. However, "_mode"
|
||||
// is important in code to differentiate the mode from the store implementation.
|
||||
mcp_oauth_credentials_store_mode: cfg.mcp_oauth_credentials_store.unwrap_or_default(),
|
||||
model_providers,
|
||||
project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES),
|
||||
project_doc_fallback_filenames: cfg
|
||||
@@ -1080,7 +1165,9 @@ impl Config {
|
||||
.or(cfg.chatgpt_base_url)
|
||||
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
|
||||
include_plan_tool: include_plan_tool.unwrap_or(false),
|
||||
include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
|
||||
include_apply_patch_tool: include_apply_patch_tool
|
||||
.or(cfg.experimental_use_freeform_apply_patch)
|
||||
.unwrap_or(false),
|
||||
tools_web_search_request,
|
||||
use_experimental_streamable_shell_tool: cfg
|
||||
.experimental_use_exec_command_tool
|
||||
@@ -1091,6 +1178,7 @@ impl Config {
|
||||
use_experimental_use_rmcp_client: cfg.experimental_use_rmcp_client.unwrap_or(false),
|
||||
include_view_image_tool,
|
||||
active_profile: active_profile_name,
|
||||
windows_wsl_setup_acknowledged: cfg.windows_wsl_setup_acknowledged.unwrap_or(false),
|
||||
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
|
||||
tui_notifications: cfg
|
||||
.tui
|
||||
@@ -1330,17 +1418,96 @@ exclude_slash_tmp = true
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_global_mcp_servers_returns_empty_if_missing() -> anyhow::Result<()> {
|
||||
fn config_defaults_to_auto_oauth_store_mode() -> std::io::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let cfg = ConfigToml::default();
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
assert!(servers.is_empty());
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_global_mcp_servers_round_trips_entries() -> anyhow::Result<()> {
|
||||
fn config_honors_explicit_file_oauth_store_mode() -> std::io::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let cfg = ConfigToml {
|
||||
mcp_oauth_credentials_store: Some(OAuthCredentialsStoreMode::File),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn managed_config_overrides_oauth_store_mode() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let managed_path = codex_home.path().join("managed_config.toml");
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
std::fs::write(&config_path, "mcp_oauth_credentials_store = \"file\"\n")?;
|
||||
std::fs::write(&managed_path, "mcp_oauth_credentials_store = \"keyring\"\n")?;
|
||||
|
||||
let overrides = crate::config_loader::LoaderOverrides {
|
||||
managed_config_path: Some(managed_path.clone()),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
};
|
||||
|
||||
let root_value = load_resolved_config(codex_home.path(), Vec::new(), overrides).await?;
|
||||
let cfg: ConfigToml = root_value.try_into().map_err(|e| {
|
||||
tracing::error!("Failed to deserialize overridden config: {e}");
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidData, e)
|
||||
})?;
|
||||
assert_eq!(
|
||||
cfg.mcp_oauth_credentials_store,
|
||||
Some(OAuthCredentialsStoreMode::Keyring),
|
||||
);
|
||||
|
||||
let final_config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
assert_eq!(
|
||||
final_config.mcp_oauth_credentials_store_mode,
|
||||
OAuthCredentialsStoreMode::Keyring,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_global_mcp_servers_returns_empty_if_missing() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_global_mcp_servers_round_trips_entries() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut servers = BTreeMap::new();
|
||||
@@ -1352,6 +1519,7 @@ exclude_slash_tmp = true
|
||||
args: vec!["hello".to_string()],
|
||||
env: None,
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(3)),
|
||||
tool_timeout_sec: Some(Duration::from_secs(5)),
|
||||
},
|
||||
@@ -1359,7 +1527,7 @@ exclude_slash_tmp = true
|
||||
|
||||
write_global_mcp_servers(codex_home.path(), &servers)?;
|
||||
|
||||
let loaded = load_global_mcp_servers(codex_home.path())?;
|
||||
let loaded = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert_eq!(loaded.len(), 1);
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
@@ -1372,17 +1540,51 @@ exclude_slash_tmp = true
|
||||
}
|
||||
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(3)));
|
||||
assert_eq!(docs.tool_timeout_sec, Some(Duration::from_secs(5)));
|
||||
assert!(docs.enabled);
|
||||
|
||||
let empty = BTreeMap::new();
|
||||
write_global_mcp_servers(codex_home.path(), &empty)?;
|
||||
let loaded = load_global_mcp_servers(codex_home.path())?;
|
||||
let loaded = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert!(loaded.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_global_mcp_servers_accepts_legacy_ms_field() -> anyhow::Result<()> {
|
||||
#[tokio::test]
|
||||
async fn managed_config_wins_over_cli_overrides() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let managed_path = codex_home.path().join("managed_config.toml");
|
||||
|
||||
std::fs::write(
|
||||
codex_home.path().join(CONFIG_TOML_FILE),
|
||||
"model = \"base\"\n",
|
||||
)?;
|
||||
std::fs::write(&managed_path, "model = \"managed_config\"\n")?;
|
||||
|
||||
let overrides = crate::config_loader::LoaderOverrides {
|
||||
managed_config_path: Some(managed_path),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
};
|
||||
|
||||
let root_value = load_resolved_config(
|
||||
codex_home.path(),
|
||||
vec![("model".to_string(), TomlValue::String("cli".to_string()))],
|
||||
overrides,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let cfg: ConfigToml = root_value.try_into().map_err(|e| {
|
||||
tracing::error!("Failed to deserialize overridden config: {e}");
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidData, e)
|
||||
})?;
|
||||
|
||||
assert_eq!(cfg.model.as_deref(), Some("managed_config"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_global_mcp_servers_accepts_legacy_ms_field() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
@@ -1396,15 +1598,40 @@ startup_timeout_ms = 2500
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path())?;
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let docs = servers.get("docs").expect("docs entry");
|
||||
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_millis(2500)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_global_mcp_servers_serializes_env_sorted() -> anyhow::Result<()> {
|
||||
#[tokio::test]
|
||||
async fn load_global_mcp_servers_rejects_inline_bearer_token() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
std::fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
[mcp_servers.docs]
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token = "secret"
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let err = load_global_mcp_servers(codex_home.path())
|
||||
.await
|
||||
.expect_err("bearer_token entries should be rejected");
|
||||
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
|
||||
assert!(err.to_string().contains("bearer_token"));
|
||||
assert!(err.to_string().contains("bearer_token_env_var"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_global_mcp_servers_serializes_env_sorted() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let servers = BTreeMap::from([(
|
||||
@@ -1418,6 +1645,7 @@ startup_timeout_ms = 2500
|
||||
("ALPHA_VAR".to_string(), "1".to_string()),
|
||||
])),
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
@@ -1439,7 +1667,7 @@ ZIG_VAR = "3"
|
||||
"#
|
||||
);
|
||||
|
||||
let loaded = load_global_mcp_servers(codex_home.path())?;
|
||||
let loaded = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
@@ -1457,8 +1685,8 @@ ZIG_VAR = "3"
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> {
|
||||
#[tokio::test]
|
||||
async fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut servers = BTreeMap::from([(
|
||||
@@ -1466,8 +1694,9 @@ ZIG_VAR = "3"
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: Some("secret-token".to_string()),
|
||||
bearer_token_env_var: Some("MCP_TOKEN".to_string()),
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(2)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
@@ -1481,17 +1710,20 @@ ZIG_VAR = "3"
|
||||
serialized,
|
||||
r#"[mcp_servers.docs]
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token = "secret-token"
|
||||
bearer_token_env_var = "MCP_TOKEN"
|
||||
startup_timeout_sec = 2.0
|
||||
"#
|
||||
);
|
||||
|
||||
let loaded = load_global_mcp_servers(codex_home.path())?;
|
||||
let loaded = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert_eq!(bearer_token.as_deref(), Some("secret-token"));
|
||||
assert_eq!(bearer_token_env_var.as_deref(), Some("MCP_TOKEN"));
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
@@ -1502,8 +1734,9 @@ startup_timeout_sec = 2.0
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: None,
|
||||
bearer_token_env_var: None,
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
@@ -1518,12 +1751,15 @@ url = "https://example.com/mcp"
|
||||
"#
|
||||
);
|
||||
|
||||
let loaded = load_global_mcp_servers(codex_home.path())?;
|
||||
let loaded = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert!(bearer_token.is_none());
|
||||
assert!(bearer_token_env_var.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
@@ -1531,6 +1767,40 @@ url = "https://example.com/mcp"
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_global_mcp_servers_serializes_disabled_flag() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let servers = BTreeMap::from([(
|
||||
"docs".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: "docs-server".to_string(),
|
||||
args: Vec::new(),
|
||||
env: None,
|
||||
},
|
||||
enabled: false,
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
)]);
|
||||
|
||||
write_global_mcp_servers(codex_home.path(), &servers)?;
|
||||
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
let serialized = std::fs::read_to_string(&config_path)?;
|
||||
assert!(
|
||||
serialized.contains("enabled = false"),
|
||||
"serialized config missing disabled flag:\n{serialized}"
|
||||
);
|
||||
|
||||
let loaded = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
assert!(!docs.enabled);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_updates_defaults() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
@@ -1828,6 +2098,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -1850,6 +2121,7 @@ model_verbosity = "high"
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("o3".to_string()),
|
||||
windows_wsl_setup_acknowledged: false,
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
otel: OtelConfig::default(),
|
||||
@@ -1889,6 +2161,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -1911,6 +2184,7 @@ model_verbosity = "high"
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("gpt3".to_string()),
|
||||
windows_wsl_setup_acknowledged: false,
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
otel: OtelConfig::default(),
|
||||
@@ -1965,6 +2239,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -1987,6 +2262,7 @@ model_verbosity = "high"
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("zdr".to_string()),
|
||||
windows_wsl_setup_acknowledged: false,
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
otel: OtelConfig::default(),
|
||||
@@ -2027,6 +2303,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -2049,6 +2326,7 @@ model_verbosity = "high"
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("gpt5".to_string()),
|
||||
windows_wsl_setup_acknowledged: false,
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
otel: OtelConfig::default(),
|
||||
@@ -2159,6 +2437,7 @@ trust_level = "trusted"
|
||||
#[cfg(test)]
|
||||
mod notifications_tests {
|
||||
use crate::config_types::Notifications;
|
||||
use assert_matches::assert_matches;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize, Debug, PartialEq)]
|
||||
@@ -2178,10 +2457,7 @@ mod notifications_tests {
|
||||
notifications = true
|
||||
"#;
|
||||
let parsed: RootTomlTest = toml::from_str(toml).expect("deserialize notifications=true");
|
||||
assert!(matches!(
|
||||
parsed.tui.notifications,
|
||||
Notifications::Enabled(true)
|
||||
));
|
||||
assert_matches!(parsed.tui.notifications, Notifications::Enabled(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2192,9 +2468,9 @@ mod notifications_tests {
|
||||
"#;
|
||||
let parsed: RootTomlTest =
|
||||
toml::from_str(toml).expect("deserialize notifications=[\"foo\"]");
|
||||
assert!(matches!(
|
||||
assert_matches!(
|
||||
parsed.tui.notifications,
|
||||
Notifications::Custom(ref v) if v == &vec!["foo".to_string()]
|
||||
));
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
118
codex-rs/core/src/config_loader/macos.rs
Normal file
118
codex-rs/core/src/config_loader/macos.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use std::io;
|
||||
use toml::Value as TomlValue;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
mod native {
|
||||
use super::*;
|
||||
use base64::Engine;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use core_foundation::base::TCFType;
|
||||
use core_foundation::string::CFString;
|
||||
use core_foundation::string::CFStringRef;
|
||||
use std::ffi::c_void;
|
||||
use tokio::task;
|
||||
|
||||
pub(crate) async fn load_managed_admin_config_layer(
|
||||
override_base64: Option<&str>,
|
||||
) -> io::Result<Option<TomlValue>> {
|
||||
if let Some(encoded) = override_base64 {
|
||||
let trimmed = encoded.trim();
|
||||
return if trimmed.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
parse_managed_preferences_base64(trimmed).map(Some)
|
||||
};
|
||||
}
|
||||
|
||||
const LOAD_ERROR: &str = "Failed to load managed preferences configuration";
|
||||
|
||||
match task::spawn_blocking(load_managed_admin_config).await {
|
||||
Ok(result) => result,
|
||||
Err(join_err) => {
|
||||
if join_err.is_cancelled() {
|
||||
tracing::error!("Managed preferences load task was cancelled");
|
||||
} else {
|
||||
tracing::error!("Managed preferences load task failed: {join_err}");
|
||||
}
|
||||
Err(io::Error::other(LOAD_ERROR))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn load_managed_admin_config() -> io::Result<Option<TomlValue>> {
|
||||
#[link(name = "CoreFoundation", kind = "framework")]
|
||||
unsafe extern "C" {
|
||||
fn CFPreferencesCopyAppValue(
|
||||
key: CFStringRef,
|
||||
application_id: CFStringRef,
|
||||
) -> *mut c_void;
|
||||
}
|
||||
|
||||
const MANAGED_PREFERENCES_APPLICATION_ID: &str = "com.openai.codex";
|
||||
const MANAGED_PREFERENCES_CONFIG_KEY: &str = "config_toml_base64";
|
||||
|
||||
let application_id = CFString::new(MANAGED_PREFERENCES_APPLICATION_ID);
|
||||
let key = CFString::new(MANAGED_PREFERENCES_CONFIG_KEY);
|
||||
|
||||
let value_ref = unsafe {
|
||||
CFPreferencesCopyAppValue(
|
||||
key.as_concrete_TypeRef(),
|
||||
application_id.as_concrete_TypeRef(),
|
||||
)
|
||||
};
|
||||
|
||||
if value_ref.is_null() {
|
||||
tracing::debug!(
|
||||
"Managed preferences for {} key {} not found",
|
||||
MANAGED_PREFERENCES_APPLICATION_ID,
|
||||
MANAGED_PREFERENCES_CONFIG_KEY
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let value = unsafe { CFString::wrap_under_create_rule(value_ref as _) };
|
||||
let contents = value.to_string();
|
||||
let trimmed = contents.trim();
|
||||
|
||||
parse_managed_preferences_base64(trimmed).map(Some)
|
||||
}
|
||||
|
||||
pub(super) fn parse_managed_preferences_base64(encoded: &str) -> io::Result<TomlValue> {
|
||||
let decoded = BASE64_STANDARD.decode(encoded.as_bytes()).map_err(|err| {
|
||||
tracing::error!("Failed to decode managed preferences as base64: {err}");
|
||||
io::Error::new(io::ErrorKind::InvalidData, err)
|
||||
})?;
|
||||
|
||||
let decoded_str = String::from_utf8(decoded).map_err(|err| {
|
||||
tracing::error!("Managed preferences base64 contents were not valid UTF-8: {err}");
|
||||
io::Error::new(io::ErrorKind::InvalidData, err)
|
||||
})?;
|
||||
|
||||
match toml::from_str::<TomlValue>(&decoded_str) {
|
||||
Ok(TomlValue::Table(parsed)) => Ok(TomlValue::Table(parsed)),
|
||||
Ok(other) => {
|
||||
tracing::error!(
|
||||
"Managed preferences TOML must have a table at the root, found {other:?}",
|
||||
);
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"managed preferences root must be a table",
|
||||
))
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to parse managed preferences TOML: {err}");
|
||||
Err(io::Error::new(io::ErrorKind::InvalidData, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub(crate) use native::load_managed_admin_config_layer;
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
pub(crate) async fn load_managed_admin_config_layer(
|
||||
_override_base64: Option<&str>,
|
||||
) -> io::Result<Option<TomlValue>> {
|
||||
Ok(None)
|
||||
}
|
||||
311
codex-rs/core/src/config_loader/mod.rs
Normal file
311
codex-rs/core/src/config_loader/mod.rs
Normal file
@@ -0,0 +1,311 @@
|
||||
mod macos;
|
||||
|
||||
use crate::config::CONFIG_TOML_FILE;
|
||||
use macos::load_managed_admin_config_layer;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tokio::fs;
|
||||
use toml::Value as TomlValue;
|
||||
|
||||
#[cfg(unix)]
|
||||
const CODEX_MANAGED_CONFIG_SYSTEM_PATH: &str = "/etc/codex/managed_config.toml";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct LoadedConfigLayers {
|
||||
pub base: TomlValue,
|
||||
pub managed_config: Option<TomlValue>,
|
||||
pub managed_preferences: Option<TomlValue>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct LoaderOverrides {
|
||||
pub managed_config_path: Option<PathBuf>,
|
||||
#[cfg(target_os = "macos")]
|
||||
pub managed_preferences_base64: Option<String>,
|
||||
}
|
||||
|
||||
// Configuration layering pipeline (top overrides bottom):
|
||||
//
|
||||
// +-------------------------+
|
||||
// | Managed preferences (*) |
|
||||
// +-------------------------+
|
||||
// ^
|
||||
// |
|
||||
// +-------------------------+
|
||||
// | managed_config.toml |
|
||||
// +-------------------------+
|
||||
// ^
|
||||
// |
|
||||
// +-------------------------+
|
||||
// | config.toml (base) |
|
||||
// +-------------------------+
|
||||
//
|
||||
// (*) Only available on macOS via managed device profiles.
|
||||
|
||||
pub async fn load_config_as_toml(codex_home: &Path) -> io::Result<TomlValue> {
|
||||
load_config_as_toml_with_overrides(codex_home, LoaderOverrides::default()).await
|
||||
}
|
||||
|
||||
fn default_empty_table() -> TomlValue {
|
||||
TomlValue::Table(Default::default())
|
||||
}
|
||||
|
||||
pub(crate) async fn load_config_layers_with_overrides(
|
||||
codex_home: &Path,
|
||||
overrides: LoaderOverrides,
|
||||
) -> io::Result<LoadedConfigLayers> {
|
||||
load_config_layers_internal(codex_home, overrides).await
|
||||
}
|
||||
|
||||
async fn load_config_as_toml_with_overrides(
|
||||
codex_home: &Path,
|
||||
overrides: LoaderOverrides,
|
||||
) -> io::Result<TomlValue> {
|
||||
let layers = load_config_layers_internal(codex_home, overrides).await?;
|
||||
Ok(apply_managed_layers(layers))
|
||||
}
|
||||
|
||||
async fn load_config_layers_internal(
|
||||
codex_home: &Path,
|
||||
overrides: LoaderOverrides,
|
||||
) -> io::Result<LoadedConfigLayers> {
|
||||
#[cfg(target_os = "macos")]
|
||||
let LoaderOverrides {
|
||||
managed_config_path,
|
||||
managed_preferences_base64,
|
||||
} = overrides;
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let LoaderOverrides {
|
||||
managed_config_path,
|
||||
} = overrides;
|
||||
|
||||
let managed_config_path =
|
||||
managed_config_path.unwrap_or_else(|| managed_config_default_path(codex_home));
|
||||
|
||||
let user_config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
let user_config = read_config_from_path(&user_config_path, true).await?;
|
||||
let managed_config = read_config_from_path(&managed_config_path, false).await?;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
let managed_preferences =
|
||||
load_managed_admin_config_layer(managed_preferences_base64.as_deref()).await?;
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let managed_preferences = load_managed_admin_config_layer(None).await?;
|
||||
|
||||
Ok(LoadedConfigLayers {
|
||||
base: user_config.unwrap_or_else(default_empty_table),
|
||||
managed_config,
|
||||
managed_preferences,
|
||||
})
|
||||
}
|
||||
|
||||
async fn read_config_from_path(
|
||||
path: &Path,
|
||||
log_missing_as_info: bool,
|
||||
) -> io::Result<Option<TomlValue>> {
|
||||
match fs::read_to_string(path).await {
|
||||
Ok(contents) => match toml::from_str::<TomlValue>(&contents) {
|
||||
Ok(value) => Ok(Some(value)),
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to parse {}: {err}", path.display());
|
||||
Err(io::Error::new(io::ErrorKind::InvalidData, err))
|
||||
}
|
||||
},
|
||||
Err(err) if err.kind() == io::ErrorKind::NotFound => {
|
||||
if log_missing_as_info {
|
||||
tracing::info!("{} not found, using defaults", path.display());
|
||||
} else {
|
||||
tracing::debug!("{} not found", path.display());
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to read {}: {err}", path.display());
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge config `overlay` into `base`, giving `overlay` precedence.
|
||||
pub(crate) fn merge_toml_values(base: &mut TomlValue, overlay: &TomlValue) {
|
||||
if let TomlValue::Table(overlay_table) = overlay
|
||||
&& let TomlValue::Table(base_table) = base
|
||||
{
|
||||
for (key, value) in overlay_table {
|
||||
if let Some(existing) = base_table.get_mut(key) {
|
||||
merge_toml_values(existing, value);
|
||||
} else {
|
||||
base_table.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*base = overlay.clone();
|
||||
}
|
||||
}
|
||||
|
||||
fn managed_config_default_path(codex_home: &Path) -> PathBuf {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let _ = codex_home;
|
||||
PathBuf::from(CODEX_MANAGED_CONFIG_SYSTEM_PATH)
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
codex_home.join("managed_config.toml")
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_managed_layers(layers: LoadedConfigLayers) -> TomlValue {
|
||||
let LoadedConfigLayers {
|
||||
mut base,
|
||||
managed_config,
|
||||
managed_preferences,
|
||||
} = layers;
|
||||
|
||||
for overlay in [managed_config, managed_preferences].into_iter().flatten() {
|
||||
merge_toml_values(&mut base, &overlay);
|
||||
}
|
||||
|
||||
base
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn merges_managed_config_layer_on_top() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let managed_path = tmp.path().join("managed_config.toml");
|
||||
|
||||
std::fs::write(
|
||||
tmp.path().join(CONFIG_TOML_FILE),
|
||||
r#"foo = 1
|
||||
|
||||
[nested]
|
||||
value = "base"
|
||||
"#,
|
||||
)
|
||||
.expect("write base");
|
||||
std::fs::write(
|
||||
&managed_path,
|
||||
r#"foo = 2
|
||||
|
||||
[nested]
|
||||
value = "managed_config"
|
||||
extra = true
|
||||
"#,
|
||||
)
|
||||
.expect("write managed config");
|
||||
|
||||
let overrides = LoaderOverrides {
|
||||
managed_config_path: Some(managed_path),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
};
|
||||
|
||||
let loaded = load_config_as_toml_with_overrides(tmp.path(), overrides)
|
||||
.await
|
||||
.expect("load config");
|
||||
let table = loaded.as_table().expect("top-level table expected");
|
||||
|
||||
assert_eq!(table.get("foo"), Some(&TomlValue::Integer(2)));
|
||||
let nested = table
|
||||
.get("nested")
|
||||
.and_then(|v| v.as_table())
|
||||
.expect("nested");
|
||||
assert_eq!(
|
||||
nested.get("value"),
|
||||
Some(&TomlValue::String("managed_config".to_string()))
|
||||
);
|
||||
assert_eq!(nested.get("extra"), Some(&TomlValue::Boolean(true)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn returns_empty_when_all_layers_missing() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let managed_path = tmp.path().join("managed_config.toml");
|
||||
let overrides = LoaderOverrides {
|
||||
managed_config_path: Some(managed_path),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
};
|
||||
|
||||
let layers = load_config_layers_with_overrides(tmp.path(), overrides)
|
||||
.await
|
||||
.expect("load layers");
|
||||
let base_table = layers.base.as_table().expect("base table expected");
|
||||
assert!(
|
||||
base_table.is_empty(),
|
||||
"expected empty base layer when configs missing"
|
||||
);
|
||||
assert!(
|
||||
layers.managed_config.is_none(),
|
||||
"managed config layer should be absent when file missing"
|
||||
);
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let loaded = load_config_as_toml(tmp.path()).await.expect("load config");
|
||||
let table = loaded.as_table().expect("top-level table expected");
|
||||
assert!(
|
||||
table.is_empty(),
|
||||
"expected empty table when configs missing"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
#[tokio::test]
|
||||
async fn managed_preferences_take_highest_precedence() {
|
||||
use base64::Engine;
|
||||
|
||||
let managed_payload = r#"
|
||||
[nested]
|
||||
value = "managed"
|
||||
flag = false
|
||||
"#;
|
||||
let encoded = base64::prelude::BASE64_STANDARD.encode(managed_payload.as_bytes());
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let managed_path = tmp.path().join("managed_config.toml");
|
||||
|
||||
std::fs::write(
|
||||
tmp.path().join(CONFIG_TOML_FILE),
|
||||
r#"[nested]
|
||||
value = "base"
|
||||
"#,
|
||||
)
|
||||
.expect("write base");
|
||||
std::fs::write(
|
||||
&managed_path,
|
||||
r#"[nested]
|
||||
value = "managed_config"
|
||||
flag = true
|
||||
"#,
|
||||
)
|
||||
.expect("write managed config");
|
||||
|
||||
let overrides = LoaderOverrides {
|
||||
managed_config_path: Some(managed_path),
|
||||
managed_preferences_base64: Some(encoded),
|
||||
};
|
||||
|
||||
let loaded = load_config_as_toml_with_overrides(tmp.path(), overrides)
|
||||
.await
|
||||
.expect("load config");
|
||||
let nested = loaded
|
||||
.get("nested")
|
||||
.and_then(|v| v.as_table())
|
||||
.expect("nested table");
|
||||
assert_eq!(
|
||||
nested.get("value"),
|
||||
Some(&TomlValue::String("managed".to_string()))
|
||||
);
|
||||
assert_eq!(nested.get("flag"), Some(&TomlValue::Boolean(false)));
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,10 @@ pub struct McpServerConfig {
|
||||
#[serde(flatten)]
|
||||
pub transport: McpServerTransportConfig,
|
||||
|
||||
/// When `false`, Codex skips initializing this MCP server.
|
||||
#[serde(default = "default_enabled")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Startup timeout in seconds for initializing MCP server & initially listing tools.
|
||||
#[serde(
|
||||
default,
|
||||
@@ -48,6 +52,7 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
|
||||
url: Option<String>,
|
||||
bearer_token: Option<String>,
|
||||
bearer_token_env_var: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
startup_timeout_sec: Option<f64>,
|
||||
@@ -55,6 +60,8 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
startup_timeout_ms: Option<u64>,
|
||||
#[serde(default, with = "option_duration_secs")]
|
||||
tool_timeout_sec: Option<Duration>,
|
||||
#[serde(default)]
|
||||
enabled: Option<bool>,
|
||||
}
|
||||
|
||||
let raw = RawMcpServerConfig::deserialize(deserializer)?;
|
||||
@@ -86,11 +93,15 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
args,
|
||||
env,
|
||||
url,
|
||||
bearer_token,
|
||||
bearer_token_env_var,
|
||||
..
|
||||
} => {
|
||||
throw_if_set("stdio", "url", url.as_ref())?;
|
||||
throw_if_set("stdio", "bearer_token", bearer_token.as_ref())?;
|
||||
throw_if_set(
|
||||
"stdio",
|
||||
"bearer_token_env_var",
|
||||
bearer_token_env_var.as_ref(),
|
||||
)?;
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args: args.unwrap_or_default(),
|
||||
@@ -100,6 +111,7 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
RawMcpServerConfig {
|
||||
url: Some(url),
|
||||
bearer_token,
|
||||
bearer_token_env_var,
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
@@ -108,7 +120,11 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
throw_if_set("streamable_http", "command", command.as_ref())?;
|
||||
throw_if_set("streamable_http", "args", args.as_ref())?;
|
||||
throw_if_set("streamable_http", "env", env.as_ref())?;
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token }
|
||||
throw_if_set("streamable_http", "bearer_token", bearer_token.as_ref())?;
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
}
|
||||
}
|
||||
_ => return Err(SerdeError::custom("invalid transport")),
|
||||
};
|
||||
@@ -117,10 +133,15 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
transport,
|
||||
startup_timeout_sec,
|
||||
tool_timeout_sec: raw.tool_timeout_sec,
|
||||
enabled: raw.enabled.unwrap_or_else(default_enabled),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const fn default_enabled() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")]
|
||||
pub enum McpServerTransportConfig {
|
||||
@@ -135,11 +156,11 @@ pub enum McpServerTransportConfig {
|
||||
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http
|
||||
StreamableHttp {
|
||||
url: String,
|
||||
/// A plain text bearer token to use for authentication.
|
||||
/// This bearer token will be included in the HTTP request header as an `Authorization: Bearer <token>` header.
|
||||
/// This should be used with caution because it lives on disk in clear text.
|
||||
/// Name of the environment variable to read for an HTTP bearer token.
|
||||
/// When set, requests will include the token via `Authorization: Bearer <token>`.
|
||||
/// The actual secret value must be provided via the environment.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
bearer_token: Option<String>,
|
||||
bearer_token_env_var: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -450,6 +471,7 @@ mod tests {
|
||||
env: None
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -470,6 +492,7 @@ mod tests {
|
||||
env: None
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -491,6 +514,20 @@ mod tests {
|
||||
env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())]))
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_disabled_server_config() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
command = "echo"
|
||||
enabled = false
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize disabled server config");
|
||||
|
||||
assert!(!cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -506,17 +543,18 @@ mod tests {
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: None
|
||||
bearer_token_env_var: None
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_streamable_http_server_config_with_bearer_token() {
|
||||
fn deserialize_streamable_http_server_config_with_env_var() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token = "secret"
|
||||
bearer_token_env_var = "GITHUB_TOKEN"
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize http config");
|
||||
@@ -525,9 +563,10 @@ mod tests {
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: Some("secret".to_string())
|
||||
bearer_token_env_var: Some("GITHUB_TOKEN".to_string())
|
||||
}
|
||||
);
|
||||
assert!(cfg.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -553,13 +592,18 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_rejects_bearer_token_for_stdio_transport() {
|
||||
toml::from_str::<McpServerConfig>(
|
||||
fn deserialize_rejects_inline_bearer_token_field() {
|
||||
let err = toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
command = "echo"
|
||||
url = "https://example.com"
|
||||
bearer_token = "secret"
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject bearer token for stdio transport");
|
||||
.expect_err("should reject bearer_token field");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("bearer_token is not supported"),
|
||||
"unexpected error: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,6 +210,7 @@ fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> Initia
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::codex::make_session_and_context;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemReasoningSummary;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -236,7 +237,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn drops_from_last_user_only() {
|
||||
let items = vec![
|
||||
let items = [
|
||||
user_msg("u1"),
|
||||
assistant_msg("a1"),
|
||||
assistant_msg("a2"),
|
||||
@@ -283,7 +284,7 @@ mod tests {
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
let truncated2 = truncate_before_nth_user_message(InitialHistory::Forked(initial2), 2);
|
||||
assert!(matches!(truncated2, InitialHistory::New));
|
||||
assert_matches!(truncated2, InitialHistory::New);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -20,7 +20,7 @@ use std::sync::OnceLock;
|
||||
/// The full user agent string is returned from the mcp initialize response.
|
||||
/// Parenthesis will be added by Codex. This should only specify what goes inside of the parenthesis.
|
||||
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));
|
||||
|
||||
pub const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";
|
||||
pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE";
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Originator {
|
||||
@@ -35,10 +35,11 @@ pub enum SetOriginatorError {
|
||||
AlreadyInitialized,
|
||||
}
|
||||
|
||||
fn init_originator_from_env() -> Originator {
|
||||
let default = "codex_cli_rs";
|
||||
fn get_originator_value(provided: Option<String>) -> Originator {
|
||||
let value = std::env::var(CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR)
|
||||
.unwrap_or_else(|_| default.to_string());
|
||||
.ok()
|
||||
.or(provided)
|
||||
.unwrap_or(DEFAULT_ORIGINATOR.to_string());
|
||||
|
||||
match HeaderValue::from_str(&value) {
|
||||
Ok(header_value) => Originator {
|
||||
@@ -48,31 +49,22 @@ fn init_originator_from_env() -> Originator {
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to turn originator override {value} into header value: {e}");
|
||||
Originator {
|
||||
value: default.to_string(),
|
||||
header_value: HeaderValue::from_static(default),
|
||||
value: DEFAULT_ORIGINATOR.to_string(),
|
||||
header_value: HeaderValue::from_static(DEFAULT_ORIGINATOR),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_originator(value: String) -> Result<Originator, SetOriginatorError> {
|
||||
let header_value =
|
||||
HeaderValue::from_str(&value).map_err(|_| SetOriginatorError::InvalidHeaderValue)?;
|
||||
Ok(Originator {
|
||||
value,
|
||||
header_value,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_default_originator(value: &str) -> Result<(), SetOriginatorError> {
|
||||
let originator = build_originator(value.to_string())?;
|
||||
pub fn set_default_originator(value: String) -> Result<(), SetOriginatorError> {
|
||||
let originator = get_originator_value(Some(value));
|
||||
ORIGINATOR
|
||||
.set(originator)
|
||||
.map_err(|_| SetOriginatorError::AlreadyInitialized)
|
||||
}
|
||||
|
||||
pub fn originator() -> &'static Originator {
|
||||
ORIGINATOR.get_or_init(init_originator_from_env)
|
||||
ORIGINATOR.get_or_init(|| get_originator_value(None))
|
||||
}
|
||||
|
||||
pub fn get_codex_user_agent() -> String {
|
||||
|
||||
@@ -29,6 +29,7 @@ pub(crate) struct EnvironmentContext {
|
||||
pub network_access: Option<NetworkAccess>,
|
||||
pub writable_roots: Option<Vec<PathBuf>>,
|
||||
pub shell: Option<Shell>,
|
||||
pub changed_files: Option<Vec<PathBuf>>,
|
||||
}
|
||||
|
||||
impl EnvironmentContext {
|
||||
@@ -37,6 +38,7 @@ impl EnvironmentContext {
|
||||
approval_policy: Option<AskForApproval>,
|
||||
sandbox_policy: Option<SandboxPolicy>,
|
||||
shell: Option<Shell>,
|
||||
changed_files: Option<Vec<PathBuf>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
cwd,
|
||||
@@ -70,6 +72,8 @@ impl EnvironmentContext {
|
||||
_ => None,
|
||||
},
|
||||
shell,
|
||||
changed_files: changed_files
|
||||
.and_then(|paths| if paths.is_empty() { None } else { Some(paths) }),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,6 +89,7 @@ impl EnvironmentContext {
|
||||
writable_roots,
|
||||
// should compare all fields except shell
|
||||
shell: _,
|
||||
changed_files,
|
||||
} = other;
|
||||
|
||||
self.cwd == *cwd
|
||||
@@ -92,6 +97,7 @@ impl EnvironmentContext {
|
||||
&& self.sandbox_mode == *sandbox_mode
|
||||
&& self.network_access == *network_access
|
||||
&& self.writable_roots == *writable_roots
|
||||
&& self.changed_files == *changed_files
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +109,7 @@ impl From<&TurnContext> for EnvironmentContext {
|
||||
Some(turn_context.sandbox_policy.clone()),
|
||||
// Shell is not configurable from turn to turn
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -155,6 +162,14 @@ impl EnvironmentContext {
|
||||
{
|
||||
lines.push(format!(" <shell>{shell_name}</shell>"));
|
||||
}
|
||||
if let Some(changed_files) = &self.changed_files
|
||||
&& !changed_files.is_empty() {
|
||||
lines.push(" <changed_files>".to_string());
|
||||
for path in changed_files {
|
||||
lines.push(format!(" <path>{}</path>", path.to_string_lossy()));
|
||||
}
|
||||
lines.push(" </changed_files>".to_string());
|
||||
}
|
||||
lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string());
|
||||
lines.join("\n")
|
||||
}
|
||||
@@ -196,6 +211,7 @@ mod tests {
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp"], false)),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
@@ -219,6 +235,7 @@ mod tests {
|
||||
Some(AskForApproval::Never),
|
||||
Some(SandboxPolicy::ReadOnly),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
@@ -237,6 +254,7 @@ mod tests {
|
||||
Some(AskForApproval::OnFailure),
|
||||
Some(SandboxPolicy::DangerFullAccess),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
@@ -256,12 +274,14 @@ mod tests {
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo"], false)),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::Never),
|
||||
Some(workspace_write_policy(vec!["/repo"], true)),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(!context1.equals_except_shell(&context2));
|
||||
}
|
||||
@@ -273,12 +293,14 @@ mod tests {
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(SandboxPolicy::new_read_only_policy()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(SandboxPolicy::new_workspace_write_policy()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(!context1.equals_except_shell(&context2));
|
||||
@@ -291,12 +313,14 @@ mod tests {
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp"], true)),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(!context1.equals_except_shell(&context2));
|
||||
@@ -312,6 +336,7 @@ mod tests {
|
||||
shell_path: "/bin/bash".into(),
|
||||
bashrc_path: "/home/user/.bashrc".into(),
|
||||
})),
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
@@ -321,6 +346,7 @@ mod tests {
|
||||
shell_path: "/bin/zsh".into(),
|
||||
zshrc_path: "/home/user/.zshrc".into(),
|
||||
})),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(context1.equals_except_shell(&context2));
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
use crate::truncate::truncate_middle;
|
||||
use codex_protocol::ConversationId;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use reqwest::StatusCode;
|
||||
@@ -12,6 +13,9 @@ use tokio::task::JoinError;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, CodexErr>;
|
||||
|
||||
/// Limit UI error messages to a reasonable size while keeping useful context.
|
||||
const ERROR_MESSAGE_UI_MAX_BYTES: usize = 2 * 1024; // 4 KiB
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SandboxErr {
|
||||
/// Error from sandbox execution
|
||||
@@ -55,6 +59,11 @@ pub enum CodexErr {
|
||||
#[error("stream disconnected before completion: {0}")]
|
||||
Stream(String, Option<Duration>),
|
||||
|
||||
#[error(
|
||||
"Codex ran out of room in the model's context window. Start a new conversation or clear earlier history before retrying."
|
||||
)]
|
||||
ContextWindowExceeded,
|
||||
|
||||
#[error("no conversation with id: {0}")]
|
||||
ConversationNotFound(ConversationId),
|
||||
|
||||
@@ -299,21 +308,44 @@ impl CodexErr {
|
||||
}
|
||||
|
||||
pub fn get_error_message_ui(e: &CodexErr) -> String {
|
||||
match e {
|
||||
CodexErr::Sandbox(SandboxErr::Denied { output }) => output.stderr.text.clone(),
|
||||
let message = match e {
|
||||
CodexErr::Sandbox(SandboxErr::Denied { output }) => {
|
||||
let aggregated = output.aggregated_output.text.trim();
|
||||
if !aggregated.is_empty() {
|
||||
output.aggregated_output.text.clone()
|
||||
} else {
|
||||
let stderr = output.stderr.text.trim();
|
||||
let stdout = output.stdout.text.trim();
|
||||
match (stderr.is_empty(), stdout.is_empty()) {
|
||||
(false, false) => format!("{stderr}\n{stdout}"),
|
||||
(false, true) => output.stderr.text.clone(),
|
||||
(true, false) => output.stdout.text.clone(),
|
||||
(true, true) => format!(
|
||||
"command failed inside sandbox with exit code {}",
|
||||
output.exit_code
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
// Timeouts are not sandbox errors from a UX perspective; present them plainly
|
||||
CodexErr::Sandbox(SandboxErr::Timeout { output }) => format!(
|
||||
"error: command timed out after {} ms",
|
||||
output.duration.as_millis()
|
||||
),
|
||||
CodexErr::Sandbox(SandboxErr::Timeout { output }) => {
|
||||
format!(
|
||||
"error: command timed out after {} ms",
|
||||
output.duration.as_millis()
|
||||
)
|
||||
}
|
||||
_ => e.to_string(),
|
||||
}
|
||||
};
|
||||
|
||||
truncate_middle(&message, ERROR_MESSAGE_UI_MAX_BYTES).0
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::exec::StreamOutput;
|
||||
use codex_protocol::protocol::RateLimitWindow;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn rate_limit_snapshot() -> RateLimitSnapshot {
|
||||
RateLimitSnapshot {
|
||||
@@ -343,6 +375,73 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_denied_uses_aggregated_output_when_stderr_empty() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 77,
|
||||
stdout: StreamOutput::new(String::new()),
|
||||
stderr: StreamOutput::new(String::new()),
|
||||
aggregated_output: StreamOutput::new("aggregate detail".to_string()),
|
||||
duration: Duration::from_millis(10),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
});
|
||||
assert_eq!(get_error_message_ui(&err), "aggregate detail");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_denied_reports_both_streams_when_available() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 9,
|
||||
stdout: StreamOutput::new("stdout detail".to_string()),
|
||||
stderr: StreamOutput::new("stderr detail".to_string()),
|
||||
aggregated_output: StreamOutput::new(String::new()),
|
||||
duration: Duration::from_millis(10),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
});
|
||||
assert_eq!(get_error_message_ui(&err), "stderr detail\nstdout detail");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_denied_reports_stdout_when_no_stderr() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 11,
|
||||
stdout: StreamOutput::new("stdout only".to_string()),
|
||||
stderr: StreamOutput::new(String::new()),
|
||||
aggregated_output: StreamOutput::new(String::new()),
|
||||
duration: Duration::from_millis(8),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
});
|
||||
assert_eq!(get_error_message_ui(&err), "stdout only");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_denied_reports_exit_code_when_no_output_available() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 13,
|
||||
stdout: StreamOutput::new(String::new()),
|
||||
stderr: StreamOutput::new(String::new()),
|
||||
aggregated_output: StreamOutput::new(String::new()),
|
||||
duration: Duration::from_millis(5),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
});
|
||||
assert_eq!(
|
||||
get_error_message_ui(&err),
|
||||
"command failed inside sandbox with exit code 13"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_free_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
|
||||
@@ -127,6 +127,7 @@ mod tests {
|
||||
use super::map_response_item_to_event_messages;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -158,7 +159,7 @@ mod tests {
|
||||
match &events[0] {
|
||||
EventMsg::UserMessage(user) => {
|
||||
assert_eq!(user.message, "Hello world");
|
||||
assert!(matches!(user.kind, Some(InputMessageKind::Plain)));
|
||||
assert_matches!(user.kind, Some(InputMessageKind::Plain));
|
||||
assert_eq!(user.images, Some(vec![img1, img2]));
|
||||
}
|
||||
other => panic!("expected UserMessage, got {other:?}"),
|
||||
|
||||
@@ -177,7 +177,7 @@ pub async fn process_exec_tool_call(
|
||||
}));
|
||||
}
|
||||
|
||||
if exit_code != 0 && is_likely_sandbox_denied(sandbox_type, exit_code) {
|
||||
if is_likely_sandbox_denied(sandbox_type, &exec_output) {
|
||||
return Err(CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(exec_output),
|
||||
}));
|
||||
@@ -195,21 +195,57 @@ pub async fn process_exec_tool_call(
|
||||
/// We don't have a fully deterministic way to tell if our command failed
|
||||
/// because of the sandbox - a command in the user's zshrc file might hit an
|
||||
/// error, but the command itself might fail or succeed for other reasons.
|
||||
/// For now, we conservatively check for 'command not found' (exit code 127),
|
||||
/// and can add additional cases as necessary.
|
||||
fn is_likely_sandbox_denied(sandbox_type: SandboxType, exit_code: i32) -> bool {
|
||||
if sandbox_type == SandboxType::None {
|
||||
/// For now, we conservatively check for well known command failure exit codes and
|
||||
/// also look for common sandbox denial keywords in the command output.
|
||||
fn is_likely_sandbox_denied(sandbox_type: SandboxType, exec_output: &ExecToolCallOutput) -> bool {
|
||||
if sandbox_type == SandboxType::None || exec_output.exit_code == 0 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Quick rejects: well-known non-sandbox shell exit codes
|
||||
// 127: command not found, 2: misuse of shell builtins
|
||||
if exit_code == 127 {
|
||||
// 2: misuse of shell builtins
|
||||
// 126: permission denied
|
||||
// 127: command not found
|
||||
const QUICK_REJECT_EXIT_CODES: [i32; 3] = [2, 126, 127];
|
||||
if QUICK_REJECT_EXIT_CODES.contains(&exec_output.exit_code) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// For all other cases, we assume the sandbox is the cause
|
||||
true
|
||||
const SANDBOX_DENIED_KEYWORDS: [&str; 6] = [
|
||||
"operation not permitted",
|
||||
"permission denied",
|
||||
"read-only file system",
|
||||
"seccomp",
|
||||
"sandbox",
|
||||
"landlock",
|
||||
];
|
||||
|
||||
if [
|
||||
&exec_output.stderr.text,
|
||||
&exec_output.stdout.text,
|
||||
&exec_output.aggregated_output.text,
|
||||
]
|
||||
.into_iter()
|
||||
.any(|section| {
|
||||
let lower = section.to_lowercase();
|
||||
SANDBOX_DENIED_KEYWORDS
|
||||
.iter()
|
||||
.any(|needle| lower.contains(needle))
|
||||
}) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
const SIGSYS_CODE: i32 = libc::SIGSYS;
|
||||
if sandbox_type == SandboxType::LinuxSeccomp
|
||||
&& exec_output.exit_code == EXIT_CODE_SIGNAL_BASE + SIGSYS_CODE
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -436,3 +472,77 @@ fn synthetic_exit_status(code: i32) -> ExitStatus {
|
||||
#[expect(clippy::unwrap_used)]
|
||||
std::process::ExitStatus::from_raw(code.try_into().unwrap())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
fn make_exec_output(
|
||||
exit_code: i32,
|
||||
stdout: &str,
|
||||
stderr: &str,
|
||||
aggregated: &str,
|
||||
) -> ExecToolCallOutput {
|
||||
ExecToolCallOutput {
|
||||
exit_code,
|
||||
stdout: StreamOutput::new(stdout.to_string()),
|
||||
stderr: StreamOutput::new(stderr.to_string()),
|
||||
aggregated_output: StreamOutput::new(aggregated.to_string()),
|
||||
duration: Duration::from_millis(1),
|
||||
timed_out: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_requires_keywords() {
|
||||
let output = make_exec_output(1, "", "", "");
|
||||
assert!(!is_likely_sandbox_denied(
|
||||
SandboxType::LinuxSeccomp,
|
||||
&output
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_identifies_keyword_in_stderr() {
|
||||
let output = make_exec_output(1, "", "Operation not permitted", "");
|
||||
assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_respects_quick_reject_exit_codes() {
|
||||
let output = make_exec_output(127, "", "command not found", "");
|
||||
assert!(!is_likely_sandbox_denied(
|
||||
SandboxType::LinuxSeccomp,
|
||||
&output
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_ignores_non_sandbox_mode() {
|
||||
let output = make_exec_output(1, "", "Operation not permitted", "");
|
||||
assert!(!is_likely_sandbox_denied(SandboxType::None, &output));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_detection_uses_aggregated_output() {
|
||||
let output = make_exec_output(
|
||||
101,
|
||||
"",
|
||||
"",
|
||||
"cargo failed: Read-only file system when writing target",
|
||||
);
|
||||
assert!(is_likely_sandbox_denied(
|
||||
SandboxType::MacosSeatbelt,
|
||||
&output
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn sandbox_detection_flags_sigsys_exit_code() {
|
||||
let exit_code = EXIT_CODE_SIGNAL_BASE + libc::SIGSYS;
|
||||
let output = make_exec_output(exit_code, "", "", "");
|
||||
assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -380,6 +380,23 @@ mod tests {
|
||||
assert_eq!(message, "failed in sandbox: sandbox stderr");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_failure_message_falls_back_to_aggregated_output() {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: 101,
|
||||
stdout: StreamOutput::new(String::new()),
|
||||
stderr: StreamOutput::new(String::new()),
|
||||
aggregated_output: StreamOutput::new("aggregate text".to_string()),
|
||||
duration: Duration::from_millis(10),
|
||||
timed_out: false,
|
||||
};
|
||||
let err = SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
};
|
||||
let message = sandbox_failure_message(err);
|
||||
assert_eq!(message, "failed in sandbox: aggregate text");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_function_error_synthesizes_payload() {
|
||||
let err = FunctionCallError::RespondToModel("boom".to_string());
|
||||
|
||||
177
codex-rs/core/src/file_change_notifier.rs
Normal file
177
codex-rs/core/src/file_change_notifier.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
use std::collections::HashSet;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use dunce::canonicalize;
|
||||
use notify::Config;
|
||||
use notify::Event;
|
||||
use notify::EventKind;
|
||||
use notify::RecommendedWatcher;
|
||||
use notify::RecursiveMode;
|
||||
use notify::Watcher;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct FileChangeCollector {
|
||||
root: Arc<PathBuf>,
|
||||
changes: Arc<Mutex<HashSet<PathBuf>>>,
|
||||
}
|
||||
|
||||
impl FileChangeCollector {
|
||||
fn new(root: PathBuf) -> Self {
|
||||
Self {
|
||||
root: Arc::new(root),
|
||||
changes: Arc::new(Mutex::new(HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
fn root_path(&self) -> &Path {
|
||||
self.root.as_ref()
|
||||
}
|
||||
|
||||
pub(crate) async fn record_event(&self, event: Event) {
|
||||
let Event { kind, paths, .. } = event;
|
||||
if should_skip(kind) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut normalized_paths = Vec::new();
|
||||
for path in paths {
|
||||
if let Some(rel) = self.normalize_path(&path) {
|
||||
normalized_paths.push(rel);
|
||||
}
|
||||
}
|
||||
|
||||
if normalized_paths.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut guard = self.changes.lock().await;
|
||||
guard.extend(normalized_paths);
|
||||
}
|
||||
|
||||
fn normalize_path(&self, path: &Path) -> Option<PathBuf> {
|
||||
let rel = path.strip_prefix(self.root_path()).ok()?;
|
||||
if should_ignore(rel) {
|
||||
return None;
|
||||
}
|
||||
Some(rel.to_path_buf())
|
||||
}
|
||||
|
||||
pub(crate) async fn drain(&self) -> Vec<PathBuf> {
|
||||
let mut guard = self.changes.lock().await;
|
||||
let mut entries: Vec<PathBuf> = guard.drain().collect();
|
||||
entries.sort();
|
||||
entries
|
||||
}
|
||||
}
|
||||
|
||||
fn should_skip(kind: EventKind) -> bool {
|
||||
matches!(
|
||||
kind,
|
||||
EventKind::Access(_) | EventKind::Other | EventKind::Any
|
||||
)
|
||||
}
|
||||
|
||||
fn should_ignore(path: &Path) -> bool {
|
||||
path.components().any(|component| {
|
||||
if let std::path::Component::Normal(part) = component {
|
||||
matches!(
|
||||
part.to_str(),
|
||||
Some(".git") | Some(".codex") | Some("target") | Some(".DS_Store") | Some(".idea")
|
||||
)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) struct FileChangeWatcher {
|
||||
_watcher: RecommendedWatcher,
|
||||
task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl FileChangeWatcher {
|
||||
pub(crate) fn start(root: PathBuf) -> Result<(Self, FileChangeCollector)> {
|
||||
let canonical_root = canonicalize(&root)
|
||||
.with_context(|| format!("failed to canonicalize {}", root.display()))?;
|
||||
let collector = FileChangeCollector::new(canonical_root.clone());
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
let mut watcher = RecommendedWatcher::new(
|
||||
move |res| {
|
||||
let _ = tx.send(res);
|
||||
},
|
||||
Config::default(),
|
||||
)
|
||||
.context("failed to initialize file watcher")?;
|
||||
watcher
|
||||
.watch(&canonical_root, RecursiveMode::Recursive)
|
||||
.with_context(|| format!("failed to watch path {}", canonical_root.display()))?;
|
||||
|
||||
let collector_clone = collector.clone();
|
||||
let task = tokio::spawn(async move {
|
||||
while let Some(event_result) = rx.recv().await {
|
||||
match event_result {
|
||||
Ok(event) => collector_clone.record_event(event).await,
|
||||
Err(err) => warn!("file watcher error: {err}"),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok((
|
||||
Self {
|
||||
_watcher: watcher,
|
||||
task,
|
||||
},
|
||||
collector,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for FileChangeWatcher {
|
||||
fn drop(&mut self) {
|
||||
if self.task.is_finished() {
|
||||
return;
|
||||
}
|
||||
self.task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn collector_for_tests(root: PathBuf) -> FileChangeCollector {
|
||||
FileChangeCollector::new(root)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl FileChangeCollector {
|
||||
pub(crate) async fn record_for_tests(&self, path: PathBuf) {
|
||||
let mut guard = self.changes.lock().await;
|
||||
guard.insert(path);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn drain_returns_empty_when_no_changes() {
|
||||
let collector = collector_for_tests(PathBuf::from("."));
|
||||
let drained = collector.drain().await;
|
||||
assert!(drained.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn record_and_drain_changes() {
|
||||
let collector = collector_for_tests(PathBuf::from("/tmp/root"));
|
||||
collector.record_for_tests(PathBuf::from("file.txt")).await;
|
||||
let drained = collector.drain().await;
|
||||
assert_eq!(drained, vec![PathBuf::from("file.txt")]);
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ pub use codex_conversation::CodexConversation;
|
||||
mod command_safety;
|
||||
pub mod config;
|
||||
pub mod config_edit;
|
||||
pub mod config_loader;
|
||||
pub mod config_profile;
|
||||
pub mod config_types;
|
||||
mod conversation_history;
|
||||
@@ -28,9 +29,11 @@ pub mod exec;
|
||||
mod exec_command;
|
||||
pub mod exec_env;
|
||||
pub mod executor;
|
||||
mod file_change_notifier;
|
||||
mod flags;
|
||||
pub mod git_info;
|
||||
pub mod landlock;
|
||||
pub mod mcp;
|
||||
mod mcp_connection_manager;
|
||||
mod mcp_tool_call;
|
||||
mod message_history;
|
||||
|
||||
58
codex-rs/core/src/mcp/auth.rs
Normal file
58
codex-rs/core/src/mcp/auth.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_protocol::protocol::McpAuthStatus;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::determine_streamable_http_auth_status;
|
||||
use futures::future::join_all;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::config_types::McpServerConfig;
|
||||
use crate::config_types::McpServerTransportConfig;
|
||||
|
||||
pub async fn compute_auth_statuses<'a, I>(
|
||||
servers: I,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> HashMap<String, McpAuthStatus>
|
||||
where
|
||||
I: IntoIterator<Item = (&'a String, &'a McpServerConfig)>,
|
||||
{
|
||||
let futures = servers.into_iter().map(|(name, config)| {
|
||||
let name = name.clone();
|
||||
let config = config.clone();
|
||||
async move {
|
||||
let status = match compute_auth_status(&name, &config, store_mode).await {
|
||||
Ok(status) => status,
|
||||
Err(error) => {
|
||||
warn!("failed to determine auth status for MCP server `{name}`: {error:?}");
|
||||
McpAuthStatus::Unsupported
|
||||
}
|
||||
};
|
||||
(name, status)
|
||||
}
|
||||
});
|
||||
|
||||
join_all(futures).await.into_iter().collect()
|
||||
}
|
||||
|
||||
async fn compute_auth_status(
|
||||
server_name: &str,
|
||||
config: &McpServerConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<McpAuthStatus> {
|
||||
match &config.transport {
|
||||
McpServerTransportConfig::Stdio { .. } => Ok(McpAuthStatus::Unsupported),
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
determine_streamable_http_auth_status(
|
||||
server_name,
|
||||
url,
|
||||
bearer_token_env_var.as_deref(),
|
||||
store_mode,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
1
codex-rs/core/src/mcp/mod.rs
Normal file
1
codex-rs/core/src/mcp/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod auth;
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -16,6 +17,7 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_mcp_client::McpClient;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
@@ -108,9 +110,6 @@ impl McpClientAdapter {
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
) -> Result<Self> {
|
||||
info!(
|
||||
"new_stdio_client use_rmcp_client: {use_rmcp_client} program: {program:?} args: {args:?} env: {env:?} params: {params:?} startup_timeout: {startup_timeout:?}"
|
||||
);
|
||||
if use_rmcp_client {
|
||||
let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
@@ -123,12 +122,17 @@ impl McpClientAdapter {
|
||||
}
|
||||
|
||||
async fn new_streamable_http_client(
|
||||
server_name: String,
|
||||
url: String,
|
||||
bearer_token: Option<String>,
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Self> {
|
||||
let client = Arc::new(RmcpClient::new_streamable_http_client(url, bearer_token)?);
|
||||
let client = Arc::new(
|
||||
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode)
|
||||
.await?,
|
||||
);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Rmcp(client))
|
||||
}
|
||||
@@ -182,6 +186,7 @@ impl McpConnectionManager {
|
||||
pub async fn new(
|
||||
mcp_servers: HashMap<String, McpServerConfig>,
|
||||
use_rmcp_client: bool,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<(Self, ClientStartErrors)> {
|
||||
// Early exit if no servers are configured.
|
||||
if mcp_servers.is_empty() {
|
||||
@@ -202,22 +207,21 @@ impl McpConnectionManager {
|
||||
continue;
|
||||
}
|
||||
|
||||
if matches!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp { .. }
|
||||
) && !use_rmcp_client
|
||||
{
|
||||
info!(
|
||||
"skipping MCP server `{}` configured with url because rmcp client is disabled",
|
||||
server_name
|
||||
);
|
||||
if !cfg.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
|
||||
let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
|
||||
|
||||
let use_rmcp_client_flag = use_rmcp_client;
|
||||
let resolved_bearer_token = match &cfg.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
bearer_token_env_var,
|
||||
..
|
||||
} => resolve_bearer_token(&server_name, bearer_token_env_var.as_deref()),
|
||||
_ => Ok(None),
|
||||
};
|
||||
|
||||
join_set.spawn(async move {
|
||||
let McpServerConfig { transport, .. } = cfg;
|
||||
let params = mcp_types::InitializeRequestParams {
|
||||
@@ -246,21 +250,23 @@ impl McpConnectionManager {
|
||||
let command_os: OsString = command.into();
|
||||
let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
|
||||
McpClientAdapter::new_stdio_client(
|
||||
use_rmcp_client_flag,
|
||||
use_rmcp_client,
|
||||
command_os,
|
||||
args_os,
|
||||
env,
|
||||
params.clone(),
|
||||
params,
|
||||
startup_timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp { url, .. } => {
|
||||
McpClientAdapter::new_streamable_http_client(
|
||||
server_name.clone(),
|
||||
url,
|
||||
bearer_token,
|
||||
resolved_bearer_token.unwrap_or_default(),
|
||||
params,
|
||||
startup_timeout,
|
||||
store_mode,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -348,6 +354,33 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_bearer_token(
|
||||
server_name: &str,
|
||||
bearer_token_env_var: Option<&str>,
|
||||
) -> Result<Option<String>> {
|
||||
let Some(env_var) = bearer_token_env_var else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
match env::var(env_var) {
|
||||
Ok(value) => {
|
||||
if value.is_empty() {
|
||||
Err(anyhow!(
|
||||
"Environment variable {env_var} for MCP server '{server_name}' is empty"
|
||||
))
|
||||
} else {
|
||||
Ok(Some(value))
|
||||
}
|
||||
}
|
||||
Err(env::VarError::NotPresent) => Err(anyhow!(
|
||||
"Environment variable {env_var} for MCP server '{server_name}' is not set"
|
||||
)),
|
||||
Err(env::VarError::NotUnicode(_)) => Err(anyhow!(
|
||||
"Environment variable {env_var} for MCP server '{server_name}' contains invalid Unicode"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Query every server for its available tools and return a single map that
|
||||
/// contains **all** tools. Each key is the fully-qualified name for the tool.
|
||||
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
|
||||
|
||||
@@ -35,12 +35,19 @@ pub struct ModelFamily {
|
||||
// See https://platform.openai.com/docs/guides/tools-local-shell
|
||||
pub uses_local_shell_tool: bool,
|
||||
|
||||
/// Whether this model supports parallel tool calls when using the
|
||||
/// Responses API.
|
||||
pub supports_parallel_tool_calls: bool,
|
||||
|
||||
/// Present if the model performs better when `apply_patch` is provided as
|
||||
/// a tool call instead of just a bash command
|
||||
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
|
||||
|
||||
// Instructions to use for querying the model
|
||||
pub base_instructions: String,
|
||||
|
||||
/// Names of beta tools that should be exposed to this model family.
|
||||
pub experimental_supported_tools: Vec<String>,
|
||||
}
|
||||
|
||||
macro_rules! model_family {
|
||||
@@ -55,8 +62,10 @@ macro_rules! model_family {
|
||||
supports_reasoning_summaries: false,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
uses_local_shell_tool: false,
|
||||
supports_parallel_tool_calls: false,
|
||||
apply_patch_tool_type: None,
|
||||
base_instructions: BASE_INSTRUCTIONS.to_string(),
|
||||
experimental_supported_tools: Vec::new(),
|
||||
};
|
||||
// apply overrides
|
||||
$(
|
||||
@@ -68,7 +77,11 @@ macro_rules! model_family {
|
||||
|
||||
/// Returns a `ModelFamily` for the given model slug, or `None` if the slug
|
||||
/// does not match any known model family.
|
||||
pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
pub fn find_family_for_model(mut slug: &str) -> Option<ModelFamily> {
|
||||
// TODO(jif) clean once we have proper feature flags
|
||||
if matches!(std::env::var("CODEX_EXPERIMENTAL").as_deref(), Ok("1")) {
|
||||
slug = "codex-experimental";
|
||||
}
|
||||
if slug.starts_with("o3") {
|
||||
model_family!(
|
||||
slug, "o3",
|
||||
@@ -99,12 +112,45 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
|
||||
} else if slug.starts_with("gpt-3.5") {
|
||||
model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
|
||||
} else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") {
|
||||
} else if slug.starts_with("test-gpt-5-codex") {
|
||||
model_family!(
|
||||
slug, slug,
|
||||
supports_reasoning_summaries: true,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
|
||||
experimental_supported_tools: vec![
|
||||
"grep_files".to_string(),
|
||||
"list_dir".to_string(),
|
||||
"read_file".to_string(),
|
||||
"test_sync_tool".to_string(),
|
||||
],
|
||||
supports_parallel_tool_calls: true,
|
||||
)
|
||||
|
||||
// Internal models.
|
||||
} else if slug.starts_with("codex-") {
|
||||
model_family!(
|
||||
slug, slug,
|
||||
supports_reasoning_summaries: true,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
|
||||
apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
|
||||
experimental_supported_tools: vec![
|
||||
"grep_files".to_string(),
|
||||
"list_dir".to_string(),
|
||||
"read_file".to_string(),
|
||||
],
|
||||
supports_parallel_tool_calls: true,
|
||||
)
|
||||
|
||||
// Production models.
|
||||
} else if slug.starts_with("gpt-5-codex") {
|
||||
model_family!(
|
||||
slug, slug,
|
||||
supports_reasoning_summaries: true,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
|
||||
apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
|
||||
)
|
||||
} else if slug.starts_with("gpt-5") {
|
||||
model_family!(
|
||||
@@ -125,7 +171,9 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily {
|
||||
supports_reasoning_summaries: false,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
uses_local_shell_tool: false,
|
||||
supports_parallel_tool_calls: false,
|
||||
apply_patch_tool_type: None,
|
||||
base_instructions: BASE_INSTRUCTIONS.to_string(),
|
||||
experimental_supported_tools: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use crate::RolloutRecorder;
|
||||
use crate::exec_command::ExecSessionManager;
|
||||
use crate::executor::Executor;
|
||||
use crate::file_change_notifier::FileChangeCollector;
|
||||
use crate::file_change_notifier::FileChangeWatcher;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_notification::UserNotifier;
|
||||
@@ -15,4 +17,6 @@ pub(crate) struct SessionServices {
|
||||
pub(crate) user_shell: crate::shell::Shell,
|
||||
pub(crate) show_raw_agent_reasoning: bool,
|
||||
pub(crate) executor: Executor,
|
||||
pub(crate) file_change_collector: Option<FileChangeCollector>,
|
||||
pub(crate) _file_change_watcher: Option<FileChangeWatcher>,
|
||||
}
|
||||
|
||||
@@ -64,5 +64,14 @@ impl SessionState {
|
||||
(self.token_info.clone(), self.latest_rate_limits.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn set_token_usage_full(&mut self, context_window: u64) {
|
||||
match &mut self.token_info {
|
||||
Some(info) => info.fill_to_context_window(context_window),
|
||||
None => {
|
||||
self.token_info = Some(TokenUsageInfo::full_context_window(context_window));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pending input/approval moved to TurnState.
|
||||
}
|
||||
|
||||
@@ -14,12 +14,17 @@ use mcp_types::CallToolResult;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub struct ToolInvocation<'a> {
|
||||
pub session: &'a Session,
|
||||
pub turn: &'a TurnContext,
|
||||
pub tracker: &'a mut TurnDiffTracker,
|
||||
pub sub_id: &'a str,
|
||||
pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolInvocation {
|
||||
pub session: Arc<Session>,
|
||||
pub turn: Arc<TurnContext>,
|
||||
pub tracker: SharedTurnDiffTracker,
|
||||
pub sub_id: String,
|
||||
pub call_id: String,
|
||||
pub tool_name: String,
|
||||
pub payload: ToolPayload,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::client_common::tools::FreeformTool;
|
||||
use crate::client_common::tools::FreeformToolFormat;
|
||||
@@ -36,10 +37,7 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
@@ -79,10 +77,10 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
let content = handle_container_exec_with_params(
|
||||
tool_name.as_str(),
|
||||
exec_params,
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
sub_id.to_string(),
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
Arc::clone(&tracker),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
@@ -106,7 +104,7 @@ pub enum ApplyPatchToolType {
|
||||
pub(crate) fn create_apply_patch_freeform_tool() -> ToolSpec {
|
||||
ToolSpec::Freeform(FreeformTool {
|
||||
name: "apply_patch".to_string(),
|
||||
description: "Use the `apply_patch` tool to edit files".to_string(),
|
||||
description: "Use the `apply_patch` tool to edit files. This is a FREEFORM tool, so do not wrap the patch in JSON.".to_string(),
|
||||
format: FreeformToolFormat {
|
||||
r#type: "grammar".to_string(),
|
||||
syntax: "lark".to_string(),
|
||||
|
||||
@@ -19,10 +19,7 @@ impl ToolHandler for ExecStreamHandler {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
tool_name,
|
||||
|
||||
272
codex-rs/core/src/tools/handlers/grep_files.rs
Normal file
272
codex-rs/core/src/tools/handlers/grep_files.rs
Normal file
@@ -0,0 +1,272 @@
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
|
||||
pub struct GrepFilesHandler;
|
||||
|
||||
const DEFAULT_LIMIT: usize = 100;
|
||||
const MAX_LIMIT: usize = 2000;
|
||||
const COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
fn default_limit() -> usize {
|
||||
DEFAULT_LIMIT
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct GrepFilesArgs {
|
||||
pattern: String,
|
||||
#[serde(default)]
|
||||
include: Option<String>,
|
||||
#[serde(default)]
|
||||
path: Option<String>,
|
||||
#[serde(default = "default_limit")]
|
||||
limit: usize,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for GrepFilesHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation { payload, turn, .. } = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
ToolPayload::Function { arguments } => arguments,
|
||||
_ => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"grep_files handler received unsupported payload".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let args: GrepFilesArgs = serde_json::from_str(&arguments).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to parse function arguments: {err:?}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let pattern = args.pattern.trim();
|
||||
if pattern.is_empty() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"pattern must not be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if args.limit == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"limit must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let limit = args.limit.min(MAX_LIMIT);
|
||||
let search_path = turn.resolve_path(args.path.clone());
|
||||
|
||||
verify_path_exists(&search_path).await?;
|
||||
|
||||
let include = args.include.as_deref().map(str::trim).and_then(|val| {
|
||||
if val.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(val.to_string())
|
||||
}
|
||||
});
|
||||
|
||||
let search_results =
|
||||
run_rg_search(pattern, include.as_deref(), &search_path, limit, &turn.cwd).await?;
|
||||
|
||||
if search_results.is_empty() {
|
||||
Ok(ToolOutput::Function {
|
||||
content: "No matches found.".to_string(),
|
||||
success: Some(false),
|
||||
})
|
||||
} else {
|
||||
Ok(ToolOutput::Function {
|
||||
content: search_results.join("\n"),
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn verify_path_exists(path: &Path) -> Result<(), FunctionCallError> {
|
||||
tokio::fs::metadata(path).await.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("unable to access `{}`: {err}", path.display()))
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_rg_search(
|
||||
pattern: &str,
|
||||
include: Option<&str>,
|
||||
search_path: &Path,
|
||||
limit: usize,
|
||||
cwd: &Path,
|
||||
) -> Result<Vec<String>, FunctionCallError> {
|
||||
let mut command = Command::new("rg");
|
||||
command
|
||||
.current_dir(cwd)
|
||||
.arg("--files-with-matches")
|
||||
.arg("--sortr=modified")
|
||||
.arg("--regexp")
|
||||
.arg(pattern)
|
||||
.arg("--no-messages");
|
||||
|
||||
if let Some(glob) = include {
|
||||
command.arg("--glob").arg(glob);
|
||||
}
|
||||
|
||||
command.arg("--").arg(search_path);
|
||||
|
||||
let output = timeout(COMMAND_TIMEOUT, command.output())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
FunctionCallError::RespondToModel("rg timed out after 30 seconds".to_string())
|
||||
})?
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to launch rg: {err}. Ensure ripgrep is installed and on PATH."
|
||||
))
|
||||
})?;
|
||||
|
||||
match output.status.code() {
|
||||
Some(0) => Ok(parse_results(&output.stdout, limit)),
|
||||
Some(1) => Ok(Vec::new()),
|
||||
_ => {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
Err(FunctionCallError::RespondToModel(format!(
|
||||
"rg failed: {stderr}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_results(stdout: &[u8], limit: usize) -> Vec<String> {
|
||||
let mut results = Vec::new();
|
||||
for line in stdout.split(|byte| *byte == b'\n') {
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(text) = std::str::from_utf8(line) {
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
results.push(text.to_string());
|
||||
if results.len() == limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::process::Command as StdCommand;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn parses_basic_results() {
|
||||
let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n";
|
||||
let parsed = parse_results(stdout, 10);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_truncates_after_limit() {
|
||||
let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n/tmp/file_c.rs\n";
|
||||
let parsed = parse_results(stdout, 2);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_search_returns_results() -> anyhow::Result<()> {
|
||||
if !rg_available() {
|
||||
return Ok(());
|
||||
}
|
||||
let temp = tempdir().expect("create temp dir");
|
||||
let dir = temp.path();
|
||||
std::fs::write(dir.join("match_one.txt"), "alpha beta gamma").unwrap();
|
||||
std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap();
|
||||
std::fs::write(dir.join("other.txt"), "omega").unwrap();
|
||||
|
||||
let results = run_rg_search("alpha", None, dir, 10, dir).await?;
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results.iter().any(|path| path.ends_with("match_one.txt")));
|
||||
assert!(results.iter().any(|path| path.ends_with("match_two.txt")));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_search_with_glob_filter() -> anyhow::Result<()> {
|
||||
if !rg_available() {
|
||||
return Ok(());
|
||||
}
|
||||
let temp = tempdir().expect("create temp dir");
|
||||
let dir = temp.path();
|
||||
std::fs::write(dir.join("match_one.rs"), "alpha beta gamma").unwrap();
|
||||
std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap();
|
||||
|
||||
let results = run_rg_search("alpha", Some("*.rs"), dir, 10, dir).await?;
|
||||
assert_eq!(results.len(), 1);
|
||||
assert!(results.iter().all(|path| path.ends_with("match_one.rs")));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_search_respects_limit() -> anyhow::Result<()> {
|
||||
if !rg_available() {
|
||||
return Ok(());
|
||||
}
|
||||
let temp = tempdir().expect("create temp dir");
|
||||
let dir = temp.path();
|
||||
std::fs::write(dir.join("one.txt"), "alpha one").unwrap();
|
||||
std::fs::write(dir.join("two.txt"), "alpha two").unwrap();
|
||||
std::fs::write(dir.join("three.txt"), "alpha three").unwrap();
|
||||
|
||||
let results = run_rg_search("alpha", None, dir, 2, dir).await?;
|
||||
assert_eq!(results.len(), 2);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_search_handles_no_matches() -> anyhow::Result<()> {
|
||||
if !rg_available() {
|
||||
return Ok(());
|
||||
}
|
||||
let temp = tempdir().expect("create temp dir");
|
||||
let dir = temp.path();
|
||||
std::fs::write(dir.join("one.txt"), "omega").unwrap();
|
||||
|
||||
let results = run_rg_search("alpha", None, dir, 5, dir).await?;
|
||||
assert!(results.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rg_available() -> bool {
|
||||
StdCommand::new("rg")
|
||||
.arg("--version")
|
||||
.output()
|
||||
.map(|output| output.status.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
476
codex-rs/core/src/tools/handlers/list_dir.rs
Normal file
476
codex-rs/core/src/tools/handlers/list_dir.rs
Normal file
@@ -0,0 +1,476 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::ffi::OsStr;
|
||||
use std::fs::FileType;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use codex_utils_string::take_bytes_at_char_boundary;
|
||||
use serde::Deserialize;
|
||||
use tokio::fs;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
|
||||
pub struct ListDirHandler;
|
||||
|
||||
const MAX_ENTRY_LENGTH: usize = 500;
|
||||
const INDENTATION_SPACES: usize = 2;
|
||||
|
||||
fn default_offset() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn default_limit() -> usize {
|
||||
25
|
||||
}
|
||||
|
||||
fn default_depth() -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ListDirArgs {
|
||||
dir_path: String,
|
||||
#[serde(default = "default_offset")]
|
||||
offset: usize,
|
||||
#[serde(default = "default_limit")]
|
||||
limit: usize,
|
||||
#[serde(default = "default_depth")]
|
||||
depth: usize,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for ListDirHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation { payload, .. } = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
ToolPayload::Function { arguments } => arguments,
|
||||
_ => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"list_dir handler received unsupported payload".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let args: ListDirArgs = serde_json::from_str(&arguments).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to parse function arguments: {err:?}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let ListDirArgs {
|
||||
dir_path,
|
||||
offset,
|
||||
limit,
|
||||
depth,
|
||||
} = args;
|
||||
|
||||
if offset == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"offset must be a 1-indexed entry number".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if limit == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"limit must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if depth == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"depth must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let path = PathBuf::from(&dir_path);
|
||||
if !path.is_absolute() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"dir_path must be an absolute path".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let entries = list_dir_slice(&path, offset, limit, depth).await?;
|
||||
let mut output = Vec::with_capacity(entries.len() + 1);
|
||||
output.push(format!("Absolute path: {}", path.display()));
|
||||
output.extend(entries);
|
||||
Ok(ToolOutput::Function {
|
||||
content: output.join("\n"),
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_dir_slice(
|
||||
path: &Path,
|
||||
offset: usize,
|
||||
limit: usize,
|
||||
depth: usize,
|
||||
) -> Result<Vec<String>, FunctionCallError> {
|
||||
let mut entries = Vec::new();
|
||||
collect_entries(path, Path::new(""), depth, &mut entries).await?;
|
||||
|
||||
if entries.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let start_index = offset - 1;
|
||||
if start_index >= entries.len() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"offset exceeds directory entry count".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let remaining_entries = entries.len() - start_index;
|
||||
let capped_limit = limit.min(remaining_entries);
|
||||
let end_index = start_index + capped_limit;
|
||||
let mut selected_entries = entries[start_index..end_index].to_vec();
|
||||
selected_entries.sort_unstable_by(|a, b| a.name.cmp(&b.name));
|
||||
let mut formatted = Vec::with_capacity(selected_entries.len());
|
||||
|
||||
for entry in &selected_entries {
|
||||
formatted.push(format_entry_line(entry));
|
||||
}
|
||||
|
||||
if end_index < entries.len() {
|
||||
formatted.push(format!("More than {capped_limit} entries found"));
|
||||
}
|
||||
|
||||
Ok(formatted)
|
||||
}
|
||||
|
||||
async fn collect_entries(
|
||||
dir_path: &Path,
|
||||
relative_prefix: &Path,
|
||||
depth: usize,
|
||||
entries: &mut Vec<DirEntry>,
|
||||
) -> Result<(), FunctionCallError> {
|
||||
let mut queue = VecDeque::new();
|
||||
queue.push_back((dir_path.to_path_buf(), relative_prefix.to_path_buf(), depth));
|
||||
|
||||
while let Some((current_dir, prefix, remaining_depth)) = queue.pop_front() {
|
||||
let mut read_dir = fs::read_dir(¤t_dir).await.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("failed to read directory: {err}"))
|
||||
})?;
|
||||
|
||||
let mut dir_entries = Vec::new();
|
||||
|
||||
while let Some(entry) = read_dir.next_entry().await.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("failed to read directory: {err}"))
|
||||
})? {
|
||||
let file_type = entry.file_type().await.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("failed to inspect entry: {err}"))
|
||||
})?;
|
||||
|
||||
let file_name = entry.file_name();
|
||||
let relative_path = if prefix.as_os_str().is_empty() {
|
||||
PathBuf::from(&file_name)
|
||||
} else {
|
||||
prefix.join(&file_name)
|
||||
};
|
||||
|
||||
let display_name = format_entry_component(&file_name);
|
||||
let display_depth = prefix.components().count();
|
||||
let sort_key = format_entry_name(&relative_path);
|
||||
let kind = DirEntryKind::from(&file_type);
|
||||
dir_entries.push((
|
||||
entry.path(),
|
||||
relative_path,
|
||||
kind,
|
||||
DirEntry {
|
||||
name: sort_key,
|
||||
display_name,
|
||||
depth: display_depth,
|
||||
kind,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
dir_entries.sort_unstable_by(|a, b| a.3.name.cmp(&b.3.name));
|
||||
|
||||
for (entry_path, relative_path, kind, dir_entry) in dir_entries {
|
||||
if kind == DirEntryKind::Directory && remaining_depth > 1 {
|
||||
queue.push_back((entry_path, relative_path, remaining_depth - 1));
|
||||
}
|
||||
entries.push(dir_entry);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn format_entry_name(path: &Path) -> String {
|
||||
let normalized = path.to_string_lossy().replace("\\", "/");
|
||||
if normalized.len() > MAX_ENTRY_LENGTH {
|
||||
take_bytes_at_char_boundary(&normalized, MAX_ENTRY_LENGTH).to_string()
|
||||
} else {
|
||||
normalized
|
||||
}
|
||||
}
|
||||
|
||||
fn format_entry_component(name: &OsStr) -> String {
|
||||
let normalized = name.to_string_lossy();
|
||||
if normalized.len() > MAX_ENTRY_LENGTH {
|
||||
take_bytes_at_char_boundary(&normalized, MAX_ENTRY_LENGTH).to_string()
|
||||
} else {
|
||||
normalized.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn format_entry_line(entry: &DirEntry) -> String {
|
||||
let indent = " ".repeat(entry.depth * INDENTATION_SPACES);
|
||||
let mut name = entry.display_name.clone();
|
||||
match entry.kind {
|
||||
DirEntryKind::Directory => name.push('/'),
|
||||
DirEntryKind::Symlink => name.push('@'),
|
||||
DirEntryKind::Other => name.push('?'),
|
||||
DirEntryKind::File => {}
|
||||
}
|
||||
format!("{indent}{name}")
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DirEntry {
|
||||
name: String,
|
||||
display_name: String,
|
||||
depth: usize,
|
||||
kind: DirEntryKind,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
enum DirEntryKind {
|
||||
Directory,
|
||||
File,
|
||||
Symlink,
|
||||
Other,
|
||||
}
|
||||
|
||||
impl From<&FileType> for DirEntryKind {
|
||||
fn from(file_type: &FileType) -> Self {
|
||||
if file_type.is_symlink() {
|
||||
DirEntryKind::Symlink
|
||||
} else if file_type.is_dir() {
|
||||
DirEntryKind::Directory
|
||||
} else if file_type.is_file() {
|
||||
DirEntryKind::File
|
||||
} else {
|
||||
DirEntryKind::Other
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn lists_directory_entries() {
|
||||
let temp = tempdir().expect("create tempdir");
|
||||
let dir_path = temp.path();
|
||||
|
||||
let sub_dir = dir_path.join("nested");
|
||||
tokio::fs::create_dir(&sub_dir)
|
||||
.await
|
||||
.expect("create sub dir");
|
||||
|
||||
let deeper_dir = sub_dir.join("deeper");
|
||||
tokio::fs::create_dir(&deeper_dir)
|
||||
.await
|
||||
.expect("create deeper dir");
|
||||
|
||||
tokio::fs::write(dir_path.join("entry.txt"), b"content")
|
||||
.await
|
||||
.expect("write file");
|
||||
tokio::fs::write(sub_dir.join("child.txt"), b"child")
|
||||
.await
|
||||
.expect("write child");
|
||||
tokio::fs::write(deeper_dir.join("grandchild.txt"), b"grandchild")
|
||||
.await
|
||||
.expect("write grandchild");
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::symlink;
|
||||
let link_path = dir_path.join("link");
|
||||
symlink(dir_path.join("entry.txt"), &link_path).expect("create symlink");
|
||||
}
|
||||
|
||||
let entries = list_dir_slice(dir_path, 1, 20, 3)
|
||||
.await
|
||||
.expect("list directory");
|
||||
|
||||
#[cfg(unix)]
|
||||
let expected = vec![
|
||||
"entry.txt".to_string(),
|
||||
"link@".to_string(),
|
||||
"nested/".to_string(),
|
||||
" child.txt".to_string(),
|
||||
" deeper/".to_string(),
|
||||
" grandchild.txt".to_string(),
|
||||
];
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let expected = vec![
|
||||
"entry.txt".to_string(),
|
||||
"nested/".to_string(),
|
||||
" child.txt".to_string(),
|
||||
" deeper/".to_string(),
|
||||
" grandchild.txt".to_string(),
|
||||
];
|
||||
|
||||
assert_eq!(entries, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn errors_when_offset_exceeds_entries() {
|
||||
let temp = tempdir().expect("create tempdir");
|
||||
let dir_path = temp.path();
|
||||
tokio::fs::create_dir(dir_path.join("nested"))
|
||||
.await
|
||||
.expect("create sub dir");
|
||||
|
||||
let err = list_dir_slice(dir_path, 10, 1, 2)
|
||||
.await
|
||||
.expect_err("offset exceeds entries");
|
||||
assert_eq!(
|
||||
err,
|
||||
FunctionCallError::RespondToModel("offset exceeds directory entry count".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn respects_depth_parameter() {
|
||||
let temp = tempdir().expect("create tempdir");
|
||||
let dir_path = temp.path();
|
||||
let nested = dir_path.join("nested");
|
||||
let deeper = nested.join("deeper");
|
||||
tokio::fs::create_dir(&nested).await.expect("create nested");
|
||||
tokio::fs::create_dir(&deeper).await.expect("create deeper");
|
||||
tokio::fs::write(dir_path.join("root.txt"), b"root")
|
||||
.await
|
||||
.expect("write root");
|
||||
tokio::fs::write(nested.join("child.txt"), b"child")
|
||||
.await
|
||||
.expect("write nested");
|
||||
tokio::fs::write(deeper.join("grandchild.txt"), b"deep")
|
||||
.await
|
||||
.expect("write deeper");
|
||||
|
||||
let entries_depth_one = list_dir_slice(dir_path, 1, 10, 1)
|
||||
.await
|
||||
.expect("list depth 1");
|
||||
assert_eq!(
|
||||
entries_depth_one,
|
||||
vec!["nested/".to_string(), "root.txt".to_string(),]
|
||||
);
|
||||
|
||||
let entries_depth_two = list_dir_slice(dir_path, 1, 20, 2)
|
||||
.await
|
||||
.expect("list depth 2");
|
||||
assert_eq!(
|
||||
entries_depth_two,
|
||||
vec![
|
||||
"nested/".to_string(),
|
||||
" child.txt".to_string(),
|
||||
" deeper/".to_string(),
|
||||
"root.txt".to_string(),
|
||||
]
|
||||
);
|
||||
|
||||
let entries_depth_three = list_dir_slice(dir_path, 1, 30, 3)
|
||||
.await
|
||||
.expect("list depth 3");
|
||||
assert_eq!(
|
||||
entries_depth_three,
|
||||
vec![
|
||||
"nested/".to_string(),
|
||||
" child.txt".to_string(),
|
||||
" deeper/".to_string(),
|
||||
" grandchild.txt".to_string(),
|
||||
"root.txt".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handles_large_limit_without_overflow() {
|
||||
let temp = tempdir().expect("create tempdir");
|
||||
let dir_path = temp.path();
|
||||
tokio::fs::write(dir_path.join("alpha.txt"), b"alpha")
|
||||
.await
|
||||
.expect("write alpha");
|
||||
tokio::fs::write(dir_path.join("beta.txt"), b"beta")
|
||||
.await
|
||||
.expect("write beta");
|
||||
tokio::fs::write(dir_path.join("gamma.txt"), b"gamma")
|
||||
.await
|
||||
.expect("write gamma");
|
||||
|
||||
let entries = list_dir_slice(dir_path, 2, usize::MAX, 1)
|
||||
.await
|
||||
.expect("list without overflow");
|
||||
assert_eq!(
|
||||
entries,
|
||||
vec!["beta.txt".to_string(), "gamma.txt".to_string(),]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn indicates_truncated_results() {
|
||||
let temp = tempdir().expect("create tempdir");
|
||||
let dir_path = temp.path();
|
||||
|
||||
for idx in 0..40 {
|
||||
let file = dir_path.join(format!("file_{idx:02}.txt"));
|
||||
tokio::fs::write(file, b"content")
|
||||
.await
|
||||
.expect("write file");
|
||||
}
|
||||
|
||||
let entries = list_dir_slice(dir_path, 1, 25, 1)
|
||||
.await
|
||||
.expect("list directory");
|
||||
assert_eq!(entries.len(), 26);
|
||||
assert_eq!(
|
||||
entries.last(),
|
||||
Some(&"More than 25 entries found".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bfs_truncation() -> anyhow::Result<()> {
|
||||
let temp = tempdir()?;
|
||||
let dir_path = temp.path();
|
||||
let nested = dir_path.join("nested");
|
||||
let deeper = nested.join("deeper");
|
||||
tokio::fs::create_dir(&nested).await?;
|
||||
tokio::fs::create_dir(&deeper).await?;
|
||||
tokio::fs::write(dir_path.join("root.txt"), b"root").await?;
|
||||
tokio::fs::write(nested.join("child.txt"), b"child").await?;
|
||||
tokio::fs::write(deeper.join("grandchild.txt"), b"deep").await?;
|
||||
|
||||
let entries_depth_three = list_dir_slice(dir_path, 1, 3, 3).await?;
|
||||
assert_eq!(
|
||||
entries_depth_three,
|
||||
vec![
|
||||
"nested/".to_string(),
|
||||
" child.txt".to_string(),
|
||||
"root.txt".to_string(),
|
||||
"More than 3 entries found".to_string()
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -16,10 +16,7 @@ impl ToolHandler for McpHandler {
|
||||
ToolKind::Mcp
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
sub_id,
|
||||
@@ -45,8 +42,8 @@ impl ToolHandler for McpHandler {
|
||||
let arguments_str = raw_arguments;
|
||||
|
||||
let response = handle_mcp_tool_call(
|
||||
session,
|
||||
sub_id,
|
||||
session.as_ref(),
|
||||
&sub_id,
|
||||
call_id.clone(),
|
||||
server,
|
||||
tool,
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
pub mod apply_patch;
|
||||
mod exec_stream;
|
||||
mod grep_files;
|
||||
mod list_dir;
|
||||
mod mcp;
|
||||
mod plan;
|
||||
mod read_file;
|
||||
mod shell;
|
||||
mod test_sync;
|
||||
mod unified_exec;
|
||||
mod view_image;
|
||||
|
||||
@@ -11,9 +14,12 @@ pub use plan::PLAN_TOOL;
|
||||
|
||||
pub use apply_patch::ApplyPatchHandler;
|
||||
pub use exec_stream::ExecStreamHandler;
|
||||
pub use grep_files::GrepFilesHandler;
|
||||
pub use list_dir::ListDirHandler;
|
||||
pub use mcp::McpHandler;
|
||||
pub use plan::PlanHandler;
|
||||
pub use read_file::ReadFileHandler;
|
||||
pub use shell::ShellHandler;
|
||||
pub use test_sync::TestSyncHandler;
|
||||
pub use unified_exec::UnifiedExecHandler;
|
||||
pub use view_image::ViewImageHandler;
|
||||
|
||||
@@ -65,10 +65,7 @@ impl ToolHandler for PlanHandler {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
sub_id,
|
||||
@@ -86,7 +83,8 @@ impl ToolHandler for PlanHandler {
|
||||
}
|
||||
};
|
||||
|
||||
let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?;
|
||||
let content =
|
||||
handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,6 @@
|
||||
use async_trait::async_trait;
|
||||
use codex_protocol::models::ShellToolCallParams;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::exec::ExecParams;
|
||||
@@ -40,10 +41,7 @@ impl ToolHandler for ShellHandler {
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
@@ -62,14 +60,14 @@ impl ToolHandler for ShellHandler {
|
||||
"failed to parse function arguments: {e:?}"
|
||||
))
|
||||
})?;
|
||||
let exec_params = Self::to_exec_params(params, turn);
|
||||
let exec_params = Self::to_exec_params(params, turn.as_ref());
|
||||
let content = handle_container_exec_with_params(
|
||||
tool_name.as_str(),
|
||||
exec_params,
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
sub_id.to_string(),
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
Arc::clone(&tracker),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
@@ -79,14 +77,14 @@ impl ToolHandler for ShellHandler {
|
||||
})
|
||||
}
|
||||
ToolPayload::LocalShell { params } => {
|
||||
let exec_params = Self::to_exec_params(params, turn);
|
||||
let exec_params = Self::to_exec_params(params, turn.as_ref());
|
||||
let content = handle_container_exec_with_params(
|
||||
tool_name.as_str(),
|
||||
exec_params,
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
sub_id.to_string(),
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
Arc::clone(&tracker),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
158
codex-rs/core/src/tools/handlers/test_sync.rs
Normal file
158
codex-rs/core/src/tools/handlers/test_sync.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::Barrier;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
|
||||
pub struct TestSyncHandler;
|
||||
|
||||
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
|
||||
|
||||
static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();
|
||||
|
||||
struct BarrierState {
|
||||
barrier: Arc<Barrier>,
|
||||
participants: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BarrierArgs {
|
||||
id: String,
|
||||
participants: usize,
|
||||
#[serde(default = "default_timeout_ms")]
|
||||
timeout_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TestSyncArgs {
|
||||
#[serde(default)]
|
||||
sleep_before_ms: Option<u64>,
|
||||
#[serde(default)]
|
||||
sleep_after_ms: Option<u64>,
|
||||
#[serde(default)]
|
||||
barrier: Option<BarrierArgs>,
|
||||
}
|
||||
|
||||
fn default_timeout_ms() -> u64 {
|
||||
DEFAULT_TIMEOUT_MS
|
||||
}
|
||||
|
||||
fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
|
||||
BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for TestSyncHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation { payload, .. } = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
ToolPayload::Function { arguments } => arguments,
|
||||
_ => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"test_sync_tool handler received unsupported payload".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to parse function arguments: {err:?}"
|
||||
))
|
||||
})?;
|
||||
|
||||
if let Some(delay) = args.sleep_before_ms
|
||||
&& delay > 0
|
||||
{
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
|
||||
if let Some(barrier) = args.barrier {
|
||||
wait_on_barrier(barrier).await?;
|
||||
}
|
||||
|
||||
if let Some(delay) = args.sleep_after_ms
|
||||
&& delay > 0
|
||||
{
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content: "ok".to_string(),
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
|
||||
if args.participants == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"barrier participants must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if args.timeout_ms == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"barrier timeout must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let barrier_id = args.id.clone();
|
||||
let barrier = {
|
||||
let mut map = barrier_map().lock().await;
|
||||
match map.entry(barrier_id.clone()) {
|
||||
Entry::Occupied(entry) => {
|
||||
let state = entry.get();
|
||||
if state.participants != args.participants {
|
||||
let existing = state.participants;
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"barrier {barrier_id} already registered with {existing} participants"
|
||||
)));
|
||||
}
|
||||
state.barrier.clone()
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
let barrier = Arc::new(Barrier::new(args.participants));
|
||||
entry.insert(BarrierState {
|
||||
barrier: barrier.clone(),
|
||||
participants: args.participants,
|
||||
});
|
||||
barrier
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let timeout = Duration::from_millis(args.timeout_ms);
|
||||
let wait_result = tokio::time::timeout(timeout, barrier.wait())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
|
||||
})?;
|
||||
|
||||
if wait_result.is_leader() {
|
||||
let mut map = barrier_map().lock().await;
|
||||
if let Some(state) = map.get(&barrier_id)
|
||||
&& Arc::ptr_eq(&state.barrier, &barrier)
|
||||
{
|
||||
map.remove(&barrier_id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -33,10 +33,7 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session, payload, ..
|
||||
} = invocation;
|
||||
|
||||
@@ -26,10 +26,7 @@ impl ToolHandler for ViewImageHandler {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod context;
|
||||
pub(crate) mod handlers;
|
||||
pub mod parallel;
|
||||
pub mod registry;
|
||||
pub mod router;
|
||||
pub mod spec;
|
||||
@@ -21,7 +22,7 @@ use crate::executor::linkers::PreparedExec;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ApplyPatchCommandContext;
|
||||
use crate::tools::context::ExecCommandContext;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use codex_apply_patch::MaybeApplyPatchVerified;
|
||||
use codex_apply_patch::maybe_parse_apply_patch_verified;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
@@ -29,6 +30,7 @@ use codex_utils_string::take_bytes_at_char_boundary;
|
||||
use codex_utils_string::take_last_bytes_at_char_boundary;
|
||||
pub use router::ToolRouter;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
use tracing::trace;
|
||||
|
||||
// Model-formatting limits: clients get full streams; only content sent to the model is truncated.
|
||||
@@ -48,9 +50,9 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
|
||||
pub(crate) async fn handle_container_exec_with_params(
|
||||
tool_name: &str,
|
||||
params: ExecParams,
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
) -> Result<String, FunctionCallError> {
|
||||
@@ -68,7 +70,15 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
// check if this was a patch, and apply it if so
|
||||
let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) {
|
||||
MaybeApplyPatchVerified::Body(changes) => {
|
||||
match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await {
|
||||
match apply_patch::apply_patch(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
&sub_id,
|
||||
&call_id,
|
||||
changes,
|
||||
)
|
||||
.await
|
||||
{
|
||||
InternalApplyPatchInvocation::Output(item) => return item,
|
||||
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
|
||||
Some(apply_patch_exec)
|
||||
@@ -139,12 +149,13 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
|
||||
let output_result = sess
|
||||
.run_exec_with_events(
|
||||
turn_diff_tracker,
|
||||
turn_diff_tracker.clone(),
|
||||
prepared_exec,
|
||||
turn_context.approval_policy,
|
||||
)
|
||||
.await;
|
||||
|
||||
// always make sure to truncate the output if its length isn't controlled.
|
||||
match output_result {
|
||||
Ok(output) => {
|
||||
let ExecToolCallOutput { exit_code, .. } = &output;
|
||||
@@ -155,13 +166,16 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
Err(FunctionCallError::RespondToModel(content))
|
||||
}
|
||||
}
|
||||
Err(ExecError::Function(err)) => Err(err),
|
||||
Err(ExecError::Function(err)) => Err(truncate_function_error(err)),
|
||||
Err(ExecError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) => Err(
|
||||
FunctionCallError::RespondToModel(format_exec_output_apply_patch(&output)),
|
||||
),
|
||||
Err(ExecError::Codex(err)) => Err(FunctionCallError::RespondToModel(format!(
|
||||
"execution error: {err:?}"
|
||||
))),
|
||||
Err(ExecError::Codex(err)) => {
|
||||
let message = format!("execution error: {err:?}");
|
||||
Err(FunctionCallError::RespondToModel(format_exec_output(
|
||||
&message,
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,26 +220,42 @@ pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
|
||||
aggregated_output, ..
|
||||
} = exec_output;
|
||||
|
||||
// Head+tail truncation for the model: show the beginning and end with an elision.
|
||||
// Clients still receive full streams; only this formatted summary is capped.
|
||||
|
||||
let mut s = &aggregated_output.text;
|
||||
let prefixed_str: String;
|
||||
let content = aggregated_output.text.as_str();
|
||||
|
||||
if exec_output.timed_out {
|
||||
prefixed_str = format!(
|
||||
"command timed out after {} milliseconds\n",
|
||||
let prefixed = format!(
|
||||
"command timed out after {} milliseconds\n{content}",
|
||||
exec_output.duration.as_millis()
|
||||
) + s;
|
||||
s = &prefixed_str;
|
||||
);
|
||||
return format_exec_output(&prefixed);
|
||||
}
|
||||
|
||||
let total_lines = s.lines().count();
|
||||
if s.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES {
|
||||
return s.to_string();
|
||||
}
|
||||
format_exec_output(content)
|
||||
}
|
||||
|
||||
let segments: Vec<&str> = s.split_inclusive('\n').collect();
|
||||
fn truncate_function_error(err: FunctionCallError) -> FunctionCallError {
|
||||
match err {
|
||||
FunctionCallError::RespondToModel(msg) => {
|
||||
FunctionCallError::RespondToModel(format_exec_output(&msg))
|
||||
}
|
||||
FunctionCallError::Fatal(msg) => FunctionCallError::Fatal(format_exec_output(&msg)),
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
fn format_exec_output(content: &str) -> String {
|
||||
// Head+tail truncation for the model: show the beginning and end with an elision.
|
||||
// Clients still receive full streams; only this formatted summary is capped.
|
||||
let total_lines = content.lines().count();
|
||||
if content.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES {
|
||||
return content.to_string();
|
||||
}
|
||||
let output = truncate_formatted_exec_output(content, total_lines);
|
||||
format!("Total output lines: {total_lines}\n\n{output}")
|
||||
}
|
||||
|
||||
fn truncate_formatted_exec_output(content: &str, total_lines: usize) -> String {
|
||||
let segments: Vec<&str> = content.split_inclusive('\n').collect();
|
||||
let head_take = MODEL_FORMAT_HEAD_LINES.min(segments.len());
|
||||
let tail_take = MODEL_FORMAT_TAIL_LINES.min(segments.len().saturating_sub(head_take));
|
||||
let omitted = segments.len().saturating_sub(head_take + tail_take);
|
||||
@@ -236,9 +266,9 @@ pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
|
||||
.map(|segment| segment.len())
|
||||
.sum();
|
||||
let tail_slice_start: usize = if tail_take == 0 {
|
||||
s.len()
|
||||
content.len()
|
||||
} else {
|
||||
s.len()
|
||||
content.len()
|
||||
- segments
|
||||
.iter()
|
||||
.rev()
|
||||
@@ -260,9 +290,9 @@ pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
|
||||
head_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(marker.len());
|
||||
}
|
||||
|
||||
let head_slice = &s[..head_slice_end];
|
||||
let head_slice = &content[..head_slice_end];
|
||||
let head_part = take_bytes_at_char_boundary(head_slice, head_budget);
|
||||
let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(s.len()));
|
||||
let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(content.len()));
|
||||
|
||||
result.push_str(head_part);
|
||||
result.push_str(&marker);
|
||||
@@ -272,9 +302,86 @@ pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
|
||||
return result;
|
||||
}
|
||||
|
||||
let tail_slice = &s[tail_slice_start..];
|
||||
let tail_slice = &content[tail_slice_start..];
|
||||
let tail_part = take_last_bytes_at_char_boundary(tail_slice, remaining);
|
||||
result.push_str(tail_part);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use regex_lite::Regex;
|
||||
|
||||
fn assert_truncated_message_matches(message: &str, line: &str, total_lines: usize) {
|
||||
let pattern = truncated_message_pattern(line, total_lines);
|
||||
let regex = Regex::new(&pattern).unwrap_or_else(|err| {
|
||||
panic!("failed to compile regex {pattern}: {err}");
|
||||
});
|
||||
let captures = regex
|
||||
.captures(message)
|
||||
.unwrap_or_else(|| panic!("message failed to match pattern {pattern}: {message}"));
|
||||
let body = captures
|
||||
.name("body")
|
||||
.expect("missing body capture")
|
||||
.as_str();
|
||||
assert!(
|
||||
body.len() <= MODEL_FORMAT_MAX_BYTES,
|
||||
"body exceeds byte limit: {} bytes",
|
||||
body.len()
|
||||
);
|
||||
}
|
||||
|
||||
fn truncated_message_pattern(line: &str, total_lines: usize) -> String {
|
||||
let head_take = MODEL_FORMAT_HEAD_LINES.min(total_lines);
|
||||
let tail_take = MODEL_FORMAT_TAIL_LINES.min(total_lines.saturating_sub(head_take));
|
||||
let omitted = total_lines.saturating_sub(head_take + tail_take);
|
||||
let escaped_line = regex_lite::escape(line);
|
||||
format!(
|
||||
r"(?s)^Total output lines: {total_lines}\n\n(?P<body>{escaped_line}.*\n\[\.{{3}} omitted {omitted} of {total_lines} lines \.{{3}}]\n\n.*)$",
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_formatted_exec_output_truncates_large_error() {
|
||||
let line = "very long execution error line that should trigger truncation\n";
|
||||
let large_error = line.repeat(2_500); // way beyond both byte and line limits
|
||||
|
||||
let truncated = format_exec_output(&large_error);
|
||||
|
||||
let total_lines = large_error.lines().count();
|
||||
assert_truncated_message_matches(&truncated, line, total_lines);
|
||||
assert_ne!(truncated, large_error);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_function_error_trims_respond_to_model() {
|
||||
let line = "respond-to-model error that should be truncated\n";
|
||||
let huge = line.repeat(3_000);
|
||||
let total_lines = huge.lines().count();
|
||||
|
||||
let err = truncate_function_error(FunctionCallError::RespondToModel(huge));
|
||||
match err {
|
||||
FunctionCallError::RespondToModel(message) => {
|
||||
assert_truncated_message_matches(&message, line, total_lines);
|
||||
}
|
||||
other => panic!("unexpected error variant: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_function_error_trims_fatal() {
|
||||
let line = "fatal error output that should be truncated\n";
|
||||
let huge = line.repeat(3_000);
|
||||
let total_lines = huge.lines().count();
|
||||
|
||||
let err = truncate_function_error(FunctionCallError::Fatal(huge));
|
||||
match err {
|
||||
FunctionCallError::Fatal(message) => {
|
||||
assert_truncated_message_matches(&message, line, total_lines);
|
||||
}
|
||||
other => panic!("unexpected error variant: {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
80
codex-rs/core/src/tools/parallel.rs
Normal file
80
codex-rs/core/src/tools/parallel.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::either::Either;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
|
||||
pub(crate) struct ToolCallRuntime {
|
||||
router: Arc<ToolRouter>,
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
parallel_execution: Arc<RwLock<()>>,
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
pub(crate) fn new(
|
||||
router: Arc<ToolRouter>,
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
router,
|
||||
session,
|
||||
turn_context,
|
||||
tracker,
|
||||
sub_id,
|
||||
parallel_execution: Arc::new(RwLock::new(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_tool_call(
|
||||
&self,
|
||||
call: ToolCall,
|
||||
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||
|
||||
let router = Arc::clone(&self.router);
|
||||
let session = Arc::clone(&self.session);
|
||||
let turn = Arc::clone(&self.turn_context);
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let sub_id = self.sub_id.clone();
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
|
||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
Either::Right(lock.write().await)
|
||||
};
|
||||
|
||||
router
|
||||
.dispatch_tool_call(session, turn, tracker, sub_id, call)
|
||||
.await
|
||||
}));
|
||||
|
||||
async move {
|
||||
match handle.await {
|
||||
Ok(Ok(response)) => Ok(response),
|
||||
Ok(Err(FunctionCallError::Fatal(message))) => Err(CodexErr::Fatal(message)),
|
||||
Ok(Err(other)) => Err(CodexErr::Fatal(other.to_string())),
|
||||
Err(err) => Err(CodexErr::Fatal(format!(
|
||||
"tool task failed to receive: {err:?}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -32,8 +32,7 @@ pub trait ToolHandler: Send + Sync {
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation<'_>)
|
||||
-> Result<ToolOutput, FunctionCallError>;
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
|
||||
}
|
||||
|
||||
pub struct ToolRegistry {
|
||||
@@ -57,9 +56,9 @@ impl ToolRegistry {
|
||||
// }
|
||||
// }
|
||||
|
||||
pub async fn dispatch<'a>(
|
||||
pub async fn dispatch(
|
||||
&self,
|
||||
invocation: ToolInvocation<'a>,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<ResponseInputItem, FunctionCallError> {
|
||||
let tool_name = invocation.tool_name.clone();
|
||||
let call_id_owned = invocation.call_id.clone();
|
||||
@@ -137,9 +136,24 @@ impl ToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConfiguredToolSpec {
|
||||
pub spec: ToolSpec,
|
||||
pub supports_parallel_tool_calls: bool,
|
||||
}
|
||||
|
||||
impl ConfiguredToolSpec {
|
||||
pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
|
||||
Self {
|
||||
spec,
|
||||
supports_parallel_tool_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolRegistryBuilder {
|
||||
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
||||
specs: Vec<ToolSpec>,
|
||||
specs: Vec<ConfiguredToolSpec>,
|
||||
}
|
||||
|
||||
impl ToolRegistryBuilder {
|
||||
@@ -151,7 +165,16 @@ impl ToolRegistryBuilder {
|
||||
}
|
||||
|
||||
pub fn push_spec(&mut self, spec: ToolSpec) {
|
||||
self.specs.push(spec);
|
||||
self.push_spec_with_parallel_support(spec, false);
|
||||
}
|
||||
|
||||
pub fn push_spec_with_parallel_support(
|
||||
&mut self,
|
||||
spec: ToolSpec,
|
||||
supports_parallel_tool_calls: bool,
|
||||
) {
|
||||
self.specs
|
||||
.push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
|
||||
}
|
||||
|
||||
pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
|
||||
@@ -183,7 +206,7 @@ impl ToolRegistryBuilder {
|
||||
// }
|
||||
// }
|
||||
|
||||
pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) {
|
||||
pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
|
||||
let registry = ToolRegistry::new(self.handlers);
|
||||
(self.specs, registry)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::client_common::tools::ToolSpec;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ConfiguredToolSpec;
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use crate::tools::spec::ToolsConfig;
|
||||
use crate::tools::spec::build_specs;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use codex_protocol::models::LocalShellAction;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -24,7 +26,7 @@ pub struct ToolCall {
|
||||
|
||||
pub struct ToolRouter {
|
||||
registry: ToolRegistry,
|
||||
specs: Vec<ToolSpec>,
|
||||
specs: Vec<ConfiguredToolSpec>,
|
||||
}
|
||||
|
||||
impl ToolRouter {
|
||||
@@ -34,11 +36,22 @@ impl ToolRouter {
|
||||
) -> Self {
|
||||
let builder = build_specs(config, mcp_tools);
|
||||
let (specs, registry) = builder.build();
|
||||
|
||||
Self { registry, specs }
|
||||
}
|
||||
|
||||
pub fn specs(&self) -> &[ToolSpec] {
|
||||
&self.specs
|
||||
pub fn specs(&self) -> Vec<ToolSpec> {
|
||||
self.specs
|
||||
.iter()
|
||||
.map(|config| config.spec.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
|
||||
self.specs
|
||||
.iter()
|
||||
.filter(|config| config.supports_parallel_tool_calls)
|
||||
.any(|config| config.spec.name() == tool_name)
|
||||
}
|
||||
|
||||
pub fn build_tool_call(
|
||||
@@ -118,10 +131,10 @@ impl ToolRouter {
|
||||
|
||||
pub async fn dispatch_tool_call(
|
||||
&self,
|
||||
session: &Session,
|
||||
turn: &TurnContext,
|
||||
tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
session: Arc<Session>,
|
||||
turn: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
call: ToolCall,
|
||||
) -> Result<ResponseInputItem, FunctionCallError> {
|
||||
let ToolCall {
|
||||
|
||||
@@ -28,6 +28,7 @@ pub(crate) struct ToolsConfig {
|
||||
pub web_search_request: bool,
|
||||
pub include_view_image_tool: bool,
|
||||
pub experimental_unified_exec_tool: bool,
|
||||
pub experimental_supported_tools: Vec<String>,
|
||||
}
|
||||
|
||||
pub(crate) struct ToolsConfigParams<'a> {
|
||||
@@ -78,6 +79,7 @@ impl ToolsConfig {
|
||||
web_search_request: *include_web_search_request,
|
||||
include_view_image_tool: *include_view_image_tool,
|
||||
experimental_unified_exec_tool: *experimental_unified_exec_tool,
|
||||
experimental_supported_tools: model_family.experimental_supported_tools.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -256,6 +258,118 @@ fn create_view_image_tool() -> ToolSpec {
|
||||
})
|
||||
}
|
||||
|
||||
fn create_test_sync_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"sleep_before_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("Optional delay in milliseconds before any other action".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"sleep_after_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Optional delay in milliseconds after completing the barrier".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
let mut barrier_properties = BTreeMap::new();
|
||||
barrier_properties.insert(
|
||||
"id".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Identifier shared by concurrent calls that should rendezvous".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
barrier_properties.insert(
|
||||
"participants".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Number of tool calls that must arrive before the barrier opens".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
barrier_properties.insert(
|
||||
"timeout_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("Maximum time in milliseconds to wait at the barrier".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
properties.insert(
|
||||
"barrier".to_string(),
|
||||
JsonSchema::Object {
|
||||
properties: barrier_properties,
|
||||
required: Some(vec!["id".to_string(), "participants".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "test_sync_tool".to_string(),
|
||||
description: "Internal synchronization helper used by Codex integration tests.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: None,
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_grep_files_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"pattern".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Regular expression pattern to search for.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"include".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Optional glob that limits which files are searched (e.g. \"*.rs\" or \
|
||||
\"*.{ts,tsx}\")."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"path".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Directory or file path to search. Defaults to the session's working directory."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"limit".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Maximum number of file paths to return (defaults to 100).".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "grep_files".to_string(),
|
||||
description: "Finds files whose contents match the pattern and lists them by modification \
|
||||
time."
|
||||
.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["pattern".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_read_file_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
@@ -278,11 +392,72 @@ fn create_read_file_tool() -> ToolSpec {
|
||||
description: Some("The maximum number of lines to return.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"mode".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Optional mode selector: \"slice\" for simple ranges (default) or \"indentation\" \
|
||||
to expand around an anchor line."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
let mut indentation_properties = BTreeMap::new();
|
||||
indentation_properties.insert(
|
||||
"anchor_line".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Anchor line to center the indentation lookup on (defaults to offset).".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
indentation_properties.insert(
|
||||
"max_levels".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"How many parent indentation levels (smaller indents) to include.".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
indentation_properties.insert(
|
||||
"include_siblings".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some(
|
||||
"When true, include additional blocks that share the anchor indentation."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
indentation_properties.insert(
|
||||
"include_header".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some(
|
||||
"Include doc comments or attributes directly above the selected block.".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
indentation_properties.insert(
|
||||
"max_lines".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Hard cap on the number of lines returned when using indentation mode.".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"indentation".to_string(),
|
||||
JsonSchema::Object {
|
||||
properties: indentation_properties,
|
||||
required: None,
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "read_file".to_string(),
|
||||
description:
|
||||
"Reads a local file with 1-indexed line numbers and returns up to the requested number of lines."
|
||||
"Reads a local file with 1-indexed line numbers, supporting slice and indentation-aware block modes."
|
||||
.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -292,6 +467,51 @@ fn create_read_file_tool() -> ToolSpec {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_list_dir_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"dir_path".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Absolute path to the directory to list.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"offset".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"The entry number to start listing from. Must be 1 or greater.".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"limit".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("The maximum number of entries to return.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"depth".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"The maximum directory depth to traverse. Must be 1 or greater.".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "list_dir".to_string(),
|
||||
description:
|
||||
"Lists entries in a local directory with 1-indexed entry numbers and simple type labels."
|
||||
.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["dir_path".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
/// TODO(dylan): deprecate once we get rid of json tool
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub(crate) struct ApplyPatchToolArgs {
|
||||
@@ -501,10 +721,13 @@ pub(crate) fn build_specs(
|
||||
use crate::exec_command::create_write_stdin_tool_for_responses_api;
|
||||
use crate::tools::handlers::ApplyPatchHandler;
|
||||
use crate::tools::handlers::ExecStreamHandler;
|
||||
use crate::tools::handlers::GrepFilesHandler;
|
||||
use crate::tools::handlers::ListDirHandler;
|
||||
use crate::tools::handlers::McpHandler;
|
||||
use crate::tools::handlers::PlanHandler;
|
||||
use crate::tools::handlers::ReadFileHandler;
|
||||
use crate::tools::handlers::ShellHandler;
|
||||
use crate::tools::handlers::TestSyncHandler;
|
||||
use crate::tools::handlers::UnifiedExecHandler;
|
||||
use crate::tools::handlers::ViewImageHandler;
|
||||
use std::sync::Arc;
|
||||
@@ -515,7 +738,6 @@ pub(crate) fn build_specs(
|
||||
let exec_stream_handler = Arc::new(ExecStreamHandler);
|
||||
let unified_exec_handler = Arc::new(UnifiedExecHandler);
|
||||
let plan_handler = Arc::new(PlanHandler);
|
||||
let read_file_handler = Arc::new(ReadFileHandler);
|
||||
let apply_patch_handler = Arc::new(ApplyPatchHandler);
|
||||
let view_image_handler = Arc::new(ViewImageHandler);
|
||||
let mcp_handler = Arc::new(McpHandler);
|
||||
@@ -566,15 +788,49 @@ pub(crate) fn build_specs(
|
||||
builder.register_handler("apply_patch", apply_patch_handler);
|
||||
}
|
||||
|
||||
builder.push_spec(create_read_file_tool());
|
||||
builder.register_handler("read_file", read_file_handler);
|
||||
if config
|
||||
.experimental_supported_tools
|
||||
.contains(&"grep_files".to_string())
|
||||
{
|
||||
let grep_files_handler = Arc::new(GrepFilesHandler);
|
||||
builder.push_spec_with_parallel_support(create_grep_files_tool(), true);
|
||||
builder.register_handler("grep_files", grep_files_handler);
|
||||
}
|
||||
|
||||
if config
|
||||
.experimental_supported_tools
|
||||
.contains(&"read_file".to_string())
|
||||
{
|
||||
let read_file_handler = Arc::new(ReadFileHandler);
|
||||
builder.push_spec_with_parallel_support(create_read_file_tool(), true);
|
||||
builder.register_handler("read_file", read_file_handler);
|
||||
}
|
||||
|
||||
if config
|
||||
.experimental_supported_tools
|
||||
.iter()
|
||||
.any(|tool| tool == "list_dir")
|
||||
{
|
||||
let list_dir_handler = Arc::new(ListDirHandler);
|
||||
builder.push_spec_with_parallel_support(create_list_dir_tool(), true);
|
||||
builder.register_handler("list_dir", list_dir_handler);
|
||||
}
|
||||
|
||||
if config
|
||||
.experimental_supported_tools
|
||||
.contains(&"test_sync_tool".to_string())
|
||||
{
|
||||
let test_sync_handler = Arc::new(TestSyncHandler);
|
||||
builder.push_spec_with_parallel_support(create_test_sync_tool(), true);
|
||||
builder.register_handler("test_sync_tool", test_sync_handler);
|
||||
}
|
||||
|
||||
if config.web_search_request {
|
||||
builder.push_spec(ToolSpec::WebSearch {});
|
||||
}
|
||||
|
||||
if config.include_view_image_tool {
|
||||
builder.push_spec(create_view_image_tool());
|
||||
builder.push_spec_with_parallel_support(create_view_image_tool(), true);
|
||||
builder.register_handler("view_image", view_image_handler);
|
||||
}
|
||||
|
||||
@@ -602,20 +858,25 @@ pub(crate) fn build_specs(
|
||||
mod tests {
|
||||
use crate::client_common::tools::FreeformTool;
|
||||
use crate::model_family::find_family_for_model;
|
||||
use crate::tools::registry::ConfiguredToolSpec;
|
||||
use mcp_types::ToolInputSchema;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn assert_eq_tool_names(tools: &[ToolSpec], expected_names: &[&str]) {
|
||||
fn tool_name(tool: &ToolSpec) -> &str {
|
||||
match tool {
|
||||
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
|
||||
ToolSpec::LocalShell {} => "local_shell",
|
||||
ToolSpec::WebSearch {} => "web_search",
|
||||
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_eq_tool_names(tools: &[ConfiguredToolSpec], expected_names: &[&str]) {
|
||||
let tool_names = tools
|
||||
.iter()
|
||||
.map(|tool| match tool {
|
||||
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
|
||||
ToolSpec::LocalShell {} => "local_shell",
|
||||
ToolSpec::WebSearch {} => "web_search",
|
||||
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
|
||||
})
|
||||
.map(|tool| tool_name(&tool.spec))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
@@ -631,6 +892,16 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn find_tool<'a>(
|
||||
tools: &'a [ConfiguredToolSpec],
|
||||
expected_name: &str,
|
||||
) -> &'a ConfiguredToolSpec {
|
||||
tools
|
||||
.iter()
|
||||
.find(|tool| tool_name(&tool.spec) == expected_name)
|
||||
.unwrap_or_else(|| panic!("expected tool {expected_name}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_specs() {
|
||||
let model_family = find_family_for_model("codex-mini-latest")
|
||||
@@ -648,13 +919,7 @@ mod tests {
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"update_plan",
|
||||
"read_file",
|
||||
"web_search",
|
||||
"view_image",
|
||||
],
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
);
|
||||
}
|
||||
|
||||
@@ -674,16 +939,65 @@ mod tests {
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"update_plan",
|
||||
"read_file",
|
||||
"web_search",
|
||||
"view_image",
|
||||
],
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_parallel_support_flags() {
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("codex-mini-latest should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: false,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, None).build();
|
||||
|
||||
assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls);
|
||||
assert!(find_tool(&tools, "grep_files").supports_parallel_tool_calls);
|
||||
assert!(find_tool(&tools, "list_dir").supports_parallel_tool_calls);
|
||||
assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_test_model_family_includes_sync_tool() {
|
||||
let model_family = find_family_for_model("test-gpt-5-codex")
|
||||
.expect("test-gpt-5-codex should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: false,
|
||||
experimental_unified_exec_tool: false,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, None).build();
|
||||
|
||||
assert!(
|
||||
tools
|
||||
.iter()
|
||||
.any(|tool| tool_name(&tool.spec) == "test_sync_tool")
|
||||
);
|
||||
assert!(
|
||||
tools
|
||||
.iter()
|
||||
.any(|tool| tool_name(&tool.spec) == "read_file")
|
||||
);
|
||||
assert!(
|
||||
tools
|
||||
.iter()
|
||||
.any(|tool| tool_name(&tool.spec) == "grep_files")
|
||||
);
|
||||
assert!(tools.iter().any(|tool| tool_name(&tool.spec) == "list_dir"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_specs_mcp_tools() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
@@ -739,7 +1053,6 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"read_file",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"test_server/do_something_cool",
|
||||
@@ -747,7 +1060,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[3].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "test_server/do_something_cool".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -858,7 +1171,6 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"read_file",
|
||||
"view_image",
|
||||
"test_server/cool",
|
||||
"test_server/do",
|
||||
@@ -869,7 +1181,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_mcp_tool_property_missing_type_defaults_to_string() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
@@ -908,7 +1221,7 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"read_file",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"dash/search",
|
||||
@@ -916,7 +1229,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/search".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -937,7 +1250,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_mcp_tool_integer_normalized_to_number() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
@@ -974,14 +1288,14 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"read_file",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"dash/paginate",
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/paginate".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1000,11 +1314,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_mcp_tool_array_without_items_gets_default_string_items() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_apply_patch_tool: true,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
@@ -1037,14 +1352,14 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"read_file",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"dash/tags",
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/tags".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1066,7 +1381,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_mcp_tool_anyof_defaults_to_string() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
@@ -1103,14 +1419,14 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"read_file",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"dash/value",
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/value".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1144,7 +1460,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_get_openai_tools_mcp_tools_with_additional_properties_schema() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("gpt-5-codex should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
@@ -1206,7 +1523,7 @@ mod tests {
|
||||
&tools,
|
||||
&[
|
||||
"unified_exec",
|
||||
"read_file",
|
||||
"apply_patch",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"test_server/do_something_cool",
|
||||
@@ -1214,7 +1531,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "test_server/do_something_cool".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
|
||||
@@ -110,11 +110,22 @@ impl ManagedUnifiedExecSession {
|
||||
let buffer_clone = Arc::clone(&output_buffer);
|
||||
let notify_clone = Arc::clone(&output_notify);
|
||||
let output_task = tokio::spawn(async move {
|
||||
while let Ok(chunk) = receiver.recv().await {
|
||||
let mut guard = buffer_clone.lock().await;
|
||||
guard.push_chunk(chunk);
|
||||
drop(guard);
|
||||
notify_clone.notify_waiters();
|
||||
loop {
|
||||
match receiver.recv().await {
|
||||
Ok(chunk) => {
|
||||
let mut guard = buffer_clone.lock().await;
|
||||
guard.push_chunk(chunk);
|
||||
drop(guard);
|
||||
notify_clone.notify_waiters();
|
||||
}
|
||||
// If we lag behind the broadcast buffer, skip missed
|
||||
// messages but keep the task alive to continue streaming.
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
|
||||
continue;
|
||||
}
|
||||
// When the sender closes, exit the task.
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use assert_matches::assert_matches;
|
||||
use std::sync::Arc;
|
||||
use tracing_test::traced_test;
|
||||
|
||||
@@ -178,7 +179,7 @@ async fn streams_text_without_reasoning() {
|
||||
other => panic!("expected terminal message, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[2], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[2], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -219,7 +220,7 @@ async fn streams_reasoning_from_string_delta() {
|
||||
other => panic!("expected message item, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[4], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[4], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -266,7 +267,7 @@ async fn streams_reasoning_from_object_delta() {
|
||||
other => panic!("expected message item, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[5], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[5], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -293,7 +294,7 @@ async fn streams_reasoning_from_final_message() {
|
||||
other => panic!("expected reasoning item, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[2], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[2], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -337,7 +338,7 @@ async fn streams_reasoning_before_tool_call() {
|
||||
other => panic!("expected function call, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[3], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[3], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -10,6 +10,7 @@ path = "lib.rs"
|
||||
anyhow = { workspace = true }
|
||||
assert_cmd = { workspace = true }
|
||||
codex-core = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tokio = { workspace = true, features = ["time"] }
|
||||
|
||||
@@ -6,6 +6,7 @@ use codex_core::CodexConversation;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::ConfigToml;
|
||||
use regex_lite::Regex;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
use assert_cmd::cargo::cargo_bin;
|
||||
@@ -14,6 +15,16 @@ pub mod responses;
|
||||
pub mod test_codex;
|
||||
pub mod test_codex_exec;
|
||||
|
||||
#[track_caller]
|
||||
pub fn assert_regex_match<'s>(pattern: &str, actual: &'s str) -> regex_lite::Captures<'s> {
|
||||
let regex = Regex::new(pattern).unwrap_or_else(|err| {
|
||||
panic!("failed to compile regex {pattern:?}: {err}");
|
||||
});
|
||||
regex
|
||||
.captures(actual)
|
||||
.unwrap_or_else(|| panic!("regex {pattern:?} did not match {actual:?}"))
|
||||
}
|
||||
|
||||
/// Returns a default `Config` whose on-disk state is confined to the provided
|
||||
/// temporary directory. Using a per-test directory keeps tests hermetic and
|
||||
/// avoids clobbering a developer’s real `~/.codex`.
|
||||
|
||||
@@ -1,11 +1,105 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use serde_json::Value;
|
||||
use wiremock::BodyPrintLimit;
|
||||
use wiremock::Match;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockBuilder;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::Respond;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
use wiremock::matchers::path_regex;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResponseMock {
|
||||
requests: Arc<Mutex<Vec<ResponsesRequest>>>,
|
||||
}
|
||||
|
||||
impl ResponseMock {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
requests: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn single_request(&self) -> ResponsesRequest {
|
||||
let requests = self.requests.lock().unwrap();
|
||||
if requests.len() != 1 {
|
||||
panic!("expected 1 request, got {}", requests.len());
|
||||
}
|
||||
requests.first().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn requests(&self) -> Vec<ResponsesRequest> {
|
||||
self.requests.lock().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResponsesRequest(wiremock::Request);
|
||||
|
||||
impl ResponsesRequest {
|
||||
pub fn body_json(&self) -> Value {
|
||||
self.0.body_json().unwrap()
|
||||
}
|
||||
|
||||
pub fn input(&self) -> Vec<Value> {
|
||||
self.0.body_json::<Value>().unwrap()["input"]
|
||||
.as_array()
|
||||
.expect("input array not found in request")
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn function_call_output(&self, call_id: &str) -> Value {
|
||||
self.call_output(call_id, "function_call_output")
|
||||
}
|
||||
|
||||
pub fn custom_tool_call_output(&self, call_id: &str) -> Value {
|
||||
self.call_output(call_id, "custom_tool_call_output")
|
||||
}
|
||||
|
||||
pub fn call_output(&self, call_id: &str, call_type: &str) -> Value {
|
||||
self.input()
|
||||
.iter()
|
||||
.find(|item| {
|
||||
item.get("type").unwrap() == call_type && item.get("call_id").unwrap() == call_id
|
||||
})
|
||||
.cloned()
|
||||
.unwrap_or_else(|| panic!("function call output {call_id} item not found in request"))
|
||||
}
|
||||
|
||||
pub fn header(&self, name: &str) -> Option<String> {
|
||||
self.0
|
||||
.headers
|
||||
.get(name)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::to_string)
|
||||
}
|
||||
|
||||
pub fn path(&self) -> String {
|
||||
self.0.url.path().to_string()
|
||||
}
|
||||
|
||||
pub fn query_param(&self, name: &str) -> Option<String> {
|
||||
self.0
|
||||
.url
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == name)
|
||||
.map(|(_, v)| v.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl Match for ResponseMock {
|
||||
fn matches(&self, request: &wiremock::Request) -> bool {
|
||||
self.requests
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push(ResponsesRequest(request.clone()));
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an SSE stream body from a list of JSON events.
|
||||
pub fn sse(events: Vec<Value>) -> String {
|
||||
@@ -34,6 +128,16 @@ pub fn ev_completed(id: &str) -> Value {
|
||||
})
|
||||
}
|
||||
|
||||
/// Convenience: SSE event for a created response with a specific id.
|
||||
pub fn ev_response_created(id: &str) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {
|
||||
"id": id,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn ev_completed_with_tokens(id: &str, total_tokens: u64) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.completed",
|
||||
@@ -135,40 +239,56 @@ pub fn ev_apply_patch_function_call(call_id: &str, patch: &str) -> Value {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn sse_failed(id: &str, code: &str, message: &str) -> String {
|
||||
sse(vec![serde_json::json!({
|
||||
"type": "response.failed",
|
||||
"response": {
|
||||
"id": id,
|
||||
"error": {"code": code, "message": message}
|
||||
}
|
||||
})])
|
||||
}
|
||||
|
||||
pub fn sse_response(body: String) -> ResponseTemplate {
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(body, "text/event-stream")
|
||||
}
|
||||
|
||||
pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String)
|
||||
fn base_mock() -> (MockBuilder, ResponseMock) {
|
||||
let response_mock = ResponseMock::new();
|
||||
let mock = Mock::given(method("POST"))
|
||||
.and(path_regex(".*/responses$"))
|
||||
.and(response_mock.clone());
|
||||
(mock, response_mock)
|
||||
}
|
||||
|
||||
pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String) -> ResponseMock
|
||||
where
|
||||
M: wiremock::Match + Send + Sync + 'static,
|
||||
{
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(matcher)
|
||||
let (mock, response_mock) = base_mock();
|
||||
mock.and(matcher)
|
||||
.respond_with(sse_response(body))
|
||||
.up_to_n_times(1)
|
||||
.mount(server)
|
||||
.await;
|
||||
response_mock
|
||||
}
|
||||
|
||||
pub async fn mount_sse_once(server: &MockServer, body: String) {
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(sse_response(body))
|
||||
.expect(1)
|
||||
pub async fn mount_sse_once(server: &MockServer, body: String) -> ResponseMock {
|
||||
let (mock, response_mock) = base_mock();
|
||||
mock.respond_with(sse_response(body))
|
||||
.up_to_n_times(1)
|
||||
.mount(server)
|
||||
.await;
|
||||
response_mock
|
||||
}
|
||||
|
||||
pub async fn mount_sse(server: &MockServer, body: String) {
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(sse_response(body))
|
||||
.mount(server)
|
||||
.await;
|
||||
pub async fn mount_sse(server: &MockServer, body: String) -> ResponseMock {
|
||||
let (mock, response_mock) = base_mock();
|
||||
mock.respond_with(sse_response(body)).mount(server).await;
|
||||
response_mock
|
||||
}
|
||||
|
||||
pub async fn start_mock_server() -> MockServer {
|
||||
@@ -181,7 +301,7 @@ pub async fn start_mock_server() -> MockServer {
|
||||
/// Mounts a sequence of SSE response bodies and serves them in order for each
|
||||
/// POST to `/v1/responses`. Panics if more requests are received than bodies
|
||||
/// provided. Also asserts the exact number of expected calls.
|
||||
pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) {
|
||||
pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) -> ResponseMock {
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
@@ -208,10 +328,11 @@ pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) {
|
||||
responses: bodies,
|
||||
};
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(responder)
|
||||
let (mock, response_mock) = base_mock();
|
||||
mock.respond_with(responder)
|
||||
.expect(num_calls as u64)
|
||||
.mount(server)
|
||||
.await;
|
||||
|
||||
response_mock
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use tempfile::TempDir;
|
||||
|
||||
use crate::load_default_config_for_test;
|
||||
|
||||
type ConfigMutator = dyn FnOnce(&mut Config);
|
||||
type ConfigMutator = dyn FnOnce(&mut Config) + Send;
|
||||
|
||||
pub struct TestCodexBuilder {
|
||||
config_mutators: Vec<Box<ConfigMutator>>,
|
||||
@@ -22,7 +22,7 @@ pub struct TestCodexBuilder {
|
||||
impl TestCodexBuilder {
|
||||
pub fn with_config<T>(mut self, mutator: T) -> Self
|
||||
where
|
||||
T: FnOnce(&mut Config) + 'static,
|
||||
T: FnOnce(&mut Config) + Send + 'static,
|
||||
{
|
||||
self.config_mutators.push(Box::new(mutator));
|
||||
self
|
||||
|
||||
@@ -3,14 +3,14 @@ use std::time::Duration;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::mount_sse_once_match;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use serde_json::json;
|
||||
use wiremock::matchers::body_string_contains;
|
||||
|
||||
/// Integration test: spawn a long‑running shell tool via a mocked Responses SSE
|
||||
/// function call, then interrupt the session and expect TurnAborted.
|
||||
@@ -27,10 +27,13 @@ async fn interrupt_long_running_tool_emits_turn_aborted() {
|
||||
"timeout_ms": 60_000
|
||||
})
|
||||
.to_string();
|
||||
let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]);
|
||||
let body = sse(vec![
|
||||
ev_function_call("call_sleep", "shell", &args),
|
||||
ev_completed("done"),
|
||||
]);
|
||||
|
||||
let server = start_mock_server().await;
|
||||
mount_sse_once_match(&server, body_string_contains("start sleep"), body).await;
|
||||
mount_sse_once(&server, body).await;
|
||||
|
||||
let codex = test_codex().build(&server).await.unwrap().codex;
|
||||
|
||||
|
||||
@@ -106,16 +106,12 @@ async fn exec_cli_applies_experimental_instructions_file() {
|
||||
"data: {\"type\":\"response.created\",\"response\":{}}\n\n",
|
||||
"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"r1\"}}\n\n"
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse, "text/event-stream"),
|
||||
)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
let resp_mock = core_test_support::responses::mount_sse_once_match(
|
||||
&server,
|
||||
path("/v1/responses"),
|
||||
sse.to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Create a temporary instructions file with a unique marker we can assert
|
||||
// appears in the outbound request payload.
|
||||
@@ -164,8 +160,8 @@ async fn exec_cli_applies_experimental_instructions_file() {
|
||||
|
||||
// Inspect the captured request and verify our custom base instructions were
|
||||
// included in the `instructions` field.
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let request = resp_mock.single_request();
|
||||
let body = request.body_json();
|
||||
let instructions = body
|
||||
.get("instructions")
|
||||
.and_then(|v| v.as_str())
|
||||
|
||||
@@ -14,6 +14,8 @@ use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::error::CodexErr;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
@@ -26,8 +28,10 @@ use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use futures::StreamExt;
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
@@ -37,6 +41,7 @@ use uuid::Uuid;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::body_string_contains;
|
||||
use wiremock::matchers::header_regex;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
@@ -218,15 +223,9 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
|
||||
// Mock server that will receive the resumed request
|
||||
let server = MockServer::start().await;
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
let resp_mock =
|
||||
responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1"))
|
||||
.await;
|
||||
|
||||
// Configure Codex to resume from our file
|
||||
let model_provider = ModelProviderInfo {
|
||||
@@ -272,8 +271,8 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let request = resp_mock.single_request();
|
||||
let request_body = request.body_json();
|
||||
let expected_input = json!([
|
||||
{
|
||||
"type": "message",
|
||||
@@ -367,18 +366,9 @@ async fn includes_base_instructions_override_in_request() {
|
||||
skip_if_no_network!();
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
let resp_mock =
|
||||
responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1"))
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
@@ -409,8 +399,8 @@ async fn includes_base_instructions_override_in_request() {
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let request = resp_mock.single_request();
|
||||
let request_body = request.body_json();
|
||||
|
||||
assert!(
|
||||
request_body["instructions"]
|
||||
@@ -565,16 +555,9 @@ async fn includes_user_instructions_message_in_request() {
|
||||
skip_if_no_network!();
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
let resp_mock =
|
||||
responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1"))
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
@@ -605,8 +588,8 @@ async fn includes_user_instructions_message_in_request() {
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let request = resp_mock.single_request();
|
||||
let request_body = request.body_json();
|
||||
|
||||
assert!(
|
||||
!request_body["instructions"]
|
||||
@@ -996,6 +979,100 @@ async fn usage_limit_error_emits_rate_limit_event() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
let server = MockServer::start().await;
|
||||
|
||||
responses::mount_sse_once_match(
|
||||
&server,
|
||||
body_string_contains("trigger context window"),
|
||||
responses::sse_failed(
|
||||
"resp_context_window",
|
||||
"context_length_exceeded",
|
||||
"Your input exceeds the context window of this model. Please adjust your input and try again.",
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
responses::mount_sse_once_match(
|
||||
&server,
|
||||
body_string_contains("seed turn"),
|
||||
sse_completed("resp_seed"),
|
||||
)
|
||||
.await;
|
||||
|
||||
let TestCodex { codex, .. } = test_codex()
|
||||
.with_config(|config| {
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("known gpt-5 model family");
|
||||
config.model_context_window = Some(272_000);
|
||||
})
|
||||
.build(&server)
|
||||
.await?;
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "seed turn".into(),
|
||||
}],
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "trigger context window".into(),
|
||||
}],
|
||||
})
|
||||
.await?;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
let token_event = wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|event| {
|
||||
matches!(
|
||||
event,
|
||||
EventMsg::TokenCount(payload)
|
||||
if payload.info.as_ref().is_some_and(|info| {
|
||||
info.model_context_window == Some(info.total_token_usage.total_tokens)
|
||||
&& info.total_token_usage.total_tokens > 0
|
||||
})
|
||||
)
|
||||
},
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
let EventMsg::TokenCount(token_payload) = token_event else {
|
||||
unreachable!("wait_for_event_with_timeout returned unexpected event");
|
||||
};
|
||||
|
||||
let info = token_payload
|
||||
.info
|
||||
.expect("token usage info present when context window is exceeded");
|
||||
|
||||
assert_eq!(info.model_context_window, Some(272_000));
|
||||
assert_eq!(info.total_token_usage.total_tokens, 272_000);
|
||||
|
||||
let error_event = wait_for_event(&codex, |ev| matches!(ev, EventMsg::Error(_))).await;
|
||||
let expected_context_window_message = CodexErr::ContextWindowExceeded.to_string();
|
||||
assert!(
|
||||
matches!(
|
||||
error_event,
|
||||
EventMsg::Error(ref err) if err.message == expected_context_window_message
|
||||
),
|
||||
"expected context window error; got {error_event:?}"
|
||||
);
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn azure_overrides_assign_properties_used_for_responses_url() {
|
||||
skip_if_no_network!();
|
||||
|
||||
@@ -13,12 +13,6 @@ use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::wait_for_event;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::Request;
|
||||
use wiremock::Respond;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
use codex_core::codex::compact::SUMMARIZATION_PROMPT;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
@@ -26,14 +20,11 @@ use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_completed_with_tokens;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::mount_sse_once_match;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::sse_response;
|
||||
use core_test_support::responses::sse_failed;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
// --- Test helpers -----------------------------------------------------------
|
||||
|
||||
pub(super) const FIRST_REPLY: &str = "FIRST_REPLY";
|
||||
@@ -48,6 +39,8 @@ const SECOND_LARGE_REPLY: &str = "SECOND_LARGE_REPLY";
|
||||
const FIRST_AUTO_SUMMARY: &str = "FIRST_AUTO_SUMMARY";
|
||||
const SECOND_AUTO_SUMMARY: &str = "SECOND_AUTO_SUMMARY";
|
||||
const FINAL_REPLY: &str = "FINAL_REPLY";
|
||||
const CONTEXT_LIMIT_MESSAGE: &str =
|
||||
"Your input exceeds the context window of this model. Please adjust your input and try again.";
|
||||
const DUMMY_FUNCTION_NAME: &str = "unsupported_tool";
|
||||
const DUMMY_CALL_ID: &str = "call-multi-auto";
|
||||
|
||||
@@ -295,12 +288,7 @@ async fn auto_compact_runs_after_token_limit_hit() {
|
||||
&& !body.contains(SECOND_AUTO_MSG)
|
||||
&& !body.contains("You have exceeded the maximum number of tokens")
|
||||
};
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(first_matcher)
|
||||
.respond_with(sse_response(sse1))
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_once_match(&server, first_matcher, sse1).await;
|
||||
|
||||
let second_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
@@ -308,23 +296,13 @@ async fn auto_compact_runs_after_token_limit_hit() {
|
||||
&& body.contains(FIRST_AUTO_MSG)
|
||||
&& !body.contains("You have exceeded the maximum number of tokens")
|
||||
};
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(second_matcher)
|
||||
.respond_with(sse_response(sse2))
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_once_match(&server, second_matcher, sse2).await;
|
||||
|
||||
let third_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains("You have exceeded the maximum number of tokens")
|
||||
};
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(third_matcher)
|
||||
.respond_with(sse_response(sse3))
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_once_match(&server, third_matcher, sse3).await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
@@ -455,12 +433,7 @@ async fn auto_compact_persists_rollout_entries() {
|
||||
&& !body.contains(SECOND_AUTO_MSG)
|
||||
&& !body.contains("You have exceeded the maximum number of tokens")
|
||||
};
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(first_matcher)
|
||||
.respond_with(sse_response(sse1))
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_once_match(&server, first_matcher, sse1).await;
|
||||
|
||||
let second_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
@@ -468,23 +441,13 @@ async fn auto_compact_persists_rollout_entries() {
|
||||
&& body.contains(FIRST_AUTO_MSG)
|
||||
&& !body.contains("You have exceeded the maximum number of tokens")
|
||||
};
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(second_matcher)
|
||||
.respond_with(sse_response(sse2))
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_once_match(&server, second_matcher, sse2).await;
|
||||
|
||||
let third_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains("You have exceeded the maximum number of tokens")
|
||||
};
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(third_matcher)
|
||||
.respond_with(sse_response(sse3))
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_once_match(&server, third_matcher, sse3).await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
@@ -582,35 +545,20 @@ async fn auto_compact_stops_after_failed_attempt() {
|
||||
body.contains(FIRST_AUTO_MSG)
|
||||
&& !body.contains("You have exceeded the maximum number of tokens")
|
||||
};
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(first_matcher)
|
||||
.respond_with(sse_response(sse1.clone()))
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_once_match(&server, first_matcher, sse1.clone()).await;
|
||||
|
||||
let second_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains("You have exceeded the maximum number of tokens")
|
||||
};
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(second_matcher)
|
||||
.respond_with(sse_response(sse2.clone()))
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_once_match(&server, second_matcher, sse2.clone()).await;
|
||||
|
||||
let third_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
!body.contains("You have exceeded the maximum number of tokens")
|
||||
&& body.contains(SUMMARY_TEXT)
|
||||
};
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(third_matcher)
|
||||
.respond_with(sse_response(sse3.clone()))
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_once_match(&server, third_matcher, sse3.clone()).await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
@@ -677,6 +625,130 @@ async fn auto_compact_stops_after_failed_attempt() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn manual_compact_retries_after_context_window_error() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let user_turn = sse(vec![
|
||||
ev_assistant_message("m1", FIRST_REPLY),
|
||||
ev_completed("r1"),
|
||||
]);
|
||||
let compact_failed = sse_failed(
|
||||
"resp-fail",
|
||||
"context_length_exceeded",
|
||||
CONTEXT_LIMIT_MESSAGE,
|
||||
);
|
||||
let compact_succeeds = sse(vec![
|
||||
ev_assistant_message("m2", SUMMARY_TEXT),
|
||||
ev_completed("r2"),
|
||||
]);
|
||||
|
||||
let request_log = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
user_turn.clone(),
|
||||
compact_failed.clone(),
|
||||
compact_succeeds.clone(),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&home);
|
||||
config.model_provider = model_provider;
|
||||
config.model_auto_compact_token_limit = Some(200_000);
|
||||
let codex = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"))
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.unwrap()
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "first turn".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
codex.submit(Op::Compact).await.unwrap();
|
||||
|
||||
let EventMsg::BackgroundEvent(event) =
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::BackgroundEvent(_))).await
|
||||
else {
|
||||
panic!("expected background event after compact retry");
|
||||
};
|
||||
assert!(
|
||||
event.message.contains("Trimmed 1 older conversation item"),
|
||||
"background event should mention trimmed item count: {}",
|
||||
event.message
|
||||
);
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = request_log.requests();
|
||||
assert_eq!(
|
||||
requests.len(),
|
||||
3,
|
||||
"expected user turn and two compact attempts"
|
||||
);
|
||||
|
||||
let compact_attempt = requests[1].body_json();
|
||||
let retry_attempt = requests[2].body_json();
|
||||
|
||||
let compact_input = compact_attempt["input"]
|
||||
.as_array()
|
||||
.unwrap_or_else(|| panic!("compact attempt missing input array: {compact_attempt}"));
|
||||
let retry_input = retry_attempt["input"]
|
||||
.as_array()
|
||||
.unwrap_or_else(|| panic!("retry attempt missing input array: {retry_attempt}"));
|
||||
assert_eq!(
|
||||
compact_input
|
||||
.last()
|
||||
.and_then(|item| item.get("content"))
|
||||
.and_then(|v| v.as_array())
|
||||
.and_then(|items| items.first())
|
||||
.and_then(|entry| entry.get("text"))
|
||||
.and_then(|text| text.as_str()),
|
||||
Some(SUMMARIZATION_PROMPT),
|
||||
"compact attempt should include summarization prompt"
|
||||
);
|
||||
assert_eq!(
|
||||
retry_input
|
||||
.last()
|
||||
.and_then(|item| item.get("content"))
|
||||
.and_then(|v| v.as_array())
|
||||
.and_then(|items| items.first())
|
||||
.and_then(|entry| entry.get("text"))
|
||||
.and_then(|text| text.as_str()),
|
||||
Some(SUMMARIZATION_PROMPT),
|
||||
"retry attempt should include summarization prompt"
|
||||
);
|
||||
assert_eq!(
|
||||
retry_input.len(),
|
||||
compact_input.len().saturating_sub(1),
|
||||
"retry should drop exactly one history item (before {} vs after {})",
|
||||
compact_input.len(),
|
||||
retry_input.len()
|
||||
);
|
||||
if let (Some(first_before), Some(first_after)) = (compact_input.first(), retry_input.first()) {
|
||||
assert_ne!(
|
||||
first_before, first_after,
|
||||
"retry should drop the oldest conversation item"
|
||||
);
|
||||
} else {
|
||||
panic!("expected non-empty compact inputs");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_events() {
|
||||
skip_if_no_network!();
|
||||
@@ -708,49 +780,7 @@ async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_
|
||||
ev_completed_with_tokens("r6", 120),
|
||||
]);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SeqResponder {
|
||||
bodies: Arc<Vec<String>>,
|
||||
calls: Arc<AtomicUsize>,
|
||||
requests: Arc<Mutex<Vec<Vec<u8>>>>,
|
||||
}
|
||||
|
||||
impl SeqResponder {
|
||||
fn new(bodies: Vec<String>) -> Self {
|
||||
Self {
|
||||
bodies: Arc::new(bodies),
|
||||
calls: Arc::new(AtomicUsize::new(0)),
|
||||
requests: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
fn recorded_requests(&self) -> Vec<Vec<u8>> {
|
||||
self.requests.lock().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, req: &Request) -> ResponseTemplate {
|
||||
let idx = self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
self.requests.lock().unwrap().push(req.body.clone());
|
||||
let body = self
|
||||
.bodies
|
||||
.get(idx)
|
||||
.unwrap_or_else(|| panic!("unexpected request index {idx}"))
|
||||
.clone();
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(body, "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
let responder = SeqResponder::new(vec![sse1, sse2, sse3, sse4, sse5, sse6]);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(responder.clone())
|
||||
.expect(6)
|
||||
.mount(&server)
|
||||
.await;
|
||||
mount_sse_sequence(&server, vec![sse1, sse2, sse3, sse4, sse5, sse6]).await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
@@ -801,10 +831,12 @@ async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_
|
||||
"auto compact should not emit task lifecycle events"
|
||||
);
|
||||
|
||||
let request_bodies: Vec<String> = responder
|
||||
.recorded_requests()
|
||||
let request_bodies: Vec<String> = server
|
||||
.received_requests()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|body| String::from_utf8(body).unwrap_or_default())
|
||||
.map(|request| String::from_utf8(request.body).unwrap_or_default())
|
||||
.collect();
|
||||
assert_eq!(
|
||||
request_bodies.len(),
|
||||
|
||||
@@ -17,6 +17,7 @@ use codex_core::NewConversation;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::codex::compact::SUMMARIZATION_PROMPT;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::OPENAI_DEFAULT_MODEL;
|
||||
use codex_core::protocol::ConversationPathResponseEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
@@ -131,9 +132,10 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
let expected_model = OPENAI_DEFAULT_MODEL;
|
||||
let user_turn_1 = json!(
|
||||
{
|
||||
"model": "gpt-5-codex",
|
||||
"model": expected_model,
|
||||
"instructions": prompt,
|
||||
"input": [
|
||||
{
|
||||
@@ -182,7 +184,7 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
|
||||
});
|
||||
let compact_1 = json!(
|
||||
{
|
||||
"model": "gpt-5-codex",
|
||||
"model": expected_model,
|
||||
"instructions": prompt,
|
||||
"input": [
|
||||
{
|
||||
@@ -251,7 +253,7 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
|
||||
});
|
||||
let user_turn_2_after_compact = json!(
|
||||
{
|
||||
"model": "gpt-5-codex",
|
||||
"model": expected_model,
|
||||
"instructions": prompt,
|
||||
"input": [
|
||||
{
|
||||
@@ -316,7 +318,7 @@ SUMMARY_ONLY_CONTEXT"
|
||||
});
|
||||
let usert_turn_3_after_resume = json!(
|
||||
{
|
||||
"model": "gpt-5-codex",
|
||||
"model": expected_model,
|
||||
"instructions": prompt,
|
||||
"input": [
|
||||
{
|
||||
@@ -401,7 +403,7 @@ SUMMARY_ONLY_CONTEXT"
|
||||
});
|
||||
let user_turn_3_after_fork = json!(
|
||||
{
|
||||
"model": "gpt-5-codex",
|
||||
"model": expected_model,
|
||||
"instructions": prompt,
|
||||
"input": [
|
||||
{
|
||||
|
||||
237
codex-rs/core/tests/suite/grep_files.rs
Normal file
237
codex-rs/core/tests/suite/grep_files.rs
Normal file
@@ -0,0 +1,237 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
use std::path::Path;
|
||||
use std::process::Command as StdCommand;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
const MODEL_WITH_TOOL: &str = "test-gpt-5-codex";
|
||||
|
||||
fn ripgrep_available() -> bool {
|
||||
StdCommand::new("rg")
|
||||
.arg("--version")
|
||||
.output()
|
||||
.map(|output| output.status.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
macro_rules! skip_if_ripgrep_missing {
|
||||
($ret:expr $(,)?) => {{
|
||||
if !ripgrep_available() {
|
||||
eprintln!("rg not available in PATH; skipping test");
|
||||
return $ret;
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn grep_files_tool_collects_matches() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_ripgrep_missing!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let test = build_test_codex(&server).await?;
|
||||
|
||||
let search_dir = test.cwd.path().join("src");
|
||||
std::fs::create_dir_all(&search_dir)?;
|
||||
let alpha = search_dir.join("alpha.rs");
|
||||
let beta = search_dir.join("beta.rs");
|
||||
let gamma = search_dir.join("gamma.txt");
|
||||
std::fs::write(&alpha, "alpha needle\n")?;
|
||||
std::fs::write(&beta, "beta needle\n")?;
|
||||
std::fs::write(&gamma, "needle in text but excluded\n")?;
|
||||
|
||||
let call_id = "grep-files-collect";
|
||||
let arguments = serde_json::json!({
|
||||
"pattern": "needle",
|
||||
"path": search_dir.to_string_lossy(),
|
||||
"include": "*.rs",
|
||||
})
|
||||
.to_string();
|
||||
|
||||
mount_tool_sequence(&server, call_id, &arguments, "grep_files").await;
|
||||
submit_turn(&test, "please find uses of needle").await?;
|
||||
|
||||
let bodies = recorded_bodies(&server).await?;
|
||||
let tool_output = find_tool_output(&bodies, call_id).expect("tool output present");
|
||||
let payload = tool_output.get("output").expect("output field present");
|
||||
let (content_opt, success_opt) = extract_content_and_success(payload);
|
||||
let content = content_opt.expect("content present");
|
||||
let success = success_opt.unwrap_or(true);
|
||||
assert!(success, "expected success for matches, got {payload:?}");
|
||||
|
||||
let entries = collect_file_names(content);
|
||||
assert_eq!(entries.len(), 2, "content: {content}");
|
||||
assert!(
|
||||
entries.contains("alpha.rs"),
|
||||
"missing alpha.rs in {entries:?}"
|
||||
);
|
||||
assert!(
|
||||
entries.contains("beta.rs"),
|
||||
"missing beta.rs in {entries:?}"
|
||||
);
|
||||
assert!(
|
||||
!entries.contains("gamma.txt"),
|
||||
"txt file should be filtered out: {entries:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn grep_files_tool_reports_empty_results() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_ripgrep_missing!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let test = build_test_codex(&server).await?;
|
||||
|
||||
let search_dir = test.cwd.path().join("logs");
|
||||
std::fs::create_dir_all(&search_dir)?;
|
||||
std::fs::write(search_dir.join("output.txt"), "no hits here")?;
|
||||
|
||||
let call_id = "grep-files-empty";
|
||||
let arguments = serde_json::json!({
|
||||
"pattern": "needle",
|
||||
"path": search_dir.to_string_lossy(),
|
||||
"limit": 5,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
mount_tool_sequence(&server, call_id, &arguments, "grep_files").await;
|
||||
submit_turn(&test, "search again").await?;
|
||||
|
||||
let bodies = recorded_bodies(&server).await?;
|
||||
let tool_output = find_tool_output(&bodies, call_id).expect("tool output present");
|
||||
let payload = tool_output.get("output").expect("output field present");
|
||||
let (content_opt, success_opt) = extract_content_and_success(payload);
|
||||
let content = content_opt.expect("content present");
|
||||
if let Some(success) = success_opt {
|
||||
assert!(!success, "expected success=false payload: {payload:?}");
|
||||
}
|
||||
assert_eq!(content, "No matches found.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
async fn build_test_codex(server: &wiremock::MockServer) -> Result<TestCodex> {
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = MODEL_WITH_TOOL.to_string();
|
||||
config.model_family =
|
||||
find_family_for_model(MODEL_WITH_TOOL).expect("model family for test model");
|
||||
});
|
||||
builder.build(server).await
|
||||
}
|
||||
|
||||
async fn submit_turn(test: &TestCodex, prompt: &str) -> Result<()> {
|
||||
let session_model = test.session_configured.model.clone();
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: prompt.into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: test.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TaskComplete(_))
|
||||
})
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn mount_tool_sequence(
|
||||
server: &wiremock::MockServer,
|
||||
call_id: &str,
|
||||
arguments: &str,
|
||||
tool_name: &str,
|
||||
) {
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, tool_name, arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(server, any(), second_response).await;
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
async fn recorded_bodies(server: &wiremock::MockServer) -> Result<Vec<Value>> {
|
||||
let requests = server.received_requests().await.expect("requests recorded");
|
||||
Ok(requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn find_tool_output<'a>(requests: &'a [Value], call_id: &str) -> Option<&'a Value> {
|
||||
requests.iter().find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn collect_file_names(content: &str) -> HashSet<String> {
|
||||
content
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
if line.trim().is_empty() {
|
||||
return None;
|
||||
}
|
||||
Path::new(line)
|
||||
.file_name()
|
||||
.map(|name| name.to_string_lossy().into_owned())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_content_and_success(value: &Value) -> (Option<&str>, Option<bool>) {
|
||||
match value {
|
||||
Value::String(text) => (Some(text.as_str()), None),
|
||||
Value::Object(obj) => (
|
||||
obj.get("content").and_then(Value::as_str),
|
||||
obj.get("success").and_then(Value::as_bool),
|
||||
),
|
||||
_ => (None, None),
|
||||
}
|
||||
}
|
||||
460
codex-rs/core/tests/suite/list_dir.rs
Normal file
460
codex-rs/core/tests/suite/list_dir.rs
Normal file
@@ -0,0 +1,460 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "disabled until we enable list_dir tool"]
|
||||
async fn list_dir_tool_returns_entries() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = test_codex().build(&server).await?;
|
||||
|
||||
let dir_path = cwd.path().join("sample_dir");
|
||||
std::fs::create_dir(&dir_path)?;
|
||||
std::fs::write(dir_path.join("alpha.txt"), "first file")?;
|
||||
std::fs::create_dir(dir_path.join("nested"))?;
|
||||
let dir_path = dir_path.to_string_lossy().to_string();
|
||||
|
||||
let call_id = "list-dir-call";
|
||||
let arguments = serde_json::json!({
|
||||
"dir_path": dir_path,
|
||||
"offset": 1,
|
||||
"limit": 2,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "list_dir", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "list directory contents".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
assert!(
|
||||
!request_bodies.is_empty(),
|
||||
"expected at least one request body"
|
||||
);
|
||||
|
||||
let tool_output_item = request_bodies
|
||||
.iter()
|
||||
.find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!("function_call_output item not found in requests: {request_bodies:#?}")
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
tool_output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
let output_text = tool_output_item
|
||||
.get("output")
|
||||
.and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
Value::Object(obj) => obj.get("content").and_then(Value::as_str),
|
||||
_ => None,
|
||||
})
|
||||
.expect("output text present");
|
||||
assert_eq!(output_text, "E1: [file] alpha.txt\nE2: [dir] nested");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "disabled until we enable list_dir tool"]
|
||||
async fn list_dir_tool_depth_one_omits_children() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = test_codex().build(&server).await?;
|
||||
|
||||
let dir_path = cwd.path().join("depth_one");
|
||||
std::fs::create_dir(&dir_path)?;
|
||||
std::fs::write(dir_path.join("alpha.txt"), "alpha")?;
|
||||
std::fs::create_dir(dir_path.join("nested"))?;
|
||||
std::fs::write(dir_path.join("nested").join("beta.txt"), "beta")?;
|
||||
let dir_path = dir_path.to_string_lossy().to_string();
|
||||
|
||||
let call_id = "list-dir-depth1";
|
||||
let arguments = serde_json::json!({
|
||||
"dir_path": dir_path,
|
||||
"offset": 1,
|
||||
"limit": 10,
|
||||
"depth": 1,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "list_dir", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "list directory contents depth one".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
assert!(
|
||||
!request_bodies.is_empty(),
|
||||
"expected at least one request body"
|
||||
);
|
||||
|
||||
let tool_output_item = request_bodies
|
||||
.iter()
|
||||
.find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!("function_call_output item not found in requests: {request_bodies:#?}")
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
tool_output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
let output_text = tool_output_item
|
||||
.get("output")
|
||||
.and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
Value::Object(obj) => obj.get("content").and_then(Value::as_str),
|
||||
_ => None,
|
||||
})
|
||||
.expect("output text present");
|
||||
assert_eq!(output_text, "E1: [file] alpha.txt\nE2: [dir] nested");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "disabled until we enable list_dir tool"]
|
||||
async fn list_dir_tool_depth_two_includes_children_only() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = test_codex().build(&server).await?;
|
||||
|
||||
let dir_path = cwd.path().join("depth_two");
|
||||
std::fs::create_dir(&dir_path)?;
|
||||
std::fs::write(dir_path.join("alpha.txt"), "alpha")?;
|
||||
let nested = dir_path.join("nested");
|
||||
std::fs::create_dir(&nested)?;
|
||||
std::fs::write(nested.join("beta.txt"), "beta")?;
|
||||
let deeper = nested.join("grand");
|
||||
std::fs::create_dir(&deeper)?;
|
||||
std::fs::write(deeper.join("gamma.txt"), "gamma")?;
|
||||
let dir_path_string = dir_path.to_string_lossy().to_string();
|
||||
|
||||
let call_id = "list-dir-depth2";
|
||||
let arguments = serde_json::json!({
|
||||
"dir_path": dir_path_string,
|
||||
"offset": 1,
|
||||
"limit": 10,
|
||||
"depth": 2,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_function_call(call_id, "list_dir", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "list directory contents depth two".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
assert!(
|
||||
!request_bodies.is_empty(),
|
||||
"expected at least one request body"
|
||||
);
|
||||
|
||||
let tool_output_item = request_bodies
|
||||
.iter()
|
||||
.find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!("function_call_output item not found in requests: {request_bodies:#?}")
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
tool_output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
let output_text = tool_output_item
|
||||
.get("output")
|
||||
.and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
Value::Object(obj) => obj.get("content").and_then(Value::as_str),
|
||||
_ => None,
|
||||
})
|
||||
.expect("output text present");
|
||||
assert_eq!(
|
||||
output_text,
|
||||
"E1: [file] alpha.txt\nE2: [dir] nested\nE3: [file] nested/beta.txt\nE4: [dir] nested/grand"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "disabled until we enable list_dir tool"]
|
||||
async fn list_dir_tool_depth_three_includes_grandchildren() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = test_codex().build(&server).await?;
|
||||
|
||||
let dir_path = cwd.path().join("depth_three");
|
||||
std::fs::create_dir(&dir_path)?;
|
||||
std::fs::write(dir_path.join("alpha.txt"), "alpha")?;
|
||||
let nested = dir_path.join("nested");
|
||||
std::fs::create_dir(&nested)?;
|
||||
std::fs::write(nested.join("beta.txt"), "beta")?;
|
||||
let deeper = nested.join("grand");
|
||||
std::fs::create_dir(&deeper)?;
|
||||
std::fs::write(deeper.join("gamma.txt"), "gamma")?;
|
||||
let dir_path_string = dir_path.to_string_lossy().to_string();
|
||||
|
||||
let call_id = "list-dir-depth3";
|
||||
let arguments = serde_json::json!({
|
||||
"dir_path": dir_path_string,
|
||||
"offset": 1,
|
||||
"limit": 10,
|
||||
"depth": 3,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_function_call(call_id, "list_dir", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "list directory contents depth three".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
assert!(
|
||||
!request_bodies.is_empty(),
|
||||
"expected at least one request body"
|
||||
);
|
||||
|
||||
let tool_output_item = request_bodies
|
||||
.iter()
|
||||
.find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!("function_call_output item not found in requests: {request_bodies:#?}")
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
tool_output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
let output_text = tool_output_item
|
||||
.get("output")
|
||||
.and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
Value::Object(obj) => obj.get("content").and_then(Value::as_str),
|
||||
_ => None,
|
||||
})
|
||||
.expect("output text present");
|
||||
assert_eq!(
|
||||
output_text,
|
||||
"E1: [file] alpha.txt\nE2: [dir] nested\nE3: [file] nested/beta.txt\nE4: [dir] nested/grand\nE5: [file] nested/grand/gamma.txt"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -9,7 +9,9 @@ mod compact_resume_fork;
|
||||
mod exec;
|
||||
mod exec_stream_events;
|
||||
mod fork_conversation;
|
||||
mod grep_files;
|
||||
mod json_result;
|
||||
mod list_dir;
|
||||
mod live_cli;
|
||||
mod model_overrides;
|
||||
mod model_tools;
|
||||
@@ -20,9 +22,11 @@ mod review;
|
||||
mod rmcp_client;
|
||||
mod rollout_list_find;
|
||||
mod seatbelt;
|
||||
mod shell_serialization;
|
||||
mod stream_error_allows_next_turn;
|
||||
mod stream_no_completed;
|
||||
mod tool_harness;
|
||||
mod tool_parallelism;
|
||||
mod tools;
|
||||
mod unified_exec;
|
||||
mod user_notification;
|
||||
|
||||
@@ -10,14 +10,11 @@ use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::wait_for_event;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
fn sse_completed(id: &str) -> String {
|
||||
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
||||
@@ -44,16 +41,7 @@ async fn collect_tool_identifiers_for_model(model: &str) -> Vec<String> {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let sse = sse_completed(model);
|
||||
let template = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse, "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(template)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
let resp_mock = responses::mount_sse_once_match(&server, wiremock::matchers::any(), sse).await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
@@ -93,13 +81,7 @@ async fn collect_tool_identifiers_for_model(model: &str) -> Vec<String> {
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(
|
||||
requests.len(),
|
||||
1,
|
||||
"expected a single request for model {model}"
|
||||
);
|
||||
let body = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
let body = resp_mock.single_request().body_json();
|
||||
tool_identifiers(&body)
|
||||
}
|
||||
|
||||
@@ -111,14 +93,21 @@ async fn model_selects_expected_tools() {
|
||||
let codex_tools = collect_tool_identifiers_for_model("codex-mini-latest").await;
|
||||
assert_eq!(
|
||||
codex_tools,
|
||||
vec!["local_shell".to_string(), "read_file".to_string()],
|
||||
vec!["local_shell".to_string()],
|
||||
"codex-mini-latest should expose the local shell tool",
|
||||
);
|
||||
|
||||
let o3_tools = collect_tool_identifiers_for_model("o3").await;
|
||||
assert_eq!(
|
||||
o3_tools,
|
||||
vec!["shell".to_string(), "read_file".to_string()],
|
||||
vec!["shell".to_string()],
|
||||
"o3 should expose the generic shell tool",
|
||||
);
|
||||
|
||||
let gpt5_codex_tools = collect_tool_identifiers_for_model("gpt-5-codex").await;
|
||||
assert_eq!(
|
||||
gpt5_codex_tools,
|
||||
vec!["shell".to_string(), "apply_patch".to_string(),],
|
||||
"gpt-5-codex should expose the apply_patch tool",
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use codex_core::CodexAuth;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::config::OPENAI_DEFAULT_MODEL;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
@@ -18,6 +19,7 @@ use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::wait_for_event;
|
||||
use std::collections::HashMap;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
@@ -178,16 +180,16 @@ async fn prompt_tools_are_consistent_across_requests() {
|
||||
|
||||
let cwd = TempDir::new().unwrap();
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.user_instructions = Some("be consistent and helpful".to_string());
|
||||
config.include_apply_patch_tool = true;
|
||||
config.include_plan_tool = true;
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let expected_instructions = config.model_family.base_instructions.clone();
|
||||
let base_instructions = config.model_family.base_instructions.clone();
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
@@ -219,14 +221,29 @@ async fn prompt_tools_are_consistent_across_requests() {
|
||||
|
||||
// our internal implementation is responsible for keeping tools in sync
|
||||
// with the OpenAI schema, so we just verify the tool presence here
|
||||
let expected_tools_names: &[&str] = &[
|
||||
"shell",
|
||||
"update_plan",
|
||||
"apply_patch",
|
||||
"read_file",
|
||||
"view_image",
|
||||
];
|
||||
let tools_by_model: HashMap<&'static str, Vec<&'static str>> = HashMap::from([
|
||||
("gpt-5", vec!["shell", "update_plan", "view_image"]),
|
||||
(
|
||||
"gpt-5-codex",
|
||||
vec!["shell", "update_plan", "apply_patch", "view_image"],
|
||||
),
|
||||
]);
|
||||
let expected_tools_names = tools_by_model
|
||||
.get(OPENAI_DEFAULT_MODEL)
|
||||
.unwrap_or_else(|| panic!("expected tools to be defined for model {OPENAI_DEFAULT_MODEL}"))
|
||||
.as_slice();
|
||||
let body0 = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
let expected_instructions = if expected_tools_names.contains(&"apply_patch") {
|
||||
base_instructions
|
||||
} else {
|
||||
[
|
||||
base_instructions.clone(),
|
||||
include_str!("../../../apply-patch/apply_patch_tool_instructions.md").to_string(),
|
||||
]
|
||||
.join("\n")
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
body0["instructions"],
|
||||
serde_json::json!(expected_instructions),
|
||||
|
||||
@@ -10,6 +10,7 @@ use core_test_support::responses;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
@@ -21,6 +22,7 @@ use serde_json::Value;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "disabled until we enable read_file tool"]
|
||||
async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
@@ -46,10 +48,7 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "read_file", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
@@ -59,7 +58,7 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
@@ -80,36 +79,12 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
assert!(
|
||||
!request_bodies.is_empty(),
|
||||
"expected at least one request body"
|
||||
);
|
||||
|
||||
let tool_output_item = request_bodies
|
||||
.iter()
|
||||
.find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!("function_call_output item not found in requests: {request_bodies:#?}")
|
||||
});
|
||||
|
||||
let req = second_mock.single_request();
|
||||
let tool_output_item = req.function_call_output(call_id);
|
||||
assert_eq!(
|
||||
tool_output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
let output_text = tool_output_item
|
||||
.get("output")
|
||||
.and_then(|value| match value {
|
||||
|
||||
@@ -24,6 +24,7 @@ use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id_from_str;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
@@ -260,25 +261,28 @@ async fn review_does_not_emit_agent_message_on_structured_output() {
|
||||
.unwrap();
|
||||
|
||||
// Drain events until TaskComplete; ensure none are AgentMessage.
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
let mut saw_entered = false;
|
||||
let mut saw_exited = false;
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(5), codex.next_event())
|
||||
.await
|
||||
.expect("timeout waiting for event")
|
||||
.expect("stream ended unexpectedly");
|
||||
match ev.msg {
|
||||
EventMsg::TaskComplete(_) => break,
|
||||
wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|event| match event {
|
||||
EventMsg::TaskComplete(_) => true,
|
||||
EventMsg::AgentMessage(_) => {
|
||||
panic!("unexpected AgentMessage during review with structured output")
|
||||
}
|
||||
EventMsg::EnteredReviewMode(_) => saw_entered = true,
|
||||
EventMsg::ExitedReviewMode(_) => saw_exited = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
EventMsg::EnteredReviewMode(_) => {
|
||||
saw_entered = true;
|
||||
false
|
||||
}
|
||||
EventMsg::ExitedReviewMode(_) => {
|
||||
saw_exited = true;
|
||||
false
|
||||
}
|
||||
_ => false,
|
||||
},
|
||||
tokio::time::Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
assert!(saw_entered && saw_exited, "missing review lifecycle events");
|
||||
|
||||
server.verify().await;
|
||||
@@ -441,7 +445,7 @@ async fn review_input_isolated_from_parent_history() {
|
||||
.await;
|
||||
let _complete = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// Assert the request `input` contains the environment context followed by the review prompt.
|
||||
// Assert the request `input` contains the environment context followed by the user review prompt.
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let input = body["input"].as_array().expect("input array");
|
||||
@@ -469,9 +473,14 @@ async fn review_input_isolated_from_parent_history() {
|
||||
assert_eq!(review_msg["role"].as_str().unwrap(), "user");
|
||||
assert_eq!(
|
||||
review_msg["content"][0]["text"].as_str().unwrap(),
|
||||
format!("{REVIEW_PROMPT}\n\n---\n\nNow, here's your task: Please review only this",)
|
||||
review_prompt,
|
||||
"user message should only contain the raw review prompt"
|
||||
);
|
||||
|
||||
// Ensure the REVIEW_PROMPT rubric is sent via instructions.
|
||||
let instructions = body["instructions"].as_str().expect("instructions string");
|
||||
assert_eq!(instructions, REVIEW_PROMPT);
|
||||
|
||||
// Also verify that a user interruption note was recorded in the rollout.
|
||||
codex.submit(Op::GetPath).await.unwrap();
|
||||
let history_event =
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use std::fs;
|
||||
use std::net::TcpListener;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
|
||||
use codex_core::config_types::McpServerConfig;
|
||||
use codex_core::config_types::McpServerTransportConfig;
|
||||
@@ -19,6 +24,8 @@ use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use escargot::CargoBuild;
|
||||
use serde_json::Value;
|
||||
use serial_test::serial;
|
||||
use tempfile::tempdir;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
@@ -40,10 +47,7 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
@@ -82,6 +86,7 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
|
||||
expected_env_value.to_string(),
|
||||
)])),
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
@@ -177,10 +182,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
@@ -231,8 +233,9 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token: None,
|
||||
bearer_token_env_var: None,
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
@@ -328,6 +331,187 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This test writes to a fallback credentials file in CODEX_HOME.
|
||||
/// Ideally, we wouldn't need to serialize the test but it's much more cumbersome to wire CODEX_HOME through the code.
|
||||
#[serial(codex_home)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
|
||||
let call_id = "call-789";
|
||||
let server_name = "rmcp_http_oauth";
|
||||
let tool_name = format!("{server_name}__echo");
|
||||
|
||||
mount_sse_once_match(
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
mount_sse_once_match(
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
responses::ev_assistant_message(
|
||||
"msg-1",
|
||||
"rmcp streamable http oauth echo tool completed successfully.",
|
||||
),
|
||||
responses::ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected_env_value = "propagated-env-http-oauth";
|
||||
let expected_token = "initial-access-token";
|
||||
let client_id = "test-client-id";
|
||||
let refresh_token = "initial-refresh-token";
|
||||
let rmcp_http_server_bin = CargoBuild::new()
|
||||
.package("codex-rmcp-client")
|
||||
.bin("test_streamable_http_server")
|
||||
.run()?
|
||||
.path()
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
drop(listener);
|
||||
let bind_addr = format!("127.0.0.1:{port}");
|
||||
let server_url = format!("http://{bind_addr}/mcp");
|
||||
|
||||
let mut http_server_child = Command::new(&rmcp_http_server_bin)
|
||||
.kill_on_drop(true)
|
||||
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
|
||||
.env("MCP_EXPECT_BEARER", expected_token)
|
||||
.env("MCP_TEST_VALUE", expected_env_value)
|
||||
.spawn()?;
|
||||
|
||||
wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5))
|
||||
.await?;
|
||||
|
||||
let temp_home = tempdir()?;
|
||||
let _guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str());
|
||||
write_fallback_oauth_tokens(
|
||||
temp_home.path(),
|
||||
server_name,
|
||||
&server_url,
|
||||
client_id,
|
||||
expected_token,
|
||||
refresh_token,
|
||||
)?;
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
config.use_experimental_use_rmcp_client = true;
|
||||
config.mcp_servers.insert(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token_env_var: None,
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
fixture
|
||||
.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "call the rmcp streamable http oauth echo tool".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: fixture.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let begin_event = wait_for_event_with_timeout(
|
||||
&fixture.codex,
|
||||
|ev| matches!(ev, EventMsg::McpToolCallBegin(_)),
|
||||
Duration::from_secs(10),
|
||||
)
|
||||
.await;
|
||||
|
||||
let EventMsg::McpToolCallBegin(begin) = begin_event else {
|
||||
unreachable!("event guard guarantees McpToolCallBegin");
|
||||
};
|
||||
assert_eq!(begin.invocation.server, server_name);
|
||||
assert_eq!(begin.invocation.tool, "echo");
|
||||
|
||||
let end_event = wait_for_event(&fixture.codex, |ev| {
|
||||
matches!(ev, EventMsg::McpToolCallEnd(_))
|
||||
})
|
||||
.await;
|
||||
let EventMsg::McpToolCallEnd(end) = end_event else {
|
||||
unreachable!("event guard guarantees McpToolCallEnd");
|
||||
};
|
||||
|
||||
let result = end
|
||||
.result
|
||||
.as_ref()
|
||||
.expect("rmcp echo tool should return success");
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
assert!(
|
||||
result.content.is_empty(),
|
||||
"content should default to an empty array"
|
||||
);
|
||||
|
||||
let structured = result
|
||||
.structured_content
|
||||
.as_ref()
|
||||
.expect("structured content");
|
||||
let Value::Object(map) = structured else {
|
||||
panic!("structured content should be an object: {structured:?}");
|
||||
};
|
||||
let echo_value = map
|
||||
.get("echo")
|
||||
.and_then(Value::as_str)
|
||||
.expect("echo payload present");
|
||||
assert_eq!(echo_value, "ECHOING: ping");
|
||||
let env_value = map
|
||||
.get("env")
|
||||
.and_then(Value::as_str)
|
||||
.expect("env snapshot inserted");
|
||||
assert_eq!(env_value, expected_env_value);
|
||||
|
||||
wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
server.verify().await;
|
||||
|
||||
match http_server_child.try_wait() {
|
||||
Ok(Some(_)) => {}
|
||||
Ok(None) => {
|
||||
let _ = http_server_child.kill().await;
|
||||
}
|
||||
Err(error) => {
|
||||
eprintln!("failed to check streamable http oauth server status: {error}");
|
||||
let _ = http_server_child.kill().await;
|
||||
}
|
||||
}
|
||||
if let Err(error) = http_server_child.wait().await {
|
||||
eprintln!("failed to await streamable http oauth server shutdown: {error}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_streamable_http_server(
|
||||
server_child: &mut Child,
|
||||
address: &str,
|
||||
@@ -369,3 +553,60 @@ async fn wait_for_streamable_http_server(
|
||||
sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn write_fallback_oauth_tokens(
|
||||
home: &Path,
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
client_id: &str,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let expires_at = SystemTime::now()
|
||||
.checked_add(Duration::from_secs(3600))
|
||||
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
|
||||
.duration_since(UNIX_EPOCH)?
|
||||
.as_millis() as u64;
|
||||
|
||||
let store = serde_json::json!({
|
||||
"stub": {
|
||||
"server_name": server_name,
|
||||
"server_url": server_url,
|
||||
"client_id": client_id,
|
||||
"access_token": access_token,
|
||||
"expires_at": expires_at,
|
||||
"refresh_token": refresh_token,
|
||||
"scopes": ["profile"],
|
||||
}
|
||||
});
|
||||
|
||||
let file_path = home.join(".credentials.json");
|
||||
fs::write(&file_path, serde_json::to_vec(&store)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct EnvVarGuard {
|
||||
key: &'static str,
|
||||
original: Option<OsString>,
|
||||
}
|
||||
|
||||
impl EnvVarGuard {
|
||||
fn set(key: &'static str, value: &std::ffi::OsStr) -> Self {
|
||||
let original = std::env::var_os(key);
|
||||
unsafe {
|
||||
std::env::set_var(key, value);
|
||||
}
|
||||
Self { key, original }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EnvVarGuard {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
match &self.original {
|
||||
Some(value) => std::env::set_var(self.key, value),
|
||||
None => std::env::remove_var(self.key),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
277
codex-rs/core/tests/suite/shell_serialization.rs
Normal file
277
codex-rs/core/tests/suite/shell_serialization.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::assert_regex_match;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
|
||||
async fn submit_turn(test: &TestCodex, prompt: &str, sandbox_policy: SandboxPolicy) -> Result<()> {
|
||||
let session_model = test.session_configured.model.clone();
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: prompt.into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: test.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TaskComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn request_bodies(requests: &[wiremock::Request]) -> Result<Vec<Value>> {
|
||||
requests
|
||||
.iter()
|
||||
.map(|req| Ok(serde_json::from_slice::<Value>(&req.body)?))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn find_function_call_output<'a>(bodies: &'a [Value], call_id: &str) -> Option<&'a Value> {
|
||||
for body in bodies {
|
||||
if let Some(items) = body.get("input").and_then(Value::as_array) {
|
||||
for item in items {
|
||||
if item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
|
||||
{
|
||||
return Some(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_output_stays_json_without_freeform_apply_patch() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = false;
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a model family");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "shell-json";
|
||||
let args = json!({
|
||||
"command": ["/bin/echo", "shell json"],
|
||||
"timeout_ms": 1_000,
|
||||
});
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"run the json shell command",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item = find_function_call_output(&bodies, call_id).expect("shell output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("shell output string");
|
||||
|
||||
let parsed: Value = serde_json::from_str(output)?;
|
||||
assert_eq!(
|
||||
parsed
|
||||
.get("metadata")
|
||||
.and_then(|metadata| metadata.get("exit_code"))
|
||||
.and_then(Value::as_i64),
|
||||
Some(0),
|
||||
"expected zero exit code in unformatted JSON output",
|
||||
);
|
||||
let stdout = parsed
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
assert_regex_match(r"(?s)^shell json\n?$", stdout);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_output_is_structured_with_freeform_apply_patch() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "shell-structured";
|
||||
let args = json!({
|
||||
"command": ["/bin/echo", "freeform shell"],
|
||||
"timeout_ms": 1_000,
|
||||
});
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"run the structured shell command",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item =
|
||||
find_function_call_output(&bodies, call_id).expect("structured output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("structured output string");
|
||||
|
||||
assert!(
|
||||
serde_json::from_str::<Value>(output).is_err(),
|
||||
"expected structured shell output to be plain text",
|
||||
);
|
||||
let expected_pattern = r"(?s)^Exit code: 0
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
freeform shell
|
||||
?$";
|
||||
assert_regex_match(expected_pattern, output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_output_reserializes_truncated_content() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = "gpt-5-codex".to_string();
|
||||
config.model_family =
|
||||
find_family_for_model("gpt-5-codex").expect("gpt-5 is a model family");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "shell-truncated";
|
||||
let args = json!({
|
||||
"command": ["/bin/sh", "-c", "seq 1 400"],
|
||||
"timeout_ms": 5_000,
|
||||
});
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"run the truncation shell command",
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("recorded requests present");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let output_item =
|
||||
find_function_call_output(&bodies, call_id).expect("truncated output present");
|
||||
let output = output_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("truncated output string");
|
||||
|
||||
assert!(
|
||||
serde_json::from_str::<Value>(output).is_err(),
|
||||
"expected truncated shell output to be plain text",
|
||||
);
|
||||
let truncated_pattern = r#"(?s)^Exit code: 0
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Total output lines: 400
|
||||
Output:
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
.*
|
||||
\[\.{3} omitted \d+ of 400 lines \.{3}\]
|
||||
|
||||
.*
|
||||
396
|
||||
397
|
||||
398
|
||||
399
|
||||
400
|
||||
$"#;
|
||||
assert_regex_match(truncated_pattern, output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -13,7 +13,7 @@ use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use tokio::time::timeout;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::Request;
|
||||
@@ -102,13 +102,10 @@ async fn retries_on_early_close() {
|
||||
.unwrap();
|
||||
|
||||
// Wait until TaskComplete (should succeed after retry).
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(10), codex.next_event())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|event| matches!(event, EventMsg::TaskComplete(_)),
|
||||
Duration::from_secs(10),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
@@ -7,31 +9,24 @@ use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::plan_tool::StepStatus;
|
||||
use core_test_support::assert_regex_match;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_apply_patch_function_call;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_local_shell_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
fn function_call_output(body: &Value) -> Option<&Value> {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_output_text(item: &Value) -> Option<&str> {
|
||||
item.get("output").and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
@@ -40,12 +35,6 @@ fn extract_output_text(item: &Value) -> Option<&str> {
|
||||
})
|
||||
}
|
||||
|
||||
fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> {
|
||||
requests
|
||||
.iter()
|
||||
.find(|body| function_call_output(body).is_some())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -53,7 +42,8 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_apply_patch_tool = true;
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a valid model");
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
@@ -65,10 +55,7 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
|
||||
let call_id = "shell-tool-call";
|
||||
let command = vec!["/bin/echo", "tool harness"];
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_local_shell_call(call_id, "completed", command),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
@@ -78,7 +65,7 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
|
||||
ev_assistant_message("msg-1", "all done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
@@ -97,32 +84,15 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
|
||||
})
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
let event = codex.next_event().await.expect("event");
|
||||
if matches!(event.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
|
||||
.expect("function_call_output item not found in requests");
|
||||
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
|
||||
let output_text = extract_output_text(output_item).expect("output text present");
|
||||
let req = second_mock.single_request();
|
||||
let output_item = req.function_call_output(call_id);
|
||||
let output_text = extract_output_text(&output_item).expect("output text present");
|
||||
let exec_output: Value = serde_json::from_str(output_text)?;
|
||||
assert_eq!(exec_output["metadata"]["exit_code"], 0);
|
||||
let stdout = exec_output["output"].as_str().expect("stdout field");
|
||||
assert!(
|
||||
stdout.contains("tool harness"),
|
||||
"expected stdout to contain command output, got {stdout:?}"
|
||||
);
|
||||
assert_regex_match(r"(?s)^tool harness\n?$", stdout);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -154,10 +124,7 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "update_plan", &plan_args),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
@@ -167,7 +134,7 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
|
||||
ev_assistant_message("msg-1", "plan acknowledged"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
@@ -187,42 +154,31 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
|
||||
.await?;
|
||||
|
||||
let mut saw_plan_update = false;
|
||||
|
||||
loop {
|
||||
let event = codex.next_event().await.expect("event");
|
||||
match event.msg {
|
||||
EventMsg::PlanUpdate(update) => {
|
||||
saw_plan_update = true;
|
||||
assert_eq!(update.explanation.as_deref(), Some("Tool harness check"));
|
||||
assert_eq!(update.plan.len(), 2);
|
||||
assert_eq!(update.plan[0].step, "Inspect workspace");
|
||||
assert!(matches!(update.plan[0].status, StepStatus::InProgress));
|
||||
assert_eq!(update.plan[1].step, "Report results");
|
||||
assert!(matches!(update.plan[1].status, StepStatus::Pending));
|
||||
}
|
||||
EventMsg::TaskComplete(_) => break,
|
||||
_ => {}
|
||||
wait_for_event(&codex, |event| match event {
|
||||
EventMsg::PlanUpdate(update) => {
|
||||
saw_plan_update = true;
|
||||
assert_eq!(update.explanation.as_deref(), Some("Tool harness check"));
|
||||
assert_eq!(update.plan.len(), 2);
|
||||
assert_eq!(update.plan[0].step, "Inspect workspace");
|
||||
assert_matches!(update.plan[0].status, StepStatus::InProgress);
|
||||
assert_eq!(update.plan[1].step, "Report results");
|
||||
assert_matches!(update.plan[1].status, StepStatus::Pending);
|
||||
false
|
||||
}
|
||||
}
|
||||
EventMsg::TaskComplete(_) => true,
|
||||
_ => false,
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(saw_plan_update, "expected PlanUpdate event");
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
|
||||
.expect("function_call_output item not found in requests");
|
||||
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
|
||||
let req = second_mock.single_request();
|
||||
let output_item = req.function_call_output(call_id);
|
||||
assert_eq!(
|
||||
output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
let output_text = extract_output_text(output_item).expect("output text present");
|
||||
let output_text = extract_output_text(&output_item).expect("output text present");
|
||||
assert_eq!(output_text, "Plan updated");
|
||||
|
||||
Ok(())
|
||||
@@ -251,10 +207,7 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "update_plan", &invalid_args),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
@@ -264,7 +217,7 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
|
||||
ev_assistant_message("msg-1", "malformed plan payload"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
@@ -284,37 +237,28 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
|
||||
.await?;
|
||||
|
||||
let mut saw_plan_update = false;
|
||||
|
||||
loop {
|
||||
let event = codex.next_event().await.expect("event");
|
||||
match event.msg {
|
||||
EventMsg::PlanUpdate(_) => saw_plan_update = true,
|
||||
EventMsg::TaskComplete(_) => break,
|
||||
_ => {}
|
||||
wait_for_event(&codex, |event| match event {
|
||||
EventMsg::PlanUpdate(_) => {
|
||||
saw_plan_update = true;
|
||||
false
|
||||
}
|
||||
}
|
||||
EventMsg::TaskComplete(_) => true,
|
||||
_ => false,
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
!saw_plan_update,
|
||||
"did not expect PlanUpdate event for malformed payload"
|
||||
);
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
|
||||
.expect("function_call_output item not found in requests");
|
||||
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
|
||||
let req = second_mock.single_request();
|
||||
let output_item = req.function_call_output(call_id);
|
||||
assert_eq!(
|
||||
output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
let output_text = extract_output_text(output_item).expect("output text present");
|
||||
let output_text = extract_output_text(&output_item).expect("output text present");
|
||||
assert!(
|
||||
output_text.contains("failed to parse function arguments"),
|
||||
"expected parse error message in output text, got {output_text:?}"
|
||||
@@ -357,10 +301,7 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
|
||||
*** End Patch"#;
|
||||
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_apply_patch_function_call(call_id, patch_content),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
@@ -370,7 +311,7 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
|
||||
ev_assistant_message("msg-1", "patch complete"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
@@ -391,43 +332,33 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
|
||||
|
||||
let mut saw_patch_begin = false;
|
||||
let mut patch_end_success = None;
|
||||
|
||||
loop {
|
||||
let event = codex.next_event().await.expect("event");
|
||||
match event.msg {
|
||||
EventMsg::PatchApplyBegin(begin) => {
|
||||
saw_patch_begin = true;
|
||||
assert_eq!(begin.call_id, call_id);
|
||||
}
|
||||
EventMsg::PatchApplyEnd(end) => {
|
||||
assert_eq!(end.call_id, call_id);
|
||||
patch_end_success = Some(end.success);
|
||||
}
|
||||
EventMsg::TaskComplete(_) => break,
|
||||
_ => {}
|
||||
wait_for_event(&codex, |event| match event {
|
||||
EventMsg::PatchApplyBegin(begin) => {
|
||||
saw_patch_begin = true;
|
||||
assert_eq!(begin.call_id, call_id);
|
||||
false
|
||||
}
|
||||
}
|
||||
EventMsg::PatchApplyEnd(end) => {
|
||||
assert_eq!(end.call_id, call_id);
|
||||
patch_end_success = Some(end.success);
|
||||
false
|
||||
}
|
||||
EventMsg::TaskComplete(_) => true,
|
||||
_ => false,
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(saw_patch_begin, "expected PatchApplyBegin event");
|
||||
let patch_end_success =
|
||||
patch_end_success.expect("expected PatchApplyEnd event to capture success flag");
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
|
||||
.expect("function_call_output item not found in requests");
|
||||
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
|
||||
let req = second_mock.single_request();
|
||||
let output_item = req.function_call_output(call_id);
|
||||
assert_eq!(
|
||||
output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
let output_text = extract_output_text(output_item).expect("output text present");
|
||||
let output_text = extract_output_text(&output_item).expect("output text present");
|
||||
|
||||
if let Ok(exec_output) = serde_json::from_str::<Value>(output_text) {
|
||||
let exit_code = exec_output["metadata"]["exit_code"]
|
||||
@@ -487,10 +418,7 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
|
||||
*** End Patch";
|
||||
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_apply_patch_function_call(call_id, patch_content),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
@@ -500,7 +428,7 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
|
||||
ev_assistant_message("msg-1", "failed"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
@@ -519,29 +447,15 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
|
||||
})
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
let event = codex.next_event().await.expect("event");
|
||||
if matches!(event.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
|
||||
.expect("function_call_output item not found in requests");
|
||||
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
|
||||
let req = second_mock.single_request();
|
||||
let output_item = req.function_call_output(call_id);
|
||||
assert_eq!(
|
||||
output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
let output_text = extract_output_text(output_item).expect("output text present");
|
||||
let output_text = extract_output_text(&output_item).expect("output text present");
|
||||
|
||||
assert!(
|
||||
output_text.contains("apply_patch verification failed"),
|
||||
|
||||
206
codex-rs/core/tests/suite/tool_parallelism.rs
Normal file
206
codex-rs/core/tests/suite/tool_parallelism.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::json;
|
||||
|
||||
async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> {
|
||||
let session_model = test.session_configured.model.clone();
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: prompt.into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: test.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_turn_and_measure(test: &TestCodex, prompt: &str) -> anyhow::Result<Duration> {
|
||||
let start = Instant::now();
|
||||
run_turn(test, prompt).await?;
|
||||
Ok(start.elapsed())
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
async fn build_codex_with_test_tool(server: &wiremock::MockServer) -> anyhow::Result<TestCodex> {
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = "test-gpt-5-codex".to_string();
|
||||
config.model_family =
|
||||
find_family_for_model("test-gpt-5-codex").expect("test-gpt-5-codex model family");
|
||||
});
|
||||
builder.build(server).await
|
||||
}
|
||||
|
||||
fn assert_parallel_duration(actual: Duration) {
|
||||
// Allow headroom for runtime overhead while still differentiating from serial execution.
|
||||
assert!(
|
||||
actual < Duration::from_millis(750),
|
||||
"expected parallel execution to finish quickly, got {actual:?}"
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_serial_duration(actual: Duration) {
|
||||
assert!(
|
||||
actual >= Duration::from_millis(500),
|
||||
"expected serial execution to take longer, got {actual:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn read_file_tools_run_in_parallel() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let test = build_codex_with_test_tool(&server).await?;
|
||||
|
||||
let warmup_args = json!({
|
||||
"sleep_after_ms": 10,
|
||||
"barrier": {
|
||||
"id": "parallel-test-sync-warmup",
|
||||
"participants": 2,
|
||||
"timeout_ms": 1_000,
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let parallel_args = json!({
|
||||
"sleep_after_ms": 300,
|
||||
"barrier": {
|
||||
"id": "parallel-test-sync",
|
||||
"participants": 2,
|
||||
"timeout_ms": 1_000,
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let warmup_first = sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-warm-1"}}),
|
||||
ev_function_call("warm-call-1", "test_sync_tool", &warmup_args),
|
||||
ev_function_call("warm-call-2", "test_sync_tool", &warmup_args),
|
||||
ev_completed("resp-warm-1"),
|
||||
]);
|
||||
let warmup_second = sse(vec![
|
||||
ev_assistant_message("warm-msg-1", "warmup complete"),
|
||||
ev_completed("resp-warm-2"),
|
||||
]);
|
||||
|
||||
let first_response = sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_function_call("call-1", "test_sync_tool", ¶llel_args),
|
||||
ev_function_call("call-2", "test_sync_tool", ¶llel_args),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
mount_sse_sequence(
|
||||
&server,
|
||||
vec![warmup_first, warmup_second, first_response, second_response],
|
||||
)
|
||||
.await;
|
||||
|
||||
run_turn(&test, "warm up parallel tool").await?;
|
||||
|
||||
let duration = run_turn_and_measure(&test, "exercise sync tool").await?;
|
||||
assert_parallel_duration(duration);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn non_parallel_tools_run_serially() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let test = test_codex().build(&server).await?;
|
||||
|
||||
let shell_args = json!({
|
||||
"command": ["/bin/sh", "-c", "sleep 0.3"],
|
||||
"timeout_ms": 1_000,
|
||||
});
|
||||
let args_one = serde_json::to_string(&shell_args)?;
|
||||
let args_two = serde_json::to_string(&shell_args)?;
|
||||
|
||||
let first_response = sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_function_call("call-1", "shell", &args_one),
|
||||
ev_function_call("call-2", "shell", &args_two),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
mount_sse_sequence(&server, vec![first_response, second_response]).await;
|
||||
|
||||
let duration = run_turn_and_measure(&test, "run shell twice").await?;
|
||||
assert_serial_duration(duration);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn mixed_tools_fall_back_to_serial() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let test = build_codex_with_test_tool(&server).await?;
|
||||
|
||||
let sync_args = json!({
|
||||
"sleep_after_ms": 300
|
||||
})
|
||||
.to_string();
|
||||
let shell_args = serde_json::to_string(&json!({
|
||||
"command": ["/bin/sh", "-c", "sleep 0.3"],
|
||||
"timeout_ms": 1_000,
|
||||
}))?;
|
||||
|
||||
let first_response = sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_function_call("call-1", "test_sync_tool", &sync_args),
|
||||
ev_function_call("call-2", "shell", &shell_args),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
mount_sse_sequence(&server, vec![first_response, second_response]).await;
|
||||
|
||||
let duration = run_turn_and_measure(&test, "mix tools").await?;
|
||||
assert_serial_duration(duration);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -2,25 +2,30 @@
|
||||
#![allow(clippy::unwrap_used, clippy::expect_used)]
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::assert_regex_match;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_custom_tool_call;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use regex_lite::Regex;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use wiremock::Request;
|
||||
|
||||
async fn submit_turn(
|
||||
test: &TestCodex,
|
||||
@@ -45,37 +50,14 @@ async fn submit_turn(
|
||||
})
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
let event = test.codex.next_event().await?;
|
||||
if matches!(event.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TaskComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn request_bodies(requests: &[Request]) -> Result<Vec<Value>> {
|
||||
requests
|
||||
.iter()
|
||||
.map(|req| Ok(serde_json::from_slice::<Value>(&req.body)?))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn collect_output_items<'a>(bodies: &'a [Value], ty: &str) -> Vec<&'a Value> {
|
||||
let mut out = Vec::new();
|
||||
for body in bodies {
|
||||
if let Some(items) = body.get("input").and_then(Value::as_array) {
|
||||
for item in items {
|
||||
if item.get("type").and_then(Value::as_str) == Some(ty) {
|
||||
out.push(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn tool_names(body: &Value) -> Vec<String> {
|
||||
body.get("tools")
|
||||
.and_then(Value::as_array)
|
||||
@@ -104,18 +86,23 @@ async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> {
|
||||
let call_id = "custom-unsupported";
|
||||
let tool_name = "unsupported_tool";
|
||||
|
||||
let responses = vec![
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_custom_tool_call(call_id, tool_name, "\"payload\""),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
@@ -125,13 +112,7 @@ async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let custom_items = collect_output_items(&bodies, "custom_tool_call_output");
|
||||
assert_eq!(custom_items.len(), 1, "expected single custom tool output");
|
||||
let item = custom_items[0];
|
||||
assert_eq!(item.get("call_id").and_then(Value::as_str), Some(call_id));
|
||||
|
||||
let item = mock.single_request().custom_tool_call_output(call_id);
|
||||
let output = item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
@@ -147,7 +128,10 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex();
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a valid model");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let command = ["/bin/echo", "shell ok"];
|
||||
@@ -164,9 +148,10 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
"timeout_ms": 1_000,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(
|
||||
call_id_blocked,
|
||||
"shell",
|
||||
@@ -174,8 +159,12 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let second_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-2"}}),
|
||||
ev_response_created("resp-2"),
|
||||
ev_function_call(
|
||||
call_id_success,
|
||||
"shell",
|
||||
@@ -183,12 +172,16 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let third_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
@@ -198,46 +191,23 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let function_outputs = collect_output_items(&bodies, "function_call_output");
|
||||
for item in &function_outputs {
|
||||
let call_id = item
|
||||
.get("call_id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
assert!(
|
||||
call_id == call_id_blocked || call_id == call_id_success,
|
||||
"unexpected call id {call_id}"
|
||||
);
|
||||
}
|
||||
|
||||
let policy = AskForApproval::Never;
|
||||
let expected_message = format!(
|
||||
"approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}"
|
||||
);
|
||||
|
||||
let blocked_outputs: Vec<&Value> = function_outputs
|
||||
.iter()
|
||||
.filter(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_blocked))
|
||||
.copied()
|
||||
.collect();
|
||||
assert!(
|
||||
!blocked_outputs.is_empty(),
|
||||
"expected at least one rejection output for {call_id_blocked}"
|
||||
let blocked_item = second_mock
|
||||
.single_request()
|
||||
.function_call_output(call_id_blocked);
|
||||
assert_eq!(
|
||||
blocked_item.get("output").and_then(Value::as_str),
|
||||
Some(expected_message.as_str()),
|
||||
"unexpected rejection message"
|
||||
);
|
||||
for item in blocked_outputs {
|
||||
assert_eq!(
|
||||
item.get("output").and_then(Value::as_str),
|
||||
Some(expected_message.as_str()),
|
||||
"unexpected rejection message"
|
||||
);
|
||||
}
|
||||
|
||||
let success_item = function_outputs
|
||||
.iter()
|
||||
.find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_success))
|
||||
.expect("success output present");
|
||||
let success_item = third_mock
|
||||
.single_request()
|
||||
.function_call_output(call_id_success);
|
||||
let output_json: Value = serde_json::from_str(
|
||||
success_item
|
||||
.get("output")
|
||||
@@ -250,10 +220,8 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
"expected exit code 0 after rerunning without escalation",
|
||||
);
|
||||
let stdout = output_json["output"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
stdout.contains("shell ok"),
|
||||
"expected stdout to include command output, got {stdout:?}"
|
||||
);
|
||||
let stdout_pattern = r"(?s)^shell ok\n?$";
|
||||
assert_regex_match(stdout_pattern, stdout);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -278,18 +246,23 @@ async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> {
|
||||
}
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
local_shell_event,
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let second_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
@@ -299,15 +272,7 @@ async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let function_outputs = collect_output_items(&bodies, "function_call_output");
|
||||
assert_eq!(
|
||||
function_outputs.len(),
|
||||
1,
|
||||
"expected a single function output"
|
||||
);
|
||||
let item = function_outputs[0];
|
||||
let item = second_mock.single_request().function_call_output("");
|
||||
assert_eq!(item.get("call_id").and_then(Value::as_str), Some(""));
|
||||
assert_eq!(
|
||||
item.get("output").and_then(Value::as_str),
|
||||
@@ -321,11 +286,11 @@ async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let responses = vec![sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-1"),
|
||||
])];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
let mock = mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let mut builder = test_codex().with_config(move |config| {
|
||||
config.use_experimental_unified_exec_tool = use_unified_exec;
|
||||
@@ -340,15 +305,8 @@ async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert_eq!(
|
||||
requests.len(),
|
||||
1,
|
||||
"expected a single request for tools collection"
|
||||
);
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let first_body = bodies.first().expect("request body present");
|
||||
Ok(tool_names(first_body))
|
||||
let first_body = mock.single_request().body_json();
|
||||
Ok(tool_names(&first_body))
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -375,7 +333,10 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex();
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a valid model");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "shell-timeout";
|
||||
@@ -385,18 +346,23 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
|
||||
"timeout_ms": timeout_ms,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let second_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
@@ -406,45 +372,111 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let function_outputs = collect_output_items(&bodies, "function_call_output");
|
||||
let timeout_item = function_outputs
|
||||
.iter()
|
||||
.find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id))
|
||||
.expect("timeout output present");
|
||||
let timeout_item = second_mock.single_request().function_call_output(call_id);
|
||||
|
||||
let output_json: Value = serde_json::from_str(
|
||||
timeout_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("timeout output string"),
|
||||
)?;
|
||||
assert_eq!(
|
||||
output_json["metadata"]["exit_code"].as_i64(),
|
||||
Some(124),
|
||||
"expected timeout exit code 124",
|
||||
);
|
||||
let output_str = timeout_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("timeout output string");
|
||||
|
||||
let stdout = output_json["output"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
stdout.starts_with("command timed out after "),
|
||||
"expected timeout prefix, got {stdout:?}"
|
||||
);
|
||||
let first_line = stdout.lines().next().unwrap_or_default();
|
||||
let duration_ms = first_line
|
||||
.strip_prefix("command timed out after ")
|
||||
.and_then(|line| line.strip_suffix(" milliseconds"))
|
||||
.and_then(|value| value.parse::<u64>().ok())
|
||||
.unwrap_or_default();
|
||||
assert!(
|
||||
duration_ms >= timeout_ms,
|
||||
"expected duration >= configured timeout, got {duration_ms} (timeout {timeout_ms})"
|
||||
);
|
||||
assert!(
|
||||
stdout.contains("[... omitted"),
|
||||
"expected truncated output marker, got {stdout:?}"
|
||||
);
|
||||
// The exec path can report a timeout in two ways depending on timing:
|
||||
// 1) Structured JSON with exit_code 124 and a timeout prefix (preferred), or
|
||||
// 2) A plain error string if the child is observed as killed by a signal first.
|
||||
if let Ok(output_json) = serde_json::from_str::<Value>(output_str) {
|
||||
assert_eq!(
|
||||
output_json["metadata"]["exit_code"].as_i64(),
|
||||
Some(124),
|
||||
"expected timeout exit code 124",
|
||||
);
|
||||
|
||||
let stdout = output_json["output"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
stdout.contains("command timed out"),
|
||||
"timeout output missing `command timed out`: {stdout}"
|
||||
);
|
||||
} else {
|
||||
// Fallback: accept the signal classification path to deflake the test.
|
||||
let signal_pattern = r"(?is)^execution error:.*signal.*$";
|
||||
assert_regex_match(signal_pattern, output_str);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_spawn_failure_truncates_exec_error() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|cfg| {
|
||||
cfg.sandbox_policy = SandboxPolicy::DangerFullAccess;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "shell-spawn-failure";
|
||||
let bogus_component = "missing-bin-".repeat(700);
|
||||
let bogus_exe = test
|
||||
.cwd
|
||||
.path()
|
||||
.join(bogus_component)
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
|
||||
let args = json!({
|
||||
"command": [bogus_exe],
|
||||
"timeout_ms": 1_000,
|
||||
});
|
||||
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let second_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"spawn a missing binary",
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let failure_item = second_mock.single_request().function_call_output(call_id);
|
||||
|
||||
let output = failure_item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.expect("spawn failure output string");
|
||||
|
||||
let spawn_error_pattern = r#"(?s)^Exit code: -?\d+
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Output:
|
||||
execution error: .*$"#;
|
||||
let spawn_truncated_pattern = r#"(?s)^Exit code: -?\d+
|
||||
Wall time: [0-9]+(?:\.[0-9]+)? seconds
|
||||
Total output lines: \d+
|
||||
Output:
|
||||
|
||||
execution error: .*$"#;
|
||||
let spawn_error_regex = Regex::new(spawn_error_pattern)?;
|
||||
let spawn_truncated_regex = Regex::new(spawn_truncated_pattern)?;
|
||||
if !spawn_error_regex.is_match(output) && !spawn_truncated_regex.is_match(output) {
|
||||
let fallback_pattern = r"(?s)^execution error: .*$";
|
||||
assert_regex_match(fallback_pattern, output);
|
||||
}
|
||||
assert!(output.len() <= 10 * 1024);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
@@ -19,6 +20,7 @@ use core_test_support::skip_if_no_network;
|
||||
use core_test_support::skip_if_sandbox;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::Value;
|
||||
|
||||
fn extract_output_text(item: &Value) -> Option<&str> {
|
||||
@@ -81,7 +83,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
|
||||
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
serde_json::json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(
|
||||
first_call_id,
|
||||
"unified_exec",
|
||||
@@ -90,7 +92,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
serde_json::json!({"type": "response.created", "response": {"id": "resp-2"}}),
|
||||
ev_response_created("resp-2"),
|
||||
ev_function_call(
|
||||
second_call_id,
|
||||
"unified_exec",
|
||||
@@ -122,12 +124,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
|
||||
})
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
let event = codex.next_event().await.expect("event");
|
||||
if matches!(event.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
@@ -170,6 +167,131 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_streams_after_lagged_output() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.use_experimental_unified_exec_tool = true;
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder.build(&server).await?;
|
||||
|
||||
let script = r#"python3 - <<'PY'
|
||||
import sys
|
||||
import time
|
||||
|
||||
chunk = b'x' * (1 << 20)
|
||||
for _ in range(4):
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.flush()
|
||||
|
||||
time.sleep(0.2)
|
||||
for _ in range(5):
|
||||
sys.stdout.write("TAIL-MARKER\n")
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.05)
|
||||
|
||||
time.sleep(0.2)
|
||||
PY
|
||||
"#;
|
||||
|
||||
let first_call_id = "uexec-lag-start";
|
||||
let first_args = serde_json::json!({
|
||||
"input": ["/bin/sh", "-c", script],
|
||||
"timeout_ms": 25,
|
||||
});
|
||||
|
||||
let second_call_id = "uexec-lag-poll";
|
||||
let second_args = serde_json::json!({
|
||||
"input": Vec::<String>::new(),
|
||||
"session_id": "0",
|
||||
"timeout_ms": 2_000,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(
|
||||
first_call_id,
|
||||
"unified_exec",
|
||||
&serde_json::to_string(&first_args)?,
|
||||
),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_function_call(
|
||||
second_call_id,
|
||||
"unified_exec",
|
||||
&serde_json::to_string(&second_args)?,
|
||||
),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "lag handled"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "exercise lag handling".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
|
||||
let start_output = outputs
|
||||
.get(first_call_id)
|
||||
.expect("missing initial unified_exec output");
|
||||
let session_id = start_output["session_id"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
!session_id.is_empty(),
|
||||
"expected session id from initial unified_exec response"
|
||||
);
|
||||
|
||||
let poll_output = outputs
|
||||
.get(second_call_id)
|
||||
.expect("missing poll unified_exec output");
|
||||
let poll_text = poll_output["output"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
poll_text.contains("TAIL-MARKER"),
|
||||
"expected poll output to contain tail marker, got {poll_text:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
@@ -202,7 +324,7 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
|
||||
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
serde_json::json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(
|
||||
first_call_id,
|
||||
"unified_exec",
|
||||
@@ -211,7 +333,7 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
serde_json::json!({"type": "response.created", "response": {"id": "resp-2"}}),
|
||||
ev_response_created("resp-2"),
|
||||
ev_function_call(
|
||||
second_call_id,
|
||||
"unified_exec",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user