Mirror of https://github.com/openai/codex.git (synced 2026-02-02 06:57:03 +00:00)

Compare commits: dev/cc/rel...models-not (18 commits)
| SHA1 |
|---|
| 8a07b59dc6 |
| 23579e2a76 |
| bbc5675974 |
| 51865695e4 |
| 3a32716e1c |
| 5ceeaa96b8 |
| b27c702e83 |
| e290d48264 |
| 3d14da9728 |
| b53889aed5 |
| d7482510b1 |
| 021c9a60e5 |
| c9f5b9a6df |
| ae57e18947 |
| cf44511e77 |
| bef36f4ae7 |
| f074e5706b |
| b9d1a087ee |
.github/actions/macos-code-sign/action.yml (vendored, new file, 212 lines)
@@ -0,0 +1,212 @@
name: macos-code-sign
description: Configure, sign, notarize, and clean up macOS code signing artifacts.

inputs:
  target:
    description: Rust compilation target triple (e.g. aarch64-apple-darwin).
    required: true
  apple-certificate:
    description: Base64-encoded Apple signing certificate (P12).
    required: true
  apple-certificate-password:
    description: Password for the signing certificate.
    required: true
  apple-notarization-key-p8:
    description: Base64-encoded Apple notarization key (P8).
    required: true
  apple-notarization-key-id:
    description: Apple notarization key ID.
    required: true
  apple-notarization-issuer-id:
    description: Apple notarization issuer ID.
    required: true

runs:
  using: composite
  steps:
    - name: Configure Apple code signing
      shell: bash
      env:
        KEYCHAIN_PASSWORD: actions
        APPLE_CERTIFICATE: ${{ inputs.apple-certificate }}
        APPLE_CERTIFICATE_PASSWORD: ${{ inputs.apple-certificate-password }}
      run: |
        set -euo pipefail

        if [[ -z "${APPLE_CERTIFICATE:-}" ]]; then
          echo "APPLE_CERTIFICATE is required for macOS signing"
          exit 1
        fi

        if [[ -z "${APPLE_CERTIFICATE_PASSWORD:-}" ]]; then
          echo "APPLE_CERTIFICATE_PASSWORD is required for macOS signing"
          exit 1
        fi

        cert_path="${RUNNER_TEMP}/apple_signing_certificate.p12"
        echo "$APPLE_CERTIFICATE" | base64 -d > "$cert_path"

        keychain_path="${RUNNER_TEMP}/codex-signing.keychain-db"
        security create-keychain -p "$KEYCHAIN_PASSWORD" "$keychain_path"
        security set-keychain-settings -lut 21600 "$keychain_path"
        security unlock-keychain -p "$KEYCHAIN_PASSWORD" "$keychain_path"

        # Capture the pre-existing keychain search list so it can be restored on failure.
        keychain_args=()
        cleanup_keychain() {
          if ((${#keychain_args[@]} > 0)); then
            security list-keychains -s "${keychain_args[@]}" || true
            security default-keychain -s "${keychain_args[0]}" || true
          else
            security list-keychains -s || true
          fi
          if [[ -f "$keychain_path" ]]; then
            security delete-keychain "$keychain_path" || true
          fi
        }

        while IFS= read -r keychain; do
          [[ -n "$keychain" ]] && keychain_args+=("$keychain")
        done < <(security list-keychains | sed 's/^[[:space:]]*//;s/[[:space:]]*$//;s/"//g')

        # Put the temporary keychain at the front of the search list.
        if ((${#keychain_args[@]} > 0)); then
          security list-keychains -s "$keychain_path" "${keychain_args[@]}"
        else
          security list-keychains -s "$keychain_path"
        fi

        security default-keychain -s "$keychain_path"
        security import "$cert_path" -k "$keychain_path" -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign -T /usr/bin/security
        # Allow codesign/security to use the imported key without an interactive prompt.
        security set-key-partition-list -S apple-tool:,apple: -s -k "$KEYCHAIN_PASSWORD" "$keychain_path" > /dev/null

        codesign_hashes=()
        while IFS= read -r hash; do
          [[ -n "$hash" ]] && codesign_hashes+=("$hash")
        done < <(security find-identity -v -p codesigning "$keychain_path" \
          | sed -n 's/.*\([0-9A-F]\{40\}\).*/\1/p' \
          | sort -u)

        if ((${#codesign_hashes[@]} == 0)); then
          echo "No signing identities found in $keychain_path"
          cleanup_keychain
          rm -f "$cert_path"
          exit 1
        fi

        if ((${#codesign_hashes[@]} > 1)); then
          echo "Multiple signing identities found in $keychain_path:"
          printf ' %s\n' "${codesign_hashes[@]}"
          cleanup_keychain
          rm -f "$cert_path"
          exit 1
        fi

        APPLE_CODESIGN_IDENTITY="${codesign_hashes[0]}"

        rm -f "$cert_path"

        echo "APPLE_CODESIGN_IDENTITY=$APPLE_CODESIGN_IDENTITY" >> "$GITHUB_ENV"
        echo "APPLE_CODESIGN_KEYCHAIN=$keychain_path" >> "$GITHUB_ENV"
        echo "::add-mask::$APPLE_CODESIGN_IDENTITY"

    - name: Sign macOS binaries
      shell: bash
      run: |
        set -euo pipefail

        if [[ -z "${APPLE_CODESIGN_IDENTITY:-}" ]]; then
          echo "APPLE_CODESIGN_IDENTITY is required for macOS signing"
          exit 1
        fi

        keychain_args=()
        if [[ -n "${APPLE_CODESIGN_KEYCHAIN:-}" && -f "${APPLE_CODESIGN_KEYCHAIN}" ]]; then
          keychain_args+=(--keychain "${APPLE_CODESIGN_KEYCHAIN}")
        fi

        for binary in codex codex-responses-api-proxy; do
          path="codex-rs/target/${{ inputs.target }}/release/${binary}"
          codesign --force --options runtime --timestamp --sign "$APPLE_CODESIGN_IDENTITY" "${keychain_args[@]}" "$path"
        done

    - name: Notarize macOS binaries
      shell: bash
      env:
        APPLE_NOTARIZATION_KEY_P8: ${{ inputs.apple-notarization-key-p8 }}
        APPLE_NOTARIZATION_KEY_ID: ${{ inputs.apple-notarization-key-id }}
        APPLE_NOTARIZATION_ISSUER_ID: ${{ inputs.apple-notarization-issuer-id }}
      run: |
        set -euo pipefail

        for var in APPLE_NOTARIZATION_KEY_P8 APPLE_NOTARIZATION_KEY_ID APPLE_NOTARIZATION_ISSUER_ID; do
          if [[ -z "${!var:-}" ]]; then
            echo "$var is required for notarization"
            exit 1
          fi
        done

        notary_key_path="${RUNNER_TEMP}/notarytool.key.p8"
        echo "$APPLE_NOTARIZATION_KEY_P8" | base64 -d > "$notary_key_path"
        cleanup_notary() {
          rm -f "$notary_key_path"
        }
        trap cleanup_notary EXIT

        notarize_binary() {
          local binary="$1"
          local source_path="codex-rs/target/${{ inputs.target }}/release/${binary}"
          local archive_path="${RUNNER_TEMP}/${binary}.zip"

          if [[ ! -f "$source_path" ]]; then
            echo "Binary $source_path not found"
            exit 1
          fi

          rm -f "$archive_path"
          ditto -c -k --keepParent "$source_path" "$archive_path"

          submission_json=$(xcrun notarytool submit "$archive_path" \
            --key "$notary_key_path" \
            --key-id "$APPLE_NOTARIZATION_KEY_ID" \
            --issuer "$APPLE_NOTARIZATION_ISSUER_ID" \
            --output-format json \
            --wait)

          status=$(printf '%s\n' "$submission_json" | jq -r '.status // "Unknown"')
          submission_id=$(printf '%s\n' "$submission_json" | jq -r '.id // ""')

          if [[ -z "$submission_id" ]]; then
            echo "Failed to retrieve submission ID for $binary"
            exit 1
          fi

          echo "::notice title=Notarization::$binary submission ${submission_id} completed with status ${status}"

          if [[ "$status" != "Accepted" ]]; then
            echo "Notarization failed for ${binary} (submission ${submission_id}, status ${status})"
            exit 1
          fi
        }

        notarize_binary "codex"
        notarize_binary "codex-responses-api-proxy"

    - name: Remove signing keychain
      if: ${{ always() }}
      shell: bash
      env:
        APPLE_CODESIGN_KEYCHAIN: ${{ env.APPLE_CODESIGN_KEYCHAIN }}
      run: |
        set -euo pipefail
        if [[ -n "${APPLE_CODESIGN_KEYCHAIN:-}" ]]; then
          # Restore the search list without the temporary keychain, then delete it.
          keychain_args=()
          while IFS= read -r keychain; do
            [[ "$keychain" == "$APPLE_CODESIGN_KEYCHAIN" ]] && continue
            [[ -n "$keychain" ]] && keychain_args+=("$keychain")
          done < <(security list-keychains | sed 's/^[[:space:]]*//;s/[[:space:]]*$//;s/"//g')
          if ((${#keychain_args[@]} > 0)); then
            security list-keychains -s "${keychain_args[@]}"
            security default-keychain -s "${keychain_args[0]}"
          fi

          if [[ -f "$APPLE_CODESIGN_KEYCHAIN" ]]; then
            security delete-keychain "$APPLE_CODESIGN_KEYCHAIN"
          fi
        fi
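Aside (illustrative, not part of the diff): the notarize step above keys off exactly two fields of `notarytool`'s JSON output via `jq`. A minimal Rust sketch of the same extraction — the JSON shape is an assumption inferred from the `jq` filters in the script, and `serde_json` is the only dependency:

    use serde_json::Value;

    fn main() {
        // Hypothetical `notarytool submit --output-format json --wait` output.
        let submission_json = r#"{"id":"12345678-aaaa-bbbb-cccc-0123456789ab","status":"Accepted"}"#;
        let v: Value = serde_json::from_str(submission_json).expect("valid JSON");

        // Mirrors `jq -r '.status // "Unknown"'` and `jq -r '.id // ""'`.
        let status = v.get("status").and_then(Value::as_str).unwrap_or("Unknown");
        let submission_id = v.get("id").and_then(Value::as_str).unwrap_or("");

        assert!(!submission_id.is_empty(), "failed to retrieve submission ID");
        assert_eq!(status, "Accepted", "notarization failed");
    }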
.github/workflows/ci.yml (vendored, 2 changes)
@@ -20,7 +20,7 @@ jobs:
           run_install: false

       - name: Setup Node.js
-        uses: actions/setup-node@v5
+        uses: actions/setup-node@v6
         with:
           node-version: 22

.github/workflows/rust-release.yml (vendored, 203 changes)
@@ -129,173 +129,15 @@ jobs:
           certificate-profile-name: ${{ secrets.AZURE_TRUSTED_SIGNING_CERTIFICATE_PROFILE_NAME }}

       - if: ${{ matrix.runner == 'macos-15-xlarge' }}
-        name: Configure Apple code signing
-        shell: bash
-        env:
-          KEYCHAIN_PASSWORD: actions
-          APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE_P12 }}
-          APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }}
-        run: |
-          set -euo pipefail
-
-          if [[ -z "${APPLE_CERTIFICATE:-}" ]]; then
-            echo "APPLE_CERTIFICATE is required for macOS signing"
-            exit 1
-          fi
-
-          if [[ -z "${APPLE_CERTIFICATE_PASSWORD:-}" ]]; then
-            echo "APPLE_CERTIFICATE_PASSWORD is required for macOS signing"
-            exit 1
-          fi
-
-          cert_path="${RUNNER_TEMP}/apple_signing_certificate.p12"
-          echo "$APPLE_CERTIFICATE" | base64 -d > "$cert_path"
-
-          keychain_path="${RUNNER_TEMP}/codex-signing.keychain-db"
-          security create-keychain -p "$KEYCHAIN_PASSWORD" "$keychain_path"
-          security set-keychain-settings -lut 21600 "$keychain_path"
-          security unlock-keychain -p "$KEYCHAIN_PASSWORD" "$keychain_path"
-
-          keychain_args=()
-          cleanup_keychain() {
-            if ((${#keychain_args[@]} > 0)); then
-              security list-keychains -s "${keychain_args[@]}" || true
-              security default-keychain -s "${keychain_args[0]}" || true
-            else
-              security list-keychains -s || true
-            fi
-            if [[ -f "$keychain_path" ]]; then
-              security delete-keychain "$keychain_path" || true
-            fi
-          }
-
-          while IFS= read -r keychain; do
-            [[ -n "$keychain" ]] && keychain_args+=("$keychain")
-          done < <(security list-keychains | sed 's/^[[:space:]]*//;s/[[:space:]]*$//;s/"//g')
-
-          if ((${#keychain_args[@]} > 0)); then
-            security list-keychains -s "$keychain_path" "${keychain_args[@]}"
-          else
-            security list-keychains -s "$keychain_path"
-          fi
-
-          security default-keychain -s "$keychain_path"
-          security import "$cert_path" -k "$keychain_path" -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign -T /usr/bin/security
-          security set-key-partition-list -S apple-tool:,apple: -s -k "$KEYCHAIN_PASSWORD" "$keychain_path" > /dev/null
-
-          codesign_hashes=()
-          while IFS= read -r hash; do
-            [[ -n "$hash" ]] && codesign_hashes+=("$hash")
-          done < <(security find-identity -v -p codesigning "$keychain_path" \
-            | sed -n 's/.*\([0-9A-F]\{40\}\).*/\1/p' \
-            | sort -u)
-
-          if ((${#codesign_hashes[@]} == 0)); then
-            echo "No signing identities found in $keychain_path"
-            cleanup_keychain
-            rm -f "$cert_path"
-            exit 1
-          fi
-
-          if ((${#codesign_hashes[@]} > 1)); then
-            echo "Multiple signing identities found in $keychain_path:"
-            printf ' %s\n' "${codesign_hashes[@]}"
-            cleanup_keychain
-            rm -f "$cert_path"
-            exit 1
-          fi
-
-          APPLE_CODESIGN_IDENTITY="${codesign_hashes[0]}"
-
-          rm -f "$cert_path"
-
-          echo "APPLE_CODESIGN_IDENTITY=$APPLE_CODESIGN_IDENTITY" >> "$GITHUB_ENV"
-          echo "APPLE_CODESIGN_KEYCHAIN=$keychain_path" >> "$GITHUB_ENV"
-          echo "::add-mask::$APPLE_CODESIGN_IDENTITY"
-
-      - if: ${{ matrix.runner == 'macos-15-xlarge' }}
-        name: Sign macOS binaries
-        shell: bash
-        run: |
-          set -euo pipefail
-
-          if [[ -z "${APPLE_CODESIGN_IDENTITY:-}" ]]; then
-            echo "APPLE_CODESIGN_IDENTITY is required for macOS signing"
-            exit 1
-          fi
-
-          keychain_args=()
-          if [[ -n "${APPLE_CODESIGN_KEYCHAIN:-}" && -f "${APPLE_CODESIGN_KEYCHAIN}" ]]; then
-            keychain_args+=(--keychain "${APPLE_CODESIGN_KEYCHAIN}")
-          fi
-
-          for binary in codex codex-responses-api-proxy; do
-            path="target/${{ matrix.target }}/release/${binary}"
-            codesign --force --options runtime --timestamp --sign "$APPLE_CODESIGN_IDENTITY" "${keychain_args[@]}" "$path"
-          done
-
-      - if: ${{ matrix.runner == 'macos-15-xlarge' }}
-        name: Notarize macOS binaries
-        shell: bash
-        env:
-          APPLE_NOTARIZATION_KEY_P8: ${{ secrets.APPLE_NOTARIZATION_KEY_P8 }}
-          APPLE_NOTARIZATION_KEY_ID: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }}
-          APPLE_NOTARIZATION_ISSUER_ID: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }}
-        run: |
-          set -euo pipefail
-
-          for var in APPLE_NOTARIZATION_KEY_P8 APPLE_NOTARIZATION_KEY_ID APPLE_NOTARIZATION_ISSUER_ID; do
-            if [[ -z "${!var:-}" ]]; then
-              echo "$var is required for notarization"
-              exit 1
-            fi
-          done
-
-          notary_key_path="${RUNNER_TEMP}/notarytool.key.p8"
-          echo "$APPLE_NOTARIZATION_KEY_P8" | base64 -d > "$notary_key_path"
-          cleanup_notary() {
-            rm -f "$notary_key_path"
-          }
-          trap cleanup_notary EXIT
-
-          notarize_binary() {
-            local binary="$1"
-            local source_path="target/${{ matrix.target }}/release/${binary}"
-            local archive_path="${RUNNER_TEMP}/${binary}.zip"
-
-            if [[ ! -f "$source_path" ]]; then
-              echo "Binary $source_path not found"
-              exit 1
-            fi
-
-            rm -f "$archive_path"
-            ditto -c -k --keepParent "$source_path" "$archive_path"
-
-            submission_json=$(xcrun notarytool submit "$archive_path" \
-              --key "$notary_key_path" \
-              --key-id "$APPLE_NOTARIZATION_KEY_ID" \
-              --issuer "$APPLE_NOTARIZATION_ISSUER_ID" \
-              --output-format json \
-              --wait)
-
-            status=$(printf '%s\n' "$submission_json" | jq -r '.status // "Unknown"')
-            submission_id=$(printf '%s\n' "$submission_json" | jq -r '.id // ""')
-
-            if [[ -z "$submission_id" ]]; then
-              echo "Failed to retrieve submission ID for $binary"
-              exit 1
-            fi
-
-            echo "::notice title=Notarization::$binary submission ${submission_id} completed with status ${status}"
-
-            if [[ "$status" != "Accepted" ]]; then
-              echo "Notarization failed for ${binary} (submission ${submission_id}, status ${status})"
-              exit 1
-            fi
-          }
-
-          notarize_binary "codex"
-          notarize_binary "codex-responses-api-proxy"
+        name: MacOS code signing
+        uses: ./.github/actions/macos-code-sign
+        with:
+          target: ${{ matrix.target }}
+          apple-certificate: ${{ secrets.APPLE_CERTIFICATE_P12 }}
+          apple-certificate-password: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }}
+          apple-notarization-key-p8: ${{ secrets.APPLE_NOTARIZATION_KEY_P8 }}
+          apple-notarization-key-id: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }}
+          apple-notarization-issuer-id: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }}

       - name: Stage artifacts
         shell: bash
@@ -380,29 +222,6 @@ jobs:
             zstd "${zstd_args[@]}" "$dest/$base"
           done

-      - name: Remove signing keychain
-        if: ${{ always() && matrix.runner == 'macos-15-xlarge' }}
-        shell: bash
-        env:
-          APPLE_CODESIGN_KEYCHAIN: ${{ env.APPLE_CODESIGN_KEYCHAIN }}
-        run: |
-          set -euo pipefail
-          if [[ -n "${APPLE_CODESIGN_KEYCHAIN:-}" ]]; then
-            keychain_args=()
-            while IFS= read -r keychain; do
-              [[ "$keychain" == "$APPLE_CODESIGN_KEYCHAIN" ]] && continue
-              [[ -n "$keychain" ]] && keychain_args+=("$keychain")
-            done < <(security list-keychains | sed 's/^[[:space:]]*//;s/[[:space:]]*$//;s/"//g')
-            if ((${#keychain_args[@]} > 0)); then
-              security list-keychains -s "${keychain_args[@]}"
-              security default-keychain -s "${keychain_args[0]}"
-            fi
-
-            if [[ -f "$APPLE_CODESIGN_KEYCHAIN" ]]; then
-              security delete-keychain "$APPLE_CODESIGN_KEYCHAIN"
-            fi
-          fi
-
       - uses: actions/upload-artifact@v6
         with:
           name: ${{ matrix.target }}
@@ -487,7 +306,7 @@ jobs:
           run_install: false

       - name: Setup Node.js for npm packaging
-        uses: actions/setup-node@v5
+        uses: actions/setup-node@v6
         with:
           node-version: 22

@@ -538,7 +357,7 @@ jobs:

     steps:
       - name: Setup Node.js
-        uses: actions/setup-node@v5
+        uses: actions/setup-node@v6
         with:
           node-version: 22
           registry-url: "https://registry.npmjs.org"
.github/workflows/sdk.yml (vendored, 2 changes)
@@ -19,7 +19,7 @@ jobs:
           run_install: false

       - name: Setup Node.js
-        uses: actions/setup-node@v5
+        uses: actions/setup-node@v6
         with:
           node-version: 22
           cache: pnpm
.github/workflows/shell-tool-mcp-ci.yml (vendored, 2 changes)
@@ -30,7 +30,7 @@ jobs:
           run_install: false

       - name: Setup Node.js
-        uses: actions/setup-node@v5
+        uses: actions/setup-node@v6
         with:
           node-version: ${{ env.NODE_VERSION }}
           cache: "pnpm"
.github/workflows/shell-tool-mcp.yml (vendored, 4 changes)
@@ -280,7 +280,7 @@ jobs:
           run_install: false

       - name: Setup Node.js
-        uses: actions/setup-node@v5
+        uses: actions/setup-node@v6
         with:
           node-version: ${{ env.NODE_VERSION }}

@@ -376,7 +376,7 @@ jobs:
           run_install: false

       - name: Setup Node.js
-        uses: actions/setup-node@v5
+        uses: actions/setup-node@v6
         with:
           node-version: ${{ env.NODE_VERSION }}
           registry-url: https://registry.npmjs.org
codex-rs/Cargo.lock (generated, 2 changes)
@@ -1701,7 +1701,6 @@ dependencies = [
  "anyhow",
  "arboard",
  "assert_matches",
- "async-stream",
  "base64",
  "chrono",
  "clap",
@@ -6910,6 +6909,7 @@ dependencies = [
  "futures-core",
  "pin-project-lite",
  "tokio",
+ "tokio-util",
 ]

 [[package]]
@@ -532,6 +532,7 @@ server_notification_definitions! {
     McpToolCallProgress => "item/mcpToolCall/progress" (v2::McpToolCallProgressNotification),
     McpServerOauthLoginCompleted => "mcpServer/oauthLogin/completed" (v2::McpServerOauthLoginCompletedNotification),
     AccountUpdated => "account/updated" (v2::AccountUpdatedNotification),
+    ModelPresetsUpdated => "model/presets/updated" (v2::ModelPresetsUpdatedNotification),
     AccountRateLimitsUpdated => "account/rateLimits/updated" (v2::AccountRateLimitsUpdatedNotification),
     ReasoningSummaryTextDelta => "item/reasoning/summaryTextDelta" (v2::ReasoningSummaryTextDeltaNotification),
     ReasoningSummaryPartAdded => "item/reasoning/summaryPartAdded" (v2::ReasoningSummaryPartAddedNotification),
||||
@@ -751,12 +751,7 @@ pub struct ReasoningEffortOption {
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ModelListResponse {
|
||||
pub data: Vec<Model>,
|
||||
/// Opaque cursor to pass to the next call to continue after the last item.
|
||||
/// If None, there are no more items to return.
|
||||
pub next_cursor: Option<String>,
|
||||
}
|
||||
pub struct ModelListResponse {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -1074,6 +1069,13 @@ pub struct AccountUpdatedNotification {
     pub auth_mode: Option<AuthMode>,
 }

+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
+#[serde(rename_all = "camelCase")]
+#[ts(export_to = "v2/")]
+pub struct ModelPresetsUpdatedNotification {
+    pub models: Vec<Model>,
+}
+
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
 #[serde(rename_all = "camelCase")]
 #[ts(export_to = "v2/")]
||||
@@ -64,7 +64,7 @@ Example (from OpenAI's official VSCode extension):
|
||||
- `turn/interrupt` — request cancellation of an in-flight turn by `(thread_id, turn_id)`; success is an empty `{}` response and the turn finishes with `status: "interrupted"`.
|
||||
- `review/start` — kick off Codex’s automated reviewer for a thread; responds like `turn/start` and emits `item/started`/`item/completed` notifications with `enteredReviewMode` and `exitedReviewMode` items, plus a final assistant `agentMessage` containing the review.
|
||||
- `command/exec` — run a single command under the server sandbox without starting a thread/turn (handy for utilities and validation).
|
||||
- `model/list` — list available models (with reasoning effort options).
|
||||
- `model/list` — request the available models; responds with `{}` and asynchronously emits `model/presets/updated` containing the catalog.
|
||||
- `skills/list` — list skills for one or more `cwd` values.
|
||||
- `mcpServer/oauth/login` — start an OAuth login for a configured MCP server; returns an `authorization_url` and later emits `mcpServer/oauthLogin/completed` once the browser flow finishes.
|
||||
- `mcpServers/list` — enumerate configured MCP servers with their tools, resources, resource templates, and auth status; supports cursor+limit pagination.
|
||||
|
||||
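Aside (illustrative, not part of the diff): a minimal Rust sketch of the new `model/list` contract described in the doc line above. The exact notification payload shape is an assumption drawn from `ModelPresetsUpdatedNotification { models }` and the test fixtures; the model id is taken from the tests below.

    use serde_json::json;

    fn main() {
        // Client -> server: the request no longer carries meaningful parameters.
        let request = json!({
            "jsonrpc": "2.0", "id": 1, "method": "model/list", "params": {}
        });

        // Server -> client: the response is now an empty object...
        let response = json!({ "jsonrpc": "2.0", "id": 1, "result": {} });

        // ...and the catalog arrives later as a push notification.
        let notification = json!({
            "jsonrpc": "2.0",
            "method": "model/presets/updated",
            "params": { "models": [ { "id": "gpt-5.1-codex-max" } ] }
        });

        for msg in [request, response, notification] {
            println!("{msg}");
        }
    }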
@@ -61,6 +61,7 @@ use codex_app_server_protocol::McpServerOauthLoginParams;
 use codex_app_server_protocol::McpServerOauthLoginResponse;
 use codex_app_server_protocol::ModelListParams;
 use codex_app_server_protocol::ModelListResponse;
+use codex_app_server_protocol::ModelPresetsUpdatedNotification;
 use codex_app_server_protocol::NewConversationParams;
 use codex_app_server_protocol::NewConversationResponse;
 use codex_app_server_protocol::RemoveConversationListenerParams;
@@ -231,6 +232,19 @@ pub(crate) enum ApiVersion {
     V2,
 }

+fn spawn_model_presets_notification(
+    outgoing: Arc<OutgoingMessageSender>,
+    conversation_manager: Arc<ConversationManager>,
+    config: Arc<Config>,
+) {
+    tokio::spawn(async move {
+        let models = supported_models(conversation_manager, &config).await;
+        let notification =
+            ServerNotification::ModelPresetsUpdated(ModelPresetsUpdatedNotification { models });
+        outgoing.send_server_notification(notification).await;
+    });
+}
+
 impl CodexMessageProcessor {
     async fn conversation_from_thread_id(
         &self,
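Aside (illustrative, not part of the diff): the helper added above is fire-and-forget — the caller returns immediately while the catalog lookup and the send happen on a detached tokio task. A standalone sketch of that pattern, with an mpsc channel standing in for `OutgoingMessageSender` (requires tokio's rt, macros, and sync features):

    use std::sync::Arc;
    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::unbounded_channel::<String>();
        // Arc mirrors the Arc<OutgoingMessageSender> shape in the real code.
        let outgoing = Arc::new(tx);

        // Handler path: spawn the lookup + send, then return without awaiting it.
        let outgoing_clone = Arc::clone(&outgoing);
        tokio::spawn(async move {
            // Stand-in for `supported_models(...)`, which may do slow I/O.
            let models = vec!["gpt-5.1-codex-max"];
            let _ = outgoing_clone.send(format!("model/presets/updated: {models:?}"));
        });

        // The caller is already free; the notification arrives asynchronously.
        println!("{}", rx.recv().await.expect("notification"));
    }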
@@ -281,6 +295,14 @@ impl CodexMessageProcessor {
         }
     }

+    pub(crate) fn spawn_model_presets_notification(&self) {
+        spawn_model_presets_notification(
+            self.outgoing.clone(),
+            self.conversation_manager.clone(),
+            self.config.clone(),
+        );
+    }
+
     async fn load_latest_config(&self) -> Result<Config, JSONRPCErrorError> {
         Config::load_with_cli_overrides(self.cli_overrides.clone(), ConfigOverrides::default())
             .await
@@ -573,6 +595,7 @@
                 self.outgoing
                     .send_server_notification(ServerNotification::AuthStatusChange(payload))
                     .await;
+                self.spawn_model_presets_notification();
             }
             Err(error) => {
                 self.outgoing.send_error(request_id, error).await;
@@ -603,6 +626,7 @@
                 self.outgoing
                     .send_server_notification(ServerNotification::AccountUpdated(payload_v2))
                     .await;
+                self.spawn_model_presets_notification();
             }
             Err(error) => {
                 self.outgoing.send_error(request_id, error).await;
@@ -659,6 +683,8 @@
         let outgoing_clone = self.outgoing.clone();
         let active_login = self.active_login.clone();
         let auth_manager = self.auth_manager.clone();
+        let conversation_manager = self.conversation_manager.clone();
+        let config = self.config.clone();
         let auth_url = server.auth_url.clone();
         tokio::spawn(async move {
             let (success, error_msg) = match tokio::time::timeout(
@@ -699,6 +725,11 @@
                         payload,
                     ))
                     .await;
+                spawn_model_presets_notification(
+                    outgoing_clone,
+                    conversation_manager,
+                    config,
+                );
             }

             // Clear the active login if it matches this attempt. It may have been replaced or cancelled.
@@ -749,6 +780,8 @@
         let outgoing_clone = self.outgoing.clone();
         let active_login = self.active_login.clone();
         let auth_manager = self.auth_manager.clone();
+        let conversation_manager = self.conversation_manager.clone();
+        let config = self.config.clone();
         let auth_url = server.auth_url.clone();
         tokio::spawn(async move {
             let (success, error_msg) = match tokio::time::timeout(
@@ -789,6 +822,11 @@
                         payload_v2,
                     ))
                     .await;
+                spawn_model_presets_notification(
+                    outgoing_clone,
+                    conversation_manager,
+                    config,
+                );
             }

             // Clear the active login if it matches this attempt. It may have been replaced or cancelled.
@@ -908,6 +946,7 @@
                 self.outgoing
                     .send_server_notification(ServerNotification::AuthStatusChange(payload))
                     .await;
+                self.spawn_model_presets_notification();
             }
             Err(error) => {
                 self.outgoing.send_error(request_id, error).await;
@@ -928,6 +967,7 @@
                 self.outgoing
                     .send_server_notification(ServerNotification::AccountUpdated(payload_v2))
                     .await;
+                self.spawn_model_presets_notification();
             }
             Err(error) => {
                 self.outgoing.send_error(request_id, error).await;
@@ -1893,59 +1933,11 @@
     }

     async fn list_models(&self, request_id: RequestId, params: ModelListParams) {
-        let ModelListParams { limit, cursor } = params;
-        let models = supported_models(self.conversation_manager.clone(), &self.config).await;
-        let total = models.len();
-
-        if total == 0 {
-            let response = ModelListResponse {
-                data: Vec::new(),
-                next_cursor: None,
-            };
-            self.outgoing.send_response(request_id, response).await;
-            return;
-        }
-
-        let effective_limit = limit.unwrap_or(total as u32).max(1) as usize;
-        let effective_limit = effective_limit.min(total);
-        let start = match cursor {
-            Some(cursor) => match cursor.parse::<usize>() {
-                Ok(idx) => idx,
-                Err(_) => {
-                    let error = JSONRPCErrorError {
-                        code: INVALID_REQUEST_ERROR_CODE,
-                        message: format!("invalid cursor: {cursor}"),
-                        data: None,
-                    };
-                    self.outgoing.send_error(request_id, error).await;
-                    return;
-                }
-            },
-            None => 0,
-        };
-
-        if start > total {
-            let error = JSONRPCErrorError {
-                code: INVALID_REQUEST_ERROR_CODE,
-                message: format!("cursor {start} exceeds total models {total}"),
-                data: None,
-            };
-            self.outgoing.send_error(request_id, error).await;
-            return;
-        }
-
-        let end = start.saturating_add(effective_limit).min(total);
-        let items = models[start..end].to_vec();
-        let next_cursor = if end < total {
-            Some(end.to_string())
-        } else {
-            None
-        };
-        let response = ModelListResponse {
-            data: items,
-            next_cursor,
-        };
+        let _ = params;
+        let response = ModelListResponse {};
         self.outgoing.send_response(request_id, response).await;
+
+        self.spawn_model_presets_notification();
     }

     async fn mcp_server_oauth_login(
@@ -128,6 +128,8 @@ impl MessageProcessor {
         self.outgoing.send_response(request_id, response).await;

         self.initialized = true;
+        self.codex_message_processor
+            .spawn_model_presets_notification();

         return;
     }
@@ -1,27 +1,26 @@
 use std::time::Duration;

 use anyhow::Result;
-use anyhow::anyhow;
 use app_test_support::McpProcess;
 use app_test_support::to_response;
 use app_test_support::write_models_cache;
-use codex_app_server_protocol::JSONRPCError;
+use codex_app_server_protocol::JSONRPCNotification;
 use codex_app_server_protocol::JSONRPCResponse;
 use codex_app_server_protocol::Model;
 use codex_app_server_protocol::ModelListParams;
 use codex_app_server_protocol::ModelListResponse;
 use codex_app_server_protocol::ReasoningEffortOption;
 use codex_app_server_protocol::RequestId;
+use codex_app_server_protocol::ServerNotification;
 use codex_protocol::openai_models::ReasoningEffort;
 use pretty_assertions::assert_eq;
 use tempfile::TempDir;
 use tokio::time::timeout;

 const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
-const INVALID_REQUEST_ERROR_CODE: i64 = -32600;

 #[tokio::test]
-async fn list_models_returns_all_models_with_large_limit() -> Result<()> {
+async fn list_models_returns_empty_response_and_notification() -> Result<()> {
     let codex_home = TempDir::new()?;
     write_models_cache(codex_home.path())?;
     let mut mcp = McpProcess::new(codex_home.path()).await?;
@@ -30,8 +29,8 @@ async fn list_models_returns_all_models_with_large_limit() -> Result<()> {

     let request_id = mcp
         .send_list_models_request(ModelListParams {
-            limit: Some(100),
-            cursor: None,
+            limit: Some(1),
+            cursor: Some("ignored".to_string()),
         })
         .await?;

@@ -41,12 +40,24 @@ async fn list_models_returns_all_models_with_large_limit() -> Result<()> {
     )
     .await??;

-    let ModelListResponse {
-        data: items,
-        next_cursor,
-    } = to_response::<ModelListResponse>(response)?;
+    let ModelListResponse {} = to_response::<ModelListResponse>(response)?;

-    let expected_models = vec![
+    let notification: JSONRPCNotification = timeout(
+        DEFAULT_TIMEOUT,
+        mcp.read_stream_until_notification_message("model/presets/updated"),
+    )
+    .await??;
+    let server_notification: ServerNotification = notification.try_into()?;
+    let ServerNotification::ModelPresetsUpdated(payload) = server_notification else {
+        unreachable!("expected model/presets/updated notification");
+    };
+
+    assert_eq!(payload.models, expected_models());
+    Ok(())
+}
+
+fn expected_models() -> Vec<Model> {
+    vec![
         Model {
             id: "gpt-5.1-codex-max".to_string(),
             model: "gpt-5.1-codex-max".to_string(),
@@ -176,156 +187,5 @@
             default_reasoning_effort: ReasoningEffort::Medium,
             is_default: false,
         },
-    ];
-
-    assert_eq!(items, expected_models);
-    assert!(next_cursor.is_none());
-    Ok(())
-}
-
-#[tokio::test]
-async fn list_models_pagination_works() -> Result<()> {
-    let codex_home = TempDir::new()?;
-    write_models_cache(codex_home.path())?;
-    let mut mcp = McpProcess::new(codex_home.path()).await?;
-
-    timeout(DEFAULT_TIMEOUT, mcp.initialize()).await??;
-
-    let first_request = mcp
-        .send_list_models_request(ModelListParams {
-            limit: Some(1),
-            cursor: None,
-        })
-        .await?;
-
-    let first_response: JSONRPCResponse = timeout(
-        DEFAULT_TIMEOUT,
-        mcp.read_stream_until_response_message(RequestId::Integer(first_request)),
-    )
-    .await??;
-
-    let ModelListResponse {
-        data: first_items,
-        next_cursor: first_cursor,
-    } = to_response::<ModelListResponse>(first_response)?;
-
-    assert_eq!(first_items.len(), 1);
-    assert_eq!(first_items[0].id, "gpt-5.1-codex-max");
-    let next_cursor = first_cursor.ok_or_else(|| anyhow!("cursor for second page"))?;
-
-    let second_request = mcp
-        .send_list_models_request(ModelListParams {
-            limit: Some(1),
-            cursor: Some(next_cursor.clone()),
-        })
-        .await?;
-
-    let second_response: JSONRPCResponse = timeout(
-        DEFAULT_TIMEOUT,
-        mcp.read_stream_until_response_message(RequestId::Integer(second_request)),
-    )
-    .await??;
-
-    let ModelListResponse {
-        data: second_items,
-        next_cursor: second_cursor,
-    } = to_response::<ModelListResponse>(second_response)?;
-
-    assert_eq!(second_items.len(), 1);
-    assert_eq!(second_items[0].id, "gpt-5.1-codex");
-    let third_cursor = second_cursor.ok_or_else(|| anyhow!("cursor for third page"))?;
-
-    let third_request = mcp
-        .send_list_models_request(ModelListParams {
-            limit: Some(1),
-            cursor: Some(third_cursor.clone()),
-        })
-        .await?;
-
-    let third_response: JSONRPCResponse = timeout(
-        DEFAULT_TIMEOUT,
-        mcp.read_stream_until_response_message(RequestId::Integer(third_request)),
-    )
-    .await??;
-
-    let ModelListResponse {
-        data: third_items,
-        next_cursor: third_cursor,
-    } = to_response::<ModelListResponse>(third_response)?;
-
-    assert_eq!(third_items.len(), 1);
-    assert_eq!(third_items[0].id, "gpt-5.1-codex-mini");
-    let fourth_cursor = third_cursor.ok_or_else(|| anyhow!("cursor for fourth page"))?;
-
-    let fourth_request = mcp
-        .send_list_models_request(ModelListParams {
-            limit: Some(1),
-            cursor: Some(fourth_cursor.clone()),
-        })
-        .await?;
-
-    let fourth_response: JSONRPCResponse = timeout(
-        DEFAULT_TIMEOUT,
-        mcp.read_stream_until_response_message(RequestId::Integer(fourth_request)),
-    )
-    .await??;
-
-    let ModelListResponse {
-        data: fourth_items,
-        next_cursor: fourth_cursor,
-    } = to_response::<ModelListResponse>(fourth_response)?;
-
-    assert_eq!(fourth_items.len(), 1);
-    assert_eq!(fourth_items[0].id, "gpt-5.2");
-    let fifth_cursor = fourth_cursor.ok_or_else(|| anyhow!("cursor for fifth page"))?;
-
-    let fifth_request = mcp
-        .send_list_models_request(ModelListParams {
-            limit: Some(1),
-            cursor: Some(fifth_cursor.clone()),
-        })
-        .await?;
-
-    let fifth_response: JSONRPCResponse = timeout(
-        DEFAULT_TIMEOUT,
-        mcp.read_stream_until_response_message(RequestId::Integer(fifth_request)),
-    )
-    .await??;
-
-    let ModelListResponse {
-        data: fifth_items,
-        next_cursor: fifth_cursor,
-    } = to_response::<ModelListResponse>(fifth_response)?;
-
-    assert_eq!(fifth_items.len(), 1);
-    assert_eq!(fifth_items[0].id, "gpt-5.1");
-    assert!(fifth_cursor.is_none());
-    Ok(())
-}
-
-#[tokio::test]
-async fn list_models_rejects_invalid_cursor() -> Result<()> {
-    let codex_home = TempDir::new()?;
-    write_models_cache(codex_home.path())?;
-    let mut mcp = McpProcess::new(codex_home.path()).await?;
-
-    timeout(DEFAULT_TIMEOUT, mcp.initialize()).await??;
-
-    let request_id = mcp
-        .send_list_models_request(ModelListParams {
-            limit: None,
-            cursor: Some("invalid".to_string()),
-        })
-        .await?;
-
-    let error: JSONRPCError = timeout(
-        DEFAULT_TIMEOUT,
-        mcp.read_stream_until_error_message(RequestId::Integer(request_id)),
-    )
-    .await??;
-
-    assert_eq!(error.id, RequestId::Integer(request_id));
-    assert_eq!(error.error.code, INVALID_REQUEST_ERROR_CODE);
-    assert_eq!(error.error.message, "invalid cursor: invalid");
-    Ok(())
+    ]
 }
codex-rs/apply-patch/src/invocation.rs (new file, 369 lines)
@@ -0,0 +1,369 @@
use std::collections::HashMap;
use std::path::Path;
use std::sync::LazyLock;

use tree_sitter::Parser;
use tree_sitter::Query;
use tree_sitter::QueryCursor;
use tree_sitter::StreamingIterator;
use tree_sitter_bash::LANGUAGE as BASH;

use crate::ApplyPatchAction;
use crate::ApplyPatchArgs;
use crate::ApplyPatchError;
use crate::ApplyPatchFileChange;
use crate::ApplyPatchFileUpdate;
use crate::IoError;
use crate::MaybeApplyPatchVerified;
use crate::parser::Hunk;
use crate::parser::ParseError;
use crate::parser::parse_patch;
use crate::unified_diff_from_chunks;
use std::str::Utf8Error;
use tree_sitter::LanguageError;

const APPLY_PATCH_COMMANDS: [&str; 2] = ["apply_patch", "applypatch"];

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ApplyPatchShell {
    Unix,
    PowerShell,
    Cmd,
}

#[derive(Debug, PartialEq)]
pub enum MaybeApplyPatch {
    Body(ApplyPatchArgs),
    ShellParseError(ExtractHeredocError),
    PatchParseError(ParseError),
    NotApplyPatch,
}

#[derive(Debug, PartialEq)]
pub enum ExtractHeredocError {
    CommandDidNotStartWithApplyPatch,
    FailedToLoadBashGrammar(LanguageError),
    HeredocNotUtf8(Utf8Error),
    FailedToParsePatchIntoAst,
    FailedToFindHeredocBody,
}

fn classify_shell_name(shell: &str) -> Option<String> {
    std::path::Path::new(shell)
        .file_stem()
        .and_then(|name| name.to_str())
        .map(str::to_ascii_lowercase)
}

fn classify_shell(shell: &str, flag: &str) -> Option<ApplyPatchShell> {
    classify_shell_name(shell).and_then(|name| match name.as_str() {
        "bash" | "zsh" | "sh" if matches!(flag, "-lc" | "-c") => Some(ApplyPatchShell::Unix),
        "pwsh" | "powershell" if flag.eq_ignore_ascii_case("-command") => {
            Some(ApplyPatchShell::PowerShell)
        }
        "cmd" if flag.eq_ignore_ascii_case("/c") => Some(ApplyPatchShell::Cmd),
        _ => None,
    })
}

fn can_skip_flag(shell: &str, flag: &str) -> bool {
    classify_shell_name(shell).is_some_and(|name| {
        matches!(name.as_str(), "pwsh" | "powershell") && flag.eq_ignore_ascii_case("-noprofile")
    })
}

fn parse_shell_script(argv: &[String]) -> Option<(ApplyPatchShell, &str)> {
    match argv {
        [shell, flag, script] => classify_shell(shell, flag).map(|shell_type| {
            let script = script.as_str();
            (shell_type, script)
        }),
        [shell, skip_flag, flag, script] if can_skip_flag(shell, skip_flag) => {
            classify_shell(shell, flag).map(|shell_type| {
                let script = script.as_str();
                (shell_type, script)
            })
        }
        _ => None,
    }
}

fn extract_apply_patch_from_shell(
    shell: ApplyPatchShell,
    script: &str,
) -> std::result::Result<(String, Option<String>), ExtractHeredocError> {
    match shell {
        ApplyPatchShell::Unix | ApplyPatchShell::PowerShell | ApplyPatchShell::Cmd => {
            extract_apply_patch_from_bash(script)
        }
    }
}

// TODO: make private once we remove tests in lib.rs
pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch {
    match argv {
        // Direct invocation: apply_patch <patch>
        [cmd, body] if APPLY_PATCH_COMMANDS.contains(&cmd.as_str()) => match parse_patch(body) {
            Ok(source) => MaybeApplyPatch::Body(source),
            Err(e) => MaybeApplyPatch::PatchParseError(e),
        },
        // Shell heredoc form: (optional `cd <path> &&`) apply_patch <<'EOF' ...
        _ => match parse_shell_script(argv) {
            Some((shell, script)) => match extract_apply_patch_from_shell(shell, script) {
                Ok((body, workdir)) => match parse_patch(&body) {
                    Ok(mut source) => {
                        source.workdir = workdir;
                        MaybeApplyPatch::Body(source)
                    }
                    Err(e) => MaybeApplyPatch::PatchParseError(e),
                },
                Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch) => {
                    MaybeApplyPatch::NotApplyPatch
                }
                Err(e) => MaybeApplyPatch::ShellParseError(e),
            },
            None => MaybeApplyPatch::NotApplyPatch,
        },
    }
}

/// cwd must be an absolute path so that we can resolve relative paths in the
/// patch.
pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApplyPatchVerified {
    // Detect a raw patch body passed directly as the command or as the body of a shell
    // script. In these cases, report an explicit error rather than applying the patch.
    if let [body] = argv
        && parse_patch(body).is_ok()
    {
        return MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation);
    }
    if let Some((_, script)) = parse_shell_script(argv)
        && parse_patch(script).is_ok()
    {
        return MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation);
    }

    match maybe_parse_apply_patch(argv) {
        MaybeApplyPatch::Body(ApplyPatchArgs {
            patch,
            hunks,
            workdir,
        }) => {
            let effective_cwd = workdir
                .as_ref()
                .map(|dir| {
                    let path = Path::new(dir);
                    if path.is_absolute() {
                        path.to_path_buf()
                    } else {
                        cwd.join(path)
                    }
                })
                .unwrap_or_else(|| cwd.to_path_buf());
            let mut changes = HashMap::new();
            for hunk in hunks {
                let path = hunk.resolve_path(&effective_cwd);
                match hunk {
                    Hunk::AddFile { contents, .. } => {
                        changes.insert(path, ApplyPatchFileChange::Add { content: contents });
                    }
                    Hunk::DeleteFile { .. } => {
                        let content = match std::fs::read_to_string(&path) {
                            Ok(content) => content,
                            Err(e) => {
                                return MaybeApplyPatchVerified::CorrectnessError(
                                    ApplyPatchError::IoError(IoError {
                                        context: format!("Failed to read {}", path.display()),
                                        source: e,
                                    }),
                                );
                            }
                        };
                        changes.insert(path, ApplyPatchFileChange::Delete { content });
                    }
                    Hunk::UpdateFile {
                        move_path, chunks, ..
                    } => {
                        let ApplyPatchFileUpdate {
                            unified_diff,
                            content: contents,
                        } = match unified_diff_from_chunks(&path, &chunks) {
                            Ok(diff) => diff,
                            Err(e) => {
                                return MaybeApplyPatchVerified::CorrectnessError(e);
                            }
                        };
                        changes.insert(
                            path,
                            ApplyPatchFileChange::Update {
                                unified_diff,
                                move_path: move_path.map(|p| effective_cwd.join(p)),
                                new_content: contents,
                            },
                        );
                    }
                }
            }
            MaybeApplyPatchVerified::Body(ApplyPatchAction {
                changes,
                patch,
                cwd: effective_cwd,
            })
        }
        MaybeApplyPatch::ShellParseError(e) => MaybeApplyPatchVerified::ShellParseError(e),
        MaybeApplyPatch::PatchParseError(e) => MaybeApplyPatchVerified::CorrectnessError(e.into()),
        MaybeApplyPatch::NotApplyPatch => MaybeApplyPatchVerified::NotApplyPatch,
    }
}

/// Extract the heredoc body (and optional `cd` workdir) from a `bash -lc` script
/// that invokes the apply_patch tool using a heredoc.
///
/// Supported top‑level forms (must be the only top‑level statement):
/// - `apply_patch <<'EOF'\n...\nEOF`
/// - `cd <path> && apply_patch <<'EOF'\n...\nEOF`
///
/// Notes about matching:
/// - Parsed with Tree‑sitter Bash and a strict query that uses anchors so the
///   heredoc‑redirected statement is the only top‑level statement.
/// - The connector between `cd` and `apply_patch` must be `&&` (not `|` or `||`).
/// - Exactly one positional `word` argument is allowed for `cd` (no flags, no quoted
///   strings, no second argument).
/// - The apply command is validated in‑query via `#any-of?` to allow `apply_patch`
///   or `applypatch`.
/// - Preceding or trailing commands (e.g., `echo ...;` or `... && echo done`) do not match.
///
/// Returns `(heredoc_body, Some(path))` when the `cd` variant matches, or
/// `(heredoc_body, None)` for the direct form. Errors are returned if the script
/// cannot be parsed or does not match the allowed patterns.
fn extract_apply_patch_from_bash(
    src: &str,
) -> std::result::Result<(String, Option<String>), ExtractHeredocError> {
    // This function uses a Tree-sitter query to recognize one of two
    // whole-script forms, each expressed as a single top-level statement:
    //
    //   1. apply_patch <<'EOF'\n...\nEOF
    //   2. cd <path> && apply_patch <<'EOF'\n...\nEOF
    //
    // Key ideas when reading the query:
    // - dots (`.`) between named nodes enforce adjacency among named children and
    //   anchor to the start/end of the expression.
    // - we match a single redirected_statement directly under program with leading
    //   and trailing anchors (`.`). This ensures it is the only top-level statement
    //   (so prefixes like `echo ...;` or suffixes like `... && echo done` do not match).
    //
    // Overall, we want to be conservative and only match the intended forms, as other
    // forms are likely to be model errors, or incorrectly interpreted by later code.
    //
    // If you're editing this query, it's helpful to start by creating a debugging binary
    // which will let you see the AST of an arbitrary bash script passed in, and optionally
    // also run an arbitrary query against the AST. This is useful for understanding
    // how tree-sitter parses the script and whether the query syntax is correct. Be sure
    // to test both positive and negative cases.
    static APPLY_PATCH_QUERY: LazyLock<Query> = LazyLock::new(|| {
        let language = BASH.into();
        #[expect(clippy::expect_used)]
        Query::new(
            &language,
            r#"
            (
              program
              . (redirected_statement
                  body: (command
                          name: (command_name (word) @apply_name) .)
                  (#any-of? @apply_name "apply_patch" "applypatch")
                  redirect: (heredoc_redirect
                              . (heredoc_start)
                              . (heredoc_body) @heredoc
                              . (heredoc_end)
                              .))
              .)

            (
              program
              . (redirected_statement
                  body: (list
                          . (command
                              name: (command_name (word) @cd_name) .
                              argument: [
                                (word) @cd_path
                                (string (string_content) @cd_path)
                                (raw_string) @cd_raw_string
                              ] .)
                          "&&"
                          . (command
                              name: (command_name (word) @apply_name))
                          .)
                  (#eq? @cd_name "cd")
                  (#any-of? @apply_name "apply_patch" "applypatch")
                  redirect: (heredoc_redirect
                              . (heredoc_start)
                              . (heredoc_body) @heredoc
                              . (heredoc_end)
                              .))
              .)
            "#,
        )
        .expect("valid bash query")
    });

    let lang = BASH.into();
    let mut parser = Parser::new();
    parser
        .set_language(&lang)
        .map_err(ExtractHeredocError::FailedToLoadBashGrammar)?;
    let tree = parser
        .parse(src, None)
        .ok_or(ExtractHeredocError::FailedToParsePatchIntoAst)?;

    let bytes = src.as_bytes();
    let root = tree.root_node();

    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&APPLY_PATCH_QUERY, root, bytes);
    while let Some(m) = matches.next() {
        let mut heredoc_text: Option<String> = None;
        let mut cd_path: Option<String> = None;

        for capture in m.captures.iter() {
            let name = APPLY_PATCH_QUERY.capture_names()[capture.index as usize];
            match name {
                "heredoc" => {
                    let text = capture
                        .node
                        .utf8_text(bytes)
                        .map_err(ExtractHeredocError::HeredocNotUtf8)?
                        .trim_end_matches('\n')
                        .to_string();
                    heredoc_text = Some(text);
                }
                "cd_path" => {
                    let text = capture
                        .node
                        .utf8_text(bytes)
                        .map_err(ExtractHeredocError::HeredocNotUtf8)?
                        .to_string();
                    cd_path = Some(text);
                }
                "cd_raw_string" => {
                    let raw = capture
                        .node
                        .utf8_text(bytes)
                        .map_err(ExtractHeredocError::HeredocNotUtf8)?;
                    let trimmed = raw
                        .strip_prefix('\'')
                        .and_then(|s| s.strip_suffix('\''))
                        .unwrap_or(raw);
                    cd_path = Some(trimmed.to_string());
                }
                _ => {}
            }
        }

        if let Some(heredoc) = heredoc_text {
            return Ok((heredoc, cd_path));
        }
    }

    Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch)
}
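Similarly illustrative: the `effective_cwd` logic in `maybe_parse_apply_patch_verified` joins a relative `cd` workdir onto the caller's absolute cwd and leaves absolute workdirs alone. A standalone sketch of that rule (the helper is a re-implementation for demonstration; the real logic is inline above):

    use std::path::{Path, PathBuf};

    fn effective_cwd(cwd: &Path, workdir: Option<&str>) -> PathBuf {
        workdir
            .map(|dir| {
                let path = Path::new(dir);
                if path.is_absolute() {
                    path.to_path_buf()
                } else {
                    cwd.join(path)
                }
            })
            .unwrap_or_else(|| cwd.to_path_buf())
    }

    fn main() {
        let cwd = Path::new("/repo");
        // No workdir: fall back to the caller's cwd.
        assert_eq!(effective_cwd(cwd, None), PathBuf::from("/repo"));
        // Relative workdir: resolved against cwd.
        assert_eq!(effective_cwd(cwd, Some("sub/dir")), PathBuf::from("/repo/sub/dir"));
        // Absolute workdir: used as-is.
        assert_eq!(effective_cwd(cwd, Some("/abs")), PathBuf::from("/abs"));
        println!("workdir resolution holds");
    }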
@@ -1,3 +1,4 @@
|
||||
mod invocation;
|
||||
mod parser;
|
||||
mod seek_sequence;
|
||||
mod standalone_executable;
|
||||
@@ -5,8 +6,6 @@ mod standalone_executable;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::str::Utf8Error;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
@@ -17,27 +16,15 @@ use parser::UpdateFileChunk;
|
||||
pub use parser::parse_patch;
|
||||
use similar::TextDiff;
|
||||
use thiserror::Error;
|
||||
use tree_sitter::LanguageError;
|
||||
use tree_sitter::Parser;
|
||||
use tree_sitter::Query;
|
||||
use tree_sitter::QueryCursor;
|
||||
use tree_sitter::StreamingIterator;
|
||||
use tree_sitter_bash::LANGUAGE as BASH;
|
||||
|
||||
pub use invocation::maybe_parse_apply_patch_verified;
|
||||
pub use standalone_executable::main;
|
||||
|
||||
use crate::invocation::ExtractHeredocError;
|
||||
|
||||
/// Detailed instructions for gpt-4.1 on how to use the `apply_patch` tool.
|
||||
pub const APPLY_PATCH_TOOL_INSTRUCTIONS: &str = include_str!("../apply_patch_tool_instructions.md");
|
||||
|
||||
const APPLY_PATCH_COMMANDS: [&str; 2] = ["apply_patch", "applypatch"];
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ApplyPatchShell {
|
||||
Unix,
|
||||
PowerShell,
|
||||
Cmd,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, PartialEq)]
|
||||
pub enum ApplyPatchError {
|
||||
#[error(transparent)]
|
||||
@@ -86,14 +73,6 @@ impl PartialEq for IoError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum MaybeApplyPatch {
|
||||
Body(ApplyPatchArgs),
|
||||
ShellParseError(ExtractHeredocError),
|
||||
PatchParseError(ParseError),
|
||||
NotApplyPatch,
|
||||
}
|
||||
|
||||
/// Both the raw PATCH argument to `apply_patch` as well as the PATCH argument
|
||||
/// parsed into hunks.
|
||||
#[derive(Debug, PartialEq)]
|
||||
@@ -103,84 +82,6 @@ pub struct ApplyPatchArgs {
|
||||
pub workdir: Option<String>,
|
||||
}
|
||||
|
||||
fn classify_shell_name(shell: &str) -> Option<String> {
|
||||
std::path::Path::new(shell)
|
||||
.file_stem()
|
||||
.and_then(|name| name.to_str())
|
||||
.map(str::to_ascii_lowercase)
|
||||
}
|
||||
|
||||
fn classify_shell(shell: &str, flag: &str) -> Option<ApplyPatchShell> {
|
||||
classify_shell_name(shell).and_then(|name| match name.as_str() {
|
||||
"bash" | "zsh" | "sh" if matches!(flag, "-lc" | "-c") => Some(ApplyPatchShell::Unix),
|
||||
"pwsh" | "powershell" if flag.eq_ignore_ascii_case("-command") => {
|
||||
Some(ApplyPatchShell::PowerShell)
|
||||
}
|
||||
"cmd" if flag.eq_ignore_ascii_case("/c") => Some(ApplyPatchShell::Cmd),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn can_skip_flag(shell: &str, flag: &str) -> bool {
|
||||
classify_shell_name(shell).is_some_and(|name| {
|
||||
matches!(name.as_str(), "pwsh" | "powershell") && flag.eq_ignore_ascii_case("-noprofile")
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_shell_script(argv: &[String]) -> Option<(ApplyPatchShell, &str)> {
|
||||
match argv {
|
||||
[shell, flag, script] => classify_shell(shell, flag).map(|shell_type| {
|
||||
let script = script.as_str();
|
||||
(shell_type, script)
|
||||
}),
|
||||
[shell, skip_flag, flag, script] if can_skip_flag(shell, skip_flag) => {
|
||||
classify_shell(shell, flag).map(|shell_type| {
|
||||
let script = script.as_str();
|
||||
(shell_type, script)
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_apply_patch_from_shell(
|
||||
shell: ApplyPatchShell,
|
||||
script: &str,
|
||||
) -> std::result::Result<(String, Option<String>), ExtractHeredocError> {
|
||||
match shell {
|
||||
ApplyPatchShell::Unix | ApplyPatchShell::PowerShell | ApplyPatchShell::Cmd => {
|
||||
extract_apply_patch_from_bash(script)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch {
|
||||
match argv {
|
||||
// Direct invocation: apply_patch <patch>
|
||||
[cmd, body] if APPLY_PATCH_COMMANDS.contains(&cmd.as_str()) => match parse_patch(body) {
|
||||
Ok(source) => MaybeApplyPatch::Body(source),
|
||||
Err(e) => MaybeApplyPatch::PatchParseError(e),
|
||||
},
|
||||
// Shell heredoc form: (optional `cd <path> &&`) apply_patch <<'EOF' ...
|
||||
_ => match parse_shell_script(argv) {
|
||||
Some((shell, script)) => match extract_apply_patch_from_shell(shell, script) {
|
||||
Ok((body, workdir)) => match parse_patch(&body) {
|
||||
Ok(mut source) => {
|
||||
source.workdir = workdir;
|
||||
MaybeApplyPatch::Body(source)
|
||||
}
|
||||
Err(e) => MaybeApplyPatch::PatchParseError(e),
|
||||
},
|
||||
Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch) => {
|
||||
MaybeApplyPatch::NotApplyPatch
|
||||
}
|
||||
Err(e) => MaybeApplyPatch::ShellParseError(e),
|
||||
},
|
||||
None => MaybeApplyPatch::NotApplyPatch,
|
||||
},
|
||||
}
|
||||
}

#[derive(Debug, PartialEq)]
pub enum ApplyPatchFileChange {
    Add {
@@ -269,256 +170,6 @@ impl ApplyPatchAction {
    }
}

/// cwd must be an absolute path so that we can resolve relative paths in the
/// patch.
pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApplyPatchVerified {
    // Detect a raw patch body passed directly as the command or as the body of a shell
    // script. In these cases, report an explicit error rather than applying the patch.
    if let [body] = argv
        && parse_patch(body).is_ok()
    {
        return MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation);
    }
    if let Some((_, script)) = parse_shell_script(argv)
        && parse_patch(script).is_ok()
    {
        return MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation);
    }

    match maybe_parse_apply_patch(argv) {
        MaybeApplyPatch::Body(ApplyPatchArgs {
            patch,
            hunks,
            workdir,
        }) => {
            let effective_cwd = workdir
                .as_ref()
                .map(|dir| {
                    let path = Path::new(dir);
                    if path.is_absolute() {
                        path.to_path_buf()
                    } else {
                        cwd.join(path)
                    }
                })
                .unwrap_or_else(|| cwd.to_path_buf());
            let mut changes = HashMap::new();
            for hunk in hunks {
                let path = hunk.resolve_path(&effective_cwd);
                match hunk {
                    Hunk::AddFile { contents, .. } => {
                        changes.insert(path, ApplyPatchFileChange::Add { content: contents });
                    }
                    Hunk::DeleteFile { .. } => {
                        let content = match std::fs::read_to_string(&path) {
                            Ok(content) => content,
                            Err(e) => {
                                return MaybeApplyPatchVerified::CorrectnessError(
                                    ApplyPatchError::IoError(IoError {
                                        context: format!("Failed to read {}", path.display()),
                                        source: e,
                                    }),
                                );
                            }
                        };
                        changes.insert(path, ApplyPatchFileChange::Delete { content });
                    }
                    Hunk::UpdateFile {
                        move_path, chunks, ..
                    } => {
                        let ApplyPatchFileUpdate {
                            unified_diff,
                            content: contents,
                        } = match unified_diff_from_chunks(&path, &chunks) {
                            Ok(diff) => diff,
                            Err(e) => {
                                return MaybeApplyPatchVerified::CorrectnessError(e);
                            }
                        };
                        changes.insert(
                            path,
                            ApplyPatchFileChange::Update {
                                unified_diff,
                                move_path: move_path.map(|p| effective_cwd.join(p)),
                                new_content: contents,
                            },
                        );
                    }
                }
            }
            MaybeApplyPatchVerified::Body(ApplyPatchAction {
                changes,
                patch,
                cwd: effective_cwd,
            })
        }
        MaybeApplyPatch::ShellParseError(e) => MaybeApplyPatchVerified::ShellParseError(e),
        MaybeApplyPatch::PatchParseError(e) => MaybeApplyPatchVerified::CorrectnessError(e.into()),
        MaybeApplyPatch::NotApplyPatch => MaybeApplyPatchVerified::NotApplyPatch,
    }
}

/// Extract the heredoc body (and optional `cd` workdir) from a `bash -lc` script
/// that invokes the apply_patch tool using a heredoc.
///
/// Supported top‑level forms (must be the only top‑level statement):
/// - `apply_patch <<'EOF'\n...\nEOF`
/// - `cd <path> && apply_patch <<'EOF'\n...\nEOF`
///
/// Notes about matching:
/// - Parsed with Tree‑sitter Bash and a strict query that uses anchors so the
///   heredoc‑redirected statement is the only top‑level statement.
/// - The connector between `cd` and `apply_patch` must be `&&` (not `|` or `||`).
/// - Exactly one positional `word` argument is allowed for `cd` (no flags, no quoted
///   strings, no second argument).
/// - The apply command is validated in‑query via `#any-of?` to allow `apply_patch`
///   or `applypatch`.
/// - Preceding or trailing commands (e.g., `echo ...;` or `... && echo done`) do not match.
///
/// Returns `(heredoc_body, Some(path))` when the `cd` variant matches, or
/// `(heredoc_body, None)` for the direct form. Errors are returned if the script
/// cannot be parsed or does not match the allowed patterns.
fn extract_apply_patch_from_bash(
    src: &str,
) -> std::result::Result<(String, Option<String>), ExtractHeredocError> {
    // This function uses a Tree-sitter query to recognize one of two
    // whole-script forms, each expressed as a single top-level statement:
    //
    // 1. apply_patch <<'EOF'\n...\nEOF
    // 2. cd <path> && apply_patch <<'EOF'\n...\nEOF
    //
    // Key ideas when reading the query:
    // - dots (`.`) between named nodes enforce adjacency among named children and
    //   anchor to the start/end of the expression.
    // - we match a single redirected_statement directly under program with leading
    //   and trailing anchors (`.`). This ensures it is the only top-level statement
    //   (so prefixes like `echo ...;` or suffixes like `... && echo done` do not match).
    //
    // Overall, we want to be conservative and only match the intended forms, as other
    // forms are likely to be model errors, or incorrectly interpreted by later code.
    //
    // If you're editing this query, it's helpful to start by creating a debugging binary
    // which will let you see the AST of an arbitrary bash script passed in, and optionally
    // also run an arbitrary query against the AST. This is useful for understanding
    // how tree-sitter parses the script and whether the query syntax is correct. Be sure
    // to test both positive and negative cases.
    static APPLY_PATCH_QUERY: LazyLock<Query> = LazyLock::new(|| {
        let language = BASH.into();
        #[expect(clippy::expect_used)]
        Query::new(
            &language,
            r#"
            (
                program
                . (redirected_statement
                    body: (command
                        name: (command_name (word) @apply_name) .)
                    (#any-of? @apply_name "apply_patch" "applypatch")
                    redirect: (heredoc_redirect
                        . (heredoc_start)
                        . (heredoc_body) @heredoc
                        . (heredoc_end)
                        .))
                .)

            (
                program
                . (redirected_statement
                    body: (list
                        . (command
                            name: (command_name (word) @cd_name) .
                            argument: [
                                (word) @cd_path
                                (string (string_content) @cd_path)
                                (raw_string) @cd_raw_string
                            ] .)
                        "&&"
                        . (command
                            name: (command_name (word) @apply_name))
                        .)
                    (#eq? @cd_name "cd")
                    (#any-of? @apply_name "apply_patch" "applypatch")
                    redirect: (heredoc_redirect
                        . (heredoc_start)
                        . (heredoc_body) @heredoc
                        . (heredoc_end)
                        .))
                .)
            "#,
        )
        .expect("valid bash query")
    });

    let lang = BASH.into();
    let mut parser = Parser::new();
    parser
        .set_language(&lang)
        .map_err(ExtractHeredocError::FailedToLoadBashGrammar)?;
    let tree = parser
        .parse(src, None)
        .ok_or(ExtractHeredocError::FailedToParsePatchIntoAst)?;

    let bytes = src.as_bytes();
    let root = tree.root_node();

    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&APPLY_PATCH_QUERY, root, bytes);
    while let Some(m) = matches.next() {
        let mut heredoc_text: Option<String> = None;
        let mut cd_path: Option<String> = None;

        for capture in m.captures.iter() {
            let name = APPLY_PATCH_QUERY.capture_names()[capture.index as usize];
            match name {
                "heredoc" => {
                    let text = capture
                        .node
                        .utf8_text(bytes)
                        .map_err(ExtractHeredocError::HeredocNotUtf8)?
                        .trim_end_matches('\n')
                        .to_string();
                    heredoc_text = Some(text);
                }
                "cd_path" => {
                    let text = capture
                        .node
                        .utf8_text(bytes)
                        .map_err(ExtractHeredocError::HeredocNotUtf8)?
                        .to_string();
                    cd_path = Some(text);
                }
                "cd_raw_string" => {
                    let raw = capture
                        .node
                        .utf8_text(bytes)
                        .map_err(ExtractHeredocError::HeredocNotUtf8)?;
                    let trimmed = raw
                        .strip_prefix('\'')
                        .and_then(|s| s.strip_suffix('\''))
                        .unwrap_or(raw);
                    cd_path = Some(trimmed.to_string());
                }
                _ => {}
            }
        }

        if let Some(heredoc) = heredoc_text {
            return Ok((heredoc, cd_path));
        }
    }

    Err(ExtractHeredocError::CommandDidNotStartWithApplyPatch)
}
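
// A rough sketch, not from this diff, of the throwaway debugging binary the
// comment above recommends for query work: parse an arbitrary bash script and
// print its S-expression AST (the function name is illustrative only).
//
//     fn debug_bash_ast(src: &str) {
//         let mut parser = Parser::new();
//         parser.set_language(&BASH.into()).expect("load bash grammar");
//         let tree = parser.parse(src, None).expect("parse script");
//         // Compare this output against the patterns in APPLY_PATCH_QUERY.
//         println!("{}", tree.root_node().to_sexp());
//     }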

#[derive(Debug, PartialEq)]
pub enum ExtractHeredocError {
    CommandDidNotStartWithApplyPatch,
    FailedToLoadBashGrammar(LanguageError),
    HeredocNotUtf8(Utf8Error),
    FailedToParsePatchIntoAst,
    FailedToFindHeredocBody,
}

/// Applies the patch and prints the result to stdout/stderr.
pub fn apply_patch(
    patch: &str,
@@ -893,6 +544,9 @@ pub fn print_summary(

#[cfg(test)]
mod tests {
    use crate::invocation::MaybeApplyPatch;
    use crate::invocation::maybe_parse_apply_patch;

    use super::*;
    use assert_matches::assert_matches;
    use pretty_assertions::assert_eq;

@@ -58,7 +58,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
        self.stream(request.body, request.headers).await
    }

    #[instrument(skip_all, err)]
    #[instrument(level = "trace", skip_all, err)]
    pub async fn stream_prompt(
        &self,
        model: &str,

@@ -181,7 +181,7 @@ mod tests {
    use opentelemetry::trace::TracerProvider;
    use opentelemetry_sdk::propagation::TraceContextPropagator;
    use opentelemetry_sdk::trace::SdkTracerProvider;
    use tracing::info_span;
    use tracing::trace_span;
    use tracing_subscriber::layer::SubscriberExt;
    use tracing_subscriber::util::SubscriberInitExt;

@@ -195,7 +195,7 @@ mod tests {
        tracing_subscriber::registry().with(tracing_opentelemetry::layer().with_tracer(tracer));
        let _guard = subscriber.set_default();

        let span = info_span!("client_request");
        let span = trace_span!("client_request");
        let _entered = span.enter();
        let span_context = span.context().span().span_context().clone();

@@ -66,8 +66,8 @@ use tracing::debug;
use tracing::error;
use tracing::field;
use tracing::info;
use tracing::info_span;
use tracing::instrument;
use tracing::trace_span;
use tracing::warn;

use crate::ModelProviderInfo;
@@ -2150,6 +2150,16 @@ pub(crate) async fn run_task(
    if input.is_empty() {
        return None;
    }

    let auto_compact_limit = turn_context
        .client
        .get_model_family()
        .auto_compact_token_limit()
        .unwrap_or(i64::MAX);
    let total_usage_tokens = sess.get_total_token_usage().await;
    if total_usage_tokens >= auto_compact_limit {
        run_auto_compact(&sess, &turn_context).await;
    }
    let event = EventMsg::TaskStarted(TaskStartedEvent {
        model_context_window: turn_context.client.get_model_context_window(),
    });
@@ -2232,25 +2242,12 @@ pub(crate) async fn run_task(
        needs_follow_up,
        last_agent_message: turn_last_agent_message,
    } = turn_output;
    let limit = turn_context
        .client
        .get_model_family()
        .auto_compact_token_limit()
        .unwrap_or(i64::MAX);
    let total_usage_tokens = sess.get_total_token_usage().await;
    let token_limit_reached = total_usage_tokens >= limit;
    let token_limit_reached = total_usage_tokens >= auto_compact_limit;

    // as long as compaction works well in getting us way below the token limit, we shouldn't worry about being in an infinite loop.
    if token_limit_reached {
        if should_use_remote_compact_task(
            sess.as_ref(),
            &turn_context.client.get_provider(),
        ) {
            run_inline_remote_auto_compact_task(sess.clone(), turn_context.clone())
                .await;
        } else {
            run_inline_auto_compact_task(sess.clone(), turn_context.clone()).await;
        }
    if token_limit_reached && needs_follow_up {
        run_auto_compact(&sess, &turn_context).await;
        continue;
    }

@@ -2292,7 +2289,15 @@ pub(crate) async fn run_task(
    last_agent_message
}

#[instrument(
async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>) {
    if should_use_remote_compact_task(sess.as_ref(), &turn_context.client.get_provider()) {
        run_inline_remote_auto_compact_task(Arc::clone(sess), Arc::clone(turn_context)).await;
    } else {
        run_inline_auto_compact_task(Arc::clone(sess), Arc::clone(turn_context)).await;
    }
}

#[instrument(level = "trace",
    skip_all,
    fields(
        turn_id = %turn_context.sub_id,
@@ -2432,7 +2437,7 @@ async fn drain_in_flight(
}

#[allow(clippy::too_many_arguments)]
#[instrument(
#[instrument(level = "trace",
    skip_all,
    fields(
        turn_id = %turn_context.sub_id,
@@ -2461,7 +2466,7 @@ async fn try_run_turn(
        .client
        .clone()
        .stream(prompt)
        .instrument(info_span!("stream_request"))
        .instrument(trace_span!("stream_request"))
        .or_cancel(&cancellation_token)
        .await??;

@@ -2477,9 +2482,9 @@ async fn try_run_turn(
    let mut last_agent_message: Option<String> = None;
    let mut active_item: Option<TurnItem> = None;
    let mut should_emit_turn_diff = false;
    let receiving_span = info_span!("receiving_stream");
    let receiving_span = trace_span!("receiving_stream");
    let outcome: CodexResult<TurnRunResult> = loop {
        let handle_responses = info_span!(
        let handle_responses = trace_span!(
            parent: &receiving_span,
            "handle_responses",
            otel.name = field::Empty,
@@ -2489,7 +2494,7 @@ async fn try_run_turn(

        let event = match stream
            .next()
            .instrument(info_span!(parent: &handle_responses, "receiving"))
            .instrument(trace_span!(parent: &handle_responses, "receiving"))
            .or_cancel(&cancellation_token)
            .await
        {

@@ -398,7 +398,7 @@ impl McpConnectionManager {

    /// Returns a single map that contains all tools. Each key is the
    /// fully-qualified name for the tool.
    #[instrument(skip_all)]
    #[instrument(level = "trace", skip_all)]
    pub async fn list_all_tools(&self) -> HashMap<String, ToolInfo> {
        let mut tools = HashMap::new();
        for managed_client in self.clients.values() {

@@ -166,30 +166,34 @@ mod tests {
    use super::create_seatbelt_command_args;
    use super::macos_dir_params;
    use crate::protocol::SandboxPolicy;
    use crate::seatbelt::MACOS_PATH_TO_SEATBELT_EXECUTABLE;
    use pretty_assertions::assert_eq;
    use std::fs;
    use std::path::Path;
    use std::path::PathBuf;
    use std::process::Command;
    use tempfile::TempDir;

    #[test]
    fn create_seatbelt_args_with_read_only_git_subpath() {
    fn create_seatbelt_args_with_read_only_git_and_codex_subpaths() {
        // Create a temporary workspace with two writable roots: one containing
        // a top-level .git directory and one without it.
        // top-level .git and .codex directories and one without them.
        let tmp = TempDir::new().expect("tempdir");
        let PopulatedTmp {
            root_with_git,
            root_without_git,
            root_with_git_canon,
            root_with_git_git_canon,
            root_without_git_canon,
            vulnerable_root,
            vulnerable_root_canonical,
            dot_git_canonical,
            dot_codex_canonical,
            empty_root,
            empty_root_canonical,
        } = populate_tmpdir(tmp.path());
        let cwd = tmp.path().join("cwd");
        fs::create_dir_all(&cwd).expect("create cwd");

        // Build a policy that only includes the two test roots as writable and
        // does not automatically include defaults TMPDIR or /tmp.
        let policy = SandboxPolicy::WorkspaceWrite {
            writable_roots: vec![root_with_git, root_without_git]
            writable_roots: vec![vulnerable_root, empty_root]
                .into_iter()
                .map(|p| p.try_into().unwrap())
                .collect(),
@@ -198,23 +202,34 @@ mod tests {
            exclude_slash_tmp: true,
        };

        let args = create_seatbelt_command_args(
            vec!["/bin/echo".to_string(), "hello".to_string()],
            &policy,
            &cwd,
        );
        // Create the Seatbelt command to wrap a shell command that tries to
        // write to .codex/config.toml in the vulnerable root.
        let shell_command: Vec<String> = [
            "bash",
            "-c",
            "echo 'sandbox_mode = \"danger-full-access\"' > \"$1\"",
            "bash",
            dot_codex_canonical
                .join("config.toml")
                .to_string_lossy()
                .as_ref(),
        ]
        .iter()
        .map(std::string::ToString::to_string)
        .collect();
        let args = create_seatbelt_command_args(shell_command.clone(), &policy, &cwd);

        // Build the expected policy text using a raw string for readability.
        // Note that the policy includes:
        // - the base policy,
        // - read-only access to the filesystem,
        // - write access to WRITABLE_ROOT_0 (but not its .git) and WRITABLE_ROOT_1.
        // - write access to WRITABLE_ROOT_0 (but not its .git or .codex), WRITABLE_ROOT_1, and cwd as WRITABLE_ROOT_2.
        let expected_policy = format!(
            r#"{MACOS_SEATBELT_BASE_POLICY}
; allow read-only file operations
(allow file-read*)
(allow file-write*
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ) (subpath (param "WRITABLE_ROOT_1")) (subpath (param "WRITABLE_ROOT_2"))
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) (require-not (subpath (param "WRITABLE_ROOT_0_RO_1"))) ) (subpath (param "WRITABLE_ROOT_1")) (subpath (param "WRITABLE_ROOT_2"))
)
"#,
        );
@@ -224,17 +239,26 @@ mod tests {
            expected_policy,
            format!(
                "-DWRITABLE_ROOT_0={}",
                root_with_git_canon.to_string_lossy()
                vulnerable_root_canonical.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_0_RO_0={}",
                root_with_git_git_canon.to_string_lossy()
                dot_git_canonical.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_0_RO_1={}",
                dot_codex_canonical.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_1={}",
                root_without_git_canon.to_string_lossy()
                empty_root_canonical.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_2={}",
                cwd.canonicalize()
                    .expect("canonicalize cwd")
                    .to_string_lossy()
            ),
            format!("-DWRITABLE_ROOT_2={}", cwd.to_string_lossy()),
        ];

        expected_args.extend(
@@ -243,30 +267,119 @@ mod tests {
                .map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy())),
        );

        expected_args.extend(vec![
            "--".to_string(),
            "/bin/echo".to_string(),
            "hello".to_string(),
        ]);
        expected_args.push("--".to_string());
        expected_args.extend(shell_command);

        assert_eq!(expected_args, args);

        // Verify that .codex/config.toml cannot be modified under the generated
        // Seatbelt policy.
        let config_toml = dot_codex_canonical.join("config.toml");
        let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE)
            .args(&args)
            .current_dir(&cwd)
            .output()
            .expect("execute seatbelt command");
        assert_eq!(
            "sandbox_mode = \"read-only\"\n",
            String::from_utf8_lossy(&fs::read(&config_toml).expect("read config.toml")),
            "config.toml should contain its original contents because it should not have been modified"
        );
        assert!(
            !output.status.success(),
            "command to write {} should fail under seatbelt",
            &config_toml.display()
        );
        assert_eq!(
            String::from_utf8_lossy(&output.stderr),
            format!("bash: {}: Operation not permitted\n", config_toml.display()),
        );

        // Create a similar Seatbelt command that tries to write to a file in
        // the .git folder, which should also be blocked.
        let pre_commit_hook = dot_git_canonical.join("hooks").join("pre-commit");
        let shell_command_git: Vec<String> = [
            "bash",
            "-c",
            "echo 'pwned!' > \"$1\"",
            "bash",
            pre_commit_hook.to_string_lossy().as_ref(),
        ]
        .iter()
        .map(std::string::ToString::to_string)
        .collect();
        let write_hooks_file_args = create_seatbelt_command_args(shell_command_git, &policy, &cwd);
        let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE)
            .args(&write_hooks_file_args)
            .current_dir(&cwd)
            .output()
            .expect("execute seatbelt command");
        assert!(
            !fs::exists(&pre_commit_hook).expect("exists pre-commit hook"),
            "{} should not exist because it should not have been created",
            pre_commit_hook.display()
        );
        assert!(
            !output.status.success(),
            "command to write {} should fail under seatbelt",
            &pre_commit_hook.display()
        );
        assert_eq!(
            String::from_utf8_lossy(&output.stderr),
            format!(
                "bash: {}: Operation not permitted\n",
                pre_commit_hook.display()
            ),
        );

        // Verify that writing a file to the folder containing .git and .codex is allowed.
        let allowed_file = vulnerable_root_canonical.join("allowed.txt");
        let shell_command_allowed: Vec<String> = [
            "bash",
            "-c",
            "echo 'this is allowed' > \"$1\"",
            "bash",
            allowed_file.to_string_lossy().as_ref(),
        ]
        .iter()
        .map(std::string::ToString::to_string)
        .collect();
        let write_allowed_file_args =
            create_seatbelt_command_args(shell_command_allowed, &policy, &cwd);
        let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE)
            .args(&write_allowed_file_args)
            .current_dir(&cwd)
            .output()
            .expect("execute seatbelt command");
        assert!(
            output.status.success(),
            "command to write {} should succeed under seatbelt",
            &allowed_file.display()
        );
        assert_eq!(
            "this is allowed\n",
            String::from_utf8_lossy(&fs::read(&allowed_file).expect("read allowed.txt")),
            "{} should contain the written text",
            allowed_file.display()
        );
    }

    #[test]
    fn create_seatbelt_args_for_cwd_as_git_repo() {
        // Create a temporary workspace with two writable roots: one containing
        // a top-level .git directory and one without it.
        // top-level .git and .codex directories and one without them.
        let tmp = TempDir::new().expect("tempdir");
        let PopulatedTmp {
            root_with_git,
            root_with_git_canon,
            root_with_git_git_canon,
            vulnerable_root,
            vulnerable_root_canonical,
            dot_git_canonical,
            dot_codex_canonical,
            ..
        } = populate_tmpdir(tmp.path());

        // Build a policy that does not specify any writable_roots, but does
        // use the default ones (cwd and TMPDIR) and verifies the `.git` check
        // is done properly for cwd.
        // use the default ones (cwd and TMPDIR) and verifies the `.git` and
        // `.codex` checks are done properly for cwd.
        let policy = SandboxPolicy::WorkspaceWrite {
            writable_roots: vec![],
            network_access: false,
@@ -274,11 +387,21 @@ mod tests {
            exclude_slash_tmp: false,
        };

        let args = create_seatbelt_command_args(
            vec!["/bin/echo".to_string(), "hello".to_string()],
            &policy,
            root_with_git.as_path(),
        );
        let shell_command: Vec<String> = [
            "bash",
            "-c",
            "echo 'sandbox_mode = \"danger-full-access\"' > \"$1\"",
            "bash",
            dot_codex_canonical
                .join("config.toml")
                .to_string_lossy()
                .as_ref(),
        ]
        .iter()
        .map(std::string::ToString::to_string)
        .collect();
        let args =
            create_seatbelt_command_args(shell_command.clone(), &policy, vulnerable_root.as_path());

        let tmpdir_env_var = std::env::var("TMPDIR")
            .ok()
@@ -296,13 +419,13 @@ mod tests {
        // Note that the policy includes:
        // - the base policy,
        // - read-only access to the filesystem,
        // - write access to WRITABLE_ROOT_0 (but not its .git) and WRITABLE_ROOT_1.
        // - write access to WRITABLE_ROOT_0 (but not its .git or .codex), WRITABLE_ROOT_1, and cwd as WRITABLE_ROOT_2.
        let expected_policy = format!(
            r#"{MACOS_SEATBELT_BASE_POLICY}
; allow read-only file operations
(allow file-read*)
(allow file-write*
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ) (subpath (param "WRITABLE_ROOT_1")){tempdir_policy_entry}
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) (require-not (subpath (param "WRITABLE_ROOT_0_RO_1"))) ) (subpath (param "WRITABLE_ROOT_1")){tempdir_policy_entry}
)
"#,
        );
@@ -312,11 +435,15 @@ mod tests {
            expected_policy,
            format!(
                "-DWRITABLE_ROOT_0={}",
                root_with_git_canon.to_string_lossy()
                vulnerable_root_canonical.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_0_RO_0={}",
                root_with_git_git_canon.to_string_lossy()
                dot_git_canonical.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_0_RO_1={}",
                dot_codex_canonical.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_1={}",
@@ -337,42 +464,68 @@ mod tests {
                .map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy())),
        );

        expected_args.extend(vec![
            "--".to_string(),
            "/bin/echo".to_string(),
            "hello".to_string(),
        ]);
        expected_args.push("--".to_string());
        expected_args.extend(shell_command);

        assert_eq!(expected_args, args);
    }

    struct PopulatedTmp {
        root_with_git: PathBuf,
        root_without_git: PathBuf,
        root_with_git_canon: PathBuf,
        root_with_git_git_canon: PathBuf,
        root_without_git_canon: PathBuf,
        /// Path containing a .git and .codex subfolder.
        /// For the purposes of this test, we consider this a "vulnerable" root
        /// because a bad actor could write to .git/hooks/pre-commit so an
        /// unsuspecting user would run code as privileged the next time they
        /// ran `git commit` themselves, or modified .codex/config.toml to
        /// contain `sandbox_mode = "danger-full-access"` so the agent would
        /// have full privileges the next time it ran in that repo.
        vulnerable_root: PathBuf,
        vulnerable_root_canonical: PathBuf,
        dot_git_canonical: PathBuf,
        dot_codex_canonical: PathBuf,

        /// Path without .git or .codex subfolders.
        empty_root: PathBuf,
        /// Canonicalized version of `empty_root`.
        empty_root_canonical: PathBuf,
    }

    fn populate_tmpdir(tmp: &Path) -> PopulatedTmp {
        let root_with_git = tmp.join("with_git");
        let root_without_git = tmp.join("no_git");
        fs::create_dir_all(&root_with_git).expect("create with_git");
        fs::create_dir_all(&root_without_git).expect("create no_git");
        fs::create_dir_all(root_with_git.join(".git")).expect("create .git");
        let vulnerable_root = tmp.join("vulnerable_root");
        fs::create_dir_all(&vulnerable_root).expect("create vulnerable_root");

        // TODO(mbolin): Should also support the case where `.git` is a file
        // with a gitdir: ... line.
        Command::new("git")
            .arg("init")
            .arg(".")
            .current_dir(&vulnerable_root)
            .output()
            .expect("git init .");

        fs::create_dir_all(vulnerable_root.join(".codex")).expect("create .codex");
        fs::write(
            vulnerable_root.join(".codex").join("config.toml"),
            "sandbox_mode = \"read-only\"\n",
        )
        .expect("write .codex/config.toml");

        let empty_root = tmp.join("empty_root");
        fs::create_dir_all(&empty_root).expect("create empty_root");

        // Ensure we have canonical paths for -D parameter matching.
        let root_with_git_canon = root_with_git.canonicalize().expect("canonicalize with_git");
        let root_with_git_git_canon = root_with_git_canon.join(".git");
        let root_without_git_canon = root_without_git
        let vulnerable_root_canonical = vulnerable_root
            .canonicalize()
            .expect("canonicalize no_git");
            .expect("canonicalize vulnerable_root");
        let dot_git_canonical = vulnerable_root_canonical.join(".git");
        let dot_codex_canonical = vulnerable_root_canonical.join(".codex");
        let empty_root_canonical = empty_root.canonicalize().expect("canonicalize empty_root");
        PopulatedTmp {
            root_with_git,
            root_without_git,
            root_with_git_canon,
            root_with_git_git_canon,
            root_without_git_canon,
            vulnerable_root,
            vulnerable_root_canonical,
            dot_git_canonical,
            dot_codex_canonical,
            empty_root,
            empty_root_canonical,
        }
    }
}

@@ -39,7 +39,7 @@ pub(crate) struct HandleOutputCtx {
    pub cancellation_token: CancellationToken,
}

#[instrument(skip_all)]
#[instrument(level = "trace", skip_all)]
pub(crate) async fn handle_output_item_done(
    ctx: &mut HandleOutputCtx,
    item: ResponseItem,

@@ -159,6 +159,7 @@ impl Session {
        for task in self.take_all_running_tasks().await {
            self.handle_task_abort(task, reason.clone()).await;
        }
        self.close_unified_exec_sessions().await;
    }

    pub async fn on_task_finished(
@@ -167,12 +168,18 @@ impl Session {
        last_agent_message: Option<String>,
    ) {
        let mut active = self.active_turn.lock().await;
        if let Some(at) = active.as_mut()
        let should_close_sessions = if let Some(at) = active.as_mut()
            && at.remove_task(&turn_context.sub_id)
        {
            *active = None;
        }
            true
        } else {
            false
        };
        drop(active);
        if should_close_sessions {
            self.close_unified_exec_sessions().await;
        }
        let event = EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message });
        self.send_event(turn_context.as_ref(), event).await;
    }
@@ -196,6 +203,13 @@ impl Session {
        }
    }

    async fn close_unified_exec_sessions(&self) {
        self.services
            .unified_exec_manager
            .terminate_all_sessions()
            .await;
    }

    async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
        let sub_id = task.turn_context.sub_id.clone();
        if task.cancellation_token.is_cancelled() {

@@ -7,7 +7,7 @@ use async_trait::async_trait;
use codex_protocol::user_input::UserInput;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use tracing::info_span;
use tracing::trace_span;

use super::SessionTask;
use super::SessionTaskContext;
@@ -30,7 +30,7 @@ impl SessionTask for RegularTask {
    ) -> Option<String> {
        let sess = session.clone_session();
        let run_task_span =
            info_span!(parent: sess.services.otel_manager.current_span(), "run_task");
            trace_span!(parent: sess.services.otel_manager.current_span(), "run_task");
        run_task(sess, ctx, input, cancellation_token)
            .instrument(run_task_span)
            .await

@@ -16,7 +16,6 @@ use tokio_util::sync::CancellationToken;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::codex_delegate::run_codex_conversation_one_shot;
use crate::protocol::SandboxPolicy;
use crate::review_format::format_review_findings_block;
use crate::review_format::render_review_output_text;
use crate::state::TaskKind;
@@ -78,7 +77,6 @@ async fn start_review_conversation(
) -> Option<async_channel::Receiver<Event>> {
    let config = ctx.client.config();
    let mut sub_agent_config = config.as_ref().clone();
    sub_agent_config.sandbox_policy = SandboxPolicy::new_read_only_policy();
    // Run with only reviewer rubric — drop outer user_instructions
    sub_agent_config.user_instructions = None;
    // Avoid loading project docs; reviewer only needs findings

@@ -70,9 +70,11 @@ pub fn format_exec_output_for_model_freeform(
    // round to 1 decimal place
    let duration_seconds = ((exec_output.duration.as_secs_f32()) * 10.0).round() / 10.0;

    let total_lines = exec_output.aggregated_output.text.lines().count();
    let content = build_content_with_timeout(exec_output);

    let formatted_output = truncate_text(&exec_output.aggregated_output.text, truncation_policy);
    let total_lines = content.lines().count();

    let formatted_output = truncate_text(&content, truncation_policy);

    let mut sections = Vec::new();

@@ -92,21 +94,21 @@ pub fn format_exec_output_str(
    exec_output: &ExecToolCallOutput,
    truncation_policy: TruncationPolicy,
) -> String {
    let ExecToolCallOutput {
        aggregated_output, ..
    } = exec_output;

    let content = aggregated_output.text.as_str();

    let body = if exec_output.timed_out {
        format!(
            "command timed out after {} milliseconds\n{content}",
            exec_output.duration.as_millis()
        )
    } else {
        content.to_string()
    };
    let content = build_content_with_timeout(exec_output);

    // Truncate for model consumption before serialization.
    formatted_truncate_text(&body, truncation_policy)
    formatted_truncate_text(&content, truncation_policy)
}

/// Extracts exec output content and prepends a timeout message if the command timed out.
fn build_content_with_timeout(exec_output: &ExecToolCallOutput) -> String {
    if exec_output.timed_out {
        format!(
            "command timed out after {} milliseconds\n{}",
            exec_output.duration.as_millis(),
            exec_output.aggregated_output.text
        )
    } else {
        exec_output.aggregated_output.text.clone()
    }
}
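
// Worked example of the helper above (values assumed, not a test from this
// diff): for a timed-out command with a 2s duration and aggregated output
// "partial", build_content_with_timeout returns
// "command timed out after 2000 milliseconds\npartial"; for a command that
// exited normally it returns the aggregated output unchanged.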

@@ -6,8 +6,8 @@ use tokio_util::either::Either;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use tracing::Instrument;
use tracing::info_span;
use tracing::instrument;
use tracing::trace_span;

use crate::codex::Session;
use crate::codex::TurnContext;
@@ -45,7 +45,7 @@ impl ToolCallRuntime {
    }
}

    #[instrument(skip_all, fields(call = ?call))]
    #[instrument(level = "trace", skip_all, fields(call = ?call))]
    pub(crate) fn handle_tool_call(
        self,
        call: ToolCall,
@@ -60,7 +60,7 @@ impl ToolCallRuntime {
        let lock = Arc::clone(&self.parallel_execution);
        let started = Instant::now();

        let dispatch_span = info_span!(
        let dispatch_span = trace_span!(
            "dispatch_tool_call",
            otel.name = call.tool_name.as_str(),
            tool_name = call.tool_name.as_str(),

@@ -55,7 +55,7 @@ impl ToolRouter {
        .any(|config| config.spec.name() == tool_name)
    }

    #[instrument(skip_all, err)]
    #[instrument(level = "trace", skip_all, err)]
    pub async fn build_tool_call(
        session: &Session,
        item: ResponseItem,
@@ -131,7 +131,7 @@ impl ToolRouter {
        }
    }

    #[instrument(skip_all, err)]
    #[instrument(level = "trace", skip_all, err)]
    pub async fn dispatch_tool_call(
        &self,
        session: Arc<Session>,

@@ -13,6 +13,7 @@ use std::path::PathBuf;
#[cfg(target_os = "linux")]
use assert_cmd::cargo::cargo_bin;

pub mod process;
pub mod responses;
pub mod streaming_sse;
pub mod test_codex;

48
codex-rs/core/tests/common/process.rs
Normal file
@@ -0,0 +1,48 @@
use anyhow::Context;
use std::fs;
use std::path::Path;
use std::time::Duration;

pub async fn wait_for_pid_file(path: &Path) -> anyhow::Result<String> {
    let pid = tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if let Ok(contents) = fs::read_to_string(path) {
                let trimmed = contents.trim();
                if !trimmed.is_empty() {
                    return trimmed.to_string();
                }
            }
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
    })
    .await
    .context("timed out waiting for pid file")?;

    Ok(pid)
}

pub fn process_is_alive(pid: &str) -> anyhow::Result<bool> {
    let status = std::process::Command::new("kill")
        .args(["-0", pid])
        .status()
        .context("failed to probe process liveness with kill -0")?;
    Ok(status.success())
}

async fn wait_for_process_exit_inner(pid: String) -> anyhow::Result<()> {
    loop {
        if !process_is_alive(&pid)? {
            return Ok(());
        }
        tokio::time::sleep(Duration::from_millis(25)).await;
    }
}

pub async fn wait_for_process_exit(pid: &str) -> anyhow::Result<()> {
    let pid = pid.to_string();
    tokio::time::timeout(Duration::from_secs(2), wait_for_process_exit_inner(pid))
        .await
        .context("timed out waiting for process to exit")??;

    Ok(())
}
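
// A hedged usage sketch for these helpers (the spawned command and test body
// are assumed, not part of this diff):
//
//     let pid = wait_for_pid_file(&pid_path).await?;   // poll until the pid lands on disk
//     assert!(process_is_alive(&pid)?);                // probes liveness via `kill -0`
//     // ... end the turn / terminate the session ...
//     wait_for_process_exit(&pid).await?;              // poll until `kill -0` fails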

@@ -1009,7 +1009,6 @@ async fn auto_compact_runs_after_token_limit_hit() {
        ev_assistant_message("m3", AUTO_SUMMARY_TEXT),
        ev_completed_with_tokens("r3", 200),
    ]);
    let sse_resume = sse(vec![ev_completed("r3-resume")]);
    let sse4 = sse(vec![
        ev_assistant_message("m4", FINAL_REPLY),
        ev_completed_with_tokens("r4", 120),
@@ -1038,15 +1037,6 @@ async fn auto_compact_runs_after_token_limit_hit() {
    };
    mount_sse_once_match(&server, third_matcher, sse3).await;

    let resume_marker = prefixed_auto_summary;
    let resume_matcher = move |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains(resume_marker)
            && !body_contains_text(body, SUMMARIZATION_PROMPT)
            && !body.contains(POST_AUTO_USER_MSG)
    };
    mount_sse_once_match(&server, resume_matcher, sse_resume).await;

    let fourth_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains(POST_AUTO_USER_MSG) && !body_contains_text(body, SUMMARIZATION_PROMPT)
@@ -1106,8 +1096,8 @@ async fn auto_compact_runs_after_token_limit_hit() {
    let requests = get_responses_requests(&server).await;
    assert_eq!(
        requests.len(),
        5,
        "expected user turns, a compaction request, a resumed turn, and the follow-up turn; got {}",
        4,
        "expected user turns, a compaction request, and the follow-up turn; got {}",
        requests.len()
    );
    let is_auto_compact = |req: &wiremock::Request| {
@@ -1131,19 +1121,6 @@ async fn auto_compact_runs_after_token_limit_hit() {
        "auto compact should add a third request"
    );

    let resume_summary_marker = prefixed_auto_summary;
    let resume_index = requests
        .iter()
        .enumerate()
        .find_map(|(idx, req)| {
            let body = std::str::from_utf8(&req.body).unwrap_or("");
            (body.contains(resume_summary_marker)
                && !body_contains_text(body, SUMMARIZATION_PROMPT)
                && !body.contains(POST_AUTO_USER_MSG))
            .then_some(idx)
        })
        .expect("resume request missing after compaction");

    let follow_up_index = requests
        .iter()
        .enumerate()
@@ -1154,15 +1131,12 @@ async fn auto_compact_runs_after_token_limit_hit() {
            .then_some(idx)
        })
        .expect("follow-up request missing");
    assert_eq!(follow_up_index, 4, "follow-up request should be last");
    assert_eq!(follow_up_index, 3, "follow-up request should be last");

    let body_first = requests[0].body_json::<serde_json::Value>().unwrap();
    let body_auto = requests[auto_compact_index]
        .body_json::<serde_json::Value>()
        .unwrap();
    let body_resume = requests[resume_index]
        .body_json::<serde_json::Value>()
        .unwrap();
    let body_follow_up = requests[follow_up_index]
        .body_json::<serde_json::Value>()
        .unwrap();
@@ -1201,23 +1175,6 @@ async fn auto_compact_runs_after_token_limit_hit() {
        "auto compact should send the summarization prompt as a user message",
    );

    let input_resume = body_resume.get("input").and_then(|v| v.as_array()).unwrap();
    assert!(
        input_resume.iter().any(|item| {
            item.get("type").and_then(|v| v.as_str()) == Some("message")
                && item.get("role").and_then(|v| v.as_str()) == Some("user")
                && item
                    .get("content")
                    .and_then(|v| v.as_array())
                    .and_then(|arr| arr.first())
                    .and_then(|entry| entry.get("text"))
                    .and_then(|v| v.as_str())
                    .map(|text| text.contains(prefixed_auto_summary))
                    .unwrap_or(false)
        }),
        "resume request should include compacted history"
    );

    let input_follow_up = body_follow_up
        .get("input")
        .and_then(|v| v.as_array())
@@ -1276,6 +1233,10 @@ async fn auto_compact_persists_rollout_entries() {
        ev_assistant_message("m3", &auto_summary_payload),
        ev_completed_with_tokens("r3", 200),
    ]);
    let sse4 = sse(vec![
        ev_assistant_message("m4", FINAL_REPLY),
        ev_completed_with_tokens("r4", 120),
    ]);

    let first_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
@@ -1299,12 +1260,19 @@ async fn auto_compact_persists_rollout_entries() {
    };
    mount_sse_once_match(&server, third_matcher, sse3).await;

    let fourth_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains(POST_AUTO_USER_MSG) && !body_contains_text(body, SUMMARIZATION_PROMPT)
    };
    mount_sse_once_match(&server, fourth_matcher, sse4).await;

    let model_provider = non_openai_model_provider(&server);

    let home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&home);
    config.model_provider = model_provider;
    set_test_compact_prompt(&mut config);
    config.model_auto_compact_token_limit = Some(200_000);
    let conversation_manager = ConversationManager::with_models_provider(
        CodexAuth::from_api_key("dummy"),
        config.model_provider.clone(),
@@ -1335,6 +1303,16 @@ async fn auto_compact_persists_rollout_entries() {
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    codex
        .submit(Op::UserInput {
            items: vec![UserInput::Text {
                text: POST_AUTO_USER_MSG.into(),
            }],
        })
        .await
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    codex.submit(Op::Shutdown).await.unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;

@@ -1731,6 +1709,8 @@ async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_
        ev_assistant_message("m6", FINAL_REPLY),
        ev_completed_with_tokens("r6", 120),
    ]);
    let follow_up_user = "FOLLOW_UP_AUTO_COMPACT";
    let final_user = "FINAL_AUTO_COMPACT";

    mount_sse_sequence(&server, vec![sse1, sse2, sse3, sse4, sse5, sse6]).await;

@@ -1751,31 +1731,31 @@ async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_
        .unwrap()
        .conversation;

    codex
        .submit(Op::UserInput {
            items: vec![UserInput::Text {
                text: MULTI_AUTO_MSG.into(),
            }],
        })
        .await
        .unwrap();

    let mut auto_compact_lifecycle_events = Vec::new();
    loop {
        let event = codex.next_event().await.unwrap();
        if event.id.starts_with("auto-compact-")
            && matches!(
                event.msg,
                EventMsg::TaskStarted(_) | EventMsg::TaskComplete(_)
            )
        {
            auto_compact_lifecycle_events.push(event);
            continue;
        }
        if let EventMsg::TaskComplete(_) = &event.msg
            && !event.id.starts_with("auto-compact-")
        {
            break;
    for user in [MULTI_AUTO_MSG, follow_up_user, final_user] {
        codex
            .submit(Op::UserInput {
                items: vec![UserInput::Text { text: user.into() }],
            })
            .await
            .unwrap();

        loop {
            let event = codex.next_event().await.unwrap();
            if event.id.starts_with("auto-compact-")
                && matches!(
                    event.msg,
                    EventMsg::TaskStarted(_) | EventMsg::TaskComplete(_)
                )
            {
                auto_compact_lifecycle_events.push(event);
                continue;
            }
            if let EventMsg::TaskComplete(_) = &event.msg
                && !event.id.starts_with("auto-compact-")
            {
                break;
            }
        }
    }

@@ -1821,6 +1801,7 @@ async fn auto_compact_triggers_after_function_call_over_95_percent_usage() {
    let context_window = 100;
    let limit = context_window * 90 / 100;
    let over_limit_tokens = context_window * 95 / 100 + 1;
    let follow_up_user = "FOLLOW_UP_AFTER_LIMIT";

    let first_turn = sse(vec![
        ev_function_call(DUMMY_CALL_ID, DUMMY_FUNCTION_NAME, "{}"),
@@ -1873,6 +1854,17 @@ async fn auto_compact_triggers_after_function_call_over_95_percent_usage() {

    wait_for_event(&codex, |msg| matches!(msg, EventMsg::TaskComplete(_))).await;

    codex
        .submit(Op::UserInput {
            items: vec![UserInput::Text {
                text: follow_up_user.into(),
            }],
        })
        .await
        .unwrap();

    wait_for_event(&codex, |msg| matches!(msg, EventMsg::TaskComplete(_))).await;

    // Assert first request captured expected user message that triggers function call.
    let first_request = first_turn_mock.single_request().input();
    assert!(
@@ -1916,6 +1908,7 @@ async fn auto_compact_counts_encrypted_reasoning_before_last_user() {

    let first_user = "COUNT_PRE_LAST_REASONING";
    let second_user = "TRIGGER_COMPACT_AT_LIMIT";
    let third_user = "AFTER_REMOTE_COMPACT";

    let pre_last_reasoning_content = "a".repeat(2_400);
    let post_last_reasoning_content = "b".repeat(4_000);
@@ -1928,7 +1921,7 @@ async fn auto_compact_counts_encrypted_reasoning_before_last_user() {
        ev_reasoning_item("post-reasoning", &["post"], &[&post_last_reasoning_content]),
        ev_completed_with_tokens("r2", 80),
    ]);
    let resume_turn = sse(vec![
    let third_turn = sse(vec![
        ev_assistant_message("m4", FINAL_REPLY),
        ev_completed_with_tokens("r4", 1),
    ]);
@@ -1940,8 +1933,8 @@ async fn auto_compact_counts_encrypted_reasoning_before_last_user() {
            first_turn,
            // Turn 2: reasoning after last user (should be ignored for compaction).
            second_turn,
            // Turn 3: resume after remote compaction.
            resume_turn,
            // Turn 3: next user turn after remote compaction.
            third_turn,
        ],
    )
    .await;
@@ -1973,7 +1966,10 @@ async fn auto_compact_counts_encrypted_reasoning_before_last_user() {
        .expect("build codex")
        .codex;

    for (idx, user) in [first_user, second_user].into_iter().enumerate() {
    for (idx, user) in [first_user, second_user, third_user]
        .into_iter()
        .enumerate()
    {
        codex
            .submit(Op::UserInput {
                items: vec![UserInput::Text { text: user.into() }],
@@ -1982,10 +1978,10 @@ async fn auto_compact_counts_encrypted_reasoning_before_last_user() {
            .unwrap();
        wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

        if idx == 0 {
        if idx < 2 {
            assert!(
                compact_mock.requests().is_empty(),
                "remote compaction should not run after the first turn"
                "remote compaction should not run before the next user turn"
            );
        }
    }
@@ -2006,20 +2002,21 @@ async fn auto_compact_counts_encrypted_reasoning_before_last_user() {
    assert_eq!(
        requests.len(),
        3,
        "conversation should include two user turns and a post-compaction resume"
        "conversation should include three user turns"
    );
    let second_request_body = requests[1].body_json().to_string();
    assert!(
        !second_request_body.contains("REMOTE_COMPACT_SUMMARY"),
        "second turn should not include compacted history"
    );
    let resume_body = requests[2].body_json().to_string();
    let third_request_body = requests[2].body_json().to_string();
    assert!(
        resume_body.contains("REMOTE_COMPACT_SUMMARY") || resume_body.contains(FINAL_REPLY),
        "resume request should follow remote compact and use compacted history"
        third_request_body.contains("REMOTE_COMPACT_SUMMARY")
            || third_request_body.contains(FINAL_REPLY),
        "third turn should include compacted history"
    );
    assert!(
        resume_body.contains("ENCRYPTED_COMPACTION_SUMMARY"),
        "resume request should include compaction summary item"
        third_request_body.contains("ENCRYPTED_COMPACTION_SUMMARY"),
        "third turn should include compaction summary item"
    );
}

@@ -25,6 +25,7 @@ use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use std::sync::Mutex;
use tracing::Level;
use tracing_test::traced_test;

use tracing_subscriber::fmt::format::FmtSpan;
@@ -454,6 +455,7 @@ async fn handle_responses_span_records_response_kind_and_tool_name() {
    let subscriber = tracing_subscriber::fmt()
        .with_level(true)
        .with_ansi(false)
        .with_max_level(Level::TRACE)
        .with_span_events(FmtSpan::FULL)
        .with_writer(MockWriter::new(buffer))
        .finish();
@@ -517,6 +519,7 @@ async fn record_responses_sets_span_fields_for_response_events() {
    let subscriber = tracing_subscriber::fmt()
        .with_level(true)
        .with_ansi(false)
        .with_max_level(Level::TRACE)
        .with_span_events(FmtSpan::FULL)
        .with_writer(MockWriter::new(buffer))
        .finish();

@@ -580,10 +580,6 @@ async fn review_input_isolated_from_parent_history() {
        review_prompt,
        "user message should only contain the raw review prompt"
    );
    assert!(
        env_text.contains("<sandbox_mode>read-only</sandbox_mode>"),
        "review environment context must run with read-only sandbox"
    );

    // Ensure the REVIEW_PROMPT rubric is sent via instructions.
    let instructions = body["instructions"].as_str().expect("instructions string");

@@ -13,10 +13,15 @@ use core_test_support::test_codex::TestCodexHarness;
use core_test_support::test_codex::test_codex;
use serde_json::json;

fn shell_responses(call_id: &str, command: &str, login: Option<bool>) -> Vec<String> {
fn shell_responses_with_timeout(
    call_id: &str,
    command: &str,
    login: Option<bool>,
    timeout_ms: i64,
) -> Vec<String> {
    let args = json!({
        "command": command,
        "timeout_ms": 2_000,
        "timeout_ms": timeout_ms,
        "login": login,
    });

@@ -36,6 +41,10 @@ fn shell_responses(call_id: &str, command: &str, login: Option<bool>) -> Vec<Str
    ]
}

fn shell_responses(call_id: &str, command: &str, login: Option<bool>) -> Vec<String> {
    shell_responses_with_timeout(call_id, command, login, 2_000)
}

async fn shell_command_harness_with(
    configure: impl FnOnce(TestCodexBuilder) -> TestCodexBuilder,
) -> Result<TestCodexHarness> {
@@ -54,6 +63,20 @@ async fn mount_shell_responses(
    mount_sse_sequence(harness.server(), shell_responses(call_id, command, login)).await;
}

async fn mount_shell_responses_with_timeout(
    harness: &TestCodexHarness,
    call_id: &str,
    command: &str,
    login: Option<bool>,
    timeout_ms: i64,
) {
    mount_sse_sequence(
        harness.server(),
        shell_responses_with_timeout(call_id, command, login, timeout_ms),
    )
    .await;
}

fn assert_shell_command_output(output: &str, expected: &str) -> Result<()> {
    let normalized_output = output
        .replace("\r\n", "\n")
@@ -172,3 +195,32 @@ async fn pipe_output_without_login() -> anyhow::Result<()> {

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_command_times_out_with_timeout_ms() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let harness = shell_command_harness_with(|builder| builder.with_model("gpt-5.1")).await?;

    let call_id = "shell-command-timeout";
    let command = if cfg!(windows) {
        "timeout /t 5"
    } else {
        "sleep 5"
    };
    mount_shell_responses_with_timeout(&harness, call_id, command, None, 200).await;
    harness
        .submit("run a long command with a short timeout")
        .await?;

    let output = harness.function_call_stdout(call_id).await;
    let normalized_output = output
        .replace("\r\n", "\n")
        .replace('\r', "\n")
        .trim_end_matches('\n')
        .to_string();
    let expected_pattern = r"(?s)^Exit code: 124\nWall time: [0-9]+(?:\.[0-9]+)? seconds\nOutput:\ncommand timed out after [0-9]+ milliseconds\n?$";
    assert_regex_match(expected_pattern, &normalized_output);

    Ok(())
}
|
||||
|
||||
@@ -14,6 +14,8 @@ use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::assert_regex_match;
|
||||
use core_test_support::process::wait_for_pid_file;
|
||||
use core_test_support::process::wait_for_process_exit;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
@@ -31,6 +33,7 @@ use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_match;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use pretty_assertions::assert_eq;
|
||||
use regex_lite::Regex;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
@@ -1640,6 +1643,111 @@ async fn unified_exec_emits_end_event_when_session_dies_via_stdin() -> Result<()
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_closes_long_running_session_at_turn_end() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_sandbox!(Ok(()));
|
||||
skip_if_windows!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.use_experimental_unified_exec_tool = true;
|
||||
config.features.enable(Feature::UnifiedExec);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder.build(&server).await?;
|
||||
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let pid_path = temp_dir.path().join("uexec_pid");
|
||||
let pid_path_str = pid_path.to_string_lossy();
|
||||
|
||||
let call_id = "uexec-long-running";
|
||||
let command = format!("printf '%s' $$ > '{pid_path_str}' && exec sleep 3000");
|
||||
let args = json!({
|
||||
"cmd": command,
|
||||
"yield_time_ms": 250,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "exec_command", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "close unified exec sessions on turn end".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let begin_event = wait_for_event_match(&codex, |msg| match msg {
|
||||
EventMsg::ExecCommandBegin(ev) if ev.call_id == call_id => Some(ev.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let begin_process_id = begin_event
|
||||
.process_id
|
||||
.clone()
|
||||
.expect("expected process_id for long-running unified exec session");
|
||||
|
||||
let pid = wait_for_pid_file(&pid_path).await?;
|
||||
assert!(
|
||||
pid.chars().all(|ch| ch.is_ascii_digit()),
|
||||
"expected numeric pid, got {pid:?}"
|
||||
);
|
||||
|
||||
let mut end_event = None;
|
||||
let mut task_complete = false;
|
||||
loop {
|
||||
let msg = wait_for_event(&codex, |_| true).await;
|
||||
match msg {
|
||||
EventMsg::ExecCommandEnd(ev) if ev.call_id == call_id => end_event = Some(ev),
|
||||
EventMsg::TaskComplete(_) => task_complete = true,
|
||||
_ => {}
|
||||
}
|
||||
if task_complete && end_event.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let end_event = end_event.expect("expected ExecCommandEnd event for unified exec session");
|
||||
assert_eq!(end_event.call_id, call_id);
|
||||
let end_process_id = end_event
|
||||
.process_id
|
||||
.clone()
|
||||
.expect("expected process_id in unified exec end event");
|
||||
assert_eq!(end_process_id, begin_process_id);
|
||||
|
||||
wait_for_process_exit(&pid).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
@@ -20,6 +20,7 @@ At a glance:
- Configuration and info
  - `getUserSavedConfig`, `setDefaultModel`, `getUserAgent`, `userInfo`
  - `model/list` → enumerate available models and reasoning options
  - notifications: `model/presets/updated`
- Auth
  - `account/read`, `account/login/start`, `account/login/cancel`, `account/logout`, `account/rateLimits/read`
  - notifications: `account/login/completed`, `account/updated`, `account/rateLimits/updated`
@@ -78,21 +79,21 @@ List/resume/archive: `listConversations`, `resumeConversation`, `archiveConversa

## Models

Fetch the catalog of models available in the current Codex build with `model/list`. The request accepts optional pagination inputs:
Request the catalog of models available in the current Codex build with `model/list`. The request accepts optional pagination inputs (currently ignored by the server):

- `pageSize` – number of models to return (defaults to a server-selected value)
- `pageSize` – number of models to return
- `cursor` – opaque string from the previous response’s `nextCursor`

Each response yields:
The response is an empty JSON object `{}`. The server asynchronously emits a
`model/presets/updated` notification containing the full model catalog. The payload is:

- `items` – ordered list of models. A model includes:
- `models` – the full list of available models. Each model includes:
  - `id`, `model`, `displayName`, `description`
  - `supportedReasoningEfforts` – array of objects with:
    - `reasoningEffort` – one of `minimal|low|medium|high`
    - `description` – human-friendly label for the effort
  - `defaultReasoningEffort` – suggested effort for the UI
  - `isDefault` – whether the model is recommended for most users
- `nextCursor` – pass into the next request to continue paging (optional)

## Event stream

@@ -100,6 +101,7 @@ While a conversation runs, the server sends notifications:

- `codex/event` with the serialized Codex event payload. The shape matches `core/src/protocol.rs`’s `Event` and `EventMsg` types. Some notifications include a `_meta.requestId` to correlate with the originating request.
- Auth notifications via method names `loginChatGptComplete` and `authStatusChange`.
- Model catalog notifications via method name `model/presets/updated`.

Clients should render events and, when present, surface approval requests (see next section).

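For orientation, here is a minimal sketch of what a `model/presets/updated` payload could look like given the fields documented above. Only the field names come from the docs; the concrete model values below are invented for illustration:

```rust
use serde_json::json;

fn main() {
    // Hypothetical catalog entry; field names follow the documented schema,
    // the values are made up and do not describe a real model.
    let params = json!({
        "models": [{
            "id": "example-model",
            "model": "example-model",
            "displayName": "Example Model",
            "description": "Illustrative entry, not a real catalog item",
            "supportedReasoningEfforts": [
                { "reasoningEffort": "medium", "description": "Balanced speed and depth" }
            ],
            "defaultReasoningEffort": "medium",
            "isDefault": true
        }]
    });
    println!("{params}");
}
```
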
@@ -25,7 +25,7 @@ use std::time::Instant;
use strum_macros::Display;
use tokio::time::error::Elapsed;
use tracing::Span;
use tracing::info_span;
use tracing::trace_span;
use tracing_opentelemetry::OpenTelemetrySpanExt;

#[derive(Debug, Clone, Serialize, Display)]
@@ -67,7 +67,7 @@ impl OtelManager {
        terminal_type: String,
        session_source: SessionSource,
    ) -> OtelManager {
        let session_span = info_span!("new_session", conversation_id = %conversation_id, session_source = %session_source);
        let session_span = trace_span!("new_session", conversation_id = %conversation_id, session_source = %session_source);

        if let Some(context) = traceparent_context_from_env() {
            session_span.set_parent(context);

@@ -134,7 +134,7 @@ impl OtelProvider {
        self.tracer.as_ref().map(|tracer| {
            tracing_opentelemetry::layer()
                .with_tracer(tracer.clone())
                .with_filter(LevelFilter::INFO)
                .with_filter(LevelFilter::TRACE)
        })
    }

@@ -306,12 +306,14 @@ pub enum SandboxPolicy {

/// A writable root path accompanied by a list of subpaths that should remain
/// read‑only even when the root is writable. This is primarily used to ensure
/// top‑level VCS metadata directories (e.g. `.git`) under a writable root are
/// not modified by the agent.
/// that folders containing files that could be modified to escalate the
/// privileges of the agent (e.g. `.codex`, `.git`, notably `.git/hooks`) under
/// a writable root are not modified by the agent.
#[derive(Debug, Clone, PartialEq, Eq, JsonSchema)]
pub struct WritableRoot {
    pub root: AbsolutePathBuf,

    /// By construction, these subpaths are all under `root`.
    pub read_only_subpaths: Vec<AbsolutePathBuf>,
}

@@ -458,6 +460,13 @@ impl SandboxPolicy {
        if top_level_git.as_path().is_dir() {
            subpaths.push(top_level_git);
        }
        #[allow(clippy::expect_used)]
        let top_level_codex = writable_root
            .join(".codex")
            .expect(".codex is a valid relative path");
        if top_level_codex.as_path().is_dir() {
            subpaths.push(top_level_codex);
        }
        WritableRoot {
            root: writable_root,
            read_only_subpaths: subpaths,

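The hunk above extends the read-only subpath computation to cover `.codex` alongside `.git`. As a standalone illustration of the idea (not the crate's actual constructor, which uses `AbsolutePathBuf`), a minimal sketch:

```rust
use std::path::{Path, PathBuf};

// Minimal sketch: collect well-known privilege-escalation directories under a
// writable root so they can be denied even though the root itself is writable.
fn read_only_subpaths(root: &Path) -> Vec<PathBuf> {
    [".git", ".codex"]
        .iter()
        .map(|name| root.join(name))
        .filter(|path| path.is_dir()) // only existing directories need a deny rule
        .collect()
}

fn main() {
    println!("{:?}", read_only_subpaths(Path::new(".")));
}
```
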
@@ -23,7 +23,6 @@ workspace = true

[dependencies]
anyhow = { workspace = true }
async-stream = { workspace = true }
base64 = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
clap = { workspace = true, features = ["derive"] }
@@ -81,7 +80,7 @@ tokio = { workspace = true, features = [
    "test-util",
    "time",
] }
tokio-stream = { workspace = true }
tokio-stream = { workspace = true, features = ["sync"] }
toml = { workspace = true }
tracing = { workspace = true, features = ["log"] }
tracing-appender = { workspace = true }

@@ -16,7 +16,6 @@ use crossterm::event::DisableBracketedPaste;
use crossterm::event::DisableFocusChange;
use crossterm::event::EnableBracketedPaste;
use crossterm::event::EnableFocusChange;
use crossterm::event::Event;
use crossterm::event::KeyEvent;
use crossterm::event::KeyboardEnhancementFlags;
use crossterm::event::PopKeyboardEnhancementFlags;
@@ -32,7 +31,6 @@ use ratatui::crossterm::terminal::enable_raw_mode;
use ratatui::layout::Offset;
use ratatui::layout::Rect;
use ratatui::text::Line;
use tokio::select;
use tokio::sync::broadcast;
use tokio_stream::Stream;

@@ -42,11 +40,12 @@ use crate::custom_terminal::Terminal as CustomTerminal;
use crate::notifications::DesktopNotificationBackend;
use crate::notifications::NotificationBackendKind;
use crate::notifications::detect_backend;
#[cfg(unix)]
use crate::tui::job_control::SUSPEND_KEY;
use crate::tui::event_stream::EventBroker;
use crate::tui::event_stream::TuiEventStream;
#[cfg(unix)]
use crate::tui::job_control::SuspendContext;

mod event_stream;
mod frame_requester;
#[cfg(unix)]
mod job_control;
@@ -156,7 +155,7 @@ fn set_panic_hook() {
    }));
}

#[derive(Debug)]
#[derive(Clone, Debug)]
pub enum TuiEvent {
    Key(KeyEvent),
    Paste(String),
@@ -166,6 +165,7 @@ pub enum TuiEvent {
pub struct Tui {
    frame_requester: FrameRequester,
    draw_tx: broadcast::Sender<()>,
    event_broker: Arc<EventBroker>,
    pub(crate) terminal: Terminal,
    pending_history_lines: Vec<Line<'static>>,
    alt_saved_viewport: Option<ratatui::layout::Rect>,
@@ -194,6 +194,7 @@ impl Tui {
        Self {
            frame_requester,
            draw_tx,
            event_broker: Arc::new(EventBroker::new()),
            terminal,
            pending_history_lines: vec![],
            alt_saved_viewport: None,
@@ -214,6 +215,18 @@ impl Tui {
        self.enhanced_keys_supported
    }

    // todo(sayan) unused for now; intend to use to enable opening external editors
    #[allow(unused)]
    pub fn pause_events(&mut self) {
        self.event_broker.pause_events();
    }

    // todo(sayan) unused for now; intend to use to enable opening external editors
    #[allow(unused)]
    pub fn resume_events(&mut self) {
        self.event_broker.resume_events();
    }

    /// Emit a desktop notification now if the terminal is unfocused.
    /// Returns true if a notification was posted.
    pub fn notify(&mut self, message: impl AsRef<str>) -> bool {
@@ -262,79 +275,21 @@ impl Tui {
    }

    pub fn event_stream(&self) -> Pin<Box<dyn Stream<Item = TuiEvent> + Send + 'static>> {
        use tokio_stream::StreamExt;

        let mut crossterm_events = crossterm::event::EventStream::new();
        let mut draw_rx = self.draw_tx.subscribe();

        // State for tracking how we should resume from ^Z suspend.
        #[cfg(unix)]
        let suspend_context = self.suspend_context.clone();
        #[cfg(unix)]
        let alt_screen_active = self.alt_screen_active.clone();

        let terminal_focused = self.terminal_focused.clone();
        let event_stream = async_stream::stream! {
            loop {
                select! {
                    event_result = crossterm_events.next() => {
                        match event_result {
                            Some(Ok(event)) => {
                                match event {
                                    Event::Key(key_event) => {
                                        #[cfg(unix)]
                                        if SUSPEND_KEY.is_press(key_event) {
                                            let _ = suspend_context.suspend(&alt_screen_active);
                                            // We continue here after resume.
                                            yield TuiEvent::Draw;
                                            continue;
                                        }
                                        yield TuiEvent::Key(key_event);
                                    }
                                    Event::Resize(_, _) => {
                                        yield TuiEvent::Draw;
                                    }
                                    Event::Paste(pasted) => {
                                        yield TuiEvent::Paste(pasted);
                                    }
                                    Event::FocusGained => {
                                        terminal_focused.store(true, Ordering::Relaxed);
                                        crate::terminal_palette::requery_default_colors();
                                        yield TuiEvent::Draw;
                                    }
                                    Event::FocusLost => {
                                        terminal_focused.store(false, Ordering::Relaxed);
                                    }
                                    _ => {}
                                }
                            }
                            Some(Err(_)) | None => {
                                // Exit the loop in case of broken pipe as we will never
                                // recover from it
                                break;
                            }
                        }
                    }
                    result = draw_rx.recv() => {
                        match result {
                            Ok(_) => {
                                yield TuiEvent::Draw;
                            }
                            Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                                // We dropped one or more draw notifications; coalesce to a single draw.
                                yield TuiEvent::Draw;
                            }
                            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                                // Sender dropped. This stream likely outlived its owning `Tui`;
                                // exit to avoid spinning on a permanently-closed receiver.
                                break;
                            }
                        }
                    }
                }
            }
        };
        Box::pin(event_stream)
        let stream = TuiEventStream::new(
            self.event_broker.clone(),
            self.draw_tx.subscribe(),
            self.terminal_focused.clone(),
            self.suspend_context.clone(),
            self.alt_screen_active.clone(),
        );
        #[cfg(not(unix))]
        let stream = TuiEventStream::new(
            self.event_broker.clone(),
            self.draw_tx.subscribe(),
            self.terminal_focused.clone(),
        );
        Box::pin(stream)
    }

    /// Enter alternate screen and expand the viewport to full terminal size, saving the current

codex-rs/tui/src/tui/event_stream.rs (new file, 511 lines)
@@ -0,0 +1,511 @@
//! Event stream plumbing for the TUI.
//!
//! - [`EventBroker`] holds the shared crossterm stream so multiple callers reuse the same
//!   input source and can drop/recreate it on pause/resume without rebuilding consumers.
//! - [`TuiEventStream`] wraps a draw event subscription plus the shared [`EventBroker`] and maps crossterm
//!   events into [`TuiEvent`].
//! - [`EventSource`] abstracts the underlying event producer; the real implementation is
//!   [`CrosstermEventSource`] and tests can swap in [`FakeEventSource`].
//!
//! The motivation for dropping/recreating the crossterm event stream is to enable the TUI to fully relinquish stdin.
//! If the stream is not dropped, it will continue to read from stdin even if it is not actively being polled
//! (due to how crossterm's EventStream is implemented), potentially stealing input from other processes reading stdin,
//! like terminal text editors. This race can cause missed input or capturing terminal query responses (for example, OSC palette/size queries)
//! that the other process expects to read. Stopping polling, instead of dropping the stream, is only sufficient when the
//! pause happens before the stream enters a pending state; otherwise the crossterm reader thread may keep reading
//! from stdin, so the safer approach is to drop and recreate the event stream when we need to hand off the terminal.
//!
//! See https://ratatui.rs/recipes/apps/spawn-vim/ and https://www.reddit.com/r/rust/comments/1f3o33u/myterious_crossterm_input_after_running_vim for more details.

use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;

use crossterm::event::Event;
use tokio::sync::broadcast;
use tokio::sync::watch;
use tokio_stream::Stream;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::wrappers::WatchStream;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;

use super::TuiEvent;

/// Result type produced by an event source.
pub type EventResult = std::io::Result<Event>;

/// Abstraction over a source of terminal events. Allows swapping in a fake for tests.
/// Value in production is [`CrosstermEventSource`].
pub trait EventSource: Send + 'static {
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<EventResult>>;
}

/// Shared crossterm input state for all [`TuiEventStream`] instances. A single crossterm EventStream
/// is reused so all streams still see the same input source.
///
/// This intermediate layer enables dropping/recreating the underlying EventStream (pause/resume) without rebuilding consumers.
pub struct EventBroker<S: EventSource = CrosstermEventSource> {
    state: Mutex<EventBrokerState<S>>,
    resume_events_tx: watch::Sender<()>,
}

/// Tracks state of underlying [`EventSource`].
enum EventBrokerState<S: EventSource> {
    Paused, // Underlying event source (i.e., crossterm EventStream) dropped
    Start, // A new event source will be created on next poll
    Running(S), // Event source is currently running
}

impl<S: EventSource + Default> EventBrokerState<S> {
    /// Return the running event source, starting it if needed; None when paused.
    fn active_event_source_mut(&mut self) -> Option<&mut S> {
        match self {
            EventBrokerState::Paused => None,
            EventBrokerState::Start => {
                *self = EventBrokerState::Running(S::default());
                match self {
                    EventBrokerState::Running(events) => Some(events),
                    EventBrokerState::Paused | EventBrokerState::Start => unreachable!(),
                }
            }
            EventBrokerState::Running(events) => Some(events),
        }
    }
}

impl<S: EventSource + Default> EventBroker<S> {
    pub fn new() -> Self {
        let (resume_events_tx, _resume_events_rx) = watch::channel(());
        Self {
            state: Mutex::new(EventBrokerState::Start),
            resume_events_tx,
        }
    }

    /// Drop the underlying event source
    pub fn pause_events(&self) {
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *state = EventBrokerState::Paused;
    }

    /// Create a new instance of the underlying event source
    pub fn resume_events(&self) {
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *state = EventBrokerState::Start;
        let _ = self.resume_events_tx.send(());
    }

    /// Subscribe to a notification that fires whenever [`Self::resume_events`] is called.
    ///
    /// This is used to wake `poll_crossterm_event` when it is paused and waiting for the
    /// underlying crossterm stream to be recreated.
    pub fn resume_events_rx(&self) -> watch::Receiver<()> {
        self.resume_events_tx.subscribe()
    }
}

/// Real crossterm-backed event source.
pub struct CrosstermEventSource(pub crossterm::event::EventStream);

impl Default for CrosstermEventSource {
    fn default() -> Self {
        Self(crossterm::event::EventStream::new())
    }
}

impl EventSource for CrosstermEventSource {
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<EventResult>> {
        Pin::new(&mut self.get_mut().0).poll_next(cx)
    }
}

/// TuiEventStream is a struct for reading TUI events (draws and user input).
/// Each instance has its own draw subscription (the draw channel is broadcast, so
/// multiple receivers are fine), while crossterm input is funneled through a
/// single shared [`EventBroker`] because crossterm uses a global stdin reader and
/// does not support fan-out. Multiple TuiEventStream instances can exist during the app lifetime
/// (for nested or sequential screens), but only one should be polled at a time,
/// otherwise one instance can consume ("steal") input events and the other will miss them.
pub struct TuiEventStream<S: EventSource + Default + Unpin = CrosstermEventSource> {
    broker: Arc<EventBroker<S>>,
    draw_stream: BroadcastStream<()>,
    resume_stream: WatchStream<()>,
    terminal_focused: Arc<AtomicBool>,
    poll_draw_first: bool,
    #[cfg(unix)]
    suspend_context: crate::tui::job_control::SuspendContext,
    #[cfg(unix)]
    alt_screen_active: Arc<AtomicBool>,
}

impl<S: EventSource + Default + Unpin> TuiEventStream<S> {
    pub fn new(
        broker: Arc<EventBroker<S>>,
        draw_rx: broadcast::Receiver<()>,
        terminal_focused: Arc<AtomicBool>,
        #[cfg(unix)] suspend_context: crate::tui::job_control::SuspendContext,
        #[cfg(unix)] alt_screen_active: Arc<AtomicBool>,
    ) -> Self {
        let resume_stream = WatchStream::from_changes(broker.resume_events_rx());
        Self {
            broker,
            draw_stream: BroadcastStream::new(draw_rx),
            resume_stream,
            terminal_focused,
            poll_draw_first: false,
            #[cfg(unix)]
            suspend_context,
            #[cfg(unix)]
            alt_screen_active,
        }
    }

    /// Poll the shared crossterm stream for the next mapped `TuiEvent`.
    ///
    /// This skips events we don't use (mouse events, etc.) and keeps polling until it yields
    /// a mapped event, hits `Pending`, or sees EOF/error. When the broker is paused, it drops
    /// the underlying stream and returns `Pending` to fully release stdin.
    pub fn poll_crossterm_event(&mut self, cx: &mut Context<'_>) -> Poll<Option<TuiEvent>> {
        // Some crossterm events map to None (e.g. FocusLost, mouse); loop so we keep polling
        // until we return a mapped event, hit Pending, or see EOF/error.
        loop {
            let poll_result = {
                let mut state = self
                    .broker
                    .state
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                let events = match state.active_event_source_mut() {
                    Some(events) => events,
                    None => {
                        drop(state);
                        // Poll resume_stream so resume_events wakes a stream paused here
                        match Pin::new(&mut self.resume_stream).poll_next(cx) {
                            Poll::Ready(Some(())) => continue,
                            Poll::Ready(None) => return Poll::Ready(None),
                            Poll::Pending => return Poll::Pending,
                        }
                    }
                };
                match Pin::new(events).poll_next(cx) {
                    Poll::Ready(Some(Ok(event))) => Some(event),
                    Poll::Ready(Some(Err(_))) | Poll::Ready(None) => {
                        *state = EventBrokerState::Start;
                        return Poll::Ready(None);
                    }
                    Poll::Pending => {
                        drop(state);
                        // Poll resume_stream so resume_events can wake us even while waiting on stdin
                        match Pin::new(&mut self.resume_stream).poll_next(cx) {
                            Poll::Ready(Some(())) => continue,
                            Poll::Ready(None) => return Poll::Ready(None),
                            Poll::Pending => return Poll::Pending,
                        }
                    }
                }
            };

            if let Some(mapped) = poll_result.and_then(|event| self.map_crossterm_event(event)) {
                return Poll::Ready(Some(mapped));
            }
        }
    }

    /// Poll the draw broadcast stream for the next draw event. Draw events are used to trigger a redraw of the TUI.
    pub fn poll_draw_event(&mut self, cx: &mut Context<'_>) -> Poll<Option<TuiEvent>> {
        match Pin::new(&mut self.draw_stream).poll_next(cx) {
            Poll::Ready(Some(Ok(()))) => Poll::Ready(Some(TuiEvent::Draw)),
            Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => {
                Poll::Ready(Some(TuiEvent::Draw))
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Map a crossterm event to a [`TuiEvent`], skipping events we don't use (mouse events, etc.).
    fn map_crossterm_event(&mut self, event: Event) -> Option<TuiEvent> {
        match event {
            Event::Key(key_event) => {
                #[cfg(unix)]
                if crate::tui::job_control::SUSPEND_KEY.is_press(key_event) {
                    let _ = self.suspend_context.suspend(&self.alt_screen_active);
                    return Some(TuiEvent::Draw);
                }
                Some(TuiEvent::Key(key_event))
            }
            Event::Resize(_, _) => Some(TuiEvent::Draw),
            Event::Paste(pasted) => Some(TuiEvent::Paste(pasted)),
            Event::FocusGained => {
                self.terminal_focused.store(true, Ordering::Relaxed);
                crate::terminal_palette::requery_default_colors();
                Some(TuiEvent::Draw)
            }
            Event::FocusLost => {
                self.terminal_focused.store(false, Ordering::Relaxed);
                None
            }
            _ => None,
        }
    }
}

impl<S: EventSource + Default + Unpin> Unpin for TuiEventStream<S> {}

impl<S: EventSource + Default + Unpin> Stream for TuiEventStream<S> {
    type Item = TuiEvent;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // approximate fairness + no starvation via round-robin.
        let draw_first = self.poll_draw_first;
        self.poll_draw_first = !self.poll_draw_first;

        if draw_first {
            if let Poll::Ready(event) = self.poll_draw_event(cx) {
                return Poll::Ready(event);
            }
            if let Poll::Ready(event) = self.poll_crossterm_event(cx) {
                return Poll::Ready(event);
            }
        } else {
            if let Poll::Ready(event) = self.poll_crossterm_event(cx) {
                return Poll::Ready(event);
            }
            if let Poll::Ready(event) = self.poll_draw_event(cx) {
                return Poll::Ready(event);
            }
        }

        Poll::Pending
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crossterm::event::Event;
    use crossterm::event::KeyCode;
    use crossterm::event::KeyEvent;
    use crossterm::event::KeyModifiers;
    use pretty_assertions::assert_eq;
    use std::task::Context;
    use std::task::Poll;
    use std::time::Duration;
    use tokio::sync::broadcast;
    use tokio::sync::mpsc;
    use tokio::time::timeout;
    use tokio_stream::StreamExt;

    /// Simple fake event source for tests; feed events via the handle.
    struct FakeEventSource {
        rx: mpsc::UnboundedReceiver<EventResult>,
        tx: mpsc::UnboundedSender<EventResult>,
    }

    struct FakeEventSourceHandle {
        broker: Arc<EventBroker<FakeEventSource>>,
    }

    impl FakeEventSource {
        fn new() -> Self {
            let (tx, rx) = mpsc::unbounded_channel();
            Self { rx, tx }
        }
    }

    impl Default for FakeEventSource {
        fn default() -> Self {
            Self::new()
        }
    }

    impl FakeEventSourceHandle {
        fn new(broker: Arc<EventBroker<FakeEventSource>>) -> Self {
            Self { broker }
        }

        fn send(&self, event: EventResult) {
            let mut state = self
                .broker
                .state
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let Some(source) = state.active_event_source_mut() else {
                return;
            };
            let _ = source.tx.send(event);
        }
    }

    impl EventSource for FakeEventSource {
        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<EventResult>> {
            Pin::new(&mut self.get_mut().rx).poll_recv(cx)
        }
    }

    fn make_stream(
        broker: Arc<EventBroker<FakeEventSource>>,
        draw_rx: broadcast::Receiver<()>,
        terminal_focused: Arc<AtomicBool>,
    ) -> TuiEventStream<FakeEventSource> {
        TuiEventStream::new(
            broker,
            draw_rx,
            terminal_focused,
            #[cfg(unix)]
            crate::tui::job_control::SuspendContext::new(),
            #[cfg(unix)]
            Arc::new(AtomicBool::new(false)),
        )
    }

    type SetupState = (
        Arc<EventBroker<FakeEventSource>>,
        FakeEventSourceHandle,
        broadcast::Sender<()>,
        broadcast::Receiver<()>,
        Arc<AtomicBool>,
    );

    fn setup() -> SetupState {
        let source = FakeEventSource::new();
        let broker = Arc::new(EventBroker::new());
        *broker.state.lock().unwrap() = EventBrokerState::Running(source);
        let handle = FakeEventSourceHandle::new(broker.clone());

        let (draw_tx, draw_rx) = broadcast::channel(1);
        let terminal_focused = Arc::new(AtomicBool::new(true));
        (broker, handle, draw_tx, draw_rx, terminal_focused)
    }

    #[tokio::test(flavor = "current_thread")]
    async fn key_event_skips_unmapped() {
        let (broker, handle, _draw_tx, draw_rx, terminal_focused) = setup();
        let mut stream = make_stream(broker, draw_rx, terminal_focused);

        handle.send(Ok(Event::FocusLost));
        handle.send(Ok(Event::Key(KeyEvent::new(
            KeyCode::Char('a'),
            KeyModifiers::NONE,
        ))));

        let next = stream.next().await.unwrap();
        match next {
            TuiEvent::Key(key) => {
                assert_eq!(key, KeyEvent::new(KeyCode::Char('a'), KeyModifiers::NONE));
            }
            other => panic!("expected key event, got {other:?}"),
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn draw_and_key_events_yield_both() {
        let (broker, handle, draw_tx, draw_rx, terminal_focused) = setup();
        let mut stream = make_stream(broker, draw_rx, terminal_focused);

        let expected_key = KeyEvent::new(KeyCode::Char('a'), KeyModifiers::NONE);
        let _ = draw_tx.send(());
        handle.send(Ok(Event::Key(expected_key)));

        let first = stream.next().await.unwrap();
        let second = stream.next().await.unwrap();

        let mut saw_draw = false;
        let mut saw_key = false;
        for event in [first, second] {
            match event {
                TuiEvent::Draw => {
                    saw_draw = true;
                }
                TuiEvent::Key(key) => {
                    assert_eq!(key, expected_key);
                    saw_key = true;
                }
                other => panic!("expected draw or key event, got {other:?}"),
            }
        }

        assert!(saw_draw && saw_key, "expected both draw and key events");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn lagged_draw_maps_to_draw() {
        let (broker, _handle, draw_tx, draw_rx, terminal_focused) = setup();
        let mut stream = make_stream(broker, draw_rx.resubscribe(), terminal_focused);

        // Fill channel to force Lagged on the receiver.
        let _ = draw_tx.send(());
        let _ = draw_tx.send(());

        let first = stream.next().await;
        assert!(matches!(first, Some(TuiEvent::Draw)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn error_or_eof_ends_stream() {
        let (broker, handle, _draw_tx, draw_rx, terminal_focused) = setup();
        let mut stream = make_stream(broker, draw_rx, terminal_focused);

        handle.send(Err(std::io::Error::other("boom")));

        let next = stream.next().await;
        assert!(next.is_none());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn resume_wakes_paused_stream() {
        let (broker, handle, _draw_tx, draw_rx, terminal_focused) = setup();
        let mut stream = make_stream(broker.clone(), draw_rx, terminal_focused);

        broker.pause_events();

        let task = tokio::spawn(async move { stream.next().await });
        tokio::task::yield_now().await;

        broker.resume_events();
        let expected_key = KeyEvent::new(KeyCode::Char('r'), KeyModifiers::NONE);
        handle.send(Ok(Event::Key(expected_key)));

        let event = timeout(Duration::from_millis(100), task)
            .await
            .expect("timed out waiting for resumed event")
            .expect("join failed");
        match event {
            Some(TuiEvent::Key(key)) => assert_eq!(key, expected_key),
            other => panic!("expected key event, got {other:?}"),
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn resume_wakes_pending_stream() {
        let (broker, handle, _draw_tx, draw_rx, terminal_focused) = setup();
        let mut stream = make_stream(broker.clone(), draw_rx, terminal_focused);

        let task = tokio::spawn(async move { stream.next().await });
        tokio::task::yield_now().await;

        broker.pause_events();
        broker.resume_events();
        let expected_key = KeyEvent::new(KeyCode::Char('p'), KeyModifiers::NONE);
        handle.send(Ok(Event::Key(expected_key)));

        let event = timeout(Duration::from_millis(100), task)
            .await
            .expect("timed out waiting for resumed event")
            .expect("join failed");
        match event {
            Some(TuiEvent::Key(key)) => assert_eq!(key, expected_key),
            other => panic!("expected key event, got {other:?}"),
        }
    }
}
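One way the pause/resume hooks above could eventually be used, as a sketch only: release stdin before handing the terminal to an external program, then recreate the stream afterwards. `vi` stands in for whatever editor integration lands later; this is not code from the change itself.

```rust
use std::process::Command;

// Sketch against the module's own types: fully release stdin while an external
// process owns the terminal, then resume TUI input handling.
fn with_input_released<S: EventSource + Default>(broker: &EventBroker<S>) -> std::io::Result<()> {
    broker.pause_events(); // drops the crossterm EventStream so stdin is released
    Command::new("vi").status()?; // the child can now read stdin without races
    broker.resume_events(); // next poll recreates the stream; the watch channel wakes pollers
    Ok(())
}
```
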
@@ -21,6 +21,8 @@ use crate::skill_error_prompt::SkillErrorPromptOutcome;
use crate::skill_error_prompt::run_skill_error_prompt;
use crate::tui;
use crate::tui::TuiEvent;
use crate::tui::scrolling::TranscriptLineMeta;
use crate::tui::scrolling::TranscriptScroll;
use crate::update_action::UpdateAction;
use crate::wrapping::RtOptions;
use crate::wrapping::word_wrap_line;
@@ -339,21 +341,6 @@ pub(crate) struct App {
    skip_world_writable_scan_once: bool,
}

/// Scroll state for the inline transcript viewport.
///
/// This tracks whether the transcript is pinned to the latest line or anchored
/// at a specific cell/line pair so later viewport changes can implement
/// scrollback without losing the notion of "bottom".
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Default)]
enum TranscriptScroll {
    #[default]
    ToBottom,
    Scrolled {
        cell_index: usize,
        line_in_cell: usize,
    },
}
/// Content-relative selection within the inline transcript viewport.
///
/// Selection endpoints are expressed in terms of flattened, wrapped transcript
@@ -494,7 +481,7 @@ impl App {
            file_search,
            enhanced_keys_supported,
            transcript_cells: Vec::new(),
            transcript_scroll: TranscriptScroll::ToBottom,
            transcript_scroll: TranscriptScroll::default(),
            transcript_selection: TranscriptSelection::default(),
            transcript_view_top: 0,
            transcript_total_lines: 0,
@@ -562,13 +549,13 @@ impl App {
        let session_lines = if width == 0 {
            Vec::new()
        } else {
            let (lines, meta) = Self::build_transcript_lines(&app.transcript_cells, width);
            let (lines, line_meta) = Self::build_transcript_lines(&app.transcript_cells, width);
            let is_user_cell: Vec<bool> = app
                .transcript_cells
                .iter()
                .map(|cell| cell.as_any().is::<UserHistoryCell>())
                .collect();
            Self::render_lines_to_ansi(&lines, &meta, &is_user_cell, width)
            Self::render_lines_to_ansi(&lines, &line_meta, &is_user_cell, width)
        };

        tui.terminal.clear()?;
@@ -676,7 +663,7 @@ impl App {
    ) -> u16 {
        let area = frame.area();
        if area.width == 0 || area.height == 0 {
            self.transcript_scroll = TranscriptScroll::ToBottom;
            self.transcript_scroll = TranscriptScroll::default();
            self.transcript_view_top = 0;
            self.transcript_total_lines = 0;
            return area.bottom().saturating_sub(chat_height);
@@ -685,7 +672,7 @@ impl App {
        let chat_height = chat_height.min(area.height);
        let max_transcript_height = area.height.saturating_sub(chat_height);
        if max_transcript_height == 0 {
            self.transcript_scroll = TranscriptScroll::ToBottom;
            self.transcript_scroll = TranscriptScroll::default();
            self.transcript_view_top = 0;
            self.transcript_total_lines = 0;
            return area.y;
@@ -698,10 +685,10 @@ impl App {
            height: max_transcript_height,
        };

        let (lines, meta) = Self::build_transcript_lines(cells, transcript_area.width);
        let (lines, line_meta) = Self::build_transcript_lines(cells, transcript_area.width);
        if lines.is_empty() {
            Clear.render_ref(transcript_area, frame.buffer);
            self.transcript_scroll = TranscriptScroll::ToBottom;
            self.transcript_scroll = TranscriptScroll::default();
            self.transcript_view_top = 0;
            self.transcript_total_lines = 0;
            return area.y;
@@ -709,7 +696,7 @@ impl App {

        let wrapped = word_wrap_lines_borrowed(&lines, transcript_area.width.max(1) as usize);
        if wrapped.is_empty() {
            self.transcript_scroll = TranscriptScroll::ToBottom;
            self.transcript_scroll = TranscriptScroll::default();
            self.transcript_view_top = 0;
            self.transcript_total_lines = 0;
            return area.y;
@@ -731,10 +718,10 @@ impl App {
                .initial_indent(base_opts.subsequent_indent.clone())
            };
            let seg_count = word_wrap_line(line, opts).len();
            let is_user_row = meta
            let is_user_row = line_meta
                .get(idx)
                .and_then(Option::as_ref)
                .map(|(cell_index, _)| is_user_cell.get(*cell_index).copied().unwrap_or(false))
                .and_then(TranscriptLineMeta::cell_index)
                .map(|cell_index| is_user_cell.get(cell_index).copied().unwrap_or(false))
                .unwrap_or(false);
            wrapped_is_user_row.extend(std::iter::repeat_n(is_user_row, seg_count));
            first = false;
@@ -745,30 +732,8 @@ impl App {
        let max_visible = std::cmp::min(max_transcript_height as usize, total_lines);
        let max_start = total_lines.saturating_sub(max_visible);

        let top_offset = match self.transcript_scroll {
            TranscriptScroll::ToBottom => max_start,
            TranscriptScroll::Scrolled {
                cell_index,
                line_in_cell,
            } => {
                let mut anchor = None;
                for (idx, entry) in meta.iter().enumerate() {
                    if let Some((ci, li)) = entry
                        && *ci == cell_index
                        && *li == line_in_cell
                    {
                        anchor = Some(idx);
                        break;
                    }
                }
                if let Some(idx) = anchor {
                    idx.min(max_start)
                } else {
                    self.transcript_scroll = TranscriptScroll::ToBottom;
                    max_start
                }
            }
        };
        let (scroll_state, top_offset) = self.transcript_scroll.resolve_top(&line_meta, max_start);
        self.transcript_scroll = scroll_state;
        self.transcript_view_top = top_offset;

        let transcript_visible_height = max_visible as u16;
@@ -974,69 +939,10 @@ impl App {
            return;
        }

        let (lines, meta) = Self::build_transcript_lines(&self.transcript_cells, width);
        let total_lines = lines.len();
        if total_lines <= visible_lines {
            self.transcript_scroll = TranscriptScroll::ToBottom;
            return;
        }

        let max_start = total_lines.saturating_sub(visible_lines);

        let current_top = match self.transcript_scroll {
            TranscriptScroll::ToBottom => max_start,
            TranscriptScroll::Scrolled {
                cell_index,
                line_in_cell,
            } => {
                let mut anchor = None;
                for (idx, entry) in meta.iter().enumerate() {
                    if let Some((ci, li)) = entry
                        && *ci == cell_index
                        && *li == line_in_cell
                    {
                        anchor = Some(idx);
                        break;
                    }
                }
                anchor.unwrap_or(max_start).min(max_start)
            }
        };

        if delta_lines == 0 {
            return;
        }

        let new_top = if delta_lines < 0 {
            current_top.saturating_sub(delta_lines.unsigned_abs() as usize)
        } else {
            current_top
                .saturating_add(delta_lines as usize)
                .min(max_start)
        };

        if new_top == max_start {
            self.transcript_scroll = TranscriptScroll::ToBottom;
        } else {
            let anchor = meta.iter().skip(new_top).find_map(|entry| *entry);
            if let Some((cell_index, line_in_cell)) = anchor {
                self.transcript_scroll = TranscriptScroll::Scrolled {
                    cell_index,
                    line_in_cell,
                };
            } else if let Some(prev_idx) = (0..=new_top).rfind(|&idx| meta[idx].is_some()) {
                if let Some((cell_index, line_in_cell)) = meta[prev_idx] {
                    self.transcript_scroll = TranscriptScroll::Scrolled {
                        cell_index,
                        line_in_cell,
                    };
                } else {
                    self.transcript_scroll = TranscriptScroll::ToBottom;
                }
            } else {
                self.transcript_scroll = TranscriptScroll::ToBottom;
            }
        }
        let (_, line_meta) = Self::build_transcript_lines(&self.transcript_cells, width);
        self.transcript_scroll =
            self.transcript_scroll
                .scrolled_by(delta_lines, &line_meta, visible_lines);

        tui.frame_requester().schedule_frame();
    }
@@ -1053,8 +959,8 @@ impl App {
            return;
        }

        let (lines, meta) = Self::build_transcript_lines(&self.transcript_cells, width);
        if lines.is_empty() || meta.is_empty() {
        let (lines, line_meta) = Self::build_transcript_lines(&self.transcript_cells, width);
        if lines.is_empty() || line_meta.is_empty() {
            return;
        }

@@ -1073,22 +979,8 @@ impl App {
            }
        };

        let mut anchor = None;
        if let Some((cell_index, line_in_cell)) = meta.iter().skip(top_offset).flatten().next() {
            anchor = Some((*cell_index, *line_in_cell));
        }
        if anchor.is_none()
            && let Some((cell_index, line_in_cell)) =
                meta[..top_offset].iter().rev().flatten().next()
        {
            anchor = Some((*cell_index, *line_in_cell));
        }

        if let Some((cell_index, line_in_cell)) = anchor {
            self.transcript_scroll = TranscriptScroll::Scrolled {
                cell_index,
                line_in_cell,
            };
        if let Some(scroll_state) = TranscriptScroll::anchor_for(&line_meta, top_offset) {
            self.transcript_scroll = scroll_state;
        }
    }

@@ -1096,16 +988,17 @@ impl App {
    ///
    /// Returns both the visible `Line` buffer and a parallel metadata vector
    /// that maps each line back to its originating `(cell_index, line_in_cell)`
    /// pair, or `None` for spacer lines. This allows the scroll state to anchor
    /// to a specific history cell even as new content arrives or the viewport
    /// size changes, and gives exit transcript renderers enough structure to
    /// style user rows differently from agent rows.
    /// pair (see `TranscriptLineMeta::CellLine`), or `TranscriptLineMeta::Spacer` for
    /// synthetic spacer rows inserted between cells. This allows the scroll state
    /// to anchor to a specific history cell even as new content arrives or the
    /// viewport size changes, and gives exit transcript renderers enough structure
    /// to style user rows differently from agent rows.
    fn build_transcript_lines(
        cells: &[Arc<dyn HistoryCell>],
        width: u16,
    ) -> (Vec<Line<'static>>, Vec<Option<(usize, usize)>>) {
    ) -> (Vec<Line<'static>>, Vec<TranscriptLineMeta>) {
        let mut lines: Vec<Line<'static>> = Vec::new();
        let mut meta: Vec<Option<(usize, usize)>> = Vec::new();
        let mut line_meta: Vec<TranscriptLineMeta> = Vec::new();
        let mut has_emitted_lines = false;

        for (cell_index, cell) in cells.iter().enumerate() {
@@ -1117,19 +1010,22 @@ impl App {
            if !cell.is_stream_continuation() {
                if has_emitted_lines {
                    lines.push(Line::from(""));
                    meta.push(None);
                    line_meta.push(TranscriptLineMeta::Spacer);
                } else {
                    has_emitted_lines = true;
                }
            }

            for (line_in_cell, line) in cell_lines.into_iter().enumerate() {
                meta.push(Some((cell_index, line_in_cell)));
                line_meta.push(TranscriptLineMeta::CellLine {
                    cell_index,
                    line_in_cell,
                });
                lines.push(line);
            }
        }

        (lines, meta)
        (lines, line_meta)
    }

    /// Render flattened transcript lines into ANSI strings suitable for
@@ -1144,7 +1040,7 @@ impl App {
    /// and tools see consistent escape sequences.
    fn render_lines_to_ansi(
        lines: &[Line<'static>],
        meta: &[Option<(usize, usize)>],
        line_meta: &[TranscriptLineMeta],
        is_user_cell: &[bool],
        width: u16,
    ) -> Vec<String> {
@@ -1152,10 +1048,10 @@ impl App {
            .iter()
            .enumerate()
            .map(|(idx, line)| {
                let is_user_row = meta
                let is_user_row = line_meta
                    .get(idx)
                    .and_then(|entry| entry.as_ref())
                    .map(|(cell_index, _)| is_user_cell.get(*cell_index).copied().unwrap_or(false))
                    .and_then(TranscriptLineMeta::cell_index)
                    .map(|cell_index| is_user_cell.get(cell_index).copied().unwrap_or(false))
                    .unwrap_or(false);

                let mut merged_spans: Vec<ratatui::text::Span<'static>> = line
@@ -2262,7 +2158,7 @@ mod tests {
            active_profile: None,
            file_search,
            transcript_cells: Vec::new(),
            transcript_scroll: TranscriptScroll::ToBottom,
            transcript_scroll: TranscriptScroll::default(),
            transcript_selection: TranscriptSelection::default(),
            transcript_view_top: 0,
            transcript_total_lines: 0,
@@ -2306,7 +2202,7 @@ mod tests {
            active_profile: None,
            file_search,
            transcript_cells: Vec::new(),
            transcript_scroll: TranscriptScroll::ToBottom,
            transcript_scroll: TranscriptScroll::default(),
            transcript_selection: TranscriptSelection::default(),
            transcript_view_top: 0,
            transcript_total_lines: 0,
@@ -2576,11 +2472,14 @@ mod tests {
    fn render_lines_to_ansi_pads_user_rows_to_full_width() {
        let line: Line<'static> = Line::from("hi");
        let lines = vec![line];
        let meta = vec![Some((0usize, 0usize))];
        let line_meta = vec![TranscriptLineMeta::CellLine {
            cell_index: 0,
            line_in_cell: 0,
        }];
        let is_user_cell = vec![true];
        let width: u16 = 10;

        let rendered = App::render_lines_to_ansi(&lines, &meta, &is_user_cell, width);
        let rendered = App::render_lines_to_ansi(&lines, &line_meta, &is_user_cell, width);
        assert_eq!(rendered.len(), 1);
        assert!(rendered[0].contains("hi"));
    }

@@ -49,6 +49,7 @@ use crate::tui::job_control::SuspendContext;
mod frame_requester;
#[cfg(unix)]
mod job_control;
pub(crate) mod scrolling;

/// A type alias for the terminal type used in this application
pub type Terminal = CustomTerminal<CrosstermBackend<Stdout>>;

codex-rs/tui2/src/tui/scrolling.rs (new file, 366 lines)
@@ -0,0 +1,366 @@
//! Inline transcript scrolling primitives.
//!
//! The TUI renders the transcript as a list of logical *cells* (user prompts, agent responses,
//! banners, etc.). Each frame flattens those cells into a sequence of visual lines (after wrapping)
//! plus a parallel `line_meta` vector that maps each visual line back to its origin
//! (`TranscriptLineMeta`) (see `App::build_transcript_lines` and the design notes in
//! `codex-rs/tui2/docs/tui_viewport_and_history.md`).
//!
//! This module defines the scroll state for the inline transcript viewport and helpers to:
//! - Resolve that state into a concrete top-row offset for the current frame.
//! - Apply a scroll delta (mouse wheel / PgUp / PgDn) in terms of *visual lines*.
//! - Convert a concrete top-row offset back into a stable anchor.
//!
//! Why anchors instead of a raw "top row" index?
//! - When the transcript grows, a raw index drifts relative to the user's chosen content.
//! - By anchoring to a particular `(cell_index, line_in_cell)`, we can re-find the same content in
//!   the newly flattened line list on the next frame.
//!
//! Spacer rows between non-continuation cells are represented as `TranscriptLineMeta::Spacer`.
//! They are not valid anchors; `anchor_for` will pick the nearest non-spacer line when needed.

/// Per-flattened-line metadata for the transcript view.
///
/// Each rendered line in the flattened transcript has a corresponding `TranscriptLineMeta` entry
/// describing where that visual line came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TranscriptLineMeta {
    /// A visual line that belongs to a transcript cell.
    CellLine {
        cell_index: usize,
        line_in_cell: usize,
    },
    /// A synthetic spacer row inserted between non-continuation cells.
    Spacer,
}

impl TranscriptLineMeta {
    pub(crate) fn cell_line(&self) -> Option<(usize, usize)> {
        match *self {
            Self::CellLine {
                cell_index,
                line_in_cell,
            } => Some((cell_index, line_in_cell)),
            Self::Spacer => None,
        }
    }

    pub(crate) fn cell_index(&self) -> Option<usize> {
        match *self {
            Self::CellLine { cell_index, .. } => Some(cell_index),
            Self::Spacer => None,
        }
    }
}

/// Scroll state for the inline transcript viewport.
///
/// This tracks whether the transcript is pinned to the latest line or anchored
/// at a specific cell/line pair so later viewport changes can implement
/// scrollback without losing the notion of "bottom".
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub(crate) enum TranscriptScroll {
    #[default]
    /// Follow the most recent line in the transcript.
    ToBottom,
    /// Anchor the viewport to a specific transcript cell and line.
    ///
    /// `cell_index` indexes into the logical transcript cell list. `line_in_cell` is the 0-based
    /// visual line index within that cell as produced by the current wrapping/layout.
    Scrolled {
        cell_index: usize,
        line_in_cell: usize,
    },
}

impl TranscriptScroll {
    /// Resolve the top row for the current scroll state.
    ///
    /// `line_meta` is a line-parallel mapping of flattened transcript lines.
    ///
    /// `max_start` is the maximum valid top-row offset for the current viewport height (i.e. the
    /// last scroll position that still yields a full viewport of content).
    ///
    /// Returns the (possibly updated) scroll state plus the resolved top-row offset. If the current
    /// anchor can no longer be found in `line_meta` (for example because the transcript was
    /// truncated), this falls back to `ToBottom` so the UI stays usable.
    pub(crate) fn resolve_top(
        self,
        line_meta: &[TranscriptLineMeta],
        max_start: usize,
    ) -> (Self, usize) {
        match self {
            Self::ToBottom => (Self::ToBottom, max_start),
            Self::Scrolled {
                cell_index,
                line_in_cell,
            } => {
                let anchor = anchor_index(line_meta, cell_index, line_in_cell);
                match anchor {
                    Some(idx) => (self, idx.min(max_start)),
                    None => (Self::ToBottom, max_start),
                }
            }
        }
    }

    /// Apply a scroll delta and return the updated scroll state.
    ///
    /// `delta_lines` is in *visual lines* (after wrapping): negative deltas scroll upward into
    /// scrollback, positive deltas scroll downward toward the latest content.
    ///
    /// See `resolve_top` for `line_meta` semantics. `visible_lines` is the viewport height in rows.
    /// If all flattened lines fit in the viewport, this always returns `ToBottom`.
    pub(crate) fn scrolled_by(
        self,
        delta_lines: i32,
        line_meta: &[TranscriptLineMeta],
        visible_lines: usize,
    ) -> Self {
        if delta_lines == 0 {
            return self;
        }

        let total_lines = line_meta.len();
        if total_lines <= visible_lines {
            return Self::ToBottom;
        }

        let max_start = total_lines.saturating_sub(visible_lines);
        let current_top = match self {
            Self::ToBottom => max_start,
            Self::Scrolled {
                cell_index,
                line_in_cell,
            } => anchor_index(line_meta, cell_index, line_in_cell)
                .unwrap_or(max_start)
                .min(max_start),
        };

        let new_top = if delta_lines < 0 {
            current_top.saturating_sub(delta_lines.unsigned_abs() as usize)
        } else {
            current_top
                .saturating_add(delta_lines as usize)
                .min(max_start)
        };

        if new_top == max_start {
            return Self::ToBottom;
        }

        Self::anchor_for(line_meta, new_top).unwrap_or(Self::ToBottom)
    }

    /// Anchor to the first available line at or near the given start offset.
    ///
    /// This is the inverse of "resolving a scroll state to a top-row offset":
    /// given a concrete flattened line index, pick a stable `(cell_index, line_in_cell)` anchor.
    ///
    /// See `resolve_top` for `line_meta` semantics. This prefers the nearest line at or after `start`
    /// (skipping spacer rows), falling back to the nearest line before it when needed.
    pub(crate) fn anchor_for(line_meta: &[TranscriptLineMeta], start: usize) -> Option<Self> {
        let anchor =
            anchor_at_or_after(line_meta, start).or_else(|| anchor_at_or_before(line_meta, start));
        anchor.map(|(cell_index, line_in_cell)| Self::Scrolled {
            cell_index,
            line_in_cell,
        })
    }
}

/// Locate the flattened line index for a specific transcript cell and line.
///
/// This scans `meta` for the exact `(cell_index, line_in_cell)` anchor. It returns `None` when the
/// anchor is not present in the current frame's flattened line list (for example if a cell was
/// removed or its displayed line count changed).
fn anchor_index(
    line_meta: &[TranscriptLineMeta],
    cell_index: usize,
    line_in_cell: usize,
) -> Option<usize> {
    line_meta
        .iter()
        .enumerate()
        .find_map(|(idx, entry)| match *entry {
            TranscriptLineMeta::CellLine {
                cell_index: ci,
                line_in_cell: li,
            } if ci == cell_index && li == line_in_cell => Some(idx),
            _ => None,
        })
}

/// Find the first transcript line at or after the given flattened index.
fn anchor_at_or_after(line_meta: &[TranscriptLineMeta], start: usize) -> Option<(usize, usize)> {
    if line_meta.is_empty() {
        return None;
    }
    let start = start.min(line_meta.len().saturating_sub(1));
    line_meta
        .iter()
        .skip(start)
        .find_map(TranscriptLineMeta::cell_line)
}

/// Find the nearest transcript line at or before the given flattened index.
fn anchor_at_or_before(line_meta: &[TranscriptLineMeta], start: usize) -> Option<(usize, usize)> {
    if line_meta.is_empty() {
        return None;
    }
    let start = start.min(line_meta.len().saturating_sub(1));
    line_meta[..=start]
        .iter()
        .rev()
        .find_map(TranscriptLineMeta::cell_line)
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    fn meta(entries: &[TranscriptLineMeta]) -> Vec<TranscriptLineMeta> {
        entries.to_vec()
    }

    fn cell_line(cell_index: usize, line_in_cell: usize) -> TranscriptLineMeta {
        TranscriptLineMeta::CellLine {
            cell_index,
            line_in_cell,
        }
    }

    #[test]
    fn resolve_top_to_bottom_clamps_to_max_start() {
        let meta = meta(&[
            cell_line(0, 0),
            cell_line(0, 1),
            TranscriptLineMeta::Spacer,
            cell_line(1, 0),
        ]);

        let (state, top) = TranscriptScroll::ToBottom.resolve_top(&meta, 3);

        assert_eq!(state, TranscriptScroll::ToBottom);
        assert_eq!(top, 3);
    }

    #[test]
    fn resolve_top_scrolled_keeps_anchor_when_present() {
        let meta = meta(&[
            cell_line(0, 0),
            TranscriptLineMeta::Spacer,
            cell_line(1, 0),
            cell_line(1, 1),
        ]);
        let scroll = TranscriptScroll::Scrolled {
            cell_index: 1,
            line_in_cell: 0,
        };

        let (state, top) = scroll.resolve_top(&meta, 2);

        assert_eq!(state, scroll);
        assert_eq!(top, 2);
    }

    #[test]
    fn resolve_top_scrolled_falls_back_when_anchor_missing() {
        let meta = meta(&[cell_line(0, 0), TranscriptLineMeta::Spacer, cell_line(1, 0)]);
        let scroll = TranscriptScroll::Scrolled {
            cell_index: 2,
            line_in_cell: 0,
        };

        let (state, top) = scroll.resolve_top(&meta, 1);

        assert_eq!(state, TranscriptScroll::ToBottom);
        assert_eq!(top, 1);
    }

    #[test]
    fn scrolled_by_moves_upward_and_anchors() {
        let meta = meta(&[
            cell_line(0, 0),
            cell_line(0, 1),
            cell_line(1, 0),
            TranscriptLineMeta::Spacer,
            cell_line(2, 0),
            cell_line(2, 1),
        ]);

        let state = TranscriptScroll::ToBottom.scrolled_by(-1, &meta, 3);

        assert_eq!(
            state,
            TranscriptScroll::Scrolled {
                cell_index: 1,
                line_in_cell: 0
            }
        );
    }

    #[test]
    fn scrolled_by_returns_to_bottom_when_scrolling_down() {
        let meta = meta(&[
            cell_line(0, 0),
            cell_line(0, 1),
            cell_line(1, 0),
            cell_line(2, 0),
        ]);
        let scroll = TranscriptScroll::Scrolled {
            cell_index: 0,
            line_in_cell: 0,
        };

        let state = scroll.scrolled_by(5, &meta, 2);

        assert_eq!(state, TranscriptScroll::ToBottom);
    }

    #[test]
    fn scrolled_by_to_bottom_when_all_lines_fit() {
        let meta = meta(&[cell_line(0, 0), cell_line(0, 1)]);

        let state = TranscriptScroll::Scrolled {
            cell_index: 0,
            line_in_cell: 0,
        }
        .scrolled_by(-1, &meta, 5);

        assert_eq!(state, TranscriptScroll::ToBottom);
    }

    #[test]
    fn anchor_for_prefers_after_then_before() {
        let meta = meta(&[
            TranscriptLineMeta::Spacer,
            cell_line(0, 0),
            TranscriptLineMeta::Spacer,
            cell_line(1, 0),
        ]);

        assert_eq!(
            TranscriptScroll::anchor_for(&meta, 0),
            Some(TranscriptScroll::Scrolled {
                cell_index: 0,
                line_in_cell: 0
            })
        );
        assert_eq!(
            TranscriptScroll::anchor_for(&meta, 2),
            Some(TranscriptScroll::Scrolled {
                cell_index: 1,
                line_in_cell: 0
            })
        );
        assert_eq!(
            TranscriptScroll::anchor_for(&meta, 3),
            Some(TranscriptScroll::Scrolled {
                cell_index: 1,
                line_in_cell: 0
            })
        );
    }
}
@@ -62,6 +62,7 @@ pub(crate) async fn run_update_prompt_if_needed(
                     frame.render_widget_ref(&screen, frame.area());
                 })?;
             }
+            TuiEvent::Mouse(_) => {}
         }
     } else {
         break;
@@ -251,7 +251,7 @@ pub fn apply_capability_denies_for_world_writable(
     }
     std::fs::create_dir_all(codex_home)?;
     let cap_path = cap_sid_file(codex_home);
-    let caps = load_or_create_cap_sids(codex_home);
+    let caps = load_or_create_cap_sids(codex_home)?;
     std::fs::write(&cap_path, serde_json::to_string(&caps)?)?;
     let (active_sid, workspace_roots): (*mut c_void, Vec<PathBuf>) = match sandbox_policy {
         SandboxPolicy::WorkspaceWrite { writable_roots, .. } => {
@@ -1,3 +1,5 @@
+use anyhow::Context;
+use anyhow::Result;
 use rand::rngs::SmallRng;
 use rand::RngCore;
 use rand::SeedableRng;
@@ -26,25 +28,39 @@ fn make_random_cap_sid_string() -> String {
     format!("S-1-5-21-{}-{}-{}-{}", a, b, c, d)
 }
 
-pub fn load_or_create_cap_sids(codex_home: &Path) -> CapSids {
+fn persist_caps(path: &Path, caps: &CapSids) -> Result<()> {
+    if let Some(dir) = path.parent() {
+        fs::create_dir_all(dir)
+            .with_context(|| format!("create cap sid dir {}", dir.display()))?;
+    }
+    let json = serde_json::to_string(caps)?;
+    fs::write(path, json).with_context(|| format!("write cap sid file {}", path.display()))?;
+    Ok(())
+}
+
+pub fn load_or_create_cap_sids(codex_home: &Path) -> Result<CapSids> {
     let path = cap_sid_file(codex_home);
     if path.exists() {
-        if let Ok(txt) = fs::read_to_string(&path) {
-            let t = txt.trim();
-            if t.starts_with('{') && t.ends_with('}') {
-                if let Ok(obj) = serde_json::from_str::<CapSids>(t) {
-                    return obj;
-                }
-            } else if !t.is_empty() {
-                return CapSids {
-                    workspace: t.to_string(),
-                    readonly: make_random_cap_sid_string(),
-                };
-            }
+        let txt = fs::read_to_string(&path)
+            .with_context(|| format!("read cap sid file {}", path.display()))?;
+        let t = txt.trim();
+        if t.starts_with('{') && t.ends_with('}') {
+            if let Ok(obj) = serde_json::from_str::<CapSids>(t) {
+                return Ok(obj);
+            }
+        } else if !t.is_empty() {
+            let caps = CapSids {
+                workspace: t.to_string(),
+                readonly: make_random_cap_sid_string(),
+            };
+            persist_caps(&path, &caps)?;
+            return Ok(caps);
        }
     }
-    CapSids {
+    let caps = CapSids {
         workspace: make_random_cap_sid_string(),
         readonly: make_random_cap_sid_string(),
-    }
+    };
+    persist_caps(&path, &caps)?;
+    Ok(caps)
 }
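The reworked `load_or_create_cap_sids` is meant to be idempotent: the first call generates and persists the SIDs, and later calls read the same JSON back. A hypothetical test sketch of that round trip (not part of the diff; assumes a `tempfile` dev-dependency):

#[test]
fn cap_sids_persist_across_calls() -> anyhow::Result<()> {
    let home = tempfile::tempdir()?;
    let first = load_or_create_cap_sids(home.path())?; // generates and writes the file
    let second = load_or_create_cap_sids(home.path())?; // reads the persisted JSON back
    assert_eq!(first.workspace, second.workspace);
    assert_eq!(first.readonly, second.readonly);
    Ok(())
}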
@@ -2,7 +2,6 @@ mod windows_impl {
     use crate::acl::allow_null_device;
     use crate::allow::compute_allow_paths;
     use crate::allow::AllowDenyPaths;
-    use crate::cap::cap_sid_file;
     use crate::cap::load_or_create_cap_sids;
     use crate::env::ensure_non_interactive_pager;
     use crate::env::inherit_path_env;
@@ -53,13 +52,6 @@ mod windows_impl {
     use windows_sys::Win32::System::Threading::STARTUPINFOW;
 
-    /// Ensures the parent directory of a path exists before writing to it.
-    fn ensure_dir(p: &Path) -> Result<()> {
-        if let Some(d) = p.parent() {
-            std::fs::create_dir_all(d)?;
-        }
-        Ok(())
-    }
 
     /// Walks upward from `start` to locate the git worktree root, following gitfile redirects.
     fn find_git_root(start: &Path) -> Option<PathBuf> {
         let mut cur = dunce::canonicalize(start).ok()?;
@@ -246,44 +238,26 @@ mod windows_impl {
         let sandbox_creds =
             require_logon_sandbox_creds(&policy, sandbox_policy_cwd, cwd, &env_map, codex_home)?;
         log_note("cli creds ready", logs_base_dir);
-        let cap_sid_path = cap_sid_file(codex_home);
 
         // Build capability SID for ACL grants.
+        if matches!(&policy, SandboxPolicy::DangerFullAccess) {
+            anyhow::bail!("DangerFullAccess is not supported for sandboxing")
+        }
+        let caps = load_or_create_cap_sids(codex_home)?;
         let (psid_to_use, cap_sid_str) = match &policy {
-            SandboxPolicy::ReadOnly => {
-                let caps = load_or_create_cap_sids(codex_home);
-                ensure_dir(&cap_sid_path)?;
-                fs::write(&cap_sid_path, serde_json::to_string(&caps)?)?;
-                (
-                    unsafe { convert_string_sid_to_sid(&caps.readonly).unwrap() },
-                    caps.readonly.clone(),
-                )
-            }
-            SandboxPolicy::WorkspaceWrite { .. } => {
-                let caps = load_or_create_cap_sids(codex_home);
-                ensure_dir(&cap_sid_path)?;
-                fs::write(&cap_sid_path, serde_json::to_string(&caps)?)?;
-                (
-                    unsafe { convert_string_sid_to_sid(&caps.workspace).unwrap() },
-                    caps.workspace.clone(),
-                )
-            }
-            SandboxPolicy::DangerFullAccess => {
-                anyhow::bail!("DangerFullAccess is not supported for sandboxing")
-            }
+            SandboxPolicy::ReadOnly => (
+                unsafe { convert_string_sid_to_sid(&caps.readonly).unwrap() },
+                caps.readonly.clone(),
+            ),
+            SandboxPolicy::WorkspaceWrite { .. } => (
+                unsafe { convert_string_sid_to_sid(&caps.workspace).unwrap() },
+                caps.workspace.clone(),
+            ),
+            SandboxPolicy::DangerFullAccess => unreachable!("DangerFullAccess handled above"),
         };
 
-        let AllowDenyPaths { allow, deny } =
+        let AllowDenyPaths { allow: _, deny: _ } =
             compute_allow_paths(&policy, sandbox_policy_cwd, &current_dir, &env_map);
         // Deny/allow ACEs are now applied during setup; avoid per-command churn.
-        log_note(
-            &format!(
-                "cli skipping per-command ACL grants (allow_count={} deny_count={})",
-                allow.len(),
-                deny.len()
-            ),
-            logs_base_dir,
-        );
         unsafe {
             allow_null_device(psid_to_use);
         }
@@ -85,7 +85,6 @@ mod windows_impl {
     use super::acl::revoke_ace;
     use super::allow::compute_allow_paths;
     use super::allow::AllowDenyPaths;
-    use super::cap::cap_sid_file;
     use super::cap::load_or_create_cap_sids;
     use super::env::apply_no_network_to_env;
     use super::env::ensure_non_interactive_pager;
@@ -104,7 +103,6 @@ mod windows_impl {
     use anyhow::Result;
     use std::collections::HashMap;
     use std::ffi::c_void;
-    use std::fs;
     use std::io;
     use std::path::Path;
     use std::path::PathBuf;
@@ -130,13 +128,6 @@ mod windows_impl {
         !policy.has_full_network_access()
     }
 
-    fn ensure_dir(p: &Path) -> Result<()> {
-        if let Some(d) = p.parent() {
-            std::fs::create_dir_all(d)?;
-        }
-        Ok(())
-    }
-
     fn ensure_codex_home_exists(p: &Path) -> Result<()> {
         std::fs::create_dir_all(p)?;
         Ok(())
@@ -194,32 +185,28 @@ mod windows_impl {
             apply_no_network_to_env(&mut env_map)?;
         }
         ensure_codex_home_exists(codex_home)?;
 
         let current_dir = cwd.to_path_buf();
-        let logs_base_dir = Some(codex_home);
+        let sandbox_base = codex_home.join(".sandbox");
+        std::fs::create_dir_all(&sandbox_base)?;
+        let logs_base_dir = Some(sandbox_base.as_path());
         log_start(&command, logs_base_dir);
-        let cap_sid_path = cap_sid_file(codex_home);
         let is_workspace_write = matches!(&policy, SandboxPolicy::WorkspaceWrite { .. });
 
+        if matches!(&policy, SandboxPolicy::DangerFullAccess) {
+            anyhow::bail!("DangerFullAccess is not supported for sandboxing")
+        }
+        let caps = load_or_create_cap_sids(codex_home)?;
         let (h_token, psid_to_use): (HANDLE, *mut c_void) = unsafe {
             match &policy {
                 SandboxPolicy::ReadOnly => {
-                    let caps = load_or_create_cap_sids(codex_home);
-                    ensure_dir(&cap_sid_path)?;
-                    fs::write(&cap_sid_path, serde_json::to_string(&caps)?)?;
                     let psid = convert_string_sid_to_sid(&caps.readonly).unwrap();
                     super::token::create_readonly_token_with_cap(psid)?
                 }
                 SandboxPolicy::WorkspaceWrite { .. } => {
-                    let caps = load_or_create_cap_sids(codex_home);
-                    ensure_dir(&cap_sid_path)?;
-                    fs::write(&cap_sid_path, serde_json::to_string(&caps)?)?;
                     let psid = convert_string_sid_to_sid(&caps.workspace).unwrap();
                     super::token::create_workspace_write_token_with_cap(psid)?
                 }
-                SandboxPolicy::DangerFullAccess => {
-                    anyhow::bail!("DangerFullAccess is not supported for sandboxing")
-                }
+                SandboxPolicy::DangerFullAccess => unreachable!("DangerFullAccess handled above"),
             }
         };
@@ -20,6 +20,7 @@ use rand::RngCore;
 use rand::SeedableRng;
 use serde::Deserialize;
 use serde::Serialize;
+use std::collections::HashSet;
 use std::ffi::c_void;
 use std::ffi::OsStr;
 use std::fs::File;
@@ -392,8 +393,8 @@ fn run_netsh_firewall(sid: &str, log: &mut File) -> Result<()> {
     log_line(
         log,
         &format!(
-                "firewall rule configured via COM with LocalUserAuthorizedList={local_user_spec}"
-            ),
+            "firewall rule configured via COM with LocalUserAuthorizedList={local_user_spec}"
+        ),
     )?;
     Ok(())
 })()
@@ -647,7 +648,7 @@ fn run_setup(payload: &Payload, log: &mut File, sbx_dir: &Path) -> Result<()> {
             string_from_sid_bytes(&online_sid).map_err(anyhow::Error::msg)?
         ),
     )?;
-    let caps = load_or_create_cap_sids(&payload.codex_home);
+    let caps = load_or_create_cap_sids(&payload.codex_home)?;
     let cap_psid = unsafe {
         convert_string_sid_to_sid(&caps.workspace)
             .ok_or_else(|| anyhow::anyhow!("convert capability SID failed"))?
@@ -758,7 +759,19 @@ fn run_setup(payload: &Payload, log: &mut File, sbx_dir: &Path) -> Result<()> {
         }
     }
 
+    let cap_sid_str = caps.workspace.clone();
+    let online_sid_str = string_from_sid_bytes(&online_sid).map_err(anyhow::Error::msg)?;
+    let sid_strings = vec![offline_sid_str.clone(), online_sid_str, cap_sid_str];
+    let write_mask =
+        FILE_GENERIC_READ | FILE_GENERIC_WRITE | FILE_GENERIC_EXECUTE | DELETE | FILE_DELETE_CHILD;
+    let mut grant_tasks: Vec<PathBuf> = Vec::new();
+
+    let mut seen_write_roots: HashSet<PathBuf> = HashSet::new();
+
     for root in &payload.write_roots {
+        if !seen_write_roots.insert(root.clone()) {
+            continue;
+        }
         if !root.exists() {
             log_line(
                 log,
@@ -766,12 +779,6 @@ fn run_setup(payload: &Payload, log: &mut File, sbx_dir: &Path) -> Result<()> {
             )?;
             continue;
         }
-        let sids = vec![offline_psid, online_psid, cap_psid];
-        let write_mask = FILE_GENERIC_READ
-            | FILE_GENERIC_WRITE
-            | FILE_GENERIC_EXECUTE
-            | DELETE
-            | FILE_DELETE_CHILD;
         let mut need_grant = false;
         for (label, psid) in [
             ("offline", offline_psid),
@@ -817,25 +824,7 @@ fn run_setup(payload: &Payload, log: &mut File, sbx_dir: &Path) -> Result<()> {
                     root.display()
                 ),
             )?;
-            match unsafe { ensure_allow_write_aces(root, &sids) } {
-                Ok(res) => {
-                    log_line(
-                        log,
-                        &format!(
-                            "write ACE {} on {}",
-                            if res { "added" } else { "already present" },
-                            root.display()
-                        ),
-                    )?;
-                }
-                Err(e) => {
-                    refresh_errors.push(format!("write ACE failed on {}: {}", root.display(), e));
-                    log_line(
-                        log,
-                        &format!("write ACE grant failed on {}: {}", root.display(), e),
-                    )?;
-                }
-            }
+            grant_tasks.push(root.clone());
         } else {
             log_line(
                 log,
@@ -847,6 +836,65 @@ fn run_setup(payload: &Payload, log: &mut File, sbx_dir: &Path) -> Result<()> {
         }
     }
 
+    let (tx, rx) = mpsc::channel::<(PathBuf, Result<bool>)>();
+    std::thread::scope(|scope| {
+        for root in grant_tasks {
+            let sid_strings = sid_strings.clone();
+            let tx = tx.clone();
+            scope.spawn(move || {
+                // Convert SID strings to psids locally in this thread.
+                let mut psids: Vec<*mut c_void> = Vec::new();
+                for sid_str in &sid_strings {
+                    if let Some(psid) = unsafe { convert_string_sid_to_sid(sid_str) } {
+                        psids.push(psid);
+                    } else {
+                        let _ = tx.send((root.clone(), Err(anyhow::anyhow!("convert SID failed"))));
+                        return;
+                    }
+                }
+
+                let res = unsafe { ensure_allow_write_aces(&root, &psids) };
+
+                for psid in psids {
+                    unsafe {
+                        LocalFree(psid as HLOCAL);
+                    }
+                }
+                let _ = tx.send((root, res));
+            });
+        }
+        drop(tx);
+        for (root, res) in rx {
+            match res {
+                Ok(added) => {
+                    if log_line(
+                        log,
+                        &format!(
+                            "write ACE {} on {}",
+                            if added { "added" } else { "already present" },
+                            root.display()
+                        ),
+                    )
+                    .is_err()
+                    {
+                        // ignore log errors inside scoped thread
+                    }
+                }
+                Err(e) => {
+                    refresh_errors.push(format!("write ACE failed on {}: {}", root.display(), e));
+                    if log_line(
+                        log,
+                        &format!("write ACE grant failed on {}: {}", root.display(), e),
+                    )
+                    .is_err()
+                    {
+                        // ignore log errors inside scoped thread
+                    }
+                }
+            }
+        }
+    });
+
     if refresh_only {
         log_line(
             log,
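The hunk above fans per-root ACL grants out to scoped worker threads and funnels results back over an mpsc channel; dropping the original sender closes the channel so the receive loop ends once every worker has reported. A minimal, self-contained sketch of the same pattern (generic task and result types, nothing Windows-specific; all names hypothetical):

use std::sync::mpsc;

fn main() {
    let tasks = vec!["a", "b", "c"];
    let (tx, rx) = mpsc::channel::<(String, Result<usize, String>)>();

    // std::thread::scope lets workers borrow from the enclosing stack frame.
    std::thread::scope(|scope| {
        for task in tasks {
            let tx = tx.clone();
            scope.spawn(move || {
                // Stand-in for per-task work (e.g. applying an ACE to a root).
                let result = Ok(task.len());
                let _ = tx.send((task.to_string(), result));
            });
        }
        // Drop the original sender so `rx` disconnects once workers finish.
        drop(tx);
        for (task, result) in rx {
            match result {
                Ok(n) => println!("{task}: ok ({n})"),
                Err(e) => eprintln!("{task}: failed: {e}"),
            }
        }
    });
}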
@@ -1,6 +1,7 @@
 use serde::Deserialize;
 use serde::Serialize;
 use std::collections::HashMap;
+use std::collections::HashSet;
 use std::ffi::c_void;
 use std::os::windows::process::CommandExt;
 use std::path::Path;
@@ -54,13 +55,22 @@ pub fn run_setup_refresh(
     if matches!(policy, SandboxPolicy::DangerFullAccess) {
         return Ok(());
     }
+    let (read_roots, write_roots) = build_payload_roots(
+        policy,
+        policy_cwd,
+        command_cwd,
+        env_map,
+        codex_home,
+        None,
+        None,
+    );
     let payload = ElevationPayload {
         version: SETUP_VERSION,
         offline_username: OFFLINE_USERNAME.to_string(),
         online_username: ONLINE_USERNAME.to_string(),
         codex_home: codex_home.to_path_buf(),
-        read_roots: gather_read_roots(command_cwd, policy),
-        write_roots: gather_write_roots(policy, policy_cwd, command_cwd, env_map),
+        read_roots,
+        write_roots,
         real_user: std::env::var("USERNAME").unwrap_or_else(|_| "Administrators".to_string()),
         refresh_only: true,
     };
@@ -219,7 +229,14 @@ pub(crate) fn gather_write_roots(
     let AllowDenyPaths { allow, .. } =
         compute_allow_paths(policy, policy_cwd, command_cwd, env_map);
     roots.extend(allow);
-    canonical_existing(&roots)
+    let mut dedup: HashSet<PathBuf> = HashSet::new();
+    let mut out: Vec<PathBuf> = Vec::new();
+    for r in canonical_existing(&roots) {
+        if dedup.insert(r.clone()) {
+            out.push(r);
+        }
+    }
+    out
 }
 
 #[derive(Serialize)]
@@ -355,19 +372,15 @@ pub fn run_elevated_setup(
     // Ensure the shared sandbox directory exists before we send it to the elevated helper.
     let sbx_dir = sandbox_dir(codex_home);
     std::fs::create_dir_all(&sbx_dir)?;
-    let mut write_roots = if let Some(roots) = write_roots_override {
-        roots
-    } else {
-        gather_write_roots(policy, policy_cwd, command_cwd, env_map)
-    };
-    if !write_roots.contains(&sbx_dir) {
-        write_roots.push(sbx_dir.clone());
-    }
-    let read_roots = if let Some(roots) = read_roots_override {
-        roots
-    } else {
-        gather_read_roots(command_cwd, policy)
-    };
+    let (read_roots, write_roots) = build_payload_roots(
+        policy,
+        policy_cwd,
+        command_cwd,
+        env_map,
+        codex_home,
+        read_roots_override,
+        write_roots_override,
+    );
     let payload = ElevationPayload {
         version: SETUP_VERSION,
         offline_username: OFFLINE_USERNAME.to_string(),
@@ -381,3 +394,31 @@ pub fn run_elevated_setup(
     let needs_elevation = !is_elevated()?;
     run_setup_exe(&payload, needs_elevation)
 }
+
+fn build_payload_roots(
+    policy: &SandboxPolicy,
+    policy_cwd: &Path,
+    command_cwd: &Path,
+    env_map: &HashMap<String, String>,
+    codex_home: &Path,
+    read_roots_override: Option<Vec<PathBuf>>,
+    write_roots_override: Option<Vec<PathBuf>>,
+) -> (Vec<PathBuf>, Vec<PathBuf>) {
+    let sbx_dir = sandbox_dir(codex_home);
+    let mut write_roots = if let Some(roots) = write_roots_override {
+        canonical_existing(&roots)
+    } else {
+        gather_write_roots(policy, policy_cwd, command_cwd, env_map)
+    };
+    if !write_roots.contains(&sbx_dir) {
+        write_roots.push(sbx_dir.clone());
+    }
+    let mut read_roots = if let Some(roots) = read_roots_override {
+        canonical_existing(&roots)
+    } else {
+        gather_read_roots(command_cwd, policy)
+    };
+    let write_root_set: HashSet<PathBuf> = write_roots.iter().cloned().collect();
+    read_roots.retain(|root| !write_root_set.contains(root));
+    (read_roots, write_roots)
+}
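One property worth noting in `build_payload_roots`: write roots win, so any path listed in both sets is dropped from the read roots before the payload is built. A standalone sketch of just that filtering step (plain `PathBuf` lists; the function name is hypothetical):

use std::collections::HashSet;
use std::path::PathBuf;

fn disjoint_read_roots(mut read_roots: Vec<PathBuf>, write_roots: &[PathBuf]) -> Vec<PathBuf> {
    // A membership set avoids rescanning the write roots for every read root.
    let write_set: HashSet<&PathBuf> = write_roots.iter().collect();
    read_roots.retain(|root| !write_set.contains(root));
    read_roots
}

fn main() {
    let read = vec![PathBuf::from("C:/repo"), PathBuf::from("C:/tools")];
    let write = vec![PathBuf::from("C:/repo")];
    // Prints ["C:/tools"]: the overlapping root stays write-only.
    println!("{:?}", disjoint_read_roots(read, &write));
}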
@@ -316,7 +316,7 @@ disk, but attempts to write a file or access the network will be blocked.
 
 A more relaxed policy is `workspace-write`. When specified, the current working directory for the Codex task will be writable (as well as `$TMPDIR` on macOS). Note that the CLI defaults to using the directory where it was spawned as `cwd`, though this can be overridden using `--cwd/-C`.
 
-On macOS (and soon Linux), all writable roots (including `cwd`) that contain a `.git/` folder _as an immediate child_ will configure the `.git/` folder to be read-only while the rest of the Git repository will be writable. This means that commands like `git commit` will fail, by default (as it entails writing to `.git/`), and will require Codex to ask for permission.
+On macOS (and soon Linux), all writable roots (including `cwd`) that contain a `.git/` or `.codex/` folder _as an immediate child_ will configure those folders to be read-only while the rest of the root stays writable. This means that commands like `git commit` will fail by default (as committing entails writing to `.git/`) and will require Codex to ask for permission.
 
 ```toml
 # same as `--sandbox workspace-write`
@@ -2,7 +2,7 @@ export type ApprovalMode = "never" | "on-request" | "on-failure" | "untrusted";
 
 export type SandboxMode = "read-only" | "workspace-write" | "danger-full-access";
 
-export type ModelReasoningEffort = "minimal" | "low" | "medium" | "high";
+export type ModelReasoningEffort = "minimal" | "low" | "medium" | "high" | "xhigh";
 
 export type ThreadOptions = {
   model?: string;