Mirror of https://github.com/openai/codex.git (synced 2026-02-02 23:13:37 +00:00)

Compare commits: daniel/tes...codex/fix- (67 commits)
Commits (SHA1 only; author and date columns were empty in the mirror):
7f3f2add3e, 7ec9444fb6, 8ec36fc336, d73055c5b1, 7e3a272b29, 661663c98a, 721003c552, 36f1cca1b1, d3e1beb26c, c264ae6021, 8cd882c4bd, 90fe5e4a7e, a90a58f7a1, b2d81a7cac, 77a8b7fdeb, 7fa5e95c1f, 191d620707, 53504a38d2, 5c42419b02, aecbe0f333, a30a902db5, f3b4a26f32, dc3c6bf62a, 3203862167, 06853d94f0, cc2f4aafd7, 356ea6ea34, 4764fc1ee7, 90ef94d3b3, 6c2969d22d, 0ad1b0782b, d7acd146fb, c5465aed60, a95605a867, 848058f05b, a4f1c9d67e, 665341c9b1, fae0e6c52c, 1b4a79f03c, 640192ac3d, 205c36e393, d13ee79c41, bde468ff8d, e292d1ed21, de8d77274a, a5b7675e42, 9823de3cc6, c32e9cfe86, 1d17ca1fa3, bfe3328129, e0b38bd7a2, 153338c20f, 3495a7dc37, 042d4d55d9, 5af08e0719, 33d3ecbccc, 69cb72f842, 69ac5153d4, 16b6951648, 231c36f8d3, 1e4541b982, 7be3b484ad, 9617b69c8a, 1d94b9111c, 2d6cd6951a, 310e3c32e5, 37786593a0
.github/ISSUE_TEMPLATE/4-feature-request.yml (vendored) — 6 changes

@@ -2,7 +2,6 @@ name: 🎁 Feature Request
 description: Propose a new feature for Codex
 labels:
   - enhancement
-  - needs triage
 body:
   - type: markdown
     attributes:
@@ -19,11 +18,6 @@ body:
       label: What feature would you like to see?
     validations:
       required: true
-  - type: textarea
-    id: author
-    attributes:
-      label: Are you interested in implementing this feature?
-      description: Please wait for acknowledgement before implementing or opening a PR.
   - type: textarea
     id: notes
     attributes:
.github/workflows/ci.yml (vendored) — 3 changes

@@ -60,3 +60,6 @@ jobs:
         run: ./scripts/asciicheck.py codex-cli/README.md
       - name: Check codex-cli/README ToC
         run: python3 scripts/readme_toc.py codex-cli/README.md
+
+      - name: Prettier (run `pnpm run format:fix` to fix)
+        run: pnpm run format
.github/workflows/issue-deduplicator.yml (vendored) — 63 changes

@@ -3,7 +3,7 @@ name: Issue Deduplicator
 on:
   issues:
     types:
-      # - opened - disabled while testing
+      - opened
       - labeled

 jobs:
@@ -14,7 +14,7 @@ jobs:
     permissions:
       contents: read
     outputs:
-      codex_output: ${{ steps.codex.outputs.final_message }}
+      codex_output: ${{ steps.codex.outputs.final-message }}
    steps:
      - uses: actions/checkout@v4

@@ -44,10 +44,38 @@ jobs:
      - id: codex
        uses: openai/codex-action@main
        with:
-          openai_api_key: ${{ secrets.CODEX_OPENAI_API_KEY }}
-          prompt_file: .github/prompts/issue-deduplicator.txt
-          require_repo_write: false
-          codex_version: 0.43.0-alpha.16
+          openai-api-key: ${{ secrets.CODEX_OPENAI_API_KEY }}
+          allow-users: "*"
+          model: gpt-5
+          prompt: |
+            You are an assistant that triages new GitHub issues by identifying potential duplicates.
+
+            You will receive the following JSON files located in the current working directory:
+            - `codex-current-issue.json`: JSON object describing the newly created issue (fields: number, title, body).
+            - `codex-existing-issues.json`: JSON array of recent issues (each element includes number, title, body, createdAt).
+
+            Instructions:
+            - Compare the current issue against the existing issues to find up to five that appear to describe the same underlying problem or request.
+            - Focus on the underlying intent and context of each issue—such as reported symptoms, feature requests, reproduction steps, or error messages—rather than relying solely on string similarity or synthetic metrics.
+            - After your analysis, validate your results in 1-2 lines explaining your decision to return the selected matches.
+            - When unsure, prefer returning fewer matches.
+            - Include at most five numbers.
+
+          output-schema: |
+            {
+              "type": "object",
+              "properties": {
+                "issues": {
+                  "type": "array",
+                  "items": {
+                    "type": "string"
+                  }
+                },
+                "reason": { "type": "string" }
+              },
+              "required": ["issues", "reason"],
+              "additionalProperties": false
+            }

  comment-on-issue:
    name: Comment with potential duplicates
@@ -65,20 +93,35 @@ jobs:
        with:
          github-token: ${{ github.token }}
          script: |
-            let numbers;
+            const raw = process.env.CODEX_OUTPUT ?? '';
+            let parsed;
             try {
-              numbers = JSON.parse(process.env.CODEX_OUTPUT);
+              parsed = JSON.parse(raw);
             } catch (error) {
+              core.info(`Codex output was not valid JSON. Raw output: ${raw}`);
+              core.info(`Parse error: ${error.message}`);
               return;
             }

-            if (numbers.length === 0) {
+            const issues = Array.isArray(parsed?.issues) ? parsed.issues : [];
+            const currentIssueNumber = String(context.payload.issue.number);
+
+            console.log(`Current issue number: ${currentIssueNumber}`);
+            console.log(issues);
+
+            const filteredIssues = issues.filter((value) => String(value) !== currentIssueNumber);
+
+            if (filteredIssues.length === 0) {
              core.info('Codex reported no potential duplicates.');
              return;
            }

-            const lines = ['Potential duplicates detected:', ...numbers.map((value) => `- #${value}`)];
+            const lines = [
+              'Potential duplicates detected. Please review them and close your issue if it is a duplicate.',
+              '',
+              ...filteredIssues.map((value) => `- #${String(value)}`),
+              '',
+              '*Powered by [Codex Action](https://github.com/openai/codex-action)*'];

            await github.rest.issues.createComment({
              owner: context.repo.owner,
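
For reference, a final message that satisfies the output schema above would be a single JSON object like the following (issue numbers and reason invented for illustration):

```json
{
  "issues": ["1234", "2345"],
  "reason": "Both describe the same login failure after upgrading the CLI."
}
```

The comment step parses exactly this shape: it reads `parsed.issues`, filters out the current issue's own number, and posts whatever remains as potential duplicates.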
.github/workflows/issue-labeler.yml (vendored) — 65 changes

@@ -3,7 +3,7 @@ name: Issue Labeler
 on:
   issues:
     types:
-      # - opened - disabled while testing
+      - opened
       - labeled

 jobs:
@@ -13,23 +13,60 @@ jobs:
     runs-on: ubuntu-latest
     permissions:
       contents: read
     env:
       ISSUE_NUMBER: ${{ github.event.issue.number }}
       ISSUE_TITLE: ${{ github.event.issue.title }}
       ISSUE_BODY: ${{ github.event.issue.body }}
       REPO_FULL_NAME: ${{ github.repository }}
     outputs:
-      codex_output: ${{ steps.codex.outputs.final_message }}
+      codex_output: ${{ steps.codex.outputs.final-message }}
    steps:
      - uses: actions/checkout@v4

      - id: codex
        uses: openai/codex-action@main
        with:
-          openai_api_key: ${{ secrets.CODEX_OPENAI_API_KEY }}
-          prompt_file: .github/prompts/issue-labeler.txt
-          require_repo_write: false
-          codex_version: 0.43.0-alpha.16
+          openai-api-key: ${{ secrets.CODEX_OPENAI_API_KEY }}
+          allow-users: "*"
+          prompt: |
+            You are an assistant that reviews GitHub issues for the repository.
+
+            Your job is to choose the most appropriate existing labels for the issue described later in this prompt.
+            Follow these rules:
+            - Only pick labels out of the list below.
+            - Prefer a small set of precise labels over many broad ones.
+
+            Labels to apply:
+            1. bug — Reproducible defects in Codex products (CLI, VS Code extension, web, auth).
+            2. enhancement — Feature requests or usability improvements that ask for new capabilities, better ergonomics, or quality-of-life tweaks.
+            3. extension — VS Code (or other IDE) extension-specific issues.
+            4. windows-os — Bugs or friction specific to Windows environments (always when PowerShell is mentioned, path handling, copy/paste, OS-specific auth or tooling failures).
+            5. mcp — Topics involving Model Context Protocol servers/clients.
+            6. codex-web — Issues targeting the Codex web UI/Cloud experience.
+            8. azure — Problems or requests tied to Azure OpenAI deployments.
+            9. documentation — Updates or corrections needed in docs/README/config references (broken links, missing examples, outdated keys, clarification requests).
+            10. model-behavior — Undesirable LLM behavior: forgetting goals, refusing work, hallucinating environment details, quota misreports, or other reasoning/performance anomalies.
+
+            Issue number: ${{ github.event.issue.number }}
+
+            Issue title:
+            ${{ github.event.issue.title }}
+
+            Issue body:
+            ${{ github.event.issue.body }}
+
+            Repository full name:
+            ${{ github.repository }}
+
+          output-schema: |
+            {
+              "type": "object",
+              "properties": {
+                "labels": {
+                  "type": "array",
+                  "items": {
+                    "type": "string"
+                  }
+                }
+              },
+              "required": ["labels"],
+              "additionalProperties": false
+            }

  apply-labels:
    name: Apply labels from Codex output
@@ -53,12 +90,12 @@ jobs:
          exit 0
        fi

-        if ! printf '%s' "$json" | jq -e 'type == "array"' >/dev/null 2>&1; then
-          echo "Codex output was not a JSON array. Raw output: $json"
+        if ! printf '%s' "$json" | jq -e 'type == "object" and (.labels | type == "array")' >/dev/null 2>&1; then
+          echo "Codex output did not include a labels array. Raw output: $json"
          exit 0
        fi

-        labels=$(printf '%s' "$json" | jq -r '.[] | tostring')
+        labels=$(printf '%s' "$json" | jq -r '.labels[] | tostring')
        if [ -z "$labels" ]; then
          echo "Codex returned an empty array. Nothing to do."
          exit 0
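
For reference, a Codex reply that passes the jq checks above would be a single JSON object such as (labels invented):

```json
{ "labels": ["bug", "windows-os"] }
```

`jq -r '.labels[] | tostring'` then flattens the array to one label per line for the apply step.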
AGENTS.md — 10 changes

@@ -8,11 +8,16 @@ In the codex-rs folder where the rust code lives:
 - Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` or `CODEX_SANDBOX_ENV_VAR`.
 - You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
 - Similarly, when you spawn a process using Seatbelt (`/usr/bin/sandbox-exec`), `CODEX_SANDBOX=seatbelt` will be set on the child process. Integration tests that want to run Seatbelt themselves cannot be run under Seatbelt, so checks for `CODEX_SANDBOX=seatbelt` are also often used to early exit out of tests, as appropriate.
+- Always collapse if statements per https://rust-lang.github.io/rust-clippy/master/index.html#collapsible_if
+- Always inline format! args when possible per https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args
+- Use method references over closures when possible per https://rust-lang.github.io/rust-clippy/master/index.html#redundant_closure_for_method_calls
+- When writing tests, prefer comparing the equality of entire objects over fields one by one.

 Run `just fmt` (in `codex-rs` directory) automatically after making Rust code changes; do not ask for approval to run it. Before finalizing a change to `codex-rs`, run `just fix -p <project>` (in `codex-rs` directory) to fix any linter issues in the code. Prefer scoping with `-p` to avoid slow workspace‑wide Clippy builds; only run `just fix` without `-p` if you changed shared crates. Additionally, run the tests:

 1. Run the test for the specific project that was changed. For example, if changes were made in `codex-rs/tui`, run `cargo test -p codex-tui`.
 2. Once those pass, if any changes were made in common, core, or protocol, run the complete test suite with `cargo test --all-features`.

 When running interactively, ask the user before running `just fix` to finalize. `just fmt` does not require approval. project-specific or individual tests can be run without asking the user, but do ask the user before running the complete test suite.

@@ -28,6 +33,7 @@ See `codex-rs/tui/styles.md`.
 - Desired: vec![" └ ".into(), "M".red(), " ".dim(), "tui/src/app.rs".dim()]

+### TUI Styling (ratatui)

 - Prefer Stylize helpers: use "text".dim(), .bold(), .cyan(), .italic(), .underlined() instead of manual Style where possible.
 - Prefer simple conversions: use "text".into() for spans and vec![…].into() for lines; when inference is ambiguous (e.g., Paragraph::new/Cell::from), use Line::from(spans) or Span::from(text).
 - Computed styles: if the Style is computed at runtime, using `Span::styled` is OK (`Span::from(text).set_style(style)` is also acceptable).

@@ -39,6 +45,7 @@ See `codex-rs/tui/styles.md`.
 - Compactness: prefer the form that stays on one line after rustfmt; if only one of Line::from(vec![…]) or vec![…].into() avoids wrapping, choose that. If both wrap, pick the one with fewer wrapped lines.

+### Text wrapping

 - Always use textwrap::wrap to wrap plain strings.
 - If you have a ratatui Line and you want to wrap it, use the helpers in tui/src/wrapping.rs, e.g. word_wrap_lines / word_wrap_line.
 - If you need to indent wrapped lines, use the initial_indent / subsequent_indent options from RtOptions if you can, rather than writing custom logic.

@@ -60,6 +67,7 @@ This repo uses snapshot tests (via `insta`), especially in `codex-rs/tui`, to va
 - `cargo insta accept -p codex-tui`

 If you don’t have the tool:

 - `cargo install cargo-insta`

+### Test assertions
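
A toy sketch of the three Clippy conventions named in the new bullets (function and variable names invented):

```rust
fn demo(is_ready: bool, has_input: bool, path: &str, items: &[i32]) -> Vec<String> {
    // collapsible_if: one `if` with `&&` instead of a nested pair of `if`s.
    if is_ready && has_input {
        // uninlined_format_args: `{path}` inline instead of `println!("{}", path)`.
        println!("processing {path}");
    }
    // redundant_closure_for_method_calls: a method reference instead of `|x| x.to_string()`.
    items.iter().map(ToString::to_string).collect()
}
```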
README.md — 10 changes

@@ -1,4 +1,3 @@
 <p align="center"><code>npm i -g @openai/codex</code><br />or <code>brew install codex</code></p>

 <p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.

@@ -62,8 +61,7 @@ You can also use Codex with an API key, but this requires [additional setup](./d
 ### Model Context Protocol (MCP)

-Codex CLI supports [MCP servers](./docs/advanced.md#model-context-protocol-mcp). Enable by adding an `mcp_servers` section to your `~/.codex/config.toml`.
+Codex can access MCP servers. To configure them, refer to the [config docs](./docs/config.md#mcp_servers).

 ### Configuration

@@ -83,9 +81,11 @@ Codex CLI supports a rich set of configuration options, with preferences stored
 - [**Authentication**](./docs/authentication.md)
   - [Auth methods](./docs/authentication.md#forcing-a-specific-auth-method-advanced)
   - [Login on a "Headless" machine](./docs/authentication.md#connecting-on-a-headless-machine)
-- [**Non-interactive mode**](./docs/exec.md)
+- **Automating Codex**
+  - [GitHub Action](https://github.com/openai/codex-action)
+  - [TypeScript SDK](./sdk/typescript/README.md)
+  - [Non-interactive mode (`codex exec`)](./docs/exec.md)
 - [**Advanced**](./docs/advanced.md)
   - [Non-interactive / CI mode](./docs/advanced.md#non-interactive--ci-mode)
   - [Tracing / verbose logging](./docs/advanced.md#tracing--verbose-logging)
   - [Model Context Protocol (MCP)](./docs/advanced.md#model-context-protocol-mcp)
 - [**Zero data retention (ZDR)**](./docs/zdr.md)
codex-rs/Cargo.lock (generated) — 1488 changes; diff suppressed because it is too large
@@ -32,6 +32,7 @@ members = [
     "git-apply",
     "utils/json-to-toml",
+    "utils/readiness",
     "utils/string",
 ]
 resolver = "2"

@@ -71,6 +72,7 @@ codex-rmcp-client = { path = "rmcp-client" }
 codex-tui = { path = "tui" }
 codex-utils-json-to-toml = { path = "utils/json-to-toml" }
+codex-utils-readiness = { path = "utils/readiness" }
 codex-utils-string = { path = "utils/string" }
 core_test_support = { path = "core/tests/common" }
 mcp-types = { path = "mcp-types" }
 mcp_test_support = { path = "mcp-server/tests/common" }

@@ -81,10 +83,12 @@ ansi-to-tui = "7.0.0"
 anyhow = "1"
 arboard = "3"
 askama = "0.12"
+assert_matches = "1.5.0"
 assert_cmd = "2"
 async-channel = "2.3.1"
 async-stream = "0.3.6"
 async-trait = "0.1.89"
+axum = { version = "0.8", default-features = false }
 base64 = "0.22.1"
 bytes = "1.10.1"
 chrono = "0.4.42"

@@ -102,7 +106,7 @@ env-flags = "0.1.1"
 env_logger = "0.11.5"
 escargot = "0.5"
 eventsource-stream = "0.2.3"
-futures = "0.3"
+futures = { version = "0.3", default-features = false }
 icu_decimal = "2.0.0"
 icu_locale_core = "2.0.0"
 ignore = "0.4.23"

@@ -110,6 +114,7 @@ image = { version = "^0.25.8", default-features = false }
 indexmap = "2.6.0"
 insta = "1.43.2"
 itertools = "0.14.0"
+keyring = "3.6"
 landlock = "0.4.1"
 lazy_static = "1"
 libc = "0.2.175"

@@ -138,11 +143,13 @@ rand = "0.9"
 ratatui = "0.29.0"
 regex-lite = "0.1.7"
 reqwest = "0.12"
+rmcp = { version = "0.8.0", default-features = false }
 schemars = "0.8.22"
 seccompiler = "0.5.0"
 serde = "1"
 serde_json = "1"
 serde_with = "3.14"
+serial_test = "3.2.0"
 sha1 = "0.10.6"
 sha2 = "0.10"
 shlex = "1.3.0"

@@ -237,5 +244,9 @@ strip = "symbols"
 codegen-units = 1

 [patch.crates-io]
+# Uncomment to debug local changes.
+# ratatui = { path = "../../ratatui" }
 ratatui = { git = "https://github.com/nornagon/ratatui", branch = "nornagon-v0.29.0-patch" }
+
+# Uncomment to debug local changes.
+# rmcp = { path = "../../rust-sdk/crates/rmcp" }
@@ -23,9 +23,15 @@ Codex supports a rich set of configuration options. Note that the Rust CLI uses
 ### Model Context Protocol Support

-Codex CLI functions as an MCP client that can connect to MCP servers on startup. See the [`mcp_servers`](../docs/config.md#mcp_servers) section in the configuration documentation for details.
+#### MCP client

-It is still experimental, but you can also launch Codex as an MCP _server_ by running `codex mcp-server`. Use the [`@modelcontextprotocol/inspector`](https://github.com/modelcontextprotocol/inspector) to try it out:
+Codex CLI functions as an MCP client that allows the Codex CLI and IDE extension to connect to MCP servers on startup. See the [`configuration documentation`](../docs/config.md#mcp_servers) for details.
+
+#### MCP server (experimental)
+
+Codex can be launched as an MCP _server_ by running `codex mcp-server`. This allows _other_ MCP clients to use Codex as a tool for another agent.
+
+Use the [`@modelcontextprotocol/inspector`](https://github.com/modelcontextprotocol/inspector) to try it out:

 ```shell
 npx @modelcontextprotocol/inspector codex mcp-server
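
For context, an `mcp_servers` entry in `~/.codex/config.toml` takes roughly this shape (server name, command, and env key here are illustrative; the linked configuration docs are authoritative):

```toml
[mcp_servers.example]
command = "npx"
args = ["-y", "some-mcp-server"]
env = { "EXAMPLE_API_KEY" = "value" }
```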
@@ -71,9 +77,13 @@ To test to see what happens when a command is run under the sandbox provided by

 ```
 # macOS
-codex debug seatbelt [--full-auto] [COMMAND]...
+codex sandbox macos [--full-auto] [COMMAND]...

 # Linux
+codex sandbox linux [--full-auto] [COMMAND]...
+
+# Legacy aliases
+codex debug seatbelt [--full-auto] [COMMAND]...
 codex debug landlock [--full-auto] [COMMAND]...
 ```
@@ -725,6 +725,7 @@ pub struct FuzzyFileSearchParams {
 pub struct FuzzyFileSearchResult {
     pub root: String,
     pub path: String,
+    pub file_name: String,
     pub score: u32,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub indices: Option<Vec<u32>>,
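
With the new field, a serialized result looks roughly like this (values invented); note that `indices` is omitted entirely when it is `None`, due to the `skip_serializing_if` attribute:

```json
{ "root": "/home/user/repo", "path": "tui/src/app.rs", "file_name": "app.rs", "score": 88, "indices": [0, 1, 2] }
```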
codex-rs/app-server/README.md (new file) — 15 lines

@@ -0,0 +1,15 @@
+# codex-app-server
+
+`codex app-server` is the harness Codex uses to power rich interfaces such as the [Codex VS Code extension](https://marketplace.visualstudio.com/items?itemName=openai.chatgpt). The message schema is currently unstable, but those who wish to build experimental UIs on top of Codex may find it valuable.
+
+## Protocol
+
+Similar to [MCP](https://modelcontextprotocol.io/), `codex app-server` supports bidirectional communication, streaming JSONL over stdio. The protocol is JSON-RPC 2.0, though the `"jsonrpc":"2.0"` header is omitted.
+
+## Message Schema
+
+Currently, you can dump a TypeScript version of the schema using `codex generate-ts`. It is specific to the version of Codex you used to run `generate-ts`, so the two are guaranteed to be compatible.
+
+```
+codex generate-ts --out DIR
+```
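
To make the wire format concrete, an exchange might look like the following, one JSON object per line with no `"jsonrpc"` field (the method names appear in the tests later in this diff; the exact field spellings here are illustrative, not taken from the schema):

```jsonl
{"id": 1, "method": "newConversation", "params": {"model": "gpt-5"}}
{"id": 1, "result": {"conversationId": "abc123"}}
{"method": "codex/event/task_complete", "params": {"conversationId": "abc123"}}
```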
@@ -500,7 +500,7 @@ impl CodexMessageProcessor {
     }

     async fn get_user_saved_config(&self, request_id: RequestId) {
-        let toml_value = match load_config_as_toml(&self.config.codex_home) {
+        let toml_value = match load_config_as_toml(&self.config.codex_home).await {
             Ok(val) => val,
             Err(err) => {
                 let error = JSONRPCErrorError {

@@ -653,18 +653,19 @@ impl CodexMessageProcessor {
     }

     async fn process_new_conversation(&self, request_id: RequestId, params: NewConversationParams) {
-        let config = match derive_config_from_params(params, self.codex_linux_sandbox_exe.clone()) {
-            Ok(config) => config,
-            Err(err) => {
-                let error = JSONRPCErrorError {
-                    code: INVALID_REQUEST_ERROR_CODE,
-                    message: format!("error deriving config: {err}"),
-                    data: None,
-                };
-                self.outgoing.send_error(request_id, error).await;
-                return;
-            }
-        };
+        let config =
+            match derive_config_from_params(params, self.codex_linux_sandbox_exe.clone()).await {
+                Ok(config) => config,
+                Err(err) => {
+                    let error = JSONRPCErrorError {
+                        code: INVALID_REQUEST_ERROR_CODE,
+                        message: format!("error deriving config: {err}"),
+                        data: None,
+                    };
+                    self.outgoing.send_error(request_id, error).await;
+                    return;
+                }
+            };

         match self.conversation_manager.new_conversation(config).await {
             Ok(conversation_id) => {

@@ -752,7 +753,7 @@ impl CodexMessageProcessor {
         // Derive a Config using the same logic as new conversation, honoring overrides if provided.
         let config = match params.overrides {
             Some(overrides) => {
-                derive_config_from_params(overrides, self.codex_linux_sandbox_exe.clone())
+                derive_config_from_params(overrides, self.codex_linux_sandbox_exe.clone()).await
             }
             None => Ok(self.config.as_ref().clone()),
         };

@@ -1320,7 +1321,7 @@ async fn apply_bespoke_event_handling(
     }
 }

-fn derive_config_from_params(
+async fn derive_config_from_params(
     params: NewConversationParams,
     codex_linux_sandbox_exe: Option<PathBuf>,
 ) -> std::io::Result<Config> {

@@ -1358,7 +1359,7 @@ fn derive_config_from_params(
         .map(|(k, v)| (k, json_to_toml(v)))
         .collect();

-    Config::load_with_cli_overrides(cli_overrides, overrides)
+    Config::load_with_cli_overrides(cli_overrides, overrides).await
 }

 async fn on_patch_approval_response(
@@ -1,5 +1,6 @@
 use std::num::NonZero;
 use std::num::NonZeroUsize;
+use std::path::Path;
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::sync::atomic::AtomicBool;

@@ -56,9 +57,16 @@ pub(crate) async fn run_fuzzy_file_search(
         match res {
             Ok(Ok((root, res))) => {
                 for m in res.matches {
+                    let path = m.path;
+                    //TODO(shijie): Move file name generation to file_search lib.
+                    let file_name = Path::new(&path)
+                        .file_name()
+                        .map(|name| name.to_string_lossy().into_owned())
+                        .unwrap_or_else(|| path.clone());
                     let result = FuzzyFileSearchResult {
                         root: root.clone(),
-                        path: m.path,
+                        path,
+                        file_name,
                         score: m.score,
                         indices: m.indices,
                     };
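
The fallback here matters for odd paths: `Path::new("sub/abce").file_name()` yields `abce`, but a path such as `..` has no final component, in which case `file_name` falls back to the full path string rather than being left empty.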
@@ -81,6 +81,7 @@ pub async fn run_main(
         )
     })?;
     let config = Config::load_with_cli_overrides(cli_kv_overrides, ConfigOverrides::default())
+        .await
         .map_err(|e| {
             std::io::Error::new(ErrorKind::InvalidData, format!("error loading config: {e}"))
         })?;
@@ -1,3 +1,4 @@
+use std::collections::VecDeque;
 use std::path::Path;
 use std::process::Stdio;
 use std::sync::atomic::AtomicI64;

@@ -47,6 +48,7 @@ pub struct McpProcess {
     process: Child,
     stdin: ChildStdin,
     stdout: BufReader<ChildStdout>,
+    pending_user_messages: VecDeque<JSONRPCNotification>,
 }

 impl McpProcess {

@@ -117,6 +119,7 @@ impl McpProcess {
             process,
             stdin,
             stdout,
+            pending_user_messages: VecDeque::new(),
         })
     }

@@ -375,8 +378,9 @@ impl McpProcess {
         let message = self.read_jsonrpc_message().await?;

         match message {
-            JSONRPCMessage::Notification(_) => {
-                eprintln!("notification: {message:?}");
+            JSONRPCMessage::Notification(notification) => {
+                eprintln!("notification: {notification:?}");
+                self.enqueue_user_message(notification);
             }
             JSONRPCMessage::Request(jsonrpc_request) => {
                 return jsonrpc_request.try_into().with_context(

@@ -402,8 +406,9 @@ impl McpProcess {
         loop {
             let message = self.read_jsonrpc_message().await?;
             match message {
-                JSONRPCMessage::Notification(_) => {
-                    eprintln!("notification: {message:?}");
+                JSONRPCMessage::Notification(notification) => {
+                    eprintln!("notification: {notification:?}");
+                    self.enqueue_user_message(notification);
                }
                JSONRPCMessage::Request(_) => {
                    anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");

@@ -427,8 +432,9 @@ impl McpProcess {
         loop {
             let message = self.read_jsonrpc_message().await?;
             match message {
-                JSONRPCMessage::Notification(_) => {
-                    eprintln!("notification: {message:?}");
+                JSONRPCMessage::Notification(notification) => {
+                    eprintln!("notification: {notification:?}");
+                    self.enqueue_user_message(notification);
                }
                JSONRPCMessage::Request(_) => {
                    anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");

@@ -451,6 +457,10 @@ impl McpProcess {
     ) -> anyhow::Result<JSONRPCNotification> {
         eprintln!("in read_stream_until_notification_message({method})");

+        if let Some(notification) = self.take_pending_notification_by_method(method) {
+            return Ok(notification);
+        }
+
         loop {
             let message = self.read_jsonrpc_message().await?;
             match message {

@@ -458,6 +468,7 @@ impl McpProcess {
                 if notification.method == method {
                     return Ok(notification);
                 }
+                self.enqueue_user_message(notification);
             }
             JSONRPCMessage::Request(_) => {
                 anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");

@@ -471,4 +482,21 @@ impl McpProcess {
             }
         }
     }
+
+    fn take_pending_notification_by_method(&mut self, method: &str) -> Option<JSONRPCNotification> {
+        if let Some(pos) = self
+            .pending_user_messages
+            .iter()
+            .position(|notification| notification.method == method)
+        {
+            return self.pending_user_messages.remove(pos);
+        }
+        None
+    }
+
+    fn enqueue_user_message(&mut self, notification: JSONRPCNotification) {
+        if notification.method == "codex/event/user_message" {
+            self.pending_user_messages.push_back(notification);
+        }
+    }
 }
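
The net effect of the queueing above: `codex/event/user_message` notifications that arrive while the test harness is waiting for some other message are buffered instead of dropped, and a later call to `read_stream_until_notification_message("codex/event/user_message")` drains that buffer before touching the stream. This is what lets the new sandbox/cwd test below poll repeatedly for environment-context updates without losing events.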
@@ -8,6 +8,7 @@ use app_test_support::to_response;
 use codex_app_server_protocol::AddConversationListenerParams;
 use codex_app_server_protocol::AddConversationSubscriptionResponse;
 use codex_app_server_protocol::ExecCommandApprovalParams;
+use codex_app_server_protocol::InputItem;
 use codex_app_server_protocol::JSONRPCNotification;
 use codex_app_server_protocol::JSONRPCResponse;
 use codex_app_server_protocol::NewConversationParams;

@@ -25,6 +26,10 @@ use codex_core::protocol::SandboxPolicy;
 use codex_core::protocol_config_types::ReasoningEffort;
 use codex_core::protocol_config_types::ReasoningSummary;
 use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
+use codex_protocol::config_types::SandboxMode;
+use codex_protocol::protocol::Event;
+use codex_protocol::protocol::EventMsg;
+use codex_protocol::protocol::InputMessageKind;
 use pretty_assertions::assert_eq;
 use std::env;
 use tempfile::TempDir;

@@ -367,6 +372,234 @@ async fn test_send_user_turn_changes_approval_policy_behavior() {
 }

 // Helper: minimal config.toml pointing at mock provider.

+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn test_send_user_turn_updates_sandbox_and_cwd_between_turns() {
+    if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let tmp = TempDir::new().expect("tmp dir");
+    let codex_home = tmp.path().join("codex_home");
+    std::fs::create_dir(&codex_home).expect("create codex home dir");
+    let workspace_root = tmp.path().join("workspace");
+    std::fs::create_dir(&workspace_root).expect("create workspace root");
+    let first_cwd = workspace_root.join("turn1");
+    let second_cwd = workspace_root.join("turn2");
+    std::fs::create_dir(&first_cwd).expect("create first cwd");
+    std::fs::create_dir(&second_cwd).expect("create second cwd");
+
+    let responses = vec![
+        create_shell_sse_response(
+            vec![
+                "bash".to_string(),
+                "-lc".to_string(),
+                "echo first turn".to_string(),
+            ],
+            None,
+            Some(5000),
+            "call-first",
+        )
+        .expect("create first shell response"),
+        create_final_assistant_message_sse_response("done first")
+            .expect("create first final assistant message"),
+        create_shell_sse_response(
+            vec![
+                "bash".to_string(),
+                "-lc".to_string(),
+                "echo second turn".to_string(),
+            ],
+            None,
+            Some(5000),
+            "call-second",
+        )
+        .expect("create second shell response"),
+        create_final_assistant_message_sse_response("done second")
+            .expect("create second final assistant message"),
+    ];
+    let server = create_mock_chat_completions_server(responses).await;
+    create_config_toml(&codex_home, &server.uri()).expect("write config");
+
+    let mut mcp = McpProcess::new(&codex_home)
+        .await
+        .expect("spawn mcp process");
+    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
+        .await
+        .expect("init timeout")
+        .expect("init failed");
+
+    let new_conv_id = mcp
+        .send_new_conversation_request(NewConversationParams {
+            cwd: Some(first_cwd.to_string_lossy().into_owned()),
+            approval_policy: Some(AskForApproval::Never),
+            sandbox: Some(SandboxMode::WorkspaceWrite),
+            ..Default::default()
+        })
+        .await
+        .expect("send newConversation");
+    let new_conv_resp: JSONRPCResponse = timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(new_conv_id)),
+    )
+    .await
+    .expect("newConversation timeout")
+    .expect("newConversation resp");
+    let NewConversationResponse {
+        conversation_id,
+        model,
+        ..
+    } = to_response::<NewConversationResponse>(new_conv_resp)
+        .expect("deserialize newConversation response");
+
+    let add_listener_id = mcp
+        .send_add_conversation_listener_request(AddConversationListenerParams { conversation_id })
+        .await
+        .expect("send addConversationListener");
+    timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(add_listener_id)),
+    )
+    .await
+    .expect("addConversationListener timeout")
+    .expect("addConversationListener resp");
+
+    let first_turn_id = mcp
+        .send_send_user_turn_request(SendUserTurnParams {
+            conversation_id,
+            items: vec![InputItem::Text {
+                text: "first turn".to_string(),
+            }],
+            cwd: first_cwd.clone(),
+            approval_policy: AskForApproval::Never,
+            sandbox_policy: SandboxPolicy::WorkspaceWrite {
+                writable_roots: vec![first_cwd.clone()],
+                network_access: false,
+                exclude_tmpdir_env_var: false,
+                exclude_slash_tmp: false,
+            },
+            model: model.clone(),
+            effort: Some(ReasoningEffort::Medium),
+            summary: ReasoningSummary::Auto,
+        })
+        .await
+        .expect("send first sendUserTurn");
+    timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(first_turn_id)),
+    )
+    .await
+    .expect("sendUserTurn 1 timeout")
+    .expect("sendUserTurn 1 resp");
+    timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_notification_message("codex/event/task_complete"),
+    )
+    .await
+    .expect("task_complete 1 timeout")
+    .expect("task_complete 1 notification");
+
+    let second_turn_id = mcp
+        .send_send_user_turn_request(SendUserTurnParams {
+            conversation_id,
+            items: vec![InputItem::Text {
+                text: "second turn".to_string(),
+            }],
+            cwd: second_cwd.clone(),
+            approval_policy: AskForApproval::Never,
+            sandbox_policy: SandboxPolicy::DangerFullAccess,
+            model: model.clone(),
+            effort: Some(ReasoningEffort::Medium),
+            summary: ReasoningSummary::Auto,
+        })
+        .await
+        .expect("send second sendUserTurn");
+    timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(second_turn_id)),
+    )
+    .await
+    .expect("sendUserTurn 2 timeout")
+    .expect("sendUserTurn 2 resp");
+
+    let mut env_message: Option<String> = None;
+    let second_cwd_str = second_cwd.to_string_lossy().into_owned();
+    for _ in 0..10 {
+        let notification = timeout(
+            DEFAULT_READ_TIMEOUT,
+            mcp.read_stream_until_notification_message("codex/event/user_message"),
+        )
+        .await
+        .expect("user_message timeout")
+        .expect("user_message notification");
+        let params = notification
+            .params
+            .clone()
+            .expect("user_message should include params");
+        let event: Event = serde_json::from_value(params).expect("deserialize user_message event");
+        if let EventMsg::UserMessage(user) = event.msg
+            && matches!(user.kind, Some(InputMessageKind::EnvironmentContext))
+            && user.message.contains(&second_cwd_str)
+        {
+            env_message = Some(user.message);
+            break;
+        }
+    }
+    let env_message = env_message.expect("expected environment context update");
+    assert!(
+        env_message.contains("<sandbox_mode>danger-full-access</sandbox_mode>"),
+        "env context should reflect new sandbox mode: {env_message}"
+    );
+    assert!(
+        env_message.contains("<network_access>enabled</network_access>"),
+        "env context should enable network access for danger-full-access policy: {env_message}"
+    );
+    assert!(
+        env_message.contains(&second_cwd_str),
+        "env context should include updated cwd: {env_message}"
+    );
+
+    let exec_begin_notification = timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_notification_message("codex/event/exec_command_begin"),
+    )
+    .await
+    .expect("exec_command_begin timeout")
+    .expect("exec_command_begin notification");
+    let params = exec_begin_notification
+        .params
+        .clone()
+        .expect("exec_command_begin params");
+    let event: Event = serde_json::from_value(params).expect("deserialize exec begin event");
+    let exec_begin = match event.msg {
+        EventMsg::ExecCommandBegin(exec_begin) => exec_begin,
+        other => panic!("expected ExecCommandBegin event, got {other:?}"),
+    };
+    assert_eq!(
+        exec_begin.cwd, second_cwd,
+        "exec turn should run from updated cwd"
+    );
+    assert_eq!(
+        exec_begin.command,
+        vec![
+            "bash".to_string(),
+            "-lc".to_string(),
+            "echo second turn".to_string()
+        ],
+        "exec turn should run expected command"
+    );
+
+    timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_notification_message("codex/event/task_complete"),
+    )
+    .await
+    .expect("task_complete 2 timeout")
+    .expect("task_complete 2 notification");
+}
+
 fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
     let config_toml = codex_home.join("config.toml");
     std::fs::write(
@@ -1,3 +1,5 @@
+use anyhow::Context;
+use anyhow::Result;
 use app_test_support::McpProcess;
 use codex_app_server_protocol::JSONRPCResponse;
 use codex_app_server_protocol::RequestId;

@@ -9,30 +11,41 @@ use tokio::time::timeout;
 const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
-async fn test_fuzzy_file_search_sorts_and_includes_indices() {
+async fn test_fuzzy_file_search_sorts_and_includes_indices() -> Result<()> {
     // Prepare a temporary Codex home and a separate root with test files.
-    let codex_home = TempDir::new().expect("create temp codex home");
-    let root = TempDir::new().expect("create temp search root");
+    let codex_home = TempDir::new().context("create temp codex home")?;
+    let root = TempDir::new().context("create temp search root")?;

-    // Create files designed to have deterministic ordering for query "abc".
-    std::fs::write(root.path().join("abc"), "x").expect("write file abc");
-    std::fs::write(root.path().join("abcde"), "x").expect("write file abcx");
-    std::fs::write(root.path().join("abexy"), "x").expect("write file abcx");
-    std::fs::write(root.path().join("zzz.txt"), "x").expect("write file zzz");
+    // Create files designed to have deterministic ordering for query "abe".
+    std::fs::write(root.path().join("abc"), "x").context("write file abc")?;
+    std::fs::write(root.path().join("abcde"), "x").context("write file abcde")?;
+    std::fs::write(root.path().join("abexy"), "x").context("write file abexy")?;
+    std::fs::write(root.path().join("zzz.txt"), "x").context("write file zzz")?;
+    let sub_dir = root.path().join("sub");
+    std::fs::create_dir_all(&sub_dir).context("create sub dir")?;
+    let sub_abce_path = sub_dir.join("abce");
+    std::fs::write(&sub_abce_path, "x").context("write file sub/abce")?;
+    let sub_abce_rel = sub_abce_path
+        .strip_prefix(root.path())
+        .context("strip root prefix from sub/abce")?
+        .to_string_lossy()
+        .to_string();

     // Start MCP server and initialize.
-    let mut mcp = McpProcess::new(codex_home.path()).await.expect("spawn mcp");
+    let mut mcp = McpProcess::new(codex_home.path())
+        .await
+        .context("spawn mcp")?;
     timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
         .await
-        .expect("init timeout")
-        .expect("init failed");
+        .context("init timeout")?
+        .context("init failed")?;

     let root_path = root.path().to_string_lossy().to_string();
     // Send fuzzyFileSearch request.
     let request_id = mcp
         .send_fuzzy_file_search_request("abe", vec![root_path.clone()], None)
         .await
-        .expect("send fuzzyFileSearch");
+        .context("send fuzzyFileSearch")?;

     // Read response and verify shape and ordering.
     let resp: JSONRPCResponse = timeout(

@@ -40,39 +53,65 @@ async fn test_fuzzy_file_search_sorts_and_includes_indices() {
         mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
     )
     .await
-    .expect("fuzzyFileSearch timeout")
-    .expect("fuzzyFileSearch resp");
+    .context("fuzzyFileSearch timeout")?
+    .context("fuzzyFileSearch resp")?;

     let value = resp.result;
+    // The path separator on Windows affects the score.
+    let expected_score = if cfg!(windows) { 69 } else { 72 };
+
     assert_eq!(
         value,
         json!({
             "files": [
-                { "root": root_path.clone(), "path": "abexy", "score": 88, "indices": [0, 1, 2] },
-                { "root": root_path.clone(), "path": "abcde", "score": 74, "indices": [0, 1, 4] },
+                {
+                    "root": root_path.clone(),
+                    "path": "abexy",
+                    "file_name": "abexy",
+                    "score": 88,
+                    "indices": [0, 1, 2],
+                },
+                {
+                    "root": root_path.clone(),
+                    "path": "abcde",
+                    "file_name": "abcde",
+                    "score": 74,
+                    "indices": [0, 1, 4],
+                },
+                {
+                    "root": root_path.clone(),
+                    "path": sub_abce_rel,
+                    "file_name": "abce",
+                    "score": expected_score,
+                    "indices": [4, 5, 7],
+                },
             ]
         })
     );
+
+    Ok(())
 }

 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
-async fn test_fuzzy_file_search_accepts_cancellation_token() {
-    let codex_home = TempDir::new().expect("create temp codex home");
-    let root = TempDir::new().expect("create temp search root");
+async fn test_fuzzy_file_search_accepts_cancellation_token() -> Result<()> {
+    let codex_home = TempDir::new().context("create temp codex home")?;
+    let root = TempDir::new().context("create temp search root")?;

-    std::fs::write(root.path().join("alpha.txt"), "contents").expect("write alpha");
+    std::fs::write(root.path().join("alpha.txt"), "contents").context("write alpha")?;

-    let mut mcp = McpProcess::new(codex_home.path()).await.expect("spawn mcp");
+    let mut mcp = McpProcess::new(codex_home.path())
+        .await
+        .context("spawn mcp")?;
     timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
         .await
-        .expect("init timeout")
-        .expect("init failed");
+        .context("init timeout")?
+        .context("init failed")?;

     let root_path = root.path().to_string_lossy().to_string();
     let request_id = mcp
         .send_fuzzy_file_search_request("alp", vec![root_path.clone()], None)
         .await
-        .expect("send fuzzyFileSearch");
+        .context("send fuzzyFileSearch")?;

     let request_id_2 = mcp
         .send_fuzzy_file_search_request(

@@ -81,24 +120,27 @@ async fn test_fuzzy_file_search_accepts_cancellation_token() {
             Some(request_id.to_string()),
         )
         .await
-        .expect("send fuzzyFileSearch");
+        .context("send fuzzyFileSearch")?;

     let resp: JSONRPCResponse = timeout(
         DEFAULT_READ_TIMEOUT,
         mcp.read_stream_until_response_message(RequestId::Integer(request_id_2)),
     )
     .await
-    .expect("fuzzyFileSearch timeout")
-    .expect("fuzzyFileSearch resp");
+    .context("fuzzyFileSearch timeout")?
+    .context("fuzzyFileSearch resp")?;

     let files = resp
         .result
         .get("files")
-        .and_then(|value| value.as_array())
-        .cloned()
-        .expect("files array");
+        .context("files key missing")?
+        .as_array()
+        .context("files not array")?
+        .clone();

     assert_eq!(files.len(), 1);
     assert_eq!(files[0]["root"], root_path);
     assert_eq!(files[0]["path"], "alpha.txt");
+
+    Ok(())
 }
@@ -23,5 +23,6 @@ tree-sitter-bash = { workspace = true }

 [dev-dependencies]
 assert_cmd = { workspace = true }
+assert_matches = { workspace = true }
 pretty_assertions = { workspace = true }
 tempfile = { workspace = true }
@@ -843,6 +843,7 @@ pub fn print_summary(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use assert_matches::assert_matches;
     use pretty_assertions::assert_eq;
     use std::fs;
     use std::string::ToString;

@@ -894,10 +895,10 @@ mod tests {

     fn assert_not_match(script: &str) {
         let args = args_bash(script);
-        assert!(matches!(
+        assert_matches!(
             maybe_parse_apply_patch(&args),
             MaybeApplyPatch::NotApplyPatch
-        ));
+        );
     }

     #[test]

@@ -905,10 +906,10 @@ mod tests {
         let patch = "*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch".to_string();
         let args = vec![patch];
         let dir = tempdir().unwrap();
-        assert!(matches!(
+        assert_matches!(
             maybe_parse_apply_patch_verified(&args, dir.path()),
             MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation)
-        ));
+        );
     }

     #[test]

@@ -916,10 +917,10 @@ mod tests {
         let script = "*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch";
         let args = args_bash(script);
         let dir = tempdir().unwrap();
-        assert!(matches!(
+        assert_matches!(
             maybe_parse_apply_patch_verified(&args, dir.path()),
             MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation)
-        ));
+        );
     }

     #[test]
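
The motivation for replacing `assert!(matches!(…))` with `assert_matches!` throughout is diagnostics: on failure, `assert_matches!` panics with the debug form of the actual value alongside the expected pattern, while the plain `assert!` form only reports that the expression evaluated to false. A minimal sketch (function name invented):

```rust
use assert_matches::assert_matches;

fn check(value: Option<u32>) {
    // On failure this prints the actual value, e.g. `Some(3)`,
    // instead of a bare "assertion failed" message.
    assert_matches!(value, Some(n) if n > 5);
}
```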
@@ -29,7 +29,8 @@ pub async fn run_apply_command(
             .parse_overrides()
             .map_err(anyhow::Error::msg)?,
         ConfigOverrides::default(),
-    )?;
+    )
+    .await?;

     init_chatgpt_token_from_auth(&config.codex_home).await?;
@@ -32,6 +32,7 @@ codex-app-server-protocol = { workspace = true }
 codex-protocol-ts = { workspace = true }
 codex-responses-api-proxy = { workspace = true }
 codex-tui = { workspace = true }
+codex-rmcp-client = { workspace = true }
 codex-cloud-tasks = { path = "../cloud-tasks" }
 ctor = { workspace = true }
 owo-colors = { workspace = true }

@@ -46,6 +47,7 @@ tokio = { workspace = true, features = [
 ] }

 [dev-dependencies]
+assert_matches = { workspace = true }
 assert_cmd = { workspace = true }
 predicates = { workspace = true }
 pretty_assertions = { workspace = true }
@@ -73,7 +73,8 @@ async fn run_command_under_sandbox(
             codex_linux_sandbox_exe,
             ..Default::default()
         },
-    )?;
+    )
+    .await?;

     // In practice, this should be `std::env::current_dir()` because this CLI
     // does not support `--cwd`, but let's use the config value for consistency.
@@ -9,6 +9,8 @@ use codex_core::config::ConfigOverrides;
 use codex_login::ServerOptions;
 use codex_login::run_device_code_login;
 use codex_login::run_login_server;
+use std::io::IsTerminal;
+use std::io::Read;
 use std::path::PathBuf;

 pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {

@@ -24,7 +26,7 @@ pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {
 }

 pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! {
-    let config = load_config_or_exit(cli_config_overrides);
+    let config = load_config_or_exit(cli_config_overrides).await;

     match login_with_chatgpt(config.codex_home).await {
         Ok(_) => {

@@ -42,7 +44,7 @@ pub async fn run_login_with_api_key(
     cli_config_overrides: CliConfigOverrides,
     api_key: String,
 ) -> ! {
-    let config = load_config_or_exit(cli_config_overrides);
+    let config = load_config_or_exit(cli_config_overrides).await;

     match login_with_api_key(&config.codex_home, &api_key) {
         Ok(_) => {

@@ -56,13 +58,40 @@ pub async fn run_login_with_api_key(
     }
 }

+pub fn read_api_key_from_stdin() -> String {
+    let mut stdin = std::io::stdin();
+
+    if stdin.is_terminal() {
+        eprintln!(
+            "--with-api-key expects the API key on stdin. Try piping it, e.g. `printenv OPENAI_API_KEY | codex login --with-api-key`."
+        );
+        std::process::exit(1);
+    }
+
+    eprintln!("Reading API key from stdin...");
+
+    let mut buffer = String::new();
+    if let Err(err) = stdin.read_to_string(&mut buffer) {
+        eprintln!("Failed to read API key from stdin: {err}");
+        std::process::exit(1);
+    }
+
+    let api_key = buffer.trim().to_string();
+    if api_key.is_empty() {
+        eprintln!("No API key provided via stdin.");
+        std::process::exit(1);
+    }
+
+    api_key
+}
+
 /// Login using the OAuth device code flow.
 pub async fn run_login_with_device_code(
     cli_config_overrides: CliConfigOverrides,
     issuer_base_url: Option<String>,
     client_id: Option<String>,
 ) -> ! {
-    let config = load_config_or_exit(cli_config_overrides);
+    let config = load_config_or_exit(cli_config_overrides).await;
     let mut opts = ServerOptions::new(
         config.codex_home,
         client_id.unwrap_or(CLIENT_ID.to_string()),

@@ -83,7 +112,7 @@ pub async fn run_login_with_device_code(
 }

 pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
-    let config = load_config_or_exit(cli_config_overrides);
+    let config = load_config_or_exit(cli_config_overrides).await;

     match CodexAuth::from_codex_home(&config.codex_home) {
         Ok(Some(auth)) => match auth.mode {

@@ -114,7 +143,7 @@ pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
 }

 pub async fn run_logout(cli_config_overrides: CliConfigOverrides) -> ! {
-    let config = load_config_or_exit(cli_config_overrides);
+    let config = load_config_or_exit(cli_config_overrides).await;

     match logout(&config.codex_home) {
         Ok(true) => {

@@ -132,7 +161,7 @@ pub async fn run_logout(cli_config_overrides: CliConfigOverrides) -> ! {
     }
 }

-fn load_config_or_exit(cli_config_overrides: CliConfigOverrides) -> Config {
+async fn load_config_or_exit(cli_config_overrides: CliConfigOverrides) -> Config {
     let cli_overrides = match cli_config_overrides.parse_overrides() {
         Ok(v) => v,
         Err(e) => {

@@ -142,7 +171,7 @@ fn load_config_or_exit(cli_config_overrides: CliConfigOverrides) -> Config {
     };

     let config_overrides = ConfigOverrides::default();
-    match Config::load_with_cli_overrides(cli_overrides, config_overrides) {
+    match Config::load_with_cli_overrides(cli_overrides, config_overrides).await {
         Ok(config) => config,
         Err(e) => {
             eprintln!("Error loading configuration: {e}");
@@ -7,6 +7,7 @@ use codex_chatgpt::apply_command::ApplyCommand;
 use codex_chatgpt::apply_command::run_apply_command;
 use codex_cli::LandlockCommand;
 use codex_cli::SeatbeltCommand;
+use codex_cli::login::read_api_key_from_stdin;
 use codex_cli::login::run_login_status;
 use codex_cli::login::run_login_with_api_key;
 use codex_cli::login::run_login_with_chatgpt;

@@ -75,8 +76,9 @@ enum Subcommand {
     /// Generate shell completion scripts.
     Completion(CompletionCommand),

-    /// Internal debugging commands.
-    Debug(DebugArgs),
+    /// Run commands within a Codex-provided sandbox.
+    #[clap(visible_alias = "debug")]
+    Sandbox(SandboxArgs),

     /// Apply the latest diff produced by Codex agent as a `git apply` to your local working tree.
     #[clap(visible_alias = "a")]

@@ -120,18 +122,20 @@ struct ResumeCommand {
 }

 #[derive(Debug, Parser)]
-struct DebugArgs {
+struct SandboxArgs {
     #[command(subcommand)]
-    cmd: DebugCommand,
+    cmd: SandboxCommand,
 }

 #[derive(Debug, clap::Subcommand)]
-enum DebugCommand {
+enum SandboxCommand {
     /// Run a command under Seatbelt (macOS only).
-    Seatbelt(SeatbeltCommand),
+    #[clap(visible_alias = "seatbelt")]
+    Macos(SeatbeltCommand),

     /// Run a command under Landlock+seccomp (Linux only).
-    Landlock(LandlockCommand),
+    #[clap(visible_alias = "landlock")]
+    Linux(LandlockCommand),
 }

 #[derive(Debug, Parser)]

@@ -139,7 +143,18 @@ struct LoginCommand {
     #[clap(skip)]
     config_overrides: CliConfigOverrides,

-    #[arg(long = "api-key", value_name = "API_KEY")]
+    #[arg(
+        long = "with-api-key",
+        help = "Read the API key from stdin (e.g. `printenv OPENAI_API_KEY | codex login --with-api-key`)"
+    )]
+    with_api_key: bool,
+
+    #[arg(
+        long = "api-key",
+        value_name = "API_KEY",
+        help = "(deprecated) Previously accepted the API key directly; now exits with guidance to use --with-api-key",
+        hide = true
+    )]
     api_key: Option<String>,

     /// EXPERIMENTAL: Use device code flow (not yet supported)

@@ -298,7 +313,13 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
                 login_cli.client_id,
             )
             .await;
-        } else if let Some(api_key) = login_cli.api_key {
+        } else if login_cli.api_key.is_some() {
+            eprintln!(
+                "The --api-key flag is no longer supported. Pipe the key instead, e.g. `printenv OPENAI_API_KEY | codex login --with-api-key`."
+            );
+            std::process::exit(1);
+        } else if login_cli.with_api_key {
+            let api_key = read_api_key_from_stdin();
             run_login_with_api_key(login_cli.config_overrides, api_key).await;
         } else {
             run_login_with_chatgpt(login_cli.config_overrides).await;

@@ -323,8 +344,8 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
             );
             codex_cloud_tasks::run_main(cloud_cli, codex_linux_sandbox_exe).await?;
         }
-        Some(Subcommand::Debug(debug_args)) => match debug_args.cmd {
-            DebugCommand::Seatbelt(mut seatbelt_cli) => {
+        Some(Subcommand::Sandbox(sandbox_args)) => match sandbox_args.cmd {
+            SandboxCommand::Macos(mut seatbelt_cli) => {
                 prepend_config_flags(
                     &mut seatbelt_cli.config_overrides,
                     root_config_overrides.clone(),

@@ -335,7 +356,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
                 )
                 .await?;
             }
-            DebugCommand::Landlock(mut landlock_cli) => {
+            SandboxCommand::Linux(mut landlock_cli) => {
                 prepend_config_flags(
                     &mut landlock_cli.config_overrides,
                     root_config_overrides.clone(),

@@ -454,6 +475,7 @@ fn print_completion(cmd: CompletionCommand) {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use assert_matches::assert_matches;
     use codex_core::protocol::TokenUsage;
     use codex_protocol::ConversationId;

@@ -586,14 +608,14 @@ mod tests {
         assert_eq!(interactive.model.as_deref(), Some("gpt-5-test"));
         assert!(interactive.oss);
         assert_eq!(interactive.config_profile.as_deref(), Some("my-profile"));
-        assert!(matches!(
+        assert_matches!(
             interactive.sandbox_mode,
             Some(codex_common::SandboxModeCliArg::WorkspaceWrite)
-        ));
-        assert!(matches!(
+        );
+        assert_matches!(
             interactive.approval_policy,
             Some(codex_common::ApprovalModeCliArg::OnRequest)
-        ));
+        );
         assert!(interactive.full_auto);
         assert_eq!(
             interactive.cwd.as_deref(),
@@ -12,6 +12,8 @@ use codex_core::config::load_global_mcp_servers;
use codex_core::config::write_global_mcp_servers;
use codex_core::config_types::McpServerConfig;
use codex_core::config_types::McpServerTransportConfig;
use codex_rmcp_client::delete_oauth_tokens;
use codex_rmcp_client::perform_oauth_login;

/// [experimental] Launch Codex as an MCP server or manage configured MCP servers.
///
@@ -43,6 +45,14 @@ pub enum McpSubcommand {

/// [experimental] Remove a global MCP server entry.
Remove(RemoveArgs),

/// [experimental] Authenticate with a configured MCP server via OAuth.
/// Requires experimental_use_rmcp_client = true in config.toml.
Login(LoginArgs),

/// [experimental] Remove stored OAuth credentials for a server.
/// Requires experimental_use_rmcp_client = true in config.toml.
Logout(LogoutArgs),
}

#[derive(Debug, clap::Parser)]
@@ -82,6 +92,18 @@ pub struct RemoveArgs {
pub name: String,
}

#[derive(Debug, clap::Parser)]
pub struct LoginArgs {
/// Name of the MCP server to authenticate with oauth.
pub name: String,
}

#[derive(Debug, clap::Parser)]
pub struct LogoutArgs {
/// Name of the MCP server to deauthenticate.
pub name: String,
}

impl McpCli {
pub async fn run(self) -> Result<()> {
let McpCli {
@@ -91,16 +113,22 @@ impl McpCli {

match subcommand {
McpSubcommand::List(args) => {
run_list(&config_overrides, args)?;
run_list(&config_overrides, args).await?;
}
McpSubcommand::Get(args) => {
run_get(&config_overrides, args)?;
run_get(&config_overrides, args).await?;
}
McpSubcommand::Add(args) => {
run_add(&config_overrides, args)?;
run_add(&config_overrides, args).await?;
}
McpSubcommand::Remove(args) => {
run_remove(&config_overrides, args)?;
run_remove(&config_overrides, args).await?;
}
McpSubcommand::Login(args) => {
run_login(&config_overrides, args).await?;
}
McpSubcommand::Logout(args) => {
run_logout(&config_overrides, args).await?;
}
}

@@ -108,7 +136,7 @@ impl McpCli {
}
}

fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
// Validate any provided overrides even though they are not currently applied.
config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;

@@ -134,6 +162,7 @@ fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<(

let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
let mut servers = load_global_mcp_servers(&codex_home)
.await
.with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?;

let new_entry = McpServerConfig {
@@ -156,7 +185,7 @@ fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<(
Ok(())
}

fn run_remove(config_overrides: &CliConfigOverrides, remove_args: RemoveArgs) -> Result<()> {
async fn run_remove(config_overrides: &CliConfigOverrides, remove_args: RemoveArgs) -> Result<()> {
config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;

let RemoveArgs { name } = remove_args;
@@ -165,6 +194,7 @@ fn run_remove(config_overrides: &CliConfigOverrides, remove_args: RemoveArgs) ->

let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
let mut servers = load_global_mcp_servers(&codex_home)
.await
.with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?;

let removed = servers.remove(&name).is_some();
@@ -183,9 +213,65 @@ fn run_remove(config_overrides: &CliConfigOverrides, remove_args: RemoveArgs) ->
Ok(())
}

fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) -> Result<()> {
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
.await
.context("failed to load configuration")?;

if !config.use_experimental_use_rmcp_client {
bail!(
"OAuth login is only supported when experimental_use_rmcp_client is true in config.toml."
);
}

let LoginArgs { name } = login_args;

let Some(server) = config.mcp_servers.get(&name) else {
bail!("No MCP server named '{name}' found.");
};

let url = match &server.transport {
McpServerTransportConfig::StreamableHttp { url, .. } => url.clone(),
_ => bail!("OAuth login is only supported for streamable HTTP servers."),
};

perform_oauth_login(&name, &url).await?;
println!("Successfully logged in to MCP server '{name}'.");
Ok(())
}

async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutArgs) -> Result<()> {
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
.await
.context("failed to load configuration")?;

let LogoutArgs { name } = logout_args;

let server = config
.mcp_servers
.get(&name)
.ok_or_else(|| anyhow!("No MCP server named '{name}' found in configuration."))?;

let url = match &server.transport {
McpServerTransportConfig::StreamableHttp { url, .. } => url.clone(),
_ => bail!("OAuth logout is only supported for streamable_http transports."),
};

match delete_oauth_tokens(&name, &url) {
Ok(true) => println!("Removed OAuth credentials for '{name}'."),
Ok(false) => println!("No OAuth credentials stored for '{name}'."),
Err(err) => return Err(anyhow!("failed to delete OAuth credentials: {err}")),
}

Ok(())
}

async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
.await
.context("failed to load configuration")?;

let mut entries: Vec<_> = config.mcp_servers.iter().collect();
@@ -343,9 +429,10 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
Ok(())
}

fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<()> {
async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<()> {
let overrides = config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
let config = Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
.await
.context("failed to load configuration")?;

let Some(server) = config.mcp_servers.get(&get_args.name) else {

@@ -13,8 +13,8 @@ fn codex_command(codex_home: &Path) -> Result<assert_cmd::Command> {
Ok(cmd)
}

#[test]
fn add_and_remove_server_updates_global_config() -> Result<()> {
#[tokio::test]
async fn add_and_remove_server_updates_global_config() -> Result<()> {
let codex_home = TempDir::new()?;

let mut add_cmd = codex_command(codex_home.path())?;
@@ -24,7 +24,7 @@ fn add_and_remove_server_updates_global_config() -> Result<()> {
.success()
.stdout(contains("Added global MCP server 'docs'."));

let servers = load_global_mcp_servers(codex_home.path())?;
let servers = load_global_mcp_servers(codex_home.path()).await?;
assert_eq!(servers.len(), 1);
let docs = servers.get("docs").expect("server should exist");
match &docs.transport {
@@ -43,7 +43,7 @@ fn add_and_remove_server_updates_global_config() -> Result<()> {
.success()
.stdout(contains("Removed global MCP server 'docs'."));

let servers = load_global_mcp_servers(codex_home.path())?;
let servers = load_global_mcp_servers(codex_home.path()).await?;
assert!(servers.is_empty());

let mut remove_again_cmd = codex_command(codex_home.path())?;
@@ -53,14 +53,14 @@ fn add_and_remove_server_updates_global_config() -> Result<()> {
.success()
.stdout(contains("No MCP server named 'docs' found."));

let servers = load_global_mcp_servers(codex_home.path())?;
let servers = load_global_mcp_servers(codex_home.path()).await?;
assert!(servers.is_empty());

Ok(())
}

#[test]
fn add_with_env_preserves_key_order_and_values() -> Result<()> {
#[tokio::test]
async fn add_with_env_preserves_key_order_and_values() -> Result<()> {
let codex_home = TempDir::new()?;

let mut add_cmd = codex_command(codex_home.path())?;
@@ -80,7 +80,7 @@ fn add_with_env_preserves_key_order_and_values() -> Result<()> {
.assert()
.success();

let servers = load_global_mcp_servers(codex_home.path())?;
let servers = load_global_mcp_servers(codex_home.path()).await?;
let envy = servers.get("envy").expect("server should exist");
let env = match &envy.transport {
McpServerTransportConfig::Stdio { env: Some(env), .. } => env,

@@ -1,7 +1,7 @@
[package]
edition = "2024"
name = "codex-cloud-tasks"
version = { workspace = true }
edition = "2024"

[lib]
name = "codex_cloud_tasks"
@@ -12,25 +12,27 @@ workspace = true

[dependencies]
anyhow = "1"
clap = { version = "4", features = ["derive"] }
codex-common = { path = "../common", features = ["cli"] }
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
tracing = { version = "0.1.41", features = ["log"] }
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
codex-cloud-tasks-client = { path = "../cloud-tasks-client", features = ["mock", "online"] }
ratatui = { version = "0.29.0" }
crossterm = { version = "0.28.1", features = ["event-stream"] }
tokio-stream = "0.1.17"
chrono = { version = "0.4", features = ["serde"] }
codex-login = { path = "../login" }
codex-core = { path = "../core" }
throbber-widgets-tui = "0.8.0"
base64 = "0.22"
serde_json = "1"
chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4", features = ["derive"] }
codex-cloud-tasks-client = { path = "../cloud-tasks-client", features = [
"mock",
"online",
] }
codex-common = { path = "../common", features = ["cli"] }
codex-core = { path = "../core" }
codex-login = { path = "../login" }
codex-tui = { path = "../tui" }
crossterm = { version = "0.28.1", features = ["event-stream"] }
ratatui = { version = "0.29.0" }
reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
tokio-stream = "0.1.17"
tracing = { version = "0.1.41", features = ["log"] }
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
unicode-width = "0.1"
codex-tui = { path = "../tui" }

[dev-dependencies]
async-trait = "0.1"

@@ -1,4 +1,5 @@
use std::time::Duration;
use std::time::Instant;

// Environment filter data models for the TUI
#[derive(Clone, Debug, Default)]
@@ -42,15 +43,13 @@ use crate::scrollable_diff::ScrollableDiff;
use codex_cloud_tasks_client::CloudBackend;
use codex_cloud_tasks_client::TaskId;
use codex_cloud_tasks_client::TaskSummary;
use throbber_widgets_tui::ThrobberState;

#[derive(Default)]
pub struct App {
pub tasks: Vec<TaskSummary>,
pub selected: usize,
pub status: String,
pub diff_overlay: Option<DiffOverlay>,
pub throbber: ThrobberState,
pub spinner_start: Option<Instant>,
pub refresh_inflight: bool,
pub details_inflight: bool,
// Environment filter state
@@ -82,7 +81,7 @@ impl App {
selected: 0,
status: "Press r to refresh".to_string(),
diff_overlay: None,
throbber: ThrobberState::default(),
spinner_start: None,
refresh_inflight: false,
details_inflight: false,
env_filter: None,

@@ -400,16 +400,20 @@ pub async fn run_main(_cli: Cli, _codex_linux_sandbox_exe: Option<PathBuf>) -> a
let _ = frame_tx.send(Instant::now() + codex_tui::ComposerInput::recommended_flush_delay());
}
}
// Advance throbber only while loading.
// Keep spinner pulsing only while loading.
if app.refresh_inflight
|| app.details_inflight
|| app.env_loading
|| app.apply_preflight_inflight
|| app.apply_inflight
{
app.throbber.calc_next();
if app.spinner_start.is_none() {
app.spinner_start = Some(Instant::now());
}
needs_redraw = true;
let _ = frame_tx.send(Instant::now() + Duration::from_millis(100));
let _ = frame_tx.send(Instant::now() + Duration::from_millis(600));
} else {
app.spinner_start = None;
}
render_if_needed(&mut terminal, &mut app, &mut needs_redraw)?;
}

@@ -16,6 +16,7 @@ use ratatui::widgets::ListState;
use ratatui::widgets::Padding;
use ratatui::widgets::Paragraph;
use std::sync::OnceLock;
use std::time::Instant;

use crate::app::App;
use crate::app::AttemptView;
@@ -229,7 +230,7 @@ fn draw_list(frame: &mut Frame, area: Rect, app: &mut App) {

// In-box spinner during initial/refresh loads
if app.refresh_inflight {
draw_centered_spinner(frame, inner, &mut app.throbber, "Loading tasks…");
draw_centered_spinner(frame, inner, &mut app.spinner_start, "Loading tasks…");
}
}

@@ -291,7 +292,7 @@ fn draw_footer(frame: &mut Frame, area: Rect, app: &mut App) {
|| app.apply_preflight_inflight
|| app.apply_inflight
{
draw_inline_spinner(frame, top[1], &mut app.throbber, "Loading…");
draw_inline_spinner(frame, top[1], &mut app.spinner_start, "Loading…");
} else {
frame.render_widget(Clear, top[1]);
}
@@ -449,7 +450,12 @@ fn draw_diff_overlay(frame: &mut Frame, area: Rect, app: &mut App) {
.map(|o| o.sd.wrapped_lines().is_empty())
.unwrap_or(true);
if app.details_inflight && raw_empty {
draw_centered_spinner(frame, content_area, &mut app.throbber, "Loading details…");
draw_centered_spinner(
frame,
content_area,
&mut app.spinner_start,
"Loading details…",
);
} else {
let scroll = app
.diff_overlay
@@ -494,11 +500,11 @@ pub fn draw_apply_modal(frame: &mut Frame, area: Rect, app: &mut App) {
frame.render_widget(header, rows[0]);
// Body: spinner while preflight/apply runs; otherwise show result message and path lists
if app.apply_preflight_inflight {
draw_centered_spinner(frame, rows[1], &mut app.throbber, "Checking…");
draw_centered_spinner(frame, rows[1], &mut app.spinner_start, "Checking…");
} else if app.apply_inflight {
draw_centered_spinner(frame, rows[1], &mut app.throbber, "Applying…");
draw_centered_spinner(frame, rows[1], &mut app.spinner_start, "Applying…");
} else if m.result_message.is_none() {
draw_centered_spinner(frame, rows[1], &mut app.throbber, "Loading…");
draw_centered_spinner(frame, rows[1], &mut app.spinner_start, "Loading…");
} else if let Some(msg) = &m.result_message {
let mut body_lines: Vec<Line> = Vec::new();
let first = match m.result_level {
@@ -859,29 +865,29 @@ fn format_relative_time(ts: chrono::DateTime<Utc>) -> String {
fn draw_inline_spinner(
frame: &mut Frame,
area: Rect,
state: &mut throbber_widgets_tui::ThrobberState,
spinner_start: &mut Option<Instant>,
label: &str,
) {
use ratatui::style::Style;
use throbber_widgets_tui::BRAILLE_EIGHT;
use throbber_widgets_tui::Throbber;
use throbber_widgets_tui::WhichUse;
let w = Throbber::default()
.label(label)
.style(Style::default().cyan())
.throbber_style(Style::default().magenta().bold())
.throbber_set(BRAILLE_EIGHT)
.use_type(WhichUse::Spin);
frame.render_stateful_widget(w, area, state);
use ratatui::widgets::Paragraph;
let start = spinner_start.get_or_insert_with(Instant::now);
let blink_on = (start.elapsed().as_millis() / 600).is_multiple_of(2);
let dot = if blink_on {
"• ".into()
} else {
"◦ ".dim()
};
let label = label.cyan();
let line = Line::from(vec![dot, label]);
frame.render_widget(Paragraph::new(line), area);
}
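
A standalone sketch of the cadence behind this replacement (illustrative only): deriving the blink state from elapsed wall time means no ThrobberState has to tick every frame, which is why the frame timer above can relax from 100 ms to 600 ms.

use std::thread::sleep;
use std::time::{Duration, Instant};

// Same arithmetic as the diff, written with a plain modulus:
// the dot is "on" during even 600 ms windows since the spinner started.
fn blink_on(start: Instant) -> bool {
    (start.elapsed().as_millis() / 600) % 2 == 0
}

fn main() {
    let start = Instant::now();
    assert!(blink_on(start)); // window 0 is "on"
    sleep(Duration::from_millis(650));
    assert!(!blink_on(start)); // window 1 is "off"
}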

fn draw_centered_spinner(
frame: &mut Frame,
area: Rect,
state: &mut throbber_widgets_tui::ThrobberState,
spinner_start: &mut Option<Instant>,
label: &str,
) {
// Center a 1xN throbber within the given rect
// Center a 1xN spinner within the given rect
let rows = Layout::default()
.direction(Direction::Vertical)
.constraints([
@@ -898,7 +904,7 @@ fn draw_centered_spinner(
Constraint::Percentage(50),
])
.split(rows[1]);
draw_inline_spinner(frame, cols[1], state, label);
draw_inline_spinner(frame, cols[1], spinner_start, label);
}

// Styling helpers for diff rendering live inline where used.
@@ -918,7 +924,12 @@ pub fn draw_env_modal(frame: &mut Frame, area: Rect, app: &mut App) {
let content = overlay_content(inner);

if app.env_loading {
draw_centered_spinner(frame, content, &mut app.throbber, "Loading environments…");
draw_centered_spinner(
frame,
content,
&mut app.spinner_start,
"Loading environments…",
);
return;
}


@@ -19,13 +19,14 @@ async-trait = { workspace = true }
base64 = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
codex-app-server-protocol = { workspace = true }
codex-apply-patch = { workspace = true }
codex-file-search = { workspace = true }
codex-mcp-client = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-protocol = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-otel = { workspace = true, features = ["otel"] }
codex-protocol = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-utils-string = { workspace = true }
dirs = { workspace = true }
dunce = { workspace = true }
env-flags = { workspace = true }
@@ -75,6 +76,24 @@ wildmatch = { workspace = true }
landlock = { workspace = true }
seccompiler = { workspace = true }

[target.'cfg(target_os = "macos")'.dependencies]
core-foundation = "0.9"

[target.'cfg(windows)'.dependencies]
windows = { version = "0.58", features = [
"Win32_Foundation",
"Win32_Security_Isolation",
"Win32_Security",
"Win32_Security_Authorization",
"Win32_Storage_FileSystem",
"Win32_System_Memory",
"Win32_System_Threading",
] }

[features]
default = []
windows_appcontainer_command_ext = []

# Build OpenSSL from source for musl builds.
[target.x86_64-unknown-linux-musl.dependencies]
openssl-sys = { workspace = true, features = ["vendored"] }
@@ -85,16 +104,18 @@ openssl-sys = { workspace = true, features = ["vendored"] }

[dev-dependencies]
assert_cmd = { workspace = true }
assert_matches = { workspace = true }
core_test_support = { workspace = true }
escargot = { workspace = true }
maplit = { workspace = true }
predicates = { workspace = true }
pretty_assertions = { workspace = true }
serial_test = { workspace = true }
tempfile = { workspace = true }
tokio-test = { workspace = true }
tracing-test = { workspace = true, features = ["no-env-filter"] }
walkdir = { workspace = true }
wiremock = { workspace = true }
tracing-test = { workspace = true, features = ["no-env-filter"] }

[package.metadata.cargo-shear]
ignored = ["openssl-sys"]

@@ -12,7 +12,7 @@ Expects `/usr/bin/sandbox-exec` to be present.

### Linux

Expects the binary containing `codex-core` to run the equivalent of `codex debug landlock` when `arg0` is `codex-linux-sandbox`. See the `codex-arg0` crate for details.
Expects the binary containing `codex-core` to run the equivalent of `codex sandbox linux` (legacy alias: `codex debug landlock`) when `arg0` is `codex-linux-sandbox`. See the `codex-arg0` crate for details.

### All Platforms


@@ -10,12 +10,14 @@ You are Codex, based on GPT-5. You are running as a coding agent in the Codex CL

- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
- You may be in a dirty git worktree.
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
* If the changes are in unrelated files, just ignore them and don't revert them.
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.

## Plan tool


@@ -27,6 +27,7 @@ pub(crate) enum InternalApplyPatchInvocation {
DelegateToExec(ApplyPatchExec),
}

#[derive(Debug)]
pub(crate) struct ApplyPatchExec {
pub(crate) action: ApplyPatchAction,
pub(crate) user_explicitly_approved_this_action: bool,
@@ -109,3 +110,28 @@ pub(crate) fn convert_apply_patch_to_protocol(
}
result
}

#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;

use tempfile::tempdir;

#[test]
fn convert_apply_patch_maps_add_variant() {
let tmp = tempdir().expect("tmp");
let p = tmp.path().join("a.txt");
// Create an action with a single Add change
let action = ApplyPatchAction::new_add_for_test(&p, "hello".to_string());

let got = convert_apply_patch_to_protocol(&action);

assert_eq!(
got.get(&p),
Some(&FileChange::Add {
content: "hello".to_string()
})
);
}
}

@@ -63,7 +63,6 @@ struct ErrorResponse {
#[derive(Debug, Deserialize)]
struct Error {
r#type: Option<String>,
#[allow(dead_code)]
code: Option<String>,
message: Option<String>,

@@ -228,7 +227,7 @@ impl ModelClient {
input: &input_with_instructions,
tools: &tools_json,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: prompt.parallel_tool_calls,
reasoning,
store: azure_workaround,
stream: true,
@@ -794,9 +793,13 @@ async fn process_sse<S>(
if let Some(error) = error {
match serde_json::from_value::<Error>(error.clone()) {
Ok(error) => {
let delay = try_parse_retry_after(&error);
let message = error.message.unwrap_or_default();
response_error = Some(CodexErr::Stream(message, delay));
if is_context_window_error(&error) {
response_error = Some(CodexErr::ContextWindowExceeded);
} else {
let delay = try_parse_retry_after(&error);
let message = error.message.clone().unwrap_or_default();
response_error = Some(CodexErr::Stream(message, delay));
}
}
Err(e) => {
let error = format!("failed to parse ErrorResponse: {e}");
@@ -922,9 +925,14 @@ fn try_parse_retry_after(err: &Error) -> Option<Duration> {
None
}

fn is_context_window_error(error: &Error) -> bool {
error.code.as_deref() == Some("context_length_exceeded")
}
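
A self-contained sketch of the detection above (illustrative; the real `Error` struct carries more fields): the check keys off the `code` field that the `response.failed` payloads in the tests below carry.

struct Error {
    code: Option<String>,
}

fn is_context_window_error(error: &Error) -> bool {
    error.code.as_deref() == Some("context_length_exceeded")
}

fn main() {
    let fatal = Error { code: Some("context_length_exceeded".to_string()) };
    let retryable = Error { code: Some("rate_limit_exceeded".to_string()) };
    assert!(is_context_window_error(&fatal));
    assert!(!is_context_window_error(&retryable));
}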

#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_test::io::Builder as IoBuilder;
@@ -1179,6 +1187,74 @@ mod tests {
}
}

#[tokio::test]
async fn context_window_error_is_fatal() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#;

let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let provider = ModelProviderInfo {
name: "test".to_string(),
base_url: Some("https://test.com".to_string()),
env_key: Some("TEST_API_KEY".to_string()),
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(1000),
requires_openai_auth: false,
};

let otel_event_manager = otel_event_manager();

let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;

assert_eq!(events.len(), 1);

match &events[0] {
Err(err @ CodexErr::ContextWindowExceeded) => {
assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
}
other => panic!("unexpected context window event: {other:?}"),
}
}

#[tokio::test]
async fn context_window_error_with_newline_is_fatal() {
let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#;

let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let provider = ModelProviderInfo {
name: "test".to_string(),
base_url: Some("https://test.com".to_string()),
env_key: Some("TEST_API_KEY".to_string()),
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(1000),
requires_openai_auth: false,
};

let otel_event_manager = otel_event_manager();

let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;

assert_eq!(events.len(), 1);

match &events[0] {
Err(err @ CodexErr::ContextWindowExceeded) => {
assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
}
other => panic!("unexpected context window event: {other:?}"),
}
}

// ────────────────────────────
// Table-driven test from `main`
// ────────────────────────────
@@ -1316,10 +1392,7 @@ mod tests {
let resp: ErrorResponse =
serde_json::from_str(json).expect("should deserialize old schema");

assert!(matches!(
resp.error.plan_type,
Some(PlanType::Known(KnownPlan::Pro))
));
assert_matches!(resp.error.plan_type, Some(PlanType::Known(KnownPlan::Pro)));

let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
assert_eq!(plan_json, "\"pro\"");
@@ -1334,7 +1407,7 @@ mod tests {
let resp: ErrorResponse =
serde_json::from_str(json).expect("should deserialize old schema");

assert!(matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip"));
assert_matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip");

let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
assert_eq!(plan_json, "\"vip\"");

@@ -1,6 +1,6 @@
use crate::client_common::tools::ToolSpec;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::openai_tools::OpenAiTool;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
@@ -9,9 +9,11 @@ use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use futures::Stream;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::borrow::Cow;
use std::collections::HashSet;
use std::ops::Deref;
use std::pin::Pin;
use std::task::Context;
@@ -29,7 +31,10 @@ pub struct Prompt {

/// Tools available to the model, including additional tools sourced from
/// external MCP servers.
pub(crate) tools: Vec<OpenAiTool>,
pub(crate) tools: Vec<ToolSpec>,

/// Whether parallel tool calls are permitted for this prompt.
pub(crate) parallel_tool_calls: bool,

/// Optional override for the built-in BASE_INSTRUCTIONS.
pub base_instructions_override: Option<String>,
@@ -49,8 +54,8 @@ impl Prompt {
// AND
// - there is no apply_patch tool present
let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
OpenAiTool::Function(f) => f.name == "apply_patch",
OpenAiTool::Freeform(f) => f.name == "apply_patch",
ToolSpec::Function(f) => f.name == "apply_patch",
ToolSpec::Freeform(f) => f.name == "apply_patch",
_ => false,
});
if self.base_instructions_override.is_none()
@@ -64,10 +69,125 @@ impl Prompt {
}

pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
self.input.clone()
let mut input = self.input.clone();

// when using the *Freeform* apply_patch tool specifically, tool outputs
// should be structured text, not json. Do NOT reserialize when using
// the Function tool - note that this differs from the check above for
// instructions. We declare the result as a named variable for clarity.
let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
ToolSpec::Freeform(f) => f.name == "apply_patch",
_ => false,
});
if is_freeform_apply_patch_tool_present {
reserialize_shell_outputs(&mut input);
}

input
}
}

fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
let mut shell_call_ids: HashSet<String> = HashSet::new();

items.iter_mut().for_each(|item| match item {
ResponseItem::LocalShellCall { call_id, id, .. } => {
if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
shell_call_ids.insert(identifier);
}
}
ResponseItem::CustomToolCall {
id: _,
status: _,
call_id,
name,
input: _,
} => {
if name == "apply_patch" {
shell_call_ids.insert(call_id.clone());
}
}
ResponseItem::CustomToolCallOutput { call_id, output } => {
if shell_call_ids.remove(call_id)
&& let Some(structured) = parse_structured_shell_output(output)
{
*output = structured
}
}
ResponseItem::FunctionCall { name, call_id, .. }
if is_shell_tool_name(name) || name == "apply_patch" =>
{
shell_call_ids.insert(call_id.clone());
}
ResponseItem::FunctionCallOutput { call_id, output } => {
if shell_call_ids.remove(call_id)
&& let Some(structured) = parse_structured_shell_output(&output.content)
{
output.content = structured
}
}
_ => {}
})
}

fn is_shell_tool_name(name: &str) -> bool {
matches!(name, "shell" | "container.exec")
}

#[derive(Deserialize)]
struct ExecOutputJson {
output: String,
metadata: ExecOutputMetadataJson,
}

#[derive(Deserialize)]
struct ExecOutputMetadataJson {
exit_code: i32,
duration_seconds: f32,
}

fn parse_structured_shell_output(raw: &str) -> Option<String> {
let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
Some(build_structured_output(&parsed))
}

fn build_structured_output(parsed: &ExecOutputJson) -> String {
let mut sections = Vec::new();
sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
sections.push(format!(
"Wall time: {} seconds",
parsed.metadata.duration_seconds
));

let mut output = parsed.output.clone();
if let Some(total_lines) = extract_total_output_lines(&parsed.output) {
sections.push(format!("Total output lines: {total_lines}"));
if let Some(stripped) = strip_total_output_header(&output) {
output = stripped.to_string();
}
}

sections.push("Output:".to_string());
sections.push(output);

sections.join("\n")
}

fn extract_total_output_lines(output: &str) -> Option<u32> {
let marker_start = output.find("[... omitted ")?;
let marker = &output[marker_start..];
let (_, after_of) = marker.split_once(" of ")?;
let (total_segment, _) = after_of.split_once(' ')?;
total_segment.parse::<u32>().ok()
}

fn strip_total_output_header(output: &str) -> Option<&str> {
let after_prefix = output.strip_prefix("Total output lines: ")?;
let (_, remainder) = after_prefix.split_once('\n')?;
let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
Some(remainder)
}
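
A worked example of the transformation these helpers perform (the input and output strings are illustrative): a shell call's JSON tool result is flattened into the structured text the freeform apply_patch flow expects.

fn main() {
    // What a shell call's JSON output might look like before reserialization:
    let raw = r#"{"output":"hello\n","metadata":{"exit_code":0,"duration_seconds":0.25}}"#;

    // What parse_structured_shell_output would produce for it:
    let structured = "Exit code: 0\nWall time: 0.25 seconds\nOutput:\nhello\n";

    println!("before: {raw}");
    println!("after:\n{structured}");
}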

#[derive(Debug)]
pub enum ResponseEvent {
Created,
@@ -160,6 +280,65 @@ pub(crate) struct ResponsesApiRequest<'a> {
pub(crate) text: Option<TextControls>,
}

pub(crate) mod tools {
use crate::openai_tools::JsonSchema;
use serde::Deserialize;
use serde::Serialize;

/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
/// Responses API.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(tag = "type")]
pub(crate) enum ToolSpec {
#[serde(rename = "function")]
Function(ResponsesApiTool),
#[serde(rename = "local_shell")]
LocalShell {},
// TODO: Understand why we get an error on web_search although the API docs say it's supported.
// https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C
#[serde(rename = "web_search")]
WebSearch {},
#[serde(rename = "custom")]
Freeform(FreeformTool),
}

impl ToolSpec {
pub(crate) fn name(&self) -> &str {
match self {
ToolSpec::Function(tool) => tool.name.as_str(),
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(tool) => tool.name.as_str(),
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FreeformTool {
pub(crate) name: String,
pub(crate) description: String,
pub(crate) format: FreeformToolFormat,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FreeformToolFormat {
pub(crate) r#type: String,
pub(crate) syntax: String,
pub(crate) definition: String,
}

#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct ResponsesApiTool {
pub(crate) name: String,
pub(crate) description: String,
/// TODO: Validation. When strict is set to true, the JSON schema,
/// `required` and `additional_properties` must be present. All fields in
/// `properties` must be present in `required`.
pub(crate) strict: bool,
pub(crate) parameters: JsonSchema,
}
}
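
For illustration (assumes serde and serde_json as dependencies; not part of the diff): a stripped-down replica of the tagging scheme above, showing the JSON shape each unit-like variant serializes to.

use serde::Serialize;

#[derive(Serialize)]
#[serde(tag = "type")]
enum ToolSpec {
    #[serde(rename = "local_shell")]
    LocalShell {},
    #[serde(rename = "web_search")]
    WebSearch {},
}

fn main() {
    // Each variant serializes to just its tag, matching the Responses API "Tool" shape.
    let shell = serde_json::to_string(&ToolSpec::LocalShell {}).unwrap();
    let search = serde_json::to_string(&ToolSpec::WebSearch {}).unwrap();
    assert_eq!(shell, r#"{"type":"local_shell"}"#);
    assert_eq!(search, r#"{"type":"web_search"}"#);
}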

pub(crate) fn create_reasoning_param_for_request(
model_family: &ModelFamily,
effort: Option<ReasoningEffortConfig>,
@@ -279,7 +458,7 @@ mod tests {
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
@@ -320,7 +499,7 @@ mod tests {
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
@@ -356,7 +535,7 @@ mod tests {
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,

File diff suppressed because it is too large
@@ -103,6 +103,18 @@ async fn run_compact_task_inner(
Err(CodexErr::Interrupted) => {
return;
}
Err(e @ CodexErr::ContextWindowExceeded) => {
sess.set_total_tokens_full(&sub_id, turn_context.as_ref())
.await;
let event = Event {
id: sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: e.to_string(),
}),
};
sess.send_event(event).await;
return;
}
Err(e) => {
if retries < max_retries {
retries += 1;

@@ -1,3 +1,7 @@
use crate::config_loader::LoadedConfigLayers;
pub use crate::config_loader::load_config_as_toml;
use crate::config_loader::load_config_layers_with_overrides;
use crate::config_loader::merge_toml_values;
use crate::config_profile::ConfigProfile;
use crate::config_types::DEFAULT_OTEL_ENVIRONMENT;
use crate::config_types::History;
@@ -22,6 +26,7 @@ use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use crate::safety::set_windows_sandbox_enabled;
use anyhow::Context;
use codex_app_server_protocol::Tools;
use codex_app_server_protocol::UserSavedConfig;
@@ -42,7 +47,10 @@ use toml_edit::DocumentMut;
use toml_edit::Item as TomlItem;
use toml_edit::Table as TomlTable;

const OPENAI_DEFAULT_MODEL: &str = "gpt-5-codex";
#[cfg(target_os = "windows")]
pub const OPENAI_DEFAULT_MODEL: &str = "gpt-5";
#[cfg(not(target_os = "windows"))]
pub const OPENAI_DEFAULT_MODEL: &str = "gpt-5-codex";
const OPENAI_DEFAULT_REVIEW_MODEL: &str = "gpt-5-codex";
pub const GPT_5_CODEX_MEDIUM_MODEL: &str = "gpt-5-codex";

@@ -163,6 +171,9 @@ pub struct Config {
/// When this program is invoked, arg0 will be set to `codex-linux-sandbox`.
pub codex_linux_sandbox_exe: Option<PathBuf>,

/// Enable the experimental Windows sandbox implementation.
pub experimental_windows_sandbox: bool,

/// Value to use for `reasoning.effort` when making a request using the
/// Responses API.
pub model_reasoning_effort: Option<ReasoningEffort>,
@@ -202,6 +213,9 @@ pub struct Config {
/// The active profile name used to derive this `Config` (if any).
pub active_profile: Option<String>,

/// Tracks whether the Windows onboarding screen has been acknowledged.
pub windows_wsl_setup_acknowledged: bool,

/// When true, disables burst-paste detection for typed input entirely.
/// All characters are inserted as they are received, and no buffering
/// or placeholder replacement will occur for fast keypress bursts.
@@ -212,50 +226,38 @@ pub struct Config {
}

impl Config {
/// Load configuration with *generic* CLI overrides (`-c key=value`) applied
/// **in between** the values parsed from `config.toml` and the
/// strongly-typed overrides specified via [`ConfigOverrides`].
///
/// The precedence order is therefore: `config.toml` < `-c` overrides <
/// `ConfigOverrides`.
pub fn load_with_cli_overrides(
pub async fn load_with_cli_overrides(
cli_overrides: Vec<(String, TomlValue)>,
overrides: ConfigOverrides,
) -> std::io::Result<Self> {
// Resolve the directory that stores Codex state (e.g. ~/.codex or the
// value of $CODEX_HOME) so we can embed it into the resulting
// `Config` instance.
let codex_home = find_codex_home()?;

// Step 1: parse `config.toml` into a generic JSON value.
let mut root_value = load_config_as_toml(&codex_home)?;
let root_value = load_resolved_config(
&codex_home,
cli_overrides,
crate::config_loader::LoaderOverrides::default(),
)
.await?;

// Step 2: apply the `-c` overrides.
for (path, value) in cli_overrides.into_iter() {
apply_toml_override(&mut root_value, &path, value);
}

// Step 3: deserialize into `ConfigToml` so that Serde can enforce the
// correct types.
let cfg: ConfigToml = root_value.try_into().map_err(|e| {
tracing::error!("Failed to deserialize overridden config: {e}");
std::io::Error::new(std::io::ErrorKind::InvalidData, e)
})?;

// Step 4: merge with the strongly-typed overrides.
Self::load_from_base_config_with_overrides(cfg, overrides, codex_home)
}
}

pub fn load_config_as_toml_with_cli_overrides(
pub async fn load_config_as_toml_with_cli_overrides(
codex_home: &Path,
cli_overrides: Vec<(String, TomlValue)>,
) -> std::io::Result<ConfigToml> {
let mut root_value = load_config_as_toml(codex_home)?;

for (path, value) in cli_overrides.into_iter() {
apply_toml_override(&mut root_value, &path, value);
}
let root_value = load_resolved_config(
codex_home,
cli_overrides,
crate::config_loader::LoaderOverrides::default(),
)
.await?;

let cfg: ConfigToml = root_value.try_into().map_err(|e| {
tracing::error!("Failed to deserialize overridden config: {e}");
@@ -265,33 +267,40 @@ pub fn load_config_as_toml_with_cli_overrides(
Ok(cfg)
}

/// Read `CODEX_HOME/config.toml` and return it as a generic TOML value. Returns
/// an empty TOML table when the file does not exist.
pub fn load_config_as_toml(codex_home: &Path) -> std::io::Result<TomlValue> {
let config_path = codex_home.join(CONFIG_TOML_FILE);
match std::fs::read_to_string(&config_path) {
Ok(contents) => match toml::from_str::<TomlValue>(&contents) {
Ok(val) => Ok(val),
Err(e) => {
tracing::error!("Failed to parse config.toml: {e}");
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
}
},
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
tracing::info!("config.toml not found, using defaults");
Ok(TomlValue::Table(Default::default()))
}
Err(e) => {
tracing::error!("Failed to read config.toml: {e}");
Err(e)
}
}
async fn load_resolved_config(
codex_home: &Path,
cli_overrides: Vec<(String, TomlValue)>,
overrides: crate::config_loader::LoaderOverrides,
) -> std::io::Result<TomlValue> {
let layers = load_config_layers_with_overrides(codex_home, overrides).await?;
Ok(apply_overlays(layers, cli_overrides))
}

pub fn load_global_mcp_servers(
fn apply_overlays(
layers: LoadedConfigLayers,
cli_overrides: Vec<(String, TomlValue)>,
) -> TomlValue {
let LoadedConfigLayers {
mut base,
managed_config,
managed_preferences,
} = layers;

for (path, value) in cli_overrides.into_iter() {
apply_toml_override(&mut base, &path, value);
}

for overlay in [managed_config, managed_preferences].into_iter().flatten() {
merge_toml_values(&mut base, &overlay);
}

base
}
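
A toy illustration of the resulting precedence (illustrative only, no TOML involved): `-c` overrides are applied on top of the base `config.toml` first, and the managed overlays are merged last, which is why the `managed_config_wins_over_cli_overrides` test below expects the managed value.

fn main() {
    // Layer values in the order apply_overlays uses them.
    let base = "base"; // config.toml
    let cli_override = Some("cli"); // -c model=cli
    let managed_overlay = Some("managed_config"); // managed_config.toml

    let mut model = base;
    if let Some(cli) = cli_override {
        model = cli;
    }
    if let Some(managed) = managed_overlay {
        model = managed; // merged last, so it wins
    }
    assert_eq!(model, "managed_config");
}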

pub async fn load_global_mcp_servers(
codex_home: &Path,
) -> std::io::Result<BTreeMap<String, McpServerConfig>> {
let root_value = load_config_as_toml(codex_home)?;
let root_value = load_config_as_toml(codex_home).await?;
let Some(servers_value) = root_value.get("mcp_servers") else {
return Ok(BTreeMap::new());
};
@@ -469,6 +478,29 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
Ok(())
}

/// Persist the acknowledgement flag for the Windows onboarding screen.
pub fn set_windows_wsl_setup_acknowledged(
codex_home: &Path,
acknowledged: bool,
) -> anyhow::Result<()> {
let config_path = codex_home.join(CONFIG_TOML_FILE);
let mut doc = match std::fs::read_to_string(config_path.clone()) {
Ok(s) => s.parse::<DocumentMut>()?,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
Err(e) => return Err(e.into()),
};

doc["windows_wsl_setup_acknowledged"] = toml_edit::value(acknowledged);

std::fs::create_dir_all(codex_home)?;

let tmp_file = NamedTempFile::new_in(codex_home)?;
std::fs::write(tmp_file.path(), doc.to_string())?;
tmp_file.persist(config_path)?;

Ok(())
}

fn ensure_profile_table<'a>(
doc: &'a mut DocumentMut,
profile_name: &str,
@@ -722,6 +754,8 @@ pub struct ConfigToml {
pub experimental_use_exec_command_tool: Option<bool>,
pub experimental_use_unified_exec_tool: Option<bool>,
pub experimental_use_rmcp_client: Option<bool>,
pub experimental_use_freeform_apply_patch: Option<bool>,
pub experimental_windows_sandbox: Option<bool>,

pub projects: Option<HashMap<String, ProjectConfig>>,

@@ -735,6 +769,9 @@ pub struct ConfigToml {

/// OTEL configuration.
pub otel: Option<crate::config_types::OtelConfigToml>,

/// Tracks whether the Windows onboarding screen has been acknowledged.
pub windows_wsl_setup_acknowledged: Option<bool>,
}

impl From<ConfigToml> for UserSavedConfig {
@@ -976,6 +1013,8 @@ impl Config {
.or(cfg.tools.as_ref().and_then(|t| t.view_image))
.unwrap_or(true);

let experimental_windows_sandbox = cfg.experimental_windows_sandbox.unwrap_or(false);

let model = model
.or(config_profile.model)
.or(cfg.model)
@@ -1061,6 +1100,7 @@ impl Config {
history,
file_opener: cfg.file_opener.unwrap_or(UriBasedFileOpener::VsCode),
codex_linux_sandbox_exe,
experimental_windows_sandbox,

hide_agent_reasoning: cfg.hide_agent_reasoning.unwrap_or(false),
show_raw_agent_reasoning: cfg
@@ -1080,7 +1120,9 @@ impl Config {
.or(cfg.chatgpt_base_url)
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
include_plan_tool: include_plan_tool.unwrap_or(false),
include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
include_apply_patch_tool: include_apply_patch_tool
.or(cfg.experimental_use_freeform_apply_patch)
.unwrap_or(false),
tools_web_search_request,
use_experimental_streamable_shell_tool: cfg
.experimental_use_exec_command_tool
@@ -1091,6 +1133,7 @@ impl Config {
use_experimental_use_rmcp_client: cfg.experimental_use_rmcp_client.unwrap_or(false),
include_view_image_tool,
active_profile: active_profile_name,
windows_wsl_setup_acknowledged: cfg.windows_wsl_setup_acknowledged.unwrap_or(false),
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
tui_notifications: cfg
.tui
@@ -1111,6 +1154,7 @@ impl Config {
}
},
};
set_windows_sandbox_enabled(config.experimental_windows_sandbox);
Ok(config)
}

@@ -1329,18 +1373,18 @@ exclude_slash_tmp = true
);
}

#[test]
fn load_global_mcp_servers_returns_empty_if_missing() -> anyhow::Result<()> {
#[tokio::test]
async fn load_global_mcp_servers_returns_empty_if_missing() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;

let servers = load_global_mcp_servers(codex_home.path())?;
let servers = load_global_mcp_servers(codex_home.path()).await?;
assert!(servers.is_empty());

Ok(())
}

#[test]
fn write_global_mcp_servers_round_trips_entries() -> anyhow::Result<()> {
#[tokio::test]
async fn write_global_mcp_servers_round_trips_entries() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;

let mut servers = BTreeMap::new();
@@ -1359,7 +1403,7 @@ exclude_slash_tmp = true

write_global_mcp_servers(codex_home.path(), &servers)?;

let loaded = load_global_mcp_servers(codex_home.path())?;
let loaded = load_global_mcp_servers(codex_home.path()).await?;
assert_eq!(loaded.len(), 1);
let docs = loaded.get("docs").expect("docs entry");
match &docs.transport {
@@ -1375,14 +1419,47 @@ exclude_slash_tmp = true

let empty = BTreeMap::new();
write_global_mcp_servers(codex_home.path(), &empty)?;
let loaded = load_global_mcp_servers(codex_home.path())?;
let loaded = load_global_mcp_servers(codex_home.path()).await?;
assert!(loaded.is_empty());

Ok(())
}

#[test]
fn load_global_mcp_servers_accepts_legacy_ms_field() -> anyhow::Result<()> {
#[tokio::test]
async fn managed_config_wins_over_cli_overrides() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;
let managed_path = codex_home.path().join("managed_config.toml");

std::fs::write(
codex_home.path().join(CONFIG_TOML_FILE),
"model = \"base\"\n",
)?;
std::fs::write(&managed_path, "model = \"managed_config\"\n")?;

let overrides = crate::config_loader::LoaderOverrides {
managed_config_path: Some(managed_path),
#[cfg(target_os = "macos")]
managed_preferences_base64: None,
};

let root_value = load_resolved_config(
codex_home.path(),
vec![("model".to_string(), TomlValue::String("cli".to_string()))],
overrides,
)
.await?;

let cfg: ConfigToml = root_value.try_into().map_err(|e| {
tracing::error!("Failed to deserialize overridden config: {e}");
std::io::Error::new(std::io::ErrorKind::InvalidData, e)
})?;

assert_eq!(cfg.model.as_deref(), Some("managed_config"));
Ok(())
}

#[tokio::test]
async fn load_global_mcp_servers_accepts_legacy_ms_field() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;
let config_path = codex_home.path().join(CONFIG_TOML_FILE);

@@ -1396,15 +1473,15 @@ startup_timeout_ms = 2500
"#,
)?;

let servers = load_global_mcp_servers(codex_home.path())?;
let servers = load_global_mcp_servers(codex_home.path()).await?;
let docs = servers.get("docs").expect("docs entry");
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_millis(2500)));

Ok(())
}

#[test]
fn write_global_mcp_servers_serializes_env_sorted() -> anyhow::Result<()> {
#[tokio::test]
async fn write_global_mcp_servers_serializes_env_sorted() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;

let servers = BTreeMap::from([(
@@ -1439,7 +1516,7 @@ ZIG_VAR = "3"
"#
);

let loaded = load_global_mcp_servers(codex_home.path())?;
let loaded = load_global_mcp_servers(codex_home.path()).await?;
let docs = loaded.get("docs").expect("docs entry");
match &docs.transport {
McpServerTransportConfig::Stdio { command, args, env } => {
@@ -1457,8 +1534,8 @@ ZIG_VAR = "3"
Ok(())
}

#[test]
fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> {
#[tokio::test]
async fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;

let mut servers = BTreeMap::from([(
@@ -1486,7 +1563,7 @@ startup_timeout_sec = 2.0
"#
);

let loaded = load_global_mcp_servers(codex_home.path())?;
let loaded = load_global_mcp_servers(codex_home.path()).await?;
let docs = loaded.get("docs").expect("docs entry");
match &docs.transport {
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
@@ -1518,7 +1595,7 @@ url = "https://example.com/mcp"
"#
);

let loaded = load_global_mcp_servers(codex_home.path())?;
let loaded = load_global_mcp_servers(codex_home.path()).await?;
let docs = loaded.get("docs").expect("docs entry");
match &docs.transport {
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
@@ -1835,6 +1912,7 @@ model_verbosity = "high"
history: History::default(),
file_opener: UriBasedFileOpener::VsCode,
codex_linux_sandbox_exe: None,
experimental_windows_sandbox: false,
hide_agent_reasoning: false,
show_raw_agent_reasoning: false,
model_reasoning_effort: Some(ReasoningEffort::High),
@@ -1850,6 +1928,7 @@ model_verbosity = "high"
use_experimental_use_rmcp_client: false,
include_view_image_tool: true,
active_profile: Some("o3".to_string()),
windows_wsl_setup_acknowledged: false,
disable_paste_burst: false,
tui_notifications: Default::default(),
otel: OtelConfig::default(),
@@ -1896,6 +1975,7 @@ model_verbosity = "high"
history: History::default(),
file_opener: UriBasedFileOpener::VsCode,
codex_linux_sandbox_exe: None,
experimental_windows_sandbox: false,
hide_agent_reasoning: false,
show_raw_agent_reasoning: false,
model_reasoning_effort: None,
@@ -1911,6 +1991,7 @@ model_verbosity = "high"
use_experimental_use_rmcp_client: false,
include_view_image_tool: true,
active_profile: Some("gpt3".to_string()),
windows_wsl_setup_acknowledged: false,
disable_paste_burst: false,
tui_notifications: Default::default(),
otel: OtelConfig::default(),
@@ -1972,6 +2053,7 @@ model_verbosity = "high"
history: History::default(),
file_opener: UriBasedFileOpener::VsCode,
codex_linux_sandbox_exe: None,
experimental_windows_sandbox: false,
hide_agent_reasoning: false,
show_raw_agent_reasoning: false,
model_reasoning_effort: None,
@@ -1987,6 +2069,7 @@ model_verbosity = "high"
use_experimental_use_rmcp_client: false,
include_view_image_tool: true,
active_profile: Some("zdr".to_string()),
windows_wsl_setup_acknowledged: false,
disable_paste_burst: false,
tui_notifications: Default::default(),
otel: OtelConfig::default(),
@@ -2034,6 +2117,7 @@ model_verbosity = "high"
history: History::default(),
file_opener: UriBasedFileOpener::VsCode,
codex_linux_sandbox_exe: None,
experimental_windows_sandbox: false,
hide_agent_reasoning: false,
show_raw_agent_reasoning: false,
model_reasoning_effort: Some(ReasoningEffort::High),
@@ -2049,6 +2133,7 @@ model_verbosity = "high"
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("gpt5".to_string()),
|
||||
windows_wsl_setup_acknowledged: false,
|
||||
disable_paste_burst: false,
|
||||
tui_notifications: Default::default(),
|
||||
otel: OtelConfig::default(),
|
||||
@@ -2159,6 +2244,7 @@ trust_level = "trusted"
|
||||
#[cfg(test)]
|
||||
mod notifications_tests {
|
||||
use crate::config_types::Notifications;
|
||||
use assert_matches::assert_matches;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize, Debug, PartialEq)]
|
||||
@@ -2178,10 +2264,7 @@ mod notifications_tests {
|
||||
notifications = true
|
||||
"#;
|
||||
let parsed: RootTomlTest = toml::from_str(toml).expect("deserialize notifications=true");
|
||||
assert!(matches!(
|
||||
parsed.tui.notifications,
|
||||
Notifications::Enabled(true)
|
||||
));
|
||||
assert_matches!(parsed.tui.notifications, Notifications::Enabled(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2192,9 +2275,9 @@ mod notifications_tests {
|
||||
"#;
|
||||
let parsed: RootTomlTest =
|
||||
toml::from_str(toml).expect("deserialize notifications=[\"foo\"]");
|
||||
assert!(matches!(
|
||||
assert_matches!(
|
||||
parsed.tui.notifications,
|
||||
Notifications::Custom(ref v) if v == &vec!["foo".to_string()]
|
||||
));
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
118
codex-rs/core/src/config_loader/macos.rs
Normal file
118
codex-rs/core/src/config_loader/macos.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use std::io;
|
||||
use toml::Value as TomlValue;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
mod native {
|
||||
use super::*;
|
||||
use base64::Engine;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use core_foundation::base::TCFType;
|
||||
use core_foundation::string::CFString;
|
||||
use core_foundation::string::CFStringRef;
|
||||
use std::ffi::c_void;
|
||||
use tokio::task;
|
||||
|
||||
pub(crate) async fn load_managed_admin_config_layer(
|
||||
override_base64: Option<&str>,
|
||||
) -> io::Result<Option<TomlValue>> {
|
||||
if let Some(encoded) = override_base64 {
|
||||
let trimmed = encoded.trim();
|
||||
return if trimmed.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
parse_managed_preferences_base64(trimmed).map(Some)
|
||||
};
|
||||
}
|
||||
|
||||
const LOAD_ERROR: &str = "Failed to load managed preferences configuration";
|
||||
|
||||
match task::spawn_blocking(load_managed_admin_config).await {
|
||||
Ok(result) => result,
|
||||
Err(join_err) => {
|
||||
if join_err.is_cancelled() {
|
||||
tracing::error!("Managed preferences load task was cancelled");
|
||||
} else {
|
||||
tracing::error!("Managed preferences load task failed: {join_err}");
|
||||
}
|
||||
Err(io::Error::other(LOAD_ERROR))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn load_managed_admin_config() -> io::Result<Option<TomlValue>> {
|
||||
#[link(name = "CoreFoundation", kind = "framework")]
|
||||
unsafe extern "C" {
|
||||
fn CFPreferencesCopyAppValue(
|
||||
key: CFStringRef,
|
||||
application_id: CFStringRef,
|
||||
) -> *mut c_void;
|
||||
}
|
||||
|
||||
const MANAGED_PREFERENCES_APPLICATION_ID: &str = "com.openai.codex";
|
||||
const MANAGED_PREFERENCES_CONFIG_KEY: &str = "config_toml_base64";
|
||||
|
||||
let application_id = CFString::new(MANAGED_PREFERENCES_APPLICATION_ID);
|
||||
let key = CFString::new(MANAGED_PREFERENCES_CONFIG_KEY);
|
||||
|
||||
let value_ref = unsafe {
|
||||
CFPreferencesCopyAppValue(
|
||||
key.as_concrete_TypeRef(),
|
||||
application_id.as_concrete_TypeRef(),
|
||||
)
|
||||
};
|
||||
|
||||
if value_ref.is_null() {
|
||||
tracing::debug!(
|
||||
"Managed preferences for {} key {} not found",
|
||||
MANAGED_PREFERENCES_APPLICATION_ID,
|
||||
MANAGED_PREFERENCES_CONFIG_KEY
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let value = unsafe { CFString::wrap_under_create_rule(value_ref as _) };
|
||||
let contents = value.to_string();
|
||||
let trimmed = contents.trim();
|
||||
|
||||
parse_managed_preferences_base64(trimmed).map(Some)
|
||||
}
|
||||
|
||||
pub(super) fn parse_managed_preferences_base64(encoded: &str) -> io::Result<TomlValue> {
|
||||
let decoded = BASE64_STANDARD.decode(encoded.as_bytes()).map_err(|err| {
|
||||
tracing::error!("Failed to decode managed preferences as base64: {err}");
|
||||
io::Error::new(io::ErrorKind::InvalidData, err)
|
||||
})?;
|
||||
|
||||
let decoded_str = String::from_utf8(decoded).map_err(|err| {
|
||||
tracing::error!("Managed preferences base64 contents were not valid UTF-8: {err}");
|
||||
io::Error::new(io::ErrorKind::InvalidData, err)
|
||||
})?;
|
||||
|
||||
match toml::from_str::<TomlValue>(&decoded_str) {
|
||||
Ok(TomlValue::Table(parsed)) => Ok(TomlValue::Table(parsed)),
|
||||
Ok(other) => {
|
||||
tracing::error!(
|
||||
"Managed preferences TOML must have a table at the root, found {other:?}",
|
||||
);
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"managed preferences root must be a table",
|
||||
))
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to parse managed preferences TOML: {err}");
|
||||
Err(io::Error::new(io::ErrorKind::InvalidData, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub(crate) use native::load_managed_admin_config_layer;
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
pub(crate) async fn load_managed_admin_config_layer(
|
||||
_override_base64: Option<&str>,
|
||||
) -> io::Result<Option<TomlValue>> {
|
||||
Ok(None)
|
||||
}
|
||||
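The managed-preferences layer above expects the `config_toml_base64` key under the `com.openai.codex` preferences domain to hold base64-encoded TOML whose root is a table. A minimal sketch (not part of the diff) of how such a payload could be produced, using the same `base64` API the decoder uses:

    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;

    fn encode_managed_config() -> String {
        // The root must be a TOML table, or the loader rejects the payload
        // with InvalidData.
        let config = "model = \"managed\"\n\n[nested]\nflag = true\n";
        // The result is what an MDM profile would set for `config_toml_base64`,
        // e.g. `defaults write com.openai.codex config_toml_base64 <encoded>`.
        BASE64_STANDARD.encode(config.as_bytes())
    }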
311 codex-rs/core/src/config_loader/mod.rs Normal file
@@ -0,0 +1,311 @@
mod macos;

use crate::config::CONFIG_TOML_FILE;
use macos::load_managed_admin_config_layer;
use std::io;
use std::path::Path;
use std::path::PathBuf;
use tokio::fs;
use toml::Value as TomlValue;

#[cfg(unix)]
const CODEX_MANAGED_CONFIG_SYSTEM_PATH: &str = "/etc/codex/managed_config.toml";

#[derive(Debug)]
pub(crate) struct LoadedConfigLayers {
    pub base: TomlValue,
    pub managed_config: Option<TomlValue>,
    pub managed_preferences: Option<TomlValue>,
}

#[derive(Debug, Default)]
pub(crate) struct LoaderOverrides {
    pub managed_config_path: Option<PathBuf>,
    #[cfg(target_os = "macos")]
    pub managed_preferences_base64: Option<String>,
}

// Configuration layering pipeline (top overrides bottom):
//
// +-------------------------+
// | Managed preferences (*) |
// +-------------------------+
//             ^
//             |
// +-------------------------+
// |   managed_config.toml   |
// +-------------------------+
//             ^
//             |
// +-------------------------+
// |   config.toml (base)    |
// +-------------------------+
//
// (*) Only available on macOS via managed device profiles.

pub async fn load_config_as_toml(codex_home: &Path) -> io::Result<TomlValue> {
    load_config_as_toml_with_overrides(codex_home, LoaderOverrides::default()).await
}

fn default_empty_table() -> TomlValue {
    TomlValue::Table(Default::default())
}

pub(crate) async fn load_config_layers_with_overrides(
    codex_home: &Path,
    overrides: LoaderOverrides,
) -> io::Result<LoadedConfigLayers> {
    load_config_layers_internal(codex_home, overrides).await
}

async fn load_config_as_toml_with_overrides(
    codex_home: &Path,
    overrides: LoaderOverrides,
) -> io::Result<TomlValue> {
    let layers = load_config_layers_internal(codex_home, overrides).await?;
    Ok(apply_managed_layers(layers))
}

async fn load_config_layers_internal(
    codex_home: &Path,
    overrides: LoaderOverrides,
) -> io::Result<LoadedConfigLayers> {
    #[cfg(target_os = "macos")]
    let LoaderOverrides {
        managed_config_path,
        managed_preferences_base64,
    } = overrides;

    #[cfg(not(target_os = "macos"))]
    let LoaderOverrides {
        managed_config_path,
    } = overrides;

    let managed_config_path =
        managed_config_path.unwrap_or_else(|| managed_config_default_path(codex_home));

    let user_config_path = codex_home.join(CONFIG_TOML_FILE);
    let user_config = read_config_from_path(&user_config_path, true).await?;
    let managed_config = read_config_from_path(&managed_config_path, false).await?;

    #[cfg(target_os = "macos")]
    let managed_preferences =
        load_managed_admin_config_layer(managed_preferences_base64.as_deref()).await?;

    #[cfg(not(target_os = "macos"))]
    let managed_preferences = load_managed_admin_config_layer(None).await?;

    Ok(LoadedConfigLayers {
        base: user_config.unwrap_or_else(default_empty_table),
        managed_config,
        managed_preferences,
    })
}

async fn read_config_from_path(
    path: &Path,
    log_missing_as_info: bool,
) -> io::Result<Option<TomlValue>> {
    match fs::read_to_string(path).await {
        Ok(contents) => match toml::from_str::<TomlValue>(&contents) {
            Ok(value) => Ok(Some(value)),
            Err(err) => {
                tracing::error!("Failed to parse {}: {err}", path.display());
                Err(io::Error::new(io::ErrorKind::InvalidData, err))
            }
        },
        Err(err) if err.kind() == io::ErrorKind::NotFound => {
            if log_missing_as_info {
                tracing::info!("{} not found, using defaults", path.display());
            } else {
                tracing::debug!("{} not found", path.display());
            }
            Ok(None)
        }
        Err(err) => {
            tracing::error!("Failed to read {}: {err}", path.display());
            Err(err)
        }
    }
}

/// Merge config `overlay` into `base`, giving `overlay` precedence.
pub(crate) fn merge_toml_values(base: &mut TomlValue, overlay: &TomlValue) {
    if let TomlValue::Table(overlay_table) = overlay
        && let TomlValue::Table(base_table) = base
    {
        for (key, value) in overlay_table {
            if let Some(existing) = base_table.get_mut(key) {
                merge_toml_values(existing, value);
            } else {
                base_table.insert(key.clone(), value.clone());
            }
        }
    } else {
        *base = overlay.clone();
    }
}

fn managed_config_default_path(codex_home: &Path) -> PathBuf {
    #[cfg(unix)]
    {
        let _ = codex_home;
        PathBuf::from(CODEX_MANAGED_CONFIG_SYSTEM_PATH)
    }

    #[cfg(not(unix))]
    {
        codex_home.join("managed_config.toml")
    }
}

fn apply_managed_layers(layers: LoadedConfigLayers) -> TomlValue {
    let LoadedConfigLayers {
        mut base,
        managed_config,
        managed_preferences,
    } = layers;

    for overlay in [managed_config, managed_preferences].into_iter().flatten() {
        merge_toml_values(&mut base, &overlay);
    }

    base
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[tokio::test]
    async fn merges_managed_config_layer_on_top() {
        let tmp = tempdir().expect("tempdir");
        let managed_path = tmp.path().join("managed_config.toml");

        std::fs::write(
            tmp.path().join(CONFIG_TOML_FILE),
            r#"foo = 1

[nested]
value = "base"
"#,
        )
        .expect("write base");
        std::fs::write(
            &managed_path,
            r#"foo = 2

[nested]
value = "managed_config"
extra = true
"#,
        )
        .expect("write managed config");

        let overrides = LoaderOverrides {
            managed_config_path: Some(managed_path),
            #[cfg(target_os = "macos")]
            managed_preferences_base64: None,
        };

        let loaded = load_config_as_toml_with_overrides(tmp.path(), overrides)
            .await
            .expect("load config");
        let table = loaded.as_table().expect("top-level table expected");

        assert_eq!(table.get("foo"), Some(&TomlValue::Integer(2)));
        let nested = table
            .get("nested")
            .and_then(|v| v.as_table())
            .expect("nested");
        assert_eq!(
            nested.get("value"),
            Some(&TomlValue::String("managed_config".to_string()))
        );
        assert_eq!(nested.get("extra"), Some(&TomlValue::Boolean(true)));
    }

    #[tokio::test]
    async fn returns_empty_when_all_layers_missing() {
        let tmp = tempdir().expect("tempdir");
        let managed_path = tmp.path().join("managed_config.toml");
        let overrides = LoaderOverrides {
            managed_config_path: Some(managed_path),
            #[cfg(target_os = "macos")]
            managed_preferences_base64: None,
        };

        let layers = load_config_layers_with_overrides(tmp.path(), overrides)
            .await
            .expect("load layers");
        let base_table = layers.base.as_table().expect("base table expected");
        assert!(
            base_table.is_empty(),
            "expected empty base layer when configs missing"
        );
        assert!(
            layers.managed_config.is_none(),
            "managed config layer should be absent when file missing"
        );

        #[cfg(not(target_os = "macos"))]
        {
            let loaded = load_config_as_toml(tmp.path()).await.expect("load config");
            let table = loaded.as_table().expect("top-level table expected");
            assert!(
                table.is_empty(),
                "expected empty table when configs missing"
            );
        }
    }

    #[cfg(target_os = "macos")]
    #[tokio::test]
    async fn managed_preferences_take_highest_precedence() {
        use base64::Engine;

        let managed_payload = r#"
[nested]
value = "managed"
flag = false
"#;
        let encoded = base64::prelude::BASE64_STANDARD.encode(managed_payload.as_bytes());
        let tmp = tempdir().expect("tempdir");
        let managed_path = tmp.path().join("managed_config.toml");

        std::fs::write(
            tmp.path().join(CONFIG_TOML_FILE),
            r#"[nested]
value = "base"
"#,
        )
        .expect("write base");
        std::fs::write(
            &managed_path,
            r#"[nested]
value = "managed_config"
flag = true
"#,
        )
        .expect("write managed config");

        let overrides = LoaderOverrides {
            managed_config_path: Some(managed_path),
            managed_preferences_base64: Some(encoded),
        };

        let loaded = load_config_as_toml_with_overrides(tmp.path(), overrides)
            .await
            .expect("load config");
        let nested = loaded
            .get("nested")
            .and_then(|v| v.as_table())
            .expect("nested table");
        assert_eq!(
            nested.get("value"),
            Some(&TomlValue::String("managed".to_string()))
        );
        assert_eq!(nested.get("flag"), Some(&TomlValue::Boolean(false)));
    }
}
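Worth noting how merge_toml_values behaves across the layers: tables merge key by key, recursively, while scalars and arrays from the overlay replace the base value wholesale. A minimal sketch (a hypothetical extra test, crate-internal since the function is pub(crate)):

    #[test]
    fn merge_replaces_scalars_and_deep_merges_tables() {
        let mut base: TomlValue =
            toml::from_str("model = \"base\"\n\n[mcp_servers.docs]\ncommand = \"docs-server\"\n")
                .expect("base toml");
        let overlay: TomlValue =
            toml::from_str("model = \"managed\"\n\n[mcp_servers.docs]\nstartup_timeout_sec = 2.0\n")
                .expect("overlay toml");

        merge_toml_values(&mut base, &overlay);

        // The scalar is replaced; the nested table keeps `command` and gains
        // the overlay's timeout key.
        assert_eq!(base["model"].as_str(), Some("managed"));
        assert_eq!(base["mcp_servers"]["docs"]["command"].as_str(), Some("docs-server"));
        assert_eq!(
            base["mcp_servers"]["docs"]["startup_timeout_sec"].as_float(),
            Some(2.0)
        );
    }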
@@ -210,6 +210,7 @@ fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> Initia
mod tests {
    use super::*;
    use crate::codex::make_session_and_context;
    use assert_matches::assert_matches;
    use codex_protocol::models::ContentItem;
    use codex_protocol::models::ReasoningItemReasoningSummary;
    use codex_protocol::models::ResponseItem;
@@ -236,7 +237,7 @@ mod tests {

    #[test]
    fn drops_from_last_user_only() {
        let items = vec![
        let items = [
            user_msg("u1"),
            assistant_msg("a1"),
            assistant_msg("a2"),
@@ -283,7 +284,7 @@ mod tests {
            .map(RolloutItem::ResponseItem)
            .collect();
        let truncated2 = truncate_before_nth_user_message(InitialHistory::Forked(initial2), 2);
        assert!(matches!(truncated2, InitialHistory::New));
        assert_matches!(truncated2, InitialHistory::New);
    }

    #[test]

@@ -55,6 +55,11 @@ pub enum CodexErr {
    #[error("stream disconnected before completion: {0}")]
    Stream(String, Option<Duration>),

    #[error(
        "Codex ran out of room in the model's context window. Start a new conversation or clear earlier history before retrying."
    )]
    ContextWindowExceeded,

    #[error("no conversation with id: {0}")]
    ConversationNotFound(ConversationId),

@@ -108,6 +113,9 @@ pub enum CodexErr {
    #[error("unsupported operation: {0}")]
    UnsupportedOperation(String),

    #[error("Fatal error: {0}")]
    Fatal(String),

    // -----------------------------------------------------------------
    // Automatic conversions for common external error types
    // -----------------------------------------------------------------

@@ -127,6 +127,7 @@ mod tests {
    use super::map_response_item_to_event_messages;
    use crate::protocol::EventMsg;
    use crate::protocol::InputMessageKind;
    use assert_matches::assert_matches;
    use codex_protocol::models::ContentItem;
    use codex_protocol::models::ResponseItem;
    use pretty_assertions::assert_eq;
@@ -158,7 +159,7 @@ mod tests {
        match &events[0] {
            EventMsg::UserMessage(user) => {
                assert_eq!(user.message, "Hello world");
                assert!(matches!(user.kind, Some(InputMessageKind::Plain)));
                assert_matches!(user.kind, Some(InputMessageKind::Plain));
                assert_eq!(user.images, Some(vec![img1, img2]));
            }
            other => panic!("expected UserMessage, got {other:?}"),

@@ -27,6 +27,8 @@ use crate::protocol::SandboxPolicy;
use crate::seatbelt::spawn_command_under_seatbelt;
use crate::spawn::StdioPolicy;
use crate::spawn::spawn_child_async;
#[cfg(windows)]
use crate::windows_appcontainer::spawn_command_under_windows_appcontainer;

const DEFAULT_TIMEOUT_MS: u64 = 10_000;

@@ -70,6 +72,9 @@ pub enum SandboxType {

    /// Only available on Linux.
    LinuxSeccomp,

    /// Only available on Windows.
    WindowsAppContainer,
}

#[derive(Clone)]
@@ -94,6 +99,31 @@ pub async fn process_exec_tool_call(
    let raw_output_result: std::result::Result<RawExecToolCallOutput, CodexErr> = match sandbox_type
    {
        SandboxType::None => exec(params, sandbox_policy, stdout_stream.clone()).await,
        SandboxType::WindowsAppContainer => {
            #[cfg(windows)]
            {
                let ExecParams {
                    command,
                    cwd: command_cwd,
                    env,
                    ..
                } = params;
                let child = spawn_command_under_windows_appcontainer(
                    command,
                    command_cwd,
                    sandbox_policy,
                    sandbox_cwd,
                    StdioPolicy::RedirectForShellTool,
                    env,
                )
                .await?;
                consume_truncated_output(child, timeout_duration, stdout_stream.clone()).await
            }
            #[cfg(not(windows))]
            {
                panic!("windows sandboxing is not available on this platform");
            }
        }
        SandboxType::MacosSeatbelt => {
            let ExecParams {
                command,
@@ -198,7 +228,10 @@ pub async fn process_exec_tool_call(
/// For now, we conservatively check for 'command not found' (exit code 127),
/// and can add additional cases as necessary.
fn is_likely_sandbox_denied(sandbox_type: SandboxType, exit_code: i32) -> bool {
    if sandbox_type == SandboxType::None {
    if matches!(
        sandbox_type,
        SandboxType::None | SandboxType::WindowsAppContainer
    ) {
        return false;
    }

@@ -1,7 +1,7 @@
use std::collections::BTreeMap;

use crate::client_common::tools::ResponsesApiTool;
use crate::openai_tools::JsonSchema;
use crate::openai_tools::ResponsesApiTool;

pub const EXEC_COMMAND_TOOL_NAME: &str = "exec_command";
pub const WRITE_STDIN_TOOL_NAME: &str = "write_stdin";

101 codex-rs/core/src/executor/backends.rs Normal file
@@ -0,0 +1,101 @@
use std::collections::HashMap;
use std::env;

use async_trait::async_trait;

use crate::CODEX_APPLY_PATCH_ARG1;
use crate::apply_patch::ApplyPatchExec;
use crate::exec::ExecParams;
use crate::function_tool::FunctionCallError;

pub(crate) enum ExecutionMode {
    Shell,
    ApplyPatch(ApplyPatchExec),
}

#[async_trait]
/// Backend-specific hooks that prepare and post-process execution requests for a
/// given [`ExecutionMode`].
pub(crate) trait ExecutionBackend: Send + Sync {
    fn prepare(
        &self,
        params: ExecParams,
        // Required for downcasting the apply_patch.
        mode: &ExecutionMode,
    ) -> Result<ExecParams, FunctionCallError>;

    fn stream_stdout(&self, _mode: &ExecutionMode) -> bool {
        true
    }
}

static SHELL_BACKEND: ShellBackend = ShellBackend;
static APPLY_PATCH_BACKEND: ApplyPatchBackend = ApplyPatchBackend;

pub(crate) fn backend_for_mode(mode: &ExecutionMode) -> &'static dyn ExecutionBackend {
    match mode {
        ExecutionMode::Shell => &SHELL_BACKEND,
        ExecutionMode::ApplyPatch(_) => &APPLY_PATCH_BACKEND,
    }
}

struct ShellBackend;

#[async_trait]
impl ExecutionBackend for ShellBackend {
    fn prepare(
        &self,
        params: ExecParams,
        mode: &ExecutionMode,
    ) -> Result<ExecParams, FunctionCallError> {
        match mode {
            ExecutionMode::Shell => Ok(params),
            _ => Err(FunctionCallError::RespondToModel(
                "shell backend invoked with non-shell mode".to_string(),
            )),
        }
    }
}

struct ApplyPatchBackend;

#[async_trait]
impl ExecutionBackend for ApplyPatchBackend {
    fn prepare(
        &self,
        params: ExecParams,
        mode: &ExecutionMode,
    ) -> Result<ExecParams, FunctionCallError> {
        match mode {
            ExecutionMode::ApplyPatch(exec) => {
                let path_to_codex = env::current_exe()
                    .ok()
                    .map(|p| p.to_string_lossy().to_string())
                    .ok_or_else(|| {
                        FunctionCallError::RespondToModel(
                            "failed to determine path to codex executable".to_string(),
                        )
                    })?;

                let patch = exec.action.patch.clone();
                Ok(ExecParams {
                    command: vec![path_to_codex, CODEX_APPLY_PATCH_ARG1.to_string(), patch],
                    cwd: exec.action.cwd.clone(),
                    timeout_ms: params.timeout_ms,
                    // Run apply_patch with a minimal environment for determinism and to
                    // avoid leaking host environment variables into the patch process.
                    env: HashMap::new(),
                    with_escalated_permissions: params.with_escalated_permissions,
                    justification: params.justification,
                })
            }
            ExecutionMode::Shell => Err(FunctionCallError::RespondToModel(
                "apply_patch backend invoked without patch context".to_string(),
            )),
        }
    }

    fn stream_stdout(&self, _mode: &ExecutionMode) -> bool {
        false
    }
}
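A minimal sketch (hypothetical crate-internal call site, not code from the diff) of how this dispatch is meant to be driven: pick the backend for the mode, let it rewrite the params, and only keep the stdout stream when the backend wants live streaming:

    fn prepare_request(
        params: ExecParams,
        mode: &ExecutionMode,
        stdout_stream: Option<StdoutStream>,
    ) -> Result<(ExecParams, Option<StdoutStream>), FunctionCallError> {
        let backend = backend_for_mode(mode);
        let stream = if backend.stream_stdout(mode) {
            stdout_stream
        } else {
            None // apply_patch output is post-processed, not streamed
        };
        let prepared = backend.prepare(params, mode)?;
        Ok((prepared, stream))
    }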
51 codex-rs/core/src/executor/cache.rs Normal file
@@ -0,0 +1,51 @@
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Mutex;

#[derive(Clone, Debug, Default)]
/// Thread-safe store of user approvals so repeated commands can reuse
/// previously granted trust.
pub(crate) struct ApprovalCache {
    inner: Arc<Mutex<HashSet<Vec<String>>>>,
}

impl ApprovalCache {
    pub(crate) fn insert(&self, command: Vec<String>) {
        if command.is_empty() {
            return;
        }
        if let Ok(mut guard) = self.inner.lock() {
            guard.insert(command);
        }
    }

    pub(crate) fn snapshot(&self) -> HashSet<Vec<String>> {
        self.inner.lock().map(|g| g.clone()).unwrap_or_default()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn insert_ignores_empty_and_dedupes() {
        let cache = ApprovalCache::default();

        // Empty should be ignored
        cache.insert(vec![]);
        assert!(cache.snapshot().is_empty());

        // Insert a command and verify snapshot contains it
        let cmd = vec!["foo".to_string(), "bar".to_string()];
        cache.insert(cmd.clone());
        let snap1 = cache.snapshot();
        assert!(snap1.contains(&cmd));

        // Reinserting should not create duplicates
        cache.insert(cmd);
        let snap2 = cache.snapshot();
        assert_eq!(snap1, snap2);
    }
}
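The `Arc<Mutex<...>>` inner state means clones of the cache share one set: approvals recorded through one handle are visible wherever the cache was cloned. A sketch of that property as a hypothetical extra test:

    #[test]
    fn clones_share_underlying_state() {
        let cache = ApprovalCache::default();
        let clone = cache.clone();
        clone.insert(vec!["git".to_string(), "status".to_string()]);
        // The original handle sees the insert made through the clone.
        assert!(
            cache
                .snapshot()
                .contains(&vec!["git".to_string(), "status".to_string()])
        );
    }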
64 codex-rs/core/src/executor/mod.rs Normal file
@@ -0,0 +1,64 @@
mod backends;
mod cache;
mod runner;
mod sandbox;

pub(crate) use backends::ExecutionMode;
pub(crate) use runner::ExecutionRequest;
pub(crate) use runner::Executor;
pub(crate) use runner::ExecutorConfig;
pub(crate) use runner::normalize_exec_result;

pub(crate) mod linkers {
    use crate::exec::ExecParams;
    use crate::exec::StdoutStream;
    use crate::executor::backends::ExecutionMode;
    use crate::executor::runner::ExecutionRequest;
    use crate::tools::context::ExecCommandContext;

    pub struct PreparedExec {
        pub(crate) context: ExecCommandContext,
        pub(crate) request: ExecutionRequest,
    }

    impl PreparedExec {
        pub fn new(
            context: ExecCommandContext,
            params: ExecParams,
            approval_command: Vec<String>,
            mode: ExecutionMode,
            stdout_stream: Option<StdoutStream>,
            use_shell_profile: bool,
        ) -> Self {
            let request = ExecutionRequest {
                params,
                approval_command,
                mode,
                stdout_stream,
                use_shell_profile,
            };

            Self { context, request }
        }
    }
}

pub mod errors {
    use crate::error::CodexErr;
    use crate::function_tool::FunctionCallError;
    use thiserror::Error;

    #[derive(Debug, Error)]
    pub enum ExecError {
        #[error(transparent)]
        Function(#[from] FunctionCallError),
        #[error(transparent)]
        Codex(#[from] CodexErr),
    }

    impl ExecError {
        pub(crate) fn rejection(msg: impl Into<String>) -> Self {
            FunctionCallError::RespondToModel(msg.into()).into()
        }
    }
}

409 codex-rs/core/src/executor/runner.rs Normal file
@@ -0,0 +1,409 @@
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;

use super::backends::ExecutionMode;
use super::backends::backend_for_mode;
use super::cache::ApprovalCache;
use crate::codex::Session;
use crate::error::CodexErr;
use crate::error::SandboxErr;
use crate::error::get_error_message_ui;
use crate::exec::ExecParams;
use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::exec::StdoutStream;
use crate::exec::StreamOutput;
use crate::exec::process_exec_tool_call;
use crate::executor::errors::ExecError;
use crate::executor::sandbox::select_sandbox;
use crate::function_tool::FunctionCallError;
use crate::protocol::AskForApproval;
use crate::protocol::ReviewDecision;
use crate::protocol::SandboxPolicy;
use crate::shell;
use crate::tools::context::ExecCommandContext;
use codex_otel::otel_event_manager::ToolDecisionSource;

#[derive(Clone, Debug)]
pub(crate) struct ExecutorConfig {
    pub(crate) sandbox_policy: SandboxPolicy,
    pub(crate) sandbox_cwd: PathBuf,
    codex_linux_sandbox_exe: Option<PathBuf>,
}

impl ExecutorConfig {
    pub(crate) fn new(
        sandbox_policy: SandboxPolicy,
        sandbox_cwd: PathBuf,
        codex_linux_sandbox_exe: Option<PathBuf>,
    ) -> Self {
        Self {
            sandbox_policy,
            sandbox_cwd,
            codex_linux_sandbox_exe,
        }
    }
}

/// Coordinates sandbox selection, backend-specific preparation, and command
/// execution for tool calls requested by the model.
pub(crate) struct Executor {
    approval_cache: ApprovalCache,
    config: Arc<RwLock<ExecutorConfig>>,
}

impl Executor {
    pub(crate) fn new(config: ExecutorConfig) -> Self {
        Self {
            approval_cache: ApprovalCache::default(),
            config: Arc::new(RwLock::new(config)),
        }
    }

    /// Updates the sandbox policy and working directory used for future
    /// executions without recreating the executor.
    pub(crate) fn update_environment(&self, sandbox_policy: SandboxPolicy, sandbox_cwd: PathBuf) {
        if let Ok(mut cfg) = self.config.write() {
            cfg.sandbox_policy = sandbox_policy;
            cfg.sandbox_cwd = sandbox_cwd;
        }
    }

    /// Runs a prepared execution request end-to-end: prepares parameters, decides on
    /// sandbox placement (prompting the user when necessary), launches the command,
    /// and lets the backend post-process the final output.
    pub(crate) async fn run(
        &self,
        mut request: ExecutionRequest,
        session: &Session,
        approval_policy: AskForApproval,
        context: &ExecCommandContext,
    ) -> Result<ExecToolCallOutput, ExecError> {
        if matches!(request.mode, ExecutionMode::Shell) {
            request.params =
                maybe_translate_shell_command(request.params, session, request.use_shell_profile);
        }

        // Step 1: Normalise parameters via the selected backend.
        let backend = backend_for_mode(&request.mode);
        let stdout_stream = if backend.stream_stdout(&request.mode) {
            request.stdout_stream.clone()
        } else {
            None
        };
        request.params = backend
            .prepare(request.params, &request.mode)
            .map_err(ExecError::from)?;

        // Step 2: Snapshot sandbox configuration so it stays stable for this run.
        let config = self
            .config
            .read()
            .map_err(|_| ExecError::rejection("executor config poisoned"))?
            .clone();

        // Step 3: Decide sandbox placement, prompting for approval when needed.
        let sandbox_decision = select_sandbox(
            &request,
            approval_policy,
            self.approval_cache.snapshot(),
            &config,
            session,
            &context.sub_id,
            &context.call_id,
            &context.otel_event_manager,
        )
        .await?;
        if sandbox_decision.record_session_approval {
            self.approval_cache.insert(request.approval_command.clone());
        }

        // Step 4: Launch the command within the chosen sandbox.
        let first_attempt = self
            .spawn(
                request.params.clone(),
                sandbox_decision.initial_sandbox,
                &config,
                stdout_stream.clone(),
            )
            .await;

        // Step 5: Handle sandbox outcomes, optionally escalating to an unsandboxed retry.
        match first_attempt {
            Ok(output) => Ok(output),
            Err(CodexErr::Sandbox(SandboxErr::Timeout { output })) => {
                Err(CodexErr::Sandbox(SandboxErr::Timeout { output }).into())
            }
            Err(CodexErr::Sandbox(error)) => {
                if sandbox_decision.escalate_on_failure {
                    self.retry_without_sandbox(
                        &request,
                        &config,
                        session,
                        context,
                        stdout_stream,
                        error,
                    )
                    .await
                } else {
                    let message = sandbox_failure_message(error);
                    Err(ExecError::rejection(message))
                }
            }
            Err(err) => Err(err.into()),
        }
    }

    /// Fallback path invoked when a sandboxed run is denied so the user can
    /// approve rerunning without isolation.
    async fn retry_without_sandbox(
        &self,
        request: &ExecutionRequest,
        config: &ExecutorConfig,
        session: &Session,
        context: &ExecCommandContext,
        stdout_stream: Option<StdoutStream>,
        sandbox_error: SandboxErr,
    ) -> Result<ExecToolCallOutput, ExecError> {
        session
            .notify_background_event(
                &context.sub_id,
                format!("Execution failed: {sandbox_error}"),
            )
            .await;
        let decision = session
            .request_command_approval(
                context.sub_id.to_string(),
                context.call_id.to_string(),
                request.approval_command.clone(),
                request.params.cwd.clone(),
                Some("command failed; retry without sandbox?".to_string()),
            )
            .await;

        context.otel_event_manager.tool_decision(
            &context.tool_name,
            &context.call_id,
            decision,
            ToolDecisionSource::User,
        );
        match decision {
            ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {
                if matches!(decision, ReviewDecision::ApprovedForSession) {
                    self.approval_cache.insert(request.approval_command.clone());
                }
                session
                    .notify_background_event(&context.sub_id, "retrying command without sandbox")
                    .await;

                let retry_output = self
                    .spawn(
                        request.params.clone(),
                        SandboxType::None,
                        config,
                        stdout_stream,
                    )
                    .await?;

                Ok(retry_output)
            }
            ReviewDecision::Denied | ReviewDecision::Abort => {
                Err(ExecError::rejection("exec command rejected by user"))
            }
        }
    }

    async fn spawn(
        &self,
        params: ExecParams,
        sandbox: SandboxType,
        config: &ExecutorConfig,
        stdout_stream: Option<StdoutStream>,
    ) -> Result<ExecToolCallOutput, CodexErr> {
        process_exec_tool_call(
            params,
            sandbox,
            &config.sandbox_policy,
            &config.sandbox_cwd,
            &config.codex_linux_sandbox_exe,
            stdout_stream,
        )
        .await
    }
}

fn maybe_translate_shell_command(
    params: ExecParams,
    session: &Session,
    use_shell_profile: bool,
) -> ExecParams {
    let should_translate =
        matches!(session.user_shell(), shell::Shell::PowerShell(_)) || use_shell_profile;

    if should_translate
        && let Some(command) = session
            .user_shell()
            .format_default_shell_invocation(params.command.clone())
    {
        return ExecParams { command, ..params };
    }

    params
}

fn sandbox_failure_message(error: SandboxErr) -> String {
    let codex_error = CodexErr::Sandbox(error);
    let friendly = get_error_message_ui(&codex_error);
    format!("failed in sandbox: {friendly}")
}

pub(crate) struct ExecutionRequest {
    pub params: ExecParams,
    pub approval_command: Vec<String>,
    pub mode: ExecutionMode,
    pub stdout_stream: Option<StdoutStream>,
    pub use_shell_profile: bool,
}

pub(crate) struct NormalizedExecOutput<'a> {
    borrowed: Option<&'a ExecToolCallOutput>,
    synthetic: Option<ExecToolCallOutput>,
}

impl<'a> NormalizedExecOutput<'a> {
    pub(crate) fn event_output(&'a self) -> &'a ExecToolCallOutput {
        match (self.borrowed, self.synthetic.as_ref()) {
            (Some(output), _) => output,
            (None, Some(output)) => output,
            (None, None) => unreachable!("normalized exec output missing data"),
        }
    }
}

/// Converts a raw execution result into a uniform view that always exposes an
/// [`ExecToolCallOutput`], synthesizing error output when the command fails
/// before producing a response.
pub(crate) fn normalize_exec_result(
    result: &Result<ExecToolCallOutput, ExecError>,
) -> NormalizedExecOutput<'_> {
    match result {
        Ok(output) => NormalizedExecOutput {
            borrowed: Some(output),
            synthetic: None,
        },
        Err(ExecError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) => {
            NormalizedExecOutput {
                borrowed: Some(output.as_ref()),
                synthetic: None,
            }
        }
        Err(err) => {
            let message = match err {
                ExecError::Function(FunctionCallError::RespondToModel(msg)) => msg.clone(),
                ExecError::Codex(e) => get_error_message_ui(e),
                err => err.to_string(),
            };
            let synthetic = ExecToolCallOutput {
                exit_code: -1,
                stdout: StreamOutput::new(String::new()),
                stderr: StreamOutput::new(message.clone()),
                aggregated_output: StreamOutput::new(message),
                duration: Duration::default(),
                timed_out: false,
            };
            NormalizedExecOutput {
                borrowed: None,
                synthetic: Some(synthetic),
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::CodexErr;
    use crate::error::EnvVarError;
    use crate::error::SandboxErr;
    use crate::exec::StreamOutput;
    use pretty_assertions::assert_eq;

    fn make_output(text: &str) -> ExecToolCallOutput {
        ExecToolCallOutput {
            exit_code: 1,
            stdout: StreamOutput::new(String::new()),
            stderr: StreamOutput::new(String::new()),
            aggregated_output: StreamOutput::new(text.to_string()),
            duration: Duration::from_millis(123),
            timed_out: false,
        }
    }

    #[test]
    fn normalize_success_borrows() {
        let out = make_output("ok");
        let result: Result<ExecToolCallOutput, ExecError> = Ok(out);
        let normalized = normalize_exec_result(&result);
        assert_eq!(normalized.event_output().aggregated_output.text, "ok");
    }

    #[test]
    fn normalize_timeout_borrows_embedded_output() {
        let out = make_output("timed out payload");
        let err = CodexErr::Sandbox(SandboxErr::Timeout {
            output: Box::new(out),
        });
        let result: Result<ExecToolCallOutput, ExecError> = Err(ExecError::Codex(err));
        let normalized = normalize_exec_result(&result);
        assert_eq!(
            normalized.event_output().aggregated_output.text,
            "timed out payload"
        );
    }

    #[test]
    fn sandbox_failure_message_uses_denied_stderr() {
        let output = ExecToolCallOutput {
            exit_code: 101,
            stdout: StreamOutput::new(String::new()),
            stderr: StreamOutput::new("sandbox stderr".to_string()),
            aggregated_output: StreamOutput::new(String::new()),
            duration: Duration::from_millis(10),
            timed_out: false,
        };
        let err = SandboxErr::Denied {
            output: Box::new(output),
        };
        let message = sandbox_failure_message(err);
        assert_eq!(message, "failed in sandbox: sandbox stderr");
    }

    #[test]
    fn normalize_function_error_synthesizes_payload() {
        let err = FunctionCallError::RespondToModel("boom".to_string());
        let result: Result<ExecToolCallOutput, ExecError> = Err(ExecError::Function(err));
        let normalized = normalize_exec_result(&result);
        assert_eq!(normalized.event_output().aggregated_output.text, "boom");
    }

    #[test]
    fn normalize_codex_error_synthesizes_user_message() {
        // Use a simple EnvVar error which formats to a clear message
        let e = CodexErr::EnvVar(EnvVarError {
            var: "FOO".to_string(),
            instructions: Some("set it".to_string()),
        });
        let result: Result<ExecToolCallOutput, ExecError> = Err(ExecError::Codex(e));
        let normalized = normalize_exec_result(&result);
        assert!(
            normalized
                .event_output()
                .aggregated_output
                .text
                .contains("Missing environment variable: `FOO`"),
            "expected synthesized user-friendly message"
        );
    }
}
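A sketch (hypothetical crate-internal call site, under the assumption that Session and ExecCommandContext are already in hand) of the pattern normalize_exec_result supports: run a request, then emit a single event payload regardless of whether the command succeeded, timed out, or failed before producing output:

    async fn run_and_report(
        executor: &Executor,
        request: ExecutionRequest,
        session: &Session,
        approval_policy: AskForApproval,
        context: &ExecCommandContext,
    ) {
        let result = executor.run(request, session, approval_policy, context).await;
        let normalized = normalize_exec_result(&result);
        // Always a valid ExecToolCallOutput: borrowed on success or timeout,
        // synthesized (exit_code -1, error text in stderr) for other errors.
        let output = normalized.event_output();
        tracing::info!(exit_code = output.exit_code, "exec finished");
    }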
405 codex-rs/core/src/executor/sandbox.rs Normal file
@@ -0,0 +1,405 @@
use crate::apply_patch::ApplyPatchExec;
use crate::codex::Session;
use crate::exec::SandboxType;
use crate::executor::ExecutionMode;
use crate::executor::ExecutionRequest;
use crate::executor::ExecutorConfig;
use crate::executor::errors::ExecError;
use crate::safety::SafetyCheck;
use crate::safety::assess_command_safety;
use crate::safety::assess_patch_safety;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_otel::otel_event_manager::ToolDecisionSource;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::ReviewDecision;
use std::collections::HashSet;

/// Sandbox placement options selected for an execution run, including whether
/// to escalate after failures and whether approvals should persist.
pub(crate) struct SandboxDecision {
    pub(crate) initial_sandbox: SandboxType,
    pub(crate) escalate_on_failure: bool,
    pub(crate) record_session_approval: bool,
}

impl SandboxDecision {
    fn auto(sandbox: SandboxType, escalate_on_failure: bool) -> Self {
        Self {
            initial_sandbox: sandbox,
            escalate_on_failure,
            record_session_approval: false,
        }
    }

    fn user_override(record_session_approval: bool) -> Self {
        Self {
            initial_sandbox: SandboxType::None,
            escalate_on_failure: false,
            record_session_approval,
        }
    }
}

fn should_escalate_on_failure(approval: AskForApproval, sandbox: SandboxType) -> bool {
    matches!(
        (approval, sandbox),
        (
            AskForApproval::UnlessTrusted | AskForApproval::OnFailure,
            SandboxType::MacosSeatbelt | SandboxType::LinuxSeccomp
        )
    )
}

/// Determines how a command should be sandboxed, prompting the user when
/// policy requires explicit approval.
#[allow(clippy::too_many_arguments)]
pub async fn select_sandbox(
    request: &ExecutionRequest,
    approval_policy: AskForApproval,
    approval_cache: HashSet<Vec<String>>,
    config: &ExecutorConfig,
    session: &Session,
    sub_id: &str,
    call_id: &str,
    otel_event_manager: &OtelEventManager,
) -> Result<SandboxDecision, ExecError> {
    match &request.mode {
        ExecutionMode::Shell => {
            select_shell_sandbox(
                request,
                approval_policy,
                approval_cache,
                config,
                session,
                sub_id,
                call_id,
                otel_event_manager,
            )
            .await
        }
        ExecutionMode::ApplyPatch(exec) => {
            select_apply_patch_sandbox(exec, approval_policy, config)
        }
    }
}

#[allow(clippy::too_many_arguments)]
async fn select_shell_sandbox(
    request: &ExecutionRequest,
    approval_policy: AskForApproval,
    approved_snapshot: HashSet<Vec<String>>,
    config: &ExecutorConfig,
    session: &Session,
    sub_id: &str,
    call_id: &str,
    otel_event_manager: &OtelEventManager,
) -> Result<SandboxDecision, ExecError> {
    let command_for_safety = if request.approval_command.is_empty() {
        request.params.command.clone()
    } else {
        request.approval_command.clone()
    };

    let safety = assess_command_safety(
        &command_for_safety,
        approval_policy,
        &config.sandbox_policy,
        &approved_snapshot,
        request.params.with_escalated_permissions.unwrap_or(false),
    );

    match safety {
        SafetyCheck::AutoApprove {
            sandbox_type,
            user_explicitly_approved,
        } => {
            let mut decision = SandboxDecision::auto(
                sandbox_type,
                should_escalate_on_failure(approval_policy, sandbox_type),
            );
            if user_explicitly_approved {
                decision.record_session_approval = true;
            }
            let (decision_for_event, source) = if user_explicitly_approved {
                (ReviewDecision::ApprovedForSession, ToolDecisionSource::User)
            } else {
                (ReviewDecision::Approved, ToolDecisionSource::Config)
            };
            otel_event_manager.tool_decision("local_shell", call_id, decision_for_event, source);
            Ok(decision)
        }
        SafetyCheck::AskUser => {
            let decision = session
                .request_command_approval(
                    sub_id.to_string(),
                    call_id.to_string(),
                    request.approval_command.clone(),
                    request.params.cwd.clone(),
                    request.params.justification.clone(),
                )
                .await;

            otel_event_manager.tool_decision(
                "local_shell",
                call_id,
                decision,
                ToolDecisionSource::User,
            );
            match decision {
                ReviewDecision::Approved => Ok(SandboxDecision::user_override(false)),
                ReviewDecision::ApprovedForSession => Ok(SandboxDecision::user_override(true)),
                ReviewDecision::Denied | ReviewDecision::Abort => {
                    Err(ExecError::rejection("exec command rejected by user"))
                }
            }
        }
        SafetyCheck::Reject { reason } => Err(ExecError::rejection(format!(
            "exec command rejected: {reason}"
        ))),
    }
}

fn select_apply_patch_sandbox(
    exec: &ApplyPatchExec,
    approval_policy: AskForApproval,
    config: &ExecutorConfig,
) -> Result<SandboxDecision, ExecError> {
    if exec.user_explicitly_approved_this_action {
        return Ok(SandboxDecision::user_override(false));
    }

    match assess_patch_safety(
        &exec.action,
        approval_policy,
        &config.sandbox_policy,
        &config.sandbox_cwd,
    ) {
        SafetyCheck::AutoApprove { sandbox_type, .. } => Ok(SandboxDecision::auto(
            sandbox_type,
            should_escalate_on_failure(approval_policy, sandbox_type),
        )),
        SafetyCheck::AskUser => Err(ExecError::rejection(
            "patch requires approval but none was recorded",
        )),
        SafetyCheck::Reject { reason } => {
            Err(ExecError::rejection(format!("patch rejected: {reason}")))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::codex::make_session_and_context;
    use crate::exec::ExecParams;
    use crate::function_tool::FunctionCallError;
    use crate::protocol::SandboxPolicy;
    use codex_apply_patch::ApplyPatchAction;
    use pretty_assertions::assert_eq;

    #[tokio::test]
    async fn select_apply_patch_user_override_when_explicit() {
        let (session, ctx) = make_session_and_context();
        let tmp = tempfile::tempdir().expect("tmp");
        let p = tmp.path().join("a.txt");
        let action = ApplyPatchAction::new_add_for_test(&p, "hello".to_string());
        let exec = ApplyPatchExec {
            action,
            user_explicitly_approved_this_action: true,
        };
        let cfg = ExecutorConfig::new(SandboxPolicy::ReadOnly, std::env::temp_dir(), None);
        let request = ExecutionRequest {
            params: ExecParams {
                command: vec!["apply_patch".into()],
                cwd: std::env::temp_dir(),
                timeout_ms: None,
                env: std::collections::HashMap::new(),
                with_escalated_permissions: None,
                justification: None,
            },
            approval_command: vec!["apply_patch".into()],
            mode: ExecutionMode::ApplyPatch(exec),
            stdout_stream: None,
            use_shell_profile: false,
        };
        let otel_event_manager = ctx.client.get_otel_event_manager();
        let decision = select_sandbox(
            &request,
            AskForApproval::OnRequest,
            Default::default(),
            &cfg,
            &session,
            "sub",
            "call",
            &otel_event_manager,
        )
        .await
        .expect("ok");
        // Explicit user override runs without sandbox
        assert_eq!(decision.initial_sandbox, SandboxType::None);
        assert_eq!(decision.escalate_on_failure, false);
    }

    #[tokio::test]
    async fn select_apply_patch_autoapprove_in_danger() {
        let (session, ctx) = make_session_and_context();
        let tmp = tempfile::tempdir().expect("tmp");
        let p = tmp.path().join("a.txt");
        let action = ApplyPatchAction::new_add_for_test(&p, "hello".to_string());
        let exec = ApplyPatchExec {
            action,
            user_explicitly_approved_this_action: false,
        };
        let cfg = ExecutorConfig::new(SandboxPolicy::DangerFullAccess, std::env::temp_dir(), None);
        let request = ExecutionRequest {
            params: ExecParams {
                command: vec!["apply_patch".into()],
                cwd: std::env::temp_dir(),
                timeout_ms: None,
                env: std::collections::HashMap::new(),
                with_escalated_permissions: None,
                justification: None,
            },
            approval_command: vec!["apply_patch".into()],
            mode: ExecutionMode::ApplyPatch(exec),
            stdout_stream: None,
            use_shell_profile: false,
        };
        let otel_event_manager = ctx.client.get_otel_event_manager();
        let decision = select_sandbox(
            &request,
            AskForApproval::OnRequest,
            Default::default(),
            &cfg,
            &session,
            "sub",
            "call",
            &otel_event_manager,
        )
        .await
        .expect("ok");
        // On platforms with a sandbox, DangerFullAccess still prefers it
        let expected = crate::safety::get_platform_sandbox().unwrap_or(SandboxType::None);
        assert_eq!(decision.initial_sandbox, expected);
        assert_eq!(decision.escalate_on_failure, false);
    }

    #[tokio::test]
    async fn select_apply_patch_requires_approval_on_unless_trusted() {
        let (session, ctx) = make_session_and_context();
        let tempdir = tempfile::tempdir().expect("tmpdir");
        let p = tempdir.path().join("a.txt");
        let action = ApplyPatchAction::new_add_for_test(&p, "hello".to_string());
        let exec = ApplyPatchExec {
            action,
            user_explicitly_approved_this_action: false,
        };
        let cfg = ExecutorConfig::new(SandboxPolicy::ReadOnly, std::env::temp_dir(), None);
        let request = ExecutionRequest {
            params: ExecParams {
                command: vec!["apply_patch".into()],
                cwd: std::env::temp_dir(),
                timeout_ms: None,
                env: std::collections::HashMap::new(),
                with_escalated_permissions: None,
                justification: None,
            },
            approval_command: vec!["apply_patch".into()],
            mode: ExecutionMode::ApplyPatch(exec),
            stdout_stream: None,
            use_shell_profile: false,
        };
        let otel_event_manager = ctx.client.get_otel_event_manager();
        let result = select_sandbox(
            &request,
            AskForApproval::UnlessTrusted,
            Default::default(),
            &cfg,
            &session,
            "sub",
            "call",
            &otel_event_manager,
        )
        .await;
        match result {
            Ok(_) => panic!("expected error"),
            Err(ExecError::Function(FunctionCallError::RespondToModel(msg))) => {
                assert!(msg.contains("requires approval"))
            }
            Err(other) => panic!("unexpected error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn select_shell_autoapprove_in_danger_mode() {
        let (session, ctx) = make_session_and_context();
        let cfg = ExecutorConfig::new(SandboxPolicy::DangerFullAccess, std::env::temp_dir(), None);
        let request = ExecutionRequest {
            params: ExecParams {
                command: vec!["some-unknown".into()],
                cwd: std::env::temp_dir(),
                timeout_ms: None,
                env: std::collections::HashMap::new(),
                with_escalated_permissions: None,
                justification: None,
            },
            approval_command: vec!["some-unknown".into()],
            mode: ExecutionMode::Shell,
            stdout_stream: None,
            use_shell_profile: false,
        };
        let otel_event_manager = ctx.client.get_otel_event_manager();
        let decision = select_sandbox(
            &request,
            AskForApproval::OnRequest,
            Default::default(),
            &cfg,
            &session,
            "sub",
            "call",
            &otel_event_manager,
        )
        .await
        .expect("ok");
        assert_eq!(decision.initial_sandbox, SandboxType::None);
        assert_eq!(decision.escalate_on_failure, false);
    }

    #[cfg(any(target_os = "macos", target_os = "linux"))]
    #[tokio::test]
    async fn select_shell_escalates_on_failure_with_platform_sandbox() {
        let (session, ctx) = make_session_and_context();
        let cfg = ExecutorConfig::new(SandboxPolicy::ReadOnly, std::env::temp_dir(), None);
        let request = ExecutionRequest {
            params: ExecParams {
                // Unknown command => untrusted but not flagged dangerous
                command: vec!["some-unknown".into()],
                cwd: std::env::temp_dir(),
                timeout_ms: None,
                env: std::collections::HashMap::new(),
                with_escalated_permissions: None,
                justification: None,
            },
            approval_command: vec!["some-unknown".into()],
            mode: ExecutionMode::Shell,
            stdout_stream: None,
            use_shell_profile: false,
        };
        let otel_event_manager = ctx.client.get_otel_event_manager();
        let decision = select_sandbox(
            &request,
            AskForApproval::OnFailure,
            Default::default(),
            &cfg,
            &session,
            "sub",
            "call",
            &otel_event_manager,
        )
        .await
        .expect("ok");
        // On macOS/Linux we should have a platform sandbox and escalate on failure
        assert_ne!(decision.initial_sandbox, SandboxType::None);
        assert_eq!(decision.escalate_on_failure, true);
    }
}
@@ -4,4 +4,8 @@ use thiserror::Error;
pub enum FunctionCallError {
    #[error("{0}")]
    RespondToModel(String),
    #[error("LocalShellCall without call_id or id")]
    MissingLocalShellCallId,
    #[error("Fatal error: {0}")]
    Fatal(String),
}
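A minimal sketch of how a caller might branch on these variants, mirroring the test above; the `report` function here is hypothetical, not part of the diff:

    // Hypothetical call site, assuming some tool call returning
    // Result<String, FunctionCallError>.
    fn report(result: Result<String, FunctionCallError>) -> String {
        match result {
            Ok(content) => content,
            // Surfaced back to the model so it can retry or adjust.
            Err(FunctionCallError::RespondToModel(msg)) => format!("tool error: {msg}"),
            Err(FunctionCallError::MissingLocalShellCallId) => {
                "LocalShellCall without call_id or id".to_string()
            }
            // Fatal errors abort the turn rather than being sent to the model.
            Err(FunctionCallError::Fatal(msg)) => format!("fatal: {msg}"),
        }
    }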
@@ -18,6 +18,7 @@ pub use codex_conversation::CodexConversation;
mod command_safety;
pub mod config;
pub mod config_edit;
pub mod config_loader;
pub mod config_profile;
pub mod config_types;
mod conversation_history;
@@ -27,6 +28,7 @@ pub mod error;
pub mod exec;
mod exec_command;
pub mod exec_env;
pub mod executor;
mod flags;
pub mod git_info;
pub mod landlock;
@@ -56,7 +58,6 @@ pub mod default_client;
pub mod model_family;
mod openai_model_info;
mod openai_tools;
pub mod plan_tool;
pub mod project_doc;
mod rollout;
pub(crate) mod safety;
@@ -64,7 +65,7 @@ pub mod seatbelt;
pub mod shell;
pub mod spawn;
pub mod terminal;
mod tool_apply_patch;
mod tools;
pub mod turn_diff_tracker;
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
pub use rollout::INTERACTIVE_SESSION_SOURCES;
@@ -81,6 +82,9 @@ mod tasks;
mod user_notification;
pub mod util;

#[cfg(windows)]
pub mod windows_appcontainer;

pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
pub use command_safety::is_safe_command;
pub use safety::get_platform_sandbox;

@@ -108,9 +108,6 @@ impl McpClientAdapter {
        params: mcp_types::InitializeRequestParams,
        startup_timeout: Duration,
    ) -> Result<Self> {
        info!(
            "new_stdio_client use_rmcp_client: {use_rmcp_client} program: {program:?} args: {args:?} env: {env:?} params: {params:?} startup_timeout: {startup_timeout:?}"
        );
        if use_rmcp_client {
            let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?);
            client.initialize(params, Some(startup_timeout)).await?;
@@ -123,12 +120,15 @@ impl McpClientAdapter {
    }

    async fn new_streamable_http_client(
        server_name: String,
        url: String,
        bearer_token: Option<String>,
        params: mcp_types::InitializeRequestParams,
        startup_timeout: Duration,
    ) -> Result<Self> {
        let client = Arc::new(RmcpClient::new_streamable_http_client(url, bearer_token)?);
        let client = Arc::new(
            RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token).await?,
        );
        client.initialize(params, Some(startup_timeout)).await?;
        Ok(McpClientAdapter::Rmcp(client))
    }
@@ -202,22 +202,9 @@ impl McpConnectionManager {
                continue;
            }

            if matches!(
                cfg.transport,
                McpServerTransportConfig::StreamableHttp { .. }
            ) && !use_rmcp_client
            {
                info!(
                    "skipping MCP server `{}` configured with url because rmcp client is disabled",
                    server_name
                );
                continue;
            }

            let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
            let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);

            let use_rmcp_client_flag = use_rmcp_client;
            join_set.spawn(async move {
                let McpServerConfig { transport, .. } = cfg;
                let params = mcp_types::InitializeRequestParams {
@@ -246,17 +233,18 @@ impl McpConnectionManager {
                        let command_os: OsString = command.into();
                        let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
                        McpClientAdapter::new_stdio_client(
                            use_rmcp_client_flag,
                            use_rmcp_client,
                            command_os,
                            args_os,
                            env,
                            params.clone(),
                            params,
                            startup_timeout,
                        )
                        .await
                    }
                    McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
                        McpClientAdapter::new_streamable_http_client(
                            server_name.clone(),
                            url,
                            bearer_token,
                            params,
@@ -1,5 +1,5 @@
use crate::config_types::ReasoningSummaryFormat;
use crate::tool_apply_patch::ApplyPatchToolType;
use crate::tools::handlers::apply_patch::ApplyPatchToolType;

/// The `instructions` field in the payload sent to a model should always start
/// with this content.
@@ -35,12 +35,19 @@ pub struct ModelFamily {
    // See https://platform.openai.com/docs/guides/tools-local-shell
    pub uses_local_shell_tool: bool,

    /// Whether this model supports parallel tool calls when using the
    /// Responses API.
    pub supports_parallel_tool_calls: bool,

    /// Present if the model performs better when `apply_patch` is provided as
    /// a tool call instead of just a bash command
    pub apply_patch_tool_type: Option<ApplyPatchToolType>,

    // Instructions to use for querying the model
    pub base_instructions: String,

    /// Names of beta tools that should be exposed to this model family.
    pub experimental_supported_tools: Vec<String>,
}

macro_rules! model_family {
@@ -55,8 +62,10 @@ macro_rules! model_family {
            supports_reasoning_summaries: false,
            reasoning_summary_format: ReasoningSummaryFormat::None,
            uses_local_shell_tool: false,
            supports_parallel_tool_calls: false,
            apply_patch_tool_type: None,
            base_instructions: BASE_INSTRUCTIONS.to_string(),
            experimental_supported_tools: Vec::new(),
        };
        // apply overrides
        $(
@@ -68,7 +77,11 @@ macro_rules! model_family {

/// Returns a `ModelFamily` for the given model slug, or `None` if the slug
/// does not match any known model family.
pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
pub fn find_family_for_model(mut slug: &str) -> Option<ModelFamily> {
    // TODO(jif) clean once we have proper feature flags
    if matches!(std::env::var("CODEX_EXPERIMENTAL").as_deref(), Ok("1")) {
        slug = "codex-experimental";
    }
    if slug.starts_with("o3") {
        model_family!(
            slug, "o3",
@@ -99,12 +112,39 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
        model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
    } else if slug.starts_with("gpt-3.5") {
        model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
    } else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") {
    } else if slug.starts_with("test-gpt-5-codex") {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
            experimental_supported_tools: vec![
                "read_file".to_string(),
                "test_sync_tool".to_string()
            ],
            supports_parallel_tool_calls: true,
        )

    // Internal models.
    } else if slug.starts_with("codex-") {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
            apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
            experimental_supported_tools: vec!["read_file".to_string()],
            supports_parallel_tool_calls: true,
        )

    // Production models.
    } else if slug.starts_with("gpt-5-codex") {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
            apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
        )
    } else if slug.starts_with("gpt-5") {
        model_family!(
@@ -125,7 +165,9 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily {
        supports_reasoning_summaries: false,
        reasoning_summary_format: ReasoningSummaryFormat::None,
        uses_local_shell_tool: false,
        supports_parallel_tool_calls: false,
        apply_patch_tool_type: None,
        base_instructions: BASE_INSTRUCTIONS.to_string(),
        experimental_supported_tools: Vec::new(),
    }
}
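A short illustrative sketch of how the prefix matching above resolves a slug; the exact field values follow from the macro overrides shown in the diff:

    // "gpt-5-codex" hits the production branch: reasoning summaries on,
    // freeform apply_patch tool attached.
    let fam = find_family_for_model("gpt-5-codex").expect("known family");
    assert!(fam.supports_reasoning_summaries);
    // With CODEX_EXPERIMENTAL=1, any slug is rewritten to "codex-experimental"
    // before matching, so the internal `codex-` branch wins instead.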
File diff suppressed because it is too large
@@ -2,6 +2,8 @@ use std::collections::HashSet;
use std::path::Component;
use std::path::Path;
use std::path::PathBuf;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

use codex_apply_patch::ApplyPatchAction;
use codex_apply_patch::ApplyPatchFileChange;
@@ -13,6 +15,12 @@ use crate::command_safety::is_safe_command::is_known_safe_command;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;

static WINDOWS_SANDBOX_ENABLED: AtomicBool = AtomicBool::new(false);

pub(crate) fn set_windows_sandbox_enabled(enabled: bool) {
    WINDOWS_SANDBOX_ENABLED.store(enabled, Ordering::Relaxed);
}

#[derive(Debug, PartialEq)]
pub enum SafetyCheck {
    AutoApprove {
@@ -125,9 +133,10 @@ pub fn assess_command_safety(
    // the session _because_ they know it needs to run outside a sandbox.

    if is_known_safe_command(command) || approved.contains(command) {
        let user_explicitly_approved = approved.contains(command);
        return SafetyCheck::AutoApprove {
            sandbox_type: SandboxType::None,
            user_explicitly_approved: false,
            user_explicitly_approved,
        };
    }

@@ -205,6 +214,12 @@ pub fn get_platform_sandbox() -> Option<SandboxType> {
        Some(SandboxType::MacosSeatbelt)
    } else if cfg!(target_os = "linux") {
        Some(SandboxType::LinuxSeccomp)
    } else if cfg!(target_os = "windows") {
        if WINDOWS_SANDBOX_ENABLED.load(Ordering::Relaxed) {
            Some(SandboxType::WindowsAppContainer)
        } else {
            None
        }
    } else {
        None
    }
@@ -380,7 +395,7 @@ mod tests {
            safety_check,
            SafetyCheck::AutoApprove {
                sandbox_type: SandboxType::None,
                user_explicitly_approved: false,
                user_explicitly_approved: true,
            }
        );
    }
@@ -435,4 +450,19 @@ mod tests {
        };
        assert_eq!(safety_check, expected);
    }

    #[cfg(target_os = "windows")]
    #[test]
    fn windows_sandbox_toggle_controls_platform_sandbox() {
        set_windows_sandbox_enabled(false);
        assert_eq!(get_platform_sandbox(), None);

        set_windows_sandbox_enabled(true);
        assert_eq!(
            get_platform_sandbox(),
            Some(SandboxType::WindowsAppContainer)
        );

        set_windows_sandbox_enabled(false);
    }
}
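The behavioural fix above in miniature: an explicitly approved command used to be auto-approved with `user_explicitly_approved: false`; now the flag records whether the command actually came from the approved set. A minimal sketch, assuming `approved` is the session's set of approved argv vectors:

    let command = vec!["git".to_string(), "push".to_string()];
    // True only when the user approved this exact argv earlier in the session.
    let user_explicitly_approved = approved.contains(&command);
    // Either way the command runs without a sandbox, but downstream logic
    // (e.g. apply_patch handling) can now tell *why* it was approved.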
@@ -1,9 +1,9 @@
use crate::RolloutRecorder;
use crate::exec_command::ExecSessionManager;
use crate::executor::Executor;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_notification::UserNotifier;
use std::path::PathBuf;
use tokio::sync::Mutex;

pub(crate) struct SessionServices {
@@ -12,7 +12,7 @@ pub(crate) struct SessionServices {
    pub(crate) unified_exec_manager: UnifiedExecSessionManager,
    pub(crate) notifier: UserNotifier,
    pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
    pub(crate) codex_linux_sandbox_exe: Option<PathBuf>,
    pub(crate) user_shell: crate::shell::Shell,
    pub(crate) show_raw_agent_reasoning: bool,
    pub(crate) executor: Executor,
}
@@ -1,7 +1,5 @@
//! Session-wide mutable state.

use std::collections::HashSet;

use codex_protocol::models::ResponseItem;

use crate::conversation_history::ConversationHistory;
@@ -12,7 +10,6 @@ use crate::protocol::TokenUsageInfo;
/// Persistent, session-scoped state previously stored directly on `Session`.
#[derive(Default)]
pub(crate) struct SessionState {
    pub(crate) approved_commands: HashSet<Vec<String>>,
    pub(crate) history: ConversationHistory,
    pub(crate) token_info: Option<TokenUsageInfo>,
    pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
@@ -44,15 +41,6 @@ impl SessionState {
        self.history.replace(items);
    }

    // Approved command helpers
    pub(crate) fn add_approved_command(&mut self, cmd: Vec<String>) {
        self.approved_commands.insert(cmd);
    }

    pub(crate) fn approved_commands_ref(&self) -> &HashSet<Vec<String>> {
        &self.approved_commands
    }

    // Token/rate limit helpers
    pub(crate) fn update_token_info_from_usage(
        &mut self,
@@ -76,5 +64,14 @@ impl SessionState {
        (self.token_info.clone(), self.latest_rate_limits.clone())
    }

    pub(crate) fn set_token_usage_full(&mut self, context_window: u64) {
        match &mut self.token_info {
            Some(info) => info.fill_to_context_window(context_window),
            None => {
                self.token_info = Some(TokenUsageInfo::full_context_window(context_window));
            }
        }
    }

    // Pending input/approval moved to TurnState.
}
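A small usage sketch for the new helper; the context-window size here is illustrative:

    let mut state = SessionState::default();
    // Marks the token budget as exhausted: fills the existing usage info,
    // or synthesizes one that already spans the whole context window.
    state.set_token_usage_full(272_000);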
codex-rs/core/src/tools/context.rs (new file, 249 lines)
@@ -0,0 +1,249 @@
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::tools::TELEMETRY_PREVIEW_MAX_BYTES;
use crate::tools::TELEMETRY_PREVIEW_MAX_LINES;
use crate::tools::TELEMETRY_PREVIEW_TRUNCATION_NOTICE;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ShellToolCallParams;
use codex_protocol::protocol::FileChange;
use codex_utils_string::take_bytes_at_char_boundary;
use mcp_types::CallToolResult;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;

pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;

#[derive(Clone)]
pub struct ToolInvocation {
    pub session: Arc<Session>,
    pub turn: Arc<TurnContext>,
    pub tracker: SharedTurnDiffTracker,
    pub sub_id: String,
    pub call_id: String,
    pub tool_name: String,
    pub payload: ToolPayload,
}

#[derive(Clone)]
pub enum ToolPayload {
    Function {
        arguments: String,
    },
    Custom {
        input: String,
    },
    LocalShell {
        params: ShellToolCallParams,
    },
    UnifiedExec {
        arguments: String,
    },
    Mcp {
        server: String,
        tool: String,
        raw_arguments: String,
    },
}

impl ToolPayload {
    pub fn log_payload(&self) -> Cow<'_, str> {
        match self {
            ToolPayload::Function { arguments } => Cow::Borrowed(arguments),
            ToolPayload::Custom { input } => Cow::Borrowed(input),
            ToolPayload::LocalShell { params } => Cow::Owned(params.command.join(" ")),
            ToolPayload::UnifiedExec { arguments } => Cow::Borrowed(arguments),
            ToolPayload::Mcp { raw_arguments, .. } => Cow::Borrowed(raw_arguments),
        }
    }
}

#[derive(Clone)]
pub enum ToolOutput {
    Function {
        content: String,
        success: Option<bool>,
    },
    Mcp {
        result: Result<CallToolResult, String>,
    },
}

impl ToolOutput {
    pub fn log_preview(&self) -> String {
        match self {
            ToolOutput::Function { content, .. } => telemetry_preview(content),
            ToolOutput::Mcp { result } => format!("{result:?}"),
        }
    }

    pub fn success_for_logging(&self) -> bool {
        match self {
            ToolOutput::Function { success, .. } => success.unwrap_or(true),
            ToolOutput::Mcp { result } => result.is_ok(),
        }
    }

    pub fn into_response(self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
        match self {
            ToolOutput::Function { content, success } => {
                if matches!(payload, ToolPayload::Custom { .. }) {
                    ResponseInputItem::CustomToolCallOutput {
                        call_id: call_id.to_string(),
                        output: content,
                    }
                } else {
                    ResponseInputItem::FunctionCallOutput {
                        call_id: call_id.to_string(),
                        output: FunctionCallOutputPayload { content, success },
                    }
                }
            }
            ToolOutput::Mcp { result } => ResponseInputItem::McpToolCallOutput {
                call_id: call_id.to_string(),
                result,
            },
        }
    }
}

fn telemetry_preview(content: &str) -> String {
    let truncated_slice = take_bytes_at_char_boundary(content, TELEMETRY_PREVIEW_MAX_BYTES);
    let truncated_by_bytes = truncated_slice.len() < content.len();

    let mut preview = String::new();
    let mut lines_iter = truncated_slice.lines();
    for idx in 0..TELEMETRY_PREVIEW_MAX_LINES {
        match lines_iter.next() {
            Some(line) => {
                if idx > 0 {
                    preview.push('\n');
                }
                preview.push_str(line);
            }
            None => break,
        }
    }
    let truncated_by_lines = lines_iter.next().is_some();

    if !truncated_by_bytes && !truncated_by_lines {
        return content.to_string();
    }

    if preview.len() < truncated_slice.len()
        && truncated_slice
            .as_bytes()
            .get(preview.len())
            .is_some_and(|byte| *byte == b'\n')
    {
        preview.push('\n');
    }

    if !preview.is_empty() && !preview.ends_with('\n') {
        preview.push('\n');
    }
    preview.push_str(TELEMETRY_PREVIEW_TRUNCATION_NOTICE);

    preview
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn custom_tool_calls_should_roundtrip_as_custom_outputs() {
        let payload = ToolPayload::Custom {
            input: "patch".to_string(),
        };
        let response = ToolOutput::Function {
            content: "patched".to_string(),
            success: Some(true),
        }
        .into_response("call-42", &payload);

        match response {
            ResponseInputItem::CustomToolCallOutput { call_id, output } => {
                assert_eq!(call_id, "call-42");
                assert_eq!(output, "patched");
            }
            other => panic!("expected CustomToolCallOutput, got {other:?}"),
        }
    }

    #[test]
    fn function_payloads_remain_function_outputs() {
        let payload = ToolPayload::Function {
            arguments: "{}".to_string(),
        };
        let response = ToolOutput::Function {
            content: "ok".to_string(),
            success: Some(true),
        }
        .into_response("fn-1", &payload);

        match response {
            ResponseInputItem::FunctionCallOutput { call_id, output } => {
                assert_eq!(call_id, "fn-1");
                assert_eq!(output.content, "ok");
                assert_eq!(output.success, Some(true));
            }
            other => panic!("expected FunctionCallOutput, got {other:?}"),
        }
    }

    #[test]
    fn telemetry_preview_returns_original_within_limits() {
        let content = "short output";
        assert_eq!(telemetry_preview(content), content);
    }

    #[test]
    fn telemetry_preview_truncates_by_bytes() {
        let content = "x".repeat(TELEMETRY_PREVIEW_MAX_BYTES + 8);
        let preview = telemetry_preview(&content);

        assert!(preview.contains(TELEMETRY_PREVIEW_TRUNCATION_NOTICE));
        assert!(
            preview.len()
                <= TELEMETRY_PREVIEW_MAX_BYTES + TELEMETRY_PREVIEW_TRUNCATION_NOTICE.len() + 1
        );
    }

    #[test]
    fn telemetry_preview_truncates_by_lines() {
        let content = (0..(TELEMETRY_PREVIEW_MAX_LINES + 5))
            .map(|idx| format!("line {idx}"))
            .collect::<Vec<_>>()
            .join("\n");

        let preview = telemetry_preview(&content);
        let lines: Vec<&str> = preview.lines().collect();

        assert!(lines.len() <= TELEMETRY_PREVIEW_MAX_LINES + 1);
        assert_eq!(lines.last(), Some(&TELEMETRY_PREVIEW_TRUNCATION_NOTICE));
    }
}

#[derive(Clone, Debug)]
pub(crate) struct ExecCommandContext {
    pub(crate) sub_id: String,
    pub(crate) call_id: String,
    pub(crate) command_for_display: Vec<String>,
    pub(crate) cwd: PathBuf,
    pub(crate) apply_patch: Option<ApplyPatchCommandContext>,
    pub(crate) tool_name: String,
    pub(crate) otel_event_manager: OtelEventManager,
}

#[derive(Clone, Debug)]
pub(crate) struct ApplyPatchCommandContext {
    pub(crate) user_explicitly_approved_this_action: bool,
    pub(crate) changes: HashMap<PathBuf, FileChange>,
}
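How these pieces compose, as a sketch: a handler returns a `ToolOutput`, the router can log a truncated preview, and the output is converted into the protocol item that matches the payload kind. Values here are illustrative:

    let payload = ToolPayload::Custom { input: "*** Begin Patch".to_string() };
    let output = ToolOutput::Function { content: "Done".to_string(), success: Some(true) };
    // Preview is byte- and line-capped before it reaches telemetry.
    let _preview = output.log_preview();
    // Custom payloads round-trip as CustomToolCallOutput; everything else
    // becomes a FunctionCallOutput carrying the same call_id.
    let _item = output.into_response("call-1", &payload);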
@@ -1,15 +1,97 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;

use crate::client_common::tools::FreeformTool;
use crate::client_common::tools::FreeformToolFormat;
use crate::client_common::tools::ResponsesApiTool;
use crate::client_common::tools::ToolSpec;
use crate::exec::ExecParams;
use crate::function_tool::FunctionCallError;
use crate::openai_tools::JsonSchema;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::handle_container_exec_with_params;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::tools::spec::ApplyPatchToolArgs;
use async_trait::async_trait;
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;

use crate::openai_tools::FreeformTool;
use crate::openai_tools::FreeformToolFormat;
use crate::openai_tools::JsonSchema;
use crate::openai_tools::OpenAiTool;
use crate::openai_tools::ResponsesApiTool;
pub struct ApplyPatchHandler;

const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark");

#[async_trait]
impl ToolHandler for ApplyPatchHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            payload,
            ToolPayload::Function { .. } | ToolPayload::Custom { .. }
        )
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            tracker,
            sub_id,
            call_id,
            tool_name,
            payload,
        } = invocation;

        let patch_input = match payload {
            ToolPayload::Function { arguments } => {
                let args: ApplyPatchToolArgs = serde_json::from_str(&arguments).map_err(|e| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse function arguments: {e:?}"
                    ))
                })?;
                args.input
            }
            ToolPayload::Custom { input } => input,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "apply_patch handler received unsupported payload".to_string(),
                ));
            }
        };

        let exec_params = ExecParams {
            command: vec!["apply_patch".to_string(), patch_input.clone()],
            cwd: turn.cwd.clone(),
            timeout_ms: None,
            env: HashMap::new(),
            with_escalated_permissions: None,
            justification: None,
        };

        let content = handle_container_exec_with_params(
            tool_name.as_str(),
            exec_params,
            Arc::clone(&session),
            Arc::clone(&turn),
            Arc::clone(&tracker),
            sub_id.clone(),
            call_id.clone(),
        )
        .await?;

        Ok(ToolOutput::Function {
            content,
            success: Some(true),
        })
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ApplyPatchToolType {
@@ -19,10 +101,10 @@ pub enum ApplyPatchToolType {

/// Returns a custom tool that can be used to edit files. Well-suited for GPT-5 models
/// https://platform.openai.com/docs/guides/function-calling#custom-tools
pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool {
    OpenAiTool::Freeform(FreeformTool {
pub(crate) fn create_apply_patch_freeform_tool() -> ToolSpec {
    ToolSpec::Freeform(FreeformTool {
        name: "apply_patch".to_string(),
        description: "Use the `apply_patch` tool to edit files".to_string(),
        description: "Use the `apply_patch` tool to edit files. This is a FREEFORM tool, so do not wrap the patch in JSON.".to_string(),
        format: FreeformToolFormat {
            r#type: "grammar".to_string(),
            syntax: "lark".to_string(),
@@ -32,7 +114,7 @@ pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool {
}

/// Returns a json tool that can be used to edit files. Should only be used with gpt-oss models
pub(crate) fn create_apply_patch_json_tool() -> OpenAiTool {
pub(crate) fn create_apply_patch_json_tool() -> ToolSpec {
    let mut properties = BTreeMap::new();
    properties.insert(
        "input".to_string(),
@@ -41,7 +123,7 @@ pub(crate) fn create_apply_patch_json_tool() -> OpenAiTool {
        },
    );

    OpenAiTool::Function(ResponsesApiTool {
    ToolSpec::Function(ResponsesApiTool {
        name: "apply_patch".to_string(),
        description: r#"Use the `apply_patch` tool to edit files.
Your patch language is a stripped‑down, file‑oriented diff format designed to be easy to parse and safe to apply. You can think of it as a high‑level envelope:
@@ -111,7 +193,7 @@ It is important to remember:
- You must prefix new lines with `+` even when creating a new file
- File references can only be relative, NEVER ABSOLUTE.
"#
        .to_string(),
        .to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties,
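For reference, a minimal patch in the envelope the description above refers to; freeform input, not wrapped in JSON, with the file name purely illustrative:

    *** Begin Patch
    *** Add File: notes.txt
    +hello
    *** End Patch

Note the `+` prefix on the added line even though the file is being created, matching the rule quoted in the tool description.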
codex-rs/core/src/tools/handlers/exec_stream.rs (new file, 68 lines)
@@ -0,0 +1,68 @@
use async_trait::async_trait;

use crate::exec_command::EXEC_COMMAND_TOOL_NAME;
use crate::exec_command::ExecCommandParams;
use crate::exec_command::WRITE_STDIN_TOOL_NAME;
use crate::exec_command::WriteStdinParams;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ExecStreamHandler;

#[async_trait]
impl ToolHandler for ExecStreamHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            tool_name,
            payload,
            ..
        } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "exec_stream handler received unsupported payload".to_string(),
                ));
            }
        };

        let content = match tool_name.as_str() {
            EXEC_COMMAND_TOOL_NAME => {
                let params: ExecCommandParams = serde_json::from_str(&arguments).map_err(|e| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse function arguments: {e:?}"
                    ))
                })?;
                session.handle_exec_command_tool(params).await?
            }
            WRITE_STDIN_TOOL_NAME => {
                let params: WriteStdinParams = serde_json::from_str(&arguments).map_err(|e| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse function arguments: {e:?}"
                    ))
                })?;
                session.handle_write_stdin_tool(params).await?
            }
            _ => {
                return Err(FunctionCallError::RespondToModel(format!(
                    "exec_stream handler does not support tool {tool_name}"
                )));
            }
        };

        Ok(ToolOutput::Function {
            content,
            success: Some(true),
        })
    }
}
codex-rs/core/src/tools/handlers/mcp.rs (new file, 67 lines)
@@ -0,0 +1,67 @@
use async_trait::async_trait;

use crate::function_tool::FunctionCallError;
use crate::mcp_tool_call::handle_mcp_tool_call;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct McpHandler;

#[async_trait]
impl ToolHandler for McpHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Mcp
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            sub_id,
            call_id,
            payload,
            ..
        } = invocation;

        let payload = match payload {
            ToolPayload::Mcp {
                server,
                tool,
                raw_arguments,
            } => (server, tool, raw_arguments),
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "mcp handler received unsupported payload".to_string(),
                ));
            }
        };

        let (server, tool, raw_arguments) = payload;
        let arguments_str = raw_arguments;

        let response = handle_mcp_tool_call(
            session.as_ref(),
            &sub_id,
            call_id.clone(),
            server,
            tool,
            arguments_str,
        )
        .await;

        match response {
            codex_protocol::models::ResponseInputItem::McpToolCallOutput { result, .. } => {
                Ok(ToolOutput::Mcp { result })
            }
            codex_protocol::models::ResponseInputItem::FunctionCallOutput { output, .. } => {
                let codex_protocol::models::FunctionCallOutputPayload { content, success } = output;
                Ok(ToolOutput::Function { content, success })
            }
            _ => Err(FunctionCallError::RespondToModel(
                "mcp handler received unexpected response variant".to_string(),
            )),
        }
    }
}
codex-rs/core/src/tools/handlers/mod.rs (new file, 21 lines)
@@ -0,0 +1,21 @@
pub mod apply_patch;
mod exec_stream;
mod mcp;
mod plan;
mod read_file;
mod shell;
mod test_sync;
mod unified_exec;
mod view_image;

pub use plan::PLAN_TOOL;

pub use apply_patch::ApplyPatchHandler;
pub use exec_stream::ExecStreamHandler;
pub use mcp::McpHandler;
pub use plan::PlanHandler;
pub use read_file::ReadFileHandler;
pub use shell::ShellHandler;
pub use test_sync::TestSyncHandler;
pub use unified_exec::UnifiedExecHandler;
pub use view_image::ViewImageHandler;
@@ -1,23 +1,23 @@
use std::collections::BTreeMap;
use std::sync::LazyLock;

use crate::client_common::tools::ResponsesApiTool;
use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::function_tool::FunctionCallError;
use crate::openai_tools::JsonSchema;
use crate::openai_tools::OpenAiTool;
use crate::openai_tools::ResponsesApiTool;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use async_trait::async_trait;
use codex_protocol::plan_tool::UpdatePlanArgs;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use std::collections::BTreeMap;
use std::sync::LazyLock;

// Use the canonical plan tool types from the protocol crate to ensure
// type-identity matches events transported via `codex_protocol`.
pub use codex_protocol::plan_tool::PlanItemArg;
pub use codex_protocol::plan_tool::StepStatus;
pub use codex_protocol::plan_tool::UpdatePlanArgs;
pub struct PlanHandler;

// Types for the TODO tool arguments matching codex-vscode/todo-mcp/src/main.rs

pub(crate) static PLAN_TOOL: LazyLock<OpenAiTool> = LazyLock::new(|| {
pub static PLAN_TOOL: LazyLock<ToolSpec> = LazyLock::new(|| {
    let mut plan_item_props = BTreeMap::new();
    plan_item_props.insert("step".to_string(), JsonSchema::String { description: None });
    plan_item_props.insert(
@@ -43,7 +43,7 @@ pub(crate) static PLAN_TOOL: LazyLock<OpenAiTool> = LazyLock::new(|| {
    );
    properties.insert("plan".to_string(), plan_items_schema);

    OpenAiTool::Function(ResponsesApiTool {
    ToolSpec::Function(ResponsesApiTool {
        name: "update_plan".to_string(),
        description: r#"Updates the task plan.
Provide an optional explanation and a list of plan items, each with a step and status.
@@ -59,6 +59,40 @@ At most one step can be in_progress at a time.
    })
});

#[async_trait]
impl ToolHandler for PlanHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            sub_id,
            call_id,
            payload,
            ..
        } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "update_plan handler received unsupported payload".to_string(),
                ));
            }
        };

        let content =
            handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;

        Ok(ToolOutput::Function {
            content,
            success: Some(true),
        })
    }
}

/// This function doesn't do anything useful. However, it gives the model a structured way to record its plan that clients can read and render.
/// So it's the _inputs_ to this function that are useful to clients, not the outputs and neither are actually useful for the model other
/// than forcing it to come up and document a plan (TBD how that affects performance).
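A plausible `update_plan` arguments payload matching the schema above; the step strings are illustrative, and per the description at most one step may be in_progress:

    let args = serde_json::json!({
        "explanation": "Start with the parser before wiring the CLI.",
        "plan": [
            { "step": "Write the parser", "status": "completed" },
            { "step": "Wire up the CLI", "status": "in_progress" },
            { "step": "Add tests", "status": "pending" }
        ]
    });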
codex-rs/core/src/tools/handlers/read_file.rs (new file, 252 lines)
@@ -0,0 +1,252 @@
use std::path::Path;
use std::path::PathBuf;

use async_trait::async_trait;
use codex_utils_string::take_bytes_at_char_boundary;
use serde::Deserialize;
use tokio::fs::File;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ReadFileHandler;

const MAX_LINE_LENGTH: usize = 500;

fn default_offset() -> usize {
    1
}

fn default_limit() -> usize {
    2000
}

#[derive(Deserialize)]
struct ReadFileArgs {
    file_path: String,
    #[serde(default = "default_offset")]
    offset: usize,
    #[serde(default = "default_limit")]
    limit: usize,
}

#[async_trait]
impl ToolHandler for ReadFileHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "read_file handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: ReadFileArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        let ReadFileArgs {
            file_path,
            offset,
            limit,
        } = args;

        if offset == 0 {
            return Err(FunctionCallError::RespondToModel(
                "offset must be a 1-indexed line number".to_string(),
            ));
        }

        if limit == 0 {
            return Err(FunctionCallError::RespondToModel(
                "limit must be greater than zero".to_string(),
            ));
        }

        let path = PathBuf::from(&file_path);
        if !path.is_absolute() {
            return Err(FunctionCallError::RespondToModel(
                "file_path must be an absolute path".to_string(),
            ));
        }

        let collected = read_file_slice(&path, offset, limit).await?;
        Ok(ToolOutput::Function {
            content: collected.join("\n"),
            success: Some(true),
        })
    }
}

async fn read_file_slice(
    path: &Path,
    offset: usize,
    limit: usize,
) -> Result<Vec<String>, FunctionCallError> {
    let file = File::open(path)
        .await
        .map_err(|err| FunctionCallError::RespondToModel(format!("failed to read file: {err}")))?;

    let mut reader = BufReader::new(file);
    let mut collected = Vec::new();
    let mut seen = 0usize;
    let mut buffer = Vec::new();

    loop {
        buffer.clear();
        let bytes_read = reader.read_until(b'\n', &mut buffer).await.map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
        })?;

        if bytes_read == 0 {
            break;
        }

        if buffer.last() == Some(&b'\n') {
            buffer.pop();
            if buffer.last() == Some(&b'\r') {
                buffer.pop();
            }
        }

        seen += 1;

        if seen < offset {
            continue;
        }

        if collected.len() == limit {
            break;
        }

        let formatted = format_line(&buffer);
        collected.push(format!("L{seen}: {formatted}"));

        if collected.len() == limit {
            break;
        }
    }

    if seen < offset {
        return Err(FunctionCallError::RespondToModel(
            "offset exceeds file length".to_string(),
        ));
    }

    Ok(collected)
}

fn format_line(bytes: &[u8]) -> String {
    let decoded = String::from_utf8_lossy(bytes);
    if decoded.len() > MAX_LINE_LENGTH {
        take_bytes_at_char_boundary(&decoded, MAX_LINE_LENGTH).to_string()
    } else {
        decoded.into_owned()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[tokio::test]
    async fn reads_requested_range() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        writeln!(temp, "alpha").unwrap();
        writeln!(temp, "beta").unwrap();
        writeln!(temp, "gamma").unwrap();

        let lines = read_file_slice(temp.path(), 2, 2)
            .await
            .expect("read slice");
        assert_eq!(lines, vec!["L2: beta".to_string(), "L3: gamma".to_string()]);
    }

    #[tokio::test]
    async fn errors_when_offset_exceeds_length() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        writeln!(temp, "only").unwrap();

        let err = read_file_slice(temp.path(), 3, 1)
            .await
            .expect_err("offset exceeds length");
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("offset exceeds file length".to_string())
        );
    }

    #[tokio::test]
    async fn reads_non_utf8_lines() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        temp.as_file_mut().write_all(b"\xff\xfe\nplain\n").unwrap();

        let lines = read_file_slice(temp.path(), 1, 2)
            .await
            .expect("read slice");
        let expected_first = format!("L1: {}{}", '\u{FFFD}', '\u{FFFD}');
        assert_eq!(lines, vec![expected_first, "L2: plain".to_string()]);
    }

    #[tokio::test]
    async fn trims_crlf_endings() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        write!(temp, "one\r\ntwo\r\n").unwrap();

        let lines = read_file_slice(temp.path(), 1, 2)
            .await
            .expect("read slice");
        assert_eq!(lines, vec!["L1: one".to_string(), "L2: two".to_string()]);
    }

    #[tokio::test]
    async fn respects_limit_even_with_more_lines() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        writeln!(temp, "first").unwrap();
        writeln!(temp, "second").unwrap();
        writeln!(temp, "third").unwrap();

        let lines = read_file_slice(temp.path(), 1, 2)
            .await
            .expect("read slice");
        assert_eq!(
            lines,
            vec!["L1: first".to_string(), "L2: second".to_string()]
        );
    }

    #[tokio::test]
    async fn truncates_lines_longer_than_max_length() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        let long_line = "x".repeat(MAX_LINE_LENGTH + 50);
        writeln!(temp, "{long_line}").unwrap();

        let lines = read_file_slice(temp.path(), 1, 1)
            .await
            .expect("read slice");
        let expected = "x".repeat(MAX_LINE_LENGTH);
        assert_eq!(lines, vec![format!("L1: {expected}")]);
    }
}
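A sample `read_file` arguments payload, matching `ReadFileArgs` above; the path is illustrative:

    let args = serde_json::json!({
        "file_path": "/tmp/example.rs", // must be absolute
        "offset": 120,                  // 1-indexed first line to return
        "limit": 40                     // defaults to 2000 when omitted
    });
    // Output lines come back prefixed "L<n>: ...", each truncated to
    // MAX_LINE_LENGTH (500) bytes at a char boundary.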
codex-rs/core/src/tools/handlers/shell.rs (new file, 101 lines)
@@ -0,0 +1,101 @@
use async_trait::async_trait;
use codex_protocol::models::ShellToolCallParams;
use std::sync::Arc;

use crate::codex::TurnContext;
use crate::exec::ExecParams;
use crate::exec_env::create_env;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::handle_container_exec_with_params;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ShellHandler;

impl ShellHandler {
    fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams {
        ExecParams {
            command: params.command,
            cwd: turn_context.resolve_path(params.workdir.clone()),
            timeout_ms: params.timeout_ms,
            env: create_env(&turn_context.shell_environment_policy),
            with_escalated_permissions: params.with_escalated_permissions,
            justification: params.justification,
        }
    }
}

#[async_trait]
impl ToolHandler for ShellHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            payload,
            ToolPayload::Function { .. } | ToolPayload::LocalShell { .. }
        )
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            tracker,
            sub_id,
            call_id,
            tool_name,
            payload,
        } = invocation;

        match payload {
            ToolPayload::Function { arguments } => {
                let params: ShellToolCallParams =
                    serde_json::from_str(&arguments).map_err(|e| {
                        FunctionCallError::RespondToModel(format!(
                            "failed to parse function arguments: {e:?}"
                        ))
                    })?;
                let exec_params = Self::to_exec_params(params, turn.as_ref());
                let content = handle_container_exec_with_params(
                    tool_name.as_str(),
                    exec_params,
                    Arc::clone(&session),
                    Arc::clone(&turn),
                    Arc::clone(&tracker),
                    sub_id.clone(),
                    call_id.clone(),
                )
                .await?;
                Ok(ToolOutput::Function {
                    content,
                    success: Some(true),
                })
            }
            ToolPayload::LocalShell { params } => {
                let exec_params = Self::to_exec_params(params, turn.as_ref());
                let content = handle_container_exec_with_params(
                    tool_name.as_str(),
                    exec_params,
                    Arc::clone(&session),
                    Arc::clone(&turn),
                    Arc::clone(&tracker),
                    sub_id.clone(),
                    call_id.clone(),
                )
                .await?;
                Ok(ToolOutput::Function {
                    content,
                    success: Some(true),
                })
            }
            _ => Err(FunctionCallError::RespondToModel(format!(
                "unsupported payload for shell handler: {tool_name}"
            ))),
        }
    }
}
codex-rs/core/src/tools/handlers/test_sync.rs (new file, 158 lines)
@@ -0,0 +1,158 @@
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;

use async_trait::async_trait;
use serde::Deserialize;
use tokio::sync::Barrier;
use tokio::time::sleep;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct TestSyncHandler;

const DEFAULT_TIMEOUT_MS: u64 = 1_000;

static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();

struct BarrierState {
    barrier: Arc<Barrier>,
    participants: usize,
}

#[derive(Debug, Deserialize)]
struct BarrierArgs {
    id: String,
    participants: usize,
    #[serde(default = "default_timeout_ms")]
    timeout_ms: u64,
}

#[derive(Debug, Deserialize)]
struct TestSyncArgs {
    #[serde(default)]
    sleep_before_ms: Option<u64>,
    #[serde(default)]
    sleep_after_ms: Option<u64>,
    #[serde(default)]
    barrier: Option<BarrierArgs>,
}

fn default_timeout_ms() -> u64 {
    DEFAULT_TIMEOUT_MS
}

fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
    BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
}

#[async_trait]
impl ToolHandler for TestSyncHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "test_sync_tool handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        if let Some(delay) = args.sleep_before_ms
            && delay > 0
        {
            sleep(Duration::from_millis(delay)).await;
        }

        if let Some(barrier) = args.barrier {
            wait_on_barrier(barrier).await?;
        }

        if let Some(delay) = args.sleep_after_ms
            && delay > 0
        {
            sleep(Duration::from_millis(delay)).await;
        }

        Ok(ToolOutput::Function {
            content: "ok".to_string(),
            success: Some(true),
        })
    }
}

async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
    if args.participants == 0 {
        return Err(FunctionCallError::RespondToModel(
            "barrier participants must be greater than zero".to_string(),
        ));
    }

    if args.timeout_ms == 0 {
        return Err(FunctionCallError::RespondToModel(
            "barrier timeout must be greater than zero".to_string(),
        ));
    }

    let barrier_id = args.id.clone();
    let barrier = {
        let mut map = barrier_map().lock().await;
        match map.entry(barrier_id.clone()) {
            Entry::Occupied(entry) => {
                let state = entry.get();
                if state.participants != args.participants {
                    let existing = state.participants;
                    return Err(FunctionCallError::RespondToModel(format!(
                        "barrier {barrier_id} already registered with {existing} participants"
                    )));
                }
                state.barrier.clone()
            }
            Entry::Vacant(entry) => {
                let barrier = Arc::new(Barrier::new(args.participants));
                entry.insert(BarrierState {
                    barrier: barrier.clone(),
                    participants: args.participants,
                });
                barrier
            }
        }
    };

    let timeout = Duration::from_millis(args.timeout_ms);
    let wait_result = tokio::time::timeout(timeout, barrier.wait())
        .await
        .map_err(|_| {
            FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
        })?;

    if wait_result.is_leader() {
        let mut map = barrier_map().lock().await;
        if let Some(state) = map.get(&barrier_id)
            && Arc::ptr_eq(&state.barrier, &barrier)
        {
            map.remove(&barrier_id);
        }
    }

    Ok(())
}
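A sample `test_sync_tool` payload matching `TestSyncArgs` above; two concurrent calls with the same barrier id rendezvous before proceeding (id and delays illustrative):

    let args = serde_json::json!({
        "sleep_before_ms": 10,
        "barrier": { "id": "turn-1", "participants": 2, "timeout_ms": 1000 },
        "sleep_after_ms": 0
    });
    // The leader that releases the barrier also removes it from the global
    // map, so the same id can be reused by a later test.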
codex-rs/core/src/tools/handlers/unified_exec.rs (new file, 109 lines)
@@ -0,0 +1,109 @@
use async_trait::async_trait;
use serde::Deserialize;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::unified_exec::UnifiedExecRequest;

pub struct UnifiedExecHandler;

#[derive(Deserialize)]
struct UnifiedExecArgs {
    input: Vec<String>,
    #[serde(default)]
    session_id: Option<String>,
    #[serde(default)]
    timeout_ms: Option<u64>,
}

#[async_trait]
impl ToolHandler for UnifiedExecHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::UnifiedExec
    }

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            payload,
            ToolPayload::UnifiedExec { .. } | ToolPayload::Function { .. }
        )
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session, payload, ..
        } = invocation;

        let args = match payload {
            ToolPayload::UnifiedExec { arguments } | ToolPayload::Function { arguments } => {
                serde_json::from_str::<UnifiedExecArgs>(&arguments).map_err(|err| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse function arguments: {err:?}"
                    ))
                })?
            }
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "unified_exec handler received unsupported payload".to_string(),
                ));
            }
        };

        let UnifiedExecArgs {
            input,
            session_id,
            timeout_ms,
        } = args;

        let parsed_session_id = if let Some(session_id) = session_id {
            match session_id.parse::<i32>() {
                Ok(parsed) => Some(parsed),
                Err(output) => {
                    return Err(FunctionCallError::RespondToModel(format!(
                        "invalid session_id: {session_id} due to error {output:?}"
                    )));
                }
            }
        } else {
            None
        };

        let request = UnifiedExecRequest {
            session_id: parsed_session_id,
            input_chunks: &input,
            timeout_ms,
        };

        let value = session
            .run_unified_exec_request(request)
            .await
            .map_err(|err| {
                FunctionCallError::RespondToModel(format!("unified exec failed: {err:?}"))
            })?;

        #[derive(serde::Serialize)]
        struct SerializedUnifiedExecResult {
            session_id: Option<String>,
            output: String,
        }

        let content = serde_json::to_string(&SerializedUnifiedExecResult {
            session_id: value.session_id.map(|id| id.to_string()),
            output: value.output,
        })
        .map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to serialize unified exec output: {err:?}"
            ))
        })?;

        Ok(ToolOutput::Function {
            content,
            success: Some(true),
        })
    }
}
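A sketch of two `unified_exec` payloads matching `UnifiedExecArgs` above; the command and ids are illustrative:

    // First call: no session_id, so a new interactive session is created.
    let start = serde_json::json!({ "input": ["python3"], "timeout_ms": 5000 });
    // Follow-up: write to the same session, reusing the session_id echoed
    // back in the serialized result of the first call.
    let follow_up = serde_json::json!({ "input": ["print(1 + 1)\n"], "session_id": "0" });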
93
codex-rs/core/src/tools/handlers/view_image.rs
Normal file
93
codex-rs/core/src/tools/handlers/view_image.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use tokio::fs;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::protocol::ViewImageToolCallEvent;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
|
||||
pub struct ViewImageHandler;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ViewImageArgs {
|
||||
path: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for ViewImageHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
sub_id,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
ToolPayload::Function { arguments } => arguments,
|
||||
_ => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"view_image handler received unsupported payload".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let args: ViewImageArgs = serde_json::from_str(&arguments).map_err(|e| {
|
||||
FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e:?}"))
|
||||
})?;
|
||||
|
||||
let abs_path = turn.resolve_path(Some(args.path));
|
||||
|
||||
let metadata = fs::metadata(&abs_path).await.map_err(|error| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"unable to locate image at `{}`: {error}",
|
||||
abs_path.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
if !metadata.is_file() {
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"image path `{}` is not a file",
|
||||
abs_path.display()
|
||||
)));
|
||||
}
|
||||
let event_path = abs_path.clone();
|
||||
|
||||
session
|
||||
.inject_input(vec![InputItem::LocalImage { path: abs_path }])
|
||||
.await
|
||||
.map_err(|_| {
|
||||
FunctionCallError::RespondToModel(
|
||||
"unable to attach image (no active task)".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
session
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::ViewImageToolCall(ViewImageToolCallEvent {
|
||||
call_id,
|
||||
path: event_path,
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content: "attached local image path".to_string(),
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
}
|
codex-rs/core/src/tools/mod.rs (new file, 387 lines)
@@ -0,0 +1,387 @@
pub mod context;
pub(crate) mod handlers;
pub mod parallel;
pub mod registry;
pub mod router;
pub mod spec;

use crate::apply_patch;
use crate::apply_patch::ApplyPatchExec;
use crate::apply_patch::InternalApplyPatchInvocation;
use crate::apply_patch::convert_apply_patch_to_protocol;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::error::SandboxErr;
use crate::exec::ExecParams;
use crate::exec::ExecToolCallOutput;
use crate::exec::StdoutStream;
use crate::executor::ExecutionMode;
use crate::executor::errors::ExecError;
use crate::executor::linkers::PreparedExec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ApplyPatchCommandContext;
use crate::tools::context::ExecCommandContext;
use crate::tools::context::SharedTurnDiffTracker;
use codex_apply_patch::MaybeApplyPatchVerified;
use codex_apply_patch::maybe_parse_apply_patch_verified;
use codex_protocol::protocol::AskForApproval;
use codex_utils_string::take_bytes_at_char_boundary;
use codex_utils_string::take_last_bytes_at_char_boundary;
pub use router::ToolRouter;
use serde::Serialize;
use std::sync::Arc;
use tracing::trace;

// Model-formatting limits: clients get full streams; only content sent to the model is truncated.
pub(crate) const MODEL_FORMAT_MAX_BYTES: usize = 10 * 1024; // 10 KiB
pub(crate) const MODEL_FORMAT_MAX_LINES: usize = 256; // lines
pub(crate) const MODEL_FORMAT_HEAD_LINES: usize = MODEL_FORMAT_MAX_LINES / 2;
pub(crate) const MODEL_FORMAT_TAIL_LINES: usize = MODEL_FORMAT_MAX_LINES - MODEL_FORMAT_HEAD_LINES; // 128
pub(crate) const MODEL_FORMAT_HEAD_BYTES: usize = MODEL_FORMAT_MAX_BYTES / 2;

// Telemetry preview limits: keep log events smaller than model budgets.
pub(crate) const TELEMETRY_PREVIEW_MAX_BYTES: usize = 2 * 1024; // 2 KiB
pub(crate) const TELEMETRY_PREVIEW_MAX_LINES: usize = 64; // lines
pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
    "[... telemetry preview truncated ...]";

// TODO(jif) break this down
pub(crate) async fn handle_container_exec_with_params(
    tool_name: &str,
    params: ExecParams,
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    turn_diff_tracker: SharedTurnDiffTracker,
    sub_id: String,
    call_id: String,
) -> Result<String, FunctionCallError> {
    let otel_event_manager = turn_context.client.get_otel_event_manager();

    if params.with_escalated_permissions.unwrap_or(false)
        && !matches!(turn_context.approval_policy, AskForApproval::OnRequest)
    {
        return Err(FunctionCallError::RespondToModel(format!(
            "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}",
            policy = turn_context.approval_policy
        )));
    }

    // check if this was a patch, and apply it if so
    let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) {
        MaybeApplyPatchVerified::Body(changes) => {
            match apply_patch::apply_patch(
                sess.as_ref(),
                turn_context.as_ref(),
                &sub_id,
                &call_id,
                changes,
            )
            .await
            {
                InternalApplyPatchInvocation::Output(item) => return item,
                InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
                    Some(apply_patch_exec)
                }
            }
        }
        MaybeApplyPatchVerified::CorrectnessError(parse_error) => {
            // It looks like an invocation of `apply_patch`, but we
            // could not resolve it into a patch that would apply
            // cleanly. Return to model for resample.
            return Err(FunctionCallError::RespondToModel(format!(
                "apply_patch verification failed: {parse_error}"
            )));
        }
        MaybeApplyPatchVerified::ShellParseError(error) => {
            trace!("Failed to parse shell command, {error:?}");
            None
        }
        MaybeApplyPatchVerified::NotApplyPatch => None,
    };

    let command_for_display = if let Some(exec) = apply_patch_exec.as_ref() {
        vec!["apply_patch".to_string(), exec.action.patch.clone()]
    } else {
        params.command.clone()
    };

    let exec_command_context = ExecCommandContext {
        sub_id: sub_id.clone(),
        call_id: call_id.clone(),
        command_for_display: command_for_display.clone(),
        cwd: params.cwd.clone(),
        apply_patch: apply_patch_exec.as_ref().map(
            |ApplyPatchExec {
                 action,
                 user_explicitly_approved_this_action,
             }| ApplyPatchCommandContext {
                user_explicitly_approved_this_action: *user_explicitly_approved_this_action,
                changes: convert_apply_patch_to_protocol(action),
            },
        ),
        tool_name: tool_name.to_string(),
        otel_event_manager,
    };

    let mode = match apply_patch_exec {
        Some(exec) => ExecutionMode::ApplyPatch(exec),
        None => ExecutionMode::Shell,
    };

    sess.services.executor.update_environment(
        turn_context.sandbox_policy.clone(),
        turn_context.cwd.clone(),
    );

    let prepared_exec = PreparedExec::new(
        exec_command_context,
        params,
        command_for_display,
        mode,
        Some(StdoutStream {
            sub_id: sub_id.clone(),
            call_id: call_id.clone(),
            tx_event: sess.get_tx_event(),
        }),
        turn_context.shell_environment_policy.use_profile,
    );

    let output_result = sess
        .run_exec_with_events(
            turn_diff_tracker.clone(),
            prepared_exec,
            turn_context.approval_policy,
        )
        .await;

    // always make sure to truncate the output if its length isn't controlled.
    match output_result {
        Ok(output) => {
            let ExecToolCallOutput { exit_code, .. } = &output;
            let content = format_exec_output_apply_patch(&output);
            if *exit_code == 0 {
                Ok(content)
            } else {
                Err(FunctionCallError::RespondToModel(content))
            }
        }
        Err(ExecError::Function(err)) => Err(truncate_function_error(err)),
        Err(ExecError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) => Err(
            FunctionCallError::RespondToModel(format_exec_output_apply_patch(&output)),
        ),
        Err(ExecError::Codex(err)) => {
            let message = format!("execution error: {err:?}");
            Err(FunctionCallError::RespondToModel(format_exec_output(
                &message,
            )))
        }
    }
}

pub fn format_exec_output_apply_patch(exec_output: &ExecToolCallOutput) -> String {
    let ExecToolCallOutput {
        exit_code,
        duration,
        ..
    } = exec_output;

    #[derive(Serialize)]
    struct ExecMetadata {
        exit_code: i32,
        duration_seconds: f32,
    }

    #[derive(Serialize)]
    struct ExecOutput<'a> {
        output: &'a str,
        metadata: ExecMetadata,
    }

    // round to 1 decimal place
    let duration_seconds = ((duration.as_secs_f32()) * 10.0).round() / 10.0;

    let formatted_output = format_exec_output_str(exec_output);

    let payload = ExecOutput {
        output: &formatted_output,
        metadata: ExecMetadata {
            exit_code: *exit_code,
            duration_seconds,
        },
    };

    #[expect(clippy::expect_used)]
    serde_json::to_string(&payload).expect("serialize ExecOutput")
}

pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
    let ExecToolCallOutput {
        aggregated_output, ..
    } = exec_output;

    let content = aggregated_output.text.as_str();

    if exec_output.timed_out {
        let prefixed = format!(
            "command timed out after {} milliseconds\n{content}",
            exec_output.duration.as_millis()
        );
        return format_exec_output(&prefixed);
    }

    format_exec_output(content)
}

fn truncate_function_error(err: FunctionCallError) -> FunctionCallError {
    match err {
        FunctionCallError::RespondToModel(msg) => {
            FunctionCallError::RespondToModel(format_exec_output(&msg))
        }
        FunctionCallError::Fatal(msg) => FunctionCallError::Fatal(format_exec_output(&msg)),
        other => other,
    }
}

fn format_exec_output(content: &str) -> String {
    // Head+tail truncation for the model: show the beginning and end with an elision.
    // Clients still receive full streams; only this formatted summary is capped.
    let total_lines = content.lines().count();
    if content.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES {
        return content.to_string();
    }
    let output = truncate_formatted_exec_output(content, total_lines);
    format!("Total output lines: {total_lines}\n\n{output}")
}

fn truncate_formatted_exec_output(content: &str, total_lines: usize) -> String {
    let segments: Vec<&str> = content.split_inclusive('\n').collect();
    let head_take = MODEL_FORMAT_HEAD_LINES.min(segments.len());
    let tail_take = MODEL_FORMAT_TAIL_LINES.min(segments.len().saturating_sub(head_take));
    let omitted = segments.len().saturating_sub(head_take + tail_take);

    let head_slice_end: usize = segments
        .iter()
        .take(head_take)
        .map(|segment| segment.len())
        .sum();
    let tail_slice_start: usize = if tail_take == 0 {
        content.len()
    } else {
        content.len()
            - segments
                .iter()
                .rev()
                .take(tail_take)
                .map(|segment| segment.len())
                .sum::<usize>()
    };
    let marker = format!("\n[... omitted {omitted} of {total_lines} lines ...]\n\n");

    // Byte budgets for head/tail around the marker
    let mut head_budget = MODEL_FORMAT_HEAD_BYTES.min(MODEL_FORMAT_MAX_BYTES);
    let tail_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(head_budget + marker.len());
    if tail_budget == 0 && marker.len() >= MODEL_FORMAT_MAX_BYTES {
        // Degenerate case: marker alone exceeds budget; return a clipped marker
        return take_bytes_at_char_boundary(&marker, MODEL_FORMAT_MAX_BYTES).to_string();
    }
    if tail_budget == 0 {
        // Make room for the marker by shrinking head
        head_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(marker.len());
    }

    let head_slice = &content[..head_slice_end];
    let head_part = take_bytes_at_char_boundary(head_slice, head_budget);
    let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(content.len()));

    result.push_str(head_part);
    result.push_str(&marker);

    let remaining = MODEL_FORMAT_MAX_BYTES.saturating_sub(result.len());
    if remaining == 0 {
        return result;
    }

    let tail_slice = &content[tail_slice_start..];
    let tail_part = take_last_bytes_at_char_boundary(tail_slice, remaining);
    result.push_str(tail_part);

    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use regex_lite::Regex;

    fn assert_truncated_message_matches(message: &str, line: &str, total_lines: usize) {
        let pattern = truncated_message_pattern(line, total_lines);
        let regex = Regex::new(&pattern).unwrap_or_else(|err| {
            panic!("failed to compile regex {pattern}: {err}");
        });
        let captures = regex
            .captures(message)
            .unwrap_or_else(|| panic!("message failed to match pattern {pattern}: {message}"));
        let body = captures
            .name("body")
            .expect("missing body capture")
            .as_str();
        assert!(
            body.len() <= MODEL_FORMAT_MAX_BYTES,
            "body exceeds byte limit: {} bytes",
            body.len()
        );
    }

    fn truncated_message_pattern(line: &str, total_lines: usize) -> String {
        let head_take = MODEL_FORMAT_HEAD_LINES.min(total_lines);
        let tail_take = MODEL_FORMAT_TAIL_LINES.min(total_lines.saturating_sub(head_take));
        let omitted = total_lines.saturating_sub(head_take + tail_take);
        let escaped_line = regex_lite::escape(line);
        format!(
            r"(?s)^Total output lines: {total_lines}\n\n(?P<body>{escaped_line}.*\n\[\.{{3}} omitted {omitted} of {total_lines} lines \.{{3}}]\n\n.*)$",
        )
    }

    #[test]
    fn truncate_formatted_exec_output_truncates_large_error() {
        let line = "very long execution error line that should trigger truncation\n";
        let large_error = line.repeat(2_500); // way beyond both byte and line limits

        let truncated = format_exec_output(&large_error);

        let total_lines = large_error.lines().count();
        assert_truncated_message_matches(&truncated, line, total_lines);
        assert_ne!(truncated, large_error);
    }

    #[test]
    fn truncate_function_error_trims_respond_to_model() {
        let line = "respond-to-model error that should be truncated\n";
        let huge = line.repeat(3_000);
        let total_lines = huge.lines().count();

        let err = truncate_function_error(FunctionCallError::RespondToModel(huge));
        match err {
            FunctionCallError::RespondToModel(message) => {
                assert_truncated_message_matches(&message, line, total_lines);
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn truncate_function_error_trims_fatal() {
        let line = "fatal error output that should be truncated\n";
        let huge = line.repeat(3_000);
        let total_lines = huge.lines().count();

        let err = truncate_function_error(FunctionCallError::Fatal(huge));
        match err {
            FunctionCallError::Fatal(message) => {
                assert_truncated_message_matches(&message, line, total_lines);
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }
}
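
For a quick sanity check of the budgets above: MODEL_FORMAT_MAX_LINES = 256 splits into 128 head and 128 tail lines, so a 2,500-line output keeps 256 lines around the elision marker. A standalone sketch of the same head plus tail split, with the line constants copied from the patch and the byte budget omitted for brevity:

// Standalone illustration of the head+tail split used by format_exec_output;
// line constants copied from the patch, byte budgeting intentionally omitted.
const MAX_LINES: usize = 256;
const HEAD_LINES: usize = MAX_LINES / 2; // 128
const TAIL_LINES: usize = MAX_LINES - HEAD_LINES; // 128

fn main() {
    let content: Vec<String> = (0..2_500).map(|i| format!("line {i}")).collect();
    let omitted = content.len() - HEAD_LINES - TAIL_LINES;
    let marker = format!("[... omitted {omitted} of {} lines ...]", content.len());
    let mut out: Vec<&str> = content[..HEAD_LINES].iter().map(String::as_str).collect();
    out.push(&marker);
    out.extend(content[content.len() - TAIL_LINES..].iter().map(String::as_str));
    assert_eq!(out.len(), MAX_LINES + 1); // 128 head + marker + 128 tail
    println!("{}", out.join("\n"));
}
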
codex-rs/core/src/tools/parallel.rs (new file, 137 lines)
@@ -0,0 +1,137 @@
use std::sync::Arc;

use tokio::task::JoinHandle;

use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_protocol::models::ResponseInputItem;

use crate::codex::ProcessedResponseItem;

struct PendingToolCall {
    index: usize,
    handle: JoinHandle<Result<ResponseInputItem, FunctionCallError>>,
}

pub(crate) struct ToolCallRuntime {
    router: Arc<ToolRouter>,
    session: Arc<Session>,
    turn_context: Arc<TurnContext>,
    tracker: SharedTurnDiffTracker,
    sub_id: String,
    pending_calls: Vec<PendingToolCall>,
}

impl ToolCallRuntime {
    pub(crate) fn new(
        router: Arc<ToolRouter>,
        session: Arc<Session>,
        turn_context: Arc<TurnContext>,
        tracker: SharedTurnDiffTracker,
        sub_id: String,
    ) -> Self {
        Self {
            router,
            session,
            turn_context,
            tracker,
            sub_id,
            pending_calls: Vec::new(),
        }
    }

    pub(crate) async fn handle_tool_call(
        &mut self,
        call: ToolCall,
        output_index: usize,
        output: &mut [ProcessedResponseItem],
    ) -> Result<(), CodexErr> {
        let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
        if supports_parallel {
            self.spawn_parallel(call, output_index);
        } else {
            self.resolve_pending(output).await?;
            let response = self.dispatch_serial(call).await?;
            let slot = output.get_mut(output_index).ok_or_else(|| {
                CodexErr::Fatal(format!("tool output index {output_index} out of bounds"))
            })?;
            slot.response = Some(response);
        }

        Ok(())
    }

    pub(crate) fn abort_all(&mut self) {
        while let Some(pending) = self.pending_calls.pop() {
            pending.handle.abort();
        }
    }

    pub(crate) async fn resolve_pending(
        &mut self,
        output: &mut [ProcessedResponseItem],
    ) -> Result<(), CodexErr> {
        while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() {
            match handle.await {
                Ok(Ok(response)) => {
                    if let Some(slot) = output.get_mut(index) {
                        slot.response = Some(response);
                    }
                }
                Ok(Err(FunctionCallError::Fatal(message))) => {
                    self.abort_all();
                    return Err(CodexErr::Fatal(message));
                }
                Ok(Err(other)) => {
                    self.abort_all();
                    return Err(CodexErr::Fatal(other.to_string()));
                }
                Err(join_err) => {
                    self.abort_all();
                    return Err(CodexErr::Fatal(format!(
                        "tool task failed to join: {join_err}"
                    )));
                }
            }
        }

        Ok(())
    }

    fn spawn_parallel(&mut self, call: ToolCall, index: usize) {
        let router = Arc::clone(&self.router);
        let session = Arc::clone(&self.session);
        let turn = Arc::clone(&self.turn_context);
        let tracker = Arc::clone(&self.tracker);
        let sub_id = self.sub_id.clone();
        let handle = tokio::spawn(async move {
            router
                .dispatch_tool_call(session, turn, tracker, sub_id, call)
                .await
        });
        self.pending_calls.push(PendingToolCall { index, handle });
    }

    async fn dispatch_serial(&self, call: ToolCall) -> Result<ResponseInputItem, CodexErr> {
        match self
            .router
            .dispatch_tool_call(
                Arc::clone(&self.session),
                Arc::clone(&self.turn_context),
                Arc::clone(&self.tracker),
                self.sub_id.clone(),
                call,
            )
            .await
        {
            Ok(response) => Ok(response),
            Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
            Err(other) => Err(CodexErr::Fatal(other.to_string())),
        }
    }
}
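
The ordering contract in ToolCallRuntime is that every pending parallel call is resolved before a serial call dispatches, and everything left is aborted on the first fatal error. The same drain-then-abort pattern can be shown standalone with tokio; this sketch uses none of the crate's types and is illustrative only:

use tokio::task::JoinHandle;

// Illustrative only: drain queued parallel tasks before a serial step runs,
// aborting whatever remains as soon as one task reports a fatal error.
async fn resolve_pending(
    pending: &mut Vec<JoinHandle<Result<String, String>>>,
) -> Result<Vec<String>, String> {
    let mut results = Vec::new();
    while let Some(handle) = pending.pop() {
        match handle.await {
            Ok(Ok(value)) => results.push(value),
            Ok(Err(fatal)) => {
                pending.iter().for_each(JoinHandle::abort);
                return Err(fatal);
            }
            Err(join_err) => {
                pending.iter().for_each(JoinHandle::abort);
                return Err(format!("join failed: {join_err}"));
            }
        }
    }
    Ok(results)
}

#[tokio::main]
async fn main() {
    let mut pending = vec![
        tokio::spawn(async { Ok::<_, String>("a".to_string()) }),
        tokio::spawn(async { Ok::<_, String>("b".to_string()) }),
    ];
    println!("{:?}", resolve_pending(&mut pending).await);
}
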
codex-rs/core/src/tools/registry.rs (new file, 220 lines)
@@ -0,0 +1,220 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use codex_protocol::models::ResponseInputItem;
use tracing::warn;

use crate::client_common::tools::ToolSpec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolKind {
    Function,
    UnifiedExec,
    Mcp,
}

#[async_trait]
pub trait ToolHandler: Send + Sync {
    fn kind(&self) -> ToolKind;

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            (self.kind(), payload),
            (ToolKind::Function, ToolPayload::Function { .. })
                | (ToolKind::UnifiedExec, ToolPayload::UnifiedExec { .. })
                | (ToolKind::Mcp, ToolPayload::Mcp { .. })
        )
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
}

pub struct ToolRegistry {
    handlers: HashMap<String, Arc<dyn ToolHandler>>,
}

impl ToolRegistry {
    pub fn new(handlers: HashMap<String, Arc<dyn ToolHandler>>) -> Self {
        Self { handlers }
    }

    pub fn handler(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
        self.handlers.get(name).map(Arc::clone)
    }

    // TODO(jif) for dynamic tools.
    // pub fn register(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
    //     let name = name.into();
    //     if self.handlers.insert(name.clone(), handler).is_some() {
    //         warn!("overwriting handler for tool {name}");
    //     }
    // }

    pub async fn dispatch(
        &self,
        invocation: ToolInvocation,
    ) -> Result<ResponseInputItem, FunctionCallError> {
        let tool_name = invocation.tool_name.clone();
        let call_id_owned = invocation.call_id.clone();
        let otel = invocation.turn.client.get_otel_event_manager();
        let payload_for_response = invocation.payload.clone();
        let log_payload = payload_for_response.log_payload();

        let handler = match self.handler(tool_name.as_ref()) {
            Some(handler) => handler,
            None => {
                let message =
                    unsupported_tool_call_message(&invocation.payload, tool_name.as_ref());
                otel.tool_result(
                    tool_name.as_ref(),
                    &call_id_owned,
                    log_payload.as_ref(),
                    Duration::ZERO,
                    false,
                    &message,
                );
                return Err(FunctionCallError::RespondToModel(message));
            }
        };

        if !handler.matches_kind(&invocation.payload) {
            let message = format!("tool {tool_name} invoked with incompatible payload");
            otel.tool_result(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                Duration::ZERO,
                false,
                &message,
            );
            return Err(FunctionCallError::Fatal(message));
        }

        let output_cell = tokio::sync::Mutex::new(None);

        let result = otel
            .log_tool_result(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                || {
                    let handler = handler.clone();
                    let output_cell = &output_cell;
                    let invocation = invocation;
                    async move {
                        match handler.handle(invocation).await {
                            Ok(output) => {
                                let preview = output.log_preview();
                                let success = output.success_for_logging();
                                let mut guard = output_cell.lock().await;
                                *guard = Some(output);
                                Ok((preview, success))
                            }
                            Err(err) => Err(err),
                        }
                    }
                },
            )
            .await;

        match result {
            Ok(_) => {
                let mut guard = output_cell.lock().await;
                let output = guard.take().ok_or_else(|| {
                    FunctionCallError::Fatal("tool produced no output".to_string())
                })?;
                Ok(output.into_response(&call_id_owned, &payload_for_response))
            }
            Err(err) => Err(err),
        }
    }
}

#[derive(Debug, Clone)]
pub struct ConfiguredToolSpec {
    pub spec: ToolSpec,
    pub supports_parallel_tool_calls: bool,
}

impl ConfiguredToolSpec {
    pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
        Self {
            spec,
            supports_parallel_tool_calls,
        }
    }
}

pub struct ToolRegistryBuilder {
    handlers: HashMap<String, Arc<dyn ToolHandler>>,
    specs: Vec<ConfiguredToolSpec>,
}

impl ToolRegistryBuilder {
    pub fn new() -> Self {
        Self {
            handlers: HashMap::new(),
            specs: Vec::new(),
        }
    }

    pub fn push_spec(&mut self, spec: ToolSpec) {
        self.push_spec_with_parallel_support(spec, false);
    }

    pub fn push_spec_with_parallel_support(
        &mut self,
        spec: ToolSpec,
        supports_parallel_tool_calls: bool,
    ) {
        self.specs
            .push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
    }

    pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
        let name = name.into();
        if self
            .handlers
            .insert(name.clone(), handler.clone())
            .is_some()
        {
            warn!("overwriting handler for tool {name}");
        }
    }

    // TODO(jif) for dynamic tools.
    // pub fn register_many<I>(&mut self, names: I, handler: Arc<dyn ToolHandler>)
    // where
    //     I: IntoIterator,
    //     I::Item: Into<String>,
    // {
    //     for name in names {
    //         let name = name.into();
    //         if self
    //             .handlers
    //             .insert(name.clone(), handler.clone())
    //             .is_some()
    //         {
    //             warn!("overwriting handler for tool {name}");
    //         }
    //     }
    // }

    pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
        let registry = ToolRegistry::new(self.handlers);
        (self.specs, registry)
    }
}

fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &str) -> String {
    match payload {
        ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
        _ => format!("unsupported call: {tool_name}"),
    }
}
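
register_handler above guards against silent double registration by warning when an insert displaces an existing handler. A dependency-free distillation of that pattern follows; MiniRegistry and Handler are hypothetical stand-ins, not the patch's API:

use std::collections::HashMap;
use std::sync::Arc;

// Hypothetical stand-in for the patch's ToolHandler; the real trait is async
// and carries kind/matches_kind, omitted here.
trait Handler: Send + Sync {}

struct MiniRegistry {
    handlers: HashMap<String, Arc<dyn Handler>>,
}

impl MiniRegistry {
    fn register(&mut self, name: impl Into<String>, handler: Arc<dyn Handler>) {
        let name = name.into();
        // Same guard as register_handler above; the patch uses tracing::warn!.
        if self.handlers.insert(name.clone(), handler).is_some() {
            eprintln!("overwriting handler for tool {name}");
        }
    }
}

fn main() {
    struct Echo;
    impl Handler for Echo {}
    let mut registry = MiniRegistry {
        handlers: HashMap::new(),
    };
    registry.register("echo", Arc::new(Echo));
    registry.register("echo", Arc::new(Echo)); // triggers the overwrite warning
}
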
codex-rs/core/src/tools/router.rs (new file, 190 lines)
@@ -0,0 +1,190 @@
use std::collections::HashMap;
use std::sync::Arc;

use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ConfiguredToolSpec;
use crate::tools::registry::ToolRegistry;
use crate::tools::spec::ToolsConfig;
use crate::tools::spec::build_specs;
use codex_protocol::models::LocalShellAction;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::models::ShellToolCallParams;

#[derive(Clone)]
pub struct ToolCall {
    pub tool_name: String,
    pub call_id: String,
    pub payload: ToolPayload,
}

pub struct ToolRouter {
    registry: ToolRegistry,
    specs: Vec<ConfiguredToolSpec>,
}

impl ToolRouter {
    pub fn from_config(
        config: &ToolsConfig,
        mcp_tools: Option<HashMap<String, mcp_types::Tool>>,
    ) -> Self {
        let builder = build_specs(config, mcp_tools);
        let (specs, registry) = builder.build();

        Self { registry, specs }
    }

    pub fn specs(&self) -> Vec<ToolSpec> {
        self.specs
            .iter()
            .map(|config| config.spec.clone())
            .collect()
    }

    pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
        self.specs
            .iter()
            .filter(|config| config.supports_parallel_tool_calls)
            .any(|config| config.spec.name() == tool_name)
    }

    pub fn build_tool_call(
        session: &Session,
        item: ResponseItem,
    ) -> Result<Option<ToolCall>, FunctionCallError> {
        match item {
            ResponseItem::FunctionCall {
                name,
                arguments,
                call_id,
                ..
            } => {
                if let Some((server, tool)) = session.parse_mcp_tool_name(&name) {
                    Ok(Some(ToolCall {
                        tool_name: name,
                        call_id,
                        payload: ToolPayload::Mcp {
                            server,
                            tool,
                            raw_arguments: arguments,
                        },
                    }))
                } else {
                    let payload = if name == "unified_exec" {
                        ToolPayload::UnifiedExec { arguments }
                    } else {
                        ToolPayload::Function { arguments }
                    };
                    Ok(Some(ToolCall {
                        tool_name: name,
                        call_id,
                        payload,
                    }))
                }
            }
            ResponseItem::CustomToolCall {
                name,
                input,
                call_id,
                ..
            } => Ok(Some(ToolCall {
                tool_name: name,
                call_id,
                payload: ToolPayload::Custom { input },
            })),
            ResponseItem::LocalShellCall {
                id,
                call_id,
                action,
                ..
            } => {
                let call_id = call_id
                    .or(id)
                    .ok_or(FunctionCallError::MissingLocalShellCallId)?;

                match action {
                    LocalShellAction::Exec(exec) => {
                        let params = ShellToolCallParams {
                            command: exec.command,
                            workdir: exec.working_directory,
                            timeout_ms: exec.timeout_ms,
                            with_escalated_permissions: None,
                            justification: None,
                        };
                        Ok(Some(ToolCall {
                            tool_name: "local_shell".to_string(),
                            call_id,
                            payload: ToolPayload::LocalShell { params },
                        }))
                    }
                }
            }
            _ => Ok(None),
        }
    }

    pub async fn dispatch_tool_call(
        &self,
        session: Arc<Session>,
        turn: Arc<TurnContext>,
        tracker: SharedTurnDiffTracker,
        sub_id: String,
        call: ToolCall,
    ) -> Result<ResponseInputItem, FunctionCallError> {
        let ToolCall {
            tool_name,
            call_id,
            payload,
        } = call;
        let payload_outputs_custom = matches!(payload, ToolPayload::Custom { .. });
        let failure_call_id = call_id.clone();

        let invocation = ToolInvocation {
            session,
            turn,
            tracker,
            sub_id,
            call_id,
            tool_name,
            payload,
        };

        match self.registry.dispatch(invocation).await {
            Ok(response) => Ok(response),
            Err(FunctionCallError::Fatal(message)) => Err(FunctionCallError::Fatal(message)),
            Err(err) => Ok(Self::failure_response(
                failure_call_id,
                payload_outputs_custom,
                err,
            )),
        }
    }

    fn failure_response(
        call_id: String,
        payload_outputs_custom: bool,
        err: FunctionCallError,
    ) -> ResponseInputItem {
        let message = err.to_string();
        if payload_outputs_custom {
            ResponseInputItem::CustomToolCallOutput {
                call_id,
                output: message,
            }
        } else {
            ResponseInputItem::FunctionCallOutput {
                call_id,
                output: codex_protocol::models::FunctionCallOutputPayload {
                    content: message,
                    success: Some(false),
                },
            }
        }
    }
}
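
The function-call routing above is purely name-based: a call the MCP parser does not claim becomes a Function payload unless it is literally named unified_exec. A distilled mirror of that match, with a hypothetical Payload enum standing in for ToolPayload:

// Hypothetical mirror of the two function-call arms of ToolPayload.
#[derive(Debug)]
enum Payload {
    UnifiedExec { arguments: String },
    Function { arguments: String },
}

fn payload_for(name: &str, arguments: String) -> Payload {
    // Anything the MCP parser does not claim becomes a Function payload
    // unless it is literally named unified_exec, as in build_tool_call.
    if name == "unified_exec" {
        Payload::UnifiedExec { arguments }
    } else {
        Payload::Function { arguments }
    }
}

fn main() {
    println!("{:?}", payload_for("unified_exec", "{}".to_string()));
    println!("{:?}", payload_for("shell", "{}".to_string()));
}
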
codex-rs/core/src/tools/spec.rs (new file, 1403 lines)
File diff suppressed because it is too large
codex-rs/core/src/windows_appcontainer.rs (new file, 467 lines)
@@ -0,0 +1,467 @@
#![cfg(windows)]

use std::collections::HashMap;
use std::io;
use std::path::Path;
use std::path::PathBuf;

use tokio::process::Child;
use tracing::trace;

use crate::protocol::SandboxPolicy;
use crate::spawn::StdioPolicy;

#[cfg(feature = "windows_appcontainer_command_ext")]
mod imp {
    use super::*;

    use std::ffi::OsStr;
    use std::ffi::c_void;
    use std::os::windows::ffi::OsStrExt;
    use std::os::windows::process::CommandExt;
    use std::ptr::null_mut;

    use tokio::process::Command;

    use crate::spawn::CODEX_SANDBOX_ENV_VAR;
    use crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;

    use windows::Win32::Foundation::ERROR_ALREADY_EXISTS;
    use windows::Win32::Foundation::ERROR_SUCCESS;
    use windows::Win32::Foundation::GetLastError;
    use windows::Win32::Foundation::HANDLE;
    use windows::Win32::Foundation::HLOCAL;
    use windows::Win32::Foundation::LocalFree;
    use windows::Win32::Foundation::WIN32_ERROR;
    use windows::Win32::Security::ACL;
    use windows::Win32::Security::Authorization::ConvertStringSidToSidW;
    use windows::Win32::Security::Authorization::EXPLICIT_ACCESS_W;
    use windows::Win32::Security::Authorization::GetNamedSecurityInfoW;
    use windows::Win32::Security::Authorization::SE_FILE_OBJECT;
    use windows::Win32::Security::Authorization::SET_ACCESS;
    use windows::Win32::Security::Authorization::SetEntriesInAclW;
    use windows::Win32::Security::Authorization::SetNamedSecurityInfoW;
    use windows::Win32::Security::Authorization::TRUSTEE_IS_SID;
    use windows::Win32::Security::Authorization::TRUSTEE_IS_UNKNOWN;
    use windows::Win32::Security::Authorization::TRUSTEE_W;
    use windows::Win32::Security::DACL_SECURITY_INFORMATION;
    use windows::Win32::Security::FreeSid;
    use windows::Win32::Security::Isolation::CreateAppContainerProfile;
    use windows::Win32::Security::Isolation::DeriveAppContainerSidFromAppContainerName;
    use windows::Win32::Security::OBJECT_INHERIT_ACE;
    use windows::Win32::Security::PSECURITY_DESCRIPTOR;
    use windows::Win32::Security::PSID;
    use windows::Win32::Security::SECURITY_CAPABILITIES;
    use windows::Win32::Security::SID_AND_ATTRIBUTES;
    use windows::Win32::Security::SUB_CONTAINERS_AND_OBJECTS_INHERIT;
    use windows::Win32::Storage::FileSystem::FILE_GENERIC_EXECUTE;
    use windows::Win32::Storage::FileSystem::FILE_GENERIC_READ;
    use windows::Win32::Storage::FileSystem::FILE_GENERIC_WRITE;
    use windows::Win32::System::Memory::GetProcessHeap;
    use windows::Win32::System::Memory::HEAP_FLAGS;
    use windows::Win32::System::Memory::HEAP_ZERO_MEMORY;
    use windows::Win32::System::Memory::HeapAlloc;
    use windows::Win32::System::Memory::HeapFree;
    use windows::Win32::System::Threading::DeleteProcThreadAttributeList;
    use windows::Win32::System::Threading::EXTENDED_STARTUPINFO_PRESENT;
    use windows::Win32::System::Threading::InitializeProcThreadAttributeList;
    use windows::Win32::System::Threading::LPPROC_THREAD_ATTRIBUTE_LIST;
    use windows::Win32::System::Threading::PROC_THREAD_ATTRIBUTE_SECURITY_CAPABILITIES;
    use windows::Win32::System::Threading::UpdateProcThreadAttribute;
    use windows::core::PCWSTR;
    use windows::core::PWSTR;

    const WINDOWS_APPCONTAINER_PROFILE_NAME: &str = "codex_appcontainer";
    const WINDOWS_APPCONTAINER_PROFILE_DESC: &str = "Codex Windows AppContainer profile";
    const WINDOWS_APPCONTAINER_SANDBOX_VALUE: &str = "windows_appcontainer";
    const INTERNET_CLIENT_SID: &str = "S-1-15-3-1";
    const PRIVATE_NETWORK_CLIENT_SID: &str = "S-1-15-3-3";

    pub async fn spawn_command_under_windows_appcontainer(
        command: Vec<String>,
        command_cwd: PathBuf,
        sandbox_policy: &SandboxPolicy,
        sandbox_policy_cwd: &Path,
        stdio_policy: StdioPolicy,
        mut env: HashMap<String, String>,
    ) -> io::Result<Child> {
        trace!("windows appcontainer sandbox command = {:?}", command);

        let (program, rest) = command
            .split_first()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "command args are empty"))?;

        ensure_appcontainer_profile()?;
        let mut sid = derive_appcontainer_sid()?;
        let mut capability_sids = build_capabilities(sandbox_policy)?;
        let mut attribute_list = AttributeList::new(&mut sid, &mut capability_sids)?;

        configure_writable_roots(sandbox_policy, sandbox_policy_cwd, sid.sid())?;
        configure_writable_roots_for_command_cwd(&command_cwd, sid.sid())?;

        if !sandbox_policy.has_full_network_access() {
            env.insert(
                CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR.to_string(),
                "1".to_string(),
            );
        }
        env.insert(
            CODEX_SANDBOX_ENV_VAR.to_string(),
            WINDOWS_APPCONTAINER_SANDBOX_VALUE.to_string(),
        );

        let mut cmd = Command::new(program);
        cmd.args(rest);
        cmd.current_dir(command_cwd);
        cmd.env_clear();
        cmd.envs(env);
        apply_stdio_policy(&mut cmd, stdio_policy);
        cmd.kill_on_drop(true);

        unsafe {
            let std_cmd = cmd.as_std_mut();
            std_cmd.creation_flags(EXTENDED_STARTUPINFO_PRESENT.0);
            std_cmd.raw_attribute_list(attribute_list.as_mut_ptr().0);
        }

        let child = cmd.spawn();
        drop(attribute_list);
        child
    }

    fn apply_stdio_policy(cmd: &mut Command, policy: StdioPolicy) {
        match policy {
            StdioPolicy::RedirectForShellTool => {
                cmd.stdin(std::process::Stdio::null());
                cmd.stdout(std::process::Stdio::piped());
                cmd.stderr(std::process::Stdio::piped());
            }
            StdioPolicy::Inherit => {
                cmd.stdin(std::process::Stdio::inherit());
                cmd.stdout(std::process::Stdio::inherit());
                cmd.stderr(std::process::Stdio::inherit());
            }
        }
    }

    fn to_wide<S: AsRef<OsStr>>(s: S) -> Vec<u16> {
        s.as_ref().encode_wide().chain(std::iter::once(0)).collect()
    }

    fn ensure_appcontainer_profile() -> io::Result<()> {
        unsafe {
            let name = to_wide(WINDOWS_APPCONTAINER_PROFILE_NAME);
            let desc = to_wide(WINDOWS_APPCONTAINER_PROFILE_DESC);
            match CreateAppContainerProfile(
                PCWSTR(name.as_ptr()),
                PCWSTR(name.as_ptr()),
                PCWSTR(desc.as_ptr()),
                None,
            ) {
                Ok(profile_sid) => {
                    if !profile_sid.is_invalid() {
                        FreeSid(profile_sid);
                    }
                }
                Err(error) => {
                    let already_exists = WIN32_ERROR::from(ERROR_ALREADY_EXISTS);
                    if GetLastError() != already_exists {
                        return Err(io::Error::from_raw_os_error(error.code().0));
                    }
                }
            }
        }
        Ok(())
    }

    struct SidHandle {
        ptr: PSID,
    }

    impl SidHandle {
        fn sid(&self) -> PSID {
            self.ptr
        }
    }

    impl Drop for SidHandle {
        fn drop(&mut self) {
            unsafe {
                if !self.ptr.is_invalid() {
                    FreeSid(self.ptr);
                }
            }
        }
    }

    fn derive_appcontainer_sid() -> io::Result<SidHandle> {
        unsafe {
            let name = to_wide(WINDOWS_APPCONTAINER_PROFILE_NAME);
            let sid = DeriveAppContainerSidFromAppContainerName(PCWSTR(name.as_ptr()))
                .map_err(|e| io::Error::from_raw_os_error(e.code().0))?;
            Ok(SidHandle { ptr: sid })
        }
    }

    struct CapabilitySid {
        sid: PSID,
    }

    impl CapabilitySid {
        fn new_from_string(value: &str) -> io::Result<Self> {
            unsafe {
                let mut sid_ptr = PSID::default();
                let wide = to_wide(value);
                ConvertStringSidToSidW(PCWSTR(wide.as_ptr()), &mut sid_ptr)
                    .map_err(|e| io::Error::from_raw_os_error(e.code().0))?;
                Ok(Self { sid: sid_ptr })
            }
        }

        fn sid_and_attributes(&self) -> SID_AND_ATTRIBUTES {
            SID_AND_ATTRIBUTES {
                Sid: self.sid,
                Attributes: 0,
            }
        }
    }

    impl Drop for CapabilitySid {
        fn drop(&mut self) {
            unsafe {
                if !self.sid.is_invalid() {
                    let _ = LocalFree(HLOCAL(self.sid.0));
                }
            }
        }
    }

    fn build_capabilities(policy: &SandboxPolicy) -> io::Result<Vec<CapabilitySid>> {
        if policy.has_full_network_access() {
            Ok(vec![
                CapabilitySid::new_from_string(INTERNET_CLIENT_SID)?,
                CapabilitySid::new_from_string(PRIVATE_NETWORK_CLIENT_SID)?,
            ])
        } else {
            Ok(Vec::new())
        }
    }

    struct AttributeList<'a> {
        heap: HANDLE,
        buffer: *mut c_void,
        list: LPPROC_THREAD_ATTRIBUTE_LIST,
        sec_caps: SECURITY_CAPABILITIES,
        sid_and_attributes: Vec<SID_AND_ATTRIBUTES>,
        #[allow(dead_code)]
        sid: &'a mut SidHandle,
        #[allow(dead_code)]
        capabilities: &'a mut Vec<CapabilitySid>,
    }

    impl<'a> AttributeList<'a> {
        fn new(sid: &'a mut SidHandle, caps: &'a mut Vec<CapabilitySid>) -> io::Result<Self> {
            unsafe {
                let mut list_size = 0usize;
                let _ = InitializeProcThreadAttributeList(
                    LPPROC_THREAD_ATTRIBUTE_LIST::default(),
                    1,
                    0,
                    &mut list_size,
                );
                let heap =
                    GetProcessHeap().map_err(|e| io::Error::from_raw_os_error(e.code().0))?;
                let buffer = HeapAlloc(heap, HEAP_ZERO_MEMORY, list_size);
                if buffer.is_null() {
                    return Err(io::Error::last_os_error());
                }
                let list = LPPROC_THREAD_ATTRIBUTE_LIST(buffer);
                InitializeProcThreadAttributeList(list, 1, 0, &mut list_size)
                    .map_err(|e| io::Error::from_raw_os_error(e.code().0))?;

                let mut sid_and_attributes: Vec<SID_AND_ATTRIBUTES> =
                    caps.iter().map(CapabilitySid::sid_and_attributes).collect();

                let mut sec_caps = SECURITY_CAPABILITIES {
                    AppContainerSid: sid.sid(),
                    Capabilities: if sid_and_attributes.is_empty() {
                        null_mut()
                    } else {
                        sid_and_attributes.as_mut_ptr()
                    },
                    CapabilityCount: sid_and_attributes.len() as u32,
                    Reserved: 0,
                };

                UpdateProcThreadAttribute(
                    list,
                    0,
                    PROC_THREAD_ATTRIBUTE_SECURITY_CAPABILITIES as usize,
                    Some(&mut sec_caps as *mut _ as *const std::ffi::c_void),
                    std::mem::size_of::<SECURITY_CAPABILITIES>(),
                    None,
                    None,
                )
                .map_err(|e| io::Error::from_raw_os_error(e.code().0))?;

                Ok(Self {
                    heap,
                    buffer,
                    list,
                    sec_caps,
                    sid_and_attributes,
                    sid,
                    capabilities: caps,
                })
            }
        }

        fn as_mut_ptr(&mut self) -> LPPROC_THREAD_ATTRIBUTE_LIST {
            self.list
        }
    }

    impl Drop for AttributeList<'_> {
        fn drop(&mut self) {
            unsafe {
                if !self.list.is_invalid() {
                    DeleteProcThreadAttributeList(self.list);
                }
                if !self.heap.is_invalid() && !self.buffer.is_null() {
                    let _ = HeapFree(self.heap, HEAP_FLAGS(0), Some(self.buffer));
                }
            }
        }
    }

    fn configure_writable_roots(
        policy: &SandboxPolicy,
        sandbox_policy_cwd: &Path,
        sid: PSID,
    ) -> io::Result<()> {
        match policy {
            SandboxPolicy::DangerFullAccess => Ok(()),
            SandboxPolicy::ReadOnly => grant_path_with_flags(sandbox_policy_cwd, sid, false),
            SandboxPolicy::WorkspaceWrite { .. } => {
                let roots = policy.get_writable_roots_with_cwd(sandbox_policy_cwd);
                for writable in roots {
                    grant_path_with_flags(&writable.root, sid, true)?;
                    for ro in writable.read_only_subpaths {
                        grant_path_with_flags(&ro, sid, false)?;
                    }
                }
                Ok(())
            }
        }
    }

    fn configure_writable_roots_for_command_cwd(command_cwd: &Path, sid: PSID) -> io::Result<()> {
        grant_path_with_flags(command_cwd, sid, true)
    }

    fn grant_path_with_flags(path: &Path, sid: PSID, write: bool) -> io::Result<()> {
        if !path.exists() {
            return Ok(());
        }

        let wide = to_wide(path.as_os_str());
        unsafe {
            let mut existing_dacl: *mut ACL = null_mut();
            let mut security_descriptor = PSECURITY_DESCRIPTOR::default();
            let status = GetNamedSecurityInfoW(
                PCWSTR(wide.as_ptr()),
                SE_FILE_OBJECT,
                DACL_SECURITY_INFORMATION,
                None,
                None,
                Some(&mut existing_dacl),
                None,
                &mut security_descriptor,
            );
            if status != WIN32_ERROR::from(ERROR_SUCCESS) {
                if !security_descriptor.is_invalid() {
                    let _ = LocalFree(HLOCAL(security_descriptor.0));
                }
                return Err(io::Error::from_raw_os_error(status.0 as i32));
            }

            let permissions = if write {
                (FILE_GENERIC_READ | FILE_GENERIC_WRITE | FILE_GENERIC_EXECUTE).0
            } else {
                (FILE_GENERIC_READ | FILE_GENERIC_EXECUTE).0
            };
            let explicit = EXPLICIT_ACCESS_W {
                grfAccessPermissions: permissions,
                grfAccessMode: SET_ACCESS,
                grfInheritance: (SUB_CONTAINERS_AND_OBJECTS_INHERIT | OBJECT_INHERIT_ACE).0,
                Trustee: TRUSTEE_W {
                    TrusteeForm: TRUSTEE_IS_SID,
                    TrusteeType: TRUSTEE_IS_UNKNOWN,
                    ptstrName: PWSTR(sid.0.cast()),
                    ..Default::default()
                },
            };

            let explicit_entries = [explicit];
            let mut new_dacl: *mut ACL = null_mut();
            let add_result =
                SetEntriesInAclW(Some(&explicit_entries), Some(existing_dacl), &mut new_dacl);
            if add_result != WIN32_ERROR::from(ERROR_SUCCESS) {
                if !new_dacl.is_null() {
                    let _ = LocalFree(HLOCAL(new_dacl.cast()));
                }
                if !security_descriptor.is_invalid() {
                    let _ = LocalFree(HLOCAL(security_descriptor.0));
                }
                return Err(io::Error::from_raw_os_error(add_result.0 as i32));
            }

            let set_result = SetNamedSecurityInfoW(
                PCWSTR(wide.as_ptr()),
                SE_FILE_OBJECT,
                DACL_SECURITY_INFORMATION,
                None,
                None,
                Some(new_dacl),
                None,
            );
            if set_result != WIN32_ERROR::from(ERROR_SUCCESS) {
                if !new_dacl.is_null() {
                    let _ = LocalFree(HLOCAL(new_dacl.cast()));
                }
                if !security_descriptor.is_invalid() {
                    let _ = LocalFree(HLOCAL(security_descriptor.0));
                }
                return Err(io::Error::from_raw_os_error(set_result.0 as i32));
            }

            if !new_dacl.is_null() {
                let _ = LocalFree(HLOCAL(new_dacl.cast()));
            }
            if !security_descriptor.is_invalid() {
                let _ = LocalFree(HLOCAL(security_descriptor.0));
            }
        }

        Ok(())
    }
}

#[cfg(feature = "windows_appcontainer_command_ext")]
pub use imp::spawn_command_under_windows_appcontainer;

#[cfg(not(feature = "windows_appcontainer_command_ext"))]
pub async fn spawn_command_under_windows_appcontainer(
    command: Vec<String>,
    command_cwd: PathBuf,
    _sandbox_policy: &SandboxPolicy,
    _sandbox_policy_cwd: &Path,
    _stdio_policy: StdioPolicy,
    _env: HashMap<String, String>,
) -> io::Result<Child> {
    let _ = (command, command_cwd);
    Err(io::Error::new(
        io::ErrorKind::Unsupported,
        "AppContainer sandboxing requires the `windows_appcontainer_command_ext` feature",
    ))
}
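
One practical consequence of the cfg gating above: without the windows_appcontainer_command_ext feature, spawn_command_under_windows_appcontainer still compiles on Windows but always fails with ErrorKind::Unsupported. Assuming the feature is declared in codex-core's Cargo.toml (that file is not shown in this diff), enabling it would look like:

cargo build -p codex-core --features windows_appcontainer_command_ext
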
@@ -1,3 +1,4 @@
|
||||
use assert_matches::assert_matches;
|
||||
use std::sync::Arc;
|
||||
use tracing_test::traced_test;
|
||||
|
||||
@@ -178,7 +179,7 @@ async fn streams_text_without_reasoning() {
|
||||
other => panic!("expected terminal message, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[2], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[2], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -219,7 +220,7 @@ async fn streams_reasoning_from_string_delta() {
|
||||
other => panic!("expected message item, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[4], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[4], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -266,7 +267,7 @@ async fn streams_reasoning_from_object_delta() {
|
||||
other => panic!("expected message item, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[5], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[5], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -293,7 +294,7 @@ async fn streams_reasoning_from_final_message() {
|
||||
other => panic!("expected reasoning item, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[2], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[2], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -337,7 +338,7 @@ async fn streams_reasoning_before_tool_call() {
|
||||
other => panic!("expected function call, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(matches!(events[3], ResponseEvent::Completed { .. }));
|
||||
assert_matches!(events[3], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -10,6 +10,7 @@ path = "lib.rs"
|
||||
anyhow = { workspace = true }
|
||||
assert_cmd = { workspace = true }
|
||||
codex-core = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tokio = { workspace = true, features = ["time"] }
|
||||
|
||||
@@ -6,23 +6,50 @@ use codex_core::CodexConversation;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::ConfigToml;
|
||||
use regex_lite::Regex;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
use assert_cmd::cargo::cargo_bin;
|
||||
|
||||
pub mod responses;
|
||||
pub mod test_codex;
|
||||
pub mod test_codex_exec;
|
||||
|
||||
#[track_caller]
|
||||
pub fn assert_regex_match<'s>(pattern: &str, actual: &'s str) -> regex_lite::Captures<'s> {
|
||||
let regex = Regex::new(pattern).unwrap_or_else(|err| {
|
||||
panic!("failed to compile regex {pattern:?}: {err}");
|
||||
});
|
||||
regex
|
||||
.captures(actual)
|
||||
.unwrap_or_else(|| panic!("regex {pattern:?} did not match {actual:?}"))
|
||||
}
|
||||
|
||||
/// Returns a default `Config` whose on-disk state is confined to the provided
|
||||
/// temporary directory. Using a per-test directory keeps tests hermetic and
|
||||
/// avoids clobbering a developer’s real `~/.codex`.
|
||||
pub fn load_default_config_for_test(codex_home: &TempDir) -> Config {
|
||||
Config::load_from_base_config_with_overrides(
|
||||
ConfigToml::default(),
|
||||
ConfigOverrides::default(),
|
||||
default_test_overrides(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)
|
||||
.expect("defaults for test should always succeed")
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
fn default_test_overrides() -> ConfigOverrides {
|
||||
ConfigOverrides {
|
||||
codex_linux_sandbox_exe: Some(cargo_bin("codex-linux-sandbox")),
|
||||
..ConfigOverrides::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
fn default_test_overrides() -> ConfigOverrides {
|
||||
ConfigOverrides::default()
|
||||
}
|
||||
|
||||
/// Builds an SSE stream body from a JSON fixture.
|
||||
///
|
||||
/// The fixture must contain an array of objects where each object represents a
|
||||
|
||||
@@ -34,6 +34,16 @@ pub fn ev_completed(id: &str) -> Value {
|
||||
})
|
||||
}
|
||||
|
||||
/// Convenience: SSE event for a created response with a specific id.
|
||||
pub fn ev_response_created(id: &str) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {
|
||||
"id": id,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn ev_completed_with_tokens(id: &str, total_tokens: u64) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.completed",
|
||||
@@ -135,6 +145,16 @@ pub fn ev_apply_patch_function_call(call_id: &str, patch: &str) -> Value {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn sse_failed(id: &str, code: &str, message: &str) -> String {
|
||||
sse(vec![serde_json::json!({
|
||||
"type": "response.failed",
|
||||
"response": {
|
||||
"id": id,
|
||||
"error": {"code": code, "message": message}
|
||||
}
|
||||
})])
|
||||
}
|
||||
|
||||
pub fn sse_response(body: String) -> ResponseTemplate {
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
|
||||
@@ -13,7 +13,7 @@ use tempfile::TempDir;
|
||||
|
||||
use crate::load_default_config_for_test;
|
||||
|
||||
type ConfigMutator = dyn FnOnce(&mut Config);
|
||||
type ConfigMutator = dyn FnOnce(&mut Config) + Send;
|
||||
|
||||
pub struct TestCodexBuilder {
|
||||
config_mutators: Vec<Box<ConfigMutator>>,
|
||||
@@ -22,7 +22,7 @@ pub struct TestCodexBuilder {
|
||||
impl TestCodexBuilder {
|
||||
pub fn with_config<T>(mut self, mutator: T) -> Self
|
||||
where
|
||||
T: FnOnce(&mut Config) + 'static,
|
||||
T: FnOnce(&mut Config) + Send + 'static,
|
||||
{
|
||||
self.config_mutators.push(Box::new(mutator));
|
||||
self
|
||||
|
||||
@@ -3,14 +3,14 @@ use std::time::Duration;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::mount_sse_once_match;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use serde_json::json;
|
||||
use wiremock::matchers::body_string_contains;
|
||||
|
||||
/// Integration test: spawn a long‑running shell tool via a mocked Responses SSE
|
||||
/// function call, then interrupt the session and expect TurnAborted.
|
||||
@@ -27,10 +27,13 @@ async fn interrupt_long_running_tool_emits_turn_aborted() {
|
||||
"timeout_ms": 60_000
|
||||
})
|
||||
.to_string();
|
||||
let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]);
|
||||
let body = sse(vec![
|
||||
ev_function_call("call_sleep", "shell", &args),
|
||||
ev_completed("done"),
|
||||
]);
|
||||
|
||||
let server = start_mock_server().await;
|
||||
mount_sse_once_match(&server, body_string_contains("start sleep"), body).await;
|
||||
mount_sse_once(&server, body).await;
|
||||
|
||||
let codex = test_codex().build(&server).await.unwrap().codex;
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::WireApi;
use codex_core::built_in_model_providers;
use codex_core::error::CodexErr;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
@@ -26,8 +28,10 @@ use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_with_timeout;
use futures::StreamExt;
use serde_json::json;
use std::io::Write;
@@ -37,6 +41,7 @@ use uuid::Uuid;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::body_string_contains;
use wiremock::matchers::header_regex;
use wiremock::matchers::method;
use wiremock::matchers::path;
@@ -996,6 +1001,100 @@ async fn usage_limit_error_emits_rate_limit_event() -> anyhow::Result<()> {
    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
    let server = MockServer::start().await;

    responses::mount_sse_once_match(
        &server,
        body_string_contains("trigger context window"),
        responses::sse_failed(
            "resp_context_window",
            "context_length_exceeded",
            "Your input exceeds the context window of this model. Please adjust your input and try again.",
        ),
    )
    .await;

    responses::mount_sse_once_match(
        &server,
        body_string_contains("seed turn"),
        sse_completed("resp_seed"),
    )
    .await;

    let TestCodex { codex, .. } = test_codex()
        .with_config(|config| {
            config.model = "gpt-5".to_string();
            config.model_family = find_family_for_model("gpt-5").expect("known gpt-5 model family");
            config.model_context_window = Some(272_000);
        })
        .build(&server)
        .await?;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: "seed turn".into(),
            }],
        })
        .await?;

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: "trigger context window".into(),
            }],
        })
        .await?;

    use std::time::Duration;

    let token_event = wait_for_event_with_timeout(
        &codex,
        |event| {
            matches!(
                event,
                EventMsg::TokenCount(payload)
                    if payload.info.as_ref().is_some_and(|info| {
                        info.model_context_window == Some(info.total_token_usage.total_tokens)
                            && info.total_token_usage.total_tokens > 0
                    })
            )
        },
        Duration::from_secs(5),
    )
    .await;

    let EventMsg::TokenCount(token_payload) = token_event else {
        unreachable!("wait_for_event_with_timeout returned unexpected event");
    };

    let info = token_payload
        .info
        .expect("token usage info present when context window is exceeded");

    assert_eq!(info.model_context_window, Some(272_000));
    assert_eq!(info.total_token_usage.total_tokens, 272_000);

    let error_event = wait_for_event(&codex, |ev| matches!(ev, EventMsg::Error(_))).await;
    let expected_context_window_message = CodexErr::ContextWindowExceeded.to_string();
    assert!(
        matches!(
            error_event,
            EventMsg::Error(ref err) if err.message == expected_context_window_message
        ),
        "expected context window error; got {error_event:?}"
    );

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn azure_overrides_assign_properties_used_for_responses_url() {
    skip_if_no_network!();

@@ -13,12 +13,6 @@ use core_test_support::load_default_config_for_test;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::Request;
use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

use codex_core::codex::compact::SUMMARIZATION_PROMPT;
use core_test_support::responses::ev_assistant_message;
@@ -26,14 +20,10 @@ use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_completed_with_tokens;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::sse_response;
use core_test_support::responses::start_mock_server;
use pretty_assertions::assert_eq;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
// --- Test helpers -----------------------------------------------------------

pub(super) const FIRST_REPLY: &str = "FIRST_REPLY";
@@ -295,12 +285,7 @@ async fn auto_compact_runs_after_token_limit_hit() {
            && !body.contains(SECOND_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(first_matcher)
        .respond_with(sse_response(sse1))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, first_matcher, sse1).await;

    let second_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
@@ -308,23 +293,13 @@ async fn auto_compact_runs_after_token_limit_hit() {
            && body.contains(FIRST_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(second_matcher)
        .respond_with(sse_response(sse2))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, second_matcher, sse2).await;

    let third_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(third_matcher)
        .respond_with(sse_response(sse3))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, third_matcher, sse3).await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
@@ -455,12 +430,7 @@ async fn auto_compact_persists_rollout_entries() {
            && !body.contains(SECOND_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(first_matcher)
        .respond_with(sse_response(sse1))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, first_matcher, sse1).await;

    let second_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
@@ -468,23 +438,13 @@ async fn auto_compact_persists_rollout_entries() {
            && body.contains(FIRST_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(second_matcher)
        .respond_with(sse_response(sse2))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, second_matcher, sse2).await;

    let third_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(third_matcher)
        .respond_with(sse_response(sse3))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, third_matcher, sse3).await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
@@ -582,35 +542,20 @@ async fn auto_compact_stops_after_failed_attempt() {
        body.contains(FIRST_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(first_matcher)
        .respond_with(sse_response(sse1.clone()))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, first_matcher, sse1.clone()).await;

    let second_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(second_matcher)
        .respond_with(sse_response(sse2.clone()))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, second_matcher, sse2.clone()).await;

    let third_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        !body.contains("You have exceeded the maximum number of tokens")
            && body.contains(SUMMARY_TEXT)
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(third_matcher)
        .respond_with(sse_response(sse3.clone()))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, third_matcher, sse3.clone()).await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
@@ -708,49 +653,7 @@ async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_
        ev_completed_with_tokens("r6", 120),
    ]);

    #[derive(Clone)]
    struct SeqResponder {
        bodies: Arc<Vec<String>>,
        calls: Arc<AtomicUsize>,
        requests: Arc<Mutex<Vec<Vec<u8>>>>,
    }

    impl SeqResponder {
        fn new(bodies: Vec<String>) -> Self {
            Self {
                bodies: Arc::new(bodies),
                calls: Arc::new(AtomicUsize::new(0)),
                requests: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn recorded_requests(&self) -> Vec<Vec<u8>> {
            self.requests.lock().unwrap().clone()
        }
    }

    impl Respond for SeqResponder {
        fn respond(&self, req: &Request) -> ResponseTemplate {
            let idx = self.calls.fetch_add(1, Ordering::SeqCst);
            self.requests.lock().unwrap().push(req.body.clone());
            let body = self
                .bodies
                .get(idx)
                .unwrap_or_else(|| panic!("unexpected request index {idx}"))
                .clone();
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_raw(body, "text/event-stream")
        }
    }

    let responder = SeqResponder::new(vec![sse1, sse2, sse3, sse4, sse5, sse6]);
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(responder.clone())
        .expect(6)
        .mount(&server)
        .await;
    mount_sse_sequence(&server, vec![sse1, sse2, sse3, sse4, sse5, sse6]).await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
@@ -801,10 +704,12 @@ async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_
        "auto compact should not emit task lifecycle events"
    );

    let request_bodies: Vec<String> = responder
        .recorded_requests()
    let request_bodies: Vec<String> = server
        .received_requests()
        .await
        .unwrap()
        .into_iter()
        .map(|body| String::from_utf8(body).unwrap_or_default())
        .map(|request| String::from_utf8(request.body).unwrap_or_default())
        .collect();
    assert_eq!(
        request_bodies.len(),

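The refactor above drops the hand-rolled SeqResponder in favor of mount_sse_sequence, which serves a fixed list of SSE bodies in request order, while wiremock's own request log replaces the responder's bookkeeping. A condensed sketch of the resulting pattern (assuming the helper signatures used in this diff):

    // Serve sse1..sse3 to the first three POSTs, in order.
    mount_sse_sequence(&server, vec![sse1, sse2, sse3]).await;

    // Afterwards, inspect what the client actually sent.
    let bodies: Vec<String> = server
        .received_requests()
        .await
        .unwrap()
        .into_iter()
        .map(|request| String::from_utf8(request.body).unwrap_or_default())
        .collect();
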
@@ -17,6 +17,7 @@ use codex_core::NewConversation;
use codex_core::built_in_model_providers;
use codex_core::codex::compact::SUMMARIZATION_PROMPT;
use codex_core::config::Config;
use codex_core::config::OPENAI_DEFAULT_MODEL;
use codex_core::protocol::ConversationPathResponseEvent;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
@@ -131,9 +132,10 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
        .as_str()
        .unwrap_or_default()
        .to_string();
    let expected_model = OPENAI_DEFAULT_MODEL;
    let user_turn_1 = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {
@@ -182,7 +184,7 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
    });
    let compact_1 = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {
@@ -251,7 +253,7 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
    });
    let user_turn_2_after_compact = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {
@@ -316,7 +318,7 @@ SUMMARY_ONLY_CONTEXT"
    });
    let usert_turn_3_after_resume = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {
@@ -401,7 +403,7 @@ SUMMARY_ONLY_CONTEXT"
    });
    let user_turn_3_after_fork = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {

@@ -12,12 +12,20 @@ mod fork_conversation;
mod json_result;
mod live_cli;
mod model_overrides;
mod model_tools;
mod otel;
mod prompt_caching;
mod read_file;
mod review;
mod rmcp_client;
mod rollout_list_find;
mod seatbelt;
mod shell_serialization;
mod stream_error_allows_next_turn;
mod stream_no_completed;
mod tool_harness;
mod tool_parallelism;
mod tools;
mod unified_exec;
mod user_notification;
mod view_image;

codex-rs/core/tests/suite/model_tools.rs (new file, 131 lines)
@@ -0,0 +1,131 @@
#![allow(clippy::unwrap_used)]

use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::built_in_model_providers;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

fn sse_completed(id: &str) -> String {
    load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
}

#[allow(clippy::expect_used)]
fn tool_identifiers(body: &serde_json::Value) -> Vec<String> {
    body["tools"]
        .as_array()
        .unwrap()
        .iter()
        .map(|tool| {
            tool.get("name")
                .and_then(|v| v.as_str())
                .or_else(|| tool.get("type").and_then(|v| v.as_str()))
                .map(std::string::ToString::to_string)
                .expect("tool should have either name or type")
        })
        .collect()
}

#[allow(clippy::expect_used)]
async fn collect_tool_identifiers_for_model(model: &str) -> Vec<String> {
    let server = MockServer::start().await;

    let sse = sse_completed(model);
    let template = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse, "text/event-stream");

    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(template)
        .expect(1)
        .mount(&server)
        .await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
        ..built_in_model_providers()["openai"].clone()
    };

    let cwd = TempDir::new().unwrap();
    let codex_home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&codex_home);
    config.cwd = cwd.path().to_path_buf();
    config.model_provider = model_provider;
    config.model = model.to_string();
    config.model_family =
        find_family_for_model(model).unwrap_or_else(|| panic!("unknown model family for {model}"));
    config.include_plan_tool = false;
    config.include_apply_patch_tool = false;
    config.include_view_image_tool = false;
    config.tools_web_search_request = false;
    config.use_experimental_streamable_shell_tool = false;
    config.use_experimental_unified_exec_tool = false;

    let conversation_manager =
        ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
    let codex = conversation_manager
        .new_conversation(config)
        .await
        .expect("create new conversation")
        .conversation;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: "hello tools".into(),
            }],
        })
        .await
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.unwrap();
    assert_eq!(
        requests.len(),
        1,
        "expected a single request for model {model}"
    );
    let body = requests[0].body_json::<serde_json::Value>().unwrap();
    tool_identifiers(&body)
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn model_selects_expected_tools() {
    skip_if_no_network!();
    use pretty_assertions::assert_eq;

    let codex_tools = collect_tool_identifiers_for_model("codex-mini-latest").await;
    assert_eq!(
        codex_tools,
        vec!["local_shell".to_string()],
        "codex-mini-latest should expose the local shell tool",
    );

    let o3_tools = collect_tool_identifiers_for_model("o3").await;
    assert_eq!(
        o3_tools,
        vec!["shell".to_string()],
        "o3 should expose the generic shell tool",
    );

    let gpt5_codex_tools = collect_tool_identifiers_for_model("gpt-5-codex").await;
    assert_eq!(
        gpt5_codex_tools,
        vec!["shell".to_string(), "apply_patch".to_string()],
        "gpt-5-codex should expose the apply_patch tool",
    );
}
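A usage sketch for probing another model's tool list with the helper above (hypothetical; the diff asserts identifiers only for the three models shown):

    let gpt5_tools = collect_tool_identifiers_for_model("gpt-5").await;
    println!("gpt-5 advertises: {gpt5_tools:?}");
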
@@ -4,6 +4,7 @@ use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::built_in_model_providers;
use codex_core::config::OPENAI_DEFAULT_MODEL;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
@@ -18,6 +19,7 @@ use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use std::collections::HashMap;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
@@ -178,16 +180,16 @@ async fn prompt_tools_are_consistent_across_requests() {

    let cwd = TempDir::new().unwrap();
    let codex_home = TempDir::new().unwrap();

    let mut config = load_default_config_for_test(&codex_home);
    config.cwd = cwd.path().to_path_buf();
    config.model_provider = model_provider;
    config.user_instructions = Some("be consistent and helpful".to_string());
    config.include_apply_patch_tool = true;
    config.include_plan_tool = true;

    let conversation_manager =
        ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
    let expected_instructions = config.model_family.base_instructions.clone();
    let base_instructions = config.model_family.base_instructions.clone();
    let codex = conversation_manager
        .new_conversation(config)
        .await
@@ -219,8 +221,29 @@ async fn prompt_tools_are_consistent_across_requests() {

    // our internal implementation is responsible for keeping tools in sync
    // with the OpenAI schema, so we just verify the tool presence here
    let expected_tools_names: &[&str] = &["shell", "update_plan", "apply_patch", "view_image"];
    let tools_by_model: HashMap<&'static str, Vec<&'static str>> = HashMap::from([
        ("gpt-5", vec!["shell", "update_plan", "view_image"]),
        (
            "gpt-5-codex",
            vec!["shell", "update_plan", "apply_patch", "view_image"],
        ),
    ]);
    let expected_tools_names = tools_by_model
        .get(OPENAI_DEFAULT_MODEL)
        .unwrap_or_else(|| panic!("expected tools to be defined for model {OPENAI_DEFAULT_MODEL}"))
        .as_slice();
    let body0 = requests[0].body_json::<serde_json::Value>().unwrap();

    let expected_instructions = if expected_tools_names.contains(&"apply_patch") {
        base_instructions
    } else {
        [
            base_instructions.clone(),
            include_str!("../../../apply-patch/apply_patch_tool_instructions.md").to_string(),
        ]
        .join("\n")
    };

    assert_eq!(
        body0["instructions"],
        serde_json::json!(expected_instructions),

codex-rs/core/tests/suite/read_file.rs (new file, 123 lines)
@@ -0,0 +1,123 @@
#![cfg(not(target_os = "windows"))]

use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use pretty_assertions::assert_eq;
use serde_json::Value;
use wiremock::matchers::any;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "disabled until we enable read_file tool"]
async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = test_codex().build(&server).await?;

    let file_path = cwd.path().join("sample.txt");
    std::fs::write(&file_path, "first\nsecond\nthird\nfourth\n")?;
    let file_path = file_path.to_string_lossy().to_string();

    let call_id = "read-file-call";
    let arguments = serde_json::json!({
        "file_path": file_path,
        "offset": 2,
        "limit": 2,
    })
    .to_string();

    let first_response = sse(vec![
        ev_response_created("resp-1"),
        ev_function_call(call_id, "read_file", &arguments),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please inspect sample.txt".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().unwrap())
        .collect::<Vec<_>>();
    assert!(
        !request_bodies.is_empty(),
        "expected at least one request body"
    );

    let tool_output_item = request_bodies
        .iter()
        .find_map(|body| {
            body.get("input")
                .and_then(Value::as_array)
                .and_then(|items| {
                    items.iter().find(|item| {
                        item.get("type").and_then(Value::as_str) == Some("function_call_output")
                    })
                })
        })
        .unwrap_or_else(|| {
            panic!("function_call_output item not found in requests: {request_bodies:#?}")
        });

    assert_eq!(
        tool_output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );

    let output_text = tool_output_item
        .get("output")
        .and_then(|value| match value {
            Value::String(text) => Some(text.as_str()),
            Value::Object(obj) => obj.get("content").and_then(Value::as_str),
            _ => None,
        })
        .expect("output text present");
    assert_eq!(output_text, "L2: second\nL3: third");

    Ok(())
}
@@ -24,6 +24,7 @@ use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id_from_str;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_with_timeout;
use pretty_assertions::assert_eq;
use std::path::PathBuf;
use std::sync::Arc;
@@ -260,25 +261,28 @@ async fn review_does_not_emit_agent_message_on_structured_output() {
    .unwrap();

    // Drain events until TaskComplete; ensure none are AgentMessage.
    use tokio::time::Duration;
    use tokio::time::timeout;
    let mut saw_entered = false;
    let mut saw_exited = false;
    loop {
        let ev = timeout(Duration::from_secs(5), codex.next_event())
            .await
            .expect("timeout waiting for event")
            .expect("stream ended unexpectedly");
        match ev.msg {
            EventMsg::TaskComplete(_) => break,
    wait_for_event_with_timeout(
        &codex,
        |event| match event {
            EventMsg::TaskComplete(_) => true,
            EventMsg::AgentMessage(_) => {
                panic!("unexpected AgentMessage during review with structured output")
            }
            EventMsg::EnteredReviewMode(_) => saw_entered = true,
            EventMsg::ExitedReviewMode(_) => saw_exited = true,
            _ => {}
        }
    }
            EventMsg::EnteredReviewMode(_) => {
                saw_entered = true;
                false
            }
            EventMsg::ExitedReviewMode(_) => {
                saw_exited = true;
                false
            }
            _ => false,
        },
        tokio::time::Duration::from_secs(5),
    )
    .await;
    assert!(saw_entered && saw_exited, "missing review lifecycle events");

    server.verify().await;

@@ -1,6 +1,11 @@
use std::collections::HashMap;
use std::ffi::OsString;
use std::fs;
use std::net::TcpListener;
use std::path::Path;
use std::time::Duration;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;

use codex_core::config_types::McpServerConfig;
use codex_core::config_types::McpServerTransportConfig;
@@ -19,6 +24,8 @@ use core_test_support::wait_for_event;
use core_test_support::wait_for_event_with_timeout;
use escargot::CargoBuild;
use serde_json::Value;
use serial_test::serial;
use tempfile::tempdir;
use tokio::net::TcpStream;
use tokio::process::Child;
use tokio::process::Command;
@@ -40,10 +47,7 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
        &server,
        any(),
        responses::sse(vec![
            serde_json::json!({
                "type": "response.created",
                "response": {"id": "resp-1"}
            }),
            responses::ev_response_created("resp-1"),
            responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
            responses::ev_completed("resp-1"),
        ]),
@@ -177,10 +181,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
        &server,
        any(),
        responses::sse(vec![
            serde_json::json!({
                "type": "response.created",
                "response": {"id": "resp-1"}
            }),
            responses::ev_response_created("resp-1"),
            responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
            responses::ev_completed("resp-1"),
        ]),
@@ -328,6 +329,186 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
    Ok(())
}

/// This test writes to a fallback credentials file in CODEX_HOME.
/// Ideally, we wouldn't need to serialize the test but it's much more cumbersome to wire CODEX_HOME through the code.
#[serial(codex_home)]
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;

    let call_id = "call-789";
    let server_name = "rmcp_http_oauth";
    let tool_name = format!("{server_name}__echo");

    mount_sse_once_match(
        &server,
        any(),
        responses::sse(vec![
            responses::ev_response_created("resp-1"),
            responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
            responses::ev_completed("resp-1"),
        ]),
    )
    .await;
    mount_sse_once_match(
        &server,
        any(),
        responses::sse(vec![
            responses::ev_assistant_message(
                "msg-1",
                "rmcp streamable http oauth echo tool completed successfully.",
            ),
            responses::ev_completed("resp-2"),
        ]),
    )
    .await;

    let expected_env_value = "propagated-env-http-oauth";
    let expected_token = "initial-access-token";
    let client_id = "test-client-id";
    let refresh_token = "initial-refresh-token";
    let rmcp_http_server_bin = CargoBuild::new()
        .package("codex-rmcp-client")
        .bin("test_streamable_http_server")
        .run()?
        .path()
        .to_string_lossy()
        .into_owned();

    let listener = TcpListener::bind("127.0.0.1:0")?;
    let port = listener.local_addr()?.port();
    drop(listener);
    let bind_addr = format!("127.0.0.1:{port}");
    let server_url = format!("http://{bind_addr}/mcp");

    let mut http_server_child = Command::new(&rmcp_http_server_bin)
        .kill_on_drop(true)
        .env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
        .env("MCP_EXPECT_BEARER", expected_token)
        .env("MCP_TEST_VALUE", expected_env_value)
        .spawn()?;

    wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5))
        .await?;

    let temp_home = tempdir()?;
    let _guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str());
    write_fallback_oauth_tokens(
        temp_home.path(),
        server_name,
        &server_url,
        client_id,
        expected_token,
        refresh_token,
    )?;

    let fixture = test_codex()
        .with_config(move |config| {
            config.use_experimental_use_rmcp_client = true;
            config.mcp_servers.insert(
                server_name.to_string(),
                McpServerConfig {
                    transport: McpServerTransportConfig::StreamableHttp {
                        url: server_url,
                        bearer_token: None,
                    },
                    startup_timeout_sec: Some(Duration::from_secs(10)),
                    tool_timeout_sec: None,
                },
            );
        })
        .build(&server)
        .await?;
    let session_model = fixture.session_configured.model.clone();

    fixture
        .codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "call the rmcp streamable http oauth echo tool".into(),
            }],
            final_output_json_schema: None,
            cwd: fixture.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    let begin_event = wait_for_event_with_timeout(
        &fixture.codex,
        |ev| matches!(ev, EventMsg::McpToolCallBegin(_)),
        Duration::from_secs(10),
    )
    .await;

    let EventMsg::McpToolCallBegin(begin) = begin_event else {
        unreachable!("event guard guarantees McpToolCallBegin");
    };
    assert_eq!(begin.invocation.server, server_name);
    assert_eq!(begin.invocation.tool, "echo");

    let end_event = wait_for_event(&fixture.codex, |ev| {
        matches!(ev, EventMsg::McpToolCallEnd(_))
    })
    .await;
    let EventMsg::McpToolCallEnd(end) = end_event else {
        unreachable!("event guard guarantees McpToolCallEnd");
    };

    let result = end
        .result
        .as_ref()
        .expect("rmcp echo tool should return success");
    assert_eq!(result.is_error, Some(false));
    assert!(
        result.content.is_empty(),
        "content should default to an empty array"
    );

    let structured = result
        .structured_content
        .as_ref()
        .expect("structured content");
    let Value::Object(map) = structured else {
        panic!("structured content should be an object: {structured:?}");
    };
    let echo_value = map
        .get("echo")
        .and_then(Value::as_str)
        .expect("echo payload present");
    assert_eq!(echo_value, "ECHOING: ping");
    let env_value = map
        .get("env")
        .and_then(Value::as_str)
        .expect("env snapshot inserted");
    assert_eq!(env_value, expected_env_value);

    wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    server.verify().await;

    match http_server_child.try_wait() {
        Ok(Some(_)) => {}
        Ok(None) => {
            let _ = http_server_child.kill().await;
        }
        Err(error) => {
            eprintln!("failed to check streamable http oauth server status: {error}");
            let _ = http_server_child.kill().await;
        }
    }
    if let Err(error) = http_server_child.wait().await {
        eprintln!("failed to await streamable http oauth server shutdown: {error}");
    }

    Ok(())
}

async fn wait_for_streamable_http_server(
    server_child: &mut Child,
    address: &str,
@@ -369,3 +550,60 @@ async fn wait_for_streamable_http_server(
        sleep(Duration::from_millis(50)).await;
    }
}

fn write_fallback_oauth_tokens(
    home: &Path,
    server_name: &str,
    server_url: &str,
    client_id: &str,
    access_token: &str,
    refresh_token: &str,
) -> anyhow::Result<()> {
    let expires_at = SystemTime::now()
        .checked_add(Duration::from_secs(3600))
        .ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
        .duration_since(UNIX_EPOCH)?
        .as_millis() as u64;

    let store = serde_json::json!({
        "stub": {
            "server_name": server_name,
            "server_url": server_url,
            "client_id": client_id,
            "access_token": access_token,
            "expires_at": expires_at,
            "refresh_token": refresh_token,
            "scopes": ["profile"],
        }
    });

    let file_path = home.join(".credentials.json");
    fs::write(&file_path, serde_json::to_vec(&store)?)?;
    Ok(())
}

struct EnvVarGuard {
    key: &'static str,
    original: Option<OsString>,
}

impl EnvVarGuard {
    fn set(key: &'static str, value: &std::ffi::OsStr) -> Self {
        let original = std::env::var_os(key);
        unsafe {
            std::env::set_var(key, value);
        }
        Self { key, original }
    }
}

impl Drop for EnvVarGuard {
    fn drop(&mut self) {
        unsafe {
            match &self.original {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }
}

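EnvVarGuard above is a small RAII guard: std::env::set_var and remove_var are unsafe in edition-2024 Rust because they mutate process-global state, so the guard scopes the mutation and restores the previous value on drop. A usage sketch (assuming the definitions above):

    {
        let _guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str());
        // Code here observes the temporary CODEX_HOME value.
    } // Dropping the guard restores (or removes) the original variable.
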
@@ -169,6 +169,12 @@ async fn python_getpwuid_works_under_seatbelt() {
        return;
    }

    // For local dev.
    if which::which("python3").is_err() {
        eprintln!("python3 not found in PATH, skipping test.");
        return;
    }

    // ReadOnly is sufficient here since we are only exercising user lookup.
    let policy = SandboxPolicy::ReadOnly;
    let command_cwd = std::env::current_dir().expect("getcwd");

codex-rs/core/tests/suite/shell_serialization.rs (new file, 277 lines)
@@ -0,0 +1,277 @@
#![cfg(not(target_os = "windows"))]

use anyhow::Result;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::assert_regex_match;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::Value;
use serde_json::json;

async fn submit_turn(test: &TestCodex, prompt: &str, sandbox_policy: SandboxPolicy) -> Result<()> {
    let session_model = test.session_configured.model.clone();

    test.codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: prompt.into(),
            }],
            final_output_json_schema: None,
            cwd: test.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&test.codex, |event| {
        matches!(event, EventMsg::TaskComplete(_))
    })
    .await;

    Ok(())
}

fn request_bodies(requests: &[wiremock::Request]) -> Result<Vec<Value>> {
    requests
        .iter()
        .map(|req| Ok(serde_json::from_slice::<Value>(&req.body)?))
        .collect()
}

fn find_function_call_output<'a>(bodies: &'a [Value], call_id: &str) -> Option<&'a Value> {
    for body in bodies {
        if let Some(items) = body.get("input").and_then(Value::as_array) {
            for item in items {
                if item.get("type").and_then(Value::as_str) == Some("function_call_output")
                    && item.get("call_id").and_then(Value::as_str) == Some(call_id)
                {
                    return Some(item);
                }
            }
        }
    }
    None
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_output_stays_json_without_freeform_apply_patch() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|config| {
        config.include_apply_patch_tool = false;
        config.model = "gpt-5".to_string();
        config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a model family");
    });
    let test = builder.build(&server).await?;

    let call_id = "shell-json";
    let args = json!({
        "command": ["/bin/echo", "shell json"],
        "timeout_ms": 1_000,
    });
    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run the json shell command",
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server
        .received_requests()
        .await
        .expect("recorded requests present");
    let bodies = request_bodies(&requests)?;
    let output_item = find_function_call_output(&bodies, call_id).expect("shell output present");
    let output = output_item
        .get("output")
        .and_then(Value::as_str)
        .expect("shell output string");

    let parsed: Value = serde_json::from_str(output)?;
    assert_eq!(
        parsed
            .get("metadata")
            .and_then(|metadata| metadata.get("exit_code"))
            .and_then(Value::as_i64),
        Some(0),
        "expected zero exit code in unformatted JSON output",
    );
    let stdout = parsed
        .get("output")
        .and_then(Value::as_str)
        .unwrap_or_default();
    assert_regex_match(r"(?s)^shell json\n?$", stdout);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_output_is_structured_with_freeform_apply_patch() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|config| {
        config.include_apply_patch_tool = true;
    });
    let test = builder.build(&server).await?;

    let call_id = "shell-structured";
    let args = json!({
        "command": ["/bin/echo", "freeform shell"],
        "timeout_ms": 1_000,
    });
    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run the structured shell command",
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server
        .received_requests()
        .await
        .expect("recorded requests present");
    let bodies = request_bodies(&requests)?;
    let output_item =
        find_function_call_output(&bodies, call_id).expect("structured output present");
    let output = output_item
        .get("output")
        .and_then(Value::as_str)
        .expect("structured output string");

    assert!(
        serde_json::from_str::<Value>(output).is_err(),
        "expected structured shell output to be plain text",
    );
    let expected_pattern = r"(?s)^Exit code: 0
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Output:
freeform shell
?$";
    assert_regex_match(expected_pattern, output);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_output_reserializes_truncated_content() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|config| {
        config.model = "gpt-5-codex".to_string();
        config.model_family =
            find_family_for_model("gpt-5-codex").expect("gpt-5 is a model family");
    });
    let test = builder.build(&server).await?;

    let call_id = "shell-truncated";
    let args = json!({
        "command": ["/bin/sh", "-c", "seq 1 400"],
        "timeout_ms": 1_000,
    });
    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run the truncation shell command",
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server
        .received_requests()
        .await
        .expect("recorded requests present");
    let bodies = request_bodies(&requests)?;
    let output_item =
        find_function_call_output(&bodies, call_id).expect("truncated output present");
    let output = output_item
        .get("output")
        .and_then(Value::as_str)
        .expect("truncated output string");

    assert!(
        serde_json::from_str::<Value>(output).is_err(),
        "expected truncated shell output to be plain text",
    );
    let truncated_pattern = r#"(?s)^Exit code: 0
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Total output lines: 400
Output:
1
2
3
4
5
6
.*
\[\.{3} omitted \d+ of 400 lines \.{3}\]

.*
396
397
398
399
400
$"#;
    assert_regex_match(truncated_pattern, output);

    Ok(())
}
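For orientation, the shape of the Responses API input item that find_function_call_output matches in these tests (a hedged sketch; the field layout is inferred from the assertions in this file):

    let item = serde_json::json!({
        "type": "function_call_output",
        "call_id": "shell-json",
        // "output" is either a plain string or an object with a "content" field.
        "output": "{\"metadata\":{\"exit_code\":0},\"output\":\"shell json\\n\"}"
    });
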
@@ -13,7 +13,7 @@ use core_test_support::load_sse_fixture_with_id;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use tokio::time::timeout;
use core_test_support::wait_for_event_with_timeout;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::Request;
@@ -102,13 +102,10 @@ async fn retries_on_early_close() {
    .unwrap();

    // Wait until TaskComplete (should succeed after retry).
    loop {
        let ev = timeout(Duration::from_secs(10), codex.next_event())
            .await
            .unwrap()
            .unwrap();
        if matches!(ev.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }
    wait_for_event_with_timeout(
        &codex,
        |event| matches!(event, EventMsg::TaskComplete(_)),
        Duration::from_secs(10),
    )
    .await;
}

543
codex-rs/core/tests/suite/tool_harness.rs
Normal file
543
codex-rs/core/tests/suite/tool_harness.rs
Normal file
@@ -0,0 +1,543 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::plan_tool::StepStatus;
|
||||
use core_test_support::assert_regex_match;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_apply_patch_function_call;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_local_shell_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
fn function_call_output(body: &Value) -> Option<&Value> {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_output_text(item: &Value) -> Option<&str> {
|
||||
item.get("output").and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
Value::Object(obj) => obj.get("content").and_then(Value::as_str),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> {
|
||||
requests
|
||||
.iter()
|
||||
.find(|body| function_call_output(body).is_some())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a valid model");
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder.build(&server).await?;
|
||||
|
||||
let call_id = "shell-tool-call";
|
||||
let command = vec!["/bin/echo", "tool harness"];
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_local_shell_call(call_id, "completed", command),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "all done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "please run the shell command".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
|
||||
.expect("function_call_output item not found in requests");
|
||||
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
|
||||
let output_text = extract_output_text(output_item).expect("output text present");
|
||||
let exec_output: Value = serde_json::from_str(output_text)?;
|
||||
assert_eq!(exec_output["metadata"]["exit_code"], 0);
|
||||
let stdout = exec_output["output"].as_str().expect("stdout field");
|
||||
assert_regex_match(r"(?s)^tool harness\n?$", stdout);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.include_plan_tool = true;
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder.build(&server).await?;
|
||||
|
||||
let call_id = "plan-tool-call";
|
||||
let plan_args = json!({
|
||||
"explanation": "Tool harness check",
|
||||
"plan": [
|
||||
{"step": "Inspect workspace", "status": "in_progress"},
|
||||
{"step": "Report results", "status": "pending"},
|
||||
],
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "update_plan", &plan_args),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "plan acknowledged"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "please update the plan".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut saw_plan_update = false;
|
||||
wait_for_event(&codex, |event| match event {
|
||||
EventMsg::PlanUpdate(update) => {
|
||||
saw_plan_update = true;
|
||||
assert_eq!(update.explanation.as_deref(), Some("Tool harness check"));
|
||||
assert_eq!(update.plan.len(), 2);
|
||||
assert_eq!(update.plan[0].step, "Inspect workspace");
|
||||
assert_matches!(update.plan[0].status, StepStatus::InProgress);
|
||||
assert_eq!(update.plan[1].step, "Report results");
|
||||
assert_matches!(update.plan[1].status, StepStatus::Pending);
|
||||
false
|
||||
}
|
||||
EventMsg::TaskComplete(_) => true,
|
||||
_ => false,
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(saw_plan_update, "expected PlanUpdate event");
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
|
||||
.expect("function_call_output item not found in requests");
|
||||
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
|
||||
assert_eq!(
|
||||
output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
let output_text = extract_output_text(output_item).expect("output text present");
|
||||
assert_eq!(output_text, "Plan updated");
|
||||
|
||||
Ok(())
|
||||
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.include_plan_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let call_id = "plan-tool-invalid";
    let invalid_args = json!({
        "explanation": "Missing plan data"
    })
    .to_string();

    let first_response = sse(vec![
        ev_response_created("resp-1"),
        ev_function_call(call_id, "update_plan", &invalid_args),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "malformed plan payload"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please update the plan".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    let mut saw_plan_update = false;
    wait_for_event(&codex, |event| match event {
        EventMsg::PlanUpdate(_) => {
            saw_plan_update = true;
            false
        }
        EventMsg::TaskComplete(_) => true,
        _ => false,
    })
    .await;

    assert!(
        !saw_plan_update,
        "did not expect PlanUpdate event for malformed payload"
    );

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");
    assert!(
        output_text.contains("failed to parse function arguments"),
        "expected parse error message in output text, got {output_text:?}"
    );
    if let Some(success_flag) = output_item
        .get("output")
        .and_then(|value| value.as_object())
        .and_then(|obj| obj.get("success"))
        .and_then(serde_json::Value::as_bool)
    {
        assert!(
            !success_flag,
            "expected tool output to mark success=false for malformed payload"
        );
    }

    Ok(())
}

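// End-to-end check of the apply_patch tool: the model sends a patch adding
// notes.txt, and the harness should emit PatchApplyBegin/End events and write
// the file to disk.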
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.include_apply_patch_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let call_id = "apply-patch-call";
    let patch_content = r#"*** Begin Patch
*** Add File: notes.txt
+Tool harness apply patch
*** End Patch"#;

    let first_response = sse(vec![
        ev_response_created("resp-1"),
        ev_apply_patch_function_call(call_id, patch_content),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "patch complete"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please apply a patch".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    let mut saw_patch_begin = false;
    let mut patch_end_success = None;
    wait_for_event(&codex, |event| match event {
        EventMsg::PatchApplyBegin(begin) => {
            saw_patch_begin = true;
            assert_eq!(begin.call_id, call_id);
            false
        }
        EventMsg::PatchApplyEnd(end) => {
            assert_eq!(end.call_id, call_id);
            patch_end_success = Some(end.success);
            false
        }
        EventMsg::TaskComplete(_) => true,
        _ => false,
    })
    .await;

    assert!(saw_patch_begin, "expected PatchApplyBegin event");
    let patch_end_success =
        patch_end_success.expect("expected PatchApplyEnd event to capture success flag");

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");

    if let Ok(exec_output) = serde_json::from_str::<Value>(output_text) {
        let exit_code = exec_output["metadata"]["exit_code"]
            .as_i64()
            .expect("exit_code present");
        let summary = exec_output["output"].as_str().expect("output field");
        assert_eq!(
            exit_code, 0,
            "expected apply_patch exit_code=0, got {exit_code}, summary: {summary:?}"
        );
        assert!(
            patch_end_success,
            "expected PatchApplyEnd success flag, summary: {summary:?}"
        );
        assert!(
            summary.contains("Success."),
            "expected apply_patch summary to note success, got {summary:?}"
        );

        let patched_path = cwd.path().join("notes.txt");
        let contents = std::fs::read_to_string(&patched_path)
            .unwrap_or_else(|e| panic!("failed reading {}: {e}", patched_path.display()));
        assert_eq!(contents, "Tool harness apply patch\n");
    } else {
        assert!(
            output_text.contains("codex-run-as-apply-patch"),
            "expected apply_patch failure message to mention codex-run-as-apply-patch, got {output_text:?}"
        );
        assert!(
            !patch_end_success,
            "expected PatchApplyEnd to report success=false when apply_patch invocation fails"
        );
    }

    Ok(())
}

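// A structurally invalid patch (an Update File hunk with no changes) should
// fail verification and report the parse diagnostics in the tool output.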
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.include_apply_patch_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let call_id = "apply-patch-parse-error";
    let patch_content = r"*** Begin Patch
*** Update File: broken.txt
*** End Patch";

    let first_response = sse(vec![
        ev_response_created("resp-1"),
        ev_apply_patch_function_call(call_id, patch_content),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "failed"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please apply a patch".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");

    assert!(
        output_text.contains("apply_patch verification failed"),
        "expected apply_patch verification failure message, got {output_text:?}"
    );
    assert!(
        output_text.contains("invalid hunk"),
        "expected parse diagnostics in output text, got {output_text:?}"
    );

    if let Some(success_flag) = output_item
        .get("output")
        .and_then(|value| value.as_object())
        .and_then(|obj| obj.get("success"))
        .and_then(serde_json::Value::as_bool)
    {
        assert!(
            !success_flag,
            "expected tool output to mark success=false for parse failures"
        );
    }

    Ok(())
}
178
codex-rs/core/tests/suite/tool_parallelism.rs
Normal file
@@ -0,0 +1,178 @@
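// Timing-based checks for tool-call parallelism: parallel-capable tool calls
// issued in one turn should overlap, while shell calls (and mixed batches)
// must run one after another.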
#![cfg(not(target_os = "windows"))]
#![allow(clippy::unwrap_used)]

use std::time::Duration;
use std::time::Instant;

use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::json;

async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> {
    let session_model = test.session_configured.model.clone();

    test.codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: prompt.into(),
            }],
            final_output_json_schema: None,
            cwd: test.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    Ok(())
}

async fn run_turn_and_measure(test: &TestCodex, prompt: &str) -> anyhow::Result<Duration> {
    let start = Instant::now();
    run_turn(test, prompt).await?;
    Ok(start.elapsed())
}

#[allow(clippy::expect_used)]
async fn build_codex_with_test_tool(server: &wiremock::MockServer) -> anyhow::Result<TestCodex> {
    let mut builder = test_codex().with_config(|config| {
        config.model = "test-gpt-5-codex".to_string();
        config.model_family =
            find_family_for_model("test-gpt-5-codex").expect("test-gpt-5-codex model family");
    });
    builder.build(server).await
}

fn assert_parallel_duration(actual: Duration) {
    assert!(
        actual < Duration::from_millis(500),
        "expected parallel execution to finish quickly, got {actual:?}"
    );
}

fn assert_serial_duration(actual: Duration) {
    assert!(
        actual >= Duration::from_millis(500),
        "expected serial execution to take longer, got {actual:?}"
    );
}

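// Two test_sync_tool calls share a barrier and sleep 300ms each; if they run
// concurrently the whole turn finishes well under the 500ms serial threshold.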
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn read_file_tools_run_in_parallel() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let test = build_codex_with_test_tool(&server).await?;

    let parallel_args = json!({
        "sleep_after_ms": 300,
        "barrier": {
            "id": "parallel-test-sync",
            "participants": 2,
            "timeout_ms": 1_000,
        }
    })
    .to_string();

    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "test_sync_tool", &parallel_args),
        ev_function_call("call-2", "test_sync_tool", &parallel_args),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;

    let duration = run_turn_and_measure(&test, "exercise sync tool").await?;
    assert_parallel_duration(duration);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn non_parallel_tools_run_serially() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let test = test_codex().build(&server).await?;

    let shell_args = json!({
        "command": ["/bin/sh", "-c", "sleep 0.3"],
        "timeout_ms": 1_000,
    });
    let args_one = serde_json::to_string(&shell_args)?;
    let args_two = serde_json::to_string(&shell_args)?;

    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "shell", &args_one),
        ev_function_call("call-2", "shell", &args_two),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;

    let duration = run_turn_and_measure(&test, "run shell twice").await?;
    assert_serial_duration(duration);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mixed_tools_fall_back_to_serial() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let test = build_codex_with_test_tool(&server).await?;

    let sync_args = json!({
        "sleep_after_ms": 300
    })
    .to_string();
    let shell_args = serde_json::to_string(&json!({
        "command": ["/bin/sh", "-c", "sleep 0.3"],
        "timeout_ms": 1_000,
    }))?;

    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "test_sync_tool", &sync_args),
        ev_function_call("call-2", "shell", &shell_args),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;

    let duration = run_turn_and_measure(&test, "mix tools").await?;
    assert_serial_duration(duration);

    Ok(())
}
614
codex-rs/core/tests/suite/tools.rs
Normal file
@@ -0,0 +1,614 @@
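// Broader coverage of the tool-call surface: unknown custom tools, escalated
// shell permissions, local_shell edge cases, the unified_exec toggle, and
// timeout/sandbox/spawn-failure output formatting.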
#![cfg(not(target_os = "windows"))]
#![allow(clippy::unwrap_used, clippy::expect_used)]

use anyhow::Result;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::assert_regex_match;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_custom_tool_call;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use regex_lite::Regex;
use serde_json::Value;
use serde_json::json;
use wiremock::Request;

async fn submit_turn(
    test: &TestCodex,
    prompt: &str,
    approval_policy: AskForApproval,
    sandbox_policy: SandboxPolicy,
) -> Result<()> {
    let session_model = test.session_configured.model.clone();

    test.codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: prompt.into(),
            }],
            final_output_json_schema: None,
            cwd: test.cwd.path().to_path_buf(),
            approval_policy,
            sandbox_policy,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&test.codex, |event| {
        matches!(event, EventMsg::TaskComplete(_))
    })
    .await;

    Ok(())
}

fn request_bodies(requests: &[Request]) -> Result<Vec<Value>> {
    requests
        .iter()
        .map(|req| Ok(serde_json::from_slice::<Value>(&req.body)?))
        .collect()
}

fn collect_output_items<'a>(bodies: &'a [Value], ty: &str) -> Vec<&'a Value> {
    let mut out = Vec::new();
    for body in bodies {
        if let Some(items) = body.get("input").and_then(Value::as_array) {
            for item in items {
                if item.get("type").and_then(Value::as_str) == Some(ty) {
                    out.push(item);
                }
            }
        }
    }
    out
}

fn tool_names(body: &Value) -> Vec<String> {
    body.get("tools")
        .and_then(Value::as_array)
        .map(|tools| {
            tools
                .iter()
                .filter_map(|tool| {
                    tool.get("name")
                        .or_else(|| tool.get("type"))
                        .and_then(Value::as_str)
                        .map(str::to_string)
                })
                .collect()
        })
        .unwrap_or_default()
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex();
    let test = builder.build(&server).await?;

    let call_id = "custom-unsupported";
    let tool_name = "unsupported_tool";

    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_custom_tool_call(call_id, tool_name, "\"payload\""),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "invoke custom tool",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let custom_items = collect_output_items(&bodies, "custom_tool_call_output");
    assert_eq!(custom_items.len(), 1, "expected single custom tool output");
    let item = custom_items[0];
    assert_eq!(item.get("call_id").and_then(Value::as_str), Some(call_id));

    let output = item
        .get("output")
        .and_then(Value::as_str)
        .unwrap_or_default();
    let expected = format!("unsupported custom tool call: {tool_name}");
    assert_eq!(output, expected);

    Ok(())
}

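// With approval policy Never, a shell call requesting escalated permissions is
// rejected with an explanatory message; rerunning without escalation succeeds.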
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|config| {
        config.model = "gpt-5".to_string();
        config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a valid model");
    });
    let test = builder.build(&server).await?;

    let command = ["/bin/echo", "shell ok"];
    let call_id_blocked = "shell-blocked";
    let call_id_success = "shell-success";

    let first_args = json!({
        "command": command,
        "timeout_ms": 1_000,
        "with_escalated_permissions": true,
    });
    let second_args = json!({
        "command": command,
        "timeout_ms": 1_000,
    });

    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(
                call_id_blocked,
                "shell",
                &serde_json::to_string(&first_args)?,
            ),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_response_created("resp-2"),
            ev_function_call(
                call_id_success,
                "shell",
                &serde_json::to_string(&second_args)?,
            ),
            ev_completed("resp-2"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-3"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run the shell command",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let function_outputs = collect_output_items(&bodies, "function_call_output");
    for item in &function_outputs {
        let call_id = item
            .get("call_id")
            .and_then(Value::as_str)
            .unwrap_or_default();
        assert!(
            call_id == call_id_blocked || call_id == call_id_success,
            "unexpected call id {call_id}"
        );
    }

    let policy = AskForApproval::Never;
    let expected_message = format!(
        "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}"
    );

    let blocked_outputs: Vec<&Value> = function_outputs
        .iter()
        .filter(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_blocked))
        .copied()
        .collect();
    assert!(
        !blocked_outputs.is_empty(),
        "expected at least one rejection output for {call_id_blocked}"
    );
    for item in blocked_outputs {
        assert_eq!(
            item.get("output").and_then(Value::as_str),
            Some(expected_message.as_str()),
            "unexpected rejection message"
        );
    }

    let success_item = function_outputs
        .iter()
        .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_success))
        .expect("success output present");
    let output_json: Value = serde_json::from_str(
        success_item
            .get("output")
            .and_then(Value::as_str)
            .expect("success output string"),
    )?;
    assert_eq!(
        output_json["metadata"]["exit_code"].as_i64(),
        Some(0),
        "expected exit code 0 after rerunning without escalation",
    );
    let stdout = output_json["output"].as_str().unwrap_or_default();
    let stdout_pattern = r"(?s)^shell ok\n?$";
    assert_regex_match(stdout_pattern, stdout);

    Ok(())
}

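// A local_shell_call item arriving without a call_id or id should map to a
// function_call_output with an empty call_id and a descriptive error string.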
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex();
    let test = builder.build(&server).await?;

    let local_shell_event = json!({
        "type": "response.output_item.done",
        "item": {
            "type": "local_shell_call",
            "status": "completed",
            "action": {
                "type": "exec",
                "command": ["/bin/echo", "hi"],
            }
        }
    });

    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            local_shell_event,
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "check shell output",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let function_outputs = collect_output_items(&bodies, "function_call_output");
    assert_eq!(
        function_outputs.len(),
        1,
        "expected a single function output"
    );
    let item = function_outputs[0];
    assert_eq!(item.get("call_id").and_then(Value::as_str), Some(""));
    assert_eq!(
        item.get("output").and_then(Value::as_str),
        Some("LocalShellCall without call_id or id"),
    );

    Ok(())
}

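// Helper that runs a trivial turn just to capture the advertised tools list,
// used below to verify the experimental unified_exec toggle end to end.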
async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
    let server = start_mock_server().await;

    let responses = vec![sse(vec![
        ev_response_created("resp-1"),
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-1"),
    ])];
    mount_sse_sequence(&server, responses).await;

    let mut builder = test_codex().with_config(move |config| {
        config.use_experimental_unified_exec_tool = use_unified_exec;
    });
    let test = builder.build(&server).await?;

    submit_turn(
        &test,
        "list tools",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    assert_eq!(
        requests.len(),
        1,
        "expected a single request for tools collection"
    );
    let bodies = request_bodies(&requests)?;
    let first_body = bodies.first().expect("request body present");
    Ok(tool_names(first_body))
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_spec_toggle_end_to_end() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let tools_disabled = collect_tools(false).await?;
    assert!(
        !tools_disabled.iter().any(|name| name == "unified_exec"),
        "tools list should not include unified_exec when disabled: {tools_disabled:?}"
    );

    let tools_enabled = collect_tools(true).await?;
    assert!(
        tools_enabled.iter().any(|name| name == "unified_exec"),
        "tools list should include unified_exec when enabled: {tools_enabled:?}"
    );

    Ok(())
}

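// A command that exceeds its timeout_ms should come back either as structured
// JSON with exit code 124 and a timeout prefix, or as a signal-kill error.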
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|config| {
        config.model = "gpt-5".to_string();
        config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a valid model");
    });
    let test = builder.build(&server).await?;

    let call_id = "shell-timeout";
    let timeout_ms = 50u64;
    let args = json!({
        "command": ["/bin/sh", "-c", "yes line | head -n 400; sleep 1"],
        "timeout_ms": timeout_ms,
    });

    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run a long command",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let function_outputs = collect_output_items(&bodies, "function_call_output");
    let timeout_item = function_outputs
        .iter()
        .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id))
        .expect("timeout output present");

    let output_str = timeout_item
        .get("output")
        .and_then(Value::as_str)
        .expect("timeout output string");

    // The exec path can report a timeout in two ways depending on timing:
    // 1) Structured JSON with exit_code 124 and a timeout prefix (preferred), or
    // 2) A plain error string if the child is observed as killed by a signal first.
    if let Ok(output_json) = serde_json::from_str::<Value>(output_str) {
        assert_eq!(
            output_json["metadata"]["exit_code"].as_i64(),
            Some(124),
            "expected timeout exit code 124",
        );

        let stdout = output_json["output"].as_str().unwrap_or_default();
        let timeout_pattern = r"(?s)^Total output lines: \d+

command timed out after (?P<ms>\d+) milliseconds
line
.*$";
        let captures = assert_regex_match(timeout_pattern, stdout);
        let duration_ms = captures
            .name("ms")
            .and_then(|m| m.as_str().parse::<u64>().ok())
            .unwrap_or_default();
        assert!(
            duration_ms >= timeout_ms,
            "expected duration >= configured timeout, got {duration_ms} (timeout {timeout_ms})"
        );
    } else {
        // Fallback: accept the signal classification path to deflake the test.
        let signal_pattern = r"(?is)^execution error:.*signal.*$";
        assert_regex_match(signal_pattern, output_str);
    }

    Ok(())
}

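// In a read-only sandbox, a write attempt fails and the large stderr stream
// should be truncated with an "omitted ... lines" marker in the tool output.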
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_sandbox_denied_truncates_error_output() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex();
    let test = builder.build(&server).await?;

    let call_id = "shell-denied";
    let long_line = "this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz";
    let script = format!(
        "for i in $(seq 1 500); do >&2 echo '{long_line}'; done; cat <<'EOF' > denied.txt\ncontent\nEOF",
    );
    let args = json!({
        "command": ["/bin/sh", "-c", script],
        "timeout_ms": 1_000,
    });

    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "attempt to write in read-only sandbox",
        AskForApproval::Never,
        SandboxPolicy::ReadOnly,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let function_outputs = collect_output_items(&bodies, "function_call_output");
    let denied_item = function_outputs
        .iter()
        .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id))
        .expect("denied output present");

    let output = denied_item
        .get("output")
        .and_then(Value::as_str)
        .expect("denied output string");

    let sandbox_pattern = r#"(?s)^Exit code: -?\d+
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Total output lines: \d+
Output:

failed in sandbox: .*?(?:Operation not permitted|Permission denied|Read-only file system).*?
\[\.{3} omitted \d+ of \d+ lines \.{3}\]
.*this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz.*
\n?$"#;
    let sandbox_regex = Regex::new(sandbox_pattern)?;
    if !sandbox_regex.is_match(output) {
        let fallback_pattern = r#"(?s)^Total output lines: \d+

failed in sandbox: this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz
.*this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz.*
.*(?:Operation not permitted|Permission denied|Read-only file system).*$"#;
        assert_regex_match(fallback_pattern, output);
    }

    Ok(())
}

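// Spawning a nonexistent binary (with an oversized path) should surface an
// execution error that is truncated to stay within the output size cap.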
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_spawn_failure_truncates_exec_error() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|cfg| {
        cfg.sandbox_policy = SandboxPolicy::DangerFullAccess;
    });
    let test = builder.build(&server).await?;

    let call_id = "shell-spawn-failure";
    let bogus_component = "missing-bin-".repeat(700);
    let bogus_exe = test
        .cwd
        .path()
        .join(bogus_component)
        .to_string_lossy()
        .into_owned();

    let args = json!({
        "command": [bogus_exe],
        "timeout_ms": 1_000,
    });

    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "spawn a missing binary",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let function_outputs = collect_output_items(&bodies, "function_call_output");
    let failure_item = function_outputs
        .iter()
        .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id))
        .expect("spawn failure output present");

    let output = failure_item
        .get("output")
        .and_then(Value::as_str)
        .expect("spawn failure output string");

    let spawn_error_pattern = r#"(?s)^Exit code: -?\d+
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Output:
execution error: .*$"#;
    let spawn_truncated_pattern = r#"(?s)^Exit code: -?\d+
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Total output lines: \d+
Output:

execution error: .*$"#;
    let spawn_error_regex = Regex::new(spawn_error_pattern)?;
    let spawn_truncated_regex = Regex::new(spawn_truncated_pattern)?;
    if !spawn_error_regex.is_match(output) && !spawn_truncated_regex.is_match(output) {
        let fallback_pattern = r"(?s)^execution error: .*$";
        assert_regex_match(fallback_pattern, output);
    }
    assert!(output.len() <= 10 * 1024);

    Ok(())
}
277
codex-rs/core/tests/suite/unified_exec.rs
Normal file
@@ -0,0 +1,277 @@
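// Session-reuse behaviour of the experimental unified_exec tool: a first call
// opens an interactive session, and follow-up calls write to its stdin or
// poll for pending output.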
#![cfg(not(target_os = "windows"))]

use std::collections::HashMap;

use anyhow::Result;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::skip_if_sandbox;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::Value;

fn extract_output_text(item: &Value) -> Option<&str> {
    item.get("output").and_then(|value| match value {
        Value::String(text) => Some(text.as_str()),
        Value::Object(obj) => obj.get("content").and_then(Value::as_str),
        _ => None,
    })
}

fn collect_tool_outputs(bodies: &[Value]) -> Result<HashMap<String, Value>> {
    let mut outputs = HashMap::new();
    for body in bodies {
        if let Some(items) = body.get("input").and_then(Value::as_array) {
            for item in items {
                if item.get("type").and_then(Value::as_str) != Some("function_call_output") {
                    continue;
                }
                if let Some(call_id) = item.get("call_id").and_then(Value::as_str) {
                    let content = extract_output_text(item)
                        .ok_or_else(|| anyhow::anyhow!("missing tool output content"))?;
                    let parsed: Value = serde_json::from_str(content)?;
                    outputs.insert(call_id.to_string(), parsed);
                }
            }
        }
    }
    Ok(outputs)
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
    skip_if_no_network!(Ok(()));
    skip_if_sandbox!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.use_experimental_unified_exec_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let first_call_id = "uexec-start";
    let first_args = serde_json::json!({
        "input": ["/bin/cat"],
        "timeout_ms": 200,
    });

    let second_call_id = "uexec-stdin";
    let second_args = serde_json::json!({
        "input": ["hello unified exec\n"],
        "session_id": "0",
        "timeout_ms": 500,
    });

    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(
                first_call_id,
                "unified_exec",
                &serde_json::to_string(&first_args)?,
            ),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_response_created("resp-2"),
            ev_function_call(
                second_call_id,
                "unified_exec",
                &serde_json::to_string(&second_args)?,
            ),
            ev_completed("resp-2"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "all done"),
            ev_completed("resp-3"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "run unified exec".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let outputs = collect_tool_outputs(&bodies)?;

    let start_output = outputs
        .get(first_call_id)
        .expect("missing first unified_exec output");
    let session_id = start_output["session_id"].as_str().unwrap_or_default();
    assert!(
        !session_id.is_empty(),
        "expected session id in first unified_exec response"
    );
    assert!(
        start_output["output"]
            .as_str()
            .unwrap_or_default()
            .is_empty()
    );

    let reuse_output = outputs
        .get(second_call_id)
        .expect("missing reused unified_exec output");
    assert_eq!(
        reuse_output["session_id"].as_str().unwrap_or_default(),
        session_id
    );
    let echoed = reuse_output["output"].as_str().unwrap_or_default();
    assert!(
        echoed.contains("hello unified exec"),
        "expected echoed output, got {echoed:?}"
    );

    Ok(())
}

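// A command that outlives its timeout returns an empty first result; an
// empty-input follow-up call against the same session then polls the output.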
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
    skip_if_no_network!(Ok(()));
    skip_if_sandbox!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.use_experimental_unified_exec_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let first_call_id = "uexec-timeout";
    let first_args = serde_json::json!({
        "input": ["/bin/sh", "-c", "sleep 0.1; echo ready"],
        "timeout_ms": 10,
    });

    let second_call_id = "uexec-poll";
    let second_args = serde_json::json!({
        "input": Vec::<String>::new(),
        "session_id": "0",
        "timeout_ms": 800,
    });

    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(
                first_call_id,
                "unified_exec",
                &serde_json::to_string(&first_args)?,
            ),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_response_created("resp-2"),
            ev_function_call(
                second_call_id,
                "unified_exec",
                &serde_json::to_string(&second_args)?,
            ),
            ev_completed("resp-2"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-3"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "check timeout".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let outputs = collect_tool_outputs(&bodies)?;

    let first_output = outputs.get(first_call_id).expect("missing timeout output");
    assert_eq!(first_output["session_id"], "0");
    assert!(
        first_output["output"]
            .as_str()
            .unwrap_or_default()
            .is_empty()
    );

    let poll_output = outputs.get(second_call_id).expect("missing poll output");
    let output_text = poll_output["output"].as_str().unwrap_or_default();
    assert!(
        output_text.contains("ready"),
        "expected ready output, got {output_text:?}"
    );

    Ok(())
}
338
codex-rs/core/tests/suite/view_image.rs
Normal file
@@ -0,0 +1,338 @@
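// Tests for the view_image tool: attaching a local image injects a base64
// data URL into the next request, while directories and missing files produce
// tool errors and no input_image message.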
#![cfg(not(target_os = "windows"))]

use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::Value;
use wiremock::matchers::any;

fn function_call_output(body: &Value) -> Option<&Value> {
    body.get("input")
        .and_then(Value::as_array)
        .and_then(|items| {
            items.iter().find(|item| {
                item.get("type").and_then(Value::as_str) == Some("function_call_output")
            })
        })
}

fn find_image_message(body: &Value) -> Option<&Value> {
    body.get("input")
        .and_then(Value::as_array)
        .and_then(|items| {
            items.iter().find(|item| {
                item.get("type").and_then(Value::as_str) == Some("message")
                    && item
                        .get("content")
                        .and_then(Value::as_array)
                        .map(|content| {
                            content.iter().any(|span| {
                                span.get("type").and_then(Value::as_str) == Some("input_image")
                            })
                        })
                        .unwrap_or(false)
            })
        })
}

fn extract_output_text(item: &Value) -> Option<&str> {
    item.get("output").and_then(|value| match value {
        Value::String(text) => Some(text.as_str()),
        Value::Object(obj) => obj.get("content").and_then(Value::as_str),
        _ => None,
    })
}

fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> {
    requests
        .iter()
        .find(|body| function_call_output(body).is_some())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = test_codex().build(&server).await?;

    let rel_path = "assets/example.png";
    let abs_path = cwd.path().join(rel_path);
    if let Some(parent) = abs_path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let image_bytes = b"fake_png_bytes".to_vec();
    std::fs::write(&abs_path, &image_bytes)?;

    let call_id = "view-image-call";
    let arguments = serde_json::json!({ "path": rel_path }).to_string();

    let first_response = sse(vec![
        ev_response_created("resp-1"),
        ev_function_call(call_id, "view_image", &arguments),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please add the screenshot".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    let mut tool_event = None;
    wait_for_event(&codex, |event| match event {
        EventMsg::ViewImageToolCall(_) => {
            tool_event = Some(event.clone());
            false
        }
        EventMsg::TaskComplete(_) => true,
        _ => false,
    })
    .await;

    let tool_event = match tool_event.expect("view image tool event emitted") {
        EventMsg::ViewImageToolCall(event) => event,
        _ => unreachable!("stored event must be ViewImageToolCall"),
    };
    assert_eq!(tool_event.call_id, call_id);
    assert_eq!(tool_event.path, abs_path);

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(
        requests.len() >= 2,
        "expected at least two POST requests, got {}",
        requests.len()
    );
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    assert_eq!(output_text, "attached local image path");

    let image_message = find_image_message(body_with_tool_output)
        .expect("pending input image message not included in request");
    let image_url = image_message
        .get("content")
        .and_then(Value::as_array)
        .and_then(|content| {
            content.iter().find_map(|span| {
                if span.get("type").and_then(Value::as_str) == Some("input_image") {
                    span.get("image_url").and_then(Value::as_str)
                } else {
                    None
                }
            })
        })
        .expect("image_url present");

    let expected_image_url = format!(
        "data:image/png;base64,{}",
        BASE64_STANDARD.encode(&image_bytes)
    );
    assert_eq!(image_url, expected_image_url);

    Ok(())
}

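// Pointing view_image at a directory should yield a "not a file" tool error
// and must not attach an input_image message.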
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = test_codex().build(&server).await?;

    let rel_path = "assets";
    let abs_path = cwd.path().join(rel_path);
    std::fs::create_dir_all(&abs_path)?;

    let call_id = "view-image-directory";
    let arguments = serde_json::json!({ "path": rel_path }).to_string();

    let first_response = sse(vec![
        ev_response_created("resp-1"),
        ev_function_call(call_id, "view_image", &arguments),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please attach the folder".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(
        requests.len() >= 2,
        "expected at least two POST requests, got {}",
        requests.len()
    );
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    let expected_message = format!("image path `{}` is not a file", abs_path.display());
    assert_eq!(output_text, expected_message);

    assert!(
        find_image_message(body_with_tool_output).is_none(),
        "directory path should not produce an input_image message"
    );

    Ok(())
}

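// A missing file should produce an "unable to locate image" error prefixed
// with the resolved absolute path.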
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = test_codex().build(&server).await?;

    let rel_path = "missing/example.png";
    let abs_path = cwd.path().join(rel_path);

    let call_id = "view-image-missing";
    let arguments = serde_json::json!({ "path": rel_path }).to_string();

    let first_response = sse(vec![
        ev_response_created("resp-1"),
        ev_function_call(call_id, "view_image", &arguments),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please attach the missing image".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(
        requests.len() >= 2,
        "expected at least two POST requests, got {}",
        requests.len()
    );
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    let expected_prefix = format!("unable to locate image at `{}`:", abs_path.display());
    assert!(
        output_text.starts_with(&expected_prefix),
        "expected error to start with `{expected_prefix}` but got `{output_text}`"
    );

    assert!(
        find_image_message(body_with_tool_output).is_none(),
        "missing file should not produce an input_image message"
    );

    Ok(())
}
Some files were not shown because too many files have changed in this diff.