mirror of
https://github.com/openai/codex.git
synced 2026-02-08 01:43:46 +00:00
Compare commits
71 Commits
subagents-
...
codex-stat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6016f5489 | ||
|
|
2721498ec9 | ||
|
|
47ef2cd9ca | ||
|
|
db70faab42 | ||
|
|
1e7796570a | ||
|
|
11b13914a8 | ||
|
|
567844fe05 | ||
|
|
6eb0474455 | ||
|
|
af1254fc4e | ||
|
|
764aff6753 | ||
|
|
e2a55921ec | ||
|
|
08b6d9ef1f | ||
|
|
bbf536847c | ||
|
|
0d0779d08a | ||
|
|
087e571198 | ||
|
|
28ff364c3a | ||
|
|
70b613be81 | ||
|
|
53ff941cf3 | ||
|
|
6b66534356 | ||
|
|
1938116a5d | ||
|
|
075d50677d | ||
|
|
5c40534e98 | ||
|
|
981e2f742d | ||
|
|
f05492fc94 | ||
|
|
44ff9fcb69 | ||
|
|
b4a1a500ec | ||
|
|
401f94ca31 | ||
|
|
865e225bde | ||
|
|
4502b1b263 | ||
|
|
2845e2c006 | ||
|
|
d8a7a63959 | ||
|
|
caf2749d5b | ||
|
|
9ba27cfa0a | ||
|
|
157a16cefa | ||
|
|
37d83e075e | ||
|
|
523b40a129 | ||
|
|
0dd822264a | ||
|
|
3308dc5e48 | ||
|
|
fc2ff624ac | ||
|
|
b897880378 | ||
|
|
99e5340c54 | ||
|
|
ec49b56874 | ||
|
|
4ed4c73d6b | ||
|
|
523dabc1ee | ||
|
|
3741f387e9 | ||
|
|
2c885edefa | ||
|
|
f6f4c8f5ee | ||
|
|
1875700220 | ||
|
|
207d94b0e7 | ||
|
|
481c84e118 | ||
|
|
6dd3606773 | ||
|
|
35850caff6 | ||
|
|
e8b6fb7937 | ||
|
|
be022be381 | ||
|
|
1e832b1438 | ||
|
|
c31663d745 | ||
|
|
35d89e820f | ||
|
|
486b1c4d9d | ||
|
|
b2cddec3d7 | ||
|
|
920239f272 | ||
|
|
99bcb90353 | ||
|
|
b519267d05 | ||
|
|
529eb4ff2a | ||
|
|
c6f68c9df8 | ||
|
|
af63e6eccc | ||
|
|
87b211709e | ||
|
|
67975ed33a | ||
|
|
7561a6aaf0 | ||
|
|
3ea33a0616 | ||
|
|
e8ef6d3c16 | ||
|
|
e52cc38dfd |
26
.github/workflows/cargo-deny.yml
vendored
Normal file
26
.github/workflows/cargo-deny.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
name: cargo-deny
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
cargo-deny:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./codex-rs
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Run cargo-deny
|
||||
uses: EmbarkStudios/cargo-deny-action@v1
|
||||
with:
|
||||
rust-version: stable
|
||||
manifest-path: ./codex-rs/Cargo.toml
|
||||
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
|
||||
1
.github/workflows/cla.yml
vendored
1
.github/workflows/cla.yml
vendored
@@ -48,4 +48,5 @@ jobs:
|
||||
branch: cla-signatures
|
||||
allowlist: |
|
||||
codex
|
||||
dependabot
|
||||
dependabot[bot]
|
||||
|
||||
2
.github/workflows/codespell.yml
vendored
2
.github/workflows/codespell.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
- name: Annotate locations with typos
|
||||
uses: codespell-project/codespell-problem-matcher@b80729f885d32f78a716c2f107b4db1025001c42 # v1
|
||||
- name: Codespell
|
||||
|
||||
2
.github/workflows/issue-deduplicator.yml
vendored
2
.github/workflows/issue-deduplicator.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
outputs:
|
||||
codex_output: ${{ steps.codex.outputs.final-message }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Prepare Codex inputs
|
||||
env:
|
||||
|
||||
2
.github/workflows/issue-labeler.yml
vendored
2
.github/workflows/issue-labeler.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
outputs:
|
||||
codex_output: ${{ steps.codex.outputs.final-message }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- id: codex
|
||||
uses: openai/codex-action@main
|
||||
|
||||
48
.github/workflows/rust-ci.yml
vendored
48
.github/workflows/rust-ci.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
codex: ${{ steps.detect.outputs.codex }}
|
||||
workflows: ${{ steps.detect.outputs.workflows }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Detect changed paths (no external action)
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
run:
|
||||
working-directory: codex-rs
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
with:
|
||||
components: rustfmt
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
run:
|
||||
working-directory: codex-rs
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
- uses: taiki-e/install-action@44c6d64aa62cd779e873306675c7a58e86d6d532 # v2
|
||||
with:
|
||||
@@ -147,12 +147,21 @@ jobs:
|
||||
profile: release
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
with:
|
||||
targets: ${{ matrix.target }}
|
||||
components: clippy
|
||||
|
||||
- name: Compute lockfile hash
|
||||
id: lockhash
|
||||
working-directory: codex-rs
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "hash=$(sha256sum Cargo.lock | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
|
||||
echo "toolchain_hash=$(sha256sum rust-toolchain.toml | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# Explicit cache restore: split cargo home vs target, so we can
|
||||
# avoid caching the large target dir on the gnu-dev job.
|
||||
- name: Restore cargo home cache
|
||||
@@ -164,7 +173,7 @@ jobs:
|
||||
~/.cargo/registry/index/
|
||||
~/.cargo/registry/cache/
|
||||
~/.cargo/git/db/
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-${{ hashFiles('codex-rs/rust-toolchain.toml') }}
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ steps.lockhash.outputs.toolchain_hash }}
|
||||
restore-keys: |
|
||||
cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-
|
||||
|
||||
@@ -201,9 +210,9 @@ jobs:
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}/.sccache/
|
||||
key: sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-${{ github.run_id }}
|
||||
key: sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ github.run_id }}
|
||||
restore-keys: |
|
||||
sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-
|
||||
sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-
|
||||
sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-
|
||||
|
||||
- if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}}
|
||||
@@ -278,7 +287,7 @@ jobs:
|
||||
~/.cargo/registry/index/
|
||||
~/.cargo/registry/cache/
|
||||
~/.cargo/git/db/
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-${{ hashFiles('codex-rs/rust-toolchain.toml') }}
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ steps.lockhash.outputs.toolchain_hash }}
|
||||
|
||||
- name: Save sccache cache (fallback)
|
||||
if: always() && !cancelled() && env.USE_SCCACHE == 'true' && env.SCCACHE_GHA_ENABLED != 'true'
|
||||
@@ -286,7 +295,7 @@ jobs:
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}/.sccache/
|
||||
key: sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-${{ github.run_id }}
|
||||
key: sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ github.run_id }}
|
||||
|
||||
- name: sccache stats
|
||||
if: always() && env.USE_SCCACHE == 'true'
|
||||
@@ -359,11 +368,20 @@ jobs:
|
||||
profile: dev
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
with:
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
- name: Compute lockfile hash
|
||||
id: lockhash
|
||||
working-directory: codex-rs
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "hash=$(sha256sum Cargo.lock | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
|
||||
echo "toolchain_hash=$(sha256sum rust-toolchain.toml | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Restore cargo home cache
|
||||
id: cache_cargo_home_restore
|
||||
uses: actions/cache/restore@v4
|
||||
@@ -373,7 +391,7 @@ jobs:
|
||||
~/.cargo/registry/index/
|
||||
~/.cargo/registry/cache/
|
||||
~/.cargo/git/db/
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-${{ hashFiles('codex-rs/rust-toolchain.toml') }}
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ steps.lockhash.outputs.toolchain_hash }}
|
||||
restore-keys: |
|
||||
cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-
|
||||
|
||||
@@ -409,9 +427,9 @@ jobs:
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}/.sccache/
|
||||
key: sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-${{ github.run_id }}
|
||||
key: sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ github.run_id }}
|
||||
restore-keys: |
|
||||
sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-
|
||||
sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-
|
||||
sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-
|
||||
|
||||
- uses: taiki-e/install-action@44c6d64aa62cd779e873306675c7a58e86d6d532 # v2
|
||||
@@ -436,7 +454,7 @@ jobs:
|
||||
~/.cargo/registry/index/
|
||||
~/.cargo/registry/cache/
|
||||
~/.cargo/git/db/
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-${{ hashFiles('codex-rs/rust-toolchain.toml') }}
|
||||
key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ steps.lockhash.outputs.toolchain_hash }}
|
||||
|
||||
- name: Save sccache cache (fallback)
|
||||
if: always() && !cancelled() && env.USE_SCCACHE == 'true' && env.SCCACHE_GHA_ENABLED != 'true'
|
||||
@@ -444,7 +462,7 @@ jobs:
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}/.sccache/
|
||||
key: sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ hashFiles('**/Cargo.lock') }}-${{ github.run_id }}
|
||||
key: sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ github.run_id }}
|
||||
|
||||
- name: sccache stats
|
||||
if: always() && env.USE_SCCACHE == 'true'
|
||||
|
||||
9
.github/workflows/rust-release.yml
vendored
9
.github/workflows/rust-release.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
tag-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Validate tag matches Cargo.toml version
|
||||
shell: bash
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
target: aarch64-pc-windows-msvc
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
with:
|
||||
targets: ${{ matrix.target }}
|
||||
@@ -377,8 +377,7 @@ jobs:
|
||||
uses: ./.github/workflows/shell-tool-mcp.yml
|
||||
with:
|
||||
release-tag: ${{ github.ref_name }}
|
||||
# We are not ready to publish yet.
|
||||
publish: false
|
||||
publish: true
|
||||
secrets: inherit
|
||||
|
||||
release:
|
||||
@@ -398,7 +397,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
|
||||
2
.github/workflows/sdk.yml
vendored
2
.github/workflows/sdk.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
|
||||
2
.github/workflows/shell-tool-mcp-ci.yml
vendored
2
.github/workflows/shell-tool-mcp-ci.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
|
||||
23
.github/workflows/shell-tool-mcp.yml
vendored
23
.github/workflows/shell-tool-mcp.yml
vendored
@@ -91,7 +91,7 @@ jobs:
|
||||
install_musl: true
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
with:
|
||||
@@ -113,7 +113,7 @@ jobs:
|
||||
cp "target/${{ matrix.target }}/release/codex-exec-mcp-server" "$dest/"
|
||||
cp "target/${{ matrix.target }}/release/codex-execve-wrapper" "$dest/"
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: shell-tool-mcp-rust-${{ matrix.target }}
|
||||
path: artifacts/**
|
||||
@@ -138,10 +138,6 @@ jobs:
|
||||
target: x86_64-unknown-linux-musl
|
||||
variant: ubuntu-22.04
|
||||
image: ubuntu:22.04
|
||||
- runner: ubuntu-24.04
|
||||
target: x86_64-unknown-linux-musl
|
||||
variant: ubuntu-20.04
|
||||
image: ubuntu:20.04
|
||||
- runner: ubuntu-24.04
|
||||
target: x86_64-unknown-linux-musl
|
||||
variant: debian-12
|
||||
@@ -196,7 +192,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Build patched Bash
|
||||
shell: bash
|
||||
@@ -215,7 +211,7 @@ jobs:
|
||||
mkdir -p "$dest"
|
||||
cp bash "$dest/bash"
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: shell-tool-mcp-bash-${{ matrix.target }}-${{ matrix.variant }}
|
||||
path: artifacts/**
|
||||
@@ -236,12 +232,9 @@ jobs:
|
||||
- runner: macos-14
|
||||
target: aarch64-apple-darwin
|
||||
variant: macos-14
|
||||
- runner: macos-13
|
||||
target: x86_64-apple-darwin
|
||||
variant: macos-13
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Build patched Bash
|
||||
shell: bash
|
||||
@@ -260,7 +253,7 @@ jobs:
|
||||
mkdir -p "$dest"
|
||||
cp bash "$dest/bash"
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: shell-tool-mcp-bash-${{ matrix.target }}-${{ matrix.variant }}
|
||||
path: artifacts/**
|
||||
@@ -278,7 +271,7 @@ jobs:
|
||||
PACKAGE_VERSION: ${{ needs.metadata.outputs.version }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
@@ -359,7 +352,7 @@ jobs:
|
||||
filename=$(PACK_INFO="$pack_info" node -e 'const data = JSON.parse(process.env.PACK_INFO); console.log(data[0].filename);')
|
||||
mv "dist/npm/${filename}" "dist/npm/codex-shell-tool-mcp-npm-${PACKAGE_VERSION}.tgz"
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: codex-shell-tool-mcp-npm
|
||||
path: dist/npm/codex-shell-tool-mcp-npm-${{ env.PACKAGE_VERSION }}.tgz
|
||||
|
||||
34
README.md
34
README.md
@@ -69,38 +69,9 @@ Codex can access MCP servers. To configure them, refer to the [config docs](./do
|
||||
|
||||
Codex CLI supports a rich set of configuration options, with preferences stored in `~/.codex/config.toml`. For full configuration options, see [Configuration](./docs/config.md).
|
||||
|
||||
### Execpolicy Quickstart
|
||||
### Execpolicy
|
||||
|
||||
Codex can enforce your own rules-based execution policy before it runs shell commands.
|
||||
|
||||
1. Create a policy directory: `mkdir -p ~/.codex/policy`.
|
||||
2. Create one or more `.codexpolicy` files in that folder. Codex automatically loads every `.codexpolicy` file in there on startup.
|
||||
3. Write `prefix_rule` entries to describe the commands you want to allow, prompt, or block:
|
||||
|
||||
```starlark
|
||||
prefix_rule(
|
||||
pattern = ["git", ["push", "fetch"]],
|
||||
decision = "prompt", # allow | prompt | forbidden
|
||||
match = [["git", "push", "origin", "main"]], # examples that must match
|
||||
not_match = [["git", "status"]], # examples that must not match
|
||||
)
|
||||
```
|
||||
|
||||
- `pattern` is a list of shell tokens, evaluated from left to right; wrap tokens in a nested list to express alternatives (e.g., match both `push` and `fetch`).
|
||||
- `decision` sets the severity; Codex picks the strictest decision when multiple rules match (forbidden > prompt > allow).
|
||||
- `match` and `not_match` act as (optional) unit tests. Codex validates them when it loads your policy, so you get feedback if an example has unexpected behavior.
|
||||
|
||||
In this example rule, if Codex wants to run commands with the prefix `git push` or `git fetch`, it will first ask for user approval.
|
||||
|
||||
Use the `codex execpolicy check` subcommand to preview decisions before you save a rule (see the [`codex-execpolicy` README](./codex-rs/execpolicy/README.md) for syntax details):
|
||||
|
||||
```shell
|
||||
codex execpolicy check --policy ~/.codex/policy/default.codexpolicy git push origin main
|
||||
```
|
||||
|
||||
Pass multiple `--policy` flags to test how several files combine, and use `--pretty` for formatted JSON output. See the [`codex-rs/execpolicy` README](./codex-rs/execpolicy/README.md) for a more detailed walkthrough of the available syntax.
|
||||
|
||||
## Note: `execpolicy` commands are still in preview. The API may have breaking changes in the future.
|
||||
See the [Execpolicy quickstart](./docs/execpolicy.md) to set up rules that govern what commands Codex can execute.
|
||||
|
||||
### Docs & FAQ
|
||||
|
||||
@@ -114,6 +85,7 @@ Pass multiple `--policy` flags to test how several files combine, and use `--pre
|
||||
- [**Configuration**](./docs/config.md)
|
||||
- [Example config](./docs/example-config.md)
|
||||
- [**Sandbox & approvals**](./docs/sandbox.md)
|
||||
- [**Execpolicy quickstart**](./docs/execpolicy.md)
|
||||
- [**Authentication**](./docs/authentication.md)
|
||||
- [Auth methods](./docs/authentication.md#forcing-a-specific-auth-method-advanced)
|
||||
- [Login on a "Headless" machine](./docs/authentication.md#connecting-on-a-headless-machine)
|
||||
|
||||
6
codex-rs/.cargo/audit.toml
Normal file
6
codex-rs/.cargo/audit.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
[advisories]
|
||||
ignore = [
|
||||
"RUSTSEC-2024-0388", # derivative 2.2.0 via starlark; upstream crate is unmaintained
|
||||
"RUSTSEC-2025-0057", # fxhash 0.2.1 via starlark_map; upstream crate is unmaintained
|
||||
"RUSTSEC-2024-0436", # paste 1.0.15 via starlark/ratatui; upstream crate is unmaintained
|
||||
]
|
||||
26
codex-rs/.github/workflows/cargo-audit.yml
vendored
Normal file
26
codex-rs/.github/workflows/cargo-audit.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
name: Cargo audit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
audit:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: codex-rs
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- name: Install cargo-audit
|
||||
uses: taiki-e/install-action@v2
|
||||
with:
|
||||
tool: cargo-audit
|
||||
- name: Run cargo audit
|
||||
run: cargo audit --deny warnings
|
||||
109
codex-rs/Cargo.lock
generated
109
codex-rs/Cargo.lock
generated
@@ -212,7 +212,7 @@ dependencies = [
|
||||
"objc2-foundation",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
"wl-clipboard-rs",
|
||||
"x11rb",
|
||||
]
|
||||
@@ -843,6 +843,30 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-api"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_matches",
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"codex-client",
|
||||
"codex-protocol",
|
||||
"eventsource-stream",
|
||||
"futures",
|
||||
"http",
|
||||
"pretty_assertions",
|
||||
"regex-lite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-app-server"
|
||||
version = "0.0.0"
|
||||
@@ -870,6 +894,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serial_test",
|
||||
"sha2",
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
@@ -1028,6 +1053,23 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-client"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"eventsource-stream",
|
||||
"futures",
|
||||
"http",
|
||||
"rand 0.9.2",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-cloud-tasks"
|
||||
version = "0.0.0"
|
||||
@@ -1095,13 +1137,15 @@ dependencies = [
|
||||
"async-channel",
|
||||
"async-trait",
|
||||
"base64",
|
||||
"bytes",
|
||||
"chardetng",
|
||||
"chrono",
|
||||
"codex-api",
|
||||
"codex-app-server-protocol",
|
||||
"codex-apply-patch",
|
||||
"codex-arg0",
|
||||
"codex-async-utils",
|
||||
"codex-client",
|
||||
"codex-core",
|
||||
"codex-execpolicy",
|
||||
"codex-file-search",
|
||||
"codex-git",
|
||||
@@ -1637,7 +1681,7 @@ version = "0.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "codex-windows-sandbox"
|
||||
version = "0.1.0"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"codex-protocol",
|
||||
@@ -1646,6 +1690,7 @@ dependencies = [
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
@@ -1777,6 +1822,7 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"base64",
|
||||
"codex-core",
|
||||
"codex-protocol",
|
||||
"notify",
|
||||
@@ -2491,7 +2537,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"rustix 1.0.8",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3395,7 +3441,7 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3587,9 +3633,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.175"
|
||||
version = "0.2.177"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
|
||||
checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976"
|
||||
|
||||
[[package]]
|
||||
name = "libdbus-sys"
|
||||
@@ -4993,9 +5039,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.11.1"
|
||||
version = "1.12.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
|
||||
checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
@@ -5005,9 +5051,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.9"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
|
||||
checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
@@ -5097,10 +5143,10 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rmcp"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5947688160b56fb6c827e3c20a72c90392a1d7e9dec74749197aa1780ac42ca"
|
||||
version = "0.9.0"
|
||||
source = "git+https://github.com/bolinfest/rust-sdk?branch=pr556#4d9cc16f4c76c84486344f542ed9a3e9364019ba"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -5131,9 +5177,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rmcp-macros"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "01263441d3f8635c628e33856c468b96ebbce1af2d3699ea712ca71432d4ee7a"
|
||||
version = "0.9.0"
|
||||
source = "git+https://github.com/bolinfest/rust-sdk?branch=pr556#4d9cc16f4c76c84486344f542ed9a3e9364019ba"
|
||||
dependencies = [
|
||||
"darling 0.21.3",
|
||||
"proc-macro2",
|
||||
@@ -5173,7 +5218,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.4.15",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6553,18 +6598,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml_datetime"
|
||||
version = "0.7.0"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bade1c3e902f58d73d3f294cd7f20391c1cb2fbcb643b73566bc773971df91e3"
|
||||
checksum = "f2cdb639ebbc97961c51720f858597f7f24c4fc295327923af55b74c3c724533"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_edit"
|
||||
version = "0.23.4"
|
||||
version = "0.23.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7211ff1b8f0d3adae1663b7da9ffe396eabe1ca25f0b0bee42b0da29a9ddce93"
|
||||
checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d"
|
||||
dependencies = [
|
||||
"indexmap 2.12.0",
|
||||
"toml_datetime",
|
||||
@@ -6575,18 +6620,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml_parser"
|
||||
version = "1.0.2"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b551886f449aa90d4fe2bdaa9f4a2577ad2dde302c61ecf262d80b116db95c10"
|
||||
checksum = "c0cbe268d35bdb4bb5a56a2de88d0ad0eb70af5384a99d648cd4b3d04039800e"
|
||||
dependencies = [
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_writer"
|
||||
version = "1.0.2"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcc842091f2def52017664b53082ecbbeb5c7731092bad69d2c63050401dfd64"
|
||||
checksum = "df8b2b54733674ad286d16267dcfc7a71ed5c776e4ac7aa3c3e2561f7c637bf2"
|
||||
|
||||
[[package]]
|
||||
name = "tonic"
|
||||
@@ -7256,9 +7301,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "webbrowser"
|
||||
version = "1.0.5"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aaf4f3c0ba838e82b4e5ccc4157003fb8c324ee24c058470ffb82820becbde98"
|
||||
checksum = "00f1243ef785213e3a32fa0396093424a3a6ea566f9948497e5a2309261a4c97"
|
||||
dependencies = [
|
||||
"core-foundation 0.10.1",
|
||||
"jni",
|
||||
@@ -7325,7 +7370,7 @@ version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7751,9 +7796,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.7.12"
|
||||
version = "0.7.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95"
|
||||
checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
@@ -41,6 +41,8 @@ members = [
|
||||
"utils/pty",
|
||||
"utils/readiness",
|
||||
"utils/string",
|
||||
"codex-client",
|
||||
"codex-api",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
@@ -51,6 +53,7 @@ version = "0.0.0"
|
||||
# crates created with `cargo new -w ...` automatically inherit the 2024
|
||||
# edition.
|
||||
edition = "2024"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
# Internal
|
||||
@@ -62,6 +65,8 @@ codex-apply-patch = { path = "apply-patch" }
|
||||
codex-arg0 = { path = "arg0" }
|
||||
codex-async-utils = { path = "async-utils" }
|
||||
codex-backend-client = { path = "backend-client" }
|
||||
codex-api = { path = "codex-api" }
|
||||
codex-client = { path = "codex-client" }
|
||||
codex-chatgpt = { path = "chatgpt" }
|
||||
codex-common = { path = "common" }
|
||||
codex-core = { path = "core" }
|
||||
@@ -108,8 +113,8 @@ async-trait = "0.1.89"
|
||||
axum = { version = "0.8", default-features = false }
|
||||
base64 = "0.22.1"
|
||||
bytes = "1.10.1"
|
||||
chrono = "0.4.42"
|
||||
chardetng = "0.1.17"
|
||||
chrono = "0.4.42"
|
||||
clap = "4"
|
||||
clap_complete = "4"
|
||||
color-eyre = "0.6.3"
|
||||
@@ -120,9 +125,9 @@ diffy = "0.4.2"
|
||||
dirs = "6"
|
||||
dotenvy = "0.15.7"
|
||||
dunce = "1.0.4"
|
||||
encoding_rs = "0.8.35"
|
||||
env-flags = "0.1.1"
|
||||
env_logger = "0.11.5"
|
||||
encoding_rs = "0.8.35"
|
||||
escargot = "0.5"
|
||||
eventsource-stream = "0.2.3"
|
||||
futures = { version = "0.3", default-features = false }
|
||||
@@ -138,7 +143,7 @@ itertools = "0.14.0"
|
||||
keyring = { version = "3.6", default-features = false }
|
||||
landlock = "0.4.1"
|
||||
lazy_static = "1"
|
||||
libc = "0.2.175"
|
||||
libc = "0.2.177"
|
||||
log = "0.4"
|
||||
lru = "0.12.5"
|
||||
maplit = "1.0.2"
|
||||
@@ -165,11 +170,12 @@ rand = "0.9"
|
||||
ratatui = "0.29.0"
|
||||
ratatui-macros = "0.6.0"
|
||||
regex-lite = "0.1.7"
|
||||
regex = "1.11.1"
|
||||
regex = "1.12.2"
|
||||
reqwest = "0.12"
|
||||
rmcp = { version = "0.8.5", default-features = false }
|
||||
rmcp = { version = "0.9.0", default-features = false }
|
||||
schemars = "0.8.22"
|
||||
seccompiler = "0.5.0"
|
||||
sentry = "0.34.0"
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
serde_with = "3.14"
|
||||
@@ -195,7 +201,7 @@ tokio-stream = "0.1.17"
|
||||
tokio-test = "0.4"
|
||||
tokio-util = "0.7.16"
|
||||
toml = "0.9.5"
|
||||
toml_edit = "0.23.4"
|
||||
toml_edit = "0.23.5"
|
||||
tonic = "0.13.1"
|
||||
tracing = "0.1.41"
|
||||
tracing-appender = "0.2.3"
|
||||
@@ -261,11 +267,7 @@ unwrap_used = "deny"
|
||||
# cargo-shear cannot see the platform-specific openssl-sys usage, so we
|
||||
# silence the false positive here instead of deleting a real dependency.
|
||||
[workspace.metadata.cargo-shear]
|
||||
ignored = [
|
||||
"icu_provider",
|
||||
"openssl-sys",
|
||||
"codex-utils-readiness",
|
||||
]
|
||||
ignored = ["icu_provider", "openssl-sys", "codex-utils-readiness"]
|
||||
|
||||
[profile.release]
|
||||
lto = "fat"
|
||||
@@ -286,6 +288,7 @@ opt-level = 0
|
||||
# ratatui = { path = "../../ratatui" }
|
||||
crossterm = { git = "https://github.com/nornagon/crossterm", branch = "nornagon/color-query" }
|
||||
ratatui = { git = "https://github.com/nornagon/ratatui", branch = "nornagon-v0.29.0-patch" }
|
||||
rmcp = { git = "https://github.com/bolinfest/rust-sdk", branch = "pr556" }
|
||||
|
||||
# Uncomment to debug local changes.
|
||||
# rmcp = { path = "../../rust-sdk/crates/rmcp" }
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-ansi-escape"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "codex_ansi_escape"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-app-server-protocol"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "codex_app_server_protocol"
|
||||
|
||||
@@ -164,6 +164,19 @@ client_request_definitions! {
|
||||
response: v2::FeedbackUploadResponse,
|
||||
},
|
||||
|
||||
ConfigRead => "config/read" {
|
||||
params: v2::ConfigReadParams,
|
||||
response: v2::ConfigReadResponse,
|
||||
},
|
||||
ConfigValueWrite => "config/value/write" {
|
||||
params: v2::ConfigValueWriteParams,
|
||||
response: v2::ConfigWriteResponse,
|
||||
},
|
||||
ConfigBatchWrite => "config/batchWrite" {
|
||||
params: v2::ConfigBatchWriteParams,
|
||||
response: v2::ConfigWriteResponse,
|
||||
},
|
||||
|
||||
GetAccount => "account/read" {
|
||||
params: v2::GetAccountParams,
|
||||
response: v2::GetAccountResponse,
|
||||
@@ -489,8 +502,10 @@ server_notification_definitions! {
|
||||
/// NEW NOTIFICATIONS
|
||||
Error => "error" (v2::ErrorNotification),
|
||||
ThreadStarted => "thread/started" (v2::ThreadStartedNotification),
|
||||
ThreadTokenUsageUpdated => "thread/tokenUsage/updated" (v2::ThreadTokenUsageUpdatedNotification),
|
||||
TurnStarted => "turn/started" (v2::TurnStartedNotification),
|
||||
TurnCompleted => "turn/completed" (v2::TurnCompletedNotification),
|
||||
TurnDiffUpdated => "turn/diff/updated" (v2::TurnDiffUpdatedNotification),
|
||||
ItemStarted => "item/started" (v2::ItemStartedNotification),
|
||||
ItemCompleted => "item/completed" (v2::ItemCompletedNotification),
|
||||
AgentMessageDelta => "item/agentMessage/delta" (v2::AgentMessageDeltaNotification),
|
||||
@@ -501,6 +516,7 @@ server_notification_definitions! {
|
||||
ReasoningSummaryTextDelta => "item/reasoning/summaryTextDelta" (v2::ReasoningSummaryTextDeltaNotification),
|
||||
ReasoningSummaryPartAdded => "item/reasoning/summaryPartAdded" (v2::ReasoningSummaryPartAddedNotification),
|
||||
ReasoningTextDelta => "item/reasoning/textDelta" (v2::ReasoningTextDeltaNotification),
|
||||
ContextCompacted => "thread/compacted" (v2::ContextCompactedNotification),
|
||||
|
||||
/// Notifies the user of world-writable directories on Windows, which cannot be protected by the sandbox.
|
||||
WindowsWorldWritableWarning => "windows/worldWritableWarning" (v2::WindowsWorldWritableWarningNotification),
|
||||
|
||||
@@ -16,6 +16,8 @@ use codex_protocol::protocol::CreditsSnapshot as CoreCreditsSnapshot;
|
||||
use codex_protocol::protocol::RateLimitSnapshot as CoreRateLimitSnapshot;
|
||||
use codex_protocol::protocol::RateLimitWindow as CoreRateLimitWindow;
|
||||
use codex_protocol::protocol::SessionSource as CoreSessionSource;
|
||||
use codex_protocol::protocol::TokenUsage as CoreTokenUsage;
|
||||
use codex_protocol::protocol::TokenUsageInfo as CoreTokenUsageInfo;
|
||||
use codex_protocol::user_input::UserInput as CoreUserInput;
|
||||
use mcp_types::ContentBlock as McpContentBlock;
|
||||
use schemars::JsonSchema;
|
||||
@@ -128,6 +130,127 @@ v2_enum_from_core!(
|
||||
}
|
||||
);
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub enum ConfigLayerName {
|
||||
Mdm,
|
||||
System,
|
||||
SessionFlags,
|
||||
User,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ConfigLayerMetadata {
|
||||
pub name: ConfigLayerName,
|
||||
pub source: String,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ConfigLayer {
|
||||
pub name: ConfigLayerName,
|
||||
pub source: String,
|
||||
pub version: String,
|
||||
pub config: JsonValue,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub enum MergeStrategy {
|
||||
Replace,
|
||||
Upsert,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub enum WriteStatus {
|
||||
Ok,
|
||||
OkOverridden,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct OverriddenMetadata {
|
||||
pub message: String,
|
||||
pub overriding_layer: ConfigLayerMetadata,
|
||||
pub effective_value: JsonValue,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ConfigWriteResponse {
|
||||
pub status: WriteStatus,
|
||||
pub version: String,
|
||||
pub overridden_metadata: Option<OverriddenMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub enum ConfigWriteErrorCode {
|
||||
ConfigLayerReadonly,
|
||||
ConfigVersionConflict,
|
||||
ConfigValidationError,
|
||||
ConfigPathNotFound,
|
||||
ConfigSchemaUnknownKey,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ConfigReadParams {
|
||||
#[serde(default)]
|
||||
pub include_layers: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ConfigReadResponse {
|
||||
pub config: JsonValue,
|
||||
pub origins: HashMap<String, ConfigLayerMetadata>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub layers: Option<Vec<ConfigLayer>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ConfigValueWriteParams {
|
||||
pub file_path: String,
|
||||
pub key_path: String,
|
||||
pub value: JsonValue,
|
||||
pub merge_strategy: MergeStrategy,
|
||||
pub expected_version: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ConfigBatchWriteParams {
|
||||
pub file_path: String,
|
||||
pub edits: Vec<ConfigEdit>,
|
||||
pub expected_version: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ConfigEdit {
|
||||
pub key_path: String,
|
||||
pub value: JsonValue,
|
||||
pub merge_strategy: MergeStrategy,
|
||||
}
|
||||
|
||||
v2_enum_from_core!(
|
||||
pub enum CommandRiskLevel from codex_protocol::approvals::SandboxRiskLevel {
|
||||
Low,
|
||||
@@ -659,6 +782,63 @@ pub struct AccountUpdatedNotification {
|
||||
pub auth_mode: Option<AuthMode>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadTokenUsageUpdatedNotification {
|
||||
pub thread_id: String,
|
||||
pub turn_id: String,
|
||||
pub token_usage: ThreadTokenUsage,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadTokenUsage {
|
||||
pub total: TokenUsageBreakdown,
|
||||
pub last: TokenUsageBreakdown,
|
||||
#[ts(type = "number | null")]
|
||||
pub model_context_window: Option<i64>,
|
||||
}
|
||||
|
||||
impl From<CoreTokenUsageInfo> for ThreadTokenUsage {
|
||||
fn from(value: CoreTokenUsageInfo) -> Self {
|
||||
Self {
|
||||
total: value.total_token_usage.into(),
|
||||
last: value.last_token_usage.into(),
|
||||
model_context_window: value.model_context_window,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct TokenUsageBreakdown {
|
||||
#[ts(type = "number")]
|
||||
pub total_tokens: i64,
|
||||
#[ts(type = "number")]
|
||||
pub input_tokens: i64,
|
||||
#[ts(type = "number")]
|
||||
pub cached_input_tokens: i64,
|
||||
#[ts(type = "number")]
|
||||
pub output_tokens: i64,
|
||||
#[ts(type = "number")]
|
||||
pub reasoning_output_tokens: i64,
|
||||
}
|
||||
|
||||
impl From<CoreTokenUsage> for TokenUsageBreakdown {
|
||||
fn from(value: CoreTokenUsage) -> Self {
|
||||
Self {
|
||||
total_tokens: value.total_tokens,
|
||||
input_tokens: value.input_tokens,
|
||||
cached_input_tokens: value.cached_input_tokens,
|
||||
output_tokens: value.output_tokens,
|
||||
reasoning_output_tokens: value.reasoning_output_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
@@ -686,6 +866,8 @@ pub struct TurnError {
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ErrorNotification {
|
||||
pub error: TurnError,
|
||||
pub thread_id: String,
|
||||
pub turn_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
@@ -838,6 +1020,8 @@ pub enum ThreadItem {
|
||||
command: String,
|
||||
/// The command's working directory.
|
||||
cwd: PathBuf,
|
||||
/// Identifier for the underlying PTY process (when available).
|
||||
process_id: Option<String>,
|
||||
status: CommandExecutionStatus,
|
||||
/// A best-effort parsing of the command to understand the action(s) it will perform.
|
||||
/// This returns a list of CommandAction objects because a single shell command may
|
||||
@@ -997,6 +1181,7 @@ pub struct ThreadStartedNotification {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct TurnStartedNotification {
|
||||
pub thread_id: String,
|
||||
pub turn: Turn,
|
||||
}
|
||||
|
||||
@@ -1013,14 +1198,27 @@ pub struct Usage {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct TurnCompletedNotification {
|
||||
pub thread_id: String,
|
||||
pub turn: Turn,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
/// Notification that the turn-level unified diff has changed.
|
||||
/// Contains the latest aggregated diff across all file changes in the turn.
|
||||
pub struct TurnDiffUpdatedNotification {
|
||||
pub turn_id: String,
|
||||
pub diff: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ItemStartedNotification {
|
||||
pub item: ThreadItem,
|
||||
pub thread_id: String,
|
||||
pub turn_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
@@ -1028,6 +1226,8 @@ pub struct ItemStartedNotification {
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ItemCompletedNotification {
|
||||
pub item: ThreadItem,
|
||||
pub thread_id: String,
|
||||
pub turn_id: String,
|
||||
}
|
||||
|
||||
// Item-specific progress notifications
|
||||
@@ -1090,6 +1290,14 @@ pub struct WindowsWorldWritableWarningNotification {
|
||||
pub failed_scan: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ContextCompactedNotification {
|
||||
pub thread_id: String,
|
||||
pub turn_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
name = "codex-app-server-test-client"
|
||||
version = { workspace = true }
|
||||
edition = "2024"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -101,6 +101,15 @@ enum CliCommand {
|
||||
/// Start a V2 turn that should not elicit an ExecCommand approval.
|
||||
#[command(name = "no-trigger-cmd-approval")]
|
||||
NoTriggerCmdApproval,
|
||||
/// Send two sequential V2 turns in the same thread to test follow-up behavior.
|
||||
SendFollowUpV2 {
|
||||
/// Initial user message for the first turn.
|
||||
#[arg()]
|
||||
first_message: String,
|
||||
/// Follow-up user message for the second turn.
|
||||
#[arg()]
|
||||
follow_up_message: String,
|
||||
},
|
||||
/// Trigger the ChatGPT login flow and wait for completion.
|
||||
TestLogin,
|
||||
/// Fetch the current account rate limits from the Codex app-server.
|
||||
@@ -120,6 +129,10 @@ fn main() -> Result<()> {
|
||||
trigger_patch_approval(codex_bin, user_message)
|
||||
}
|
||||
CliCommand::NoTriggerCmdApproval => no_trigger_cmd_approval(codex_bin),
|
||||
CliCommand::SendFollowUpV2 {
|
||||
first_message,
|
||||
follow_up_message,
|
||||
} => send_follow_up_v2(codex_bin, first_message, follow_up_message),
|
||||
CliCommand::TestLogin => test_login(codex_bin),
|
||||
CliCommand::GetAccountRateLimits => get_account_rate_limits(codex_bin),
|
||||
}
|
||||
@@ -209,6 +222,44 @@ fn send_message_v2_with_policies(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_follow_up_v2(
|
||||
codex_bin: String,
|
||||
first_message: String,
|
||||
follow_up_message: String,
|
||||
) -> Result<()> {
|
||||
let mut client = CodexClient::spawn(codex_bin)?;
|
||||
|
||||
let initialize = client.initialize()?;
|
||||
println!("< initialize response: {initialize:?}");
|
||||
|
||||
let thread_response = client.thread_start(ThreadStartParams::default())?;
|
||||
println!("< thread/start response: {thread_response:?}");
|
||||
|
||||
let first_turn_params = TurnStartParams {
|
||||
thread_id: thread_response.thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: first_message,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
let first_turn_response = client.turn_start(first_turn_params)?;
|
||||
println!("< turn/start response (initial): {first_turn_response:?}");
|
||||
client.stream_turn(&thread_response.thread.id, &first_turn_response.turn.id)?;
|
||||
|
||||
let follow_up_params = TurnStartParams {
|
||||
thread_id: thread_response.thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: follow_up_message,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
let follow_up_response = client.turn_start(follow_up_params)?;
|
||||
println!("< turn/start response (follow-up): {follow_up_response:?}");
|
||||
client.stream_turn(&thread_response.thread.id, &follow_up_response.turn.id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_login(codex_bin: String) -> Result<()> {
|
||||
let mut client = CodexClient::spawn(codex_bin)?;
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-app-server"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "codex-app-server"
|
||||
@@ -29,6 +30,9 @@ codex-utils-json-to-toml = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
@@ -50,7 +54,5 @@ mcp-types = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
serial_test = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
|
||||
@@ -240,7 +240,7 @@ Event notifications are the server-initiated event stream for thread lifecycles,
|
||||
|
||||
### Turn events
|
||||
|
||||
The app-server streams JSON-RPC notifications while a turn is running. Each turn starts with `turn/started` (initial `turn`) and ends with `turn/completed` (final `turn` plus token `usage`), and clients subscribe to the events they care about, rendering each item incrementally as updates arrive. The per-item lifecycle is always: `item/started` → zero or more item-specific deltas → `item/completed`.
|
||||
The app-server streams JSON-RPC notifications while a turn is running. Each turn starts with `turn/started` (initial `turn`) and ends with `turn/completed` (final `turn` status). Token usage events stream separately via `thread/tokenUsage/updated`. Clients subscribe to the events they care about, rendering each item incrementally as updates arrive. The per-item lifecycle is always: `item/started` → zero or more item-specific deltas → `item/completed`.
|
||||
|
||||
- `turn/started` — `{ turn }` with the turn id, empty `items`, and `status: "inProgress"`.
|
||||
- `turn/completed` — `{ turn }` where `turn.status` is `completed`, `interrupted`, or `failed`; failures carry `{ error: { message, codexErrorInfo? } }`.
|
||||
@@ -257,6 +257,7 @@ Today both notifications carry an empty `items` array even when item events were
|
||||
- `fileChange` — `{id, changes, status}` describing proposed edits; `changes` list `{path, kind, diff}` and `status` is `inProgress`, `completed`, `failed`, or `declined`.
|
||||
- `mcpToolCall` — `{id, server, tool, status, arguments, result?, error?}` describing MCP calls; `status` is `inProgress`, `completed`, or `failed`.
|
||||
- `webSearch` — `{id, query}` for a web search request issued by the agent.
|
||||
- `compacted` - `{threadId, turnId}` when codex compacts the conversation history. This can happen automatically.
|
||||
|
||||
All items emit two shared lifecycle events:
|
||||
- `item/started` — emits the full `item` when a new unit of work begins so the UI can render it immediately; the `item.id` in this payload matches the `itemId` used by deltas.
|
||||
|
||||
@@ -14,6 +14,7 @@ use codex_app_server_protocol::CommandExecutionOutputDeltaNotification;
|
||||
use codex_app_server_protocol::CommandExecutionRequestApprovalParams;
|
||||
use codex_app_server_protocol::CommandExecutionRequestApprovalResponse;
|
||||
use codex_app_server_protocol::CommandExecutionStatus;
|
||||
use codex_app_server_protocol::ContextCompactedNotification;
|
||||
use codex_app_server_protocol::ErrorNotification;
|
||||
use codex_app_server_protocol::ExecCommandApprovalParams;
|
||||
use codex_app_server_protocol::ExecCommandApprovalResponse;
|
||||
@@ -35,8 +36,11 @@ use codex_app_server_protocol::SandboxCommandAssessment as V2SandboxCommandAsses
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ServerRequestPayload;
|
||||
use codex_app_server_protocol::ThreadItem;
|
||||
use codex_app_server_protocol::ThreadTokenUsage;
|
||||
use codex_app_server_protocol::ThreadTokenUsageUpdatedNotification;
|
||||
use codex_app_server_protocol::Turn;
|
||||
use codex_app_server_protocol::TurnCompletedNotification;
|
||||
use codex_app_server_protocol::TurnDiffUpdatedNotification;
|
||||
use codex_app_server_protocol::TurnError;
|
||||
use codex_app_server_protocol::TurnInterruptResponse;
|
||||
use codex_app_server_protocol::TurnStatus;
|
||||
@@ -52,6 +56,8 @@ use codex_core::protocol::McpToolCallBeginEvent;
|
||||
use codex_core::protocol::McpToolCallEndEvent;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_core::protocol::TokenCountEvent;
|
||||
use codex_core::protocol::TurnDiffEvent;
|
||||
use codex_core::review_format::format_review_findings_block;
|
||||
use codex_protocol::ConversationId;
|
||||
use codex_protocol::protocol::ReviewOutputEvent;
|
||||
@@ -73,10 +79,19 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
api_version: ApiVersion,
|
||||
) {
|
||||
let Event { id: event_id, msg } = event;
|
||||
let Event {
|
||||
id: event_turn_id,
|
||||
msg,
|
||||
} = event;
|
||||
match msg {
|
||||
EventMsg::TaskComplete(_ev) => {
|
||||
handle_turn_complete(conversation_id, event_id, &outgoing, &turn_summary_store).await;
|
||||
handle_turn_complete(
|
||||
conversation_id,
|
||||
event_turn_id,
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
@@ -97,7 +112,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.send_request(ServerRequestPayload::ApplyPatchApproval(params))
|
||||
.await;
|
||||
tokio::spawn(async move {
|
||||
on_patch_approval_response(event_id, rx, conversation).await;
|
||||
on_patch_approval_response(event_turn_id, rx, conversation).await;
|
||||
});
|
||||
}
|
||||
ApiVersion::V2 => {
|
||||
@@ -117,7 +132,11 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
changes: patch_changes.clone(),
|
||||
status: PatchApplyStatus::InProgress,
|
||||
};
|
||||
let notification = ItemStartedNotification { item };
|
||||
let notification = ItemStartedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
item,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemStarted(notification))
|
||||
.await;
|
||||
@@ -135,7 +154,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.await;
|
||||
tokio::spawn(async move {
|
||||
on_file_change_request_approval_response(
|
||||
event_id,
|
||||
event_turn_id,
|
||||
conversation_id,
|
||||
item_id,
|
||||
patch_changes,
|
||||
@@ -171,7 +190,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.send_request(ServerRequestPayload::ExecCommandApproval(params))
|
||||
.await;
|
||||
tokio::spawn(async move {
|
||||
on_exec_approval_response(event_id, rx, conversation).await;
|
||||
on_exec_approval_response(event_turn_id, rx, conversation).await;
|
||||
});
|
||||
}
|
||||
ApiVersion::V2 => {
|
||||
@@ -199,7 +218,8 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.await;
|
||||
tokio::spawn(async move {
|
||||
on_command_execution_request_approval_response(
|
||||
event_id,
|
||||
event_turn_id,
|
||||
conversation_id,
|
||||
item_id,
|
||||
command_string,
|
||||
cwd,
|
||||
@@ -214,13 +234,23 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
},
|
||||
// TODO(celia): properly construct McpToolCall TurnItem in core.
|
||||
EventMsg::McpToolCallBegin(begin_event) => {
|
||||
let notification = construct_mcp_tool_call_notification(begin_event).await;
|
||||
let notification = construct_mcp_tool_call_notification(
|
||||
begin_event,
|
||||
conversation_id.to_string(),
|
||||
event_turn_id.clone(),
|
||||
)
|
||||
.await;
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemStarted(notification))
|
||||
.await;
|
||||
}
|
||||
EventMsg::McpToolCallEnd(end_event) => {
|
||||
let notification = construct_mcp_tool_call_end_notification(end_event).await;
|
||||
let notification = construct_mcp_tool_call_end_notification(
|
||||
end_event,
|
||||
conversation_id.to_string(),
|
||||
event_turn_id.clone(),
|
||||
)
|
||||
.await;
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemCompleted(notification))
|
||||
.await;
|
||||
@@ -234,6 +264,15 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.send_server_notification(ServerNotification::AgentMessageDelta(notification))
|
||||
.await;
|
||||
}
|
||||
EventMsg::ContextCompacted(..) => {
|
||||
let notification = ContextCompactedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ContextCompacted(notification))
|
||||
.await;
|
||||
}
|
||||
EventMsg::ReasoningContentDelta(event) => {
|
||||
let notification = ReasoningSummaryTextDeltaNotification {
|
||||
item_id: event.item_id,
|
||||
@@ -268,15 +307,8 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.await;
|
||||
}
|
||||
EventMsg::TokenCount(token_count_event) => {
|
||||
if let Some(rate_limits) = token_count_event.rate_limits {
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::AccountRateLimitsUpdated(
|
||||
AccountRateLimitsUpdatedNotification {
|
||||
rate_limits: rate_limits.into(),
|
||||
},
|
||||
))
|
||||
.await;
|
||||
}
|
||||
handle_token_count_event(conversation_id, event_turn_id, token_count_event, &outgoing)
|
||||
.await;
|
||||
}
|
||||
EventMsg::Error(ev) => {
|
||||
let turn_error = TurnError {
|
||||
@@ -287,6 +319,8 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::Error(ErrorNotification {
|
||||
error: turn_error,
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
@@ -300,13 +334,17 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::Error(ErrorNotification {
|
||||
error: turn_error,
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
EventMsg::EnteredReviewMode(review_request) => {
|
||||
let notification = ItemStartedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
item: ThreadItem::CodeReview {
|
||||
id: event_id.clone(),
|
||||
id: event_turn_id.clone(),
|
||||
review: review_request.user_facing_hint,
|
||||
},
|
||||
};
|
||||
@@ -316,14 +354,22 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
}
|
||||
EventMsg::ItemStarted(item_started_event) => {
|
||||
let item: ThreadItem = item_started_event.item.clone().into();
|
||||
let notification = ItemStartedNotification { item };
|
||||
let notification = ItemStartedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
item,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemStarted(notification))
|
||||
.await;
|
||||
}
|
||||
EventMsg::ItemCompleted(item_completed_event) => {
|
||||
let item: ThreadItem = item_completed_event.item.clone().into();
|
||||
let notification = ItemCompletedNotification { item };
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
item,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemCompleted(notification))
|
||||
.await;
|
||||
@@ -333,9 +379,12 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
Some(output) => render_review_output_text(&output),
|
||||
None => REVIEW_FALLBACK_MESSAGE.to_string(),
|
||||
};
|
||||
let review_item_id = event_turn_id.clone();
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
item: ThreadItem::CodeReview {
|
||||
id: event_id,
|
||||
id: review_item_id,
|
||||
review: review_text,
|
||||
},
|
||||
};
|
||||
@@ -359,7 +408,11 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
changes: convert_patch_changes(&patch_begin_event.changes),
|
||||
status: PatchApplyStatus::InProgress,
|
||||
};
|
||||
let notification = ItemStartedNotification { item };
|
||||
let notification = ItemStartedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
item,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemStarted(notification))
|
||||
.await;
|
||||
@@ -381,6 +434,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
item_id,
|
||||
changes,
|
||||
status,
|
||||
event_turn_id.clone(),
|
||||
outgoing.as_ref(),
|
||||
&turn_summary_store,
|
||||
)
|
||||
@@ -395,18 +449,24 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.collect::<Vec<_>>();
|
||||
let command = shlex_join(&exec_command_begin_event.command);
|
||||
let cwd = exec_command_begin_event.cwd;
|
||||
let process_id = exec_command_begin_event.process_id;
|
||||
|
||||
let item = ThreadItem::CommandExecution {
|
||||
id: item_id,
|
||||
command,
|
||||
cwd,
|
||||
process_id,
|
||||
status: CommandExecutionStatus::InProgress,
|
||||
command_actions,
|
||||
aggregated_output: None,
|
||||
exit_code: None,
|
||||
duration_ms: None,
|
||||
};
|
||||
let notification = ItemStartedNotification { item };
|
||||
let notification = ItemStartedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
item,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemStarted(notification))
|
||||
.await;
|
||||
@@ -428,6 +488,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
command,
|
||||
cwd,
|
||||
parsed_cmd,
|
||||
process_id,
|
||||
aggregated_output,
|
||||
exit_code,
|
||||
duration,
|
||||
@@ -456,6 +517,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
id: call_id,
|
||||
command: shlex_join(&command),
|
||||
cwd,
|
||||
process_id,
|
||||
status,
|
||||
command_actions,
|
||||
aggregated_output,
|
||||
@@ -463,7 +525,11 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
duration_ms: Some(duration_ms),
|
||||
};
|
||||
|
||||
let notification = ItemCompletedNotification { item };
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
item,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemCompleted(notification))
|
||||
.await;
|
||||
@@ -491,22 +557,55 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
}
|
||||
}
|
||||
|
||||
handle_turn_interrupted(conversation_id, event_id, &outgoing, &turn_summary_store)
|
||||
.await;
|
||||
handle_turn_interrupted(
|
||||
conversation_id,
|
||||
event_turn_id,
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
EventMsg::TurnDiff(turn_diff_event) => {
|
||||
handle_turn_diff(
|
||||
&event_turn_id,
|
||||
turn_diff_event,
|
||||
api_version,
|
||||
outgoing.as_ref(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_turn_diff(
|
||||
event_turn_id: &str,
|
||||
turn_diff_event: TurnDiffEvent,
|
||||
api_version: ApiVersion,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
) {
|
||||
if let ApiVersion::V2 = api_version {
|
||||
let notification = TurnDiffUpdatedNotification {
|
||||
turn_id: event_turn_id.to_string(),
|
||||
diff: turn_diff_event.unified_diff,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::TurnDiffUpdated(notification))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn emit_turn_completed_with_status(
|
||||
event_id: String,
|
||||
conversation_id: ConversationId,
|
||||
event_turn_id: String,
|
||||
status: TurnStatus,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
) {
|
||||
let notification = TurnCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn: Turn {
|
||||
id: event_id,
|
||||
id: event_turn_id,
|
||||
items: vec![],
|
||||
status,
|
||||
},
|
||||
@@ -521,6 +620,7 @@ async fn complete_file_change_item(
|
||||
item_id: String,
|
||||
changes: Vec<FileUpdateChange>,
|
||||
status: PatchApplyStatus,
|
||||
turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
) {
|
||||
@@ -536,16 +636,24 @@ async fn complete_file_change_item(
|
||||
changes,
|
||||
status,
|
||||
};
|
||||
let notification = ItemCompletedNotification { item };
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id,
|
||||
item,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemCompleted(notification))
|
||||
.await;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn complete_command_execution_item(
|
||||
conversation_id: ConversationId,
|
||||
turn_id: String,
|
||||
item_id: String,
|
||||
command: String,
|
||||
cwd: PathBuf,
|
||||
process_id: Option<String>,
|
||||
command_actions: Vec<V2ParsedCommand>,
|
||||
status: CommandExecutionStatus,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
@@ -554,13 +662,18 @@ async fn complete_command_execution_item(
|
||||
id: item_id,
|
||||
command,
|
||||
cwd,
|
||||
process_id,
|
||||
status,
|
||||
command_actions,
|
||||
aggregated_output: None,
|
||||
exit_code: None,
|
||||
duration_ms: None,
|
||||
};
|
||||
let notification = ItemCompletedNotification { item };
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id,
|
||||
item,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ItemCompleted(notification))
|
||||
.await;
|
||||
@@ -576,7 +689,7 @@ async fn find_and_remove_turn_summary(
|
||||
|
||||
async fn handle_turn_complete(
|
||||
conversation_id: ConversationId,
|
||||
event_id: String,
|
||||
event_turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
) {
|
||||
@@ -588,18 +701,52 @@ async fn handle_turn_complete(
|
||||
TurnStatus::Completed
|
||||
};
|
||||
|
||||
emit_turn_completed_with_status(event_id, status, outgoing).await;
|
||||
emit_turn_completed_with_status(conversation_id, event_turn_id, status, outgoing).await;
|
||||
}
|
||||
|
||||
async fn handle_turn_interrupted(
|
||||
conversation_id: ConversationId,
|
||||
event_id: String,
|
||||
event_turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
) {
|
||||
find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
|
||||
|
||||
emit_turn_completed_with_status(event_id, TurnStatus::Interrupted, outgoing).await;
|
||||
emit_turn_completed_with_status(
|
||||
conversation_id,
|
||||
event_turn_id,
|
||||
TurnStatus::Interrupted,
|
||||
outgoing,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn handle_token_count_event(
|
||||
conversation_id: ConversationId,
|
||||
turn_id: String,
|
||||
token_count_event: TokenCountEvent,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
) {
|
||||
let TokenCountEvent { info, rate_limits } = token_count_event;
|
||||
if let Some(token_usage) = info.map(ThreadTokenUsage::from) {
|
||||
let notification = ThreadTokenUsageUpdatedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id,
|
||||
token_usage,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::ThreadTokenUsageUpdated(notification))
|
||||
.await;
|
||||
}
|
||||
if let Some(rate_limits) = rate_limits {
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::AccountRateLimitsUpdated(
|
||||
AccountRateLimitsUpdatedNotification {
|
||||
rate_limits: rate_limits.into(),
|
||||
},
|
||||
))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_error(
|
||||
@@ -612,7 +759,7 @@ async fn handle_error(
|
||||
}
|
||||
|
||||
async fn on_patch_approval_response(
|
||||
event_id: String,
|
||||
event_turn_id: String,
|
||||
receiver: oneshot::Receiver<JsonValue>,
|
||||
codex: Arc<CodexConversation>,
|
||||
) {
|
||||
@@ -623,7 +770,7 @@ async fn on_patch_approval_response(
|
||||
error!("request failed: {err:?}");
|
||||
if let Err(submit_err) = codex
|
||||
.submit(Op::PatchApproval {
|
||||
id: event_id.clone(),
|
||||
id: event_turn_id.clone(),
|
||||
decision: ReviewDecision::Denied,
|
||||
})
|
||||
.await
|
||||
@@ -644,7 +791,7 @@ async fn on_patch_approval_response(
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::PatchApproval {
|
||||
id: event_id,
|
||||
id: event_turn_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
@@ -654,7 +801,7 @@ async fn on_patch_approval_response(
|
||||
}
|
||||
|
||||
async fn on_exec_approval_response(
|
||||
event_id: String,
|
||||
event_turn_id: String,
|
||||
receiver: oneshot::Receiver<JsonValue>,
|
||||
conversation: Arc<CodexConversation>,
|
||||
) {
|
||||
@@ -680,7 +827,7 @@ async fn on_exec_approval_response(
|
||||
|
||||
if let Err(err) = conversation
|
||||
.submit(Op::ExecApproval {
|
||||
id: event_id,
|
||||
id: event_turn_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
@@ -753,7 +900,7 @@ fn format_file_change_diff(change: &CoreFileChange) -> String {
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn on_file_change_request_approval_response(
|
||||
event_id: String,
|
||||
event_turn_id: String,
|
||||
conversation_id: ConversationId,
|
||||
item_id: String,
|
||||
changes: Vec<FileUpdateChange>,
|
||||
@@ -798,6 +945,7 @@ async fn on_file_change_request_approval_response(
|
||||
item_id,
|
||||
changes,
|
||||
status,
|
||||
event_turn_id.clone(),
|
||||
outgoing.as_ref(),
|
||||
&turn_summary_store,
|
||||
)
|
||||
@@ -806,7 +954,7 @@ async fn on_file_change_request_approval_response(
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::PatchApproval {
|
||||
id: event_id,
|
||||
id: event_turn_id,
|
||||
decision,
|
||||
})
|
||||
.await
|
||||
@@ -817,7 +965,8 @@ async fn on_file_change_request_approval_response(
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn on_command_execution_request_approval_response(
|
||||
event_id: String,
|
||||
event_turn_id: String,
|
||||
conversation_id: ConversationId,
|
||||
item_id: String,
|
||||
command: String,
|
||||
cwd: PathBuf,
|
||||
@@ -867,9 +1016,12 @@ async fn on_command_execution_request_approval_response(
|
||||
|
||||
if let Some(status) = completion_status {
|
||||
complete_command_execution_item(
|
||||
conversation_id,
|
||||
event_turn_id.clone(),
|
||||
item_id.clone(),
|
||||
command.clone(),
|
||||
cwd.clone(),
|
||||
None,
|
||||
command_actions.clone(),
|
||||
status,
|
||||
outgoing.as_ref(),
|
||||
@@ -879,7 +1031,7 @@ async fn on_command_execution_request_approval_response(
|
||||
|
||||
if let Err(err) = conversation
|
||||
.submit(Op::ExecApproval {
|
||||
id: event_id,
|
||||
id: event_turn_id,
|
||||
decision,
|
||||
})
|
||||
.await
|
||||
@@ -891,6 +1043,8 @@ async fn on_command_execution_request_approval_response(
|
||||
/// similar to handle_mcp_tool_call_begin in exec
|
||||
async fn construct_mcp_tool_call_notification(
|
||||
begin_event: McpToolCallBeginEvent,
|
||||
thread_id: String,
|
||||
turn_id: String,
|
||||
) -> ItemStartedNotification {
|
||||
let item = ThreadItem::McpToolCall {
|
||||
id: begin_event.call_id,
|
||||
@@ -901,12 +1055,18 @@ async fn construct_mcp_tool_call_notification(
|
||||
result: None,
|
||||
error: None,
|
||||
};
|
||||
ItemStartedNotification { item }
|
||||
ItemStartedNotification {
|
||||
thread_id,
|
||||
turn_id,
|
||||
item,
|
||||
}
|
||||
}
|
||||
|
||||
/// simiilar to handle_mcp_tool_call_end in exec
|
||||
async fn construct_mcp_tool_call_end_notification(
|
||||
end_event: McpToolCallEndEvent,
|
||||
thread_id: String,
|
||||
turn_id: String,
|
||||
) -> ItemCompletedNotification {
|
||||
let status = if end_event.is_success() {
|
||||
McpToolCallStatus::Completed
|
||||
@@ -939,7 +1099,11 @@ async fn construct_mcp_tool_call_end_notification(
|
||||
result,
|
||||
error,
|
||||
};
|
||||
ItemCompletedNotification { item }
|
||||
ItemCompletedNotification {
|
||||
thread_id,
|
||||
turn_id,
|
||||
item,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -951,7 +1115,12 @@ mod tests {
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use anyhow::bail;
|
||||
use codex_core::protocol::CreditsSnapshot;
|
||||
use codex_core::protocol::McpInvocation;
|
||||
use codex_core::protocol::RateLimitSnapshot;
|
||||
use codex_core::protocol::RateLimitWindow;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use codex_core::protocol::TokenUsageInfo;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::TextContent;
|
||||
@@ -995,14 +1164,14 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_handle_turn_complete_emits_completed_without_error() -> Result<()> {
|
||||
let conversation_id = ConversationId::new();
|
||||
let event_id = "complete1".to_string();
|
||||
let event_turn_id = "complete1".to_string();
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
|
||||
handle_turn_complete(
|
||||
conversation_id,
|
||||
event_id.clone(),
|
||||
event_turn_id.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
@@ -1014,7 +1183,7 @@ mod tests {
|
||||
.ok_or_else(|| anyhow!("should send one notification"))?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, event_id);
|
||||
assert_eq!(n.turn.id, event_turn_id);
|
||||
assert_eq!(n.turn.status, TurnStatus::Completed);
|
||||
}
|
||||
other => bail!("unexpected message: {other:?}"),
|
||||
@@ -1026,7 +1195,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_handle_turn_interrupted_emits_interrupted_with_error() -> Result<()> {
|
||||
let conversation_id = ConversationId::new();
|
||||
let event_id = "interrupt1".to_string();
|
||||
let event_turn_id = "interrupt1".to_string();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
handle_error(
|
||||
conversation_id,
|
||||
@@ -1042,7 +1211,7 @@ mod tests {
|
||||
|
||||
handle_turn_interrupted(
|
||||
conversation_id,
|
||||
event_id.clone(),
|
||||
event_turn_id.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
@@ -1054,7 +1223,7 @@ mod tests {
|
||||
.ok_or_else(|| anyhow!("should send one notification"))?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, event_id);
|
||||
assert_eq!(n.turn.id, event_turn_id);
|
||||
assert_eq!(n.turn.status, TurnStatus::Interrupted);
|
||||
}
|
||||
other => bail!("unexpected message: {other:?}"),
|
||||
@@ -1066,7 +1235,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_handle_turn_complete_emits_failed_with_error() -> Result<()> {
|
||||
let conversation_id = ConversationId::new();
|
||||
let event_id = "complete_err1".to_string();
|
||||
let event_turn_id = "complete_err1".to_string();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
handle_error(
|
||||
conversation_id,
|
||||
@@ -1082,7 +1251,7 @@ mod tests {
|
||||
|
||||
handle_turn_complete(
|
||||
conversation_id,
|
||||
event_id.clone(),
|
||||
event_turn_id.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
@@ -1094,7 +1263,7 @@ mod tests {
|
||||
.ok_or_else(|| anyhow!("should send one notification"))?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, event_id);
|
||||
assert_eq!(n.turn.id, event_turn_id);
|
||||
assert_eq!(
|
||||
n.turn.status,
|
||||
TurnStatus::Failed {
|
||||
@@ -1111,6 +1280,115 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_token_count_event_emits_usage_and_rate_limits() -> Result<()> {
|
||||
let conversation_id = ConversationId::new();
|
||||
let turn_id = "turn-123".to_string();
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
|
||||
let info = TokenUsageInfo {
|
||||
total_token_usage: TokenUsage {
|
||||
input_tokens: 100,
|
||||
cached_input_tokens: 25,
|
||||
output_tokens: 50,
|
||||
reasoning_output_tokens: 9,
|
||||
total_tokens: 200,
|
||||
},
|
||||
last_token_usage: TokenUsage {
|
||||
input_tokens: 10,
|
||||
cached_input_tokens: 5,
|
||||
output_tokens: 7,
|
||||
reasoning_output_tokens: 1,
|
||||
total_tokens: 23,
|
||||
},
|
||||
model_context_window: Some(4096),
|
||||
};
|
||||
let rate_limits = RateLimitSnapshot {
|
||||
primary: Some(RateLimitWindow {
|
||||
used_percent: 42.5,
|
||||
window_minutes: Some(15),
|
||||
resets_at: Some(1700000000),
|
||||
}),
|
||||
secondary: None,
|
||||
credits: Some(CreditsSnapshot {
|
||||
has_credits: true,
|
||||
unlimited: false,
|
||||
balance: Some("5".to_string()),
|
||||
}),
|
||||
};
|
||||
|
||||
handle_token_count_event(
|
||||
conversation_id,
|
||||
turn_id.clone(),
|
||||
TokenCountEvent {
|
||||
info: Some(info),
|
||||
rate_limits: Some(rate_limits),
|
||||
},
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
|
||||
let first = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("expected usage notification"))?;
|
||||
match first {
|
||||
OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ThreadTokenUsageUpdated(payload),
|
||||
) => {
|
||||
assert_eq!(payload.thread_id, conversation_id.to_string());
|
||||
assert_eq!(payload.turn_id, turn_id);
|
||||
let usage = payload.token_usage;
|
||||
assert_eq!(usage.total.total_tokens, 200);
|
||||
assert_eq!(usage.total.cached_input_tokens, 25);
|
||||
assert_eq!(usage.last.output_tokens, 7);
|
||||
assert_eq!(usage.model_context_window, Some(4096));
|
||||
}
|
||||
other => bail!("unexpected notification: {other:?}"),
|
||||
}
|
||||
|
||||
let second = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("expected rate limit notification"))?;
|
||||
match second {
|
||||
OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::AccountRateLimitsUpdated(payload),
|
||||
) => {
|
||||
assert!(payload.rate_limits.primary.is_some());
|
||||
assert!(payload.rate_limits.credits.is_some());
|
||||
}
|
||||
other => bail!("unexpected notification: {other:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_token_count_event_without_usage_info() -> Result<()> {
|
||||
let conversation_id = ConversationId::new();
|
||||
let turn_id = "turn-456".to_string();
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
|
||||
handle_token_count_event(
|
||||
conversation_id,
|
||||
turn_id.clone(),
|
||||
TokenCountEvent {
|
||||
info: None,
|
||||
rate_limits: None,
|
||||
},
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
rx.try_recv().is_err(),
|
||||
"no notifications should be emitted when token usage info is absent"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_construct_mcp_tool_call_begin_notification_with_args() {
|
||||
let begin_event = McpToolCallBeginEvent {
|
||||
@@ -1122,9 +1400,18 @@ mod tests {
|
||||
},
|
||||
};
|
||||
|
||||
let notification = construct_mcp_tool_call_notification(begin_event.clone()).await;
|
||||
let thread_id = ConversationId::new().to_string();
|
||||
let turn_id = "turn_1".to_string();
|
||||
let notification = construct_mcp_tool_call_notification(
|
||||
begin_event.clone(),
|
||||
thread_id.clone(),
|
||||
turn_id.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected = ItemStartedNotification {
|
||||
thread_id,
|
||||
turn_id,
|
||||
item: ThreadItem::McpToolCall {
|
||||
id: begin_event.call_id,
|
||||
server: begin_event.invocation.server,
|
||||
@@ -1267,9 +1554,18 @@ mod tests {
|
||||
},
|
||||
};
|
||||
|
||||
let notification = construct_mcp_tool_call_notification(begin_event.clone()).await;
|
||||
let thread_id = ConversationId::new().to_string();
|
||||
let turn_id = "turn_2".to_string();
|
||||
let notification = construct_mcp_tool_call_notification(
|
||||
begin_event.clone(),
|
||||
thread_id.clone(),
|
||||
turn_id.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected = ItemStartedNotification {
|
||||
thread_id,
|
||||
turn_id,
|
||||
item: ThreadItem::McpToolCall {
|
||||
id: begin_event.call_id,
|
||||
server: begin_event.invocation.server,
|
||||
@@ -1308,9 +1604,18 @@ mod tests {
|
||||
result: Ok(result),
|
||||
};
|
||||
|
||||
let notification = construct_mcp_tool_call_end_notification(end_event.clone()).await;
|
||||
let thread_id = ConversationId::new().to_string();
|
||||
let turn_id = "turn_3".to_string();
|
||||
let notification = construct_mcp_tool_call_end_notification(
|
||||
end_event.clone(),
|
||||
thread_id.clone(),
|
||||
turn_id.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected = ItemCompletedNotification {
|
||||
thread_id,
|
||||
turn_id,
|
||||
item: ThreadItem::McpToolCall {
|
||||
id: end_event.call_id,
|
||||
server: end_event.invocation.server,
|
||||
@@ -1341,9 +1646,18 @@ mod tests {
|
||||
result: Err("boom".to_string()),
|
||||
};
|
||||
|
||||
let notification = construct_mcp_tool_call_end_notification(end_event.clone()).await;
|
||||
let thread_id = ConversationId::new().to_string();
|
||||
let turn_id = "turn_4".to_string();
|
||||
let notification = construct_mcp_tool_call_end_notification(
|
||||
end_event.clone(),
|
||||
thread_id.clone(),
|
||||
turn_id.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected = ItemCompletedNotification {
|
||||
thread_id,
|
||||
turn_id,
|
||||
item: ThreadItem::McpToolCall {
|
||||
id: end_event.call_id,
|
||||
server: end_event.invocation.server,
|
||||
@@ -1359,4 +1673,56 @@ mod tests {
|
||||
|
||||
assert_eq!(notification, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_turn_diff_emits_v2_notification() -> Result<()> {
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
let unified_diff = "--- a\n+++ b\n".to_string();
|
||||
|
||||
handle_turn_diff(
|
||||
"turn-1",
|
||||
TurnDiffEvent {
|
||||
unified_diff: unified_diff.clone(),
|
||||
},
|
||||
ApiVersion::V2,
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
|
||||
let msg = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send one notification"))?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnDiffUpdated(
|
||||
notification,
|
||||
)) => {
|
||||
assert_eq!(notification.turn_id, "turn-1");
|
||||
assert_eq!(notification.diff, unified_diff);
|
||||
}
|
||||
other => bail!("unexpected message: {other:?}"),
|
||||
}
|
||||
assert!(rx.try_recv().is_err(), "no extra messages expected");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_turn_diff_is_noop_for_v1() -> Result<()> {
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
|
||||
handle_turn_diff(
|
||||
"turn-1",
|
||||
TurnDiffEvent {
|
||||
unified_diff: "diff".to_string(),
|
||||
},
|
||||
ApiVersion::V1,
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(rx.try_recv().is_err(), "no messages expected");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,7 +116,6 @@ use codex_core::exec::ExecParams;
|
||||
use codex_core::exec_env::create_env;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::find_conversation_path_by_id_str;
|
||||
use codex_core::get_platform_sandbox;
|
||||
use codex_core::git_info::git_diff_to_remote;
|
||||
use codex_core::parse_cursor;
|
||||
use codex_core::protocol::EventMsg;
|
||||
@@ -473,6 +472,11 @@ impl CodexMessageProcessor {
|
||||
ClientRequest::ExecOneOffCommand { request_id, params } => {
|
||||
self.exec_one_off_command(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ConfigRead { .. }
|
||||
| ClientRequest::ConfigValueWrite { .. }
|
||||
| ClientRequest::ConfigBatchWrite { .. } => {
|
||||
warn!("Config request reached CodexMessageProcessor unexpectedly");
|
||||
}
|
||||
ClientRequest::GetAccountRateLimits {
|
||||
request_id,
|
||||
params: _,
|
||||
@@ -1182,13 +1186,6 @@ impl CodexMessageProcessor {
|
||||
.sandbox_policy
|
||||
.unwrap_or_else(|| self.config.sandbox_policy.clone());
|
||||
|
||||
let sandbox_type = match &effective_policy {
|
||||
codex_core::protocol::SandboxPolicy::DangerFullAccess => {
|
||||
codex_core::exec::SandboxType::None
|
||||
}
|
||||
_ => get_platform_sandbox().unwrap_or(codex_core::exec::SandboxType::None),
|
||||
};
|
||||
tracing::debug!("Sandbox type: {sandbox_type:?}");
|
||||
let codex_linux_sandbox_exe = self.config.codex_linux_sandbox_exe.clone();
|
||||
let outgoing = self.outgoing.clone();
|
||||
let req_id = request_id;
|
||||
@@ -1197,7 +1194,6 @@ impl CodexMessageProcessor {
|
||||
tokio::spawn(async move {
|
||||
match codex_core::exec::process_exec_tool_call(
|
||||
exec_params,
|
||||
sandbox_type,
|
||||
&effective_policy,
|
||||
sandbox_cwd.as_path(),
|
||||
&codex_linux_sandbox_exe,
|
||||
@@ -2482,7 +2478,10 @@ impl CodexMessageProcessor {
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
|
||||
// Emit v2 turn/started notification.
|
||||
let notif = TurnStartedNotification { turn };
|
||||
let notif = TurnStartedNotification {
|
||||
thread_id: params.thread_id,
|
||||
turn,
|
||||
};
|
||||
self.outgoing
|
||||
.send_server_notification(ServerNotification::TurnStarted(notif))
|
||||
.await;
|
||||
@@ -2540,7 +2539,7 @@ impl CodexMessageProcessor {
|
||||
let response = TurnStartResponse { turn: turn.clone() };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
|
||||
let notif = TurnStartedNotification { turn };
|
||||
let notif = TurnStartedNotification { thread_id, turn };
|
||||
self.outgoing
|
||||
.send_server_notification(ServerNotification::TurnStarted(notif))
|
||||
.await;
|
||||
@@ -2803,6 +2802,7 @@ impl CodexMessageProcessor {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let session_source = self.conversation_manager.session_source();
|
||||
|
||||
let upload_result = tokio::task::spawn_blocking(move || {
|
||||
let rollout_path_ref = validated_rollout_path.as_deref();
|
||||
@@ -2811,6 +2811,7 @@ impl CodexMessageProcessor {
|
||||
reason.as_deref(),
|
||||
include_logs,
|
||||
rollout_path_ref,
|
||||
Some(session_source),
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
||||
974
codex-rs/app-server/src/config_api.rs
Normal file
974
codex-rs/app-server/src/config_api.rs
Normal file
@@ -0,0 +1,974 @@
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use anyhow::anyhow;
|
||||
use codex_app_server_protocol::ConfigBatchWriteParams;
|
||||
use codex_app_server_protocol::ConfigLayer;
|
||||
use codex_app_server_protocol::ConfigLayerMetadata;
|
||||
use codex_app_server_protocol::ConfigLayerName;
|
||||
use codex_app_server_protocol::ConfigReadParams;
|
||||
use codex_app_server_protocol::ConfigReadResponse;
|
||||
use codex_app_server_protocol::ConfigValueWriteParams;
|
||||
use codex_app_server_protocol::ConfigWriteErrorCode;
|
||||
use codex_app_server_protocol::ConfigWriteResponse;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::MergeStrategy;
|
||||
use codex_app_server_protocol::OverriddenMetadata;
|
||||
use codex_app_server_protocol::WriteStatus;
|
||||
use codex_core::config::ConfigToml;
|
||||
use codex_core::config_loader::LoadedConfigLayers;
|
||||
use codex_core::config_loader::LoaderOverrides;
|
||||
use codex_core::config_loader::load_config_layers_with_overrides;
|
||||
use codex_core::config_loader::merge_toml_values;
|
||||
use serde_json::Value as JsonValue;
|
||||
use serde_json::json;
|
||||
use sha2::Digest;
|
||||
use sha2::Sha256;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::NamedTempFile;
|
||||
use tokio::task;
|
||||
use toml::Value as TomlValue;
|
||||
|
||||
const SESSION_FLAGS_SOURCE: &str = "--config";
|
||||
const MDM_SOURCE: &str = "com.openai.codex/config_toml_base64";
|
||||
const CONFIG_FILE_NAME: &str = "config.toml";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ConfigApi {
|
||||
codex_home: PathBuf,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
loader_overrides: LoaderOverrides,
|
||||
}
|
||||
|
||||
impl ConfigApi {
|
||||
pub(crate) fn new(codex_home: PathBuf, cli_overrides: Vec<(String, TomlValue)>) -> Self {
|
||||
Self {
|
||||
codex_home,
|
||||
cli_overrides,
|
||||
loader_overrides: LoaderOverrides::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn with_overrides(
|
||||
codex_home: PathBuf,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
loader_overrides: LoaderOverrides,
|
||||
) -> Self {
|
||||
Self {
|
||||
codex_home,
|
||||
cli_overrides,
|
||||
loader_overrides,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn read(
|
||||
&self,
|
||||
params: ConfigReadParams,
|
||||
) -> Result<ConfigReadResponse, JSONRPCErrorError> {
|
||||
let layers = self
|
||||
.load_layers_state()
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to read configuration layers", err))?;
|
||||
|
||||
let effective = layers.effective_config();
|
||||
validate_config(&effective).map_err(|err| internal_error("invalid configuration", err))?;
|
||||
|
||||
let response = ConfigReadResponse {
|
||||
config: to_json_value(&effective),
|
||||
origins: layers.origins(),
|
||||
layers: params.include_layers.then(|| layers.layers_high_to_low()),
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub(crate) async fn write_value(
|
||||
&self,
|
||||
params: ConfigValueWriteParams,
|
||||
) -> Result<ConfigWriteResponse, JSONRPCErrorError> {
|
||||
let edits = vec![(params.key_path, params.value, params.merge_strategy)];
|
||||
self.apply_edits(params.file_path, params.expected_version, edits)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn batch_write(
|
||||
&self,
|
||||
params: ConfigBatchWriteParams,
|
||||
) -> Result<ConfigWriteResponse, JSONRPCErrorError> {
|
||||
let edits = params
|
||||
.edits
|
||||
.into_iter()
|
||||
.map(|edit| (edit.key_path, edit.value, edit.merge_strategy))
|
||||
.collect();
|
||||
|
||||
self.apply_edits(params.file_path, params.expected_version, edits)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn apply_edits(
|
||||
&self,
|
||||
file_path: String,
|
||||
expected_version: Option<String>,
|
||||
edits: Vec<(String, JsonValue, MergeStrategy)>,
|
||||
) -> Result<ConfigWriteResponse, JSONRPCErrorError> {
|
||||
let allowed_path = self.codex_home.join(CONFIG_FILE_NAME);
|
||||
if !paths_match(&allowed_path, &file_path) {
|
||||
return Err(config_write_error(
|
||||
ConfigWriteErrorCode::ConfigLayerReadonly,
|
||||
"Only writes to the user config are allowed",
|
||||
));
|
||||
}
|
||||
|
||||
let layers = self
|
||||
.load_layers_state()
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to load configuration", err))?;
|
||||
|
||||
if let Some(expected) = expected_version.as_deref()
|
||||
&& expected != layers.user.version
|
||||
{
|
||||
return Err(config_write_error(
|
||||
ConfigWriteErrorCode::ConfigVersionConflict,
|
||||
"Configuration was modified since last read. Fetch latest version and retry.",
|
||||
));
|
||||
}
|
||||
|
||||
let mut user_config = layers.user.config.clone();
|
||||
let mut mutated = false;
|
||||
let mut parsed_segments = Vec::new();
|
||||
|
||||
for (key_path, value, strategy) in edits.into_iter() {
|
||||
let segments = parse_key_path(&key_path).map_err(|message| {
|
||||
config_write_error(ConfigWriteErrorCode::ConfigValidationError, message)
|
||||
})?;
|
||||
let parsed_value = parse_value(value).map_err(|message| {
|
||||
config_write_error(ConfigWriteErrorCode::ConfigValidationError, message)
|
||||
})?;
|
||||
|
||||
let changed = apply_merge(&mut user_config, &segments, parsed_value.as_ref(), strategy)
|
||||
.map_err(|err| match err {
|
||||
MergeError::PathNotFound => config_write_error(
|
||||
ConfigWriteErrorCode::ConfigPathNotFound,
|
||||
"Path not found",
|
||||
),
|
||||
MergeError::Validation(message) => {
|
||||
config_write_error(ConfigWriteErrorCode::ConfigValidationError, message)
|
||||
}
|
||||
})?;
|
||||
|
||||
mutated |= changed;
|
||||
parsed_segments.push(segments);
|
||||
}
|
||||
|
||||
validate_config(&user_config).map_err(|err| {
|
||||
config_write_error(
|
||||
ConfigWriteErrorCode::ConfigValidationError,
|
||||
format!("Invalid configuration: {err}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
let updated_layers = layers.with_user_config(user_config.clone());
|
||||
let effective = updated_layers.effective_config();
|
||||
validate_config(&effective).map_err(|err| {
|
||||
config_write_error(
|
||||
ConfigWriteErrorCode::ConfigValidationError,
|
||||
format!("Invalid configuration: {err}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
if mutated {
|
||||
self.persist_user_config(&user_config)
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to persist config.toml", err))?;
|
||||
}
|
||||
|
||||
let overridden = first_overridden_edit(&updated_layers, &effective, &parsed_segments);
|
||||
let status = overridden
|
||||
.as_ref()
|
||||
.map(|_| WriteStatus::OkOverridden)
|
||||
.unwrap_or(WriteStatus::Ok);
|
||||
|
||||
Ok(ConfigWriteResponse {
|
||||
status,
|
||||
version: updated_layers.user.version.clone(),
|
||||
overridden_metadata: overridden,
|
||||
})
|
||||
}
|
||||
|
||||
async fn load_layers_state(&self) -> std::io::Result<LayersState> {
|
||||
let LoadedConfigLayers {
|
||||
base,
|
||||
managed_config,
|
||||
managed_preferences,
|
||||
} = load_config_layers_with_overrides(&self.codex_home, self.loader_overrides.clone())
|
||||
.await?;
|
||||
|
||||
let user = LayerState::new(
|
||||
ConfigLayerName::User,
|
||||
self.codex_home.join(CONFIG_FILE_NAME),
|
||||
base,
|
||||
);
|
||||
|
||||
let session_flags = LayerState::new(
|
||||
ConfigLayerName::SessionFlags,
|
||||
PathBuf::from(SESSION_FLAGS_SOURCE),
|
||||
{
|
||||
let mut root = TomlValue::Table(toml::map::Map::new());
|
||||
for (path, value) in self.cli_overrides.iter() {
|
||||
apply_override(&mut root, path, value.clone());
|
||||
}
|
||||
root
|
||||
},
|
||||
);
|
||||
|
||||
let system = managed_config.map(|cfg| {
|
||||
LayerState::new(
|
||||
ConfigLayerName::System,
|
||||
system_config_path(&self.codex_home),
|
||||
cfg,
|
||||
)
|
||||
});
|
||||
|
||||
let mdm = managed_preferences
|
||||
.map(|cfg| LayerState::new(ConfigLayerName::Mdm, PathBuf::from(MDM_SOURCE), cfg));
|
||||
|
||||
Ok(LayersState {
|
||||
user,
|
||||
session_flags,
|
||||
system,
|
||||
mdm,
|
||||
})
|
||||
}
|
||||
|
||||
async fn persist_user_config(&self, user_config: &TomlValue) -> anyhow::Result<()> {
|
||||
let codex_home = self.codex_home.clone();
|
||||
let serialized = toml::to_string_pretty(user_config)?;
|
||||
|
||||
task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||
std::fs::create_dir_all(&codex_home)?;
|
||||
|
||||
let target = codex_home.join(CONFIG_FILE_NAME);
|
||||
let tmp = NamedTempFile::new_in(&codex_home)?;
|
||||
std::fs::write(tmp.path(), serialized.as_bytes())?;
|
||||
tmp.persist(&target)?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(|err| anyhow!("config persistence task panicked: {err}"))??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_value(value: JsonValue) -> Result<Option<TomlValue>, String> {
|
||||
if value.is_null() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
serde_json::from_value::<TomlValue>(value)
|
||||
.map(Some)
|
||||
.map_err(|err| format!("invalid value: {err}"))
|
||||
}
|
||||
|
||||
fn parse_key_path(path: &str) -> Result<Vec<String>, String> {
|
||||
if path.trim().is_empty() {
|
||||
return Err("keyPath must not be empty".to_string());
|
||||
}
|
||||
Ok(path
|
||||
.split('.')
|
||||
.map(std::string::ToString::to_string)
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn apply_override(target: &mut TomlValue, path: &str, value: TomlValue) {
|
||||
use toml::value::Table;
|
||||
|
||||
let segments: Vec<&str> = path.split('.').collect();
|
||||
let mut current = target;
|
||||
|
||||
for (idx, segment) in segments.iter().enumerate() {
|
||||
let is_last = idx == segments.len() - 1;
|
||||
|
||||
if is_last {
|
||||
match current {
|
||||
TomlValue::Table(table) => {
|
||||
table.insert(segment.to_string(), value);
|
||||
}
|
||||
_ => {
|
||||
let mut table = Table::new();
|
||||
table.insert(segment.to_string(), value);
|
||||
*current = TomlValue::Table(table);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
match current {
|
||||
TomlValue::Table(table) => {
|
||||
current = table
|
||||
.entry((*segment).to_string())
|
||||
.or_insert_with(|| TomlValue::Table(Table::new()));
|
||||
}
|
||||
_ => {
|
||||
*current = TomlValue::Table(Table::new());
|
||||
if let TomlValue::Table(tbl) = current {
|
||||
current = tbl
|
||||
.entry((*segment).to_string())
|
||||
.or_insert_with(|| TomlValue::Table(Table::new()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum MergeError {
|
||||
PathNotFound,
|
||||
Validation(String),
|
||||
}
|
||||
|
||||
fn apply_merge(
|
||||
root: &mut TomlValue,
|
||||
segments: &[String],
|
||||
value: Option<&TomlValue>,
|
||||
strategy: MergeStrategy,
|
||||
) -> Result<bool, MergeError> {
|
||||
let Some(value) = value else {
|
||||
return clear_path(root, segments);
|
||||
};
|
||||
|
||||
let Some((last, parents)) = segments.split_last() else {
|
||||
return Err(MergeError::Validation(
|
||||
"keyPath must not be empty".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
let mut current = root;
|
||||
|
||||
for segment in parents {
|
||||
match current {
|
||||
TomlValue::Table(table) => {
|
||||
current = table
|
||||
.entry(segment.clone())
|
||||
.or_insert_with(|| TomlValue::Table(toml::map::Map::new()));
|
||||
}
|
||||
_ => {
|
||||
*current = TomlValue::Table(toml::map::Map::new());
|
||||
if let TomlValue::Table(table) = current {
|
||||
current = table
|
||||
.entry(segment.clone())
|
||||
.or_insert_with(|| TomlValue::Table(toml::map::Map::new()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let table = current.as_table_mut().ok_or_else(|| {
|
||||
MergeError::Validation("cannot set value on non-table parent".to_string())
|
||||
})?;
|
||||
|
||||
if matches!(strategy, MergeStrategy::Upsert)
|
||||
&& let Some(existing) = table.get_mut(last)
|
||||
&& matches!(existing, TomlValue::Table(_))
|
||||
&& matches!(value, TomlValue::Table(_))
|
||||
{
|
||||
merge_toml_values(existing, value);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let changed = table
|
||||
.get(last)
|
||||
.map(|existing| Some(existing) != Some(value))
|
||||
.unwrap_or(true);
|
||||
table.insert(last.clone(), value.clone());
|
||||
Ok(changed)
|
||||
}
|
||||
|
||||
fn clear_path(root: &mut TomlValue, segments: &[String]) -> Result<bool, MergeError> {
|
||||
let Some((last, parents)) = segments.split_last() else {
|
||||
return Err(MergeError::Validation(
|
||||
"keyPath must not be empty".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
let mut current = root;
|
||||
for segment in parents {
|
||||
match current {
|
||||
TomlValue::Table(table) => {
|
||||
current = table.get_mut(segment).ok_or(MergeError::PathNotFound)?;
|
||||
}
|
||||
_ => return Err(MergeError::PathNotFound),
|
||||
}
|
||||
}
|
||||
|
||||
let Some(parent) = current.as_table_mut() else {
|
||||
return Err(MergeError::PathNotFound);
|
||||
};
|
||||
|
||||
Ok(parent.remove(last).is_some())
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct LayerState {
|
||||
name: ConfigLayerName,
|
||||
source: PathBuf,
|
||||
config: TomlValue,
|
||||
version: String,
|
||||
}
|
||||
|
||||
impl LayerState {
|
||||
fn new(name: ConfigLayerName, source: PathBuf, config: TomlValue) -> Self {
|
||||
let version = version_for_toml(&config);
|
||||
Self {
|
||||
name,
|
||||
source,
|
||||
config,
|
||||
version,
|
||||
}
|
||||
}
|
||||
|
||||
fn metadata(&self) -> ConfigLayerMetadata {
|
||||
ConfigLayerMetadata {
|
||||
name: self.name.clone(),
|
||||
source: self.source.display().to_string(),
|
||||
version: self.version.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn as_layer(&self) -> ConfigLayer {
|
||||
ConfigLayer {
|
||||
name: self.name.clone(),
|
||||
source: self.source.display().to_string(),
|
||||
version: self.version.clone(),
|
||||
config: to_json_value(&self.config),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct LayersState {
|
||||
user: LayerState,
|
||||
session_flags: LayerState,
|
||||
system: Option<LayerState>,
|
||||
mdm: Option<LayerState>,
|
||||
}
|
||||
|
||||
impl LayersState {
|
||||
fn with_user_config(self, user_config: TomlValue) -> Self {
|
||||
Self {
|
||||
user: LayerState::new(self.user.name, self.user.source, user_config),
|
||||
session_flags: self.session_flags,
|
||||
system: self.system,
|
||||
mdm: self.mdm,
|
||||
}
|
||||
}
|
||||
|
||||
fn effective_config(&self) -> TomlValue {
|
||||
let mut merged = self.user.config.clone();
|
||||
merge_toml_values(&mut merged, &self.session_flags.config);
|
||||
if let Some(system) = &self.system {
|
||||
merge_toml_values(&mut merged, &system.config);
|
||||
}
|
||||
if let Some(mdm) = &self.mdm {
|
||||
merge_toml_values(&mut merged, &mdm.config);
|
||||
}
|
||||
merged
|
||||
}
|
||||
|
||||
fn origins(&self) -> HashMap<String, ConfigLayerMetadata> {
|
||||
let mut origins = HashMap::new();
|
||||
let mut path = Vec::new();
|
||||
|
||||
record_origins(
|
||||
&self.user.config,
|
||||
&self.user.metadata(),
|
||||
&mut path,
|
||||
&mut origins,
|
||||
);
|
||||
record_origins(
|
||||
&self.session_flags.config,
|
||||
&self.session_flags.metadata(),
|
||||
&mut path,
|
||||
&mut origins,
|
||||
);
|
||||
if let Some(system) = &self.system {
|
||||
record_origins(&system.config, &system.metadata(), &mut path, &mut origins);
|
||||
}
|
||||
if let Some(mdm) = &self.mdm {
|
||||
record_origins(&mdm.config, &mdm.metadata(), &mut path, &mut origins);
|
||||
}
|
||||
|
||||
origins
|
||||
}
|
||||
|
||||
fn layers_high_to_low(&self) -> Vec<ConfigLayer> {
|
||||
let mut layers = Vec::new();
|
||||
if let Some(mdm) = &self.mdm {
|
||||
layers.push(mdm.as_layer());
|
||||
}
|
||||
if let Some(system) = &self.system {
|
||||
layers.push(system.as_layer());
|
||||
}
|
||||
layers.push(self.session_flags.as_layer());
|
||||
layers.push(self.user.as_layer());
|
||||
layers
|
||||
}
|
||||
}
|
||||
|
||||
fn record_origins(
|
||||
value: &TomlValue,
|
||||
meta: &ConfigLayerMetadata,
|
||||
path: &mut Vec<String>,
|
||||
origins: &mut HashMap<String, ConfigLayerMetadata>,
|
||||
) {
|
||||
match value {
|
||||
TomlValue::Table(table) => {
|
||||
for (key, val) in table {
|
||||
path.push(key.clone());
|
||||
record_origins(val, meta, path, origins);
|
||||
path.pop();
|
||||
}
|
||||
}
|
||||
TomlValue::Array(items) => {
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
path.push(idx.to_string());
|
||||
record_origins(item, meta, path, origins);
|
||||
path.pop();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if !path.is_empty() {
|
||||
origins.insert(path.join("."), meta.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_json_value(value: &TomlValue) -> JsonValue {
|
||||
serde_json::to_value(value).unwrap_or(JsonValue::Null)
|
||||
}
|
||||
|
||||
fn validate_config(value: &TomlValue) -> Result<(), toml::de::Error> {
|
||||
let _: ConfigToml = value.clone().try_into()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn version_for_toml(value: &TomlValue) -> String {
|
||||
let json = to_json_value(value);
|
||||
let canonical = canonical_json(&json);
|
||||
let serialized = serde_json::to_vec(&canonical).unwrap_or_default();
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(serialized);
|
||||
let hash = hasher.finalize();
|
||||
let hex = hash
|
||||
.iter()
|
||||
.map(|byte| format!("{byte:02x}"))
|
||||
.collect::<String>();
|
||||
format!("sha256:{hex}")
|
||||
}
|
||||
|
||||
fn canonical_json(value: &JsonValue) -> JsonValue {
|
||||
match value {
|
||||
JsonValue::Object(map) => {
|
||||
let mut sorted = serde_json::Map::new();
|
||||
let mut keys = map.keys().cloned().collect::<Vec<_>>();
|
||||
keys.sort();
|
||||
for key in keys {
|
||||
if let Some(val) = map.get(&key) {
|
||||
sorted.insert(key, canonical_json(val));
|
||||
}
|
||||
}
|
||||
JsonValue::Object(sorted)
|
||||
}
|
||||
JsonValue::Array(items) => JsonValue::Array(items.iter().map(canonical_json).collect()),
|
||||
other => other.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn paths_match(expected: &Path, provided: &str) -> bool {
|
||||
let provided_path = PathBuf::from(provided);
|
||||
if let (Ok(expanded_expected), Ok(expanded_provided)) =
|
||||
(expected.canonicalize(), provided_path.canonicalize())
|
||||
{
|
||||
return expanded_expected == expanded_provided;
|
||||
}
|
||||
|
||||
expected == provided_path
|
||||
}
|
||||
|
||||
fn value_at_path<'a>(root: &'a TomlValue, segments: &[String]) -> Option<&'a TomlValue> {
|
||||
let mut current = root;
|
||||
for segment in segments {
|
||||
match current {
|
||||
TomlValue::Table(table) => {
|
||||
current = table.get(segment)?;
|
||||
}
|
||||
TomlValue::Array(items) => {
|
||||
let idx: usize = segment.parse().ok()?;
|
||||
current = items.get(idx)?;
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
Some(current)
|
||||
}
|
||||
|
||||
fn override_message(layer: &ConfigLayerName) -> String {
|
||||
match layer {
|
||||
ConfigLayerName::Mdm => "Overridden by managed policy (mdm)".to_string(),
|
||||
ConfigLayerName::System => "Overridden by managed config (system)".to_string(),
|
||||
ConfigLayerName::SessionFlags => "Overridden by session flags".to_string(),
|
||||
ConfigLayerName::User => "Overridden by user config".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_override_metadata(
|
||||
layers: &LayersState,
|
||||
effective: &TomlValue,
|
||||
segments: &[String],
|
||||
) -> Option<OverriddenMetadata> {
|
||||
let user_value = value_at_path(&layers.user.config, segments);
|
||||
let effective_value = value_at_path(effective, segments);
|
||||
|
||||
if user_value.is_some() && user_value == effective_value {
|
||||
return None;
|
||||
}
|
||||
|
||||
if user_value.is_none() && effective_value.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let effective_layer = find_effective_layer(layers, segments);
|
||||
let overriding_layer = effective_layer.unwrap_or_else(|| layers.user.metadata());
|
||||
let message = override_message(&overriding_layer.name);
|
||||
|
||||
Some(OverriddenMetadata {
|
||||
message,
|
||||
overriding_layer,
|
||||
effective_value: effective_value
|
||||
.map(to_json_value)
|
||||
.unwrap_or(JsonValue::Null),
|
||||
})
|
||||
}
|
||||
|
||||
fn first_overridden_edit(
|
||||
layers: &LayersState,
|
||||
effective: &TomlValue,
|
||||
edits: &[Vec<String>],
|
||||
) -> Option<OverriddenMetadata> {
|
||||
for segments in edits {
|
||||
if let Some(meta) = compute_override_metadata(layers, effective, segments) {
|
||||
return Some(meta);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn find_effective_layer(layers: &LayersState, segments: &[String]) -> Option<ConfigLayerMetadata> {
|
||||
let check =
|
||||
|state: &LayerState| value_at_path(&state.config, segments).map(|_| state.metadata());
|
||||
|
||||
if let Some(mdm) = &layers.mdm
|
||||
&& let Some(meta) = check(mdm)
|
||||
{
|
||||
return Some(meta);
|
||||
}
|
||||
if let Some(system) = &layers.system
|
||||
&& let Some(meta) = check(system)
|
||||
{
|
||||
return Some(meta);
|
||||
}
|
||||
if let Some(meta) = check(&layers.session_flags) {
|
||||
return Some(meta);
|
||||
}
|
||||
check(&layers.user)
|
||||
}
|
||||
|
||||
fn system_config_path(codex_home: &Path) -> PathBuf {
|
||||
if let Ok(path) = std::env::var("CODEX_MANAGED_CONFIG_PATH") {
|
||||
return PathBuf::from(path);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let _ = codex_home;
|
||||
PathBuf::from("/etc/codex/managed_config.toml")
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
codex_home.join("managed_config.toml")
|
||||
}
|
||||
}
|
||||
|
||||
fn internal_error<E: std::fmt::Display>(context: &str, err: E) -> JSONRPCErrorError {
|
||||
JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("{context}: {err}"),
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn config_write_error(code: ConfigWriteErrorCode, message: impl Into<String>) -> JSONRPCErrorError {
|
||||
JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: message.into(),
|
||||
data: Some(json!({
|
||||
"config_write_error_code": code,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_includes_origins_and_layers() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
std::fs::write(tmp.path().join(CONFIG_FILE_NAME), "model = \"user\"").unwrap();
|
||||
|
||||
let managed_path = tmp.path().join("managed_config.toml");
|
||||
std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap();
|
||||
|
||||
let api = ConfigApi::with_overrides(
|
||||
tmp.path().to_path_buf(),
|
||||
vec![],
|
||||
LoaderOverrides {
|
||||
managed_config_path: Some(managed_path),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
},
|
||||
);
|
||||
|
||||
let response = api
|
||||
.read(ConfigReadParams {
|
||||
include_layers: true,
|
||||
})
|
||||
.await
|
||||
.expect("response");
|
||||
|
||||
assert_eq!(
|
||||
response.config.get("approval_policy"),
|
||||
Some(&json!("never"))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
response
|
||||
.origins
|
||||
.get("approval_policy")
|
||||
.expect("origin")
|
||||
.name,
|
||||
ConfigLayerName::System
|
||||
);
|
||||
let layers = response.layers.expect("layers present");
|
||||
assert_eq!(layers.first().unwrap().name, ConfigLayerName::System);
|
||||
assert_eq!(layers.get(1).unwrap().name, ConfigLayerName::SessionFlags);
|
||||
assert_eq!(layers.last().unwrap().name, ConfigLayerName::User);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_value_reports_override() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
std::fs::write(
|
||||
tmp.path().join(CONFIG_FILE_NAME),
|
||||
"approval_policy = \"on-request\"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let managed_path = tmp.path().join("managed_config.toml");
|
||||
std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap();
|
||||
|
||||
let api = ConfigApi::with_overrides(
|
||||
tmp.path().to_path_buf(),
|
||||
vec![],
|
||||
LoaderOverrides {
|
||||
managed_config_path: Some(managed_path),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
},
|
||||
);
|
||||
|
||||
let result = api
|
||||
.write_value(ConfigValueWriteParams {
|
||||
file_path: tmp.path().join(CONFIG_FILE_NAME).display().to_string(),
|
||||
key_path: "approval_policy".to_string(),
|
||||
value: json!("never"),
|
||||
merge_strategy: MergeStrategy::Replace,
|
||||
expected_version: None,
|
||||
})
|
||||
.await
|
||||
.expect("result");
|
||||
|
||||
let read_after = api
|
||||
.read(ConfigReadParams {
|
||||
include_layers: true,
|
||||
})
|
||||
.await
|
||||
.expect("read");
|
||||
let config_object = read_after.config.as_object().expect("object");
|
||||
assert_eq!(config_object.get("approval_policy"), Some(&json!("never")));
|
||||
assert_eq!(
|
||||
read_after
|
||||
.origins
|
||||
.get("approval_policy")
|
||||
.expect("origin")
|
||||
.name,
|
||||
ConfigLayerName::System
|
||||
);
|
||||
assert_eq!(result.status, WriteStatus::Ok);
|
||||
assert!(result.overridden_metadata.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn version_conflict_rejected() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
std::fs::write(tmp.path().join(CONFIG_FILE_NAME), "model = \"user\"").unwrap();
|
||||
|
||||
let api = ConfigApi::new(tmp.path().to_path_buf(), vec![]);
|
||||
let error = api
|
||||
.write_value(ConfigValueWriteParams {
|
||||
file_path: tmp.path().join(CONFIG_FILE_NAME).display().to_string(),
|
||||
key_path: "model".to_string(),
|
||||
value: json!("gpt-5"),
|
||||
merge_strategy: MergeStrategy::Replace,
|
||||
expected_version: Some("sha256:bogus".to_string()),
|
||||
})
|
||||
.await
|
||||
.expect_err("should fail");
|
||||
|
||||
assert_eq!(error.code, INVALID_REQUEST_ERROR_CODE);
|
||||
assert_eq!(
|
||||
error
|
||||
.data
|
||||
.as_ref()
|
||||
.and_then(|d| d.get("config_write_error_code"))
|
||||
.and_then(serde_json::Value::as_str),
|
||||
Some("configVersionConflict")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_user_value_rejected_even_if_overridden_by_managed() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
std::fs::write(tmp.path().join(CONFIG_FILE_NAME), "model = \"user\"").unwrap();
|
||||
|
||||
let managed_path = tmp.path().join("managed_config.toml");
|
||||
std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap();
|
||||
|
||||
let api = ConfigApi::with_overrides(
|
||||
tmp.path().to_path_buf(),
|
||||
vec![],
|
||||
LoaderOverrides {
|
||||
managed_config_path: Some(managed_path),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
},
|
||||
);
|
||||
|
||||
let error = api
|
||||
.write_value(ConfigValueWriteParams {
|
||||
file_path: tmp.path().join(CONFIG_FILE_NAME).display().to_string(),
|
||||
key_path: "approval_policy".to_string(),
|
||||
value: json!("bogus"),
|
||||
merge_strategy: MergeStrategy::Replace,
|
||||
expected_version: None,
|
||||
})
|
||||
.await
|
||||
.expect_err("should fail validation");
|
||||
|
||||
assert_eq!(error.code, INVALID_REQUEST_ERROR_CODE);
|
||||
assert_eq!(
|
||||
error
|
||||
.data
|
||||
.as_ref()
|
||||
.and_then(|d| d.get("config_write_error_code"))
|
||||
.and_then(serde_json::Value::as_str),
|
||||
Some("configValidationError")
|
||||
);
|
||||
|
||||
let contents =
|
||||
std::fs::read_to_string(tmp.path().join(CONFIG_FILE_NAME)).expect("read config");
|
||||
assert_eq!(contents.trim(), "model = \"user\"");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_reports_managed_overrides_user_and_session_flags() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
std::fs::write(tmp.path().join(CONFIG_FILE_NAME), "model = \"user\"").unwrap();
|
||||
|
||||
let managed_path = tmp.path().join("managed_config.toml");
|
||||
std::fs::write(&managed_path, "model = \"system\"").unwrap();
|
||||
|
||||
let cli_overrides = vec![(
|
||||
"model".to_string(),
|
||||
TomlValue::String("session".to_string()),
|
||||
)];
|
||||
|
||||
let api = ConfigApi::with_overrides(
|
||||
tmp.path().to_path_buf(),
|
||||
cli_overrides,
|
||||
LoaderOverrides {
|
||||
managed_config_path: Some(managed_path),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
},
|
||||
);
|
||||
|
||||
let response = api
|
||||
.read(ConfigReadParams {
|
||||
include_layers: true,
|
||||
})
|
||||
.await
|
||||
.expect("response");
|
||||
|
||||
assert_eq!(response.config.get("model"), Some(&json!("system")));
|
||||
assert_eq!(
|
||||
response.origins.get("model").expect("origin").name,
|
||||
ConfigLayerName::System
|
||||
);
|
||||
let layers = response.layers.expect("layers");
|
||||
assert_eq!(layers.first().unwrap().name, ConfigLayerName::System);
|
||||
assert_eq!(layers.get(1).unwrap().name, ConfigLayerName::SessionFlags);
|
||||
assert_eq!(layers.get(2).unwrap().name, ConfigLayerName::User);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_value_reports_managed_override() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
std::fs::write(tmp.path().join(CONFIG_FILE_NAME), "").unwrap();
|
||||
|
||||
let managed_path = tmp.path().join("managed_config.toml");
|
||||
std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap();
|
||||
|
||||
let api = ConfigApi::with_overrides(
|
||||
tmp.path().to_path_buf(),
|
||||
vec![],
|
||||
LoaderOverrides {
|
||||
managed_config_path: Some(managed_path),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
},
|
||||
);
|
||||
|
||||
let result = api
|
||||
.write_value(ConfigValueWriteParams {
|
||||
file_path: tmp.path().join(CONFIG_FILE_NAME).display().to_string(),
|
||||
key_path: "approval_policy".to_string(),
|
||||
value: json!("on-request"),
|
||||
merge_strategy: MergeStrategy::Replace,
|
||||
expected_version: None,
|
||||
})
|
||||
.await
|
||||
.expect("result");
|
||||
|
||||
assert_eq!(result.status, WriteStatus::OkOverridden);
|
||||
let overridden = result.overridden_metadata.expect("overridden metadata");
|
||||
assert_eq!(overridden.overriding_layer.name, ConfigLayerName::System);
|
||||
assert_eq!(overridden.effective_value, json!("never"));
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::io::{self};
|
||||
use tokio::sync::mpsc;
|
||||
use toml::Value as TomlValue;
|
||||
use tracing::Level;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
@@ -30,6 +31,7 @@ use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
mod bespoke_event_handling;
|
||||
mod codex_message_processor;
|
||||
mod config_api;
|
||||
mod error_code;
|
||||
mod fuzzy_file_search;
|
||||
mod message_processor;
|
||||
@@ -80,11 +82,12 @@ pub async fn run_main(
|
||||
format!("error parsing -c overrides: {e}"),
|
||||
)
|
||||
})?;
|
||||
let config = Config::load_with_cli_overrides(cli_kv_overrides, ConfigOverrides::default())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
std::io::Error::new(ErrorKind::InvalidData, format!("error loading config: {e}"))
|
||||
})?;
|
||||
let config =
|
||||
Config::load_with_cli_overrides(cli_kv_overrides.clone(), ConfigOverrides::default())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
std::io::Error::new(ErrorKind::InvalidData, format!("error loading config: {e}"))
|
||||
})?;
|
||||
|
||||
let feedback = CodexFeedback::new();
|
||||
|
||||
@@ -121,10 +124,12 @@ pub async fn run_main(
|
||||
// Task: process incoming messages.
|
||||
let processor_handle = tokio::spawn({
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone();
|
||||
let mut processor = MessageProcessor::new(
|
||||
outgoing_message_sender,
|
||||
codex_linux_sandbox_exe,
|
||||
std::sync::Arc::new(config),
|
||||
cli_overrides,
|
||||
feedback.clone(),
|
||||
);
|
||||
async move {
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::codex_message_processor::CodexMessageProcessor;
|
||||
use crate::config_api::ConfigApi;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use codex_app_server_protocol::ClientInfo;
|
||||
use codex_app_server_protocol::ClientRequest;
|
||||
use codex_app_server_protocol::ConfigBatchWriteParams;
|
||||
use codex_app_server_protocol::ConfigReadParams;
|
||||
use codex_app_server_protocol::ConfigValueWriteParams;
|
||||
use codex_app_server_protocol::InitializeResponse;
|
||||
use codex_app_server_protocol::JSONRPCError;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::config::Config;
|
||||
@@ -18,11 +24,12 @@ use codex_core::default_client::USER_AGENT_SUFFIX;
|
||||
use codex_core::default_client::get_codex_user_agent;
|
||||
use codex_feedback::CodexFeedback;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use std::sync::Arc;
|
||||
use toml::Value as TomlValue;
|
||||
|
||||
pub(crate) struct MessageProcessor {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
codex_message_processor: CodexMessageProcessor,
|
||||
config_api: ConfigApi,
|
||||
initialized: bool,
|
||||
}
|
||||
|
||||
@@ -33,6 +40,7 @@ impl MessageProcessor {
|
||||
outgoing: OutgoingMessageSender,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
config: Arc<Config>,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
feedback: CodexFeedback,
|
||||
) -> Self {
|
||||
let outgoing = Arc::new(outgoing);
|
||||
@@ -50,13 +58,15 @@ impl MessageProcessor {
|
||||
conversation_manager,
|
||||
outgoing.clone(),
|
||||
codex_linux_sandbox_exe,
|
||||
config,
|
||||
Arc::clone(&config),
|
||||
feedback,
|
||||
);
|
||||
let config_api = ConfigApi::new(config.codex_home.clone(), cli_overrides);
|
||||
|
||||
Self {
|
||||
outgoing,
|
||||
codex_message_processor,
|
||||
config_api,
|
||||
initialized: false,
|
||||
}
|
||||
}
|
||||
@@ -134,9 +144,20 @@ impl MessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
self.codex_message_processor
|
||||
.process_request(codex_request)
|
||||
.await;
|
||||
match codex_request {
|
||||
ClientRequest::ConfigRead { request_id, params } => {
|
||||
self.handle_config_read(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ConfigValueWrite { request_id, params } => {
|
||||
self.handle_config_value_write(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ConfigBatchWrite { request_id, params } => {
|
||||
self.handle_config_batch_write(request_id, params).await;
|
||||
}
|
||||
other => {
|
||||
self.codex_message_processor.process_request(other).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn process_notification(&self, notification: JSONRPCNotification) {
|
||||
@@ -156,4 +177,33 @@ impl MessageProcessor {
|
||||
pub(crate) fn process_error(&mut self, err: JSONRPCError) {
|
||||
tracing::error!("<- error: {:?}", err);
|
||||
}
|
||||
|
||||
async fn handle_config_read(&self, request_id: RequestId, params: ConfigReadParams) {
|
||||
match self.config_api.read(params).await {
|
||||
Ok(response) => self.outgoing.send_response(request_id, response).await,
|
||||
Err(error) => self.outgoing.send_error(request_id, error).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_config_value_write(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
params: ConfigValueWriteParams,
|
||||
) {
|
||||
match self.config_api.write_value(params).await {
|
||||
Ok(response) => self.outgoing.send_response(request_id, response).await,
|
||||
Err(error) => self.outgoing.send_error(request_id, error).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_config_batch_write(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
params: ConfigBatchWriteParams,
|
||||
) {
|
||||
match self.config_api.batch_write(params).await {
|
||||
Ok(response) => self.outgoing.send_response(request_id, response).await,
|
||||
Err(error) => self.outgoing.send_error(request_id, error).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "app_test_support"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "lib.rs"
|
||||
|
||||
@@ -15,6 +15,7 @@ pub use mcp_process::McpProcess;
|
||||
pub use mock_model_server::create_mock_chat_completions_server;
|
||||
pub use mock_model_server::create_mock_chat_completions_server_unchecked;
|
||||
pub use responses::create_apply_patch_sse_response;
|
||||
pub use responses::create_exec_command_sse_response;
|
||||
pub use responses::create_final_assistant_message_sse_response;
|
||||
pub use responses::create_shell_command_sse_response;
|
||||
pub use rollout::create_fake_rollout;
|
||||
|
||||
@@ -18,6 +18,9 @@ use codex_app_server_protocol::CancelLoginAccountParams;
|
||||
use codex_app_server_protocol::CancelLoginChatGptParams;
|
||||
use codex_app_server_protocol::ClientInfo;
|
||||
use codex_app_server_protocol::ClientNotification;
|
||||
use codex_app_server_protocol::ConfigBatchWriteParams;
|
||||
use codex_app_server_protocol::ConfigReadParams;
|
||||
use codex_app_server_protocol::ConfigValueWriteParams;
|
||||
use codex_app_server_protocol::FeedbackUploadParams;
|
||||
use codex_app_server_protocol::GetAccountParams;
|
||||
use codex_app_server_protocol::GetAuthStatusParams;
|
||||
@@ -401,6 +404,30 @@ impl McpProcess {
|
||||
self.send_request("logoutChatGpt", None).await
|
||||
}
|
||||
|
||||
pub async fn send_config_read_request(
|
||||
&mut self,
|
||||
params: ConfigReadParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("config/read", params).await
|
||||
}
|
||||
|
||||
pub async fn send_config_value_write_request(
|
||||
&mut self,
|
||||
params: ConfigValueWriteParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("config/value/write", params).await
|
||||
}
|
||||
|
||||
pub async fn send_config_batch_write_request(
|
||||
&mut self,
|
||||
params: ConfigBatchWriteParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("config/batchWrite", params).await
|
||||
}
|
||||
|
||||
/// Send an `account/logout` JSON-RPC request.
|
||||
pub async fn send_logout_account_request(&mut self) -> anyhow::Result<i64> {
|
||||
self.send_request("account/logout", None).await
|
||||
|
||||
@@ -94,3 +94,42 @@ pub fn create_apply_patch_sse_response(
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
pub fn create_exec_command_sse_response(call_id: &str) -> anyhow::Result<String> {
|
||||
let (cmd, args) = if cfg!(windows) {
|
||||
("cmd.exe", vec!["/d", "/c", "echo hi"])
|
||||
} else {
|
||||
("/bin/sh", vec!["-c", "echo hi"])
|
||||
};
|
||||
let command = std::iter::once(cmd.to_string())
|
||||
.chain(args.into_iter().map(str::to_string))
|
||||
.collect::<Vec<_>>();
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"cmd": command.join(" "),
|
||||
"yield_time_ms": 500
|
||||
}))?;
|
||||
let tool_call = json!({
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"function": {
|
||||
"name": "exec_command",
|
||||
"arguments": tool_call_arguments
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let sse = format!(
|
||||
"data: {}\n\ndata: DONE\n\n",
|
||||
serde_json::to_string(&tool_call)?
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
@@ -272,40 +272,45 @@ async fn read_raw_response_item(
|
||||
mcp: &mut McpProcess,
|
||||
conversation_id: ConversationId,
|
||||
) -> ResponseItem {
|
||||
let raw_notification: JSONRPCNotification = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("codex/event/raw_response_item"),
|
||||
)
|
||||
.await
|
||||
.expect("codex/event/raw_response_item notification timeout")
|
||||
.expect("codex/event/raw_response_item notification resp");
|
||||
loop {
|
||||
let raw_notification: JSONRPCNotification = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("codex/event/raw_response_item"),
|
||||
)
|
||||
.await
|
||||
.expect("codex/event/raw_response_item notification timeout")
|
||||
.expect("codex/event/raw_response_item notification resp");
|
||||
|
||||
let serde_json::Value::Object(params) = raw_notification
|
||||
.params
|
||||
.expect("codex/event/raw_response_item should have params")
|
||||
else {
|
||||
panic!("codex/event/raw_response_item should have params");
|
||||
};
|
||||
let serde_json::Value::Object(params) = raw_notification
|
||||
.params
|
||||
.expect("codex/event/raw_response_item should have params")
|
||||
else {
|
||||
panic!("codex/event/raw_response_item should have params");
|
||||
};
|
||||
|
||||
let conversation_id_value = params
|
||||
.get("conversationId")
|
||||
.and_then(|value| value.as_str())
|
||||
.expect("raw response item should include conversationId");
|
||||
let conversation_id_value = params
|
||||
.get("conversationId")
|
||||
.and_then(|value| value.as_str())
|
||||
.expect("raw response item should include conversationId");
|
||||
|
||||
assert_eq!(
|
||||
conversation_id_value,
|
||||
conversation_id.to_string(),
|
||||
"raw response item conversation mismatch"
|
||||
);
|
||||
assert_eq!(
|
||||
conversation_id_value,
|
||||
conversation_id.to_string(),
|
||||
"raw response item conversation mismatch"
|
||||
);
|
||||
|
||||
let msg_value = params
|
||||
.get("msg")
|
||||
.cloned()
|
||||
.expect("raw response item should include msg payload");
|
||||
let msg_value = params
|
||||
.get("msg")
|
||||
.cloned()
|
||||
.expect("raw response item should include msg payload");
|
||||
|
||||
let event: RawResponseItemEvent =
|
||||
serde_json::from_value(msg_value).expect("deserialize raw response item");
|
||||
event.item
|
||||
// Ghost snapshots are produced concurrently and may arrive before the model reply.
|
||||
let event: RawResponseItemEvent =
|
||||
serde_json::from_value(msg_value).expect("deserialize raw response item");
|
||||
if !matches!(event.item, ResponseItem::GhostSnapshot { .. }) {
|
||||
return event.item;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_instructions_message(item: &ResponseItem) {
|
||||
|
||||
347
codex-rs/app-server/tests/suite/v2/config_rpc.rs
Normal file
347
codex-rs/app-server/tests/suite/v2/config_rpc.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
use anyhow::Result;
|
||||
use app_test_support::McpProcess;
|
||||
use app_test_support::to_response;
|
||||
use codex_app_server_protocol::ConfigBatchWriteParams;
|
||||
use codex_app_server_protocol::ConfigEdit;
|
||||
use codex_app_server_protocol::ConfigLayerName;
|
||||
use codex_app_server_protocol::ConfigReadParams;
|
||||
use codex_app_server_protocol::ConfigReadResponse;
|
||||
use codex_app_server_protocol::ConfigValueWriteParams;
|
||||
use codex_app_server_protocol::ConfigWriteResponse;
|
||||
use codex_app_server_protocol::JSONRPCError;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::MergeStrategy;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::WriteStatus;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
fn write_config(codex_home: &TempDir, contents: &str) -> Result<()> {
|
||||
Ok(std::fs::write(
|
||||
codex_home.path().join("config.toml"),
|
||||
contents,
|
||||
)?)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn config_read_returns_effective_and_layers() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
write_config(
|
||||
&codex_home,
|
||||
r#"
|
||||
model = "gpt-user"
|
||||
sandbox_mode = "workspace-write"
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let request_id = mcp
|
||||
.send_config_read_request(ConfigReadParams {
|
||||
include_layers: true,
|
||||
})
|
||||
.await?;
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await??;
|
||||
let ConfigReadResponse {
|
||||
config,
|
||||
origins,
|
||||
layers,
|
||||
} = to_response(resp)?;
|
||||
|
||||
assert_eq!(config.get("model"), Some(&json!("gpt-user")));
|
||||
assert_eq!(
|
||||
origins.get("model").expect("origin").name,
|
||||
ConfigLayerName::User
|
||||
);
|
||||
let layers = layers.expect("layers present");
|
||||
assert_eq!(layers.len(), 2);
|
||||
assert_eq!(layers[0].name, ConfigLayerName::SessionFlags);
|
||||
assert_eq!(layers[1].name, ConfigLayerName::User);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn config_read_includes_system_layer_and_overrides() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
write_config(
|
||||
&codex_home,
|
||||
r#"
|
||||
model = "gpt-user"
|
||||
approval_policy = "on-request"
|
||||
sandbox_mode = "workspace-write"
|
||||
|
||||
[sandbox_workspace_write]
|
||||
writable_roots = ["/user"]
|
||||
network_access = true
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let managed_path = codex_home.path().join("managed_config.toml");
|
||||
std::fs::write(
|
||||
&managed_path,
|
||||
r#"
|
||||
model = "gpt-system"
|
||||
approval_policy = "never"
|
||||
|
||||
[sandbox_workspace_write]
|
||||
writable_roots = ["/system"]
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let managed_path_str = managed_path.display().to_string();
|
||||
|
||||
let mut mcp = McpProcess::new_with_env(
|
||||
codex_home.path(),
|
||||
&[("CODEX_MANAGED_CONFIG_PATH", Some(&managed_path_str))],
|
||||
)
|
||||
.await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let request_id = mcp
|
||||
.send_config_read_request(ConfigReadParams {
|
||||
include_layers: true,
|
||||
})
|
||||
.await?;
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await??;
|
||||
let ConfigReadResponse {
|
||||
config,
|
||||
origins,
|
||||
layers,
|
||||
} = to_response(resp)?;
|
||||
|
||||
assert_eq!(config.get("model"), Some(&json!("gpt-system")));
|
||||
assert_eq!(
|
||||
origins.get("model").expect("origin").name,
|
||||
ConfigLayerName::System
|
||||
);
|
||||
|
||||
assert_eq!(config.get("approval_policy"), Some(&json!("never")));
|
||||
assert_eq!(
|
||||
origins.get("approval_policy").expect("origin").name,
|
||||
ConfigLayerName::System
|
||||
);
|
||||
|
||||
assert_eq!(config.get("sandbox_mode"), Some(&json!("workspace-write")));
|
||||
assert_eq!(
|
||||
origins.get("sandbox_mode").expect("origin").name,
|
||||
ConfigLayerName::User
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
config
|
||||
.get("sandbox_workspace_write")
|
||||
.and_then(|v| v.get("writable_roots")),
|
||||
Some(&json!(["/system"]))
|
||||
);
|
||||
assert_eq!(
|
||||
origins
|
||||
.get("sandbox_workspace_write.writable_roots.0")
|
||||
.expect("origin")
|
||||
.name,
|
||||
ConfigLayerName::System
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
config
|
||||
.get("sandbox_workspace_write")
|
||||
.and_then(|v| v.get("network_access")),
|
||||
Some(&json!(true))
|
||||
);
|
||||
assert_eq!(
|
||||
origins
|
||||
.get("sandbox_workspace_write.network_access")
|
||||
.expect("origin")
|
||||
.name,
|
||||
ConfigLayerName::User
|
||||
);
|
||||
|
||||
let layers = layers.expect("layers present");
|
||||
assert_eq!(layers.len(), 3);
|
||||
assert_eq!(layers[0].name, ConfigLayerName::System);
|
||||
assert_eq!(layers[1].name, ConfigLayerName::SessionFlags);
|
||||
assert_eq!(layers[2].name, ConfigLayerName::User);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn config_value_write_replaces_value() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
write_config(
|
||||
&codex_home,
|
||||
r#"
|
||||
model = "gpt-old"
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let read_id = mcp
|
||||
.send_config_read_request(ConfigReadParams {
|
||||
include_layers: false,
|
||||
})
|
||||
.await?;
|
||||
let read_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(read_id)),
|
||||
)
|
||||
.await??;
|
||||
let read: ConfigReadResponse = to_response(read_resp)?;
|
||||
let expected_version = read.origins.get("model").map(|m| m.version.clone());
|
||||
|
||||
let write_id = mcp
|
||||
.send_config_value_write_request(ConfigValueWriteParams {
|
||||
file_path: codex_home.path().join("config.toml").display().to_string(),
|
||||
key_path: "model".to_string(),
|
||||
value: json!("gpt-new"),
|
||||
merge_strategy: MergeStrategy::Replace,
|
||||
expected_version,
|
||||
})
|
||||
.await?;
|
||||
let write_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(write_id)),
|
||||
)
|
||||
.await??;
|
||||
let write: ConfigWriteResponse = to_response(write_resp)?;
|
||||
|
||||
assert_eq!(write.status, WriteStatus::Ok);
|
||||
assert!(write.overridden_metadata.is_none());
|
||||
|
||||
let verify_id = mcp
|
||||
.send_config_read_request(ConfigReadParams {
|
||||
include_layers: false,
|
||||
})
|
||||
.await?;
|
||||
let verify_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(verify_id)),
|
||||
)
|
||||
.await??;
|
||||
let verify: ConfigReadResponse = to_response(verify_resp)?;
|
||||
assert_eq!(verify.config.get("model"), Some(&json!("gpt-new")));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn config_value_write_rejects_version_conflict() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
write_config(
|
||||
&codex_home,
|
||||
r#"
|
||||
model = "gpt-old"
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let write_id = mcp
|
||||
.send_config_value_write_request(ConfigValueWriteParams {
|
||||
file_path: codex_home.path().join("config.toml").display().to_string(),
|
||||
key_path: "model".to_string(),
|
||||
value: json!("gpt-new"),
|
||||
merge_strategy: MergeStrategy::Replace,
|
||||
expected_version: Some("sha256:stale".to_string()),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let err: JSONRPCError = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_error_message(RequestId::Integer(write_id)),
|
||||
)
|
||||
.await??;
|
||||
let code = err
|
||||
.error
|
||||
.data
|
||||
.as_ref()
|
||||
.and_then(|d| d.get("config_write_error_code"))
|
||||
.and_then(|v| v.as_str());
|
||||
assert_eq!(code, Some("configVersionConflict"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn config_batch_write_applies_multiple_edits() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
write_config(&codex_home, "")?;
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let batch_id = mcp
|
||||
.send_config_batch_write_request(ConfigBatchWriteParams {
|
||||
file_path: codex_home.path().join("config.toml").display().to_string(),
|
||||
edits: vec![
|
||||
ConfigEdit {
|
||||
key_path: "sandbox_mode".to_string(),
|
||||
value: json!("workspace-write"),
|
||||
merge_strategy: MergeStrategy::Replace,
|
||||
},
|
||||
ConfigEdit {
|
||||
key_path: "sandbox_workspace_write".to_string(),
|
||||
value: json!({
|
||||
"writable_roots": ["/tmp"],
|
||||
"network_access": false
|
||||
}),
|
||||
merge_strategy: MergeStrategy::Replace,
|
||||
},
|
||||
],
|
||||
expected_version: None,
|
||||
})
|
||||
.await?;
|
||||
let batch_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(batch_id)),
|
||||
)
|
||||
.await??;
|
||||
let batch_write: ConfigWriteResponse = to_response(batch_resp)?;
|
||||
assert_eq!(batch_write.status, WriteStatus::Ok);
|
||||
|
||||
let read_id = mcp
|
||||
.send_config_read_request(ConfigReadParams {
|
||||
include_layers: false,
|
||||
})
|
||||
.await?;
|
||||
let read_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(read_id)),
|
||||
)
|
||||
.await??;
|
||||
let read: ConfigReadResponse = to_response(read_resp)?;
|
||||
assert_eq!(
|
||||
read.config.get("sandbox_mode"),
|
||||
Some(&json!("workspace-write"))
|
||||
);
|
||||
assert_eq!(
|
||||
read.config
|
||||
.get("sandbox_workspace_write")
|
||||
.and_then(|v| v.get("writable_roots")),
|
||||
Some(&json!(["/tmp"]))
|
||||
);
|
||||
assert_eq!(
|
||||
read.config
|
||||
.get("sandbox_workspace_write")
|
||||
.and_then(|v| v.get("network_access")),
|
||||
Some(&json!(false))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
mod account;
|
||||
mod config_rpc;
|
||||
mod model_list;
|
||||
mod rate_limits;
|
||||
mod review;
|
||||
|
||||
@@ -88,10 +88,11 @@ async fn turn_interrupt_aborts_running_turn() -> Result<()> {
|
||||
// Give the command a brief moment to start.
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
|
||||
let thread_id = thread.id.clone();
|
||||
// Interrupt the in-progress turn by id (v2 API).
|
||||
let interrupt_id = mcp
|
||||
.send_turn_interrupt_request(TurnInterruptParams {
|
||||
thread_id: thread.id,
|
||||
thread_id: thread_id.clone(),
|
||||
turn_id: turn.id,
|
||||
})
|
||||
.await?;
|
||||
@@ -112,6 +113,7 @@ async fn turn_interrupt_aborts_running_turn() -> Result<()> {
|
||||
.params
|
||||
.expect("turn/completed params must be present"),
|
||||
)?;
|
||||
assert_eq!(completed.thread_id, thread_id);
|
||||
assert_eq!(completed.turn.status, TurnStatus::Interrupted);
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use app_test_support::McpProcess;
|
||||
use app_test_support::create_apply_patch_sse_response;
|
||||
use app_test_support::create_exec_command_sse_response;
|
||||
use app_test_support::create_final_assistant_message_sse_response;
|
||||
use app_test_support::create_mock_chat_completions_server;
|
||||
use app_test_support::create_mock_chat_completions_server_unchecked;
|
||||
@@ -95,6 +96,7 @@ async fn turn_start_emits_notifications_and_accepts_model_override() -> Result<(
|
||||
.await??;
|
||||
let started: TurnStartedNotification =
|
||||
serde_json::from_value(notif.params.expect("params must be present"))?;
|
||||
assert_eq!(started.thread_id, thread.id);
|
||||
assert_eq!(
|
||||
started.turn.status,
|
||||
codex_app_server_protocol::TurnStatus::InProgress
|
||||
@@ -138,6 +140,7 @@ async fn turn_start_emits_notifications_and_accepts_model_override() -> Result<(
|
||||
.params
|
||||
.expect("turn/completed params must be present"),
|
||||
)?;
|
||||
assert_eq!(completed.thread_id, thread.id);
|
||||
assert_eq!(completed.turn.status, TurnStatus::Completed);
|
||||
|
||||
Ok(())
|
||||
@@ -614,10 +617,6 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn turn_start_file_change_approval_v2() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
if cfg!(windows) {
|
||||
// TODO apply_patch approvals are not parsed from powershell commands yet
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let tmp = TempDir::new()?;
|
||||
let codex_home = tmp.path().join("codex_home");
|
||||
@@ -764,10 +763,6 @@ async fn turn_start_file_change_approval_v2() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn turn_start_file_change_approval_decline_v2() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
if cfg!(windows) {
|
||||
// TODO apply_patch approvals are not parsed from powershell commands yet
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let tmp = TempDir::new()?;
|
||||
let codex_home = tmp.path().join("codex_home");
|
||||
@@ -913,6 +908,134 @@ async fn turn_start_file_change_approval_decline_v2() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[cfg_attr(windows, ignore = "process id reporting differs on Windows")]
|
||||
async fn command_execution_notifications_include_process_id() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let responses = vec![
|
||||
create_exec_command_sse_response("uexec-1")?,
|
||||
create_final_assistant_message_sse_response("done")?,
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri(), "never")?;
|
||||
let config_toml = codex_home.path().join("config.toml");
|
||||
let mut config_contents = std::fs::read_to_string(&config_toml)?;
|
||||
config_contents.push_str(
|
||||
r#"
|
||||
[features]
|
||||
unified_exec = true
|
||||
"#,
|
||||
);
|
||||
std::fs::write(&config_toml, config_contents)?;
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let start_id = mcp
|
||||
.send_thread_start_request(ThreadStartParams {
|
||||
model: Some("mock-model".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let start_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(start_id)),
|
||||
)
|
||||
.await??;
|
||||
let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(start_resp)?;
|
||||
|
||||
let turn_id = mcp
|
||||
.send_turn_start_request(TurnStartParams {
|
||||
thread_id: thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: "run a command".to_string(),
|
||||
}],
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let turn_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(turn_id)),
|
||||
)
|
||||
.await??;
|
||||
let TurnStartResponse { turn: _turn } = to_response::<TurnStartResponse>(turn_resp)?;
|
||||
|
||||
let started_command = timeout(DEFAULT_READ_TIMEOUT, async {
|
||||
loop {
|
||||
let notif = mcp
|
||||
.read_stream_until_notification_message("item/started")
|
||||
.await?;
|
||||
let started: ItemStartedNotification = serde_json::from_value(
|
||||
notif
|
||||
.params
|
||||
.clone()
|
||||
.expect("item/started should include params"),
|
||||
)?;
|
||||
if let ThreadItem::CommandExecution { .. } = started.item {
|
||||
return Ok::<ThreadItem, anyhow::Error>(started.item);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
let ThreadItem::CommandExecution {
|
||||
id,
|
||||
process_id: started_process_id,
|
||||
status,
|
||||
..
|
||||
} = started_command
|
||||
else {
|
||||
unreachable!("loop ensures we break on command execution items");
|
||||
};
|
||||
assert_eq!(id, "uexec-1");
|
||||
assert_eq!(status, CommandExecutionStatus::InProgress);
|
||||
let started_process_id = started_process_id.expect("process id should be present");
|
||||
|
||||
let completed_command = timeout(DEFAULT_READ_TIMEOUT, async {
|
||||
loop {
|
||||
let notif = mcp
|
||||
.read_stream_until_notification_message("item/completed")
|
||||
.await?;
|
||||
let completed: ItemCompletedNotification = serde_json::from_value(
|
||||
notif
|
||||
.params
|
||||
.clone()
|
||||
.expect("item/completed should include params"),
|
||||
)?;
|
||||
if let ThreadItem::CommandExecution { .. } = completed.item {
|
||||
return Ok::<ThreadItem, anyhow::Error>(completed.item);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
let ThreadItem::CommandExecution {
|
||||
id: completed_id,
|
||||
process_id: completed_process_id,
|
||||
status: completed_status,
|
||||
exit_code,
|
||||
..
|
||||
} = completed_command
|
||||
else {
|
||||
unreachable!("loop ensures we break on command execution items");
|
||||
};
|
||||
assert_eq!(completed_id, "uexec-1");
|
||||
assert_eq!(completed_status, CommandExecutionStatus::Completed);
|
||||
assert_eq!(exit_code, Some(0));
|
||||
assert_eq!(
|
||||
completed_process_id.as_deref(),
|
||||
Some(started_process_id.as_str())
|
||||
);
|
||||
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("turn/completed"),
|
||||
)
|
||||
.await??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Helper to create a config.toml pointing at the mock model server.
|
||||
fn create_config_toml(
|
||||
codex_home: &Path,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-apply-patch"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "codex_apply_patch"
|
||||
|
||||
@@ -30,7 +30,13 @@ pub use standalone_executable::main;
|
||||
pub const APPLY_PATCH_TOOL_INSTRUCTIONS: &str = include_str!("../apply_patch_tool_instructions.md");
|
||||
|
||||
const APPLY_PATCH_COMMANDS: [&str; 2] = ["apply_patch", "applypatch"];
|
||||
const APPLY_PATCH_SHELLS: [&str; 3] = ["bash", "zsh", "sh"];
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ApplyPatchShell {
|
||||
Unix,
|
||||
PowerShell,
|
||||
Cmd,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, PartialEq)]
|
||||
pub enum ApplyPatchError {
|
||||
@@ -97,11 +103,55 @@ pub struct ApplyPatchArgs {
|
||||
pub workdir: Option<String>,
|
||||
}
|
||||
|
||||
fn shell_supports_apply_patch(shell: &str) -> bool {
|
||||
fn classify_shell_name(shell: &str) -> Option<String> {
|
||||
std::path::Path::new(shell)
|
||||
.file_name()
|
||||
.file_stem()
|
||||
.and_then(|name| name.to_str())
|
||||
.is_some_and(|name| APPLY_PATCH_SHELLS.contains(&name))
|
||||
.map(str::to_ascii_lowercase)
|
||||
}
|
||||
|
||||
fn classify_shell(shell: &str, flag: &str) -> Option<ApplyPatchShell> {
|
||||
classify_shell_name(shell).and_then(|name| match name.as_str() {
|
||||
"bash" | "zsh" | "sh" if flag == "-lc" => Some(ApplyPatchShell::Unix),
|
||||
"pwsh" | "powershell" if flag.eq_ignore_ascii_case("-command") => {
|
||||
Some(ApplyPatchShell::PowerShell)
|
||||
}
|
||||
"cmd" if flag.eq_ignore_ascii_case("/c") => Some(ApplyPatchShell::Cmd),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn can_skip_flag(shell: &str, flag: &str) -> bool {
|
||||
classify_shell_name(shell).is_some_and(|name| {
|
||||
matches!(name.as_str(), "pwsh" | "powershell") && flag.eq_ignore_ascii_case("-noprofile")
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_shell_script(argv: &[String]) -> Option<(ApplyPatchShell, &str)> {
|
||||
match argv {
|
||||
[shell, flag, script] => classify_shell(shell, flag).map(|shell_type| {
|
||||
let script = script.as_str();
|
||||
(shell_type, script)
|
||||
}),
|
||||
[shell, skip_flag, flag, script] if can_skip_flag(shell, skip_flag) => {
|
||||
classify_shell(shell, flag).map(|shell_type| {
|
||||
let script = script.as_str();
|
||||
(shell_type, script)
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_apply_patch_from_shell(
|
||||
shell: ApplyPatchShell,
|
||||
script: &str,
|
||||
) -> std::result::Result<(String, Option<String>), ExtractHeredocError> {
|
||||
match shell {
|
||||
ApplyPatchShell::Unix | ApplyPatchShell::PowerShell | ApplyPatchShell::Cmd => {
|
||||
extract_apply_patch_from_bash(script)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch {
|
||||
@@ -111,9 +161,9 @@ pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch {
|
||||
Ok(source) => MaybeApplyPatch::Body(source),
|
||||
Err(e) => MaybeApplyPatch::PatchParseError(e),
|
||||
},
|
||||
// Bash heredoc form: (optional `cd <path> &&`) apply_patch <<'EOF' ...
|
||||
[shell, flag, script] if shell_supports_apply_patch(shell) && flag == "-lc" => {
|
||||
match extract_apply_patch_from_bash(script) {
|
||||
// Shell heredoc form: (optional `cd <path> &&`) apply_patch <<'EOF' ...
|
||||
_ => match parse_shell_script(argv) {
|
||||
Some((shell, script)) => match extract_apply_patch_from_shell(shell, script) {
|
||||
Ok((body, workdir)) => match parse_patch(&body) {
|
||||
Ok(mut source) => {
|
||||
source.workdir = workdir;
|
||||
@@ -125,9 +175,9 @@ pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch {
|
||||
MaybeApplyPatch::NotApplyPatch
|
||||
}
|
||||
Err(e) => MaybeApplyPatch::ShellParseError(e),
|
||||
}
|
||||
}
|
||||
_ => MaybeApplyPatch::NotApplyPatch,
|
||||
},
|
||||
None => MaybeApplyPatch::NotApplyPatch,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,24 +272,17 @@ impl ApplyPatchAction {
|
||||
/// cwd must be an absolute path so that we can resolve relative paths in the
|
||||
/// patch.
|
||||
pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApplyPatchVerified {
|
||||
// Detect a raw patch body passed directly as the command or as the body of a bash -lc
|
||||
// Detect a raw patch body passed directly as the command or as the body of a shell
|
||||
// script. In these cases, report an explicit error rather than applying the patch.
|
||||
match argv {
|
||||
[body] => {
|
||||
if parse_patch(body).is_ok() {
|
||||
return MaybeApplyPatchVerified::CorrectnessError(
|
||||
ApplyPatchError::ImplicitInvocation,
|
||||
);
|
||||
}
|
||||
}
|
||||
[shell, flag, script]
|
||||
if shell_supports_apply_patch(shell)
|
||||
&& flag == "-lc"
|
||||
&& parse_patch(script).is_ok() =>
|
||||
{
|
||||
return MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation);
|
||||
}
|
||||
_ => {}
|
||||
if let [body] = argv
|
||||
&& parse_patch(body).is_ok()
|
||||
{
|
||||
return MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation);
|
||||
}
|
||||
if let Some((_, script)) = parse_shell_script(argv)
|
||||
&& parse_patch(script).is_ok()
|
||||
{
|
||||
return MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation);
|
||||
}
|
||||
|
||||
match maybe_parse_apply_patch(argv) {
|
||||
@@ -871,6 +914,22 @@ mod tests {
|
||||
strs_to_strings(&["bash", "-lc", script])
|
||||
}
|
||||
|
||||
fn args_powershell(script: &str) -> Vec<String> {
|
||||
strs_to_strings(&["powershell.exe", "-Command", script])
|
||||
}
|
||||
|
||||
fn args_powershell_no_profile(script: &str) -> Vec<String> {
|
||||
strs_to_strings(&["powershell.exe", "-NoProfile", "-Command", script])
|
||||
}
|
||||
|
||||
fn args_pwsh(script: &str) -> Vec<String> {
|
||||
strs_to_strings(&["pwsh", "-NoProfile", "-Command", script])
|
||||
}
|
||||
|
||||
fn args_cmd(script: &str) -> Vec<String> {
|
||||
strs_to_strings(&["cmd.exe", "/c", script])
|
||||
}
|
||||
|
||||
fn heredoc_script(prefix: &str) -> String {
|
||||
format!(
|
||||
"{prefix}apply_patch <<'PATCH'\n*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch\nPATCH"
|
||||
@@ -890,8 +949,7 @@ mod tests {
|
||||
}]
|
||||
}
|
||||
|
||||
fn assert_match(script: &str, expected_workdir: Option<&str>) {
|
||||
let args = args_bash(script);
|
||||
fn assert_match_args(args: Vec<String>, expected_workdir: Option<&str>) {
|
||||
match maybe_parse_apply_patch(&args) {
|
||||
MaybeApplyPatch::Body(ApplyPatchArgs { hunks, workdir, .. }) => {
|
||||
assert_eq!(workdir.as_deref(), expected_workdir);
|
||||
@@ -901,6 +959,11 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_match(script: &str, expected_workdir: Option<&str>) {
|
||||
let args = args_bash(script);
|
||||
assert_match_args(args, expected_workdir);
|
||||
}
|
||||
|
||||
fn assert_not_match(script: &str) {
|
||||
let args = args_bash(script);
|
||||
assert_matches!(
|
||||
@@ -1014,6 +1077,28 @@ PATCH"#,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_powershell_heredoc() {
|
||||
let script = heredoc_script("");
|
||||
assert_match_args(args_powershell(&script), None);
|
||||
}
|
||||
#[test]
|
||||
fn test_powershell_heredoc_no_profile() {
|
||||
let script = heredoc_script("");
|
||||
assert_match_args(args_powershell_no_profile(&script), None);
|
||||
}
|
||||
#[test]
|
||||
fn test_pwsh_heredoc() {
|
||||
let script = heredoc_script("");
|
||||
assert_match_args(args_pwsh(&script), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cmd_heredoc_with_cd() {
|
||||
let script = heredoc_script("cd foo && ");
|
||||
assert_match_args(args_cmd(&script), Some("foo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_heredoc_with_leading_cd() {
|
||||
assert_match(&heredoc_script("cd foo && "), Some("foo"));
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-arg0"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "codex_arg0"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition.workspace = true
|
||||
name = "codex-async-utils"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
name = "codex-backend-client"
|
||||
version = "0.0.0"
|
||||
edition = "2024"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-chatgpt"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-cli"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "codex"
|
||||
|
||||
206
codex-rs/client.md
Normal file
206
codex-rs/client.md
Normal file
@@ -0,0 +1,206 @@
|
||||
# Client Extraction Plan
|
||||
|
||||
## Goals
|
||||
- Split the HTTP transport/client code out of `codex-core` into a reusable crate that is agnostic of Codex/OpenAI business logic and API schemas.
|
||||
- Create a separate API library crate that houses typed requests/responses for well-known APIs (Responses, Chat Completions, Compact) and plugs into the transport crate via minimal traits.
|
||||
- Preserve current behaviour (auth headers, retries, SSE handling, rate-limit parsing, compaction, fixtures) while making the APIs symmetric and avoiding code duplication.
|
||||
- Keep existing consumers (`codex-core`, tests, and tools) stable by providing a small compatibility layer during the transition.
|
||||
|
||||
## Snapshot of Today
|
||||
- `core/src/client.rs (ModelClient)` owns config/auth/session state, chooses wire API, builds payloads, drives retries, parses SSE, compaction, and rate-limit headers.
|
||||
- `core/src/chat_completions.rs` implements the Chat Completions call + SSE parser + aggregation helper.
|
||||
- `core/src/client_common.rs` holds `Prompt`, tool specs, shared request structs (`ResponsesApiRequest`, `TextControls`), and `ResponseEvent`/`ResponseStream`.
|
||||
- `core/src/default_client.rs` wraps `reqwest` with Codex UA/originator defaults.
|
||||
- `core/src/model_provider_info.rs` models providers (base URL, headers, env keys, retry/timeout tuning) and builds `CodexRequestBuilder`s.
|
||||
- Current retry logic is co-located with API handling; streaming SSE parsing is duplicated across Responses/Chat.
|
||||
|
||||
## Target Crates (with interfaces)
|
||||
|
||||
- `codex-client` (generic transport)
|
||||
- Owns the generic HTTP machinery: a `CodexHttpClient`/`CodexRequestBuilder`-style wrapper, retry/backoff hooks, streaming connector (SSE framing + idle timeout), header injection, and optional telemetry callbacks.
|
||||
- Does **not** know about OpenAI/Codex-specific paths, headers, or error codes; it only exposes HTTP-level concepts (status, headers, bodies, connection errors).
|
||||
- Minimal surface:
|
||||
```rust
|
||||
pub trait HttpTransport {
|
||||
fn execute(&self, req: Request) -> Result<Response, TransportError>;
|
||||
fn stream(&self, req: Request) -> Result<ByteStream, TransportError>;
|
||||
}
|
||||
|
||||
pub struct Request {
|
||||
pub method: Method,
|
||||
pub url: String,
|
||||
pub headers: HeaderMap,
|
||||
pub body: Option<serde_json::Value>,
|
||||
pub timeout: Option<Duration>,
|
||||
}
|
||||
```
|
||||
- Generic client traits (request/response/chunk are abstract over the transport):
|
||||
```rust
|
||||
#[async_trait::async_trait]
|
||||
pub trait UnaryClient<Req, Resp> {
|
||||
async fn run(&self, req: Req) -> Result<Resp, TransportError>;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait StreamClient<Req, Chunk> {
|
||||
async fn run(&self, req: Req) -> Result<ResponseStream<Chunk>, TransportError>;
|
||||
}
|
||||
|
||||
pub struct RetryPolicy {
|
||||
pub max_attempts: u64,
|
||||
pub base_delay: Duration,
|
||||
pub retry_on: RetryOn, // e.g., transport errors + 429/5xx
|
||||
}
|
||||
```
|
||||
- `RetryOn` lives in `codex-client` and captures HTTP status classes and transport failures that qualify for retry.
|
||||
- Implementations in `codex-api` plug in their own request types, parsers, and retry policies while reusing the transport’s backoff and error types.
|
||||
- Planned runtime helper:
|
||||
```rust
|
||||
pub async fn run_with_retry<T, F, Fut>(
|
||||
policy: RetryPolicy,
|
||||
make_req: impl Fn() -> Request,
|
||||
op: F,
|
||||
) -> Result<T, TransportError>
|
||||
where
|
||||
F: Fn(Request) -> Fut,
|
||||
Fut: Future<Output = Result<T, TransportError>>,
|
||||
{
|
||||
for attempt in 0..=policy.max_attempts {
|
||||
let req = make_req();
|
||||
match op(req).await {
|
||||
Ok(resp) => return Ok(resp),
|
||||
Err(err) if policy.retry_on.should_retry(&err, attempt) => {
|
||||
tokio::time::sleep(backoff(policy.base_delay, attempt + 1)).await;
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
Err(TransportError::RetryLimit)
|
||||
}
|
||||
```
|
||||
- Unary clients wrap `transport.execute` with this helper and then deserialize.
|
||||
- Stream clients wrap the **initial** `transport.stream` call with this helper. Mid-stream disconnects are surfaced as `StreamError`s; automatic resume/reconnect can be added later on top of this primitive if we introduce cursor support.
|
||||
- Common helpers: `retry::backoff(attempt)`, `errors::{TransportError, StreamError}`.
|
||||
- Streaming utility (SSE framing only):
|
||||
```rust
|
||||
pub fn sse_stream<S>(
|
||||
bytes: S,
|
||||
idle_timeout: Duration,
|
||||
tx: mpsc::Sender<Result<String, StreamError>>,
|
||||
telemetry: Option<Box<dyn Telemetry>>,
|
||||
)
|
||||
where
|
||||
S: Stream<Item = Result<Bytes, TransportError>> + Unpin + Send + 'static;
|
||||
```
|
||||
- `sse_stream` is responsible for timeouts, connection-level errors, and emitting raw `data:` chunks as UTF-8 strings; parsing those strings into structured events is done in `codex-api`.
|
||||
|
||||
- `codex-api` (OpenAI/Codex API library)
|
||||
- Owns typed models for Responses/Chat/Compact plus shared helpers (`Prompt`, tool specs, text controls, `ResponsesApiRequest`, etc.).
|
||||
- Knows about OpenAI/Codex semantics:
|
||||
- URL shapes (`/v1/responses`, `/v1/chat/completions`, `/responses/compact`).
|
||||
- Provider configuration (`WireApi`, base URLs, query params, per-provider retry knobs).
|
||||
- Rate-limit headers (`x-codex-*`) and their mapping into `RateLimitSnapshot` / `CreditsSnapshot`.
|
||||
- Error body formats (`{ error: { type, code, message, plan_type, resets_at } }`) and how they become API errors (context window exceeded, quota/usage limit, etc.).
|
||||
- SSE event names (`response.output_item.done`, `response.completed`, `response.failed`, etc.) and their mapping into high-level events.
|
||||
- Provides a provider abstraction (conceptually similar to `ModelProviderInfo`):
|
||||
```rust
|
||||
pub struct Provider {
|
||||
pub name: String,
|
||||
pub base_url: String,
|
||||
pub wire: WireApi, // Responses | Chat
|
||||
pub headers: HeaderMap,
|
||||
pub retry: RetryConfig,
|
||||
pub stream_idle_timeout: Duration,
|
||||
}
|
||||
|
||||
pub trait AuthProvider {
|
||||
/// Returns a bearer token to use for this request (if any).
|
||||
/// Implementations are expected to be cheap and to surface already-refreshed tokens;
|
||||
/// higher layers (`codex-core`) remain responsible for token refresh flows.
|
||||
fn bearer_token(&self) -> Option<String>;
|
||||
|
||||
/// Optional ChatGPT account id header for Chat mode.
|
||||
fn account_id(&self) -> Option<String>;
|
||||
}
|
||||
```
|
||||
- Ready-made clients built on `HttpTransport`:
|
||||
```rust
|
||||
pub struct ResponsesClient<T: HttpTransport, A: AuthProvider> { /* ... */ }
|
||||
impl<T, A> ResponsesClient<T, A> {
|
||||
pub async fn stream(&self, prompt: &Prompt) -> ApiResult<ResponseStream<ApiEvent>>;
|
||||
pub async fn compact(&self, prompt: &Prompt) -> ApiResult<Vec<ResponseItem>>;
|
||||
}
|
||||
|
||||
pub struct ChatClient<T: HttpTransport, A: AuthProvider> { /* ... */ }
|
||||
impl<T, A> ChatClient<T, A> {
|
||||
pub async fn stream(&self, prompt: &Prompt) -> ApiResult<ResponseStream<ApiEvent>>;
|
||||
}
|
||||
|
||||
pub struct CompactClient<T: HttpTransport, A: AuthProvider> { /* ... */ }
|
||||
impl<T, A> CompactClient<T, A> {
|
||||
pub async fn compact(&self, prompt: &Prompt) -> ApiResult<Vec<ResponseItem>>;
|
||||
}
|
||||
```
|
||||
- Streaming events unified across wire APIs (this can closely mirror `ResponseEvent` today, and we may type-alias one to the other during migration):
|
||||
```rust
|
||||
pub enum ApiEvent {
|
||||
Created,
|
||||
OutputItemAdded(ResponseItem),
|
||||
OutputItemDone(ResponseItem),
|
||||
OutputTextDelta(String),
|
||||
ReasoningContentDelta { delta: String, content_index: i64 },
|
||||
ReasoningSummaryDelta { delta: String, summary_index: i64 },
|
||||
RateLimits(RateLimitSnapshot),
|
||||
Completed { response_id: String, token_usage: Option<TokenUsage> },
|
||||
}
|
||||
```
|
||||
- Error layering:
|
||||
- `codex-client`: defines `TransportError` / `StreamError` (status codes, IO, timeouts).
|
||||
- `codex-api`: defines `ApiError` that wraps `TransportError` plus API-specific errors parsed from bodies and headers.
|
||||
- `codex-core`: maps `ApiError` into existing `CodexErr` variants so downstream callers remain unchanged.
|
||||
- Aggregation strategies (today’s `AggregateStreamExt`) live here as adapters (`Aggregated`, `Streaming`) that transform `ResponseStream<ApiEvent>` into the higher-level views used by `codex-core`.
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
1. **Create crates**: add `codex-client` and `codex-api` (names keep the `codex-` prefix). Stub lib files with feature flags/tests wired into the workspace; wire them into `Cargo.toml`.
|
||||
2. **Extract API-level SSE + rate limits into `codex-api`**:
|
||||
- Move the Responses SSE parser (`process_sse`), rate-limit parsing, and related tests from `core/src/client.rs` into `codex-api`, keeping the behavior identical.
|
||||
- Introduce `ApiEvent` (initially equivalent to `ResponseEvent`) and `ApiError`, and adjust the parser to emit those.
|
||||
- Provide test-only helpers for fixture streams (replacement for `CODEX_RS_SSE_FIXTURE`) in `codex-api`.
|
||||
3. **Lift transport layer into `codex-client`**:
|
||||
- Move `CodexHttpClient`/`CodexRequestBuilder`, UA/originator plumbing, and backoff helpers from `core/src/default_client.rs` into `codex-client` (or a thin wrapper on top of it).
|
||||
- Introduce `HttpTransport`, `Request`, `RetryPolicy`, `RetryOn`, and `run_with_retry` as described above.
|
||||
- Keep sandbox/no-proxy toggles behind injected configuration so `codex-client` stays generic and does not depend on Codex-specific env vars.
|
||||
4. **Model provider abstraction in `codex-api`**:
|
||||
- Relocate `ModelProviderInfo` (base URL, env/header resolution, retry knobs, wire API enum) into `codex-api`, expressed in terms of `Provider` and `AuthProvider`.
|
||||
- Ensure provider logic handles:
|
||||
- URL building for Responses/Chat/Compact (including Azure special cases).
|
||||
- Static and env-based headers.
|
||||
- Per-provider retry and idle-timeout settings that map cleanly into `RetryPolicy`/`RetryOn`.
|
||||
5. **API crate wiring**:
|
||||
- Move `Prompt`, tool specs, `ResponsesApiRequest`, `TextControls`, and `ResponseEvent/ResponseStream` into `codex-api` under modules (`common`, `responses`, `chat`, `compact`), keeping public types stable or re-exported through `codex-core` as needed.
|
||||
- Rebuild Responses and Chat clients on top of `HttpTransport` + `StreamClient`, reusing shared retry + SSE helpers; keep aggregation adapters as reusable strategies instead of `ModelClient`-local logic.
|
||||
- Implement Compact on top of `UnaryClient` and the unary `execute` path with JSON deserialization, sharing the same retry policy.
|
||||
- Keep request builders symmetric: each client prepares a `Request<serde_json::Value>`, attaches headers/auth via `AuthProvider`, and plugs in its parser (streaming clients) or deserializer (unary) while sharing retry/backoff configuration derived from `Provider`.
|
||||
6. **Core integration layer**:
|
||||
- Replace `core::ModelClient` internals with thin adapters that construct `codex-api` clients using `Config`, `AuthManager`, and `OtelEventManager`.
|
||||
- Keep the public `ModelClient` API and `ResponseEvent`/`ResponseStream` types stable by re-exporting `codex-api` types or providing type aliases.
|
||||
- Preserve existing auth flows (including ChatGPT token refresh) inside `codex-core` or a thin adapter, using `AuthProvider` to surface bearer tokens to `codex-api` and handling 401/refresh semantics at this layer.
|
||||
7. **Tests/migration**:
|
||||
- Move unit tests for SSE parsing, retry/backoff decisions, and provider/header behavior into the new crates; keep integration tests in `core` using the compatibility layer.
|
||||
- Update fixtures to be consumed via test-only adapters in `codex-api`.
|
||||
- Run targeted `just fmt`, `just fix -p` for the touched crates, and scoped `cargo test -p codex-client`, `-p codex-api`, and existing `codex-core` suites.
|
||||
|
||||
## Design Decisions
|
||||
|
||||
- **UA construction**
|
||||
- `codex-client` exposes an optional UA suffix/provider hook (tiny feature) and remains unaware of the CLI; `codex-core` / the CLI compute the full UA (including `terminal::user_agent()`) and pass the suffix or builder down.
|
||||
- **Config vs provider**
|
||||
- Most configuration stays in `codex-core`. `codex-api::Provider` only contains what is strictly required for HTTP (base URLs, query params, retry/timeout knobs, wire API), while higher-level knobs (reasoning defaults, verbosity flags, etc.) remain core concerns.
|
||||
- **Auth flow ownership**
|
||||
- Auth flows (including ChatGPT token refresh) remain in `codex-core`. `AuthProvider` simply exposes already-fresh tokens/account IDs; 401 handling and refresh retries stay in the existing auth layer.
|
||||
- **Error enums**
|
||||
- `codex-client` continues to define `TransportError` / `StreamError`. `codex-api` defines an `ApiError` (deriving `thiserror::Error`) that wraps `TransportError` and API-specific failures, and `codex-core` maps `ApiError` into existing `CodexErr` variants for callers.
|
||||
- **Streaming reconnection semantics**
|
||||
- For now, mid-stream SSE failures are surfaced as errors and only the initial connection is retried via `run_with_retry`. We will revisit mid-stream reconnect/resume once the underlying APIs support cursor/idempotent event semantics.
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
name = "codex-cloud-tasks-client"
|
||||
version = { workspace = true }
|
||||
edition = "2024"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "codex_cloud_tasks_client"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-cloud-tasks"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "codex_cloud_tasks"
|
||||
|
||||
30
codex-rs/codex-api/Cargo.toml
Normal file
30
codex-rs/codex-api/Cargo.toml
Normal file
@@ -0,0 +1,30 @@
|
||||
[package]
|
||||
name = "codex-api"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
codex-client = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
http = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "sync", "time"] }
|
||||
tracing = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["codec"] }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
assert_matches = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
tokio-test = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
32
codex-rs/codex-api/README.md
Normal file
32
codex-rs/codex-api/README.md
Normal file
@@ -0,0 +1,32 @@
|
||||
# codex-api
|
||||
|
||||
Typed clients for Codex/OpenAI APIs built on top of the generic transport in `codex-client`.
|
||||
|
||||
- Hosts the request/response models and prompt helpers for Responses, Chat Completions, and Compact APIs.
|
||||
- Owns provider configuration (base URLs, headers, query params), auth header injection, retry tuning, and stream idle settings.
|
||||
- Parses SSE streams into `ResponseEvent`/`ResponseStream`, including rate-limit snapshots and API-specific error mapping.
|
||||
- Serves as the wire-level layer consumed by `codex-core`; higher layers handle auth refresh and business logic.
|
||||
|
||||
## Core interface
|
||||
|
||||
The public interface of this crate is intentionally small and uniform:
|
||||
|
||||
- **Prompted endpoints (Chat + Responses)**
|
||||
- Input: a single `Prompt` plus endpoint-specific options.
|
||||
- `Prompt` (re-exported as `codex_api::Prompt`) carries:
|
||||
- `instructions: String` – the fully-resolved system prompt for this turn.
|
||||
- `input: Vec<ResponseItem>` – conversation history and user/tool messages.
|
||||
- `tools: Vec<serde_json::Value>` – JSON tools compatible with the target API.
|
||||
- `parallel_tool_calls: bool`.
|
||||
- `output_schema: Option<Value>` – used to build `text.format` when present.
|
||||
- Output: a `ResponseStream` of `ResponseEvent` (both re-exported from `common`).
|
||||
|
||||
- **Compaction endpoint**
|
||||
- Input: `CompactionInput<'a>` (re-exported as `codex_api::CompactionInput`):
|
||||
- `model: &str`.
|
||||
- `input: &[ResponseItem]` – history to compact.
|
||||
- `instructions: &str` – fully-resolved compaction instructions.
|
||||
- Output: `Vec<ResponseItem>`.
|
||||
- `CompactClient::compact_input(&CompactionInput, extra_headers)` wraps the JSON encoding and retry/telemetry wiring.
|
||||
|
||||
All HTTP details (URLs, headers, retry/backoff policies, SSE framing) are encapsulated in `codex-api` and `codex-client`. Callers construct prompts/inputs using protocol types and work with typed streams of `ResponseEvent` or compacted `ResponseItem` values.
|
||||
27
codex-rs/codex-api/src/auth.rs
Normal file
27
codex-rs/codex-api/src/auth.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use codex_client::Request;
|
||||
|
||||
/// Provides bearer and account identity information for API requests.
|
||||
///
|
||||
/// Implementations should be cheap and non-blocking; any asynchronous
|
||||
/// refresh or I/O should be handled by higher layers before requests
|
||||
/// reach this interface.
|
||||
pub trait AuthProvider: Send + Sync {
|
||||
fn bearer_token(&self) -> Option<String>;
|
||||
fn account_id(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_auth_headers<A: AuthProvider>(auth: &A, mut req: Request) -> Request {
|
||||
if let Some(token) = auth.bearer_token()
|
||||
&& let Ok(header) = format!("Bearer {token}").parse()
|
||||
{
|
||||
let _ = req.headers.insert(http::header::AUTHORIZATION, header);
|
||||
}
|
||||
if let Some(account_id) = auth.account_id()
|
||||
&& let Ok(header) = account_id.parse()
|
||||
{
|
||||
let _ = req.headers.insert("ChatGPT-Account-ID", header);
|
||||
}
|
||||
req
|
||||
}
|
||||
167
codex-rs/codex-api/src/common.rs
Normal file
167
codex-rs/codex-api/src/common.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use crate::error::ApiError;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use codex_protocol::config_types::Verbosity as VerbosityConfig;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use futures::Stream;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Canonical prompt input for Chat and Responses endpoints.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Prompt {
|
||||
/// Fully-resolved system instructions for this turn.
|
||||
pub instructions: String,
|
||||
/// Conversation history and user/tool messages.
|
||||
pub input: Vec<ResponseItem>,
|
||||
/// JSON-encoded tool definitions compatible with the target API.
|
||||
// TODO(jif) have a proper type here
|
||||
pub tools: Vec<Value>,
|
||||
/// Whether parallel tool calls are permitted.
|
||||
pub parallel_tool_calls: bool,
|
||||
/// Optional output schema used to build the `text.format` controls.
|
||||
pub output_schema: Option<Value>,
|
||||
}
|
||||
|
||||
/// Canonical input payload for the compaction endpoint.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct CompactionInput<'a> {
|
||||
pub model: &'a str,
|
||||
pub input: &'a [ResponseItem],
|
||||
pub instructions: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ResponseEvent {
|
||||
Created,
|
||||
OutputItemDone(ResponseItem),
|
||||
OutputItemAdded(ResponseItem),
|
||||
Completed {
|
||||
response_id: String,
|
||||
token_usage: Option<TokenUsage>,
|
||||
},
|
||||
OutputTextDelta(String),
|
||||
ReasoningSummaryDelta {
|
||||
delta: String,
|
||||
summary_index: i64,
|
||||
},
|
||||
ReasoningContentDelta {
|
||||
delta: String,
|
||||
content_index: i64,
|
||||
},
|
||||
ReasoningSummaryPartAdded {
|
||||
summary_index: i64,
|
||||
},
|
||||
RateLimits(RateLimitSnapshot),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct Reasoning {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub effort: Option<ReasoningEffortConfig>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub summary: Option<ReasoningSummaryConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TextFormatType {
|
||||
#[default]
|
||||
JsonSchema,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
pub struct TextFormat {
|
||||
/// Format type used by the OpenAI text controls.
|
||||
pub r#type: TextFormatType,
|
||||
/// When true, the server is expected to strictly validate responses.
|
||||
pub strict: bool,
|
||||
/// JSON schema for the desired output.
|
||||
pub schema: Value,
|
||||
/// Friendly name for the format, used in telemetry/debugging.
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
/// Controls the `text` field for the Responses API, combining verbosity and
|
||||
/// optional JSON schema output formatting.
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
pub struct TextControls {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub verbosity: Option<OpenAiVerbosity>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub format: Option<TextFormat>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OpenAiVerbosity {
|
||||
Low,
|
||||
#[default]
|
||||
Medium,
|
||||
High,
|
||||
}
|
||||
|
||||
impl From<VerbosityConfig> for OpenAiVerbosity {
|
||||
fn from(v: VerbosityConfig) -> Self {
|
||||
match v {
|
||||
VerbosityConfig::Low => OpenAiVerbosity::Low,
|
||||
VerbosityConfig::Medium => OpenAiVerbosity::Medium,
|
||||
VerbosityConfig::High => OpenAiVerbosity::High,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ResponsesApiRequest<'a> {
|
||||
pub model: &'a str,
|
||||
pub instructions: &'a str,
|
||||
pub input: &'a [ResponseItem],
|
||||
pub tools: &'a [serde_json::Value],
|
||||
pub tool_choice: &'static str,
|
||||
pub parallel_tool_calls: bool,
|
||||
pub reasoning: Option<Reasoning>,
|
||||
pub store: bool,
|
||||
pub stream: bool,
|
||||
pub include: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_cache_key: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<TextControls>,
|
||||
}
|
||||
|
||||
pub fn create_text_param_for_request(
|
||||
verbosity: Option<VerbosityConfig>,
|
||||
output_schema: &Option<Value>,
|
||||
) -> Option<TextControls> {
|
||||
if verbosity.is_none() && output_schema.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(TextControls {
|
||||
verbosity: verbosity.map(std::convert::Into::into),
|
||||
format: output_schema.as_ref().map(|schema| TextFormat {
|
||||
r#type: TextFormatType::JsonSchema,
|
||||
strict: true,
|
||||
schema: schema.clone(),
|
||||
name: "codex_output_schema".to_string(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
pub struct ResponseStream {
|
||||
pub rx_event: mpsc::Receiver<Result<ResponseEvent, ApiError>>,
|
||||
}
|
||||
|
||||
impl Stream for ResponseStream {
|
||||
type Item = Result<ResponseEvent, ApiError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.rx_event.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
266
codex-rs/codex-api/src/endpoint/chat.rs
Normal file
266
codex-rs/codex-api/src/endpoint/chat.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
use crate::ChatRequest;
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::common::Prompt as ApiPrompt;
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::endpoint::streaming::StreamingClient;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::WireApi;
|
||||
use crate::sse::chat::spawn_chat_stream;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use futures::Stream;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
pub struct ChatClient<T: HttpTransport, A: AuthProvider> {
|
||||
streaming: StreamingClient<T, A>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
||||
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
streaming: StreamingClient::new(transport, provider, auth),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_telemetry(
|
||||
self,
|
||||
request: Option<Arc<dyn RequestTelemetry>>,
|
||||
sse: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
streaming: self.streaming.with_telemetry(request, sse),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_request(&self, request: ChatRequest) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers).await
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &ApiPrompt,
|
||||
conversation_id: Option<String>,
|
||||
session_source: Option<SessionSource>,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
use crate::requests::ChatRequestBuilder;
|
||||
|
||||
let request =
|
||||
ChatRequestBuilder::new(model, &prompt.instructions, &prompt.input, &prompt.tools)
|
||||
.conversation_id(conversation_id)
|
||||
.session_source(session_source)
|
||||
.build(self.streaming.provider())?;
|
||||
|
||||
self.stream_request(request).await
|
||||
}
|
||||
|
||||
fn path(&self) -> &'static str {
|
||||
match self.streaming.provider().wire {
|
||||
WireApi::Chat => "chat/completions",
|
||||
_ => "responses",
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.streaming
|
||||
.stream(self.path(), body, extra_headers, spawn_chat_stream)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub enum AggregateMode {
|
||||
AggregatedOnly,
|
||||
Streaming,
|
||||
}
|
||||
|
||||
/// Stream adapter that merges token deltas into a single assistant message per turn.
|
||||
pub struct AggregatedStream {
|
||||
inner: ResponseStream,
|
||||
cumulative: String,
|
||||
cumulative_reasoning: String,
|
||||
pending: VecDeque<ResponseEvent>,
|
||||
mode: AggregateMode,
|
||||
}
|
||||
|
||||
impl Stream for AggregatedStream {
|
||||
type Item = Result<ResponseEvent, ApiError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
|
||||
loop {
|
||||
match Pin::new(&mut this.inner).poll_next(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
|
||||
let is_assistant_message = matches!(
|
||||
&item,
|
||||
ResponseItem::Message { role, .. } if role == "assistant"
|
||||
);
|
||||
|
||||
if is_assistant_message {
|
||||
match this.mode {
|
||||
AggregateMode::AggregatedOnly => {
|
||||
if this.cumulative.is_empty()
|
||||
&& let ResponseItem::Message { content, .. } = &item
|
||||
&& let Some(text) = content.iter().find_map(|c| match c {
|
||||
ContentItem::OutputText { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
AggregateMode::Streaming => {
|
||||
if this.cumulative.is_empty() {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
||||
item,
|
||||
))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}))) => {
|
||||
let mut emitted_any = false;
|
||||
|
||||
if !this.cumulative_reasoning.is_empty() {
|
||||
let aggregated_reasoning = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![ReasoningItemContent::ReasoningText {
|
||||
text: std::mem::take(&mut this.cumulative_reasoning),
|
||||
}]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
if !this.cumulative.is_empty() {
|
||||
let aggregated_message = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: std::mem::take(&mut this.cumulative),
|
||||
}],
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
if emitted_any {
|
||||
this.pending.push_back(ResponseEvent::Completed {
|
||||
response_id: response_id.clone(),
|
||||
token_usage: token_usage.clone(),
|
||||
});
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
}
|
||||
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
})));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
|
||||
this.cumulative.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
}))) => {
|
||||
this.cumulative_reasoning.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
})));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => continue,
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => {
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AggregateStreamExt {
|
||||
fn aggregate(self) -> AggregatedStream;
|
||||
|
||||
fn streaming_mode(self) -> ResponseStream;
|
||||
}
|
||||
|
||||
impl AggregateStreamExt for ResponseStream {
|
||||
fn aggregate(self) -> AggregatedStream {
|
||||
AggregatedStream::new(self, AggregateMode::AggregatedOnly)
|
||||
}
|
||||
|
||||
fn streaming_mode(self) -> ResponseStream {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl AggregatedStream {
|
||||
fn new(inner: ResponseStream, mode: AggregateMode) -> Self {
|
||||
AggregatedStream {
|
||||
inner,
|
||||
cumulative: String::new(),
|
||||
cumulative_reasoning: String::new(),
|
||||
pending: VecDeque::new(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
162
codex-rs/codex-api/src/endpoint/compact.rs
Normal file
162
codex-rs/codex-api/src/endpoint/compact.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::auth::add_auth_headers;
|
||||
use crate::common::CompactionInput;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::WireApi;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
use serde::Deserialize;
|
||||
use serde_json::to_value;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct CompactClient<T: HttpTransport, A: AuthProvider> {
|
||||
transport: T,
|
||||
provider: Provider,
|
||||
auth: A,
|
||||
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
|
||||
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
provider,
|
||||
auth,
|
||||
request_telemetry: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
|
||||
self.request_telemetry = request;
|
||||
self
|
||||
}
|
||||
|
||||
fn path(&self) -> Result<&'static str, ApiError> {
|
||||
match self.provider.wire {
|
||||
WireApi::Compact | WireApi::Responses => Ok("responses/compact"),
|
||||
WireApi::Chat => Err(ApiError::Stream(
|
||||
"compact endpoint requires responses wire api".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn compact(
|
||||
&self,
|
||||
body: serde_json::Value,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<Vec<ResponseItem>, ApiError> {
|
||||
let path = self.path()?;
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::POST, path);
|
||||
req.headers.extend(extra_headers.clone());
|
||||
req.body = Some(body.clone());
|
||||
add_auth_headers(&self.auth, req)
|
||||
};
|
||||
|
||||
let resp = run_with_request_telemetry(
|
||||
self.provider.retry.to_policy(),
|
||||
self.request_telemetry.clone(),
|
||||
builder,
|
||||
|req| self.transport.execute(req),
|
||||
)
|
||||
.await?;
|
||||
let parsed: CompactHistoryResponse =
|
||||
serde_json::from_slice(&resp.body).map_err(|e| ApiError::Stream(e.to_string()))?;
|
||||
Ok(parsed.output)
|
||||
}
|
||||
|
||||
pub async fn compact_input(
|
||||
&self,
|
||||
input: &CompactionInput<'_>,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<Vec<ResponseItem>, ApiError> {
|
||||
let body = to_value(input)
|
||||
.map_err(|e| ApiError::Stream(format!("failed to encode compaction input: {e}")))?;
|
||||
self.compact(body, extra_headers).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CompactHistoryResponse {
|
||||
output: Vec<ResponseItem>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::RetryConfig;
|
||||
use async_trait::async_trait;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use http::HeaderMap;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct DummyTransport;
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for DummyTransport {
|
||||
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
|
||||
Err(TransportError::Build("execute should not run".to_string()))
|
||||
}
|
||||
|
||||
async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
|
||||
Err(TransportError::Build("stream should not run".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct DummyAuth;
|
||||
|
||||
impl AuthProvider for DummyAuth {
|
||||
fn bearer_token(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn provider(wire: WireApi) -> Provider {
|
||||
Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: "https://example.com/v1".to_string(),
|
||||
query_params: None,
|
||||
wire,
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn errors_when_wire_is_chat() {
|
||||
let client = CompactClient::new(DummyTransport, provider(WireApi::Chat), DummyAuth);
|
||||
let input = CompactionInput {
|
||||
model: "gpt-test",
|
||||
input: &[],
|
||||
instructions: "inst",
|
||||
};
|
||||
let err = client
|
||||
.compact_input(&input, HeaderMap::new())
|
||||
.await
|
||||
.expect_err("expected wire mismatch to fail");
|
||||
|
||||
match err {
|
||||
ApiError::Stream(msg) => {
|
||||
assert_eq!(msg, "compact endpoint requires responses wire api");
|
||||
}
|
||||
other => panic!("unexpected error: {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
4
codex-rs/codex-api/src/endpoint/mod.rs
Normal file
4
codex-rs/codex-api/src/endpoint/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod chat;
|
||||
pub mod compact;
|
||||
pub mod responses;
|
||||
mod streaming;
|
||||
107
codex-rs/codex-api/src/endpoint/responses.rs
Normal file
107
codex-rs/codex-api/src/endpoint/responses.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::common::Prompt as ApiPrompt;
|
||||
use crate::common::Reasoning;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::common::TextControls;
|
||||
use crate::endpoint::streaming::StreamingClient;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::WireApi;
|
||||
use crate::requests::ResponsesRequest;
|
||||
use crate::requests::ResponsesRequestBuilder;
|
||||
use crate::sse::spawn_response_stream;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ResponsesClient<T: HttpTransport, A: AuthProvider> {
|
||||
streaming: StreamingClient<T, A>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ResponsesOptions {
|
||||
pub reasoning: Option<Reasoning>,
|
||||
pub include: Vec<String>,
|
||||
pub prompt_cache_key: Option<String>,
|
||||
pub text: Option<TextControls>,
|
||||
pub store_override: Option<bool>,
|
||||
pub conversation_id: Option<String>,
|
||||
pub session_source: Option<SessionSource>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
streaming: StreamingClient::new(transport, provider, auth),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_telemetry(
|
||||
self,
|
||||
request: Option<Arc<dyn RequestTelemetry>>,
|
||||
sse: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
streaming: self.streaming.with_telemetry(request, sse),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_request(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers).await
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &ApiPrompt,
|
||||
options: ResponsesOptions,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let ResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
prompt_cache_key,
|
||||
text,
|
||||
store_override,
|
||||
conversation_id,
|
||||
session_source,
|
||||
} = options;
|
||||
|
||||
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
|
||||
.tools(&prompt.tools)
|
||||
.parallel_tool_calls(prompt.parallel_tool_calls)
|
||||
.reasoning(reasoning)
|
||||
.include(include)
|
||||
.prompt_cache_key(prompt_cache_key)
|
||||
.text(text)
|
||||
.conversation(conversation_id)
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.build(self.streaming.provider())?;
|
||||
|
||||
self.stream_request(request).await
|
||||
}
|
||||
|
||||
fn path(&self) -> &'static str {
|
||||
match self.streaming.provider().wire {
|
||||
WireApi::Responses | WireApi::Compact => "responses",
|
||||
WireApi::Chat => "chat/completions",
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.streaming
|
||||
.stream(self.path(), body, extra_headers, spawn_response_stream)
|
||||
.await
|
||||
}
|
||||
}
|
||||
82
codex-rs/codex-api/src/endpoint/streaming.rs
Normal file
82
codex-rs/codex-api/src/endpoint/streaming.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::auth::add_auth_headers;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_client::StreamResponse;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
pub(crate) struct StreamingClient<T: HttpTransport, A: AuthProvider> {
|
||||
transport: T,
|
||||
provider: Provider,
|
||||
auth: A,
|
||||
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
|
||||
sse_telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
|
||||
pub(crate) fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
provider,
|
||||
auth,
|
||||
request_telemetry: None,
|
||||
sse_telemetry: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_telemetry(
|
||||
mut self,
|
||||
request: Option<Arc<dyn RequestTelemetry>>,
|
||||
sse: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> Self {
|
||||
self.request_telemetry = request;
|
||||
self.sse_telemetry = sse;
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn provider(&self) -> &Provider {
|
||||
&self.provider
|
||||
}
|
||||
|
||||
pub(crate) async fn stream(
|
||||
&self,
|
||||
path: &str,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
spawner: fn(StreamResponse, Duration, Option<Arc<dyn SseTelemetry>>) -> ResponseStream,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::POST, path);
|
||||
req.headers.extend(extra_headers.clone());
|
||||
req.headers.insert(
|
||||
http::header::ACCEPT,
|
||||
http::HeaderValue::from_static("text/event-stream"),
|
||||
);
|
||||
req.body = Some(body.clone());
|
||||
add_auth_headers(&self.auth, req)
|
||||
};
|
||||
|
||||
let stream_response = run_with_request_telemetry(
|
||||
self.provider.retry.to_policy(),
|
||||
self.request_telemetry.clone(),
|
||||
builder,
|
||||
|req| self.transport.stream(req),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(spawner(
|
||||
stream_response,
|
||||
self.provider.stream_idle_timeout,
|
||||
self.sse_telemetry.clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
34
codex-rs/codex-api/src/error.rs
Normal file
34
codex-rs/codex-api/src/error.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use crate::rate_limits::RateLimitError;
|
||||
use codex_client::TransportError;
|
||||
use http::StatusCode;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
#[error(transparent)]
|
||||
Transport(#[from] TransportError),
|
||||
#[error("api error {status}: {message}")]
|
||||
Api { status: StatusCode, message: String },
|
||||
#[error("stream error: {0}")]
|
||||
Stream(String),
|
||||
#[error("context window exceeded")]
|
||||
ContextWindowExceeded,
|
||||
#[error("quota exceeded")]
|
||||
QuotaExceeded,
|
||||
#[error("usage not included")]
|
||||
UsageNotIncluded,
|
||||
#[error("retryable error: {message}")]
|
||||
Retryable {
|
||||
message: String,
|
||||
delay: Option<Duration>,
|
||||
},
|
||||
#[error("rate limit: {0}")]
|
||||
RateLimit(String),
|
||||
}
|
||||
|
||||
impl From<RateLimitError> for ApiError {
|
||||
fn from(err: RateLimitError) -> Self {
|
||||
Self::RateLimit(err.to_string())
|
||||
}
|
||||
}
|
||||
35
codex-rs/codex-api/src/lib.rs
Normal file
35
codex-rs/codex-api/src/lib.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
pub mod auth;
|
||||
pub mod common;
|
||||
pub mod endpoint;
|
||||
pub mod error;
|
||||
pub mod provider;
|
||||
pub mod rate_limits;
|
||||
pub mod requests;
|
||||
pub mod sse;
|
||||
pub mod telemetry;
|
||||
|
||||
pub use codex_client::RequestTelemetry;
|
||||
pub use codex_client::ReqwestTransport;
|
||||
pub use codex_client::TransportError;
|
||||
|
||||
pub use crate::auth::AuthProvider;
|
||||
pub use crate::common::CompactionInput;
|
||||
pub use crate::common::Prompt;
|
||||
pub use crate::common::ResponseEvent;
|
||||
pub use crate::common::ResponseStream;
|
||||
pub use crate::common::ResponsesApiRequest;
|
||||
pub use crate::common::create_text_param_for_request;
|
||||
pub use crate::endpoint::chat::AggregateStreamExt;
|
||||
pub use crate::endpoint::chat::ChatClient;
|
||||
pub use crate::endpoint::compact::CompactClient;
|
||||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
pub use crate::endpoint::responses::ResponsesOptions;
|
||||
pub use crate::error::ApiError;
|
||||
pub use crate::provider::Provider;
|
||||
pub use crate::provider::WireApi;
|
||||
pub use crate::requests::ChatRequest;
|
||||
pub use crate::requests::ChatRequestBuilder;
|
||||
pub use crate::requests::ResponsesRequest;
|
||||
pub use crate::requests::ResponsesRequestBuilder;
|
||||
pub use crate::sse::stream_from_fixture;
|
||||
pub use crate::telemetry::SseTelemetry;
|
||||
118
codex-rs/codex-api/src/provider.rs
Normal file
118
codex-rs/codex-api/src/provider.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use codex_client::Request;
|
||||
use codex_client::RetryOn;
|
||||
use codex_client::RetryPolicy;
|
||||
use http::Method;
|
||||
use http::header::HeaderMap;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Wire-level APIs supported by a `Provider`.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum WireApi {
|
||||
Responses,
|
||||
Chat,
|
||||
Compact,
|
||||
}
|
||||
|
||||
/// High-level retry configuration for a provider.
|
||||
///
|
||||
/// This is converted into a `RetryPolicy` used by `codex-client` to drive
|
||||
/// transport-level retries for both unary and streaming calls.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryConfig {
|
||||
pub max_attempts: u64,
|
||||
pub base_delay: Duration,
|
||||
pub retry_429: bool,
|
||||
pub retry_5xx: bool,
|
||||
pub retry_transport: bool,
|
||||
}
|
||||
|
||||
impl RetryConfig {
|
||||
pub fn to_policy(&self) -> RetryPolicy {
|
||||
RetryPolicy {
|
||||
max_attempts: self.max_attempts,
|
||||
base_delay: self.base_delay,
|
||||
retry_on: RetryOn {
|
||||
retry_429: self.retry_429,
|
||||
retry_5xx: self.retry_5xx,
|
||||
retry_transport: self.retry_transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HTTP endpoint configuration used to talk to a concrete API deployment.
|
||||
///
|
||||
/// Encapsulates base URL, default headers, query params, retry policy, and
|
||||
/// stream idle timeout, plus helper methods for building requests.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Provider {
|
||||
pub name: String,
|
||||
pub base_url: String,
|
||||
pub query_params: Option<HashMap<String, String>>,
|
||||
pub wire: WireApi,
|
||||
pub headers: HeaderMap,
|
||||
pub retry: RetryConfig,
|
||||
pub stream_idle_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
pub fn url_for_path(&self, path: &str) -> String {
|
||||
let base = self.base_url.trim_end_matches('/');
|
||||
let path = path.trim_start_matches('/');
|
||||
let mut url = if path.is_empty() {
|
||||
base.to_string()
|
||||
} else {
|
||||
format!("{base}/{path}")
|
||||
};
|
||||
|
||||
if let Some(params) = &self.query_params
|
||||
&& !params.is_empty()
|
||||
{
|
||||
let qs = params
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{k}={v}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
url.push('?');
|
||||
url.push_str(&qs);
|
||||
}
|
||||
|
||||
url
|
||||
}
|
||||
|
||||
pub fn build_request(&self, method: Method, path: &str) -> Request {
|
||||
Request {
|
||||
method,
|
||||
url: self.url_for_path(path),
|
||||
headers: self.headers.clone(),
|
||||
body: None,
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_azure_responses_endpoint(&self) -> bool {
|
||||
if self.wire != WireApi::Responses {
|
||||
return false;
|
||||
}
|
||||
|
||||
if self.name.eq_ignore_ascii_case("azure") {
|
||||
return true;
|
||||
}
|
||||
|
||||
self.base_url.to_ascii_lowercase().contains("openai.azure.")
|
||||
|| matches_azure_responses_base_url(&self.base_url)
|
||||
}
|
||||
}
|
||||
|
||||
fn matches_azure_responses_base_url(base_url: &str) -> bool {
|
||||
const AZURE_MARKERS: [&str; 5] = [
|
||||
"cognitiveservices.azure.",
|
||||
"aoai.azure.",
|
||||
"azure-api.",
|
||||
"azurefd.",
|
||||
"windows.net/openai",
|
||||
];
|
||||
let base = base_url.to_ascii_lowercase();
|
||||
AZURE_MARKERS.iter().any(|marker| base.contains(marker))
|
||||
}
|
||||
105
codex-rs/codex-api/src/rate_limits.rs
Normal file
105
codex-rs/codex-api/src/rate_limits.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use codex_protocol::protocol::CreditsSnapshot;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use codex_protocol::protocol::RateLimitWindow;
|
||||
use http::HeaderMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RateLimitError {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl Display for RateLimitError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses the bespoke Codex rate-limit headers into a `RateLimitSnapshot`.
|
||||
pub fn parse_rate_limit(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
let primary = parse_rate_limit_window(
|
||||
headers,
|
||||
"x-codex-primary-used-percent",
|
||||
"x-codex-primary-window-minutes",
|
||||
"x-codex-primary-reset-at",
|
||||
);
|
||||
|
||||
let secondary = parse_rate_limit_window(
|
||||
headers,
|
||||
"x-codex-secondary-used-percent",
|
||||
"x-codex-secondary-window-minutes",
|
||||
"x-codex-secondary-reset-at",
|
||||
);
|
||||
|
||||
let credits = parse_credits_snapshot(headers);
|
||||
|
||||
Some(RateLimitSnapshot {
|
||||
primary,
|
||||
secondary,
|
||||
credits,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_rate_limit_window(
|
||||
headers: &HeaderMap,
|
||||
used_percent_header: &str,
|
||||
window_minutes_header: &str,
|
||||
resets_at_header: &str,
|
||||
) -> Option<RateLimitWindow> {
|
||||
let used_percent: Option<f64> = parse_header_f64(headers, used_percent_header);
|
||||
|
||||
used_percent.and_then(|used_percent| {
|
||||
let window_minutes = parse_header_i64(headers, window_minutes_header);
|
||||
let resets_at = parse_header_i64(headers, resets_at_header);
|
||||
|
||||
let has_data = used_percent != 0.0
|
||||
|| window_minutes.is_some_and(|minutes| minutes != 0)
|
||||
|| resets_at.is_some();
|
||||
|
||||
has_data.then_some(RateLimitWindow {
|
||||
used_percent,
|
||||
window_minutes,
|
||||
resets_at,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_credits_snapshot(headers: &HeaderMap) -> Option<CreditsSnapshot> {
|
||||
let has_credits = parse_header_bool(headers, "x-codex-credits-has-credits")?;
|
||||
let unlimited = parse_header_bool(headers, "x-codex-credits-unlimited")?;
|
||||
let balance = parse_header_str(headers, "x-codex-credits-balance")
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(std::string::ToString::to_string);
|
||||
Some(CreditsSnapshot {
|
||||
has_credits,
|
||||
unlimited,
|
||||
balance,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
|
||||
parse_header_str(headers, name)?
|
||||
.parse::<f64>()
|
||||
.ok()
|
||||
.filter(|v| v.is_finite())
|
||||
}
|
||||
|
||||
fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
|
||||
parse_header_str(headers, name)?.parse::<i64>().ok()
|
||||
}
|
||||
|
||||
fn parse_header_bool(headers: &HeaderMap, name: &str) -> Option<bool> {
|
||||
let raw = parse_header_str(headers, name)?;
|
||||
if raw.eq_ignore_ascii_case("true") || raw == "1" {
|
||||
Some(true)
|
||||
} else if raw.eq_ignore_ascii_case("false") || raw == "0" {
|
||||
Some(false)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
|
||||
headers.get(name)?.to_str().ok()
|
||||
}
|
||||
388
codex-rs/codex-api/src/requests/chat.rs
Normal file
388
codex-rs/codex-api/src/requests/chat.rs
Normal file
@@ -0,0 +1,388 @@
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::requests::headers::build_conversation_headers;
|
||||
use crate::requests::headers::insert_header;
|
||||
use crate::requests::headers::subagent_header;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Assembled request body plus headers for Chat Completions streaming calls.
|
||||
pub struct ChatRequest {
|
||||
pub body: Value,
|
||||
pub headers: HeaderMap,
|
||||
}
|
||||
|
||||
pub struct ChatRequestBuilder<'a> {
|
||||
model: &'a str,
|
||||
instructions: &'a str,
|
||||
input: &'a [ResponseItem],
|
||||
tools: &'a [Value],
|
||||
conversation_id: Option<String>,
|
||||
session_source: Option<SessionSource>,
|
||||
}
|
||||
|
||||
impl<'a> ChatRequestBuilder<'a> {
|
||||
pub fn new(
|
||||
model: &'a str,
|
||||
instructions: &'a str,
|
||||
input: &'a [ResponseItem],
|
||||
tools: &'a [Value],
|
||||
) -> Self {
|
||||
Self {
|
||||
model,
|
||||
instructions,
|
||||
input,
|
||||
tools,
|
||||
conversation_id: None,
|
||||
session_source: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conversation_id(mut self, id: Option<String>) -> Self {
|
||||
self.conversation_id = id;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn session_source(mut self, source: Option<SessionSource>) -> Self {
|
||||
self.session_source = source;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self, _provider: &Provider) -> Result<ChatRequest, ApiError> {
|
||||
let mut messages = Vec::<Value>::new();
|
||||
messages.push(json!({"role": "system", "content": self.instructions}));
|
||||
|
||||
let input = self.input;
|
||||
let mut reasoning_by_anchor_index: HashMap<usize, String> = HashMap::new();
|
||||
let mut last_emitted_role: Option<&str> = None;
|
||||
for item in input {
|
||||
match item {
|
||||
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
|
||||
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
|
||||
last_emitted_role = Some("assistant")
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
ResponseItem::CustomToolCall { .. } => {}
|
||||
ResponseItem::CustomToolCallOutput { .. } => {}
|
||||
ResponseItem::WebSearchCall { .. } => {}
|
||||
ResponseItem::GhostSnapshot { .. } => {}
|
||||
ResponseItem::CompactionSummary { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_user_index: Option<usize> = None;
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
if let ResponseItem::Message { role, .. } = item
|
||||
&& role == "user"
|
||||
{
|
||||
last_user_index = Some(idx);
|
||||
}
|
||||
}
|
||||
|
||||
if !matches!(last_emitted_role, Some("user")) {
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
if let Some(u_idx) = last_user_index
|
||||
&& idx <= u_idx
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let ResponseItem::Reasoning {
|
||||
content: Some(items),
|
||||
..
|
||||
} = item
|
||||
{
|
||||
let mut text = String::new();
|
||||
for entry in items {
|
||||
match entry {
|
||||
ReasoningItemContent::ReasoningText { text: segment }
|
||||
| ReasoningItemContent::Text { text: segment } => {
|
||||
text.push_str(segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
if text.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut attached = false;
|
||||
if idx > 0
|
||||
&& let ResponseItem::Message { role, .. } = &input[idx - 1]
|
||||
&& role == "assistant"
|
||||
{
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx - 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
attached = true;
|
||||
}
|
||||
|
||||
if !attached && idx + 1 < input.len() {
|
||||
match &input[idx + 1] {
|
||||
ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::LocalShellCall { .. } => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
ResponseItem::Message { role, .. } if role == "assistant" => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_assistant_text: Option<String> = None;
|
||||
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
let mut text = String::new();
|
||||
let mut items: Vec<Value> = Vec::new();
|
||||
let mut saw_image = false;
|
||||
|
||||
for c in content {
|
||||
match c {
|
||||
ContentItem::InputText { text: t }
|
||||
| ContentItem::OutputText { text: t } => {
|
||||
text.push_str(t);
|
||||
items.push(json!({"type":"text","text": t}));
|
||||
}
|
||||
ContentItem::InputImage { image_url } => {
|
||||
saw_image = true;
|
||||
items.push(
|
||||
json!({"type":"image_url","image_url": {"url": image_url}}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if role == "assistant" {
|
||||
if let Some(prev) = &last_assistant_text
|
||||
&& prev == &text
|
||||
{
|
||||
continue;
|
||||
}
|
||||
last_assistant_text = Some(text.clone());
|
||||
}
|
||||
|
||||
let content_value = if role == "assistant" {
|
||||
json!(text)
|
||||
} else if saw_image {
|
||||
json!(items)
|
||||
} else {
|
||||
json!(text)
|
||||
};
|
||||
|
||||
let mut msg = json!({"role": role, "content": content_value});
|
||||
if role == "assistant"
|
||||
&& let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCall {
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
}
|
||||
}]
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
call_id: _,
|
||||
status,
|
||||
action,
|
||||
} => {
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": id.clone().unwrap_or_default(),
|
||||
"type": "local_shell_call",
|
||||
"status": status,
|
||||
"action": action,
|
||||
}]
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { call_id, output } => {
|
||||
let content_value = if let Some(items) = &output.content_items {
|
||||
let mapped: Vec<Value> = items
|
||||
.iter()
|
||||
.map(|it| match it {
|
||||
FunctionCallOutputContentItem::InputText { text } => {
|
||||
json!({"type":"text","text": text})
|
||||
}
|
||||
FunctionCallOutputContentItem::InputImage { image_url } => {
|
||||
json!({"type":"image_url","image_url": {"url": image_url}})
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
json!(mapped)
|
||||
} else {
|
||||
json!(output.content)
|
||||
};
|
||||
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": content_value,
|
||||
}));
|
||||
}
|
||||
ResponseItem::CustomToolCall {
|
||||
id,
|
||||
call_id: _,
|
||||
name,
|
||||
input,
|
||||
status: _,
|
||||
} => {
|
||||
messages.push(json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": id,
|
||||
"type": "custom",
|
||||
"custom": {
|
||||
"name": name,
|
||||
"input": input,
|
||||
}
|
||||
}]
|
||||
}));
|
||||
}
|
||||
ResponseItem::CustomToolCallOutput { call_id, output } => {
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": output,
|
||||
}));
|
||||
}
|
||||
ResponseItem::GhostSnapshot { .. } => {
|
||||
continue;
|
||||
}
|
||||
ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::CompactionSummary { .. } => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let payload = json!({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
"tools": self.tools,
|
||||
});
|
||||
|
||||
let mut headers = build_conversation_headers(self.conversation_id);
|
||||
if let Some(subagent) = subagent_header(&self.session_source) {
|
||||
insert_header(&mut headers, "x-openai-subagent", &subagent);
|
||||
}
|
||||
|
||||
Ok(ChatRequest {
|
||||
body: payload,
|
||||
headers,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::RetryConfig;
|
||||
use crate::provider::WireApi;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use http::HeaderValue;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::time::Duration;
|
||||
|
||||
fn provider() -> Provider {
|
||||
Provider {
|
||||
name: "openai".to_string(),
|
||||
base_url: "https://api.openai.com/v1".to_string(),
|
||||
query_params: None,
|
||||
wire: WireApi::Chat,
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(10),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn attaches_conversation_and_subagent_headers() {
|
||||
let prompt_input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hi".to_string(),
|
||||
}],
|
||||
}];
|
||||
let req = ChatRequestBuilder::new("gpt-test", "inst", &prompt_input, &[])
|
||||
.conversation_id(Some("conv-1".into()))
|
||||
.session_source(Some(SessionSource::SubAgent(SubAgentSource::Review)))
|
||||
.build(&provider())
|
||||
.expect("request");
|
||||
|
||||
assert_eq!(
|
||||
req.headers.get("conversation_id"),
|
||||
Some(&HeaderValue::from_static("conv-1"))
|
||||
);
|
||||
assert_eq!(
|
||||
req.headers.get("session_id"),
|
||||
Some(&HeaderValue::from_static("conv-1"))
|
||||
);
|
||||
assert_eq!(
|
||||
req.headers.get("x-openai-subagent"),
|
||||
Some(&HeaderValue::from_static("review"))
|
||||
);
|
||||
}
|
||||
}
|
||||
36
codex-rs/codex-api/src/requests/headers.rs
Normal file
36
codex-rs/codex-api/src/requests/headers.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
|
||||
pub(crate) fn build_conversation_headers(conversation_id: Option<String>) -> HeaderMap {
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(id) = conversation_id {
|
||||
insert_header(&mut headers, "conversation_id", &id);
|
||||
insert_header(&mut headers, "session_id", &id);
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
pub(crate) fn subagent_header(source: &Option<SessionSource>) -> Option<String> {
|
||||
let SessionSource::SubAgent(sub) = source.as_ref()? else {
|
||||
return None;
|
||||
};
|
||||
match sub {
|
||||
codex_protocol::protocol::SubAgentSource::Other(label) => Some(label.clone()),
|
||||
other => Some(
|
||||
serde_json::to_value(other)
|
||||
.ok()
|
||||
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
|
||||
.unwrap_or_else(|| "other".to_string()),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert_header(headers: &mut HeaderMap, name: &str, value: &str) {
|
||||
if let (Ok(header_name), Ok(header_value)) = (
|
||||
name.parse::<http::HeaderName>(),
|
||||
HeaderValue::from_str(value),
|
||||
) {
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
8
codex-rs/codex-api/src/requests/mod.rs
Normal file
8
codex-rs/codex-api/src/requests/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod chat;
|
||||
pub(crate) mod headers;
|
||||
pub mod responses;
|
||||
|
||||
pub use chat::ChatRequest;
|
||||
pub use chat::ChatRequestBuilder;
|
||||
pub use responses::ResponsesRequest;
|
||||
pub use responses::ResponsesRequestBuilder;
|
||||
247
codex-rs/codex-api/src/requests/responses.rs
Normal file
247
codex-rs/codex-api/src/requests/responses.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
use crate::common::Reasoning;
|
||||
use crate::common::ResponsesApiRequest;
|
||||
use crate::common::TextControls;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::requests::headers::build_conversation_headers;
|
||||
use crate::requests::headers::insert_header;
|
||||
use crate::requests::headers::subagent_header;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
|
||||
/// Assembled request body plus headers for a Responses stream request.
|
||||
pub struct ResponsesRequest {
|
||||
pub body: Value,
|
||||
pub headers: HeaderMap,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ResponsesRequestBuilder<'a> {
|
||||
model: Option<&'a str>,
|
||||
instructions: Option<&'a str>,
|
||||
input: Option<&'a [ResponseItem]>,
|
||||
tools: Option<&'a [Value]>,
|
||||
parallel_tool_calls: bool,
|
||||
reasoning: Option<Reasoning>,
|
||||
include: Vec<String>,
|
||||
prompt_cache_key: Option<String>,
|
||||
text: Option<TextControls>,
|
||||
conversation_id: Option<String>,
|
||||
session_source: Option<SessionSource>,
|
||||
store_override: Option<bool>,
|
||||
headers: HeaderMap,
|
||||
}
|
||||
|
||||
impl<'a> ResponsesRequestBuilder<'a> {
|
||||
pub fn new(model: &'a str, instructions: &'a str, input: &'a [ResponseItem]) -> Self {
|
||||
Self {
|
||||
model: Some(model),
|
||||
instructions: Some(instructions),
|
||||
input: Some(input),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tools(mut self, tools: &'a [Value]) -> Self {
|
||||
self.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn parallel_tool_calls(mut self, enabled: bool) -> Self {
|
||||
self.parallel_tool_calls = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn reasoning(mut self, reasoning: Option<Reasoning>) -> Self {
|
||||
self.reasoning = reasoning;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn include(mut self, include: Vec<String>) -> Self {
|
||||
self.include = include;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn prompt_cache_key(mut self, key: Option<String>) -> Self {
|
||||
self.prompt_cache_key = key;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn text(mut self, text: Option<TextControls>) -> Self {
|
||||
self.text = text;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn conversation(mut self, conversation_id: Option<String>) -> Self {
|
||||
self.conversation_id = conversation_id;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn session_source(mut self, source: Option<SessionSource>) -> Self {
|
||||
self.session_source = source;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn store_override(mut self, store: Option<bool>) -> Self {
|
||||
self.store_override = store;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn extra_headers(mut self, headers: HeaderMap) -> Self {
|
||||
self.headers = headers;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self, provider: &Provider) -> Result<ResponsesRequest, ApiError> {
|
||||
let model = self
|
||||
.model
|
||||
.ok_or_else(|| ApiError::Stream("missing model for responses request".into()))?;
|
||||
let instructions = self
|
||||
.instructions
|
||||
.ok_or_else(|| ApiError::Stream("missing instructions for responses request".into()))?;
|
||||
let input = self
|
||||
.input
|
||||
.ok_or_else(|| ApiError::Stream("missing input for responses request".into()))?;
|
||||
let tools = self.tools.unwrap_or_default();
|
||||
|
||||
let store = self
|
||||
.store_override
|
||||
.unwrap_or_else(|| provider.is_azure_responses_endpoint());
|
||||
|
||||
let req = ResponsesApiRequest {
|
||||
model,
|
||||
instructions,
|
||||
input,
|
||||
tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: self.parallel_tool_calls,
|
||||
reasoning: self.reasoning,
|
||||
store,
|
||||
stream: true,
|
||||
include: self.include,
|
||||
prompt_cache_key: self.prompt_cache_key,
|
||||
text: self.text,
|
||||
};
|
||||
|
||||
let mut body = serde_json::to_value(&req)
|
||||
.map_err(|e| ApiError::Stream(format!("failed to encode responses request: {e}")))?;
|
||||
|
||||
if store && provider.is_azure_responses_endpoint() {
|
||||
attach_item_ids(&mut body, input);
|
||||
}
|
||||
|
||||
let mut headers = self.headers;
|
||||
headers.extend(build_conversation_headers(self.conversation_id));
|
||||
if let Some(subagent) = subagent_header(&self.session_source) {
|
||||
insert_header(&mut headers, "x-openai-subagent", &subagent);
|
||||
}
|
||||
|
||||
Ok(ResponsesRequest { body, headers })
|
||||
}
|
||||
}
|
||||
|
||||
fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) {
|
||||
let Some(input_value) = payload_json.get_mut("input") else {
|
||||
return;
|
||||
};
|
||||
let Value::Array(items) = input_value else {
|
||||
return;
|
||||
};
|
||||
|
||||
for (value, item) in items.iter_mut().zip(original_items.iter()) {
|
||||
if let ResponseItem::Reasoning { id, .. }
|
||||
| ResponseItem::Message { id: Some(id), .. }
|
||||
| ResponseItem::WebSearchCall { id: Some(id), .. }
|
||||
| ResponseItem::FunctionCall { id: Some(id), .. }
|
||||
| ResponseItem::LocalShellCall { id: Some(id), .. }
|
||||
| ResponseItem::CustomToolCall { id: Some(id), .. } = item
|
||||
{
|
||||
if id.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(obj) = value.as_object_mut() {
|
||||
obj.insert("id".to_string(), Value::String(id.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::RetryConfig;
|
||||
use crate::provider::WireApi;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use http::HeaderValue;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::time::Duration;
|
||||
|
||||
fn provider(name: &str, base_url: &str) -> Provider {
|
||||
Provider {
|
||||
name: name.to_string(),
|
||||
base_url: base_url.to_string(),
|
||||
query_params: None,
|
||||
wire: WireApi::Responses,
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(50),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn azure_default_store_attaches_ids_and_headers() {
|
||||
let provider = provider("azure", "https://example.openai.azure.com/v1");
|
||||
let input = vec![
|
||||
ResponseItem::Message {
|
||||
id: Some("m1".into()),
|
||||
role: "assistant".into(),
|
||||
content: Vec::new(),
|
||||
},
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".into(),
|
||||
content: Vec::new(),
|
||||
},
|
||||
];
|
||||
|
||||
let request = ResponsesRequestBuilder::new("gpt-test", "inst", &input)
|
||||
.conversation(Some("conv-1".into()))
|
||||
.session_source(Some(SessionSource::SubAgent(SubAgentSource::Review)))
|
||||
.build(&provider)
|
||||
.expect("request");
|
||||
|
||||
assert_eq!(request.body.get("store"), Some(&Value::Bool(true)));
|
||||
|
||||
let ids: Vec<Option<String>> = request
|
||||
.body
|
||||
.get("input")
|
||||
.and_then(|v| v.as_array())
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|item| item.get("id").and_then(|v| v.as_str().map(str::to_string)))
|
||||
.collect();
|
||||
assert_eq!(ids, vec![Some("m1".to_string()), None]);
|
||||
|
||||
assert_eq!(
|
||||
request.headers.get("conversation_id"),
|
||||
Some(&HeaderValue::from_static("conv-1"))
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("session_id"),
|
||||
Some(&HeaderValue::from_static("conv-1"))
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("x-openai-subagent"),
|
||||
Some(&HeaderValue::from_static("review"))
|
||||
);
|
||||
}
|
||||
}
|
||||
504
codex-rs/codex-api/src/sse/chat.rs
Normal file
504
codex-rs/codex-api/src/sse/chat.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::error::ApiError;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::Stream;
|
||||
use futures::StreamExt;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::timeout;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
|
||||
pub(crate) fn spawn_chat_stream(
|
||||
stream_response: StreamResponse,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
|
||||
) -> ResponseStream {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(async move {
|
||||
process_chat_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
|
||||
});
|
||||
ResponseStream { rx_event }
|
||||
}
|
||||
|
||||
pub async fn process_chat_sse<S>(
|
||||
stream: S,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
|
||||
) where
|
||||
S: Stream<Item = Result<bytes::Bytes, codex_client::TransportError>> + Unpin,
|
||||
{
|
||||
let mut stream = stream.eventsource();
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct ToolCallState {
|
||||
name: Option<String>,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
let mut tool_calls: HashMap<String, ToolCallState> = HashMap::new();
|
||||
let mut tool_call_order: Vec<String> = Vec::new();
|
||||
let mut assistant_item: Option<ResponseItem> = None;
|
||||
let mut reasoning_item: Option<ResponseItem> = None;
|
||||
let mut completed_sent = false;
|
||||
|
||||
loop {
|
||||
let start = Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
if let Some(t) = telemetry.as_ref() {
|
||||
t.on_sse_poll(&response, start.elapsed());
|
||||
}
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(sse))) => sse,
|
||||
Ok(Some(Err(e))) => {
|
||||
let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await;
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
if let Some(reasoning) = reasoning_item {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(assistant) = assistant_item {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
|
||||
.await;
|
||||
}
|
||||
if !completed_sent {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = tx_event
|
||||
.send(Err(ApiError::Stream("idle timeout waiting for SSE".into())))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
trace!("SSE event: {}", sse.data);
|
||||
|
||||
if sse.data.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let value: serde_json::Value = match serde_json::from_str(&sse.data) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
debug!(
|
||||
"Failed to parse ChatCompletions SSE event: {err}, data: {}",
|
||||
&sse.data
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let Some(choices) = value.get("choices").and_then(|c| c.as_array()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for choice in choices {
|
||||
if let Some(delta) = choice.get("delta") {
|
||||
if let Some(reasoning) = delta.get("reasoning") {
|
||||
if let Some(text) = reasoning.as_str() {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
|
||||
.await;
|
||||
} else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
|
||||
.await;
|
||||
} else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = delta.get("content") {
|
||||
if content.is_array() {
|
||||
for item in content.as_array().unwrap_or(&vec![]) {
|
||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||
append_assistant_text(
|
||||
&tx_event,
|
||||
&mut assistant_item,
|
||||
text.to_string(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
} else if let Some(text) = content.as_str() {
|
||||
append_assistant_text(&tx_event, &mut assistant_item, text.to_string())
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tool_call_values) = delta.get("tool_calls").and_then(|c| c.as_array()) {
|
||||
for tool_call in tool_call_values {
|
||||
let id = tool_call
|
||||
.get("id")
|
||||
.and_then(|i| i.as_str())
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("tool-call-{}", tool_call_order.len()));
|
||||
|
||||
let call_state = tool_calls.entry(id.clone()).or_default();
|
||||
if !tool_call_order.contains(&id) {
|
||||
tool_call_order.push(id.clone());
|
||||
}
|
||||
|
||||
if let Some(func) = tool_call.get("function") {
|
||||
if let Some(fname) = func.get("name").and_then(|n| n.as_str()) {
|
||||
call_state.name = Some(fname.to_string());
|
||||
}
|
||||
if let Some(arguments) = func.get("arguments").and_then(|a| a.as_str())
|
||||
{
|
||||
call_state.arguments.push_str(arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(message) = choice.get("message")
|
||||
&& let Some(reasoning) = message.get("reasoning")
|
||||
{
|
||||
if let Some(text) = reasoning.as_str() {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
|
||||
} else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
|
||||
} else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
|
||||
}
|
||||
}
|
||||
|
||||
let finish_reason = choice.get("finish_reason").and_then(|r| r.as_str());
|
||||
if finish_reason == Some("stop") {
|
||||
if let Some(reasoning) = reasoning_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(assistant) = assistant_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
|
||||
.await;
|
||||
}
|
||||
if !completed_sent {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}))
|
||||
.await;
|
||||
completed_sent = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if finish_reason == Some("length") {
|
||||
let _ = tx_event.send(Err(ApiError::ContextWindowExceeded)).await;
|
||||
return;
|
||||
}
|
||||
|
||||
if finish_reason == Some("tool_calls") {
|
||||
if let Some(reasoning) = reasoning_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
|
||||
.await;
|
||||
}
|
||||
|
||||
for call_id in tool_call_order.drain(..) {
|
||||
let state = tool_calls.remove(&call_id).unwrap_or_default();
|
||||
let item = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: state.name.unwrap_or_default(),
|
||||
arguments: state.arguments,
|
||||
call_id: call_id.clone(),
|
||||
};
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_assistant_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
assistant_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if assistant_item.is_none() {
|
||||
let item = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
};
|
||||
*assistant_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Message { content, .. }) = assistant_item {
|
||||
content.push(ContentItem::OutputText { text: text.clone() });
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_reasoning_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
reasoning_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if reasoning_item.is_none() {
|
||||
let item = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
*reasoning_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Reasoning {
|
||||
content: Some(content),
|
||||
..
|
||||
}) = reasoning_item
|
||||
{
|
||||
let content_index = content.len() as i64;
|
||||
content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
|
||||
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta: text.clone(),
|
||||
content_index,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use futures::TryStreamExt;
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::io::ReaderStream;
|
||||
|
||||
fn build_body(events: &[serde_json::Value]) -> String {
|
||||
let mut body = String::new();
|
||||
for e in events {
|
||||
body.push_str(&format!("event: message\ndata: {e}\n\n"));
|
||||
}
|
||||
body
|
||||
}
|
||||
|
||||
async fn collect_events(body: &str) -> Vec<ResponseEvent> {
|
||||
let reader = ReaderStream::new(std::io::Cursor::new(body.to_string()))
|
||||
.map_err(|err| codex_client::TransportError::Network(err.to_string()));
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(16);
|
||||
tokio::spawn(process_chat_sse(
|
||||
reader,
|
||||
tx,
|
||||
Duration::from_millis(1000),
|
||||
None,
|
||||
));
|
||||
|
||||
let mut out = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
out.push(ev.expect("stream error"));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn emits_multiple_tool_calls() {
|
||||
let delta_a = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "name": "do_a", "arguments": "{\"foo\":1}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let delta_b = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_b",
|
||||
"function": { "name": "do_b", "arguments": "{\"bar\":2}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let finish = json!({
|
||||
"choices": [{
|
||||
"finish_reason": "tool_calls"
|
||||
}]
|
||||
});
|
||||
|
||||
let body = build_body(&[delta_a, delta_b, finish]);
|
||||
let events = collect_events(&body).await;
|
||||
assert_eq!(events.len(), 3);
|
||||
|
||||
assert_matches!(
|
||||
&events[0],
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
|
||||
if call_id == "call_a" && name == "do_a" && arguments == "{\"foo\":1}"
|
||||
);
|
||||
assert_matches!(
|
||||
&events[1],
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
|
||||
if call_id == "call_b" && name == "do_b" && arguments == "{\"bar\":2}"
|
||||
);
|
||||
assert_matches!(events[2], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concatenates_tool_call_arguments_across_deltas() {
|
||||
let delta_name = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "name": "do_a" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let delta_args_1 = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "arguments": "{ \"foo\":" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let delta_args_2 = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "arguments": "1}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let finish = json!({
|
||||
"choices": [{
|
||||
"finish_reason": "tool_calls"
|
||||
}]
|
||||
});
|
||||
|
||||
let body = build_body(&[delta_name, delta_args_1, delta_args_2, finish]);
|
||||
let events = collect_events(&body).await;
|
||||
assert_matches!(
|
||||
&events[..],
|
||||
[
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }),
|
||||
ResponseEvent::Completed { .. }
|
||||
] if call_id == "call_a" && name == "do_a" && arguments == "{ \"foo\":1}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn emits_tool_calls_even_when_content_and_reasoning_present() {
|
||||
let delta_content_and_tools = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"content": [{"text": "hi"}],
|
||||
"reasoning": "because",
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "name": "do_a", "arguments": "{}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let finish = json!({
|
||||
"choices": [{
|
||||
"finish_reason": "tool_calls"
|
||||
}]
|
||||
});
|
||||
|
||||
let body = build_body(&[delta_content_and_tools, finish]);
|
||||
let events = collect_events(&body).await;
|
||||
|
||||
assert_matches!(
|
||||
&events[..],
|
||||
[
|
||||
ResponseEvent::OutputItemAdded(ResponseItem::Reasoning { .. }),
|
||||
ResponseEvent::ReasoningContentDelta { .. },
|
||||
ResponseEvent::OutputItemAdded(ResponseItem::Message { .. }),
|
||||
ResponseEvent::OutputTextDelta(delta),
|
||||
ResponseEvent::OutputItemDone(ResponseItem::Reasoning { .. }),
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, .. }),
|
||||
ResponseEvent::OutputItemDone(ResponseItem::Message { .. }),
|
||||
ResponseEvent::Completed { .. }
|
||||
] if delta == "hi" && call_id == "call_a" && name == "do_a"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn drops_partial_tool_calls_on_stop_finish_reason() {
|
||||
let delta_tool = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "name": "do_a", "arguments": "{}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let finish_stop = json!({
|
||||
"choices": [{
|
||||
"finish_reason": "stop"
|
||||
}]
|
||||
});
|
||||
|
||||
let body = build_body(&[delta_tool, finish_stop]);
|
||||
let events = collect_events(&body).await;
|
||||
|
||||
assert!(!events.iter().any(|ev| {
|
||||
matches!(
|
||||
ev,
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
|
||||
)
|
||||
}));
|
||||
assert_matches!(events.last(), Some(ResponseEvent::Completed { .. }));
|
||||
}
|
||||
}
|
||||
6
codex-rs/codex-api/src/sse/mod.rs
Normal file
6
codex-rs/codex-api/src/sse/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod chat;
|
||||
pub mod responses;
|
||||
|
||||
pub use responses::process_sse;
|
||||
pub use responses::spawn_response_stream;
|
||||
pub use responses::stream_from_fixture;
|
||||
672
codex-rs/codex-api/src/sse/responses.rs
Normal file
672
codex-rs/codex-api/src/sse/responses.rs
Normal file
@@ -0,0 +1,672 @@
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::error::ApiError;
|
||||
use crate::rate_limits::parse_rate_limit;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::ByteStream;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::StreamExt;
|
||||
use futures::TryStreamExt;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
|
||||
/// Streams SSE events from an on-disk fixture for tests.
|
||||
pub fn stream_from_fixture(
|
||||
path: impl AsRef<Path>,
|
||||
idle_timeout: Duration,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let file =
|
||||
std::fs::File::open(path.as_ref()).map_err(|err| ApiError::Stream(err.to_string()))?;
|
||||
let mut content = String::new();
|
||||
for line in std::io::BufReader::new(file).lines() {
|
||||
let line = line.map_err(|err| ApiError::Stream(err.to_string()))?;
|
||||
content.push_str(&line);
|
||||
content.push_str("\n\n");
|
||||
}
|
||||
|
||||
let reader = std::io::Cursor::new(content);
|
||||
let stream = ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string()));
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(process_sse(Box::pin(stream), tx_event, idle_timeout, None));
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
|
||||
pub fn spawn_response_stream(
|
||||
stream_response: StreamResponse,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> ResponseStream {
|
||||
let rate_limits = parse_rate_limit(&stream_response.headers);
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(async move {
|
||||
if let Some(snapshot) = rate_limits {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
|
||||
}
|
||||
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
|
||||
});
|
||||
|
||||
ResponseStream { rx_event }
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct Error {
|
||||
r#type: Option<String>,
|
||||
code: Option<String>,
|
||||
message: Option<String>,
|
||||
plan_type: Option<String>,
|
||||
resets_at: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct ResponseCompleted {
|
||||
id: String,
|
||||
#[serde(default)]
|
||||
usage: Option<ResponseCompletedUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompletedUsage {
|
||||
input_tokens: i64,
|
||||
input_tokens_details: Option<ResponseCompletedInputTokensDetails>,
|
||||
output_tokens: i64,
|
||||
output_tokens_details: Option<ResponseCompletedOutputTokensDetails>,
|
||||
total_tokens: i64,
|
||||
}
|
||||
|
||||
impl From<ResponseCompletedUsage> for TokenUsage {
|
||||
fn from(val: ResponseCompletedUsage) -> Self {
|
||||
TokenUsage {
|
||||
input_tokens: val.input_tokens,
|
||||
cached_input_tokens: val
|
||||
.input_tokens_details
|
||||
.map(|d| d.cached_tokens)
|
||||
.unwrap_or(0),
|
||||
output_tokens: val.output_tokens,
|
||||
reasoning_output_tokens: val
|
||||
.output_tokens_details
|
||||
.map(|d| d.reasoning_tokens)
|
||||
.unwrap_or(0),
|
||||
total_tokens: val.total_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompletedInputTokensDetails {
|
||||
cached_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompletedOutputTokensDetails {
|
||||
reasoning_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct SseEvent {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
response: Option<Value>,
|
||||
item: Option<Value>,
|
||||
delta: Option<String>,
|
||||
summary_index: Option<i64>,
|
||||
content_index: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn process_sse(
|
||||
stream: ByteStream,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
) {
|
||||
let mut stream = stream.eventsource();
|
||||
let mut response_completed: Option<ResponseCompleted> = None;
|
||||
let mut response_error: Option<ApiError> = None;
|
||||
|
||||
loop {
|
||||
let start = Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
if let Some(t) = telemetry.as_ref() {
|
||||
t.on_sse_poll(&response, start.elapsed());
|
||||
}
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(sse))) => sse,
|
||||
Ok(Some(Err(e))) => {
|
||||
debug!("SSE Error: {e:#}");
|
||||
let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await;
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
match response_completed.take() {
|
||||
Some(ResponseCompleted { id, usage }) => {
|
||||
let event = ResponseEvent::Completed {
|
||||
response_id: id,
|
||||
token_usage: usage.map(Into::into),
|
||||
};
|
||||
let _ = tx_event.send(Ok(event)).await;
|
||||
}
|
||||
None => {
|
||||
let error = response_error.unwrap_or(ApiError::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
));
|
||||
let _ = tx_event.send(Err(error)).await;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = tx_event
|
||||
.send(Err(ApiError::Stream("idle timeout waiting for SSE".into())))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let raw = sse.data.clone();
|
||||
trace!("SSE event: {raw}");
|
||||
|
||||
let event: SseEvent = match serde_json::from_str(&sse.data) {
|
||||
Ok(event) => event,
|
||||
Err(e) => {
|
||||
debug!("Failed to parse SSE event: {e}, data: {}", &sse.data);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match event.kind.as_str() {
|
||||
"response.output_item.done" => {
|
||||
let Some(item_val) = event.item else { continue };
|
||||
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
|
||||
debug!("failed to parse ResponseItem from output_item.done");
|
||||
continue;
|
||||
};
|
||||
|
||||
let event = ResponseEvent::OutputItemDone(item);
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
"response.output_text.delta" => {
|
||||
if let Some(delta) = event.delta {
|
||||
let event = ResponseEvent::OutputTextDelta(delta);
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.reasoning_summary_text.delta" => {
|
||||
if let (Some(delta), Some(summary_index)) = (event.delta, event.summary_index) {
|
||||
let event = ResponseEvent::ReasoningSummaryDelta {
|
||||
delta,
|
||||
summary_index,
|
||||
};
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.reasoning_text.delta" => {
|
||||
if let (Some(delta), Some(content_index)) = (event.delta, event.content_index) {
|
||||
let event = ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
};
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.created" => {
|
||||
if event.response.is_some() {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
|
||||
}
|
||||
}
|
||||
"response.failed" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
response_error =
|
||||
Some(ApiError::Stream("response.failed event received".into()));
|
||||
|
||||
if let Some(error) = resp_val.get("error")
|
||||
&& let Ok(error) = serde_json::from_value::<Error>(error.clone())
|
||||
{
|
||||
if is_context_window_error(&error) {
|
||||
response_error = Some(ApiError::ContextWindowExceeded);
|
||||
} else if is_quota_exceeded_error(&error) {
|
||||
response_error = Some(ApiError::QuotaExceeded);
|
||||
} else if is_usage_not_included(&error) {
|
||||
response_error = Some(ApiError::UsageNotIncluded);
|
||||
} else {
|
||||
let delay = try_parse_retry_after(&error);
|
||||
let message = error.message.clone().unwrap_or_default();
|
||||
response_error = Some(ApiError::Retryable { message, delay });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.completed" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
match serde_json::from_value::<ResponseCompleted>(resp_val) {
|
||||
Ok(r) => {
|
||||
response_completed = Some(r);
|
||||
}
|
||||
Err(e) => {
|
||||
let error = format!("failed to parse ResponseCompleted: {e}");
|
||||
debug!(error);
|
||||
response_error = Some(ApiError::Stream(error));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
"response.output_item.added" => {
|
||||
let Some(item_val) = event.item else { continue };
|
||||
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
|
||||
debug!("failed to parse ResponseItem from output_item.done");
|
||||
continue;
|
||||
};
|
||||
|
||||
let event = ResponseEvent::OutputItemAdded(item);
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
"response.reasoning_summary_part.added" => {
|
||||
if let Some(summary_index) = event.summary_index {
|
||||
let event = ResponseEvent::ReasoningSummaryPartAdded { summary_index };
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_parse_retry_after(err: &Error) -> Option<Duration> {
|
||||
if err.code.as_deref() != Some("rate_limit_exceeded") {
|
||||
return None;
|
||||
}
|
||||
|
||||
let re = rate_limit_regex();
|
||||
if let Some(message) = &err.message
|
||||
&& let Some(captures) = re.captures(message)
|
||||
{
|
||||
let seconds = captures.get(1);
|
||||
let unit = captures.get(2);
|
||||
|
||||
if let (Some(value), Some(unit)) = (seconds, unit) {
|
||||
let value = value.as_str().parse::<f64>().ok()?;
|
||||
let unit = unit.as_str().to_ascii_lowercase();
|
||||
|
||||
if unit == "s" || unit.starts_with("second") {
|
||||
return Some(Duration::from_secs_f64(value));
|
||||
} else if unit == "ms" {
|
||||
return Some(Duration::from_millis(value as u64));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn is_context_window_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("context_length_exceeded")
|
||||
}
|
||||
|
||||
fn is_quota_exceeded_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("insufficient_quota")
|
||||
}
|
||||
|
||||
fn is_usage_not_included(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("usage_not_included")
|
||||
}
|
||||
|
||||
fn rate_limit_regex() -> &'static regex_lite::Regex {
|
||||
static RE: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
|
||||
#[expect(clippy::unwrap_used)]
|
||||
RE.get_or_init(|| {
|
||||
regex_lite::Regex::new(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)").unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_test::io::Builder as IoBuilder;
|
||||
|
||||
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent, ApiError>> {
|
||||
let mut builder = IoBuilder::new();
|
||||
for chunk in chunks {
|
||||
builder.read(chunk);
|
||||
}
|
||||
|
||||
let reader = builder.build();
|
||||
let stream =
|
||||
ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string()));
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(16);
|
||||
tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None));
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
events.push(ev);
|
||||
}
|
||||
events
|
||||
}
|
||||
|
||||
async fn run_sse(events: Vec<serde_json::Value>) -> Vec<ResponseEvent> {
|
||||
let mut body = String::new();
|
||||
for e in events {
|
||||
let kind = e
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.expect("fixture event missing type");
|
||||
if e.as_object().map(|o| o.len() == 1).unwrap_or(false) {
|
||||
body.push_str(&format!("event: {kind}\n\n"));
|
||||
} else {
|
||||
body.push_str(&format!("event: {kind}\ndata: {e}\n\n"));
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(8);
|
||||
let stream = ReaderStream::new(std::io::Cursor::new(body))
|
||||
.map_err(|err| TransportError::Network(err.to_string()));
|
||||
tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None));
|
||||
|
||||
let mut out = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
out.push(ev.expect("channel closed"));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn idle_timeout() -> Duration {
|
||||
Duration::from_millis(1000)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parses_items_and_completed() {
|
||||
let item1 = json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "Hello"}]
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let item2 = json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "World"}]
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let completed = json!({
|
||||
"type": "response.completed",
|
||||
"response": { "id": "resp1" }
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
|
||||
let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n");
|
||||
let sse3 = format!("event: response.completed\ndata: {completed}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 3);
|
||||
|
||||
assert_matches!(
|
||||
&events[0],
|
||||
Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
|
||||
if role == "assistant"
|
||||
);
|
||||
|
||||
assert_matches!(
|
||||
&events[1],
|
||||
Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
|
||||
if role == "assistant"
|
||||
);
|
||||
|
||||
match &events[2] {
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}) => {
|
||||
assert_eq!(response_id, "resp1");
|
||||
assert!(token_usage.is_none());
|
||||
}
|
||||
other => panic!("unexpected third event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_when_missing_completed() {
|
||||
let item1 = json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "Hello"}]
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 2);
|
||||
|
||||
assert_matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
|
||||
|
||||
match &events[1] {
|
||||
Err(ApiError::Stream(msg)) => {
|
||||
assert_eq!(msg, "stream closed before response.completed")
|
||||
}
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_when_error_event() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Err(ApiError::Retryable { message, delay }) => {
|
||||
assert_eq!(
|
||||
message,
|
||||
"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."
|
||||
);
|
||||
assert_eq!(*delay, Some(Duration::from_secs_f64(11.054)));
|
||||
}
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn context_window_error_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
assert_matches!(events[0], Err(ApiError::ContextWindowExceeded));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn context_window_error_with_newline_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
assert_matches!(events[0], Err(ApiError::ContextWindowExceeded));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_exceeded_error_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_quota","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"insufficient_quota","message":"You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors."},"incomplete_details":null}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
assert_matches!(events[0], Err(ApiError::QuotaExceeded));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn table_driven_event_kinds() {
|
||||
struct TestCase {
|
||||
name: &'static str,
|
||||
event: serde_json::Value,
|
||||
expect_first: fn(&ResponseEvent) -> bool,
|
||||
expected_len: usize,
|
||||
}
|
||||
|
||||
fn is_created(ev: &ResponseEvent) -> bool {
|
||||
matches!(ev, ResponseEvent::Created)
|
||||
}
|
||||
fn is_output(ev: &ResponseEvent) -> bool {
|
||||
matches!(ev, ResponseEvent::OutputItemDone(_))
|
||||
}
|
||||
fn is_completed(ev: &ResponseEvent) -> bool {
|
||||
matches!(ev, ResponseEvent::Completed { .. })
|
||||
}
|
||||
|
||||
let completed = json!({
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": "c",
|
||||
"usage": {
|
||||
"input_tokens": 0,
|
||||
"input_tokens_details": null,
|
||||
"output_tokens": 0,
|
||||
"output_tokens_details": null,
|
||||
"total_tokens": 0
|
||||
},
|
||||
"output": []
|
||||
}
|
||||
});
|
||||
|
||||
let cases = vec![
|
||||
TestCase {
|
||||
name: "created",
|
||||
event: json!({"type": "response.created", "response": {}}),
|
||||
expect_first: is_created,
|
||||
expected_len: 2,
|
||||
},
|
||||
TestCase {
|
||||
name: "output_item.done",
|
||||
event: json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "output_text", "text": "hi"}
|
||||
]
|
||||
}
|
||||
}),
|
||||
expect_first: is_output,
|
||||
expected_len: 2,
|
||||
},
|
||||
TestCase {
|
||||
name: "unknown",
|
||||
event: json!({"type": "response.new_tool_event"}),
|
||||
expect_first: is_completed,
|
||||
expected_len: 1,
|
||||
},
|
||||
];
|
||||
|
||||
for case in cases {
|
||||
let mut evs = vec![case.event];
|
||||
evs.push(completed.clone());
|
||||
|
||||
let out = run_sse(evs).await;
|
||||
assert_eq!(out.len(), case.expected_len, "case {}", case.name);
|
||||
assert!(
|
||||
(case.expect_first)(&out[0]),
|
||||
"first event mismatch in case {}",
|
||||
case.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after() {
|
||||
let err = Error {
|
||||
r#type: None,
|
||||
message: Some("Rate limit reached for gpt-5.1 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
|
||||
code: Some("rate_limit_exceeded".to_string()),
|
||||
plan_type: None,
|
||||
resets_at: None,
|
||||
};
|
||||
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_millis(28)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after_no_delay() {
|
||||
let err = Error {
|
||||
r#type: None,
|
||||
message: Some("Rate limit reached for gpt-5.1 in organization <ORG> on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
|
||||
code: Some("rate_limit_exceeded".to_string()),
|
||||
plan_type: None,
|
||||
resets_at: None,
|
||||
};
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after_azure() {
|
||||
let err = Error {
|
||||
r#type: None,
|
||||
message: Some("Rate limit exceeded. Try again in 35 seconds.".to_string()),
|
||||
code: Some("rate_limit_exceeded".to_string()),
|
||||
plan_type: None,
|
||||
resets_at: None,
|
||||
};
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_secs(35)));
|
||||
}
|
||||
}
|
||||
84
codex-rs/codex-api/src/telemetry.rs
Normal file
84
codex-rs/codex-api/src/telemetry.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use codex_client::Request;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_client::Response;
|
||||
use codex_client::RetryPolicy;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use codex_client::run_with_retry;
|
||||
use http::StatusCode;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
|
||||
/// Generic telemetry.
|
||||
pub trait SseTelemetry: Send + Sync {
|
||||
fn on_sse_poll(
|
||||
&self,
|
||||
result: &Result<
|
||||
Option<
|
||||
Result<
|
||||
eventsource_stream::Event,
|
||||
eventsource_stream::EventStreamError<TransportError>,
|
||||
>,
|
||||
>,
|
||||
tokio::time::error::Elapsed,
|
||||
>,
|
||||
duration: Duration,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) trait WithStatus {
|
||||
fn status(&self) -> StatusCode;
|
||||
}
|
||||
|
||||
fn http_status(err: &TransportError) -> Option<StatusCode> {
|
||||
match err {
|
||||
TransportError::Http { status, .. } => Some(*status),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl WithStatus for Response {
|
||||
fn status(&self) -> StatusCode {
|
||||
self.status
|
||||
}
|
||||
}
|
||||
|
||||
impl WithStatus for StreamResponse {
|
||||
fn status(&self) -> StatusCode {
|
||||
self.status
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn run_with_request_telemetry<T, F, Fut>(
|
||||
policy: RetryPolicy,
|
||||
telemetry: Option<Arc<dyn RequestTelemetry>>,
|
||||
make_request: impl FnMut() -> Request,
|
||||
send: F,
|
||||
) -> Result<T, TransportError>
|
||||
where
|
||||
T: WithStatus,
|
||||
F: Clone + Fn(Request) -> Fut,
|
||||
Fut: Future<Output = Result<T, TransportError>>,
|
||||
{
|
||||
// Wraps `run_with_retry` to attach per-attempt request telemetry for both
|
||||
// unary and streaming HTTP calls.
|
||||
run_with_retry(policy, make_request, move |req, attempt| {
|
||||
let telemetry = telemetry.clone();
|
||||
let send = send.clone();
|
||||
async move {
|
||||
let start = Instant::now();
|
||||
let result = send(req).await;
|
||||
if let Some(t) = telemetry.as_ref() {
|
||||
let (status, err) = match &result {
|
||||
Ok(resp) => (Some(resp.status()), None),
|
||||
Err(err) => (http_status(err), Some(err)),
|
||||
};
|
||||
t.on_request(attempt, status, err, start.elapsed());
|
||||
}
|
||||
result
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
315
codex-rs/codex-api/tests/clients.rs
Normal file
315
codex-rs/codex-api/tests/clients.rs
Normal file
@@ -0,0 +1,315 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use codex_api::AuthProvider;
|
||||
use codex_api::ChatClient;
|
||||
use codex_api::Provider;
|
||||
use codex_api::ResponsesClient;
|
||||
use codex_api::ResponsesOptions;
|
||||
use codex_api::WireApi;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use http::HeaderMap;
|
||||
use http::StatusCode;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
|
||||
fn assert_path_ends_with(requests: &[Request], suffix: &str) {
|
||||
assert_eq!(requests.len(), 1);
|
||||
let url = &requests[0].url;
|
||||
assert!(
|
||||
url.ends_with(suffix),
|
||||
"expected url to end with {suffix}, got {url}"
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
struct RecordingState {
|
||||
stream_requests: Arc<Mutex<Vec<Request>>>,
|
||||
}
|
||||
|
||||
impl RecordingState {
|
||||
fn record(&self, req: Request) {
|
||||
let mut guard = self
|
||||
.stream_requests
|
||||
.lock()
|
||||
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"));
|
||||
guard.push(req);
|
||||
}
|
||||
|
||||
fn take_stream_requests(&self) -> Vec<Request> {
|
||||
let mut guard = self
|
||||
.stream_requests
|
||||
.lock()
|
||||
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"));
|
||||
std::mem::take(&mut *guard)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct RecordingTransport {
|
||||
state: RecordingState,
|
||||
}
|
||||
|
||||
impl RecordingTransport {
|
||||
fn new(state: RecordingState) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for RecordingTransport {
|
||||
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
|
||||
Err(TransportError::Build("execute should not run".to_string()))
|
||||
}
|
||||
|
||||
async fn stream(&self, req: Request) -> Result<StreamResponse, TransportError> {
|
||||
self.state.record(req);
|
||||
|
||||
let stream = futures::stream::iter(Vec::<Result<Bytes, TransportError>>::new());
|
||||
Ok(StreamResponse {
|
||||
status: StatusCode::OK,
|
||||
headers: HeaderMap::new(),
|
||||
bytes: Box::pin(stream),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct NoAuth;
|
||||
|
||||
impl AuthProvider for NoAuth {
|
||||
fn bearer_token(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StaticAuth {
|
||||
token: String,
|
||||
account_id: String,
|
||||
}
|
||||
|
||||
impl StaticAuth {
|
||||
fn new(token: &str, account_id: &str) -> Self {
|
||||
Self {
|
||||
token: token.to_string(),
|
||||
account_id: account_id.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthProvider for StaticAuth {
|
||||
fn bearer_token(&self) -> Option<String> {
|
||||
Some(self.token.clone())
|
||||
}
|
||||
|
||||
fn account_id(&self) -> Option<String> {
|
||||
Some(self.account_id.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn provider(name: &str, wire: WireApi) -> Provider {
|
||||
Provider {
|
||||
name: name.to_string(),
|
||||
base_url: "https://example.com/v1".to_string(),
|
||||
query_params: None,
|
||||
wire,
|
||||
headers: HeaderMap::new(),
|
||||
retry: codex_api::provider::RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_millis(10),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FlakyTransport {
|
||||
state: Arc<Mutex<i64>>,
|
||||
}
|
||||
|
||||
impl Default for FlakyTransport {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl FlakyTransport {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
state: Arc::new(Mutex::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
fn attempts(&self) -> i64 {
|
||||
*self
|
||||
.state
|
||||
.lock()
|
||||
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for FlakyTransport {
|
||||
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
|
||||
Err(TransportError::Build("execute should not run".to_string()))
|
||||
}
|
||||
|
||||
async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
|
||||
let mut attempts = self
|
||||
.state
|
||||
.lock()
|
||||
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"));
|
||||
*attempts += 1;
|
||||
|
||||
if *attempts == 1 {
|
||||
return Err(TransportError::Network("first attempt fails".to_string()));
|
||||
}
|
||||
|
||||
let stream = futures::stream::iter(vec![Ok(Bytes::from(
|
||||
r#"event: message
|
||||
data: {"id":"resp-1","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"hi"}]}]}
|
||||
|
||||
"#,
|
||||
))]);
|
||||
|
||||
Ok(StreamResponse {
|
||||
status: StatusCode::OK,
|
||||
headers: HeaderMap::new(),
|
||||
bytes: Box::pin(stream),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_client_uses_chat_completions_path_for_chat_wire() -> Result<()> {
|
||||
let state = RecordingState::default();
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ChatClient::new(transport, provider("openai", WireApi::Chat), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/chat/completions");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_client_uses_responses_path_for_responses_wire() -> Result<()> {
|
||||
let state = RecordingState::default();
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ChatClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/responses");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()> {
|
||||
let state = RecordingState::default();
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/responses");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> {
|
||||
let state = RecordingState::default();
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Chat), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/chat/completions");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn streaming_client_adds_auth_headers() -> Result<()> {
|
||||
let state = RecordingState::default();
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let auth = StaticAuth::new("secret-token", "acct-1");
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), auth);
|
||||
|
||||
let body = serde_json::json!({ "model": "gpt-test" });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_eq!(requests.len(), 1);
|
||||
let req = &requests[0];
|
||||
|
||||
let auth_header = req.headers.get(http::header::AUTHORIZATION);
|
||||
assert!(auth_header.is_some(), "missing auth header");
|
||||
assert_eq!(
|
||||
auth_header.unwrap().to_str().ok(),
|
||||
Some("Bearer secret-token")
|
||||
);
|
||||
|
||||
let account_header = req.headers.get("ChatGPT-Account-ID");
|
||||
assert!(account_header.is_some(), "missing account header");
|
||||
assert_eq!(account_header.unwrap().to_str().ok(), Some("acct-1"));
|
||||
|
||||
let accept_header = req.headers.get(http::header::ACCEPT);
|
||||
assert!(accept_header.is_some(), "missing Accept header");
|
||||
assert_eq!(
|
||||
accept_header.unwrap().to_str().ok(),
|
||||
Some("text/event-stream")
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn streaming_client_retries_on_transport_error() -> Result<()> {
|
||||
let transport = FlakyTransport::new();
|
||||
|
||||
let mut provider = provider("openai", WireApi::Responses);
|
||||
provider.retry.max_attempts = 2;
|
||||
|
||||
let client = ResponsesClient::new(transport.clone(), provider, NoAuth);
|
||||
|
||||
let prompt = codex_api::Prompt {
|
||||
instructions: "Say hi".to_string(),
|
||||
input: vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hi".to_string(),
|
||||
}],
|
||||
}],
|
||||
tools: Vec::<Value>::new(),
|
||||
parallel_tool_calls: false,
|
||||
output_schema: None,
|
||||
};
|
||||
|
||||
let options = ResponsesOptions::default();
|
||||
|
||||
let _stream = client.stream_prompt("gpt-test", &prompt, options).await?;
|
||||
assert_eq!(transport.attempts(), 2);
|
||||
Ok(())
|
||||
}
|
||||
229
codex-rs/codex-api/tests/sse_end_to_end.rs
Normal file
229
codex-rs/codex-api/tests/sse_end_to_end.rs
Normal file
@@ -0,0 +1,229 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use codex_api::AggregateStreamExt;
|
||||
use codex_api::AuthProvider;
|
||||
use codex_api::Provider;
|
||||
use codex_api::ResponseEvent;
|
||||
use codex_api::ResponsesClient;
|
||||
use codex_api::WireApi;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::StatusCode;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FixtureSseTransport {
|
||||
body: String,
|
||||
}
|
||||
|
||||
impl FixtureSseTransport {
|
||||
fn new(body: String) -> Self {
|
||||
Self { body }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for FixtureSseTransport {
|
||||
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
|
||||
Err(TransportError::Build("execute should not run".to_string()))
|
||||
}
|
||||
|
||||
async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
|
||||
let stream = futures::stream::iter(vec![Ok::<Bytes, TransportError>(Bytes::from(
|
||||
self.body.clone(),
|
||||
))]);
|
||||
Ok(StreamResponse {
|
||||
status: StatusCode::OK,
|
||||
headers: HeaderMap::new(),
|
||||
bytes: Box::pin(stream),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct NoAuth;
|
||||
|
||||
impl AuthProvider for NoAuth {
|
||||
fn bearer_token(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn provider(name: &str, wire: WireApi) -> Provider {
|
||||
Provider {
|
||||
name: name.to_string(),
|
||||
base_url: "https://example.com/v1".to_string(),
|
||||
query_params: None,
|
||||
wire,
|
||||
headers: HeaderMap::new(),
|
||||
retry: codex_api::provider::RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: false,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_millis(50),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_responses_body(events: Vec<Value>) -> String {
|
||||
let mut body = String::new();
|
||||
for e in events {
|
||||
let kind = e
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_else(|| panic!("fixture event missing type in SSE fixture: {e}"));
|
||||
if e.as_object().map(|o| o.len() == 1).unwrap_or(false) {
|
||||
body.push_str(&format!("event: {kind}\n\n"));
|
||||
} else {
|
||||
body.push_str(&format!("event: {kind}\ndata: {e}\n\n"));
|
||||
}
|
||||
}
|
||||
body
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()> {
|
||||
let item1 = serde_json::json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "Hello"}]
|
||||
}
|
||||
});
|
||||
|
||||
let item2 = serde_json::json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "World"}]
|
||||
}
|
||||
});
|
||||
|
||||
let completed = serde_json::json!({
|
||||
"type": "response.completed",
|
||||
"response": { "id": "resp1" }
|
||||
});
|
||||
|
||||
let body = build_responses_body(vec![item1, item2, completed]);
|
||||
let transport = FixtureSseTransport::new(body);
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let mut stream = client
|
||||
.stream(serde_json::json!({"echo": true}), HeaderMap::new())
|
||||
.await?;
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(ev) = stream.next().await {
|
||||
events.push(ev?);
|
||||
}
|
||||
|
||||
let events: Vec<ResponseEvent> = events
|
||||
.into_iter()
|
||||
.filter(|ev| !matches!(ev, ResponseEvent::RateLimits(_)))
|
||||
.collect();
|
||||
|
||||
assert_eq!(events.len(), 3);
|
||||
|
||||
match &events[0] {
|
||||
ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }) => {
|
||||
assert_eq!(role, "assistant");
|
||||
}
|
||||
other => panic!("unexpected first event: {other:?}"),
|
||||
}
|
||||
|
||||
match &events[1] {
|
||||
ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }) => {
|
||||
assert_eq!(role, "assistant");
|
||||
}
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
}
|
||||
|
||||
match &events[2] {
|
||||
ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
} => {
|
||||
assert_eq!(response_id, "resp1");
|
||||
assert!(token_usage.is_none());
|
||||
}
|
||||
other => panic!("unexpected third event: {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_stream_aggregates_output_text_deltas() -> Result<()> {
|
||||
let delta1 = serde_json::json!({
|
||||
"type": "response.output_text.delta",
|
||||
"delta": "Hello, "
|
||||
});
|
||||
|
||||
let delta2 = serde_json::json!({
|
||||
"type": "response.output_text.delta",
|
||||
"delta": "world"
|
||||
});
|
||||
|
||||
let completed = serde_json::json!({
|
||||
"type": "response.completed",
|
||||
"response": { "id": "resp-agg" }
|
||||
});
|
||||
|
||||
let body = build_responses_body(vec![delta1, delta2, completed]);
|
||||
let transport = FixtureSseTransport::new(body);
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let stream = client
|
||||
.stream(serde_json::json!({"echo": true}), HeaderMap::new())
|
||||
.await?;
|
||||
|
||||
let mut stream = stream.aggregate();
|
||||
let mut events = Vec::new();
|
||||
while let Some(ev) = stream.next().await {
|
||||
events.push(ev?);
|
||||
}
|
||||
|
||||
let events: Vec<ResponseEvent> = events
|
||||
.into_iter()
|
||||
.filter(|ev| !matches!(ev, ResponseEvent::RateLimits(_)))
|
||||
.collect();
|
||||
|
||||
assert_eq!(events.len(), 2);
|
||||
|
||||
match &events[0] {
|
||||
ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
|
||||
let mut aggregated = String::new();
|
||||
for item in content {
|
||||
if let ContentItem::OutputText { text } = item {
|
||||
aggregated.push_str(text);
|
||||
}
|
||||
}
|
||||
assert_eq!(aggregated, "Hello, world");
|
||||
}
|
||||
other => panic!("unexpected first event: {other:?}"),
|
||||
}
|
||||
|
||||
match &events[1] {
|
||||
ResponseEvent::Completed { response_id, .. } => {
|
||||
assert_eq!(response_id, "resp-agg");
|
||||
}
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
name = "codex-backend-openapi-models"
|
||||
version = { workspace = true }
|
||||
edition = "2024"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "codex_backend_openapi_models"
|
||||
|
||||
21
codex-rs/codex-client/Cargo.toml
Normal file
21
codex-rs/codex-client/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "codex-client"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
http = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["json", "stream"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "time", "sync"] }
|
||||
rand = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
8
codex-rs/codex-client/README.md
Normal file
8
codex-rs/codex-client/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# codex-client
|
||||
|
||||
Generic transport layer that wraps HTTP requests, retries, and streaming primitives without any Codex/OpenAI awareness.
|
||||
|
||||
- Defines `HttpTransport` and a default `ReqwestTransport` plus thin `Request`/`Response` types.
|
||||
- Provides retry utilities (`RetryPolicy`, `RetryOn`, `run_with_retry`, `backoff`) that callers plug into for unary and streaming calls.
|
||||
- Supplies the `sse_stream` helper to turn byte streams into raw SSE `data:` frames with idle timeouts and surfaced stream errors.
|
||||
- Consumed by higher-level crates like `codex-api`; it stays neutral on endpoints, headers, or API-specific error shapes.
|
||||
29
codex-rs/codex-client/src/error.rs
Normal file
29
codex-rs/codex-client/src/error.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use http::HeaderMap;
|
||||
use http::StatusCode;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TransportError {
|
||||
#[error("http {status}: {body:?}")]
|
||||
Http {
|
||||
status: StatusCode,
|
||||
headers: Option<HeaderMap>,
|
||||
body: Option<String>,
|
||||
},
|
||||
#[error("retry limit reached")]
|
||||
RetryLimit,
|
||||
#[error("timeout")]
|
||||
Timeout,
|
||||
#[error("network error: {0}")]
|
||||
Network(String),
|
||||
#[error("request build error: {0}")]
|
||||
Build(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum StreamError {
|
||||
#[error("stream failed: {0}")]
|
||||
Stream(String),
|
||||
#[error("timeout")]
|
||||
Timeout,
|
||||
}
|
||||
21
codex-rs/codex-client/src/lib.rs
Normal file
21
codex-rs/codex-client/src/lib.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
mod error;
|
||||
mod request;
|
||||
mod retry;
|
||||
mod sse;
|
||||
mod telemetry;
|
||||
mod transport;
|
||||
|
||||
pub use crate::error::StreamError;
|
||||
pub use crate::error::TransportError;
|
||||
pub use crate::request::Request;
|
||||
pub use crate::request::Response;
|
||||
pub use crate::retry::RetryOn;
|
||||
pub use crate::retry::RetryPolicy;
|
||||
pub use crate::retry::backoff;
|
||||
pub use crate::retry::run_with_retry;
|
||||
pub use crate::sse::sse_stream;
|
||||
pub use crate::telemetry::RequestTelemetry;
|
||||
pub use crate::transport::ByteStream;
|
||||
pub use crate::transport::HttpTransport;
|
||||
pub use crate::transport::ReqwestTransport;
|
||||
pub use crate::transport::StreamResponse;
|
||||
39
codex-rs/codex-client/src/request.rs
Normal file
39
codex-rs/codex-client/src/request.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use bytes::Bytes;
|
||||
use http::Method;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Request {
|
||||
pub method: Method,
|
||||
pub url: String,
|
||||
pub headers: HeaderMap,
|
||||
pub body: Option<Value>,
|
||||
pub timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl Request {
|
||||
pub fn new(method: Method, url: String) -> Self {
|
||||
Self {
|
||||
method,
|
||||
url,
|
||||
headers: HeaderMap::new(),
|
||||
body: None,
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_json<T: Serialize>(mut self, body: &T) -> Self {
|
||||
self.body = serde_json::to_value(body).ok();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Response {
|
||||
pub status: http::StatusCode,
|
||||
pub headers: HeaderMap,
|
||||
pub body: Bytes,
|
||||
}
|
||||
73
codex-rs/codex-client/src/retry.rs
Normal file
73
codex-rs/codex-client/src/retry.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use crate::error::TransportError;
|
||||
use crate::request::Request;
|
||||
use rand::Rng;
|
||||
use std::future::Future;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryPolicy {
|
||||
pub max_attempts: u64,
|
||||
pub base_delay: Duration,
|
||||
pub retry_on: RetryOn,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryOn {
|
||||
pub retry_429: bool,
|
||||
pub retry_5xx: bool,
|
||||
pub retry_transport: bool,
|
||||
}
|
||||
|
||||
impl RetryOn {
|
||||
pub fn should_retry(&self, err: &TransportError, attempt: u64, max_attempts: u64) -> bool {
|
||||
if attempt >= max_attempts {
|
||||
return false;
|
||||
}
|
||||
match err {
|
||||
TransportError::Http { status, .. } => {
|
||||
(self.retry_429 && status.as_u16() == 429)
|
||||
|| (self.retry_5xx && status.is_server_error())
|
||||
}
|
||||
TransportError::Timeout | TransportError::Network(_) => self.retry_transport,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn backoff(base: Duration, attempt: u64) -> Duration {
|
||||
if attempt == 0 {
|
||||
return base;
|
||||
}
|
||||
let exp = 2u64.saturating_pow(attempt as u32 - 1);
|
||||
let millis = base.as_millis() as u64;
|
||||
let raw = millis.saturating_mul(exp);
|
||||
let jitter: f64 = rand::rng().random_range(0.9..1.1);
|
||||
Duration::from_millis((raw as f64 * jitter) as u64)
|
||||
}
|
||||
|
||||
pub async fn run_with_retry<T, F, Fut>(
|
||||
policy: RetryPolicy,
|
||||
mut make_req: impl FnMut() -> Request,
|
||||
op: F,
|
||||
) -> Result<T, TransportError>
|
||||
where
|
||||
F: Fn(Request, u64) -> Fut,
|
||||
Fut: Future<Output = Result<T, TransportError>>,
|
||||
{
|
||||
for attempt in 0..=policy.max_attempts {
|
||||
let req = make_req();
|
||||
match op(req, attempt).await {
|
||||
Ok(resp) => return Ok(resp),
|
||||
Err(err)
|
||||
if policy
|
||||
.retry_on
|
||||
.should_retry(&err, attempt, policy.max_attempts) =>
|
||||
{
|
||||
sleep(backoff(policy.base_delay, attempt + 1)).await;
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
Err(TransportError::RetryLimit)
|
||||
}
|
||||
48
codex-rs/codex-client/src/sse.rs
Normal file
48
codex-rs/codex-client/src/sse.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use crate::error::StreamError;
|
||||
use crate::transport::ByteStream;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::StreamExt;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
/// Minimal SSE helper that forwards raw `data:` frames as UTF-8 strings.
|
||||
///
|
||||
/// Errors and idle timeouts are sent as `Err(StreamError)` before the task exits.
|
||||
pub fn sse_stream(
|
||||
stream: ByteStream,
|
||||
idle_timeout: Duration,
|
||||
tx: mpsc::Sender<Result<String, StreamError>>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
let mut stream = stream
|
||||
.map(|res| res.map_err(|e| StreamError::Stream(e.to_string())))
|
||||
.eventsource();
|
||||
|
||||
loop {
|
||||
match timeout(idle_timeout, stream.next()).await {
|
||||
Ok(Some(Ok(ev))) => {
|
||||
if tx.send(Ok(ev.data.clone())).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Ok(Some(Err(e))) => {
|
||||
let _ = tx.send(Err(StreamError::Stream(e.to_string()))).await;
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
let _ = tx
|
||||
.send(Err(StreamError::Stream(
|
||||
"stream closed before completion".into(),
|
||||
)))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = tx.send(Err(StreamError::Timeout)).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
14
codex-rs/codex-client/src/telemetry.rs
Normal file
14
codex-rs/codex-client/src/telemetry.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use crate::error::TransportError;
|
||||
use http::StatusCode;
|
||||
use std::time::Duration;
|
||||
|
||||
/// API specific telemetry.
|
||||
pub trait RequestTelemetry: Send + Sync {
|
||||
fn on_request(
|
||||
&self,
|
||||
attempt: u64,
|
||||
status: Option<StatusCode>,
|
||||
error: Option<&TransportError>,
|
||||
duration: Duration,
|
||||
);
|
||||
}
|
||||
107
codex-rs/codex-client/src/transport.rs
Normal file
107
codex-rs/codex-client/src/transport.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use crate::error::TransportError;
|
||||
use crate::request::Request;
|
||||
use crate::request::Response;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::BoxStream;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
use http::StatusCode;
|
||||
|
||||
pub type ByteStream = BoxStream<'static, Result<Bytes, TransportError>>;
|
||||
|
||||
pub struct StreamResponse {
|
||||
pub status: StatusCode,
|
||||
pub headers: HeaderMap,
|
||||
pub bytes: ByteStream,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait HttpTransport: Send + Sync {
|
||||
async fn execute(&self, req: Request) -> Result<Response, TransportError>;
|
||||
async fn stream(&self, req: Request) -> Result<StreamResponse, TransportError>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ReqwestTransport {
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl ReqwestTransport {
|
||||
pub fn new(client: reqwest::Client) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
|
||||
fn build(&self, req: Request) -> Result<reqwest::RequestBuilder, TransportError> {
|
||||
let mut builder = self
|
||||
.client
|
||||
.request(
|
||||
Method::from_bytes(req.method.as_str().as_bytes()).unwrap_or(Method::GET),
|
||||
&req.url,
|
||||
)
|
||||
.headers(req.headers);
|
||||
if let Some(timeout) = req.timeout {
|
||||
builder = builder.timeout(timeout);
|
||||
}
|
||||
if let Some(body) = req.body {
|
||||
builder = builder.json(&body);
|
||||
}
|
||||
Ok(builder)
|
||||
}
|
||||
|
||||
fn map_error(err: reqwest::Error) -> TransportError {
|
||||
if err.is_timeout() {
|
||||
TransportError::Timeout
|
||||
} else {
|
||||
TransportError::Network(err.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for ReqwestTransport {
|
||||
async fn execute(&self, req: Request) -> Result<Response, TransportError> {
|
||||
let builder = self.build(req)?;
|
||||
let resp = builder.send().await.map_err(Self::map_error)?;
|
||||
let status = resp.status();
|
||||
let headers = resp.headers().clone();
|
||||
let bytes = resp.bytes().await.map_err(Self::map_error)?;
|
||||
if !status.is_success() {
|
||||
let body = String::from_utf8(bytes.to_vec()).ok();
|
||||
return Err(TransportError::Http {
|
||||
status,
|
||||
headers: Some(headers),
|
||||
body,
|
||||
});
|
||||
}
|
||||
Ok(Response {
|
||||
status,
|
||||
headers,
|
||||
body: bytes,
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream(&self, req: Request) -> Result<StreamResponse, TransportError> {
|
||||
let builder = self.build(req)?;
|
||||
let resp = builder.send().await.map_err(Self::map_error)?;
|
||||
let status = resp.status();
|
||||
let headers = resp.headers().clone();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.ok();
|
||||
return Err(TransportError::Http {
|
||||
status,
|
||||
headers: Some(headers),
|
||||
body,
|
||||
});
|
||||
}
|
||||
let stream = resp
|
||||
.bytes_stream()
|
||||
.map(|result| result.map_err(Self::map_error));
|
||||
Ok(StreamResponse {
|
||||
status,
|
||||
headers,
|
||||
bytes: Box::pin(stream),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-common"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "codex-core"
|
||||
version = { workspace = true }
|
||||
version.workspace = true
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
@@ -17,12 +18,13 @@ askama = { workspace = true }
|
||||
async-channel = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
chardetng = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
codex-api = { workspace = true }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-apply-patch = { workspace = true }
|
||||
codex-async-utils = { workspace = true }
|
||||
codex-client = { workspace = true }
|
||||
codex-execpolicy = { workspace = true }
|
||||
codex-file-search = { workspace = true }
|
||||
codex-git = { workspace = true }
|
||||
@@ -36,8 +38,8 @@ codex-utils-string = { workspace = true }
|
||||
codex-windows-sandbox = { package = "codex-windows-sandbox", path = "../windows-sandbox-rs" }
|
||||
dirs = { workspace = true }
|
||||
dunce = { workspace = true }
|
||||
env-flags = { workspace = true }
|
||||
encoding_rs = { workspace = true }
|
||||
env-flags = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
http = { workspace = true }
|
||||
@@ -45,8 +47,10 @@ indexmap = { workspace = true }
|
||||
keyring = { workspace = true, features = ["crypto-rust"] }
|
||||
libc = { workspace = true }
|
||||
mcp-types = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["json", "stream"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
@@ -56,9 +60,6 @@ sha2 = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
similar = { workspace = true }
|
||||
strum_macros = { workspace = true }
|
||||
url = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
test-case = "3.3.1"
|
||||
test-log = { workspace = true }
|
||||
@@ -82,15 +83,19 @@ toml_edit = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tree-sitter = { workspace = true }
|
||||
tree-sitter-bash = { workspace = true }
|
||||
url = { workspace = true }
|
||||
uuid = { workspace = true, features = ["serde", "v4", "v5"] }
|
||||
which = { workspace = true }
|
||||
wildmatch = { workspace = true }
|
||||
|
||||
[features]
|
||||
deterministic_process_ids = []
|
||||
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
keyring = { workspace = true, features = ["linux-native-async-persistent"] }
|
||||
landlock = { workspace = true }
|
||||
seccompiler = { workspace = true }
|
||||
keyring = { workspace = true, features = ["linux-native-async-persistent"] }
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
core-foundation = "0.9"
|
||||
@@ -114,6 +119,7 @@ keyring = { workspace = true, features = ["sync-secret-service"] }
|
||||
assert_cmd = { workspace = true }
|
||||
assert_matches = { workspace = true }
|
||||
codex-arg0 = { workspace = true }
|
||||
codex-core = { path = ".", features = ["deterministic_process_ids"] }
|
||||
core_test_support = { workspace = true }
|
||||
ctor = { workspace = true }
|
||||
escargot = { workspace = true }
|
||||
|
||||
154
codex-rs/core/src/api_bridge.rs
Normal file
154
codex-rs/core/src/api_bridge.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
use chrono::DateTime;
|
||||
use chrono::Utc;
|
||||
use codex_api::AuthProvider as ApiAuthProvider;
|
||||
use codex_api::TransportError;
|
||||
use codex_api::error::ApiError;
|
||||
use codex_api::rate_limits::parse_rate_limit;
|
||||
use http::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::auth::CodexAuth;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::RetryLimitReachedError;
|
||||
use crate::error::UnexpectedResponseError;
|
||||
use crate::error::UsageLimitReachedError;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::token_data::PlanType;
|
||||
|
||||
pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
|
||||
match err {
|
||||
ApiError::ContextWindowExceeded => CodexErr::ContextWindowExceeded,
|
||||
ApiError::QuotaExceeded => CodexErr::QuotaExceeded,
|
||||
ApiError::UsageNotIncluded => CodexErr::UsageNotIncluded,
|
||||
ApiError::Retryable { message, delay } => CodexErr::Stream(message, delay),
|
||||
ApiError::Stream(msg) => CodexErr::Stream(msg, None),
|
||||
ApiError::Api { status, message } => CodexErr::UnexpectedStatus(UnexpectedResponseError {
|
||||
status,
|
||||
body: message,
|
||||
request_id: None,
|
||||
}),
|
||||
ApiError::Transport(transport) => match transport {
|
||||
TransportError::Http {
|
||||
status,
|
||||
headers,
|
||||
body,
|
||||
} => {
|
||||
if status == http::StatusCode::INTERNAL_SERVER_ERROR {
|
||||
CodexErr::InternalServerError
|
||||
} else if status == http::StatusCode::TOO_MANY_REQUESTS {
|
||||
if let Some(body) = body
|
||||
&& let Ok(err) = serde_json::from_str::<UsageErrorResponse>(&body)
|
||||
{
|
||||
if err.error.error_type.as_deref() == Some("usage_limit_reached") {
|
||||
let rate_limits = headers.as_ref().and_then(parse_rate_limit);
|
||||
let resets_at = err
|
||||
.error
|
||||
.resets_at
|
||||
.and_then(|seconds| DateTime::<Utc>::from_timestamp(seconds, 0));
|
||||
return CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||
plan_type: err.error.plan_type,
|
||||
resets_at,
|
||||
rate_limits,
|
||||
});
|
||||
} else if err.error.error_type.as_deref() == Some("usage_not_included") {
|
||||
return CodexErr::UsageNotIncluded;
|
||||
}
|
||||
}
|
||||
|
||||
CodexErr::RetryLimit(RetryLimitReachedError {
|
||||
status,
|
||||
request_id: extract_request_id(headers.as_ref()),
|
||||
})
|
||||
} else {
|
||||
CodexErr::UnexpectedStatus(UnexpectedResponseError {
|
||||
status,
|
||||
body: body.unwrap_or_default(),
|
||||
request_id: extract_request_id(headers.as_ref()),
|
||||
})
|
||||
}
|
||||
}
|
||||
TransportError::RetryLimit => CodexErr::RetryLimit(RetryLimitReachedError {
|
||||
status: http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
request_id: None,
|
||||
}),
|
||||
TransportError::Timeout => CodexErr::Timeout,
|
||||
TransportError::Network(msg) | TransportError::Build(msg) => {
|
||||
CodexErr::Stream(msg, None)
|
||||
}
|
||||
},
|
||||
ApiError::RateLimit(msg) => CodexErr::Stream(msg, None),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_request_id(headers: Option<&HeaderMap>) -> Option<String> {
|
||||
headers.and_then(|map| {
|
||||
["cf-ray", "x-request-id", "x-oai-request-id"]
|
||||
.iter()
|
||||
.find_map(|name| {
|
||||
map.get(*name)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::to_string)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn auth_provider_from_auth(
|
||||
auth: Option<CodexAuth>,
|
||||
provider: &ModelProviderInfo,
|
||||
) -> crate::error::Result<CoreAuthProvider> {
|
||||
if let Some(api_key) = provider.api_key()? {
|
||||
return Ok(CoreAuthProvider {
|
||||
token: Some(api_key),
|
||||
account_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(token) = provider.experimental_bearer_token.clone() {
|
||||
return Ok(CoreAuthProvider {
|
||||
token: Some(token),
|
||||
account_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(auth) = auth {
|
||||
let token = auth.get_token().await?;
|
||||
Ok(CoreAuthProvider {
|
||||
token: Some(token),
|
||||
account_id: auth.get_account_id(),
|
||||
})
|
||||
} else {
|
||||
Ok(CoreAuthProvider {
|
||||
token: None,
|
||||
account_id: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UsageErrorResponse {
|
||||
error: UsageErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UsageErrorBody {
|
||||
#[serde(rename = "type")]
|
||||
error_type: Option<String>,
|
||||
plan_type: Option<PlanType>,
|
||||
resets_at: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct CoreAuthProvider {
|
||||
token: Option<String>,
|
||||
account_id: Option<String>,
|
||||
}
|
||||
|
||||
impl ApiAuthProvider for CoreAuthProvider {
|
||||
fn bearer_token(&self) -> Option<String> {
|
||||
self.token.clone()
|
||||
}
|
||||
|
||||
fn account_id(&self) -> Option<String> {
|
||||
self.account_id.clone()
|
||||
}
|
||||
}
|
||||
@@ -1,981 +0,0 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::ModelProviderInfo;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
use crate::client_common::ResponseStream;
|
||||
use crate::default_client::CodexHttpClient;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::ConnectionFailedError;
|
||||
use crate::error::ResponseStreamFailed;
|
||||
use crate::error::Result;
|
||||
use crate::error::RetryLimitReachedError;
|
||||
use crate::error::UnexpectedResponseError;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::tools::spec::create_tools_json_for_chat_completions_api;
|
||||
use crate::util::backoff;
|
||||
use bytes::Bytes;
|
||||
use codex_otel::otel_event_manager::OtelEventManager;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::Stream;
|
||||
use futures::StreamExt;
|
||||
use futures::TryStreamExt;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::json;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
|
||||
/// Implementation for the classic Chat Completions API.
|
||||
pub(crate) async fn stream_chat_completions(
|
||||
prompt: &Prompt,
|
||||
model_family: &ModelFamily,
|
||||
client: &CodexHttpClient,
|
||||
provider: &ModelProviderInfo,
|
||||
otel_event_manager: &OtelEventManager,
|
||||
session_source: &SessionSource,
|
||||
) -> Result<ResponseStream> {
|
||||
if prompt.output_schema.is_some() {
|
||||
return Err(CodexErr::UnsupportedOperation(
|
||||
"output_schema is not supported for Chat Completions API".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Build messages array
|
||||
let mut messages = Vec::<serde_json::Value>::new();
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(model_family);
|
||||
messages.push(json!({"role": "system", "content": full_instructions}));
|
||||
|
||||
let input = prompt.get_formatted_input();
|
||||
|
||||
// Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
|
||||
// - If the last emitted message is a user message, drop all reasoning.
|
||||
// - Otherwise, for each Reasoning item after the last user message, attach it
|
||||
// to the immediate previous assistant message (stop turns) or the immediate
|
||||
// next assistant anchor (tool-call turns: function/local shell call, or assistant message).
|
||||
let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
// Determine the last role that would be emitted to Chat Completions.
|
||||
let mut last_emitted_role: Option<&str> = None;
|
||||
for item in &input {
|
||||
match item {
|
||||
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
|
||||
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
|
||||
last_emitted_role = Some("assistant")
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
ResponseItem::CustomToolCall { .. } => {}
|
||||
ResponseItem::CustomToolCallOutput { .. } => {}
|
||||
ResponseItem::WebSearchCall { .. } => {}
|
||||
ResponseItem::GhostSnapshot { .. } => {}
|
||||
ResponseItem::CompactionSummary { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Find the last user message index in the input.
|
||||
let mut last_user_index: Option<usize> = None;
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
if let ResponseItem::Message { role, .. } = item
|
||||
&& role == "user"
|
||||
{
|
||||
last_user_index = Some(idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Attach reasoning only if the conversation does not end with a user message.
|
||||
if !matches!(last_emitted_role, Some("user")) {
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
// Only consider reasoning that appears after the last user message.
|
||||
if let Some(u_idx) = last_user_index
|
||||
&& idx <= u_idx
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let ResponseItem::Reasoning {
|
||||
content: Some(items),
|
||||
..
|
||||
} = item
|
||||
{
|
||||
let mut text = String::new();
|
||||
for entry in items {
|
||||
match entry {
|
||||
ReasoningItemContent::ReasoningText { text: segment }
|
||||
| ReasoningItemContent::Text { text: segment } => text.push_str(segment),
|
||||
}
|
||||
}
|
||||
if text.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Prefer immediate previous assistant message (stop turns)
|
||||
let mut attached = false;
|
||||
if idx > 0
|
||||
&& let ResponseItem::Message { role, .. } = &input[idx - 1]
|
||||
&& role == "assistant"
|
||||
{
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx - 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
attached = true;
|
||||
}
|
||||
|
||||
// Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
|
||||
if !attached && idx + 1 < input.len() {
|
||||
match &input[idx + 1] {
|
||||
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
ResponseItem::Message { role, .. } if role == "assistant" => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track last assistant text we emitted to avoid duplicate assistant messages
|
||||
// in the outbound Chat Completions payload (can happen if a final
|
||||
// aggregated assistant message was recorded alongside an earlier partial).
|
||||
let mut last_assistant_text: Option<String> = None;
|
||||
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
// Build content either as a plain string (typical for assistant text)
|
||||
// or as an array of content items when images are present (user/tool multimodal).
|
||||
let mut text = String::new();
|
||||
let mut items: Vec<serde_json::Value> = Vec::new();
|
||||
let mut saw_image = false;
|
||||
|
||||
for c in content {
|
||||
match c {
|
||||
ContentItem::InputText { text: t }
|
||||
| ContentItem::OutputText { text: t } => {
|
||||
text.push_str(t);
|
||||
items.push(json!({"type":"text","text": t}));
|
||||
}
|
||||
ContentItem::InputImage { image_url } => {
|
||||
saw_image = true;
|
||||
items.push(json!({"type":"image_url","image_url": {"url": image_url}}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip exact-duplicate assistant messages.
|
||||
if role == "assistant" {
|
||||
if let Some(prev) = &last_assistant_text
|
||||
&& prev == &text
|
||||
{
|
||||
continue;
|
||||
}
|
||||
last_assistant_text = Some(text.clone());
|
||||
}
|
||||
|
||||
// For assistant messages, always send a plain string for compatibility.
|
||||
// For user messages, if an image is present, send an array of content items.
|
||||
let content_value = if role == "assistant" {
|
||||
json!(text)
|
||||
} else if saw_image {
|
||||
json!(items)
|
||||
} else {
|
||||
json!(text)
|
||||
};
|
||||
|
||||
let mut msg = json!({"role": role, "content": content_value});
|
||||
if role == "assistant"
|
||||
&& let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCall {
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
}
|
||||
}]
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
call_id: _,
|
||||
status,
|
||||
action,
|
||||
} => {
|
||||
// Confirm with API team.
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": id.clone().unwrap_or_else(|| "".to_string()),
|
||||
"type": "local_shell_call",
|
||||
"status": status,
|
||||
"action": action,
|
||||
}]
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { call_id, output } => {
|
||||
// Prefer structured content items when available (e.g., images)
|
||||
// otherwise fall back to the legacy plain-string content.
|
||||
let content_value = if let Some(items) = &output.content_items {
|
||||
let mapped: Vec<serde_json::Value> = items
|
||||
.iter()
|
||||
.map(|it| match it {
|
||||
FunctionCallOutputContentItem::InputText { text } => {
|
||||
json!({"type":"text","text": text})
|
||||
}
|
||||
FunctionCallOutputContentItem::InputImage { image_url } => {
|
||||
json!({"type":"image_url","image_url": {"url": image_url}})
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
json!(mapped)
|
||||
} else {
|
||||
json!(output.content)
|
||||
};
|
||||
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": content_value,
|
||||
}));
|
||||
}
|
||||
ResponseItem::CustomToolCall {
|
||||
id,
|
||||
call_id: _,
|
||||
name,
|
||||
input,
|
||||
status: _,
|
||||
} => {
|
||||
messages.push(json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": id,
|
||||
"type": "custom",
|
||||
"custom": {
|
||||
"name": name,
|
||||
"input": input,
|
||||
}
|
||||
}]
|
||||
}));
|
||||
}
|
||||
ResponseItem::CustomToolCallOutput { call_id, output } => {
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": output,
|
||||
}));
|
||||
}
|
||||
ResponseItem::GhostSnapshot { .. } => {
|
||||
// Ghost snapshots annotate history but are not sent to the model.
|
||||
continue;
|
||||
}
|
||||
ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::CompactionSummary { .. } => {
|
||||
// Omit these items from the conversation history.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
|
||||
let payload = json!({
|
||||
"model": model_family.slug,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
"tools": tools_json,
|
||||
});
|
||||
|
||||
debug!(
|
||||
"POST to {}: {}",
|
||||
provider.get_full_url(&None),
|
||||
payload.to_string()
|
||||
);
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = provider.request_max_retries();
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
let mut req_builder = provider.create_request_builder(client, &None).await?;
|
||||
|
||||
// Include subagent header only for subagent sessions.
|
||||
if let SessionSource::SubAgent(sub) = session_source.clone() {
|
||||
let subagent = if let SubAgentSource::Other(label) = sub {
|
||||
label
|
||||
} else {
|
||||
serde_json::to_value(&sub)
|
||||
.ok()
|
||||
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
|
||||
.unwrap_or_else(|| "other".to_string())
|
||||
};
|
||||
req_builder = req_builder.header("x-openai-subagent", subagent);
|
||||
}
|
||||
|
||||
let res = otel_event_manager
|
||||
.log_request(attempt, || {
|
||||
req_builder
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload)
|
||||
.send()
|
||||
})
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
let stream = resp.bytes_stream().map_err(|e| {
|
||||
CodexErr::ResponseStreamFailed(ResponseStreamFailed {
|
||||
source: e,
|
||||
request_id: None,
|
||||
})
|
||||
});
|
||||
tokio::spawn(process_chat_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
provider.stream_idle_timeout(),
|
||||
otel_event_manager.clone(),
|
||||
));
|
||||
return Ok(ResponseStream { rx_event });
|
||||
}
|
||||
Ok(res) => {
|
||||
let status = res.status();
|
||||
if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
|
||||
let body = (res.text().await).unwrap_or_default();
|
||||
return Err(CodexErr::UnexpectedStatus(UnexpectedResponseError {
|
||||
status,
|
||||
body,
|
||||
request_id: None,
|
||||
}));
|
||||
}
|
||||
|
||||
if attempt > max_retries {
|
||||
return Err(CodexErr::RetryLimit(RetryLimitReachedError {
|
||||
status,
|
||||
request_id: None,
|
||||
}));
|
||||
}
|
||||
|
||||
let retry_after_secs = res
|
||||
.headers()
|
||||
.get(reqwest::header::RETRY_AFTER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok());
|
||||
|
||||
let delay = retry_after_secs
|
||||
.map(|s| Duration::from_millis(s * 1_000))
|
||||
.unwrap_or_else(|| backoff(attempt));
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt > max_retries {
|
||||
return Err(CodexErr::ConnectionFailed(ConnectionFailedError {
|
||||
source: e,
|
||||
}));
|
||||
}
|
||||
let delay = backoff(attempt);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_assistant_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
|
||||
assistant_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if assistant_item.is_none() {
|
||||
let item = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
};
|
||||
*assistant_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Message { content, .. }) = assistant_item {
|
||||
content.push(ContentItem::OutputText { text: text.clone() });
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_reasoning_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
|
||||
reasoning_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if reasoning_item.is_none() {
|
||||
let item = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
*reasoning_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Reasoning {
|
||||
content: Some(content),
|
||||
..
|
||||
}) = reasoning_item
|
||||
{
|
||||
let content_index = content.len() as i64;
|
||||
content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
|
||||
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta: text.clone(),
|
||||
content_index,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
/// Lightweight SSE processor for the Chat Completions streaming format. The
|
||||
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
|
||||
/// of the pipeline can stay agnostic of the underlying wire format.
|
||||
async fn process_chat_sse<S>(
|
||||
stream: S,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent>>,
|
||||
idle_timeout: Duration,
|
||||
otel_event_manager: OtelEventManager,
|
||||
) where
|
||||
S: Stream<Item = Result<Bytes>> + Unpin,
|
||||
{
|
||||
let mut stream = stream.eventsource();
|
||||
|
||||
// State to accumulate a function call across streaming chunks.
|
||||
// OpenAI may split the `arguments` string over multiple `delta` events
|
||||
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
|
||||
// keep collecting the pieces here and forward a single
|
||||
// `ResponseItem::FunctionCall` once the call is complete.
|
||||
#[derive(Default)]
|
||||
struct FunctionCallState {
|
||||
name: Option<String>,
|
||||
arguments: String,
|
||||
call_id: Option<String>,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
let mut fn_call_state = FunctionCallState::default();
|
||||
let mut assistant_item: Option<ResponseItem> = None;
|
||||
let mut reasoning_item: Option<ResponseItem> = None;
|
||||
|
||||
loop {
|
||||
let start = std::time::Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
let duration = start.elapsed();
|
||||
otel_event_manager.log_sse_event(&response, duration);
|
||||
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(ev))) => ev,
|
||||
Ok(Some(Err(e))) => {
|
||||
let _ = tx_event
|
||||
.send(Err(CodexErr::Stream(e.to_string(), None)))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
// Stream closed gracefully – emit Completed with dummy id.
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = tx_event
|
||||
.send(Err(CodexErr::Stream(
|
||||
"idle timeout waiting for SSE".into(),
|
||||
None,
|
||||
)))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// OpenAI Chat streaming sends a literal string "[DONE]" when finished.
|
||||
if sse.data.trim() == "[DONE]" {
|
||||
// Emit any finalized items before closing so downstream consumers receive
|
||||
// terminal events for both assistant content and raw reasoning.
|
||||
if let Some(item) = assistant_item {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
|
||||
if let Some(item) = reasoning_item {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse JSON chunk
|
||||
let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
trace!("chat_completions received SSE chunk: {chunk:?}");
|
||||
|
||||
let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
|
||||
|
||||
if let Some(choice) = choice_opt {
|
||||
// Handle assistant content tokens as streaming deltas.
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
&& !content.is_empty()
|
||||
{
|
||||
append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await;
|
||||
}
|
||||
|
||||
// Forward any reasoning/thinking deltas if present.
|
||||
// Some providers stream `reasoning` as a plain string while others
|
||||
// nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
|
||||
if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
|
||||
let mut maybe_text = reasoning_val
|
||||
.as_str()
|
||||
.map(str::to_string)
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
if maybe_text.is_none() && reasoning_val.is_object() {
|
||||
if let Some(s) = reasoning_val
|
||||
.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
{
|
||||
maybe_text = Some(s.to_string());
|
||||
} else if let Some(s) = reasoning_val
|
||||
.get("content")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
{
|
||||
maybe_text = Some(s.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(reasoning) = maybe_text {
|
||||
// Accumulate so we can emit a terminal Reasoning item at the end.
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Some providers only include reasoning on the final message object.
|
||||
if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
|
||||
{
|
||||
// Accept either a plain string or an object with { text | content }
|
||||
if let Some(s) = message_reasoning.as_str() {
|
||||
if !s.is_empty() {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
|
||||
}
|
||||
} else if let Some(obj) = message_reasoning.as_object()
|
||||
&& let Some(s) = obj
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.or_else(|| obj.get("content").and_then(|v| v.as_str()))
|
||||
&& !s.is_empty()
|
||||
{
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle streaming function / tool calls.
|
||||
if let Some(tool_calls) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("tool_calls"))
|
||||
.and_then(|tc| tc.as_array())
|
||||
&& let Some(tool_call) = tool_calls.first()
|
||||
{
|
||||
// Mark that we have an active function call in progress.
|
||||
fn_call_state.active = true;
|
||||
|
||||
// Extract call_id if present.
|
||||
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
|
||||
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
|
||||
}
|
||||
|
||||
// Extract function details if present.
|
||||
if let Some(function) = tool_call.get("function") {
|
||||
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
|
||||
fn_call_state.name.get_or_insert_with(|| name.to_string());
|
||||
}
|
||||
|
||||
if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str())
|
||||
{
|
||||
fn_call_state.arguments.push_str(args_fragment);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Emit end-of-turn when finish_reason signals completion.
|
||||
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str())
|
||||
&& !finish_reason.is_empty()
|
||||
{
|
||||
match finish_reason {
|
||||
"tool_calls" if fn_call_state.active => {
|
||||
// First, flush the terminal raw reasoning so UIs can finalize
|
||||
// the reasoning stream before any exec/tool events begin.
|
||||
if let Some(item) = reasoning_item.take() {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
|
||||
// Then emit the FunctionCall response item.
|
||||
let item = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
|
||||
arguments: fn_call_state.arguments.clone(),
|
||||
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
|
||||
};
|
||||
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
"stop" => {
|
||||
// Regular turn without tool-call. Emit the final assistant message
|
||||
// as a single OutputItemDone so non-delta consumers see the result.
|
||||
if let Some(item) = assistant_item.take() {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
// Also emit a terminal Reasoning item so UIs can finalize raw reasoning.
|
||||
if let Some(item) = reasoning_item.take() {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Emit Completed regardless of reason so the agent can advance.
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
// Prepare for potential next turn (should not happen in same stream).
|
||||
// fn_call_state = FunctionCallState::default();
|
||||
|
||||
return; // End processing for this SSE stream.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Optional client-side aggregation helper
|
||||
///
|
||||
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
|
||||
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
|
||||
/// per-token deltas**. The stream stays silent while the model is thinking
|
||||
/// and only emits two events per turn:
|
||||
///
|
||||
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
|
||||
/// (fully concatenated).
|
||||
/// 2. The original `ResponseEvent::Completed` right after it.
|
||||
///
|
||||
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
|
||||
///
|
||||
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
|
||||
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
|
||||
/// events.
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
enum AggregateMode {
|
||||
AggregatedOnly,
|
||||
Streaming,
|
||||
}
|
||||
pub(crate) struct AggregatedChatStream<S> {
|
||||
inner: S,
|
||||
cumulative: String,
|
||||
cumulative_reasoning: String,
|
||||
pending: std::collections::VecDeque<ResponseEvent>,
|
||||
mode: AggregateMode,
|
||||
}
|
||||
|
||||
impl<S> Stream for AggregatedChatStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<ResponseEvent>> + Unpin,
|
||||
{
|
||||
type Item = Result<ResponseEvent>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// First, flush any buffered events from the previous call.
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
|
||||
loop {
|
||||
match Pin::new(&mut this.inner).poll_next(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
|
||||
// If this is an incremental assistant message chunk, accumulate but
|
||||
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
|
||||
// away so downstream consumers see it.
|
||||
|
||||
let is_assistant_message = matches!(
|
||||
&item,
|
||||
codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"
|
||||
);
|
||||
|
||||
if is_assistant_message {
|
||||
match this.mode {
|
||||
AggregateMode::AggregatedOnly => {
|
||||
// Only use the final assistant message if we have not
|
||||
// seen any deltas; otherwise, deltas already built the
|
||||
// cumulative text and this would duplicate it.
|
||||
if this.cumulative.is_empty()
|
||||
&& let codex_protocol::models::ResponseItem::Message {
|
||||
content,
|
||||
..
|
||||
} = &item
|
||||
&& let Some(text) = content.iter().find_map(|c| match c {
|
||||
codex_protocol::models::ContentItem::OutputText {
|
||||
text,
|
||||
} => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
// Swallow assistant message here; emit on Completed.
|
||||
continue;
|
||||
}
|
||||
AggregateMode::Streaming => {
|
||||
// In streaming mode, if we have not seen any deltas, forward
|
||||
// the final assistant message directly. If deltas were seen,
|
||||
// suppress the final message to avoid duplication.
|
||||
if this.cumulative.is_empty() {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
||||
item,
|
||||
))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not an assistant message – forward immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}))) => {
|
||||
// Build any aggregated items in the correct order: Reasoning first, then Message.
|
||||
let mut emitted_any = false;
|
||||
|
||||
if !this.cumulative_reasoning.is_empty()
|
||||
&& matches!(this.mode, AggregateMode::AggregatedOnly)
|
||||
{
|
||||
let aggregated_reasoning =
|
||||
codex_protocol::models::ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![
|
||||
codex_protocol::models::ReasoningItemContent::ReasoningText {
|
||||
text: std::mem::take(&mut this.cumulative_reasoning),
|
||||
},
|
||||
]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
// Always emit the final aggregated assistant message when any
|
||||
// content deltas have been observed. In AggregatedOnly mode this
|
||||
// is the sole assistant output; in Streaming mode this finalizes
|
||||
// the streamed deltas into a terminal OutputItemDone so callers
|
||||
// can persist/render the message once per turn.
|
||||
if !this.cumulative.is_empty() {
|
||||
let aggregated_message = codex_protocol::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![codex_protocol::models::ContentItem::OutputText {
|
||||
text: std::mem::take(&mut this.cumulative),
|
||||
}],
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
// Always emit Completed last when anything was aggregated.
|
||||
if emitted_any {
|
||||
this.pending.push_back(ResponseEvent::Completed {
|
||||
response_id: response_id.clone(),
|
||||
token_usage: token_usage.clone(),
|
||||
});
|
||||
// Return the first pending event now.
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing aggregated – forward Completed directly.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
})));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
|
||||
// These events are exclusive to the Responses API and
|
||||
// will never appear in a Chat Completions stream.
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
|
||||
// Always accumulate deltas so we can emit a final OutputItemDone at Completed.
|
||||
this.cumulative.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
// In streaming mode, also forward the delta immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
}))) => {
|
||||
// Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
|
||||
this.cumulative_reasoning.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
// In streaming mode, also forward the delta immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
})));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => {
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => {
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
|
||||
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
|
||||
/// Returns a new stream that emits **only** the final assistant message
|
||||
/// per turn instead of every incremental delta. The produced
|
||||
/// `ResponseEvent` sequence for a typical text turn looks like:
|
||||
///
|
||||
/// ```ignore
|
||||
/// OutputItemDone(<full message>)
|
||||
/// Completed
|
||||
/// ```
|
||||
///
|
||||
/// No other `OutputItemDone` events will be seen by the caller.
|
||||
///
|
||||
/// Usage:
|
||||
///
|
||||
/// ```ignore
|
||||
/// let agg_stream = client.stream(&prompt).await?.aggregate();
|
||||
/// while let Some(event) = agg_stream.next().await {
|
||||
/// // event now contains cumulative text
|
||||
/// }
|
||||
/// ```
|
||||
fn aggregate(self) -> AggregatedChatStream<Self> {
|
||||
AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
|
||||
|
||||
impl<S> AggregatedChatStream<S> {
|
||||
fn new(inner: S, mode: AggregateMode) -> Self {
|
||||
AggregatedChatStream {
|
||||
inner,
|
||||
cumulative: String::new(),
|
||||
cumulative_reasoning: String::new(),
|
||||
pending: std::collections::VecDeque::new(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn streaming_mode(inner: S) -> Self {
|
||||
Self::new(inner, AggregateMode::Streaming)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,16 +1,11 @@
|
||||
use crate::client_common::tools::ToolSpec;
|
||||
use crate::error::Result;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::TokenUsage;
|
||||
pub use codex_api::common::ResponseEvent;
|
||||
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use codex_protocol::config_types::Verbosity as VerbosityConfig;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use futures::Stream;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashSet;
|
||||
@@ -184,104 +179,6 @@ fn strip_total_output_header(output: &str) -> Option<(&str, u32)> {
|
||||
Some((remainder, total_lines))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ResponseEvent {
|
||||
Created,
|
||||
OutputItemDone(ResponseItem),
|
||||
OutputItemAdded(ResponseItem),
|
||||
Completed {
|
||||
response_id: String,
|
||||
token_usage: Option<TokenUsage>,
|
||||
},
|
||||
OutputTextDelta(String),
|
||||
ReasoningSummaryDelta {
|
||||
delta: String,
|
||||
summary_index: i64,
|
||||
},
|
||||
ReasoningContentDelta {
|
||||
delta: String,
|
||||
content_index: i64,
|
||||
},
|
||||
ReasoningSummaryPartAdded {
|
||||
summary_index: i64,
|
||||
},
|
||||
RateLimits(RateLimitSnapshot),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(crate) struct Reasoning {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) effort: Option<ReasoningEffortConfig>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) summary: Option<ReasoningSummaryConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) enum TextFormatType {
|
||||
#[default]
|
||||
JsonSchema,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
pub(crate) struct TextFormat {
|
||||
pub(crate) r#type: TextFormatType,
|
||||
pub(crate) strict: bool,
|
||||
pub(crate) schema: Value,
|
||||
pub(crate) name: String,
|
||||
}
|
||||
|
||||
/// Controls under the `text` field in the Responses API for GPT-5.
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
pub(crate) struct TextControls {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) verbosity: Option<OpenAiVerbosity>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) format: Option<TextFormat>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub(crate) enum OpenAiVerbosity {
|
||||
Low,
|
||||
#[default]
|
||||
Medium,
|
||||
High,
|
||||
}
|
||||
|
||||
impl From<VerbosityConfig> for OpenAiVerbosity {
|
||||
fn from(v: VerbosityConfig) -> Self {
|
||||
match v {
|
||||
VerbosityConfig::Low => OpenAiVerbosity::Low,
|
||||
VerbosityConfig::Medium => OpenAiVerbosity::Medium,
|
||||
VerbosityConfig::High => OpenAiVerbosity::High,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Request object that is serialized as JSON and POST'ed when using the
|
||||
/// Responses API.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(crate) struct ResponsesApiRequest<'a> {
|
||||
pub(crate) model: &'a str,
|
||||
pub(crate) instructions: &'a str,
|
||||
// TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
|
||||
// we code defensively to avoid this case, but perhaps we should use a
|
||||
// separate enum for serialization.
|
||||
pub(crate) input: &'a Vec<ResponseItem>,
|
||||
pub(crate) tools: &'a [serde_json::Value],
|
||||
pub(crate) tool_choice: &'static str,
|
||||
pub(crate) parallel_tool_calls: bool,
|
||||
pub(crate) reasoning: Option<Reasoning>,
|
||||
pub(crate) store: bool,
|
||||
pub(crate) stream: bool,
|
||||
pub(crate) include: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) prompt_cache_key: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) text: Option<TextControls>,
|
||||
}
|
||||
|
||||
pub(crate) mod tools {
|
||||
use crate::tools::spec::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -341,25 +238,6 @@ pub(crate) mod tools {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn create_text_param_for_request(
|
||||
verbosity: Option<VerbosityConfig>,
|
||||
output_schema: &Option<Value>,
|
||||
) -> Option<TextControls> {
|
||||
if verbosity.is_none() && output_schema.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(TextControls {
|
||||
verbosity: verbosity.map(std::convert::Into::into),
|
||||
format: output_schema.as_ref().map(|schema| TextFormat {
|
||||
r#type: TextFormatType::JsonSchema,
|
||||
strict: true,
|
||||
schema: schema.clone(),
|
||||
name: "codex_output_schema".to_string(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
pub struct ResponseStream {
|
||||
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
|
||||
}
|
||||
@@ -375,6 +253,10 @@ impl Stream for ResponseStream {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::model_family::find_family_for_model;
|
||||
use codex_api::ResponsesApiRequest;
|
||||
use codex_api::common::OpenAiVerbosity;
|
||||
use codex_api::common::TextControls;
|
||||
use codex_api::create_text_param_for_request;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::SandboxState;
|
||||
use crate::client_common::REVIEW_PROMPT;
|
||||
use crate::compact;
|
||||
use crate::compact::run_inline_auto_compact_task;
|
||||
@@ -34,6 +35,7 @@ use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::TaskStartedEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnContextItem;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::FuturesOrdered;
|
||||
@@ -44,6 +46,7 @@ use mcp_types::ListResourcesRequestParams;
|
||||
use mcp_types::ListResourcesResult;
|
||||
use mcp_types::ReadResourceRequestParams;
|
||||
use mcp_types::ReadResourceResult;
|
||||
use mcp_types::RequestId;
|
||||
use serde_json;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -103,6 +106,8 @@ use crate::shell;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::state::SessionServices;
|
||||
use crate::state::SessionState;
|
||||
use crate::status::ComponentHealth;
|
||||
use crate::status::maybe_codex_status_warning;
|
||||
use crate::tasks::GhostSnapshotTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::tasks::SessionTask;
|
||||
@@ -448,6 +453,16 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn replace_codex_backend_status(
|
||||
&self,
|
||||
status: ComponentHealth,
|
||||
) -> Option<ComponentHealth> {
|
||||
let mut guard = self.services.codex_backend_status.lock().await;
|
||||
let previous = *guard;
|
||||
*guard = Some(status);
|
||||
previous
|
||||
}
|
||||
|
||||
async fn new(
|
||||
session_configuration: SessionConfiguration,
|
||||
config: Arc<Config>,
|
||||
@@ -564,6 +579,7 @@ impl Session {
|
||||
auth_manager: Arc::clone(&auth_manager),
|
||||
otel_event_manager,
|
||||
tool_approvals: Mutex::new(ApprovalStore::default()),
|
||||
codex_backend_status: Mutex::new(None),
|
||||
};
|
||||
|
||||
let sess = Arc::new(Session {
|
||||
@@ -612,6 +628,22 @@ impl Session {
|
||||
)
|
||||
.await;
|
||||
|
||||
let sandbox_state = SandboxState {
|
||||
sandbox_policy: session_configuration.sandbox_policy.clone(),
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
sandbox_cwd: session_configuration.cwd.clone(),
|
||||
};
|
||||
if let Err(e) = sess
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.notify_sandbox_state_change(&sandbox_state)
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to notify sandbox state change: {e}");
|
||||
}
|
||||
|
||||
// record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted.
|
||||
sess.record_initial_history(initial_history).await;
|
||||
|
||||
@@ -642,6 +674,11 @@ impl Session {
|
||||
format!("auto-compact-{id}")
|
||||
}
|
||||
|
||||
async fn get_total_token_usage(&self) -> i64 {
|
||||
let state = self.state.lock().await;
|
||||
state.get_total_token_usage()
|
||||
}
|
||||
|
||||
async fn record_initial_history(&self, conversation_history: InitialHistory) {
|
||||
let turn_context = self.new_turn(SessionSettingsUpdate::default()).await;
|
||||
match conversation_history {
|
||||
@@ -940,6 +977,19 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn resolve_elicitation(
|
||||
&self,
|
||||
server_name: String,
|
||||
id: RequestId,
|
||||
response: ElicitationResponse,
|
||||
) -> anyhow::Result<()> {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.resolve_elicitation(server_name, id, response)
|
||||
}
|
||||
|
||||
/// Records input items: always append to conversation history and
|
||||
/// persist these response items to rollout.
|
||||
pub(crate) async fn record_conversation_items(
|
||||
@@ -1413,6 +1463,13 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Op::ResolveElicitation {
|
||||
server_name,
|
||||
request_id,
|
||||
decision,
|
||||
} => {
|
||||
handlers::resolve_elicitation(&sess, server_name, request_id, decision).await;
|
||||
}
|
||||
Op::Shutdown => {
|
||||
if handlers::shutdown(&sess, sub.id.clone()).await {
|
||||
break;
|
||||
@@ -1452,6 +1509,9 @@ mod handlers {
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use mcp_types::RequestId;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
@@ -1535,6 +1595,32 @@ mod handlers {
|
||||
*previous_context = Some(turn_context);
|
||||
}
|
||||
|
||||
pub async fn resolve_elicitation(
|
||||
sess: &Arc<Session>,
|
||||
server_name: String,
|
||||
request_id: RequestId,
|
||||
decision: codex_protocol::approvals::ElicitationAction,
|
||||
) {
|
||||
let action = match decision {
|
||||
codex_protocol::approvals::ElicitationAction::Accept => ElicitationAction::Accept,
|
||||
codex_protocol::approvals::ElicitationAction::Decline => ElicitationAction::Decline,
|
||||
codex_protocol::approvals::ElicitationAction::Cancel => ElicitationAction::Cancel,
|
||||
};
|
||||
let response = ElicitationResponse {
|
||||
action,
|
||||
content: None,
|
||||
};
|
||||
if let Err(err) = sess
|
||||
.resolve_elicitation(server_name, request_id, response)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
error = %err,
|
||||
"failed to resolve elicitation request in session"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn exec_approval(sess: &Arc<Session>, id: String, decision: ReviewDecision) {
|
||||
match decision {
|
||||
ReviewDecision::Abort => {
|
||||
@@ -1672,6 +1758,10 @@ mod handlers {
|
||||
|
||||
pub async fn shutdown(sess: &Arc<Session>, sub_id: String) -> bool {
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
sess.services
|
||||
.unified_exec_manager
|
||||
.terminate_all_sessions()
|
||||
.await;
|
||||
info!("Shutting down Codex instance");
|
||||
|
||||
// Gracefully flush and shutdown rollout recorder on session end so tests
|
||||
@@ -1890,20 +1980,13 @@ pub(crate) async fn run_task(
|
||||
.await
|
||||
{
|
||||
Ok(turn_output) => {
|
||||
let TurnRunResult {
|
||||
processed_items,
|
||||
total_token_usage,
|
||||
} = turn_output;
|
||||
let processed_items = turn_output;
|
||||
let limit = turn_context
|
||||
.client
|
||||
.get_auto_compact_token_limit()
|
||||
.unwrap_or(i64::MAX);
|
||||
let total_usage_tokens = total_token_usage
|
||||
.as_ref()
|
||||
.map(TokenUsage::tokens_in_context_window);
|
||||
let token_limit_reached = total_usage_tokens
|
||||
.map(|tokens| tokens >= limit)
|
||||
.unwrap_or(false);
|
||||
let total_usage_tokens = sess.get_total_token_usage().await;
|
||||
let token_limit_reached = total_usage_tokens >= limit;
|
||||
let (responses, items_to_record_in_conversation_history) =
|
||||
process_items(processed_items, &sess, &turn_context).await;
|
||||
|
||||
@@ -1960,7 +2043,7 @@ async fn run_turn(
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
input: Vec<ResponseItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
let mcp_tools = sess
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
@@ -2091,12 +2174,6 @@ pub struct ProcessedResponseItem {
|
||||
pub response: Option<ResponseInputItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TurnRunResult {
|
||||
processed_items: Vec<ProcessedResponseItem>,
|
||||
total_token_usage: Option<TokenUsage>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn try_run_turn(
|
||||
router: Arc<ToolRouter>,
|
||||
@@ -2105,7 +2182,7 @@ async fn try_run_turn(
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
prompt: &Prompt,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
@@ -2116,6 +2193,19 @@ async fn try_run_turn(
|
||||
});
|
||||
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let sess_clone = Arc::clone(&sess);
|
||||
let turn_context_clone = Arc::clone(&turn_context);
|
||||
tokio::spawn(async move {
|
||||
if let Some(message) = maybe_codex_status_warning(sess_clone.as_ref()).await {
|
||||
sess_clone
|
||||
.send_event(
|
||||
&turn_context_clone,
|
||||
EventMsg::Warning(WarningEvent { message }),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
});
|
||||
|
||||
let mut stream = turn_context
|
||||
.client
|
||||
.clone()
|
||||
@@ -2267,12 +2357,7 @@ async fn try_run_turn(
|
||||
sess.send_event(&turn_context, msg).await;
|
||||
}
|
||||
|
||||
let result = TurnRunResult {
|
||||
processed_items,
|
||||
total_token_usage: token_usage.clone(),
|
||||
};
|
||||
|
||||
return Ok(result);
|
||||
return Ok(processed_items);
|
||||
}
|
||||
ResponseEvent::OutputTextDelta(delta) => {
|
||||
// In review child threads, suppress assistant text deltas; the
|
||||
@@ -2287,7 +2372,7 @@ async fn try_run_turn(
|
||||
sess.send_event(&turn_context, EventMsg::AgentMessageContentDelta(event))
|
||||
.await;
|
||||
} else {
|
||||
error_or_panic("ReasoningSummaryDelta without active item".to_string());
|
||||
error_or_panic("OutputTextDelta without active item".to_string());
|
||||
}
|
||||
}
|
||||
ResponseEvent::ReasoningSummaryDelta {
|
||||
@@ -2381,6 +2466,9 @@ use crate::features::Features;
|
||||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_and_context;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_and_context_with_rx;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -2632,6 +2720,7 @@ mod tests {
|
||||
auth_manager: Arc::clone(&auth_manager),
|
||||
otel_event_manager: otel_event_manager.clone(),
|
||||
tool_approvals: Mutex::new(ApprovalStore::default()),
|
||||
codex_backend_status: Mutex::new(None),
|
||||
};
|
||||
|
||||
let turn_context = Session::make_turn_context(
|
||||
@@ -2657,7 +2746,7 @@ mod tests {
|
||||
|
||||
// Like make_session_and_context, but returns Arc<Session> and the event receiver
|
||||
// so tests can assert on emitted events.
|
||||
fn make_session_and_context_with_rx() -> (
|
||||
pub(crate) fn make_session_and_context_with_rx() -> (
|
||||
Arc<Session>,
|
||||
Arc<TurnContext>,
|
||||
async_channel::Receiver<Event>,
|
||||
@@ -2710,6 +2799,7 @@ mod tests {
|
||||
auth_manager: Arc::clone(&auth_manager),
|
||||
otel_event_manager: otel_event_manager.clone(),
|
||||
tool_approvals: Mutex::new(ApprovalStore::default()),
|
||||
codex_backend_status: Mutex::new(None),
|
||||
};
|
||||
|
||||
let turn_context = Arc::new(Session::make_turn_context(
|
||||
|
||||
@@ -13,6 +13,8 @@ use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use codex_protocol::protocol::Submission;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::AuthManager;
|
||||
@@ -60,14 +62,13 @@ pub(crate) async fn run_codex_conversation_interactive(
|
||||
let parent_ctx_clone = Arc::clone(&parent_ctx);
|
||||
let codex_for_events = Arc::clone(&codex);
|
||||
tokio::spawn(async move {
|
||||
let _ = forward_events(
|
||||
forward_events(
|
||||
codex_for_events,
|
||||
tx_sub,
|
||||
parent_session_clone,
|
||||
parent_ctx_clone,
|
||||
cancel_token_events.clone(),
|
||||
cancel_token_events,
|
||||
)
|
||||
.or_cancel(&cancel_token_events)
|
||||
.await;
|
||||
});
|
||||
|
||||
@@ -156,53 +157,92 @@ async fn forward_events(
|
||||
parent_ctx: Arc<TurnContext>,
|
||||
cancel_token: CancellationToken,
|
||||
) {
|
||||
while let Ok(event) = codex.next_event().await {
|
||||
match event {
|
||||
// ignore all legacy delta events
|
||||
Event {
|
||||
id: _,
|
||||
msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_),
|
||||
} => continue,
|
||||
Event {
|
||||
id: _,
|
||||
msg: EventMsg::SessionConfigured(_),
|
||||
} => continue,
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::ExecApprovalRequest(event),
|
||||
} => {
|
||||
// Initiate approval via parent session; do not surface to consumer.
|
||||
handle_exec_approval(
|
||||
&codex,
|
||||
id,
|
||||
&parent_session,
|
||||
&parent_ctx,
|
||||
event,
|
||||
&cancel_token,
|
||||
)
|
||||
.await;
|
||||
let cancelled = cancel_token.cancelled();
|
||||
tokio::pin!(cancelled);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = &mut cancelled => {
|
||||
shutdown_delegate(&codex).await;
|
||||
break;
|
||||
}
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::ApplyPatchApprovalRequest(event),
|
||||
} => {
|
||||
handle_patch_approval(
|
||||
&codex,
|
||||
id,
|
||||
&parent_session,
|
||||
&parent_ctx,
|
||||
event,
|
||||
&cancel_token,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
other => {
|
||||
let _ = tx_sub.send(other).await;
|
||||
event = codex.next_event() => {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(_) => break,
|
||||
};
|
||||
match event {
|
||||
// ignore all legacy delta events
|
||||
Event {
|
||||
id: _,
|
||||
msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_),
|
||||
} => {}
|
||||
Event {
|
||||
id: _,
|
||||
msg: EventMsg::SessionConfigured(_),
|
||||
} => {}
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::ExecApprovalRequest(event),
|
||||
} => {
|
||||
// Initiate approval via parent session; do not surface to consumer.
|
||||
handle_exec_approval(
|
||||
&codex,
|
||||
id,
|
||||
&parent_session,
|
||||
&parent_ctx,
|
||||
event,
|
||||
&cancel_token,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::ApplyPatchApprovalRequest(event),
|
||||
} => {
|
||||
handle_patch_approval(
|
||||
&codex,
|
||||
id,
|
||||
&parent_session,
|
||||
&parent_ctx,
|
||||
event,
|
||||
&cancel_token,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
other => {
|
||||
match tx_sub.send(other).or_cancel(&cancel_token).await {
|
||||
Ok(Ok(())) => {}
|
||||
_ => {
|
||||
shutdown_delegate(&codex).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ask the delegate to stop and drain its events so background sends do not hit a closed channel.
|
||||
async fn shutdown_delegate(codex: &Codex) {
|
||||
let _ = codex.submit(Op::Interrupt).await;
|
||||
let _ = codex.submit(Op::Shutdown {}).await;
|
||||
|
||||
let _ = timeout(Duration::from_millis(500), async {
|
||||
while let Ok(event) = codex.next_event().await {
|
||||
if matches!(
|
||||
event.msg,
|
||||
EventMsg::TurnAborted(_) | EventMsg::TaskComplete(_)
|
||||
) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Forward ops from a caller to a sub-agent, respecting cancellation.
|
||||
async fn forward_ops(
|
||||
codex: Arc<Codex>,
|
||||
@@ -298,3 +338,85 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_channel::bounded;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::RawResponseItemEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn forward_events_cancelled_while_send_blocked_shuts_down_delegate() {
|
||||
let (tx_events, rx_events) = bounded(1);
|
||||
let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY);
|
||||
let codex = Arc::new(Codex {
|
||||
next_id: AtomicU64::new(0),
|
||||
tx_sub,
|
||||
rx_event: rx_events,
|
||||
});
|
||||
|
||||
let (session, ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx();
|
||||
|
||||
let (tx_out, rx_out) = bounded(1);
|
||||
tx_out
|
||||
.send(Event {
|
||||
id: "full".to_string(),
|
||||
msg: EventMsg::TurnAborted(TurnAbortedEvent {
|
||||
reason: TurnAbortReason::Interrupted,
|
||||
}),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cancel = CancellationToken::new();
|
||||
let forward = tokio::spawn(forward_events(
|
||||
Arc::clone(&codex),
|
||||
tx_out.clone(),
|
||||
session,
|
||||
ctx,
|
||||
cancel.clone(),
|
||||
));
|
||||
|
||||
tx_events
|
||||
.send(Event {
|
||||
id: "evt".to_string(),
|
||||
msg: EventMsg::RawResponseItem(RawResponseItemEvent {
|
||||
item: ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "call-1".to_string(),
|
||||
name: "tool".to_string(),
|
||||
input: "{}".to_string(),
|
||||
},
|
||||
}),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
drop(tx_events);
|
||||
cancel.cancel();
|
||||
timeout(std::time::Duration::from_millis(1000), forward)
|
||||
.await
|
||||
.expect("forward_events hung")
|
||||
.expect("forward_events join error");
|
||||
|
||||
let received = rx_out.recv().await.expect("prefilled event missing");
|
||||
assert_eq!("full", received.id);
|
||||
let mut ops = Vec::new();
|
||||
while let Ok(sub) = rx_sub.try_recv() {
|
||||
ops.push(sub.op);
|
||||
}
|
||||
assert!(
|
||||
ops.iter().any(|op| matches!(op, Op::Interrupt)),
|
||||
"expected Interrupt op after cancellation"
|
||||
);
|
||||
assert!(
|
||||
ops.iter().any(|op| matches!(op, Op::Shutdown)),
|
||||
"expected Shutdown op after cancellation"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ use crate::codex::get_last_assistant_message_from_turn;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::features::Feature;
|
||||
use crate::protocol::AgentMessageEvent;
|
||||
use crate::protocol::CompactedItem;
|
||||
use crate::protocol::ContextCompactedEvent;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::TaskStartedEvent;
|
||||
use crate::protocol::TurnContextItem;
|
||||
@@ -175,9 +175,7 @@ async fn run_compact_task_inner(
|
||||
});
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
let event = EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "Compact task completed".to_string(),
|
||||
});
|
||||
let event = EventMsg::ContextCompacted(ContextCompactedEvent {});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
|
||||
let warning = EventMsg::Warning(WarningEvent {
|
||||
|
||||
@@ -4,8 +4,8 @@ use crate::Prompt;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::protocol::AgentMessageEvent;
|
||||
use crate::protocol::CompactedItem;
|
||||
use crate::protocol::ContextCompactedEvent;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::RolloutItem;
|
||||
use crate::protocol::TaskStartedEvent;
|
||||
@@ -74,9 +74,7 @@ async fn run_remote_compact_task_inner_impl(
|
||||
sess.persist_rollout_items(&[RolloutItem::Compacted(compacted_item)])
|
||||
.await;
|
||||
|
||||
let event = EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "Compact task completed".to_string(),
|
||||
});
|
||||
let event = EventMsg::ContextCompacted(ContextCompactedEvent {});
|
||||
sess.send_event(turn_context, event).await;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -70,7 +70,7 @@ pub const GPT_5_CODEX_MEDIUM_MODEL: &str = "gpt-5.1-codex";
|
||||
/// the context window.
|
||||
pub(crate) const PROJECT_DOC_MAX_BYTES: usize = 32 * 1024; // 32 KiB
|
||||
|
||||
pub(crate) const CONFIG_TOML_FILE: &str = "config.toml";
|
||||
pub const CONFIG_TOML_FILE: &str = "config.toml";
|
||||
|
||||
/// Application configuration loaded from disk and merged with overrides.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -267,6 +267,11 @@ pub struct Config {
|
||||
/// Collection of various notices we show the user
|
||||
pub notices: Notice,
|
||||
|
||||
/// When `true`, checks for Codex updates on startup and surfaces update prompts.
|
||||
/// Set to `false` only if your Codex updates are centrally managed.
|
||||
/// Defaults to `true`.
|
||||
pub check_for_update_on_startup: bool,
|
||||
|
||||
/// When true, disables burst-paste detection for typed input entirely.
|
||||
/// All characters are inserted as they are received, and no buffering
|
||||
/// or placeholder replacement will occur for fast keypress bursts.
|
||||
@@ -685,6 +690,11 @@ pub struct ConfigToml {
|
||||
#[serde(default)]
|
||||
pub features: Option<FeaturesToml>,
|
||||
|
||||
/// When `true`, checks for Codex updates on startup and surfaces update prompts.
|
||||
/// Set to `false` only if your Codex updates are centrally managed.
|
||||
/// Defaults to `true`.
|
||||
pub check_for_update_on_startup: Option<bool>,
|
||||
|
||||
/// When true, disables burst-paste detection for typed input entirely.
|
||||
/// All characters are inserted as they are received, and no buffering
|
||||
/// or placeholder replacement will occur for fast keypress bursts.
|
||||
@@ -1162,6 +1172,8 @@ impl Config {
|
||||
.or(cfg.review_model)
|
||||
.unwrap_or_else(default_review_model);
|
||||
|
||||
let check_for_update_on_startup = cfg.check_for_update_on_startup.unwrap_or(true);
|
||||
|
||||
let config = Self {
|
||||
model,
|
||||
review_model,
|
||||
@@ -1238,6 +1250,7 @@ impl Config {
|
||||
active_project,
|
||||
windows_wsl_setup_acknowledged: cfg.windows_wsl_setup_acknowledged.unwrap_or(false),
|
||||
notices: cfg.notice.unwrap_or_default(),
|
||||
check_for_update_on_startup,
|
||||
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
|
||||
tui_notifications: cfg
|
||||
.tui
|
||||
@@ -2992,6 +3005,7 @@ model_verbosity = "high"
|
||||
active_project: ProjectConfig { trust_level: None },
|
||||
windows_wsl_setup_acknowledged: false,
|
||||
notices: Default::default(),
|
||||
check_for_update_on_startup: true,
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
animations: true,
|
||||
@@ -3064,6 +3078,7 @@ model_verbosity = "high"
|
||||
active_project: ProjectConfig { trust_level: None },
|
||||
windows_wsl_setup_acknowledged: false,
|
||||
notices: Default::default(),
|
||||
check_for_update_on_startup: true,
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
animations: true,
|
||||
@@ -3151,6 +3166,7 @@ model_verbosity = "high"
|
||||
active_project: ProjectConfig { trust_level: None },
|
||||
windows_wsl_setup_acknowledged: false,
|
||||
notices: Default::default(),
|
||||
check_for_update_on_startup: true,
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
animations: true,
|
||||
@@ -3224,6 +3240,7 @@ model_verbosity = "high"
|
||||
active_project: ProjectConfig { trust_level: None },
|
||||
windows_wsl_setup_acknowledged: false,
|
||||
notices: Default::default(),
|
||||
check_for_update_on_startup: true,
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
animations: true,
|
||||
|
||||
@@ -11,15 +11,15 @@ use toml::Value as TomlValue;
|
||||
#[cfg(unix)]
|
||||
const CODEX_MANAGED_CONFIG_SYSTEM_PATH: &str = "/etc/codex/managed_config.toml";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct LoadedConfigLayers {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoadedConfigLayers {
|
||||
pub base: TomlValue,
|
||||
pub managed_config: Option<TomlValue>,
|
||||
pub managed_preferences: Option<TomlValue>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct LoaderOverrides {
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct LoaderOverrides {
|
||||
pub managed_config_path: Option<PathBuf>,
|
||||
#[cfg(target_os = "macos")]
|
||||
pub managed_preferences_base64: Option<String>,
|
||||
@@ -47,11 +47,15 @@ pub async fn load_config_as_toml(codex_home: &Path) -> io::Result<TomlValue> {
|
||||
load_config_as_toml_with_overrides(codex_home, LoaderOverrides::default()).await
|
||||
}
|
||||
|
||||
pub async fn load_config_layers(codex_home: &Path) -> io::Result<LoadedConfigLayers> {
|
||||
load_config_layers_with_overrides(codex_home, LoaderOverrides::default()).await
|
||||
}
|
||||
|
||||
fn default_empty_table() -> TomlValue {
|
||||
TomlValue::Table(Default::default())
|
||||
}
|
||||
|
||||
pub(crate) async fn load_config_layers_with_overrides(
|
||||
pub async fn load_config_layers_with_overrides(
|
||||
codex_home: &Path,
|
||||
overrides: LoaderOverrides,
|
||||
) -> io::Result<LoadedConfigLayers> {
|
||||
@@ -130,7 +134,7 @@ async fn read_config_from_path(
|
||||
}
|
||||
|
||||
/// Merge config `overlay` into `base`, giving `overlay` precedence.
|
||||
pub(crate) fn merge_toml_values(base: &mut TomlValue, overlay: &TomlValue) {
|
||||
pub fn merge_toml_values(base: &mut TomlValue, overlay: &TomlValue) {
|
||||
if let TomlValue::Table(overlay_table) = overlay
|
||||
&& let TomlValue::Table(base_table) = base
|
||||
{
|
||||
@@ -147,6 +151,10 @@ pub(crate) fn merge_toml_values(base: &mut TomlValue, overlay: &TomlValue) {
|
||||
}
|
||||
|
||||
fn managed_config_default_path(codex_home: &Path) -> PathBuf {
|
||||
if let Ok(path) = std::env::var("CODEX_MANAGED_CONFIG_PATH") {
|
||||
return PathBuf::from(path);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let _ = codex_home;
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::codex::TurnContext;
|
||||
use crate::context_manager::normalize;
|
||||
use crate::truncate::TruncationPolicy;
|
||||
use crate::truncate::approx_token_count;
|
||||
use crate::truncate::approx_tokens_from_byte_count;
|
||||
use crate::truncate::truncate_function_output_items_with_policy;
|
||||
use crate::truncate::truncate_text;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
@@ -83,9 +84,19 @@ impl ContextManager {
|
||||
.unwrap_or(i64::MAX);
|
||||
|
||||
let items_tokens = self.items.iter().fold(0i64, |acc, item| {
|
||||
let serialized = serde_json::to_string(item).unwrap_or_default();
|
||||
let item_tokens = i64::try_from(approx_token_count(&serialized)).unwrap_or(i64::MAX);
|
||||
acc.saturating_add(item_tokens)
|
||||
acc + match item {
|
||||
ResponseItem::Reasoning {
|
||||
encrypted_content: Some(content),
|
||||
..
|
||||
}
|
||||
| ResponseItem::CompactionSummary {
|
||||
encrypted_content: content,
|
||||
} => estimate_reasoning_length(content.len()) as i64,
|
||||
item => {
|
||||
let serialized = serde_json::to_string(item).unwrap_or_default();
|
||||
i64::try_from(approx_token_count(&serialized)).unwrap_or(i64::MAX)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Some(base_tokens.saturating_add(items_tokens))
|
||||
@@ -119,6 +130,46 @@ impl ContextManager {
|
||||
);
|
||||
}
|
||||
|
||||
fn get_non_last_reasoning_items_tokens(&self) -> usize {
|
||||
// get reasoning items excluding all the ones after the last user message
|
||||
let Some(last_user_index) = self
|
||||
.items
|
||||
.iter()
|
||||
.rposition(|item| matches!(item, ResponseItem::Message { role, .. } if role == "user"))
|
||||
else {
|
||||
return 0usize;
|
||||
};
|
||||
|
||||
let total_reasoning_bytes = self
|
||||
.items
|
||||
.iter()
|
||||
.take(last_user_index)
|
||||
.filter_map(|item| {
|
||||
if let ResponseItem::Reasoning {
|
||||
encrypted_content: Some(content),
|
||||
..
|
||||
} = item
|
||||
{
|
||||
Some(content.len())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.map(estimate_reasoning_length)
|
||||
.fold(0usize, usize::saturating_add);
|
||||
|
||||
let token_estimate = approx_tokens_from_byte_count(total_reasoning_bytes);
|
||||
token_estimate as usize
|
||||
}
|
||||
|
||||
pub(crate) fn get_total_token_usage(&self) -> i64 {
|
||||
self.token_info
|
||||
.as_ref()
|
||||
.map(|info| info.last_token_usage.total_tokens)
|
||||
.unwrap_or(0)
|
||||
.saturating_add(self.get_non_last_reasoning_items_tokens() as i64)
|
||||
}
|
||||
|
||||
/// This function enforces a couple of invariants on the in-memory history:
|
||||
/// 1. every call (function/custom) has a corresponding output entry
|
||||
/// 2. every output has a corresponding call entry
|
||||
@@ -198,6 +249,14 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_reasoning_length(encoded_len: usize) -> usize {
|
||||
encoded_len
|
||||
.saturating_mul(3)
|
||||
.checked_div(4)
|
||||
.unwrap_or(0)
|
||||
.saturating_sub(650)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "history_tests.rs"]
|
||||
mod tests;
|
||||
|
||||
@@ -56,6 +56,17 @@ fn reasoning_msg(text: &str) -> ResponseItem {
|
||||
}
|
||||
}
|
||||
|
||||
fn reasoning_with_encrypted_content(len: usize) -> ResponseItem {
|
||||
ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: vec![ReasoningItemReasoningSummary::SummaryText {
|
||||
text: "summary".to_string(),
|
||||
}],
|
||||
content: None,
|
||||
encrypted_content: Some("a".repeat(len)),
|
||||
}
|
||||
}
|
||||
|
||||
fn truncate_exec_output(content: &str) -> String {
|
||||
truncate::truncate_text(content, TruncationPolicy::Tokens(EXEC_FORMAT_MAX_TOKENS))
|
||||
}
|
||||
@@ -112,6 +123,28 @@ fn filters_non_api_messages() {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_last_reasoning_tokens_return_zero_when_no_user_messages() {
|
||||
let history = create_history_with_items(vec![reasoning_with_encrypted_content(800)]);
|
||||
|
||||
assert_eq!(history.get_non_last_reasoning_items_tokens(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_last_reasoning_tokens_ignore_entries_after_last_user() {
|
||||
let history = create_history_with_items(vec![
|
||||
reasoning_with_encrypted_content(900),
|
||||
user_msg("first"),
|
||||
reasoning_with_encrypted_content(1_000),
|
||||
user_msg("second"),
|
||||
reasoning_with_encrypted_content(2_000),
|
||||
]);
|
||||
// first: (900 * 0.75 - 650) / 4 = 6.25 tokens
|
||||
// second: (1000 * 0.75 - 650) / 4 = 25 tokens
|
||||
// first + second = 62.5
|
||||
assert_eq!(history.get_non_last_reasoning_items_tokens(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_history_for_prompt_drops_ghost_commits() {
|
||||
let items = vec![ResponseItem::GhostSnapshot {
|
||||
|
||||
@@ -56,6 +56,10 @@ impl ConversationManager {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn session_source(&self) -> SessionSource {
|
||||
self.session_source.clone()
|
||||
}
|
||||
|
||||
pub async fn new_conversation(&self, config: Config) -> CodexResult<NewConversation> {
|
||||
self.spawn_conversation(config, self.auth_manager.clone())
|
||||
.await
|
||||
|
||||
@@ -258,6 +258,11 @@ fn sanitize_user_agent(candidate: String, fallback: &str) -> String {
|
||||
|
||||
/// Create an HTTP client with default `originator` and `User-Agent` headers set.
|
||||
pub fn create_client() -> CodexHttpClient {
|
||||
let inner = build_reqwest_client();
|
||||
CodexHttpClient::new(inner)
|
||||
}
|
||||
|
||||
pub fn build_reqwest_client() -> reqwest::Client {
|
||||
use reqwest::header::HeaderMap;
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
@@ -272,8 +277,7 @@ pub fn create_client() -> CodexHttpClient {
|
||||
builder = builder.no_proxy();
|
||||
}
|
||||
|
||||
let inner = builder.build().unwrap_or_else(|_| reqwest::Client::new());
|
||||
CodexHttpClient::new(inner)
|
||||
builder.build().unwrap_or_else(|_| reqwest::Client::new())
|
||||
}
|
||||
|
||||
fn is_sandboxed() -> bool {
|
||||
|
||||
@@ -19,6 +19,7 @@ use tokio_util::sync::CancellationToken;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::error::SandboxErr;
|
||||
use crate::get_platform_sandbox;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::ExecCommandOutputDeltaEvent;
|
||||
@@ -127,12 +128,17 @@ pub struct StdoutStream {
|
||||
|
||||
pub async fn process_exec_tool_call(
|
||||
params: ExecParams,
|
||||
sandbox_type: SandboxType,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
sandbox_cwd: &Path,
|
||||
codex_linux_sandbox_exe: &Option<PathBuf>,
|
||||
stdout_stream: Option<StdoutStream>,
|
||||
) -> Result<ExecToolCallOutput> {
|
||||
let sandbox_type = match &sandbox_policy {
|
||||
SandboxPolicy::DangerFullAccess => SandboxType::None,
|
||||
_ => get_platform_sandbox().unwrap_or(SandboxType::None),
|
||||
};
|
||||
tracing::debug!("Sandbox type: {sandbox_type:?}");
|
||||
|
||||
let ExecParams {
|
||||
command,
|
||||
cwd,
|
||||
@@ -893,7 +899,6 @@ mod tests {
|
||||
});
|
||||
let result = process_exec_tool_call(
|
||||
params,
|
||||
SandboxType::None,
|
||||
&SandboxPolicy::DangerFullAccess,
|
||||
cwd.as_path(),
|
||||
&None,
|
||||
|
||||
@@ -107,7 +107,9 @@ fn evaluate_with_policy(
|
||||
})
|
||||
}
|
||||
}
|
||||
Decision::Allow => Some(ApprovalRequirement::Skip),
|
||||
Decision::Allow => Some(ApprovalRequirement::Skip {
|
||||
bypass_sandbox: true,
|
||||
}),
|
||||
},
|
||||
Evaluation::NoMatch { .. } => None,
|
||||
}
|
||||
@@ -132,7 +134,9 @@ pub(crate) fn create_approval_requirement_for_command(
|
||||
) {
|
||||
ApprovalRequirement::NeedsApproval { reason: None }
|
||||
} else {
|
||||
ApprovalRequirement::Skip
|
||||
ApprovalRequirement::Skip {
|
||||
bypass_sandbox: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
// the TUI or the tracing stack).
|
||||
#![deny(clippy::print_stdout, clippy::print_stderr)]
|
||||
|
||||
pub mod api_bridge;
|
||||
mod apply_patch;
|
||||
pub mod auth;
|
||||
pub mod bash;
|
||||
mod chat_completions;
|
||||
mod client;
|
||||
mod client_common;
|
||||
pub mod codex;
|
||||
@@ -32,6 +32,9 @@ pub mod git_info;
|
||||
pub mod landlock;
|
||||
pub mod mcp;
|
||||
mod mcp_connection_manager;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_NOTIFICATION;
|
||||
pub use mcp_connection_manager::SandboxState;
|
||||
mod mcp_tool_call;
|
||||
mod message_history;
|
||||
mod model_provider_info;
|
||||
@@ -39,6 +42,7 @@ pub mod parse_command;
|
||||
pub mod powershell;
|
||||
mod response_processing;
|
||||
pub mod sandboxing;
|
||||
pub mod status;
|
||||
mod text_encoding;
|
||||
pub mod token_data;
|
||||
mod truncate;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user