Mirror of https://github.com/openai/codex.git
Synced 2026-02-01 22:47:52 +00:00

Compare commits: add-admin-... → pr4925 (62 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | c986aeb0c1 |  |
|  | a43ae86b6c |  |
|  | 496cb801e1 |  |
|  | abd517091f |  |
|  | b8b04514bc |  |
|  | 0e5d72cc57 |  |
|  | 60f9e85c16 |  |
|  | b016a3e7d8 |  |
|  | a0d56541cf |  |
|  | 226215f36d |  |
|  | 338c2c873c |  |
|  | 4b0f5eb6a8 |  |
|  | 75176dae70 |  |
|  | 12fd2b4160 |  |
|  | f2555422b9 |  |
|  | 27f169bb91 |  |
|  | b16c985ed2 |  |
|  | 35a770e871 |  |
|  | b09f62a1c3 |  |
|  | 5833508a17 |  |
|  | d73055c5b1 |  |
|  | 7e3a272b29 |  |
|  | 661663c98a |  |
|  | 721003c552 |  |
|  | 36f1cca1b1 |  |
|  | d3e1beb26c |  |
|  | c264ae6021 |  |
|  | 8cd882c4bd |  |
|  | 90fe5e4a7e |  |
|  | a90a58f7a1 |  |
|  | b2d81a7cac |  |
|  | 77a8b7fdeb |  |
|  | 7fa5e95c1f |  |
|  | 191d620707 |  |
|  | 53504a38d2 |  |
|  | 5c42419b02 |  |
|  | aecbe0f333 |  |
|  | a30a902db5 |  |
|  | f3b4a26f32 |  |
|  | dc3c6bf62a |  |
|  | 3203862167 |  |
|  | 06853d94f0 |  |
|  | cc2f4aafd7 |  |
|  | 356ea6ea34 |  |
|  | 4764fc1ee7 |  |
|  | 90ef94d3b3 |  |
|  | 6c2969d22d |  |
|  | 0ad1b0782b |  |
|  | d7acd146fb |  |
|  | c5465aed60 |  |
|  | a95605a867 |  |
|  | 848058f05b |  |
|  | a4f1c9d67e |  |
|  | 665341c9b1 |  |
|  | fae0e6c52c |  |
|  | 1b4a79f03c |  |
|  | 640192ac3d |  |
|  | 205c36e393 |  |
|  | d13ee79c41 |  |
|  | bde468ff8d |  |
|  | e292d1ed21 |  |
|  | de8d77274a |  |
.github/ISSUE_TEMPLATE/4-feature-request.yml (6 changes, vendored)

```diff
@@ -2,7 +2,6 @@ name: 🎁 Feature Request
 description: Propose a new feature for Codex
 labels:
   - enhancement
-  - needs triage
 body:
   - type: markdown
     attributes:
@@ -19,11 +18,6 @@ body:
         label: What feature would you like to see?
       validations:
         required: true
-  - type: textarea
-    id: author
-    attributes:
-      label: Are you interested in implementing this feature?
-      description: Please wait for acknowledgement before implementing or opening a PR.
   - type: textarea
     id: notes
     attributes:
```
.github/workflows/issue-deduplicator.yml (30 changes, vendored)

```diff
@@ -14,7 +14,7 @@ jobs:
     permissions:
       contents: read
     outputs:
-      codex_output: ${{ steps.codex.outputs.final_message }}
+      codex_output: ${{ steps.codex.outputs.final-message }}
     steps:
       - uses: actions/checkout@v4

@@ -44,8 +44,8 @@ jobs:
       - id: codex
         uses: openai/codex-action@main
         with:
-          openai_api_key: ${{ secrets.CODEX_OPENAI_API_KEY }}
-          require_repo_write: false
+          openai-api-key: ${{ secrets.CODEX_OPENAI_API_KEY }}
+          allow-users: "*"
           model: gpt-5
           prompt: |
             You are an assistant that triages new GitHub issues by identifying potential duplicates.
@@ -55,12 +55,13 @@ jobs:
             - `codex-existing-issues.json`: JSON array of recent issues (each element includes number, title, body, createdAt).

             Instructions:
             - Load both files as JSON and review their contents carefully. The codex-existing-issues.json file is large, ensure you explore all of it.
             - Compare the current issue against the existing issues to find up to five that appear to describe the same underlying problem or request.
             - Focus on the underlying intent and context of each issue—such as reported symptoms, feature requests, reproduction steps, or error messages—rather than relying solely on string similarity or synthetic metrics.
+            - After your analysis, validate your results in 1-2 lines explaining your decision to return the selected matches.
             - When unsure, prefer returning fewer matches.
             - Include at most five numbers.

-          output_schema: |
+          output-schema: |
             {
@@ -69,9 +70,10 @@ jobs:
                 "items": {
                   "type": "string"
                 }
               }
             },
+            "reason": { "type": "string" }
           },
-          "required": ["issues"],
+          "required": ["issues", "reason"],
           "additionalProperties": false
           }

@@ -102,14 +104,22 @@ jobs:
             }

             const issues = Array.isArray(parsed?.issues) ? parsed.issues : [];
-            if (issues.length === 0) {
+            const currentIssueNumber = String(context.payload.issue.number);
+
+            console.log(`Current issue number: ${currentIssueNumber}`);
+            console.log(issues);
+
+            const filteredIssues = issues.filter((value) => String(value) !== currentIssueNumber);
+
+            if (filteredIssues.length === 0) {
               core.info('Codex reported no potential duplicates.');
               return;
             }

             const lines = [
-              'Potential duplicates detected:',
-              ...issues.map((value) => `- #${String(value)}`),
+              'Potential duplicates detected. Please review them and close your issue if it is a duplicate.',
+              '',
+              ...filteredIssues.map((value) => `- #${String(value)}`),
+              '',
               '*Powered by [Codex Action](https://github.com/openai/codex-action)*'];
```
.github/workflows/issue-labeler.yml (8 changes, vendored)

```diff
@@ -14,15 +14,15 @@ jobs:
     permissions:
       contents: read
     outputs:
-      codex_output: ${{ steps.codex.outputs.final_message }}
+      codex_output: ${{ steps.codex.outputs.final-message }}
     steps:
       - uses: actions/checkout@v4

       - id: codex
         uses: openai/codex-action@main
         with:
-          openai_api_key: ${{ secrets.CODEX_OPENAI_API_KEY }}
-          require_repo_write: false
+          openai-api-key: ${{ secrets.CODEX_OPENAI_API_KEY }}
+          allow-users: "*"
           prompt: |
             You are an assistant that reviews GitHub issues for the repository.

@@ -53,7 +53,7 @@ jobs:
             Repository full name:
             ${{ github.repository }}

-          output_schema: |
+          output-schema: |
             {
               "type": "object",
               "properties": {
```
AGENTS.md (25 changes)

````diff
@@ -73,3 +73,28 @@ If you don’t have the tool:
 ### Test assertions

 - Tests should use pretty_assertions::assert_eq for clearer diffs. Import this at the top of the test module if it isn't already.
+
+### Integration tests (core)
+
+- Prefer the utilities in `core_test_support::responses` when writing end-to-end Codex tests.
+- All `mount_sse*` helpers return a `ResponseMock`; hold onto it so you can assert against outbound `/responses` POST bodies.
+- Use `ResponseMock::single_request()` when a test should only issue one POST, or `ResponseMock::requests()` to inspect every captured `ResponsesRequest`.
+- `ResponsesRequest` exposes helpers (`body_json`, `input`, `function_call_output`, `custom_tool_call_output`, `call_output`, `header`, `path`, `query_param`) so assertions can target structured payloads instead of manual JSON digging.
+- Build SSE payloads with the provided `ev_*` constructors and the `sse(...)` helper.
+- Typical pattern:
+
+```rust
+let mock = responses::mount_sse_once(&server, responses::sse(vec![
+    responses::ev_response_created("resp-1"),
+    responses::ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
+    responses::ev_completed("resp-1"),
+])).await;
+
+codex.submit(Op::UserTurn { ... }).await?;
+
+// Assert request body if needed.
+let request = mock.single_request();
+// assert using request.function_call_output(call_id) or request.json_body() or other helpers.
+```
````
README.md

```diff
@@ -61,7 +61,7 @@ You can also use Codex with an API key, but this requires [additional setup](./d

 ### Model Context Protocol (MCP)

-Codex CLI supports [MCP servers](./docs/advanced.md#model-context-protocol-mcp). Enable by adding an `mcp_servers` section to your `~/.codex/config.toml`.
+Codex can access MCP servers. To configure them, refer to the [config docs](./docs/config.md#mcp_servers).

 ### Configuration

@@ -81,9 +81,11 @@ Codex CLI supports a rich set of configuration options, with preferences stored
 - [**Authentication**](./docs/authentication.md)
   - [Auth methods](./docs/authentication.md#forcing-a-specific-auth-method-advanced)
   - [Login on a "Headless" machine](./docs/authentication.md#connecting-on-a-headless-machine)
-- [**Non-interactive mode**](./docs/exec.md)
+- **Automating Codex**
+  - [GitHub Action](https://github.com/openai/codex-action)
+  - [TypeScript SDK](./sdk/typescript/README.md)
+  - [Non-interactive mode (`codex exec`)](./docs/exec.md)
 - [**Advanced**](./docs/advanced.md)
   - [Non-interactive / CI mode](./docs/advanced.md#non-interactive--ci-mode)
   - [Tracing / verbose logging](./docs/advanced.md#tracing--verbose-logging)
   - [Model Context Protocol (MCP)](./docs/advanced.md#model-context-protocol-mcp)
 - [**Zero data retention (ZDR)**](./docs/zdr.md)
```
codex-rs/Cargo.lock (54 changes, generated)

```diff
@@ -300,6 +300,12 @@ dependencies = [
  "wait-timeout",
 ]

+[[package]]
+name = "assert_matches"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9"
+
 [[package]]
 name = "async-broadcast"
 version = "0.7.2"
@@ -871,6 +877,7 @@ version = "0.0.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
+ "assert_matches",
  "pretty_assertions",
  "similar",
  "tempfile",
@@ -933,6 +940,7 @@ version = "0.0.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
+ "assert_matches",
  "clap",
  "clap_complete",
  "codex-app-server",
@@ -980,12 +988,11 @@ dependencies = [
  "reqwest",
  "serde",
  "serde_json",
- "throbber-widgets-tui",
  "tokio",
  "tokio-stream",
  "tracing",
  "tracing-subscriber",
- "unicode-width 0.1.14",
+ "unicode-width 0.2.1",
 ]

 [[package]]
@@ -1022,6 +1029,7 @@ dependencies = [
  "anyhow",
  "askama",
  "assert_cmd",
+ "assert_matches",
  "async-channel",
  "async-trait",
  "base64",
@@ -1162,6 +1170,7 @@ dependencies = [
 name = "codex-git-tooling"
 version = "0.0.0"
 dependencies = [
+ "assert_matches",
  "pretty_assertions",
  "tempfile",
  "thiserror 2.0.16",
@@ -1249,6 +1258,7 @@ dependencies = [
 name = "codex-ollama"
 version = "0.0.0"
 dependencies = [
+ "assert_matches",
  "async-stream",
  "bytes",
  "codex-core",
@@ -1367,6 +1377,7 @@ version = "0.0.0"
 dependencies = [
  "anyhow",
  "arboard",
+ "assert_matches",
  "async-stream",
  "base64",
  "chrono",
@@ -1413,6 +1424,8 @@ dependencies = [
  "tracing",
  "tracing-appender",
  "tracing-subscriber",
+ "tree-sitter-bash",
+ "tree-sitter-highlight",
  "unicode-segmentation",
  "unicode-width 0.2.1",
  "url",
@@ -1432,6 +1445,7 @@ dependencies = [
 name = "codex-utils-readiness"
 version = "0.0.0"
 dependencies = [
+ "assert_matches",
  "async-trait",
  "thiserror 2.0.16",
  "time",
@@ -1562,6 +1576,7 @@ dependencies = [
  "anyhow",
  "assert_cmd",
  "codex-core",
  "regex-lite",
  "serde_json",
  "tempfile",
  "tokio",
@@ -4757,9 +4772,9 @@ dependencies = [

 [[package]]
 name = "rmcp"
-version = "0.7.0"
+version = "0.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "534fd1cd0601e798ac30545ff2b7f4a62c6f14edd4aaed1cc5eb1e85f69f09af"
+checksum = "6f35acda8f89fca5fd8c96cae3c6d5b4c38ea0072df4c8030915f3b5ff469c1c"
 dependencies = [
  "base64",
  "bytes",
@@ -4791,9 +4806,9 @@ dependencies = [

 [[package]]
 name = "rmcp-macros"
-version = "0.7.0"
+version = "0.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9ba777eb0e5f53a757e36f0e287441da0ab766564ba7201600eeb92a4753022e"
+checksum = "c9f1d5220aaa23b79c3d02e18f7a554403b3ccea544bbb6c69d6bcb3e854a274"
 dependencies = [
  "darling 0.21.3",
  "proc-macro2",
@@ -5829,16 +5844,6 @@ dependencies = [
  "cfg-if",
 ]

-[[package]]
-name = "throbber-widgets-tui"
-version = "0.8.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1d36b5738d666a2b4c91b7c24998a8588db724b3107258343ebf8824bf55b06d"
-dependencies = [
- "rand 0.8.5",
- "ratatui",
-]
-
 [[package]]
 name = "tiff"
 version = "0.10.3"
@@ -6016,6 +6021,7 @@ dependencies = [
  "bytes",
  "futures-core",
  "futures-sink",
+ "futures-util",
  "pin-project-lite",
  "tokio",
 ]
@@ -6257,9 +6263,9 @@ dependencies = [

 [[package]]
 name = "tree-sitter"
-version = "0.25.9"
+version = "0.25.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ccd2a058a86cfece0bf96f7cce1021efef9c8ed0e892ab74639173e5ed7a34fa"
+checksum = "78f873475d258561b06f1c595d93308a7ed124d9977cb26b148c2084a4a3cc87"
 dependencies = [
  "cc",
  "regex",
@@ -6279,6 +6285,18 @@ dependencies = [
  "tree-sitter-language",
 ]

+[[package]]
+name = "tree-sitter-highlight"
+version = "0.25.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adc5f880ad8d8f94e88cb81c3557024cf1a8b75e3b504c50481ed4f5a6006ff3"
+dependencies = [
+ "regex",
+ "streaming-iterator",
+ "thiserror 2.0.16",
+ "tree-sitter",
+]
+
 [[package]]
 name = "tree-sitter-language"
 version = "0.1.5"
```
codex-rs/Cargo.toml

```diff
@@ -83,6 +83,7 @@ ansi-to-tui = "7.0.0"
 anyhow = "1"
 arboard = "3"
 askama = "0.12"
+assert_matches = "1.5.0"
 assert_cmd = "2"
 async-channel = "2.3.1"
 async-stream = "0.3.6"
@@ -142,7 +143,7 @@ rand = "0.9"
 ratatui = "0.29.0"
 regex-lite = "0.1.7"
 reqwest = "0.12"
-rmcp = { version = "0.7.0", default-features = false }
+rmcp = { version = "0.8.0", default-features = false }
 schemars = "0.8.22"
 seccompiler = "0.5.0"
 serde = "1"
@@ -174,8 +175,9 @@ tracing = "0.1.41"
 tracing-appender = "0.2.3"
 tracing-subscriber = "0.3.20"
 tracing-test = "0.2.5"
-tree-sitter = "0.25.9"
-tree-sitter-bash = "0.25.0"
+tree-sitter = "0.25.10"
+tree-sitter-bash = "0.25"
+tree-sitter-highlight = "0.25.10"
 ts-rs = "11"
 unicode-segmentation = "1.12.0"
 unicode-width = "0.2"
@@ -243,5 +245,9 @@ strip = "symbols"
 codegen-units = 1

 [patch.crates-io]
 # Uncomment to debug local changes.
 # ratatui = { path = "../../ratatui" }
+ratatui = { git = "https://github.com/nornagon/ratatui", branch = "nornagon-v0.29.0-patch" }
+
+# Uncomment to debug local changes.
+# rmcp = { path = "../../rust-sdk/crates/rmcp" }
```
codex-rs/README.md

````diff
@@ -23,9 +23,15 @@ Codex supports a rich set of configuration options. Note that the Rust CLI uses

 ### Model Context Protocol Support

-Codex CLI functions as an MCP client that can connect to MCP servers on startup. See the [`mcp_servers`](../docs/config.md#mcp_servers) section in the configuration documentation for details.
+#### MCP client
+
+Codex CLI functions as an MCP client that allows the Codex CLI and IDE extension to connect to MCP servers on startup. See the [`configuration documentation`](../docs/config.md#mcp_servers) for details.

-It is still experimental, but you can also launch Codex as an MCP _server_ by running `codex mcp-server`. Use the [`@modelcontextprotocol/inspector`](https://github.com/modelcontextprotocol/inspector) to try it out:
+#### MCP server (experimental)
+
+Codex can be launched as an MCP _server_ by running `codex mcp-server`. This allows _other_ MCP clients to use Codex as a tool for another agent.
+
+Use the [`@modelcontextprotocol/inspector`](https://github.com/modelcontextprotocol/inspector) to try it out:

 ```shell
 npx @modelcontextprotocol/inspector codex mcp-server
@@ -71,9 +77,13 @@ To test to see what happens when a command is run under the sandbox provided by

 ```
 # macOS
-codex debug seatbelt [--full-auto] [COMMAND]...
+codex sandbox macos [--full-auto] [COMMAND]...

 # Linux
-codex debug landlock [--full-auto] [COMMAND]...
+codex sandbox linux [--full-auto] [COMMAND]...
+
+# Legacy aliases
+codex debug seatbelt [--full-auto] [COMMAND]...
+codex debug landlock [--full-auto] [COMMAND]...
 ```
````
codex-rs/apply-patch/Cargo.toml

```diff
@@ -23,5 +23,6 @@ tree-sitter-bash = { workspace = true }

 [dev-dependencies]
 assert_cmd = { workspace = true }
+assert_matches = { workspace = true }
 pretty_assertions = { workspace = true }
 tempfile = { workspace = true }
```
codex-rs/apply-patch/src/lib.rs

```diff
@@ -843,6 +843,7 @@ pub fn print_summary(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use assert_matches::assert_matches;
     use pretty_assertions::assert_eq;
     use std::fs;
     use std::string::ToString;
@@ -894,10 +895,10 @@ mod tests {

     fn assert_not_match(script: &str) {
         let args = args_bash(script);
-        assert!(matches!(
+        assert_matches!(
             maybe_parse_apply_patch(&args),
             MaybeApplyPatch::NotApplyPatch
-        ));
+        );
     }

     #[test]
@@ -905,10 +906,10 @@ mod tests {
         let patch = "*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch".to_string();
         let args = vec![patch];
         let dir = tempdir().unwrap();
-        assert!(matches!(
+        assert_matches!(
             maybe_parse_apply_patch_verified(&args, dir.path()),
             MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation)
-        ));
+        );
     }

     #[test]
@@ -916,10 +917,10 @@ mod tests {
         let script = "*** Begin Patch\n*** Add File: foo\n+hi\n*** End Patch";
         let args = args_bash(script);
         let dir = tempdir().unwrap();
-        assert!(matches!(
+        assert_matches!(
             maybe_parse_apply_patch_verified(&args, dir.path()),
             MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation)
-        ));
+        );
     }

     #[test]
```
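The conversions in this file follow one mechanical rule: `assert!(matches!(value, Pattern))` becomes `assert_matches!(value, Pattern)`, using the `assert_matches` crate that the manifests above add as a dev-dependency. The payoff is the failure output: `assert_matches!` prints the actual value when the pattern does not match, whereas `assert!(matches!(...))` only reports that a boolean was false. A minimal sketch with a hypothetical enum:

```rust
use assert_matches::assert_matches;

#[derive(Debug)]
enum Parsed {
    Patch(String),
    NotApplyPatch,
}

fn main() {
    let value = Parsed::Patch("hello".to_string());

    // Old style: on failure this prints only
    // "assertion failed: matches!(value, Parsed::NotApplyPatch)".
    assert!(matches!(value, Parsed::Patch(_)));

    // New style: on failure this also prints the actual value,
    // e.g. `Patch("hello")`, which makes the mismatch obvious.
    assert_matches!(value, Parsed::Patch(_));
}
```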
codex-rs/cli/Cargo.toml

```diff
@@ -47,6 +47,7 @@ tokio = { workspace = true, features = [
 ] }

 [dev-dependencies]
+assert_matches = { workspace = true }
 assert_cmd = { workspace = true }
 predicates = { workspace = true }
 pretty_assertions = { workspace = true }
```
codex-rs/cli/src/main.rs

```diff
@@ -76,8 +76,9 @@ enum Subcommand {
     /// Generate shell completion scripts.
     Completion(CompletionCommand),

-    /// Internal debugging commands.
-    Debug(DebugArgs),
+    /// Run commands within a Codex-provided sandbox.
+    #[clap(visible_alias = "debug")]
+    Sandbox(SandboxArgs),

     /// Apply the latest diff produced by Codex agent as a `git apply` to your local working tree.
     #[clap(visible_alias = "a")]
@@ -121,18 +122,20 @@ struct ResumeCommand {
 }

 #[derive(Debug, Parser)]
-struct DebugArgs {
+struct SandboxArgs {
     #[command(subcommand)]
-    cmd: DebugCommand,
+    cmd: SandboxCommand,
 }

 #[derive(Debug, clap::Subcommand)]
-enum DebugCommand {
+enum SandboxCommand {
     /// Run a command under Seatbelt (macOS only).
-    Seatbelt(SeatbeltCommand),
+    #[clap(visible_alias = "seatbelt")]
+    Macos(SeatbeltCommand),

     /// Run a command under Landlock+seccomp (Linux only).
-    Landlock(LandlockCommand),
+    #[clap(visible_alias = "landlock")]
+    Linux(LandlockCommand),
 }

 #[derive(Debug, Parser)]
@@ -154,9 +157,7 @@ struct LoginCommand {
     )]
     api_key: Option<String>,

-    /// EXPERIMENTAL: Use device code flow (not yet supported)
     /// This feature is experimental and may changed in future releases.
-    #[arg(long = "experimental_use-device-code", hide = true)]
+    #[arg(long = "use-device-code")]
     use_device_code: bool,

     /// EXPERIMENTAL: Use custom OAuth issuer base URL (advanced)
@@ -291,7 +292,8 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
                 last,
                 config_overrides,
             );
-            codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
+            let exit_info = codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
+            print_exit_messages(exit_info);
         }
         Some(Subcommand::Login(mut login_cli)) => {
             prepend_config_flags(
@@ -341,8 +343,8 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
             );
             codex_cloud_tasks::run_main(cloud_cli, codex_linux_sandbox_exe).await?;
         }
-        Some(Subcommand::Debug(debug_args)) => match debug_args.cmd {
-            DebugCommand::Seatbelt(mut seatbelt_cli) => {
+        Some(Subcommand::Sandbox(sandbox_args)) => match sandbox_args.cmd {
+            SandboxCommand::Macos(mut seatbelt_cli) => {
                 prepend_config_flags(
                     &mut seatbelt_cli.config_overrides,
                     root_config_overrides.clone(),
@@ -353,7 +355,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
                 )
                 .await?;
             }
-            DebugCommand::Landlock(mut landlock_cli) => {
+            SandboxCommand::Linux(mut landlock_cli) => {
                 prepend_config_flags(
                     &mut landlock_cli.config_overrides,
                     root_config_overrides.clone(),
@@ -472,6 +474,7 @@ fn print_completion(cmd: CompletionCommand) {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use assert_matches::assert_matches;
     use codex_core::protocol::TokenUsage;
     use codex_protocol::ConversationId;

@@ -604,14 +607,14 @@ mod tests {
         assert_eq!(interactive.model.as_deref(), Some("gpt-5-test"));
         assert!(interactive.oss);
         assert_eq!(interactive.config_profile.as_deref(), Some("my-profile"));
-        assert!(matches!(
+        assert_matches!(
             interactive.sandbox_mode,
             Some(codex_common::SandboxModeCliArg::WorkspaceWrite)
-        ));
-        assert!(matches!(
+        );
+        assert_matches!(
             interactive.approval_policy,
             Some(codex_common::ApprovalModeCliArg::OnRequest)
-        ));
+        );
         assert!(interactive.full_auto);
         assert_eq!(
             interactive.cwd.as_deref(),
```
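Note how the rename stays backward compatible: `visible_alias = "debug"` on the `Sandbox` subcommand and `visible_alias = "seatbelt"` / `"landlock"` on its variants keep `codex debug seatbelt` parsing exactly as before, just routed to the new `Sandbox`/`Macos` names. A standalone sketch of the same clap pattern, with hypothetical names:

```rust
use clap::Parser;

#[derive(Debug, Parser)]
struct Cli {
    #[command(subcommand)]
    cmd: Subcommand,
}

#[derive(Debug, clap::Subcommand)]
enum Subcommand {
    /// Run commands within a provided sandbox.
    #[clap(visible_alias = "debug")]
    Sandbox(SandboxArgs),
}

#[derive(Debug, clap::Args)]
struct SandboxArgs {
    #[command(subcommand)]
    cmd: SandboxCommand,
}

#[derive(Debug, clap::Subcommand)]
enum SandboxCommand {
    /// New name; the old spelling still works as a visible alias.
    #[clap(visible_alias = "seatbelt")]
    Macos { command: Vec<String> },
}

fn main() {
    // Both spellings parse to the same variant.
    let new_style = Cli::parse_from(["app", "sandbox", "macos", "ls"]);
    let old_style = Cli::parse_from(["app", "debug", "seatbelt", "ls"]);
    println!("{new_style:?}\n{old_style:?}");
}
```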
codex-rs/cli/src/mcp_cmd.rs

```diff
@@ -4,6 +4,7 @@ use anyhow::Context;
 use anyhow::Result;
 use anyhow::anyhow;
 use anyhow::bail;
+use clap::ArgGroup;
 use codex_common::CliConfigOverrides;
 use codex_core::config::Config;
 use codex_core::config::ConfigOverrides;
@@ -77,13 +78,61 @@ pub struct AddArgs {
     /// Name for the MCP server configuration.
     pub name: String,

-    /// Environment variables to set when launching the server.
-    #[arg(long, value_parser = parse_env_pair, value_name = "KEY=VALUE")]
-    pub env: Vec<(String, String)>,
-
-    /// Command to launch the MCP server.
-    #[arg(trailing_var_arg = true, num_args = 1..)]
-    pub command: Vec<String>,
+    #[command(flatten)]
+    pub transport_args: AddMcpTransportArgs,
+}
+
+#[derive(Debug, clap::Args)]
+#[command(
+    group(
+        ArgGroup::new("transport")
+            .args(["command", "url"])
+            .required(true)
+            .multiple(false)
+    )
+)]
+pub struct AddMcpTransportArgs {
+    #[command(flatten)]
+    pub stdio: Option<AddMcpStdioArgs>,
+
+    #[command(flatten)]
+    pub streamable_http: Option<AddMcpStreamableHttpArgs>,
+}
+
+#[derive(Debug, clap::Args)]
+pub struct AddMcpStdioArgs {
+    /// Command to launch the MCP server.
+    /// Use --url for a streamable HTTP server.
+    #[arg(
+        trailing_var_arg = true,
+        num_args = 0..,
+    )]
+    pub command: Vec<String>,
+
+    /// Environment variables to set when launching the server.
+    /// Only valid with stdio servers.
+    #[arg(
+        long,
+        value_parser = parse_env_pair,
+        value_name = "KEY=VALUE",
+    )]
+    pub env: Vec<(String, String)>,
+}
+
+#[derive(Debug, clap::Args)]
+pub struct AddMcpStreamableHttpArgs {
+    /// URL for a streamable HTTP MCP server.
+    #[arg(long)]
+    pub url: String,
+
+    /// Optional environment variable to read for a bearer token.
+    /// Only valid with streamable HTTP servers.
+    #[arg(
+        long = "bearer-token-env-var",
+        value_name = "ENV_VAR",
+        requires = "url"
+    )]
+    pub bearer_token_env_var: Option<String>,
 }

 #[derive(Debug, clap::Parser)]
@@ -140,37 +189,51 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
     // Validate any provided overrides even though they are not currently applied.
     config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;

-    let AddArgs { name, env, command } = add_args;
+    let AddArgs {
+        name,
+        transport_args,
+    } = add_args;

     validate_server_name(&name)?;

-    let mut command_parts = command.into_iter();
-    let command_bin = command_parts
-        .next()
-        .ok_or_else(|| anyhow!("command is required"))?;
-    let command_args: Vec<String> = command_parts.collect();
-
-    let env_map = if env.is_empty() {
-        None
-    } else {
-        let mut map = HashMap::new();
-        for (key, value) in env {
-            map.insert(key, value);
-        }
-        Some(map)
-    };
-
     let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
     let mut servers = load_global_mcp_servers(&codex_home)
         .await
         .with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?;

-    let new_entry = McpServerConfig {
-        transport: McpServerTransportConfig::Stdio {
-            command: command_bin,
-            args: command_args,
-            env: env_map,
+    let transport = match transport_args {
+        AddMcpTransportArgs {
+            stdio: Some(stdio), ..
+        } => {
+            let mut command_parts = stdio.command.into_iter();
+            let command_bin = command_parts
+                .next()
+                .ok_or_else(|| anyhow!("command is required"))?;
+            let command_args: Vec<String> = command_parts.collect();
+
+            let env_map = if stdio.env.is_empty() {
+                None
+            } else {
+                Some(stdio.env.into_iter().collect::<HashMap<_, _>>())
+            };
+            McpServerTransportConfig::Stdio {
+                command: command_bin,
+                args: command_args,
+                env: env_map,
+            }
+        }
+        AddMcpTransportArgs {
+            streamable_http: Some(streamable_http),
+            ..
+        } => McpServerTransportConfig::StreamableHttp {
+            url: streamable_http.url,
+            bearer_token_env_var: streamable_http.bearer_token_env_var,
         },
+        AddMcpTransportArgs { .. } => bail!("exactly one of --command or --url must be provided"),
+    };
+
+    let new_entry = McpServerConfig {
+        transport,
         startup_timeout_sec: None,
         tool_timeout_sec: None,
     };
@@ -236,7 +299,7 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs)
         _ => bail!("OAuth login is only supported for streamable HTTP servers."),
     };

-    perform_oauth_login(&name, &url).await?;
+    perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?;
     println!("Successfully logged in to MCP server '{name}'.");
     Ok(())
 }
@@ -259,7 +322,7 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr
         _ => bail!("OAuth logout is only supported for streamable_http transports."),
     };

-    match delete_oauth_tokens(&name, &url) {
+    match delete_oauth_tokens(&name, &url, config.mcp_oauth_credentials_store_mode) {
         Ok(true) => println!("Removed OAuth credentials for '{name}'."),
         Ok(false) => println!("No OAuth credentials stored for '{name}'."),
         Err(err) => return Err(anyhow!("failed to delete OAuth credentials: {err}")),
@@ -288,11 +351,14 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
                 "args": args,
                 "env": env,
             }),
-            McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
+            McpServerTransportConfig::StreamableHttp {
+                url,
+                bearer_token_env_var,
+            } => {
                 serde_json::json!({
                     "type": "streamable_http",
                     "url": url,
-                    "bearer_token": bearer_token,
+                    "bearer_token_env_var": bearer_token_env_var,
                 })
             }
         };
@@ -345,13 +411,15 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
                 };
                 stdio_rows.push([name.clone(), command.clone(), args_display, env_display]);
             }
-            McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
-                let has_bearer = if bearer_token.is_some() {
-                    "True"
-                } else {
-                    "False"
-                };
-                http_rows.push([name.clone(), url.clone(), has_bearer.into()]);
+            McpServerTransportConfig::StreamableHttp {
+                url,
+                bearer_token_env_var,
+            } => {
+                http_rows.push([
+                    name.clone(),
+                    url.clone(),
+                    bearer_token_env_var.clone().unwrap_or("-".to_string()),
+                ]);
             }
         }
     }
@@ -396,7 +464,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
     }

     if !http_rows.is_empty() {
-        let mut widths = ["Name".len(), "Url".len(), "Has Bearer Token".len()];
+        let mut widths = ["Name".len(), "Url".len(), "Bearer Token Env Var".len()];
         for row in &http_rows {
             for (i, cell) in row.iter().enumerate() {
                 widths[i] = widths[i].max(cell.len());
@@ -407,7 +475,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
             "{:<name_w$}  {:<url_w$}  {:<token_w$}",
             "Name",
             "Url",
-            "Has Bearer Token",
+            "Bearer Token Env Var",
             name_w = widths[0],
             url_w = widths[1],
             token_w = widths[2],
@@ -447,10 +515,13 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
             "args": args,
             "env": env,
         }),
-        McpServerTransportConfig::StreamableHttp { url, bearer_token } => serde_json::json!({
+        McpServerTransportConfig::StreamableHttp {
+            url,
+            bearer_token_env_var,
+        } => serde_json::json!({
             "type": "streamable_http",
             "url": url,
-            "bearer_token": bearer_token,
+            "bearer_token_env_var": bearer_token_env_var,
         }),
     };
     let output = serde_json::to_string_pretty(&serde_json::json!({
@@ -493,11 +564,14 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
             };
             println!("  env: {env_display}");
         }
-        McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
+        McpServerTransportConfig::StreamableHttp {
+            url,
+            bearer_token_env_var,
+        } => {
             println!("  transport: streamable_http");
             println!("  url: {url}");
-            let bearer = bearer_token.as_deref().unwrap_or("-");
-            println!("  bearer_token: {bearer}");
+            let env_var = bearer_token_env_var.as_deref().unwrap_or("-");
+            println!("  bearer_token_env_var: {env_var}");
         }
     }
     if let Some(timeout) = server.startup_timeout_sec {
```
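The restructuring above is clap's idiom for mutually exclusive argument sets: each transport's flags live in their own `Args` struct, both structs are flattened as `Option<...>`, and an `ArgGroup` over the two entry-point arguments (`command`, `url`) enforces exactly one of them at parse time; the final `bail!` arm only covers combinations the group cannot express. A self-contained sketch of the same pattern, with hypothetical names:

```rust
use clap::{ArgGroup, Parser};

#[derive(Debug, Parser)]
#[command(group(
    ArgGroup::new("transport")
        .args(["command", "url"])
        .required(true)
        .multiple(false)
))]
struct AddCmd {
    /// Server name.
    name: String,

    #[command(flatten)]
    stdio: Option<StdioArgs>,

    #[command(flatten)]
    http: Option<HttpArgs>,
}

#[derive(Debug, clap::Args)]
struct StdioArgs {
    /// Command to launch (stdio transport).
    #[arg(long, num_args = 1..)]
    command: Vec<String>,
}

#[derive(Debug, clap::Args)]
struct HttpArgs {
    /// URL of a streamable HTTP server.
    #[arg(long)]
    url: String,

    /// Env var holding a bearer token; only meaningful with --url.
    #[arg(long, requires = "url")]
    bearer_token_env_var: Option<String>,
}

fn main() {
    // OK: exactly one of --command / --url is present.
    let ok = AddCmd::parse_from(["add", "docs", "--url", "https://example.com/mcp"]);
    println!("{ok:?}");

    // Rejected: the ArgGroup forbids supplying both transports at once.
    let err = AddCmd::try_parse_from([
        "add", "docs", "--url", "https://example.com/mcp", "--command", "echo",
    ]);
    assert!(err.is_err());
}
```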
```diff
@@ -93,3 +93,116 @@ async fn add_with_env_preserves_key_order_and_values() -> Result<()> {

     Ok(())
 }
+
+#[tokio::test]
+async fn add_streamable_http_without_manual_token() -> Result<()> {
+    let codex_home = TempDir::new()?;
+
+    let mut add_cmd = codex_command(codex_home.path())?;
+    add_cmd
+        .args(["mcp", "add", "github", "--url", "https://example.com/mcp"])
+        .assert()
+        .success();
+
+    let servers = load_global_mcp_servers(codex_home.path()).await?;
+    let github = servers.get("github").expect("github server should exist");
+    match &github.transport {
+        McpServerTransportConfig::StreamableHttp {
+            url,
+            bearer_token_env_var,
+        } => {
+            assert_eq!(url, "https://example.com/mcp");
+            assert!(bearer_token_env_var.is_none());
+        }
+        other => panic!("unexpected transport: {other:?}"),
+    }
+
+    assert!(!codex_home.path().join(".credentials.json").exists());
+    assert!(!codex_home.path().join(".env").exists());
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn add_streamable_http_with_custom_env_var() -> Result<()> {
+    let codex_home = TempDir::new()?;
+
+    let mut add_cmd = codex_command(codex_home.path())?;
+    add_cmd
+        .args([
+            "mcp",
+            "add",
+            "issues",
+            "--url",
+            "https://example.com/issues",
+            "--bearer-token-env-var",
+            "GITHUB_TOKEN",
+        ])
+        .assert()
+        .success();
+
+    let servers = load_global_mcp_servers(codex_home.path()).await?;
+    let issues = servers.get("issues").expect("issues server should exist");
+    match &issues.transport {
+        McpServerTransportConfig::StreamableHttp {
+            url,
+            bearer_token_env_var,
+        } => {
+            assert_eq!(url, "https://example.com/issues");
+            assert_eq!(bearer_token_env_var.as_deref(), Some("GITHUB_TOKEN"));
+        }
+        other => panic!("unexpected transport: {other:?}"),
+    }
+    Ok(())
+}
+
+#[tokio::test]
+async fn add_streamable_http_rejects_removed_flag() -> Result<()> {
+    let codex_home = TempDir::new()?;
+
+    let mut add_cmd = codex_command(codex_home.path())?;
+    add_cmd
+        .args([
+            "mcp",
+            "add",
+            "github",
+            "--url",
+            "https://example.com/mcp",
+            "--with-bearer-token",
+        ])
+        .assert()
+        .failure()
+        .stderr(contains("--with-bearer-token"));
+
+    let servers = load_global_mcp_servers(codex_home.path()).await?;
+    assert!(servers.is_empty());
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn add_cant_add_command_and_url() -> Result<()> {
+    let codex_home = TempDir::new()?;
+
+    let mut add_cmd = codex_command(codex_home.path())?;
+    add_cmd
+        .args([
+            "mcp",
+            "add",
+            "github",
+            "--url",
+            "https://example.com/mcp",
+            "--command",
+            "--",
+            "echo",
+            "hello",
+        ])
+        .assert()
+        .failure()
+        .stderr(contains("unexpected argument '--command' found"));
+
+    let servers = load_global_mcp_servers(codex_home.path()).await?;
+    assert!(servers.is_empty());
+
+    Ok(())
+}
```
codex-rs/cloud-tasks/Cargo.toml

```diff
@@ -1,7 +1,7 @@
 [package]
-edition = "2024"
 name = "codex-cloud-tasks"
 version = { workspace = true }
+edition = "2024"

 [lib]
 name = "codex_cloud_tasks"
@@ -11,26 +11,28 @@ path = "src/lib.rs"
 workspace = true

 [dependencies]
-anyhow = "1"
-clap = { version = "4", features = ["derive"] }
+anyhow = { workspace = true }
+base64 = { workspace = true }
+chrono = { workspace = true, features = ["serde"] }
+clap = { workspace = true, features = ["derive"] }
+codex-cloud-tasks-client = { path = "../cloud-tasks-client", features = [
+    "mock",
+    "online",
+] }
 codex-common = { path = "../common", features = ["cli"] }
-tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
-tracing = { version = "0.1.41", features = ["log"] }
-tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
-codex-cloud-tasks-client = { path = "../cloud-tasks-client", features = ["mock", "online"] }
-ratatui = { version = "0.29.0" }
-crossterm = { version = "0.28.1", features = ["event-stream"] }
-tokio-stream = "0.1.17"
-chrono = { version = "0.4", features = ["serde"] }
 codex-login = { path = "../login" }
 codex-core = { path = "../core" }
-throbber-widgets-tui = "0.8.0"
-base64 = "0.22"
-serde_json = "1"
-reqwest = { version = "0.12", features = ["json"] }
-serde = { version = "1", features = ["derive"] }
-unicode-width = "0.1"
+codex-tui = { path = "../tui" }
+crossterm = { workspace = true, features = ["event-stream"] }
+ratatui = { workspace = true }
+reqwest = { workspace = true, features = ["json"] }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
+tokio-stream = { workspace = true }
+tracing = { workspace = true, features = ["log"] }
+tracing-subscriber = { workspace = true, features = ["env-filter"] }
+unicode-width = { workspace = true }

 [dev-dependencies]
-async-trait = "0.1"
+async-trait = { workspace = true }
```
codex-rs/cloud-tasks/src/app.rs

```diff
@@ -1,4 +1,5 @@
 use std::time::Duration;
+use std::time::Instant;

 // Environment filter data models for the TUI
 #[derive(Clone, Debug, Default)]
@@ -42,15 +43,13 @@ use crate::scrollable_diff::ScrollableDiff;
 use codex_cloud_tasks_client::CloudBackend;
 use codex_cloud_tasks_client::TaskId;
 use codex_cloud_tasks_client::TaskSummary;
-use throbber_widgets_tui::ThrobberState;

 #[derive(Default)]
 pub struct App {
     pub tasks: Vec<TaskSummary>,
     pub selected: usize,
     pub status: String,
     pub diff_overlay: Option<DiffOverlay>,
-    pub throbber: ThrobberState,
+    pub spinner_start: Option<Instant>,
     pub refresh_inflight: bool,
     pub details_inflight: bool,
     // Environment filter state
@@ -82,7 +81,7 @@ impl App {
             selected: 0,
             status: "Press r to refresh".to_string(),
             diff_overlay: None,
-            throbber: ThrobberState::default(),
+            spinner_start: None,
             refresh_inflight: false,
             details_inflight: false,
             env_filter: None,
@@ -400,16 +400,20 @@ pub async fn run_main(_cli: Cli, _codex_linux_sandbox_exe: Option<PathBuf>) -> a
                 let _ = frame_tx.send(Instant::now() + codex_tui::ComposerInput::recommended_flush_delay());
             }
         }
-        // Advance throbber only while loading.
+        // Keep spinner pulsing only while loading.
        if app.refresh_inflight
            || app.details_inflight
            || app.env_loading
            || app.apply_preflight_inflight
            || app.apply_inflight
        {
-            app.throbber.calc_next();
+            if app.spinner_start.is_none() {
+                app.spinner_start = Some(Instant::now());
+            }
            needs_redraw = true;
-            let _ = frame_tx.send(Instant::now() + Duration::from_millis(100));
+            let _ = frame_tx.send(Instant::now() + Duration::from_millis(600));
+        } else {
+            app.spinner_start = None;
        }
        render_if_needed(&mut terminal, &mut app, &mut needs_redraw)?;
    }
```
codex-rs/cloud-tasks/src/ui.rs

```diff
@@ -16,6 +16,7 @@ use ratatui::widgets::ListState;
 use ratatui::widgets::Padding;
 use ratatui::widgets::Paragraph;
 use std::sync::OnceLock;
+use std::time::Instant;

 use crate::app::App;
 use crate::app::AttemptView;
@@ -229,7 +230,7 @@ fn draw_list(frame: &mut Frame, area: Rect, app: &mut App) {

     // In-box spinner during initial/refresh loads
     if app.refresh_inflight {
-        draw_centered_spinner(frame, inner, &mut app.throbber, "Loading tasks…");
+        draw_centered_spinner(frame, inner, &mut app.spinner_start, "Loading tasks…");
     }
 }

@@ -291,7 +292,7 @@ fn draw_footer(frame: &mut Frame, area: Rect, app: &mut App) {
         || app.apply_preflight_inflight
         || app.apply_inflight
     {
-        draw_inline_spinner(frame, top[1], &mut app.throbber, "Loading…");
+        draw_inline_spinner(frame, top[1], &mut app.spinner_start, "Loading…");
     } else {
         frame.render_widget(Clear, top[1]);
     }
@@ -449,7 +450,12 @@ fn draw_diff_overlay(frame: &mut Frame, area: Rect, app: &mut App) {
         .map(|o| o.sd.wrapped_lines().is_empty())
         .unwrap_or(true);
     if app.details_inflight && raw_empty {
-        draw_centered_spinner(frame, content_area, &mut app.throbber, "Loading details…");
+        draw_centered_spinner(
+            frame,
+            content_area,
+            &mut app.spinner_start,
+            "Loading details…",
+        );
     } else {
         let scroll = app
             .diff_overlay
@@ -494,11 +500,11 @@ pub fn draw_apply_modal(frame: &mut Frame, area: Rect, app: &mut App) {
     frame.render_widget(header, rows[0]);
     // Body: spinner while preflight/apply runs; otherwise show result message and path lists
     if app.apply_preflight_inflight {
-        draw_centered_spinner(frame, rows[1], &mut app.throbber, "Checking…");
+        draw_centered_spinner(frame, rows[1], &mut app.spinner_start, "Checking…");
     } else if app.apply_inflight {
-        draw_centered_spinner(frame, rows[1], &mut app.throbber, "Applying…");
+        draw_centered_spinner(frame, rows[1], &mut app.spinner_start, "Applying…");
     } else if m.result_message.is_none() {
-        draw_centered_spinner(frame, rows[1], &mut app.throbber, "Loading…");
+        draw_centered_spinner(frame, rows[1], &mut app.spinner_start, "Loading…");
     } else if let Some(msg) = &m.result_message {
         let mut body_lines: Vec<Line> = Vec::new();
         let first = match m.result_level {
@@ -859,29 +865,29 @@ fn format_relative_time(ts: chrono::DateTime<Utc>) -> String {
 fn draw_inline_spinner(
     frame: &mut Frame,
     area: Rect,
-    state: &mut throbber_widgets_tui::ThrobberState,
+    spinner_start: &mut Option<Instant>,
     label: &str,
 ) {
-    use ratatui::style::Style;
-    use throbber_widgets_tui::BRAILLE_EIGHT;
-    use throbber_widgets_tui::Throbber;
-    use throbber_widgets_tui::WhichUse;
-    let w = Throbber::default()
-        .label(label)
-        .style(Style::default().cyan())
-        .throbber_style(Style::default().magenta().bold())
-        .throbber_set(BRAILLE_EIGHT)
-        .use_type(WhichUse::Spin);
-    frame.render_stateful_widget(w, area, state);
+    use ratatui::widgets::Paragraph;
+    let start = spinner_start.get_or_insert_with(Instant::now);
+    let blink_on = (start.elapsed().as_millis() / 600).is_multiple_of(2);
+    let dot = if blink_on {
+        "• ".into()
+    } else {
+        "◦ ".dim()
+    };
+    let label = label.cyan();
+    let line = Line::from(vec![dot, label]);
+    frame.render_widget(Paragraph::new(line), area);
 }

 fn draw_centered_spinner(
     frame: &mut Frame,
     area: Rect,
-    state: &mut throbber_widgets_tui::ThrobberState,
+    spinner_start: &mut Option<Instant>,
     label: &str,
 ) {
-    // Center a 1xN throbber within the given rect
+    // Center a 1xN spinner within the given rect
     let rows = Layout::default()
         .direction(Direction::Vertical)
         .constraints([
@@ -898,7 +904,7 @@ fn draw_centered_spinner(
             Constraint::Percentage(50),
         ])
         .split(rows[1]);
-    draw_inline_spinner(frame, cols[1], state, label);
+    draw_inline_spinner(frame, cols[1], spinner_start, label);
 }

 // Styling helpers for diff rendering live inline where used.
@@ -918,7 +924,12 @@ pub fn draw_env_modal(frame: &mut Frame, area: Rect, app: &mut App) {
     let content = overlay_content(inner);

     if app.env_loading {
-        draw_centered_spinner(frame, content, &mut app.throbber, "Loading environments…");
+        draw_centered_spinner(
+            frame,
+            content,
+            &mut app.spinner_start,
+            "Loading environments…",
+        );
         return;
     }
```
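The net effect of this pair of files: the TUI no longer stores per-frame animation state (`ThrobberState`) or depends on `throbber-widgets-tui`; it derives the spinner's phase from wall-clock time via a single `Option<Instant>`, which also lets the redraw timer relax from 100 ms to 600 ms since the blink changes state only that often. A minimal sketch of the timing logic, independent of any TUI library:

```rust
use std::time::{Duration, Instant};

/// Two-phase blink driven purely by elapsed time:
/// the phase flips every `period_ms` milliseconds.
fn blink_dot(start: Instant, period_ms: u128) -> &'static str {
    let on = (start.elapsed().as_millis() / period_ms) % 2 == 0;
    if on { "•" } else { "◦" }
}

fn main() {
    let start = Instant::now();
    for _ in 0..5 {
        println!("{} Loading…", blink_dot(start, 600));
        std::thread::sleep(Duration::from_millis(300));
    }
}
```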
codex-rs/core/Cargo.toml

```diff
@@ -61,7 +61,7 @@ tokio = { workspace = true, features = [
     "rt-multi-thread",
     "signal",
 ] }
-tokio-util = { workspace = true }
+tokio-util = { workspace = true, features = ["rt"] }
 toml = { workspace = true }
 toml_edit = { workspace = true }
 tracing = { workspace = true, features = ["log"] }
@@ -89,6 +89,7 @@ openssl-sys = { workspace = true, features = ["vendored"] }

 [dev-dependencies]
 assert_cmd = { workspace = true }
+assert_matches = { workspace = true }
 core_test_support = { workspace = true }
 escargot = { workspace = true }
 maplit = { workspace = true }
```
```diff
@@ -12,7 +12,7 @@ Expects `/usr/bin/sandbox-exec` to be present.

 ### Linux

-Expects the binary containing `codex-core` to run the equivalent of `codex debug landlock` when `arg0` is `codex-linux-sandbox`. See the `codex-arg0` crate for details.
+Expects the binary containing `codex-core` to run the equivalent of `codex sandbox linux` (legacy alias: `codex debug landlock`) when `arg0` is `codex-linux-sandbox`. See the `codex-arg0` crate for details.

 ### All Platforms
```
```diff
@@ -10,12 +10,14 @@ You are Codex, based on GPT-5. You are running as a coding agent in the Codex CL

 - Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
 - Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
 - Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
 - You may be in a dirty git worktree.
     * NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
     * If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
     * If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
     * If the changes are in unrelated files, just ignore them and don't revert them.
+- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
+- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.

 ## Plan tool
```
codex-rs/core/src/client.rs

```diff
@@ -22,6 +22,7 @@ use tokio::sync::mpsc;
 use tokio::time::timeout;
 use tokio_util::io::ReaderStream;
 use tracing::debug;
+use tracing::error;
 use tracing::trace;
 use tracing::warn;

@@ -63,7 +64,6 @@ struct ErrorResponse {
 #[derive(Debug, Deserialize)]
 struct Error {
     r#type: Option<String>,
-    #[allow(dead_code)]
     code: Option<String>,
     message: Option<String>,

@@ -228,7 +228,7 @@ impl ModelClient {
             input: &input_with_instructions,
             tools: &tools_json,
             tool_choice: "auto",
-            parallel_tool_calls: false,
+            parallel_tool_calls: prompt.parallel_tool_calls,
             reasoning,
             store: azure_workaround,
             stream: true,
@@ -656,7 +656,7 @@ async fn process_sse<S>(
     {
         Ok(Some(Ok(sse))) => sse,
         Ok(Some(Err(e))) => {
-            debug!("SSE Error: {e:#}");
+            error!("SSE Error: {e:#}");
             let event = CodexErr::Stream(e.to_string(), None);
             let _ = tx_event.send(Err(event)).await;
             return;
@@ -717,7 +717,7 @@ async fn process_sse<S>(
         let event: SseEvent = match serde_json::from_str(&sse.data) {
             Ok(event) => event,
             Err(e) => {
-                debug!("Failed to parse SSE event: {e}, data: {}", &sse.data);
+                error!("Failed to parse SSE event: {e}, data: {}", &sse.data);
                 continue;
             }
         };
@@ -744,7 +744,7 @@ async fn process_sse<S>(
             "response.output_item.done" => {
                 let Some(item_val) = event.item else { continue };
                 let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
-                    debug!("failed to parse ResponseItem from output_item.done");
+                    error!("failed to parse ResponseItem from output_item.done");
                     continue;
                 };

@@ -794,14 +794,16 @@ async fn process_sse<S>(
                 if let Some(error) = error {
                     match serde_json::from_value::<Error>(error.clone()) {
                         Ok(error) => {
-                            let delay = try_parse_retry_after(&error);
-                            let message = error.message.unwrap_or_default();
-                            response_error = Some(CodexErr::Stream(message, delay));
+                            if is_context_window_error(&error) {
+                                response_error = Some(CodexErr::ContextWindowExceeded);
+                            } else {
+                                let delay = try_parse_retry_after(&error);
+                                let message = error.message.clone().unwrap_or_default();
+                                response_error = Some(CodexErr::Stream(message, delay));
+                            }
                         }
                         Err(e) => {
-                            let error = format!("failed to parse ErrorResponse: {e}");
-                            debug!(error);
-                            response_error = Some(CodexErr::Stream(error, None))
+                            error!("failed to parse ErrorResponse: {e}");
                         }
                     }
                 }
@@ -815,9 +817,7 @@ async fn process_sse<S>(
                         response_completed = Some(r);
                     }
                     Err(e) => {
-                        let error = format!("failed to parse ResponseCompleted: {e}");
-                        debug!(error);
-                        response_error = Some(CodexErr::Stream(error, None));
+                        error!("failed to parse ResponseCompleted: {e}");
                         continue;
                     }
                 };
@@ -922,9 +922,14 @@ fn try_parse_retry_after(err: &Error) -> Option<Duration> {
     None
 }

+fn is_context_window_error(error: &Error) -> bool {
+    error.code.as_deref() == Some("context_length_exceeded")
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
+    use assert_matches::assert_matches;
     use serde_json::json;
     use tokio::sync::mpsc;
     use tokio_test::io::Builder as IoBuilder;
@@ -1179,6 +1184,74 @@ mod tests {
         }
     }

+    #[tokio::test]
+    async fn context_window_error_is_fatal() {
+        let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#;
+
+        let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
+        let provider = ModelProviderInfo {
+            name: "test".to_string(),
+            base_url: Some("https://test.com".to_string()),
+            env_key: Some("TEST_API_KEY".to_string()),
+            env_key_instructions: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+            request_max_retries: Some(0),
+            stream_max_retries: Some(0),
+            stream_idle_timeout_ms: Some(1000),
+            requires_openai_auth: false,
+        };
+
+        let otel_event_manager = otel_event_manager();
+
+        let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
+
+        assert_eq!(events.len(), 1);
+
+        match &events[0] {
+            Err(err @ CodexErr::ContextWindowExceeded) => {
+                assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
+            }
+            other => panic!("unexpected context window event: {other:?}"),
+        }
+    }
+
+    #[tokio::test]
+    async fn context_window_error_with_newline_is_fatal() {
+        let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#;
+
+        let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
+        let provider = ModelProviderInfo {
+            name: "test".to_string(),
+            base_url: Some("https://test.com".to_string()),
+            env_key: Some("TEST_API_KEY".to_string()),
+            env_key_instructions: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+            request_max_retries: Some(0),
+            stream_max_retries: Some(0),
+            stream_idle_timeout_ms: Some(1000),
+            requires_openai_auth: false,
+        };
+
+        let otel_event_manager = otel_event_manager();
+
+        let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
+
+        assert_eq!(events.len(), 1);
+
+        match &events[0] {
+            Err(err @ CodexErr::ContextWindowExceeded) => {
+                assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
+            }
+            other => panic!("unexpected context window event: {other:?}"),
+        }
+    }
+
     // ────────────────────────────
     // Table-driven test from `main`
     // ────────────────────────────
@@ -1316,10 +1389,7 @@ mod tests {
         let resp: ErrorResponse =
             serde_json::from_str(json).expect("should deserialize old schema");

-        assert!(matches!(
-            resp.error.plan_type,
-            Some(PlanType::Known(KnownPlan::Pro))
-        ));
+        assert_matches!(resp.error.plan_type, Some(PlanType::Known(KnownPlan::Pro)));

         let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
         assert_eq!(plan_json, "\"pro\"");
@@ -1334,7 +1404,7 @@ mod tests {
         let resp: ErrorResponse =
             serde_json::from_str(json).expect("should deserialize old schema");

-        assert!(matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip"));
+        assert_matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip");

         let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
         assert_eq!(plan_json, "\"vip\"");
```
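The reason for the dedicated `CodexErr::ContextWindowExceeded` variant is retry behavior: a `context_length_exceeded` failure is deterministic, so resending the identical request (as the generic `CodexErr::Stream` path may, honoring an optional retry-after delay) can never succeed. A sketch of that distinction, with hypothetical types standing in for the real error enum:

```rust
use std::time::Duration;

#[derive(Debug)]
enum StreamErr {
    /// Transient stream failure; may carry a server-suggested delay.
    Transient(String, Option<Duration>),
    /// The input cannot fit in the model's context window;
    /// retrying the identical request is pointless.
    ContextWindowExceeded,
}

fn classify(code: Option<&str>, message: &str) -> StreamErr {
    if code == Some("context_length_exceeded") {
        StreamErr::ContextWindowExceeded
    } else {
        StreamErr::Transient(message.to_string(), None)
    }
}

fn should_retry(err: &StreamErr) -> bool {
    // Only transient errors are worth retrying.
    matches!(err, StreamErr::Transient(..))
}

fn main() {
    let fatal = classify(Some("context_length_exceeded"), "input too long");
    let transient = classify(None, "connection reset");
    assert!(!should_retry(&fatal));
    assert!(should_retry(&transient));
    println!("{fatal:?} / {transient:?}");
}
```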
codex-rs/core/src/client_common.rs (truncated in the mirror)

```diff
@@ -9,9 +9,11 @@ use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
 use codex_protocol::config_types::Verbosity as VerbosityConfig;
 use codex_protocol::models::ResponseItem;
 use futures::Stream;
+use serde::Deserialize;
 use serde::Serialize;
 use serde_json::Value;
 use std::borrow::Cow;
+use std::collections::HashSet;
 use std::ops::Deref;
 use std::pin::Pin;
 use std::task::Context;
@@ -31,6 +33,9 @@ pub struct Prompt {
     /// external MCP servers.
     pub(crate) tools: Vec<ToolSpec>,

+    /// Whether parallel tool calls are permitted for this prompt.
+    pub(crate) parallel_tool_calls: bool,
+
     /// Optional override for the built-in BASE_INSTRUCTIONS.
     pub base_instructions_override: Option<String>,

@@ -64,10 +69,125 @@ impl Prompt {
     }

     pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
-        self.input.clone()
+        let mut input = self.input.clone();
+
+        // when using the *Freeform* apply_patch tool specifically, tool outputs
+        // should be structured text, not json. Do NOT reserialize when using
+        // the Function tool - note that this differs from the check above for
+        // instructions. We declare the result as a named variable for clarity.
+        let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
+            ToolSpec::Freeform(f) => f.name == "apply_patch",
+            _ => false,
+        });
+        if is_freeform_apply_patch_tool_present {
+            reserialize_shell_outputs(&mut input);
+        }
+
+        input
+    }
+}
+
+fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
+    let mut shell_call_ids: HashSet<String> = HashSet::new();
+
+    items.iter_mut().for_each(|item| match item {
+        ResponseItem::LocalShellCall { call_id, id, .. } => {
+            if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
+                shell_call_ids.insert(identifier);
+            }
+        }
+        ResponseItem::CustomToolCall {
+            id: _,
+            status: _,
+            call_id,
+            name,
+            input: _,
+        } => {
+            if name == "apply_patch" {
+                shell_call_ids.insert(call_id.clone());
+            }
+        }
+        ResponseItem::CustomToolCallOutput { call_id, output } => {
+            if shell_call_ids.remove(call_id)
+                && let Some(structured) = parse_structured_shell_output(output)
+            {
+                *output = structured
+            }
+        }
+        ResponseItem::FunctionCall { name, call_id, .. }
+            if is_shell_tool_name(name) || name == "apply_patch" =>
+        {
+            shell_call_ids.insert(call_id.clone());
+        }
+        ResponseItem::FunctionCallOutput { call_id, output } => {
+            if shell_call_ids.remove(call_id)
+                && let Some(structured) = parse_structured_shell_output(&output.content)
+            {
+                output.content = structured
+            }
+        }
+        _ => {}
+    })
+}
+
+fn is_shell_tool_name(name: &str) -> bool {
+    matches!(name, "shell" | "container.exec")
+}
+
+#[derive(Deserialize)]
+struct ExecOutputJson {
+    output: String,
+    metadata: ExecOutputMetadataJson,
+}
+
+#[derive(Deserialize)]
+struct ExecOutputMetadataJson {
+    exit_code: i32,
+    duration_seconds: f32,
+}
+
+fn parse_structured_shell_output(raw: &str) -> Option<String> {
+    let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
+    Some(build_structured_output(&parsed))
+}
+
+fn build_structured_output(parsed: &ExecOutputJson) -> String {
+    let mut sections = Vec::new();
+    sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
+    sections.push(format!(
+        "Wall time: {} seconds",
+        parsed.metadata.duration_seconds
+    ));
+
+    let mut output = parsed.output.clone();
+    if let Some(total_lines) = extract_total_output_lines(&parsed.output) {
+        sections.push(format!("Total output lines: {total_lines}"));
+        if let Some(stripped) = strip_total_output_header(&output) {
+            output = stripped.to_string();
```
}
|
||||
}
|
||||
|
||||
sections.push("Output:".to_string());
|
||||
sections.push(output);
|
||||
|
||||
sections.join("\n")
|
||||
}
|
||||
|
||||
fn extract_total_output_lines(output: &str) -> Option<u32> {
|
||||
let marker_start = output.find("[... omitted ")?;
|
||||
let marker = &output[marker_start..];
|
||||
let (_, after_of) = marker.split_once(" of ")?;
|
||||
let (total_segment, _) = after_of.split_once(' ')?;
|
||||
total_segment.parse::<u32>().ok()
|
||||
}
|
||||
|
||||
fn strip_total_output_header(output: &str) -> Option<&str> {
|
||||
let after_prefix = output.strip_prefix("Total output lines: ")?;
|
||||
let (_, remainder) = after_prefix.split_once('\n')?;
|
||||
let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
|
||||
Some(remainder)
|
||||
}
|
||||
|
||||
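// A minimal illustrative sketch (an assumption, not part of the diff): feeding a
// truncated exec result through the helpers above yields the structured text that
// reserialize_shell_outputs swaps in for the JSON payload.
fn demo_structured_shell_output() {
    let raw = r#"{"output":"Total output lines: 400\n\nline 1\n[... omitted 144 of 400 lines ...]\nline 400","metadata":{"exit_code":0,"duration_seconds":1.5}}"#;
    let structured = parse_structured_shell_output(raw).expect("valid exec output JSON");
    // Header fields come first, then the body with its own truncation header stripped.
    assert!(structured.starts_with("Exit code: 0\nWall time: 1.5 seconds\nTotal output lines: 400\nOutput:\nline 1"));
}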
#[derive(Debug)]
pub enum ResponseEvent {
    Created,
@@ -182,6 +302,17 @@ pub(crate) mod tools {
        Freeform(FreeformTool),
    }

    impl ToolSpec {
        pub(crate) fn name(&self) -> &str {
            match self {
                ToolSpec::Function(tool) => tool.name.as_str(),
                ToolSpec::LocalShell {} => "local_shell",
                ToolSpec::WebSearch {} => "web_search",
                ToolSpec::Freeform(tool) => tool.name.as_str(),
            }
        }
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct FreeformTool {
        pub(crate) name: String,
@@ -327,7 +458,7 @@ mod tests {
            input: &input,
            tools: &tools,
            tool_choice: "auto",
            parallel_tool_calls: false,
            parallel_tool_calls: true,
            reasoning: None,
            store: false,
            stream: true,
@@ -368,7 +499,7 @@ mod tests {
            input: &input,
            tools: &tools,
            tool_choice: "auto",
            parallel_tool_calls: false,
            parallel_tool_calls: true,
            reasoning: None,
            store: false,
            stream: true,
@@ -404,7 +535,7 @@ mod tests {
            input: &input,
            tools: &tools,
            tool_choice: "auto",
            parallel_tool_calls: false,
            parallel_tool_calls: true,
            reasoning: None,
            store: false,
            stream: true,

@@ -23,7 +23,9 @@ use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::TaskStartedEvent;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::TurnContextItem;
use futures::future::BoxFuture;
use futures::prelude::*;
use futures::stream::FuturesOrdered;
use mcp_types::CallToolResult;
use serde_json;
use serde_json::Value;
@@ -100,7 +102,9 @@ use crate::tasks::CompactTask;
use crate::tasks::RegularTask;
use crate::tasks::ReviewTask;
use crate::tools::ToolRouter;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::format_exec_output_str;
use crate::tools::parallel::ToolCallRuntime;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::UserInstructions;
@@ -360,6 +364,7 @@ impl Session {
        let mcp_fut = McpConnectionManager::new(
            config.mcp_servers.clone(),
            config.use_experimental_use_rmcp_client,
            config.mcp_oauth_credentials_store_mode,
        );
        let default_shell_fut = shell::default_user_shell();
        let history_meta_fut = crate::message_history::history_metadata(&config);
@@ -782,6 +787,17 @@ impl Session {
        self.send_event(event).await;
    }

    async fn set_total_tokens_full(&self, sub_id: &str, turn_context: &TurnContext) {
        let context_window = turn_context.client.get_model_context_window();
        if let Some(context_window) = context_window {
            {
                let mut state = self.state.lock().await;
                state.set_token_usage_full(context_window);
            }
            self.send_token_count_event(sub_id).await;
        }
    }

    /// Record a user input item to conversation history and also persist a
    /// corresponding UserMessage EventMsg to rollout.
    async fn record_input_and_rollout_usermsg(&self, response_input: &ResponseInputItem) {
@@ -807,7 +823,7 @@ impl Session {

    async fn on_exec_command_begin(
        &self,
        turn_diff_tracker: &mut TurnDiffTracker,
        turn_diff_tracker: SharedTurnDiffTracker,
        exec_command_context: ExecCommandContext,
    ) {
        let ExecCommandContext {
@@ -823,7 +839,10 @@ impl Session {
                user_explicitly_approved_this_action,
                changes,
            }) => {
                turn_diff_tracker.on_patch_begin(&changes);
                {
                    let mut tracker = turn_diff_tracker.lock().await;
                    tracker.on_patch_begin(&changes);
                }

                EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
                    call_id,
@@ -850,7 +869,7 @@ impl Session {

    async fn on_exec_command_end(
        &self,
        turn_diff_tracker: &mut TurnDiffTracker,
        turn_diff_tracker: SharedTurnDiffTracker,
        sub_id: &str,
        call_id: &str,
        output: &ExecToolCallOutput,
@@ -898,7 +917,10 @@ impl Session {
        // If this is an apply_patch, after we emit the end patch, emit a second event
        // with the full turn diff if there is one.
        if is_apply_patch {
            let unified_diff = turn_diff_tracker.get_unified_diff();
            let unified_diff = {
                let mut tracker = turn_diff_tracker.lock().await;
                tracker.get_unified_diff()
            };
            if let Ok(Some(unified_diff)) = unified_diff {
                let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
                let event = Event {
@@ -915,7 +937,7 @@ impl Session {
    /// Returns the output of the exec tool call.
    pub(crate) async fn run_exec_with_events(
        &self,
        turn_diff_tracker: &mut TurnDiffTracker,
        turn_diff_tracker: SharedTurnDiffTracker,
        prepared: PreparedExec,
        approval_policy: AskForApproval,
    ) -> Result<ExecToolCallOutput, ExecError> {
@@ -924,7 +946,7 @@ impl Session {
        let sub_id = context.sub_id.clone();
        let call_id = context.call_id.clone();

        self.on_exec_command_begin(turn_diff_tracker, context.clone())
        self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone())
            .await;

        let result = self
@@ -1565,7 +1587,7 @@ async fn spawn_review_thread(

    // Seed the child task with the review prompt as the initial user message.
    let input: Vec<InputItem> = vec![InputItem::Text {
        text: format!("{base_instructions}\n\n---\n\nNow, here's your task: {review_prompt}"),
        text: review_prompt,
    }];
    let tc = Arc::new(review_turn_context);

@@ -1633,7 +1655,7 @@ pub(crate) async fn run_task(
    let mut last_agent_message: Option<String> = None;
    // Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
    // many turns, from the perspective of the user, it is a single turn.
    let mut turn_diff_tracker = TurnDiffTracker::new();
    let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
    let mut auto_compact_recently_attempted = false;

    loop {
@@ -1681,9 +1703,9 @@ pub(crate) async fn run_task(
            })
            .collect();
        match run_turn(
            &sess,
            turn_context.as_ref(),
            &mut turn_diff_tracker,
            Arc::clone(&sess),
            Arc::clone(&turn_context),
            Arc::clone(&turn_diff_tracker),
            sub_id.clone(),
            turn_input,
        )
@@ -1906,18 +1928,27 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
}

async fn run_turn(
    sess: &Session,
    turn_context: &TurnContext,
    turn_diff_tracker: &mut TurnDiffTracker,
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    turn_diff_tracker: SharedTurnDiffTracker,
    sub_id: String,
    input: Vec<ResponseItem>,
) -> CodexResult<TurnRunResult> {
    let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
    let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools));
    let router = Arc::new(ToolRouter::from_config(
        &turn_context.tools_config,
        Some(mcp_tools),
    ));

    let model_supports_parallel = turn_context
        .client
        .get_model_family()
        .supports_parallel_tool_calls;
    let parallel_tool_calls = model_supports_parallel;
    let prompt = Prompt {
        input,
        tools: router.specs().to_vec(),
        tools: router.specs(),
        parallel_tool_calls,
        base_instructions_override: turn_context.base_instructions.clone(),
        output_schema: turn_context.final_output_json_schema.clone(),
    };
@@ -1925,10 +1956,10 @@ async fn run_turn(
    let mut retries = 0;
    loop {
        match try_run_turn(
            &router,
            sess,
            turn_context,
            turn_diff_tracker,
            Arc::clone(&router),
            Arc::clone(&sess),
            Arc::clone(&turn_context),
            Arc::clone(&turn_diff_tracker),
            &sub_id,
            &prompt,
        )
@@ -1938,6 +1969,10 @@ async fn run_turn(
            Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
            Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
            Err(e @ CodexErr::Fatal(_)) => return Err(e),
            Err(e @ CodexErr::ContextWindowExceeded) => {
                sess.set_total_tokens_full(&sub_id, &turn_context).await;
                return Err(e);
            }
            Err(CodexErr::UsageLimitReached(e)) => {
                let rate_limits = e.rate_limits.clone();
                if let Some(rate_limits) = rate_limits {
@@ -1984,9 +2019,9 @@ async fn run_turn(
/// "handled" such that it produces a `ResponseInputItem` that needs to be
/// sent back to the model on the next turn.
#[derive(Debug)]
struct ProcessedResponseItem {
    item: ResponseItem,
    response: Option<ResponseInputItem>,
pub(crate) struct ProcessedResponseItem {
    pub(crate) item: ResponseItem,
    pub(crate) response: Option<ResponseInputItem>,
}

#[derive(Debug)]
@@ -1996,10 +2031,10 @@ struct TurnRunResult {
}

async fn try_run_turn(
    router: &crate::tools::ToolRouter,
    sess: &Session,
    turn_context: &TurnContext,
    turn_diff_tracker: &mut TurnDiffTracker,
    router: Arc<ToolRouter>,
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    turn_diff_tracker: SharedTurnDiffTracker,
    sub_id: &str,
    prompt: &Prompt,
) -> CodexResult<TurnRunResult> {
@@ -2069,44 +2104,102 @@ async fn try_run_turn(
    sess.persist_rollout_items(&[rollout_item]).await;
    let mut stream = turn_context.client.clone().stream(&prompt).await?;

    let mut output = Vec::new();
    let tool_runtime = ToolCallRuntime::new(
        Arc::clone(&router),
        Arc::clone(&sess),
        Arc::clone(&turn_context),
        Arc::clone(&turn_diff_tracker),
        sub_id.to_string(),
    );
    let mut output: FuturesOrdered<BoxFuture<CodexResult<ProcessedResponseItem>>> =
        FuturesOrdered::new();

    loop {
        // Poll the next item from the model stream. We must inspect *both* Ok and Err
        // cases so that transient stream failures (e.g., dropped SSE connection before
        // `response.completed`) bubble up and trigger the caller's retry logic.
        let event = stream.next().await;
        let Some(event) = event else {
            // Channel closed without yielding a final Completed event or explicit error.
            // Treat as a disconnected stream so the caller can retry.
            return Err(CodexErr::Stream(
                "stream closed before response.completed".into(),
                None,
            ));
        let event = match event {
            Some(res) => res?,
            None => {
                return Err(CodexErr::Stream(
                    "stream closed before response.completed".into(),
                    None,
                ));
            }
        };

        let event = match event {
            Ok(ev) => ev,
            Err(e) => {
                // Propagate the underlying stream error to the caller (run_turn), which
                // will apply the configured `stream_max_retries` policy.
                return Err(e);
            }
        let add_completed = &mut |response_item: ProcessedResponseItem| {
            output.push_back(future::ready(Ok(response_item)).boxed());
        };

        match event {
            ResponseEvent::Created => {}
            ResponseEvent::OutputItemDone(item) => {
                let response = handle_response_item(
                    router,
                    sess,
                    turn_context,
                    turn_diff_tracker,
                    sub_id,
                    item.clone(),
                )
                .await?;
                output.push(ProcessedResponseItem { item, response });
                match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) {
                    Ok(Some(call)) => {
                        let payload_preview = call.payload.log_payload().into_owned();
                        tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);

                        let response = tool_runtime.handle_tool_call(call);

                        output.push_back(
                            async move {
                                Ok(ProcessedResponseItem {
                                    item,
                                    response: Some(response.await?),
                                })
                            }
                            .boxed(),
                        );
                    }
                    Ok(None) => {
                        let response = handle_non_tool_response_item(
                            Arc::clone(&sess),
                            Arc::clone(&turn_context),
                            sub_id,
                            item.clone(),
                        )
                        .await?;
                        add_completed(ProcessedResponseItem { item, response });
                    }
                    Err(FunctionCallError::MissingLocalShellCallId) => {
                        let msg = "LocalShellCall without call_id or id";
                        turn_context
                            .client
                            .get_otel_event_manager()
                            .log_tool_failed("local_shell", msg);
                        error!(msg);

                        let response = ResponseInputItem::FunctionCallOutput {
                            call_id: String::new(),
                            output: FunctionCallOutputPayload {
                                content: msg.to_string(),
                                success: None,
                            },
                        };
                        add_completed(ProcessedResponseItem {
                            item,
                            response: Some(response),
                        });
                    }
                    Err(FunctionCallError::RespondToModel(message)) => {
                        let response = ResponseInputItem::FunctionCallOutput {
                            call_id: String::new(),
                            output: FunctionCallOutputPayload {
                                content: message,
                                success: None,
                            },
                        };
                        add_completed(ProcessedResponseItem {
                            item,
                            response: Some(response),
                        });
                    }
                    Err(FunctionCallError::Fatal(message)) => {
                        return Err(CodexErr::Fatal(message));
                    }
                }
            }
            ResponseEvent::WebSearchCallBegin { call_id } => {
                let _ = sess
@@ -2126,10 +2219,15 @@ async fn try_run_turn(
                response_id: _,
                token_usage,
            } => {
                sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
                sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
                    .await;

                let unified_diff = turn_diff_tracker.get_unified_diff();
                let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;

                let unified_diff = {
                    let mut tracker = turn_diff_tracker.lock().await;
                    tracker.get_unified_diff()
                };
                if let Ok(Some(unified_diff)) = unified_diff {
                    let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
                    let event = Event {
@@ -2140,7 +2238,7 @@ async fn try_run_turn(
                }

                let result = TurnRunResult {
                    processed_items: output,
                    processed_items,
                    total_token_usage: token_usage.clone(),
                };

@@ -2188,88 +2286,40 @@ async fn try_run_turn(
    }
}

async fn handle_response_item(
    router: &crate::tools::ToolRouter,
    sess: &Session,
    turn_context: &TurnContext,
    turn_diff_tracker: &mut TurnDiffTracker,
async fn handle_non_tool_response_item(
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    sub_id: &str,
    item: ResponseItem,
) -> CodexResult<Option<ResponseInputItem>> {
    debug!(?item, "Output item");

    match ToolRouter::build_tool_call(sess, item.clone()) {
        Ok(Some(call)) => {
            let payload_preview = call.payload.log_payload().into_owned();
            tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
            match router
                .dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call)
                .await
            {
                Ok(response) => Ok(Some(response)),
                Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
                Err(other) => unreachable!("non-fatal tool error returned: {other:?}"),
    match &item {
        ResponseItem::Message { .. }
        | ResponseItem::Reasoning { .. }
        | ResponseItem::WebSearchCall { .. } => {
            let msgs = match &item {
                ResponseItem::Message { .. } if turn_context.is_review_mode => {
                    trace!("suppressing assistant Message in review mode");
                    Vec::new()
                }
                _ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
            };
            for msg in msgs {
                let event = Event {
                    id: sub_id.to_string(),
                    msg,
                };
                sess.send_event(event).await;
            }
        }
        Ok(None) => {
            match &item {
                ResponseItem::Message { .. }
                | ResponseItem::Reasoning { .. }
                | ResponseItem::WebSearchCall { .. } => {
                    let msgs = match &item {
                        ResponseItem::Message { .. } if turn_context.is_review_mode => {
                            trace!("suppressing assistant Message in review mode");
                            Vec::new()
                        }
                        _ => map_response_item_to_event_messages(
                            &item,
                            sess.show_raw_agent_reasoning(),
                        ),
                    };
                    for msg in msgs {
                        let event = Event {
                            id: sub_id.to_string(),
                            msg,
                        };
                        sess.send_event(event).await;
                    }
                }
                ResponseItem::FunctionCallOutput { .. }
                | ResponseItem::CustomToolCallOutput { .. } => {
                    debug!("unexpected tool output from stream");
                }
                _ => {}
            }

            Ok(None)
        ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
            debug!("unexpected tool output from stream");
        }
        Err(FunctionCallError::MissingLocalShellCallId) => {
            let msg = "LocalShellCall without call_id or id";
            turn_context
                .client
                .get_otel_event_manager()
                .log_tool_failed("local_shell", msg);
            error!(msg);

            Ok(Some(ResponseInputItem::FunctionCallOutput {
                call_id: String::new(),
                output: FunctionCallOutputPayload {
                    content: msg.to_string(),
                    success: None,
                },
            }))
        }
        Err(FunctionCallError::RespondToModel(msg)) => {
            Ok(Some(ResponseInputItem::FunctionCallOutput {
                call_id: String::new(),
                output: FunctionCallOutputPayload {
                    content: msg,
                    success: None,
                },
            }))
        }
        Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
        _ => {}
    }

    Ok(None)
}

pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
@@ -2505,13 +2555,19 @@ mod tests {

        let out = format_exec_output_str(&exec);

        // Strip truncation header if present for subsequent assertions
        let body = out
            .strip_prefix("Total output lines: ")
            .and_then(|rest| rest.split_once("\n\n").map(|x| x.1))
            .unwrap_or(out.as_str());

        // Expect elision marker with correct counts
        let omitted = 400 - MODEL_FORMAT_MAX_LINES; // 144
        let marker = format!("\n[... omitted {omitted} of 400 lines ...]\n\n");
        assert!(out.contains(&marker), "missing marker: {out}");

        // Validate head and tail
        let parts: Vec<&str> = out.split(&marker).collect();
        let parts: Vec<&str> = body.split(&marker).collect();
        assert_eq!(parts.len(), 2, "expected one marker split");
        let head = parts[0];
        let tail = parts[1];
@@ -2547,14 +2603,19 @@ mod tests {
        };

        let out = format_exec_output_str(&exec);
        assert!(out.len() <= MODEL_FORMAT_MAX_BYTES, "exceeds byte budget");
        // Keep strict budget on the truncated body (excluding header)
        let body = out
            .strip_prefix("Total output lines: ")
            .and_then(|rest| rest.split_once("\n\n").map(|x| x.1))
            .unwrap_or(out.as_str());
        assert!(body.len() <= MODEL_FORMAT_MAX_BYTES, "exceeds byte budget");
        assert!(out.contains("omitted"), "should contain elision marker");

        // Ensure head and tail are drawn from the original
        assert!(full.starts_with(out.chars().take(8).collect::<String>().as_str()));
        assert!(full.starts_with(body.chars().take(8).collect::<String>().as_str()));
        assert!(
            full.ends_with(
                out.chars()
                body.chars()
                    .rev()
                    .take(8)
                    .collect::<String>()
@@ -2901,13 +2962,10 @@ mod tests {
    #[tokio::test]
    async fn fatal_tool_error_stops_turn_and_reports_error() {
        let (session, turn_context, _rx) = make_session_and_context_with_rx();
        let session_ref = session.as_ref();
        let turn_context_ref = turn_context.as_ref();
        let router = ToolRouter::from_config(
            &turn_context_ref.tools_config,
            Some(session_ref.services.mcp_connection_manager.list_all_tools()),
            &turn_context.tools_config,
            Some(session.services.mcp_connection_manager.list_all_tools()),
        );
        let mut tracker = TurnDiffTracker::new();
        let item = ResponseItem::CustomToolCall {
            id: None,
            status: None,
@@ -2916,22 +2974,26 @@ mod tests {
            input: "{}".to_string(),
        };

        let err = handle_response_item(
            &router,
            session_ref,
            turn_context_ref,
            &mut tracker,
            "sub-id",
            item,
        )
        .await
        .expect_err("expected fatal error");
        let call = ToolRouter::build_tool_call(session.as_ref(), item.clone())
            .expect("build tool call")
            .expect("tool call present");
        let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
        let err = router
            .dispatch_tool_call(
                Arc::clone(&session),
                Arc::clone(&turn_context),
                tracker,
                "sub-id".to_string(),
                call,
            )
            .await
            .expect_err("expected fatal error");

        match err {
            CodexErr::Fatal(message) => {
            FunctionCallError::Fatal(message) => {
                assert_eq!(message, "tool shell invoked with incompatible payload");
            }
            other => panic!("expected CodexErr::Fatal, got {other:?}"),
            other => panic!("expected FunctionCallError::Fatal, got {other:?}"),
        }
    }

@@ -3045,9 +3107,11 @@ mod tests {
        use crate::turn_diff_tracker::TurnDiffTracker;
        use std::collections::HashMap;

        let (session, mut turn_context) = make_session_and_context();
        let (session, mut turn_context_raw) = make_session_and_context();
        // Ensure policy is NOT OnRequest so the early rejection path triggers
        turn_context.approval_policy = AskForApproval::OnFailure;
        turn_context_raw.approval_policy = AskForApproval::OnFailure;
        let session = Arc::new(session);
        let mut turn_context = Arc::new(turn_context_raw);

        let params = ExecParams {
            command: if cfg!(windows) {
@@ -3075,7 +3139,7 @@ mod tests {
            ..params.clone()
        };

        let mut turn_diff_tracker = TurnDiffTracker::new();
        let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));

        let tool_name = "shell";
        let sub_id = "test-sub".to_string();
@@ -3084,9 +3148,9 @@ mod tests {
        let resp = handle_container_exec_with_params(
            tool_name,
            params,
            &session,
            &turn_context,
            &mut turn_diff_tracker,
            Arc::clone(&session),
            Arc::clone(&turn_context),
            Arc::clone(&turn_diff_tracker),
            sub_id,
            call_id,
        )
@@ -3105,14 +3169,16 @@ mod tests {

        // Now retry the same command WITHOUT escalated permissions; should succeed.
        // Force DangerFullAccess to avoid platform sandbox dependencies in tests.
        turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess;
        Arc::get_mut(&mut turn_context)
            .expect("unique turn context Arc")
            .sandbox_policy = SandboxPolicy::DangerFullAccess;

        let resp2 = handle_container_exec_with_params(
            tool_name,
            params2,
            &session,
            &turn_context,
            &mut turn_diff_tracker,
            Arc::clone(&session),
            Arc::clone(&turn_context),
            Arc::clone(&turn_diff_tracker),
            "test-sub".to_string(),
            "test-call-2".to_string(),
        )

@@ -103,6 +103,18 @@ async fn run_compact_task_inner(
        Err(CodexErr::Interrupted) => {
            return;
        }
        Err(e @ CodexErr::ContextWindowExceeded) => {
            sess.set_total_tokens_full(&sub_id, turn_context.as_ref())
                .await;
            let event = Event {
                id: sub_id.clone(),
                msg: EventMsg::Error(ErrorEvent {
                    message: e.to_string(),
                }),
            };
            sess.send_event(event).await;
            return;
        }
        Err(e) => {
            if retries < max_retries {
                retries += 1;

@@ -33,12 +33,15 @@ use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::config_types::Verbosity;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use dirs::home_dir;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::io::ErrorKind;
use std::path::Path;
use std::path::PathBuf;

use tempfile::NamedTempFile;
use toml::Value as TomlValue;
use toml_edit::Array as TomlArray;
@@ -46,7 +49,10 @@ use toml_edit::DocumentMut;
use toml_edit::Item as TomlItem;
use toml_edit::Table as TomlTable;

const OPENAI_DEFAULT_MODEL: &str = "gpt-5-codex";
#[cfg(target_os = "windows")]
pub const OPENAI_DEFAULT_MODEL: &str = "gpt-5";
#[cfg(not(target_os = "windows"))]
pub const OPENAI_DEFAULT_MODEL: &str = "gpt-5-codex";
const OPENAI_DEFAULT_REVIEW_MODEL: &str = "gpt-5-codex";
pub const GPT_5_CODEX_MEDIUM_MODEL: &str = "gpt-5-codex";

@@ -139,6 +145,15 @@ pub struct Config {
    /// Definition for MCP servers that Codex can reach out to for tool calls.
    pub mcp_servers: HashMap<String, McpServerConfig>,

    /// Preferred store for MCP OAuth credentials.
    /// keyring: Use an OS-specific keyring service.
    /// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.
    /// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2
    /// file: CODEX_HOME/.credentials.json
    /// This file will be readable to Codex and other applications running as the same user.
    /// auto (default): keyring if available, otherwise file.
    pub mcp_oauth_credentials_store_mode: OAuthCredentialsStoreMode,

    /// Combined provider map (defaults merged with user-defined overrides).
    pub model_providers: HashMap<String, ModelProviderInfo>,

@@ -206,6 +221,9 @@ pub struct Config {
    /// The active profile name used to derive this `Config` (if any).
    pub active_profile: Option<String>,

    /// Tracks whether the Windows onboarding screen has been acknowledged.
    pub windows_wsl_setup_acknowledged: bool,

    /// When true, disables burst-paste detection for typed input entirely.
    /// All characters are inserted as they are received, and no buffering
    /// or placeholder replacement will occur for fast keypress bursts.
@@ -295,12 +313,35 @@ pub async fn load_global_mcp_servers(
        return Ok(BTreeMap::new());
    };

    ensure_no_inline_bearer_tokens(servers_value)?;

    servers_value
        .clone()
        .try_into()
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
}

/// We briefly allowed plain text bearer_token fields in MCP server configs.
/// This warns people who recently added such fields; the check can be removed after a few months.
fn ensure_no_inline_bearer_tokens(value: &TomlValue) -> std::io::Result<()> {
    let Some(servers_table) = value.as_table() else {
        return Ok(());
    };

    for (server_name, server_value) in servers_table {
        if let Some(server_table) = server_value.as_table()
            && server_table.contains_key("bearer_token")
        {
            let message = format!(
                "mcp_servers.{server_name} uses unsupported `bearer_token`; set `bearer_token_env_var`."
            );
            return Err(std::io::Error::new(ErrorKind::InvalidData, message));
        }
    }

    Ok(())
}
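// Illustrative sketch (an assumption, not part of the diff): a servers table that
// still carries an inline bearer_token is rejected by the guard above.
fn demo_rejects_inline_bearer_token() {
    let servers: TomlValue = toml::from_str(
        r#"
        [docs]
        url = "https://example.com/mcp"
        bearer_token = "secret"
        "#,
    )
    .expect("valid TOML");
    assert!(ensure_no_inline_bearer_tokens(&servers).is_err());
}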
pub fn write_global_mcp_servers(
    codex_home: &Path,
    servers: &BTreeMap<String, McpServerConfig>,
@@ -349,10 +390,13 @@ pub fn write_global_mcp_servers(
                    entry["env"] = TomlItem::Table(env_table);
                }
            }
            McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
            McpServerTransportConfig::StreamableHttp {
                url,
                bearer_token_env_var,
            } => {
                entry["url"] = toml_edit::value(url.clone());
                if let Some(token) = bearer_token {
                    entry["bearer_token"] = toml_edit::value(token.clone());
                if let Some(env_var) = bearer_token_env_var {
                    entry["bearer_token_env_var"] = toml_edit::value(env_var.clone());
                }
            }
@@ -468,6 +512,29 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
    Ok(())
}

/// Persist the acknowledgement flag for the Windows onboarding screen.
pub fn set_windows_wsl_setup_acknowledged(
    codex_home: &Path,
    acknowledged: bool,
) -> anyhow::Result<()> {
    let config_path = codex_home.join(CONFIG_TOML_FILE);
    let mut doc = match std::fs::read_to_string(config_path.clone()) {
        Ok(s) => s.parse::<DocumentMut>()?,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
        Err(e) => return Err(e.into()),
    };

    doc["windows_wsl_setup_acknowledged"] = toml_edit::value(acknowledged);

    std::fs::create_dir_all(codex_home)?;

    let tmp_file = NamedTempFile::new_in(codex_home)?;
    std::fs::write(tmp_file.path(), doc.to_string())?;
    tmp_file.persist(config_path)?;

    Ok(())
}
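// Hypothetical usage sketch (not part of the diff): callers persist the flag once
// the onboarding screen is acknowledged; a later Config load then reads it back
// as windows_wsl_setup_acknowledged == true.
fn demo_acknowledge_onboarding(codex_home: &Path) -> anyhow::Result<()> {
    set_windows_wsl_setup_acknowledged(codex_home, true)
}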
fn ensure_profile_table<'a>(
    doc: &'a mut DocumentMut,
    profile_name: &str,
@@ -665,6 +732,14 @@ pub struct ConfigToml {
    #[serde(default)]
    pub mcp_servers: HashMap<String, McpServerConfig>,

    /// Preferred backend for storing MCP OAuth credentials.
    /// keyring: Use an OS-specific keyring service.
    /// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2
    /// file: Use a file in the Codex home directory.
    /// auto (default): Use the OS-specific keyring service if available, otherwise use a file.
    #[serde(default)]
    pub mcp_oauth_credentials_store: Option<OAuthCredentialsStoreMode>,

    /// User-defined provider entries that extend/override the built-in list.
    #[serde(default)]
    pub model_providers: HashMap<String, ModelProviderInfo>,
@@ -721,6 +796,7 @@ pub struct ConfigToml {
    pub experimental_use_exec_command_tool: Option<bool>,
    pub experimental_use_unified_exec_tool: Option<bool>,
    pub experimental_use_rmcp_client: Option<bool>,
    pub experimental_use_freeform_apply_patch: Option<bool>,

    pub projects: Option<HashMap<String, ProjectConfig>>,

@@ -734,6 +810,9 @@ pub struct ConfigToml {

    /// OTEL configuration.
    pub otel: Option<crate::config_types::OtelConfigToml>,

    /// Tracks whether the Windows onboarding screen has been acknowledged.
    pub windows_wsl_setup_acknowledged: Option<bool>,
}

impl From<ConfigToml> for UserSavedConfig {
@@ -1041,6 +1120,9 @@ impl Config {
            user_instructions,
            base_instructions,
            mcp_servers: cfg.mcp_servers,
            // The config.toml omits "_mode" because it's a config file. However, "_mode"
            // is important in code to differentiate the mode from the store implementation.
            mcp_oauth_credentials_store_mode: cfg.mcp_oauth_credentials_store.unwrap_or_default(),
            model_providers,
            project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES),
            project_doc_fallback_filenames: cfg
@@ -1079,7 +1161,9 @@ impl Config {
                .or(cfg.chatgpt_base_url)
                .unwrap_or("https://chatgpt.com/backend-api/".to_string()),
            include_plan_tool: include_plan_tool.unwrap_or(false),
            include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
            include_apply_patch_tool: include_apply_patch_tool
                .or(cfg.experimental_use_freeform_apply_patch)
                .unwrap_or(false),
            tools_web_search_request,
            use_experimental_streamable_shell_tool: cfg
                .experimental_use_exec_command_tool
@@ -1090,6 +1174,7 @@ impl Config {
            use_experimental_use_rmcp_client: cfg.experimental_use_rmcp_client.unwrap_or(false),
            include_view_image_tool,
            active_profile: active_profile_name,
            windows_wsl_setup_acknowledged: cfg.windows_wsl_setup_acknowledged.unwrap_or(false),
            disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
            tui_notifications: cfg
                .tui
@@ -1328,6 +1413,85 @@ exclude_slash_tmp = true
        );
    }

    #[test]
    fn config_defaults_to_auto_oauth_store_mode() -> std::io::Result<()> {
        let codex_home = TempDir::new()?;
        let cfg = ConfigToml::default();

        let config = Config::load_from_base_config_with_overrides(
            cfg,
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )?;

        assert_eq!(
            config.mcp_oauth_credentials_store_mode,
            OAuthCredentialsStoreMode::Auto,
        );

        Ok(())
    }

    #[test]
    fn config_honors_explicit_file_oauth_store_mode() -> std::io::Result<()> {
        let codex_home = TempDir::new()?;
        let cfg = ConfigToml {
            mcp_oauth_credentials_store: Some(OAuthCredentialsStoreMode::File),
            ..Default::default()
        };

        let config = Config::load_from_base_config_with_overrides(
            cfg,
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )?;

        assert_eq!(
            config.mcp_oauth_credentials_store_mode,
            OAuthCredentialsStoreMode::File,
        );

        Ok(())
    }

    #[tokio::test]
    async fn managed_config_overrides_oauth_store_mode() -> anyhow::Result<()> {
        let codex_home = TempDir::new()?;
        let managed_path = codex_home.path().join("managed_config.toml");
        let config_path = codex_home.path().join(CONFIG_TOML_FILE);

        std::fs::write(&config_path, "mcp_oauth_credentials_store = \"file\"\n")?;
        std::fs::write(&managed_path, "mcp_oauth_credentials_store = \"keyring\"\n")?;

        let overrides = crate::config_loader::LoaderOverrides {
            managed_config_path: Some(managed_path.clone()),
            #[cfg(target_os = "macos")]
            managed_preferences_base64: None,
        };

        let root_value = load_resolved_config(codex_home.path(), Vec::new(), overrides).await?;
        let cfg: ConfigToml = root_value.try_into().map_err(|e| {
            tracing::error!("Failed to deserialize overridden config: {e}");
            std::io::Error::new(std::io::ErrorKind::InvalidData, e)
        })?;
        assert_eq!(
            cfg.mcp_oauth_credentials_store,
            Some(OAuthCredentialsStoreMode::Keyring),
        );

        let final_config = Config::load_from_base_config_with_overrides(
            cfg,
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )?;
        assert_eq!(
            final_config.mcp_oauth_credentials_store_mode,
            OAuthCredentialsStoreMode::Keyring,
        );

        Ok(())
    }

    #[tokio::test]
    async fn load_global_mcp_servers_returns_empty_if_missing() -> anyhow::Result<()> {
        let codex_home = TempDir::new()?;
@@ -1435,6 +1599,31 @@ startup_timeout_ms = 2500
        Ok(())
    }

    #[tokio::test]
    async fn load_global_mcp_servers_rejects_inline_bearer_token() -> anyhow::Result<()> {
        let codex_home = TempDir::new()?;
        let config_path = codex_home.path().join(CONFIG_TOML_FILE);

        std::fs::write(
            &config_path,
            r#"
[mcp_servers.docs]
url = "https://example.com/mcp"
bearer_token = "secret"
"#,
        )?;

        let err = load_global_mcp_servers(codex_home.path())
            .await
            .expect_err("bearer_token entries should be rejected");

        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("bearer_token"));
        assert!(err.to_string().contains("bearer_token_env_var"));

        Ok(())
    }

    #[tokio::test]
    async fn write_global_mcp_servers_serializes_env_sorted() -> anyhow::Result<()> {
        let codex_home = TempDir::new()?;
@@ -1498,7 +1687,7 @@ ZIG_VAR = "3"
            McpServerConfig {
                transport: McpServerTransportConfig::StreamableHttp {
                    url: "https://example.com/mcp".to_string(),
                    bearer_token: Some("secret-token".to_string()),
                    bearer_token_env_var: Some("MCP_TOKEN".to_string()),
                },
                startup_timeout_sec: Some(Duration::from_secs(2)),
                tool_timeout_sec: None,
@@ -1513,7 +1702,7 @@ ZIG_VAR = "3"
            serialized,
            r#"[mcp_servers.docs]
url = "https://example.com/mcp"
bearer_token = "secret-token"
bearer_token_env_var = "MCP_TOKEN"
startup_timeout_sec = 2.0
"#
        );
@@ -1521,9 +1710,12 @@ startup_timeout_sec = 2.0
        let loaded = load_global_mcp_servers(codex_home.path()).await?;
        let docs = loaded.get("docs").expect("docs entry");
        match &docs.transport {
            McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
            McpServerTransportConfig::StreamableHttp {
                url,
                bearer_token_env_var,
            } => {
                assert_eq!(url, "https://example.com/mcp");
                assert_eq!(bearer_token.as_deref(), Some("secret-token"));
                assert_eq!(bearer_token_env_var.as_deref(), Some("MCP_TOKEN"));
            }
            other => panic!("unexpected transport {other:?}"),
        }
@@ -1534,7 +1726,7 @@ startup_timeout_sec = 2.0
            McpServerConfig {
                transport: McpServerTransportConfig::StreamableHttp {
                    url: "https://example.com/mcp".to_string(),
                    bearer_token: None,
                    bearer_token_env_var: None,
                },
                startup_timeout_sec: None,
                tool_timeout_sec: None,
@@ -1553,9 +1745,12 @@ url = "https://example.com/mcp"
        let loaded = load_global_mcp_servers(codex_home.path()).await?;
        let docs = loaded.get("docs").expect("docs entry");
        match &docs.transport {
            McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
            McpServerTransportConfig::StreamableHttp {
                url,
                bearer_token_env_var,
            } => {
                assert_eq!(url, "https://example.com/mcp");
                assert!(bearer_token.is_none());
                assert!(bearer_token_env_var.is_none());
            }
            other => panic!("unexpected transport {other:?}"),
        }
@@ -1860,6 +2055,7 @@ model_verbosity = "high"
            notify: None,
            cwd: fixture.cwd(),
            mcp_servers: HashMap::new(),
            mcp_oauth_credentials_store_mode: Default::default(),
            model_providers: fixture.model_provider_map.clone(),
            project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
            project_doc_fallback_filenames: Vec::new(),
@@ -1882,6 +2078,7 @@ model_verbosity = "high"
            use_experimental_use_rmcp_client: false,
            include_view_image_tool: true,
            active_profile: Some("o3".to_string()),
            windows_wsl_setup_acknowledged: false,
            disable_paste_burst: false,
            tui_notifications: Default::default(),
            otel: OtelConfig::default(),
@@ -1921,6 +2118,7 @@ model_verbosity = "high"
            notify: None,
            cwd: fixture.cwd(),
            mcp_servers: HashMap::new(),
            mcp_oauth_credentials_store_mode: Default::default(),
            model_providers: fixture.model_provider_map.clone(),
            project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
            project_doc_fallback_filenames: Vec::new(),
@@ -1943,6 +2141,7 @@ model_verbosity = "high"
            use_experimental_use_rmcp_client: false,
            include_view_image_tool: true,
            active_profile: Some("gpt3".to_string()),
            windows_wsl_setup_acknowledged: false,
            disable_paste_burst: false,
            tui_notifications: Default::default(),
            otel: OtelConfig::default(),
@@ -1997,6 +2196,7 @@ model_verbosity = "high"
            notify: None,
            cwd: fixture.cwd(),
            mcp_servers: HashMap::new(),
            mcp_oauth_credentials_store_mode: Default::default(),
            model_providers: fixture.model_provider_map.clone(),
            project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
            project_doc_fallback_filenames: Vec::new(),
@@ -2019,6 +2219,7 @@ model_verbosity = "high"
            use_experimental_use_rmcp_client: false,
            include_view_image_tool: true,
            active_profile: Some("zdr".to_string()),
            windows_wsl_setup_acknowledged: false,
            disable_paste_burst: false,
            tui_notifications: Default::default(),
            otel: OtelConfig::default(),
@@ -2059,6 +2260,7 @@ model_verbosity = "high"
            notify: None,
            cwd: fixture.cwd(),
            mcp_servers: HashMap::new(),
            mcp_oauth_credentials_store_mode: Default::default(),
            model_providers: fixture.model_provider_map.clone(),
            project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
            project_doc_fallback_filenames: Vec::new(),
@@ -2081,6 +2283,7 @@ model_verbosity = "high"
            use_experimental_use_rmcp_client: false,
            include_view_image_tool: true,
            active_profile: Some("gpt5".to_string()),
            windows_wsl_setup_acknowledged: false,
            disable_paste_burst: false,
            tui_notifications: Default::default(),
            otel: OtelConfig::default(),
@@ -2191,6 +2394,7 @@ trust_level = "trusted"
#[cfg(test)]
mod notifications_tests {
    use crate::config_types::Notifications;
    use assert_matches::assert_matches;
    use serde::Deserialize;

    #[derive(Deserialize, Debug, PartialEq)]
@@ -2210,10 +2414,7 @@ mod notifications_tests {
notifications = true
"#;
        let parsed: RootTomlTest = toml::from_str(toml).expect("deserialize notifications=true");
        assert!(matches!(
            parsed.tui.notifications,
            Notifications::Enabled(true)
        ));
        assert_matches!(parsed.tui.notifications, Notifications::Enabled(true));
    }

    #[test]
@@ -2224,9 +2425,9 @@ mod notifications_tests {
"#;
        let parsed: RootTomlTest =
            toml::from_str(toml).expect("deserialize notifications=[\"foo\"]");
        assert!(matches!(
        assert_matches!(
            parsed.tui.notifications,
            Notifications::Custom(ref v) if v == &vec!["foo".to_string()]
        ));
        );
    }
}

@@ -48,6 +48,7 @@ impl<'de> Deserialize<'de> for McpServerConfig {

        url: Option<String>,
        bearer_token: Option<String>,
        bearer_token_env_var: Option<String>,

        #[serde(default)]
        startup_timeout_sec: Option<f64>,
@@ -86,11 +87,15 @@ impl<'de> Deserialize<'de> for McpServerConfig {
                args,
                env,
                url,
                bearer_token,
                bearer_token_env_var,
                ..
            } => {
                throw_if_set("stdio", "url", url.as_ref())?;
                throw_if_set("stdio", "bearer_token", bearer_token.as_ref())?;
                throw_if_set(
                    "stdio",
                    "bearer_token_env_var",
                    bearer_token_env_var.as_ref(),
                )?;
                McpServerTransportConfig::Stdio {
                    command,
                    args: args.unwrap_or_default(),
@@ -100,6 +105,7 @@ impl<'de> Deserialize<'de> for McpServerConfig {
            RawMcpServerConfig {
                url: Some(url),
                bearer_token,
                bearer_token_env_var,
                command,
                args,
                env,
@@ -108,7 +114,11 @@ impl<'de> Deserialize<'de> for McpServerConfig {
                throw_if_set("streamable_http", "command", command.as_ref())?;
                throw_if_set("streamable_http", "args", args.as_ref())?;
                throw_if_set("streamable_http", "env", env.as_ref())?;
                McpServerTransportConfig::StreamableHttp { url, bearer_token }
                throw_if_set("streamable_http", "bearer_token", bearer_token.as_ref())?;
                McpServerTransportConfig::StreamableHttp {
                    url,
                    bearer_token_env_var,
                }
            }
            _ => return Err(SerdeError::custom("invalid transport")),
        };
@@ -135,11 +145,11 @@ pub enum McpServerTransportConfig {
    /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http
    StreamableHttp {
        url: String,
        /// A plain text bearer token to use for authentication.
        /// This bearer token will be included in the HTTP request header as an `Authorization: Bearer <token>` header.
        /// This should be used with caution because it lives on disk in clear text.
        /// Name of the environment variable to read for an HTTP bearer token.
        /// When set, requests will include the token via `Authorization: Bearer <token>`.
        /// The actual secret value must be provided via the environment.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        bearer_token: Option<String>,
        bearer_token_env_var: Option<String>,
    },
}
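// Sketch of the new shape, using the field names declared above: config carries
// only the environment variable name; the secret itself stays in the environment.
fn demo_streamable_http_transport() -> McpServerTransportConfig {
    McpServerTransportConfig::StreamableHttp {
        url: "https://example.com/mcp".to_string(),
        bearer_token_env_var: Some("GITHUB_TOKEN".to_string()),
    }
}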
@@ -506,17 +516,17 @@ mod tests {
            cfg.transport,
            McpServerTransportConfig::StreamableHttp {
                url: "https://example.com/mcp".to_string(),
                bearer_token: None
                bearer_token_env_var: None
            }
        );
    }

    #[test]
    fn deserialize_streamable_http_server_config_with_bearer_token() {
    fn deserialize_streamable_http_server_config_with_env_var() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
url = "https://example.com/mcp"
bearer_token = "secret"
bearer_token_env_var = "GITHUB_TOKEN"
"#,
        )
        .expect("should deserialize http config");
@@ -525,7 +535,7 @@ mod tests {
            cfg.transport,
            McpServerTransportConfig::StreamableHttp {
                url: "https://example.com/mcp".to_string(),
                bearer_token: Some("secret".to_string())
                bearer_token_env_var: Some("GITHUB_TOKEN".to_string())
            }
        );
    }
@@ -553,13 +563,18 @@ mod tests {
    }

    #[test]
    fn deserialize_rejects_bearer_token_for_stdio_transport() {
        toml::from_str::<McpServerConfig>(
    fn deserialize_rejects_inline_bearer_token_field() {
        let err = toml::from_str::<McpServerConfig>(
            r#"
command = "echo"
url = "https://example.com"
bearer_token = "secret"
"#,
        )
        .expect_err("should reject bearer token for stdio transport");
        .expect_err("should reject bearer_token field");

        assert!(
            err.to_string().contains("bearer_token is not supported"),
            "unexpected error: {err}"
        );
    }
}

@@ -210,6 +210,7 @@ fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> Initia
mod tests {
    use super::*;
    use crate::codex::make_session_and_context;
    use assert_matches::assert_matches;
    use codex_protocol::models::ContentItem;
    use codex_protocol::models::ReasoningItemReasoningSummary;
    use codex_protocol::models::ResponseItem;
@@ -236,7 +237,7 @@ mod tests {

    #[test]
    fn drops_from_last_user_only() {
        let items = vec![
        let items = [
            user_msg("u1"),
            assistant_msg("a1"),
            assistant_msg("a2"),
@@ -283,7 +284,7 @@ mod tests {
            .map(RolloutItem::ResponseItem)
            .collect();
        let truncated2 = truncate_before_nth_user_message(InitialHistory::Forked(initial2), 2);
        assert!(matches!(truncated2, InitialHistory::New));
        assert_matches!(truncated2, InitialHistory::New);
    }

    #[test]

@@ -20,7 +20,7 @@ use std::sync::OnceLock;
/// The full user agent string is returned from the MCP initialize response.
/// Parentheses will be added by Codex. This should only specify what goes inside the parentheses.
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));

pub const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";
pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE";
#[derive(Debug, Clone)]
pub struct Originator {
@@ -35,10 +35,11 @@ pub enum SetOriginatorError {
    AlreadyInitialized,
}

fn init_originator_from_env() -> Originator {
    let default = "codex_cli_rs";
fn get_originator_value(provided: Option<String>) -> Originator {
    let value = std::env::var(CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR)
        .unwrap_or_else(|_| default.to_string());
        .ok()
        .or(provided)
        .unwrap_or(DEFAULT_ORIGINATOR.to_string());

    match HeaderValue::from_str(&value) {
        Ok(header_value) => Originator {
@@ -48,31 +49,22 @@ fn init_originator_from_env() -> Originator {
        Err(e) => {
            tracing::error!("Unable to turn originator override {value} into header value: {e}");
            Originator {
                value: default.to_string(),
                header_value: HeaderValue::from_static(default),
                value: DEFAULT_ORIGINATOR.to_string(),
                header_value: HeaderValue::from_static(DEFAULT_ORIGINATOR),
            }
        }
    }
}

fn build_originator(value: String) -> Result<Originator, SetOriginatorError> {
    let header_value =
        HeaderValue::from_str(&value).map_err(|_| SetOriginatorError::InvalidHeaderValue)?;
    Ok(Originator {
        value,
        header_value,
    })
}

pub fn set_default_originator(value: &str) -> Result<(), SetOriginatorError> {
    let originator = build_originator(value.to_string())?;
pub fn set_default_originator(value: String) -> Result<(), SetOriginatorError> {
    let originator = get_originator_value(Some(value));
    ORIGINATOR
        .set(originator)
        .map_err(|_| SetOriginatorError::AlreadyInitialized)
}

pub fn originator() -> &'static Originator {
    ORIGINATOR.get_or_init(init_originator_from_env)
    ORIGINATOR.get_or_init(|| get_originator_value(None))
}
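// Hypothetical usage sketch of the precedence implemented above: the
// CODEX_INTERNAL_ORIGINATOR_OVERRIDE env var wins, then the value passed to
// set_default_originator, then DEFAULT_ORIGINATOR.
fn demo_set_originator() {
    set_default_originator("codex_vscode".to_string()).expect("originator not yet initialized");
    // Assuming the env override is unset:
    assert_eq!(originator().value, "codex_vscode");
}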
pub fn get_codex_user_agent() -> String {

@@ -55,6 +55,11 @@ pub enum CodexErr {
    #[error("stream disconnected before completion: {0}")]
    Stream(String, Option<Duration>),

    #[error(
        "Codex ran out of room in the model's context window. Start a new conversation or clear earlier history before retrying."
    )]
    ContextWindowExceeded,

    #[error("no conversation with id: {0}")]
    ConversationNotFound(ConversationId),

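// Minimal sketch: the new variant carries a fixed, user-facing message via
// thiserror's #[error] attribute, so no payload is needed.
fn demo_context_window_error() {
    let err = CodexErr::ContextWindowExceeded;
    assert!(err.to_string().contains("context window"));
}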
@@ -127,6 +127,7 @@ mod tests {
    use super::map_response_item_to_event_messages;
    use crate::protocol::EventMsg;
    use crate::protocol::InputMessageKind;
    use assert_matches::assert_matches;
    use codex_protocol::models::ContentItem;
    use codex_protocol::models::ResponseItem;
    use pretty_assertions::assert_eq;

@@ -158,7 +159,7 @@ mod tests {
    match &events[0] {
        EventMsg::UserMessage(user) => {
            assert_eq!(user.message, "Hello world");
            assert!(matches!(user.kind, Some(InputMessageKind::Plain)));
            assert_matches!(user.kind, Some(InputMessageKind::Plain));
            assert_eq!(user.images, Some(vec![img1, img2]));
        }
        other => panic!("expected UserMessage, got {other:?}"),
@@ -8,6 +8,7 @@

use std::collections::HashMap;
use std::collections::HashSet;
use std::env;
use std::ffi::OsString;
use std::sync::Arc;
use std::time::Duration;

@@ -16,6 +17,7 @@ use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use codex_mcp_client::McpClient;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use codex_rmcp_client::RmcpClient;
use mcp_types::ClientCapabilities;
use mcp_types::Implementation;

@@ -108,9 +110,6 @@ impl McpClientAdapter {
        params: mcp_types::InitializeRequestParams,
        startup_timeout: Duration,
    ) -> Result<Self> {
        info!(
            "new_stdio_client use_rmcp_client: {use_rmcp_client} program: {program:?} args: {args:?} env: {env:?} params: {params:?} startup_timeout: {startup_timeout:?}"
        );
        if use_rmcp_client {
            let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?);
            client.initialize(params, Some(startup_timeout)).await?;

@@ -128,9 +127,11 @@ impl McpClientAdapter {
        bearer_token: Option<String>,
        params: mcp_types::InitializeRequestParams,
        startup_timeout: Duration,
        store_mode: OAuthCredentialsStoreMode,
    ) -> Result<Self> {
        let client = Arc::new(
            RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token).await?,
            RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode)
                .await?,
        );
        client.initialize(params, Some(startup_timeout)).await?;
        Ok(McpClientAdapter::Rmcp(client))

@@ -185,6 +186,7 @@ impl McpConnectionManager {
    pub async fn new(
        mcp_servers: HashMap<String, McpServerConfig>,
        use_rmcp_client: bool,
        store_mode: OAuthCredentialsStoreMode,
    ) -> Result<(Self, ClientStartErrors)> {
        // Early exit if no servers are configured.
        if mcp_servers.is_empty() {

@@ -205,20 +207,17 @@ impl McpConnectionManager {
                continue;
            }

            if matches!(
                cfg.transport,
                McpServerTransportConfig::StreamableHttp { .. }
            ) && !use_rmcp_client
            {
                info!(
                    "skipping MCP server `{server_name}` because the legacy MCP client only supports stdio servers",
                );
                continue;
            }

            let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
            let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);

            let resolved_bearer_token = match &cfg.transport {
                McpServerTransportConfig::StreamableHttp {
                    bearer_token_env_var,
                    ..
                } => resolve_bearer_token(&server_name, bearer_token_env_var.as_deref()),
                _ => Ok(None),
            };

            join_set.spawn(async move {
                let McpServerConfig { transport, .. } = cfg;
                let params = mcp_types::InitializeRequestParams {

@@ -256,13 +255,14 @@ impl McpConnectionManager {
                    )
                    .await
                }
                McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
                McpServerTransportConfig::StreamableHttp { url, .. } => {
                    McpClientAdapter::new_streamable_http_client(
                        server_name.clone(),
                        url,
                        bearer_token,
                        resolved_bearer_token.unwrap_or_default(),
                        params,
                        startup_timeout,
                        store_mode,
                    )
                    .await
                }

@@ -350,6 +350,33 @@ impl McpConnectionManager {
    }
}

fn resolve_bearer_token(
    server_name: &str,
    bearer_token_env_var: Option<&str>,
) -> Result<Option<String>> {
    let Some(env_var) = bearer_token_env_var else {
        return Ok(None);
    };

    match env::var(env_var) {
        Ok(value) => {
            if value.is_empty() {
                Err(anyhow!(
                    "Environment variable {env_var} for MCP server '{server_name}' is empty"
                ))
            } else {
                Ok(Some(value))
            }
        }
        Err(env::VarError::NotPresent) => Err(anyhow!(
            "Environment variable {env_var} for MCP server '{server_name}' is not set"
        )),
        Err(env::VarError::NotUnicode(_)) => Err(anyhow!(
            "Environment variable {env_var} for MCP server '{server_name}' contains invalid Unicode"
        )),
    }
}

/// Query every server for its available tools and return a single map that
/// contains **all** tools. Each key is the fully-qualified name for the tool.
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
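Reviewer note: `resolve_bearer_token` deliberately treats a configured-but-unusable env var (missing, empty, or non-Unicode) as a hard error instead of silently starting the server without credentials. A standalone sketch of that contract (hypothetical function with a simplified error type, not the crate's API):

use std::env;

fn resolve_token(server: &str, var: Option<&str>) -> Result<Option<String>, String> {
    // No env var configured: no token, and that is not an error.
    let Some(var) = var else { return Ok(None) };
    match env::var(var) {
        Ok(v) if v.is_empty() => Err(format!("{var} for MCP server '{server}' is empty")),
        Ok(v) => Ok(Some(v)),
        Err(env::VarError::NotPresent) => Err(format!("{var} for MCP server '{server}' is not set")),
        Err(env::VarError::NotUnicode(_)) => {
            Err(format!("{var} for MCP server '{server}' is not valid Unicode"))
        }
    }
}

fn main() {
    // With no env var configured, absence is not an error.
    assert_eq!(resolve_token("docs", None), Ok(None));
}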
@@ -35,6 +35,10 @@ pub struct ModelFamily {
    // See https://platform.openai.com/docs/guides/tools-local-shell
    pub uses_local_shell_tool: bool,

    /// Whether this model supports parallel tool calls when using the
    /// Responses API.
    pub supports_parallel_tool_calls: bool,

    /// Present if the model performs better when `apply_patch` is provided as
    /// a tool call instead of just a bash command
    pub apply_patch_tool_type: Option<ApplyPatchToolType>,

@@ -58,6 +62,7 @@ macro_rules! model_family {
            supports_reasoning_summaries: false,
            reasoning_summary_format: ReasoningSummaryFormat::None,
            uses_local_shell_tool: false,
            supports_parallel_tool_calls: false,
            apply_patch_tool_type: None,
            base_instructions: BASE_INSTRUCTIONS.to_string(),
            experimental_supported_tools: Vec::new(),

@@ -72,7 +77,11 @@ macro_rules! model_family {

/// Returns a `ModelFamily` for the given model slug, or `None` if the slug
/// does not match any known model family.
pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
pub fn find_family_for_model(mut slug: &str) -> Option<ModelFamily> {
    // TODO(jif) clean once we have proper feature flags
    if matches!(std::env::var("CODEX_EXPERIMENTAL").as_deref(), Ok("1")) {
        slug = "codex-experimental";
    }
    if slug.starts_with("o3") {
        model_family!(
            slug, "o3",

@@ -103,13 +112,40 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
        model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
    } else if slug.starts_with("gpt-3.5") {
        model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
    } else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") {
    } else if slug.starts_with("test-gpt-5-codex") {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
            experimental_supported_tools: vec!["read_file".to_string()],
            experimental_supported_tools: vec![
                "read_file".to_string(),
                "list_dir".to_string(),
                "test_sync_tool".to_string()
            ],
            supports_parallel_tool_calls: true,
        )

    // Internal models.
    } else if slug.starts_with("codex-") {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
            apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
            experimental_supported_tools: vec!["read_file".to_string(), "list_dir".to_string()],
            supports_parallel_tool_calls: true,
        )

    // Production models.
    } else if slug.starts_with("gpt-5-codex") {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
            apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
        )
    } else if slug.starts_with("gpt-5") {
        model_family!(

@@ -130,6 +166,7 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily {
        supports_reasoning_summaries: false,
        reasoning_summary_format: ReasoningSummaryFormat::None,
        uses_local_shell_tool: false,
        supports_parallel_tool_calls: false,
        apply_patch_tool_type: None,
        base_instructions: BASE_INSTRUCTIONS.to_string(),
        experimental_supported_tools: Vec::new(),
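Reviewer note: since matching is prefix-based, branch order is load-bearing here: `test-gpt-5-codex` must be tested before the `codex-` and `gpt-5-codex` branches, and `gpt-5-codex` before `gpt-5`. A toy illustration of the ordering concern with made-up labels (not the real families):

fn family(slug: &str) -> &'static str {
    if slug.starts_with("test-gpt-5-codex") {
        "test"
    } else if slug.starts_with("codex-") {
        "internal"
    } else if slug.starts_with("gpt-5-codex") {
        "codex"
    } else if slug.starts_with("gpt-5") {
        "gpt-5"
    } else {
        "unknown"
    }
}

fn main() {
    // Would wrongly resolve to "gpt-5" if the gpt-5-codex branch came after gpt-5.
    assert_eq!(family("gpt-5-codex-high"), "codex");
    assert_eq!(family("test-gpt-5-codex"), "test");
}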
@@ -64,5 +64,14 @@ impl SessionState {
        (self.token_info.clone(), self.latest_rate_limits.clone())
    }

    pub(crate) fn set_token_usage_full(&mut self, context_window: u64) {
        match &mut self.token_info {
            Some(info) => info.fill_to_context_window(context_window),
            None => {
                self.token_info = Some(TokenUsageInfo::full_context_window(context_window));
            }
        }
    }

    // Pending input/approval moved to TurnState.
}
@@ -14,12 +14,17 @@ use mcp_types::CallToolResult;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;

pub struct ToolInvocation<'a> {
    pub session: &'a Session,
    pub turn: &'a TurnContext,
    pub tracker: &'a mut TurnDiffTracker,
    pub sub_id: &'a str,
pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;

#[derive(Clone)]
pub struct ToolInvocation {
    pub session: Arc<Session>,
    pub turn: Arc<TurnContext>,
    pub tracker: SharedTurnDiffTracker,
    pub sub_id: String,
    pub call_id: String,
    pub tool_name: String,
    pub payload: ToolPayload,
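Reviewer note: moving `ToolInvocation` from borrowed references to owned `Arc`s (with the diff tracker behind `Arc<Mutex<…>>`) is what allows an invocation to be cloned into a spawned task for parallel dispatch. A minimal sketch of the sharing pattern, assuming a tokio runtime and a stand-in tracker type:

use std::sync::Arc;
use tokio::sync::Mutex;

#[derive(Default)]
struct Tracker { events: Vec<String> } // stand-in for TurnDiffTracker

type Shared = Arc<Mutex<Tracker>>;

#[tokio::main]
async fn main() {
    let tracker: Shared = Arc::new(Mutex::new(Tracker::default()));
    let handles: Vec<_> = (0..3)
        .map(|i| {
            let tracker = Arc::clone(&tracker); // each spawned task owns a handle
            tokio::spawn(async move {
                tracker.lock().await.events.push(format!("call {i}"));
            })
        })
        .collect();
    for h in handles {
        h.await.unwrap();
    }
    assert_eq!(tracker.lock().await.events.len(), 3);
}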
@@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;

use crate::client_common::tools::FreeformTool;
use crate::client_common::tools::FreeformToolFormat;

@@ -36,10 +37,7 @@ impl ToolHandler for ApplyPatchHandler {
        )
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,

@@ -79,10 +77,10 @@ impl ToolHandler for ApplyPatchHandler {
        let content = handle_container_exec_with_params(
            tool_name.as_str(),
            exec_params,
            session,
            turn,
            tracker,
            sub_id.to_string(),
            Arc::clone(&session),
            Arc::clone(&turn),
            Arc::clone(&tracker),
            sub_id.clone(),
            call_id.clone(),
        )
        .await?;

@@ -106,7 +104,7 @@ pub enum ApplyPatchToolType {
pub(crate) fn create_apply_patch_freeform_tool() -> ToolSpec {
    ToolSpec::Freeform(FreeformTool {
        name: "apply_patch".to_string(),
        description: "Use the `apply_patch` tool to edit files".to_string(),
        description: "Use the `apply_patch` tool to edit files. This is a FREEFORM tool, so do not wrap the patch in JSON.".to_string(),
        format: FreeformToolFormat {
            r#type: "grammar".to_string(),
            syntax: "lark".to_string(),
@@ -19,10 +19,7 @@ impl ToolHandler for ExecStreamHandler {
        ToolKind::Function
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            tool_name,
codex-rs/core/src/tools/handlers/list_dir.rs (new file, 476 lines)
@@ -0,0 +1,476 @@
use std::collections::VecDeque;
use std::ffi::OsStr;
use std::fs::FileType;
use std::path::Path;
use std::path::PathBuf;

use async_trait::async_trait;
use codex_utils_string::take_bytes_at_char_boundary;
use serde::Deserialize;
use tokio::fs;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ListDirHandler;

const MAX_ENTRY_LENGTH: usize = 500;
const INDENTATION_SPACES: usize = 2;

fn default_offset() -> usize {
    1
}

fn default_limit() -> usize {
    25
}

fn default_depth() -> usize {
    2
}

#[derive(Deserialize)]
struct ListDirArgs {
    dir_path: String,
    #[serde(default = "default_offset")]
    offset: usize,
    #[serde(default = "default_limit")]
    limit: usize,
    #[serde(default = "default_depth")]
    depth: usize,
}

#[async_trait]
impl ToolHandler for ListDirHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "list_dir handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: ListDirArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        let ListDirArgs {
            dir_path,
            offset,
            limit,
            depth,
        } = args;

        if offset == 0 {
            return Err(FunctionCallError::RespondToModel(
                "offset must be a 1-indexed entry number".to_string(),
            ));
        }

        if limit == 0 {
            return Err(FunctionCallError::RespondToModel(
                "limit must be greater than zero".to_string(),
            ));
        }

        if depth == 0 {
            return Err(FunctionCallError::RespondToModel(
                "depth must be greater than zero".to_string(),
            ));
        }

        let path = PathBuf::from(&dir_path);
        if !path.is_absolute() {
            return Err(FunctionCallError::RespondToModel(
                "dir_path must be an absolute path".to_string(),
            ));
        }

        let entries = list_dir_slice(&path, offset, limit, depth).await?;
        let mut output = Vec::with_capacity(entries.len() + 1);
        output.push(format!("Absolute path: {}", path.display()));
        output.extend(entries);
        Ok(ToolOutput::Function {
            content: output.join("\n"),
            success: Some(true),
        })
    }
}

async fn list_dir_slice(
    path: &Path,
    offset: usize,
    limit: usize,
    depth: usize,
) -> Result<Vec<String>, FunctionCallError> {
    let mut entries = Vec::new();
    collect_entries(path, Path::new(""), depth, &mut entries).await?;

    if entries.is_empty() {
        return Ok(Vec::new());
    }

    let start_index = offset - 1;
    if start_index >= entries.len() {
        return Err(FunctionCallError::RespondToModel(
            "offset exceeds directory entry count".to_string(),
        ));
    }

    let remaining_entries = entries.len() - start_index;
    let capped_limit = limit.min(remaining_entries);
    let end_index = start_index + capped_limit;
    let mut selected_entries = entries[start_index..end_index].to_vec();
    selected_entries.sort_unstable_by(|a, b| a.name.cmp(&b.name));
    let mut formatted = Vec::with_capacity(selected_entries.len());

    for entry in &selected_entries {
        formatted.push(format_entry_line(entry));
    }

    if end_index < entries.len() {
        formatted.push(format!("More than {capped_limit} entries found"));
    }

    Ok(formatted)
}

async fn collect_entries(
    dir_path: &Path,
    relative_prefix: &Path,
    depth: usize,
    entries: &mut Vec<DirEntry>,
) -> Result<(), FunctionCallError> {
    let mut queue = VecDeque::new();
    queue.push_back((dir_path.to_path_buf(), relative_prefix.to_path_buf(), depth));

    while let Some((current_dir, prefix, remaining_depth)) = queue.pop_front() {
        let mut read_dir = fs::read_dir(&current_dir).await.map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to read directory: {err}"))
        })?;

        let mut dir_entries = Vec::new();

        while let Some(entry) = read_dir.next_entry().await.map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to read directory: {err}"))
        })? {
            let file_type = entry.file_type().await.map_err(|err| {
                FunctionCallError::RespondToModel(format!("failed to inspect entry: {err}"))
            })?;

            let file_name = entry.file_name();
            let relative_path = if prefix.as_os_str().is_empty() {
                PathBuf::from(&file_name)
            } else {
                prefix.join(&file_name)
            };

            let display_name = format_entry_component(&file_name);
            let display_depth = prefix.components().count();
            let sort_key = format_entry_name(&relative_path);
            let kind = DirEntryKind::from(&file_type);
            dir_entries.push((
                entry.path(),
                relative_path,
                kind,
                DirEntry {
                    name: sort_key,
                    display_name,
                    depth: display_depth,
                    kind,
                },
            ));
        }

        dir_entries.sort_unstable_by(|a, b| a.3.name.cmp(&b.3.name));

        for (entry_path, relative_path, kind, dir_entry) in dir_entries {
            if kind == DirEntryKind::Directory && remaining_depth > 1 {
                queue.push_back((entry_path, relative_path, remaining_depth - 1));
            }
            entries.push(dir_entry);
        }
    }

    Ok(())
}

fn format_entry_name(path: &Path) -> String {
    let normalized = path.to_string_lossy().replace("\\", "/");
    if normalized.len() > MAX_ENTRY_LENGTH {
        take_bytes_at_char_boundary(&normalized, MAX_ENTRY_LENGTH).to_string()
    } else {
        normalized
    }
}

fn format_entry_component(name: &OsStr) -> String {
    let normalized = name.to_string_lossy();
    if normalized.len() > MAX_ENTRY_LENGTH {
        take_bytes_at_char_boundary(&normalized, MAX_ENTRY_LENGTH).to_string()
    } else {
        normalized.to_string()
    }
}

fn format_entry_line(entry: &DirEntry) -> String {
    let indent = " ".repeat(entry.depth * INDENTATION_SPACES);
    let mut name = entry.display_name.clone();
    match entry.kind {
        DirEntryKind::Directory => name.push('/'),
        DirEntryKind::Symlink => name.push('@'),
        DirEntryKind::Other => name.push('?'),
        DirEntryKind::File => {}
    }
    format!("{indent}{name}")
}

#[derive(Clone)]
struct DirEntry {
    name: String,
    display_name: String,
    depth: usize,
    kind: DirEntryKind,
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum DirEntryKind {
    Directory,
    File,
    Symlink,
    Other,
}

impl From<&FileType> for DirEntryKind {
    fn from(file_type: &FileType) -> Self {
        if file_type.is_symlink() {
            DirEntryKind::Symlink
        } else if file_type.is_dir() {
            DirEntryKind::Directory
        } else if file_type.is_file() {
            DirEntryKind::File
        } else {
            DirEntryKind::Other
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[tokio::test]
    async fn lists_directory_entries() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();

        let sub_dir = dir_path.join("nested");
        tokio::fs::create_dir(&sub_dir)
            .await
            .expect("create sub dir");

        let deeper_dir = sub_dir.join("deeper");
        tokio::fs::create_dir(&deeper_dir)
            .await
            .expect("create deeper dir");

        tokio::fs::write(dir_path.join("entry.txt"), b"content")
            .await
            .expect("write file");
        tokio::fs::write(sub_dir.join("child.txt"), b"child")
            .await
            .expect("write child");
        tokio::fs::write(deeper_dir.join("grandchild.txt"), b"grandchild")
            .await
            .expect("write grandchild");

        #[cfg(unix)]
        {
            use std::os::unix::fs::symlink;
            let link_path = dir_path.join("link");
            symlink(dir_path.join("entry.txt"), &link_path).expect("create symlink");
        }

        let entries = list_dir_slice(dir_path, 1, 20, 3)
            .await
            .expect("list directory");

        #[cfg(unix)]
        let expected = vec![
            "entry.txt".to_string(),
            "link@".to_string(),
            "nested/".to_string(),
            "  child.txt".to_string(),
            "  deeper/".to_string(),
            "    grandchild.txt".to_string(),
        ];

        #[cfg(not(unix))]
        let expected = vec![
            "entry.txt".to_string(),
            "nested/".to_string(),
            "  child.txt".to_string(),
            "  deeper/".to_string(),
            "    grandchild.txt".to_string(),
        ];

        assert_eq!(entries, expected);
    }

    #[tokio::test]
    async fn errors_when_offset_exceeds_entries() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();
        tokio::fs::create_dir(dir_path.join("nested"))
            .await
            .expect("create sub dir");

        let err = list_dir_slice(dir_path, 10, 1, 2)
            .await
            .expect_err("offset exceeds entries");
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("offset exceeds directory entry count".to_string())
        );
    }

    #[tokio::test]
    async fn respects_depth_parameter() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();
        let nested = dir_path.join("nested");
        let deeper = nested.join("deeper");
        tokio::fs::create_dir(&nested).await.expect("create nested");
        tokio::fs::create_dir(&deeper).await.expect("create deeper");
        tokio::fs::write(dir_path.join("root.txt"), b"root")
            .await
            .expect("write root");
        tokio::fs::write(nested.join("child.txt"), b"child")
            .await
            .expect("write nested");
        tokio::fs::write(deeper.join("grandchild.txt"), b"deep")
            .await
            .expect("write deeper");

        let entries_depth_one = list_dir_slice(dir_path, 1, 10, 1)
            .await
            .expect("list depth 1");
        assert_eq!(
            entries_depth_one,
            vec!["nested/".to_string(), "root.txt".to_string(),]
        );

        let entries_depth_two = list_dir_slice(dir_path, 1, 20, 2)
            .await
            .expect("list depth 2");
        assert_eq!(
            entries_depth_two,
            vec![
                "nested/".to_string(),
                "  child.txt".to_string(),
                "  deeper/".to_string(),
                "root.txt".to_string(),
            ]
        );

        let entries_depth_three = list_dir_slice(dir_path, 1, 30, 3)
            .await
            .expect("list depth 3");
        assert_eq!(
            entries_depth_three,
            vec![
                "nested/".to_string(),
                "  child.txt".to_string(),
                "  deeper/".to_string(),
                "    grandchild.txt".to_string(),
                "root.txt".to_string(),
            ]
        );
    }

    #[tokio::test]
    async fn handles_large_limit_without_overflow() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();
        tokio::fs::write(dir_path.join("alpha.txt"), b"alpha")
            .await
            .expect("write alpha");
        tokio::fs::write(dir_path.join("beta.txt"), b"beta")
            .await
            .expect("write beta");
        tokio::fs::write(dir_path.join("gamma.txt"), b"gamma")
            .await
            .expect("write gamma");

        let entries = list_dir_slice(dir_path, 2, usize::MAX, 1)
            .await
            .expect("list without overflow");
        assert_eq!(
            entries,
            vec!["beta.txt".to_string(), "gamma.txt".to_string(),]
        );
    }

    #[tokio::test]
    async fn indicates_truncated_results() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();

        for idx in 0..40 {
            let file = dir_path.join(format!("file_{idx:02}.txt"));
            tokio::fs::write(file, b"content")
                .await
                .expect("write file");
        }

        let entries = list_dir_slice(dir_path, 1, 25, 1)
            .await
            .expect("list directory");
        assert_eq!(entries.len(), 26);
        assert_eq!(
            entries.last(),
            Some(&"More than 25 entries found".to_string())
        );
    }

    #[tokio::test]
    async fn bfs_truncation() -> anyhow::Result<()> {
        let temp = tempdir()?;
        let dir_path = temp.path();
        let nested = dir_path.join("nested");
        let deeper = nested.join("deeper");
        tokio::fs::create_dir(&nested).await?;
        tokio::fs::create_dir(&deeper).await?;
        tokio::fs::write(dir_path.join("root.txt"), b"root").await?;
        tokio::fs::write(nested.join("child.txt"), b"child").await?;
        tokio::fs::write(deeper.join("grandchild.txt"), b"deep").await?;

        let entries_depth_three = list_dir_slice(dir_path, 1, 3, 3).await?;
        assert_eq!(
            entries_depth_three,
            vec![
                "nested/".to_string(),
                "  child.txt".to_string(),
                "root.txt".to_string(),
                "More than 3 entries found".to_string()
            ]
        );

        Ok(())
    }
}
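Reviewer note: a `list_dir` call takes JSON arguments where only `dir_path` is required; `offset`, `limit`, and `depth` default to 1, 25, and 2 via the serde defaults above. A self-contained sketch of the same argument parsing (struct re-declared here for illustration; assumes serde with derive plus serde_json):

use serde::Deserialize;

fn default_offset() -> usize { 1 }
fn default_limit() -> usize { 25 }
fn default_depth() -> usize { 2 }

#[derive(Deserialize)]
struct ListDirArgs {
    dir_path: String,
    #[serde(default = "default_offset")]
    offset: usize,
    #[serde(default = "default_limit")]
    limit: usize,
    #[serde(default = "default_depth")]
    depth: usize,
}

fn main() {
    // Omitted fields fall back to the serde defaults, matching the handler.
    let args: ListDirArgs =
        serde_json::from_str(r#"{"dir_path": "/workspace/project", "depth": 3}"#).unwrap();
    assert_eq!((args.offset, args.limit, args.depth), (1, 25, 3));
    assert_eq!(args.dir_path, "/workspace/project");
}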
@@ -16,10 +16,7 @@ impl ToolHandler for McpHandler {
        ToolKind::Mcp
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            sub_id,

@@ -45,8 +42,8 @@ impl ToolHandler for McpHandler {
        let arguments_str = raw_arguments;

        let response = handle_mcp_tool_call(
            session,
            sub_id,
            session.as_ref(),
            &sub_id,
            call_id.clone(),
            server,
            tool,
@@ -1,9 +1,11 @@
pub mod apply_patch;
mod exec_stream;
mod list_dir;
mod mcp;
mod plan;
mod read_file;
mod shell;
mod test_sync;
mod unified_exec;
mod view_image;

@@ -11,9 +13,11 @@ pub use plan::PLAN_TOOL;

pub use apply_patch::ApplyPatchHandler;
pub use exec_stream::ExecStreamHandler;
pub use list_dir::ListDirHandler;
pub use mcp::McpHandler;
pub use plan::PlanHandler;
pub use read_file::ReadFileHandler;
pub use shell::ShellHandler;
pub use test_sync::TestSyncHandler;
pub use unified_exec::UnifiedExecHandler;
pub use view_image::ViewImageHandler;
@@ -65,10 +65,7 @@ impl ToolHandler for PlanHandler {
        ToolKind::Function
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            sub_id,

@@ -86,7 +83,8 @@ impl ToolHandler for PlanHandler {
            }
        };

        let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?;
        let content =
            handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;

        Ok(ToolOutput::Function {
            content,
@@ -42,10 +42,7 @@ impl ToolHandler for ReadFileHandler {
        ToolKind::Function
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;

        let arguments = match payload {
@@ -1,5 +1,6 @@
use async_trait::async_trait;
use codex_protocol::models::ShellToolCallParams;
use std::sync::Arc;

use crate::codex::TurnContext;
use crate::exec::ExecParams;

@@ -40,10 +41,7 @@ impl ToolHandler for ShellHandler {
        )
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,

@@ -62,14 +60,14 @@ impl ToolHandler for ShellHandler {
                        "failed to parse function arguments: {e:?}"
                    ))
                })?;
                let exec_params = Self::to_exec_params(params, turn);
                let exec_params = Self::to_exec_params(params, turn.as_ref());
                let content = handle_container_exec_with_params(
                    tool_name.as_str(),
                    exec_params,
                    session,
                    turn,
                    tracker,
                    sub_id.to_string(),
                    Arc::clone(&session),
                    Arc::clone(&turn),
                    Arc::clone(&tracker),
                    sub_id.clone(),
                    call_id.clone(),
                )
                .await?;

@@ -79,14 +77,14 @@ impl ToolHandler for ShellHandler {
                })
            }
            ToolPayload::LocalShell { params } => {
                let exec_params = Self::to_exec_params(params, turn);
                let exec_params = Self::to_exec_params(params, turn.as_ref());
                let content = handle_container_exec_with_params(
                    tool_name.as_str(),
                    exec_params,
                    session,
                    turn,
                    tracker,
                    sub_id.to_string(),
                    Arc::clone(&session),
                    Arc::clone(&turn),
                    Arc::clone(&tracker),
                    sub_id.clone(),
                    call_id.clone(),
                )
                .await?;
codex-rs/core/src/tools/handlers/test_sync.rs (new file, 158 lines)
@@ -0,0 +1,158 @@
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;

use async_trait::async_trait;
use serde::Deserialize;
use tokio::sync::Barrier;
use tokio::time::sleep;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct TestSyncHandler;

const DEFAULT_TIMEOUT_MS: u64 = 1_000;

static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();

struct BarrierState {
    barrier: Arc<Barrier>,
    participants: usize,
}

#[derive(Debug, Deserialize)]
struct BarrierArgs {
    id: String,
    participants: usize,
    #[serde(default = "default_timeout_ms")]
    timeout_ms: u64,
}

#[derive(Debug, Deserialize)]
struct TestSyncArgs {
    #[serde(default)]
    sleep_before_ms: Option<u64>,
    #[serde(default)]
    sleep_after_ms: Option<u64>,
    #[serde(default)]
    barrier: Option<BarrierArgs>,
}

fn default_timeout_ms() -> u64 {
    DEFAULT_TIMEOUT_MS
}

fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
    BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
}

#[async_trait]
impl ToolHandler for TestSyncHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "test_sync_tool handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        if let Some(delay) = args.sleep_before_ms
            && delay > 0
        {
            sleep(Duration::from_millis(delay)).await;
        }

        if let Some(barrier) = args.barrier {
            wait_on_barrier(barrier).await?;
        }

        if let Some(delay) = args.sleep_after_ms
            && delay > 0
        {
            sleep(Duration::from_millis(delay)).await;
        }

        Ok(ToolOutput::Function {
            content: "ok".to_string(),
            success: Some(true),
        })
    }
}

async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
    if args.participants == 0 {
        return Err(FunctionCallError::RespondToModel(
            "barrier participants must be greater than zero".to_string(),
        ));
    }

    if args.timeout_ms == 0 {
        return Err(FunctionCallError::RespondToModel(
            "barrier timeout must be greater than zero".to_string(),
        ));
    }

    let barrier_id = args.id.clone();
    let barrier = {
        let mut map = barrier_map().lock().await;
        match map.entry(barrier_id.clone()) {
            Entry::Occupied(entry) => {
                let state = entry.get();
                if state.participants != args.participants {
                    let existing = state.participants;
                    return Err(FunctionCallError::RespondToModel(format!(
                        "barrier {barrier_id} already registered with {existing} participants"
                    )));
                }
                state.barrier.clone()
            }
            Entry::Vacant(entry) => {
                let barrier = Arc::new(Barrier::new(args.participants));
                entry.insert(BarrierState {
                    barrier: barrier.clone(),
                    participants: args.participants,
                });
                barrier
            }
        }
    };

    let timeout = Duration::from_millis(args.timeout_ms);
    let wait_result = tokio::time::timeout(timeout, barrier.wait())
        .await
        .map_err(|_| {
            FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
        })?;

    if wait_result.is_leader() {
        let mut map = barrier_map().lock().await;
        if let Some(state) = map.get(&barrier_id)
            && Arc::ptr_eq(&state.barrier, &barrier)
        {
            map.remove(&barrier_id);
        }
    }

    Ok(())
}
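Reviewer note: the point of the barrier is to let a test prove two tool calls actually overlapped: with `participants: 2`, both calls return only if both reach the barrier before the timeout, so serialized execution would time out instead. A stand-alone sketch of the rendezvous-with-timeout core, assuming tokio:

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Barrier;

#[tokio::main]
async fn main() {
    let barrier = Arc::new(Barrier::new(2)); // two participants must rendezvous
    let b = Arc::clone(&barrier);
    let first = tokio::spawn(async move {
        tokio::time::timeout(Duration::from_millis(1_000), b.wait()).await.is_ok()
    });
    let second = tokio::spawn(async move {
        tokio::time::timeout(Duration::from_millis(1_000), barrier.wait()).await.is_ok()
    });
    // Both waits succeed only because the two tasks ran concurrently.
    assert!(first.await.unwrap() && second.await.unwrap());
}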
@@ -33,10 +33,7 @@ impl ToolHandler for UnifiedExecHandler {
        )
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session, payload, ..
        } = invocation;
@@ -26,10 +26,7 @@ impl ToolHandler for ViewImageHandler {
        ToolKind::Function
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
@@ -1,5 +1,6 @@
pub mod context;
pub(crate) mod handlers;
pub mod parallel;
pub mod registry;
pub mod router;
pub mod spec;

@@ -21,7 +22,7 @@ use crate::executor::linkers::PreparedExec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ApplyPatchCommandContext;
use crate::tools::context::ExecCommandContext;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::tools::context::SharedTurnDiffTracker;
use codex_apply_patch::MaybeApplyPatchVerified;
use codex_apply_patch::maybe_parse_apply_patch_verified;
use codex_protocol::protocol::AskForApproval;

@@ -29,6 +30,7 @@ use codex_utils_string::take_bytes_at_char_boundary;
use codex_utils_string::take_last_bytes_at_char_boundary;
pub use router::ToolRouter;
use serde::Serialize;
use std::sync::Arc;
use tracing::trace;

// Model-formatting limits: clients get full streams; only content sent to the model is truncated.

@@ -48,9 +50,9 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
pub(crate) async fn handle_container_exec_with_params(
    tool_name: &str,
    params: ExecParams,
    sess: &Session,
    turn_context: &TurnContext,
    turn_diff_tracker: &mut TurnDiffTracker,
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    turn_diff_tracker: SharedTurnDiffTracker,
    sub_id: String,
    call_id: String,
) -> Result<String, FunctionCallError> {

@@ -68,7 +70,15 @@ pub(crate) async fn handle_container_exec_with_params(
    // check if this was a patch, and apply it if so
    let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) {
        MaybeApplyPatchVerified::Body(changes) => {
            match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await {
            match apply_patch::apply_patch(
                sess.as_ref(),
                turn_context.as_ref(),
                &sub_id,
                &call_id,
                changes,
            )
            .await
            {
                InternalApplyPatchInvocation::Output(item) => return item,
                InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
                    Some(apply_patch_exec)

@@ -139,12 +149,13 @@ pub(crate) async fn handle_container_exec_with_params(

    let output_result = sess
        .run_exec_with_events(
            turn_diff_tracker,
            turn_diff_tracker.clone(),
            prepared_exec,
            turn_context.approval_policy,
        )
        .await;

    // always make sure to truncate the output if its length isn't controlled.
    match output_result {
        Ok(output) => {
            let ExecToolCallOutput { exit_code, .. } = &output;

@@ -155,13 +166,16 @@ pub(crate) async fn handle_container_exec_with_params(
                Err(FunctionCallError::RespondToModel(content))
            }
        }
        Err(ExecError::Function(err)) => Err(err),
        Err(ExecError::Function(err)) => Err(truncate_function_error(err)),
        Err(ExecError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) => Err(
            FunctionCallError::RespondToModel(format_exec_output_apply_patch(&output)),
        ),
        Err(ExecError::Codex(err)) => Err(FunctionCallError::RespondToModel(format!(
            "execution error: {err:?}"
        ))),
        Err(ExecError::Codex(err)) => {
            let message = format!("execution error: {err:?}");
            Err(FunctionCallError::RespondToModel(format_exec_output(
                &message,
            )))
        }
    }
}

@@ -206,26 +220,42 @@ pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
        aggregated_output, ..
    } = exec_output;

    // Head+tail truncation for the model: show the beginning and end with an elision.
    // Clients still receive full streams; only this formatted summary is capped.

    let mut s = &aggregated_output.text;
    let prefixed_str: String;
    let content = aggregated_output.text.as_str();

    if exec_output.timed_out {
        prefixed_str = format!(
            "command timed out after {} milliseconds\n",
        let prefixed = format!(
            "command timed out after {} milliseconds\n{content}",
            exec_output.duration.as_millis()
        ) + s;
        s = &prefixed_str;
        );
        return format_exec_output(&prefixed);
    }

    let total_lines = s.lines().count();
    if s.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES {
        return s.to_string();
    }
    format_exec_output(content)
}

    let segments: Vec<&str> = s.split_inclusive('\n').collect();
fn truncate_function_error(err: FunctionCallError) -> FunctionCallError {
    match err {
        FunctionCallError::RespondToModel(msg) => {
            FunctionCallError::RespondToModel(format_exec_output(&msg))
        }
        FunctionCallError::Fatal(msg) => FunctionCallError::Fatal(format_exec_output(&msg)),
        other => other,
    }
}

fn format_exec_output(content: &str) -> String {
    // Head+tail truncation for the model: show the beginning and end with an elision.
    // Clients still receive full streams; only this formatted summary is capped.
    let total_lines = content.lines().count();
    if content.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES {
        return content.to_string();
    }
    let output = truncate_formatted_exec_output(content, total_lines);
    format!("Total output lines: {total_lines}\n\n{output}")
}

fn truncate_formatted_exec_output(content: &str, total_lines: usize) -> String {
    let segments: Vec<&str> = content.split_inclusive('\n').collect();
    let head_take = MODEL_FORMAT_HEAD_LINES.min(segments.len());
    let tail_take = MODEL_FORMAT_TAIL_LINES.min(segments.len().saturating_sub(head_take));
    let omitted = segments.len().saturating_sub(head_take + tail_take);

@@ -236,9 +266,9 @@ pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
        .map(|segment| segment.len())
        .sum();
    let tail_slice_start: usize = if tail_take == 0 {
        s.len()
        content.len()
    } else {
        s.len()
        content.len()
            - segments
                .iter()
                .rev()

@@ -260,9 +290,9 @@ pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
        head_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(marker.len());
    }

    let head_slice = &s[..head_slice_end];
    let head_slice = &content[..head_slice_end];
    let head_part = take_bytes_at_char_boundary(head_slice, head_budget);
    let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(s.len()));
    let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(content.len()));

    result.push_str(head_part);
    result.push_str(&marker);

@@ -272,9 +302,86 @@ pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
        return result;
    }

    let tail_slice = &s[tail_slice_start..];
    let tail_slice = &content[tail_slice_start..];
    let tail_part = take_last_bytes_at_char_boundary(tail_slice, remaining);
    result.push_str(tail_part);

    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use regex_lite::Regex;

    fn assert_truncated_message_matches(message: &str, line: &str, total_lines: usize) {
        let pattern = truncated_message_pattern(line, total_lines);
        let regex = Regex::new(&pattern).unwrap_or_else(|err| {
            panic!("failed to compile regex {pattern}: {err}");
        });
        let captures = regex
            .captures(message)
            .unwrap_or_else(|| panic!("message failed to match pattern {pattern}: {message}"));
        let body = captures
            .name("body")
            .expect("missing body capture")
            .as_str();
        assert!(
            body.len() <= MODEL_FORMAT_MAX_BYTES,
            "body exceeds byte limit: {} bytes",
            body.len()
        );
    }

    fn truncated_message_pattern(line: &str, total_lines: usize) -> String {
        let head_take = MODEL_FORMAT_HEAD_LINES.min(total_lines);
        let tail_take = MODEL_FORMAT_TAIL_LINES.min(total_lines.saturating_sub(head_take));
        let omitted = total_lines.saturating_sub(head_take + tail_take);
        let escaped_line = regex_lite::escape(line);
        format!(
            r"(?s)^Total output lines: {total_lines}\n\n(?P<body>{escaped_line}.*\n\[\.{{3}} omitted {omitted} of {total_lines} lines \.{{3}}]\n\n.*)$",
        )
    }

    #[test]
    fn truncate_formatted_exec_output_truncates_large_error() {
        let line = "very long execution error line that should trigger truncation\n";
        let large_error = line.repeat(2_500); // way beyond both byte and line limits

        let truncated = format_exec_output(&large_error);

        let total_lines = large_error.lines().count();
        assert_truncated_message_matches(&truncated, line, total_lines);
        assert_ne!(truncated, large_error);
    }

    #[test]
    fn truncate_function_error_trims_respond_to_model() {
        let line = "respond-to-model error that should be truncated\n";
        let huge = line.repeat(3_000);
        let total_lines = huge.lines().count();

        let err = truncate_function_error(FunctionCallError::RespondToModel(huge));
        match err {
            FunctionCallError::RespondToModel(message) => {
                assert_truncated_message_matches(&message, line, total_lines);
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn truncate_function_error_trims_fatal() {
        let line = "fatal error output that should be truncated\n";
        let huge = line.repeat(3_000);
        let total_lines = huge.lines().count();

        let err = truncate_function_error(FunctionCallError::Fatal(huge));
        match err {
            FunctionCallError::Fatal(message) => {
                assert_truncated_message_matches(&message, line, total_lines);
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }
}
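Reviewer note: extracting `format_exec_output` means every oversized payload (normal output, timeouts, and now function errors too) gets the same head+tail treatment: keep the first and last lines, insert an elision marker, and prefix the total line count. A simplified sketch of the policy with made-up limits (the real version also enforces a byte budget at char boundaries):

fn truncate_head_tail(content: &str, head: usize, tail: usize) -> String {
    let lines: Vec<&str> = content.lines().collect();
    if lines.len() <= head + tail {
        return content.to_string();
    }
    let omitted = lines.len() - head - tail;
    let mut out = format!("Total output lines: {}\n\n", lines.len());
    out.push_str(&lines[..head].join("\n"));
    out.push_str(&format!("\n[... omitted {omitted} of {} lines ...]\n", lines.len()));
    out.push_str(&lines[lines.len() - tail..].join("\n"));
    out
}

fn main() {
    let body = (1..=10).map(|i| format!("line {i}")).collect::<Vec<_>>().join("\n");
    let truncated = truncate_head_tail(&body, 2, 2);
    assert!(truncated.contains("[... omitted 6 of 10 lines ...]"));
}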
codex-rs/core/src/tools/parallel.rs (new file, 80 lines)
@@ -0,0 +1,80 @@
use std::sync::Arc;

use tokio::sync::RwLock;
use tokio_util::either::Either;
use tokio_util::task::AbortOnDropHandle;

use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_protocol::models::ResponseInputItem;

pub(crate) struct ToolCallRuntime {
    router: Arc<ToolRouter>,
    session: Arc<Session>,
    turn_context: Arc<TurnContext>,
    tracker: SharedTurnDiffTracker,
    sub_id: String,
    parallel_execution: Arc<RwLock<()>>,
}

impl ToolCallRuntime {
    pub(crate) fn new(
        router: Arc<ToolRouter>,
        session: Arc<Session>,
        turn_context: Arc<TurnContext>,
        tracker: SharedTurnDiffTracker,
        sub_id: String,
    ) -> Self {
        Self {
            router,
            session,
            turn_context,
            tracker,
            sub_id,
            parallel_execution: Arc::new(RwLock::new(())),
        }
    }

    pub(crate) fn handle_tool_call(
        &self,
        call: ToolCall,
    ) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
        let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);

        let router = Arc::clone(&self.router);
        let session = Arc::clone(&self.session);
        let turn = Arc::clone(&self.turn_context);
        let tracker = Arc::clone(&self.tracker);
        let sub_id = self.sub_id.clone();
        let lock = Arc::clone(&self.parallel_execution);

        let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
            AbortOnDropHandle::new(tokio::spawn(async move {
                let _guard = if supports_parallel {
                    Either::Left(lock.read().await)
                } else {
                    Either::Right(lock.write().await)
                };

                router
                    .dispatch_tool_call(session, turn, tracker, sub_id, call)
                    .await
            }));

        async move {
            match handle.await {
                Ok(Ok(response)) => Ok(response),
                Ok(Err(FunctionCallError::Fatal(message))) => Err(CodexErr::Fatal(message)),
                Ok(Err(other)) => Err(CodexErr::Fatal(other.to_string())),
                Err(err) => Err(CodexErr::Fatal(format!(
                    "tool task failed to receive: {err:?}"
                ))),
            }
        }
    }
}
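Reviewer note: the `RwLock<()>` is used purely as a gate, not to protect data: parallel-safe tools take the read lock (many may hold it concurrently), while serial tools take the write lock (exclusive, so they wait for in-flight readers to drain and block new ones). A minimal sketch of that gating idea, assuming tokio:

use std::sync::Arc;
use tokio::sync::RwLock;

async fn run(gate: Arc<RwLock<()>>, parallel_ok: bool, name: &str) {
    if parallel_ok {
        let _g = gate.read().await; // shared: overlaps with other readers
        println!("{name} running in parallel");
    } else {
        let _g = gate.write().await; // exclusive: waits for all readers
        println!("{name} running alone");
    }
}

#[tokio::main]
async fn main() {
    let gate = Arc::new(RwLock::new(()));
    let a = tokio::spawn(run(Arc::clone(&gate), true, "read_file"));
    let b = tokio::spawn(run(Arc::clone(&gate), true, "list_dir"));
    let c = tokio::spawn(run(Arc::clone(&gate), false, "shell"));
    let _ = tokio::join!(a, b, c);
}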
@@ -32,8 +32,7 @@ pub trait ToolHandler: Send + Sync {
        )
    }

    async fn handle(&self, invocation: ToolInvocation<'_>)
    -> Result<ToolOutput, FunctionCallError>;
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
}

pub struct ToolRegistry {

@@ -57,9 +56,9 @@ impl ToolRegistry {
    // }
    // }

    pub async fn dispatch<'a>(
    pub async fn dispatch(
        &self,
        invocation: ToolInvocation<'a>,
        invocation: ToolInvocation,
    ) -> Result<ResponseInputItem, FunctionCallError> {
        let tool_name = invocation.tool_name.clone();
        let call_id_owned = invocation.call_id.clone();

@@ -137,9 +136,24 @@ impl ToolRegistry {
    }
}

#[derive(Debug, Clone)]
pub struct ConfiguredToolSpec {
    pub spec: ToolSpec,
    pub supports_parallel_tool_calls: bool,
}

impl ConfiguredToolSpec {
    pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
        Self {
            spec,
            supports_parallel_tool_calls,
        }
    }
}

pub struct ToolRegistryBuilder {
    handlers: HashMap<String, Arc<dyn ToolHandler>>,
    specs: Vec<ToolSpec>,
    specs: Vec<ConfiguredToolSpec>,
}

impl ToolRegistryBuilder {

@@ -151,7 +165,16 @@ impl ToolRegistryBuilder {
    }

    pub fn push_spec(&mut self, spec: ToolSpec) {
        self.specs.push(spec);
        self.push_spec_with_parallel_support(spec, false);
    }

    pub fn push_spec_with_parallel_support(
        &mut self,
        spec: ToolSpec,
        supports_parallel_tool_calls: bool,
    ) {
        self.specs
            .push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
    }

    pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {

@@ -183,7 +206,7 @@ impl ToolRegistryBuilder {
    // }
    // }

    pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) {
    pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
        let registry = ToolRegistry::new(self.handlers);
        (self.specs, registry)
    }
@@ -1,15 +1,17 @@
use std::collections::HashMap;
use std::sync::Arc;

use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ConfiguredToolSpec;
use crate::tools::registry::ToolRegistry;
use crate::tools::spec::ToolsConfig;
use crate::tools::spec::build_specs;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::models::LocalShellAction;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;

@@ -24,7 +26,7 @@ pub struct ToolCall {

pub struct ToolRouter {
    registry: ToolRegistry,
    specs: Vec<ToolSpec>,
    specs: Vec<ConfiguredToolSpec>,
}

impl ToolRouter {

@@ -34,11 +36,22 @@ impl ToolRouter {
    ) -> Self {
        let builder = build_specs(config, mcp_tools);
        let (specs, registry) = builder.build();

        Self { registry, specs }
    }

    pub fn specs(&self) -> &[ToolSpec] {
        &self.specs
    pub fn specs(&self) -> Vec<ToolSpec> {
        self.specs
            .iter()
            .map(|config| config.spec.clone())
            .collect()
    }

    pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
        self.specs
            .iter()
            .filter(|config| config.supports_parallel_tool_calls)
            .any(|config| config.spec.name() == tool_name)
    }

    pub fn build_tool_call(

@@ -118,10 +131,10 @@ impl ToolRouter {

    pub async fn dispatch_tool_call(
        &self,
        session: &Session,
        turn: &TurnContext,
        tracker: &mut TurnDiffTracker,
        sub_id: &str,
        session: Arc<Session>,
        turn: Arc<TurnContext>,
        tracker: SharedTurnDiffTracker,
        sub_id: String,
        call: ToolCall,
    ) -> Result<ResponseInputItem, FunctionCallError> {
        let ToolCall {
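Reviewer note: with each spec now carrying its parallel flag, `tool_supports_parallel` is a linear scan over the configured specs. A compact stand-alone sketch of the lookup (simplified spec type, not the crate's):

struct ConfiguredSpec { name: &'static str, parallel: bool }

fn supports_parallel(specs: &[ConfiguredSpec], tool: &str) -> bool {
    specs.iter().any(|s| s.parallel && s.name == tool)
}

fn main() {
    let specs = [
        ConfiguredSpec { name: "read_file", parallel: true },
        ConfiguredSpec { name: "shell", parallel: false },
    ];
    assert!(supports_parallel(&specs, "read_file"));
    assert!(!supports_parallel(&specs, "shell"));
}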
@@ -258,6 +258,68 @@ fn create_view_image_tool() -> ToolSpec {
|
||||
})
|
||||
}
|
||||
|
||||
fn create_test_sync_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"sleep_before_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("Optional delay in milliseconds before any other action".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"sleep_after_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Optional delay in milliseconds after completing the barrier".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
let mut barrier_properties = BTreeMap::new();
|
||||
barrier_properties.insert(
|
||||
"id".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Identifier shared by concurrent calls that should rendezvous".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
barrier_properties.insert(
|
||||
"participants".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Number of tool calls that must arrive before the barrier opens".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
barrier_properties.insert(
|
||||
"timeout_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("Maximum time in milliseconds to wait at the barrier".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
properties.insert(
|
||||
"barrier".to_string(),
|
||||
JsonSchema::Object {
|
||||
properties: barrier_properties,
|
||||
required: Some(vec!["id".to_string(), "participants".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "test_sync_tool".to_string(),
|
||||
description: "Internal synchronization helper used by Codex integration tests.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: None,
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_read_file_tool() -> ToolSpec {
    let mut properties = BTreeMap::new();
    properties.insert(

@@ -294,6 +356,51 @@ fn create_read_file_tool() -> ToolSpec {
        },
    })
}

fn create_list_dir_tool() -> ToolSpec {
    let mut properties = BTreeMap::new();
    properties.insert(
        "dir_path".to_string(),
        JsonSchema::String {
            description: Some("Absolute path to the directory to list.".to_string()),
        },
    );
    properties.insert(
        "offset".to_string(),
        JsonSchema::Number {
            description: Some(
                "The entry number to start listing from. Must be 1 or greater.".to_string(),
            ),
        },
    );
    properties.insert(
        "limit".to_string(),
        JsonSchema::Number {
            description: Some("The maximum number of entries to return.".to_string()),
        },
    );
    properties.insert(
        "depth".to_string(),
        JsonSchema::Number {
            description: Some(
                "The maximum directory depth to traverse. Must be 1 or greater.".to_string(),
            ),
        },
    );

    ToolSpec::Function(ResponsesApiTool {
        name: "list_dir".to_string(),
        description:
            "Lists entries in a local directory with 1-indexed entry numbers and simple type labels."
                .to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties,
            required: Some(vec!["dir_path".to_string()]),
            additional_properties: Some(false.into()),
        },
    })
}
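
A call matching the list_dir schema above could look like this; only `dir_path` is required, and the path shown is a hypothetical example:

// Hypothetical arguments: list up to 100 entries, two levels deep,
// starting from the first entry of /tmp/project/src.
let args = serde_json::json!({
    "dir_path": "/tmp/project/src",
    "offset": 1,
    "limit": 100,
    "depth": 2
});
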
/// TODO(dylan): deprecate once we get rid of json tool
#[derive(Serialize, Deserialize)]
pub(crate) struct ApplyPatchToolArgs {

@@ -503,10 +610,12 @@ pub(crate) fn build_specs(
    use crate::exec_command::create_write_stdin_tool_for_responses_api;
    use crate::tools::handlers::ApplyPatchHandler;
    use crate::tools::handlers::ExecStreamHandler;
    use crate::tools::handlers::ListDirHandler;
    use crate::tools::handlers::McpHandler;
    use crate::tools::handlers::PlanHandler;
    use crate::tools::handlers::ReadFileHandler;
    use crate::tools::handlers::ShellHandler;
    use crate::tools::handlers::TestSyncHandler;
    use crate::tools::handlers::UnifiedExecHandler;
    use crate::tools::handlers::ViewImageHandler;
    use std::sync::Arc;

@@ -573,16 +682,36 @@ pub(crate) fn build_specs(
        .any(|tool| tool == "read_file")
    {
        let read_file_handler = Arc::new(ReadFileHandler);
        builder.push_spec(create_read_file_tool());
        builder.push_spec_with_parallel_support(create_read_file_tool(), true);
        builder.register_handler("read_file", read_file_handler);
    }

    if config
        .experimental_supported_tools
        .iter()
        .any(|tool| tool == "list_dir")
    {
        let list_dir_handler = Arc::new(ListDirHandler);
        builder.push_spec_with_parallel_support(create_list_dir_tool(), true);
        builder.register_handler("list_dir", list_dir_handler);
    }

    if config
        .experimental_supported_tools
        .iter()
        .any(|tool| tool == "test_sync_tool")
    {
        let test_sync_handler = Arc::new(TestSyncHandler);
        builder.push_spec_with_parallel_support(create_test_sync_tool(), true);
        builder.register_handler("test_sync_tool", test_sync_handler);
    }

    if config.web_search_request {
        builder.push_spec(ToolSpec::WebSearch {});
    }

    if config.include_view_image_tool {
        builder.push_spec(create_view_image_tool());
        builder.push_spec_with_parallel_support(create_view_image_tool(), true);
        builder.register_handler("view_image", view_image_handler);
    }

@@ -610,20 +739,25 @@ pub(crate) fn build_specs(
mod tests {
    use crate::client_common::tools::FreeformTool;
    use crate::model_family::find_family_for_model;
    use crate::tools::registry::ConfiguredToolSpec;
    use mcp_types::ToolInputSchema;
    use pretty_assertions::assert_eq;

    use super::*;

    fn assert_eq_tool_names(tools: &[ToolSpec], expected_names: &[&str]) {
    fn tool_name(tool: &ToolSpec) -> &str {
        match tool {
            ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
            ToolSpec::LocalShell {} => "local_shell",
            ToolSpec::WebSearch {} => "web_search",
            ToolSpec::Freeform(FreeformTool { name, .. }) => name,
        }
    }

    fn assert_eq_tool_names(tools: &[ConfiguredToolSpec], expected_names: &[&str]) {
        let tool_names = tools
            .iter()
            .map(|tool| match tool {
                ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
                ToolSpec::LocalShell {} => "local_shell",
                ToolSpec::WebSearch {} => "web_search",
                ToolSpec::Freeform(FreeformTool { name, .. }) => name,
            })
            .map(|tool| tool_name(&tool.spec))
            .collect::<Vec<_>>();

        assert_eq!(

@@ -639,6 +773,16 @@ mod tests {
        }
    }

    fn find_tool<'a>(
        tools: &'a [ConfiguredToolSpec],
        expected_name: &str,
    ) -> &'a ConfiguredToolSpec {
        tools
            .iter()
            .find(|tool| tool_name(&tool.spec) == expected_name)
            .unwrap_or_else(|| panic!("expected tool {expected_name}"))
    }

    #[test]
    fn test_build_specs() {
        let model_family = find_family_for_model("codex-mini-latest")

@@ -681,9 +825,10 @@ mod tests {
    }

    #[test]
    fn test_build_specs_includes_beta_read_file_tool() {
    #[ignore]
    fn test_parallel_support_flags() {
        let model_family = find_family_for_model("gpt-5-codex")
            .expect("gpt-5-codex should be a valid model family");
            .expect("codex-mini-latest should be a valid model family");
        let config = ToolsConfig::new(&ToolsConfigParams {
            model_family: &model_family,
            include_plan_tool: false,

@@ -693,9 +838,39 @@ mod tests {
            include_view_image_tool: false,
            experimental_unified_exec_tool: true,
        });
        let (tools, _) = build_specs(&config, Some(HashMap::new())).build();
        let (tools, _) = build_specs(&config, None).build();

        assert_eq_tool_names(&tools, &["unified_exec", "read_file"]);
        assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls);
        assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
        assert!(find_tool(&tools, "list_dir").supports_parallel_tool_calls);
    }

    #[test]
    fn test_test_model_family_includes_sync_tool() {
        let model_family = find_family_for_model("test-gpt-5-codex")
            .expect("test-gpt-5-codex should be a valid model family");
        let config = ToolsConfig::new(&ToolsConfigParams {
            model_family: &model_family,
            include_plan_tool: false,
            include_apply_patch_tool: false,
            include_web_search_request: false,
            use_streamable_shell_tool: false,
            include_view_image_tool: false,
            experimental_unified_exec_tool: false,
        });
        let (tools, _) = build_specs(&config, None).build();

        assert!(
            tools
                .iter()
                .any(|tool| tool_name(&tool.spec) == "test_sync_tool")
        );
        assert!(
            tools
                .iter()
                .any(|tool| tool_name(&tool.spec) == "read_file")
        );
        assert!(tools.iter().any(|tool| tool_name(&tool.spec) == "list_dir"));
    }

    #[test]

@@ -760,7 +935,7 @@ mod tests {
        );

        assert_eq!(
            tools[3],
            tools[3].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "test_server/do_something_cool".to_string(),
                parameters: JsonSchema::Object {

@@ -921,7 +1096,7 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "read_file",
                "apply_patch",
                "web_search",
                "view_image",
                "dash/search",

@@ -929,7 +1104,7 @@ mod tests {
        );

        assert_eq!(
            tools[4],
            tools[4].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "dash/search".to_string(),
                parameters: JsonSchema::Object {

@@ -988,14 +1163,14 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "read_file",
                "apply_patch",
                "web_search",
                "view_image",
                "dash/paginate",
            ],
        );
        assert_eq!(
            tools[4],
            tools[4].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "dash/paginate".to_string(),
                parameters: JsonSchema::Object {

@@ -1019,7 +1194,7 @@ mod tests {
        let config = ToolsConfig::new(&ToolsConfigParams {
            model_family: &model_family,
            include_plan_tool: false,
            include_apply_patch_tool: false,
            include_apply_patch_tool: true,
            include_web_search_request: true,
            use_streamable_shell_tool: false,
            include_view_image_tool: true,

@@ -1052,14 +1227,14 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "read_file",
                "apply_patch",
                "web_search",
                "view_image",
                "dash/tags",
            ],
        );
        assert_eq!(
            tools[4],
            tools[4].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "dash/tags".to_string(),
                parameters: JsonSchema::Object {

@@ -1119,14 +1294,14 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "read_file",
                "apply_patch",
                "web_search",
                "view_image",
                "dash/value",
            ],
        );
        assert_eq!(
            tools[4],
            tools[4].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "dash/value".to_string(),
                parameters: JsonSchema::Object {

@@ -1223,7 +1398,7 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "read_file",
                "apply_patch",
                "web_search",
                "view_image",
                "test_server/do_something_cool",

@@ -1231,7 +1406,7 @@ mod tests {
        );

        assert_eq!(
            tools[4],
            tools[4].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "test_server/do_something_cool".to_string(),
                parameters: JsonSchema::Object {

@@ -1,3 +1,4 @@
use assert_matches::assert_matches;
use std::sync::Arc;
use tracing_test::traced_test;

@@ -178,7 +179,7 @@ async fn streams_text_without_reasoning() {
        other => panic!("expected terminal message, got {other:?}"),
    }

    assert!(matches!(events[2], ResponseEvent::Completed { .. }));
    assert_matches!(events[2], ResponseEvent::Completed { .. });
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]

@@ -219,7 +220,7 @@ async fn streams_reasoning_from_string_delta() {
        other => panic!("expected message item, got {other:?}"),
    }

    assert!(matches!(events[4], ResponseEvent::Completed { .. }));
    assert_matches!(events[4], ResponseEvent::Completed { .. });
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]

@@ -266,7 +267,7 @@ async fn streams_reasoning_from_object_delta() {
        other => panic!("expected message item, got {other:?}"),
    }

    assert!(matches!(events[5], ResponseEvent::Completed { .. }));
    assert_matches!(events[5], ResponseEvent::Completed { .. });
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]

@@ -293,7 +294,7 @@ async fn streams_reasoning_from_final_message() {
        other => panic!("expected reasoning item, got {other:?}"),
    }

    assert!(matches!(events[2], ResponseEvent::Completed { .. }));
    assert_matches!(events[2], ResponseEvent::Completed { .. });
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]

@@ -337,7 +338,7 @@ async fn streams_reasoning_before_tool_call() {
        other => panic!("expected function call, got {other:?}"),
    }

    assert!(matches!(events[3], ResponseEvent::Completed { .. }));
    assert_matches!(events[3], ResponseEvent::Completed { .. });
}

#[tokio::test]

@@ -10,6 +10,7 @@ path = "lib.rs"
anyhow = { workspace = true }
assert_cmd = { workspace = true }
codex-core = { workspace = true }
regex-lite = { workspace = true }
serde_json = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true, features = ["time"] }

@@ -6,6 +6,7 @@ use codex_core::CodexConversation;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_core::config::ConfigToml;
use regex_lite::Regex;

#[cfg(target_os = "linux")]
use assert_cmd::cargo::cargo_bin;

@@ -14,6 +15,16 @@ pub mod responses;
pub mod test_codex;
pub mod test_codex_exec;

#[track_caller]
pub fn assert_regex_match<'s>(pattern: &str, actual: &'s str) -> regex_lite::Captures<'s> {
    let regex = Regex::new(pattern).unwrap_or_else(|err| {
        panic!("failed to compile regex {pattern:?}: {err}");
    });
    regex
        .captures(actual)
        .unwrap_or_else(|| panic!("regex {pattern:?} did not match {actual:?}"))
}

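A usage sketch for the helper above; the pattern and input are illustrative:

// The helper returns the captures, so assertions can reuse matched groups.
let captures = assert_regex_match(r"exit code: (\d+)", "exit code: 0");
assert_eq!(captures.get(1).map(|m| m.as_str()), Some("0"));
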
/// Returns a default `Config` whose on-disk state is confined to the provided
/// temporary directory. Using a per-test directory keeps tests hermetic and
/// avoids clobbering a developer’s real `~/.codex`.

@@ -1,11 +1,105 @@
use std::sync::Arc;
use std::sync::Mutex;

use serde_json::Value;
use wiremock::BodyPrintLimit;
use wiremock::Match;
use wiremock::Mock;
use wiremock::MockBuilder;
use wiremock::MockServer;
use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
use wiremock::matchers::path_regex;

#[derive(Debug, Clone)]
pub struct ResponseMock {
    requests: Arc<Mutex<Vec<ResponsesRequest>>>,
}

impl ResponseMock {
    fn new() -> Self {
        Self {
            requests: Arc::new(Mutex::new(Vec::new())),
        }
    }

    pub fn single_request(&self) -> ResponsesRequest {
        let requests = self.requests.lock().unwrap();
        if requests.len() != 1 {
            panic!("expected 1 request, got {}", requests.len());
        }
        requests.first().unwrap().clone()
    }

    pub fn requests(&self) -> Vec<ResponsesRequest> {
        self.requests.lock().unwrap().clone()
    }
}

#[derive(Debug, Clone)]
pub struct ResponsesRequest(wiremock::Request);

impl ResponsesRequest {
    pub fn body_json(&self) -> Value {
        self.0.body_json().unwrap()
    }

    pub fn input(&self) -> Vec<Value> {
        self.0.body_json::<Value>().unwrap()["input"]
            .as_array()
            .expect("input array not found in request")
            .clone()
    }

    pub fn function_call_output(&self, call_id: &str) -> Value {
        self.call_output(call_id, "function_call_output")
    }

    pub fn custom_tool_call_output(&self, call_id: &str) -> Value {
        self.call_output(call_id, "custom_tool_call_output")
    }

    pub fn call_output(&self, call_id: &str, call_type: &str) -> Value {
        self.input()
            .iter()
            .find(|item| {
                item.get("type").unwrap() == call_type && item.get("call_id").unwrap() == call_id
            })
            .cloned()
            .unwrap_or_else(|| panic!("function call output {call_id} item not found in request"))
    }

    pub fn header(&self, name: &str) -> Option<String> {
        self.0
            .headers
            .get(name)
            .and_then(|v| v.to_str().ok())
            .map(str::to_string)
    }

    pub fn path(&self) -> String {
        self.0.url.path().to_string()
    }

    pub fn query_param(&self, name: &str) -> Option<String> {
        self.0
            .url
            .query_pairs()
            .find(|(k, _)| k == name)
            .map(|(_, v)| v.to_string())
    }
}

impl Match for ResponseMock {
    fn matches(&self, request: &wiremock::Request) -> bool {
        self.requests
            .lock()
            .unwrap()
            .push(ResponsesRequest(request.clone()));
        true
    }
}

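The `Match` impl above is the recording trick: every matching request is pushed into the shared vector before the canned response is served. A short sketch of the intended call pattern, using the mount helpers defined later in this file:

// Mount a canned SSE body, drive the conversation under test,
// then inspect exactly what was sent to the mock server.
let server = start_mock_server().await;
let resp_mock = mount_sse_once(&server, sse(vec![ev_completed("resp-1")])).await;
// ... run the turn ...
let request = resp_mock.single_request();
assert_eq!(request.path(), "/v1/responses");
assert!(request.body_json().get("input").is_some());
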
/// Build an SSE stream body from a list of JSON events.
pub fn sse(events: Vec<Value>) -> String {

@@ -34,6 +128,16 @@ pub fn ev_completed(id: &str) -> Value {
    })
}

/// Convenience: SSE event for a created response with a specific id.
pub fn ev_response_created(id: &str) -> Value {
    serde_json::json!({
        "type": "response.created",
        "response": {
            "id": id,
        }
    })
}

pub fn ev_completed_with_tokens(id: &str, total_tokens: u64) -> Value {
    serde_json::json!({
        "type": "response.completed",

@@ -135,40 +239,56 @@ pub fn ev_apply_patch_function_call(call_id: &str, patch: &str) -> Value {
    })
}

pub fn sse_failed(id: &str, code: &str, message: &str) -> String {
    sse(vec![serde_json::json!({
        "type": "response.failed",
        "response": {
            "id": id,
            "error": {"code": code, "message": message}
        }
    })])
}

pub fn sse_response(body: String) -> ResponseTemplate {
    ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(body, "text/event-stream")
}

pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String)
fn base_mock() -> (MockBuilder, ResponseMock) {
    let response_mock = ResponseMock::new();
    let mock = Mock::given(method("POST"))
        .and(path_regex(".*/responses$"))
        .and(response_mock.clone());
    (mock, response_mock)
}

pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String) -> ResponseMock
where
    M: wiremock::Match + Send + Sync + 'static,
{
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(matcher)
    let (mock, response_mock) = base_mock();
    mock.and(matcher)
        .respond_with(sse_response(body))
        .up_to_n_times(1)
        .mount(server)
        .await;
    response_mock
}

pub async fn mount_sse_once(server: &MockServer, body: String) {
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(sse_response(body))
        .expect(1)
pub async fn mount_sse_once(server: &MockServer, body: String) -> ResponseMock {
    let (mock, response_mock) = base_mock();
    mock.respond_with(sse_response(body))
        .up_to_n_times(1)
        .mount(server)
        .await;
    response_mock
}

pub async fn mount_sse(server: &MockServer, body: String) {
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(sse_response(body))
        .mount(server)
        .await;
pub async fn mount_sse(server: &MockServer, body: String) -> ResponseMock {
    let (mock, response_mock) = base_mock();
    mock.respond_with(sse_response(body)).mount(server).await;
    response_mock
}

pub async fn start_mock_server() -> MockServer {

@@ -181,7 +301,7 @@ pub async fn start_mock_server() -> MockServer {
/// Mounts a sequence of SSE response bodies and serves them in order for each
/// POST to `/v1/responses`. Panics if more requests are received than bodies
/// provided. Also asserts the exact number of expected calls.
pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) {
pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) -> ResponseMock {
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

@@ -208,10 +328,11 @@ pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) {
        responses: bodies,
    };

    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(responder)
    let (mock, response_mock) = base_mock();
    mock.respond_with(responder)
        .expect(num_calls as u64)
        .mount(server)
        .await;

    response_mock
}

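A sketch of how the sequence helper above is meant to be used; the bodies here are illustrative, and `server` is assumed to come from `start_mock_server()`:

// Three bodies served in order; the mock asserts exactly three calls arrive.
let bodies = vec![
    sse(vec![ev_completed("resp-1")]),
    sse(vec![ev_completed("resp-2")]),
    sse(vec![ev_completed("resp-3")]),
];
let resp_mock = mount_sse_sequence(&server, bodies).await;
// ... run three turns ...
assert_eq!(resp_mock.requests().len(), 3);
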
@@ -3,14 +3,14 @@ use std::time::Duration;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event_with_timeout;
use serde_json::json;
use wiremock::matchers::body_string_contains;

/// Integration test: spawn a long‑running shell tool via a mocked Responses SSE
/// function call, then interrupt the session and expect TurnAborted.

@@ -27,10 +27,13 @@ async fn interrupt_long_running_tool_emits_turn_aborted() {
        "timeout_ms": 60_000
    })
    .to_string();
    let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]);
    let body = sse(vec![
        ev_function_call("call_sleep", "shell", &args),
        ev_completed("done"),
    ]);

    let server = start_mock_server().await;
    mount_sse_once_match(&server, body_string_contains("start sleep"), body).await;
    mount_sse_once(&server, body).await;

    let codex = test_codex().build(&server).await.unwrap().codex;

@@ -106,16 +106,12 @@ async fn exec_cli_applies_experimental_instructions_file() {
        "data: {\"type\":\"response.created\",\"response\":{}}\n\n",
        "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"r1\"}}\n\n"
    );
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_raw(sse, "text/event-stream"),
        )
        .expect(1)
        .mount(&server)
        .await;
    let resp_mock = core_test_support::responses::mount_sse_once_match(
        &server,
        path("/v1/responses"),
        sse.to_string(),
    )
    .await;

    // Create a temporary instructions file with a unique marker we can assert
    // appears in the outbound request payload.

@@ -164,8 +160,8 @@ async fn exec_cli_applies_experimental_instructions_file() {

    // Inspect the captured request and verify our custom base instructions were
    // included in the `instructions` field.
    let request = &server.received_requests().await.unwrap()[0];
    let body = request.body_json::<serde_json::Value>().unwrap();
    let request = resp_mock.single_request();
    let body = request.body_json();
    let instructions = body
        .get("instructions")
        .and_then(|v| v.as_str())

@@ -14,6 +14,8 @@ use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::WireApi;
use codex_core::built_in_model_providers;
use codex_core::error::CodexErr;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;

@@ -26,8 +28,10 @@ use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_with_timeout;
use futures::StreamExt;
use serde_json::json;
use std::io::Write;

@@ -37,6 +41,7 @@ use uuid::Uuid;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::body_string_contains;
use wiremock::matchers::header_regex;
use wiremock::matchers::method;
use wiremock::matchers::path;

@@ -218,15 +223,9 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {

    // Mock server that will receive the resumed request
    let server = MockServer::start().await;
    let first = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse_completed("resp1"), "text/event-stream");
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(first)
        .expect(1)
        .mount(&server)
        .await;
    let resp_mock =
        responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1"))
            .await;

    // Configure Codex to resume from our file
    let model_provider = ModelProviderInfo {

@@ -272,8 +271,8 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let request = &server.received_requests().await.unwrap()[0];
    let request_body = request.body_json::<serde_json::Value>().unwrap();
    let request = resp_mock.single_request();
    let request_body = request.body_json();
    let expected_input = json!([
        {
            "type": "message",

@@ -367,18 +366,9 @@ async fn includes_base_instructions_override_in_request() {
    skip_if_no_network!();
    // Mock server
    let server = MockServer::start().await;

    // First request – must NOT include `previous_response_id`.
    let first = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse_completed("resp1"), "text/event-stream");

    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(first)
        .expect(1)
        .mount(&server)
        .await;
    let resp_mock =
        responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1"))
            .await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),

@@ -409,8 +399,8 @@ async fn includes_base_instructions_override_in_request() {

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let request = &server.received_requests().await.unwrap()[0];
    let request_body = request.body_json::<serde_json::Value>().unwrap();
    let request = resp_mock.single_request();
    let request_body = request.body_json();

    assert!(
        request_body["instructions"]

@@ -565,16 +555,9 @@ async fn includes_user_instructions_message_in_request() {
    skip_if_no_network!();
    let server = MockServer::start().await;

    let first = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse_completed("resp1"), "text/event-stream");

    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(first)
        .expect(1)
        .mount(&server)
        .await;
    let resp_mock =
        responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1"))
            .await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),

@@ -605,8 +588,8 @@ async fn includes_user_instructions_message_in_request() {

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let request = &server.received_requests().await.unwrap()[0];
    let request_body = request.body_json::<serde_json::Value>().unwrap();
    let request = resp_mock.single_request();
    let request_body = request.body_json();

    assert!(
        !request_body["instructions"]

@@ -996,6 +979,100 @@ async fn usage_limit_error_emits_rate_limit_event() -> anyhow::Result<()> {
    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
    let server = MockServer::start().await;

    responses::mount_sse_once_match(
        &server,
        body_string_contains("trigger context window"),
        responses::sse_failed(
            "resp_context_window",
            "context_length_exceeded",
            "Your input exceeds the context window of this model. Please adjust your input and try again.",
        ),
    )
    .await;

    responses::mount_sse_once_match(
        &server,
        body_string_contains("seed turn"),
        sse_completed("resp_seed"),
    )
    .await;

    let TestCodex { codex, .. } = test_codex()
        .with_config(|config| {
            config.model = "gpt-5".to_string();
            config.model_family = find_family_for_model("gpt-5").expect("known gpt-5 model family");
            config.model_context_window = Some(272_000);
        })
        .build(&server)
        .await?;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: "seed turn".into(),
            }],
        })
        .await?;

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: "trigger context window".into(),
            }],
        })
        .await?;

    use std::time::Duration;

    let token_event = wait_for_event_with_timeout(
        &codex,
        |event| {
            matches!(
                event,
                EventMsg::TokenCount(payload)
                    if payload.info.as_ref().is_some_and(|info| {
                        info.model_context_window == Some(info.total_token_usage.total_tokens)
                            && info.total_token_usage.total_tokens > 0
                    })
            )
        },
        Duration::from_secs(5),
    )
    .await;

    let EventMsg::TokenCount(token_payload) = token_event else {
        unreachable!("wait_for_event_with_timeout returned unexpected event");
    };

    let info = token_payload
        .info
        .expect("token usage info present when context window is exceeded");

    assert_eq!(info.model_context_window, Some(272_000));
    assert_eq!(info.total_token_usage.total_tokens, 272_000);

    let error_event = wait_for_event(&codex, |ev| matches!(ev, EventMsg::Error(_))).await;
    let expected_context_window_message = CodexErr::ContextWindowExceeded.to_string();
    assert!(
        matches!(
            error_event,
            EventMsg::Error(ref err) if err.message == expected_context_window_message
        ),
        "expected context window error; got {error_event:?}"
    );

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn azure_overrides_assign_properties_used_for_responses_url() {
    skip_if_no_network!();

@@ -13,12 +13,6 @@ use core_test_support::load_default_config_for_test;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::Request;
use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

use codex_core::codex::compact::SUMMARIZATION_PROMPT;
use core_test_support::responses::ev_assistant_message;

@@ -26,14 +20,10 @@ use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_completed_with_tokens;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::sse_response;
use core_test_support::responses::start_mock_server;
use pretty_assertions::assert_eq;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
// --- Test helpers -----------------------------------------------------------

pub(super) const FIRST_REPLY: &str = "FIRST_REPLY";

@@ -295,12 +285,7 @@ async fn auto_compact_runs_after_token_limit_hit() {
            && !body.contains(SECOND_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(first_matcher)
        .respond_with(sse_response(sse1))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, first_matcher, sse1).await;

    let second_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");

@@ -308,23 +293,13 @@ async fn auto_compact_runs_after_token_limit_hit() {
            && body.contains(FIRST_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(second_matcher)
        .respond_with(sse_response(sse2))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, second_matcher, sse2).await;

    let third_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(third_matcher)
        .respond_with(sse_response(sse3))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, third_matcher, sse3).await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),

@@ -455,12 +430,7 @@ async fn auto_compact_persists_rollout_entries() {
            && !body.contains(SECOND_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(first_matcher)
        .respond_with(sse_response(sse1))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, first_matcher, sse1).await;

    let second_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");

@@ -468,23 +438,13 @@ async fn auto_compact_persists_rollout_entries() {
            && body.contains(FIRST_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(second_matcher)
        .respond_with(sse_response(sse2))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, second_matcher, sse2).await;

    let third_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(third_matcher)
        .respond_with(sse_response(sse3))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, third_matcher, sse3).await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),

@@ -582,35 +542,20 @@ async fn auto_compact_stops_after_failed_attempt() {
        body.contains(FIRST_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(first_matcher)
        .respond_with(sse_response(sse1.clone()))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, first_matcher, sse1.clone()).await;

    let second_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(second_matcher)
        .respond_with(sse_response(sse2.clone()))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, second_matcher, sse2.clone()).await;

    let third_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        !body.contains("You have exceeded the maximum number of tokens")
            && body.contains(SUMMARY_TEXT)
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(third_matcher)
        .respond_with(sse_response(sse3.clone()))
        .mount(&server)
        .await;
    mount_sse_once_match(&server, third_matcher, sse3.clone()).await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),

@@ -708,49 +653,7 @@ async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_
        ev_completed_with_tokens("r6", 120),
    ]);

    #[derive(Clone)]
    struct SeqResponder {
        bodies: Arc<Vec<String>>,
        calls: Arc<AtomicUsize>,
        requests: Arc<Mutex<Vec<Vec<u8>>>>,
    }

    impl SeqResponder {
        fn new(bodies: Vec<String>) -> Self {
            Self {
                bodies: Arc::new(bodies),
                calls: Arc::new(AtomicUsize::new(0)),
                requests: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn recorded_requests(&self) -> Vec<Vec<u8>> {
            self.requests.lock().unwrap().clone()
        }
    }

    impl Respond for SeqResponder {
        fn respond(&self, req: &Request) -> ResponseTemplate {
            let idx = self.calls.fetch_add(1, Ordering::SeqCst);
            self.requests.lock().unwrap().push(req.body.clone());
            let body = self
                .bodies
                .get(idx)
                .unwrap_or_else(|| panic!("unexpected request index {idx}"))
                .clone();
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_raw(body, "text/event-stream")
        }
    }

    let responder = SeqResponder::new(vec![sse1, sse2, sse3, sse4, sse5, sse6]);
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(responder.clone())
        .expect(6)
        .mount(&server)
        .await;
    mount_sse_sequence(&server, vec![sse1, sse2, sse3, sse4, sse5, sse6]).await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),

@@ -801,10 +704,12 @@ async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_
        "auto compact should not emit task lifecycle events"
    );

    let request_bodies: Vec<String> = responder
        .recorded_requests()
    let request_bodies: Vec<String> = server
        .received_requests()
        .await
        .unwrap()
        .into_iter()
        .map(|body| String::from_utf8(body).unwrap_or_default())
        .map(|request| String::from_utf8(request.body).unwrap_or_default())
        .collect();
    assert_eq!(
        request_bodies.len(),

@@ -17,6 +17,7 @@ use codex_core::NewConversation;
use codex_core::built_in_model_providers;
use codex_core::codex::compact::SUMMARIZATION_PROMPT;
use codex_core::config::Config;
use codex_core::config::OPENAI_DEFAULT_MODEL;
use codex_core::protocol::ConversationPathResponseEvent;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;

@@ -131,9 +132,10 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
        .as_str()
        .unwrap_or_default()
        .to_string();
    let expected_model = OPENAI_DEFAULT_MODEL;
    let user_turn_1 = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {

@@ -182,7 +184,7 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
    });
    let compact_1 = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {

@@ -251,7 +253,7 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
    });
    let user_turn_2_after_compact = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {

@@ -316,7 +318,7 @@ SUMMARY_ONLY_CONTEXT"
    });
    let usert_turn_3_after_resume = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {

@@ -401,7 +403,7 @@ SUMMARY_ONLY_CONTEXT"
    });
    let user_turn_3_after_fork = json!(
    {
        "model": "gpt-5-codex",
        "model": expected_model,
        "instructions": prompt,
        "input": [
            {

codex-rs/core/tests/suite/list_dir.rs
Normal file
460
codex-rs/core/tests/suite/list_dir.rs
Normal file
@@ -0,0 +1,460 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "disabled until we enable list_dir tool"]
|
||||
async fn list_dir_tool_returns_entries() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = test_codex().build(&server).await?;
|
||||
|
||||
let dir_path = cwd.path().join("sample_dir");
|
||||
std::fs::create_dir(&dir_path)?;
|
||||
std::fs::write(dir_path.join("alpha.txt"), "first file")?;
|
||||
std::fs::create_dir(dir_path.join("nested"))?;
|
||||
let dir_path = dir_path.to_string_lossy().to_string();
|
||||
|
||||
let call_id = "list-dir-call";
|
||||
let arguments = serde_json::json!({
|
||||
"dir_path": dir_path,
|
||||
"offset": 1,
|
||||
"limit": 2,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "list_dir", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "list directory contents".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
assert!(
|
||||
!request_bodies.is_empty(),
|
||||
"expected at least one request body"
|
||||
);
|
||||
|
||||
let tool_output_item = request_bodies
|
||||
.iter()
|
||||
.find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!("function_call_output item not found in requests: {request_bodies:#?}")
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
tool_output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
let output_text = tool_output_item
|
||||
.get("output")
|
||||
.and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
Value::Object(obj) => obj.get("content").and_then(Value::as_str),
|
||||
_ => None,
|
||||
})
|
||||
.expect("output text present");
|
||||
assert_eq!(output_text, "E1: [file] alpha.txt\nE2: [dir] nested");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "disabled until we enable list_dir tool"]
|
||||
async fn list_dir_tool_depth_one_omits_children() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = test_codex().build(&server).await?;
|
||||
|
||||
let dir_path = cwd.path().join("depth_one");
|
||||
std::fs::create_dir(&dir_path)?;
|
||||
std::fs::write(dir_path.join("alpha.txt"), "alpha")?;
|
||||
std::fs::create_dir(dir_path.join("nested"))?;
|
||||
std::fs::write(dir_path.join("nested").join("beta.txt"), "beta")?;
|
||||
let dir_path = dir_path.to_string_lossy().to_string();
|
||||
|
||||
let call_id = "list-dir-depth1";
|
||||
let arguments = serde_json::json!({
|
||||
"dir_path": dir_path,
|
||||
"offset": 1,
|
||||
"limit": 10,
|
||||
"depth": 1,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "list_dir", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "list directory contents depth one".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
assert!(
|
||||
!request_bodies.is_empty(),
|
||||
"expected at least one request body"
|
||||
);
|
||||
|
||||
let tool_output_item = request_bodies
|
||||
.iter()
|
||||
.find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!("function_call_output item not found in requests: {request_bodies:#?}")
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
tool_output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
let output_text = tool_output_item
|
||||
.get("output")
|
||||
.and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
Value::Object(obj) => obj.get("content").and_then(Value::as_str),
|
||||
_ => None,
|
||||
})
|
||||
.expect("output text present");
|
||||
assert_eq!(output_text, "E1: [file] alpha.txt\nE2: [dir] nested");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "disabled until we enable list_dir tool"]
|
||||
async fn list_dir_tool_depth_two_includes_children_only() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = test_codex().build(&server).await?;
|
||||
|
||||
let dir_path = cwd.path().join("depth_two");
|
||||
std::fs::create_dir(&dir_path)?;
|
||||
std::fs::write(dir_path.join("alpha.txt"), "alpha")?;
|
||||
let nested = dir_path.join("nested");
|
||||
std::fs::create_dir(&nested)?;
|
||||
std::fs::write(nested.join("beta.txt"), "beta")?;
|
||||
let deeper = nested.join("grand");
|
||||
std::fs::create_dir(&deeper)?;
|
||||
std::fs::write(deeper.join("gamma.txt"), "gamma")?;
|
||||
let dir_path_string = dir_path.to_string_lossy().to_string();
|
||||
|
||||
let call_id = "list-dir-depth2";
|
||||
let arguments = serde_json::json!({
|
||||
"dir_path": dir_path_string,
|
||||
"offset": 1,
|
||||
"limit": 10,
|
||||
"depth": 2,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_function_call(call_id, "list_dir", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "list directory contents depth two".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
assert!(
|
||||
!request_bodies.is_empty(),
|
||||
"expected at least one request body"
|
||||
);
|
||||
|
||||
let tool_output_item = request_bodies
|
||||
.iter()
|
||||
.find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!("function_call_output item not found in requests: {request_bodies:#?}")
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
tool_output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
let output_text = tool_output_item
|
||||
.get("output")
|
||||
.and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
Value::Object(obj) => obj.get("content").and_then(Value::as_str),
|
||||
_ => None,
|
||||
})
|
||||
.expect("output text present");
|
||||
assert_eq!(
|
||||
output_text,
|
||||
"E1: [file] alpha.txt\nE2: [dir] nested\nE3: [file] nested/beta.txt\nE4: [dir] nested/grand"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "disabled until we enable list_dir tool"]
|
||||
async fn list_dir_tool_depth_three_includes_grandchildren() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = test_codex().build(&server).await?;
|
||||
|
||||
let dir_path = cwd.path().join("depth_three");
|
||||
std::fs::create_dir(&dir_path)?;
|
||||
std::fs::write(dir_path.join("alpha.txt"), "alpha")?;
|
||||
let nested = dir_path.join("nested");
|
||||
std::fs::create_dir(&nested)?;
|
||||
std::fs::write(nested.join("beta.txt"), "beta")?;
|
||||
let deeper = nested.join("grand");
|
||||
std::fs::create_dir(&deeper)?;
|
||||
std::fs::write(deeper.join("gamma.txt"), "gamma")?;
|
||||
let dir_path_string = dir_path.to_string_lossy().to_string();
|
||||
|
||||
let call_id = "list-dir-depth3";
|
||||
let arguments = serde_json::json!({
|
||||
"dir_path": dir_path_string,
|
||||
"offset": 1,
|
||||
"limit": 10,
|
||||
"depth": 3,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let first_response = sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
ev_function_call(call_id, "list_dir", &arguments),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||
|
||||
let second_response = sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "list directory contents depth three".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let request_bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
assert!(
|
||||
!request_bodies.is_empty(),
|
||||
"expected at least one request body"
|
||||
);
|
||||
|
||||
let tool_output_item = request_bodies
|
||||
.iter()
|
||||
.find_map(|body| {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|items| {
|
||||
items.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!("function_call_output item not found in requests: {request_bodies:#?}")
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
tool_output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
let output_text = tool_output_item
|
||||
.get("output")
|
||||
.and_then(|value| match value {
|
||||
Value::String(text) => Some(text.as_str()),
|
||||
Value::Object(obj) => obj.get("content").and_then(Value::as_str),
|
||||
_ => None,
|
||||
})
|
||||
.expect("output text present");
|
||||
assert_eq!(
|
||||
output_text,
|
||||
"E1: [file] alpha.txt\nE2: [dir] nested\nE3: [file] nested/beta.txt\nE4: [dir] nested/grand\nE5: [file] nested/grand/gamma.txt"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
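The assertions in the tests above pin down the listing format: 1-indexed `E<n>:` prefixes, a `[file]` or `[dir]` label, and paths relative to the listed directory. A minimal sketch of a formatter that would produce that shape; this is an assumption for illustration, as the real implementation lives in the `ListDirHandler`:

// Hypothetical formatter matching the expected output,
// e.g. format_entry(1, false, "alpha.txt") == "E1: [file] alpha.txt".
fn format_entry(index: usize, is_dir: bool, rel_path: &str) -> String {
    let kind = if is_dir { "dir" } else { "file" };
    format!("E{index}: [{kind}] {rel_path}")
}
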
@@ -10,6 +10,7 @@ mod exec;
mod exec_stream_events;
mod fork_conversation;
mod json_result;
mod list_dir;
mod live_cli;
mod model_overrides;
mod model_tools;
@@ -20,9 +21,11 @@ mod review;
mod rmcp_client;
mod rollout_list_find;
mod seatbelt;
mod shell_serialization;
mod stream_error_allows_next_turn;
mod stream_no_completed;
mod tool_harness;
mod tool_parallelism;
mod tools;
mod unified_exec;
mod user_notification;
@@ -10,14 +10,11 @@ use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

fn sse_completed(id: &str) -> String {
    load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
@@ -44,16 +41,7 @@ async fn collect_tool_identifiers_for_model(model: &str) -> Vec<String> {
    let server = MockServer::start().await;

    let sse = sse_completed(model);
    let template = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse, "text/event-stream");

    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(template)
        .expect(1)
        .mount(&server)
        .await;
    let resp_mock = responses::mount_sse_once_match(&server, wiremock::matchers::any(), sse).await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
@@ -93,13 +81,7 @@ async fn collect_tool_identifiers_for_model(model: &str) -> Vec<String> {
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.unwrap();
    assert_eq!(
        requests.len(),
        1,
        "expected a single request for model {model}"
    );
    let body = requests[0].body_json::<serde_json::Value>().unwrap();
    let body = resp_mock.single_request().body_json();
    tool_identifiers(&body)
}

@@ -125,7 +107,7 @@ async fn model_selects_expected_tools() {
    let gpt5_codex_tools = collect_tool_identifiers_for_model("gpt-5-codex").await;
    assert_eq!(
        gpt5_codex_tools,
        vec!["shell".to_string(), "read_file".to_string()],
        "gpt-5-codex should expose the beta read_file tool",
        vec!["shell".to_string(), "apply_patch".to_string(),],
        "gpt-5-codex should expose the apply_patch tool",
    );
}
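The hunks above (and several below) swap hand-rolled wiremock setup and request scanning for a recorded-mock handle. The handle's real API lives in core_test_support::responses and is not part of this diff; the sketch below is only an assumption of the shape that `single_request()`, `body_json()`, and `function_call_output(call_id)` would need for the call sites to read as they do (all names here are hypothetical stand-ins):

use serde_json::Value;

struct RecordedMock {
    // Bodies of the requests this mock matched, captured by the mount helper.
    bodies: Vec<Value>,
}

impl RecordedMock {
    // Panics unless exactly one request hit this mock, then returns its body.
    fn single_request(&self) -> &Value {
        assert_eq!(self.bodies.len(), 1, "expected exactly one matched request");
        &self.bodies[0]
    }
}

// Finds the function_call_output input item carrying the given call_id.
fn function_call_output<'a>(body: &'a Value, call_id: &str) -> Option<&'a Value> {
    body.get("input")?.as_array()?.iter().find(|item| {
        item.get("type").and_then(Value::as_str) == Some("function_call_output")
            && item.get("call_id").and_then(Value::as_str) == Some(call_id)
    })
}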
@@ -4,6 +4,7 @@ use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::built_in_model_providers;
use codex_core::config::OPENAI_DEFAULT_MODEL;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
@@ -18,6 +19,7 @@ use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use std::collections::HashMap;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
@@ -178,16 +180,16 @@ async fn prompt_tools_are_consistent_across_requests() {

    let cwd = TempDir::new().unwrap();
    let codex_home = TempDir::new().unwrap();

    let mut config = load_default_config_for_test(&codex_home);
    config.cwd = cwd.path().to_path_buf();
    config.model_provider = model_provider;
    config.user_instructions = Some("be consistent and helpful".to_string());
    config.include_apply_patch_tool = true;
    config.include_plan_tool = true;

    let conversation_manager =
        ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
    let expected_instructions = config.model_family.base_instructions.clone();
    let base_instructions = config.model_family.base_instructions.clone();
    let codex = conversation_manager
        .new_conversation(config)
        .await
@@ -219,14 +221,29 @@ async fn prompt_tools_are_consistent_across_requests() {

    // our internal implementation is responsible for keeping tools in sync
    // with the OpenAI schema, so we just verify the tool presence here
    let expected_tools_names: &[&str] = &[
        "shell",
        "update_plan",
        "apply_patch",
        "read_file",
        "view_image",
    ];
    let tools_by_model: HashMap<&'static str, Vec<&'static str>> = HashMap::from([
        ("gpt-5", vec!["shell", "update_plan", "view_image"]),
        (
            "gpt-5-codex",
            vec!["shell", "update_plan", "apply_patch", "view_image"],
        ),
    ]);
    let expected_tools_names = tools_by_model
        .get(OPENAI_DEFAULT_MODEL)
        .unwrap_or_else(|| panic!("expected tools to be defined for model {OPENAI_DEFAULT_MODEL}"))
        .as_slice();
    let body0 = requests[0].body_json::<serde_json::Value>().unwrap();

    let expected_instructions = if expected_tools_names.contains(&"apply_patch") {
        base_instructions
    } else {
        [
            base_instructions.clone(),
            include_str!("../../../apply-patch/apply_patch_tool_instructions.md").to_string(),
        ]
        .join("\n")
    };

    assert_eq!(
        body0["instructions"],
        serde_json::json!(expected_instructions),
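Condensing the expectation logic above into one place (same behavior, same inputs as the test body): when the selected model family does not expose apply_patch as a tool, the apply_patch usage guide is appended to the base instructions; otherwise the base instructions are sent as-is.

fn expected_instructions(base: &str, tools: &[&str], apply_patch_guide: &str) -> String {
    if tools.contains(&"apply_patch") {
        // The tool carries its own schema; no extra prose is appended.
        base.to_string()
    } else {
        // Mirrors [base, guide].join("\n") from the test body above.
        format!("{base}\n{apply_patch_guide}")
    }
}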
@@ -10,6 +10,7 @@ use core_test_support::responses;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
@@ -21,6 +22,7 @@ use serde_json::Value;
use wiremock::matchers::any;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "disabled until we enable read_file tool"]
async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

@@ -46,10 +48,7 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
    .to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_response_created("resp-1"),
        ev_function_call(call_id, "read_file", &arguments),
        ev_completed("resp-1"),
    ]);
@@ -59,7 +58,7 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;
    let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

@@ -80,36 +79,12 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().unwrap())
        .collect::<Vec<_>>();
    assert!(
        !request_bodies.is_empty(),
        "expected at least one request body"
    );

    let tool_output_item = request_bodies
        .iter()
        .find_map(|body| {
            body.get("input")
                .and_then(Value::as_array)
                .and_then(|items| {
                    items.iter().find(|item| {
                        item.get("type").and_then(Value::as_str) == Some("function_call_output")
                    })
                })
        })
        .unwrap_or_else(|| {
            panic!("function_call_output item not found in requests: {request_bodies:#?}")
        });

    let req = second_mock.single_request();
    let tool_output_item = req.function_call_output(call_id);
    assert_eq!(
        tool_output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );

    let output_text = tool_output_item
        .get("output")
        .and_then(|value| match value {
@@ -24,6 +24,7 @@ use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id_from_str;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_with_timeout;
use pretty_assertions::assert_eq;
use std::path::PathBuf;
use std::sync::Arc;
@@ -260,25 +261,28 @@ async fn review_does_not_emit_agent_message_on_structured_output() {
        .unwrap();

    // Drain events until TaskComplete; ensure none are AgentMessage.
    use tokio::time::Duration;
    use tokio::time::timeout;
    let mut saw_entered = false;
    let mut saw_exited = false;
    loop {
        let ev = timeout(Duration::from_secs(5), codex.next_event())
            .await
            .expect("timeout waiting for event")
            .expect("stream ended unexpectedly");
        match ev.msg {
            EventMsg::TaskComplete(_) => break,
    wait_for_event_with_timeout(
        &codex,
        |event| match event {
            EventMsg::TaskComplete(_) => true,
            EventMsg::AgentMessage(_) => {
                panic!("unexpected AgentMessage during review with structured output")
            }
            EventMsg::EnteredReviewMode(_) => saw_entered = true,
            EventMsg::ExitedReviewMode(_) => saw_exited = true,
            _ => {}
        }
    }
            EventMsg::EnteredReviewMode(_) => {
                saw_entered = true;
                false
            }
            EventMsg::ExitedReviewMode(_) => {
                saw_exited = true;
                false
            }
            _ => false,
        },
        tokio::time::Duration::from_secs(5),
    )
    .await;
    assert!(saw_entered && saw_exited, "missing review lifecycle events");

    server.verify().await;
@@ -441,7 +445,7 @@ async fn review_input_isolated_from_parent_history() {
        .await;
    let _complete = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    // Assert the request `input` contains the environment context followed by the review prompt.
    // Assert the request `input` contains the environment context followed by the user review prompt.
    let request = &server.received_requests().await.unwrap()[0];
    let body = request.body_json::<serde_json::Value>().unwrap();
    let input = body["input"].as_array().expect("input array");
@@ -469,9 +473,14 @@ async fn review_input_isolated_from_parent_history() {
    assert_eq!(review_msg["role"].as_str().unwrap(), "user");
    assert_eq!(
        review_msg["content"][0]["text"].as_str().unwrap(),
        format!("{REVIEW_PROMPT}\n\n---\n\nNow, here's your task: Please review only this",)
        review_prompt,
        "user message should only contain the raw review prompt"
    );

    // Ensure the REVIEW_PROMPT rubric is sent via instructions.
    let instructions = body["instructions"].as_str().expect("instructions string");
    assert_eq!(instructions, REVIEW_PROMPT);

    // Also verify that a user interruption note was recorded in the rollout.
    codex.submit(Op::GetPath).await.unwrap();
    let history_event =
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
@@ -184,10 +181,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
@@ -238,7 +232,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token: None,
|
||||
bearer_token_env_var: None,
|
||||
},
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
@@ -352,10 +346,7 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
@@ -421,7 +412,7 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token: None,
|
||||
bearer_token_env_var: None,
|
||||
},
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
|
||||
277 codex-rs/core/tests/suite/shell_serialization.rs (new file)
@@ -0,0 +1,277 @@
#![cfg(not(target_os = "windows"))]

use anyhow::Result;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::assert_regex_match;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::Value;
use serde_json::json;

async fn submit_turn(test: &TestCodex, prompt: &str, sandbox_policy: SandboxPolicy) -> Result<()> {
    let session_model = test.session_configured.model.clone();

    test.codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: prompt.into(),
            }],
            final_output_json_schema: None,
            cwd: test.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&test.codex, |event| {
        matches!(event, EventMsg::TaskComplete(_))
    })
    .await;

    Ok(())
}

fn request_bodies(requests: &[wiremock::Request]) -> Result<Vec<Value>> {
    requests
        .iter()
        .map(|req| Ok(serde_json::from_slice::<Value>(&req.body)?))
        .collect()
}

fn find_function_call_output<'a>(bodies: &'a [Value], call_id: &str) -> Option<&'a Value> {
    for body in bodies {
        if let Some(items) = body.get("input").and_then(Value::as_array) {
            for item in items {
                if item.get("type").and_then(Value::as_str) == Some("function_call_output")
                    && item.get("call_id").and_then(Value::as_str) == Some(call_id)
                {
                    return Some(item);
                }
            }
        }
    }
    None
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_output_stays_json_without_freeform_apply_patch() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|config| {
        config.include_apply_patch_tool = false;
        config.model = "gpt-5".to_string();
        config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a model family");
    });
    let test = builder.build(&server).await?;

    let call_id = "shell-json";
    let args = json!({
        "command": ["/bin/echo", "shell json"],
        "timeout_ms": 1_000,
    });
    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run the json shell command",
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server
        .received_requests()
        .await
        .expect("recorded requests present");
    let bodies = request_bodies(&requests)?;
    let output_item = find_function_call_output(&bodies, call_id).expect("shell output present");
    let output = output_item
        .get("output")
        .and_then(Value::as_str)
        .expect("shell output string");

    let parsed: Value = serde_json::from_str(output)?;
    assert_eq!(
        parsed
            .get("metadata")
            .and_then(|metadata| metadata.get("exit_code"))
            .and_then(Value::as_i64),
        Some(0),
        "expected zero exit code in unformatted JSON output",
    );
    let stdout = parsed
        .get("output")
        .and_then(Value::as_str)
        .unwrap_or_default();
    assert_regex_match(r"(?s)^shell json\n?$", stdout);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_output_is_structured_with_freeform_apply_patch() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|config| {
        config.include_apply_patch_tool = true;
    });
    let test = builder.build(&server).await?;

    let call_id = "shell-structured";
    let args = json!({
        "command": ["/bin/echo", "freeform shell"],
        "timeout_ms": 1_000,
    });
    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run the structured shell command",
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server
        .received_requests()
        .await
        .expect("recorded requests present");
    let bodies = request_bodies(&requests)?;
    let output_item =
        find_function_call_output(&bodies, call_id).expect("structured output present");
    let output = output_item
        .get("output")
        .and_then(Value::as_str)
        .expect("structured output string");

    assert!(
        serde_json::from_str::<Value>(output).is_err(),
        "expected structured shell output to be plain text",
    );
    let expected_pattern = r"(?s)^Exit code: 0
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Output:
freeform shell
?$";
    assert_regex_match(expected_pattern, output);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_output_reserializes_truncated_content() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|config| {
        config.model = "gpt-5-codex".to_string();
        config.model_family =
            find_family_for_model("gpt-5-codex").expect("gpt-5-codex is a model family");
    });
    let test = builder.build(&server).await?;

    let call_id = "shell-truncated";
    let args = json!({
        "command": ["/bin/sh", "-c", "seq 1 400"],
        "timeout_ms": 1_000,
    });
    let responses = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run the truncation shell command",
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server
        .received_requests()
        .await
        .expect("recorded requests present");
    let bodies = request_bodies(&requests)?;
    let output_item =
        find_function_call_output(&bodies, call_id).expect("truncated output present");
    let output = output_item
        .get("output")
        .and_then(Value::as_str)
        .expect("truncated output string");

    assert!(
        serde_json::from_str::<Value>(output).is_err(),
        "expected truncated shell output to be plain text",
    );
    let truncated_pattern = r#"(?s)^Exit code: 0
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Total output lines: 400
Output:
1
2
3
4
5
6
.*
\[\.{3} omitted \d+ of 400 lines \.{3}\]

.*
396
397
398
399
400
$"#;
    assert_regex_match(truncated_pattern, output);

    Ok(())
}
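assert_regex_match is imported from core_test_support but not defined in this diff. A plausible minimal implementation, shown only as an assumption for context, on top of regex_lite (which this suite already imports elsewhere):

fn assert_regex_match(pattern: &str, haystack: &str) {
    // Compile the pattern and fail the test with a readable message on mismatch.
    let re = regex_lite::Regex::new(pattern).expect("valid regex pattern");
    assert!(
        re.is_match(haystack),
        "expected {haystack:?} to match {pattern:?}"
    );
}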
@@ -13,7 +13,7 @@ use core_test_support::load_sse_fixture_with_id;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use tokio::time::timeout;
use core_test_support::wait_for_event_with_timeout;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::Request;
@@ -102,13 +102,10 @@ async fn retries_on_early_close() {
        .unwrap();

    // Wait until TaskComplete (should succeed after retry).
    loop {
        let ev = timeout(Duration::from_secs(10), codex.next_event())
            .await
            .unwrap()
            .unwrap();
        if matches!(ev.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }
    wait_for_event_with_timeout(
        &codex,
        |event| matches!(event, EventMsg::TaskComplete(_)),
        Duration::from_secs(10),
    )
    .await;
}
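wait_for_event_with_timeout is the helper these event loops collapse into; its definition is outside this diff. A sketch of the obvious shape, stated as an assumption (poll next_event() until the predicate matches, failing the test when the deadline passes), using the Codex, Event, and EventMsg types from codex_core seen above:

use std::time::Duration;

async fn wait_for_event_with_timeout<F>(codex: &Codex, mut predicate: F, deadline: Duration) -> Event
where
    F: FnMut(&EventMsg) -> bool,
{
    tokio::time::timeout(deadline, async {
        loop {
            // Keep draining events until one satisfies the caller's predicate.
            let event = codex.next_event().await.expect("event stream ended");
            if predicate(&event.msg) {
                return event;
            }
        }
    })
    .await
    .expect("timed out waiting for matching event")
}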
@@ -1,5 +1,7 @@
#![cfg(not(target_os = "windows"))]

use assert_matches::assert_matches;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
@@ -7,31 +9,24 @@ use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::plan_tool::StepStatus;
use core_test_support::assert_regex_match;
use core_test_support::responses;
use core_test_support::responses::ev_apply_patch_function_call;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_local_shell_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::Value;
use serde_json::json;
use wiremock::matchers::any;

fn function_call_output(body: &Value) -> Option<&Value> {
    body.get("input")
        .and_then(Value::as_array)
        .and_then(|items| {
            items.iter().find(|item| {
                item.get("type").and_then(Value::as_str) == Some("function_call_output")
            })
        })
}

fn extract_output_text(item: &Value) -> Option<&str> {
    item.get("output").and_then(|value| match value {
        Value::String(text) => Some(text.as_str()),
@@ -40,12 +35,6 @@ fn extract_output_text(item: &Value) -> Option<&str> {
    })
}

fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> {
    requests
        .iter()
        .find(|body| function_call_output(body).is_some())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
@@ -53,7 +42,8 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.include_apply_patch_tool = true;
        config.model = "gpt-5".to_string();
        config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a valid model");
    });
    let TestCodex {
        codex,
@@ -65,10 +55,7 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
    let call_id = "shell-tool-call";
    let command = vec!["/bin/echo", "tool harness"];
    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_response_created("resp-1"),
        ev_local_shell_call(call_id, "completed", command),
        ev_completed("resp-1"),
    ]);
@@ -78,7 +65,7 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
        ev_assistant_message("msg-1", "all done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;
    let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

@@ -97,32 +84,15 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
        })
        .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }
    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    let req = second_mock.single_request();
    let output_item = req.function_call_output(call_id);
    let output_text = extract_output_text(&output_item).expect("output text present");
    let exec_output: Value = serde_json::from_str(output_text)?;
    assert_eq!(exec_output["metadata"]["exit_code"], 0);
    let stdout = exec_output["output"].as_str().expect("stdout field");
    assert!(
        stdout.contains("tool harness"),
        "expected stdout to contain command output, got {stdout:?}"
    );
    assert_regex_match(r"(?s)^tool harness\n?$", stdout);

    Ok(())
}
@@ -154,10 +124,7 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
    .to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_response_created("resp-1"),
        ev_function_call(call_id, "update_plan", &plan_args),
        ev_completed("resp-1"),
    ]);
@@ -167,7 +134,7 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
        ev_assistant_message("msg-1", "plan acknowledged"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;
    let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

@@ -187,42 +154,31 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
        .await?;

    let mut saw_plan_update = false;

    loop {
        let event = codex.next_event().await.expect("event");
        match event.msg {
            EventMsg::PlanUpdate(update) => {
                saw_plan_update = true;
                assert_eq!(update.explanation.as_deref(), Some("Tool harness check"));
                assert_eq!(update.plan.len(), 2);
                assert_eq!(update.plan[0].step, "Inspect workspace");
                assert!(matches!(update.plan[0].status, StepStatus::InProgress));
                assert_eq!(update.plan[1].step, "Report results");
                assert!(matches!(update.plan[1].status, StepStatus::Pending));
            }
            EventMsg::TaskComplete(_) => break,
            _ => {}
    wait_for_event(&codex, |event| match event {
        EventMsg::PlanUpdate(update) => {
            saw_plan_update = true;
            assert_eq!(update.explanation.as_deref(), Some("Tool harness check"));
            assert_eq!(update.plan.len(), 2);
            assert_eq!(update.plan[0].step, "Inspect workspace");
            assert_matches!(update.plan[0].status, StepStatus::InProgress);
            assert_eq!(update.plan[1].step, "Report results");
            assert_matches!(update.plan[1].status, StepStatus::Pending);
            false
        }
    }
        EventMsg::TaskComplete(_) => true,
        _ => false,
    })
    .await;

    assert!(saw_plan_update, "expected PlanUpdate event");

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let req = second_mock.single_request();
    let output_item = req.function_call_output(call_id);
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");
    let output_text = extract_output_text(&output_item).expect("output text present");
    assert_eq!(output_text, "Plan updated");

    Ok(())
@@ -251,10 +207,7 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
    .to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_response_created("resp-1"),
        ev_function_call(call_id, "update_plan", &invalid_args),
        ev_completed("resp-1"),
    ]);
@@ -264,7 +217,7 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
        ev_assistant_message("msg-1", "malformed plan payload"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;
    let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

@@ -284,37 +237,28 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
        .await?;

    let mut saw_plan_update = false;

    loop {
        let event = codex.next_event().await.expect("event");
        match event.msg {
            EventMsg::PlanUpdate(_) => saw_plan_update = true,
            EventMsg::TaskComplete(_) => break,
            _ => {}
    wait_for_event(&codex, |event| match event {
        EventMsg::PlanUpdate(_) => {
            saw_plan_update = true;
            false
        }
    }
        EventMsg::TaskComplete(_) => true,
        _ => false,
    })
    .await;

    assert!(
        !saw_plan_update,
        "did not expect PlanUpdate event for malformed payload"
    );

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let req = second_mock.single_request();
    let output_item = req.function_call_output(call_id);
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");
    let output_text = extract_output_text(&output_item).expect("output text present");
    assert!(
        output_text.contains("failed to parse function arguments"),
        "expected parse error message in output text, got {output_text:?}"
@@ -357,10 +301,7 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
*** End Patch"#;

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_response_created("resp-1"),
        ev_apply_patch_function_call(call_id, patch_content),
        ev_completed("resp-1"),
    ]);
@@ -370,7 +311,7 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
        ev_assistant_message("msg-1", "patch complete"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;
    let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

@@ -391,43 +332,33 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()

    let mut saw_patch_begin = false;
    let mut patch_end_success = None;

    loop {
        let event = codex.next_event().await.expect("event");
        match event.msg {
            EventMsg::PatchApplyBegin(begin) => {
                saw_patch_begin = true;
                assert_eq!(begin.call_id, call_id);
            }
            EventMsg::PatchApplyEnd(end) => {
                assert_eq!(end.call_id, call_id);
                patch_end_success = Some(end.success);
            }
            EventMsg::TaskComplete(_) => break,
            _ => {}
    wait_for_event(&codex, |event| match event {
        EventMsg::PatchApplyBegin(begin) => {
            saw_patch_begin = true;
            assert_eq!(begin.call_id, call_id);
            false
        }
    }
        EventMsg::PatchApplyEnd(end) => {
            assert_eq!(end.call_id, call_id);
            patch_end_success = Some(end.success);
            false
        }
        EventMsg::TaskComplete(_) => true,
        _ => false,
    })
    .await;

    assert!(saw_patch_begin, "expected PatchApplyBegin event");
    let patch_end_success =
        patch_end_success.expect("expected PatchApplyEnd event to capture success flag");

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let req = second_mock.single_request();
    let output_item = req.function_call_output(call_id);
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");
    let output_text = extract_output_text(&output_item).expect("output text present");

    if let Ok(exec_output) = serde_json::from_str::<Value>(output_text) {
        let exit_code = exec_output["metadata"]["exit_code"]
@@ -487,10 +418,7 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
*** End Patch";

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_response_created("resp-1"),
        ev_apply_patch_function_call(call_id, patch_content),
        ev_completed("resp-1"),
    ]);
@@ -500,7 +428,7 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
        ev_assistant_message("msg-1", "failed"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;
    let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

@@ -519,29 +447,15 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
        })
        .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }
    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let req = second_mock.single_request();
    let output_item = req.function_call_output(call_id);
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");
    let output_text = extract_output_text(&output_item).expect("output text present");

    assert!(
        output_text.contains("apply_patch verification failed"),
205 codex-rs/core/tests/suite/tool_parallelism.rs (new file)
@@ -0,0 +1,205 @@
#![cfg(not(target_os = "windows"))]
#![allow(clippy::unwrap_used)]

use std::time::Duration;
use std::time::Instant;

use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::json;

async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> {
    let session_model = test.session_configured.model.clone();

    test.codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: prompt.into(),
            }],
            final_output_json_schema: None,
            cwd: test.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    Ok(())
}

async fn run_turn_and_measure(test: &TestCodex, prompt: &str) -> anyhow::Result<Duration> {
    let start = Instant::now();
    run_turn(test, prompt).await?;
    Ok(start.elapsed())
}

#[allow(clippy::expect_used)]
async fn build_codex_with_test_tool(server: &wiremock::MockServer) -> anyhow::Result<TestCodex> {
    let mut builder = test_codex().with_config(|config| {
        config.model = "test-gpt-5-codex".to_string();
        config.model_family =
            find_family_for_model("test-gpt-5-codex").expect("test-gpt-5-codex model family");
    });
    builder.build(server).await
}

fn assert_parallel_duration(actual: Duration) {
    assert!(
        actual < Duration::from_millis(500),
        "expected parallel execution to finish quickly, got {actual:?}"
    );
}

fn assert_serial_duration(actual: Duration) {
    assert!(
        actual >= Duration::from_millis(500),
        "expected serial execution to take longer, got {actual:?}"
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn read_file_tools_run_in_parallel() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let test = build_codex_with_test_tool(&server).await?;

    let warmup_args = json!({
        "sleep_after_ms": 10,
        "barrier": {
            "id": "parallel-test-sync-warmup",
            "participants": 2,
            "timeout_ms": 1_000,
        }
    })
    .to_string();

    let parallel_args = json!({
        "sleep_after_ms": 300,
        "barrier": {
            "id": "parallel-test-sync",
            "participants": 2,
            "timeout_ms": 1_000,
        }
    })
    .to_string();

    let warmup_first = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-warm-1"}}),
        ev_function_call("warm-call-1", "test_sync_tool", &warmup_args),
        ev_function_call("warm-call-2", "test_sync_tool", &warmup_args),
        ev_completed("resp-warm-1"),
    ]);
    let warmup_second = sse(vec![
        ev_assistant_message("warm-msg-1", "warmup complete"),
        ev_completed("resp-warm-2"),
    ]);

    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "test_sync_tool", &parallel_args),
        ev_function_call("call-2", "test_sync_tool", &parallel_args),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(
        &server,
        vec![warmup_first, warmup_second, first_response, second_response],
    )
    .await;

    run_turn(&test, "warm up parallel tool").await?;

    let duration = run_turn_and_measure(&test, "exercise sync tool").await?;
    assert_parallel_duration(duration);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn non_parallel_tools_run_serially() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let test = test_codex().build(&server).await?;

    let shell_args = json!({
        "command": ["/bin/sh", "-c", "sleep 0.3"],
        "timeout_ms": 1_000,
    });
    let args_one = serde_json::to_string(&shell_args)?;
    let args_two = serde_json::to_string(&shell_args)?;

    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "shell", &args_one),
        ev_function_call("call-2", "shell", &args_two),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;

    let duration = run_turn_and_measure(&test, "run shell twice").await?;
    assert_serial_duration(duration);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mixed_tools_fall_back_to_serial() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let test = build_codex_with_test_tool(&server).await?;

    let sync_args = json!({
        "sleep_after_ms": 300
    })
    .to_string();
    let shell_args = serde_json::to_string(&json!({
        "command": ["/bin/sh", "-c", "sleep 0.3"],
        "timeout_ms": 1_000,
    }))?;

    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "test_sync_tool", &sync_args),
        ev_function_call("call-2", "shell", &shell_args),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;

    let duration = run_turn_and_measure(&test, "mix tools").await?;
    assert_serial_duration(duration);

    Ok(())
}
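The barrier arguments above are what turn "both calls completed" into "both calls overlapped": each test_sync_tool invocation blocks until `participants` callers have arrived, so the turn can only finish inside the timeout if the runtime dispatched the two calls concurrently. A minimal sketch of that rendezvous idea (an assumption; the real tool lives in the test harness, not in this diff):

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Barrier;

async fn sync_tool_call(barrier: Arc<Barrier>, sleep_after: Duration) {
    // Under serial dispatch the second participant never arrives while the
    // first is parked here, so a real implementation would time out instead.
    barrier.wait().await;
    tokio::time::sleep(sleep_after).await;
}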
@@ -2,25 +2,30 @@
|
||||
#![allow(clippy::unwrap_used, clippy::expect_used)]
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::model_family::find_family_for_model;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::assert_regex_match;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_custom_tool_call;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use regex_lite::Regex;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use wiremock::Request;
|
||||
|
||||
async fn submit_turn(
|
||||
test: &TestCodex,
|
||||
@@ -45,37 +50,14 @@ async fn submit_turn(
|
||||
})
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
let event = test.codex.next_event().await?;
|
||||
if matches!(event.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TaskComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn request_bodies(requests: &[Request]) -> Result<Vec<Value>> {
|
||||
requests
|
||||
.iter()
|
||||
.map(|req| Ok(serde_json::from_slice::<Value>(&req.body)?))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn collect_output_items<'a>(bodies: &'a [Value], ty: &str) -> Vec<&'a Value> {
|
||||
let mut out = Vec::new();
|
||||
for body in bodies {
|
||||
if let Some(items) = body.get("input").and_then(Value::as_array) {
|
||||
for item in items {
|
||||
if item.get("type").and_then(Value::as_str) == Some(ty) {
|
||||
out.push(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn tool_names(body: &Value) -> Vec<String> {
|
||||
body.get("tools")
|
||||
.and_then(Value::as_array)
|
||||
@@ -104,18 +86,23 @@ async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> {
|
||||
let call_id = "custom-unsupported";
|
||||
let tool_name = "unsupported_tool";
|
||||
|
||||
let responses = vec![
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_custom_tool_call(call_id, tool_name, "\"payload\""),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
@@ -125,13 +112,7 @@ async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let custom_items = collect_output_items(&bodies, "custom_tool_call_output");
|
||||
assert_eq!(custom_items.len(), 1, "expected single custom tool output");
|
||||
let item = custom_items[0];
|
||||
assert_eq!(item.get("call_id").and_then(Value::as_str), Some(call_id));
|
||||
|
||||
let item = mock.single_request().custom_tool_call_output(call_id);
|
||||
let output = item
|
||||
.get("output")
|
||||
.and_then(Value::as_str)
|
||||
@@ -147,7 +128,10 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex();
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a valid model");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let command = ["/bin/echo", "shell ok"];
|
||||
@@ -164,9 +148,10 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
"timeout_ms": 1_000,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(
|
||||
call_id_blocked,
|
||||
"shell",
|
||||
@@ -174,8 +159,12 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let second_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-2"}}),
|
||||
ev_response_created("resp-2"),
|
||||
ev_function_call(
|
||||
call_id_success,
|
||||
"shell",
|
||||
@@ -183,12 +172,16 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let third_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
@@ -198,46 +191,23 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let function_outputs = collect_output_items(&bodies, "function_call_output");
|
||||
for item in &function_outputs {
|
||||
let call_id = item
|
||||
.get("call_id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
assert!(
|
||||
call_id == call_id_blocked || call_id == call_id_success,
|
||||
"unexpected call id {call_id}"
|
||||
);
|
||||
}
|
||||
|
||||
let policy = AskForApproval::Never;
|
||||
let expected_message = format!(
|
||||
"approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}"
|
||||
);
|
||||
|
||||
let blocked_outputs: Vec<&Value> = function_outputs
|
||||
.iter()
|
||||
.filter(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_blocked))
|
||||
.copied()
|
||||
.collect();
|
||||
assert!(
|
||||
!blocked_outputs.is_empty(),
|
||||
"expected at least one rejection output for {call_id_blocked}"
|
||||
let blocked_item = second_mock
|
||||
.single_request()
|
||||
.function_call_output(call_id_blocked);
|
||||
assert_eq!(
|
||||
blocked_item.get("output").and_then(Value::as_str),
|
||||
Some(expected_message.as_str()),
|
||||
"unexpected rejection message"
|
||||
);
|
||||
for item in blocked_outputs {
|
||||
assert_eq!(
|
||||
item.get("output").and_then(Value::as_str),
|
||||
Some(expected_message.as_str()),
|
||||
"unexpected rejection message"
|
||||
);
|
||||
}
|
||||
|
||||
let success_item = function_outputs
|
||||
.iter()
|
||||
.find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_success))
|
||||
.expect("success output present");
|
||||
let success_item = third_mock
|
||||
.single_request()
|
||||
.function_call_output(call_id_success);
|
||||
let output_json: Value = serde_json::from_str(
|
||||
success_item
|
||||
.get("output")
|
||||
@@ -250,10 +220,8 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
"expected exit code 0 after rerunning without escalation",
|
||||
);
|
||||
let stdout = output_json["output"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
stdout.contains("shell ok"),
|
||||
"expected stdout to include command output, got {stdout:?}"
|
||||
);
|
||||
let stdout_pattern = r"(?s)^shell ok\n?$";
|
||||
assert_regex_match(stdout_pattern, stdout);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -278,18 +246,23 @@ async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> {
|
||||
}
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
local_shell_event,
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let second_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
@@ -299,15 +272,7 @@ async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let function_outputs = collect_output_items(&bodies, "function_call_output");
|
||||
assert_eq!(
|
||||
function_outputs.len(),
|
||||
1,
|
||||
"expected a single function output"
|
||||
);
|
||||
let item = function_outputs[0];
|
||||
let item = second_mock.single_request().function_call_output("");
|
||||
assert_eq!(item.get("call_id").and_then(Value::as_str), Some(""));
|
||||
assert_eq!(
|
||||
item.get("output").and_then(Value::as_str),
|
||||
@@ -321,11 +286,11 @@ async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let responses = vec![sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-1"),
|
||||
])];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
let mock = mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let mut builder = test_codex().with_config(move |config| {
|
||||
config.use_experimental_unified_exec_tool = use_unified_exec;
|
||||
@@ -340,15 +305,8 @@ async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert_eq!(
|
||||
requests.len(),
|
||||
1,
|
||||
"expected a single request for tools collection"
|
||||
);
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let first_body = bodies.first().expect("request body present");
|
||||
Ok(tool_names(first_body))
|
||||
let first_body = mock.single_request().body_json();
|
||||
Ok(tool_names(&first_body))
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -375,7 +333,10 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex();
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = "gpt-5".to_string();
|
||||
config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a valid model");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let call_id = "shell-timeout";
|
||||
@@ -385,18 +346,23 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
|
||||
"timeout_ms": timeout_ms,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
json!({"type": "response.created", "response": {"id": "resp-1"}}),
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let second_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
@@ -406,13 +372,7 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = server.received_requests().await.expect("recorded requests");
|
||||
let bodies = request_bodies(&requests)?;
|
||||
let function_outputs = collect_output_items(&bodies, "function_call_output");
|
||||
let timeout_item = function_outputs
|
||||
.iter()
|
||||
.find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id))
|
||||
.expect("timeout output present");
|
||||
let timeout_item = second_mock.single_request().function_call_output(call_id);
|
||||
|
||||
let output_str = timeout_item
|
||||
.get("output")
|
||||
@@ -431,30 +391,165 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
|
||||
|
||||
let stdout = output_json["output"].as_str().unwrap_or_default();
|
||||
assert!(
|
||||
stdout.starts_with("command timed out after "),
|
||||
"expected timeout prefix, got {stdout:?}"
|
||||
);
|
||||
let first_line = stdout.lines().next().unwrap_or_default();
|
||||
let duration_ms = first_line
|
||||
.strip_prefix("command timed out after ")
|
||||
.and_then(|line| line.strip_suffix(" milliseconds"))
|
||||
.and_then(|value| value.parse::<u64>().ok())
|
||||
.unwrap_or_default();
|
||||
assert!(
|
||||
duration_ms >= timeout_ms,
|
||||
"expected duration >= configured timeout, got {duration_ms} (timeout {timeout_ms})"
|
||||
stdout.contains("command timed out"),
|
||||
"timeout output missing `command timed out`: {stdout}"
|
||||
);
|
||||
} else {
|
||||
// Fallback: accept the signal classification path to deflake the test.
|
||||
assert!(
|
||||
output_str.contains("execution error"),
|
||||
"unexpected non-JSON output: {output_str:?}"
|
||||
);
|
||||
assert!(
|
||||
output_str.contains("Signal(") || output_str.to_lowercase().contains("signal"),
|
||||
"expected signal classification in error output, got {output_str:?}"
|
||||
);
|
||||
let signal_pattern = r"(?is)^execution error:.*signal.*$";
|
||||
assert_regex_match(signal_pattern, output_str);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_sandbox_denied_truncates_error_output() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex();
    let test = builder.build(&server).await?;

    let call_id = "shell-denied";
    let long_line = "this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz";
    let script = format!(
        "for i in $(seq 1 500); do >&2 echo '{long_line}'; done; cat <<'EOF' > denied.txt\ncontent\nEOF",
    );
    let args = json!({
        "command": ["/bin/sh", "-c", script],
        "timeout_ms": 1_000,
    });

    mount_sse_once(
        &server,
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
    )
    .await;
    let second_mock = mount_sse_once(
        &server,
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    )
    .await;

    submit_turn(
        &test,
        "attempt to write in read-only sandbox",
        AskForApproval::Never,
        SandboxPolicy::ReadOnly,
    )
    .await?;

    let denied_item = second_mock.single_request().function_call_output(call_id);

    let output = denied_item
        .get("output")
        .and_then(Value::as_str)
        .expect("denied output string");

    let sandbox_pattern = r#"(?s)^Exit code: -?\d+
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Total output lines: \d+
Output:

failed in sandbox: .*?(?:Operation not permitted|Permission denied|Read-only file system).*?
\[\.{3} omitted \d+ of \d+ lines \.{3}\]
.*this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz.*
\n?$"#;
    let sandbox_regex = Regex::new(sandbox_pattern)?;
    if !sandbox_regex.is_match(output) {
        let fallback_pattern = r#"(?s)^Total output lines: \d+

failed in sandbox: this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz
.*this is a long stderr line that should trigger truncation 0123456789abcdefghijklmnopqrstuvwxyz.*
.*(?:Operation not permitted|Permission denied|Read-only file system).*$"#;
        assert_regex_match(fallback_pattern, output);
    }

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_spawn_failure_truncates_exec_error() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex().with_config(|cfg| {
        cfg.sandbox_policy = SandboxPolicy::DangerFullAccess;
    });
    let test = builder.build(&server).await?;

    let call_id = "shell-spawn-failure";
    let bogus_component = "missing-bin-".repeat(700);
    let bogus_exe = test
        .cwd
        .path()
        .join(bogus_component)
        .to_string_lossy()
        .into_owned();

    let args = json!({
        "command": [bogus_exe],
        "timeout_ms": 1_000,
    });

    mount_sse_once(
        &server,
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
    )
    .await;
    let second_mock = mount_sse_once(
        &server,
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    )
    .await;

    submit_turn(
        &test,
        "spawn a missing binary",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let failure_item = second_mock.single_request().function_call_output(call_id);

    let output = failure_item
        .get("output")
        .and_then(Value::as_str)
        .expect("spawn failure output string");

    let spawn_error_pattern = r#"(?s)^Exit code: -?\d+
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Output:
execution error: .*$"#;
    let spawn_truncated_pattern = r#"(?s)^Exit code: -?\d+
Wall time: [0-9]+(?:\.[0-9]+)? seconds
Total output lines: \d+
Output:

execution error: .*$"#;
    let spawn_error_regex = Regex::new(spawn_error_pattern)?;
    let spawn_truncated_regex = Regex::new(spawn_truncated_pattern)?;
    if !spawn_error_regex.is_match(output) && !spawn_truncated_regex.is_match(output) {
        let fallback_pattern = r"(?s)^execution error: .*$";
        assert_regex_match(fallback_pattern, output);
    }
    assert!(output.len() <= 10 * 1024);

    Ok(())
}

@@ -12,6 +12,7 @@ use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
@@ -19,6 +20,7 @@ use core_test_support::skip_if_no_network;
use core_test_support::skip_if_sandbox;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::Value;

fn extract_output_text(item: &Value) -> Option<&str> {
@@ -81,7 +83,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {

    let responses = vec![
        sse(vec![
            serde_json::json!({"type": "response.created", "response": {"id": "resp-1"}}),
            ev_response_created("resp-1"),
            ev_function_call(
                first_call_id,
                "unified_exec",
@@ -90,7 +92,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
            ev_completed("resp-1"),
        ]),
        sse(vec![
            serde_json::json!({"type": "response.created", "response": {"id": "resp-2"}}),
            ev_response_created("resp-2"),
            ev_function_call(
                second_call_id,
                "unified_exec",
@@ -122,12 +124,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
    })
    .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }
    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");
@@ -202,7 +199,7 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {

    let responses = vec![
        sse(vec![
            serde_json::json!({"type": "response.created", "response": {"id": "resp-1"}}),
            ev_response_created("resp-1"),
            ev_function_call(
                first_call_id,
                "unified_exec",
@@ -211,7 +208,7 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
            ev_completed("resp-1"),
        ]),
        sse(vec![
            serde_json::json!({"type": "response.created", "response": {"id": "resp-2"}}),
            ev_response_created("resp-2"),
            ev_function_call(
                second_call_id,
                "unified_exec",
@@ -12,24 +12,16 @@ use core_test_support::responses;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::Value;
use wiremock::matchers::any;

fn function_call_output(body: &Value) -> Option<&Value> {
    body.get("input")
        .and_then(Value::as_array)
        .and_then(|items| {
            items.iter().find(|item| {
                item.get("type").and_then(Value::as_str) == Some("function_call_output")
            })
        })
}

fn find_image_message(body: &Value) -> Option<&Value> {
    body.get("input")
        .and_then(Value::as_array)
@@ -57,12 +49,6 @@ fn extract_output_text(item: &Value) -> Option<&str> {
        })
}

fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> {
    requests
        .iter()
        .find(|body| function_call_output(body).is_some())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
@@ -88,10 +74,7 @@ async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> {
    let arguments = serde_json::json!({ "path": rel_path }).to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_response_created("resp-1"),
        ev_function_call(call_id, "view_image", &arguments),
        ev_completed("resp-1"),
    ]);
@@ -101,7 +84,7 @@ async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> {
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;
    let mock = responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

@@ -121,38 +104,31 @@ async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> {
    .await?;

    let mut tool_event = None;
    loop {
        let event = codex.next_event().await.expect("event");
        match event.msg {
            EventMsg::ViewImageToolCall(ev) => tool_event = Some(ev),
            EventMsg::TaskComplete(_) => break,
            _ => {}
    wait_for_event(&codex, |event| match event {
        EventMsg::ViewImageToolCall(_) => {
            tool_event = Some(event.clone());
            false
        }
    }
        EventMsg::TaskComplete(_) => true,
        _ => false,
    })
    .await;

    let tool_event = tool_event.expect("view image tool event emitted");
    let tool_event = match tool_event.expect("view image tool event emitted") {
        EventMsg::ViewImageToolCall(event) => event,
        _ => unreachable!("stored event must be ViewImageToolCall"),
    };
    assert_eq!(tool_event.call_id, call_id);
    assert_eq!(tool_event.path, abs_path);

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(
        requests.len() >= 2,
        "expected at least two POST requests, got {}",
        requests.len()
    );
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();
    let body = mock.single_request().body_json();
    let output_item = mock.single_request().function_call_output(call_id);

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    let output_text = extract_output_text(&output_item).expect("output text present");
    assert_eq!(output_text, "attached local image path");

    let image_message = find_image_message(body_with_tool_output)
        .expect("pending input image message not included in request");
    let image_message =
        find_image_message(&body).expect("pending input image message not included in request");
    let image_url = image_message
        .get("content")
        .and_then(Value::as_array)
@@ -197,10 +173,7 @@ async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> {
    let arguments = serde_json::json!({ "path": rel_path }).to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_response_created("resp-1"),
        ev_function_call(call_id, "view_image", &arguments),
        ev_completed("resp-1"),
    ]);
@@ -210,7 +183,7 @@ async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> {
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;
    let mock = responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

@@ -229,33 +202,16 @@ async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> {
    })
    .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }
    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(
        requests.len() >= 2,
        "expected at least two POST requests, got {}",
        requests.len()
    );
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    let body_with_tool_output = mock.single_request().body_json();
    let output_item = mock.single_request().function_call_output(call_id);
    let output_text = extract_output_text(&output_item).expect("output text present");
    let expected_message = format!("image path `{}` is not a file", abs_path.display());
    assert_eq!(output_text, expected_message);

    assert!(
        find_image_message(body_with_tool_output).is_none(),
        find_image_message(&body_with_tool_output).is_none(),
        "directory path should not produce an input_image message"
    );

@@ -282,10 +238,7 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> {
    let arguments = serde_json::json!({ "path": rel_path }).to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_response_created("resp-1"),
        ev_function_call(call_id, "view_image", &arguments),
        ev_completed("resp-1"),
    ]);
@@ -295,7 +248,7 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> {
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;
    let mock = responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

@@ -314,28 +267,11 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> {
    })
    .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }
    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(
        requests.len() >= 2,
        "expected at least two POST requests, got {}",
        requests.len()
    );
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    let body_with_tool_output = mock.single_request().body_json();
    let output_item = mock.single_request().function_call_output(call_id);
    let output_text = extract_output_text(&output_item).expect("output text present");
    let expected_prefix = format!("unable to locate image at `{}`:", abs_path.display());
    assert!(
        output_text.starts_with(&expected_prefix),
@@ -343,7 +279,7 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> {
    );

    assert!(
        find_image_message(body_with_tool_output).is_none(),
        find_image_message(&body_with_tool_output).is_none(),
        "missing file should not produce an input_image message"
    );

@@ -1,6 +1,6 @@
# Codex MCP Interface [experimental]
# Codex MCP Server Interface [experimental]

This document describes Codex’s experimental MCP interface: a JSON‑RPC API that runs over the Model Context Protocol (MCP) transport to control a local Codex engine.
This document describes Codex’s experimental MCP server interface: a JSON‑RPC API that runs over the Model Context Protocol (MCP) transport to control a local Codex engine.

- Status: experimental and subject to change without notice
- Server binary: `codex mcp-server` (or `codex-mcp-server`)

@@ -77,7 +77,7 @@ pub struct Cli {

    /// Initial instructions for the agent. If not provided as an argument (or
    /// if `-` is used), instructions are read from stdin.
    #[arg(value_name = "PROMPT")]
    #[arg(value_name = "PROMPT", value_hint = clap::ValueHint::Other)]
    pub prompt: Option<String>,
}

@@ -99,7 +99,7 @@ pub struct ResumeArgs {
    pub last: bool,

    /// Prompt to send after resuming the session. If `-` is used, read from stdin.
    #[arg(value_name = "PROMPT")]
    #[arg(value_name = "PROMPT", value_hint = clap::ValueHint::Other)]
    pub prompt: Option<String>,
}

@@ -48,7 +48,7 @@ use codex_core::default_client::set_default_originator;
use codex_core::find_conversation_path_by_id_str;

pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
    if let Err(err) = set_default_originator("codex_exec") {
    if let Err(err) = set_default_originator("codex_exec".to_string()) {
        tracing::warn!(?err, "Failed to set codex exec originator override {err:?}");
    }

@@ -177,7 +177,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
        codex_linux_sandbox_exe,
        base_instructions: None,
        include_plan_tool: Some(include_plan_tool),
        include_apply_patch_tool: Some(true),
        include_apply_patch_tool: None,
        include_view_image_tool: None,
        show_raw_agent_reasoning: oss.then_some(true),
        tools_web_search_request: None,

@@ -1,26 +1,22 @@
#![allow(clippy::unwrap_used, clippy::expect_used)]
use core_test_support::responses::ev_completed;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::sse;
use core_test_support::responses::sse_response;
use core_test_support::responses::start_mock_server;
use core_test_support::test_codex_exec::test_codex_exec;
use wiremock::Mock;
use wiremock::matchers::header;
use wiremock::matchers::method;
use wiremock::matchers::path;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_uses_codex_api_key_env_var() -> anyhow::Result<()> {
    let test = test_codex_exec();
    let server = start_mock_server().await;

    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(header("Authorization", "Bearer dummy"))
        .respond_with(sse_response(sse(vec![ev_completed("request_0")])))
        .expect(1)
        .mount(&server)
        .await;
    mount_sse_once_match(
        &server,
        header("Authorization", "Bearer dummy"),
        sse(vec![ev_completed("request_0")]),
    )
    .await;

    test.cmd_with_server(&server)
        .arg("--skip-git-repo-check")

@@ -1,6 +1,7 @@
// Aggregates all former standalone integration tests as modules.
mod apply_patch;
mod auth_env;
mod originator;
mod output_schema;
mod resume;
mod sandbox;

codex-rs/exec/tests/suite/originator.rs (new file, 52 lines)
@@ -0,0 +1,52 @@
#![cfg(not(target_os = "windows"))]
#![allow(clippy::expect_used, clippy::unwrap_used)]

use core_test_support::responses;
use core_test_support::test_codex_exec::test_codex_exec;
use wiremock::matchers::header;

/// Verify that `codex-exec` sends the default `Originator: codex_exec` header
/// on requests to the responses API.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn send_codex_exec_originator() -> anyhow::Result<()> {
    let test = test_codex_exec();

    let server = responses::start_mock_server().await;
    let body = responses::sse(vec![
        responses::ev_response_created("response_1"),
        responses::ev_assistant_message("response_1", "Hello, world!"),
        responses::ev_completed("response_1"),
    ]);
    responses::mount_sse_once_match(&server, header("Originator", "codex_exec"), body).await;

    test.cmd_with_server(&server)
        .arg("--skip-git-repo-check")
        .arg("tell me something")
        .assert()
        .code(0);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn supports_originator_override() -> anyhow::Result<()> {
    let test = test_codex_exec();

    let server = responses::start_mock_server().await;
    let body = responses::sse(vec![
        responses::ev_response_created("response_1"),
        responses::ev_assistant_message("response_1", "Hello, world!"),
        responses::ev_completed("response_1"),
    ]);
    responses::mount_sse_once_match(&server, header("Originator", "codex_exec_override"), body)
        .await;

    test.cmd_with_server(&server)
        .env("CODEX_INTERNAL_ORIGINATOR_OVERRIDE", "codex_exec_override")
        .arg("--skip-git-repo-check")
        .arg("tell me something")
        .assert()
        .code(0);

    Ok(())
}
@@ -24,14 +24,11 @@ async fn exec_includes_output_schema_in_request() -> anyhow::Result<()> {

    let server = responses::start_mock_server().await;
    let body = responses::sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp1"}
        }),
        responses::ev_response_created("resp1"),
        responses::ev_assistant_message("m1", "fixture hello"),
        responses::ev_completed("resp1"),
    ]);
    responses::mount_sse_once_match(&server, any(), body).await;
    let response_mock = responses::mount_sse_once_match(&server, any(), body).await;

    test.cmd_with_server(&server)
        .arg("--skip-git-repo-check")
@@ -46,12 +43,8 @@ async fn exec_includes_output_schema_in_request() -> anyhow::Result<()> {
        .assert()
        .success();

    let requests = server
        .received_requests()
        .await
        .expect("failed to capture requests");
    assert_eq!(requests.len(), 1, "expected exactly one request");
    let payload: Value = serde_json::from_slice(&requests[0].body)?;
    let request = response_mock.single_request();
    let payload: Value = request.body_json();
    let text = payload.get("text").expect("request missing text field");
    let format = text
        .get("format")

@@ -17,4 +17,5 @@ walkdir = "2"
workspace = true

[dev-dependencies]
assert_matches = { workspace = true }
pretty_assertions = "1.4.1"

@@ -186,6 +186,7 @@ fn default_commit_identity() -> Vec<(OsString, OsString)> {
mod tests {
    use super::*;
    use crate::operations::run_git_for_stdout;
    use assert_matches::assert_matches;
    use pretty_assertions::assert_eq;
    use std::process::Command;

@@ -348,7 +349,7 @@ mod tests {
        let options = CreateGhostCommitOptions::new(repo)
            .force_include(vec![PathBuf::from("../outside.txt")]);
        let err = create_ghost_commit(&options).unwrap_err();
        assert!(matches!(err, GitToolingError::PathEscapesRepository { .. }));
        assert_matches!(err, GitToolingError::PathEscapesRepository { .. });
    }

    #[test]
@@ -356,7 +357,7 @@ mod tests {
    fn restore_requires_git_repository() {
        let temp = tempfile::tempdir().expect("tempdir");
        let err = restore_to_commit(temp.path(), "deadbeef").unwrap_err();
        assert!(matches!(err, GitToolingError::NotAGitRepository { .. }));
        assert_matches!(err, GitToolingError::NotAGitRepository { .. });
    }

    #[test]

@@ -11,6 +11,10 @@ use crate::server::ServerOptions;
use std::io::Write;
use std::io::{self};

const ANSI_YELLOW: &str = "\x1b[93m";
const ANSI_BOLD: &str = "\x1b[1m";
const ANSI_RESET: &str = "\x1b[0m";

#[derive(Deserialize)]
struct UserCodeResp {
    device_auth_id: String,
@@ -68,9 +72,15 @@ async fn request_user_code(
        .map_err(std::io::Error::other)?;

    if !resp.status().is_success() {
        let status = resp.status();
        if status == StatusCode::NOT_FOUND {
            return Err(std::io::Error::other(
                "device code login is not enabled for this Codex server. Use the browser login or verify the server URL.",
            ));
        }

        return Err(std::io::Error::other(format!(
            "device code request failed with status {}",
            resp.status()
            "device code request failed with status {status}"
        )));
    }

@@ -128,20 +138,13 @@ async fn poll_for_token(
    }
}

// Helper to print colored text if terminal supports ANSI
fn print_colored_warning_device_code() {
    // ANSI escape code for bright yellow
    const YELLOW: &str = "\x1b[93m";
    const RESET: &str = "\x1b[0m";
    let warning = "WARN!!! device code authentication has potential risks and\n\
        should be used with caution only in cases where browser support \n\
        is missing. This is prone to attacks.\n\
        \n\
        - This code is valid for 15 minutes.\n\
        - Do not share this code with anyone.\n\
        ";
    let mut stdout = io::stdout().lock();
    let _ = write!(stdout, "{YELLOW}{warning}{RESET}");
    let _ = write!(
        stdout,
        "{ANSI_YELLOW}{ANSI_BOLD}Only use device code authentication when browser login is not available.{ANSI_RESET}{ANSI_YELLOW}\n\
        {ANSI_BOLD}Keep the code secret; do not share it.{ANSI_RESET}{ANSI_RESET}\n\n"
    );
    let _ = stdout.flush();
}
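
// Aside: a minimal, self-contained sketch of the polling pattern that
// `poll_for_token` above embodies. This helper is hypothetical (not this
// crate's API); it retries at a fixed interval until an attempt succeeds or
// the overall deadline passes, mirroring the 15-minute device-code window.
use std::time::{Duration, Instant};

fn poll_until<T>(
    deadline: Duration,
    interval: Duration,
    mut attempt: impl FnMut() -> Option<T>,
) -> Option<T> {
    let start = Instant::now();
    while start.elapsed() < deadline {
        if let Some(value) = attempt() {
            return Some(value);
        }
        std::thread::sleep(interval);
    }
    None // the device code expired before authorization completed
}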

@@ -151,12 +154,11 @@ pub async fn run_device_code_login(opts: ServerOptions) -> std::io::Result<()> {
    let base_url = opts.issuer.trim_end_matches('/');
    let api_base_url = format!("{}/api/accounts", opts.issuer.trim_end_matches('/'));
    print_colored_warning_device_code();
    println!("⏳ Generating a new 9-digit device code for authentication...\n");
    let uc = request_user_code(&client, &api_base_url, &opts.client_id).await?;

    println!(
        "To authenticate, visit: {}/deviceauth/authorize and enter code: {}",
        api_base_url, uc.user_code
        "To authenticate:\n  1. Open in your browser: {ANSI_BOLD}https://auth.openai.com/codex/device{ANSI_RESET}\n  2. Enter the one-time code below within 15 minutes:\n\n     {ANSI_BOLD}{}{ANSI_RESET}\n",
        uc.user_code
    );

    let code_resp = poll_for_token(
@@ -172,7 +174,6 @@ pub async fn run_device_code_login(opts: ServerOptions) -> std::io::Result<()> {
        code_verifier: code_resp.code_verifier,
        code_challenge: code_resp.code_challenge,
    };
    println!("authorization code received");
    let redirect_uri = format!("{base_url}/deviceauth/callback");

    let tokens = crate::server::exchange_code_for_tokens(

@@ -28,3 +28,4 @@ tracing = { workspace = true, features = ["log"] }
wiremock = { workspace = true }

[dev-dependencies]
assert_matches = { workspace = true }

@@ -30,19 +30,21 @@ pub(crate) fn pull_events_from_value(value: &JsonValue) -> Vec<PullEvent> {

#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;

    #[test]
    fn test_pull_events_decoder_status_and_success() {
        let v: JsonValue = serde_json::json!({"status":"verifying"});
        let events = pull_events_from_value(&v);
        assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying"));
        assert_matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying");

        let v2: JsonValue = serde_json::json!({"status":"success"});
        let events2 = pull_events_from_value(&v2);
        assert_eq!(events2.len(), 2);
        assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success"));
        assert!(matches!(events2[1], PullEvent::Success));
        assert_matches!(events2[0], PullEvent::Status(ref s) if s == "success");
        assert_matches!(events2[1], PullEvent::Success);
    }

    #[test]
@@ -50,33 +52,24 @@ mod tests {
        let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100});
        let events = pull_events_from_value(&v);
        assert_eq!(events.len(), 1);
        match &events[0] {
        assert_matches!(
            &events[0],
            PullEvent::ChunkProgress {
                digest,
                total,
                completed,
            } => {
                assert_eq!(digest, "sha256:abc");
                assert_eq!(*total, Some(100));
                assert_eq!(*completed, None);
            }
            _ => panic!("expected ChunkProgress"),
        }

            } if digest == "sha256:abc" && total == &Some(100) && completed.is_none()
        );
        let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42});
        let events2 = pull_events_from_value(&v2);
        assert_eq!(events2.len(), 1);
        match &events2[0] {
        assert_matches!(
            &events2[0],
            PullEvent::ChunkProgress {
                digest,
                total,
                completed,
            } => {
                assert_eq!(digest, "sha256:def");
                assert_eq!(*total, None);
                assert_eq!(*completed, Some(42));
            }
            _ => panic!("expected ChunkProgress"),
        }
            } if digest == "sha256:def" && total.is_none() && completed == &Some(42)
        );
    }
}
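
// For reference, a standalone sketch of the `assert_matches!` style this diff
// migrates to (the macro comes from the `assert_matches` crate). A guard
// clause keeps the pattern and the value check in one assertion, and the
// macro prints the mismatched value on failure instead of a bare boolean.
use assert_matches::assert_matches;

#[derive(Debug)]
enum Event {
    Status(String),
    Success,
}

fn main() {
    let events = vec![Event::Status("success".to_string()), Event::Success];
    assert_matches!(&events[0], Event::Status(s) if s == "success");
    assert_matches!(events[1], Event::Success);
}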

@@ -548,10 +548,15 @@ pub struct TaskStartedEvent {

#[derive(Debug, Clone, Deserialize, Serialize, Default, TS)]
pub struct TokenUsage {
    #[ts(type = "number")]
    pub input_tokens: u64,
    #[ts(type = "number")]
    pub cached_input_tokens: u64,
    #[ts(type = "number")]
    pub output_tokens: u64,
    #[ts(type = "number")]
    pub reasoning_output_tokens: u64,
    #[ts(type = "number")]
    pub total_tokens: u64,
}

@@ -559,6 +564,7 @@ pub struct TokenUsage {
pub struct TokenUsageInfo {
    pub total_token_usage: TokenUsage,
    pub last_token_usage: TokenUsage,
    #[ts(type = "number | null")]
    pub model_context_window: Option<u64>,
}

@@ -590,6 +596,31 @@ impl TokenUsageInfo {
        self.total_token_usage.add_assign(last);
        self.last_token_usage = last.clone();
    }

    pub fn fill_to_context_window(&mut self, context_window: u64) {
        let previous_total = self.total_token_usage.total_tokens;
        let delta = context_window.saturating_sub(previous_total);

        self.model_context_window = Some(context_window);
        self.total_token_usage = TokenUsage {
            total_tokens: context_window,
            ..TokenUsage::default()
        };
        self.last_token_usage = TokenUsage {
            total_tokens: delta,
            ..TokenUsage::default()
        };
    }

    pub fn full_context_window(context_window: u64) -> Self {
        let mut info = Self {
            total_token_usage: TokenUsage::default(),
            last_token_usage: TokenUsage::default(),
            model_context_window: Some(context_window),
        };
        info.fill_to_context_window(context_window);
        info
    }
}
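
// A quick standalone check of the arithmetic in `fill_to_context_window`
// (simplified to bare integers rather than the protocol crate's types):
// filling pins the total to the window size and records only the newly
// consumed delta as the last-turn usage.
fn fill(total: &mut u64, last: &mut u64, context_window: u64) {
    let delta = context_window.saturating_sub(*total);
    *total = context_window;
    *last = delta;
}

fn main() {
    let (mut total, mut last) = (7_000u64, 0u64);
    fill(&mut total, &mut last, 32_000);
    assert_eq!((total, last), (32_000, 25_000));
    // Filling again is a zero delta: the total already equals the window.
    fill(&mut total, &mut last, 32_000);
    assert_eq!((total, last), (32_000, 0));
}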

#[derive(Debug, Clone, Deserialize, Serialize, TS)]
@@ -609,8 +640,10 @@ pub struct RateLimitWindow {
    /// Percentage (0-100) of the window that has been consumed.
    pub used_percent: f64,
    /// Rolling window duration, in minutes.
    #[ts(type = "number | null")]
    pub window_minutes: Option<u64>,
    /// Seconds until the window resets.
    #[ts(type = "number | null")]
    pub resets_in_seconds: Option<u64>,
}

@@ -4,12 +4,12 @@ A strict HTTP proxy that only forwards `POST` requests to `/v1/responses` to the

## Expected Usage

**IMPORTANT:** `codex-responses-api-proxy` is designed to be run by a privileged user with access to `OPENAI_API_KEY` so that an unprivileged user cannot inspect or tamper with the process. Though if `--http-shutdown` is specified, an unprivileged user _can_ make a `GET` request to `/shutdown` to shutdown the server, as an unprivileged could not send `SIGTERM` to kill the process.
**IMPORTANT:** `codex-responses-api-proxy` is designed to be run by a privileged user with access to `OPENAI_API_KEY` so that an unprivileged user cannot inspect or tamper with the process. Though if `--http-shutdown` is specified, an unprivileged user _can_ make a `GET` request to `/shutdown` to shutdown the server, as an unprivileged user could not send `SIGTERM` to kill the process.

A privileged user (i.e., `root` or a user with `sudo`) who has access to `OPENAI_API_KEY` would run the following to start the server, as `codex-responses-api-proxy` reads the auth token from `stdin`:

```shell
printenv OPENAI_API_KEY | codex-responses-api-proxy --http-shutdown --server-info /tmp/server-info.json
printenv OPENAI_API_KEY | env -u OPENAI_API_KEY codex-responses-api-proxy --http-shutdown --server-info /tmp/server-info.json
```

A non-privileged user would then run Codex as follows, specifying the `model_provider` dynamically:
@@ -35,7 +35,7 @@ curl --fail --silent --show-error "${PROXY_BASE_URL}/shutdown"
- Listens on the provided port or an ephemeral port if `--port` is not specified.
- Accepts exactly `POST /v1/responses` (no query string). The request body is forwarded to `https://api.openai.com/v1/responses` with `Authorization: Bearer <key>` set. All original request headers (except any incoming `Authorization`) are forwarded upstream. For other requests, it responds with `403`.
- Optionally writes a single-line JSON file with server info, currently `{ "port": <u16> }`.
- Optional `--http-shutdown` enables `GET /shutdown` to terminate the process with exit code 0. This allows one user (e.g., `root`) to start the proxy and another unprivileged user on the host to shut it down.
- Optional `--http-shutdown` enables `GET /shutdown` to terminate the process with exit code `0`. This allows one user (e.g., `root`) to start the proxy and another unprivileged user on the host to shut it down.

## CLI

@@ -44,7 +44,7 @@ codex-responses-api-proxy [--port <PORT>] [--server-info <FILE>] [--http-shutdow
```

- `--port <PORT>`: Port to bind on `127.0.0.1`. If omitted, an ephemeral port is chosen.
- `--server-info <FILE>`: If set, the proxy writes a single line of JSON with `{ "port": <PORT> }` once listening.
- `--server-info <FILE>`: If set, the proxy writes a single line of JSON with `{ "port": <PORT>, "pid": <PID> }` once listening.
- `--http-shutdown`: If set, enables `GET /shutdown` to exit the process with code `0`.
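
As a quick illustration (not part of the proxy's own documentation), a client could read the documented server-info shape like this; the file path is the one used in the example above:

```rust
use serde::Deserialize;

// Fields follow the documented `{ "port": <PORT>, "pid": <PID> }` shape.
#[derive(Deserialize)]
struct ServerInfo {
    port: u16,
    pid: u32,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let raw = std::fs::read_to_string("/tmp/server-info.json")?;
    let info: ServerInfo = serde_json::from_str(&raw)?;
    println!("proxy listening on 127.0.0.1:{} (pid {})", info.port, info.pid);
    Ok(())
}
```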

## Notes

@@ -1,7 +1,6 @@
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use std::io::Read;
use zeroize::Zeroize;

/// Use a generous buffer size to avoid truncation and to allow for longer API
@@ -13,13 +12,66 @@ const AUTH_HEADER_PREFIX: &[u8] = b"Bearer ";
/// value with the auth token used with `Bearer`. The header value is returned
/// as a `&'static str` whose bytes are locked in memory to avoid accidental
/// exposure.
#[cfg(unix)]
pub(crate) fn read_auth_header_from_stdin() -> Result<&'static str> {
    read_auth_header_with(read_from_unix_stdin)
}

#[cfg(windows)]
pub(crate) fn read_auth_header_from_stdin() -> Result<&'static str> {
    use std::io::Read;

    // Use of `std::io::stdin()` has the problem mentioned in the docstring on
    // the UNIX version of `read_from_unix_stdin()`, so this should ultimately
    // be replaced with the low-level Windows equivalent. Because we do not
    // have an equivalent of mlock() on Windows right now, it is not pressing
    // until we address that issue.
    read_auth_header_with(|buffer| std::io::stdin().read(buffer))
}

fn read_auth_header_with<F>(read_fn: F) -> Result<&'static str>
/// We perform a low-level read with `read(2)` because `std::io::stdin()` has
/// an internal BufReader:
///
/// https://github.com/rust-lang/rust/blob/bcbbdcb8522fd3cb4a8dde62313b251ab107694d/library/std/src/io/stdio.rs#L250-L252
///
/// that can end up retaining a copy of stdin data in memory with no way to zero
/// it out, whereas we aim to guarantee there is exactly one copy of the API key
/// in memory, protected by mlock(2).
#[cfg(unix)]
fn read_from_unix_stdin(buffer: &mut [u8]) -> std::io::Result<usize> {
    use libc::c_void;
    use libc::read;

    // Perform a single read(2) call into the provided buffer slice.
    // Looping and newline/EOF handling are managed by the caller.
    loop {
        let result = unsafe {
            read(
                libc::STDIN_FILENO,
                buffer.as_mut_ptr().cast::<c_void>(),
                buffer.len(),
            )
        };

        if result == 0 {
            return Ok(0);
        }

        if result < 0 {
            let err = std::io::Error::last_os_error();
            if err.kind() == std::io::ErrorKind::Interrupted {
                continue;
            }
            return Err(err);
        }

        return Ok(result as usize);
    }
}

fn read_auth_header_with<F>(mut read_fn: F) -> Result<&'static str>
where
    F: FnOnce(&mut [u8]) -> std::io::Result<usize>,
    F: FnMut(&mut [u8]) -> std::io::Result<usize>,
{
    // TAKE CARE WHEN MODIFYING THIS CODE!!!
    //
@@ -31,19 +83,50 @@ where
    let mut buf = [0u8; BUFFER_SIZE];
    buf[..AUTH_HEADER_PREFIX.len()].copy_from_slice(AUTH_HEADER_PREFIX);

    let read = read_fn(&mut buf[AUTH_HEADER_PREFIX.len()..]).inspect_err(|_err| {
        buf.zeroize();
    })?;
    let prefix_len = AUTH_HEADER_PREFIX.len();
    let capacity = buf.len() - prefix_len;
    let mut total_read = 0usize; // number of bytes read into the token region
    let mut saw_newline = false;
    let mut saw_eof = false;

    if read == buf.len() - AUTH_HEADER_PREFIX.len() {
    while total_read < capacity {
        let slice = &mut buf[prefix_len + total_read..];
        let read = match read_fn(slice) {
            Ok(n) => n,
            Err(err) => {
                buf.zeroize();
                return Err(err.into());
            }
        };

        if read == 0 {
            saw_eof = true;
            break;
        }

        // Search only the newly written region for a newline.
        let newly_written = &slice[..read];
        if let Some(pos) = newly_written.iter().position(|&b| b == b'\n') {
            total_read += pos + 1; // include the newline for trimming below
            saw_newline = true;
            break;
        }

        total_read += read;

        // Continue loop; if buffer fills without newline/EOF we'll error below.
    }

    // If buffer filled and we did not see newline or EOF, error out.
    if total_read == capacity && !saw_newline && !saw_eof {
        buf.zeroize();
        return Err(anyhow!(
            "OPENAI_API_KEY is too large to fit in the 512-byte buffer"
        ));
    }

    let mut total = AUTH_HEADER_PREFIX.len() + read;
    while total > AUTH_HEADER_PREFIX.len() && (buf[total - 1] == b'\n' || buf[total - 1] == b'\r') {
    let mut total = prefix_len + total_read;
    while total > prefix_len && (buf[total - 1] == b'\n' || buf[total - 1] == b'\r') {
        total -= 1;
    }

@@ -138,13 +221,19 @@ fn validate_auth_header_bytes(key_bytes: &[u8]) -> Result<()> {
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::io;

    #[test]
    fn reads_key_with_no_newlines() {
        let mut sent = false;
        let result = read_auth_header_with(|buf| {
            if sent {
                return Ok(0);
            }
            let data = b"sk-abc123";
            buf[..data.len()].copy_from_slice(data);
            sent = true;
            Ok(data.len())
        })
        .unwrap();
@@ -152,11 +241,32 @@ mod tests {
        assert_eq!(result, "Bearer sk-abc123");
    }

    #[test]
    fn reads_key_with_short_reads() {
        let mut chunks: VecDeque<&[u8]> =
            VecDeque::from(vec![b"sk-".as_ref(), b"abc".as_ref(), b"123\n".as_ref()]);
        let result = read_auth_header_with(|buf| match chunks.pop_front() {
            Some(chunk) if !chunk.is_empty() => {
                buf[..chunk.len()].copy_from_slice(chunk);
                Ok(chunk.len())
            }
            _ => Ok(0),
        })
        .unwrap();

        assert_eq!(result, "Bearer sk-abc123");
    }

    #[test]
    fn reads_key_and_trims_newlines() {
        let mut sent = false;
        let result = read_auth_header_with(|buf| {
            if sent {
                return Ok(0);
            }
            let data = b"sk-abc123\r\n";
            buf[..data.len()].copy_from_slice(data);
            sent = true;
            Ok(data.len())
        })
        .unwrap();
@@ -194,9 +304,14 @@ mod tests {

    #[test]
    fn errors_on_invalid_utf8() {
        let mut sent = false;
        let err = read_auth_header_with(|buf| {
            if sent {
                return Ok(0);
            }
            let data = b"sk-abc\xff";
            buf[..data.len()].copy_from_slice(data);
            sent = true;
            Ok(data.len())
        })
        .unwrap_err();
@@ -209,9 +324,14 @@ mod tests {

    #[test]
    fn errors_on_invalid_characters() {
        let mut sent = false;
        let err = read_auth_header_with(|buf| {
            if sent {
                return Ok(0);
            }
            let data = b"sk-abc!23";
            buf[..data.len()].copy_from_slice(data);
            sent = true;
            Ok(data.len())
        })
        .unwrap_err();

@@ -5,6 +5,7 @@ mod perform_oauth_login;
mod rmcp_client;
mod utils;

pub use oauth::OAuthCredentialsStoreMode;
pub use oauth::StoredOAuthTokens;
pub use oauth::WrappedOAuthTokenResponse;
pub use oauth::delete_oauth_tokens;

@@ -58,6 +58,21 @@ pub struct StoredOAuthTokens {
    pub token_response: WrappedOAuthTokenResponse,
}

/// Determine where Codex should store and read MCP credentials.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OAuthCredentialsStoreMode {
    /// `Keyring` when available; otherwise, `File`.
    /// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.
    #[default]
    Auto,
    /// CODEX_HOME/.credentials.json
    /// This file will be readable to Codex and other applications running as the same user.
    File,
    /// Keyring when available, otherwise fail.
    Keyring,
}

#[derive(Debug)]
struct CredentialStoreError(anyhow::Error);

@@ -83,15 +98,15 @@ impl fmt::Display for CredentialStoreError {

impl std::error::Error for CredentialStoreError {}

trait CredentialStore {
trait KeyringStore {
    fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError>;
    fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError>;
    fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError>;
}

struct KeyringCredentialStore;
struct DefaultKeyringStore;

impl CredentialStore for KeyringCredentialStore {
impl KeyringStore for DefaultKeyringStore {
    fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError> {
        let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
        match entry.get_password() {
@@ -129,47 +144,85 @@ impl PartialEq for WrappedOAuthTokenResponse {
    }
}

pub(crate) fn load_oauth_tokens(server_name: &str, url: &str) -> Result<Option<StoredOAuthTokens>> {
    let store = KeyringCredentialStore;
    load_oauth_tokens_with_store(&store, server_name, url)
pub(crate) fn load_oauth_tokens(
    server_name: &str,
    url: &str,
    store_mode: OAuthCredentialsStoreMode,
) -> Result<Option<StoredOAuthTokens>> {
    let keyring_store = DefaultKeyringStore;
    match store_mode {
        OAuthCredentialsStoreMode::Auto => {
            load_oauth_tokens_from_keyring_with_fallback_to_file(&keyring_store, server_name, url)
        }
        OAuthCredentialsStoreMode::File => load_oauth_tokens_from_file(server_name, url),
        OAuthCredentialsStoreMode::Keyring => {
            load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
                .with_context(|| "failed to read OAuth tokens from keyring".to_string())
        }
    }
}

fn load_oauth_tokens_with_store<C: CredentialStore>(
    store: &C,
fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
    keyring_store: &K,
    server_name: &str,
    url: &str,
) -> Result<Option<StoredOAuthTokens>> {
    match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
        Ok(Some(tokens)) => Ok(Some(tokens)),
        Ok(None) => load_oauth_tokens_from_file(server_name, url),
        Err(error) => {
            warn!("failed to read OAuth tokens from keyring: {error}");
            load_oauth_tokens_from_file(server_name, url)
                .with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))
        }
    }
}

fn load_oauth_tokens_from_keyring<K: KeyringStore>(
    keyring_store: &K,
    server_name: &str,
    url: &str,
) -> Result<Option<StoredOAuthTokens>> {
    let key = compute_store_key(server_name, url)?;
    match store.load(KEYRING_SERVICE, &key) {
    match keyring_store.load(KEYRING_SERVICE, &key) {
        Ok(Some(serialized)) => {
            let tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
                .context("failed to deserialize OAuth tokens from keyring")?;
            Ok(Some(tokens))
        }
        Ok(None) => load_oauth_tokens_from_file(server_name, url),
        Err(error) => {
            let message = error.message();
            warn!("failed to read OAuth tokens from keyring: {message}");
            load_oauth_tokens_from_file(server_name, url)
                .with_context(|| format!("failed to read OAuth tokens from keyring: {message}"))
        Ok(None) => Ok(None),
        Err(error) => Err(error.into_error()),
    }
}

pub fn save_oauth_tokens(
    server_name: &str,
    tokens: &StoredOAuthTokens,
    store_mode: OAuthCredentialsStoreMode,
) -> Result<()> {
    let keyring_store = DefaultKeyringStore;
    match store_mode {
        OAuthCredentialsStoreMode::Auto => save_oauth_tokens_with_keyring_with_fallback_to_file(
            &keyring_store,
            server_name,
            tokens,
        ),
        OAuthCredentialsStoreMode::File => save_oauth_tokens_to_file(tokens),
        OAuthCredentialsStoreMode::Keyring => {
            save_oauth_tokens_with_keyring(&keyring_store, server_name, tokens)
        }
    }
}

pub fn save_oauth_tokens(server_name: &str, tokens: &StoredOAuthTokens) -> Result<()> {
    let store = KeyringCredentialStore;
    save_oauth_tokens_with_store(&store, server_name, tokens)
}

fn save_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
fn save_oauth_tokens_with_keyring<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
let serialized = serde_json::to_string(tokens).context("failed to serialize OAuth tokens")?;
|
||||
|
||||
let key = compute_store_key(server_name, &tokens.url)?;
|
||||
match store.save(KEYRING_SERVICE, &key, &serialized) {
|
||||
match keyring_store.save(KEYRING_SERVICE, &key, &serialized) {
|
||||
Ok(()) => {
|
||||
if let Err(error) = delete_oauth_tokens_from_file(&key) {
|
||||
warn!("failed to remove OAuth tokens from fallback storage: {error:?}");
|
||||
@@ -177,31 +230,61 @@ fn save_oauth_tokens_with_store<C: CredentialStore>(
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to write OAuth tokens to keyring: {message}");
|
||||
let message = format!(
|
||||
"failed to write OAuth tokens to keyring: {}",
|
||||
error.message()
|
||||
);
|
||||
warn!("{message}");
|
||||
Err(error.into_error().context(message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_with_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
match save_oauth_tokens_with_keyring(keyring_store, server_name, tokens) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(error) => {
|
||||
let message = error.to_string();
|
||||
warn!("falling back to file storage for OAuth tokens: {message}");
|
||||
save_oauth_tokens_to_file(tokens)
|
||||
.with_context(|| format!("failed to write OAuth tokens to keyring: {message}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
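Two invariants are worth noting in the save path above: a successful keyring write also deletes any stale fallback file, and a failed keyring write in Auto mode degrades to the file store while keeping the original error as context. A compressed, purely illustrative sketch of that flow (hypothetical names, file modeled as an Option):

// Hypothetical sketch of "write to keyring, clean up or fall back to file".
fn save_with_fallback(
    keyring_ok: bool,
    file: &mut Option<String>,
    payload: &str,
) -> Result<(), String> {
    if keyring_ok {
        // Success: remove any stale copy so the file never shadows the keyring.
        *file = None;
        Ok(())
    } else {
        // Failure: degrade to file storage; the real code attaches the
        // keyring error as context instead of discarding it.
        *file = Some(payload.to_string());
        Ok(())
    }
}

fn main() {
    let mut file = Some("stale".to_string());
    save_with_fallback(true, &mut file, "fresh").unwrap();
    assert!(file.is_none(), "stale fallback copy should be gone");

    save_with_fallback(false, &mut file, "fresh").unwrap();
    assert_eq!(file.as_deref(), Some("fresh"));
}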
-pub fn delete_oauth_tokens(server_name: &str, url: &str) -> Result<bool> {
-    let store = KeyringCredentialStore;
-    delete_oauth_tokens_with_store(&store, server_name, url)
+pub fn delete_oauth_tokens(
+    server_name: &str,
+    url: &str,
+    store_mode: OAuthCredentialsStoreMode,
+) -> Result<bool> {
+    let keyring_store = DefaultKeyringStore;
+    delete_oauth_tokens_from_keyring_and_file(&keyring_store, store_mode, server_name, url)
 }

-fn delete_oauth_tokens_with_store<C: CredentialStore>(
-    store: &C,
+fn delete_oauth_tokens_from_keyring_and_file<K: KeyringStore>(
+    keyring_store: &K,
+    store_mode: OAuthCredentialsStoreMode,
     server_name: &str,
     url: &str,
 ) -> Result<bool> {
     let key = compute_store_key(server_name, url)?;
-    let keyring_removed = match store.delete(KEYRING_SERVICE, &key) {
+    let keyring_result = keyring_store.delete(KEYRING_SERVICE, &key);
+    let keyring_removed = match keyring_result {
         Ok(removed) => removed,
         Err(error) => {
             let message = error.message();
             warn!("failed to delete OAuth tokens from keyring: {message}");
-            return Err(error.into_error()).context("failed to delete OAuth tokens from keyring");
+            match store_mode {
+                OAuthCredentialsStoreMode::Auto | OAuthCredentialsStoreMode::Keyring => {
+                    return Err(error.into_error())
+                        .context("failed to delete OAuth tokens from keyring");
+                }
+                OAuthCredentialsStoreMode::File => false,
+            }
         }
     };
@@ -218,6 +301,7 @@ struct OAuthPersistorInner {
     server_name: String,
     url: String,
     authorization_manager: Arc<Mutex<AuthorizationManager>>,
+    store_mode: OAuthCredentialsStoreMode,
     last_credentials: Mutex<Option<StoredOAuthTokens>>,
 }

@@ -225,14 +309,16 @@ impl OAuthPersistor {
     pub(crate) fn new(
         server_name: String,
         url: String,
-        manager: Arc<Mutex<AuthorizationManager>>,
+        authorization_manager: Arc<Mutex<AuthorizationManager>>,
+        store_mode: OAuthCredentialsStoreMode,
         initial_credentials: Option<StoredOAuthTokens>,
     ) -> Self {
         Self {
             inner: Arc::new(OAuthPersistorInner {
                 server_name,
                 url,
-                authorization_manager: manager,
+                authorization_manager,
+                store_mode,
                 last_credentials: Mutex::new(initial_credentials),
             }),
         }

@@ -257,15 +343,18 @@ impl OAuthPersistor {
             };
             let mut last_credentials = self.inner.last_credentials.lock().await;
             if last_credentials.as_ref() != Some(&stored) {
-                save_oauth_tokens(&self.inner.server_name, &stored)?;
+                save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
                 *last_credentials = Some(stored);
             }
         }
         None => {
             let mut last_serialized = self.inner.last_credentials.lock().await;
             if last_serialized.take().is_some()
-                && let Err(error) =
-                    delete_oauth_tokens(&self.inner.server_name, &self.inner.url)
+                && let Err(error) = delete_oauth_tokens(
+                    &self.inner.server_name,
+                    &self.inner.url,
+                    self.inner.store_mode,
+                )
             {
                 warn!(
                     "failed to remove OAuth tokens for server {}: {error}",
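The persistor hunks above cache the last written credentials and only touch the store when the refreshed tokens actually differ, which avoids a keyring write on every request. A minimal async sketch of that compare-and-save guard, assuming tokio (features "macros" and "rt") and hypothetical names:

use std::sync::Arc;
use tokio::sync::Mutex;

// Hypothetical persistor that only writes when credentials change.
struct Persistor {
    last: Arc<Mutex<Option<String>>>,
}

impl Persistor {
    async fn persist_if_changed(&self, current: Option<String>, writes: &Mutex<u32>) {
        let mut last = self.last.lock().await;
        if *last != current {
            *writes.lock().await += 1; // stand-in for save_oauth_tokens / delete_oauth_tokens
            *last = current;
        }
    }
}

#[tokio::main]
async fn main() {
    let p = Persistor { last: Arc::new(Mutex::new(None)) };
    let writes = Mutex::new(0);
    p.persist_if_changed(Some("tok1".into()), &writes).await;
    p.persist_if_changed(Some("tok1".into()), &writes).await; // no-op: unchanged
    p.persist_if_changed(Some("tok2".into()), &writes).await;
    assert_eq!(*writes.lock().await, 2);
}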
@@ -542,7 +631,7 @@ mod tests {
         }
     }

-    impl CredentialStore for MockCredentialStore {
+    impl KeyringStore for MockCredentialStore {
         fn load(
             &self,
             _service: &str,

@@ -643,7 +732,8 @@ mod tests {
         let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
         store.save(KEYRING_SERVICE, &key, &serialized)?;

-        let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
+        let loaded =
+            super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?;
         assert_eq!(loaded, Some(expected));
         Ok(())
     }

@@ -657,8 +747,12 @@ mod tests {

         super::save_oauth_tokens_to_file(&tokens)?;

-        let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
-            .expect("tokens should load from fallback");
+        let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
+            &store,
+            &tokens.server_name,
+            &tokens.url,
+        )?
+        .expect("tokens should load from fallback");
         assert_tokens_match_without_expiry(&loaded, &expected);
         Ok(())
     }

@@ -674,8 +768,12 @@ mod tests {

         super::save_oauth_tokens_to_file(&tokens)?;

-        let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
-            .expect("tokens should load from fallback");
+        let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
+            &store,
+            &tokens.server_name,
+            &tokens.url,
+        )?
+        .expect("tokens should load from fallback");
         assert_tokens_match_without_expiry(&loaded, &expected);
         Ok(())
     }

@@ -689,7 +787,11 @@ mod tests {

         super::save_oauth_tokens_to_file(&tokens)?;

-        super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
+        super::save_oauth_tokens_with_keyring_with_fallback_to_file(
+            &store,
+            &tokens.server_name,
+            &tokens,
+        )?;

         let fallback_path = super::fallback_file_path()?;
         assert!(!fallback_path.exists(), "fallback file should be removed");

@@ -706,7 +808,11 @@ mod tests {
         let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
         store.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));

-        super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
+        super::save_oauth_tokens_with_keyring_with_fallback_to_file(
+            &store,
+            &tokens.server_name,
+            &tokens,
+        )?;

         let fallback_path = super::fallback_file_path()?;
         assert!(fallback_path.exists(), "fallback file should be created");

@@ -734,8 +840,34 @@ mod tests {
         store.save(KEYRING_SERVICE, &key, &serialized)?;
         super::save_oauth_tokens_to_file(&tokens)?;

-        let removed =
-            super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
+        let removed = super::delete_oauth_tokens_from_keyring_and_file(
+            &store,
+            OAuthCredentialsStoreMode::Auto,
+            &tokens.server_name,
+            &tokens.url,
+        )?;
         assert!(removed);
         assert!(!store.contains(&key));
         assert!(!super::fallback_file_path()?.exists());
         Ok(())
     }

+    #[test]
+    fn delete_oauth_tokens_file_mode_removes_keyring_only_entry() -> Result<()> {
+        let _env = TempCodexHome::new();
+        let store = MockCredentialStore::default();
+        let tokens = sample_tokens();
+        let serialized = serde_json::to_string(&tokens)?;
+        let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
+        store.save(KEYRING_SERVICE, &key, &serialized)?;
+        assert!(store.contains(&key));
+
+        let removed = super::delete_oauth_tokens_from_keyring_and_file(
+            &store,
+            OAuthCredentialsStoreMode::Auto,
+            &tokens.server_name,
+            &tokens.url,
+        )?;
+        assert!(removed);
+        assert!(!store.contains(&key));
+        assert!(!super::fallback_file_path()?.exists());

@@ -751,8 +883,12 @@ mod tests {
         store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));
         super::save_oauth_tokens_to_file(&tokens).unwrap();

-        let result =
-            super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url);
+        let result = super::delete_oauth_tokens_from_keyring_and_file(
+            &store,
+            OAuthCredentialsStoreMode::Auto,
+            &tokens.server_name,
+            &tokens.url,
+        );
         assert!(result.is_err());
         assert!(super::fallback_file_path().unwrap().exists());
         Ok(())
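These tests lean on the mock store's ability to inject keyring failures (`set_error`) and then assert on the fallback file. A standalone sketch of the same fault-injection idea, with hypothetical types in place of MockCredentialStore:

use std::cell::RefCell;
use std::collections::HashMap;

// Hypothetical fault-injecting store, in the spirit of MockCredentialStore::set_error.
#[derive(Default)]
struct FlakyStore {
    entries: RefCell<HashMap<String, String>>,
    fail_keys: RefCell<Vec<String>>,
}

impl FlakyStore {
    fn set_error(&self, key: &str) {
        self.fail_keys.borrow_mut().push(key.to_string());
    }

    fn save(&self, key: &str, value: &str) -> Result<(), String> {
        if self.fail_keys.borrow().iter().any(|k| k.as_str() == key) {
            return Err(format!("injected keyring failure for {key}"));
        }
        self.entries.borrow_mut().insert(key.into(), value.into());
        Ok(())
    }
}

#[test]
fn save_falls_back_when_keyring_errors() {
    let store = FlakyStore::default();
    store.set_error("server");

    let mut fallback_file: Option<String> = None;
    if store.save("server", "tokens").is_err() {
        // Degrade to file storage, as the Auto path does above.
        fallback_file = Some("tokens".to_string());
    }

    assert!(fallback_file.is_some(), "fallback file should be created");
    assert!(store.entries.borrow().is_empty(), "keyring write never landed");
}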
@@ -12,6 +12,7 @@ use tokio::sync::oneshot;
 use tokio::time::timeout;
 use urlencoding::decode;

+use crate::OAuthCredentialsStoreMode;
 use crate::StoredOAuthTokens;
 use crate::WrappedOAuthTokenResponse;
 use crate::save_oauth_tokens;

@@ -26,7 +27,11 @@ impl Drop for CallbackServerGuard {
     }
 }

-pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result<()> {
+pub async fn perform_oauth_login(
+    server_name: &str,
+    server_url: &str,
+    store_mode: OAuthCredentialsStoreMode,
+) -> Result<()> {
     let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
     let guard = CallbackServerGuard {
         server: Arc::clone(&server),

@@ -47,7 +52,9 @@ pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result<
     spawn_callback_server(server, tx);

     let mut oauth_state = OAuthState::new(server_url, None).await?;
-    oauth_state.start_authorization(&[], &redirect_uri).await?;
+    oauth_state
+        .start_authorization(&[], &redirect_uri, Some("Codex"))
+        .await?;
     let auth_url = oauth_state.get_authorization_url().await?;

     println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n");

@@ -79,7 +86,7 @@ pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result<
         client_id,
         token_response: WrappedOAuthTokenResponse(credentials),
     };
-    save_oauth_tokens(server_name, &stored)?;
+    save_oauth_tokens(server_name, &stored, store_mode)?;

     drop(guard);
     Ok(())
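The login flow above binds a loopback HTTP server on an OS-assigned port (`127.0.0.1:0`; the `Server::http` call suggests tiny_http) and hands the OAuth redirect back to the async task over a oneshot channel. A reduced sketch of that handshake using only the standard library and a plain mpsc channel in place of the oneshot:

use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::mpsc;
use std::thread;

fn main() {
    // Bind to port 0 so the OS assigns a free port; this mirrors "127.0.0.1:0" above.
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();
    println!("redirect URI would be http://127.0.0.1:{port}/callback");

    // Channel standing in for the tokio::sync::oneshot used by the real flow.
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        if let Ok((mut stream, _)) = listener.accept() {
            let mut buf = [0u8; 1024];
            let n = stream.read(&mut buf).unwrap_or(0);
            let request_line = String::from_utf8_lossy(&buf[..n])
                .lines()
                .next()
                .unwrap_or_default()
                .to_string();
            let _ = stream.write_all(b"HTTP/1.1 200 OK\r\n\r\nYou may close this tab.");
            let _ = tx.send(request_line); // e.g. "GET /callback?code=... HTTP/1.1"
        }
    });

    // Simulate the browser being redirected back to us.
    let mut client = TcpStream::connect(("127.0.0.1", port)).unwrap();
    client.write_all(b"GET /callback?code=abc HTTP/1.1\r\n\r\n").unwrap();
    assert!(rx.recv().unwrap().contains("code=abc"));
}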
@@ -35,6 +35,7 @@ use tracing::warn;

 use crate::load_oauth_tokens;
 use crate::logging_client_handler::LoggingClientHandler;
+use crate::oauth::OAuthCredentialsStoreMode;
 use crate::oauth::OAuthPersistor;
 use crate::oauth::StoredOAuthTokens;
 use crate::utils::convert_call_tool_result;

@@ -119,17 +120,22 @@ impl RmcpClient {
         server_name: &str,
         url: &str,
         bearer_token: Option<String>,
+        store_mode: OAuthCredentialsStoreMode,
     ) -> Result<Self> {
-        let initial_tokens = match load_oauth_tokens(server_name, url) {
-            Ok(tokens) => tokens,
-            Err(err) => {
-                warn!("failed to read tokens for server `{server_name}`: {err}");
-                None
-            }
+        let initial_oauth_tokens = match bearer_token {
+            Some(_) => None,
+            None => match load_oauth_tokens(server_name, url, store_mode) {
+                Ok(tokens) => tokens,
+                Err(err) => {
+                    warn!("failed to read tokens for server `{server_name}`: {err}");
+                    None
+                }
+            },
         };
-        let transport = if let Some(initial_tokens) = initial_tokens.clone() {
+        let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
             let (transport, oauth_persistor) =
-                create_oauth_transport_and_runtime(server_name, url, initial_tokens).await?;
+                create_oauth_transport_and_runtime(server_name, url, initial_tokens, store_mode)
+                    .await?;
             PendingTransport::StreamableHttpWithOAuth {
                 transport,
                 oauth_persistor,

@@ -137,7 +143,7 @@ impl RmcpClient {
         } else {
             let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
             if let Some(bearer_token) = bearer_token {
-                http_config = http_config.auth_header(format!("Bearer {bearer_token}"));
+                http_config = http_config.auth_header(bearer_token);
             }

             let transport = StreamableHttpClientTransport::from_config(http_config);

@@ -283,6 +289,7 @@ async fn create_oauth_transport_and_runtime(
     server_name: &str,
     url: &str,
     initial_tokens: StoredOAuthTokens,
+    credentials_store: OAuthCredentialsStoreMode,
 ) -> Result<(
     StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
     OAuthPersistor,

@@ -317,6 +324,7 @@ async fn create_oauth_transport_and_runtime(
         server_name.to_string(),
         url.to_string(),
         auth_manager,
+        credentials_store,
         Some(initial_tokens),
     );
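One behavioral detail in the client hunk above: an explicitly configured bearer token now suppresses loading any stored OAuth tokens, rather than both being in play, and `auth_header` now receives the raw token, which suggests the transport adds the `Bearer ` prefix itself. A tiny sketch of the precedence rule (hypothetical names):

// Hypothetical: pick the auth source, letting an explicit bearer token win.
fn initial_oauth_tokens(
    bearer_token: Option<&str>,
    stored: Option<String>,
) -> Option<String> {
    match bearer_token {
        Some(_) => None, // bearer auth configured: ignore stored OAuth state
        None => stored,  // otherwise fall back to persisted tokens, if any
    }
}

fn main() {
    assert_eq!(initial_oauth_tokens(Some("abc"), Some("stored".into())), None);
    assert_eq!(
        initial_oauth_tokens(None, Some("stored".into())).as_deref(),
        Some("stored")
    );
}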
@@ -68,6 +68,8 @@ strum_macros = { workspace = true }
|
||||
supports-color = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
textwrap = { workspace = true }
|
||||
tree-sitter-highlight = { workspace = true }
|
||||
tree-sitter-bash = { workspace = true }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
@@ -94,6 +96,7 @@ arboard = { workspace = true }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
assert_matches = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
insta = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
|
||||
@@ -134,8 +134,9 @@ impl App {
|
||||
/// Useful when switching sessions to ensure prior history remains visible.
|
||||
pub(crate) fn render_transcript_once(&mut self, tui: &mut tui::Tui) {
|
||||
if !self.transcript_cells.is_empty() {
|
||||
let width = tui.terminal.last_known_screen_size.width;
|
||||
for cell in &self.transcript_cells {
|
||||
tui.insert_history_lines(cell.transcript_lines());
|
||||
tui.insert_history_lines(cell.display_lines(width));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -315,15 +315,16 @@ impl From<ApprovalRequest> for ApprovalRequestState {
|
||||
changes,
|
||||
} => {
|
||||
let mut header: Vec<Box<dyn Renderable>> = Vec::new();
|
||||
header.push(DiffSummary::new(changes, cwd).into());
|
||||
if let Some(reason) = reason
|
||||
&& !reason.is_empty()
|
||||
{
|
||||
header.push(Box::new(Line::from("")));
|
||||
header.push(Box::new(
|
||||
Paragraph::new(reason.italic()).wrap(Wrap { trim: false }),
|
||||
Paragraph::new(Line::from_iter(["Reason: ".into(), reason.italic()]))
|
||||
.wrap(Wrap { trim: false }),
|
||||
));
|
||||
header.push(Box::new(Line::from("")));
|
||||
}
|
||||
header.push(DiffSummary::new(changes, cwd).into());
|
||||
Self {
|
||||
variant: ApprovalVariant::ApplyPatch { id },
|
||||
header: Box::new(ColumnRenderable::new(header)),
|
||||
|
||||
@@ -38,7 +38,6 @@ use crate::bottom_pane::prompt_args::prompt_has_numeric_placeholders;
|
||||
use crate::slash_command::SlashCommand;
|
||||
use crate::slash_command::built_in_slash_commands;
|
||||
use crate::style::user_message_style;
|
||||
use crate::terminal_palette;
|
||||
use codex_protocol::custom_prompts::CustomPrompt;
|
||||
use codex_protocol::custom_prompts::PROMPTS_CMD_PREFIX;
|
||||
|
||||
@@ -1533,7 +1532,7 @@ impl WidgetRef for ChatComposer {
|
||||
}
|
||||
}
|
||||
}
|
||||
let style = user_message_style(terminal_palette::default_bg());
|
||||
let style = user_message_style();
|
||||
let mut block_rect = composer_rect;
|
||||
block_rect.y = composer_rect.y.saturating_sub(1);
|
||||
block_rect.height = composer_rect.height.saturating_add(1);
|
||||
|
||||
@@ -20,7 +20,6 @@ use crate::render::RectExt as _;
|
||||
use crate::render::renderable::ColumnRenderable;
|
||||
use crate::render::renderable::Renderable;
|
||||
use crate::style::user_message_style;
|
||||
use crate::terminal_palette;
|
||||
|
||||
use super::CancellationEvent;
|
||||
use super::bottom_pane_view::BottomPaneView;
|
||||
@@ -350,7 +349,7 @@ impl Renderable for ListSelectionView {
|
||||
.areas(area);
|
||||
|
||||
Block::default()
|
||||
.style(user_message_style(terminal_palette::default_bg()))
|
||||
.style(user_message_style())
|
||||
.render(content_area, buf);
|
||||
|
||||
let header_height = self
|
||||
|
||||
@@ -81,7 +81,7 @@ pub(crate) struct BottomPaneParams {
|
||||
}
|
||||
|
||||
impl BottomPane {
|
||||
const BOTTOM_PAD_LINES: u16 = 1;
|
||||
const BOTTOM_PAD_LINES: u16 = 0;
|
||||
pub fn new(params: BottomPaneParams) -> Self {
|
||||
let enhanced_keys_supported = params.enhanced_keys_supported;
|
||||
Self {
|
||||
@@ -522,10 +522,29 @@ impl WidgetRef for &BottomPane {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::app_event::AppEvent;
|
||||
use insta::assert_snapshot;
|
||||
use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::Rect;
|
||||
use tokio::sync::mpsc::unbounded_channel;
|
||||
|
||||
fn snapshot_buffer(buf: &Buffer) -> String {
|
||||
let mut lines = Vec::new();
|
||||
for y in 0..buf.area().height {
|
||||
let mut row = String::new();
|
||||
for x in 0..buf.area().width {
|
||||
row.push(buf[(x, y)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
lines.push(row);
|
||||
}
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn render_snapshot(pane: &BottomPane, area: Rect) -> String {
|
||||
let mut buf = Buffer::empty(area);
|
||||
(&pane).render_ref(area, &mut buf);
|
||||
snapshot_buffer(&buf)
|
||||
}
|
||||
|
||||
fn exec_request() -> ApprovalRequest {
|
||||
ApprovalRequest::Exec {
|
||||
id: "1".to_string(),
|
||||
@@ -685,7 +704,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bottom_padding_present_with_status_above_composer() {
|
||||
fn status_and_composer_fill_height_without_bottom_padding() {
|
||||
let (tx_raw, _rx) = unbounded_channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
@@ -700,43 +719,21 @@ mod tests {
|
||||
// Activate spinner (status view replaces composer) with no live ring.
|
||||
pane.set_task_running(true);
|
||||
|
||||
// Use height == desired_height; expect 1 status row at top and 2 bottom padding rows.
|
||||
// Use height == desired_height; expect spacer + status + composer rows without trailing padding.
|
||||
let height = pane.desired_height(30);
|
||||
assert!(
|
||||
height >= 3,
|
||||
"expected at least 3 rows with bottom padding; got {height}"
|
||||
"expected at least 3 rows to render spacer, status, and composer; got {height}"
|
||||
);
|
||||
let area = Rect::new(0, 0, 30, height);
|
||||
let mut buf = Buffer::empty(area);
|
||||
(&pane).render_ref(area, &mut buf);
|
||||
|
||||
// Row 1 contains the status header (row 0 is the spacer)
|
||||
let mut top = String::new();
|
||||
for x in 0..area.width {
|
||||
top.push(buf[(x, 1)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(
|
||||
top.trim_start().starts_with("• Working"),
|
||||
"expected top row to start with '• Working': {top:?}"
|
||||
);
|
||||
assert!(
|
||||
top.contains("Working"),
|
||||
"expected Working header on top row: {top:?}"
|
||||
);
|
||||
|
||||
// Last row should be blank padding; the row above should generally contain composer content.
|
||||
let mut r_last = String::new();
|
||||
for x in 0..area.width {
|
||||
r_last.push(buf[(x, height - 1)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(
|
||||
r_last.trim().is_empty(),
|
||||
"expected last row blank: {r_last:?}"
|
||||
assert_snapshot!(
|
||||
"status_and_composer_fill_height_without_bottom_padding",
|
||||
render_snapshot(&pane, area)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bottom_padding_shrinks_when_tiny() {
|
||||
fn status_hidden_when_height_too_small() {
|
||||
let (tx_raw, _rx) = unbounded_channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
@@ -750,37 +747,18 @@ mod tests {
|
||||
|
||||
pane.set_task_running(true);
|
||||
|
||||
// Height=2 → status on one row, composer on the other.
|
||||
// Height=2 → composer takes the full space; status collapses when there is no room.
|
||||
let area2 = Rect::new(0, 0, 20, 2);
|
||||
let mut buf2 = Buffer::empty(area2);
|
||||
(&pane).render_ref(area2, &mut buf2);
|
||||
let mut row0 = String::new();
|
||||
let mut row1 = String::new();
|
||||
for x in 0..area2.width {
|
||||
row0.push(buf2[(x, 0)].symbol().chars().next().unwrap_or(' '));
|
||||
row1.push(buf2[(x, 1)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
let has_composer = row0.contains("Ask Codex") || row1.contains("Ask Codex");
|
||||
assert!(
|
||||
has_composer,
|
||||
"expected composer to be visible on one of the rows: row0={row0:?}, row1={row1:?}"
|
||||
);
|
||||
assert!(
|
||||
row0.contains("Working") || row1.contains("Working"),
|
||||
"expected status header to be visible at height=2: row0={row0:?}, row1={row1:?}"
|
||||
assert_snapshot!(
|
||||
"status_hidden_when_height_too_small_height_2",
|
||||
render_snapshot(&pane, area2)
|
||||
);
|
||||
|
||||
// Height=1 → no padding; single row is the composer (status hidden).
|
||||
let area1 = Rect::new(0, 0, 20, 1);
|
||||
let mut buf1 = Buffer::empty(area1);
|
||||
(&pane).render_ref(area1, &mut buf1);
|
||||
let mut only = String::new();
|
||||
for x in 0..area1.width {
|
||||
only.push(buf1[(x, 0)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(
|
||||
only.contains("Ask Codex"),
|
||||
"expected composer with no padding: {only:?}"
|
||||
assert_snapshot!(
|
||||
"status_hidden_when_height_too_small_height_1",
|
||||
render_snapshot(&pane, area1)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
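The rewritten tests above replace hand-rolled row scraping with insta's named snapshots (note the new `assert_snapshot` import and the `.snap` file that follows). A tiny standalone illustration of the named form, assuming insta is available as a dev-dependency:

#[cfg(test)]
mod tests {
    #[test]
    fn renders_footer_hint() {
        // Right-align a footer hint in a 30-column row, then snapshot it.
        let rendered = format!("{:>30}", "? for shortcuts");
        // The string literal names the stored snapshot file; the second
        // argument is the value under test. `cargo insta review` manages updates.
        insta::assert_snapshot!("renders_footer_hint", rendered);
    }
}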
@@ -0,0 +1,11 @@
+---
+source: tui/src/bottom_pane/mod.rs
+expression: "render_snapshot(&pane, area)"
+---
+
+• Working (0s • esc to interru
+
+
+› Ask Codex to do anything
+
+               ? for shortcuts
Some files were not shown because too many files have changed in this diff.