mirror of
https://github.com/openai/codex.git
synced 2026-02-04 16:03:46 +00:00
Compare commits
139 Commits
daniel/rel
...
patch-tool
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
570639cf98 | ||
|
|
1c50fbb8a7 | ||
|
|
3316d04ed4 | ||
|
|
67a8566f59 | ||
|
|
2d36621f48 | ||
|
|
0a70810fc0 | ||
|
|
b5cf9e09ff | ||
|
|
b2067c73d9 | ||
|
|
13e8771ee9 | ||
|
|
bba567cee9 | ||
|
|
6577197fa4 | ||
|
|
fd1e12f34e | ||
|
|
ba6af23cb6 | ||
|
|
f805d17930 | ||
|
|
9580603fed | ||
|
|
da38a8f56a | ||
|
|
552a438cc9 | ||
|
|
a36a273d4e | ||
|
|
6884c6ccf6 | ||
|
|
1e5a613c55 | ||
|
|
90965fbc84 | ||
|
|
4fee2ca3fd | ||
|
|
3318cf9369 | ||
|
|
5ba0bcf035 | ||
|
|
6d55ef62f9 | ||
|
|
cecf3a82a6 | ||
|
|
c172e8e997 | ||
|
|
9bbeb75361 | ||
|
|
6ccd32c601 | ||
|
|
3b5a5412bb | ||
|
|
44bb53df1e | ||
|
|
9a7266a33f | ||
|
|
2abad8fece | ||
|
|
0d4a25b981 | ||
|
|
8453915e02 | ||
|
|
44587c2443 | ||
|
|
8f7b22b652 | ||
|
|
027944c64e | ||
|
|
bec51f6c05 | ||
|
|
66967500bb | ||
|
|
167b4f0e25 | ||
|
|
167154178b | ||
|
|
674e3d3c90 | ||
|
|
114ce9ff4d | ||
|
|
e13b35ecb0 | ||
|
|
377af75730 | ||
|
|
86e0f31a7e | ||
|
|
8f837f1093 | ||
|
|
162e1235a8 | ||
|
|
c09ed74a16 | ||
|
|
65f3528cad | ||
|
|
44262d8fd8 | ||
|
|
95a9938d3a | ||
|
|
f69f07b028 | ||
|
|
8d766088e6 | ||
|
|
87654ec0b7 | ||
|
|
51d9e05de7 | ||
|
|
8068cc75f8 | ||
|
|
acb28bf914 | ||
|
|
97338de578 | ||
|
|
5200b7a95d | ||
|
|
64e6c4afbb | ||
|
|
39db113cc9 | ||
|
|
45bd5ca4b9 | ||
|
|
c13c3dadbf | ||
|
|
8636bff46d | ||
|
|
43809a454e | ||
|
|
5c48600bb3 | ||
|
|
de6559f2ab | ||
|
|
5bcc9d8b77 | ||
|
|
5eab4c7ab4 | ||
|
|
f656e192bf | ||
|
|
ee5ecae7c0 | ||
|
|
58bb2048ac | ||
|
|
ac8a3155d6 | ||
|
|
ace14e8d36 | ||
|
|
2a76a08a9e | ||
|
|
16309d6b68 | ||
|
|
62bd0e3d9d | ||
|
|
a9c68ea270 | ||
|
|
ac58749bd3 | ||
|
|
79cbd2ab1b | ||
|
|
5eaaf307e1 | ||
|
|
18330c2362 | ||
|
|
4c46490e53 | ||
|
|
5c1416d99b | ||
|
|
0525b48baa | ||
|
|
1f4f9cde8e | ||
|
|
cad37009e1 | ||
|
|
e2b3053b2b | ||
|
|
e47bd33689 | ||
|
|
6b878bea01 | ||
|
|
ca46510fd3 | ||
|
|
6efb52e545 | ||
|
|
d84a799ec0 | ||
|
|
c8fab51372 | ||
|
|
58d77ca4e7 | ||
|
|
0269096229 | ||
|
|
70a6d4b1b4 | ||
|
|
b1d5f7c0bd | ||
|
|
066c6cce02 | ||
|
|
bd65f81e54 | ||
|
|
ba9620aea7 | ||
|
|
45c3b20041 | ||
|
|
6cfc012e9d | ||
|
|
17a80d43c8 | ||
|
|
c11696f6b1 | ||
|
|
5775174ec2 | ||
|
|
ba631e7928 | ||
|
|
db3834733a | ||
|
|
d6182becbe | ||
|
|
323a5cb7e7 | ||
|
|
3f40fbc0a8 | ||
|
|
742feaf40f | ||
|
|
907d3dd348 | ||
|
|
7df9e9c664 | ||
|
|
b795fbe244 | ||
|
|
82ed7bd285 | ||
|
|
1c04e1314d | ||
|
|
bef7ed0ccc | ||
|
|
be23fe1353 | ||
|
|
2073fa7139 | ||
|
|
e60a44cbab | ||
|
|
075e385969 | ||
|
|
aa083b795d | ||
|
|
91708bb031 | ||
|
|
82dfec5b10 | ||
|
|
1e82bf9d98 | ||
|
|
0a83db5512 | ||
|
|
bd4fa85507 | ||
|
|
234c0a0469 | ||
|
|
0f4ae1b5b0 | ||
|
|
2b96f9f569 | ||
|
|
f2036572b6 | ||
|
|
bea64569c1 | ||
|
|
e83c5f429c | ||
|
|
ed0d23d560 | ||
|
|
4ae45a6c8d | ||
|
|
6b83c1c3f3 |
23
.github/workflows/ci.yml
vendored
23
.github/workflows/ci.yml
vendored
@@ -14,33 +14,18 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10.8.1
|
||||
run_install: false
|
||||
|
||||
- name: Get pnpm store directory
|
||||
id: pnpm-cache
|
||||
shell: bash
|
||||
run: |
|
||||
echo "store_path=$(pnpm store path --silent)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Setup pnpm cache
|
||||
uses: actions/cache@v4
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v5
|
||||
with:
|
||||
path: ${{ steps.pnpm-cache.outputs.store_path }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
node-version: 22
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Run all tasks using workspace filters
|
||||
|
||||
|
||||
35
.github/workflows/rust-ci.yml
vendored
35
.github/workflows/rust-ci.yml
vendored
@@ -62,6 +62,26 @@ jobs:
|
||||
components: rustfmt
|
||||
- name: cargo fmt
|
||||
run: cargo fmt -- --config imports_granularity=Item --check
|
||||
- name: Verify codegen for mcp-types
|
||||
run: ./mcp-types/check_lib_rs.py
|
||||
|
||||
cargo_shear:
|
||||
name: cargo shear
|
||||
runs-on: ubuntu-24.04
|
||||
needs: changed
|
||||
if: ${{ needs.changed.outputs.codex == 'true' || needs.changed.outputs.workflows == 'true' || github.event_name == 'push' }}
|
||||
defaults:
|
||||
run:
|
||||
working-directory: codex-rs
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: dtolnay/rust-toolchain@1.89
|
||||
- uses: taiki-e/install-action@0c5db7f7f897c03b771660e91d065338615679f4 # v2
|
||||
with:
|
||||
tool: cargo-shear
|
||||
version: 1.5.1
|
||||
- name: cargo shear
|
||||
run: cargo shear
|
||||
|
||||
# --- CI to validate on different os/targets --------------------------------
|
||||
lint_build_test:
|
||||
@@ -160,12 +180,17 @@ jobs:
|
||||
find . -name Cargo.toml -mindepth 2 -maxdepth 2 -print0 \
|
||||
| xargs -0 -n1 -I{} bash -c 'cd "$(dirname "{}")" && cargo check --profile ${{ matrix.profile }}'
|
||||
|
||||
- name: cargo test
|
||||
- uses: taiki-e/install-action@0c5db7f7f897c03b771660e91d065338615679f4 # v2
|
||||
with:
|
||||
tool: nextest
|
||||
version: 0.9.103
|
||||
|
||||
- name: tests
|
||||
id: test
|
||||
# `cargo test` takes too long for release builds to run them on every PR
|
||||
# Tests take too long for release builds to run them on every PR.
|
||||
if: ${{ matrix.profile != 'release' }}
|
||||
continue-on-error: true
|
||||
run: cargo test --all-features --target ${{ matrix.target }} --profile ${{ matrix.profile }}
|
||||
run: cargo nextest run --all-features --no-fail-fast --target ${{ matrix.target }}
|
||||
env:
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
@@ -182,7 +207,7 @@ jobs:
|
||||
# --- Gatherer job that you mark as the ONLY required status -----------------
|
||||
results:
|
||||
name: CI results (required)
|
||||
needs: [changed, general, lint_build_test]
|
||||
needs: [changed, general, cargo_shear, lint_build_test]
|
||||
if: always()
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
@@ -190,6 +215,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
echo "general: ${{ needs.general.result }}"
|
||||
echo "shear : ${{ needs.cargo_shear.result }}"
|
||||
echo "matrix : ${{ needs.lint_build_test.result }}"
|
||||
|
||||
# If nothing relevant changed (PR touching only root README, etc.),
|
||||
@@ -201,4 +227,5 @@ jobs:
|
||||
|
||||
# Otherwise require the jobs to have succeeded
|
||||
[[ '${{ needs.general.result }}' == 'success' ]] || { echo 'general failed'; exit 1; }
|
||||
[[ '${{ needs.cargo_shear.result }}' == 'success' ]] || { echo 'cargo_shear failed'; exit 1; }
|
||||
[[ '${{ needs.lint_build_test.result }}' == 'success' ]] || { echo 'matrix failed'; exit 1; }
|
||||
|
||||
19
.github/workflows/rust-release.yml
vendored
19
.github/workflows/rust-release.yml
vendored
@@ -219,3 +219,22 @@ jobs:
|
||||
with:
|
||||
tag: ${{ github.ref_name }}
|
||||
config: .github/dotslash-config.json
|
||||
|
||||
update-branch:
|
||||
name: Update latest-alpha-cli branch
|
||||
permissions:
|
||||
contents: write
|
||||
needs: release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Update latest-alpha-cli branch
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
gh api \
|
||||
repos/${GITHUB_REPOSITORY}/git/refs/heads/latest-alpha-cli \
|
||||
-X PATCH \
|
||||
-f sha="${GITHUB_SHA}" \
|
||||
-F force=true
|
||||
|
||||
6
.vscode/extensions.json
vendored
6
.vscode/extensions.json
vendored
@@ -1,5 +1,11 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"rust-lang.rust-analyzer",
|
||||
"tamasfe.even-better-toml",
|
||||
"vadimcn.vscode-lldb",
|
||||
|
||||
// Useful if touching files in .github/workflows, though most
|
||||
// contributors will not be doing that?
|
||||
// "github.vscode-github-actions",
|
||||
]
|
||||
}
|
||||
|
||||
16
AGENTS.md
16
AGENTS.md
@@ -11,7 +11,7 @@ In the codex-rs folder where the rust code lives:
|
||||
Run `just fmt` (in `codex-rs` directory) automatically after making Rust code changes; do not ask for approval to run it. Before finalizing a change to `codex-rs`, run `just fix -p <project>` (in `codex-rs` directory) to fix any linter issues in the code. Prefer scoping with `-p` to avoid slow workspace‑wide Clippy builds; only run `just fix` without `-p` if you changed shared crates. Additionally, run the tests:
|
||||
1. Run the test for the specific project that was changed. For example, if changes were made in `codex-rs/tui`, run `cargo test -p codex-tui`.
|
||||
2. Once those pass, if any changes were made in common, core, or protocol, run the complete test suite with `cargo test --all-features`.
|
||||
When running interactively, ask the user before running `just fix` and tests to finalize; `just fmt` does not require approval.
|
||||
When running interactively, ask the user before running `just fix` to finalize. `just fmt` does not require approval. project-specific or individual tests can be run without asking the user, but do ask the user before running the complete test suite.
|
||||
|
||||
## TUI style conventions
|
||||
|
||||
@@ -37,7 +37,15 @@ See `codex-rs/tui/styles.md`.
|
||||
- Avoid churn: don’t refactor between equivalent forms (Span::styled ↔ set_style, Line::from ↔ .into()) without a clear readability or functional gain; follow file‑local conventions and do not introduce type annotations solely to satisfy .into().
|
||||
- Compactness: prefer the form that stays on one line after rustfmt; if only one of Line::from(vec![…]) or vec![…].into() avoids wrapping, choose that. If both wrap, pick the one with fewer wrapped lines.
|
||||
|
||||
## Snapshot tests
|
||||
### Text wrapping
|
||||
- Always use textwrap::wrap to wrap plain strings.
|
||||
- If you have a ratatui Line and you want to wrap it, use the helpers in tui/src/wrapping.rs, e.g. word_wrap_lines / word_wrap_line.
|
||||
- If you need to indent wrapped lines, use the initial_indent / subsequent_indent options from RtOptions if you can, rather than writing custom logic.
|
||||
- If you have a list of lines and you need to prefix them all with some prefix (optionally different on the first vs subsequent lines), use the `prefix_lines` helper from line_utils.
|
||||
|
||||
## Tests
|
||||
|
||||
### Snapshot tests
|
||||
|
||||
This repo uses snapshot tests (via `insta`), especially in `codex-rs/tui`, to validate rendered output. When UI or text output changes intentionally, update the snapshots as follows:
|
||||
|
||||
@@ -52,3 +60,7 @@ This repo uses snapshot tests (via `insta`), especially in `codex-rs/tui`, to va
|
||||
|
||||
If you don’t have the tool:
|
||||
- `cargo install cargo-insta`
|
||||
|
||||
### Test assertions
|
||||
|
||||
- Tests should use pretty_assertions::assert_eq for clearer diffs. Import this at the top of the test module if it isn't already.
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
<p align="center"><code>npm i -g @openai/codex</code><br />or <code>brew install codex</code></p>
|
||||
|
||||
<p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.</br>If you are looking for the <em>cloud-based agent</em> from OpenAI, <strong>Codex Web</strong>, see <a href="https://chatgpt.com/codex">chatgpt.com/codex</a>.</p>
|
||||
<p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.
|
||||
</br>
|
||||
</br>If you want Codex in your code editor (VS Code, Cursor, Windsurf), <a href="https://developers.openai.com/codex/ide">install in your IDE</a>
|
||||
</br>If you are looking for the <em>cloud-based agent</em> from OpenAI, <strong>Codex Web</strong>, go to <a href="https://chatgpt.com/codex">chatgpt.com/codex</a></p>
|
||||
|
||||
<p align="center">
|
||||
<img src="./.github/codex-cli-splash.png" alt="Codex CLI splash" width="80%" />
|
||||
@@ -75,7 +78,7 @@ Codex CLI supports a rich set of configuration options, with preferences stored
|
||||
- [CLI usage](./docs/getting-started.md#cli-usage)
|
||||
- [Running with a prompt as input](./docs/getting-started.md#running-with-a-prompt-as-input)
|
||||
- [Example prompts](./docs/getting-started.md#example-prompts)
|
||||
- [Memory with AGENTS.md](./docs/getting-started.md#memory--project-docs)
|
||||
- [Memory with AGENTS.md](./docs/getting-started.md#memory-with-agentsmd)
|
||||
- [Configuration](./docs/config.md)
|
||||
- [**Sandbox & approvals**](./docs/sandbox.md)
|
||||
- [**Authentication**](./docs/authentication.md)
|
||||
|
||||
985
codex-rs/Cargo.lock
generated
985
codex-rs/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -34,6 +34,7 @@ rust = {}
|
||||
|
||||
[workspace.lints.clippy]
|
||||
expect_used = "deny"
|
||||
redundant_clone = "deny"
|
||||
uninlined_format_args = "deny"
|
||||
unwrap_used = "deny"
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ npx @modelcontextprotocol/inspector codex mcp
|
||||
|
||||
You can enable notifications by configuring a script that is run whenever the agent finishes a turn. The [notify documentation](../docs/config.md#notify) includes a detailed example that explains how to get desktop notifications via [terminal-notifier](https://github.com/julienXX/terminal-notifier) on macOS.
|
||||
|
||||
### `codex exec` to run Codex programmatially/non-interactively
|
||||
### `codex exec` to run Codex programmatically/non-interactively
|
||||
|
||||
To run Codex non-interactively, run `codex exec PROMPT` (you can also pass the prompt via `stdin`) and Codex will work on your task until it decides that it is done and exits. Output is printed to the terminal directly. You can set the `RUST_LOG` environment variable to see more about what's going on.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ workspace = true
|
||||
anyhow = "1"
|
||||
similar = "2.7.0"
|
||||
thiserror = "2.0.16"
|
||||
tree-sitter = "0.25.8"
|
||||
tree-sitter = "0.25.9"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
once_cell = "1"
|
||||
|
||||
|
||||
@@ -726,13 +726,15 @@ fn compute_replacements(
|
||||
line_index = start_idx + pattern.len();
|
||||
} else {
|
||||
return Err(ApplyPatchError::ComputeReplacements(format!(
|
||||
"Failed to find expected lines {:?} in {}",
|
||||
chunk.old_lines,
|
||||
path.display()
|
||||
"Failed to find expected lines in {}:\n{}",
|
||||
path.display(),
|
||||
chunk.old_lines.join("\n"),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
replacements.sort_by(|(lhs_idx, _, _), (rhs_idx, _, _)| lhs_idx.cmp(rhs_idx));
|
||||
|
||||
Ok(replacements)
|
||||
}
|
||||
|
||||
@@ -1216,6 +1218,33 @@ PATCH"#,
|
||||
assert_eq!(contents, "a\nB\nc\nd\nE\nf\ng\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pure_addition_chunk_followed_by_removal() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("panic.txt");
|
||||
fs::write(&path, "line1\nline2\nline3\n").unwrap();
|
||||
let patch = wrap_patch(&format!(
|
||||
r#"*** Update File: {}
|
||||
@@
|
||||
+after-context
|
||||
+second-line
|
||||
@@
|
||||
line1
|
||||
-line2
|
||||
-line3
|
||||
+line2-replacement"#,
|
||||
path.display()
|
||||
));
|
||||
let mut stdout = Vec::new();
|
||||
let mut stderr = Vec::new();
|
||||
apply_patch(&patch, &mut stdout, &mut stderr).unwrap();
|
||||
let contents = fs::read_to_string(path).unwrap();
|
||||
assert_eq!(
|
||||
contents,
|
||||
"line1\nline2-replacement\nafter-context\nsecond-line\n"
|
||||
);
|
||||
}
|
||||
|
||||
/// Ensure that patches authored with ASCII characters can update lines that
|
||||
/// contain typographic Unicode punctuation (e.g. EN DASH, NON-BREAKING
|
||||
/// HYPHEN). Historically `git apply` succeeds in such scenarios but our
|
||||
|
||||
@@ -617,7 +617,7 @@ fn test_parse_patch_lenient() {
|
||||
assert_eq!(
|
||||
parse_patch_text(&patch_text_in_double_quoted_heredoc, ParseMode::Lenient),
|
||||
Ok(ApplyPatchArgs {
|
||||
hunks: expected_patch.clone(),
|
||||
hunks: expected_patch,
|
||||
patch: patch_text.to_string(),
|
||||
workdir: None,
|
||||
})
|
||||
@@ -637,7 +637,7 @@ fn test_parse_patch_lenient() {
|
||||
"<<EOF\n*** Begin Patch\n*** Update File: file2.py\nEOF\n".to_string();
|
||||
assert_eq!(
|
||||
parse_patch_text(&patch_text_with_missing_closing_heredoc, ParseMode::Strict),
|
||||
Err(expected_error.clone())
|
||||
Err(expected_error)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_patch_text(&patch_text_with_missing_closing_heredoc, ParseMode::Lenient),
|
||||
|
||||
@@ -21,8 +21,7 @@ const MISSPELLED_APPLY_PATCH_ARG0: &str = "applypatch";
|
||||
/// `codex-linux-sandbox` we *directly* execute
|
||||
/// [`codex_linux_sandbox::run_main`] (which never returns). Otherwise we:
|
||||
///
|
||||
/// 1. Use [`dotenvy::from_path`] and [`dotenvy::dotenv`] to modify the
|
||||
/// environment before creating any threads.
|
||||
/// 1. Load `.env` values from `~/.codex/.env` before creating any threads.
|
||||
/// 2. Construct a Tokio multi-thread runtime.
|
||||
/// 3. Derive the path to the current executable (so children can re-invoke the
|
||||
/// sandbox) when running on Linux.
|
||||
@@ -106,7 +105,7 @@ where
|
||||
|
||||
const ILLEGAL_ENV_VAR_PREFIX: &str = "CODEX_";
|
||||
|
||||
/// Load env vars from ~/.codex/.env and `$(pwd)/.env`.
|
||||
/// Load env vars from ~/.codex/.env.
|
||||
///
|
||||
/// Security: Do not allow `.env` files to create or modify any variables
|
||||
/// with names starting with `CODEX_`.
|
||||
@@ -116,10 +115,6 @@ fn load_dotenv() {
|
||||
{
|
||||
set_filtered(iter);
|
||||
}
|
||||
|
||||
if let Ok(iter) = dotenvy::dotenv_iter() {
|
||||
set_filtered(iter);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to set vars from a dotenvy iterator while filtering out `CODEX_` keys.
|
||||
|
||||
@@ -11,8 +11,6 @@ anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
codex-core = { path = "../core" }
|
||||
codex-protocol = { path = "../protocol" }
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
@@ -31,7 +31,7 @@ pub async fn run_apply_command(
|
||||
ConfigOverrides::default(),
|
||||
)?;
|
||||
|
||||
init_chatgpt_token_from_auth(&config.codex_home, &config.responses_originator_header).await?;
|
||||
init_chatgpt_token_from_auth(&config.codex_home).await?;
|
||||
|
||||
let task_response = get_task(&config, apply_cli.task_id).await?;
|
||||
apply_diff_from_task(task_response, cwd).await
|
||||
|
||||
@@ -13,10 +13,10 @@ pub(crate) async fn chatgpt_get_request<T: DeserializeOwned>(
|
||||
path: String,
|
||||
) -> anyhow::Result<T> {
|
||||
let chatgpt_base_url = &config.chatgpt_base_url;
|
||||
init_chatgpt_token_from_auth(&config.codex_home, &config.responses_originator_header).await?;
|
||||
init_chatgpt_token_from_auth(&config.codex_home).await?;
|
||||
|
||||
// Make direct HTTP request to ChatGPT backend API with the token
|
||||
let client = create_client(&config.responses_originator_header);
|
||||
let client = create_client();
|
||||
let url = format!("{chatgpt_base_url}{path}");
|
||||
|
||||
let token =
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use std::path::Path;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::RwLock;
|
||||
@@ -19,11 +18,8 @@ pub fn set_chatgpt_token_data(value: TokenData) {
|
||||
}
|
||||
|
||||
/// Initialize the ChatGPT token from auth.json file
|
||||
pub async fn init_chatgpt_token_from_auth(
|
||||
codex_home: &Path,
|
||||
originator: &str,
|
||||
) -> std::io::Result<()> {
|
||||
let auth = CodexAuth::from_codex_home(codex_home, AuthMode::ChatGPT, originator)?;
|
||||
pub async fn init_chatgpt_token_from_auth(codex_home: &Path) -> std::io::Result<()> {
|
||||
let auth = CodexAuth::from_codex_home(codex_home)?;
|
||||
if let Some(auth) = auth {
|
||||
let token_data = auth.get_token_data().await?;
|
||||
set_chatgpt_token_data(token_data);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::auth::CLIENT_ID;
|
||||
use codex_core::auth::OPENAI_API_KEY_ENV_VAR;
|
||||
use codex_core::auth::login_with_api_key;
|
||||
use codex_core::auth::logout;
|
||||
use codex_core::config::Config;
|
||||
@@ -9,7 +8,6 @@ use codex_core::config::ConfigOverrides;
|
||||
use codex_login::ServerOptions;
|
||||
use codex_login::run_login_server;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {
|
||||
@@ -60,23 +58,11 @@ pub async fn run_login_with_api_key(
|
||||
pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
|
||||
let config = load_config_or_exit(cli_config_overrides);
|
||||
|
||||
match CodexAuth::from_codex_home(
|
||||
&config.codex_home,
|
||||
config.preferred_auth_method,
|
||||
&config.responses_originator_header,
|
||||
) {
|
||||
match CodexAuth::from_codex_home(&config.codex_home) {
|
||||
Ok(Some(auth)) => match auth.mode {
|
||||
AuthMode::ApiKey => match auth.get_token().await {
|
||||
Ok(api_key) => {
|
||||
eprintln!("Logged in using an API key - {}", safe_format_key(&api_key));
|
||||
|
||||
if let Ok(env_api_key) = env::var(OPENAI_API_KEY_ENV_VAR)
|
||||
&& env_api_key == api_key
|
||||
{
|
||||
eprintln!(
|
||||
" API loaded from OPENAI_API_KEY environment variable or .env file"
|
||||
);
|
||||
}
|
||||
std::process::exit(0);
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
@@ -37,11 +37,8 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> {
|
||||
|
||||
let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?;
|
||||
// Use conversation_manager API to start a conversation
|
||||
let conversation_manager = ConversationManager::new(AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
config.preferred_auth_method,
|
||||
config.responses_originator_header.clone(),
|
||||
));
|
||||
let conversation_manager =
|
||||
ConversationManager::new(AuthManager::shared(config.codex_home.clone()));
|
||||
let NewConversation {
|
||||
conversation_id: _,
|
||||
conversation,
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Returns a string representing the elapsed time since `start_time` like
|
||||
/// "1m15s" or "1.50s".
|
||||
/// "1m 15s" or "1.50s".
|
||||
pub fn format_elapsed(start_time: Instant) -> String {
|
||||
format_duration(start_time.elapsed())
|
||||
}
|
||||
@@ -12,7 +12,7 @@ pub fn format_elapsed(start_time: Instant) -> String {
|
||||
/// Formatting rules:
|
||||
/// * < 1 s -> "{milli}ms"
|
||||
/// * < 60 s -> "{sec:.2}s" (two decimal places)
|
||||
/// * >= 60 s -> "{min}m{sec:02}s"
|
||||
/// * >= 60 s -> "{min}m {sec:02}s"
|
||||
pub fn format_duration(duration: Duration) -> String {
|
||||
let millis = duration.as_millis() as i64;
|
||||
format_elapsed_millis(millis)
|
||||
@@ -26,7 +26,7 @@ fn format_elapsed_millis(millis: i64) -> String {
|
||||
} else {
|
||||
let minutes = millis / 60_000;
|
||||
let seconds = (millis % 60_000) / 1000;
|
||||
format!("{minutes}m{seconds:02}s")
|
||||
format!("{minutes}m {seconds:02}s")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,12 +61,18 @@ mod tests {
|
||||
fn test_format_duration_minutes() {
|
||||
// Durations ≥ 1 minute should be printed mmss.
|
||||
let dur = Duration::from_millis(75_000); // 1m15s
|
||||
assert_eq!(format_duration(dur), "1m15s");
|
||||
assert_eq!(format_duration(dur), "1m 15s");
|
||||
|
||||
let dur_exact = Duration::from_millis(60_000); // 1m0s
|
||||
assert_eq!(format_duration(dur_exact), "1m00s");
|
||||
assert_eq!(format_duration(dur_exact), "1m 00s");
|
||||
|
||||
let dur_long = Duration::from_millis(3_601_000);
|
||||
assert_eq!(format_duration(dur_long), "60m01s");
|
||||
assert_eq!(format_duration(dur_long), "60m 01s");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_duration_one_hour_has_space() {
|
||||
let dur_hour = Duration::from_millis(3_600_000);
|
||||
assert_eq!(format_duration(dur_hour), "60m 00s");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +49,13 @@ pub fn builtin_model_presets() -> &'static [ModelPreset] {
|
||||
model: "gpt-5",
|
||||
effort: ReasoningEffort::High,
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-high-new",
|
||||
label: "gpt-5 high new",
|
||||
description: "— our latest release tuned to rely on the model's built-in reasoning defaults",
|
||||
model: "gpt-5-high-new",
|
||||
effort: ReasoningEffort::Medium,
|
||||
},
|
||||
];
|
||||
PRESETS
|
||||
}
|
||||
|
||||
@@ -26,14 +26,12 @@ eventsource-stream = "0.2.3"
|
||||
futures = "0.3"
|
||||
libc = "0.2.175"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
mime_guess = "2.0"
|
||||
os_info = "3.12.0"
|
||||
portable-pty = "0.9.0"
|
||||
rand = "0.9"
|
||||
regex-lite = "0.1.7"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_bytes = "0.11"
|
||||
serde_json = "1"
|
||||
sha1 = "0.10.6"
|
||||
shlex = "1.3.0"
|
||||
@@ -53,10 +51,10 @@ tokio-util = "0.7.16"
|
||||
toml = "0.9.5"
|
||||
toml_edit = "0.23.4"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tree-sitter = "0.25.8"
|
||||
tree-sitter = "0.25.9"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
whoami = "1.6.1"
|
||||
which = "6"
|
||||
wildmatch = "2.4.0"
|
||||
|
||||
|
||||
@@ -72,9 +70,6 @@ openssl-sys = { version = "*", features = ["vendored"] }
|
||||
[target.aarch64-unknown-linux-musl.dependencies]
|
||||
openssl-sys = { version = "*", features = ["vendored"] }
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
which = "6"
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
core_test_support = { path = "tests/common" }
|
||||
@@ -85,3 +80,6 @@ tempfile = "3"
|
||||
tokio-test = "0.4"
|
||||
walkdir = "2.5.0"
|
||||
wiremock = "0.6"
|
||||
|
||||
[package.metadata.cargo-shear]
|
||||
ignored = ["openssl-sys"]
|
||||
|
||||
@@ -14,6 +14,18 @@ Within this context, Codex refers to the open-source agentic coding interface (n
|
||||
|
||||
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
|
||||
|
||||
# AGENTS.md spec
|
||||
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
|
||||
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
|
||||
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
|
||||
- Instructions in AGENTS.md files:
|
||||
- The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
|
||||
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
|
||||
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
|
||||
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
|
||||
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
|
||||
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
|
||||
|
||||
## Responsiveness
|
||||
|
||||
### Preamble messages
|
||||
@@ -228,7 +240,6 @@ You are producing plain text that will later be styled by the CLI. Follow these
|
||||
**Bullets**
|
||||
|
||||
- Use `-` followed by a space for every bullet.
|
||||
- Bold the keyword, then colon + concise description.
|
||||
- Merge related points when possible; avoid a bullet for every trivial detail.
|
||||
- Keep bullets to one line unless breaking for clarity is unavoidable.
|
||||
- Group into short lists (4–6 bullets) ordered by importance.
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::time::Duration;
|
||||
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
|
||||
use crate::token_data::PlanType;
|
||||
use crate::token_data::TokenData;
|
||||
use crate::token_data::parse_id_token;
|
||||
|
||||
@@ -70,14 +71,9 @@ impl CodexAuth {
|
||||
Ok(access)
|
||||
}
|
||||
|
||||
/// Loads the available auth information from the auth.json or
|
||||
/// OPENAI_API_KEY environment variable.
|
||||
pub fn from_codex_home(
|
||||
codex_home: &Path,
|
||||
preferred_auth_method: AuthMode,
|
||||
originator: &str,
|
||||
) -> std::io::Result<Option<CodexAuth>> {
|
||||
load_auth(codex_home, true, preferred_auth_method, originator)
|
||||
/// Loads the available auth information from the auth.json.
|
||||
pub fn from_codex_home(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
|
||||
load_auth(codex_home)
|
||||
}
|
||||
|
||||
pub async fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
|
||||
@@ -136,13 +132,12 @@ impl CodexAuth {
|
||||
}
|
||||
|
||||
pub fn get_account_id(&self) -> Option<String> {
|
||||
self.get_current_token_data()
|
||||
.and_then(|t| t.account_id.clone())
|
||||
self.get_current_token_data().and_then(|t| t.account_id)
|
||||
}
|
||||
|
||||
pub fn get_plan_type(&self) -> Option<String> {
|
||||
pub(crate) fn get_plan_type(&self) -> Option<PlanType> {
|
||||
self.get_current_token_data()
|
||||
.and_then(|t| t.id_token.chatgpt_plan_type.as_ref().map(|p| p.as_string()))
|
||||
.and_then(|t| t.id_token.chatgpt_plan_type)
|
||||
}
|
||||
|
||||
fn get_current_auth_json(&self) -> Option<AuthDotJson> {
|
||||
@@ -151,7 +146,7 @@ impl CodexAuth {
|
||||
}
|
||||
|
||||
fn get_current_token_data(&self) -> Option<TokenData> {
|
||||
self.get_current_auth_json().and_then(|t| t.tokens.clone())
|
||||
self.get_current_auth_json().and_then(|t| t.tokens)
|
||||
}
|
||||
|
||||
/// Consider this private to integration tests.
|
||||
@@ -173,7 +168,7 @@ impl CodexAuth {
|
||||
mode: AuthMode::ChatGPT,
|
||||
auth_file: PathBuf::new(),
|
||||
auth_dot_json,
|
||||
client: crate::default_client::create_client("codex_cli_rs"),
|
||||
client: crate::default_client::create_client(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,19 +183,17 @@ impl CodexAuth {
|
||||
}
|
||||
|
||||
pub fn from_api_key(api_key: &str) -> Self {
|
||||
Self::from_api_key_with_client(
|
||||
api_key,
|
||||
crate::default_client::create_client(crate::default_client::DEFAULT_ORIGINATOR),
|
||||
)
|
||||
Self::from_api_key_with_client(api_key, crate::default_client::create_client())
|
||||
}
|
||||
}
|
||||
|
||||
pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
|
||||
|
||||
fn read_openai_api_key_from_env() -> Option<String> {
|
||||
pub fn read_openai_api_key_from_env() -> Option<String> {
|
||||
env::var(OPENAI_API_KEY_ENV_VAR)
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
}
|
||||
|
||||
pub fn get_auth_file(codex_home: &Path) -> PathBuf {
|
||||
@@ -218,7 +211,7 @@ pub fn logout(codex_home: &Path) -> std::io::Result<bool> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes an `auth.json` that contains only the API key. Intended for CLI use.
|
||||
/// Writes an `auth.json` that contains only the API key.
|
||||
pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<()> {
|
||||
let auth_dot_json = AuthDotJson {
|
||||
openai_api_key: Some(api_key.to_string()),
|
||||
@@ -228,29 +221,11 @@ pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<(
|
||||
write_auth_json(&get_auth_file(codex_home), &auth_dot_json)
|
||||
}
|
||||
|
||||
fn load_auth(
|
||||
codex_home: &Path,
|
||||
include_env_var: bool,
|
||||
preferred_auth_method: AuthMode,
|
||||
originator: &str,
|
||||
) -> std::io::Result<Option<CodexAuth>> {
|
||||
// First, check to see if there is a valid auth.json file. If not, we fall
|
||||
// back to AuthMode::ApiKey using the OPENAI_API_KEY environment variable
|
||||
// (if it is set).
|
||||
fn load_auth(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
|
||||
let auth_file = get_auth_file(codex_home);
|
||||
let client = crate::default_client::create_client(originator);
|
||||
let client = crate::default_client::create_client();
|
||||
let auth_dot_json = match try_read_auth_json(&auth_file) {
|
||||
Ok(auth) => auth,
|
||||
// If auth.json does not exist, try to read the OPENAI_API_KEY from the
|
||||
// environment variable.
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound && include_env_var => {
|
||||
return match read_openai_api_key_from_env() {
|
||||
Some(api_key) => Ok(Some(CodexAuth::from_api_key_with_client(&api_key, client))),
|
||||
None => Ok(None),
|
||||
};
|
||||
}
|
||||
// Though if auth.json exists but is malformed, do not fall back to the
|
||||
// env var because the user may be expecting to use AuthMode::ChatGPT.
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
@@ -262,32 +237,11 @@ fn load_auth(
|
||||
last_refresh,
|
||||
} = auth_dot_json;
|
||||
|
||||
// If the auth.json has an API key AND does not appear to be on a plan that
|
||||
// should prefer AuthMode::ChatGPT, use AuthMode::ApiKey.
|
||||
// Prefer AuthMode.ApiKey if it's set in the auth.json.
|
||||
if let Some(api_key) = &auth_json_api_key {
|
||||
// Should any of these be AuthMode::ChatGPT with the api_key set?
|
||||
// Does AuthMode::ChatGPT indicate that there is an auth.json that is
|
||||
// "refreshable" even if we are using the API key for auth?
|
||||
match &tokens {
|
||||
Some(tokens) => {
|
||||
if tokens.should_use_api_key(preferred_auth_method, tokens.is_openai_email()) {
|
||||
return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
|
||||
} else {
|
||||
// Ignore the API key and fall through to ChatGPT auth.
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// We have an API key but no tokens in the auth.json file.
|
||||
// Perhaps the user ran `codex login --api-key <KEY>` or updated
|
||||
// auth.json by hand. Either way, let's assume they are trying
|
||||
// to use their API key.
|
||||
return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
|
||||
}
|
||||
}
|
||||
return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
|
||||
}
|
||||
|
||||
// For the AuthMode::ChatGPT variant, perhaps neither api_key nor
|
||||
// openai_api_key should exist?
|
||||
Ok(Some(CodexAuth {
|
||||
api_key: None,
|
||||
mode: AuthMode::ChatGPT,
|
||||
@@ -337,10 +291,10 @@ async fn update_tokens(
|
||||
let tokens = auth_dot_json.tokens.get_or_insert_with(TokenData::default);
|
||||
tokens.id_token = parse_id_token(&id_token).map_err(std::io::Error::other)?;
|
||||
if let Some(access_token) = access_token {
|
||||
tokens.access_token = access_token.to_string();
|
||||
tokens.access_token = access_token;
|
||||
}
|
||||
if let Some(refresh_token) = refresh_token {
|
||||
tokens.refresh_token = refresh_token.to_string();
|
||||
tokens.refresh_token = refresh_token;
|
||||
}
|
||||
auth_dot_json.last_refresh = Some(Utc::now());
|
||||
write_auth_json(auth_file, &auth_dot_json)?;
|
||||
@@ -417,7 +371,6 @@ use std::sync::RwLock;
|
||||
/// Internal cached auth state.
|
||||
#[derive(Clone, Debug)]
|
||||
struct CachedAuth {
|
||||
preferred_auth_mode: AuthMode,
|
||||
auth: Option<CodexAuth>,
|
||||
}
|
||||
|
||||
@@ -473,9 +426,7 @@ mod tests {
|
||||
auth_dot_json,
|
||||
auth_file: _,
|
||||
..
|
||||
} = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
} = super::load_auth(codex_home.path()).unwrap().unwrap();
|
||||
assert_eq!(None, api_key);
|
||||
assert_eq!(AuthMode::ChatGPT, mode);
|
||||
|
||||
@@ -504,88 +455,6 @@ mod tests {
|
||||
)
|
||||
}
|
||||
|
||||
/// Even if the OPENAI_API_KEY is set in auth.json, if the plan is not in
|
||||
/// [`TokenData::is_plan_that_should_use_api_key`], it should use
|
||||
/// [`AuthMode::ChatGPT`].
|
||||
#[tokio::test]
|
||||
async fn pro_account_with_api_key_still_uses_chatgpt_auth() {
|
||||
let codex_home = tempdir().unwrap();
|
||||
let fake_jwt = write_auth_file(
|
||||
AuthFileParams {
|
||||
openai_api_key: Some("sk-test-key".to_string()),
|
||||
chatgpt_plan_type: "pro".to_string(),
|
||||
},
|
||||
codex_home.path(),
|
||||
)
|
||||
.expect("failed to write auth file");
|
||||
|
||||
let CodexAuth {
|
||||
api_key,
|
||||
mode,
|
||||
auth_dot_json,
|
||||
auth_file: _,
|
||||
..
|
||||
} = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(None, api_key);
|
||||
assert_eq!(AuthMode::ChatGPT, mode);
|
||||
|
||||
let guard = auth_dot_json.lock().unwrap();
|
||||
let auth_dot_json = guard.as_ref().expect("AuthDotJson should exist");
|
||||
assert_eq!(
|
||||
&AuthDotJson {
|
||||
openai_api_key: None,
|
||||
tokens: Some(TokenData {
|
||||
id_token: IdTokenInfo {
|
||||
email: Some("user@example.com".to_string()),
|
||||
chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
|
||||
raw_jwt: fake_jwt,
|
||||
},
|
||||
access_token: "test-access-token".to_string(),
|
||||
refresh_token: "test-refresh-token".to_string(),
|
||||
account_id: None,
|
||||
}),
|
||||
last_refresh: Some(
|
||||
DateTime::parse_from_rfc3339(LAST_REFRESH)
|
||||
.unwrap()
|
||||
.with_timezone(&Utc)
|
||||
),
|
||||
},
|
||||
auth_dot_json
|
||||
)
|
||||
}
|
||||
|
||||
/// If the OPENAI_API_KEY is set in auth.json and it is an enterprise
|
||||
/// account, then it should use [`AuthMode::ApiKey`].
|
||||
#[tokio::test]
|
||||
async fn enterprise_account_with_api_key_uses_apikey_auth() {
|
||||
let codex_home = tempdir().unwrap();
|
||||
write_auth_file(
|
||||
AuthFileParams {
|
||||
openai_api_key: Some("sk-test-key".to_string()),
|
||||
chatgpt_plan_type: "enterprise".to_string(),
|
||||
},
|
||||
codex_home.path(),
|
||||
)
|
||||
.expect("failed to write auth file");
|
||||
|
||||
let CodexAuth {
|
||||
api_key,
|
||||
mode,
|
||||
auth_dot_json,
|
||||
auth_file: _,
|
||||
..
|
||||
} = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(Some("sk-test-key".to_string()), api_key);
|
||||
assert_eq!(AuthMode::ApiKey, mode);
|
||||
|
||||
let guard = auth_dot_json.lock().expect("should unwrap");
|
||||
assert!(guard.is_none(), "auth_dot_json should be None");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn loads_api_key_from_auth_json() {
|
||||
let dir = tempdir().unwrap();
|
||||
@@ -596,9 +465,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let auth = super::load_auth(dir.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let auth = super::load_auth(dir.path()).unwrap().unwrap();
|
||||
assert_eq!(auth.mode, AuthMode::ApiKey);
|
||||
assert_eq!(auth.api_key, Some("sk-test-key".to_string()));
|
||||
|
||||
@@ -680,7 +547,6 @@ mod tests {
|
||||
#[derive(Debug)]
|
||||
pub struct AuthManager {
|
||||
codex_home: PathBuf,
|
||||
originator: String,
|
||||
inner: RwLock<CachedAuth>,
|
||||
}
|
||||
|
||||
@@ -689,30 +555,19 @@ impl AuthManager {
|
||||
/// preferred auth method. Errors loading auth are swallowed; `auth()` will
|
||||
/// simply return `None` in that case so callers can treat it as an
|
||||
/// unauthenticated state.
|
||||
pub fn new(codex_home: PathBuf, preferred_auth_mode: AuthMode, originator: String) -> Self {
|
||||
let auth = CodexAuth::from_codex_home(&codex_home, preferred_auth_mode, &originator)
|
||||
.ok()
|
||||
.flatten();
|
||||
pub fn new(codex_home: PathBuf) -> Self {
|
||||
let auth = CodexAuth::from_codex_home(&codex_home).ok().flatten();
|
||||
Self {
|
||||
codex_home,
|
||||
originator,
|
||||
inner: RwLock::new(CachedAuth {
|
||||
preferred_auth_mode,
|
||||
auth,
|
||||
}),
|
||||
inner: RwLock::new(CachedAuth { auth }),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an AuthManager with a specific CodexAuth, for testing only.
|
||||
pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
|
||||
let preferred_auth_mode = auth.mode;
|
||||
let cached = CachedAuth {
|
||||
preferred_auth_mode,
|
||||
auth: Some(auth),
|
||||
};
|
||||
let cached = CachedAuth { auth: Some(auth) };
|
||||
Arc::new(Self {
|
||||
codex_home: PathBuf::new(),
|
||||
originator: "codex_cli_rs".to_string(),
|
||||
inner: RwLock::new(cached),
|
||||
})
|
||||
}
|
||||
@@ -722,21 +577,10 @@ impl AuthManager {
|
||||
self.inner.read().ok().and_then(|c| c.auth.clone())
|
||||
}
|
||||
|
||||
/// Preferred auth method used when (re)loading.
|
||||
pub fn preferred_auth_method(&self) -> AuthMode {
|
||||
self.inner
|
||||
.read()
|
||||
.map(|c| c.preferred_auth_mode)
|
||||
.unwrap_or(AuthMode::ApiKey)
|
||||
}
|
||||
|
||||
/// Force a reload using the existing preferred auth method. Returns
|
||||
/// Force a reload of the auth information from auth.json. Returns
|
||||
/// whether the auth value changed.
|
||||
pub fn reload(&self) -> bool {
|
||||
let preferred = self.preferred_auth_method();
|
||||
let new_auth = CodexAuth::from_codex_home(&self.codex_home, preferred, &self.originator)
|
||||
.ok()
|
||||
.flatten();
|
||||
let new_auth = CodexAuth::from_codex_home(&self.codex_home).ok().flatten();
|
||||
if let Ok(mut guard) = self.inner.write() {
|
||||
let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
|
||||
guard.auth = new_auth;
|
||||
@@ -755,12 +599,8 @@ impl AuthManager {
|
||||
}
|
||||
|
||||
/// Convenience constructor returning an `Arc` wrapper.
|
||||
pub fn shared(
|
||||
codex_home: PathBuf,
|
||||
preferred_auth_mode: AuthMode,
|
||||
originator: String,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self::new(codex_home, preferred_auth_mode, originator))
|
||||
pub fn shared(codex_home: PathBuf) -> Arc<Self> {
|
||||
Arc::new(Self::new(codex_home))
|
||||
}
|
||||
|
||||
/// Attempt to refresh the current auth token (if any). On success, reload
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::time::Duration;
|
||||
use crate::AuthManager;
|
||||
use bytes::Bytes;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::prelude::*;
|
||||
use regex_lite::Regex;
|
||||
@@ -19,7 +20,6 @@ use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::chat_completions::AggregateStreamExt;
|
||||
use crate::chat_completions::stream_chat_completions;
|
||||
@@ -41,6 +41,7 @@ use crate::model_provider_info::WireApi;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
use crate::openai_tools::create_tools_json_for_responses_api;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::token_data::PlanType;
|
||||
use crate::util::backoff;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
@@ -60,7 +61,7 @@ struct Error {
|
||||
message: Option<String>,
|
||||
|
||||
// Optional fields available on "usage_limit_reached" and "usage_not_included" errors
|
||||
plan_type: Option<String>,
|
||||
plan_type: Option<PlanType>,
|
||||
resets_in_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
@@ -70,7 +71,7 @@ pub struct ModelClient {
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
client: reqwest::Client,
|
||||
provider: ModelProviderInfo,
|
||||
session_id: Uuid,
|
||||
conversation_id: ConversationId,
|
||||
effort: ReasoningEffortConfig,
|
||||
summary: ReasoningSummaryConfig,
|
||||
}
|
||||
@@ -82,16 +83,16 @@ impl ModelClient {
|
||||
provider: ModelProviderInfo,
|
||||
effort: ReasoningEffortConfig,
|
||||
summary: ReasoningSummaryConfig,
|
||||
session_id: Uuid,
|
||||
conversation_id: ConversationId,
|
||||
) -> Self {
|
||||
let client = create_client(&config.responses_originator_header);
|
||||
let client = create_client();
|
||||
|
||||
Self {
|
||||
config,
|
||||
auth_manager,
|
||||
client,
|
||||
provider,
|
||||
session_id,
|
||||
conversation_id,
|
||||
effort,
|
||||
summary,
|
||||
}
|
||||
@@ -157,14 +158,6 @@ impl ModelClient {
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
|
||||
let auth_mode = auth_manager
|
||||
.as_ref()
|
||||
.and_then(|m| m.auth())
|
||||
.as_ref()
|
||||
.map(|a| a.mode);
|
||||
|
||||
let store = prompt.store && auth_mode != Some(AuthMode::ChatGPT);
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(&self.config.model_family);
|
||||
let tools_json = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
let reasoning = create_reasoning_param_for_request(
|
||||
@@ -173,9 +166,7 @@ impl ModelClient {
|
||||
self.summary,
|
||||
);
|
||||
|
||||
// Request encrypted COT if we are not storing responses,
|
||||
// otherwise reasoning items will be referenced by ID
|
||||
let include: Vec<String> = if !store && reasoning.is_some() {
|
||||
let include: Vec<String> = if reasoning.is_some() {
|
||||
vec!["reasoning.encrypted_content".to_string()]
|
||||
} else {
|
||||
vec![]
|
||||
@@ -204,10 +195,10 @@ impl ModelClient {
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning,
|
||||
store,
|
||||
store: false,
|
||||
stream: true,
|
||||
include,
|
||||
prompt_cache_key: Some(self.session_id.to_string()),
|
||||
prompt_cache_key: Some(self.conversation_id.to_string()),
|
||||
text,
|
||||
};
|
||||
|
||||
@@ -233,7 +224,9 @@ impl ModelClient {
|
||||
|
||||
req_builder = req_builder
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
.header("session_id", self.session_id.to_string())
|
||||
// Send session_id for compatibility.
|
||||
.header("conversation_id", self.conversation_id.to_string())
|
||||
.header("session_id", self.conversation_id.to_string())
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload);
|
||||
|
||||
@@ -247,10 +240,10 @@ impl ModelClient {
|
||||
let res = req_builder.send().await;
|
||||
if let Ok(resp) = &res {
|
||||
trace!(
|
||||
"Response status: {}, request-id: {}",
|
||||
"Response status: {}, cf-ray: {}",
|
||||
resp.status(),
|
||||
resp.headers()
|
||||
.get("x-request-id")
|
||||
.get("cf-ray")
|
||||
.map(|v| v.to_str().unwrap_or_default())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
@@ -312,7 +305,7 @@ impl ModelClient {
|
||||
// token.
|
||||
let plan_type = error
|
||||
.plan_type
|
||||
.or_else(|| auth.and_then(|a| a.get_plan_type()));
|
||||
.or_else(|| auth.as_ref().and_then(|a| a.get_plan_type()));
|
||||
let resets_in_seconds = error.resets_in_seconds;
|
||||
return Err(CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||
plan_type,
|
||||
@@ -408,9 +401,15 @@ impl From<ResponseCompletedUsage> for TokenUsage {
|
||||
fn from(val: ResponseCompletedUsage) -> Self {
|
||||
TokenUsage {
|
||||
input_tokens: val.input_tokens,
|
||||
cached_input_tokens: val.input_tokens_details.map(|d| d.cached_tokens),
|
||||
cached_input_tokens: val
|
||||
.input_tokens_details
|
||||
.map(|d| d.cached_tokens)
|
||||
.unwrap_or(0),
|
||||
output_tokens: val.output_tokens,
|
||||
reasoning_output_tokens: val.output_tokens_details.map(|d| d.reasoning_tokens),
|
||||
reasoning_output_tokens: val
|
||||
.output_tokens_details
|
||||
.map(|d| d.reasoning_tokens)
|
||||
.unwrap_or(0),
|
||||
total_tokens: val.total_tokens,
|
||||
}
|
||||
}
|
||||
@@ -1039,4 +1038,37 @@ mod tests {
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_response_deserializes_old_schema_known_plan_type_and_serializes_back() {
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
|
||||
let json = r#"{"error":{"type":"usage_limit_reached","plan_type":"pro","resets_in_seconds":3600}}"#;
|
||||
let resp: ErrorResponse =
|
||||
serde_json::from_str(json).expect("should deserialize old schema");
|
||||
|
||||
assert!(matches!(
|
||||
resp.error.plan_type,
|
||||
Some(PlanType::Known(KnownPlan::Pro))
|
||||
));
|
||||
|
||||
let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
|
||||
assert_eq!(plan_json, "\"pro\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_response_deserializes_old_schema_unknown_plan_type_and_serializes_back() {
|
||||
use crate::token_data::PlanType;
|
||||
|
||||
let json =
|
||||
r#"{"error":{"type":"usage_limit_reached","plan_type":"vip","resets_in_seconds":60}}"#;
|
||||
let resp: ErrorResponse =
|
||||
serde_json::from_str(json).expect("should deserialize old schema");
|
||||
|
||||
assert!(matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip"));
|
||||
|
||||
let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
|
||||
assert_eq!(plan_json, "\"vip\"");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use codex_protocol::config_types::Verbosity as VerbosityConfig;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use futures::Stream;
|
||||
use serde::Serialize;
|
||||
@@ -20,19 +19,12 @@ use tokio::sync::mpsc;
|
||||
/// with this content.
|
||||
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
|
||||
|
||||
/// wraps user instructions message in a tag for the model to parse more easily.
|
||||
const USER_INSTRUCTIONS_START: &str = "<user_instructions>\n\n";
|
||||
const USER_INSTRUCTIONS_END: &str = "\n\n</user_instructions>";
|
||||
|
||||
/// API request payload for a single model turn
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Prompt {
|
||||
/// Conversation context input items.
|
||||
pub input: Vec<ResponseItem>,
|
||||
|
||||
/// Whether to store response on server side (disable_response_storage = !store).
|
||||
pub store: bool,
|
||||
|
||||
/// Tools available to the model, including additional tools sourced from
|
||||
/// external MCP servers.
|
||||
pub(crate) tools: Vec<OpenAiTool>,
|
||||
@@ -68,17 +60,6 @@ impl Prompt {
|
||||
pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
|
||||
self.input.clone()
|
||||
}
|
||||
|
||||
/// Creates a formatted user instructions message from a string
|
||||
pub(crate) fn format_user_instructions_message(ui: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: format!("{USER_INSTRUCTIONS_START}{ui}{USER_INSTRUCTIONS_END}"),
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -144,7 +125,6 @@ pub(crate) struct ResponsesApiRequest<'a> {
|
||||
pub(crate) tool_choice: &'static str,
|
||||
pub(crate) parallel_tool_calls: bool,
|
||||
pub(crate) reasoning: Option<Reasoning>,
|
||||
/// true when using the Responses API.
|
||||
pub(crate) store: bool,
|
||||
pub(crate) stream: bool,
|
||||
pub(crate) include: Vec<String>,
|
||||
@@ -215,7 +195,7 @@ mod tests {
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning: None,
|
||||
store: true,
|
||||
store: false,
|
||||
stream: true,
|
||||
include: vec![],
|
||||
prompt_cache_key: None,
|
||||
@@ -245,7 +225,7 @@ mod tests {
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning: None,
|
||||
store: true,
|
||||
store: false,
|
||||
stream: true,
|
||||
include: vec![],
|
||||
prompt_cache_key: None,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
use crate::config_profile::ConfigProfile;
|
||||
use crate::config_types::History;
|
||||
use crate::config_types::McpServerConfig;
|
||||
use crate::config_types::ReasoningSummaryFormat;
|
||||
use crate::config_types::SandboxWorkspaceWrite;
|
||||
use crate::config_types::ShellEnvironmentPolicy;
|
||||
use crate::config_types::ShellEnvironmentPolicyToml;
|
||||
@@ -14,11 +15,13 @@ use crate::model_provider_info::built_in_model_providers;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use anyhow::Context;
|
||||
use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::config_types::SandboxMode;
|
||||
use codex_protocol::config_types::Verbosity;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use codex_protocol::mcp_protocol::Tools;
|
||||
use codex_protocol::mcp_protocol::UserSavedConfig;
|
||||
use dirs::home_dir;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
@@ -29,15 +32,14 @@ use toml::Value as TomlValue;
|
||||
use toml_edit::DocumentMut;
|
||||
|
||||
const OPENAI_DEFAULT_MODEL: &str = "gpt-5";
|
||||
pub const GPT5_HIGH_MODEL: &str = "gpt-5-high";
|
||||
|
||||
/// Maximum number of bytes of the documentation that will be embedded. Larger
|
||||
/// files are *silently truncated* to this size so we do not take up too much of
|
||||
/// the context window.
|
||||
pub(crate) const PROJECT_DOC_MAX_BYTES: usize = 32 * 1024; // 32 KiB
|
||||
|
||||
const CONFIG_TOML_FILE: &str = "config.toml";
|
||||
|
||||
const DEFAULT_RESPONSES_ORIGINATOR_HEADER: &str = "codex_cli_rs";
|
||||
pub(crate) const CONFIG_TOML_FILE: &str = "config.toml";
|
||||
|
||||
/// Application configuration loaded from disk and merged with overrides.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -75,11 +77,6 @@ pub struct Config {
|
||||
/// Defaults to `false`.
|
||||
pub show_raw_agent_reasoning: bool,
|
||||
|
||||
/// Disable server-side response storage (sends the full conversation
|
||||
/// context with every request). Currently necessary for OpenAI customers
|
||||
/// who have opted into Zero Data Retention (ZDR).
|
||||
pub disable_response_storage: bool,
|
||||
|
||||
/// User-provided instructions from AGENTS.md.
|
||||
pub user_instructions: Option<String>,
|
||||
|
||||
@@ -133,9 +130,6 @@ pub struct Config {
|
||||
/// output will be hyperlinked using the specified URI scheme.
|
||||
pub file_opener: UriBasedFileOpener,
|
||||
|
||||
/// Collection of settings that are specific to the TUI.
|
||||
pub tui: Tui,
|
||||
|
||||
/// Path to the `codex-linux-sandbox` executable. This must be set if
|
||||
/// [`crate::exec::SandboxType::LinuxSeccomp`] is used. Note that this
|
||||
/// cannot be set in the config file: it must be set in code via
|
||||
@@ -171,22 +165,21 @@ pub struct Config {
|
||||
|
||||
pub tools_web_search_request: bool,
|
||||
|
||||
/// The value for the `originator` header included with Responses API requests.
|
||||
pub responses_originator_header: String,
|
||||
|
||||
/// If set to `true`, the API key will be signed with the `originator` header.
|
||||
pub preferred_auth_method: AuthMode,
|
||||
|
||||
pub use_experimental_streamable_shell_tool: bool,
|
||||
|
||||
/// If set to `true`, used only the experimental unified exec tool.
|
||||
pub use_experimental_unified_exec_tool: bool,
|
||||
|
||||
/// Include the `view_image` tool that lets the agent attach a local image path to context.
|
||||
pub include_view_image_tool: bool,
|
||||
|
||||
/// The active profile name used to derive this `Config` (if any).
|
||||
pub active_profile: Option<String>,
|
||||
|
||||
/// When true, disables burst-paste detection for typed input entirely.
|
||||
/// All characters are inserted as they are received, and no buffering
|
||||
/// or placeholder replacement will occur for fast keypress bursts.
|
||||
pub disable_paste_burst: bool,
|
||||
|
||||
pub use_experimental_reasoning_summary: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -266,17 +259,7 @@ pub fn load_config_as_toml(codex_home: &Path) -> std::io::Result<TomlValue> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Patch `CODEX_HOME/config.toml` project state.
|
||||
/// Use with caution.
|
||||
pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Result<()> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
// Parse existing config if present; otherwise start a new document.
|
||||
let mut doc = match std::fs::read_to_string(config_path.clone()) {
|
||||
Ok(s) => s.parse::<DocumentMut>()?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
fn set_project_trusted_inner(doc: &mut DocumentMut, project_path: &Path) -> anyhow::Result<()> {
|
||||
// Ensure we render a human-friendly structure:
|
||||
//
|
||||
// [projects]
|
||||
@@ -292,14 +275,26 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
|
||||
// Ensure top-level `projects` exists as a non-inline, explicit table. If it
|
||||
// exists but was previously represented as a non-table (e.g., inline),
|
||||
// replace it with an explicit table.
|
||||
let mut created_projects_table = false;
|
||||
{
|
||||
let root = doc.as_table_mut();
|
||||
let needs_table = !root.contains_key("projects")
|
||||
|| root.get("projects").and_then(|i| i.as_table()).is_none();
|
||||
if needs_table {
|
||||
root.insert("projects", toml_edit::table());
|
||||
created_projects_table = true;
|
||||
// If `projects` exists but isn't a standard table (e.g., it's an inline table),
|
||||
// convert it to an explicit table while preserving existing entries.
|
||||
let existing_projects = root.get("projects").cloned();
|
||||
if existing_projects.as_ref().is_none_or(|i| !i.is_table()) {
|
||||
let mut projects_tbl = toml_edit::Table::new();
|
||||
projects_tbl.set_implicit(true);
|
||||
|
||||
// If there was an existing inline table, migrate its entries to explicit tables.
|
||||
if let Some(inline_tbl) = existing_projects.as_ref().and_then(|i| i.as_inline_table()) {
|
||||
for (k, v) in inline_tbl.iter() {
|
||||
if let Some(inner_tbl) = v.as_inline_table() {
|
||||
let new_tbl = inner_tbl.clone().into_table();
|
||||
projects_tbl.insert(k, toml_edit::Item::Table(new_tbl));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
root.insert("projects", toml_edit::Item::Table(projects_tbl));
|
||||
}
|
||||
}
|
||||
let Some(projects_tbl) = doc["projects"].as_table_mut() else {
|
||||
@@ -308,12 +303,6 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
|
||||
));
|
||||
};
|
||||
|
||||
// If we created the `projects` table ourselves, keep it implicit so we
|
||||
// don't render a standalone `[projects]` header.
|
||||
if created_projects_table {
|
||||
projects_tbl.set_implicit(true);
|
||||
}
|
||||
|
||||
// Ensure the per-project entry is its own explicit table. If it exists but
|
||||
// is not a table (e.g., an inline table), replace it with an explicit table.
|
||||
let needs_proj_table = !projects_tbl.contains_key(project_key.as_str())
|
||||
@@ -332,6 +321,21 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
|
||||
};
|
||||
proj_tbl.set_implicit(false);
|
||||
proj_tbl["trust_level"] = toml_edit::value("trusted");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Patch `CODEX_HOME/config.toml` project state.
|
||||
/// Use with caution.
|
||||
pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Result<()> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
// Parse existing config if present; otherwise start a new document.
|
||||
let mut doc = match std::fs::read_to_string(config_path.clone()) {
|
||||
Ok(s) => s.parse::<DocumentMut>()?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
set_project_trusted_inner(&mut doc, project_path)?;
|
||||
|
||||
// ensure codex_home exists
|
||||
std::fs::create_dir_all(codex_home)?;
|
||||
@@ -346,6 +350,107 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_profile_table<'a>(
|
||||
doc: &'a mut DocumentMut,
|
||||
profile_name: &str,
|
||||
) -> anyhow::Result<&'a mut toml_edit::Table> {
|
||||
let mut created_profiles_table = false;
|
||||
{
|
||||
let root = doc.as_table_mut();
|
||||
let needs_table = !root.contains_key("profiles")
|
||||
|| root
|
||||
.get("profiles")
|
||||
.and_then(|item| item.as_table())
|
||||
.is_none();
|
||||
if needs_table {
|
||||
root.insert("profiles", toml_edit::table());
|
||||
created_profiles_table = true;
|
||||
}
|
||||
}
|
||||
|
||||
let Some(profiles_table) = doc["profiles"].as_table_mut() else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"profiles table missing after initialization"
|
||||
));
|
||||
};
|
||||
|
||||
if created_profiles_table {
|
||||
profiles_table.set_implicit(true);
|
||||
}
|
||||
|
||||
let needs_profile_table = !profiles_table.contains_key(profile_name)
|
||||
|| profiles_table
|
||||
.get(profile_name)
|
||||
.and_then(|item| item.as_table())
|
||||
.is_none();
|
||||
if needs_profile_table {
|
||||
profiles_table.insert(profile_name, toml_edit::table());
|
||||
}
|
||||
|
||||
let Some(profile_table) = profiles_table
|
||||
.get_mut(profile_name)
|
||||
.and_then(|item| item.as_table_mut())
|
||||
else {
|
||||
return Err(anyhow::anyhow!(format!(
|
||||
"profile table missing for {profile_name}"
|
||||
)));
|
||||
};
|
||||
|
||||
profile_table.set_implicit(false);
|
||||
Ok(profile_table)
|
||||
}
|
||||
|
||||
// TODO(jif) refactor config persistence.
|
||||
pub async fn persist_model_selection(
|
||||
codex_home: &Path,
|
||||
active_profile: Option<&str>,
|
||||
model: &str,
|
||||
effort: Option<ReasoningEffort>,
|
||||
) -> anyhow::Result<()> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
let serialized = match tokio::fs::read_to_string(&config_path).await {
|
||||
Ok(contents) => contents,
|
||||
Err(err) if err.kind() == std::io::ErrorKind::NotFound => String::new(),
|
||||
Err(err) => return Err(err.into()),
|
||||
};
|
||||
|
||||
let mut doc = if serialized.is_empty() {
|
||||
DocumentMut::new()
|
||||
} else {
|
||||
serialized.parse::<DocumentMut>()?
|
||||
};
|
||||
|
||||
if let Some(profile_name) = active_profile {
|
||||
let profile_table = ensure_profile_table(&mut doc, profile_name)?;
|
||||
profile_table["model"] = toml_edit::value(model);
|
||||
if let Some(effort) = effort {
|
||||
profile_table["model_reasoning_effort"] = toml_edit::value(effort.to_string());
|
||||
}
|
||||
} else {
|
||||
let table = doc.as_table_mut();
|
||||
table["model"] = toml_edit::value(model);
|
||||
if let Some(effort) = effort {
|
||||
table["model_reasoning_effort"] = toml_edit::value(effort.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(jif) refactor the home creation
|
||||
tokio::fs::create_dir_all(codex_home)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to create Codex home directory at {}",
|
||||
codex_home.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
tokio::fs::write(&config_path, doc.to_string())
|
||||
.await
|
||||
.with_context(|| format!("failed to persist config.toml at {}", config_path.display()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply a single dotted-path override onto a TOML value.
|
||||
fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) {
|
||||
use toml::value::Table;
|
||||
@@ -390,7 +495,7 @@ fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) {
|
||||
}
|
||||
|
||||
/// Base config deserialized from ~/.codex/config.toml.
|
||||
#[derive(Deserialize, Debug, Clone, Default)]
|
||||
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
|
||||
pub struct ConfigToml {
|
||||
/// Optional override of model selection.
|
||||
pub model: Option<String>,
|
||||
@@ -416,11 +521,6 @@ pub struct ConfigToml {
|
||||
/// Sandbox configuration to apply if `sandbox` is `WorkspaceWrite`.
|
||||
pub sandbox_workspace_write: Option<SandboxWorkspaceWrite>,
|
||||
|
||||
/// Disable server-side response storage (sends the full conversation
|
||||
/// context with every request). Currently necessary for OpenAI customers
|
||||
/// who have opted into Zero Data Retention (ZDR).
|
||||
pub disable_response_storage: Option<bool>,
|
||||
|
||||
/// Optional external command to spawn for end-user notifications.
|
||||
#[serde(default)]
|
||||
pub notify: Option<Vec<String>>,
|
||||
@@ -473,6 +573,9 @@ pub struct ConfigToml {
|
||||
/// Override to force-enable reasoning summaries for the configured model.
|
||||
pub model_supports_reasoning_summaries: Option<bool>,
|
||||
|
||||
/// Override to force reasoning summary format for the configured model.
|
||||
pub model_reasoning_summary_format: Option<ReasoningSummaryFormat>,
|
||||
|
||||
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
|
||||
pub chatgpt_base_url: Option<String>,
|
||||
|
||||
@@ -483,17 +586,10 @@ pub struct ConfigToml {
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
|
||||
pub experimental_use_exec_command_tool: Option<bool>,
|
||||
|
||||
pub use_experimental_reasoning_summary: Option<bool>,
|
||||
|
||||
/// The value for the `originator` header included with Responses API requests.
|
||||
pub responses_originator_header_internal_override: Option<String>,
|
||||
pub experimental_use_unified_exec_tool: Option<bool>,
|
||||
|
||||
pub projects: Option<HashMap<String, ProjectConfig>>,
|
||||
|
||||
/// If set to `true`, the API key will be signed with the `originator` header.
|
||||
pub preferred_auth_method: Option<AuthMode>,
|
||||
|
||||
/// Nested tools section for feature toggles
|
||||
pub tools: Option<ToolsToml>,
|
||||
|
||||
@@ -503,12 +599,35 @@ pub struct ConfigToml {
|
||||
pub disable_paste_burst: Option<bool>,
|
||||
}
|
||||
|
||||
impl From<ConfigToml> for UserSavedConfig {
|
||||
fn from(config_toml: ConfigToml) -> Self {
|
||||
let profiles = config_toml
|
||||
.profiles
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.into()))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
approval_policy: config_toml.approval_policy,
|
||||
sandbox_mode: config_toml.sandbox_mode,
|
||||
sandbox_settings: config_toml.sandbox_workspace_write.map(From::from),
|
||||
model: config_toml.model,
|
||||
model_reasoning_effort: config_toml.model_reasoning_effort,
|
||||
model_reasoning_summary: config_toml.model_reasoning_summary,
|
||||
model_verbosity: config_toml.model_verbosity,
|
||||
tools: config_toml.tools.map(From::from),
|
||||
profile: config_toml.profile,
|
||||
profiles,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ProjectConfig {
|
||||
pub trust_level: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Default)]
|
||||
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
|
||||
pub struct ToolsToml {
|
||||
#[serde(default, alias = "web_search_request")]
|
||||
pub web_search: Option<bool>,
|
||||
@@ -518,6 +637,15 @@ pub struct ToolsToml {
|
||||
pub view_image: Option<bool>,
|
||||
}
|
||||
|
||||
impl From<ToolsToml> for Tools {
|
||||
fn from(tools_toml: ToolsToml) -> Self {
|
||||
Self {
|
||||
web_search: tools_toml.web_search,
|
||||
view_image: tools_toml.view_image,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigToml {
|
||||
/// Derive the effective sandbox policy from the configuration.
|
||||
fn derive_sandbox_policy(&self, sandbox_mode_override: Option<SandboxMode>) -> SandboxPolicy {
|
||||
@@ -606,7 +734,6 @@ pub struct ConfigOverrides {
|
||||
pub include_plan_tool: Option<bool>,
|
||||
pub include_apply_patch_tool: Option<bool>,
|
||||
pub include_view_image_tool: Option<bool>,
|
||||
pub disable_response_storage: Option<bool>,
|
||||
pub show_raw_agent_reasoning: Option<bool>,
|
||||
pub tools_web_search_request: Option<bool>,
|
||||
}
|
||||
@@ -634,12 +761,15 @@ impl Config {
|
||||
include_plan_tool,
|
||||
include_apply_patch_tool,
|
||||
include_view_image_tool,
|
||||
disable_response_storage,
|
||||
show_raw_agent_reasoning,
|
||||
tools_web_search_request: override_tools_web_search_request,
|
||||
} = overrides;
|
||||
|
||||
let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) {
|
||||
let active_profile_name = config_profile_key
|
||||
.as_ref()
|
||||
.or(cfg.profile.as_ref())
|
||||
.cloned();
|
||||
let config_profile = match active_profile_name.as_ref() {
|
||||
Some(key) => cfg
|
||||
.profiles
|
||||
.get(key)
|
||||
@@ -710,19 +840,24 @@ impl Config {
|
||||
.or(config_profile.model)
|
||||
.or(cfg.model)
|
||||
.unwrap_or_else(default_model);
|
||||
let model_family = find_family_for_model(&model).unwrap_or_else(|| {
|
||||
let supports_reasoning_summaries =
|
||||
cfg.model_supports_reasoning_summaries.unwrap_or(false);
|
||||
ModelFamily {
|
||||
slug: model.clone(),
|
||||
family: model.clone(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries,
|
||||
uses_local_shell_tool: false,
|
||||
apply_patch_tool_type: None,
|
||||
}
|
||||
|
||||
let mut model_family = find_family_for_model(&model).unwrap_or_else(|| ModelFamily {
|
||||
slug: model.clone(),
|
||||
family: model.clone(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries: false,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
uses_local_shell_tool: false,
|
||||
apply_patch_tool_type: None,
|
||||
});
|
||||
|
||||
if let Some(supports_reasoning_summaries) = cfg.model_supports_reasoning_summaries {
|
||||
model_family.supports_reasoning_summaries = supports_reasoning_summaries;
|
||||
}
|
||||
if let Some(model_reasoning_summary_format) = cfg.model_reasoning_summary_format {
|
||||
model_family.reasoning_summary_format = model_reasoning_summary_format;
|
||||
}
|
||||
|
||||
let openai_model_info = get_model_info(&model_family);
|
||||
let model_context_window = cfg
|
||||
.model_context_window
|
||||
@@ -746,10 +881,6 @@ impl Config {
|
||||
Self::get_base_instructions(experimental_instructions_path, &resolved_cwd)?;
|
||||
let base_instructions = base_instructions.or(file_base_instructions);
|
||||
|
||||
let responses_originator_header: String = cfg
|
||||
.responses_originator_header_internal_override
|
||||
.unwrap_or(DEFAULT_RESPONSES_ORIGINATOR_HEADER.to_owned());
|
||||
|
||||
let config = Self {
|
||||
model,
|
||||
model_family,
|
||||
@@ -764,11 +895,6 @@ impl Config {
|
||||
.unwrap_or_else(AskForApproval::default),
|
||||
sandbox_policy,
|
||||
shell_environment_policy,
|
||||
disable_response_storage: config_profile
|
||||
.disable_response_storage
|
||||
.or(cfg.disable_response_storage)
|
||||
.or(disable_response_storage)
|
||||
.unwrap_or(false),
|
||||
notify: cfg.notify,
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
@@ -778,7 +904,6 @@ impl Config {
|
||||
codex_home,
|
||||
history,
|
||||
file_opener: cfg.file_opener.unwrap_or(UriBasedFileOpener::VsCode),
|
||||
tui: cfg.tui.unwrap_or_default(),
|
||||
codex_linux_sandbox_exe,
|
||||
|
||||
hide_agent_reasoning: cfg.hide_agent_reasoning.unwrap_or(false),
|
||||
@@ -804,16 +929,15 @@ impl Config {
|
||||
include_plan_tool: include_plan_tool.unwrap_or(false),
|
||||
include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
|
||||
tools_web_search_request,
|
||||
responses_originator_header,
|
||||
preferred_auth_method: cfg.preferred_auth_method.unwrap_or(AuthMode::ChatGPT),
|
||||
use_experimental_streamable_shell_tool: cfg
|
||||
.experimental_use_exec_command_tool
|
||||
.unwrap_or(false),
|
||||
include_view_image_tool,
|
||||
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
|
||||
use_experimental_reasoning_summary: cfg
|
||||
.use_experimental_reasoning_summary
|
||||
use_experimental_unified_exec_tool: cfg
|
||||
.experimental_use_unified_exec_tool
|
||||
.unwrap_or(false),
|
||||
include_view_image_tool,
|
||||
active_profile: active_profile_name,
|
||||
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
@@ -923,6 +1047,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
@@ -1013,6 +1138,145 @@ exclude_slash_tmp = true
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_updates_defaults() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
persist_model_selection(
|
||||
codex_home.path(),
|
||||
None,
|
||||
"gpt-5-high-new",
|
||||
Some(ReasoningEffort::High),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let serialized =
|
||||
tokio::fs::read_to_string(codex_home.path().join(CONFIG_TOML_FILE)).await?;
|
||||
let parsed: ConfigToml = toml::from_str(&serialized)?;
|
||||
|
||||
assert_eq!(parsed.model.as_deref(), Some("gpt-5-high-new"));
|
||||
assert_eq!(parsed.model_reasoning_effort, Some(ReasoningEffort::High));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_overwrites_existing_model() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
tokio::fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "medium"
|
||||
|
||||
[profiles.dev]
|
||||
model = "gpt-4.1"
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
persist_model_selection(
|
||||
codex_home.path(),
|
||||
None,
|
||||
"o4-mini",
|
||||
Some(ReasoningEffort::High),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let serialized = tokio::fs::read_to_string(config_path).await?;
|
||||
let parsed: ConfigToml = toml::from_str(&serialized)?;
|
||||
|
||||
assert_eq!(parsed.model.as_deref(), Some("o4-mini"));
|
||||
assert_eq!(parsed.model_reasoning_effort, Some(ReasoningEffort::High));
|
||||
assert_eq!(
|
||||
parsed
|
||||
.profiles
|
||||
.get("dev")
|
||||
.and_then(|profile| profile.model.as_deref()),
|
||||
Some("gpt-4.1"),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_updates_profile() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
persist_model_selection(
|
||||
codex_home.path(),
|
||||
Some("dev"),
|
||||
"gpt-5-high-new",
|
||||
Some(ReasoningEffort::Low),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let serialized =
|
||||
tokio::fs::read_to_string(codex_home.path().join(CONFIG_TOML_FILE)).await?;
|
||||
let parsed: ConfigToml = toml::from_str(&serialized)?;
|
||||
let profile = parsed
|
||||
.profiles
|
||||
.get("dev")
|
||||
.expect("profile should be created");
|
||||
|
||||
assert_eq!(profile.model.as_deref(), Some("gpt-5-high-new"));
|
||||
assert_eq!(profile.model_reasoning_effort, Some(ReasoningEffort::Low));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_updates_existing_profile() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
tokio::fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
[profiles.dev]
|
||||
model = "gpt-4"
|
||||
model_reasoning_effort = "medium"
|
||||
|
||||
[profiles.prod]
|
||||
model = "gpt-5"
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
persist_model_selection(
|
||||
codex_home.path(),
|
||||
Some("dev"),
|
||||
"o4-high",
|
||||
Some(ReasoningEffort::Medium),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let serialized = tokio::fs::read_to_string(config_path).await?;
|
||||
let parsed: ConfigToml = toml::from_str(&serialized)?;
|
||||
|
||||
let dev_profile = parsed
|
||||
.profiles
|
||||
.get("dev")
|
||||
.expect("dev profile should survive updates");
|
||||
assert_eq!(dev_profile.model.as_deref(), Some("o4-high"));
|
||||
assert_eq!(
|
||||
dev_profile.model_reasoning_effort,
|
||||
Some(ReasoningEffort::Medium)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
parsed
|
||||
.profiles
|
||||
.get("prod")
|
||||
.and_then(|profile| profile.model.as_deref()),
|
||||
Some("gpt-5"),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct PrecedenceTestFixture {
|
||||
cwd: TempDir,
|
||||
codex_home: TempDir,
|
||||
@@ -1036,7 +1300,6 @@ exclude_slash_tmp = true
|
||||
let toml = r#"
|
||||
model = "o3"
|
||||
approval_policy = "untrusted"
|
||||
disable_response_storage = false
|
||||
|
||||
# Can be used to determine which profile to use if not specified by
|
||||
# `ConfigOverrides`.
|
||||
@@ -1066,7 +1329,6 @@ model_provider = "openai-chat-completions"
|
||||
model = "o3"
|
||||
model_provider = "openai"
|
||||
approval_policy = "on-failure"
|
||||
disable_response_storage = true
|
||||
|
||||
[profiles.gpt5]
|
||||
model = "gpt-5"
|
||||
@@ -1164,7 +1426,6 @@ model_verbosity = "high"
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: false,
|
||||
user_instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
@@ -1174,7 +1435,6 @@ model_verbosity = "high"
|
||||
codex_home: fixture.codex_home(),
|
||||
history: History::default(),
|
||||
file_opener: UriBasedFileOpener::VsCode,
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
@@ -1187,12 +1447,11 @@ model_verbosity = "high"
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
tools_web_search_request: false,
|
||||
responses_originator_header: "codex_cli_rs".to_string(),
|
||||
preferred_auth_method: AuthMode::ChatGPT,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("o3".to_string()),
|
||||
disable_paste_burst: false,
|
||||
use_experimental_reasoning_summary: false,
|
||||
},
|
||||
o3_profile_config
|
||||
);
|
||||
@@ -1223,7 +1482,6 @@ model_verbosity = "high"
|
||||
approval_policy: AskForApproval::UnlessTrusted,
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: false,
|
||||
user_instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
@@ -1233,7 +1491,6 @@ model_verbosity = "high"
|
||||
codex_home: fixture.codex_home(),
|
||||
history: History::default(),
|
||||
file_opener: UriBasedFileOpener::VsCode,
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
@@ -1246,12 +1503,11 @@ model_verbosity = "high"
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
tools_web_search_request: false,
|
||||
responses_originator_header: "codex_cli_rs".to_string(),
|
||||
preferred_auth_method: AuthMode::ChatGPT,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("gpt3".to_string()),
|
||||
disable_paste_burst: false,
|
||||
use_experimental_reasoning_summary: false,
|
||||
};
|
||||
|
||||
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
|
||||
@@ -1297,7 +1553,6 @@ model_verbosity = "high"
|
||||
approval_policy: AskForApproval::OnFailure,
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: true,
|
||||
user_instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
@@ -1307,7 +1562,6 @@ model_verbosity = "high"
|
||||
codex_home: fixture.codex_home(),
|
||||
history: History::default(),
|
||||
file_opener: UriBasedFileOpener::VsCode,
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
@@ -1320,12 +1574,11 @@ model_verbosity = "high"
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
tools_web_search_request: false,
|
||||
responses_originator_header: "codex_cli_rs".to_string(),
|
||||
preferred_auth_method: AuthMode::ChatGPT,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("zdr".to_string()),
|
||||
disable_paste_burst: false,
|
||||
use_experimental_reasoning_summary: false,
|
||||
};
|
||||
|
||||
assert_eq!(expected_zdr_profile_config, zdr_profile_config);
|
||||
@@ -1350,14 +1603,13 @@ model_verbosity = "high"
|
||||
let expected_gpt5_profile_config = Config {
|
||||
model: "gpt-5".to_string(),
|
||||
model_family: find_family_for_model("gpt-5").expect("known model slug"),
|
||||
model_context_window: Some(400_000),
|
||||
model_context_window: Some(272_000),
|
||||
model_max_output_tokens: Some(128_000),
|
||||
model_provider_id: "openai".to_string(),
|
||||
model_provider: fixture.openai_provider.clone(),
|
||||
approval_policy: AskForApproval::OnFailure,
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: false,
|
||||
user_instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
@@ -1367,7 +1619,6 @@ model_verbosity = "high"
|
||||
codex_home: fixture.codex_home(),
|
||||
history: History::default(),
|
||||
file_opener: UriBasedFileOpener::VsCode,
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
@@ -1380,12 +1631,11 @@ model_verbosity = "high"
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
tools_web_search_request: false,
|
||||
responses_originator_header: "codex_cli_rs".to_string(),
|
||||
preferred_auth_method: AuthMode::ChatGPT,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("gpt5".to_string()),
|
||||
disable_paste_burst: false,
|
||||
use_experimental_reasoning_summary: false,
|
||||
};
|
||||
|
||||
assert_eq!(expected_gpt5_profile_config, gpt5_profile_config);
|
||||
@@ -1395,17 +1645,14 @@ model_verbosity = "high"
|
||||
|
||||
#[test]
|
||||
fn test_set_project_trusted_writes_explicit_tables() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let project_dir = TempDir::new().unwrap();
|
||||
let project_dir = Path::new("/some/path");
|
||||
let mut doc = DocumentMut::new();
|
||||
|
||||
// Call the function under test
|
||||
set_project_trusted(codex_home.path(), project_dir.path())?;
|
||||
set_project_trusted_inner(&mut doc, project_dir)?;
|
||||
|
||||
// Read back the generated config.toml and assert exact contents
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
let contents = std::fs::read_to_string(&config_path)?;
|
||||
let contents = doc.to_string();
|
||||
|
||||
let raw_path = project_dir.path().to_string_lossy();
|
||||
let raw_path = project_dir.to_string_lossy();
|
||||
let path_str = if raw_path.contains('\\') {
|
||||
format!("'{raw_path}'")
|
||||
} else {
|
||||
@@ -1423,12 +1670,10 @@ trust_level = "trusted"
|
||||
|
||||
#[test]
|
||||
fn test_set_project_trusted_converts_inline_to_explicit() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let project_dir = TempDir::new().unwrap();
|
||||
let project_dir = Path::new("/some/path");
|
||||
|
||||
// Seed config.toml with an inline project entry under [projects]
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
let raw_path = project_dir.path().to_string_lossy();
|
||||
let raw_path = project_dir.to_string_lossy();
|
||||
let path_str = if raw_path.contains('\\') {
|
||||
format!("'{raw_path}'")
|
||||
} else {
|
||||
@@ -1440,13 +1685,12 @@ trust_level = "trusted"
|
||||
{path_str} = {{ trust_level = "untrusted" }}
|
||||
"#
|
||||
);
|
||||
std::fs::create_dir_all(codex_home.path())?;
|
||||
std::fs::write(&config_path, initial)?;
|
||||
let mut doc = initial.parse::<DocumentMut>()?;
|
||||
|
||||
// Run the function; it should convert to explicit tables and set trusted
|
||||
set_project_trusted(codex_home.path(), project_dir.path())?;
|
||||
set_project_trusted_inner(&mut doc, project_dir)?;
|
||||
|
||||
let contents = std::fs::read_to_string(&config_path)?;
|
||||
let contents = doc.to_string();
|
||||
|
||||
// Assert exact output after conversion to explicit table
|
||||
let expected = format!(
|
||||
@@ -1461,5 +1705,37 @@ trust_level = "trusted"
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// No test enforcing the presence of a standalone [projects] header.
|
||||
#[test]
|
||||
fn test_set_project_trusted_migrates_top_level_inline_projects_preserving_entries()
|
||||
-> anyhow::Result<()> {
|
||||
let initial = r#"toplevel = "baz"
|
||||
projects = { "/Users/mbolin/code/codex4" = { trust_level = "trusted", foo = "bar" } , "/Users/mbolin/code/codex3" = { trust_level = "trusted" } }
|
||||
model = "foo""#;
|
||||
let mut doc = initial.parse::<DocumentMut>()?;
|
||||
|
||||
// Approve a new directory
|
||||
let new_project = Path::new("/Users/mbolin/code/codex2");
|
||||
set_project_trusted_inner(&mut doc, new_project)?;
|
||||
|
||||
let contents = doc.to_string();
|
||||
|
||||
// Since we created the [projects] table as part of migration, it is kept implicit.
|
||||
// Expect explicit per-project tables, preserving prior entries and appending the new one.
|
||||
let expected = r#"toplevel = "baz"
|
||||
model = "foo"
|
||||
|
||||
[projects."/Users/mbolin/code/codex4"]
|
||||
trust_level = "trusted"
|
||||
foo = "bar"
|
||||
|
||||
[projects."/Users/mbolin/code/codex3"]
|
||||
trust_level = "trusted"
|
||||
|
||||
[projects."/Users/mbolin/code/codex2"]
|
||||
trust_level = "trusted"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
582
codex-rs/core/src/config_edit.rs
Normal file
582
codex-rs/core/src/config_edit.rs
Normal file
@@ -0,0 +1,582 @@
|
||||
use crate::config::CONFIG_TOML_FILE;
|
||||
use anyhow::Result;
|
||||
use std::path::Path;
|
||||
use tempfile::NamedTempFile;
|
||||
use toml_edit::DocumentMut;
|
||||
|
||||
pub const CONFIG_KEY_MODEL: &str = "model";
|
||||
pub const CONFIG_KEY_EFFORT: &str = "model_reasoning_effort";
|
||||
|
||||
/// Persist overrides into `config.toml` using explicit key segments per
|
||||
/// override. This avoids ambiguity with keys that contain dots or spaces.
|
||||
pub async fn persist_overrides(
|
||||
codex_home: &Path,
|
||||
profile: Option<&str>,
|
||||
overrides: &[(&[&str], &str)],
|
||||
) -> Result<()> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
|
||||
let mut doc = match tokio::fs::read_to_string(&config_path).await {
|
||||
Ok(s) => s.parse::<DocumentMut>()?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
tokio::fs::create_dir_all(codex_home).await?;
|
||||
DocumentMut::new()
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
let effective_profile = if let Some(p) = profile {
|
||||
Some(p.to_owned())
|
||||
} else {
|
||||
doc.get("profile")
|
||||
.and_then(|i| i.as_str())
|
||||
.map(|s| s.to_string())
|
||||
};
|
||||
|
||||
for (segments, val) in overrides.iter().copied() {
|
||||
let value = toml_edit::value(val);
|
||||
if let Some(ref name) = effective_profile {
|
||||
if segments.first().copied() == Some("profiles") {
|
||||
apply_toml_edit_override_segments(&mut doc, segments, value);
|
||||
} else {
|
||||
let mut seg_buf: Vec<&str> = Vec::with_capacity(2 + segments.len());
|
||||
seg_buf.push("profiles");
|
||||
seg_buf.push(name.as_str());
|
||||
seg_buf.extend_from_slice(segments);
|
||||
apply_toml_edit_override_segments(&mut doc, &seg_buf, value);
|
||||
}
|
||||
} else {
|
||||
apply_toml_edit_override_segments(&mut doc, segments, value);
|
||||
}
|
||||
}
|
||||
|
||||
let tmp_file = NamedTempFile::new_in(codex_home)?;
|
||||
tokio::fs::write(tmp_file.path(), doc.to_string()).await?;
|
||||
tmp_file.persist(config_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Persist overrides where values may be optional. Any entries with `None`
|
||||
/// values are skipped. If all values are `None`, this becomes a no-op and
|
||||
/// returns `Ok(())` without touching the file.
|
||||
pub async fn persist_non_null_overrides(
|
||||
codex_home: &Path,
|
||||
profile: Option<&str>,
|
||||
overrides: &[(&[&str], Option<&str>)],
|
||||
) -> Result<()> {
|
||||
let filtered: Vec<(&[&str], &str)> = overrides
|
||||
.iter()
|
||||
.filter_map(|(k, v)| v.map(|vv| (*k, vv)))
|
||||
.collect();
|
||||
|
||||
if filtered.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
persist_overrides(codex_home, profile, &filtered).await
|
||||
}
|
||||
|
||||
/// Apply a single override onto a `toml_edit` document while preserving
|
||||
/// existing formatting/comments.
|
||||
/// The key is expressed as explicit segments to correctly handle keys that
|
||||
/// contain dots or spaces.
|
||||
fn apply_toml_edit_override_segments(
|
||||
doc: &mut DocumentMut,
|
||||
segments: &[&str],
|
||||
value: toml_edit::Item,
|
||||
) {
|
||||
use toml_edit::Item;
|
||||
|
||||
if segments.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut current = doc.as_table_mut();
|
||||
for seg in &segments[..segments.len() - 1] {
|
||||
if !current.contains_key(seg) {
|
||||
current[*seg] = Item::Table(toml_edit::Table::new());
|
||||
if let Some(t) = current[*seg].as_table_mut() {
|
||||
t.set_implicit(true);
|
||||
}
|
||||
}
|
||||
|
||||
let maybe_item = current.get_mut(seg);
|
||||
let Some(item) = maybe_item else { return };
|
||||
|
||||
if !item.is_table() {
|
||||
*item = Item::Table(toml_edit::Table::new());
|
||||
if let Some(t) = item.as_table_mut() {
|
||||
t.set_implicit(true);
|
||||
}
|
||||
}
|
||||
|
||||
let Some(tbl) = item.as_table_mut() else {
|
||||
return;
|
||||
};
|
||||
current = tbl;
|
||||
}
|
||||
|
||||
let last = segments[segments.len() - 1];
|
||||
current[last] = value;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// Verifies model and effort are written at top-level when no profile is set.
|
||||
#[tokio::test]
|
||||
async fn set_default_model_and_effort_top_level_when_no_profile() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "gpt-5"),
|
||||
(&[CONFIG_KEY_EFFORT], "high"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model = "gpt-5"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies values are written under the active profile when `profile` is set.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_profile_when_profile_set() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed config with a profile selection but without profiles table
|
||||
let seed = "profile = \"o3\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "o3"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "o3"
|
||||
|
||||
[profiles.o3]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies profile names with dots/spaces are preserved via explicit segments.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_profile_with_dot_and_space() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed config with a profile name that contains a dot and a space
|
||||
let seed = "profile = \"my.team name\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "o3"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "my.team name"
|
||||
|
||||
[profiles."my.team name"]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies explicit profile override writes under that profile even without active profile.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_when_profile_override_supplied() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// No profile key in config.toml
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), "")
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Persist with an explicit profile override
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
Some("o3"),
|
||||
&[(&[CONFIG_KEY_MODEL], "o3"), (&[CONFIG_KEY_EFFORT], "high")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.o3]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies nested tables are created as needed when applying overrides.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_creates_nested_tables() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&["a", "b", "c"], "v"),
|
||||
(&["x"], "y"),
|
||||
(&["profiles", "p1", CONFIG_KEY_MODEL], "gpt-5"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"x = "y"
|
||||
|
||||
[a.b]
|
||||
c = "v"
|
||||
|
||||
[profiles.p1]
|
||||
model = "gpt-5"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies a scalar key becomes a table when nested keys are written.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_replaces_scalar_with_table() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
let seed = "foo = \"bar\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(codex_home, None, &[(&["foo", "bar", "baz"], "ok")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[foo.bar]
|
||||
baz = "ok"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies comments and spacing are preserved when writing under active profile.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_preserve_comments() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed a config with comments and spacing we expect to preserve
|
||||
let seed = r#"# Global comment
|
||||
# Another line
|
||||
|
||||
profile = "o3"
|
||||
|
||||
# Profile settings
|
||||
[profiles.o3]
|
||||
# keep me
|
||||
existing = "keep"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Apply defaults; since profile is set, it should write under [profiles.o3]
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[(&[CONFIG_KEY_MODEL], "o3"), (&[CONFIG_KEY_EFFORT], "high")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"# Global comment
|
||||
# Another line
|
||||
|
||||
profile = "o3"
|
||||
|
||||
# Profile settings
|
||||
[profiles.o3]
|
||||
# keep me
|
||||
existing = "keep"
|
||||
model = "o3"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies comments and spacing are preserved when writing at top level.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_preserve_global_comments() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed a config WITHOUT a profile, containing comments and spacing
|
||||
let seed = r#"# Top-level comments
|
||||
# should be preserved
|
||||
|
||||
existing = "keep"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Since there is no profile, the defaults should be written at top-level
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "gpt-5"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"# Top-level comments
|
||||
# should be preserved
|
||||
|
||||
existing = "keep"
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies errors on invalid TOML propagate and file is not clobbered.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_errors_on_parse_failure() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Write an intentionally invalid TOML file
|
||||
let invalid = "invalid = [unclosed";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), invalid)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Attempting to persist should return an error and must not clobber the file.
|
||||
let res = persist_overrides(codex_home, None, &[(&["x"], "y")]).await;
|
||||
assert!(res.is_err(), "expected parse error to propagate");
|
||||
|
||||
// File should be unchanged
|
||||
let contents = read_config(codex_home).await;
|
||||
assert_eq!(contents, invalid);
|
||||
}
|
||||
|
||||
/// Verifies changing model only preserves existing effort at top-level.
|
||||
#[tokio::test]
|
||||
async fn changing_only_model_preserves_existing_effort_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with an effort value only
|
||||
let seed = "model_reasoning_effort = \"minimal\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Change only the model
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_MODEL], "o3")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model_reasoning_effort = "minimal"
|
||||
model = "o3"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing effort only preserves existing model at top-level.
|
||||
#[tokio::test]
|
||||
async fn changing_only_effort_preserves_existing_model_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with a model value only
|
||||
let seed = "model = \"gpt-5\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Change only the effort
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_EFFORT], "high")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model = "gpt-5"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing model only preserves existing effort in active profile.
|
||||
#[tokio::test]
|
||||
async fn changing_only_model_preserves_effort_in_active_profile() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with an active profile and an existing effort under that profile
|
||||
let seed = r#"profile = "p1"
|
||||
|
||||
[profiles.p1]
|
||||
model_reasoning_effort = "low"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_MODEL], "o4-mini")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "p1"
|
||||
|
||||
[profiles.p1]
|
||||
model_reasoning_effort = "low"
|
||||
model = "o4-mini"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing effort only preserves existing model in a profile override.
|
||||
#[tokio::test]
|
||||
async fn changing_only_effort_preserves_model_in_profile_override() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// No active profile key; we'll target an explicit override
|
||||
let seed = r#"[profiles.team]
|
||||
model = "gpt-5"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
Some("team"),
|
||||
&[(&[CONFIG_KEY_EFFORT], "minimal")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.team]
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies `persist_non_null_overrides` skips `None` entries and writes only present values at top-level.
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_skips_none_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], Some("gpt-5")),
|
||||
(&[CONFIG_KEY_EFFORT], None),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = "model = \"gpt-5\"\n";
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies no-op behavior when all provided overrides are `None` (no file created/modified).
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_noop_when_all_none() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[(&["a"], None), (&["profiles", "p", "x"], None)],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
// Should not create config.toml on a pure no-op
|
||||
assert!(!codex_home.join(CONFIG_TOML_FILE).exists());
|
||||
}
|
||||
|
||||
/// Verifies entries are written under the specified profile and `None` entries are skipped.
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_respects_profile_override() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
Some("team"),
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], Some("o3")),
|
||||
(&[CONFIG_KEY_EFFORT], None),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.team]
|
||||
model = "o3"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
// Test helper moved to bottom per review guidance.
|
||||
async fn read_config(codex_home: &Path) -> String {
|
||||
let p = codex_home.join(CONFIG_TOML_FILE);
|
||||
tokio::fs::read_to_string(p).await.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
@@ -15,10 +15,23 @@ pub struct ConfigProfile {
|
||||
/// [`ModelProviderInfo`] to use.
|
||||
pub model_provider: Option<String>,
|
||||
pub approval_policy: Option<AskForApproval>,
|
||||
pub disable_response_storage: Option<bool>,
|
||||
pub model_reasoning_effort: Option<ReasoningEffort>,
|
||||
pub model_reasoning_summary: Option<ReasoningSummary>,
|
||||
pub model_verbosity: Option<Verbosity>,
|
||||
pub chatgpt_base_url: Option<String>,
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl From<ConfigProfile> for codex_protocol::mcp_protocol::Profile {
|
||||
fn from(config_profile: ConfigProfile) -> Self {
|
||||
Self {
|
||||
model: config_profile.model,
|
||||
model_provider: config_profile.model_provider,
|
||||
approval_policy: config_profile.approval_policy,
|
||||
model_reasoning_effort: config_profile.model_reasoning_effort,
|
||||
model_reasoning_summary: config_profile.model_reasoning_summary,
|
||||
model_verbosity: config_profile.model_verbosity,
|
||||
chatgpt_base_url: config_profile.chatgpt_base_url,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,10 @@ pub struct McpServerConfig {
|
||||
|
||||
#[serde(default)]
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
|
||||
/// Startup timeout in milliseconds for initializing MCP server & initially listing tools.
|
||||
#[serde(default)]
|
||||
pub startup_timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Copy, Clone, PartialEq)]
|
||||
@@ -88,6 +92,17 @@ pub struct SandboxWorkspaceWrite {
|
||||
pub exclude_slash_tmp: bool,
|
||||
}
|
||||
|
||||
impl From<SandboxWorkspaceWrite> for codex_protocol::mcp_protocol::SandboxSettings {
|
||||
fn from(sandbox_workspace_write: SandboxWorkspaceWrite) -> Self {
|
||||
Self {
|
||||
writable_roots: sandbox_workspace_write.writable_roots,
|
||||
network_access: Some(sandbox_workspace_write.network_access),
|
||||
exclude_tmpdir_env_var: Some(sandbox_workspace_write.exclude_tmpdir_env_var),
|
||||
exclude_slash_tmp: Some(sandbox_workspace_write.exclude_slash_tmp),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum ShellEnvironmentPolicyInherit {
|
||||
@@ -183,3 +198,11 @@ impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Eq, Default, Hash)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum ReasoningSummaryFormat {
|
||||
#[default]
|
||||
None,
|
||||
Experimental,
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
/// Transcript of conversation history
|
||||
@@ -59,6 +60,26 @@ impl ConversationHistory {
|
||||
kept.reverse();
|
||||
self.items = kept;
|
||||
}
|
||||
|
||||
pub(crate) fn last_agent_message(&self) -> String {
|
||||
for item in self.items.iter().rev() {
|
||||
if let ResponseItem::Message { role, content, .. } = item
|
||||
&& role == "assistant"
|
||||
{
|
||||
return content
|
||||
.iter()
|
||||
.find_map(|ci| {
|
||||
if let ContentItem::OutputText { text } = ci {
|
||||
Some(text.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
}
|
||||
}
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Anything that is not a system message or "reasoning" message is considered
|
||||
|
||||
@@ -1,12 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::CodexAuth;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::codex::Codex;
|
||||
use crate::codex::CodexSpawnOk;
|
||||
use crate::codex::INITIAL_SUBMIT_ID;
|
||||
@@ -18,18 +11,19 @@ use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum InitialHistory {
|
||||
New,
|
||||
Resumed(Vec<ResponseItem>),
|
||||
}
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Represents a newly created Codex conversation, including the first event
|
||||
/// (which is [`EventMsg::SessionConfigured`]).
|
||||
pub struct NewConversation {
|
||||
pub conversation_id: Uuid,
|
||||
pub conversation_id: ConversationId,
|
||||
pub conversation: Arc<CodexConversation>,
|
||||
pub session_configured: SessionConfiguredEvent,
|
||||
}
|
||||
@@ -37,7 +31,7 @@ pub struct NewConversation {
|
||||
/// [`ConversationManager`] is responsible for creating conversations and
|
||||
/// maintaining them in memory.
|
||||
pub struct ConversationManager {
|
||||
conversations: Arc<RwLock<HashMap<Uuid, Arc<CodexConversation>>>>,
|
||||
conversations: Arc<RwLock<HashMap<ConversationId, Arc<CodexConversation>>>>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
}
|
||||
|
||||
@@ -70,14 +64,14 @@ impl ConversationManager {
|
||||
let initial_history = RolloutRecorder::get_rollout_history(resume_path).await?;
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, initial_history).await?;
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
} else {
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = { Codex::spawn(config, auth_manager, InitialHistory::New).await? };
|
||||
conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, InitialHistory::New).await?;
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
}
|
||||
@@ -85,7 +79,7 @@ impl ConversationManager {
|
||||
async fn finalize_spawn(
|
||||
&self,
|
||||
codex: Codex,
|
||||
conversation_id: Uuid,
|
||||
conversation_id: ConversationId,
|
||||
) -> CodexResult<NewConversation> {
|
||||
// The first event must be `SessionInitialized`. Validate and forward it
|
||||
// to the caller so that they can display it in the conversation
|
||||
@@ -116,7 +110,7 @@ impl ConversationManager {
|
||||
|
||||
pub async fn get_conversation(
|
||||
&self,
|
||||
conversation_id: Uuid,
|
||||
conversation_id: ConversationId,
|
||||
) -> CodexResult<Arc<CodexConversation>> {
|
||||
let conversations = self.conversations.read().await;
|
||||
conversations
|
||||
@@ -134,13 +128,20 @@ impl ConversationManager {
|
||||
let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, initial_history).await?;
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
|
||||
pub async fn remove_conversation(&self, conversation_id: Uuid) {
|
||||
self.conversations.write().await.remove(&conversation_id);
|
||||
/// Removes the conversation from the manager's internal map, though the
|
||||
/// conversation is stored as `Arc<CodexConversation>`, it is possible that
|
||||
/// other references to it exist elsewhere. Returns the conversation if the
|
||||
/// conversation was found and removed.
|
||||
pub async fn remove_conversation(
|
||||
&self,
|
||||
conversation_id: &ConversationId,
|
||||
) -> Option<Arc<CodexConversation>> {
|
||||
self.conversations.write().await.remove(conversation_id)
|
||||
}
|
||||
|
||||
/// Fork an existing conversation by dropping the last `drop_last_messages`
|
||||
@@ -149,19 +150,19 @@ impl ConversationManager {
|
||||
/// caller's `config`). The new conversation will have a fresh id.
|
||||
pub async fn fork_conversation(
|
||||
&self,
|
||||
conversation_history: Vec<ResponseItem>,
|
||||
num_messages_to_drop: usize,
|
||||
config: Config,
|
||||
path: PathBuf,
|
||||
) -> CodexResult<NewConversation> {
|
||||
// Compute the prefix up to the cut point.
|
||||
let history =
|
||||
truncate_after_dropping_last_messages(conversation_history, num_messages_to_drop);
|
||||
let history = RolloutRecorder::get_rollout_history(&path).await?;
|
||||
let history = truncate_after_dropping_last_messages(history, num_messages_to_drop);
|
||||
|
||||
// Spawn a new conversation with the computed initial history.
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, history).await?;
|
||||
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
@@ -170,31 +171,37 @@ impl ConversationManager {
|
||||
|
||||
/// Return a prefix of `items` obtained by dropping the last `n` user messages
|
||||
/// and all items that follow them.
|
||||
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> InitialHistory {
|
||||
fn truncate_after_dropping_last_messages(history: InitialHistory, n: usize) -> InitialHistory {
|
||||
if n == 0 {
|
||||
return InitialHistory::Resumed(items);
|
||||
return InitialHistory::Forked(history.get_rollout_items());
|
||||
}
|
||||
|
||||
// Walk backwards counting only `user` Message items, find cut index.
|
||||
let mut count = 0usize;
|
||||
let mut cut_index = 0usize;
|
||||
for (idx, item) in items.iter().enumerate().rev() {
|
||||
if let ResponseItem::Message { role, .. } = item
|
||||
// Work directly on rollout items, and cut the vector at the nth-from-last user message input.
|
||||
let items: Vec<RolloutItem> = history.get_rollout_items();
|
||||
|
||||
// Find indices of user message inputs in rollout order.
|
||||
let mut user_positions: Vec<usize> = Vec::new();
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
if let RolloutItem::ResponseItem(ResponseItem::Message { role, .. }) = item
|
||||
&& role == "user"
|
||||
{
|
||||
count += 1;
|
||||
if count == n {
|
||||
// Cut everything from this user message to the end.
|
||||
cut_index = idx;
|
||||
break;
|
||||
}
|
||||
user_positions.push(idx);
|
||||
}
|
||||
}
|
||||
if cut_index == 0 {
|
||||
// No prefix remains after dropping; start a new conversation.
|
||||
|
||||
// If fewer than n user messages exist, treat as empty.
|
||||
if user_positions.len() < n {
|
||||
return InitialHistory::New;
|
||||
}
|
||||
|
||||
// Cut strictly before the nth-from-last user message (do not keep the nth itself).
|
||||
let cut_idx = user_positions[user_positions.len() - n];
|
||||
let rolled: Vec<RolloutItem> = items.into_iter().take(cut_idx).collect();
|
||||
|
||||
if rolled.is_empty() {
|
||||
InitialHistory::New
|
||||
} else {
|
||||
InitialHistory::Resumed(items.into_iter().take(cut_index).collect())
|
||||
InitialHistory::Forked(rolled)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,13 +256,30 @@ mod tests {
|
||||
assistant_msg("a4"),
|
||||
];
|
||||
|
||||
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
|
||||
// Wrap as InitialHistory::Forked with response items only.
|
||||
let initial: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
let truncated = truncate_after_dropping_last_messages(InitialHistory::Forked(initial), 1);
|
||||
let got_items = truncated.get_rollout_items();
|
||||
let expected_items = vec![
|
||||
RolloutItem::ResponseItem(items[0].clone()),
|
||||
RolloutItem::ResponseItem(items[1].clone()),
|
||||
RolloutItem::ResponseItem(items[2].clone()),
|
||||
];
|
||||
assert_eq!(
|
||||
truncated,
|
||||
InitialHistory::Resumed(vec![items[0].clone(), items[1].clone(), items[2].clone(),])
|
||||
serde_json::to_value(&got_items).unwrap(),
|
||||
serde_json::to_value(&expected_items).unwrap()
|
||||
);
|
||||
|
||||
let truncated2 = truncate_after_dropping_last_messages(items, 2);
|
||||
assert_eq!(truncated2, InitialHistory::New);
|
||||
let initial2: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
let truncated2 = truncate_after_dropping_last_messages(InitialHistory::Forked(initial2), 2);
|
||||
assert!(matches!(truncated2, InitialHistory::New));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,38 +1,123 @@
|
||||
pub const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";
|
||||
use reqwest::header::HeaderValue;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::Mutex;
|
||||
|
||||
pub fn get_codex_user_agent(originator: Option<&str>) -> String {
|
||||
/// Set this to add a suffix to the User-Agent string.
|
||||
///
|
||||
/// It is not ideal that we're using a global singleton for this.
|
||||
/// This is primarily designed to differentiate MCP clients from each other.
|
||||
/// Because there can only be one MCP server per process, it should be safe for this to be a global static.
|
||||
/// However, future users of this should use this with caution as a result.
|
||||
/// In addition, we want to be confident that this value is used for ALL clients and doing that requires a
|
||||
/// lot of wiring and it's easy to miss code paths by doing so.
|
||||
/// See https://github.com/openai/codex/pull/3388/files for an example of what that would look like.
|
||||
/// Finally, we want to make sure this is set for ALL mcp clients without needing to know a special env var
|
||||
/// or having to set data that they already specified in the mcp initialize request somewhere else.
|
||||
///
|
||||
/// A space is automatically added between the suffix and the rest of the User-Agent string.
|
||||
/// The full user agent string is returned from the mcp initialize response.
|
||||
/// Parenthesis will be added by Codex. This should only specify what goes inside of the parenthesis.
|
||||
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));
|
||||
|
||||
pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Originator {
|
||||
pub value: String,
|
||||
pub header_value: HeaderValue,
|
||||
}
|
||||
|
||||
pub static ORIGINATOR: LazyLock<Originator> = LazyLock::new(|| {
|
||||
let default = "codex_cli_rs";
|
||||
let value = std::env::var(CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR)
|
||||
.unwrap_or_else(|_| default.to_string());
|
||||
|
||||
match HeaderValue::from_str(&value) {
|
||||
Ok(header_value) => Originator {
|
||||
value,
|
||||
header_value,
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to turn originator override {value} into header value: {e}");
|
||||
Originator {
|
||||
value: default.to_string(),
|
||||
header_value: HeaderValue::from_static(default),
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
pub fn get_codex_user_agent() -> String {
|
||||
let build_version = env!("CARGO_PKG_VERSION");
|
||||
let os_info = os_info::get();
|
||||
format!(
|
||||
let prefix = format!(
|
||||
"{}/{build_version} ({} {}; {}) {}",
|
||||
originator.unwrap_or(DEFAULT_ORIGINATOR),
|
||||
ORIGINATOR.value.as_str(),
|
||||
os_info.os_type(),
|
||||
os_info.version(),
|
||||
os_info.architecture().unwrap_or("unknown"),
|
||||
crate::terminal::user_agent()
|
||||
)
|
||||
);
|
||||
let suffix = USER_AGENT_SUFFIX
|
||||
.lock()
|
||||
.ok()
|
||||
.and_then(|guard| guard.clone());
|
||||
let suffix = suffix
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map_or_else(String::new, |value| format!(" ({value})"));
|
||||
|
||||
let candidate = format!("{prefix}{suffix}");
|
||||
sanitize_user_agent(candidate, &prefix)
|
||||
}
|
||||
|
||||
/// Sanitize the user agent string.
|
||||
///
|
||||
/// Invalid characters are replaced with an underscore.
|
||||
///
|
||||
/// If the user agent fails to parse, it falls back to fallback and then to ORIGINATOR.
|
||||
fn sanitize_user_agent(candidate: String, fallback: &str) -> String {
|
||||
if HeaderValue::from_str(candidate.as_str()).is_ok() {
|
||||
return candidate;
|
||||
}
|
||||
|
||||
let sanitized: String = candidate
|
||||
.chars()
|
||||
.map(|ch| if matches!(ch, ' '..='~') { ch } else { '_' })
|
||||
.collect();
|
||||
if !sanitized.is_empty() && HeaderValue::from_str(sanitized.as_str()).is_ok() {
|
||||
tracing::warn!(
|
||||
"Sanitized Codex user agent because provided suffix contained invalid header characters"
|
||||
);
|
||||
sanitized
|
||||
} else if HeaderValue::from_str(fallback).is_ok() {
|
||||
tracing::warn!(
|
||||
"Falling back to base Codex user agent because provided suffix could not be sanitized"
|
||||
);
|
||||
fallback.to_string()
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Falling back to default Codex originator because base user agent string is invalid"
|
||||
);
|
||||
ORIGINATOR.value.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a reqwest client with default `originator` and `User-Agent` headers set.
|
||||
pub fn create_client(originator: &str) -> reqwest::Client {
|
||||
pub fn create_client() -> reqwest::Client {
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
let originator_value = HeaderValue::from_str(originator)
|
||||
.unwrap_or_else(|_| HeaderValue::from_static(DEFAULT_ORIGINATOR));
|
||||
headers.insert("originator", originator_value);
|
||||
let ua = get_codex_user_agent(Some(originator));
|
||||
headers.insert("originator", ORIGINATOR.header_value.clone());
|
||||
let ua = get_codex_user_agent();
|
||||
|
||||
match reqwest::Client::builder()
|
||||
reqwest::Client::builder()
|
||||
// Set UA via dedicated helper to avoid header validation pitfalls
|
||||
.user_agent(ua)
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
{
|
||||
Ok(client) => client,
|
||||
Err(_) => reqwest::Client::new(),
|
||||
}
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -41,7 +126,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_get_codex_user_agent() {
|
||||
let user_agent = get_codex_user_agent(None);
|
||||
let user_agent = get_codex_user_agent();
|
||||
assert!(user_agent.starts_with("codex_cli_rs/"));
|
||||
}
|
||||
|
||||
@@ -53,8 +138,7 @@ mod tests {
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
let originator = "test_originator";
|
||||
let client = create_client(originator);
|
||||
let client = create_client();
|
||||
|
||||
// Spin up a local mock server and capture a request.
|
||||
let server = MockServer::start().await;
|
||||
@@ -82,21 +166,43 @@ mod tests {
|
||||
let originator_header = headers
|
||||
.get("originator")
|
||||
.expect("originator header missing");
|
||||
assert_eq!(originator_header.to_str().unwrap(), originator);
|
||||
assert_eq!(originator_header.to_str().unwrap(), "codex_cli_rs");
|
||||
|
||||
// User-Agent matches the computed Codex UA for that originator
|
||||
let expected_ua = get_codex_user_agent(Some(originator));
|
||||
let expected_ua = get_codex_user_agent();
|
||||
let ua_header = headers
|
||||
.get("user-agent")
|
||||
.expect("user-agent header missing");
|
||||
assert_eq!(ua_header.to_str().unwrap(), expected_ua);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_suffix_is_sanitized() {
|
||||
let prefix = "codex_cli_rs/0.0.0";
|
||||
let suffix = "bad\rsuffix";
|
||||
|
||||
assert_eq!(
|
||||
sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
|
||||
"codex_cli_rs/0.0.0 (bad_suffix)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_suffix_is_sanitized2() {
|
||||
let prefix = "codex_cli_rs/0.0.0";
|
||||
let suffix = "bad\0suffix";
|
||||
|
||||
assert_eq!(
|
||||
sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
|
||||
"codex_cli_rs/0.0.0 (bad_suffix)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(target_os = "macos")]
|
||||
fn test_macos() {
|
||||
use regex_lite::Regex;
|
||||
let user_agent = get_codex_user_agent(None);
|
||||
let user_agent = get_codex_user_agent();
|
||||
let re = Regex::new(
|
||||
r"^codex_cli_rs/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$",
|
||||
)
|
||||
|
||||
@@ -8,12 +8,10 @@ use crate::shell::Shell;
|
||||
use codex_protocol::config_types::SandboxMode;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::ENVIRONMENT_CONTEXT_CLOSE_TAG;
|
||||
use codex_protocol::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// wraps environment context message in a tag for the model to parse more easily.
|
||||
pub(crate) const ENVIRONMENT_CONTEXT_START: &str = "<environment_context>";
|
||||
pub(crate) const ENVIRONMENT_CONTEXT_END: &str = "</environment_context>";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, DeriveDisplay)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
#[strum(serialize_all = "kebab-case")]
|
||||
@@ -28,6 +26,7 @@ pub(crate) struct EnvironmentContext {
|
||||
pub approval_policy: Option<AskForApproval>,
|
||||
pub sandbox_mode: Option<SandboxMode>,
|
||||
pub network_access: Option<NetworkAccess>,
|
||||
pub writable_roots: Option<Vec<PathBuf>>,
|
||||
pub shell: Option<Shell>,
|
||||
}
|
||||
|
||||
@@ -59,6 +58,16 @@ impl EnvironmentContext {
|
||||
}
|
||||
None => None,
|
||||
},
|
||||
writable_roots: match sandbox_policy {
|
||||
Some(SandboxPolicy::WorkspaceWrite { writable_roots, .. }) => {
|
||||
if writable_roots.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(writable_roots)
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
shell,
|
||||
}
|
||||
}
|
||||
@@ -74,12 +83,13 @@ impl EnvironmentContext {
|
||||
/// <cwd>...</cwd>
|
||||
/// <approval_policy>...</approval_policy>
|
||||
/// <sandbox_mode>...</sandbox_mode>
|
||||
/// <writable_roots>...</writable_roots>
|
||||
/// <network_access>...</network_access>
|
||||
/// <shell>...</shell>
|
||||
/// </environment_context>
|
||||
/// ```
|
||||
pub fn serialize_to_xml(self) -> String {
|
||||
let mut lines = vec![ENVIRONMENT_CONTEXT_START.to_string()];
|
||||
let mut lines = vec![ENVIRONMENT_CONTEXT_OPEN_TAG.to_string()];
|
||||
if let Some(cwd) = self.cwd {
|
||||
lines.push(format!(" <cwd>{}</cwd>", cwd.to_string_lossy()));
|
||||
}
|
||||
@@ -96,12 +106,22 @@ impl EnvironmentContext {
|
||||
" <network_access>{network_access}</network_access>"
|
||||
));
|
||||
}
|
||||
if let Some(writable_roots) = self.writable_roots {
|
||||
lines.push(" <writable_roots>".to_string());
|
||||
for writable_root in writable_roots {
|
||||
lines.push(format!(
|
||||
" <root>{}</root>",
|
||||
writable_root.to_string_lossy()
|
||||
));
|
||||
}
|
||||
lines.push(" </writable_roots>".to_string());
|
||||
}
|
||||
if let Some(shell) = self.shell
|
||||
&& let Some(shell_name) = shell.name()
|
||||
{
|
||||
lines.push(format!(" <shell>{shell_name}</shell>"));
|
||||
}
|
||||
lines.push(ENVIRONMENT_CONTEXT_END.to_string());
|
||||
lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string());
|
||||
lines.join("\n")
|
||||
}
|
||||
}
|
||||
@@ -117,3 +137,77 @@ impl From<EnvironmentContext> for ResponseItem {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: writable_roots.into_iter().map(PathBuf::from).collect(),
|
||||
network_access,
|
||||
exclude_tmpdir_env_var: false,
|
||||
exclude_slash_tmp: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_workspace_write_environment_context() {
|
||||
let context = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp"], false)),
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
<cwd>/repo</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>workspace-write</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
<writable_roots>
|
||||
<root>/repo</root>
|
||||
<root>/tmp</root>
|
||||
</writable_roots>
|
||||
</environment_context>"#;
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_read_only_environment_context() {
|
||||
let context = EnvironmentContext::new(
|
||||
None,
|
||||
Some(AskForApproval::Never),
|
||||
Some(SandboxPolicy::ReadOnly),
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
<approval_policy>never</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
</environment_context>"#;
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_full_access_environment_context() {
|
||||
let context = EnvironmentContext::new(
|
||||
None,
|
||||
Some(AskForApproval::OnFailure),
|
||||
Some(SandboxPolicy::DangerFullAccess),
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
<approval_policy>on-failure</approval_policy>
|
||||
<sandbox_mode>danger-full-access</sandbox_mode>
|
||||
<network_access>enabled</network_access>
|
||||
</environment_context>"#;
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json;
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use tokio::task::JoinError;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, CodexErr>;
|
||||
|
||||
@@ -49,7 +51,7 @@ pub enum CodexErr {
|
||||
Stream(String, Option<Duration>),
|
||||
|
||||
#[error("no conversation with id: {0}")]
|
||||
ConversationNotFound(Uuid),
|
||||
ConversationNotFound(ConversationId),
|
||||
|
||||
#[error("session configured event was not the first event in the stream")]
|
||||
SessionConfiguredNotFirstEvent,
|
||||
@@ -127,38 +129,58 @@ pub enum CodexErr {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UsageLimitReachedError {
|
||||
pub plan_type: Option<String>,
|
||||
pub resets_in_seconds: Option<u64>,
|
||||
pub(crate) plan_type: Option<PlanType>,
|
||||
pub(crate) resets_in_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UsageLimitReachedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Base message differs slightly for legacy ChatGPT Plus plan users.
|
||||
if let Some(plan_type) = &self.plan_type
|
||||
&& plan_type == "plus"
|
||||
{
|
||||
write!(
|
||||
f,
|
||||
"You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing) or try again"
|
||||
)?;
|
||||
if let Some(secs) = self.resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
write!(f, " in {reset_duration}.")?;
|
||||
} else {
|
||||
write!(f, " later.")?;
|
||||
let message = match self.plan_type.as_ref() {
|
||||
Some(PlanType::Known(KnownPlan::Plus)) => format!(
|
||||
"You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing){}",
|
||||
retry_suffix_after_or(self.resets_in_seconds)
|
||||
),
|
||||
Some(PlanType::Known(KnownPlan::Team)) | Some(PlanType::Known(KnownPlan::Business)) => {
|
||||
format!(
|
||||
"You've hit your usage limit. To get more access now, send a request to your admin{}",
|
||||
retry_suffix_after_or(self.resets_in_seconds)
|
||||
)
|
||||
}
|
||||
} else {
|
||||
write!(f, "You've hit your usage limit.")?;
|
||||
|
||||
if let Some(secs) = self.resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
write!(f, " Try again in {reset_duration}.")?;
|
||||
} else {
|
||||
write!(f, " Try again later.")?;
|
||||
Some(PlanType::Known(KnownPlan::Free)) => {
|
||||
"To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
Some(PlanType::Known(KnownPlan::Pro))
|
||||
| Some(PlanType::Known(KnownPlan::Enterprise))
|
||||
| Some(PlanType::Known(KnownPlan::Edu)) => format!(
|
||||
"You've hit your usage limit.{}",
|
||||
retry_suffix(self.resets_in_seconds)
|
||||
),
|
||||
Some(PlanType::Unknown(_)) | None => format!(
|
||||
"You've hit your usage limit.{}",
|
||||
retry_suffix(self.resets_in_seconds)
|
||||
),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
write!(f, "{message}")
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_suffix(resets_in_seconds: Option<u64>) -> String {
|
||||
if let Some(secs) = resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
format!(" Try again in {reset_duration}.")
|
||||
} else {
|
||||
" Try again later.".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_suffix_after_or(resets_in_seconds: Option<u64>) -> String {
|
||||
if let Some(secs) = resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
format!(" or try again in {reset_duration}.")
|
||||
} else {
|
||||
" or try again later.".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,7 +259,7 @@ mod tests {
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_plus_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some("plus".to_string()),
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Plus)),
|
||||
resets_in_seconds: None,
|
||||
};
|
||||
assert_eq!(
|
||||
@@ -246,6 +268,18 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_free_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Free)),
|
||||
resets_in_seconds: Some(3600),
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_default_when_none() {
|
||||
let err = UsageLimitReachedError {
|
||||
@@ -258,10 +292,34 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_team_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Team)),
|
||||
resets_in_seconds: Some(3600),
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"You've hit your usage limit. To get more access now, send a request to your admin or try again in 1 hour."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_business_plan_without_reset() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Business)),
|
||||
resets_in_seconds: None,
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"You've hit your usage limit. To get more access now, send a request to your admin or try again later."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_default_for_other_plans() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some("pro".to_string()),
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Pro)),
|
||||
resets_in_seconds: None,
|
||||
};
|
||||
assert_eq!(
|
||||
@@ -285,7 +343,7 @@ mod tests {
|
||||
#[test]
|
||||
fn usage_limit_reached_includes_hours_and_minutes() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some("plus".to_string()),
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Plus)),
|
||||
resets_in_seconds: Some(3 * 3600 + 32 * 60),
|
||||
};
|
||||
assert_eq!(
|
||||
|
||||
167
codex-rs/core/src/event_mapping.rs
Normal file
167
codex-rs/core/src/event_mapping.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use crate::protocol::AgentMessageEvent;
|
||||
use crate::protocol::AgentReasoningEvent;
|
||||
use crate::protocol::AgentReasoningRawContentEvent;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use crate::protocol::UserMessageEvent;
|
||||
use crate::protocol::WebSearchEndEvent;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ReasoningItemReasoningSummary;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::models::WebSearchAction;
|
||||
|
||||
/// Convert a `ResponseItem` into zero or more `EventMsg` values that the UI can render.
|
||||
///
|
||||
/// When `show_raw_agent_reasoning` is false, raw reasoning content events are omitted.
|
||||
pub(crate) fn map_response_item_to_event_messages(
|
||||
item: &ResponseItem,
|
||||
show_raw_agent_reasoning: bool,
|
||||
) -> Vec<EventMsg> {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
// Do not surface system messages as user events.
|
||||
if role == "system" {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut events: Vec<EventMsg> = Vec::new();
|
||||
let mut message_parts: Vec<String> = Vec::new();
|
||||
let mut images: Vec<String> = Vec::new();
|
||||
let mut kind: Option<InputMessageKind> = None;
|
||||
|
||||
for content_item in content.iter() {
|
||||
match content_item {
|
||||
ContentItem::InputText { text } => {
|
||||
if kind.is_none() {
|
||||
let trimmed = text.trim_start();
|
||||
kind = if trimmed.starts_with("<environment_context>") {
|
||||
Some(InputMessageKind::EnvironmentContext)
|
||||
} else if trimmed.starts_with("<user_instructions>") {
|
||||
Some(InputMessageKind::UserInstructions)
|
||||
} else {
|
||||
Some(InputMessageKind::Plain)
|
||||
};
|
||||
}
|
||||
message_parts.push(text.clone());
|
||||
}
|
||||
ContentItem::InputImage { image_url } => {
|
||||
images.push(image_url.clone());
|
||||
}
|
||||
ContentItem::OutputText { text } => {
|
||||
events.push(EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: text.clone(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !message_parts.is_empty() || !images.is_empty() {
|
||||
let message = if message_parts.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
message_parts.join("")
|
||||
};
|
||||
let images = if images.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(images)
|
||||
};
|
||||
|
||||
events.push(EventMsg::UserMessage(UserMessageEvent {
|
||||
message,
|
||||
kind,
|
||||
images,
|
||||
}));
|
||||
}
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
ResponseItem::Reasoning {
|
||||
summary, content, ..
|
||||
} => {
|
||||
let mut events = Vec::new();
|
||||
for ReasoningItemReasoningSummary::SummaryText { text } in summary {
|
||||
events.push(EventMsg::AgentReasoning(AgentReasoningEvent {
|
||||
text: text.clone(),
|
||||
}));
|
||||
}
|
||||
if let Some(items) = content.as_ref().filter(|_| show_raw_agent_reasoning) {
|
||||
for c in items {
|
||||
let text = match c {
|
||||
ReasoningItemContent::ReasoningText { text }
|
||||
| ReasoningItemContent::Text { text } => text,
|
||||
};
|
||||
events.push(EventMsg::AgentReasoningRawContent(
|
||||
AgentReasoningRawContentEvent { text: text.clone() },
|
||||
));
|
||||
}
|
||||
}
|
||||
events
|
||||
}
|
||||
|
||||
ResponseItem::WebSearchCall { id, action, .. } => match action {
|
||||
WebSearchAction::Search { query } => {
|
||||
let call_id = id.clone().unwrap_or_else(|| "".to_string());
|
||||
vec![EventMsg::WebSearchEnd(WebSearchEndEvent {
|
||||
call_id,
|
||||
query: query.clone(),
|
||||
})]
|
||||
}
|
||||
WebSearchAction::Other => Vec::new(),
|
||||
},
|
||||
|
||||
// Variants that require side effects are handled by higher layers and do not emit events here.
|
||||
ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::Other => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::map_response_item_to_event_messages;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn maps_user_message_with_text_and_two_images() {
|
||||
let img1 = "https://example.com/one.png".to_string();
|
||||
let img2 = "https://example.com/two.jpg".to_string();
|
||||
|
||||
let item = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
ContentItem::InputText {
|
||||
text: "Hello world".to_string(),
|
||||
},
|
||||
ContentItem::InputImage {
|
||||
image_url: img1.clone(),
|
||||
},
|
||||
ContentItem::InputImage {
|
||||
image_url: img2.clone(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let events = map_response_item_to_event_messages(&item, false);
|
||||
assert_eq!(events.len(), 1, "expected a single user message event");
|
||||
|
||||
match &events[0] {
|
||||
EventMsg::UserMessage(user) => {
|
||||
assert_eq!(user.message, "Hello world");
|
||||
assert!(matches!(user.kind, Some(InputMessageKind::Plain)));
|
||||
assert_eq!(user.images, Some(vec![img1, img2]));
|
||||
}
|
||||
other => panic!("expected UserMessage, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -26,7 +26,6 @@ use crate::protocol::SandboxPolicy;
|
||||
use crate::seatbelt::spawn_command_under_seatbelt;
|
||||
use crate::spawn::StdioPolicy;
|
||||
use crate::spawn::spawn_child_async;
|
||||
use serde_bytes::ByteBuf;
|
||||
|
||||
const DEFAULT_TIMEOUT_MS: u64 = 10_000;
|
||||
|
||||
@@ -369,7 +368,7 @@ async fn read_capped<R: AsyncRead + Unpin + Send + 'static>(
|
||||
} else {
|
||||
ExecOutputStream::Stdout
|
||||
},
|
||||
chunk: ByteBuf::from(chunk),
|
||||
chunk,
|
||||
});
|
||||
let event = Event {
|
||||
id: stream.sub_id.clone(),
|
||||
|
||||
@@ -24,6 +24,9 @@ pub(crate) struct ExecCommandSession {
|
||||
|
||||
/// JoinHandle for the child wait task.
|
||||
wait_handle: StdMutex<Option<JoinHandle<()>>>,
|
||||
|
||||
/// Tracks whether the underlying process has exited.
|
||||
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||
}
|
||||
|
||||
impl ExecCommandSession {
|
||||
@@ -34,6 +37,7 @@ impl ExecCommandSession {
|
||||
reader_handle: JoinHandle<()>,
|
||||
writer_handle: JoinHandle<()>,
|
||||
wait_handle: JoinHandle<()>,
|
||||
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
writer_tx,
|
||||
@@ -42,6 +46,7 @@ impl ExecCommandSession {
|
||||
reader_handle: StdMutex::new(Some(reader_handle)),
|
||||
writer_handle: StdMutex::new(Some(writer_handle)),
|
||||
wait_handle: StdMutex::new(Some(wait_handle)),
|
||||
exit_status,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,6 +57,10 @@ impl ExecCommandSession {
|
||||
pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
|
||||
self.output_tx.subscribe()
|
||||
}
|
||||
|
||||
pub(crate) fn has_exited(&self) -> bool {
|
||||
self.exit_status.load(std::sync::atomic::Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ExecCommandSession {
|
||||
|
||||
@@ -6,6 +6,7 @@ mod session_manager;
|
||||
|
||||
pub use exec_command_params::ExecCommandParams;
|
||||
pub use exec_command_params::WriteStdinParams;
|
||||
pub(crate) use exec_command_session::ExecCommandSession;
|
||||
pub use responses_api::EXEC_COMMAND_TOOL_NAME;
|
||||
pub use responses_api::WRITE_STDIN_TOOL_NAME;
|
||||
pub use responses_api::create_exec_command_tool_for_responses_api;
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::io::ErrorKind;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU32;
|
||||
|
||||
use portable_pty::CommandBuilder;
|
||||
@@ -19,6 +20,7 @@ use crate::exec_command::exec_command_params::ExecCommandParams;
|
||||
use crate::exec_command::exec_command_params::WriteStdinParams;
|
||||
use crate::exec_command::exec_command_session::ExecCommandSession;
|
||||
use crate::exec_command::session_id::SessionId;
|
||||
use crate::truncate::truncate_middle;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -327,11 +329,14 @@ async fn create_exec_command_session(
|
||||
|
||||
// Keep the child alive until it exits, then signal exit code.
|
||||
let (exit_tx, exit_rx) = oneshot::channel::<i32>();
|
||||
let exit_status = Arc::new(AtomicBool::new(false));
|
||||
let wait_exit_status = exit_status.clone();
|
||||
let wait_handle = tokio::task::spawn_blocking(move || {
|
||||
let code = match child.wait() {
|
||||
Ok(status) => status.exit_code() as i32,
|
||||
Err(_) => -1,
|
||||
};
|
||||
wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||
let _ = exit_tx.send(code);
|
||||
});
|
||||
|
||||
@@ -343,116 +348,11 @@ async fn create_exec_command_session(
|
||||
reader_handle,
|
||||
writer_handle,
|
||||
wait_handle,
|
||||
exit_status,
|
||||
);
|
||||
Ok((session, exit_rx))
|
||||
}
|
||||
|
||||
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
|
||||
/// preserving the beginning and the end. Returns the possibly truncated
|
||||
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
|
||||
/// if truncation occurred; otherwise returns the original string and `None`.
|
||||
fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
|
||||
// No truncation needed
|
||||
if s.len() <= max_bytes {
|
||||
return (s.to_string(), None);
|
||||
}
|
||||
let est_tokens = (s.len() as u64).div_ceil(4);
|
||||
if max_bytes == 0 {
|
||||
// Cannot keep any content; still return a full marker (never truncated).
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
// Helper to truncate a string to a given byte length on a char boundary.
|
||||
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
|
||||
if input.len() <= max_len {
|
||||
return input;
|
||||
}
|
||||
let mut end = max_len;
|
||||
while end > 0 && !input.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&input[..end]
|
||||
}
|
||||
|
||||
// Given a left/right budget, prefer newline boundaries; otherwise fall back
|
||||
// to UTF-8 char boundaries.
|
||||
fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
|
||||
if let Some(head) = s.get(..left_budget)
|
||||
&& let Some(i) = head.rfind('\n')
|
||||
{
|
||||
return i + 1; // keep the newline so suffix starts on a fresh line
|
||||
}
|
||||
truncate_on_boundary(s, left_budget).len()
|
||||
}
|
||||
|
||||
fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
|
||||
let start_tail = s.len().saturating_sub(right_budget);
|
||||
if let Some(tail) = s.get(start_tail..)
|
||||
&& let Some(i) = tail.find('\n')
|
||||
{
|
||||
return start_tail + i + 1; // start after newline
|
||||
}
|
||||
// Fall back to a char boundary at or after start_tail.
|
||||
let mut idx = start_tail.min(s.len());
|
||||
while idx < s.len() && !s.is_char_boundary(idx) {
|
||||
idx += 1;
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
// Refine marker length and budgets until stable. Marker is never truncated.
|
||||
let mut guess_tokens = est_tokens; // worst-case: everything truncated
|
||||
for _ in 0..4 {
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
// No room for any content within the cap; return a full, untruncated marker
|
||||
// that reflects the entire truncated content.
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let mut suffix_start = pick_suffix_start(s, right_budget);
|
||||
if suffix_start < prefix_end {
|
||||
suffix_start = prefix_end;
|
||||
}
|
||||
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
|
||||
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
|
||||
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
|
||||
if new_tokens == guess_tokens {
|
||||
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
// Place marker on its own line for symmetry when we keep line boundaries.
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
return (out, Some(est_tokens));
|
||||
}
|
||||
guess_tokens = new_tokens;
|
||||
}
|
||||
|
||||
// Fallback: use last guess to build output.
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let suffix_start = pick_suffix_start(s, right_budget);
|
||||
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
(out, Some(est_tokens))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -616,50 +516,4 @@ Output:
|
||||
abc"#;
|
||||
assert_eq!(expected, text);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_no_newlines_fallback() {
|
||||
// A long string with no newlines that exceeds the cap.
|
||||
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
||||
let max_bytes = 16; // force truncation
|
||||
let (out, original) = truncate_middle(s, max_bytes);
|
||||
// For very small caps, we return the full, untruncated marker,
|
||||
// even if it exceeds the cap.
|
||||
assert_eq!(out, "…16 tokens truncated…");
|
||||
// Original string length is 62 bytes => ceil(62/4) = 16 tokens.
|
||||
assert_eq!(original, Some(16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_prefers_newline_boundaries() {
|
||||
// Build a multi-line string of 20 numbered lines (each "NNN\n").
|
||||
let mut s = String::new();
|
||||
for i in 1..=20 {
|
||||
s.push_str(&format!("{i:03}\n"));
|
||||
}
|
||||
// Total length: 20 lines * 4 bytes per line = 80 bytes.
|
||||
assert_eq!(s.len(), 80);
|
||||
|
||||
// Choose a cap that forces truncation while leaving room for
|
||||
// a few lines on each side after accounting for the marker.
|
||||
let max_bytes = 64;
|
||||
// Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
|
||||
assert_eq!(
|
||||
truncate_middle(&s, max_bytes),
|
||||
(
|
||||
r#"001
|
||||
002
|
||||
003
|
||||
004
|
||||
…12 tokens truncated…
|
||||
017
|
||||
018
|
||||
019
|
||||
020
|
||||
"#
|
||||
.to_string(),
|
||||
Some(20)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_protocol::mcp_protocol::GitSha;
|
||||
use codex_protocol::protocol::GitInfo;
|
||||
use futures::future::join_all;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -43,19 +44,6 @@ pub fn get_git_repo_root(base_dir: &Path) -> Option<PathBuf> {
|
||||
/// Timeout for git commands to prevent freezing on large repositories
|
||||
const GIT_COMMAND_TIMEOUT: TokioDuration = TokioDuration::from_secs(5);
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct GitInfo {
|
||||
/// Current commit hash (SHA)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub commit_hash: Option<String>,
|
||||
/// Current branch name
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub branch: Option<String>,
|
||||
/// Repository URL (if available from remote)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub repository_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct GitDiffToRemote {
|
||||
pub sha: GitSha,
|
||||
@@ -814,7 +802,7 @@ mod tests {
|
||||
async fn resolve_root_git_project_for_trust_regular_repo_returns_repo_root() {
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let repo_path = create_test_git_repo(&temp_dir).await;
|
||||
let expected = std::fs::canonicalize(&repo_path).unwrap().to_path_buf();
|
||||
let expected = std::fs::canonicalize(&repo_path).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
resolve_root_git_project_for_trust(&repo_path),
|
||||
@@ -822,10 +810,7 @@ mod tests {
|
||||
);
|
||||
let nested = repo_path.join("sub/dir");
|
||||
std::fs::create_dir_all(&nested).unwrap();
|
||||
assert_eq!(
|
||||
resolve_root_git_project_for_trust(&nested),
|
||||
Some(expected.clone())
|
||||
);
|
||||
assert_eq!(resolve_root_git_project_for_trust(&nested), Some(expected));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
68
codex-rs/core/src/internal_storage.rs
Normal file
68
codex-rs/core/src/internal_storage.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub(crate) const INTERNAL_STORAGE_FILE: &str = "internal_storage.json";
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct InternalStorage {
|
||||
#[serde(skip)]
|
||||
storage_path: PathBuf,
|
||||
#[serde(default)]
|
||||
pub gpt_5_high_model_prompt_seen: bool,
|
||||
}
|
||||
|
||||
// TODO(jif) generalise all the file writers and build proper async channel inserters.
|
||||
impl InternalStorage {
|
||||
pub fn load(codex_home: &Path) -> Self {
|
||||
let storage_path = codex_home.join(INTERNAL_STORAGE_FILE);
|
||||
|
||||
match std::fs::read_to_string(&storage_path) {
|
||||
Ok(serialized) => match serde_json::from_str::<Self>(&serialized) {
|
||||
Ok(mut storage) => {
|
||||
storage.storage_path = storage_path;
|
||||
storage
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!("failed to parse internal storage: {error:?}");
|
||||
Self::empty(storage_path)
|
||||
}
|
||||
},
|
||||
Err(error) => {
|
||||
tracing::warn!("failed to read internal storage: {error:?}");
|
||||
Self::empty(storage_path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn empty(storage_path: PathBuf) -> Self {
|
||||
Self {
|
||||
storage_path,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn persist(&self) -> anyhow::Result<()> {
|
||||
let serialized = serde_json::to_string_pretty(self)?;
|
||||
|
||||
if let Some(parent) = self.storage_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await.with_context(|| {
|
||||
format!(
|
||||
"failed to create internal storage directory at {}",
|
||||
parent.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
tokio::fs::write(&self.storage_path, serialized)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to persist internal storage at {}",
|
||||
self.storage_path.display()
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
mod apply_patch;
|
||||
pub mod auth;
|
||||
mod bash;
|
||||
pub mod bash;
|
||||
mod chat_completions;
|
||||
mod client;
|
||||
mod client_common;
|
||||
@@ -16,6 +16,7 @@ mod codex_conversation;
|
||||
pub mod token_data;
|
||||
pub use codex_conversation::CodexConversation;
|
||||
pub mod config;
|
||||
pub mod config_edit;
|
||||
pub mod config_profile;
|
||||
pub mod config_types;
|
||||
mod conversation_history;
|
||||
@@ -27,6 +28,7 @@ mod exec_command;
|
||||
pub mod exec_env;
|
||||
mod flags;
|
||||
pub mod git_info;
|
||||
pub mod internal_storage;
|
||||
mod is_safe_command;
|
||||
pub mod landlock;
|
||||
mod mcp_connection_manager;
|
||||
@@ -34,12 +36,17 @@ mod mcp_tool_call;
|
||||
mod message_history;
|
||||
mod model_provider_info;
|
||||
pub mod parse_command;
|
||||
mod truncate;
|
||||
mod unified_exec;
|
||||
mod user_instructions;
|
||||
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
||||
pub use model_provider_info::ModelProviderInfo;
|
||||
pub use model_provider_info::WireApi;
|
||||
pub use model_provider_info::built_in_model_providers;
|
||||
pub use model_provider_info::create_oss_provider_with_base_url;
|
||||
mod conversation_manager;
|
||||
mod event_mapping;
|
||||
pub use codex_protocol::protocol::InitialHistory;
|
||||
pub use conversation_manager::ConversationManager;
|
||||
pub use conversation_manager::NewConversation;
|
||||
// Re-export common auth types for workspace consumers
|
||||
@@ -59,9 +66,16 @@ pub mod spawn;
|
||||
pub mod terminal;
|
||||
mod tool_apply_patch;
|
||||
pub mod turn_diff_tracker;
|
||||
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
|
||||
pub use rollout::RolloutRecorder;
|
||||
pub use rollout::SESSIONS_SUBDIR;
|
||||
pub use rollout::SessionMeta;
|
||||
pub use rollout::list::ConversationItem;
|
||||
pub use rollout::list::ConversationsPage;
|
||||
pub use rollout::list::Cursor;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
|
||||
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
|
||||
pub use safety::get_platform_sandbox;
|
||||
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::ffi::OsString;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -36,8 +37,8 @@ use crate::config_types::McpServerConfig;
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__";
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
/// Timeout for the `tools/list` request.
|
||||
const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
/// Default timeout for initializing MCP server & initially listing tools.
|
||||
const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
/// Map that holds a startup error for every MCP server that could **not** be
|
||||
/// spawned successfully.
|
||||
@@ -81,6 +82,11 @@ struct ToolInfo {
|
||||
tool: Tool,
|
||||
}
|
||||
|
||||
struct ManagedClient {
|
||||
client: Arc<McpClient>,
|
||||
startup_timeout: Duration,
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`McpClient`] instances.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct McpConnectionManager {
|
||||
@@ -88,7 +94,7 @@ pub(crate) struct McpConnectionManager {
|
||||
///
|
||||
/// The server name originates from the keys of the `mcp_servers` map in
|
||||
/// the user configuration.
|
||||
clients: HashMap<String, std::sync::Arc<McpClient>>,
|
||||
clients: HashMap<String, ManagedClient>,
|
||||
|
||||
/// Fully qualified tool name -> tool instance.
|
||||
tools: HashMap<String, ToolInfo>,
|
||||
@@ -126,8 +132,15 @@ impl McpConnectionManager {
|
||||
continue;
|
||||
}
|
||||
|
||||
let startup_timeout = cfg
|
||||
.startup_timeout_ms
|
||||
.map(Duration::from_millis)
|
||||
.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
|
||||
|
||||
join_set.spawn(async move {
|
||||
let McpServerConfig { command, args, env } = cfg;
|
||||
let McpServerConfig {
|
||||
command, args, env, ..
|
||||
} = cfg;
|
||||
let client_res = McpClient::new_stdio_client(
|
||||
command.into(),
|
||||
args.into_iter().map(OsString::from).collect(),
|
||||
@@ -150,16 +163,23 @@ impl McpConnectionManager {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".into()),
|
||||
// This field is used by Codex when it is an MCP
|
||||
// server: it should not be used when Codex is
|
||||
// an MCP client.
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
let initialize_notification_params = None;
|
||||
let timeout = Some(Duration::from_secs(10));
|
||||
match client
|
||||
.initialize(params, initialize_notification_params, timeout)
|
||||
.initialize(
|
||||
params,
|
||||
initialize_notification_params,
|
||||
Some(startup_timeout),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_response) => (server_name, Ok(client)),
|
||||
Ok(_response) => (server_name, Ok((client, startup_timeout))),
|
||||
Err(e) => (server_name, Err(e)),
|
||||
}
|
||||
}
|
||||
@@ -168,15 +188,26 @@ impl McpConnectionManager {
|
||||
});
|
||||
}
|
||||
|
||||
let mut clients: HashMap<String, std::sync::Arc<McpClient>> =
|
||||
HashMap::with_capacity(join_set.len());
|
||||
let mut clients: HashMap<String, ManagedClient> = HashMap::with_capacity(join_set.len());
|
||||
|
||||
while let Some(res) = join_set.join_next().await {
|
||||
let (server_name, client_res) = res?; // JoinError propagation
|
||||
let (server_name, client_res) = match res {
|
||||
Ok((server_name, client_res)) => (server_name, client_res),
|
||||
Err(e) => {
|
||||
warn!("Task panic when starting MCP server: {e:#}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match client_res {
|
||||
Ok(client) => {
|
||||
clients.insert(server_name, std::sync::Arc::new(client));
|
||||
Ok((client, startup_timeout)) => {
|
||||
clients.insert(
|
||||
server_name,
|
||||
ManagedClient {
|
||||
client: Arc::new(client),
|
||||
startup_timeout,
|
||||
},
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
errors.insert(server_name, e);
|
||||
@@ -184,7 +215,13 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
let all_tools = list_all_tools(&clients).await?;
|
||||
let all_tools = match list_all_tools(&clients).await {
|
||||
Ok(tools) => tools,
|
||||
Err(e) => {
|
||||
warn!("Failed to list tools from some MCP servers: {e:#}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
let tools = qualify_tools(all_tools);
|
||||
|
||||
@@ -212,6 +249,7 @@ impl McpConnectionManager {
|
||||
.clients
|
||||
.get(server)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
|
||||
.client
|
||||
.clone();
|
||||
|
||||
client
|
||||
@@ -229,21 +267,18 @@ impl McpConnectionManager {
|
||||
|
||||
/// Query every server for its available tools and return a single map that
|
||||
/// contains **all** tools. Each key is the fully-qualified name for the tool.
|
||||
async fn list_all_tools(
|
||||
clients: &HashMap<String, std::sync::Arc<McpClient>>,
|
||||
) -> Result<Vec<ToolInfo>> {
|
||||
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn one task per server so we can query them concurrently. This
|
||||
// keeps the overall latency roughly at the slowest server instead of
|
||||
// the cumulative latency.
|
||||
for (server_name, client) in clients {
|
||||
for (server_name, managed_client) in clients {
|
||||
let server_name_cloned = server_name.clone();
|
||||
let client_clone = client.clone();
|
||||
let client_clone = managed_client.client.clone();
|
||||
let startup_timeout = managed_client.startup_timeout;
|
||||
join_set.spawn(async move {
|
||||
let res = client_clone
|
||||
.list_tools(None, Some(LIST_TOOLS_TIMEOUT))
|
||||
.await;
|
||||
let res = client_clone.list_tools(None, Some(startup_timeout)).await;
|
||||
(server_name_cloned, res)
|
||||
});
|
||||
}
|
||||
@@ -251,8 +286,19 @@ async fn list_all_tools(
|
||||
let mut aggregated: Vec<ToolInfo> = Vec::with_capacity(join_set.len());
|
||||
|
||||
while let Some(join_res) = join_set.join_next().await {
|
||||
let (server_name, list_result) = join_res?;
|
||||
let list_result = list_result?;
|
||||
let (server_name, list_result) = if let Ok(result) = join_res {
|
||||
result
|
||||
} else {
|
||||
warn!("Task panic when listing tools for MCP server: {join_res:#?}");
|
||||
continue;
|
||||
};
|
||||
|
||||
let list_result = if let Ok(result) = list_result {
|
||||
result
|
||||
} else {
|
||||
warn!("Failed to list tools for MCP server '{server_name}': {list_result:#?}");
|
||||
continue;
|
||||
};
|
||||
|
||||
for tool in list_result.tools {
|
||||
let tool_info = ToolInfo {
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
//! JSON-Lines tooling. Each record has the following schema:
|
||||
//!
|
||||
//! ````text
|
||||
//! {"session_id":"<uuid>","ts":<unix_seconds>,"text":"<message>"}
|
||||
//! {"conversation_id":"<uuid>","ts":<unix_seconds>,"text":"<message>"}
|
||||
//! ````
|
||||
//!
|
||||
//! To minimise the chance of interleaved writes when multiple processes are
|
||||
@@ -22,14 +22,15 @@ use std::path::PathBuf;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
use std::time::Duration;
|
||||
use tokio::fs;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::config_types::HistoryPersistence;
|
||||
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
#[cfg(unix)]
|
||||
@@ -54,10 +55,14 @@ fn history_filepath(config: &Config) -> PathBuf {
|
||||
path
|
||||
}
|
||||
|
||||
/// Append a `text` entry associated with `session_id` to the history file. Uses
|
||||
/// Append a `text` entry associated with `conversation_id` to the history file. Uses
|
||||
/// advisory file locking to ensure that concurrent writes do not interleave,
|
||||
/// which entails a small amount of blocking I/O internally.
|
||||
pub(crate) async fn append_entry(text: &str, session_id: &Uuid, config: &Config) -> Result<()> {
|
||||
pub(crate) async fn append_entry(
|
||||
text: &str,
|
||||
conversation_id: &ConversationId,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
match config.history.persistence {
|
||||
HistoryPersistence::SaveAll => {
|
||||
// Save everything: proceed.
|
||||
@@ -84,7 +89,7 @@ pub(crate) async fn append_entry(text: &str, session_id: &Uuid, config: &Config)
|
||||
|
||||
// Construct the JSON line first so we can write it in a single syscall.
|
||||
let entry = HistoryEntry {
|
||||
session_id: session_id.to_string(),
|
||||
session_id: conversation_id.to_string(),
|
||||
ts,
|
||||
text: text.to_string(),
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::config_types::ReasoningSummaryFormat;
|
||||
use crate::tool_apply_patch::ApplyPatchToolType;
|
||||
|
||||
/// A model family is a group of models that share certain characteristics.
|
||||
@@ -20,6 +21,9 @@ pub struct ModelFamily {
|
||||
// `summary` is optional).
|
||||
pub supports_reasoning_summaries: bool,
|
||||
|
||||
// Define if we need a special handling of reasoning summary
|
||||
pub reasoning_summary_format: ReasoningSummaryFormat,
|
||||
|
||||
// This should be set to true when the model expects a tool named
|
||||
// "local_shell" to be provided. Its contract must be understood natively by
|
||||
// the model such that its description can be omitted.
|
||||
@@ -41,6 +45,7 @@ macro_rules! model_family {
|
||||
family: $family.to_string(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries: false,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
uses_local_shell_tool: false,
|
||||
apply_patch_tool_type: None,
|
||||
};
|
||||
@@ -61,6 +66,7 @@ macro_rules! simple_model_family {
|
||||
family: $family.to_string(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries: false,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
uses_local_shell_tool: false,
|
||||
apply_patch_tool_type: None,
|
||||
})
|
||||
@@ -90,13 +96,14 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
model_family!(
|
||||
slug, slug,
|
||||
supports_reasoning_summaries: true,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
)
|
||||
} else if slug.starts_with("gpt-4.1") {
|
||||
model_family!(
|
||||
slug, "gpt-4.1",
|
||||
needs_special_apply_patch_instructions: true,
|
||||
)
|
||||
} else if slug.starts_with("gpt-oss") {
|
||||
} else if slug.starts_with("gpt-oss") || slug.starts_with("openai/gpt-oss") {
|
||||
model_family!(slug, "gpt-oss", apply_patch_tool_type: Some(ApplyPatchToolType::Function))
|
||||
} else if slug.starts_with("gpt-4o") {
|
||||
simple_model_family!(slug, "gpt-4o")
|
||||
|
||||
@@ -80,7 +80,10 @@ pub struct ModelProviderInfo {
|
||||
/// the connection as lost.
|
||||
pub stream_idle_timeout_ms: Option<u64>,
|
||||
|
||||
/// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
|
||||
/// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
|
||||
/// user is presented with login screen on first run, and login preference and token/key
|
||||
/// are stored in auth.json. If false (which is the default), login screen is skipped,
|
||||
/// and API key (if needed) comes from the "env_key" environment variable.
|
||||
#[serde(default)]
|
||||
pub requires_openai_auth: bool,
|
||||
}
|
||||
|
||||
@@ -78,13 +78,13 @@ pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
|
||||
max_output_tokens: 4_096,
|
||||
}),
|
||||
|
||||
"gpt-5" => Some(ModelInfo {
|
||||
context_window: 400_000,
|
||||
_ if slug.starts_with("gpt-5") => Some(ModelInfo {
|
||||
context_window: 272_000,
|
||||
max_output_tokens: 128_000,
|
||||
}),
|
||||
|
||||
_ if slug.starts_with("codex-") => Some(ModelInfo {
|
||||
context_window: 400_000,
|
||||
context_window: 272_000,
|
||||
max_output_tokens: 128_000,
|
||||
}),
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ use std::collections::HashMap;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::plan_tool::PLAN_TOOL;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::tool_apply_patch::ApplyPatchToolType;
|
||||
use crate::tool_apply_patch::create_apply_patch_freeform_tool;
|
||||
use crate::tool_apply_patch::create_apply_patch_json_tool;
|
||||
@@ -58,7 +57,7 @@ pub(crate) enum OpenAiTool {
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ConfigShellToolType {
|
||||
DefaultShell,
|
||||
ShellWithRequest { sandbox_policy: SandboxPolicy },
|
||||
ShellWithRequest,
|
||||
LocalShell,
|
||||
StreamableShell,
|
||||
}
|
||||
@@ -70,17 +69,18 @@ pub(crate) struct ToolsConfig {
|
||||
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
|
||||
pub web_search_request: bool,
|
||||
pub include_view_image_tool: bool,
|
||||
pub experimental_unified_exec_tool: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct ToolsConfigParams<'a> {
|
||||
pub(crate) model_family: &'a ModelFamily,
|
||||
pub(crate) approval_policy: AskForApproval,
|
||||
pub(crate) sandbox_policy: SandboxPolicy,
|
||||
pub(crate) include_plan_tool: bool,
|
||||
pub(crate) include_apply_patch_tool: bool,
|
||||
pub(crate) include_web_search_request: bool,
|
||||
pub(crate) use_streamable_shell_tool: bool,
|
||||
pub(crate) include_view_image_tool: bool,
|
||||
pub(crate) experimental_unified_exec_tool: bool,
|
||||
}
|
||||
|
||||
impl ToolsConfig {
|
||||
@@ -88,12 +88,12 @@ impl ToolsConfig {
|
||||
let ToolsConfigParams {
|
||||
model_family,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
include_plan_tool,
|
||||
include_apply_patch_tool,
|
||||
include_web_search_request,
|
||||
use_streamable_shell_tool,
|
||||
include_view_image_tool,
|
||||
experimental_unified_exec_tool,
|
||||
} = params;
|
||||
let mut shell_type = if *use_streamable_shell_tool {
|
||||
ConfigShellToolType::StreamableShell
|
||||
@@ -103,9 +103,7 @@ impl ToolsConfig {
|
||||
ConfigShellToolType::DefaultShell
|
||||
};
|
||||
if matches!(approval_policy, AskForApproval::OnRequest) && !use_streamable_shell_tool {
|
||||
shell_type = ConfigShellToolType::ShellWithRequest {
|
||||
sandbox_policy: sandbox_policy.clone(),
|
||||
}
|
||||
shell_type = ConfigShellToolType::ShellWithRequest;
|
||||
}
|
||||
|
||||
let apply_patch_tool_type = match model_family.apply_patch_tool_type {
|
||||
@@ -126,6 +124,7 @@ impl ToolsConfig {
|
||||
apply_patch_tool_type,
|
||||
web_search_request: *include_web_search_request,
|
||||
include_view_image_tool: *include_view_image_tool,
|
||||
experimental_unified_exec_tool: *experimental_unified_exec_tool,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,7 +199,56 @@ fn create_shell_tool() -> OpenAiTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
|
||||
fn create_unified_exec_tool() -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"input".to_string(),
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String { description: None }),
|
||||
description: Some(
|
||||
"When no session_id is provided, treat the array as the command and arguments \
|
||||
to launch. When session_id is set, concatenate the strings (in order) and write \
|
||||
them to the session's stdin."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"session_id".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Identifier for an existing interactive session. If omitted, a new command \
|
||||
is spawned."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"timeout_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Maximum time in milliseconds to wait for output after writing the input."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "unified_exec".to_string(),
|
||||
description:
|
||||
"Runs a command in a PTY. Provide a session_id to reuse an existing interactive session.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["input".to_string()]),
|
||||
additional_properties: Some(false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
const SHELL_TOOL_DESCRIPTION: &str = r#"Runs a shell command and returns its output"#;
|
||||
|
||||
fn create_shell_tool_for_request() -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"command".to_string(),
|
||||
@@ -212,82 +260,29 @@ fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
|
||||
properties.insert(
|
||||
"workdir".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("The working directory to execute the command in".to_string()),
|
||||
description: Some("Working directory to execute the command in.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"timeout_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("The timeout for the command in milliseconds".to_string()),
|
||||
description: Some("Timeout for the command in milliseconds.".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
if matches!(sandbox_policy, SandboxPolicy::WorkspaceWrite { .. }) {
|
||||
properties.insert(
|
||||
properties.insert(
|
||||
"with_escalated_permissions".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some("Whether to request escalated permissions. Set to true if command needs to be run without sandbox restrictions".to_string()),
|
||||
description: Some("Request escalated permissions, only for when a command would otherwise be blocked by the sandbox.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
properties.insert(
|
||||
"justification".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Only set if with_escalated_permissions is true. 1-sentence explanation of why we want to run this command.".to_string()),
|
||||
description: Some("Required if and only if with_escalated_permissions == true. One sentence explaining why escalation is needed (e.g., write outside CWD, network fetch, git commit).".to_string()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let description = match sandbox_policy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
network_access,
|
||||
writable_roots,
|
||||
..
|
||||
} => {
|
||||
format!(
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands will require escalated privileges:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Writing files other than those in the writable roots
|
||||
- writable roots:
|
||||
{}{}
|
||||
- Examples of commands that require escalated privileges:
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#,
|
||||
writable_roots.iter().map(|wr| format!(" - {}", wr.to_string_lossy())).collect::<Vec<String>>().join("\n"),
|
||||
if !network_access {
|
||||
"\n - Commands that require network access\n"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)
|
||||
}
|
||||
SandboxPolicy::DangerFullAccess => {
|
||||
"Runs a shell command and returns its output.".to_string()
|
||||
}
|
||||
SandboxPolicy::ReadOnly => {
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands (including apply_patch) will require escalated permissions:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Writing files
|
||||
- Applying patches
|
||||
- Examples of commands that require escalated privileges:
|
||||
- apply_patch
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#.to_string()
|
||||
}
|
||||
};
|
||||
let description = SHELL_TOOL_DESCRIPTION.to_string();
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "shell".to_string(),
|
||||
@@ -300,7 +295,6 @@ The shell tool is used to execute shell commands.
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_view_image_tool() -> OpenAiTool {
|
||||
// Support only local filesystem path.
|
||||
let mut properties = BTreeMap::new();
|
||||
@@ -534,23 +528,27 @@ pub(crate) fn get_openai_tools(
|
||||
) -> Vec<OpenAiTool> {
|
||||
let mut tools: Vec<OpenAiTool> = Vec::new();
|
||||
|
||||
match &config.shell_type {
|
||||
ConfigShellToolType::DefaultShell => {
|
||||
tools.push(create_shell_tool());
|
||||
}
|
||||
ConfigShellToolType::ShellWithRequest { sandbox_policy } => {
|
||||
tools.push(create_shell_tool_for_sandbox(sandbox_policy));
|
||||
}
|
||||
ConfigShellToolType::LocalShell => {
|
||||
tools.push(OpenAiTool::LocalShell {});
|
||||
}
|
||||
ConfigShellToolType::StreamableShell => {
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_exec_command_tool_for_responses_api(),
|
||||
));
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_write_stdin_tool_for_responses_api(),
|
||||
));
|
||||
if config.experimental_unified_exec_tool {
|
||||
tools.push(create_unified_exec_tool());
|
||||
} else {
|
||||
match &config.shell_type {
|
||||
ConfigShellToolType::DefaultShell => {
|
||||
tools.push(create_shell_tool());
|
||||
}
|
||||
ConfigShellToolType::ShellWithRequest => {
|
||||
tools.push(create_shell_tool_for_request());
|
||||
}
|
||||
ConfigShellToolType::LocalShell => {
|
||||
tools.push(OpenAiTool::LocalShell {});
|
||||
}
|
||||
ConfigShellToolType::StreamableShell => {
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_exec_command_tool_for_responses_api(),
|
||||
));
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_write_stdin_tool_for_responses_api(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -577,10 +575,8 @@ pub(crate) fn get_openai_tools(
|
||||
if config.include_view_image_tool {
|
||||
tools.push(create_view_image_tool());
|
||||
}
|
||||
|
||||
if let Some(mcp_tools) = mcp_tools {
|
||||
// Ensure deterministic ordering to maximize prompt cache hits.
|
||||
// HashMap iteration order is non-deterministic, so sort by fully-qualified tool name.
|
||||
let mut entries: Vec<(String, mcp_types::Tool)> = mcp_tools.into_iter().collect();
|
||||
entries.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
@@ -636,18 +632,18 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: true,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["local_shell", "update_plan", "web_search", "view_image"],
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
);
|
||||
}
|
||||
|
||||
@@ -657,18 +653,18 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: true,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["shell", "update_plan", "web_search", "view_image"],
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
);
|
||||
}
|
||||
|
||||
@@ -678,12 +674,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let tools = get_openai_tools(
|
||||
&config,
|
||||
@@ -726,7 +722,7 @@ mod tests {
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&[
|
||||
"shell",
|
||||
"unified_exec",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"test_server/do_something_cool",
|
||||
@@ -783,12 +779,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
// Intentionally construct a map with keys that would sort alphabetically.
|
||||
@@ -841,11 +837,11 @@ mod tests {
|
||||
]);
|
||||
|
||||
let tools = get_openai_tools(&config, Some(tools_map));
|
||||
// Expect shell first, followed by MCP tools sorted by fully-qualified name.
|
||||
// Expect unified_exec first, followed by MCP tools sorted by fully-qualified name.
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&[
|
||||
"shell",
|
||||
"unified_exec",
|
||||
"view_image",
|
||||
"test_server/cool",
|
||||
"test_server/do",
|
||||
@@ -860,12 +856,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -893,7 +889,7 @@ mod tests {
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["shell", "web_search", "view_image", "dash/search"],
|
||||
&["unified_exec", "web_search", "view_image", "dash/search"],
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
@@ -922,12 +918,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -953,7 +949,7 @@ mod tests {
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["shell", "web_search", "view_image", "dash/paginate"],
|
||||
&["unified_exec", "web_search", "view_image", "dash/paginate"],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
@@ -979,12 +975,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -1008,7 +1004,10 @@ mod tests {
|
||||
)])),
|
||||
);
|
||||
|
||||
assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/tags"]);
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["unified_exec", "web_search", "view_image", "dash/tags"],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
@@ -1036,12 +1035,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -1065,7 +1064,10 @@ mod tests {
|
||||
)])),
|
||||
);
|
||||
|
||||
assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/value"]);
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["unified_exec", "web_search", "view_image", "dash/value"],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
@@ -1086,13 +1088,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_shell_tool_for_sandbox_workspace_write() {
|
||||
let sandbox_policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec!["workspace".into()],
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: false,
|
||||
exclude_slash_tmp: false,
|
||||
};
|
||||
let tool = super::create_shell_tool_for_sandbox(&sandbox_policy);
|
||||
let tool = super::create_shell_tool_for_request();
|
||||
let OpenAiTool::Function(ResponsesApiTool {
|
||||
description, name, ..
|
||||
}) = &tool
|
||||
@@ -1101,29 +1097,13 @@ mod tests {
|
||||
};
|
||||
assert_eq!(name, "shell");
|
||||
|
||||
let expected = r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands will require escalated privileges:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Writing files other than those in the writable roots
|
||||
- writable roots:
|
||||
- workspace
|
||||
- Commands that require network access
|
||||
|
||||
- Examples of commands that require escalated privileges:
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#;
|
||||
let expected = super::SHELL_TOOL_DESCRIPTION;
|
||||
assert_eq!(description, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_tool_for_sandbox_readonly() {
|
||||
let tool = super::create_shell_tool_for_sandbox(&SandboxPolicy::ReadOnly);
|
||||
let tool = super::create_shell_tool_for_request();
|
||||
let OpenAiTool::Function(ResponsesApiTool {
|
||||
description, name, ..
|
||||
}) = &tool
|
||||
@@ -1132,27 +1112,13 @@ The shell tool is used to execute shell commands.
|
||||
};
|
||||
assert_eq!(name, "shell");
|
||||
|
||||
let expected = r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands (including apply_patch) will require escalated permissions:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Writing files
|
||||
- Applying patches
|
||||
- Examples of commands that require escalated privileges:
|
||||
- apply_patch
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#;
|
||||
let expected = super::SHELL_TOOL_DESCRIPTION;
|
||||
assert_eq!(description, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_tool_for_sandbox_danger_full_access() {
|
||||
let tool = super::create_shell_tool_for_sandbox(&SandboxPolicy::DangerFullAccess);
|
||||
let tool = super::create_shell_tool_for_request();
|
||||
let OpenAiTool::Function(ResponsesApiTool {
|
||||
description, name, ..
|
||||
}) = &tool
|
||||
@@ -1161,6 +1127,7 @@ The shell tool is used to execute shell commands.
|
||||
};
|
||||
assert_eq!(name, "shell");
|
||||
|
||||
assert_eq!(description, "Runs a shell command and returns its output.");
|
||||
let expected = super::SHELL_TOOL_DESCRIPTION;
|
||||
assert_eq!(description, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -868,7 +868,7 @@ pub fn parse_command_impl(command: &[String]) -> Vec<ParsedCommand> {
|
||||
let parts = if contains_connectors(&normalized) {
|
||||
split_on_connectors(&normalized)
|
||||
} else {
|
||||
vec![normalized.clone()]
|
||||
vec![normalized]
|
||||
};
|
||||
|
||||
// Preserve left-to-right execution order for all commands, including bash -c/-lc
|
||||
@@ -1201,10 +1201,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read {
|
||||
cmd: cmd.clone(),
|
||||
name,
|
||||
}
|
||||
ParsedCommand::Read { cmd, name }
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read {
|
||||
@@ -1215,10 +1212,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
}
|
||||
ParsedCommand::ListFiles { path, cmd, .. } => {
|
||||
if had_connectors {
|
||||
ParsedCommand::ListFiles {
|
||||
cmd: cmd.clone(),
|
||||
path,
|
||||
}
|
||||
ParsedCommand::ListFiles { cmd, path }
|
||||
} else {
|
||||
ParsedCommand::ListFiles {
|
||||
cmd: shlex_join(&script_tokens),
|
||||
@@ -1230,11 +1224,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
query, path, cmd, ..
|
||||
} => {
|
||||
if had_connectors {
|
||||
ParsedCommand::Search {
|
||||
cmd: cmd.clone(),
|
||||
query,
|
||||
path,
|
||||
}
|
||||
ParsedCommand::Search { cmd, query, path }
|
||||
} else {
|
||||
ParsedCommand::Search {
|
||||
cmd: shlex_join(&script_tokens),
|
||||
|
||||
@@ -26,7 +26,7 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";
|
||||
|
||||
/// Combines `Config::instructions` and `AGENTS.md` (if present) into a single
|
||||
/// string of instructions.
|
||||
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
|
||||
pub async fn get_user_instructions(config: &Config) -> Option<String> {
|
||||
match read_project_docs(config).await {
|
||||
Ok(Some(project_doc)) => match &config.user_instructions {
|
||||
Some(original_instructions) => Some(format!(
|
||||
@@ -115,7 +115,7 @@ pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBu
|
||||
// Build chain from cwd upwards and detect git root.
|
||||
let mut chain: Vec<PathBuf> = vec![dir.clone()];
|
||||
let mut git_root: Option<PathBuf> = None;
|
||||
let mut cursor = dir.clone();
|
||||
let mut cursor = dir;
|
||||
while let Some(parent) = cursor.parent() {
|
||||
let git_marker = cursor.join(".git");
|
||||
let git_exists = match std::fs::metadata(&git_marker) {
|
||||
|
||||
@@ -10,6 +10,9 @@ use time::macros::format_description;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::SESSIONS_SUBDIR;
|
||||
use crate::protocol::EventMsg;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
|
||||
/// Returned page of conversation summaries.
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
@@ -34,7 +37,8 @@ pub struct ConversationItem {
|
||||
}
|
||||
|
||||
/// Hard cap to bound worst‑case work per request.
|
||||
const MAX_SCAN_FILES: usize = 50_000;
|
||||
const MAX_SCAN_FILES: usize = 100;
|
||||
const HEAD_RECORD_LIMIT: usize = 10;
|
||||
|
||||
/// Pagination cursor identifying a file by timestamp and UUID.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -166,8 +170,16 @@ async fn traverse_directories_for_paths(
|
||||
if items.len() == page_size {
|
||||
break 'outer;
|
||||
}
|
||||
let head = read_first_jsonl_records(&path, 5).await.unwrap_or_default();
|
||||
items.push(ConversationItem { path, head });
|
||||
// Read head and simultaneously detect message events within the same
|
||||
// first N JSONL records to avoid a second file read.
|
||||
let (head, saw_session_meta, saw_user_event) =
|
||||
read_head_and_flags(&path, HEAD_RECORD_LIMIT)
|
||||
.await
|
||||
.unwrap_or((Vec::new(), false, false));
|
||||
// Apply filters: must have session meta and at least one user message event
|
||||
if saw_session_meta && saw_user_event {
|
||||
items.push(ConversationItem { path, head });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -270,16 +282,19 @@ fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uui
|
||||
Some((ts, uuid))
|
||||
}
|
||||
|
||||
async fn read_first_jsonl_records(
|
||||
async fn read_head_and_flags(
|
||||
path: &Path,
|
||||
max_records: usize,
|
||||
) -> io::Result<Vec<serde_json::Value>> {
|
||||
) -> io::Result<(Vec<serde_json::Value>, bool, bool)> {
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
|
||||
let file = tokio::fs::File::open(path).await?;
|
||||
let reader = tokio::io::BufReader::new(file);
|
||||
let mut lines = reader.lines();
|
||||
let mut head: Vec<serde_json::Value> = Vec::new();
|
||||
let mut saw_session_meta = false;
|
||||
let mut saw_user_event = false;
|
||||
|
||||
while head.len() < max_records {
|
||||
let line_opt = lines.next_line().await?;
|
||||
let Some(line) = line_opt else { break };
|
||||
@@ -287,9 +302,35 @@ async fn read_first_jsonl_records(
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(trimmed) {
|
||||
head.push(v);
|
||||
|
||||
let parsed: Result<RolloutLine, _> = serde_json::from_str(trimmed);
|
||||
let Ok(rollout_line) = parsed else { continue };
|
||||
|
||||
match rollout_line.item {
|
||||
RolloutItem::SessionMeta(session_meta_line) => {
|
||||
if let Ok(val) = serde_json::to_value(session_meta_line) {
|
||||
head.push(val);
|
||||
saw_session_meta = true;
|
||||
}
|
||||
}
|
||||
RolloutItem::ResponseItem(item) => {
|
||||
if let Ok(val) = serde_json::to_value(item) {
|
||||
head.push(val);
|
||||
}
|
||||
}
|
||||
RolloutItem::TurnContext(_) => {
|
||||
// Not included in `head`; skip.
|
||||
}
|
||||
RolloutItem::Compacted(_) => {
|
||||
// Not included in `head`; skip.
|
||||
}
|
||||
RolloutItem::EventMsg(ev) => {
|
||||
if matches!(ev, EventMsg::UserMessage(_)) {
|
||||
saw_user_event = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(head)
|
||||
|
||||
Ok((head, saw_session_meta, saw_user_event))
|
||||
}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
//! Rollout module: persistence and discovery of session rollout files.
|
||||
|
||||
pub(crate) const SESSIONS_SUBDIR: &str = "sessions";
|
||||
pub const SESSIONS_SUBDIR: &str = "sessions";
|
||||
pub const ARCHIVED_SESSIONS_SUBDIR: &str = "archived_sessions";
|
||||
|
||||
pub mod list;
|
||||
pub(crate) mod policy;
|
||||
pub mod recorder;
|
||||
|
||||
pub use codex_protocol::protocol::SessionMeta;
|
||||
pub use recorder::RolloutRecorder;
|
||||
pub use recorder::SessionStateSnapshot;
|
||||
pub use recorder::RolloutRecorderParams;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
|
||||
@@ -1,8 +1,23 @@
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::RolloutItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
/// Whether a rollout `item` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub(crate) fn is_persisted_response_item(item: &RolloutItem) -> bool {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(item) => should_persist_response_item(item),
|
||||
RolloutItem::EventMsg(ev) => should_persist_event_msg(ev),
|
||||
// Persist Codex executive markers so we can analyze flows (e.g., compaction, API turns).
|
||||
RolloutItem::Compacted(_) | RolloutItem::TurnContext(_) | RolloutItem::SessionMeta(_) => {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether a `ResponseItem` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub(crate) fn is_persisted_response_item(item: &ResponseItem) -> bool {
|
||||
pub(crate) fn should_persist_response_item(item: &ResponseItem) -> bool {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
@@ -14,3 +29,44 @@ pub(crate) fn is_persisted_response_item(item: &ResponseItem) -> bool {
|
||||
ResponseItem::WebSearchCall { .. } | ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether an `EventMsg` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool {
|
||||
match ev {
|
||||
EventMsg::UserMessage(_)
|
||||
| EventMsg::AgentMessage(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::AgentReasoningRawContent(_)
|
||||
| EventMsg::TokenCount(_) => true,
|
||||
EventMsg::Error(_)
|
||||
| EventMsg::TaskStarted(_)
|
||||
| EventMsg::TaskComplete(_)
|
||||
| EventMsg::AgentMessageDelta(_)
|
||||
| EventMsg::AgentReasoningDelta(_)
|
||||
| EventMsg::AgentReasoningRawContentDelta(_)
|
||||
| EventMsg::AgentReasoningSectionBreak(_)
|
||||
| EventMsg::SessionConfigured(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::WebSearchBegin(_)
|
||||
| EventMsg::WebSearchEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::ExecApprovalRequest(_)
|
||||
| EventMsg::ApplyPatchApprovalRequest(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::StreamError(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::McpListToolsResponse(_)
|
||||
| EventMsg::ListCustomPromptsResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::TurnAborted(_)
|
||||
| EventMsg::ShutdownComplete
|
||||
| EventMsg::ConversationPath(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Error as IoError;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
@@ -17,7 +19,6 @@ use tokio::sync::mpsc::{self};
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::SESSIONS_SUBDIR;
|
||||
use super::list::ConversationsPage;
|
||||
@@ -25,25 +26,15 @@ use super::list::Cursor;
|
||||
use super::list::get_conversations;
|
||||
use super::policy::is_persisted_response_item;
|
||||
use crate::config::Config;
|
||||
use crate::conversation_manager::InitialHistory;
|
||||
use crate::git_info::GitInfo;
|
||||
use crate::default_client::ORIGINATOR;
|
||||
use crate::git_info::collect_git_info;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Default)]
|
||||
pub struct SessionMeta {
|
||||
pub id: Uuid,
|
||||
pub timestamp: String,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SessionMetaWithGit {
|
||||
#[serde(flatten)]
|
||||
meta: SessionMeta,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
git: Option<GitInfo>,
|
||||
}
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::protocol::ResumedHistory;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
use codex_protocol::protocol::SessionMeta;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {}
|
||||
@@ -55,7 +46,7 @@ pub struct SavedSession {
|
||||
pub items: Vec<ResponseItem>,
|
||||
#[serde(default)]
|
||||
pub state: SessionStateSnapshot,
|
||||
pub session_id: Uuid,
|
||||
pub session_id: ConversationId,
|
||||
}
|
||||
|
||||
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
|
||||
@@ -70,16 +61,45 @@ pub struct SavedSession {
|
||||
#[derive(Clone)]
|
||||
pub struct RolloutRecorder {
|
||||
tx: Sender<RolloutCmd>,
|
||||
pub(crate) rollout_path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum RolloutRecorderParams {
|
||||
Create {
|
||||
conversation_id: ConversationId,
|
||||
instructions: Option<String>,
|
||||
},
|
||||
Resume {
|
||||
path: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
enum RolloutCmd {
|
||||
AddItems(Vec<ResponseItem>),
|
||||
UpdateState(SessionStateSnapshot),
|
||||
Shutdown { ack: oneshot::Sender<()> },
|
||||
AddItems(Vec<RolloutItem>),
|
||||
/// Ensure all prior writes are processed; respond when flushed.
|
||||
Flush {
|
||||
ack: oneshot::Sender<()>,
|
||||
},
|
||||
Shutdown {
|
||||
ack: oneshot::Sender<()>,
|
||||
},
|
||||
}
|
||||
|
||||
impl RolloutRecorderParams {
|
||||
pub fn new(conversation_id: ConversationId, instructions: Option<String>) -> Self {
|
||||
Self::Create {
|
||||
conversation_id,
|
||||
instructions,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resume(path: PathBuf) -> Self {
|
||||
Self::Resume { path }
|
||||
}
|
||||
}
|
||||
|
||||
impl RolloutRecorder {
|
||||
#[allow(dead_code)]
|
||||
/// List conversations (rollout files) under the provided Codex home directory.
|
||||
pub async fn list_conversations(
|
||||
codex_home: &Path,
|
||||
@@ -92,23 +112,49 @@ impl RolloutRecorder {
|
||||
/// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
|
||||
/// cannot be created or the rollout file cannot be opened we return the
|
||||
/// error so the caller can decide whether to disable persistence.
|
||||
pub async fn new(
|
||||
config: &Config,
|
||||
uuid: Uuid,
|
||||
instructions: Option<String>,
|
||||
) -> std::io::Result<Self> {
|
||||
let LogFileInfo {
|
||||
file,
|
||||
session_id,
|
||||
timestamp,
|
||||
} = create_log_file(config, uuid)?;
|
||||
pub async fn new(config: &Config, params: RolloutRecorderParams) -> std::io::Result<Self> {
|
||||
let (file, rollout_path, meta) = match params {
|
||||
RolloutRecorderParams::Create {
|
||||
conversation_id,
|
||||
instructions,
|
||||
} => {
|
||||
let LogFileInfo {
|
||||
file,
|
||||
path,
|
||||
conversation_id: session_id,
|
||||
timestamp,
|
||||
} = create_log_file(config, conversation_id)?;
|
||||
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = timestamp
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = timestamp
|
||||
.to_offset(time::UtcOffset::UTC)
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
(
|
||||
tokio::fs::File::from_std(file),
|
||||
path,
|
||||
Some(SessionMeta {
|
||||
id: session_id,
|
||||
timestamp,
|
||||
cwd: config.cwd.clone(),
|
||||
originator: ORIGINATOR.value.clone(),
|
||||
cli_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
instructions,
|
||||
}),
|
||||
)
|
||||
}
|
||||
RolloutRecorderParams::Resume { path } => (
|
||||
tokio::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.open(&path)
|
||||
.await?,
|
||||
path,
|
||||
None,
|
||||
),
|
||||
};
|
||||
|
||||
// Clone the cwd for the spawned task to collect git info asynchronously
|
||||
let cwd = config.cwd.clone();
|
||||
@@ -121,21 +167,12 @@ impl RolloutRecorder {
|
||||
// Spawn a Tokio task that owns the file handle and performs async
|
||||
// writes. Using `tokio::fs::File` keeps everything on the async I/O
|
||||
// driver instead of blocking the runtime.
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
Some(SessionMeta {
|
||||
timestamp,
|
||||
id: session_id,
|
||||
instructions,
|
||||
}),
|
||||
cwd,
|
||||
));
|
||||
tokio::task::spawn(rollout_writer(file, rx, meta, cwd));
|
||||
|
||||
Ok(Self { tx })
|
||||
Ok(Self { tx, rollout_path })
|
||||
}
|
||||
|
||||
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
|
||||
pub(crate) async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
@@ -154,55 +191,91 @@ impl RolloutRecorder {
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
|
||||
}
|
||||
|
||||
pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
|
||||
/// Flush all queued writes and wait until they are committed by the writer task.
|
||||
pub async fn flush(&self) -> std::io::Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(RolloutCmd::UpdateState(state))
|
||||
.send(RolloutCmd::Flush { ack: tx })
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout flush: {e}")))?;
|
||||
rx.await
|
||||
.map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
|
||||
}
|
||||
|
||||
pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
|
||||
pub(crate) async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
|
||||
info!("Resuming rollout from {path:?}");
|
||||
tracing::error!("Resuming rollout from {path:?}");
|
||||
let text = tokio::fs::read_to_string(path).await?;
|
||||
let mut lines = text.lines();
|
||||
let _ = lines
|
||||
.next()
|
||||
.ok_or_else(|| IoError::other("empty session file"))?;
|
||||
let mut items = Vec::new();
|
||||
if text.trim().is_empty() {
|
||||
return Err(IoError::other("empty session file"));
|
||||
}
|
||||
|
||||
for line in lines {
|
||||
let mut items: Vec<RolloutItem> = Vec::new();
|
||||
let mut conversation_id: Option<ConversationId> = None;
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if v.get("record_type")
|
||||
.and_then(|rt| rt.as_str())
|
||||
.map(|s| s == "state")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
Ok(item) => {
|
||||
if is_persisted_response_item(&item) {
|
||||
items.push(item);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("failed to parse item: {v:?}, error: {e}");
|
||||
warn!("failed to parse line as JSON: {line:?}, error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse the rollout line structure
|
||||
match serde_json::from_value::<RolloutLine>(v.clone()) {
|
||||
Ok(rollout_line) => match rollout_line.item {
|
||||
RolloutItem::SessionMeta(session_meta_line) => {
|
||||
// Use the FIRST SessionMeta encountered in the file as the canonical
|
||||
// conversation id and main session information. Keep all items intact.
|
||||
if conversation_id.is_none() {
|
||||
conversation_id = Some(session_meta_line.meta.id);
|
||||
}
|
||||
items.push(RolloutItem::SessionMeta(session_meta_line));
|
||||
}
|
||||
RolloutItem::ResponseItem(item) => {
|
||||
items.push(RolloutItem::ResponseItem(item));
|
||||
}
|
||||
RolloutItem::Compacted(item) => {
|
||||
items.push(RolloutItem::Compacted(item));
|
||||
}
|
||||
RolloutItem::TurnContext(item) => {
|
||||
items.push(RolloutItem::TurnContext(item));
|
||||
}
|
||||
RolloutItem::EventMsg(_ev) => {
|
||||
items.push(RolloutItem::EventMsg(_ev));
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("failed to parse rollout line: {v:?}, error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
tracing::error!(
|
||||
"Resumed rollout with {} items, conversation ID: {:?}",
|
||||
items.len(),
|
||||
conversation_id
|
||||
);
|
||||
let conversation_id = conversation_id
|
||||
.ok_or_else(|| IoError::other("failed to parse conversation ID from rollout file"))?;
|
||||
|
||||
if items.is_empty() {
|
||||
Ok(InitialHistory::New)
|
||||
} else {
|
||||
Ok(InitialHistory::Resumed(items))
|
||||
return Ok(InitialHistory::New);
|
||||
}
|
||||
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
Ok(InitialHistory::Resumed(ResumedHistory {
|
||||
conversation_id,
|
||||
history: items,
|
||||
rollout_path: path.to_path_buf(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn get_rollout_path(&self) -> PathBuf {
|
||||
self.rollout_path.clone()
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> std::io::Result<()> {
|
||||
@@ -225,14 +298,20 @@ struct LogFileInfo {
|
||||
/// Opened file handle to the rollout file.
|
||||
file: File,
|
||||
|
||||
/// Full path to the rollout file.
|
||||
path: PathBuf,
|
||||
|
||||
/// Session ID (also embedded in filename).
|
||||
session_id: Uuid,
|
||||
conversation_id: ConversationId,
|
||||
|
||||
/// Timestamp for the start of the session.
|
||||
timestamp: OffsetDateTime,
|
||||
}
|
||||
|
||||
fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFileInfo> {
|
||||
fn create_log_file(
|
||||
config: &Config,
|
||||
conversation_id: ConversationId,
|
||||
) -> std::io::Result<LogFileInfo> {
|
||||
// Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
|
||||
let timestamp = OffsetDateTime::now_local()
|
||||
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
|
||||
@@ -251,7 +330,7 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
|
||||
.format(format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
let filename = format!("rollout-{date_str}-{session_id}.jsonl");
|
||||
let filename = format!("rollout-{date_str}-{conversation_id}.jsonl");
|
||||
|
||||
let path = dir.join(filename);
|
||||
let file = std::fs::OpenOptions::new()
|
||||
@@ -261,7 +340,8 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
|
||||
|
||||
Ok(LogFileInfo {
|
||||
file,
|
||||
session_id,
|
||||
path,
|
||||
conversation_id,
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
@@ -277,13 +357,15 @@ async fn rollout_writer(
|
||||
// If we have a meta, collect git info asynchronously and write meta first
|
||||
if let Some(session_meta) = meta.take() {
|
||||
let git_info = collect_git_info(&cwd).await;
|
||||
let session_meta_with_git = SessionMetaWithGit {
|
||||
let session_meta_line = SessionMetaLine {
|
||||
meta: session_meta,
|
||||
git: git_info,
|
||||
};
|
||||
|
||||
// Write the SessionMeta as the first item in the file
|
||||
writer.write_line(&session_meta_with_git).await?;
|
||||
// Write the SessionMeta as the first item in the file, wrapped in a rollout line
|
||||
writer
|
||||
.write_rollout_item(RolloutItem::SessionMeta(session_meta_line))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Process rollout commands
|
||||
@@ -292,23 +374,17 @@ async fn rollout_writer(
|
||||
RolloutCmd::AddItems(items) => {
|
||||
for item in items {
|
||||
if is_persisted_response_item(&item) {
|
||||
writer.write_line(&item).await?;
|
||||
writer.write_rollout_item(item).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
RolloutCmd::UpdateState(state) => {
|
||||
#[derive(Serialize)]
|
||||
struct StateLine<'a> {
|
||||
record_type: &'static str,
|
||||
#[serde(flatten)]
|
||||
state: &'a SessionStateSnapshot,
|
||||
RolloutCmd::Flush { ack } => {
|
||||
// Ensure underlying file is flushed and then ack.
|
||||
if let Err(e) = writer.file.flush().await {
|
||||
let _ = ack.send(());
|
||||
return Err(e);
|
||||
}
|
||||
writer
|
||||
.write_line(&StateLine {
|
||||
record_type: "state",
|
||||
state: &state,
|
||||
})
|
||||
.await?;
|
||||
let _ = ack.send(());
|
||||
}
|
||||
RolloutCmd::Shutdown { ack } => {
|
||||
let _ = ack.send(());
|
||||
@@ -324,10 +400,24 @@ struct JsonlWriter {
|
||||
}
|
||||
|
||||
impl JsonlWriter {
|
||||
async fn write_rollout_item(&mut self, rollout_item: RolloutItem) -> std::io::Result<()> {
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = OffsetDateTime::now_utc()
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
let line = RolloutLine {
|
||||
timestamp,
|
||||
item: rollout_item,
|
||||
};
|
||||
self.write_line(&line).await
|
||||
}
|
||||
async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
|
||||
let mut json = serde_json::to_string(item)?;
|
||||
json.push('\n');
|
||||
let _ = self.file.write_all(json.as_bytes()).await;
|
||||
self.file.write_all(json.as_bytes()).await?;
|
||||
self.file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -42,10 +42,30 @@ fn write_session_file(
|
||||
|
||||
let meta = serde_json::json!({
|
||||
"timestamp": ts_str,
|
||||
"id": uuid.to_string()
|
||||
"type": "session_meta",
|
||||
"payload": {
|
||||
"id": uuid,
|
||||
"timestamp": ts_str,
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
}
|
||||
});
|
||||
writeln!(file, "{meta}")?;
|
||||
|
||||
// Include at least one user message event to satisfy listing filters
|
||||
let user_event = serde_json::json!({
|
||||
"timestamp": ts_str,
|
||||
"type": "event_msg",
|
||||
"payload": {
|
||||
"type": "user_message",
|
||||
"message": "Hello from user",
|
||||
"kind": "plain"
|
||||
}
|
||||
});
|
||||
writeln!(file, "{user_event}")?;
|
||||
|
||||
for i in 0..num_records {
|
||||
let rec = serde_json::json!({
|
||||
"record_type": "response",
|
||||
@@ -93,24 +113,30 @@ async fn test_list_conversations_latest_first() {
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-01-01T12-00-00-{u1}.jsonl"));
|
||||
|
||||
let head_3 = vec![
|
||||
serde_json::json!({"timestamp": "2025-01-03T12-00-00", "id": u3.to_string()}),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
serde_json::json!({"record_type": "response", "index": 2}),
|
||||
];
|
||||
let head_2 = vec![
|
||||
serde_json::json!({"timestamp": "2025-01-02T12-00-00", "id": u2.to_string()}),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
serde_json::json!({"record_type": "response", "index": 2}),
|
||||
];
|
||||
let head_1 = vec![
|
||||
serde_json::json!({"timestamp": "2025-01-01T12-00-00", "id": u1.to_string()}),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
serde_json::json!({"record_type": "response", "index": 2}),
|
||||
];
|
||||
let head_3 = vec![serde_json::json!({
|
||||
"id": u3,
|
||||
"timestamp": "2025-01-03T12-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_2 = vec![serde_json::json!({
|
||||
"id": u2,
|
||||
"timestamp": "2025-01-02T12-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_1 = vec![serde_json::json!({
|
||||
"id": u1,
|
||||
"timestamp": "2025-01-01T12-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
|
||||
let expected_cursor: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-01-01T12-00-00|{u1}\"")).unwrap();
|
||||
@@ -170,14 +196,22 @@ async fn test_pagination_cursor() {
|
||||
.join("03")
|
||||
.join("04")
|
||||
.join(format!("rollout-2025-03-04T09-00-00-{u4}.jsonl"));
|
||||
let head_5 = vec![
|
||||
serde_json::json!({"timestamp": "2025-03-05T09-00-00", "id": u5.to_string()}),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_4 = vec![
|
||||
serde_json::json!({"timestamp": "2025-03-04T09-00-00", "id": u4.to_string()}),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_5 = vec![serde_json::json!({
|
||||
"id": u5,
|
||||
"timestamp": "2025-03-05T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_4 = vec![serde_json::json!({
|
||||
"id": u4,
|
||||
"timestamp": "2025-03-04T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor1: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-03-04T09-00-00|{u4}\"")).unwrap();
|
||||
let expected_page1 = ConversationsPage {
|
||||
@@ -212,14 +246,22 @@ async fn test_pagination_cursor() {
|
||||
.join("03")
|
||||
.join("02")
|
||||
.join(format!("rollout-2025-03-02T09-00-00-{u2}.jsonl"));
|
||||
let head_3 = vec![
|
||||
serde_json::json!({"timestamp": "2025-03-03T09-00-00", "id": u3.to_string()}),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_2 = vec![
|
||||
serde_json::json!({"timestamp": "2025-03-02T09-00-00", "id": u2.to_string()}),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_3 = vec![serde_json::json!({
|
||||
"id": u3,
|
||||
"timestamp": "2025-03-03T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_2 = vec![serde_json::json!({
|
||||
"id": u2,
|
||||
"timestamp": "2025-03-02T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor2: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-03-02T09-00-00|{u2}\"")).unwrap();
|
||||
let expected_page2 = ConversationsPage {
|
||||
@@ -248,10 +290,14 @@ async fn test_pagination_cursor() {
|
||||
.join("03")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-03-01T09-00-00-{u1}.jsonl"));
|
||||
let head_1 = vec![
|
||||
serde_json::json!({"timestamp": "2025-03-01T09-00-00", "id": u1.to_string()}),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_1 = vec![serde_json::json!({
|
||||
"id": u1,
|
||||
"timestamp": "2025-03-01T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor3: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-03-01T09-00-00|{u1}\"")).unwrap();
|
||||
let expected_page3 = ConversationsPage {
|
||||
@@ -259,7 +305,7 @@ async fn test_pagination_cursor() {
|
||||
path: p1,
|
||||
head: head_1,
|
||||
}],
|
||||
next_cursor: Some(expected_cursor3.clone()),
|
||||
next_cursor: Some(expected_cursor3),
|
||||
num_scanned_files: 5, // scanned 05, 04 (anchor), 03, 02 (anchor), 01
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
@@ -287,15 +333,18 @@ async fn test_get_conversation_contents() {
|
||||
.join("04")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-04-01T10-30-00-{uuid}.jsonl"));
|
||||
let expected_head = vec![
|
||||
serde_json::json!({"timestamp": ts, "id": uuid.to_string()}),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
];
|
||||
let expected_head = vec![serde_json::json!({
|
||||
"id": uuid,
|
||||
"timestamp": ts,
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor: Cursor = serde_json::from_str(&format!("\"{ts}|{uuid}\"")).unwrap();
|
||||
let expected_page = ConversationsPage {
|
||||
items: vec![ConversationItem {
|
||||
path: expected_path.clone(),
|
||||
path: expected_path,
|
||||
head: expected_head,
|
||||
}],
|
||||
next_cursor: Some(expected_cursor),
|
||||
@@ -305,10 +354,15 @@ async fn test_get_conversation_contents() {
|
||||
assert_eq!(page, expected_page);
|
||||
|
||||
// Entire file contents equality
|
||||
let meta = serde_json::json!({"timestamp": ts, "id": uuid.to_string()});
|
||||
let meta = serde_json::json!({"timestamp": ts, "type": "session_meta", "payload": {"id": uuid, "timestamp": ts, "instructions": null, "cwd": ".", "originator": "test_originator", "cli_version": "test_version"}});
|
||||
let user_event = serde_json::json!({
|
||||
"timestamp": ts,
|
||||
"type": "event_msg",
|
||||
"payload": {"type": "user_message", "message": "Hello from user", "kind": "plain"}
|
||||
});
|
||||
let rec0 = serde_json::json!({"record_type": "response", "index": 0});
|
||||
let rec1 = serde_json::json!({"record_type": "response", "index": 1});
|
||||
let expected_content = format!("{meta}\n{rec0}\n{rec1}\n");
|
||||
let expected_content = format!("{meta}\n{user_event}\n{rec0}\n{rec1}\n");
|
||||
assert_eq!(content, expected_content);
|
||||
}
|
||||
|
||||
@@ -341,7 +395,14 @@ async fn test_stable_ordering_same_second_pagination() {
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-07-01T00-00-00-{u2}.jsonl"));
|
||||
let head = |u: Uuid| -> Vec<serde_json::Value> {
|
||||
vec![serde_json::json!({"timestamp": ts, "id": u.to_string()})]
|
||||
vec![serde_json::json!({
|
||||
"id": u,
|
||||
"timestamp": ts,
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})]
|
||||
};
|
||||
let expected_cursor1: Cursor = serde_json::from_str(&format!("\"{ts}|{u2}\"")).unwrap();
|
||||
let expected_page1 = ConversationsPage {
|
||||
@@ -376,7 +437,7 @@ async fn test_stable_ordering_same_second_pagination() {
|
||||
path: p1,
|
||||
head: head(u1),
|
||||
}],
|
||||
next_cursor: Some(expected_cursor2.clone()),
|
||||
next_cursor: Some(expected_cursor2),
|
||||
num_scanned_files: 3, // scanned u3, u2 (anchor), u1
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
|
||||
@@ -293,7 +293,7 @@ mod tests {
|
||||
// With the parent dir explicitly added as a writable root, the
|
||||
// outside write should be permitted.
|
||||
let policy_with_parent = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![parent.clone()],
|
||||
writable_roots: vec![parent],
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: true,
|
||||
exclude_slash_tmp: true,
|
||||
|
||||
@@ -153,7 +153,7 @@ mod tests {
|
||||
// Build a policy that only includes the two test roots as writable and
|
||||
// does not automatically include defaults TMPDIR or /tmp.
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![root_with_git.clone(), root_without_git.clone()],
|
||||
writable_roots: vec![root_with_git, root_without_git],
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: true,
|
||||
exclude_slash_tmp: true,
|
||||
|
||||
@@ -69,3 +69,8 @@
|
||||
; Added on top of Chrome profile
|
||||
; Needed for python multiprocessing on MacOS for the SemLock
|
||||
(allow ipc-posix-sem)
|
||||
|
||||
; needed to look up user info, see https://crbug.com/792228
|
||||
(allow mach-lookup
|
||||
(global-name "com.apple.system.opendirectoryd.libinfo")
|
||||
)
|
||||
|
||||
@@ -9,6 +9,12 @@ pub struct ZshShell {
|
||||
zshrc_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct BashShell {
|
||||
shell_path: String,
|
||||
bashrc_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct PowerShellConfig {
|
||||
exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
|
||||
@@ -18,6 +24,7 @@ pub struct PowerShellConfig {
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub enum Shell {
|
||||
Zsh(ZshShell),
|
||||
Bash(BashShell),
|
||||
PowerShell(PowerShellConfig),
|
||||
Unknown,
|
||||
}
|
||||
@@ -26,22 +33,10 @@ impl Shell {
|
||||
pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
|
||||
match self {
|
||||
Shell::Zsh(zsh) => {
|
||||
if !std::path::Path::new(&zsh.zshrc_path).exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut result = vec![zsh.shell_path.clone()];
|
||||
result.push("-lc".to_string());
|
||||
|
||||
let joined = strip_bash_lc(&command)
|
||||
.or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok());
|
||||
|
||||
if let Some(joined) = joined {
|
||||
result.push(format!("source {} && ({joined})", zsh.zshrc_path));
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
Some(result)
|
||||
format_shell_invocation_with_rc(&command, &zsh.shell_path, &zsh.zshrc_path)
|
||||
}
|
||||
Shell::Bash(bash) => {
|
||||
format_shell_invocation_with_rc(&command, &bash.shell_path, &bash.bashrc_path)
|
||||
}
|
||||
Shell::PowerShell(ps) => {
|
||||
// If model generated a bash command, prefer a detected bash fallback
|
||||
@@ -97,12 +92,32 @@ impl Shell {
|
||||
Shell::Zsh(zsh) => std::path::Path::new(&zsh.shell_path)
|
||||
.file_name()
|
||||
.map(|s| s.to_string_lossy().to_string()),
|
||||
Shell::Bash(bash) => std::path::Path::new(&bash.shell_path)
|
||||
.file_name()
|
||||
.map(|s| s.to_string_lossy().to_string()),
|
||||
Shell::PowerShell(ps) => Some(ps.exe.clone()),
|
||||
Shell::Unknown => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn format_shell_invocation_with_rc(
|
||||
command: &Vec<String>,
|
||||
shell_path: &str,
|
||||
rc_path: &str,
|
||||
) -> Option<Vec<String>> {
|
||||
let joined = strip_bash_lc(command)
|
||||
.or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok())?;
|
||||
|
||||
let rc_command = if std::path::Path::new(rc_path).exists() {
|
||||
format!("source {rc_path} && ({joined})")
|
||||
} else {
|
||||
joined
|
||||
};
|
||||
|
||||
Some(vec![shell_path.to_string(), "-lc".to_string(), rc_command])
|
||||
}
|
||||
|
||||
fn strip_bash_lc(command: &Vec<String>) -> Option<String> {
|
||||
match command.as_slice() {
|
||||
// exactly three items
|
||||
@@ -116,44 +131,43 @@ fn strip_bash_lc(command: &Vec<String>) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub async fn default_user_shell() -> Shell {
|
||||
use tokio::process::Command;
|
||||
use whoami;
|
||||
#[cfg(unix)]
|
||||
fn detect_default_user_shell() -> Shell {
|
||||
use libc::getpwuid;
|
||||
use libc::getuid;
|
||||
use std::ffi::CStr;
|
||||
|
||||
let user = whoami::username();
|
||||
let home = format!("/Users/{user}");
|
||||
let output = Command::new("dscl")
|
||||
.args([".", "-read", &home, "UserShell"])
|
||||
.output()
|
||||
.await
|
||||
.ok();
|
||||
match output {
|
||||
Some(o) => {
|
||||
if !o.status.success() {
|
||||
return Shell::Unknown;
|
||||
}
|
||||
let stdout = String::from_utf8_lossy(&o.stdout);
|
||||
for line in stdout.lines() {
|
||||
if let Some(shell_path) = line.strip_prefix("UserShell: ")
|
||||
&& shell_path.ends_with("/zsh")
|
||||
{
|
||||
return Shell::Zsh(ZshShell {
|
||||
shell_path: shell_path.to_string(),
|
||||
zshrc_path: format!("{home}/.zshrc"),
|
||||
});
|
||||
}
|
||||
unsafe {
|
||||
let uid = getuid();
|
||||
let pw = getpwuid(uid);
|
||||
|
||||
if !pw.is_null() {
|
||||
let shell_path = CStr::from_ptr((*pw).pw_shell)
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
let home_path = CStr::from_ptr((*pw).pw_dir).to_string_lossy().into_owned();
|
||||
|
||||
if shell_path.ends_with("/zsh") {
|
||||
return Shell::Zsh(ZshShell {
|
||||
shell_path,
|
||||
zshrc_path: format!("{home_path}/.zshrc"),
|
||||
});
|
||||
}
|
||||
|
||||
Shell::Unknown
|
||||
if shell_path.ends_with("/bash") {
|
||||
return Shell::Bash(BashShell {
|
||||
shell_path,
|
||||
bashrc_path: format!("{home_path}/.bashrc"),
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => Shell::Unknown,
|
||||
}
|
||||
Shell::Unknown
|
||||
}
|
||||
|
||||
#[cfg(all(not(target_os = "macos"), not(target_os = "windows")))]
|
||||
#[cfg(unix)]
|
||||
pub async fn default_user_shell() -> Shell {
|
||||
Shell::Unknown
|
||||
detect_default_user_shell()
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
@@ -196,8 +210,13 @@ pub async fn default_user_shell() -> Shell {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(not(target_os = "windows"), not(unix)))]
|
||||
pub async fn default_user_shell() -> Shell {
|
||||
Shell::Unknown
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(target_os = "macos")]
|
||||
#[cfg(unix)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::process::Command;
|
||||
@@ -230,9 +249,124 @@ mod tests {
|
||||
zshrc_path: "/does/not/exist/.zshrc".to_string(),
|
||||
});
|
||||
let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
|
||||
assert_eq!(actual_cmd, None);
|
||||
assert_eq!(
|
||||
actual_cmd,
|
||||
Some(vec![
|
||||
"/bin/zsh".to_string(),
|
||||
"-lc".to_string(),
|
||||
"myecho".to_string()
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_run_with_profile_bashrc_not_exists() {
|
||||
let shell = Shell::Bash(BashShell {
|
||||
shell_path: "/bin/bash".to_string(),
|
||||
bashrc_path: "/does/not/exist/.bashrc".to_string(),
|
||||
});
|
||||
let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
|
||||
assert_eq!(
|
||||
actual_cmd,
|
||||
Some(vec![
|
||||
"/bin/bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"myecho".to_string()
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_run_with_profile_bash_escaping_and_execution() {
|
||||
let shell_path = "/bin/bash";
|
||||
|
||||
let cases = vec![
|
||||
(
|
||||
vec!["myecho"],
|
||||
vec![shell_path, "-lc", "source BASHRC_PATH && (myecho)"],
|
||||
Some("It works!\n"),
|
||||
),
|
||||
(
|
||||
vec!["bash", "-lc", "echo 'single' \"double\""],
|
||||
vec![
|
||||
shell_path,
|
||||
"-lc",
|
||||
"source BASHRC_PATH && (echo 'single' \"double\")",
|
||||
],
|
||||
Some("single double\n"),
|
||||
),
|
||||
];
|
||||
|
||||
for (input, expected_cmd, expected_output) in cases {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::exec::ExecParams;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::exec::process_exec_tool_call;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
|
||||
let temp_home = tempfile::tempdir().unwrap();
|
||||
let bashrc_path = temp_home.path().join(".bashrc");
|
||||
std::fs::write(
|
||||
&bashrc_path,
|
||||
r#"
|
||||
set -x
|
||||
function myecho {
|
||||
echo 'It works!'
|
||||
}
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let shell = Shell::Bash(BashShell {
|
||||
shell_path: shell_path.to_string(),
|
||||
bashrc_path: bashrc_path.to_str().unwrap().to_string(),
|
||||
});
|
||||
|
||||
let actual_cmd = shell
|
||||
.format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
|
||||
let expected_cmd = expected_cmd
|
||||
.iter()
|
||||
.map(|s| s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap()))
|
||||
.collect();
|
||||
|
||||
assert_eq!(actual_cmd, Some(expected_cmd));
|
||||
|
||||
let output = process_exec_tool_call(
|
||||
ExecParams {
|
||||
command: actual_cmd.unwrap(),
|
||||
cwd: PathBuf::from(temp_home.path()),
|
||||
timeout_ms: None,
|
||||
env: HashMap::from([(
|
||||
"HOME".to_string(),
|
||||
temp_home.path().to_str().unwrap().to_string(),
|
||||
)]),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
},
|
||||
SandboxType::None,
|
||||
&SandboxPolicy::DangerFullAccess,
|
||||
&None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
|
||||
if let Some(expected) = expected_output {
|
||||
assert_eq!(
|
||||
output.stdout.text, expected,
|
||||
"input: {input:?} output: {output:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(target_os = "macos")]
|
||||
mod macos_tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_run_with_profile_escaping_and_execution() {
|
||||
let shell_path = "/bin/zsh";
|
||||
@@ -298,10 +432,7 @@ mod tests {
|
||||
.format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
|
||||
let expected_cmd = expected_cmd
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap())
|
||||
.to_string()
|
||||
})
|
||||
.map(|s| s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap()))
|
||||
.collect();
|
||||
|
||||
assert_eq!(actual_cmd, Some(expected_cmd));
|
||||
|
||||
@@ -3,8 +3,6 @@ use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
|
||||
pub struct TokenData {
|
||||
/// Flat info parsed from the JWT in auth.json.
|
||||
@@ -22,36 +20,6 @@ pub struct TokenData {
|
||||
pub account_id: Option<String>,
|
||||
}
|
||||
|
||||
impl TokenData {
|
||||
/// Returns true if this is a plan that should use the traditional
|
||||
/// "metered" billing via an API key.
|
||||
pub(crate) fn should_use_api_key(
|
||||
&self,
|
||||
preferred_auth_method: AuthMode,
|
||||
is_openai_email: bool,
|
||||
) -> bool {
|
||||
if preferred_auth_method == AuthMode::ApiKey {
|
||||
return true;
|
||||
}
|
||||
// If the email is an OpenAI email, use AuthMode::ChatGPT unless preferred_auth_method is AuthMode::ApiKey.
|
||||
if is_openai_email {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.id_token
|
||||
.chatgpt_plan_type
|
||||
.as_ref()
|
||||
.is_none_or(|plan| plan.is_plan_that_should_use_api_key())
|
||||
}
|
||||
|
||||
pub fn is_openai_email(&self) -> bool {
|
||||
self.id_token
|
||||
.email
|
||||
.as_deref()
|
||||
.is_some_and(|email| email.trim().to_ascii_lowercase().ends_with("@openai.com"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Flat subset of useful claims in id_token from auth.json.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub struct IdTokenInfo {
|
||||
@@ -79,28 +47,6 @@ pub(crate) enum PlanType {
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
impl PlanType {
|
||||
fn is_plan_that_should_use_api_key(&self) -> bool {
|
||||
match self {
|
||||
Self::Known(known) => {
|
||||
use KnownPlan::*;
|
||||
!matches!(known, Free | Plus | Pro | Team)
|
||||
}
|
||||
Self::Unknown(_) => {
|
||||
// Unknown plans should use the API key.
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_string(&self) -> String {
|
||||
match self {
|
||||
Self::Known(known) => format!("{known:?}").to_lowercase(),
|
||||
Self::Unknown(s) => s.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub(crate) enum KnownPlan {
|
||||
|
||||
180
codex-rs/core/src/truncate.rs
Normal file
180
codex-rs/core/src/truncate.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
//! Utilities for truncating large chunks of output while preserving a prefix
|
||||
//! and suffix on UTF-8 boundaries.
|
||||
|
||||
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
|
||||
/// preserving the beginning and the end. Returns the possibly truncated
|
||||
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
|
||||
/// if truncation occurred; otherwise returns the original string and `None`.
|
||||
pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
|
||||
if s.len() <= max_bytes {
|
||||
return (s.to_string(), None);
|
||||
}
|
||||
|
||||
let est_tokens = (s.len() as u64).div_ceil(4);
|
||||
if max_bytes == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
|
||||
if input.len() <= max_len {
|
||||
return input;
|
||||
}
|
||||
let mut end = max_len;
|
||||
while end > 0 && !input.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&input[..end]
|
||||
}
|
||||
|
||||
fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
|
||||
if let Some(head) = s.get(..left_budget)
|
||||
&& let Some(i) = head.rfind('\n')
|
||||
{
|
||||
return i + 1;
|
||||
}
|
||||
truncate_on_boundary(s, left_budget).len()
|
||||
}
|
||||
|
||||
fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
|
||||
let start_tail = s.len().saturating_sub(right_budget);
|
||||
if let Some(tail) = s.get(start_tail..)
|
||||
&& let Some(i) = tail.find('\n')
|
||||
{
|
||||
return start_tail + i + 1;
|
||||
}
|
||||
|
||||
let mut idx = start_tail.min(s.len());
|
||||
while idx < s.len() && !s.is_char_boundary(idx) {
|
||||
idx += 1;
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
let mut guess_tokens = est_tokens;
|
||||
for _ in 0..4 {
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let mut suffix_start = pick_suffix_start(s, right_budget);
|
||||
if suffix_start < prefix_end {
|
||||
suffix_start = prefix_end;
|
||||
}
|
||||
|
||||
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
|
||||
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
|
||||
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
|
||||
|
||||
if new_tokens == guess_tokens {
|
||||
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
return (out, Some(est_tokens));
|
||||
}
|
||||
|
||||
guess_tokens = new_tokens;
|
||||
}
|
||||
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let suffix_start = pick_suffix_start(s, right_budget);
|
||||
|
||||
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
(out, Some(est_tokens))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::truncate_middle;
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_no_newlines_fallback() {
|
||||
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
|
||||
let max_bytes = 32;
|
||||
let (out, original) = truncate_middle(s, max_bytes);
|
||||
assert!(out.starts_with("abc"));
|
||||
assert!(out.contains("tokens truncated"));
|
||||
assert!(out.ends_with("XYZ*"));
|
||||
assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_prefers_newline_boundaries() {
|
||||
let mut s = String::new();
|
||||
for i in 1..=20 {
|
||||
s.push_str(&format!("{i:03}\n"));
|
||||
}
|
||||
assert_eq!(s.len(), 80);
|
||||
|
||||
let max_bytes = 64;
|
||||
let (out, tokens) = truncate_middle(&s, max_bytes);
|
||||
assert!(out.starts_with("001\n002\n003\n004\n"));
|
||||
assert!(out.contains("tokens truncated"));
|
||||
assert!(out.ends_with("017\n018\n019\n020\n"));
|
||||
assert_eq!(tokens, Some(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_handles_utf8_content() {
|
||||
let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
|
||||
let max_bytes = 32;
|
||||
let (out, tokens) = truncate_middle(s, max_bytes);
|
||||
|
||||
assert!(out.contains("tokens truncated"));
|
||||
assert!(!out.contains('\u{fffd}'));
|
||||
assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_prefers_newline_boundaries_2() {
|
||||
// Build a multi-line string of 20 numbered lines (each "NNN\n").
|
||||
let mut s = String::new();
|
||||
for i in 1..=20 {
|
||||
s.push_str(&format!("{i:03}\n"));
|
||||
}
|
||||
// Total length: 20 lines * 4 bytes per line = 80 bytes.
|
||||
assert_eq!(s.len(), 80);
|
||||
|
||||
// Choose a cap that forces truncation while leaving room for
|
||||
// a few lines on each side after accounting for the marker.
|
||||
let max_bytes = 64;
|
||||
// Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
|
||||
assert_eq!(
|
||||
truncate_middle(&s, max_bytes),
|
||||
(
|
||||
r#"001
|
||||
002
|
||||
003
|
||||
004
|
||||
…12 tokens truncated…
|
||||
017
|
||||
018
|
||||
019
|
||||
020
|
||||
"#
|
||||
.to_string(),
|
||||
Some(20)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -678,7 +678,7 @@ index {left_oid}..{right_oid}
|
||||
let dest = dir.path().join("dest.txt");
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv = HashMap::from([(
|
||||
src.clone(),
|
||||
src,
|
||||
FileChange::Update {
|
||||
unified_diff: "".into(),
|
||||
move_path: Some(dest.clone()),
|
||||
|
||||
22
codex-rs/core/src/unified_exec/errors.rs
Normal file
22
codex-rs/core/src/unified_exec/errors.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum UnifiedExecError {
|
||||
#[error("Failed to create unified exec session: {pty_error}")]
|
||||
CreateSession {
|
||||
#[source]
|
||||
pty_error: anyhow::Error,
|
||||
},
|
||||
#[error("Unknown session id {session_id}")]
|
||||
UnknownSessionId { session_id: i32 },
|
||||
#[error("failed to write to stdin")]
|
||||
WriteToStdin,
|
||||
#[error("missing command line for unified exec request")]
|
||||
MissingCommandLine,
|
||||
}
|
||||
|
||||
impl UnifiedExecError {
|
||||
pub(crate) fn create_session(error: anyhow::Error) -> Self {
|
||||
Self::CreateSession { pty_error: error }
|
||||
}
|
||||
}
|
||||
633
codex-rs/core/src/unified_exec/mod.rs
Normal file
633
codex-rs/core/src/unified_exec/mod.rs
Normal file
@@ -0,0 +1,633 @@
|
||||
use portable_pty::CommandBuilder;
|
||||
use portable_pty::PtySize;
|
||||
use portable_pty::native_pty_system;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicI32;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::exec_command::ExecCommandSession;
|
||||
use crate::truncate::truncate_middle;
|
||||
|
||||
mod errors;
|
||||
|
||||
pub(crate) use errors::UnifiedExecError;
|
||||
|
||||
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
|
||||
const MAX_TIMEOUT_MS: u64 = 60_000;
|
||||
const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 128 KiB
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UnifiedExecRequest<'a> {
|
||||
pub session_id: Option<i32>,
|
||||
pub input_chunks: &'a [String],
|
||||
pub timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) struct UnifiedExecResult {
|
||||
pub session_id: Option<i32>,
|
||||
pub output: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct UnifiedExecSessionManager {
|
||||
next_session_id: AtomicI32,
|
||||
sessions: Mutex<HashMap<i32, ManagedUnifiedExecSession>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ManagedUnifiedExecSession {
|
||||
session: ExecCommandSession,
|
||||
output_buffer: OutputBuffer,
|
||||
/// Notifies waiters whenever new output has been appended to
|
||||
/// `output_buffer`, allowing clients to poll for fresh data.
|
||||
output_notify: Arc<Notify>,
|
||||
output_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct OutputBufferState {
|
||||
chunks: VecDeque<Vec<u8>>,
|
||||
total_bytes: usize,
|
||||
}
|
||||
|
||||
impl OutputBufferState {
|
||||
fn push_chunk(&mut self, chunk: Vec<u8>) {
|
||||
self.total_bytes = self.total_bytes.saturating_add(chunk.len());
|
||||
self.chunks.push_back(chunk);
|
||||
|
||||
let mut excess = self
|
||||
.total_bytes
|
||||
.saturating_sub(UNIFIED_EXEC_OUTPUT_MAX_BYTES);
|
||||
|
||||
while excess > 0 {
|
||||
match self.chunks.front_mut() {
|
||||
Some(front) if excess >= front.len() => {
|
||||
excess -= front.len();
|
||||
self.total_bytes = self.total_bytes.saturating_sub(front.len());
|
||||
self.chunks.pop_front();
|
||||
}
|
||||
Some(front) => {
|
||||
front.drain(..excess);
|
||||
self.total_bytes = self.total_bytes.saturating_sub(excess);
|
||||
break;
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn drain(&mut self) -> Vec<Vec<u8>> {
|
||||
let drained: Vec<Vec<u8>> = self.chunks.drain(..).collect();
|
||||
self.total_bytes = 0;
|
||||
drained
|
||||
}
|
||||
}
|
||||
|
||||
type OutputBuffer = Arc<Mutex<OutputBufferState>>;
|
||||
type OutputHandles = (OutputBuffer, Arc<Notify>);
|
||||
|
||||
impl ManagedUnifiedExecSession {
|
||||
fn new(session: ExecCommandSession) -> Self {
|
||||
let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
|
||||
let output_notify = Arc::new(Notify::new());
|
||||
let mut receiver = session.output_receiver();
|
||||
let buffer_clone = Arc::clone(&output_buffer);
|
||||
let notify_clone = Arc::clone(&output_notify);
|
||||
let output_task = tokio::spawn(async move {
|
||||
while let Ok(chunk) = receiver.recv().await {
|
||||
let mut guard = buffer_clone.lock().await;
|
||||
guard.push_chunk(chunk);
|
||||
drop(guard);
|
||||
notify_clone.notify_waiters();
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
session,
|
||||
output_buffer,
|
||||
output_notify,
|
||||
output_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
|
||||
self.session.writer_sender()
|
||||
}
|
||||
|
||||
fn output_handles(&self) -> OutputHandles {
|
||||
(
|
||||
Arc::clone(&self.output_buffer),
|
||||
Arc::clone(&self.output_notify),
|
||||
)
|
||||
}
|
||||
|
||||
fn has_exited(&self) -> bool {
|
||||
self.session.has_exited()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ManagedUnifiedExecSession {
|
||||
fn drop(&mut self) {
|
||||
self.output_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl UnifiedExecSessionManager {
|
||||
pub async fn handle_request(
|
||||
&self,
|
||||
request: UnifiedExecRequest<'_>,
|
||||
) -> Result<UnifiedExecResult, UnifiedExecError> {
|
||||
let (timeout_ms, timeout_warning) = match request.timeout_ms {
|
||||
Some(requested) if requested > MAX_TIMEOUT_MS => (
|
||||
MAX_TIMEOUT_MS,
|
||||
Some(format!(
|
||||
"Warning: requested timeout {requested}ms exceeds maximum of {MAX_TIMEOUT_MS}ms; clamping to {MAX_TIMEOUT_MS}ms.\n"
|
||||
)),
|
||||
),
|
||||
Some(requested) => (requested, None),
|
||||
None => (DEFAULT_TIMEOUT_MS, None),
|
||||
};
|
||||
|
||||
let mut new_session: Option<ManagedUnifiedExecSession> = None;
|
||||
let session_id;
|
||||
let writer_tx;
|
||||
let output_buffer;
|
||||
let output_notify;
|
||||
|
||||
if let Some(existing_id) = request.session_id {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
match sessions.get(&existing_id) {
|
||||
Some(session) => {
|
||||
if session.has_exited() {
|
||||
sessions.remove(&existing_id);
|
||||
return Err(UnifiedExecError::UnknownSessionId {
|
||||
session_id: existing_id,
|
||||
});
|
||||
}
|
||||
let (buffer, notify) = session.output_handles();
|
||||
session_id = existing_id;
|
||||
writer_tx = session.writer_sender();
|
||||
output_buffer = buffer;
|
||||
output_notify = notify;
|
||||
}
|
||||
None => {
|
||||
return Err(UnifiedExecError::UnknownSessionId {
|
||||
session_id: existing_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
drop(sessions);
|
||||
} else {
|
||||
let command = request.input_chunks.to_vec();
|
||||
let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
|
||||
let session = create_unified_exec_session(&command).await?;
|
||||
let managed_session = ManagedUnifiedExecSession::new(session);
|
||||
let (buffer, notify) = managed_session.output_handles();
|
||||
writer_tx = managed_session.writer_sender();
|
||||
output_buffer = buffer;
|
||||
output_notify = notify;
|
||||
session_id = new_id;
|
||||
new_session = Some(managed_session);
|
||||
};
|
||||
|
||||
if request.session_id.is_some() {
|
||||
let joined_input = request.input_chunks.join(" ");
|
||||
if !joined_input.is_empty() && writer_tx.send(joined_input.into_bytes()).await.is_err()
|
||||
{
|
||||
return Err(UnifiedExecError::WriteToStdin);
|
||||
}
|
||||
}
|
||||
|
||||
let mut collected: Vec<u8> = Vec::with_capacity(4096);
|
||||
let start = Instant::now();
|
||||
let deadline = start + Duration::from_millis(timeout_ms);
|
||||
|
||||
loop {
|
||||
let drained_chunks;
|
||||
let mut wait_for_output = None;
|
||||
{
|
||||
let mut guard = output_buffer.lock().await;
|
||||
drained_chunks = guard.drain();
|
||||
if drained_chunks.is_empty() {
|
||||
wait_for_output = Some(output_notify.notified());
|
||||
}
|
||||
}
|
||||
|
||||
if drained_chunks.is_empty() {
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
if remaining == Duration::ZERO {
|
||||
break;
|
||||
}
|
||||
|
||||
let notified = wait_for_output.unwrap_or_else(|| output_notify.notified());
|
||||
tokio::pin!(notified);
|
||||
tokio::select! {
|
||||
_ = &mut notified => {}
|
||||
_ = tokio::time::sleep(remaining) => break,
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
for chunk in drained_chunks {
|
||||
collected.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
if Instant::now() >= deadline {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let (output, _maybe_tokens) = truncate_middle(
|
||||
&String::from_utf8_lossy(&collected),
|
||||
UNIFIED_EXEC_OUTPUT_MAX_BYTES,
|
||||
);
|
||||
let output = if let Some(warning) = timeout_warning {
|
||||
format!("{warning}{output}")
|
||||
} else {
|
||||
output
|
||||
};
|
||||
|
||||
let should_store_session = if let Some(session) = new_session.as_ref() {
|
||||
!session.has_exited()
|
||||
} else if request.session_id.is_some() {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(existing) = sessions.get(&session_id) {
|
||||
if existing.has_exited() {
|
||||
sessions.remove(&session_id);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if should_store_session {
|
||||
if let Some(session) = new_session {
|
||||
self.sessions.lock().await.insert(session_id, session);
|
||||
}
|
||||
Ok(UnifiedExecResult {
|
||||
session_id: Some(session_id),
|
||||
output,
|
||||
})
|
||||
} else {
|
||||
Ok(UnifiedExecResult {
|
||||
session_id: None,
|
||||
output,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_unified_exec_session(
|
||||
command: &[String],
|
||||
) -> Result<ExecCommandSession, UnifiedExecError> {
|
||||
if command.is_empty() {
|
||||
return Err(UnifiedExecError::MissingCommandLine);
|
||||
}
|
||||
|
||||
let pty_system = native_pty_system();
|
||||
|
||||
let pair = pty_system
|
||||
.openpty(PtySize {
|
||||
rows: 24,
|
||||
cols: 80,
|
||||
pixel_width: 0,
|
||||
pixel_height: 0,
|
||||
})
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
|
||||
// Safe thanks to the check at the top of the function.
|
||||
let mut command_builder = CommandBuilder::new(command[0].clone());
|
||||
for arg in &command[1..] {
|
||||
command_builder.arg(arg);
|
||||
}
|
||||
|
||||
let mut child = pair
|
||||
.slave
|
||||
.spawn_command(command_builder)
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
let killer = child.clone_killer();
|
||||
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
|
||||
|
||||
let mut reader = pair
|
||||
.master
|
||||
.try_clone_reader()
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
let output_tx_clone = output_tx.clone();
|
||||
let reader_handle = tokio::task::spawn_blocking(move || {
|
||||
let mut buf = [0u8; 8192];
|
||||
loop {
|
||||
match reader.read(&mut buf) {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
let _ = output_tx_clone.send(buf[..n].to_vec());
|
||||
}
|
||||
Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
|
||||
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
continue;
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let writer = pair
|
||||
.master
|
||||
.take_writer()
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
let writer = Arc::new(StdMutex::new(writer));
|
||||
let writer_handle = tokio::spawn({
|
||||
let writer = writer.clone();
|
||||
async move {
|
||||
while let Some(bytes) = writer_rx.recv().await {
|
||||
let writer = writer.clone();
|
||||
let _ = tokio::task::spawn_blocking(move || {
|
||||
if let Ok(mut guard) = writer.lock() {
|
||||
use std::io::Write;
|
||||
let _ = guard.write_all(&bytes);
|
||||
let _ = guard.flush();
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let exit_status = Arc::new(AtomicBool::new(false));
|
||||
let wait_exit_status = Arc::clone(&exit_status);
|
||||
let wait_handle = tokio::task::spawn_blocking(move || {
|
||||
let _ = child.wait();
|
||||
wait_exit_status.store(true, Ordering::SeqCst);
|
||||
});
|
||||
|
||||
Ok(ExecCommandSession::new(
|
||||
writer_tx,
|
||||
output_tx,
|
||||
killer,
|
||||
reader_handle,
|
||||
writer_handle,
|
||||
wait_handle,
|
||||
exit_status,
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn push_chunk_trims_only_excess_bytes() {
|
||||
let mut buffer = OutputBufferState::default();
|
||||
buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]);
|
||||
buffer.push_chunk(vec![b'b']);
|
||||
buffer.push_chunk(vec![b'c']);
|
||||
|
||||
assert_eq!(buffer.total_bytes, UNIFIED_EXEC_OUTPUT_MAX_BYTES);
|
||||
assert_eq!(buffer.chunks.len(), 3);
|
||||
assert_eq!(
|
||||
buffer.chunks.front().unwrap().len(),
|
||||
UNIFIED_EXEC_OUTPUT_MAX_BYTES - 2
|
||||
);
|
||||
assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'c']);
|
||||
assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'b']);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_persists_across_requests_jif() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
let session_id = open_shell.session_id.expect("expected session_id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &[
|
||||
"export".to_string(),
|
||||
"CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
|
||||
],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let out_2 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
assert!(out_2.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn multi_unified_exec_sessions() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let shell_a = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
let session_a = shell_a.session_id.expect("expected session id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_a),
|
||||
input_chunks: &["export CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let out_2 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &[
|
||||
"echo".to_string(),
|
||||
"$CODEX_INTERACTIVE_SHELL_VAR\n".to_string(),
|
||||
],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
assert!(!out_2.output.contains("codex"));
|
||||
|
||||
let out_3 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_a),
|
||||
input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
assert!(out_3.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn unified_exec_timeouts() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
let session_id = open_shell.session_id.expect("expected session id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &[
|
||||
"export".to_string(),
|
||||
"CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
|
||||
],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let out_2 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &["sleep 5 && echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||
timeout_ms: Some(10),
|
||||
})
|
||||
.await?;
|
||||
assert!(!out_2.output.contains("codex"));
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(7)).await;
|
||||
|
||||
let empty = Vec::new();
|
||||
let out_3 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &empty,
|
||||
timeout_ms: Some(100),
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert!(out_3.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn requests_with_large_timeout_are_capped() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let result = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["echo".to_string(), "codex".to_string()],
|
||||
timeout_ms: Some(120_000),
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert!(result.output.starts_with(
|
||||
"Warning: requested timeout 120000ms exceeds maximum of 60000ms; clamping to 60000ms.\n"
|
||||
));
|
||||
assert!(result.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn completed_commands_do_not_persist_sessions() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
let result = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["/bin/echo".to_string(), "codex".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert!(result.session_id.is_none());
|
||||
assert!(result.output.contains("codex"));
|
||||
|
||||
assert!(manager.sessions.lock().await.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn reusing_completed_session_returns_unknown_session() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
let session_id = open_shell.session_id.expect("expected session id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &["exit\n".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
let err = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &[],
|
||||
timeout_ms: Some(100),
|
||||
})
|
||||
.await
|
||||
.expect_err("expected unknown session error");
|
||||
|
||||
match err {
|
||||
UnifiedExecError::UnknownSessionId { session_id: err_id } => {
|
||||
assert_eq!(err_id, session_id);
|
||||
}
|
||||
other => panic!("expected UnknownSessionId, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(!manager.sessions.lock().await.contains_key(&session_id));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
42
codex-rs/core/src/user_instructions.rs
Normal file
42
codex-rs/core/src/user_instructions.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::USER_INSTRUCTIONS_CLOSE_TAG;
|
||||
use codex_protocol::protocol::USER_INSTRUCTIONS_OPEN_TAG;
|
||||
|
||||
/// Wraps user instructions in a tag so the model can classify them easily.
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename = "user_instructions", rename_all = "snake_case")]
|
||||
pub(crate) struct UserInstructions {
|
||||
text: String,
|
||||
}
|
||||
|
||||
impl UserInstructions {
|
||||
pub fn new<T: Into<String>>(text: T) -> Self {
|
||||
Self { text: text.into() }
|
||||
}
|
||||
|
||||
/// Serializes the user instructions to an XML-like tagged block that starts
|
||||
/// with <user_instructions> so clients can classify it.
|
||||
pub fn serialize_to_xml(self) -> String {
|
||||
format!(
|
||||
"{USER_INSTRUCTIONS_OPEN_TAG}\n\n{}\n\n{USER_INSTRUCTIONS_CLOSE_TAG}",
|
||||
self.text
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UserInstructions> for ResponseItem {
|
||||
fn from(ui: UserInstructions) -> Self {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: ui.serialize_to_xml(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11,11 +11,11 @@ use codex_core::ReasoningItemContent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use futures::StreamExt;
|
||||
use serde_json::Value;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
@@ -76,7 +76,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
|
||||
provider,
|
||||
effort,
|
||||
summary,
|
||||
Uuid::new_v4(),
|
||||
ConversationId::new(),
|
||||
);
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
|
||||
@@ -8,10 +8,10 @@ use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use futures::StreamExt;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
@@ -69,7 +69,7 @@ async fn run_stream(sse_body: &str) -> Vec<ResponseEvent> {
|
||||
provider,
|
||||
effort,
|
||||
summary,
|
||||
Uuid::new_v4(),
|
||||
ConversationId::new(),
|
||||
);
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use assert_cmd::Command as AssertCommand;
|
||||
use codex_core::RolloutRecorder;
|
||||
use codex_core::protocol::GitInfo;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
@@ -77,6 +79,22 @@ async fn chat_mode_stream_cli() {
|
||||
assert_eq!(hi_lines, 1, "Expected exactly one line with 'hi'");
|
||||
|
||||
server.verify().await;
|
||||
|
||||
// Verify a new session rollout was created and is discoverable via list_conversations
|
||||
let page = RolloutRecorder::list_conversations(home.path(), 10, None)
|
||||
.await
|
||||
.expect("list conversations");
|
||||
assert!(
|
||||
!page.items.is_empty(),
|
||||
"expected at least one session to be listed"
|
||||
);
|
||||
// First line of head must be the SessionMeta payload (id/timestamp)
|
||||
let head0 = page.items[0].head.first().expect("missing head record");
|
||||
assert!(head0.get("id").is_some(), "head[0] missing id");
|
||||
assert!(
|
||||
head0.get("timestamp").is_some(),
|
||||
"head[0] missing timestamp"
|
||||
);
|
||||
}
|
||||
|
||||
/// Verify that passing `-c experimental_instructions_file=...` to the CLI
|
||||
@@ -297,8 +315,10 @@ async fn integration_creates_and_checks_session_file() {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& let Some(c) = item.get("content")
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("response_item")
|
||||
&& let Some(payload) = item.get("payload")
|
||||
&& payload.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& let Some(c) = payload.get("content")
|
||||
&& c.to_string().contains(&marker)
|
||||
{
|
||||
matching_path = Some(path.to_path_buf());
|
||||
@@ -361,9 +381,16 @@ async fn integration_creates_and_checks_session_file() {
|
||||
.unwrap_or_else(|_| panic!("missing session meta line"));
|
||||
let meta: serde_json::Value = serde_json::from_str(meta_line)
|
||||
.unwrap_or_else(|_| panic!("Failed to parse session meta line as JSON"));
|
||||
assert!(meta.get("id").is_some(), "SessionMeta missing id");
|
||||
assert_eq!(
|
||||
meta.get("type").and_then(|v| v.as_str()),
|
||||
Some("session_meta")
|
||||
);
|
||||
let payload = meta
|
||||
.get("payload")
|
||||
.unwrap_or_else(|| panic!("Missing payload in meta line"));
|
||||
assert!(payload.get("id").is_some(), "SessionMeta missing id");
|
||||
assert!(
|
||||
meta.get("timestamp").is_some(),
|
||||
payload.get("timestamp").is_some(),
|
||||
"SessionMeta missing timestamp"
|
||||
);
|
||||
|
||||
@@ -375,8 +402,10 @@ async fn integration_creates_and_checks_session_file() {
|
||||
let Ok(item) = serde_json::from_str::<serde_json::Value>(line) else {
|
||||
continue;
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& let Some(c) = item.get("content")
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("response_item")
|
||||
&& let Some(payload) = item.get("payload")
|
||||
&& payload.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& let Some(c) = payload.get("content")
|
||||
&& c.to_string().contains(&marker)
|
||||
{
|
||||
found_message = true;
|
||||
@@ -388,8 +417,7 @@ async fn integration_creates_and_checks_session_file() {
|
||||
"No message found in session file containing the marker"
|
||||
);
|
||||
|
||||
// Second run: resume should create a NEW session file that contains both old and new history.
|
||||
let orig_len = content.lines().count();
|
||||
// Second run: resume should update the existing file.
|
||||
let marker2 = format!("integration-resume-{}", Uuid::new_v4());
|
||||
let prompt2 = format!("echo {marker2}");
|
||||
// Cross‑platform safe resume override. On Windows, backslashes in a TOML string must be escaped
|
||||
@@ -449,8 +477,8 @@ async fn integration_creates_and_checks_session_file() {
|
||||
}
|
||||
|
||||
let resumed_path = resumed_path.expect("No resumed session file found containing the marker2");
|
||||
// Resume should have written to a new file, not the original one.
|
||||
assert_ne!(
|
||||
// Resume should write to the existing log file.
|
||||
assert_eq!(
|
||||
resumed_path, path,
|
||||
"resume should create a new session file"
|
||||
);
|
||||
@@ -464,14 +492,6 @@ async fn integration_creates_and_checks_session_file() {
|
||||
resumed_content.contains(&marker2),
|
||||
"resumed file missing resumed marker"
|
||||
);
|
||||
|
||||
// Original file should remain unchanged.
|
||||
let content_after = std::fs::read_to_string(&path).unwrap();
|
||||
assert_eq!(
|
||||
content_after.lines().count(),
|
||||
orig_len,
|
||||
"original rollout file should not change on resume"
|
||||
);
|
||||
}
|
||||
|
||||
/// Integration test to verify git info is collected and recorded in session files.
|
||||
@@ -598,7 +618,7 @@ async fn integration_git_info_unit_test() {
|
||||
|
||||
// 5. Test serialization to ensure it works in SessionMeta
|
||||
let serialized = serde_json::to_string(&git_info).unwrap();
|
||||
let deserialized: codex_core::git_info::GitInfo = serde_json::from_str(&serialized).unwrap();
|
||||
let deserialized: GitInfo = serde_json::from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(git_info.commit_hash, deserialized.commit_hash);
|
||||
assert_eq!(git_info.branch, deserialized.branch);
|
||||
|
||||
@@ -4,16 +4,19 @@ use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::project_doc::get_user_instructions;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::shell::default_user_shell;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
@@ -96,7 +99,7 @@ fn write_auth_json(
|
||||
"OPENAI_API_KEY": openai_api_key,
|
||||
"tokens": tokens,
|
||||
// RFC3339 datetime; value doesn't matter for these tests
|
||||
"last_refresh": "2025-08-06T20:41:36.232376Z",
|
||||
"last_refresh": chrono::Utc::now(),
|
||||
});
|
||||
|
||||
std::fs::write(
|
||||
@@ -109,7 +112,204 @@ fn write_auth_json(
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_session_id_and_model_headers_in_request() {
|
||||
async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a fake rollout session file with prior user + system + assistant messages.
|
||||
let tmpdir = TempDir::new().unwrap();
|
||||
let session_path = tmpdir.path().join("resume-session.jsonl");
|
||||
let mut f = std::fs::File::create(&session_path).unwrap();
|
||||
let convo_id = Uuid::new_v4();
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
json!({
|
||||
"timestamp": "2024-01-01T00:00:00.000Z",
|
||||
"type": "session_meta",
|
||||
"payload": {
|
||||
"id": convo_id,
|
||||
"timestamp": "2024-01-01T00:00:00Z",
|
||||
"instructions": "be nice",
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
}
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Prior item: user message (should be delivered)
|
||||
let prior_user = codex_protocol::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![codex_protocol::models::ContentItem::InputText {
|
||||
text: "resumed user message".to_string(),
|
||||
}],
|
||||
};
|
||||
let prior_user_json = serde_json::to_value(&prior_user).unwrap();
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
json!({
|
||||
"timestamp": "2024-01-01T00:00:01.000Z",
|
||||
"type": "response_item",
|
||||
"payload": prior_user_json
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Prior item: system message (excluded from API history)
|
||||
let prior_system = codex_protocol::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "system".to_string(),
|
||||
content: vec![codex_protocol::models::ContentItem::OutputText {
|
||||
text: "resumed system instruction".to_string(),
|
||||
}],
|
||||
};
|
||||
let prior_system_json = serde_json::to_value(&prior_system).unwrap();
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
json!({
|
||||
"timestamp": "2024-01-01T00:00:02.000Z",
|
||||
"type": "response_item",
|
||||
"payload": prior_system_json
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Prior item: assistant message
|
||||
let prior_item = codex_protocol::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![codex_protocol::models::ContentItem::OutputText {
|
||||
text: "resumed assistant message".to_string(),
|
||||
}],
|
||||
};
|
||||
let prior_item_json = serde_json::to_value(&prior_item).unwrap();
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
json!({
|
||||
"timestamp": "2024-01-01T00:00:03.000Z",
|
||||
"type": "response_item",
|
||||
"payload": prior_item_json
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
drop(f);
|
||||
|
||||
// Mock server that will receive the resumed request
|
||||
let server = MockServer::start().await;
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Configure Codex to resume from our file
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
let cwd = TempDir::new().unwrap();
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.experimental_resume = Some(session_path.clone());
|
||||
// Also configure user instructions to ensure they are NOT delivered on resume.
|
||||
config.user_instructions = Some("be nice".to_string());
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let NewConversation {
|
||||
conversation: codex,
|
||||
session_configured,
|
||||
..
|
||||
} = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation");
|
||||
|
||||
// 1) Assert initial_messages only includes existing EventMsg entries; response items are not converted
|
||||
let initial_msgs = session_configured
|
||||
.initial_messages
|
||||
.clone()
|
||||
.expect("expected initial messages option for resumed session");
|
||||
let initial_json = serde_json::to_value(&initial_msgs).unwrap();
|
||||
let expected_initial_json = json!([]);
|
||||
assert_eq!(initial_json, expected_initial_json);
|
||||
|
||||
// 2) Submit new input; the request body must include the prior item followed by the new user input.
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
// Build expected environment context for this turn.
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
let expected_env_text_turn = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
{}</environment_context>"#,
|
||||
cwd.path().to_string_lossy(),
|
||||
shell_line.as_str(),
|
||||
);
|
||||
let expected_env_msg_turn = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text_turn } ]
|
||||
});
|
||||
|
||||
let expected_input = json!([
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{ "type": "input_text", "text": "resumed user message" }]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{ "type": "output_text", "text": "resumed assistant message" }]
|
||||
},
|
||||
expected_env_msg_turn,
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{ "type": "input_text", "text": "hello" }]
|
||||
}
|
||||
]);
|
||||
|
||||
assert_eq!(request_body["input"], expected_input);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_conversation_id_and_model_headers_in_request() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
@@ -166,12 +366,12 @@ async fn includes_session_id_and_model_headers_in_request() {
|
||||
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_session_id = request.headers.get("session_id").unwrap();
|
||||
let request_conversation_id = request.headers.get("conversation_id").unwrap();
|
||||
let request_authorization = request.headers.get("authorization").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
request_session_id.to_str().unwrap(),
|
||||
request_conversation_id.to_str().unwrap(),
|
||||
conversation_id.to_string()
|
||||
);
|
||||
assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
|
||||
@@ -238,56 +438,6 @@ async fn includes_base_instructions_override_in_request() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn originator_config_override_is_used() {
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.responses_originator_header = "my_override".to_owned();
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation")
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
assert_eq!(request_originator.to_str().unwrap(), "my_override");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn chatgpt_auth_sends_correct_request() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
@@ -344,14 +494,14 @@ async fn chatgpt_auth_sends_correct_request() {
|
||||
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_session_id = request.headers.get("session_id").unwrap();
|
||||
let request_conversation_id = request.headers.get("conversation_id").unwrap();
|
||||
let request_authorization = request.headers.get("authorization").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
let request_chatgpt_account_id = request.headers.get("chatgpt-account-id").unwrap();
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
assert_eq!(
|
||||
request_session_id.to_str().unwrap(),
|
||||
request_conversation_id.to_str().unwrap(),
|
||||
conversation_id.to_string()
|
||||
);
|
||||
assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
|
||||
@@ -360,7 +510,6 @@ async fn chatgpt_auth_sends_correct_request() {
|
||||
"Bearer Access Token"
|
||||
);
|
||||
assert_eq!(request_chatgpt_account_id.to_str().unwrap(), "account_id");
|
||||
assert!(!request_body["store"].as_bool().unwrap());
|
||||
assert!(request_body["stream"].as_bool().unwrap());
|
||||
assert_eq!(
|
||||
request_body["include"][0].as_str().unwrap(),
|
||||
@@ -368,90 +517,6 @@ async fn chatgpt_auth_sends_correct_request() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn prefers_chatgpt_token_when_config_prefers_chatgpt() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
// Expect ChatGPT base path and correct headers
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(header_regex("Authorization", r"Bearer Access-123"))
|
||||
.and(header_regex("chatgpt-account-id", r"acc-123"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
// Write auth.json that contains both API key and ChatGPT tokens for a plan that should prefer ChatGPT.
|
||||
let _jwt = write_auth_json(
|
||||
&codex_home,
|
||||
Some("sk-test-key"),
|
||||
"pro",
|
||||
"Access-123",
|
||||
Some("acc-123"),
|
||||
);
|
||||
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.preferred_auth_method = AuthMode::ChatGPT;
|
||||
|
||||
let auth_manager = match CodexAuth::from_codex_home(
|
||||
codex_home.path(),
|
||||
config.preferred_auth_method,
|
||||
&config.responses_originator_header,
|
||||
) {
|
||||
Ok(Some(auth)) => codex_core::AuthManager::from_auth_for_testing(auth),
|
||||
Ok(None) => panic!("No CodexAuth found in codex_home"),
|
||||
Err(e) => panic!("Failed to load CodexAuth: {e}"),
|
||||
};
|
||||
let conversation_manager = ConversationManager::new(auth_manager);
|
||||
let NewConversation {
|
||||
conversation: codex,
|
||||
..
|
||||
} = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation");
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// verify request body flags
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
assert!(
|
||||
!request_body["store"].as_bool().unwrap(),
|
||||
"store should be false for ChatGPT auth"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
@@ -496,13 +561,8 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
|
||||
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.preferred_auth_method = AuthMode::ApiKey;
|
||||
|
||||
let auth_manager = match CodexAuth::from_codex_home(
|
||||
codex_home.path(),
|
||||
config.preferred_auth_method,
|
||||
&config.responses_originator_header,
|
||||
) {
|
||||
let auth_manager = match CodexAuth::from_codex_home(codex_home.path()) {
|
||||
Ok(Some(auth)) => codex_core::AuthManager::from_auth_for_testing(auth),
|
||||
Ok(None) => panic!("No CodexAuth found in codex_home"),
|
||||
Err(e) => panic!("Failed to load CodexAuth: {e}"),
|
||||
@@ -526,14 +586,6 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// verify request body flags
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
assert!(
|
||||
request_body["store"].as_bool().unwrap(),
|
||||
"store should be true for API key auth"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -815,7 +867,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
conversation: codex,
|
||||
..
|
||||
} = conversation_manager
|
||||
.new_conversation(config)
|
||||
.new_conversation(config.clone())
|
||||
.await
|
||||
.expect("create new conversation");
|
||||
|
||||
@@ -850,39 +902,49 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(requests.len(), 3, "expected 3 requests (one per turn)");
|
||||
|
||||
// Replace full-array compare with tail-only raw JSON compare using a single hard-coded value.
|
||||
let r3_tail_expected = serde_json::json!([
|
||||
{
|
||||
"type": "message",
|
||||
"id": null,
|
||||
"role": "user",
|
||||
"content": [{"type":"input_text","text":"U1"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"id": null,
|
||||
"role": "assistant",
|
||||
"content": [{"type":"output_text","text":"Hey there!\n"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"id": null,
|
||||
"role": "user",
|
||||
"content": [{"type":"input_text","text":"U2"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"id": null,
|
||||
"role": "assistant",
|
||||
"content": [{"type":"output_text","text":"Hey there!\n"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"id": null,
|
||||
"role": "user",
|
||||
"content": [{"type":"input_text","text":"U3"}]
|
||||
}
|
||||
]);
|
||||
// Build expected environment context dynamically to avoid OS-dependent flakiness.
|
||||
let user_instructions = get_user_instructions(&config).await;
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
let expected_env_text = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
{}</environment_context>"#,
|
||||
std::env::current_dir().unwrap().to_string_lossy(),
|
||||
shell_line.as_str(),
|
||||
);
|
||||
let expected_env_msg = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text } ]
|
||||
});
|
||||
// Wrap user instructions in the XML container to match the raw/ingest view
|
||||
let expected_ui_text = format!(
|
||||
"<user_instructions>\n\n{}\n\n</user_instructions>",
|
||||
user_instructions.clone().unwrap()
|
||||
);
|
||||
let expected_ui_msg = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_ui_text } ]
|
||||
});
|
||||
|
||||
let expected_full = json!([
|
||||
expected_ui_msg,
|
||||
expected_env_msg.clone(),
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"U1"}]},
|
||||
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hey there!\n"}]},
|
||||
expected_env_msg.clone(),
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"U2"}]},
|
||||
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hey there!\n"}]},
|
||||
expected_env_msg,
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"U3"}]}]);
|
||||
|
||||
let r3_input_array = requests[2]
|
||||
.body_json::<serde_json::Value>()
|
||||
@@ -891,12 +953,6 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
.and_then(|v| v.as_array())
|
||||
.cloned()
|
||||
.expect("r3 missing input array");
|
||||
// skipping earlier context and developer messages
|
||||
let tail_len = r3_tail_expected.as_array().unwrap().len();
|
||||
let actual_tail = &r3_input_array[r3_input_array.len() - tail_len..];
|
||||
assert_eq!(
|
||||
serde_json::Value::Array(actual_tail.to_vec()),
|
||||
r3_tail_expected,
|
||||
"request 3 tail mismatch",
|
||||
);
|
||||
|
||||
assert_eq!(json!(r3_input_array), expected_full);
|
||||
}
|
||||
|
||||
@@ -3,10 +3,13 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
@@ -142,11 +145,12 @@ async fn summarize_context_three_requests_and_instructions() {
|
||||
let mut config = load_default_config_for_test(&home);
|
||||
config.model_provider = model_provider;
|
||||
let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.unwrap()
|
||||
.conversation;
|
||||
let NewConversation {
|
||||
conversation: codex,
|
||||
session_configured,
|
||||
..
|
||||
} = conversation_manager.new_conversation(config).await.unwrap();
|
||||
let rollout_path = session_configured.rollout_path;
|
||||
|
||||
// 1) Normal user input – should hit server once.
|
||||
codex
|
||||
@@ -248,4 +252,47 @@ async fn summarize_context_three_requests_and_instructions() {
|
||||
!messages.iter().any(|(_, t)| t.contains(SUMMARIZE_TRIGGER)),
|
||||
"third request should not include the summarize trigger"
|
||||
);
|
||||
|
||||
// Shut down Codex to flush rollout entries before inspecting the file.
|
||||
codex.submit(Op::Shutdown).await.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;
|
||||
|
||||
// Verify rollout contains APITurn entries for each API call and a Compacted entry.
|
||||
let text = std::fs::read_to_string(&rollout_path).unwrap_or_else(|e| {
|
||||
panic!(
|
||||
"failed to read rollout file {}: {e}",
|
||||
rollout_path.display()
|
||||
)
|
||||
});
|
||||
let mut api_turn_count = 0usize;
|
||||
let mut saw_compacted_summary = false;
|
||||
for line in text.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let Ok(entry): Result<RolloutLine, _> = serde_json::from_str(trimmed) else {
|
||||
continue;
|
||||
};
|
||||
match entry.item {
|
||||
RolloutItem::TurnContext(_) => {
|
||||
api_turn_count += 1;
|
||||
}
|
||||
RolloutItem::Compacted(ci) => {
|
||||
if ci.message == SUMMARY_TEXT {
|
||||
saw_compacted_summary = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
api_turn_count == 3,
|
||||
"expected three APITurn entries in rollout"
|
||||
);
|
||||
assert!(
|
||||
saw_compacted_summary,
|
||||
"expected a Compacted entry containing the summarizer output"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ContentItem;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::ConversationHistoryResponseEvent;
|
||||
use codex_core::protocol::ConversationPathResponseEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
use tempfile::TempDir;
|
||||
@@ -71,84 +75,121 @@ async fn fork_conversation_twice_drops_to_first_message() {
|
||||
let _ = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
// Request history from the base conversation.
|
||||
codex.submit(Op::GetHistory).await.unwrap();
|
||||
// Request history from the base conversation to obtain rollout path.
|
||||
codex.submit(Op::GetPath).await.unwrap();
|
||||
let base_history =
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationHistory(_))).await;
|
||||
|
||||
// Capture entries from the base history and compute expected prefixes after each fork.
|
||||
let entries_after_three = match &base_history {
|
||||
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { entries, .. }) => {
|
||||
entries.clone()
|
||||
}
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationPath(_))).await;
|
||||
let base_path = match &base_history {
|
||||
EventMsg::ConversationPath(ConversationPathResponseEvent { path, .. }) => path.clone(),
|
||||
_ => panic!("expected ConversationHistory event"),
|
||||
};
|
||||
// History layout for this test:
|
||||
// [0] user instructions,
|
||||
// [1] environment context,
|
||||
// [2] "first" user message,
|
||||
// [3] "second" user message,
|
||||
// [4] "third" user message.
|
||||
|
||||
// Fork 1: drops the last user message and everything after.
|
||||
let expected_after_first = vec![
|
||||
entries_after_three[0].clone(),
|
||||
entries_after_three[1].clone(),
|
||||
entries_after_three[2].clone(),
|
||||
entries_after_three[3].clone(),
|
||||
];
|
||||
// GetHistory flushes before returning the path; no wait needed.
|
||||
|
||||
// Fork 2: drops the last user message and everything after.
|
||||
// [0] user instructions,
|
||||
// [1] environment context,
|
||||
// [2] "first" user message,
|
||||
let expected_after_second = vec![
|
||||
entries_after_three[0].clone(),
|
||||
entries_after_three[1].clone(),
|
||||
entries_after_three[2].clone(),
|
||||
];
|
||||
// Helper: read rollout items (excluding SessionMeta) from a JSONL path.
|
||||
let read_items = |p: &std::path::Path| -> Vec<RolloutItem> {
|
||||
let text = std::fs::read_to_string(p).expect("read rollout file");
|
||||
let mut items: Vec<RolloutItem> = Vec::new();
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: serde_json::Value = serde_json::from_str(line).expect("jsonl line");
|
||||
let rl: RolloutLine = serde_json::from_value(v).expect("rollout line");
|
||||
match rl.item {
|
||||
RolloutItem::SessionMeta(_) => {}
|
||||
other => items.push(other),
|
||||
}
|
||||
}
|
||||
items
|
||||
};
|
||||
|
||||
// Fork once with n=1 → drops the last user message and everything after.
|
||||
// Compute expected prefixes after each fork by truncating base rollout at nth-from-last user input.
|
||||
let base_items = read_items(&base_path);
|
||||
let find_user_input_positions = |items: &[RolloutItem]| -> Vec<usize> {
|
||||
let mut pos = Vec::new();
|
||||
for (i, it) in items.iter().enumerate() {
|
||||
if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. }) = it
|
||||
&& role == "user"
|
||||
{
|
||||
// Consider any user message as an input boundary; recorder stores both EventMsg and ResponseItem.
|
||||
// We specifically look for input items, which are represented as ContentItem::InputText.
|
||||
if content
|
||||
.iter()
|
||||
.any(|c| matches!(c, ContentItem::InputText { .. }))
|
||||
{
|
||||
pos.push(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
pos
|
||||
};
|
||||
let user_inputs = find_user_input_positions(&base_items);
|
||||
|
||||
// After dropping last user input (n=1), cut strictly before that input if present, else empty.
|
||||
let cut1 = user_inputs
|
||||
.get(user_inputs.len().saturating_sub(1))
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let expected_after_first: Vec<RolloutItem> = base_items[..cut1].to_vec();
|
||||
|
||||
// After dropping again (n=1 on fork1), compute expected relative to fork1's rollout.
|
||||
|
||||
// Fork once with n=1 → drops the last user input and everything after.
|
||||
let NewConversation {
|
||||
conversation: codex_fork1,
|
||||
..
|
||||
} = conversation_manager
|
||||
.fork_conversation(entries_after_three.clone(), 1, config_for_fork.clone())
|
||||
.fork_conversation(1, config_for_fork.clone(), base_path.clone())
|
||||
.await
|
||||
.expect("fork 1");
|
||||
|
||||
codex_fork1.submit(Op::GetHistory).await.unwrap();
|
||||
codex_fork1.submit(Op::GetPath).await.unwrap();
|
||||
let fork1_history = wait_for_event(&codex_fork1, |ev| {
|
||||
matches!(ev, EventMsg::ConversationHistory(_))
|
||||
matches!(ev, EventMsg::ConversationPath(_))
|
||||
})
|
||||
.await;
|
||||
let entries_after_first_fork = match &fork1_history {
|
||||
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { entries, .. }) => {
|
||||
assert!(matches!(
|
||||
fork1_history,
|
||||
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { ref entries, .. }) if *entries == expected_after_first
|
||||
));
|
||||
entries.clone()
|
||||
}
|
||||
let fork1_path = match &fork1_history {
|
||||
EventMsg::ConversationPath(ConversationPathResponseEvent { path, .. }) => path.clone(),
|
||||
_ => panic!("expected ConversationHistory event after first fork"),
|
||||
};
|
||||
|
||||
// GetHistory on fork1 flushed; the file is ready.
|
||||
let fork1_items = read_items(&fork1_path);
|
||||
pretty_assertions::assert_eq!(
|
||||
serde_json::to_value(&fork1_items).unwrap(),
|
||||
serde_json::to_value(&expected_after_first).unwrap()
|
||||
);
|
||||
|
||||
// Fork again with n=1 → drops the (new) last user message, leaving only the first.
|
||||
let NewConversation {
|
||||
conversation: codex_fork2,
|
||||
..
|
||||
} = conversation_manager
|
||||
.fork_conversation(entries_after_first_fork.clone(), 1, config_for_fork.clone())
|
||||
.fork_conversation(1, config_for_fork.clone(), fork1_path.clone())
|
||||
.await
|
||||
.expect("fork 2");
|
||||
|
||||
codex_fork2.submit(Op::GetHistory).await.unwrap();
|
||||
codex_fork2.submit(Op::GetPath).await.unwrap();
|
||||
let fork2_history = wait_for_event(&codex_fork2, |ev| {
|
||||
matches!(ev, EventMsg::ConversationHistory(_))
|
||||
matches!(ev, EventMsg::ConversationPath(_))
|
||||
})
|
||||
.await;
|
||||
assert!(matches!(
|
||||
fork2_history,
|
||||
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { ref entries, .. }) if *entries == expected_after_second
|
||||
));
|
||||
let fork2_path = match &fork2_history {
|
||||
EventMsg::ConversationPath(ConversationPathResponseEvent { path, .. }) => path.clone(),
|
||||
_ => panic!("expected ConversationHistory event after second fork"),
|
||||
};
|
||||
// GetHistory on fork2 flushed; the file is ready.
|
||||
let fork1_items = read_items(&fork1_path);
|
||||
let fork1_user_inputs = find_user_input_positions(&fork1_items);
|
||||
let cut_last_on_fork1 = fork1_user_inputs
|
||||
.get(fork1_user_inputs.len().saturating_sub(1))
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let expected_after_second: Vec<RolloutItem> = fork1_items[..cut_last_on_fork1].to_vec();
|
||||
let fork2_items = read_items(&fork2_path);
|
||||
pretty_assertions::assert_eq!(
|
||||
serde_json::to_value(&fork2_items).unwrap(),
|
||||
serde_json::to_value(&expected_after_second).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ mod exec;
|
||||
mod exec_stream_events;
|
||||
mod fork_conversation;
|
||||
mod live_cli;
|
||||
mod model_overrides;
|
||||
mod prompt_caching;
|
||||
mod seatbelt;
|
||||
mod stream_error_allows_next_turn;
|
||||
|
||||
92
codex-rs/core/tests/suite/model_overrides.rs
Normal file
92
codex-rs/core/tests/suite/model_overrides.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol_config_types::ReasoningEffort;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
|
||||
const CONFIG_TOML: &str = "config.toml";
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn override_turn_context_does_not_persist_when_config_exists() {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let config_path = codex_home.path().join(CONFIG_TOML);
|
||||
let initial_contents = "model = \"gpt-4o\"\n";
|
||||
tokio::fs::write(&config_path, initial_contents)
|
||||
.await
|
||||
.expect("seed config.toml");
|
||||
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model = "gpt-4o".to_string();
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create conversation")
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::OverrideTurnContext {
|
||||
cwd: None,
|
||||
approval_policy: None,
|
||||
sandbox_policy: None,
|
||||
model: Some("o3".to_string()),
|
||||
effort: Some(ReasoningEffort::High),
|
||||
summary: None,
|
||||
})
|
||||
.await
|
||||
.expect("submit override");
|
||||
|
||||
codex.submit(Op::Shutdown).await.expect("request shutdown");
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;
|
||||
|
||||
let contents = tokio::fs::read_to_string(&config_path)
|
||||
.await
|
||||
.expect("read config.toml after override");
|
||||
assert_eq!(contents, initial_contents);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn override_turn_context_does_not_create_config_file() {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let config_path = codex_home.path().join(CONFIG_TOML);
|
||||
assert!(
|
||||
!config_path.exists(),
|
||||
"test setup should start without config"
|
||||
);
|
||||
|
||||
let config = load_default_config_for_test(&codex_home);
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create conversation")
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::OverrideTurnContext {
|
||||
cwd: None,
|
||||
approval_policy: None,
|
||||
sandbox_policy: None,
|
||||
model: Some("o3".to_string()),
|
||||
effort: Some(ReasoningEffort::Medium),
|
||||
summary: None,
|
||||
})
|
||||
.await
|
||||
.expect("submit override");
|
||||
|
||||
codex.submit(Op::Shutdown).await.expect("request shutdown");
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;
|
||||
|
||||
assert!(
|
||||
!config_path.exists(),
|
||||
"override should not create config.toml"
|
||||
);
|
||||
}
|
||||
@@ -270,8 +270,13 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
assert_eq!(requests.len(), 2, "expected two POST requests");
|
||||
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
let expected_env_text = format!(
|
||||
// Per-turn environment context includes the shell tag.
|
||||
let expected_env_text_turn = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
@@ -279,42 +284,53 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
<network_access>restricted</network_access>
|
||||
{}</environment_context>"#,
|
||||
cwd.path().to_string_lossy(),
|
||||
match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
}
|
||||
shell_line.as_str(),
|
||||
);
|
||||
let expected_ui_text =
|
||||
"<user_instructions>\n\nbe consistent and helpful\n\n</user_instructions>";
|
||||
|
||||
let expected_env_msg = serde_json::json!({
|
||||
let expected_env_msg_turn = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": serde_json::Value::Null,
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text } ]
|
||||
"content": [ { "type": "input_text", "text": expected_env_text_turn } ]
|
||||
});
|
||||
let expected_ui_msg = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": serde_json::Value::Null,
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_ui_text } ]
|
||||
});
|
||||
|
||||
let expected_user_message_1 = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": serde_json::Value::Null,
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": "hello 1" } ]
|
||||
});
|
||||
let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
let body1_input = body1["input"].as_array().unwrap();
|
||||
assert_eq!(
|
||||
body1["input"],
|
||||
serde_json::json!([expected_ui_msg, expected_env_msg, expected_user_message_1])
|
||||
serde_json::json!([
|
||||
expected_ui_msg,
|
||||
expected_env_msg_turn,
|
||||
expected_user_message_1
|
||||
])
|
||||
);
|
||||
|
||||
let env_texts: Vec<&str> = body1_input
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
msg.get("content")
|
||||
.and_then(|content| content.as_array())
|
||||
.and_then(|content| content.first())
|
||||
.and_then(|item| item.get("text"))
|
||||
.and_then(|text| text.as_str())
|
||||
})
|
||||
.filter(|text| text.starts_with("<environment_context>"))
|
||||
.collect();
|
||||
assert_eq!(env_texts, vec![expected_env_text_turn.as_str()]);
|
||||
|
||||
let expected_user_message_2 = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": serde_json::Value::Null,
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": "hello 2" } ]
|
||||
});
|
||||
@@ -322,7 +338,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
let expected_body2 = serde_json::json!(
|
||||
[
|
||||
body1["input"].as_array().unwrap().as_slice(),
|
||||
[expected_user_message_2].as_slice(),
|
||||
[expected_env_msg_turn, expected_user_message_2].as_slice(),
|
||||
]
|
||||
.concat()
|
||||
);
|
||||
@@ -424,21 +440,34 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() {
|
||||
// as the prefix of the second request, ensuring cache hit potential.
|
||||
let expected_user_message_2 = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": serde_json::Value::Null,
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": "hello 2" } ]
|
||||
});
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
// After overriding the turn context, the environment context should be emitted again
|
||||
// reflecting the new approval policy and sandbox settings. Omit cwd because it did
|
||||
// not change.
|
||||
let expected_env_text_2 = r#"<environment_context>
|
||||
let expected_env_text_2 = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>never</approval_policy>
|
||||
<sandbox_mode>workspace-write</sandbox_mode>
|
||||
<network_access>enabled</network_access>
|
||||
</environment_context>"#;
|
||||
<writable_roots>
|
||||
<root>{}</root>
|
||||
</writable_roots>
|
||||
{}</environment_context>"#,
|
||||
cwd.path().to_string_lossy(),
|
||||
writable.path().to_string_lossy(),
|
||||
shell_line.as_str()
|
||||
);
|
||||
let expected_env_msg_2 = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": serde_json::Value::Null,
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text_2 } ]
|
||||
});
|
||||
@@ -543,16 +572,168 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() {
|
||||
// as the prefix of the second request.
|
||||
let expected_user_message_2 = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": serde_json::Value::Null,
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": "hello 2" } ]
|
||||
});
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
let expected_env_text_2 = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>never</approval_policy>
|
||||
<sandbox_mode>workspace-write</sandbox_mode>
|
||||
<network_access>enabled</network_access>
|
||||
<writable_roots>
|
||||
<root>{}</root>
|
||||
</writable_roots>
|
||||
{}</environment_context>"#,
|
||||
new_cwd.path().to_string_lossy(),
|
||||
writable.path().to_string_lossy(),
|
||||
shell_line.as_str()
|
||||
);
|
||||
let expected_env_msg_2 = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text_2 } ]
|
||||
});
|
||||
let expected_body2 = serde_json::json!(
|
||||
[
|
||||
body1["input"].as_array().unwrap().as_slice(),
|
||||
[expected_user_message_2].as_slice(),
|
||||
[expected_env_msg_2, expected_user_message_2].as_slice(),
|
||||
]
|
||||
.concat()
|
||||
);
|
||||
assert_eq!(body2["input"], expected_body2);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn tools_stable_across_all_approval_policy_transitions() {
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let sse = sse_completed("resp");
|
||||
let template = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse, "text/event-stream");
|
||||
|
||||
// Build all transitions FROM each to each other (exclude self transitions)
|
||||
let policies = vec![
|
||||
AskForApproval::UnlessTrusted,
|
||||
AskForApproval::OnFailure,
|
||||
AskForApproval::OnRequest,
|
||||
AskForApproval::Never,
|
||||
];
|
||||
let mut transitions: Vec<(AskForApproval, AskForApproval)> = Vec::new();
|
||||
for &from in &policies {
|
||||
for &to in &policies {
|
||||
if from != to {
|
||||
transitions.push((from, to));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Expect 2 POSTs per transition
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(template)
|
||||
.expect((transitions.len() * 2) as u64)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let cwd = TempDir::new().unwrap();
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.user_instructions = Some("be consistent and helpful".to_string());
|
||||
// Keep tools stable and minimal
|
||||
config.include_plan_tool = false;
|
||||
config.include_apply_patch_tool = false;
|
||||
config.tools_web_search_request = false;
|
||||
config.use_experimental_unified_exec_tool = true; // policy-independent tool
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation")
|
||||
.conversation;
|
||||
|
||||
for (i, (from, to)) in transitions.iter().enumerate() {
|
||||
// Ensure a known starting policy for this pair
|
||||
codex
|
||||
.submit(Op::OverrideTurnContext {
|
||||
cwd: None,
|
||||
approval_policy: Some(*from),
|
||||
sandbox_policy: None,
|
||||
model: None,
|
||||
effort: None,
|
||||
summary: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: format!("turn {i}-a"),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// Override to the target policy and send next turn
|
||||
codex
|
||||
.submit(Op::OverrideTurnContext {
|
||||
cwd: None,
|
||||
approval_policy: Some(*to),
|
||||
sandbox_policy: None,
|
||||
model: None,
|
||||
effort: None,
|
||||
summary: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: format!("turn {i}-b"),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
// Verify tool arrays are identical across each pair of requests
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(
|
||||
requests.len(),
|
||||
transitions.len() * 2,
|
||||
"expected 2 requests per transition"
|
||||
);
|
||||
|
||||
for i in 0..transitions.len() {
|
||||
let body_a = requests[2 * i].body_json::<serde_json::Value>().unwrap();
|
||||
let body_b = requests[2 * i + 1]
|
||||
.body_json::<serde_json::Value>()
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
body_a["tools"], body_b["tools"],
|
||||
"tools changed between requests for transition #{i}: {:?}",
|
||||
transitions[i]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,6 +159,41 @@ async fn read_only_forbids_all_writes() {
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Verify that user lookups via `pwd.getpwuid(os.getuid())` work under the
|
||||
/// seatbelt sandbox. Prior to allowing the necessary mach‑lookup for
|
||||
/// OpenDirectory libinfo, this would fail with `KeyError: getpwuid(): uid not found`.
|
||||
#[tokio::test]
|
||||
async fn python_getpwuid_works_under_seatbelt() {
|
||||
if std::env::var(CODEX_SANDBOX_ENV_VAR) == Ok("seatbelt".to_string()) {
|
||||
eprintln!("{CODEX_SANDBOX_ENV_VAR} is set to 'seatbelt', skipping test.");
|
||||
return;
|
||||
}
|
||||
|
||||
// ReadOnly is sufficient here since we are only exercising user lookup.
|
||||
let policy = SandboxPolicy::ReadOnly;
|
||||
|
||||
let mut child = spawn_command_under_seatbelt(
|
||||
vec![
|
||||
"python3".to_string(),
|
||||
"-c".to_string(),
|
||||
// Print the passwd struct; success implies lookup worked.
|
||||
"import pwd, os; print(pwd.getpwuid(os.getuid()))".to_string(),
|
||||
],
|
||||
&policy,
|
||||
std::env::current_dir().expect("should be able to get current dir"),
|
||||
StdioPolicy::RedirectForShellTool,
|
||||
HashMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("should be able to spawn python under seatbelt");
|
||||
|
||||
let status = child
|
||||
.wait()
|
||||
.await
|
||||
.expect("should be able to wait for child process");
|
||||
assert!(status.success(), "python exited with {status:?}");
|
||||
}
|
||||
|
||||
#[expect(clippy::expect_used)]
|
||||
fn create_test_scenario(tmp: &TempDir) -> TestScenario {
|
||||
let repo_parent = tmp.path().to_path_buf();
|
||||
|
||||
@@ -25,7 +25,6 @@ codex-common = { path = "../common", features = [
|
||||
"sandbox_summary",
|
||||
] }
|
||||
codex-core = { path = "../core" }
|
||||
codex-login = { path = "../login" }
|
||||
codex-ollama = { path = "../ollama" }
|
||||
codex-protocol = { path = "../protocol" }
|
||||
owo-colors = "4.2.0"
|
||||
|
||||
@@ -26,6 +26,7 @@ use codex_core::protocol::TurnAbortReason;
|
||||
use codex_core::protocol::TurnDiffEvent;
|
||||
use codex_core::protocol::WebSearchBeginEvent;
|
||||
use codex_core::protocol::WebSearchEndEvent;
|
||||
use codex_protocol::num_format::format_with_separators;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Style;
|
||||
use shlex::try_join;
|
||||
@@ -189,8 +190,14 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
}
|
||||
return CodexStatus::InitiateShutdown;
|
||||
}
|
||||
EventMsg::TokenCount(token_usage) => {
|
||||
ts_println!(self, "tokens used: {}", token_usage.blended_total());
|
||||
EventMsg::TokenCount(ev) => {
|
||||
if let Some(usage_info) = ev.info {
|
||||
ts_println!(
|
||||
self,
|
||||
"tokens used: {}",
|
||||
format_with_separators(usage_info.total_token_usage.blended_total())
|
||||
);
|
||||
}
|
||||
}
|
||||
EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
|
||||
if !self.answer_started {
|
||||
@@ -273,7 +280,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
parsed_cmd: _,
|
||||
}) => {
|
||||
self.call_id_to_command.insert(
|
||||
call_id.clone(),
|
||||
call_id,
|
||||
ExecCommandBegin {
|
||||
command: command.clone(),
|
||||
},
|
||||
@@ -375,7 +382,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
// Store metadata so we can calculate duration later when we
|
||||
// receive the corresponding PatchApplyEnd event.
|
||||
self.call_id_to_patch.insert(
|
||||
call_id.clone(),
|
||||
call_id,
|
||||
PatchApplyBegin {
|
||||
start_time: Instant::now(),
|
||||
auto_approved,
|
||||
@@ -511,17 +518,20 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
}
|
||||
EventMsg::SessionConfigured(session_configured_event) => {
|
||||
let SessionConfiguredEvent {
|
||||
session_id,
|
||||
session_id: conversation_id,
|
||||
model,
|
||||
reasoning_effort: _,
|
||||
history_log_id: _,
|
||||
history_entry_count: _,
|
||||
initial_messages: _,
|
||||
rollout_path: _,
|
||||
} = session_configured_event;
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {}",
|
||||
"codex session".style(self.magenta).style(self.bold),
|
||||
session_id.to_string().style(self.dimmed)
|
||||
conversation_id.to_string().style(self.dimmed)
|
||||
);
|
||||
|
||||
ts_println!(self, "model: {}", model);
|
||||
@@ -550,7 +560,8 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
}
|
||||
},
|
||||
EventMsg::ShutdownComplete => return CodexStatus::Shutdown,
|
||||
EventMsg::ConversationHistory(_) => {}
|
||||
EventMsg::ConversationPath(_) => {}
|
||||
EventMsg::UserMessage(_) => {}
|
||||
}
|
||||
CodexStatus::Running
|
||||
}
|
||||
|
||||
@@ -149,7 +149,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
include_plan_tool: None,
|
||||
include_apply_patch_tool: None,
|
||||
include_view_image_tool: None,
|
||||
disable_response_storage: oss.then_some(true),
|
||||
show_raw_agent_reasoning: oss.then_some(true),
|
||||
tools_web_search_request: None,
|
||||
};
|
||||
@@ -188,11 +187,8 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let conversation_manager = ConversationManager::new(AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
config.preferred_auth_method,
|
||||
config.responses_originator_header.clone(),
|
||||
));
|
||||
let conversation_manager =
|
||||
ConversationManager::new(AuthManager::shared(config.codex_home.clone()));
|
||||
let NewConversation {
|
||||
conversation_id: _,
|
||||
conversation,
|
||||
|
||||
@@ -61,7 +61,7 @@ pub(crate) async fn run_e2e_exec_test(cwd: &Path, response_streams: Vec<String>)
|
||||
.context("should find binary for codex-exec")
|
||||
.expect("should find binary for codex-exec")
|
||||
.current_dir(cwd.clone())
|
||||
.env("CODEX_HOME", cwd.clone())
|
||||
.env("CODEX_HOME", cwd)
|
||||
.env("OPENAI_API_KEY", "dummy")
|
||||
.env("OPENAI_BASE_URL", format!("{uri}/v1"))
|
||||
.arg("--skip-git-repo-check")
|
||||
|
||||
@@ -88,7 +88,7 @@ impl ExecvChecker {
|
||||
let mut program = valid_exec.program.to_string();
|
||||
for system_path in valid_exec.system_path {
|
||||
if is_executable_file(&system_path) {
|
||||
program = system_path.to_string();
|
||||
program = system_path;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -196,7 +196,7 @@ system_path=[{fake_cp:?}]
|
||||
let checker = setup(&fake_cp);
|
||||
let exec_call = ExecCall {
|
||||
program: "cp".into(),
|
||||
args: vec![source.clone(), dest.clone()],
|
||||
args: vec![source, dest.clone()],
|
||||
};
|
||||
let valid_exec = match checker.r#match(&exec_call)? {
|
||||
MatchedExec::Match { exec } => exec,
|
||||
@@ -207,7 +207,7 @@ system_path=[{fake_cp:?}]
|
||||
assert_eq!(
|
||||
checker.check(valid_exec.clone(), &cwd, &[], &[]),
|
||||
Err(ReadablePathNotInReadableFolders {
|
||||
file: source_path.clone(),
|
||||
file: source_path,
|
||||
folders: vec![]
|
||||
}),
|
||||
);
|
||||
@@ -229,7 +229,7 @@ system_path=[{fake_cp:?}]
|
||||
// Both readable and writeable folders specified.
|
||||
assert_eq!(
|
||||
checker.check(
|
||||
valid_exec.clone(),
|
||||
valid_exec,
|
||||
&cwd,
|
||||
std::slice::from_ref(&root_path),
|
||||
std::slice::from_ref(&root_path)
|
||||
@@ -241,7 +241,7 @@ system_path=[{fake_cp:?}]
|
||||
// folders.
|
||||
let exec_call_folders_as_args = ExecCall {
|
||||
program: "cp".into(),
|
||||
args: vec![root.clone(), root.clone()],
|
||||
args: vec![root.clone(), root],
|
||||
};
|
||||
let valid_exec_call_folders_as_args = match checker.r#match(&exec_call_folders_as_args)? {
|
||||
MatchedExec::Match { exec } => exec,
|
||||
@@ -254,7 +254,7 @@ system_path=[{fake_cp:?}]
|
||||
std::slice::from_ref(&root_path),
|
||||
std::slice::from_ref(&root_path)
|
||||
),
|
||||
Ok(cp.clone()),
|
||||
Ok(cp),
|
||||
);
|
||||
|
||||
// Specify a parent of a readable folder as input.
|
||||
|
||||
@@ -104,7 +104,7 @@ impl PolicyBuilder {
|
||||
info!("adding program spec: {program_spec:?}");
|
||||
let name = program_spec.program.clone();
|
||||
let mut programs = self.programs.borrow_mut();
|
||||
programs.insert(name.clone(), program_spec);
|
||||
programs.insert(name, program_spec);
|
||||
}
|
||||
|
||||
fn add_forbidden_substrings(&self, substrings: &[String]) {
|
||||
|
||||
@@ -30,3 +30,14 @@ fix *args:
|
||||
install:
|
||||
rustup show active-toolchain
|
||||
cargo fetch
|
||||
|
||||
# Run `cargo nextest` since it's faster than `cargo test`, though including
|
||||
# --no-fail-fast is important to ensure all tests are run.
|
||||
#
|
||||
# Run `cargo install cargo-nextest` if you don't have it installed.
|
||||
test:
|
||||
cargo nextest run --no-fail-fast
|
||||
|
||||
# Run the MCP server
|
||||
mcp-server-run *args:
|
||||
cargo run -p codex-mcp-server -- "$@"
|
||||
|
||||
@@ -15,9 +15,7 @@ path = "src/lib.rs"
|
||||
workspace = true
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
codex-core = { path = "../core" }
|
||||
landlock = "0.4.1"
|
||||
libc = "0.2.175"
|
||||
|
||||
@@ -17,7 +17,6 @@ serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
sha2 = "0.10"
|
||||
tempfile = "3"
|
||||
thiserror = "2.0.16"
|
||||
tiny_http = "0.12"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
@@ -31,5 +30,4 @@ urlencoding = "2.1"
|
||||
webbrowser = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
use std::io::Cursor;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::io::{self};
|
||||
use std::net::SocketAddr;
|
||||
use std::net::TcpStream;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::pkce::PkceCodes;
|
||||
use crate::pkce::generate_pkce;
|
||||
@@ -11,6 +16,7 @@ use base64::Engine;
|
||||
use chrono::Utc;
|
||||
use codex_core::auth::AuthDotJson;
|
||||
use codex_core::auth::get_auth_file;
|
||||
use codex_core::default_client::ORIGINATOR;
|
||||
use codex_core::token_data::TokenData;
|
||||
use codex_core::token_data::parse_id_token;
|
||||
use rand::RngCore;
|
||||
@@ -36,7 +42,7 @@ impl ServerOptions {
|
||||
pub fn new(codex_home: PathBuf, client_id: String) -> Self {
|
||||
Self {
|
||||
codex_home,
|
||||
client_id: client_id.to_string(),
|
||||
client_id,
|
||||
issuer: DEFAULT_ISSUER.to_string(),
|
||||
port: DEFAULT_PORT,
|
||||
open_browser: true,
|
||||
@@ -83,7 +89,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
let pkce = generate_pkce();
|
||||
let state = opts.force_state.clone().unwrap_or_else(generate_state);
|
||||
|
||||
let server = Server::http(format!("127.0.0.1:{}", opts.port)).map_err(io::Error::other)?;
|
||||
let server = bind_server(opts.port)?;
|
||||
let actual_port = match server.server_addr().to_ip() {
|
||||
Some(addr) => addr.port(),
|
||||
None => {
|
||||
@@ -120,7 +126,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
|
||||
let server_handle = {
|
||||
let shutdown_notify = shutdown_notify.clone();
|
||||
let server = server.clone();
|
||||
let server = server;
|
||||
tokio::spawn(async move {
|
||||
let result = loop {
|
||||
tokio::select! {
|
||||
@@ -136,19 +142,24 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
let response =
|
||||
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
|
||||
|
||||
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
|
||||
match response {
|
||||
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
|
||||
let exit_result = match response {
|
||||
HandledRequest::Response(response) => {
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(response)).await;
|
||||
None
|
||||
}
|
||||
HandledRequest::ResponseAndExit { response, result } => {
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(response)).await;
|
||||
Some(result)
|
||||
}
|
||||
HandledRequest::RedirectWithHeader(header) => {
|
||||
let redirect = Response::empty(302).with_header(header);
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if is_login_complete {
|
||||
break Ok(());
|
||||
if let Some(result) = exit_result {
|
||||
break result;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -172,7 +183,10 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
enum HandledRequest {
|
||||
Response(Response<Cursor<Vec<u8>>>),
|
||||
RedirectWithHeader(Header),
|
||||
ResponseAndExit(Response<Cursor<Vec<u8>>>),
|
||||
ResponseAndExit {
|
||||
response: Response<Cursor<Vec<u8>>>,
|
||||
result: io::Result<()>,
|
||||
},
|
||||
}
|
||||
|
||||
async fn process_request(
|
||||
@@ -267,8 +281,18 @@ async fn process_request(
|
||||
) {
|
||||
resp.add_header(h);
|
||||
}
|
||||
HandledRequest::ResponseAndExit(resp)
|
||||
HandledRequest::ResponseAndExit {
|
||||
response: resp,
|
||||
result: Ok(()),
|
||||
}
|
||||
}
|
||||
"/cancel" => HandledRequest::ResponseAndExit {
|
||||
response: Response::from_string("Login cancelled"),
|
||||
result: Err(io::Error::new(
|
||||
io::ErrorKind::Interrupted,
|
||||
"Login cancelled",
|
||||
)),
|
||||
},
|
||||
_ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)),
|
||||
}
|
||||
}
|
||||
@@ -290,6 +314,7 @@ fn build_authorize_url(
|
||||
("id_token_add_organizations", "true"),
|
||||
("codex_cli_simplified_flow", "true"),
|
||||
("state", state),
|
||||
("originator", ORIGINATOR.value.as_str()),
|
||||
];
|
||||
let qs = query
|
||||
.into_iter()
|
||||
@@ -305,6 +330,68 @@ fn generate_state() -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
fn send_cancel_request(port: u16) -> io::Result<()> {
|
||||
let addr: SocketAddr = format!("127.0.0.1:{port}")
|
||||
.parse()
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
|
||||
let mut stream = TcpStream::connect_timeout(&addr, Duration::from_secs(2))?;
|
||||
stream.set_read_timeout(Some(Duration::from_secs(2)))?;
|
||||
stream.set_write_timeout(Some(Duration::from_secs(2)))?;
|
||||
|
||||
stream.write_all(b"GET /cancel HTTP/1.1\r\n")?;
|
||||
stream.write_all(format!("Host: 127.0.0.1:{port}\r\n").as_bytes())?;
|
||||
stream.write_all(b"Connection: close\r\n\r\n")?;
|
||||
|
||||
let mut buf = [0u8; 64];
|
||||
let _ = stream.read(&mut buf);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn bind_server(port: u16) -> io::Result<Server> {
|
||||
let bind_address = format!("127.0.0.1:{port}");
|
||||
let mut cancel_attempted = false;
|
||||
let mut attempts = 0;
|
||||
const MAX_ATTEMPTS: u32 = 10;
|
||||
const RETRY_DELAY: Duration = Duration::from_millis(200);
|
||||
|
||||
loop {
|
||||
match Server::http(&bind_address) {
|
||||
Ok(server) => return Ok(server),
|
||||
Err(err) => {
|
||||
attempts += 1;
|
||||
let is_addr_in_use = err
|
||||
.downcast_ref::<io::Error>()
|
||||
.map(|io_err| io_err.kind() == io::ErrorKind::AddrInUse)
|
||||
.unwrap_or(false);
|
||||
|
||||
// If the address is in use, there is probably another instance of the login server
|
||||
// running. Attempt to cancel it and retry.
|
||||
if is_addr_in_use {
|
||||
if !cancel_attempted {
|
||||
cancel_attempted = true;
|
||||
if let Err(cancel_err) = send_cancel_request(port) {
|
||||
eprintln!("Failed to cancel previous login server: {cancel_err}");
|
||||
}
|
||||
}
|
||||
|
||||
thread::sleep(RETRY_DELAY);
|
||||
|
||||
if attempts >= MAX_ATTEMPTS {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
format!("Port {bind_address} is already in use"),
|
||||
));
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(io::Error::other(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ExchangedTokens {
|
||||
id_token: String,
|
||||
access_token: String,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::TcpListener;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use base64::Engine;
|
||||
use codex_login::ServerOptions;
|
||||
@@ -175,3 +177,65 @@ async fn creates_missing_codex_home_dir() {
|
||||
"auth.json should be created even if parent dir was missing"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn cancels_previous_login_server_when_port_is_in_use() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let (issuer_addr, _issuer_handle) = start_mock_issuer();
|
||||
let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port());
|
||||
|
||||
let first_tmp = tempdir().unwrap();
|
||||
let first_codex_home = first_tmp.path().to_path_buf();
|
||||
|
||||
let first_opts = ServerOptions {
|
||||
codex_home: first_codex_home,
|
||||
client_id: codex_login::CLIENT_ID.to_string(),
|
||||
issuer: issuer.clone(),
|
||||
port: 0,
|
||||
open_browser: false,
|
||||
force_state: Some("cancel_state".to_string()),
|
||||
};
|
||||
|
||||
let first_server = run_login_server(first_opts).unwrap();
|
||||
let login_port = first_server.actual_port;
|
||||
let first_server_task = tokio::spawn(async move { first_server.block_until_done().await });
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
let second_tmp = tempdir().unwrap();
|
||||
let second_codex_home = second_tmp.path().to_path_buf();
|
||||
|
||||
let second_opts = ServerOptions {
|
||||
codex_home: second_codex_home,
|
||||
client_id: codex_login::CLIENT_ID.to_string(),
|
||||
issuer,
|
||||
port: login_port,
|
||||
open_browser: false,
|
||||
force_state: Some("cancel_state_2".to_string()),
|
||||
};
|
||||
|
||||
let second_server = run_login_server(second_opts).unwrap();
|
||||
assert_eq!(second_server.actual_port, login_port);
|
||||
|
||||
let cancel_result = first_server_task
|
||||
.await
|
||||
.expect("first login server task panicked")
|
||||
.expect_err("login server should report cancellation");
|
||||
assert_eq!(cancel_result.kind(), io::ErrorKind::Interrupted);
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let cancel_url = format!("http://127.0.0.1:{login_port}/cancel");
|
||||
let resp = client.get(cancel_url).send().await.unwrap();
|
||||
assert!(resp.status().is_success());
|
||||
|
||||
second_server
|
||||
.block_until_done()
|
||||
.await
|
||||
.expect_err("second login server should report cancellation");
|
||||
}
|
||||
|
||||
@@ -64,6 +64,9 @@ async fn main() -> Result<()> {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".to_string()),
|
||||
// This field is used by Codex when it is an MCP server: it should
|
||||
// not be used when Codex is an MCP client.
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
|
||||
@@ -26,7 +26,6 @@ schemars = "0.8.22"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
strum_macros = "0.27.2"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
@@ -41,8 +40,9 @@ uuid = { version = "1", features = ["serde", "v4"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
base64 = "0.22"
|
||||
mcp_test_support = { path = "tests/common" }
|
||||
os_info = "3.12.0"
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
tokio-test = "0.4"
|
||||
wiremock = "0.6"
|
||||
|
||||
@@ -1,39 +1,38 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::CodexConversation;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::ConfigToml;
|
||||
use codex_core::config::load_config_as_toml;
|
||||
use codex_core::git_info::git_diff_to_remote;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use codex_protocol::mcp_protocol::GitDiffToRemoteResponse;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::json_to_toml::json_to_toml;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::OutgoingNotification;
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::CodexConversation;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::Cursor as RolloutCursor;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::RolloutRecorder;
|
||||
use codex_core::SessionMeta;
|
||||
use codex_core::auth::CLIENT_ID;
|
||||
use codex_core::auth::get_auth_file;
|
||||
use codex_core::auth::login_with_api_key;
|
||||
use codex_core::auth::try_read_auth_json;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::ConfigToml;
|
||||
use codex_core::config::load_config_as_toml;
|
||||
use codex_core::config_edit::CONFIG_KEY_EFFORT;
|
||||
use codex_core::config_edit::CONFIG_KEY_MODEL;
|
||||
use codex_core::config_edit::persist_non_null_overrides;
|
||||
use codex_core::default_client::get_codex_user_agent;
|
||||
use codex_core::exec::ExecParams;
|
||||
use codex_core::exec_env::create_env;
|
||||
use codex_core::get_platform_sandbox;
|
||||
use codex_core::git_info::git_diff_to_remote;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use codex_core::protocol::InputItem as CoreInputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_login::ServerOptions as LoginServerOptions;
|
||||
use codex_login::ShutdownHandle;
|
||||
use codex_login::run_login_server;
|
||||
@@ -42,27 +41,61 @@ use codex_protocol::mcp_protocol::AddConversationListenerParams;
|
||||
use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse;
|
||||
use codex_protocol::mcp_protocol::ApplyPatchApprovalParams;
|
||||
use codex_protocol::mcp_protocol::ApplyPatchApprovalResponse;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationParams;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationResponse;
|
||||
use codex_protocol::mcp_protocol::AuthStatusChangeNotification;
|
||||
use codex_protocol::mcp_protocol::ClientRequest;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::mcp_protocol::ConversationSummary;
|
||||
use codex_protocol::mcp_protocol::EXEC_COMMAND_APPROVAL_METHOD;
|
||||
use codex_protocol::mcp_protocol::ExecArbitraryCommandResponse;
|
||||
use codex_protocol::mcp_protocol::ExecCommandApprovalParams;
|
||||
use codex_protocol::mcp_protocol::ExecCommandApprovalResponse;
|
||||
use codex_protocol::mcp_protocol::GetConfigTomlResponse;
|
||||
use codex_protocol::mcp_protocol::ExecOneOffCommandParams;
|
||||
use codex_protocol::mcp_protocol::GetUserAgentResponse;
|
||||
use codex_protocol::mcp_protocol::GetUserSavedConfigResponse;
|
||||
use codex_protocol::mcp_protocol::GitDiffToRemoteResponse;
|
||||
use codex_protocol::mcp_protocol::InputItem as WireInputItem;
|
||||
use codex_protocol::mcp_protocol::InterruptConversationParams;
|
||||
use codex_protocol::mcp_protocol::InterruptConversationResponse;
|
||||
use codex_protocol::mcp_protocol::ListConversationsParams;
|
||||
use codex_protocol::mcp_protocol::ListConversationsResponse;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyParams;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyResponse;
|
||||
use codex_protocol::mcp_protocol::LoginChatGptCompleteNotification;
|
||||
use codex_protocol::mcp_protocol::LoginChatGptResponse;
|
||||
use codex_protocol::mcp_protocol::NewConversationParams;
|
||||
use codex_protocol::mcp_protocol::NewConversationResponse;
|
||||
use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
|
||||
use codex_protocol::mcp_protocol::RemoveConversationSubscriptionResponse;
|
||||
use codex_protocol::mcp_protocol::ResumeConversationParams;
|
||||
use codex_protocol::mcp_protocol::SendUserMessageParams;
|
||||
use codex_protocol::mcp_protocol::SendUserMessageResponse;
|
||||
use codex_protocol::mcp_protocol::SendUserTurnParams;
|
||||
use codex_protocol::mcp_protocol::SendUserTurnResponse;
|
||||
use codex_protocol::mcp_protocol::ServerNotification;
|
||||
use codex_protocol::mcp_protocol::SetDefaultModelParams;
|
||||
use codex_protocol::mcp_protocol::SetDefaultModelResponse;
|
||||
use codex_protocol::mcp_protocol::UserInfoResponse;
|
||||
use codex_protocol::mcp_protocol::UserSavedConfig;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::InputMessageKind;
|
||||
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::RequestId;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsStr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::select;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
// Duration before a ChatGPT login attempt is abandoned.
|
||||
const LOGIN_CHATGPT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
|
||||
@@ -88,7 +121,7 @@ pub(crate) struct CodexMessageProcessor {
|
||||
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
|
||||
active_login: Arc<Mutex<Option<ActiveLogin>>>,
|
||||
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
|
||||
pending_interrupts: Arc<Mutex<HashMap<Uuid, Vec<RequestId>>>>,
|
||||
pending_interrupts: Arc<Mutex<HashMap<ConversationId, Vec<RequestId>>>>,
|
||||
}
|
||||
|
||||
impl CodexMessageProcessor {
|
||||
@@ -119,6 +152,15 @@ impl CodexMessageProcessor {
|
||||
// created before processing any subsequent messages.
|
||||
self.process_new_conversation(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ListConversations { request_id, params } => {
|
||||
self.handle_list_conversations(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ResumeConversation { request_id, params } => {
|
||||
self.handle_resume_conversation(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ArchiveConversation { request_id, params } => {
|
||||
self.archive_conversation(request_id, params).await;
|
||||
}
|
||||
ClientRequest::SendUserMessage { request_id, params } => {
|
||||
self.send_user_message(request_id, params).await;
|
||||
}
|
||||
@@ -137,6 +179,9 @@ impl CodexMessageProcessor {
|
||||
ClientRequest::GitDiffToRemote { request_id, params } => {
|
||||
self.git_diff_to_origin(request_id, params.cwd).await;
|
||||
}
|
||||
ClientRequest::LoginApiKey { request_id, params } => {
|
||||
self.login_api_key(request_id, params).await;
|
||||
}
|
||||
ClientRequest::LoginChatGpt { request_id } => {
|
||||
self.login_chatgpt(request_id).await;
|
||||
}
|
||||
@@ -149,8 +194,53 @@ impl CodexMessageProcessor {
|
||||
ClientRequest::GetAuthStatus { request_id, params } => {
|
||||
self.get_auth_status(request_id, params).await;
|
||||
}
|
||||
ClientRequest::GetConfigToml { request_id } => {
|
||||
self.get_config_toml(request_id).await;
|
||||
ClientRequest::GetUserSavedConfig { request_id } => {
|
||||
self.get_user_saved_config(request_id).await;
|
||||
}
|
||||
ClientRequest::SetDefaultModel { request_id, params } => {
|
||||
self.set_default_model(request_id, params).await;
|
||||
}
|
||||
ClientRequest::GetUserAgent { request_id } => {
|
||||
self.get_user_agent(request_id).await;
|
||||
}
|
||||
ClientRequest::UserInfo { request_id } => {
|
||||
self.get_user_info(request_id).await;
|
||||
}
|
||||
ClientRequest::ExecOneOffCommand { request_id, params } => {
|
||||
self.exec_one_off_command(request_id, params).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn login_api_key(&mut self, request_id: RequestId, params: LoginApiKeyParams) {
|
||||
{
|
||||
let mut guard = self.active_login.lock().await;
|
||||
if let Some(active) = guard.take() {
|
||||
active.drop();
|
||||
}
|
||||
}
|
||||
|
||||
match login_with_api_key(&self.config.codex_home, ¶ms.api_key) {
|
||||
Ok(()) => {
|
||||
self.auth_manager.reload();
|
||||
self.outgoing
|
||||
.send_response(request_id, LoginApiKeyResponse {})
|
||||
.await;
|
||||
|
||||
let payload = AuthStatusChangeNotification {
|
||||
auth_method: self.auth_manager.auth().map(|auth| auth.mode),
|
||||
};
|
||||
self.outgoing
|
||||
.send_server_notification(ServerNotification::AuthStatusChange(payload))
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to save api key: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -308,7 +398,7 @@ impl CodexMessageProcessor {
|
||||
.await;
|
||||
|
||||
// Send auth status change notification reflecting the current auth mode
|
||||
// after logout (which may fall back to API key via env var).
|
||||
// after logout.
|
||||
let current_auth_method = self.auth_manager.auth().map(|auth| auth.mode);
|
||||
let payload = AuthStatusChangeNotification {
|
||||
auth_method: current_auth_method,
|
||||
@@ -323,7 +413,6 @@ impl CodexMessageProcessor {
|
||||
request_id: RequestId,
|
||||
params: codex_protocol::mcp_protocol::GetAuthStatusParams,
|
||||
) {
|
||||
let preferred_auth_method: AuthMode = self.auth_manager.preferred_auth_method();
|
||||
let include_token = params.include_token.unwrap_or(false);
|
||||
let do_refresh = params.refresh_token.unwrap_or(false);
|
||||
|
||||
@@ -331,6 +420,11 @@ impl CodexMessageProcessor {
|
||||
tracing::warn!("failed to refresh token while getting auth status: {err}");
|
||||
}
|
||||
|
||||
// Determine whether auth is required based on the active model provider.
|
||||
// If a custom provider is configured with `requires_openai_auth == false`,
|
||||
// then no auth step is required; otherwise, default to requiring auth.
|
||||
let requires_openai_auth = Some(self.config.model_provider.requires_openai_auth);
|
||||
|
||||
let response = match self.auth_manager.auth() {
|
||||
Some(auth) => {
|
||||
let (reported_auth_method, token_opt) = match auth.get_token().await {
|
||||
@@ -346,21 +440,27 @@ impl CodexMessageProcessor {
|
||||
};
|
||||
codex_protocol::mcp_protocol::GetAuthStatusResponse {
|
||||
auth_method: reported_auth_method,
|
||||
preferred_auth_method,
|
||||
auth_token: token_opt,
|
||||
requires_openai_auth,
|
||||
}
|
||||
}
|
||||
None => codex_protocol::mcp_protocol::GetAuthStatusResponse {
|
||||
auth_method: None,
|
||||
preferred_auth_method,
|
||||
auth_token: None,
|
||||
requires_openai_auth,
|
||||
},
|
||||
};
|
||||
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn get_config_toml(&self, request_id: RequestId) {
|
||||
async fn get_user_agent(&self, request_id: RequestId) {
|
||||
let user_agent = get_codex_user_agent();
|
||||
let response = GetUserAgentResponse { user_agent };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn get_user_saved_config(&self, request_id: RequestId) {
|
||||
let toml_value = match load_config_as_toml(&self.config.codex_home) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
@@ -387,35 +487,130 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
};
|
||||
|
||||
let profiles: HashMap<String, codex_protocol::config_types::ConfigProfile> = cfg
|
||||
.profiles
|
||||
.into_iter()
|
||||
.map(|(k, v)| {
|
||||
(
|
||||
k,
|
||||
// Define this explicitly here to avoid the need to
|
||||
// implement `From<codex_core::config_profile::ConfigProfile>`
|
||||
// for the `ConfigProfile` type and introduce a dependency on codex_core
|
||||
codex_protocol::config_types::ConfigProfile {
|
||||
model: v.model,
|
||||
approval_policy: v.approval_policy,
|
||||
model_reasoning_effort: v.model_reasoning_effort,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let user_saved_config: UserSavedConfig = cfg.into();
|
||||
|
||||
let response = GetConfigTomlResponse {
|
||||
approval_policy: cfg.approval_policy,
|
||||
sandbox_mode: cfg.sandbox_mode,
|
||||
model_reasoning_effort: cfg.model_reasoning_effort,
|
||||
profile: cfg.profile,
|
||||
profiles: Some(profiles),
|
||||
let response = GetUserSavedConfigResponse {
|
||||
config: user_saved_config,
|
||||
};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn get_user_info(&self, request_id: RequestId) {
|
||||
// Read alleged user email from auth.json (best-effort; not verified).
|
||||
let auth_path = get_auth_file(&self.config.codex_home);
|
||||
let alleged_user_email = match try_read_auth_json(&auth_path) {
|
||||
Ok(auth) => auth.tokens.and_then(|t| t.id_token.email),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
let response = UserInfoResponse { alleged_user_email };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn set_default_model(&self, request_id: RequestId, params: SetDefaultModelParams) {
|
||||
let SetDefaultModelParams {
|
||||
model,
|
||||
reasoning_effort,
|
||||
} = params;
|
||||
let effort_str = reasoning_effort.map(|effort| effort.to_string());
|
||||
|
||||
let overrides: [(&[&str], Option<&str>); 2] = [
|
||||
(&[CONFIG_KEY_MODEL], model.as_deref()),
|
||||
(&[CONFIG_KEY_EFFORT], effort_str.as_deref()),
|
||||
];
|
||||
|
||||
match persist_non_null_overrides(
|
||||
&self.config.codex_home,
|
||||
self.config.active_profile.as_deref(),
|
||||
&overrides,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
let response = SetDefaultModelResponse {};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to persist overrides: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn exec_one_off_command(&self, request_id: RequestId, params: ExecOneOffCommandParams) {
|
||||
tracing::debug!("ExecOneOffCommand params: {params:?}");
|
||||
|
||||
if params.command.is_empty() {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: "command must not be empty".to_string(),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let cwd = params.cwd.unwrap_or_else(|| self.config.cwd.clone());
|
||||
let env = create_env(&self.config.shell_environment_policy);
|
||||
let timeout_ms = params.timeout_ms;
|
||||
let exec_params = ExecParams {
|
||||
command: params.command,
|
||||
cwd,
|
||||
timeout_ms,
|
||||
env,
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
|
||||
let effective_policy = params
|
||||
.sandbox_policy
|
||||
.unwrap_or_else(|| self.config.sandbox_policy.clone());
|
||||
|
||||
let sandbox_type = match &effective_policy {
|
||||
codex_core::protocol::SandboxPolicy::DangerFullAccess => {
|
||||
codex_core::exec::SandboxType::None
|
||||
}
|
||||
_ => get_platform_sandbox().unwrap_or(codex_core::exec::SandboxType::None),
|
||||
};
|
||||
tracing::debug!("Sandbox type: {sandbox_type:?}");
|
||||
let codex_linux_sandbox_exe = self.config.codex_linux_sandbox_exe.clone();
|
||||
let outgoing = self.outgoing.clone();
|
||||
let req_id = request_id;
|
||||
|
||||
tokio::spawn(async move {
|
||||
match codex_core::exec::process_exec_tool_call(
|
||||
exec_params,
|
||||
sandbox_type,
|
||||
&effective_policy,
|
||||
&codex_linux_sandbox_exe,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
let response = ExecArbitraryCommandResponse {
|
||||
exit_code: output.exit_code,
|
||||
stdout: output.stdout.text,
|
||||
stderr: output.stderr.text,
|
||||
};
|
||||
outgoing.send_response(req_id, response).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("exec failed: {err}"),
|
||||
data: None,
|
||||
};
|
||||
outgoing.send_error(req_id, error).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn process_new_conversation(&self, request_id: RequestId, params: NewConversationParams) {
|
||||
let config = match derive_config_from_params(params, self.codex_linux_sandbox_exe.clone()) {
|
||||
Ok(config) => config,
|
||||
@@ -438,8 +633,10 @@ impl CodexMessageProcessor {
|
||||
..
|
||||
} = conversation_id;
|
||||
let response = NewConversationResponse {
|
||||
conversation_id: ConversationId(conversation_id),
|
||||
conversation_id,
|
||||
model: session_configured.model,
|
||||
reasoning_effort: session_configured.reasoning_effort,
|
||||
rollout_path: session_configured.rollout_path,
|
||||
};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
@@ -454,6 +651,268 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_list_conversations(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
params: ListConversationsParams,
|
||||
) {
|
||||
let page_size = params.page_size.unwrap_or(25);
|
||||
// Decode the optional cursor string to a Cursor via serde (Cursor implements Deserialize from string)
|
||||
let cursor_obj: Option<RolloutCursor> = match params.cursor {
|
||||
Some(s) => serde_json::from_str::<RolloutCursor>(&format!("\"{s}\"")).ok(),
|
||||
None => None,
|
||||
};
|
||||
let cursor_ref = cursor_obj.as_ref();
|
||||
|
||||
let page = match RolloutRecorder::list_conversations(
|
||||
&self.config.codex_home,
|
||||
page_size,
|
||||
cursor_ref,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to list conversations: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let items = page
|
||||
.items
|
||||
.into_iter()
|
||||
.filter_map(|it| extract_conversation_summary(it.path, &it.head))
|
||||
.collect();
|
||||
|
||||
// Encode next_cursor as a plain string
|
||||
let next_cursor = match page.next_cursor {
|
||||
Some(c) => match serde_json::to_value(&c) {
|
||||
Ok(serde_json::Value::String(s)) => Some(s),
|
||||
_ => None,
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
let response = ListConversationsResponse { items, next_cursor };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn handle_resume_conversation(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
params: ResumeConversationParams,
|
||||
) {
|
||||
// Derive a Config using the same logic as new conversation, honoring overrides if provided.
|
||||
let config = match params.overrides {
|
||||
Some(overrides) => {
|
||||
derive_config_from_params(overrides, self.codex_linux_sandbox_exe.clone())
|
||||
}
|
||||
None => Ok(self.config.as_ref().clone()),
|
||||
};
|
||||
let config = match config {
|
||||
Ok(cfg) => cfg,
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!("error deriving config: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match self
|
||||
.conversation_manager
|
||||
.resume_conversation_from_rollout(
|
||||
config,
|
||||
params.path.clone(),
|
||||
self.auth_manager.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(NewConversation {
|
||||
conversation_id,
|
||||
session_configured,
|
||||
..
|
||||
}) => {
|
||||
let event = Event {
|
||||
id: "".to_string(),
|
||||
msg: EventMsg::SessionConfigured(session_configured.clone()),
|
||||
};
|
||||
self.outgoing.send_event_as_notification(&event, None).await;
|
||||
let initial_messages = session_configured.initial_messages.map(|msgs| {
|
||||
msgs.into_iter()
|
||||
.filter(|event| {
|
||||
// Don't send non-plain user messages (like user instructions
|
||||
// or environment context) back so they don't get rendered.
|
||||
if let EventMsg::UserMessage(user_message) = event {
|
||||
return matches!(user_message.kind, Some(InputMessageKind::Plain));
|
||||
}
|
||||
true
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
// Reply with conversation id + model and initial messages (when present)
|
||||
let response = codex_protocol::mcp_protocol::ResumeConversationResponse {
|
||||
conversation_id,
|
||||
model: session_configured.model.clone(),
|
||||
initial_messages,
|
||||
};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("error resuming conversation: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn archive_conversation(&self, request_id: RequestId, params: ArchiveConversationParams) {
|
||||
let ArchiveConversationParams {
|
||||
conversation_id,
|
||||
rollout_path,
|
||||
} = params;
|
||||
|
||||
// Verify that the rollout path is in the sessions directory or else
|
||||
// a malicious client could specify an arbitrary path.
|
||||
let rollout_folder = self.config.codex_home.join(codex_core::SESSIONS_SUBDIR);
|
||||
let canonical_rollout_path = tokio::fs::canonicalize(&rollout_path).await;
|
||||
let canonical_rollout_path = if let Ok(path) = canonical_rollout_path
|
||||
&& path.starts_with(&rollout_folder)
|
||||
{
|
||||
path
|
||||
} else {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!(
|
||||
"rollout path `{}` must be in sessions directory",
|
||||
rollout_path.display()
|
||||
),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let required_suffix = format!("{}.jsonl", conversation_id.0);
|
||||
let Some(file_name) = canonical_rollout_path.file_name().map(OsStr::to_owned) else {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!(
|
||||
"rollout path `{}` missing file name",
|
||||
rollout_path.display()
|
||||
),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
};
|
||||
|
||||
if !file_name
|
||||
.to_string_lossy()
|
||||
.ends_with(required_suffix.as_str())
|
||||
{
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!(
|
||||
"rollout path `{}` does not match conversation id {conversation_id}",
|
||||
rollout_path.display()
|
||||
),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let removed_conversation = self
|
||||
.conversation_manager
|
||||
.remove_conversation(&conversation_id)
|
||||
.await;
|
||||
if let Some(conversation) = removed_conversation {
|
||||
info!("conversation {conversation_id} was active; shutting down");
|
||||
let conversation_clone = conversation.clone();
|
||||
let notify = Arc::new(tokio::sync::Notify::new());
|
||||
let notify_clone = notify.clone();
|
||||
|
||||
// Establish the listener for ShutdownComplete before submitting
|
||||
// Shutdown so it is not missed.
|
||||
let is_shutdown = tokio::spawn(async move {
|
||||
loop {
|
||||
select! {
|
||||
_ = notify_clone.notified() => {
|
||||
break;
|
||||
}
|
||||
event = conversation_clone.next_event() => {
|
||||
if let Ok(event) = event && matches!(event.msg, EventMsg::ShutdownComplete) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Request shutdown.
|
||||
match conversation.submit(Op::Shutdown).await {
|
||||
Ok(_) => {
|
||||
// Successfully submitted Shutdown; wait before proceeding.
|
||||
select! {
|
||||
_ = is_shutdown => {
|
||||
// Normal shutdown: proceed with archive.
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs(10)) => {
|
||||
warn!("conversation {conversation_id} shutdown timed out; proceeding with archive");
|
||||
notify.notify_one();
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to submit Shutdown to conversation {conversation_id}: {err}");
|
||||
notify.notify_one();
|
||||
// Perhaps we lost a shutdown race, so let's continue to
|
||||
// clean up the .jsonl file.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Move the .jsonl file to the archived sessions subdir.
|
||||
let result: std::io::Result<()> = async {
|
||||
let archive_folder = self
|
||||
.config
|
||||
.codex_home
|
||||
.join(codex_core::ARCHIVED_SESSIONS_SUBDIR);
|
||||
tokio::fs::create_dir_all(&archive_folder).await?;
|
||||
tokio::fs::rename(&canonical_rollout_path, &archive_folder.join(&file_name)).await?;
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => {
|
||||
let response = ArchiveConversationResponse {};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to archive conversation: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_user_message(&self, request_id: RequestId, params: SendUserMessageParams) {
|
||||
let SendUserMessageParams {
|
||||
conversation_id,
|
||||
@@ -461,7 +920,7 @@ impl CodexMessageProcessor {
|
||||
} = params;
|
||||
let Ok(conversation) = self
|
||||
.conversation_manager
|
||||
.get_conversation(conversation_id.0)
|
||||
.get_conversation(conversation_id)
|
||||
.await
|
||||
else {
|
||||
let error = JSONRPCErrorError {
|
||||
@@ -509,7 +968,7 @@ impl CodexMessageProcessor {
|
||||
|
||||
let Ok(conversation) = self
|
||||
.conversation_manager
|
||||
.get_conversation(conversation_id.0)
|
||||
.get_conversation(conversation_id)
|
||||
.await
|
||||
else {
|
||||
let error = JSONRPCErrorError {
|
||||
@@ -555,7 +1014,7 @@ impl CodexMessageProcessor {
|
||||
let InterruptConversationParams { conversation_id } = params;
|
||||
let Ok(conversation) = self
|
||||
.conversation_manager
|
||||
.get_conversation(conversation_id.0)
|
||||
.get_conversation(conversation_id)
|
||||
.await
|
||||
else {
|
||||
let error = JSONRPCErrorError {
|
||||
@@ -570,7 +1029,7 @@ impl CodexMessageProcessor {
|
||||
// Record the pending interrupt so we can reply when TurnAborted arrives.
|
||||
{
|
||||
let mut map = self.pending_interrupts.lock().await;
|
||||
map.entry(conversation_id.0).or_default().push(request_id);
|
||||
map.entry(conversation_id).or_default().push(request_id);
|
||||
}
|
||||
|
||||
// Submit the interrupt; we'll respond upon TurnAborted.
|
||||
@@ -585,12 +1044,12 @@ impl CodexMessageProcessor {
|
||||
let AddConversationListenerParams { conversation_id } = params;
|
||||
let Ok(conversation) = self
|
||||
.conversation_manager
|
||||
.get_conversation(conversation_id.0)
|
||||
.get_conversation(conversation_id)
|
||||
.await
|
||||
else {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!("conversation not found: {}", conversation_id.0),
|
||||
message: format!("conversation not found: {conversation_id}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
@@ -620,18 +1079,18 @@ impl CodexMessageProcessor {
|
||||
};
|
||||
|
||||
// For now, we send a notification for every event,
|
||||
// JSON-serializing the `Event` as-is, but we will move
|
||||
// to creating a special enum for notifications with a
|
||||
// stable wire format.
|
||||
// JSON-serializing the `Event` as-is, but these should
|
||||
// be migrated to be variants of `ServerNotification`
|
||||
// instead.
|
||||
let method = format!("codex/event/{}", event.msg);
|
||||
let mut params = match serde_json::to_value(event.clone()) {
|
||||
Ok(serde_json::Value::Object(map)) => map,
|
||||
Ok(_) => {
|
||||
tracing::error!("event did not serialize to an object");
|
||||
error!("event did not serialize to an object");
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("failed to serialize event: {err}");
|
||||
error!("failed to serialize event: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -703,7 +1162,7 @@ async fn apply_bespoke_event_handling(
|
||||
conversation_id: ConversationId,
|
||||
conversation: Arc<CodexConversation>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
pending_interrupts: Arc<Mutex<HashMap<Uuid, Vec<RequestId>>>>,
|
||||
pending_interrupts: Arc<Mutex<HashMap<ConversationId, Vec<RequestId>>>>,
|
||||
) {
|
||||
let Event { id: event_id, msg } = event;
|
||||
match msg {
|
||||
@@ -756,7 +1215,7 @@ async fn apply_bespoke_event_handling(
|
||||
EventMsg::TurnAborted(turn_aborted_event) => {
|
||||
let pending = {
|
||||
let mut map = pending_interrupts.lock().await;
|
||||
map.remove(&conversation_id.0).unwrap_or_default()
|
||||
map.remove(&conversation_id).unwrap_or_default()
|
||||
};
|
||||
if !pending.is_empty() {
|
||||
let response = InterruptConversationResponse {
|
||||
@@ -799,7 +1258,6 @@ fn derive_config_from_params(
|
||||
include_plan_tool,
|
||||
include_apply_patch_tool,
|
||||
include_view_image_tool: None,
|
||||
disable_response_storage: None,
|
||||
show_raw_agent_reasoning: None,
|
||||
tools_web_search_request: None,
|
||||
};
|
||||
@@ -815,7 +1273,7 @@ fn derive_config_from_params(
|
||||
|
||||
async fn on_patch_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
receiver: oneshot::Receiver<mcp_types::Result>,
|
||||
codex: Arc<CodexConversation>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
@@ -857,14 +1315,14 @@ async fn on_patch_approval_response(
|
||||
|
||||
async fn on_exec_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
receiver: oneshot::Receiver<mcp_types::Result>,
|
||||
conversation: Arc<CodexConversation>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let value = match response {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
tracing::error!("request failed: {err:?}");
|
||||
error!("request failed: {err:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
@@ -890,3 +1348,101 @@ async fn on_exec_approval_response(
|
||||
error!("failed to submit ExecApproval: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_conversation_summary(
|
||||
path: PathBuf,
|
||||
head: &[serde_json::Value],
|
||||
) -> Option<ConversationSummary> {
|
||||
let session_meta = match head.first() {
|
||||
Some(first_line) => serde_json::from_value::<SessionMeta>(first_line.clone()).ok()?,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
let preview = head
|
||||
.iter()
|
||||
.filter_map(|value| serde_json::from_value::<ResponseItem>(value.clone()).ok())
|
||||
.find_map(|item| match item {
|
||||
ResponseItem::Message { content, .. } => {
|
||||
content.into_iter().find_map(|content| match content {
|
||||
ContentItem::InputText { text } => {
|
||||
match InputMessageKind::from(("user", &text)) {
|
||||
InputMessageKind::Plain => Some(text),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
})?;
|
||||
|
||||
let preview = match preview.find(USER_MESSAGE_BEGIN) {
|
||||
Some(idx) => preview[idx + USER_MESSAGE_BEGIN.len()..].trim(),
|
||||
None => preview.as_str(),
|
||||
};
|
||||
|
||||
let timestamp = if session_meta.timestamp.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(session_meta.timestamp.clone())
|
||||
};
|
||||
|
||||
Some(ConversationSummary {
|
||||
conversation_id: session_meta.id,
|
||||
timestamp,
|
||||
path,
|
||||
preview: preview.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn extract_conversation_summary_prefers_plain_user_messages() {
|
||||
let conversation_id =
|
||||
ConversationId(Uuid::parse_str("3f941c35-29b3-493b-b0a4-e25800d9aeb0").unwrap());
|
||||
let timestamp = Some("2025-09-05T16:53:11.850Z".to_string());
|
||||
let path = PathBuf::from("rollout.jsonl");
|
||||
|
||||
let head = vec![
|
||||
json!({
|
||||
"id": conversation_id.0,
|
||||
"timestamp": timestamp,
|
||||
"cwd": "/",
|
||||
"originator": "codex",
|
||||
"cli_version": "0.0.0",
|
||||
"instructions": null
|
||||
}),
|
||||
json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "input_text",
|
||||
"text": "<user_instructions>\n<AGENTS.md contents>\n</user_instructions>".to_string(),
|
||||
}],
|
||||
}),
|
||||
json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "input_text",
|
||||
"text": format!("<prior context> {USER_MESSAGE_BEGIN}Count to 5"),
|
||||
}],
|
||||
}),
|
||||
];
|
||||
|
||||
let summary = extract_conversation_summary(path.clone(), &head).expect("summary");
|
||||
|
||||
assert_eq!(summary.conversation_id, conversation_id);
|
||||
assert_eq!(
|
||||
summary.timestamp,
|
||||
Some("2025-09-05T16:53:11.850Z".to_string())
|
||||
);
|
||||
assert_eq!(summary.path, path);
|
||||
assert_eq!(summary.preview, "Count to 5");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,7 +162,6 @@ impl CodexToolCallParam {
|
||||
include_plan_tool,
|
||||
include_apply_patch_tool: None,
|
||||
include_view_image_tool: None,
|
||||
disable_response_storage: None,
|
||||
show_raw_agent_reasoning: None,
|
||||
tools_web_search_request: None,
|
||||
};
|
||||
@@ -182,8 +181,8 @@ impl CodexToolCallParam {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CodexToolCallReplyParam {
|
||||
/// The *session id* for this conversation.
|
||||
pub session_id: String,
|
||||
/// The conversation id for this Codex session.
|
||||
pub conversation_id: String,
|
||||
|
||||
/// The *next user prompt* to continue the Codex conversation.
|
||||
pub prompt: String,
|
||||
@@ -214,7 +213,8 @@ pub(crate) fn create_tool_for_codex_tool_call_reply_param() -> Tool {
|
||||
input_schema: tool_input_schema,
|
||||
output_schema: None,
|
||||
description: Some(
|
||||
"Continue a Codex session by providing the session id and prompt.".to_string(),
|
||||
"Continue a Codex conversation by providing the conversation id and prompt."
|
||||
.to_string(),
|
||||
),
|
||||
annotations: None,
|
||||
}
|
||||
@@ -309,21 +309,21 @@ mod tests {
|
||||
let tool = create_tool_for_codex_tool_call_reply_param();
|
||||
let tool_json = serde_json::to_value(&tool).expect("tool serializes");
|
||||
let expected_tool_json = serde_json::json!({
|
||||
"description": "Continue a Codex session by providing the session id and prompt.",
|
||||
"description": "Continue a Codex conversation by providing the conversation id and prompt.",
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"conversationId": {
|
||||
"description": "The conversation id for this Codex session.",
|
||||
"type": "string"
|
||||
},
|
||||
"prompt": {
|
||||
"description": "The *next user prompt* to continue the Codex conversation.",
|
||||
"type": "string"
|
||||
},
|
||||
"sessionId": {
|
||||
"description": "The *session id* for this conversation.",
|
||||
"type": "string"
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"conversationId",
|
||||
"prompt",
|
||||
"sessionId",
|
||||
],
|
||||
"type": "object",
|
||||
},
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::OutgoingNotificationMeta;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
use codex_core::CodexConversation;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::NewConversation;
|
||||
@@ -18,18 +22,13 @@ use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::Submission;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::TextContent;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::OutgoingNotificationMeta;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
|
||||
pub(crate) const INVALID_PARAMS_ERROR_CODE: i64 = -32602;
|
||||
|
||||
@@ -43,7 +42,7 @@ pub async fn run_codex_tool_session(
|
||||
config: CodexConfig,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
conversation_manager: Arc<ConversationManager>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, ConversationId>>>,
|
||||
) {
|
||||
let NewConversation {
|
||||
conversation_id,
|
||||
@@ -119,13 +118,13 @@ pub async fn run_codex_tool_session_reply(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
prompt: String,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
session_id: Uuid,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, ConversationId>>>,
|
||||
conversation_id: ConversationId,
|
||||
) {
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id.clone(), session_id);
|
||||
.insert(request_id.clone(), conversation_id);
|
||||
if let Err(e) = conversation
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text { text: prompt }],
|
||||
@@ -154,7 +153,7 @@ async fn run_codex_tool_session_inner(
|
||||
codex: Arc<CodexConversation>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, ConversationId>>>,
|
||||
) {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
@@ -223,7 +222,7 @@ async fn run_codex_tool_session_inner(
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
let text = match last_agent_message {
|
||||
Some(msg) => msg.clone(),
|
||||
Some(msg) => msg,
|
||||
None => "".to_string(),
|
||||
};
|
||||
let result = CallToolResult {
|
||||
@@ -278,7 +277,8 @@ async fn run_codex_tool_session_inner(
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::TurnAborted(_)
|
||||
| EventMsg::ConversationHistory(_)
|
||||
| EventMsg::ConversationPath(_)
|
||||
| EventMsg::UserMessage(_)
|
||||
| EventMsg::ShutdownComplete => {
|
||||
// For now, we do not do anything extra for these
|
||||
// events. Note that
|
||||
|
||||
@@ -9,10 +9,13 @@ use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use codex_protocol::mcp_protocol::ClientRequest;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::default_client::USER_AGENT_SUFFIX;
|
||||
use codex_core::default_client::get_codex_user_agent;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
@@ -41,7 +44,7 @@ pub(crate) struct MessageProcessor {
|
||||
initialized: bool,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
conversation_manager: Arc<ConversationManager>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, ConversationId>>>,
|
||||
}
|
||||
|
||||
impl MessageProcessor {
|
||||
@@ -53,11 +56,7 @@ impl MessageProcessor {
|
||||
config: Arc<Config>,
|
||||
) -> Self {
|
||||
let outgoing = Arc::new(outgoing);
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
config.preferred_auth_method,
|
||||
config.responses_originator_header.clone(),
|
||||
);
|
||||
let auth_manager = AuthManager::shared(config.codex_home.clone());
|
||||
let conversation_manager = Arc::new(ConversationManager::new(auth_manager.clone()));
|
||||
let codex_message_processor = CodexMessageProcessor::new(
|
||||
auth_manager,
|
||||
@@ -210,6 +209,14 @@ impl MessageProcessor {
|
||||
return;
|
||||
}
|
||||
|
||||
let client_info = params.client_info;
|
||||
let name = client_info.name;
|
||||
let version = client_info.version;
|
||||
let user_agent_suffix = format!("{name}; {version}");
|
||||
if let Ok(mut suffix) = USER_AGENT_SUFFIX.lock() {
|
||||
*suffix = Some(user_agent_suffix);
|
||||
}
|
||||
|
||||
self.initialized = true;
|
||||
|
||||
// Build a minimal InitializeResult. Fill with placeholders.
|
||||
@@ -230,6 +237,7 @@ impl MessageProcessor {
|
||||
name: "codex-mcp-server".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
title: Some("Codex".to_string()),
|
||||
user_agent: Some(get_codex_user_agent()),
|
||||
},
|
||||
};
|
||||
|
||||
@@ -436,7 +444,10 @@ impl MessageProcessor {
|
||||
tracing::info!("tools/call -> params: {:?}", arguments);
|
||||
|
||||
// parse arguments
|
||||
let CodexToolCallReplyParam { session_id, prompt } = match arguments {
|
||||
let CodexToolCallReplyParam {
|
||||
conversation_id,
|
||||
prompt,
|
||||
} = match arguments {
|
||||
Some(json_val) => match serde_json::from_value::<CodexToolCallReplyParam>(json_val) {
|
||||
Ok(params) => params,
|
||||
Err(e) => {
|
||||
@@ -457,12 +468,12 @@ impl MessageProcessor {
|
||||
},
|
||||
None => {
|
||||
tracing::error!(
|
||||
"Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required."
|
||||
"Missing arguments for codex-reply tool-call; the `conversation_id` and `prompt` fields are required."
|
||||
);
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: "Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required.".to_owned(),
|
||||
text: "Missing arguments for codex-reply tool-call; the `conversation_id` and `prompt` fields are required.".to_owned(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
@@ -473,14 +484,14 @@ impl MessageProcessor {
|
||||
return;
|
||||
}
|
||||
};
|
||||
let session_id = match Uuid::parse_str(&session_id) {
|
||||
Ok(id) => id,
|
||||
let conversation_id = match Uuid::parse_str(&conversation_id) {
|
||||
Ok(id) => ConversationId::from(id),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to parse session_id: {e}");
|
||||
tracing::error!("Failed to parse conversation_id: {e}");
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Failed to parse session_id: {e}"),
|
||||
text: format!("Failed to parse conversation_id: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
@@ -496,14 +507,18 @@ impl MessageProcessor {
|
||||
let outgoing = self.outgoing.clone();
|
||||
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
|
||||
|
||||
let codex = match self.conversation_manager.get_conversation(session_id).await {
|
||||
let codex = match self
|
||||
.conversation_manager
|
||||
.get_conversation(conversation_id)
|
||||
.await
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
tracing::warn!("Session not found for session_id: {session_id}");
|
||||
tracing::warn!("Session not found for conversation_id: {conversation_id}");
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Session not found for session_id: {session_id}"),
|
||||
text: format!("Session not found for conversation_id: {conversation_id}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
@@ -516,7 +531,6 @@ impl MessageProcessor {
|
||||
|
||||
// Spawn the long-running reply handler.
|
||||
tokio::spawn({
|
||||
let codex = codex.clone();
|
||||
let outgoing = outgoing.clone();
|
||||
let prompt = prompt.clone();
|
||||
let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone();
|
||||
@@ -528,7 +542,7 @@ impl MessageProcessor {
|
||||
request_id,
|
||||
prompt,
|
||||
running_requests_id_to_codex_uuid,
|
||||
session_id,
|
||||
conversation_id,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -564,24 +578,28 @@ impl MessageProcessor {
|
||||
RequestId::Integer(i) => i.to_string(),
|
||||
};
|
||||
|
||||
// Obtain the session_id while holding the first lock, then release.
|
||||
let session_id = {
|
||||
// Obtain the conversation id while holding the first lock, then release.
|
||||
let conversation_id = {
|
||||
let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
|
||||
match map_guard.get(&request_id) {
|
||||
Some(id) => *id, // Uuid is Copy
|
||||
Some(id) => *id,
|
||||
None => {
|
||||
tracing::warn!("Session not found for request_id: {}", request_id_string);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
tracing::info!("session_id: {session_id}");
|
||||
tracing::info!("conversation_id: {conversation_id}");
|
||||
|
||||
// Obtain the Codex conversation from the server.
|
||||
let codex_arc = match self.conversation_manager.get_conversation(session_id).await {
|
||||
let codex_arc = match self
|
||||
.conversation_manager
|
||||
.get_conversation(conversation_id)
|
||||
.await
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
tracing::warn!("Session not found for session_id: {session_id}");
|
||||
tracing::warn!("Session not found for conversation_id: {conversation_id}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -97,6 +97,9 @@ impl OutgoingMessageSender {
|
||||
}
|
||||
}
|
||||
|
||||
/// This is used with the MCP server, but not the more general JSON-RPC app
|
||||
/// server. Prefer [`OutgoingMessageSender::send_server_notification`] where
|
||||
/// possible.
|
||||
pub(crate) async fn send_event_as_notification(
|
||||
&self,
|
||||
event: &Event,
|
||||
@@ -123,14 +126,9 @@ impl OutgoingMessageSender {
|
||||
}
|
||||
|
||||
pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
|
||||
let method = format!("codex/event/{notification}");
|
||||
let params = match serde_json::to_value(¬ification) {
|
||||
Ok(serde_json::Value::Object(mut map)) => map.remove("data"),
|
||||
_ => None,
|
||||
};
|
||||
let outgoing_message =
|
||||
OutgoingMessage::Notification(OutgoingNotification { method, params });
|
||||
let _ = self.sender.send(outgoing_message);
|
||||
let _ = self
|
||||
.sender
|
||||
.send(OutgoingMessage::AppServerNotification(notification));
|
||||
}
|
||||
|
||||
pub(crate) async fn send_notification(&self, notification: OutgoingNotification) {
|
||||
@@ -148,6 +146,9 @@ impl OutgoingMessageSender {
|
||||
pub(crate) enum OutgoingMessage {
|
||||
Request(OutgoingRequest),
|
||||
Notification(OutgoingNotification),
|
||||
/// AppServerNotification is specific to the case where this is run as an
|
||||
/// "app server" as opposed to an MCP server.
|
||||
AppServerNotification(ServerNotification),
|
||||
Response(OutgoingResponse),
|
||||
Error(OutgoingError),
|
||||
}
|
||||
@@ -171,6 +172,21 @@ impl From<OutgoingMessage> for JSONRPCMessage {
|
||||
params,
|
||||
})
|
||||
}
|
||||
AppServerNotification(notification) => {
|
||||
let method = notification.to_string();
|
||||
let params = match notification.to_params() {
|
||||
Ok(params) => Some(params),
|
||||
Err(err) => {
|
||||
warn!("failed to serialize notification params: {err}");
|
||||
None
|
||||
}
|
||||
};
|
||||
JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method,
|
||||
params,
|
||||
})
|
||||
}
|
||||
Response(OutgoingResponse { id, result }) => {
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
@@ -242,8 +258,12 @@ pub(crate) struct OutgoingError {
|
||||
mod tests {
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::mcp_protocol::LoginChatGptCompleteNotification;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::NamedTempFile;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::*;
|
||||
@@ -253,13 +273,18 @@ mod tests {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<OutgoingMessage>();
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
|
||||
let conversation_id = ConversationId::new();
|
||||
let rollout_file = NamedTempFile::new().unwrap();
|
||||
let event = Event {
|
||||
id: "1".to_string(),
|
||||
msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
|
||||
session_id: Uuid::new_v4(),
|
||||
session_id: conversation_id,
|
||||
model: "gpt-4o".to_string(),
|
||||
reasoning_effort: ReasoningEffort::default(),
|
||||
history_log_id: 1,
|
||||
history_entry_count: 1000,
|
||||
initial_messages: None,
|
||||
rollout_path: rollout_file.path().to_path_buf(),
|
||||
}),
|
||||
};
|
||||
|
||||
@@ -276,7 +301,7 @@ mod tests {
|
||||
let Ok(expected_params) = serde_json::to_value(&event) else {
|
||||
panic!("Event must serialize");
|
||||
};
|
||||
assert_eq!(params, Some(expected_params.clone()));
|
||||
assert_eq!(params, Some(expected_params));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -284,11 +309,16 @@ mod tests {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<OutgoingMessage>();
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
|
||||
let conversation_id = ConversationId::new();
|
||||
let rollout_file = NamedTempFile::new().unwrap();
|
||||
let session_configured_event = SessionConfiguredEvent {
|
||||
session_id: Uuid::new_v4(),
|
||||
session_id: conversation_id,
|
||||
model: "gpt-4o".to_string(),
|
||||
reasoning_effort: ReasoningEffort::default(),
|
||||
history_log_id: 1,
|
||||
history_entry_count: 1000,
|
||||
initial_messages: None,
|
||||
rollout_path: rollout_file.path().to_path_buf(),
|
||||
};
|
||||
let event = Event {
|
||||
id: "1".to_string(),
|
||||
@@ -315,11 +345,38 @@ mod tests {
|
||||
"msg": {
|
||||
"session_id": session_configured_event.session_id,
|
||||
"model": session_configured_event.model,
|
||||
"reasoning_effort": session_configured_event.reasoning_effort,
|
||||
"history_log_id": session_configured_event.history_log_id,
|
||||
"history_entry_count": session_configured_event.history_entry_count,
|
||||
"type": "session_configured",
|
||||
"rollout_path": rollout_file.path().to_path_buf(),
|
||||
}
|
||||
});
|
||||
assert_eq!(params.unwrap(), expected_params);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_server_notification_serialization() {
|
||||
let notification =
|
||||
ServerNotification::LoginChatGptComplete(LoginChatGptCompleteNotification {
|
||||
login_id: Uuid::nil(),
|
||||
success: true,
|
||||
error: None,
|
||||
});
|
||||
|
||||
let jsonrpc_notification: JSONRPCMessage =
|
||||
OutgoingMessage::AppServerNotification(notification).into();
|
||||
assert_eq!(
|
||||
JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
jsonrpc: "2.0".into(),
|
||||
method: "loginChatGptComplete".into(),
|
||||
params: Some(json!({
|
||||
"loginId": Uuid::nil(),
|
||||
"success": true,
|
||||
})),
|
||||
}),
|
||||
jsonrpc_notification,
|
||||
"ensure the strum macros serialize the method field correctly"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,16 +13,14 @@ codex-core = { path = "../../../core" }
|
||||
codex-mcp-server = { path = "../.." }
|
||||
codex-protocol = { path = "../../../protocol" }
|
||||
mcp-types = { path = "../../../mcp-types" }
|
||||
os_info = "3.12.0"
|
||||
pretty_assertions = "1.4.1"
|
||||
serde = { version = "1" }
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
tempfile = "3"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
] }
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
wiremock = "0.6"
|
||||
|
||||
@@ -13,13 +13,18 @@ use anyhow::Context;
|
||||
use assert_cmd::prelude::*;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use codex_protocol::mcp_protocol::AddConversationListenerParams;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationParams;
|
||||
use codex_protocol::mcp_protocol::CancelLoginChatGptParams;
|
||||
use codex_protocol::mcp_protocol::GetAuthStatusParams;
|
||||
use codex_protocol::mcp_protocol::InterruptConversationParams;
|
||||
use codex_protocol::mcp_protocol::ListConversationsParams;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyParams;
|
||||
use codex_protocol::mcp_protocol::NewConversationParams;
|
||||
use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
|
||||
use codex_protocol::mcp_protocol::ResumeConversationParams;
|
||||
use codex_protocol::mcp_protocol::SendUserMessageParams;
|
||||
use codex_protocol::mcp_protocol::SendUserTurnParams;
|
||||
use codex_protocol::mcp_protocol::SetDefaultModelParams;
|
||||
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::ClientCapabilities;
|
||||
@@ -51,6 +56,18 @@ pub struct McpProcess {
|
||||
|
||||
impl McpProcess {
|
||||
pub async fn new(codex_home: &Path) -> anyhow::Result<Self> {
|
||||
Self::new_with_env(codex_home, &[]).await
|
||||
}
|
||||
|
||||
/// Creates a new MCP process, allowing tests to override or remove
|
||||
/// specific environment variables for the child process only.
|
||||
///
|
||||
/// Pass a tuple of (key, Some(value)) to set/override, or (key, None) to
|
||||
/// remove a variable from the child's environment.
|
||||
pub async fn new_with_env(
|
||||
codex_home: &Path,
|
||||
env_overrides: &[(&str, Option<&str>)],
|
||||
) -> anyhow::Result<Self> {
|
||||
// Use assert_cmd to locate the binary path and then switch to tokio::process::Command
|
||||
let std_cmd = StdCommand::cargo_bin("codex-mcp-server")
|
||||
.context("should find binary for codex-mcp-server")?;
|
||||
@@ -65,6 +82,17 @@ impl McpProcess {
|
||||
cmd.env("CODEX_HOME", codex_home);
|
||||
cmd.env("RUST_LOG", "debug");
|
||||
|
||||
for (k, v) in env_overrides {
|
||||
match v {
|
||||
Some(val) => {
|
||||
cmd.env(k, val);
|
||||
}
|
||||
None => {
|
||||
cmd.env_remove(k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut process = cmd
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
@@ -112,6 +140,7 @@ impl McpProcess {
|
||||
name: "elicitation test".into(),
|
||||
title: Some("Elicitation Test".into()),
|
||||
version: "0.0.0".into(),
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.into(),
|
||||
};
|
||||
@@ -126,6 +155,14 @@ impl McpProcess {
|
||||
.await?;
|
||||
|
||||
let initialized = self.read_jsonrpc_message().await?;
|
||||
let os_info = os_info::get();
|
||||
let user_agent = format!(
|
||||
"codex_cli_rs/0.0.0 ({} {}; {}) {} (elicitation test; 0.0.0)",
|
||||
os_info.os_type(),
|
||||
os_info.version(),
|
||||
os_info.architecture().unwrap_or("unknown"),
|
||||
codex_core::terminal::user_agent()
|
||||
);
|
||||
assert_eq!(
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
@@ -139,7 +176,8 @@ impl McpProcess {
|
||||
"serverInfo": {
|
||||
"name": "codex-mcp-server",
|
||||
"title": "Codex",
|
||||
"version": "0.0.0"
|
||||
"version": "0.0.0",
|
||||
"user_agent": user_agent
|
||||
},
|
||||
"protocolVersion": mcp_types::MCP_SCHEMA_VERSION
|
||||
})
|
||||
@@ -184,6 +222,15 @@ impl McpProcess {
|
||||
self.send_request("newConversation", params).await
|
||||
}
|
||||
|
||||
/// Send an `archiveConversation` JSON-RPC request.
|
||||
pub async fn send_archive_conversation_request(
|
||||
&mut self,
|
||||
params: ArchiveConversationParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("archiveConversation", params).await
|
||||
}
|
||||
|
||||
/// Send an `addConversationListener` JSON-RPC request.
|
||||
pub async fn send_add_conversation_listener_request(
|
||||
&mut self,
|
||||
@@ -240,9 +287,55 @@ impl McpProcess {
|
||||
self.send_request("getAuthStatus", params).await
|
||||
}
|
||||
|
||||
/// Send a `getConfigToml` JSON-RPC request.
|
||||
pub async fn send_get_config_toml_request(&mut self) -> anyhow::Result<i64> {
|
||||
self.send_request("getConfigToml", None).await
|
||||
/// Send a `getUserSavedConfig` JSON-RPC request.
|
||||
pub async fn send_get_user_saved_config_request(&mut self) -> anyhow::Result<i64> {
|
||||
self.send_request("getUserSavedConfig", None).await
|
||||
}
|
||||
|
||||
/// Send a `getUserAgent` JSON-RPC request.
|
||||
pub async fn send_get_user_agent_request(&mut self) -> anyhow::Result<i64> {
|
||||
self.send_request("getUserAgent", None).await
|
||||
}
|
||||
|
||||
/// Send a `userInfo` JSON-RPC request.
|
||||
pub async fn send_user_info_request(&mut self) -> anyhow::Result<i64> {
|
||||
self.send_request("userInfo", None).await
|
||||
}
|
||||
|
||||
/// Send a `setDefaultModel` JSON-RPC request.
|
||||
pub async fn send_set_default_model_request(
|
||||
&mut self,
|
||||
params: SetDefaultModelParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("setDefaultModel", params).await
|
||||
}
|
||||
|
||||
/// Send a `listConversations` JSON-RPC request.
|
||||
pub async fn send_list_conversations_request(
|
||||
&mut self,
|
||||
params: ListConversationsParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("listConversations", params).await
|
||||
}
|
||||
|
||||
/// Send a `resumeConversation` JSON-RPC request.
|
||||
pub async fn send_resume_conversation_request(
|
||||
&mut self,
|
||||
params: ResumeConversationParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("resumeConversation", params).await
|
||||
}
|
||||
|
||||
/// Send a `loginApiKey` JSON-RPC request.
|
||||
pub async fn send_login_api_key_request(
|
||||
&mut self,
|
||||
params: LoginApiKeyParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("loginApiKey", params).await
|
||||
}
|
||||
|
||||
/// Send a `loginChatGpt` JSON-RPC request.
|
||||
|
||||
105
codex-rs/mcp-server/tests/suite/archive_conversation.rs
Normal file
105
codex-rs/mcp-server/tests/suite/archive_conversation.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use std::path::Path;
|
||||
|
||||
use codex_core::ARCHIVED_SESSIONS_SUBDIR;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationParams;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationResponse;
|
||||
use codex_protocol::mcp_protocol::NewConversationParams;
|
||||
use codex_protocol::mcp_protocol::NewConversationResponse;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::to_response;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn archive_conversation_moves_rollout_into_archived_directory() {
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path()).expect("write config.toml");
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("initialize timeout")
|
||||
.expect("initialize request");
|
||||
|
||||
let new_request_id = mcp
|
||||
.send_new_conversation_request(NewConversationParams {
|
||||
model: Some("mock-model".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("send newConversation");
|
||||
let new_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(new_request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("newConversation timeout")
|
||||
.expect("newConversation response");
|
||||
|
||||
let NewConversationResponse {
|
||||
conversation_id,
|
||||
rollout_path,
|
||||
..
|
||||
} = to_response::<NewConversationResponse>(new_response)
|
||||
.expect("deserialize newConversation response");
|
||||
|
||||
assert!(
|
||||
rollout_path.exists(),
|
||||
"expected rollout path {} to exist",
|
||||
rollout_path.display()
|
||||
);
|
||||
|
||||
let archive_request_id = mcp
|
||||
.send_archive_conversation_request(ArchiveConversationParams {
|
||||
conversation_id,
|
||||
rollout_path: rollout_path.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("send archiveConversation");
|
||||
let archive_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(archive_request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("archiveConversation timeout")
|
||||
.expect("archiveConversation response");
|
||||
|
||||
let _: ArchiveConversationResponse =
|
||||
to_response::<ArchiveConversationResponse>(archive_response)
|
||||
.expect("deserialize archiveConversation response");
|
||||
|
||||
let archived_directory = codex_home.path().join(ARCHIVED_SESSIONS_SUBDIR);
|
||||
let archived_rollout_path =
|
||||
archived_directory.join(rollout_path.file_name().unwrap_or_else(|| {
|
||||
panic!("rollout path {} missing file name", rollout_path.display())
|
||||
}));
|
||||
|
||||
assert!(
|
||||
!rollout_path.exists(),
|
||||
"expected rollout path {} to be moved",
|
||||
rollout_path.display()
|
||||
);
|
||||
assert!(
|
||||
archived_rollout_path.exists(),
|
||||
"expected archived rollout path {} to exist",
|
||||
archived_rollout_path.display()
|
||||
);
|
||||
}
|
||||
|
||||
fn create_config_toml(codex_home: &Path) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(config_toml, config_contents())
|
||||
}
|
||||
|
||||
fn config_contents() -> &'static str {
|
||||
r#"model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "read-only"
|
||||
"#
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::path::Path;
|
||||
|
||||
use codex_core::auth::login_with_api_key;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use codex_protocol::mcp_protocol::GetAuthStatusParams;
|
||||
use codex_protocol::mcp_protocol::GetAuthStatusResponse;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyParams;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyResponse;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::to_response;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
@@ -36,12 +37,31 @@ stream_max_retries = 0
|
||||
)
|
||||
}
|
||||
|
||||
async fn login_with_api_key_via_request(mcp: &mut McpProcess, api_key: &str) {
|
||||
let request_id = mcp
|
||||
.send_login_api_key_request(LoginApiKeyParams {
|
||||
api_key: api_key.to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|e| panic!("send loginApiKey: {e}"));
|
||||
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await
|
||||
.unwrap_or_else(|e| panic!("loginApiKey timeout: {e}"))
|
||||
.unwrap_or_else(|e| panic!("loginApiKey response: {e}"));
|
||||
let _: LoginApiKeyResponse =
|
||||
to_response(resp).unwrap_or_else(|e| panic!("deserialize login response: {e}"));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn get_auth_status_no_auth() {
|
||||
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
|
||||
create_config_toml(codex_home.path()).expect("write config.toml");
|
||||
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
let mut mcp = McpProcess::new_with_env(codex_home.path(), &[("OPENAI_API_KEY", None)])
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
@@ -72,8 +92,7 @@ async fn get_auth_status_no_auth() {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn get_auth_status_with_api_key() {
|
||||
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
|
||||
create_config_toml(codex_home.path()).expect("write config.toml");
|
||||
login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key");
|
||||
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
@@ -83,6 +102,8 @@ async fn get_auth_status_with_api_key() {
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
login_with_api_key_via_request(&mut mcp, "sk-test-key").await;
|
||||
|
||||
let request_id = mcp
|
||||
.send_get_auth_status_request(GetAuthStatusParams {
|
||||
include_token: Some(true),
|
||||
@@ -101,14 +122,12 @@ async fn get_auth_status_with_api_key() {
|
||||
let status: GetAuthStatusResponse = to_response(resp).expect("deserialize status");
|
||||
assert_eq!(status.auth_method, Some(AuthMode::ApiKey));
|
||||
assert_eq!(status.auth_token, Some("sk-test-key".to_string()));
|
||||
assert_eq!(status.preferred_auth_method, AuthMode::ChatGPT);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn get_auth_status_with_api_key_no_include_token() {
|
||||
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
|
||||
create_config_toml(codex_home.path()).expect("write config.toml");
|
||||
login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key");
|
||||
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
@@ -118,6 +137,8 @@ async fn get_auth_status_with_api_key_no_include_token() {
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
login_with_api_key_via_request(&mut mcp, "sk-test-key").await;
|
||||
|
||||
// Build params via struct so None field is omitted in wire JSON.
|
||||
let params = GetAuthStatusParams {
|
||||
include_token: None,
|
||||
@@ -138,5 +159,4 @@ async fn get_auth_status_with_api_key_no_include_token() {
|
||||
let status: GetAuthStatusResponse = to_response(resp).expect("deserialize status");
|
||||
assert_eq!(status.auth_method, Some(AuthMode::ApiKey));
|
||||
assert!(status.auth_token.is_none(), "token must be omitted");
|
||||
assert_eq!(status.preferred_auth_method, AuthMode::ChatGPT);
|
||||
}
|
||||
|
||||
@@ -90,6 +90,8 @@ async fn test_codex_jsonrpc_conversation_flow() {
|
||||
let NewConversationResponse {
|
||||
conversation_id,
|
||||
model,
|
||||
reasoning_effort: _,
|
||||
rollout_path: _,
|
||||
} = new_conv_resp;
|
||||
assert_eq!(model, "mock-model");
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user