mirror of
https://github.com/openai/codex.git
synced 2026-02-02 06:57:03 +00:00
Compare commits
79 Commits
initial-co
...
pr1889
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fde2b273d1 | ||
|
|
6cef86f05b | ||
|
|
8262ba58b2 | ||
|
|
081caa5a6b | ||
|
|
4344537742 | ||
|
|
64f2f2eca2 | ||
|
|
ae88b69b09 | ||
|
|
ffe24991b7 | ||
|
|
dc468d563f | ||
|
|
3e8bcf0247 | ||
|
|
cda39e417f | ||
|
|
d642b07fcc | ||
|
|
7b3ab968a0 | ||
|
|
02e7965228 | ||
|
|
493e4c9463 | ||
|
|
1f7003b476 | ||
|
|
eaf2fb5b4f | ||
|
|
f8d70d67b6 | ||
|
|
966d957faf | ||
|
|
b90c15abc4 | ||
|
|
31dcae67db | ||
|
|
725dd6be6a | ||
|
|
aff97ed7dd | ||
|
|
afa8f0d617 | ||
|
|
ea7d3f27bd | ||
|
|
f6c8d1117c | ||
|
|
42bd73e150 | ||
|
|
d365cae077 | ||
|
|
0c5fa271bc | ||
|
|
bd24bc320e | ||
|
|
9f91b3da24 | ||
|
|
9285350842 | ||
|
|
e0303dbac0 | ||
|
|
d31e149cb1 | ||
|
|
136b3ee5bf | ||
|
|
fcdb1c4b4d | ||
|
|
906d449760 | ||
|
|
063083af15 | ||
|
|
f58401e203 | ||
|
|
84bcadb8d9 | ||
|
|
e38ce39c51 | ||
|
|
1a33de34b0 | ||
|
|
bd171e5206 | ||
|
|
3f13ebce10 | ||
|
|
7279080edd | ||
|
|
89ab5c3f74 | ||
|
|
6db597ec0c | ||
|
|
2899817c94 | ||
|
|
64cfbbd3c8 | ||
|
|
a6139aa003 | ||
|
|
dc15a5cf0b | ||
|
|
1f3318c1c5 | ||
|
|
e3565a3f43 | ||
|
|
2576fadc74 | ||
|
|
78a1d49fac | ||
|
|
d62b703a21 | ||
|
|
4c9f7b6bcc | ||
|
|
75eecb656e | ||
|
|
81bb1c9e26 | ||
|
|
7e0f506da2 | ||
|
|
929ba50adc | ||
|
|
80555d4ff2 | ||
|
|
97ab8fb610 | ||
|
|
fe62f859a6 | ||
|
|
92f3566d78 | ||
|
|
f20de21cb6 | ||
|
|
bc7beddaa2 | ||
|
|
8360c6a3ec | ||
|
|
f918198bbb | ||
|
|
88ea215c80 | ||
|
|
b67c485d84 | ||
|
|
e2c994e32a | ||
|
|
ad0295b893 | ||
|
|
d3aa5f46b7 | ||
|
|
575590e4c2 | ||
|
|
4aca3e46c8 | ||
|
|
d787434aa8 | ||
|
|
ea69a1d72f | ||
|
|
610addbc2e |
4
.github/actions/codex/bun.lock
vendored
4
.github/actions/codex/bun.lock
vendored
@@ -11,7 +11,7 @@
|
||||
"@types/bun": "^1.2.19",
|
||||
"@types/node": "^24.1.0",
|
||||
"prettier": "^3.6.2",
|
||||
"typescript": "^5.8.3",
|
||||
"typescript": "^5.9.2",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -68,7 +68,7 @@
|
||||
|
||||
"tunnel": ["tunnel@0.0.6", "", {}, "sha512-1h/Lnq9yajKY2PEbBadPXj3VxsDDu844OnaAo52UVmIzIvwwtBPIuNvkjuzBlTWpfJyUbG3ez0KSBibQkj4ojg=="],
|
||||
|
||||
"typescript": ["typescript@5.8.3", "", { "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" } }, "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ=="],
|
||||
"typescript": ["typescript@5.9.2", "", { "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" } }, "sha512-CWBzXQrc/qOkhidw1OzBTQuYRbfyxDXJMVJ1XNwUHGROVmuaeiEm3OslpZ1RV96d7SKKjZKrSJu3+t/xlw3R9A=="],
|
||||
|
||||
"undici": ["undici@5.29.0", "", { "dependencies": { "@fastify/busboy": "^2.0.0" } }, "sha512-raqeBD6NQK4SkWhQzeYKd1KmIG6dllBOTt55Rmkt4HtI9mwdWtJljnrXjAFUBLTSN67HWrOIZ3EPF4kjUw80Bg=="],
|
||||
|
||||
|
||||
2
.github/actions/codex/package.json
vendored
2
.github/actions/codex/package.json
vendored
@@ -16,6 +16,6 @@
|
||||
"@types/bun": "^1.2.19",
|
||||
"@types/node": "^24.1.0",
|
||||
"prettier": "^3.6.2",
|
||||
"typescript": "^5.8.3"
|
||||
"typescript": "^5.9.2"
|
||||
}
|
||||
}
|
||||
|
||||
33
.github/actions/codex/src/process-label.ts
vendored
33
.github/actions/codex/src/process-label.ts
vendored
@@ -91,7 +91,38 @@ async function processLabel(
|
||||
labelConfig: LabelConfig,
|
||||
): Promise<void> {
|
||||
const template = labelConfig.getPromptTemplate();
|
||||
const populatedTemplate = await renderPromptTemplate(template, ctx);
|
||||
|
||||
// If this is a review label, prepend explicit PR-diff scoping guidance to
|
||||
// reduce out-of-scope feedback. Do this before rendering so placeholders in
|
||||
// the guidance (e.g., {CODEX_ACTION_GITHUB_EVENT_PATH}) are substituted.
|
||||
const isReview = label.toLowerCase().includes("review");
|
||||
const reviewScopeGuidance = `
|
||||
PR Diff Scope
|
||||
- Only review changes between the PR's merge-base and head; do not comment on commits or files outside this range.
|
||||
- Derive the base/head SHAs from the event JSON at {CODEX_ACTION_GITHUB_EVENT_PATH}, then compute and use the PR diff for all analysis and comments.
|
||||
|
||||
Commands to determine scope
|
||||
- Resolve SHAs:
|
||||
- BASE_SHA=$(jq -r '.pull_request.base.sha // .pull_request.base.ref' "{CODEX_ACTION_GITHUB_EVENT_PATH}")
|
||||
- HEAD_SHA=$(jq -r '.pull_request.head.sha // .pull_request.head.ref' "{CODEX_ACTION_GITHUB_EVENT_PATH}")
|
||||
- BASE_SHA=$(git rev-parse "$BASE_SHA")
|
||||
- HEAD_SHA=$(git rev-parse "$HEAD_SHA")
|
||||
- Prefer triple-dot (merge-base) semantics for PR diffs:
|
||||
- Changed commits: git log --oneline "$BASE_SHA...$HEAD_SHA"
|
||||
- Changed files: git diff --name-status "$BASE_SHA...$HEAD_SHA"
|
||||
- Review hunks: git diff -U0 "$BASE_SHA...$HEAD_SHA"
|
||||
|
||||
Review rules
|
||||
- Anchor every comment to a file and hunk present in git diff "$BASE_SHA...$HEAD_SHA".
|
||||
- If you mention context outside the diff, label it as "Follow-up (outside this PR scope)" and keep it brief (<=2 bullets).
|
||||
- Do not critique commits or files not reachable in the PR range (merge-base(base, head) → head).
|
||||
`.trim();
|
||||
|
||||
const effectiveTemplate = isReview
|
||||
? `${reviewScopeGuidance}\n\n${template}`
|
||||
: template;
|
||||
|
||||
const populatedTemplate = await renderPromptTemplate(effectiveTemplate, ctx);
|
||||
|
||||
// Always run Codex and post the resulting message as a comment.
|
||||
let commentBody = await runCodex(populatedTemplate, ctx);
|
||||
|
||||
128
.github/codex/labels/codex-rust-review.md
vendored
128
.github/codex/labels/codex-rust-review.md
vendored
@@ -6,18 +6,134 @@ Then provide the **review** (1-2 sentences plus bullet points, friendly tone).
|
||||
|
||||
Things to look out for when doing the review:
|
||||
|
||||
## General Principles
|
||||
|
||||
- **Make sure the pull request body explains the motivation behind the change.** If the author has failed to do this, call it out, and if you think you can deduce the motivation behind the change, propose copy.
|
||||
- Ideally, the PR body also contains a small summary of the change. For small changes, the PR title may be sufficient.
|
||||
- Each PR should ideally do one conceptual thing. For example, if a PR does a refactoring as well as introducing a new feature, push back and suggest the refactoring be done in a separate PR. This makes things easier for the reviewer, as refactoring changes can often be far-reaching, yet quick to review.
|
||||
- If the nature of the change seems to have a visual component (which is often the case for changes to `codex-rs/tui`), recommend including a screenshot or video to demonstrate the change, if appropriate.
|
||||
- Rust files should generally be organized such that the public parts of the API appear near the top of the file and helper functions go below. This is analagous to the "inverted pyramid" structure that is favored in journalism.
|
||||
- Encourage the use of small enums or the newtype pattern in Rust if it helps readability without adding significant cognitive load or lines of code.
|
||||
- Be wary of large files and offer suggestions for how to break things into more reasonably-sized files.
|
||||
- When modifying a `Cargo.toml` file, make sure that dependency lists stay alphabetically sorted. Also consider whether a new dependency is added to the appropriate place (e.g., `[dependencies]` versus `[dev-dependencies]`)
|
||||
- If you see opportunities for the changes in a diff to use more idiomatic Rust, please make specific recommendations. For example, favor the use of expressions over `return`.
|
||||
- When introducing new code, be on the lookout for code that duplicates existing code. When found, propose a way to refactor the existing code such that it should be reused.
|
||||
|
||||
## Code Organization
|
||||
|
||||
- Each create in the Cargo workspace in `codex-rs` has a specific purpose: make a note if you believe new code is not introduced in the correct crate.
|
||||
- When possible, try to keep the `core` crate as small as possible. Non-core but shared logic is often a good candidate for `codex-rs/common`.
|
||||
- Be wary of large files and offer suggestions for how to break things into more reasonably-sized files.
|
||||
- Rust files should generally be organized such that the public parts of the API appear near the top of the file and helper functions go below. This is analagous to the "inverted pyramid" structure that is favored in journalism.
|
||||
|
||||
## Assertions in Tests
|
||||
|
||||
Assert the equality of the entire objects instead of doing "piecemeal comparisons," performing `assert_eq!()` on individual fields.
|
||||
|
||||
Note that unit tests also function as "executable documentation." As shown in the following example, "piecemeal comparisons" are often more verbose, provide less coverage, and are not as useful as executable documentation.
|
||||
|
||||
For example, suppose you have the following enum:
|
||||
|
||||
```rust
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum Message {
|
||||
Request {
|
||||
id: String,
|
||||
method: String,
|
||||
params: Option<serde_json::Value>,
|
||||
},
|
||||
Notification {
|
||||
method: String,
|
||||
params: Option<serde_json::Value>,
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
This is an example of a _piecemeal_ comparison:
|
||||
|
||||
```rust
|
||||
// BAD: Piecemeal Comparison
|
||||
|
||||
#[test]
|
||||
fn test_get_latest_messages() {
|
||||
let messages = get_latest_messages();
|
||||
assert_eq!(messages.len(), 2);
|
||||
|
||||
let m0 = &messages[0];
|
||||
match m0 {
|
||||
Message::Request { id, method, params } => {
|
||||
assert_eq!(id, "123");
|
||||
assert_eq!(method, "subscribe");
|
||||
assert_eq!(
|
||||
*params,
|
||||
Some(json!({
|
||||
"conversation_id": "x42z86"
|
||||
}))
|
||||
)
|
||||
}
|
||||
Message::Notification { .. } => {
|
||||
panic!("expected Request");
|
||||
}
|
||||
}
|
||||
|
||||
let m1 = &messages[1];
|
||||
match m1 {
|
||||
Message::Request { .. } => {
|
||||
panic!("expected Notification");
|
||||
}
|
||||
Message::Notification { method, params } => {
|
||||
assert_eq!(method, "log");
|
||||
assert_eq!(
|
||||
*params,
|
||||
Some(json!({
|
||||
"level": "info",
|
||||
"message": "subscribed"
|
||||
}))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This is a _deep_ comparison:
|
||||
|
||||
```rust
|
||||
// GOOD: Verify the entire structure with a single assert_eq!().
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn test_get_latest_messages() {
|
||||
let messages = get_latest_messages();
|
||||
|
||||
assert_eq!(
|
||||
vec![
|
||||
Message::Request {
|
||||
id: "123".to_string(),
|
||||
method: "subscribe".to_string(),
|
||||
params: Some(json!({
|
||||
"conversation_id": "x42z86"
|
||||
})),
|
||||
},
|
||||
Message::Notification {
|
||||
method: "log".to_string(),
|
||||
params: Some(json!({
|
||||
"level": "info",
|
||||
"message": "subscribed"
|
||||
})),
|
||||
},
|
||||
],
|
||||
messages,
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
## More Tactical Rust Things To Look Out For
|
||||
|
||||
- Do not use `unsafe` (unless you have a really, really good reason like using an operating system API directly and no safe wrapper exists). For example, there are cases where it is tempting to use `unsafe` in order to use `std::env::set_var()`, but this indeed `unsafe` and has led to race conditions on multiple occasions. (When this happens, find a mechanism other than environment variables to use for configuration.)
|
||||
- Encourage the use of small enums or the newtype pattern in Rust if it helps readability without adding significant cognitive load or lines of code.
|
||||
- If you see opportunities for the changes in a diff to use more idiomatic Rust, please make specific recommendations. For example, favor the use of expressions over `return`.
|
||||
- When modifying a `Cargo.toml` file, make sure that dependency lists stay alphabetically sorted. Also consider whether a new dependency is added to the appropriate place (e.g., `[dependencies]` versus `[dev-dependencies]`)
|
||||
|
||||
## Pull Request Body
|
||||
|
||||
- If the nature of the change seems to have a visual component (which is often the case for changes to `codex-rs/tui`), recommend including a screenshot or video to demonstrate the change, if appropriate.
|
||||
- References to existing GitHub issues and PRs are encouraged, where appropriate, though you likely do not have network access, so may not be able to help here.
|
||||
|
||||
# PR Information
|
||||
|
||||
{CODEX_ACTION_GITHUB_EVENT_PATH} contains the JSON that triggered this GitHub workflow. It contains the `base` and `head` refs that define this PR. Both refs are available locally.
|
||||
|
||||
6
.github/workflows/rust-release.yml
vendored
6
.github/workflows/rust-release.yml
vendored
@@ -181,9 +181,9 @@ jobs:
|
||||
name: ${{ steps.release_name.outputs.name }}
|
||||
tag_name: ${{ github.ref_name }}
|
||||
files: dist/**
|
||||
# For now, tag releases as "prerelease" because we are not claiming
|
||||
# the Rust CLI is stable yet.
|
||||
prerelease: true
|
||||
# Mark as prerelease only when the version has a suffix after x.y.z
|
||||
# (e.g. -alpha, -beta). Otherwise publish a normal release.
|
||||
prerelease: ${{ contains(steps.release_name.outputs.name, '-') }}
|
||||
|
||||
- uses: facebook/dotslash-publish-release@v2
|
||||
env:
|
||||
|
||||
4
.vscode/settings.json
vendored
4
.vscode/settings.json
vendored
@@ -11,6 +11,8 @@
|
||||
"editor.defaultFormatter": "tamasfe.even-better-toml",
|
||||
"editor.formatOnSave": true,
|
||||
},
|
||||
"evenBetterToml.formatter.reorderArrays": true,
|
||||
// Array order for options in ~/.codex/config.toml such as `notify` and the
|
||||
// `args` for an MCP server is significant, so we disable reordering.
|
||||
"evenBetterToml.formatter.reorderArrays": false,
|
||||
"evenBetterToml.formatter.reorderKeys": true,
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
In the codex-rs folder where the rust code lives:
|
||||
|
||||
- Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR`. You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
|
||||
- Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` or `CODEX_SANDBOX_ENV_VAR`.
|
||||
- You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
|
||||
- Similarly, when you spawn a process using Seatbelt (`/usr/bin/sandbox-exec`), `CODEX_SANDBOX=seatbelt` will be set on the child process. Integration tests that want to run Seatbelt themselves cannot be run under Seatbelt, so checks for `CODEX_SANDBOX=seatbelt` are also often used to early exit out of tests, as appropriate.
|
||||
|
||||
Before creating a pull request with changes to `codex-rs`, run `just fmt` (in `codex-rs` directory) to format the code and `just fix` (in `codex-rs` directory) to fix any linter issues in the code, ensure the test suite passes by running `cargo test --all-features` in the `codex-rs` directory.
|
||||
|
||||
|
||||
36
README.md
36
README.md
@@ -18,6 +18,7 @@ This is the home of the **Codex CLI**, which is a coding agent from OpenAI that
|
||||
- [Quickstart](#quickstart)
|
||||
- [OpenAI API Users](#openai-api-users)
|
||||
- [OpenAI Plus/Pro Users](#openai-pluspro-users)
|
||||
- [Using Open Source Models](#using-open-source-models)
|
||||
- [Why Codex?](#why-codex)
|
||||
- [Security model & permissions](#security-model--permissions)
|
||||
- [Platform sandboxing details](#platform-sandboxing-details)
|
||||
@@ -186,6 +187,41 @@ they'll be committed to your working directory.
|
||||
|
||||
---
|
||||
|
||||
## Using Open Source Models
|
||||
|
||||
Codex can run fully locally against an OpenAI-compatible OSS host (like Ollama) using the `--oss` flag:
|
||||
|
||||
- Interactive UI:
|
||||
- codex --oss
|
||||
- Non-interactive (programmatic) mode:
|
||||
- echo "Refactor utils" | codex exec --oss
|
||||
|
||||
Model selection when using `--oss`:
|
||||
|
||||
- If you omit `-m/--model`, Codex defaults to -m gpt-oss:20b and will verify it exists locally (downloading if needed).
|
||||
- To pick a different size, pass one of:
|
||||
- -m "gpt-oss:20b"
|
||||
- -m "gpt-oss:120b"
|
||||
|
||||
Point Codex at your own OSS host:
|
||||
|
||||
- By default, `--oss` talks to http://localhost:11434/v1.
|
||||
- To use a different host, set one of these environment variables before running Codex:
|
||||
- CODEX_OSS_BASE_URL, for example:
|
||||
- CODEX_OSS_BASE_URL="http://my-ollama.example.com:11434/v1" codex --oss -m gpt-oss:20b
|
||||
- or CODEX_OSS_PORT (when the host is localhost):
|
||||
- CODEX_OSS_PORT=11434 codex --oss
|
||||
|
||||
Advanced: you can persist this in your config instead of environment variables by overriding the built-in `oss` provider in `~/.codex/config.toml`:
|
||||
|
||||
```toml
|
||||
[model_providers.oss]
|
||||
name = "Open Source"
|
||||
base_url = "http://my-ollama.example.com:11434/v1"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Why Codex?
|
||||
|
||||
Codex CLI is built for developers who already **live in the terminal** and want
|
||||
|
||||
21
SUMMARY.md
Normal file
21
SUMMARY.md
Normal file
@@ -0,0 +1,21 @@
|
||||
You are a summarization assistant. A conversation follows between a user and a coding-focused AI (Codex). Your task is to generate a clear summary capturing:
|
||||
|
||||
• High-level objective or problem being solved
|
||||
• Key instructions or design decisions given by the user
|
||||
• Main code actions or behaviors from the AI
|
||||
• Important variables, functions, modules, or outputs discussed
|
||||
• Any unresolved questions or next steps
|
||||
|
||||
Produce the summary in a structured format like:
|
||||
|
||||
**Objective:** …
|
||||
|
||||
**User instructions:** … (bulleted)
|
||||
|
||||
**AI actions / code behavior:** … (bulleted)
|
||||
|
||||
**Important entities:** … (e.g. function names, variables, files)
|
||||
|
||||
**Open issues / next steps:** … (if any)
|
||||
|
||||
**Summary (concise):** (one or two sentences)
|
||||
@@ -83,6 +83,7 @@ if (wantsNative && process.platform !== 'win32') {
|
||||
|
||||
const child = spawn(binaryPath, process.argv.slice(2), {
|
||||
stdio: "inherit",
|
||||
env: { ...process.env, CODEX_MANAGED_BY_NPM: "1" },
|
||||
});
|
||||
|
||||
child.on("error", (err) => {
|
||||
|
||||
@@ -147,4 +147,8 @@ const READ_ONLY_SEATBELT_POLICY = `
|
||||
(sysctl-name "kern.version")
|
||||
(sysctl-name "sysctl.proc_cputype")
|
||||
(sysctl-name-prefix "hw.perflevel")
|
||||
)`.trim();
|
||||
)
|
||||
|
||||
; Added on top of Chrome profile
|
||||
; Needed for python multiprocessing on MacOS for the SemLock
|
||||
(allow ipc-posix-sem)`.trim();
|
||||
|
||||
166
codex-rs/Cargo.lock
generated
166
codex-rs/Cargo.lock
generated
@@ -661,7 +661,7 @@ dependencies = [
|
||||
"clap",
|
||||
"codex-core",
|
||||
"serde",
|
||||
"toml 0.9.2",
|
||||
"toml 0.9.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -695,9 +695,11 @@ dependencies = [
|
||||
"reqwest",
|
||||
"seccompiler",
|
||||
"serde",
|
||||
"serde_bytes",
|
||||
"serde_json",
|
||||
"sha1",
|
||||
"shlex",
|
||||
"similar",
|
||||
"strum_macros 0.27.2",
|
||||
"tempfile",
|
||||
"thiserror 2.0.12",
|
||||
@@ -705,7 +707,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"tokio-util",
|
||||
"toml 0.9.2",
|
||||
"toml 0.9.4",
|
||||
"tracing",
|
||||
"tree-sitter",
|
||||
"tree-sitter-bash",
|
||||
@@ -727,6 +729,7 @@ dependencies = [
|
||||
"codex-arg0",
|
||||
"codex-common",
|
||||
"codex-core",
|
||||
"codex-ollama",
|
||||
"owo-colors",
|
||||
"predicates",
|
||||
"serde_json",
|
||||
@@ -829,19 +832,37 @@ dependencies = [
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"toml 0.9.2",
|
||||
"toml 0.9.4",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-ollama"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"bytes",
|
||||
"codex-core",
|
||||
"futures",
|
||||
"reqwest",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"toml 0.9.4",
|
||||
"tracing",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-tui"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"clap",
|
||||
"codex-ansi-escape",
|
||||
"codex-arg0",
|
||||
@@ -849,6 +870,7 @@ dependencies = [
|
||||
"codex-core",
|
||||
"codex-file-search",
|
||||
"codex-login",
|
||||
"codex-ollama",
|
||||
"color-eyre",
|
||||
"crossterm",
|
||||
"image",
|
||||
@@ -857,23 +879,28 @@ dependencies = [
|
||||
"mcp-types",
|
||||
"path-clean",
|
||||
"pretty_assertions",
|
||||
"rand 0.8.5",
|
||||
"ratatui",
|
||||
"ratatui-image",
|
||||
"regex-lite",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"strum 0.27.2",
|
||||
"strum_macros 0.27.2",
|
||||
"supports-color",
|
||||
"textwrap 0.16.2",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-appender",
|
||||
"tracing-subscriber",
|
||||
"tui-input",
|
||||
"tui-markdown",
|
||||
"tui-textarea",
|
||||
"unicode-segmentation",
|
||||
"unicode-width 0.1.14",
|
||||
"uuid",
|
||||
"vt100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1466,7 +1493,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1546,7 +1573,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"rustix 1.0.8",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1749,7 +1776,7 @@ version = "0.2.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cba6ae63eb948698e300f645f87c70f76630d505f23b8907cf1e193ee85048c1"
|
||||
dependencies = [
|
||||
"unicode-width 0.2.0",
|
||||
"unicode-width 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2010,7 +2037,7 @@ dependencies = [
|
||||
"libc",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"socket2 0.6.0",
|
||||
"socket2",
|
||||
"system-configuration",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
@@ -2329,9 +2356,15 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is_ci"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7655c9839580ee829dfacba1d1278c2b7883e50a277ff7541299489d6bdfdc45"
|
||||
|
||||
[[package]]
|
||||
name = "is_terminal_polyfill"
|
||||
version = "1.70.1"
|
||||
@@ -2643,6 +2676,7 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"codex-core",
|
||||
"codex-mcp-server",
|
||||
"mcp-types",
|
||||
"pretty_assertions",
|
||||
@@ -2650,6 +2684,7 @@ dependencies = [
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
@@ -3377,7 +3412,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "ratatui"
|
||||
version = "0.29.0"
|
||||
source = "git+https://github.com/nornagon/ratatui?branch=nornagon-v0.29.0-patch#bca287ddc5d38fe088c79e2eda22422b96226f2e"
|
||||
source = "git+https://github.com/nornagon/ratatui?branch=nornagon-v0.29.0-patch#9b2ad1298408c45918ee9f8241a6f95498cdbed2"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"cassowary",
|
||||
@@ -3391,7 +3426,7 @@ dependencies = [
|
||||
"strum 0.26.3",
|
||||
"unicode-segmentation",
|
||||
"unicode-truncate",
|
||||
"unicode-width 0.2.0",
|
||||
"unicode-width 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3705,7 +3740,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.4.15",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3718,7 +3753,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.9.4",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3949,6 +3984,15 @@ dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_bytes"
|
||||
version = "0.11.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8437fd221bde2d4ca316d61b90e337e9e702b3820b87d63caa9ba6c02bd06d96"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.219"
|
||||
@@ -3973,9 +4017,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.141"
|
||||
version = "1.0.142"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3"
|
||||
checksum = "030fedb782600dcbd6f02d479bf0d817ac3bb40d644745b769d6a96bc3afc5a7"
|
||||
dependencies = [
|
||||
"indexmap 2.10.0",
|
||||
"itoa",
|
||||
@@ -4159,14 +4203,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.5.10"
|
||||
name = "smawk"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c"
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
@@ -4220,7 +4260,7 @@ dependencies = [
|
||||
"starlark_syntax",
|
||||
"static_assertions",
|
||||
"strsim 0.10.0",
|
||||
"textwrap",
|
||||
"textwrap 0.11.0",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
@@ -4356,6 +4396,15 @@ version = "2.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
|
||||
|
||||
[[package]]
|
||||
name = "supports-color"
|
||||
version = "3.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c64fc7232dd8d2e4ac5ce4ef302b1d81e0b80d055b9d77c7c4f51f6aa4c867d6"
|
||||
dependencies = [
|
||||
"is_ci",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
@@ -4470,7 +4519,7 @@ dependencies = [
|
||||
"getrandom 0.3.3",
|
||||
"once_cell",
|
||||
"rustix 1.0.8",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4509,6 +4558,17 @@ dependencies = [
|
||||
"unicode-width 0.1.14",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057"
|
||||
dependencies = [
|
||||
"smawk",
|
||||
"unicode-linebreak",
|
||||
"unicode-width 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.69"
|
||||
@@ -4623,9 +4683,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.46.1"
|
||||
version = "1.47.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17"
|
||||
checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"bytes",
|
||||
@@ -4636,9 +4696,9 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"slab",
|
||||
"socket2 0.5.10",
|
||||
"socket2",
|
||||
"tokio-macros",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4723,9 +4783,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.9.2"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed0aee96c12fa71097902e0bb061a5e1ebd766a6636bb605ba401c45c1650eac"
|
||||
checksum = "41ae868b5a0f67631c14589f7e250c1ea2c574ee5ba21c6c8dd4b1485705a5a1"
|
||||
dependencies = [
|
||||
"indexmap 2.10.0",
|
||||
"serde",
|
||||
@@ -4954,7 +5014,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "911e93158bf80bbc94bad533b2b16e3d711e1132d69a6a6980c3920a63422c19"
|
||||
dependencies = [
|
||||
"ratatui",
|
||||
"unicode-width 0.2.0",
|
||||
"unicode-width 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4973,17 +5033,6 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tui-textarea"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0a5318dd619ed73c52a9417ad19046724effc1287fb75cdcc4eca1d6ac1acbae"
|
||||
dependencies = [
|
||||
"crossterm",
|
||||
"ratatui",
|
||||
"unicode-width 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.18.0"
|
||||
@@ -5002,6 +5051,12 @@ version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-linebreak"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.12.0"
|
||||
@@ -5027,9 +5082,9 @@ checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
|
||||
checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
@@ -5114,6 +5169,27 @@ version = "0.9.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
|
||||
[[package]]
|
||||
name = "vt100"
|
||||
version = "0.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "054ff75fb8fa83e609e685106df4faeffdf3a735d3c74ebce97ec557d5d36fd9"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"unicode-width 0.2.1",
|
||||
"vte",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vte"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5924018406ce0063cd67f8e008104968b74b563ee1b85dde3ed1f7cb87d3dbd"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wait-timeout"
|
||||
version = "0.2.1"
|
||||
@@ -5302,7 +5378,7 @@ version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -14,6 +14,7 @@ members = [
|
||||
"mcp-client",
|
||||
"mcp-server",
|
||||
"mcp-types",
|
||||
"ollama",
|
||||
"tui",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
@@ -18,6 +18,9 @@ pub enum ApprovalModeCliArg {
|
||||
/// will escalate to the user to ask for un-sandboxed execution.
|
||||
OnFailure,
|
||||
|
||||
/// The model decides when to ask the user for approval.
|
||||
OnRequest,
|
||||
|
||||
/// Never ask for user approval
|
||||
/// Execution failures are immediately returned to the model.
|
||||
Never,
|
||||
@@ -28,6 +31,7 @@ impl From<ApprovalModeCliArg> for AskForApproval {
|
||||
match value {
|
||||
ApprovalModeCliArg::Untrusted => AskForApproval::UnlessTrusted,
|
||||
ApprovalModeCliArg::OnFailure => AskForApproval::OnFailure,
|
||||
ApprovalModeCliArg::OnRequest => AskForApproval::OnRequest,
|
||||
ApprovalModeCliArg::Never => AskForApproval::Never,
|
||||
}
|
||||
}
|
||||
|
||||
29
codex-rs/common/src/config_summary.rs
Normal file
29
codex-rs/common/src/config_summary.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use codex_core::WireApi;
|
||||
use codex_core::config::Config;
|
||||
|
||||
use crate::sandbox_summary::summarize_sandbox_policy;
|
||||
|
||||
/// Build a list of key/value pairs summarizing the effective configuration.
|
||||
pub fn create_config_summary_entries(config: &Config) -> Vec<(&'static str, String)> {
|
||||
let mut entries = vec![
|
||||
("workdir", config.cwd.display().to_string()),
|
||||
("model", config.model.clone()),
|
||||
("provider", config.model_provider_id.clone()),
|
||||
("approval", config.approval_policy.to_string()),
|
||||
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
|
||||
];
|
||||
if config.model_provider.wire_api == WireApi::Responses
|
||||
&& config.model_family.supports_reasoning_summaries
|
||||
{
|
||||
entries.push((
|
||||
"reasoning effort",
|
||||
config.model_reasoning_effort.to_string(),
|
||||
));
|
||||
entries.push((
|
||||
"reasoning summaries",
|
||||
config.model_reasoning_summary.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
entries
|
||||
}
|
||||
@@ -23,3 +23,7 @@ mod sandbox_summary;
|
||||
|
||||
#[cfg(feature = "sandbox_summary")]
|
||||
pub use sandbox_summary::summarize_sandbox_policy;
|
||||
|
||||
mod config_summary;
|
||||
|
||||
pub use config_summary::create_config_summary_entries;
|
||||
|
||||
@@ -7,6 +7,7 @@ pub fn summarize_sandbox_policy(sandbox_policy: &SandboxPolicy) -> String {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots,
|
||||
network_access,
|
||||
include_default_writable_roots,
|
||||
} => {
|
||||
let mut summary = "workspace-write".to_string();
|
||||
if !writable_roots.is_empty() {
|
||||
@@ -19,6 +20,9 @@ pub fn summarize_sandbox_policy(sandbox_policy: &SandboxPolicy) -> String {
|
||||
.join(", ")
|
||||
));
|
||||
}
|
||||
if !*include_default_writable_roots {
|
||||
summary.push_str(" (exact writable roots)");
|
||||
}
|
||||
if *network_access {
|
||||
summary.push_str(" (network access enabled)");
|
||||
}
|
||||
|
||||
@@ -148,12 +148,20 @@ Determines when the user should be prompted to approve whether Codex can execute
|
||||
approval_policy = "untrusted"
|
||||
```
|
||||
|
||||
If you want to be notified whenever a command fails, use "on-failure":
|
||||
```toml
|
||||
# If the command fails when run in the sandbox, Codex asks for permission to
|
||||
# retry the command outside the sandbox.
|
||||
approval_policy = "on-failure"
|
||||
```
|
||||
|
||||
If you want the model to run until it decides that it needs to ask you for escalated permissions, use "on-request":
|
||||
```toml
|
||||
# The model decides when to escalate
|
||||
approval_policy = "on-request"
|
||||
```
|
||||
|
||||
Alternatively, you can have the model run until it is done, and never ask to run a command with escalated permissions:
|
||||
```toml
|
||||
# User is never prompted: if the command fails, Codex will automatically try
|
||||
# something out. Note the `exec` subcommand always uses this mode.
|
||||
@@ -259,6 +267,8 @@ disk, but attempts to write a file or access the network will be blocked.
|
||||
|
||||
A more relaxed policy is `workspace-write`. When specified, the current working directory for the Codex task will be writable (as well as `$TMPDIR` on macOS). Note that the CLI defaults to using the directory where it was spawned as `cwd`, though this can be overridden using `--cwd/-C`.
|
||||
|
||||
On macOS (and soon Linux), all writable roots (including `cwd`) that contain a `.git/` folder _as an immediate child_ will configure the `.git/` folder to be read-only while the rest of the Git repository will be writable. This means that commands like `git commit` will fail, by default (as it entails writing to `.git/`), and will require Codex to ask for permission.
|
||||
|
||||
```toml
|
||||
# same as `--sandbox workspace-write`
|
||||
sandbox_mode = "workspace-write"
|
||||
@@ -481,6 +491,19 @@ Setting `hide_agent_reasoning` to `true` suppresses these events in **both** the
|
||||
hide_agent_reasoning = true # defaults to false
|
||||
```
|
||||
|
||||
## show_raw_agent_reasoning
|
||||
|
||||
Surfaces the model’s raw chain-of-thought ("raw reasoning content") when available.
|
||||
|
||||
Notes:
|
||||
- Only takes effect if the selected model/provider actually emits raw reasoning content. Many models do not. When unsupported, this option has no visible effect.
|
||||
- Raw reasoning may include intermediate thoughts or sensitive context. Enable only if acceptable for your workflow.
|
||||
|
||||
Example:
|
||||
```toml
|
||||
show_raw_agent_reasoning = true # defaults to false
|
||||
```
|
||||
|
||||
## model_context_window
|
||||
|
||||
The size of the context window for the model, in tokens.
|
||||
|
||||
@@ -31,8 +31,10 @@ rand = "0.9"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
serde_bytes = "0.11"
|
||||
sha1 = "0.10.6"
|
||||
shlex = "1.3.0"
|
||||
similar = "2.7.0"
|
||||
strum_macros = "0.27.2"
|
||||
thiserror = "2.0.12"
|
||||
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
|
||||
@@ -44,7 +46,7 @@ tokio = { version = "1", features = [
|
||||
"signal",
|
||||
] }
|
||||
tokio-util = "0.7.14"
|
||||
toml = "0.9.2"
|
||||
toml = "0.9.4"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tree-sitter = "0.25.8"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
|
||||
@@ -1,16 +1,31 @@
|
||||
Please resolve the user's task by editing and testing the code files in your current code execution session.
|
||||
You are a deployed coding agent.
|
||||
Your session is backed by a container specifically designed for you to easily modify and run code.
|
||||
The repo(s) are already cloned in your working directory, and you must fully solve the problem for your answer to be considered correct.
|
||||
You are operating as and within the Codex CLI, an open-source, terminal-based agentic coding assistant built by OpenAI. It wraps OpenAI models to enable natural language interaction with a local codebase. You are expected to be precise, safe, and helpful.
|
||||
|
||||
Your capabilities:
|
||||
- Receive user prompts, project context, and files.
|
||||
- Stream responses and emit function calls (e.g., shell commands, code edits).
|
||||
- Run commands, like apply_patch, and manage user approvals based on policy.
|
||||
- Work inside a workspace with sandboxing instructions specified by the policy described in (## Sandbox environment and approval instructions)
|
||||
|
||||
Within this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI).
|
||||
|
||||
## General guidelines
|
||||
As a deployed coding agent, please continue working on the user's task until their query is resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the task is solved. If you are not sure about file content or codebase structure pertaining to the user's request, use your tools to read files and gather the relevant information. Do NOT guess or make up an answer.
|
||||
|
||||
After a user sends their first message, you should immediately provide a brief message acknowledging their request to set the tone and expectation of future work to be done (no more than 8-10 words). This should be done before performing work like exploring the codebase, writing or reading files, or other tool calls needed to complete the task. Use a natural, collaborative tone similar to how a teammate would receive a task during a pair programming session.
|
||||
|
||||
Please resolve the user's task by editing the code files in your current code execution session. Your session allows for you to modify and run code. The repo(s) are already cloned in your working directory, and you must fully solve the problem for your answer to be considered correct.
|
||||
|
||||
### Task execution
|
||||
You MUST adhere to the following criteria when executing the task:
|
||||
|
||||
- Working on the repo(s) in the current environment is allowed, even if they are proprietary.
|
||||
- Analyzing code for vulnerabilities is allowed.
|
||||
- Showing user code and tool call details is allowed.
|
||||
- User instructions may overwrite the _CODING GUIDELINES_ section in this developer message.
|
||||
- `user_instructions` are not part of the user's request, but guidance for how to complete the task.
|
||||
- Do not cite `user_instructions` back to the user unless a specific piece is relevant.
|
||||
- Do not use \`ls -R\`, \`find\`, or \`grep\` - these are slow in large repos. Use \`rg\` and \`rg --files\`.
|
||||
- Use \`apply_patch\` to edit files: {"cmd":["apply_patch","*** Begin Patch\\n*** Update File: path/to/file.py\\n@@ def example():\\n- pass\\n+ return 123\\n*** End Patch"]}
|
||||
- Use the \`apply_patch\` shell command to edit files: {"command":["apply_patch","*** Begin Patch\\n*** Update File: path/to/file.py\\n@@ def example():\\n- pass\\n+ return 123\\n*** End Patch"]}
|
||||
- If completing the user's task requires writing or modifying files:
|
||||
- Your code and final answer should follow these _CODING GUIDELINES_:
|
||||
- Fix the problem at the root cause rather than applying surface-level patches, when possible.
|
||||
@@ -33,23 +48,22 @@ You MUST adhere to the following criteria when executing the task:
|
||||
- If completing the user's task DOES NOT require writing or modifying files (e.g., the user asks a question about the code base):
|
||||
- Respond in a friendly tune as a remote teammate, who is knowledgeable, capable and eager to help with coding.
|
||||
- When your task involves writing or modifying files:
|
||||
- Do NOT tell the user to "save the file" or "copy the code into a file" if you already created or modified the file using \`apply_patch\`. Instead, reference the file as already saved.
|
||||
- Do NOT tell the user to "save the file" or "copy the code into a file" if you already created or modified the file using the `apply_patch` shell command. Instead, reference the file as already saved.
|
||||
- Do NOT show the full contents of large files you have already written, unless the user explicitly asks for them.
|
||||
|
||||
§ `apply-patch` Specification
|
||||
## Using the shell command `apply_patch` to edit files
|
||||
`apply_patch` is a shell command for editing files. Your patch language is a stripped‑down, file‑oriented diff format designed to be easy to parse and safe to apply. You can think of it as a high‑level envelope:
|
||||
|
||||
Your patch language is a stripped‑down, file‑oriented diff format designed to be easy to parse and safe to apply. You can think of it as a high‑level envelope:
|
||||
|
||||
**_ Begin Patch
|
||||
*** Begin Patch
|
||||
[ one or more file sections ]
|
||||
_** End Patch
|
||||
*** End Patch
|
||||
|
||||
Within that envelope, you get a sequence of file operations.
|
||||
You MUST include a header to specify the action you are taking.
|
||||
Each operation starts with one of three headers:
|
||||
|
||||
**_ Add File: <path> - create a new file. Every following line is a + line (the initial contents).
|
||||
_** Delete File: <path> - remove an existing file. Nothing follows.
|
||||
*** Add File: <path> - create a new file. Every following line is a + line (the initial contents).
|
||||
*** Delete File: <path> - remove an existing file. Nothing follows.
|
||||
\*\*\* Update File: <path> - patch an existing file in place (optionally with a rename).
|
||||
|
||||
May be immediately followed by \*\*\* Move to: <new path> if you want to rename the file.
|
||||
@@ -63,45 +77,60 @@ Within a hunk each line starts with:
|
||||
At the end of a truncated hunk you can emit \*\*\* End of File.
|
||||
|
||||
Patch := Begin { FileOp } End
|
||||
Begin := "**_ Begin Patch" NEWLINE
|
||||
End := "_** End Patch" NEWLINE
|
||||
Begin := "*** Begin Patch" NEWLINE
|
||||
End := "*** End Patch" NEWLINE
|
||||
FileOp := AddFile | DeleteFile | UpdateFile
|
||||
AddFile := "**_ Add File: " path NEWLINE { "+" line NEWLINE }
|
||||
DeleteFile := "_** Delete File: " path NEWLINE
|
||||
UpdateFile := "**_ Update File: " path NEWLINE [ MoveTo ] { Hunk }
|
||||
MoveTo := "_** Move to: " newPath NEWLINE
|
||||
AddFile := "*** Add File: " path NEWLINE { "+" line NEWLINE }
|
||||
DeleteFile := "*** Delete File: " path NEWLINE
|
||||
UpdateFile := "*** Update File: " path NEWLINE [ MoveTo ] { Hunk }
|
||||
MoveTo := "*** Move to: " newPath NEWLINE
|
||||
Hunk := "@@" [ header ] NEWLINE { HunkLine } [ "*** End of File" NEWLINE ]
|
||||
HunkLine := (" " | "-" | "+") text NEWLINE
|
||||
|
||||
A full patch can combine several operations:
|
||||
|
||||
**_ Begin Patch
|
||||
_** Add File: hello.txt
|
||||
*** Begin Patch
|
||||
*** Add File: hello.txt
|
||||
+Hello world
|
||||
**_ Update File: src/app.py
|
||||
_** Move to: src/main.py
|
||||
*** Update File: src/app.py
|
||||
*** Move to: src/main.py
|
||||
@@ def greet():
|
||||
-print("Hi")
|
||||
+print("Hello, world!")
|
||||
**_ Delete File: obsolete.txt
|
||||
_** End Patch
|
||||
*** Delete File: obsolete.txt
|
||||
*** End Patch
|
||||
|
||||
It is important to remember:
|
||||
|
||||
- You must include a header with your intended action (Add/Delete/Update)
|
||||
- You must prefix new lines with `+` even when creating a new file
|
||||
- You must follow this schema exactly when providing a patch
|
||||
|
||||
You can invoke apply_patch like:
|
||||
You can invoke apply_patch with the following shell command:
|
||||
|
||||
```
|
||||
shell {"command":["apply_patch","*** Begin Patch\n*** Add File: hello.txt\n+Hello, world!\n*** End Patch\n"]}
|
||||
```
|
||||
|
||||
Plan updates
|
||||
## Sandbox environment and approval instructions
|
||||
|
||||
You are running in a sandboxed workspace backed by version control. The sandbox might be configured by the user to restrict certain behaviors, like accessing the internet or writing to files outside the current directory.
|
||||
|
||||
Commands that are blocked by sandbox settings will be automatically sent to the user for approval. The result of the request will be returned (i.e. the command result, or the request denial).
|
||||
The user also has an opportunity to approve the same command for the rest of the session.
|
||||
|
||||
Guidance on running within the sandbox:
|
||||
- When running commands that will likely require approval, attempt to use simple, precise commands, to reduce frequency of approval requests.
|
||||
- When approval is denied or a command fails due to a permission error, do not retry the exact command in a different way. Move on and continue trying to address the user's request.
|
||||
|
||||
|
||||
## Tools available
|
||||
### Plan updates
|
||||
|
||||
A tool named `update_plan` is available. Use it to keep an up‑to‑date, step‑by‑step plan for the task so you can follow your progress. When making your plans, keep in mind that you are a deployed coding agent - `update_plan` calls should not involve doing anything that you aren't capable of doing. For example, `update_plan` calls should NEVER contain tasks to merge your own pull requests. Only stop to ask the user if you genuinely need their feedback on a change.
|
||||
|
||||
- At the start of the task, call `update_plan` with an initial plan: a short list of 1‑sentence steps with a `status` for each step (`pending`, `in_progress`, or `completed`). There should always be exactly one `in_progress` step until everything is done.
|
||||
- At the start of any nontrivial task, call `update_plan` with an initial plan: a short list of 1‑sentence steps with a `status` for each step (`pending`, `in_progress`, or `completed`). There should always be exactly one `in_progress` step until everything is done.
|
||||
- Whenever you finish a step, call `update_plan` again, marking the finished step as `completed` and the next step as `in_progress`.
|
||||
- If your plan needs to change, call `update_plan` with the revised steps and include an `explanation` describing the change.
|
||||
- When all steps are complete, make a final `update_plan` call with all steps marked `completed`.
|
||||
|
||||
|
||||
@@ -21,7 +21,9 @@ use crate::client_common::ResponseEvent;
|
||||
use crate::client_common::ResponseStream;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ReasoningItemContent;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::openai_tools::create_tools_json_for_chat_completions_api;
|
||||
use crate::util::backoff;
|
||||
@@ -29,22 +31,19 @@ use crate::util::backoff;
|
||||
/// Implementation for the classic Chat Completions API.
|
||||
pub(crate) async fn stream_chat_completions(
|
||||
prompt: &Prompt,
|
||||
model: &str,
|
||||
include_plan_tool: bool,
|
||||
model_family: &ModelFamily,
|
||||
client: &reqwest::Client,
|
||||
provider: &ModelProviderInfo,
|
||||
) -> Result<ResponseStream> {
|
||||
// Build messages array
|
||||
let mut messages = Vec::<serde_json::Value>::new();
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(model);
|
||||
let full_instructions = prompt.get_full_instructions(model_family);
|
||||
messages.push(json!({"role": "system", "content": full_instructions}));
|
||||
|
||||
if let Some(instr) = &prompt.user_instructions {
|
||||
messages.push(json!({"role": "user", "content": instr}));
|
||||
}
|
||||
let input = prompt.get_formatted_input();
|
||||
|
||||
for item in &prompt.input {
|
||||
for item in &input {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
let mut text = String::new();
|
||||
@@ -110,9 +109,9 @@ pub(crate) async fn stream_chat_completions(
|
||||
}
|
||||
}
|
||||
|
||||
let tools_json = create_tools_json_for_chat_completions_api(prompt, model, include_plan_tool)?;
|
||||
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
|
||||
let payload = json!({
|
||||
"model": model,
|
||||
"model": model_family.slug,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
"tools": tools_json,
|
||||
@@ -120,7 +119,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
|
||||
debug!(
|
||||
"POST to {}: {}",
|
||||
provider.get_full_url(),
|
||||
provider.get_full_url(&None),
|
||||
serde_json::to_string_pretty(&payload).unwrap_or_default()
|
||||
);
|
||||
|
||||
@@ -129,7 +128,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
let req_builder = provider.create_request_builder(client)?;
|
||||
let req_builder = provider.create_request_builder(client, &None).await?;
|
||||
|
||||
let res = req_builder
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
@@ -207,6 +206,8 @@ async fn process_chat_sse<S>(
|
||||
}
|
||||
|
||||
let mut fn_call_state = FunctionCallState::default();
|
||||
let mut assistant_text = String::new();
|
||||
let mut reasoning_text = String::new();
|
||||
|
||||
loop {
|
||||
let sse = match timeout(idle_timeout, stream.next()).await {
|
||||
@@ -235,6 +236,31 @@ async fn process_chat_sse<S>(
|
||||
|
||||
// OpenAI Chat streaming sends a literal string "[DONE]" when finished.
|
||||
if sse.data.trim() == "[DONE]" {
|
||||
// Emit any finalized items before closing so downstream consumers receive
|
||||
// terminal events for both assistant content and raw reasoning.
|
||||
if !assistant_text.is_empty() {
|
||||
let item = ResponseItem::Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: std::mem::take(&mut assistant_text),
|
||||
}],
|
||||
id: None,
|
||||
};
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
|
||||
if !reasoning_text.is_empty() {
|
||||
let item = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![ReasoningItemContent::ReasoningText {
|
||||
text: std::mem::take(&mut reasoning_text),
|
||||
}]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
@@ -254,21 +280,47 @@ async fn process_chat_sse<S>(
|
||||
let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
|
||||
|
||||
if let Some(choice) = choice_opt {
|
||||
// Handle assistant content tokens.
|
||||
// Handle assistant content tokens as streaming deltas.
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
let item = ResponseItem::Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: content.to_string(),
|
||||
}],
|
||||
id: None,
|
||||
};
|
||||
if !content.is_empty() {
|
||||
assistant_text.push_str(content);
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputTextDelta(content.to_string())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
// Forward any reasoning/thinking deltas if present.
|
||||
// Some providers stream `reasoning` as a plain string while others
|
||||
// nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
|
||||
if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
|
||||
let mut maybe_text = reasoning_val.as_str().map(|s| s.to_string());
|
||||
|
||||
if maybe_text.is_none() && reasoning_val.is_object() {
|
||||
if let Some(s) = reasoning_val
|
||||
.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
{
|
||||
maybe_text = Some(s.to_string());
|
||||
} else if let Some(s) = reasoning_val
|
||||
.get("content")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
{
|
||||
maybe_text = Some(s.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(reasoning) = maybe_text {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta(reasoning)))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle streaming function / tool calls.
|
||||
@@ -305,7 +357,21 @@ async fn process_chat_sse<S>(
|
||||
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
|
||||
match finish_reason {
|
||||
"tool_calls" if fn_call_state.active => {
|
||||
// Build the FunctionCall response item.
|
||||
// First, flush the terminal raw reasoning so UIs can finalize
|
||||
// the reasoning stream before any exec/tool events begin.
|
||||
if !reasoning_text.is_empty() {
|
||||
let item = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![ReasoningItemContent::ReasoningText {
|
||||
text: std::mem::take(&mut reasoning_text),
|
||||
}]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
|
||||
// Then emit the FunctionCall response item.
|
||||
let item = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
|
||||
@@ -313,11 +379,33 @@ async fn process_chat_sse<S>(
|
||||
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
|
||||
};
|
||||
|
||||
// Emit it downstream.
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
"stop" => {
|
||||
// Regular turn without tool-call.
|
||||
// Regular turn without tool-call. Emit the final assistant message
|
||||
// as a single OutputItemDone so non-delta consumers see the result.
|
||||
if !assistant_text.is_empty() {
|
||||
let item = ResponseItem::Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: std::mem::take(&mut assistant_text),
|
||||
}],
|
||||
id: None,
|
||||
};
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
// Also emit a terminal Reasoning item so UIs can finalize raw reasoning.
|
||||
if !reasoning_text.is_empty() {
|
||||
let item = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![ReasoningItemContent::ReasoningText {
|
||||
text: std::mem::take(&mut reasoning_text),
|
||||
}]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
@@ -355,10 +443,17 @@ async fn process_chat_sse<S>(
|
||||
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
|
||||
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
|
||||
/// events.
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
enum AggregateMode {
|
||||
AggregatedOnly,
|
||||
Streaming,
|
||||
}
|
||||
pub(crate) struct AggregatedChatStream<S> {
|
||||
inner: S,
|
||||
cumulative: String,
|
||||
pending_completed: Option<ResponseEvent>,
|
||||
cumulative_reasoning: String,
|
||||
pending: std::collections::VecDeque<ResponseEvent>,
|
||||
mode: AggregateMode,
|
||||
}
|
||||
|
||||
impl<S> Stream for AggregatedChatStream<S>
|
||||
@@ -370,8 +465,8 @@ where
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// First, flush any buffered Completed event from the previous call.
|
||||
if let Some(ev) = this.pending_completed.take() {
|
||||
// First, flush any buffered events from the previous call.
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
|
||||
@@ -388,16 +483,21 @@ where
|
||||
let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant");
|
||||
|
||||
if is_assistant_delta {
|
||||
if let crate::models::ResponseItem::Message { content, .. } = &item {
|
||||
if let Some(text) = content.iter().find_map(|c| match c {
|
||||
crate::models::ContentItem::OutputText { text } => Some(text),
|
||||
_ => None,
|
||||
}) {
|
||||
this.cumulative.push_str(text);
|
||||
// Only use the final assistant message if we have not
|
||||
// seen any deltas; otherwise, deltas already built the
|
||||
// cumulative text and this would duplicate it.
|
||||
if this.cumulative.is_empty() {
|
||||
if let crate::models::ResponseItem::Message { content, .. } = &item {
|
||||
if let Some(text) = content.iter().find_map(|c| match c {
|
||||
crate::models::ContentItem::OutputText { text } => Some(text),
|
||||
_ => None,
|
||||
}) {
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Swallow partial assistant chunk; keep polling.
|
||||
// Swallow assistant message here; emit on Completed.
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -408,24 +508,50 @@ where
|
||||
response_id,
|
||||
token_usage,
|
||||
}))) => {
|
||||
// Build any aggregated items in the correct order: Reasoning first, then Message.
|
||||
let mut emitted_any = false;
|
||||
|
||||
if !this.cumulative_reasoning.is_empty()
|
||||
&& matches!(this.mode, AggregateMode::AggregatedOnly)
|
||||
{
|
||||
let aggregated_reasoning = crate::models::ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![
|
||||
crate::models::ReasoningItemContent::ReasoningText {
|
||||
text: std::mem::take(&mut this.cumulative_reasoning),
|
||||
},
|
||||
]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
if !this.cumulative.is_empty() {
|
||||
let aggregated_item = crate::models::ResponseItem::Message {
|
||||
let aggregated_message = crate::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![crate::models::ContentItem::OutputText {
|
||||
text: std::mem::take(&mut this.cumulative),
|
||||
}],
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
// Buffer Completed so it is returned *after* the aggregated message.
|
||||
this.pending_completed = Some(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
// Always emit Completed last when anything was aggregated.
|
||||
if emitted_any {
|
||||
this.pending.push_back(ResponseEvent::Completed {
|
||||
response_id: response_id.clone(),
|
||||
token_usage: token_usage.clone(),
|
||||
});
|
||||
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
||||
aggregated_item,
|
||||
))));
|
||||
// Return the first pending event now.
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing aggregated – forward Completed directly.
|
||||
@@ -439,10 +565,27 @@ where
|
||||
// will never appear in a Chat Completions stream.
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(_))))
|
||||
| Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {
|
||||
// Deltas are ignored here since aggregation waits for the
|
||||
// final OutputItemDone.
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
|
||||
// Always accumulate deltas so we can emit a final OutputItemDone at Completed.
|
||||
this.cumulative.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
// In streaming mode, also forward the delta immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
|
||||
// Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
|
||||
this.cumulative_reasoning.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
// In streaming mode, also forward the delta immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -472,12 +615,24 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
|
||||
/// }
|
||||
/// ```
|
||||
fn aggregate(self) -> AggregatedChatStream<Self> {
|
||||
AggregatedChatStream {
|
||||
inner: self,
|
||||
cumulative: String::new(),
|
||||
pending_completed: None,
|
||||
}
|
||||
AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
|
||||
|
||||
impl<S> AggregatedChatStream<S> {
|
||||
fn new(inner: S, mode: AggregateMode) -> Self {
|
||||
AggregatedChatStream {
|
||||
inner,
|
||||
cumulative: String::new(),
|
||||
cumulative_reasoning: String::new(),
|
||||
pending: std::collections::VecDeque::new(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn streaming_mode(inner: S) -> Self {
|
||||
Self::new(inner, AggregateMode::Streaming)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,12 +30,10 @@ use crate::config::Config;
|
||||
use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::EnvVarError;
|
||||
use crate::error::Result;
|
||||
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::openai_tools::create_tools_json_for_responses_api;
|
||||
use crate::protocol::TokenUsage;
|
||||
@@ -83,8 +81,7 @@ impl ModelClient {
|
||||
// Create the raw streaming connection first.
|
||||
let response_stream = stream_chat_completions(
|
||||
prompt,
|
||||
&self.config.model,
|
||||
self.config.include_plan_tool,
|
||||
&self.config.model_family,
|
||||
&self.client,
|
||||
&self.provider,
|
||||
)
|
||||
@@ -93,7 +90,11 @@ impl ModelClient {
|
||||
// Wrap it with the aggregation adapter so callers see *only*
|
||||
// the final assistant message per turn (matching the
|
||||
// behaviour of the Responses API).
|
||||
let mut aggregated = response_stream.aggregate();
|
||||
let mut aggregated = if self.config.show_raw_agent_reasoning {
|
||||
crate::chat_completions::AggregatedChatStream::streaming_mode(response_stream)
|
||||
} else {
|
||||
response_stream.aggregate()
|
||||
};
|
||||
|
||||
// Bridge the aggregated stream back into a standard
|
||||
// `ResponseStream` by forwarding events through a channel.
|
||||
@@ -122,32 +123,19 @@ impl ModelClient {
|
||||
return stream_from_fixture(path, self.provider.clone()).await;
|
||||
}
|
||||
|
||||
let auth = self.auth.as_ref().ok_or_else(|| {
|
||||
CodexErr::EnvVar(EnvVarError {
|
||||
var: "OPENAI_API_KEY".to_string(),
|
||||
instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".to_string()),
|
||||
})
|
||||
})?;
|
||||
let auth = self.auth.clone();
|
||||
|
||||
let store = prompt.store && auth.mode != AuthMode::ChatGPT;
|
||||
let auth_mode = auth.as_ref().map(|a| a.mode);
|
||||
|
||||
let base_url = match self.provider.base_url.clone() {
|
||||
Some(url) => url,
|
||||
None => match auth.mode {
|
||||
AuthMode::ChatGPT => "https://chatgpt.com/backend-api/codex".to_string(),
|
||||
AuthMode::ApiKey => "https://api.openai.com/v1".to_string(),
|
||||
},
|
||||
};
|
||||
let store = prompt.store && auth_mode != Some(AuthMode::ChatGPT);
|
||||
|
||||
let token = auth.get_token().await?;
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
||||
let tools_json = create_tools_json_for_responses_api(
|
||||
prompt,
|
||||
&self.config.model,
|
||||
self.config.include_plan_tool,
|
||||
)?;
|
||||
let reasoning = create_reasoning_param_for_request(&self.config, self.effort, self.summary);
|
||||
let full_instructions = prompt.get_full_instructions(&self.config.model_family);
|
||||
let tools_json = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
let reasoning = create_reasoning_param_for_request(
|
||||
&self.config.model_family,
|
||||
self.effort,
|
||||
self.summary,
|
||||
);
|
||||
|
||||
// Request encrypted COT if we are not storing responses,
|
||||
// otherwise reasoning items will be referenced by ID
|
||||
@@ -157,15 +145,7 @@ impl ModelClient {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let mut input_with_instructions = Vec::with_capacity(prompt.input.len() + 1);
|
||||
if let Some(ui) = &prompt.user_instructions {
|
||||
input_with_instructions.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText { text: ui.clone() }],
|
||||
});
|
||||
}
|
||||
input_with_instructions.extend(prompt.input.clone());
|
||||
let input_with_instructions = prompt.get_formatted_input();
|
||||
|
||||
let payload = ResponsesApiRequest {
|
||||
model: &self.config.model,
|
||||
@@ -180,34 +160,42 @@ impl ModelClient {
|
||||
include,
|
||||
};
|
||||
|
||||
trace!(
|
||||
"POST to {}: {}",
|
||||
self.provider.get_full_url(),
|
||||
serde_json::to_string(&payload)?
|
||||
);
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = self.provider.request_max_retries();
|
||||
|
||||
trace!(
|
||||
"POST to {}: {}",
|
||||
self.provider.get_full_url(&auth),
|
||||
serde_json::to_string(&payload)?
|
||||
);
|
||||
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
let mut req_builder = self
|
||||
.client
|
||||
.post(format!("{base_url}/responses"))
|
||||
.provider
|
||||
.create_request_builder(&self.client, &auth)
|
||||
.await?;
|
||||
|
||||
req_builder = req_builder
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
.header("session_id", self.session_id.to_string())
|
||||
.bearer_auth(&token)
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload);
|
||||
|
||||
if auth.mode == AuthMode::ChatGPT {
|
||||
if let Some(account_id) = auth.get_account_id().await {
|
||||
req_builder = req_builder.header("chatgpt-account-id", account_id);
|
||||
}
|
||||
if let Some(auth) = auth.as_ref()
|
||||
&& auth.mode == AuthMode::ChatGPT
|
||||
&& let Some(account_id) = auth.get_account_id().await
|
||||
{
|
||||
req_builder = req_builder.header("chatgpt-account-id", account_id);
|
||||
}
|
||||
|
||||
let req_builder = self.provider.apply_http_headers(req_builder);
|
||||
let originator = self
|
||||
.config
|
||||
.internal_originator
|
||||
.as_deref()
|
||||
.unwrap_or("codex_cli_rs");
|
||||
req_builder = req_builder.header("originator", originator);
|
||||
|
||||
let res = req_builder.send().await;
|
||||
if let Ok(resp) = &res {
|
||||
@@ -439,6 +427,14 @@ async fn process_sse<S>(
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.reasoning_text.delta" => {
|
||||
if let Some(delta) = event.delta {
|
||||
let event = ResponseEvent::ReasoningContentDelta(delta);
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.created" => {
|
||||
if event.response.is_some() {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
|
||||
@@ -627,7 +623,7 @@ mod tests {
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
requires_auth: false,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let events = collect_events(
|
||||
@@ -687,7 +683,7 @@ mod tests {
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
requires_auth: false,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()], provider).await;
|
||||
@@ -790,7 +786,7 @@ mod tests {
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
requires_auth: false,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let out = run_sse(evs, provider).await;
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use crate::error::Result;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::openai_tools::OpenAiTool;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::protocol::TokenUsage;
|
||||
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
|
||||
use futures::Stream;
|
||||
use serde::Serialize;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::path::PathBuf;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
@@ -17,6 +23,47 @@ use tokio::sync::mpsc;
|
||||
/// with this content.
|
||||
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
|
||||
|
||||
/// wraps environment context message in a tag for the model to parse more easily.
|
||||
const ENVIRONMENT_CONTEXT_START: &str = "<environment_context>\n\n";
|
||||
const ENVIRONMENT_CONTEXT_END: &str = "\n\n</environment_context>";
|
||||
|
||||
/// wraps user instructions message in a tag for the model to parse more easily.
|
||||
const USER_INSTRUCTIONS_START: &str = "<user_instructions>\n\n";
|
||||
const USER_INSTRUCTIONS_END: &str = "\n\n</user_instructions>";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct EnvironmentContext {
|
||||
pub cwd: PathBuf,
|
||||
pub approval_policy: AskForApproval,
|
||||
pub sandbox_policy: SandboxPolicy,
|
||||
}
|
||||
|
||||
impl Display for EnvironmentContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"Current working directory: {}",
|
||||
self.cwd.to_string_lossy()
|
||||
)?;
|
||||
writeln!(f, "Approval policy: {}", self.approval_policy)?;
|
||||
writeln!(f, "Sandbox policy: {}", self.sandbox_policy)?;
|
||||
|
||||
let network_access = match self.sandbox_policy.clone() {
|
||||
SandboxPolicy::DangerFullAccess => "enabled",
|
||||
SandboxPolicy::ReadOnly => "restricted",
|
||||
SandboxPolicy::WorkspaceWrite { network_access, .. } => {
|
||||
if network_access {
|
||||
"enabled"
|
||||
} else {
|
||||
"restricted"
|
||||
}
|
||||
}
|
||||
};
|
||||
writeln!(f, "Network access: {network_access}")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// API request payload for a single model turn.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Prompt {
|
||||
@@ -28,27 +75,62 @@ pub struct Prompt {
|
||||
/// Whether to store response on server side (disable_response_storage = !store).
|
||||
pub store: bool,
|
||||
|
||||
/// Additional tools sourced from external MCP servers. Note each key is
|
||||
/// the "fully qualified" tool name (i.e., prefixed with the server name),
|
||||
/// which should be reported to the model in place of Tool::name.
|
||||
pub extra_tools: HashMap<String, mcp_types::Tool>,
|
||||
/// A list of key-value pairs that will be added as a developer message
|
||||
/// for the model to use
|
||||
pub environment_context: Option<EnvironmentContext>,
|
||||
|
||||
/// Tools available to the model, including additional tools sourced from
|
||||
/// external MCP servers.
|
||||
pub tools: Vec<OpenAiTool>,
|
||||
|
||||
/// Optional override for the built-in BASE_INSTRUCTIONS.
|
||||
pub base_instructions_override: Option<String>,
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) fn get_full_instructions(&self, model: &str) -> Cow<'_, str> {
|
||||
pub(crate) fn get_full_instructions(&self, model: &ModelFamily) -> Cow<'_, str> {
|
||||
let base = self
|
||||
.base_instructions_override
|
||||
.as_deref()
|
||||
.unwrap_or(BASE_INSTRUCTIONS);
|
||||
let mut sections: Vec<&str> = vec![base];
|
||||
if model.starts_with("gpt-4.1") {
|
||||
if model.needs_special_apply_patch_instructions {
|
||||
sections.push(APPLY_PATCH_TOOL_INSTRUCTIONS);
|
||||
}
|
||||
Cow::Owned(sections.join("\n"))
|
||||
}
|
||||
|
||||
fn get_formatted_user_instructions(&self) -> Option<String> {
|
||||
self.user_instructions
|
||||
.as_ref()
|
||||
.map(|ui| format!("{USER_INSTRUCTIONS_START}{ui}{USER_INSTRUCTIONS_END}"))
|
||||
}
|
||||
|
||||
fn get_formatted_environment_context(&self) -> Option<String> {
|
||||
self.environment_context
|
||||
.as_ref()
|
||||
.map(|ec| format!("{ENVIRONMENT_CONTEXT_START}{ec}{ENVIRONMENT_CONTEXT_END}"))
|
||||
}
|
||||
|
||||
pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
|
||||
let mut input_with_instructions = Vec::with_capacity(self.input.len() + 2);
|
||||
if let Some(ec) = self.get_formatted_environment_context() {
|
||||
input_with_instructions.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText { text: ec }],
|
||||
});
|
||||
}
|
||||
if let Some(ui) = self.get_formatted_user_instructions() {
|
||||
input_with_instructions.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText { text: ui }],
|
||||
});
|
||||
}
|
||||
input_with_instructions.extend(self.input.clone());
|
||||
input_with_instructions
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -61,6 +143,7 @@ pub enum ResponseEvent {
|
||||
},
|
||||
OutputTextDelta(String),
|
||||
ReasoningSummaryDelta(String),
|
||||
ReasoningContentDelta(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -134,14 +217,12 @@ pub(crate) struct ResponsesApiRequest<'a> {
|
||||
pub(crate) include: Vec<String>,
|
||||
}
|
||||
|
||||
use crate::config::Config;
|
||||
|
||||
pub(crate) fn create_reasoning_param_for_request(
|
||||
config: &Config,
|
||||
model_family: &ModelFamily,
|
||||
effort: ReasoningEffortConfig,
|
||||
summary: ReasoningSummaryConfig,
|
||||
) -> Option<Reasoning> {
|
||||
if model_supports_reasoning_summaries(config) {
|
||||
if model_family.supports_reasoning_summaries {
|
||||
let effort: Option<OpenAiReasoningEffort> = effort.into();
|
||||
let effort = effort?;
|
||||
Some(Reasoning {
|
||||
@@ -153,27 +234,6 @@ pub(crate) fn create_reasoning_param_for_request(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn model_supports_reasoning_summaries(config: &Config) -> bool {
|
||||
// Currently, we hardcode this rule to decide whether to enable reasoning.
|
||||
// We expect reasoning to apply only to OpenAI models, but we do not want
|
||||
// users to have to mess with their config to disable reasoning for models
|
||||
// that do not support it, such as `gpt-4.1`.
|
||||
//
|
||||
// Though if a user is using Codex with non-OpenAI models that, say, happen
|
||||
// to start with "o", then they can set `model_reasoning_effort = "none"` in
|
||||
// config.toml to disable reasoning.
|
||||
//
|
||||
// Converseley, if a user has a non-OpenAI provider that supports reasoning,
|
||||
// they can set the top-level `model_supports_reasoning_summaries = true`
|
||||
// config option to enable reasoning.
|
||||
if config.model_supports_reasoning_summaries {
|
||||
return true;
|
||||
}
|
||||
|
||||
let model = &config.model;
|
||||
model.starts_with("o") || model.starts_with("codex")
|
||||
}
|
||||
|
||||
pub(crate) struct ResponseStream {
|
||||
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
|
||||
}
|
||||
@@ -188,6 +248,9 @@ impl Stream for ResponseStream {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used)]
|
||||
use crate::model_family::find_family_for_model;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@@ -197,7 +260,8 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
let expected = format!("{BASE_INSTRUCTIONS}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}");
|
||||
let full = prompt.get_full_instructions("gpt-4.1");
|
||||
let model_family = find_family_for_model("gpt-4.1").expect("known model slug");
|
||||
let full = prompt.get_full_instructions(&model_family);
|
||||
assert_eq!(full, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ use crate::apply_patch::convert_apply_patch_to_protocol;
|
||||
use crate::apply_patch::get_writable_roots;
|
||||
use crate::apply_patch::{self};
|
||||
use crate::client::ModelClient;
|
||||
use crate::client_common::EnvironmentContext;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
use crate::config::Config;
|
||||
@@ -48,6 +49,7 @@ use crate::error::SandboxErr;
|
||||
use crate::exec::ExecParams;
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::exec::StdoutStream;
|
||||
use crate::exec::process_exec_tool_call;
|
||||
use crate::exec_env::create_env;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
@@ -55,16 +57,21 @@ use crate::mcp_tool_call::handle_mcp_tool_call;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::FunctionCallOutputPayload;
|
||||
use crate::models::LocalShellAction;
|
||||
use crate::models::ReasoningItemContent;
|
||||
use crate::models::ReasoningItemReasoningSummary;
|
||||
use crate::models::ResponseInputItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::models::ShellToolCallParams;
|
||||
use crate::openai_tools::ToolsConfig;
|
||||
use crate::openai_tools::get_openai_tools;
|
||||
use crate::plan_tool::handle_update_plan;
|
||||
use crate::project_doc::get_user_instructions;
|
||||
use crate::protocol::AgentMessageDeltaEvent;
|
||||
use crate::protocol::AgentMessageEvent;
|
||||
use crate::protocol::AgentReasoningDeltaEvent;
|
||||
use crate::protocol::AgentReasoningEvent;
|
||||
use crate::protocol::AgentReasoningRawContentDeltaEvent;
|
||||
use crate::protocol::AgentReasoningRawContentEvent;
|
||||
use crate::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::BackgroundEventEvent;
|
||||
@@ -84,11 +91,13 @@ use crate::protocol::SandboxPolicy;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::protocol::Submission;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TurnDiffEvent;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::safety::SafetyCheck;
|
||||
use crate::safety::assess_command_safety;
|
||||
use crate::safety::assess_safety_for_untrusted_command;
|
||||
use crate::shell;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::user_notification::UserNotification;
|
||||
use crate::util::backoff;
|
||||
|
||||
@@ -120,7 +129,7 @@ impl Codex {
|
||||
let resume_path = config.experimental_resume.clone();
|
||||
info!("resume_path: {resume_path:?}");
|
||||
let (tx_sub, rx_sub) = async_channel::bounded(64);
|
||||
let (tx_event, rx_event) = async_channel::bounded(1600);
|
||||
let (tx_event, rx_event) = async_channel::unbounded();
|
||||
|
||||
let user_instructions = get_user_instructions(&config).await;
|
||||
|
||||
@@ -210,6 +219,7 @@ pub(crate) struct Session {
|
||||
shell_environment_policy: ShellEnvironmentPolicy,
|
||||
pub(crate) writable_roots: Mutex<Vec<PathBuf>>,
|
||||
disable_response_storage: bool,
|
||||
tools_config: ToolsConfig,
|
||||
|
||||
/// Manager for external MCP servers/tools.
|
||||
mcp_connection_manager: McpConnectionManager,
|
||||
@@ -224,6 +234,7 @@ pub(crate) struct Session {
|
||||
state: Mutex<State>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
user_shell: shell::Shell,
|
||||
show_raw_agent_reasoning: bool,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -361,7 +372,11 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
async fn notify_exec_command_begin(&self, exec_command_context: ExecCommandContext) {
|
||||
async fn on_exec_command_begin(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
exec_command_context: ExecCommandContext,
|
||||
) {
|
||||
let ExecCommandContext {
|
||||
sub_id,
|
||||
call_id,
|
||||
@@ -373,11 +388,15 @@ impl Session {
|
||||
Some(ApplyPatchCommandContext {
|
||||
user_explicitly_approved_this_action,
|
||||
changes,
|
||||
}) => EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved: !user_explicitly_approved_this_action,
|
||||
changes,
|
||||
}),
|
||||
}) => {
|
||||
turn_diff_tracker.on_patch_begin(&changes);
|
||||
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved: !user_explicitly_approved_this_action,
|
||||
changes,
|
||||
})
|
||||
}
|
||||
None => EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id,
|
||||
command: command_for_display.clone(),
|
||||
@@ -391,15 +410,21 @@ impl Session {
|
||||
let _ = self.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
async fn notify_exec_command_end(
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn on_exec_command_end(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
call_id: &str,
|
||||
stdout: &str,
|
||||
stderr: &str,
|
||||
exit_code: i32,
|
||||
output: &ExecToolCallOutput,
|
||||
is_apply_patch: bool,
|
||||
) {
|
||||
let ExecToolCallOutput {
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
exit_code,
|
||||
} = output;
|
||||
// Because stdout and stderr could each be up to 100 KiB, we send
|
||||
// truncated versions.
|
||||
const MAX_STREAM_OUTPUT: usize = 5 * 1024; // 5KiB
|
||||
@@ -411,14 +436,15 @@ impl Session {
|
||||
call_id: call_id.to_string(),
|
||||
stdout,
|
||||
stderr,
|
||||
success: exit_code == 0,
|
||||
success: *exit_code == 0,
|
||||
})
|
||||
} else {
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id: call_id.to_string(),
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
duration: *duration,
|
||||
exit_code: *exit_code,
|
||||
})
|
||||
};
|
||||
|
||||
@@ -427,6 +453,20 @@ impl Session {
|
||||
msg,
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
|
||||
// If this is an apply_patch, after we emit the end patch, emit a second event
|
||||
// with the full turn diff if there is one.
|
||||
if is_apply_patch {
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
id: sub_id.into(),
|
||||
msg,
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper that emits a BackgroundEvent with the given message. This keeps
|
||||
@@ -442,6 +482,12 @@ impl Session {
|
||||
let _ = self.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
/// Build the full turn input by concatenating the current conversation
|
||||
/// history with additional items for this turn.
|
||||
pub fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
|
||||
[self.state.lock().unwrap().history.contents(), extra].concat()
|
||||
}
|
||||
|
||||
/// Returns the input if there was no task running to inject into
|
||||
pub fn inject_input(&self, input: Vec<InputItem>) -> Result<(), Vec<InputItem>> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
@@ -564,6 +610,25 @@ impl AgentTask {
|
||||
handle,
|
||||
}
|
||||
}
|
||||
fn compact(
|
||||
sess: Arc<Session>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
compact_instructions: String,
|
||||
) -> Self {
|
||||
let handle = tokio::spawn(run_compact_task(
|
||||
Arc::clone(&sess),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
compact_instructions,
|
||||
))
|
||||
.abort_handle();
|
||||
Self {
|
||||
sess,
|
||||
sub_id,
|
||||
handle,
|
||||
}
|
||||
}
|
||||
|
||||
fn abort(self) {
|
||||
if !self.handle.is_finished() {
|
||||
@@ -644,7 +709,7 @@ async fn submission_loop(
|
||||
cwd,
|
||||
resume_path,
|
||||
} => {
|
||||
info!(
|
||||
debug!(
|
||||
"Configuring session: model={model}; provider={provider:?}; resume={resume_path:?}"
|
||||
);
|
||||
if !cwd.is_absolute() {
|
||||
@@ -695,6 +760,10 @@ async fn submission_loop(
|
||||
}
|
||||
};
|
||||
|
||||
// Determine rollout path if available before moving the recorder.
|
||||
let rollout_path: Option<std::path::PathBuf> =
|
||||
rollout_recorder.as_ref().map(|r| r.path().to_path_buf());
|
||||
|
||||
let client = ModelClient::new(
|
||||
config.clone(),
|
||||
auth.clone(),
|
||||
@@ -749,6 +818,12 @@ async fn submission_loop(
|
||||
let default_shell = shell::default_user_shell().await;
|
||||
sess = Some(Arc::new(Session {
|
||||
client,
|
||||
tools_config: ToolsConfig::new(
|
||||
&config.model_family,
|
||||
approval_policy,
|
||||
sandbox_policy.clone(),
|
||||
config.include_plan_tool,
|
||||
),
|
||||
tx_event: tx_event.clone(),
|
||||
ctrl_c: Arc::clone(&ctrl_c),
|
||||
user_instructions,
|
||||
@@ -765,6 +840,7 @@ async fn submission_loop(
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
disable_response_storage,
|
||||
user_shell: default_shell,
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
}));
|
||||
|
||||
// Patch restored state into the newly created session.
|
||||
@@ -787,6 +863,7 @@ async fn submission_loop(
|
||||
model,
|
||||
history_log_id,
|
||||
history_entry_count,
|
||||
rollout_path,
|
||||
}),
|
||||
})
|
||||
.chain(mcp_connection_errors.into_iter());
|
||||
@@ -884,6 +961,31 @@ async fn submission_loop(
|
||||
}
|
||||
});
|
||||
}
|
||||
Op::Compact => {
|
||||
let sess = match sess.as_ref() {
|
||||
Some(sess) => sess,
|
||||
None => {
|
||||
send_no_session_event(sub.id).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Create a summarization request as user input
|
||||
const SUMMARIZATION_PROMPT: &str = include_str!("../../../SUMMARY.md");
|
||||
|
||||
// Attempt to inject input into current task
|
||||
if let Err(items) = sess.inject_input(vec![InputItem::Text {
|
||||
text: "Start Summarization".to_string(),
|
||||
}]) {
|
||||
let task = AgentTask::compact(
|
||||
sess.clone(),
|
||||
sub.id,
|
||||
items,
|
||||
SUMMARIZATION_PROMPT.to_string(),
|
||||
);
|
||||
sess.set_task(task);
|
||||
}
|
||||
}
|
||||
Op::Shutdown => {
|
||||
info!("Shutting down Codex instance");
|
||||
|
||||
@@ -945,11 +1047,15 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
return;
|
||||
}
|
||||
|
||||
let initial_input_for_turn = ResponseInputItem::from(input);
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
sess.record_conversation_items(&[initial_input_for_turn.clone().into()])
|
||||
.await;
|
||||
|
||||
let last_agent_message: Option<String>;
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let mut turn_diff_tracker = TurnDiffTracker::new();
|
||||
|
||||
loop {
|
||||
// Note that pending_input would be something like a message the user
|
||||
// submitted through the UI while the model was running. Though the UI
|
||||
@@ -966,8 +1072,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
// conversation history on each turn. The rollout file, however, should
|
||||
// only record the new items that originated in this turn so that it
|
||||
// represents an append-only log without duplicates.
|
||||
let turn_input: Vec<ResponseItem> =
|
||||
[sess.state.lock().unwrap().history.contents(), pending_input].concat();
|
||||
let turn_input: Vec<ResponseItem> = sess.turn_input_with_history(pending_input);
|
||||
|
||||
let turn_input_messages: Vec<String> = turn_input
|
||||
.iter()
|
||||
@@ -982,7 +1087,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
match run_turn(&sess, sub_id.clone(), turn_input).await {
|
||||
match run_turn(&sess, &mut turn_diff_tracker, sub_id.clone(), turn_input).await {
|
||||
Ok(turn_output) => {
|
||||
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
||||
let mut responses = Vec::<ResponseInputItem>::new();
|
||||
@@ -1047,6 +1152,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
ResponseItem::Reasoning {
|
||||
id,
|
||||
summary,
|
||||
content,
|
||||
encrypted_content,
|
||||
},
|
||||
None,
|
||||
@@ -1054,6 +1160,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
|
||||
id: id.clone(),
|
||||
summary: summary.clone(),
|
||||
content: content.clone(),
|
||||
encrypted_content: encrypted_content.clone(),
|
||||
});
|
||||
}
|
||||
@@ -1108,21 +1215,31 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
|
||||
async fn run_turn(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
let extra_tools = sess.mcp_connection_manager.list_all_tools();
|
||||
let tools = get_openai_tools(
|
||||
&sess.tools_config,
|
||||
Some(sess.mcp_connection_manager.list_all_tools()),
|
||||
);
|
||||
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
user_instructions: sess.user_instructions.clone(),
|
||||
store: !sess.disable_response_storage,
|
||||
extra_tools,
|
||||
tools,
|
||||
base_instructions_override: sess.base_instructions.clone(),
|
||||
environment_context: Some(EnvironmentContext {
|
||||
cwd: sess.cwd.clone(),
|
||||
approval_policy: sess.approval_policy,
|
||||
sandbox_policy: sess.sandbox_policy.clone(),
|
||||
}),
|
||||
};
|
||||
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match try_run_turn(sess, &sub_id, &prompt).await {
|
||||
match try_run_turn(sess, turn_diff_tracker, &sub_id, &prompt).await {
|
||||
Ok(output) => return Ok(output),
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
@@ -1168,6 +1285,7 @@ struct ProcessedResponseItem {
|
||||
|
||||
async fn try_run_turn(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
@@ -1255,7 +1373,8 @@ async fn try_run_turn(
|
||||
match event {
|
||||
ResponseEvent::Created => {}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
let response = handle_response_item(sess, sub_id, item.clone()).await?;
|
||||
let response =
|
||||
handle_response_item(sess, turn_diff_tracker, sub_id, item.clone()).await?;
|
||||
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
@@ -1273,9 +1392,24 @@ async fn try_run_turn(
|
||||
.ok();
|
||||
}
|
||||
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
let _ = sess.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
return Ok(output);
|
||||
}
|
||||
ResponseEvent::OutputTextDelta(delta) => {
|
||||
{
|
||||
let mut st = sess.state.lock().unwrap();
|
||||
st.history.append_assistant_text(&delta);
|
||||
}
|
||||
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
|
||||
@@ -1289,12 +1423,107 @@ async fn try_run_turn(
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
}
|
||||
ResponseEvent::ReasoningContentDelta(delta) => {
|
||||
if sess.show_raw_agent_reasoning {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningRawContentDelta(
|
||||
AgentReasoningRawContentDeltaEvent { delta },
|
||||
),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_compact_task(
|
||||
sess: Arc<Session>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
compact_instructions: String,
|
||||
) {
|
||||
let start_event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskStarted,
|
||||
};
|
||||
if sess.tx_event.send(start_event).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
let turn_input: Vec<ResponseItem> =
|
||||
sess.turn_input_with_history(vec![initial_input_for_turn.clone().into()]);
|
||||
|
||||
let prompt = Prompt {
|
||||
input: turn_input,
|
||||
user_instructions: None,
|
||||
store: !sess.disable_response_storage,
|
||||
environment_context: None,
|
||||
tools: Vec::new(),
|
||||
base_instructions_override: Some(compact_instructions.clone()),
|
||||
};
|
||||
|
||||
let max_retries = sess.client.get_provider().stream_max_retries();
|
||||
let mut retries = 0;
|
||||
|
||||
loop {
|
||||
let attempt_result = drain_to_completed(&sess, &sub_id, &prompt).await;
|
||||
|
||||
match attempt_result {
|
||||
Ok(()) => break,
|
||||
Err(CodexErr::Interrupted) => return,
|
||||
Err(e) => {
|
||||
if retries < max_retries {
|
||||
retries += 1;
|
||||
let delay = backoff(retries);
|
||||
sess.notify_background_event(
|
||||
&sub_id,
|
||||
format!(
|
||||
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…"
|
||||
),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(delay).await;
|
||||
continue;
|
||||
} else {
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sess.remove_task(&sub_id);
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "Compact task completed".to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: None,
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.history.keep_last_messages(1);
|
||||
}
|
||||
|
||||
async fn handle_response_item(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
item: ResponseItem,
|
||||
) -> CodexResult<Option<ResponseInputItem>> {
|
||||
@@ -1312,7 +1541,12 @@ async fn handle_response_item(
|
||||
}
|
||||
None
|
||||
}
|
||||
ResponseItem::Reasoning { summary, .. } => {
|
||||
ResponseItem::Reasoning {
|
||||
id: _,
|
||||
summary,
|
||||
content,
|
||||
encrypted_content: _,
|
||||
} => {
|
||||
for item in summary {
|
||||
let text = match item {
|
||||
ReasoningItemReasoningSummary::SummaryText { text } => text,
|
||||
@@ -1323,6 +1557,21 @@ async fn handle_response_item(
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
}
|
||||
if sess.show_raw_agent_reasoning && content.is_some() {
|
||||
let content = content.unwrap();
|
||||
for item in content {
|
||||
let text = match item {
|
||||
ReasoningItemContent::ReasoningText { text } => text,
|
||||
};
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningRawContent(AgentReasoningRawContentEvent {
|
||||
text,
|
||||
}),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
ResponseItem::FunctionCall {
|
||||
@@ -1332,7 +1581,17 @@ async fn handle_response_item(
|
||||
..
|
||||
} => {
|
||||
info!("FunctionCall: {arguments}");
|
||||
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
|
||||
Some(
|
||||
handle_function_call(
|
||||
sess,
|
||||
turn_diff_tracker,
|
||||
sub_id.to_string(),
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
@@ -1346,6 +1605,8 @@ async fn handle_response_item(
|
||||
command: action.command,
|
||||
workdir: action.working_directory,
|
||||
timeout_ms: action.timeout_ms,
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
let effective_call_id = match (call_id, id) {
|
||||
(Some(call_id), _) => call_id,
|
||||
@@ -1367,6 +1628,7 @@ async fn handle_response_item(
|
||||
handle_container_exec_with_params(
|
||||
exec_params,
|
||||
sess,
|
||||
turn_diff_tracker,
|
||||
sub_id.to_string(),
|
||||
effective_call_id,
|
||||
)
|
||||
@@ -1384,6 +1646,7 @@ async fn handle_response_item(
|
||||
|
||||
async fn handle_function_call(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
@@ -1397,7 +1660,8 @@ async fn handle_function_call(
|
||||
return *output;
|
||||
}
|
||||
};
|
||||
handle_container_exec_with_params(params, sess, sub_id, call_id).await
|
||||
handle_container_exec_with_params(params, sess, turn_diff_tracker, sub_id, call_id)
|
||||
.await
|
||||
}
|
||||
"update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
|
||||
_ => {
|
||||
@@ -1431,6 +1695,8 @@ fn to_exec_params(params: ShellToolCallParams, sess: &Session) -> ExecParams {
|
||||
cwd: sess.resolve_path(params.workdir.clone()),
|
||||
timeout_ms: params.timeout_ms,
|
||||
env: create_env(&sess.shell_environment_policy),
|
||||
with_escalated_permissions: params.with_escalated_permissions,
|
||||
justification: params.justification,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1471,6 +1737,7 @@ fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams
|
||||
async fn handle_container_exec_with_params(
|
||||
params: ExecParams,
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
) -> ResponseInputItem {
|
||||
@@ -1530,13 +1797,19 @@ async fn handle_container_exec_with_params(
|
||||
cwd: cwd.clone(),
|
||||
timeout_ms: params.timeout_ms,
|
||||
env: HashMap::new(),
|
||||
with_escalated_permissions: params.with_escalated_permissions,
|
||||
justification: params.justification.clone(),
|
||||
};
|
||||
let safety = if *user_explicitly_approved_this_action {
|
||||
SafetyCheck::AutoApprove {
|
||||
sandbox_type: SandboxType::None,
|
||||
}
|
||||
} else {
|
||||
assess_safety_for_untrusted_command(sess.approval_policy, &sess.sandbox_policy)
|
||||
assess_safety_for_untrusted_command(
|
||||
sess.approval_policy,
|
||||
&sess.sandbox_policy,
|
||||
params.with_escalated_permissions.unwrap_or(false),
|
||||
)
|
||||
};
|
||||
(
|
||||
params,
|
||||
@@ -1552,6 +1825,7 @@ async fn handle_container_exec_with_params(
|
||||
sess.approval_policy,
|
||||
&sess.sandbox_policy,
|
||||
&state.approved_commands,
|
||||
params.with_escalated_permissions.unwrap_or(false),
|
||||
)
|
||||
};
|
||||
let command_for_display = params.command.clone();
|
||||
@@ -1568,7 +1842,7 @@ async fn handle_container_exec_with_params(
|
||||
call_id.clone(),
|
||||
params.command.clone(),
|
||||
params.cwd.clone(),
|
||||
None,
|
||||
params.justification.clone(),
|
||||
)
|
||||
.await;
|
||||
match rx_approve.await.unwrap_or_default() {
|
||||
@@ -1618,7 +1892,7 @@ async fn handle_container_exec_with_params(
|
||||
},
|
||||
),
|
||||
};
|
||||
sess.notify_exec_command_begin(exec_command_context.clone())
|
||||
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context.clone())
|
||||
.await;
|
||||
|
||||
let params = maybe_run_with_user_profile(params, sess);
|
||||
@@ -1628,6 +1902,11 @@ async fn handle_container_exec_with_params(
|
||||
sess.ctrl_c.clone(),
|
||||
&sess.sandbox_policy,
|
||||
&sess.codex_linux_sandbox_exe,
|
||||
Some(StdoutStream {
|
||||
sub_id: sub_id.clone(),
|
||||
call_id: call_id.clone(),
|
||||
tx_event: sess.tx_event.clone(),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1638,23 +1917,22 @@ async fn handle_container_exec_with_params(
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
} = output;
|
||||
} = &output;
|
||||
|
||||
sess.notify_exec_command_end(
|
||||
sess.on_exec_command_end(
|
||||
turn_diff_tracker,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
&stdout,
|
||||
&stderr,
|
||||
exit_code,
|
||||
&output,
|
||||
exec_command_context.apply_patch.is_some(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let is_success = exit_code == 0;
|
||||
let is_success = *exit_code == 0;
|
||||
let content = format_exec_output(
|
||||
if is_success { &stdout } else { &stderr },
|
||||
exit_code,
|
||||
duration,
|
||||
if is_success { stdout } else { stderr },
|
||||
*exit_code,
|
||||
*duration,
|
||||
);
|
||||
|
||||
ResponseInputItem::FunctionCallOutput {
|
||||
@@ -1666,7 +1944,15 @@ async fn handle_container_exec_with_params(
|
||||
}
|
||||
}
|
||||
Err(CodexErr::Sandbox(error)) => {
|
||||
handle_sandbox_error(params, exec_command_context, error, sandbox_type, sess).await
|
||||
handle_sandbox_error(
|
||||
turn_diff_tracker,
|
||||
params,
|
||||
exec_command_context,
|
||||
error,
|
||||
sandbox_type,
|
||||
sess,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Err(e) => {
|
||||
// Handle non-sandbox errors
|
||||
@@ -1682,6 +1968,7 @@ async fn handle_container_exec_with_params(
|
||||
}
|
||||
|
||||
async fn handle_sandbox_error(
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
params: ExecParams,
|
||||
exec_command_context: ExecCommandContext,
|
||||
error: SandboxErr,
|
||||
@@ -1693,13 +1980,31 @@ async fn handle_sandbox_error(
|
||||
let cwd = exec_command_context.cwd.clone();
|
||||
let is_apply_patch = exec_command_context.apply_patch.is_some();
|
||||
|
||||
// Early out if the user never wants to be asked for approval; just return to the model immediately
|
||||
if sess.approval_policy == AskForApproval::Never {
|
||||
// Early out if either the user never wants to be asked for approval, or
|
||||
// we're letting the model manage escalation requests. Otherwise, continue
|
||||
match sess.approval_policy {
|
||||
AskForApproval::Never | AskForApproval::OnRequest => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!(
|
||||
"failed in sandbox {sandbox_type:?} with execution error: {error}"
|
||||
),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
AskForApproval::UnlessTrusted | AskForApproval::OnFailure => (),
|
||||
}
|
||||
|
||||
// similarly, if the command timed out, we can simply return this failure to the model
|
||||
if matches!(error, SandboxErr::Timeout) {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!(
|
||||
"failed in sandbox {sandbox_type:?} with execution error: {error}"
|
||||
"command timed out after {} milliseconds",
|
||||
params.timeout_duration().as_millis()
|
||||
),
|
||||
success: Some(false),
|
||||
},
|
||||
@@ -1713,7 +2018,8 @@ async fn handle_sandbox_error(
|
||||
// include additional metadata on the command to indicate whether non-zero
|
||||
// exit codes merit a retry.
|
||||
|
||||
// For now, we categorically ask the user to retry without sandbox.
|
||||
// For now, we categorically ask the user to retry without sandbox and
|
||||
// emit the raw error as a background event.
|
||||
sess.notify_background_event(&sub_id, format!("Execution failed: {error}"))
|
||||
.await;
|
||||
|
||||
@@ -1738,7 +2044,8 @@ async fn handle_sandbox_error(
|
||||
sess.notify_background_event(&sub_id, "retrying command without sandbox")
|
||||
.await;
|
||||
|
||||
sess.notify_exec_command_begin(exec_command_context).await;
|
||||
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context)
|
||||
.await;
|
||||
|
||||
// This is an escalated retry; the policy will not be
|
||||
// examined and the sandbox has been set to `None`.
|
||||
@@ -1748,6 +2055,11 @@ async fn handle_sandbox_error(
|
||||
sess.ctrl_c.clone(),
|
||||
&sess.sandbox_policy,
|
||||
&sess.codex_linux_sandbox_exe,
|
||||
Some(StdoutStream {
|
||||
sub_id: sub_id.clone(),
|
||||
call_id: call_id.clone(),
|
||||
tx_event: sess.tx_event.clone(),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1758,23 +2070,22 @@ async fn handle_sandbox_error(
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
} = retry_output;
|
||||
} = &retry_output;
|
||||
|
||||
sess.notify_exec_command_end(
|
||||
sess.on_exec_command_end(
|
||||
turn_diff_tracker,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
&stdout,
|
||||
&stderr,
|
||||
exit_code,
|
||||
&retry_output,
|
||||
is_apply_patch,
|
||||
)
|
||||
.await;
|
||||
|
||||
let is_success = exit_code == 0;
|
||||
let is_success = *exit_code == 0;
|
||||
let content = format_exec_output(
|
||||
if is_success { &stdout } else { &stderr },
|
||||
exit_code,
|
||||
duration,
|
||||
if is_success { stdout } else { stderr },
|
||||
*exit_code,
|
||||
*duration,
|
||||
);
|
||||
|
||||
ResponseInputItem::FunctionCallOutput {
|
||||
@@ -1858,3 +2169,45 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn drain_to_completed(sess: &Session, sub_id: &str, prompt: &Prompt) -> CodexResult<()> {
|
||||
let mut stream = sess.client.clone().stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
let Some(event) = maybe_event else {
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
));
|
||||
};
|
||||
match event {
|
||||
Ok(ResponseEvent::OutputItemDone(item)) => {
|
||||
// Record only to in-memory conversation history; avoid state snapshot.
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.history.record_items(std::slice::from_ref(&item));
|
||||
}
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id: _,
|
||||
token_usage,
|
||||
}) => {
|
||||
let token_usage = match token_usage {
|
||||
Some(usage) => usage,
|
||||
None => {
|
||||
return Err(CodexErr::Stream(
|
||||
"token_usage was None in ResponseEvent::Completed".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
sess.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::TokenCount(token_usage),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return Ok(());
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ use crate::config_types::ShellEnvironmentPolicyToml;
|
||||
use crate::config_types::Tui;
|
||||
use crate::config_types::UriBasedFileOpener;
|
||||
use crate::flags::OPENAI_DEFAULT_MODEL;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::model_family::find_family_for_model;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::built_in_model_providers;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
@@ -33,6 +35,8 @@ pub struct Config {
|
||||
/// Optional override of model selection.
|
||||
pub model: String,
|
||||
|
||||
pub model_family: ModelFamily,
|
||||
|
||||
/// Size of the context window for the model, in tokens.
|
||||
pub model_context_window: Option<u64>,
|
||||
|
||||
@@ -57,12 +61,16 @@ pub struct Config {
|
||||
/// users are only interested in the final agent responses.
|
||||
pub hide_agent_reasoning: bool,
|
||||
|
||||
/// When set to `true`, `AgentReasoningRawContentEvent` events will be shown in the UI/output.
|
||||
/// Defaults to `false`.
|
||||
pub show_raw_agent_reasoning: bool,
|
||||
|
||||
/// Disable server-side response storage (sends the full conversation
|
||||
/// context with every request). Currently necessary for OpenAI customers
|
||||
/// who have opted into Zero Data Retention (ZDR).
|
||||
pub disable_response_storage: bool,
|
||||
|
||||
/// User-provided instructions from instructions.md.
|
||||
/// User-provided instructions from AGENTS.md.
|
||||
pub user_instructions: Option<String>,
|
||||
|
||||
/// Base instructions override.
|
||||
@@ -134,10 +142,6 @@ pub struct Config {
|
||||
/// request using the Responses API.
|
||||
pub model_reasoning_summary: ReasoningSummary,
|
||||
|
||||
/// When set to `true`, overrides the default heuristic and forces
|
||||
/// `model_supports_reasoning_summaries()` to return `true`.
|
||||
pub model_supports_reasoning_summaries: bool,
|
||||
|
||||
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
|
||||
pub chatgpt_base_url: String,
|
||||
|
||||
@@ -146,6 +150,9 @@ pub struct Config {
|
||||
|
||||
/// Include an experimental plan tool that the model can use to update its current plan and status of each step.
|
||||
pub include_plan_tool: bool,
|
||||
|
||||
/// The value for the `originator` header included with Responses API requests.
|
||||
pub internal_originator: Option<String>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -322,6 +329,10 @@ pub struct ConfigToml {
|
||||
/// UI/output. Defaults to `false`.
|
||||
pub hide_agent_reasoning: Option<bool>,
|
||||
|
||||
/// When set to `true`, `AgentReasoningRawContentEvent` events will be shown in the UI/output.
|
||||
/// Defaults to `false`.
|
||||
pub show_raw_agent_reasoning: Option<bool>,
|
||||
|
||||
pub model_reasoning_effort: Option<ReasoningEffort>,
|
||||
pub model_reasoning_summary: Option<ReasoningSummary>,
|
||||
|
||||
@@ -336,6 +347,9 @@ pub struct ConfigToml {
|
||||
|
||||
/// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
|
||||
/// The value for the `originator` header included with Responses API requests.
|
||||
pub internal_originator: Option<String>,
|
||||
}
|
||||
|
||||
impl ConfigToml {
|
||||
@@ -350,6 +364,7 @@ impl ConfigToml {
|
||||
Some(s) => SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: s.writable_roots.clone(),
|
||||
network_access: s.network_access,
|
||||
include_default_writable_roots: true,
|
||||
},
|
||||
None => SandboxPolicy::new_workspace_write_policy(),
|
||||
},
|
||||
@@ -370,6 +385,8 @@ pub struct ConfigOverrides {
|
||||
pub codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub base_instructions: Option<String>,
|
||||
pub include_plan_tool: Option<bool>,
|
||||
pub disable_response_storage: Option<bool>,
|
||||
pub show_raw_agent_reasoning: Option<bool>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -393,6 +410,8 @@ impl Config {
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions,
|
||||
include_plan_tool,
|
||||
disable_response_storage,
|
||||
show_raw_agent_reasoning,
|
||||
} = overrides;
|
||||
|
||||
let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) {
|
||||
@@ -458,7 +477,19 @@ impl Config {
|
||||
.or(config_profile.model)
|
||||
.or(cfg.model)
|
||||
.unwrap_or_else(default_model);
|
||||
let openai_model_info = get_model_info(&model);
|
||||
let model_family = find_family_for_model(&model).unwrap_or_else(|| {
|
||||
let supports_reasoning_summaries =
|
||||
cfg.model_supports_reasoning_summaries.unwrap_or(false);
|
||||
ModelFamily {
|
||||
slug: model.clone(),
|
||||
family: model.clone(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries,
|
||||
uses_local_shell_tool: false,
|
||||
}
|
||||
});
|
||||
|
||||
let openai_model_info = get_model_info(&model_family);
|
||||
let model_context_window = cfg
|
||||
.model_context_window
|
||||
.or_else(|| openai_model_info.as_ref().map(|info| info.context_window));
|
||||
@@ -473,14 +504,17 @@ impl Config {
|
||||
// Load base instructions override from a file if specified. If the
|
||||
// path is relative, resolve it against the effective cwd so the
|
||||
// behaviour matches other path-like config values.
|
||||
let file_base_instructions = Self::get_base_instructions(
|
||||
cfg.experimental_instructions_file.as_ref(),
|
||||
&resolved_cwd,
|
||||
)?;
|
||||
let experimental_instructions_path = config_profile
|
||||
.experimental_instructions_file
|
||||
.as_ref()
|
||||
.or(cfg.experimental_instructions_file.as_ref());
|
||||
let file_base_instructions =
|
||||
Self::get_base_instructions(experimental_instructions_path, &resolved_cwd)?;
|
||||
let base_instructions = base_instructions.or(file_base_instructions);
|
||||
|
||||
let config = Self {
|
||||
model,
|
||||
model_family,
|
||||
model_context_window,
|
||||
model_max_output_tokens,
|
||||
model_provider_id,
|
||||
@@ -495,6 +529,7 @@ impl Config {
|
||||
disable_response_storage: config_profile
|
||||
.disable_response_storage
|
||||
.or(cfg.disable_response_storage)
|
||||
.or(disable_response_storage)
|
||||
.unwrap_or(false),
|
||||
notify: cfg.notify,
|
||||
user_instructions,
|
||||
@@ -509,6 +544,10 @@ impl Config {
|
||||
codex_linux_sandbox_exe,
|
||||
|
||||
hide_agent_reasoning: cfg.hide_agent_reasoning.unwrap_or(false),
|
||||
show_raw_agent_reasoning: cfg
|
||||
.show_raw_agent_reasoning
|
||||
.or(show_raw_agent_reasoning)
|
||||
.unwrap_or(false),
|
||||
model_reasoning_effort: config_profile
|
||||
.model_reasoning_effort
|
||||
.or(cfg.model_reasoning_effort)
|
||||
@@ -518,10 +557,6 @@ impl Config {
|
||||
.or(cfg.model_reasoning_summary)
|
||||
.unwrap_or_default(),
|
||||
|
||||
model_supports_reasoning_summaries: cfg
|
||||
.model_supports_reasoning_summaries
|
||||
.unwrap_or(false),
|
||||
|
||||
chatgpt_base_url: config_profile
|
||||
.chatgpt_base_url
|
||||
.or(cfg.chatgpt_base_url)
|
||||
@@ -529,6 +564,7 @@ impl Config {
|
||||
|
||||
experimental_resume,
|
||||
include_plan_tool: include_plan_tool.unwrap_or(false),
|
||||
internal_originator: cfg.internal_originator,
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
@@ -539,7 +575,7 @@ impl Config {
|
||||
None => return None,
|
||||
};
|
||||
|
||||
p.push("instructions.md");
|
||||
p.push("AGENTS.md");
|
||||
std::fs::read_to_string(&p).ok().and_then(|s| {
|
||||
let s = s.trim();
|
||||
if s.is_empty() {
|
||||
@@ -720,6 +756,7 @@ writable_roots = [
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![PathBuf::from("/tmp")],
|
||||
network_access: false,
|
||||
include_default_writable_roots: true,
|
||||
},
|
||||
sandbox_workspace_write_cfg.derive_sandbox_policy(sandbox_mode_override)
|
||||
);
|
||||
@@ -805,7 +842,7 @@ disable_response_storage = true
|
||||
request_max_retries: Some(4),
|
||||
stream_max_retries: Some(10),
|
||||
stream_idle_timeout_ms: Some(300_000),
|
||||
requires_auth: false,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
let model_provider_map = {
|
||||
let mut model_provider_map = built_in_model_providers();
|
||||
@@ -860,6 +897,7 @@ disable_response_storage = true
|
||||
assert_eq!(
|
||||
Config {
|
||||
model: "o3".to_string(),
|
||||
model_family: find_family_for_model("o3").expect("known model slug"),
|
||||
model_context_window: Some(200_000),
|
||||
model_max_output_tokens: Some(100_000),
|
||||
model_provider_id: "openai".to_string(),
|
||||
@@ -880,13 +918,14 @@ disable_response_storage = true
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
model_reasoning_effort: ReasoningEffort::High,
|
||||
model_reasoning_summary: ReasoningSummary::Detailed,
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
include_plan_tool: false,
|
||||
internal_originator: None,
|
||||
},
|
||||
o3_profile_config
|
||||
);
|
||||
@@ -909,6 +948,7 @@ disable_response_storage = true
|
||||
)?;
|
||||
let expected_gpt3_profile_config = Config {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
model_family: find_family_for_model("gpt-3.5-turbo").expect("known model slug"),
|
||||
model_context_window: Some(16_385),
|
||||
model_max_output_tokens: Some(4_096),
|
||||
model_provider_id: "openai-chat-completions".to_string(),
|
||||
@@ -929,13 +969,14 @@ disable_response_storage = true
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
model_reasoning_effort: ReasoningEffort::default(),
|
||||
model_reasoning_summary: ReasoningSummary::default(),
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
include_plan_tool: false,
|
||||
internal_originator: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
|
||||
@@ -973,6 +1014,7 @@ disable_response_storage = true
|
||||
)?;
|
||||
let expected_zdr_profile_config = Config {
|
||||
model: "o3".to_string(),
|
||||
model_family: find_family_for_model("o3").expect("known model slug"),
|
||||
model_context_window: Some(200_000),
|
||||
model_max_output_tokens: Some(100_000),
|
||||
model_provider_id: "openai".to_string(),
|
||||
@@ -993,13 +1035,14 @@ disable_response_storage = true
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
model_reasoning_effort: ReasoningEffort::default(),
|
||||
model_reasoning_summary: ReasoningSummary::default(),
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
include_plan_tool: false,
|
||||
internal_originator: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_zdr_profile_config, zdr_profile_config);
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use serde::Deserialize;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::config_types::ReasoningEffort;
|
||||
use crate::config_types::ReasoningSummary;
|
||||
@@ -17,4 +18,5 @@ pub struct ConfigProfile {
|
||||
pub model_reasoning_effort: Option<ReasoningEffort>,
|
||||
pub model_reasoning_summary: Option<ReasoningSummary>,
|
||||
pub chatgpt_base_url: Option<String>,
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
}
|
||||
|
||||
@@ -24,12 +24,83 @@ impl ConversationHistory {
|
||||
I::Item: std::ops::Deref<Target = ResponseItem>,
|
||||
{
|
||||
for item in items {
|
||||
if is_api_message(&item) {
|
||||
// Note agent-loop.ts also does filtering on some of the fields.
|
||||
self.items.push(item.clone());
|
||||
if !is_api_message(&item) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Merge adjacent assistant messages into a single history entry.
|
||||
// This prevents duplicates when a partial assistant message was
|
||||
// streamed into history earlier in the turn and the final full
|
||||
// message is recorded at turn end.
|
||||
match (&*item, self.items.last_mut()) {
|
||||
(
|
||||
ResponseItem::Message {
|
||||
role: new_role,
|
||||
content: new_content,
|
||||
..
|
||||
},
|
||||
Some(ResponseItem::Message {
|
||||
role: last_role,
|
||||
content: last_content,
|
||||
..
|
||||
}),
|
||||
) if new_role == "assistant" && last_role == "assistant" => {
|
||||
append_text_content(last_content, new_content);
|
||||
}
|
||||
_ => {
|
||||
self.items.push(item.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Append a text `delta` to the latest assistant message, creating a new
|
||||
/// assistant entry if none exists yet (e.g. first delta for this turn).
|
||||
pub(crate) fn append_assistant_text(&mut self, delta: &str) {
|
||||
match self.items.last_mut() {
|
||||
Some(ResponseItem::Message { role, content, .. }) if role == "assistant" => {
|
||||
append_text_delta(content, delta);
|
||||
}
|
||||
_ => {
|
||||
// Start a new assistant message with the delta.
|
||||
self.items.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![crate::models::ContentItem::OutputText {
|
||||
text: delta.to_string(),
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn keep_last_messages(&mut self, n: usize) {
|
||||
if n == 0 {
|
||||
self.items.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect the last N message items (assistant/user), newest to oldest.
|
||||
let mut kept: Vec<ResponseItem> = Vec::with_capacity(n);
|
||||
for item in self.items.iter().rev() {
|
||||
if let ResponseItem::Message { role, content, .. } = item {
|
||||
kept.push(ResponseItem::Message {
|
||||
// we need to remove the id or the model will complain that messages are sent without
|
||||
// their reasonings
|
||||
id: None,
|
||||
role: role.clone(),
|
||||
content: content.clone(),
|
||||
});
|
||||
if kept.len() == n {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Preserve chronological order (oldest to newest) within the kept slice.
|
||||
kept.reverse();
|
||||
self.items = kept;
|
||||
}
|
||||
}
|
||||
|
||||
/// Anything that is not a system message or "reasoning" message is considered
|
||||
@@ -44,3 +115,140 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
||||
ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to append the textual content from `src` into `dst` in place.
|
||||
fn append_text_content(
|
||||
dst: &mut Vec<crate::models::ContentItem>,
|
||||
src: &Vec<crate::models::ContentItem>,
|
||||
) {
|
||||
for c in src {
|
||||
if let crate::models::ContentItem::OutputText { text } = c {
|
||||
append_text_delta(dst, text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Append a single text delta to the last OutputText item in `content`, or
|
||||
/// push a new OutputText item if none exists.
|
||||
fn append_text_delta(content: &mut Vec<crate::models::ContentItem>, delta: &str) {
|
||||
if let Some(crate::models::ContentItem::OutputText { text }) = content
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|c| matches!(c, crate::models::ContentItem::OutputText { .. }))
|
||||
{
|
||||
text.push_str(delta);
|
||||
} else {
|
||||
content.push(crate::models::ContentItem::OutputText {
|
||||
text: delta.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::models::ContentItem;
|
||||
|
||||
fn assistant_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
fn user_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merges_adjacent_assistant_messages() {
|
||||
let mut h = ConversationHistory::default();
|
||||
let a1 = assistant_msg("Hello");
|
||||
let a2 = assistant_msg(", world!");
|
||||
h.record_items([&a1, &a2]);
|
||||
|
||||
let items = h.contents();
|
||||
assert_eq!(
|
||||
items,
|
||||
vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: "Hello, world!".to_string()
|
||||
}]
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn append_assistant_text_creates_and_appends() {
|
||||
let mut h = ConversationHistory::default();
|
||||
h.append_assistant_text("Hello");
|
||||
h.append_assistant_text(", world");
|
||||
|
||||
// Now record a final full assistant message and verify it merges.
|
||||
let final_msg = assistant_msg("!");
|
||||
h.record_items([&final_msg]);
|
||||
|
||||
let items = h.contents();
|
||||
assert_eq!(
|
||||
items,
|
||||
vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: "Hello, world!".to_string()
|
||||
}]
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filters_non_api_messages() {
|
||||
let mut h = ConversationHistory::default();
|
||||
// System message is not an API message; Other is ignored.
|
||||
let system = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "system".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: "ignored".to_string(),
|
||||
}],
|
||||
};
|
||||
h.record_items([&system, &ResponseItem::Other]);
|
||||
|
||||
// User and assistant should be retained.
|
||||
let u = user_msg("hi");
|
||||
let a = assistant_msg("hello");
|
||||
h.record_items([&u, &a]);
|
||||
|
||||
let items = h.contents();
|
||||
assert_eq!(
|
||||
items,
|
||||
vec![
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: "hi".to_string()
|
||||
}]
|
||||
},
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: "hello".to_string()
|
||||
}]
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use async_channel::Sender;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::BufReader;
|
||||
@@ -19,10 +20,15 @@ use tokio::sync::Notify;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::error::SandboxErr;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::ExecCommandOutputDeltaEvent;
|
||||
use crate::protocol::ExecOutputStream;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::seatbelt::spawn_command_under_seatbelt;
|
||||
use crate::spawn::StdioPolicy;
|
||||
use crate::spawn::spawn_child_async;
|
||||
use serde_bytes::ByteBuf;
|
||||
|
||||
// Maximum we send for each stream, which is either:
|
||||
// - 10KiB OR
|
||||
@@ -43,6 +49,14 @@ pub struct ExecParams {
|
||||
pub cwd: PathBuf,
|
||||
pub timeout_ms: Option<u64>,
|
||||
pub env: HashMap<String, String>,
|
||||
pub with_escalated_permissions: Option<bool>,
|
||||
pub justification: Option<String>,
|
||||
}
|
||||
|
||||
impl ExecParams {
|
||||
pub fn timeout_duration(&self) -> Duration {
|
||||
Duration::from_millis(self.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
@@ -56,24 +70,30 @@ pub enum SandboxType {
|
||||
LinuxSeccomp,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StdoutStream {
|
||||
pub sub_id: String,
|
||||
pub call_id: String,
|
||||
pub tx_event: Sender<Event>,
|
||||
}
|
||||
|
||||
pub async fn process_exec_tool_call(
|
||||
params: ExecParams,
|
||||
sandbox_type: SandboxType,
|
||||
ctrl_c: Arc<Notify>,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
codex_linux_sandbox_exe: &Option<PathBuf>,
|
||||
stdout_stream: Option<StdoutStream>,
|
||||
) -> Result<ExecToolCallOutput> {
|
||||
let start = Instant::now();
|
||||
|
||||
let raw_output_result: std::result::Result<RawExecToolCallOutput, CodexErr> = match sandbox_type
|
||||
{
|
||||
SandboxType::None => exec(params, sandbox_policy, ctrl_c).await,
|
||||
SandboxType::None => exec(params, sandbox_policy, ctrl_c, stdout_stream.clone()).await,
|
||||
SandboxType::MacosSeatbelt => {
|
||||
let timeout = params.timeout_duration();
|
||||
let ExecParams {
|
||||
command,
|
||||
cwd,
|
||||
timeout_ms,
|
||||
env,
|
||||
command, cwd, env, ..
|
||||
} = params;
|
||||
let child = spawn_command_under_seatbelt(
|
||||
command,
|
||||
@@ -83,14 +103,12 @@ pub async fn process_exec_tool_call(
|
||||
env,
|
||||
)
|
||||
.await?;
|
||||
consume_truncated_output(child, ctrl_c, timeout_ms).await
|
||||
consume_truncated_output(child, ctrl_c, timeout, stdout_stream.clone()).await
|
||||
}
|
||||
SandboxType::LinuxSeccomp => {
|
||||
let timeout = params.timeout_duration();
|
||||
let ExecParams {
|
||||
command,
|
||||
cwd,
|
||||
timeout_ms,
|
||||
env,
|
||||
command, cwd, env, ..
|
||||
} = params;
|
||||
|
||||
let codex_linux_sandbox_exe = codex_linux_sandbox_exe
|
||||
@@ -106,7 +124,7 @@ pub async fn process_exec_tool_call(
|
||||
)
|
||||
.await?;
|
||||
|
||||
consume_truncated_output(child, ctrl_c, timeout_ms).await
|
||||
consume_truncated_output(child, ctrl_c, timeout, stdout_stream).await
|
||||
}
|
||||
};
|
||||
let duration = start.elapsed();
|
||||
@@ -126,11 +144,7 @@ pub async fn process_exec_tool_call(
|
||||
|
||||
let exit_code = raw_output.exit_status.code().unwrap_or(-1);
|
||||
|
||||
// NOTE(ragona): This is much less restrictive than the previous check. If we exec
|
||||
// a command, and it returns anything other than success, we assume that it may have
|
||||
// been a sandboxing error and allow the user to retry. (The user of course may choose
|
||||
// not to retry, or in a non-interactive mode, would automatically reject the approval.)
|
||||
if exit_code != 0 && sandbox_type != SandboxType::None {
|
||||
if exit_code != 0 && is_likely_sandbox_denied(sandbox_type, exit_code) {
|
||||
return Err(CodexErr::Sandbox(SandboxErr::Denied(
|
||||
exit_code, stdout, stderr,
|
||||
)));
|
||||
@@ -209,6 +223,26 @@ fn create_linux_sandbox_command_args(
|
||||
linux_cmd
|
||||
}
|
||||
|
||||
/// We don't have a fully deterministic way to tell if our command failed
|
||||
/// because of the sandbox - a command in the user's zshrc file might hit an
|
||||
/// error, but the command itself might fail or succeed for other reasons.
|
||||
/// For now, we conservatively check for 'command not found' (exit code 127),
|
||||
/// and can add additional cases as necessary.
|
||||
fn is_likely_sandbox_denied(sandbox_type: SandboxType, exit_code: i32) -> bool {
|
||||
if sandbox_type == SandboxType::None {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Quick rejects: well-known non-sandbox shell exit codes
|
||||
// 127: command not found, 2: misuse of shell builtins
|
||||
if exit_code == 127 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// For all other cases, we assume the sandbox is the cause
|
||||
true
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RawExecToolCallOutput {
|
||||
pub exit_status: ExitStatus,
|
||||
@@ -225,15 +259,16 @@ pub struct ExecToolCallOutput {
|
||||
}
|
||||
|
||||
async fn exec(
|
||||
ExecParams {
|
||||
command,
|
||||
cwd,
|
||||
timeout_ms,
|
||||
env,
|
||||
}: ExecParams,
|
||||
params: ExecParams,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
ctrl_c: Arc<Notify>,
|
||||
stdout_stream: Option<StdoutStream>,
|
||||
) -> Result<RawExecToolCallOutput> {
|
||||
let timeout = params.timeout_duration();
|
||||
let ExecParams {
|
||||
command, cwd, env, ..
|
||||
} = params;
|
||||
|
||||
let (program, args) = command.split_first().ok_or_else(|| {
|
||||
CodexErr::Io(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
@@ -251,7 +286,7 @@ async fn exec(
|
||||
env,
|
||||
)
|
||||
.await?;
|
||||
consume_truncated_output(child, ctrl_c, timeout_ms).await
|
||||
consume_truncated_output(child, ctrl_c, timeout, stdout_stream).await
|
||||
}
|
||||
|
||||
/// Consumes the output of a child process, truncating it so it is suitable for
|
||||
@@ -259,7 +294,8 @@ async fn exec(
|
||||
pub(crate) async fn consume_truncated_output(
|
||||
mut child: Child,
|
||||
ctrl_c: Arc<Notify>,
|
||||
timeout_ms: Option<u64>,
|
||||
timeout: Duration,
|
||||
stdout_stream: Option<StdoutStream>,
|
||||
) -> Result<RawExecToolCallOutput> {
|
||||
// Both stdout and stderr were configured with `Stdio::piped()`
|
||||
// above, therefore `take()` should normally return `Some`. If it doesn't
|
||||
@@ -280,15 +316,18 @@ pub(crate) async fn consume_truncated_output(
|
||||
BufReader::new(stdout_reader),
|
||||
MAX_STREAM_OUTPUT,
|
||||
MAX_STREAM_OUTPUT_LINES,
|
||||
stdout_stream.clone(),
|
||||
false,
|
||||
));
|
||||
let stderr_handle = tokio::spawn(read_capped(
|
||||
BufReader::new(stderr_reader),
|
||||
MAX_STREAM_OUTPUT,
|
||||
MAX_STREAM_OUTPUT_LINES,
|
||||
stdout_stream.clone(),
|
||||
true,
|
||||
));
|
||||
|
||||
let interrupted = ctrl_c.notified();
|
||||
let timeout = Duration::from_millis(timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS));
|
||||
let exit_status = tokio::select! {
|
||||
result = tokio::time::timeout(timeout, child.wait()) => {
|
||||
match result {
|
||||
@@ -318,10 +357,12 @@ pub(crate) async fn consume_truncated_output(
|
||||
})
|
||||
}
|
||||
|
||||
async fn read_capped<R: AsyncRead + Unpin>(
|
||||
async fn read_capped<R: AsyncRead + Unpin + Send + 'static>(
|
||||
mut reader: R,
|
||||
max_output: usize,
|
||||
max_lines: usize,
|
||||
stream: Option<StdoutStream>,
|
||||
is_stderr: bool,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let mut buf = Vec::with_capacity(max_output.min(8 * 1024));
|
||||
let mut tmp = [0u8; 8192];
|
||||
@@ -335,6 +376,25 @@ async fn read_capped<R: AsyncRead + Unpin>(
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(stream) = &stream {
|
||||
let chunk = tmp[..n].to_vec();
|
||||
let msg = EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
|
||||
call_id: stream.call_id.clone(),
|
||||
stream: if is_stderr {
|
||||
ExecOutputStream::Stderr
|
||||
} else {
|
||||
ExecOutputStream::Stdout
|
||||
},
|
||||
chunk: ByteBuf::from(chunk),
|
||||
});
|
||||
let event = Event {
|
||||
id: stream.sub_id.clone(),
|
||||
msg,
|
||||
};
|
||||
#[allow(clippy::let_unit_value)]
|
||||
let _ = stream.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
// Copy into the buffer only while we still have byte and line budget.
|
||||
if remaining_bytes > 0 && remaining_lines > 0 {
|
||||
let mut copy_len = 0;
|
||||
|
||||
@@ -9,7 +9,7 @@ use tokio::time::timeout;
|
||||
/// Timeout for git commands to prevent freezing on large repositories
|
||||
const GIT_COMMAND_TIMEOUT: TokioDuration = TokioDuration::from_secs(5);
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct GitInfo {
|
||||
/// Current commit hash (SHA)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
||||
@@ -28,9 +28,12 @@ mod mcp_connection_manager;
|
||||
mod mcp_tool_call;
|
||||
mod message_history;
|
||||
mod model_provider_info;
|
||||
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
||||
pub use model_provider_info::ModelProviderInfo;
|
||||
pub use model_provider_info::WireApi;
|
||||
pub use model_provider_info::built_in_model_providers;
|
||||
pub use model_provider_info::create_oss_provider_with_base_url;
|
||||
pub mod model_family;
|
||||
mod models;
|
||||
mod openai_model_info;
|
||||
mod openai_tools;
|
||||
@@ -38,12 +41,12 @@ pub mod plan_tool;
|
||||
mod project_doc;
|
||||
pub mod protocol;
|
||||
mod rollout;
|
||||
mod safety;
|
||||
pub(crate) mod safety;
|
||||
pub mod seatbelt;
|
||||
pub mod shell;
|
||||
pub mod spawn;
|
||||
pub mod turn_diff_tracker;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
|
||||
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
|
||||
pub use client_common::model_supports_reasoning_summaries;
|
||||
pub use safety::get_platform_sandbox;
|
||||
|
||||
95
codex-rs/core/src/model_family.rs
Normal file
95
codex-rs/core/src/model_family.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
/// A model family is a group of models that share certain characteristics.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ModelFamily {
|
||||
/// The full model slug used to derive this model family, e.g.
|
||||
/// "gpt-4.1-2025-04-14".
|
||||
pub slug: String,
|
||||
|
||||
/// The model family name, e.g. "gpt-4.1". Note this should able to be used
|
||||
/// with [`crate::openai_model_info::get_model_info`].
|
||||
pub family: String,
|
||||
|
||||
/// True if the model needs additional instructions on how to use the
|
||||
/// "virtual" `apply_patch` CLI.
|
||||
pub needs_special_apply_patch_instructions: bool,
|
||||
|
||||
// Whether the `reasoning` field can be set when making a request to this
|
||||
// model family. Note it has `effort` and `summary` subfields (though
|
||||
// `summary` is optional).
|
||||
pub supports_reasoning_summaries: bool,
|
||||
|
||||
// This should be set to true when the model expects a tool named
|
||||
// "local_shell" to be provided. Its contract must be understood natively by
|
||||
// the model such that its description can be omitted.
|
||||
// See https://platform.openai.com/docs/guides/tools-local-shell
|
||||
pub uses_local_shell_tool: bool,
|
||||
}
|
||||
|
||||
macro_rules! model_family {
|
||||
(
|
||||
$slug:expr, $family:expr $(, $key:ident : $value:expr )* $(,)?
|
||||
) => {{
|
||||
// defaults
|
||||
let mut mf = ModelFamily {
|
||||
slug: $slug.to_string(),
|
||||
family: $family.to_string(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries: false,
|
||||
uses_local_shell_tool: false,
|
||||
};
|
||||
// apply overrides
|
||||
$(
|
||||
mf.$key = $value;
|
||||
)*
|
||||
Some(mf)
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! simple_model_family {
|
||||
(
|
||||
$slug:expr, $family:expr
|
||||
) => {{
|
||||
Some(ModelFamily {
|
||||
slug: $slug.to_string(),
|
||||
family: $family.to_string(),
|
||||
needs_special_apply_patch_instructions: false,
|
||||
supports_reasoning_summaries: false,
|
||||
uses_local_shell_tool: false,
|
||||
})
|
||||
}};
|
||||
}
|
||||
|
||||
/// Returns a `ModelFamily` for the given model slug, or `None` if the slug
|
||||
/// does not match any known model family.
|
||||
pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
if slug.starts_with("o3") {
|
||||
model_family!(
|
||||
slug, "o3",
|
||||
supports_reasoning_summaries: true,
|
||||
)
|
||||
} else if slug.starts_with("o4-mini") {
|
||||
model_family!(
|
||||
slug, "o4-mini",
|
||||
supports_reasoning_summaries: true,
|
||||
)
|
||||
} else if slug.starts_with("codex-mini-latest") {
|
||||
model_family!(
|
||||
slug, "codex-mini-latest",
|
||||
supports_reasoning_summaries: true,
|
||||
uses_local_shell_tool: true,
|
||||
)
|
||||
} else if slug.starts_with("gpt-4.1") {
|
||||
model_family!(
|
||||
slug, "gpt-4.1",
|
||||
needs_special_apply_patch_instructions: true,
|
||||
)
|
||||
} else if slug.starts_with("gpt-4o") {
|
||||
simple_model_family!(slug, "gpt-4o")
|
||||
} else if slug.starts_with("gpt-oss") {
|
||||
simple_model_family!(slug, "gpt-oss")
|
||||
} else if slug.starts_with("gpt-3.5") {
|
||||
simple_model_family!(slug, "gpt-3.5")
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,8 @@
|
||||
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
|
||||
//! key. These override or extend the defaults at runtime.
|
||||
|
||||
use codex_login::AuthMode;
|
||||
use codex_login::CodexAuth;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
@@ -12,10 +14,6 @@ use std::env::VarError;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error::EnvVarError;
|
||||
|
||||
/// Value for the `OpenAI-Originator` header that is sent with requests to
|
||||
/// OpenAI.
|
||||
const OPENAI_ORIGINATOR_HEADER: &str = "codex_cli_rs";
|
||||
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
|
||||
const DEFAULT_STREAM_MAX_RETRIES: u64 = 10;
|
||||
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
@@ -80,7 +78,7 @@ pub struct ModelProviderInfo {
|
||||
|
||||
/// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
|
||||
#[serde(default)]
|
||||
pub requires_auth: bool,
|
||||
pub requires_openai_auth: bool,
|
||||
}
|
||||
|
||||
impl ModelProviderInfo {
|
||||
@@ -88,29 +86,40 @@ impl ModelProviderInfo {
|
||||
/// reqwest Client applying:
|
||||
/// • provider-specific headers (static + env based)
|
||||
/// • Bearer auth header when an API key is available.
|
||||
/// • Auth token for OAuth.
|
||||
///
|
||||
/// When `require_api_key` is true and the provider declares an `env_key`
|
||||
/// but the variable is missing/empty, returns an [`Err`] identical to the
|
||||
/// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
|
||||
/// one produced by [`ModelProviderInfo::api_key`].
|
||||
pub fn create_request_builder<'a>(
|
||||
pub async fn create_request_builder<'a>(
|
||||
&'a self,
|
||||
client: &'a reqwest::Client,
|
||||
auth: &Option<CodexAuth>,
|
||||
) -> crate::error::Result<reqwest::RequestBuilder> {
|
||||
let url = self.get_full_url();
|
||||
let effective_auth = match self.api_key() {
|
||||
Ok(Some(key)) => Some(CodexAuth::from_api_key(key)),
|
||||
Ok(None) => auth.clone(),
|
||||
Err(err) => {
|
||||
if auth.is_some() {
|
||||
auth.clone()
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let url = self.get_full_url(&effective_auth);
|
||||
|
||||
let mut builder = client.post(url);
|
||||
|
||||
let api_key = self.api_key()?;
|
||||
if let Some(key) = api_key {
|
||||
builder = builder.bearer_auth(key);
|
||||
if let Some(auth) = effective_auth.as_ref() {
|
||||
builder = builder.bearer_auth(auth.get_token().await?);
|
||||
}
|
||||
|
||||
Ok(self.apply_http_headers(builder))
|
||||
}
|
||||
|
||||
pub(crate) fn get_full_url(&self) -> String {
|
||||
let query_string = self
|
||||
.query_params
|
||||
fn get_query_string(&self) -> String {
|
||||
self.query_params
|
||||
.as_ref()
|
||||
.map_or_else(String::new, |params| {
|
||||
let full_params = params
|
||||
@@ -119,16 +128,29 @@ impl ModelProviderInfo {
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
format!("?{full_params}")
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String {
|
||||
let default_base_url = if matches!(
|
||||
auth,
|
||||
Some(CodexAuth {
|
||||
mode: AuthMode::ChatGPT,
|
||||
..
|
||||
})
|
||||
) {
|
||||
"https://chatgpt.com/backend-api/codex"
|
||||
} else {
|
||||
"https://api.openai.com/v1"
|
||||
};
|
||||
let query_string = self.get_query_string();
|
||||
let base_url = self
|
||||
.base_url
|
||||
.clone()
|
||||
.unwrap_or("https://api.openai.com/v1".to_string());
|
||||
.unwrap_or(default_base_url.to_string());
|
||||
|
||||
match self.wire_api {
|
||||
WireApi::Responses => {
|
||||
format!("{base_url}/responses{query_string}")
|
||||
}
|
||||
WireApi::Responses => format!("{base_url}/responses{query_string}"),
|
||||
WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
|
||||
}
|
||||
}
|
||||
@@ -136,10 +158,7 @@ impl ModelProviderInfo {
|
||||
/// Apply provider-specific HTTP headers (both static and environment-based)
|
||||
/// onto an existing `reqwest::RequestBuilder` and return the updated
|
||||
/// builder.
|
||||
pub fn apply_http_headers(
|
||||
&self,
|
||||
mut builder: reqwest::RequestBuilder,
|
||||
) -> reqwest::RequestBuilder {
|
||||
fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
if let Some(extra) = &self.http_headers {
|
||||
for (k, v) in extra {
|
||||
builder = builder.header(k, v);
|
||||
@@ -161,7 +180,7 @@ impl ModelProviderInfo {
|
||||
/// If `env_key` is Some, returns the API key for this provider if present
|
||||
/// (and non-empty) in the environment. If `env_key` is required but
|
||||
/// cannot be found, returns an error.
|
||||
fn api_key(&self) -> crate::error::Result<Option<String>> {
|
||||
pub fn api_key(&self) -> crate::error::Result<Option<String>> {
|
||||
match &self.env_key {
|
||||
Some(env_key) => {
|
||||
let env_value = std::env::var(env_key);
|
||||
@@ -204,64 +223,103 @@ impl ModelProviderInfo {
|
||||
}
|
||||
}
|
||||
|
||||
const DEFAULT_OLLAMA_PORT: u32 = 11434;
|
||||
|
||||
pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
|
||||
|
||||
/// Built-in default provider list.
|
||||
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
||||
use ModelProviderInfo as P;
|
||||
|
||||
// We do not want to be in the business of adjucating which third-party
|
||||
// providers are bundled with Codex CLI, so we only include the OpenAI
|
||||
// provider by default. Users are encouraged to add to `model_providers`
|
||||
// in config.toml to add their own providers.
|
||||
[(
|
||||
"openai",
|
||||
P {
|
||||
name: "OpenAI".into(),
|
||||
// Allow users to override the default OpenAI endpoint by
|
||||
// exporting `OPENAI_BASE_URL`. This is useful when pointing
|
||||
// Codex at a proxy, mock server, or Azure-style deployment
|
||||
// without requiring a full TOML override for the built-in
|
||||
// OpenAI provider.
|
||||
base_url: std::env::var("OPENAI_BASE_URL")
|
||||
.ok()
|
||||
.filter(|v| !v.trim().is_empty()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(
|
||||
[
|
||||
(
|
||||
"originator".to_string(),
|
||||
OPENAI_ORIGINATOR_HEADER.to_string(),
|
||||
),
|
||||
("version".to_string(), env!("CARGO_PKG_VERSION").to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
env_http_headers: Some(
|
||||
[
|
||||
(
|
||||
"OpenAI-Organization".to_string(),
|
||||
"OPENAI_ORGANIZATION".to_string(),
|
||||
),
|
||||
("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
// Use global defaults for retry/timeout unless overridden in config.toml.
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_auth: true,
|
||||
},
|
||||
)]
|
||||
// providers are bundled with Codex CLI, so we only include the OpenAI and
|
||||
// open source ("oss") providers by default. Users are encouraged to add to
|
||||
// `model_providers` in config.toml to add their own providers.
|
||||
[
|
||||
(
|
||||
"openai",
|
||||
P {
|
||||
name: "OpenAI".into(),
|
||||
// Allow users to override the default OpenAI endpoint by
|
||||
// exporting `OPENAI_BASE_URL`. This is useful when pointing
|
||||
// Codex at a proxy, mock server, or Azure-style deployment
|
||||
// without requiring a full TOML override for the built-in
|
||||
// OpenAI provider.
|
||||
base_url: std::env::var("OPENAI_BASE_URL")
|
||||
.ok()
|
||||
.filter(|v| !v.trim().is_empty()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(
|
||||
[("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
env_http_headers: Some(
|
||||
[
|
||||
(
|
||||
"OpenAI-Organization".to_string(),
|
||||
"OPENAI_ORGANIZATION".to_string(),
|
||||
),
|
||||
("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
// Use global defaults for retry/timeout unless overridden in config.toml.
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: true,
|
||||
},
|
||||
),
|
||||
(BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.to_string(), v))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn create_oss_provider() -> ModelProviderInfo {
|
||||
// These CODEX_OSS_ environment variables are experimental: we may
|
||||
// switch to reading values from config.toml instead.
|
||||
let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
|
||||
.ok()
|
||||
.filter(|v| !v.trim().is_empty())
|
||||
{
|
||||
Some(url) => url,
|
||||
None => format!(
|
||||
"http://localhost:{port}/v1",
|
||||
port = std::env::var("CODEX_OSS_PORT")
|
||||
.ok()
|
||||
.filter(|v| !v.trim().is_empty())
|
||||
.and_then(|v| v.parse::<u32>().ok())
|
||||
.unwrap_or(DEFAULT_OLLAMA_PORT)
|
||||
),
|
||||
};
|
||||
|
||||
create_oss_provider_with_base_url(&codex_oss_base_url)
|
||||
}
|
||||
|
||||
pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
|
||||
ModelProviderInfo {
|
||||
name: "gpt-oss".into(),
|
||||
base_url: Some(base_url.into()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
@@ -286,7 +344,7 @@ base_url = "http://localhost:11434/v1"
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_auth: false,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
@@ -315,7 +373,7 @@ query_params = { api-version = "2025-04-01-preview" }
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_auth: false,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
@@ -347,7 +405,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_auth: false,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
|
||||
@@ -9,7 +9,7 @@ use serde::ser::Serializer;
|
||||
|
||||
use crate::protocol::InputItem;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponseInputItem {
|
||||
Message {
|
||||
@@ -26,7 +26,7 @@ pub enum ResponseInputItem {
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentItem {
|
||||
InputText { text: String },
|
||||
@@ -34,7 +34,7 @@ pub enum ContentItem {
|
||||
OutputText { text: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponseItem {
|
||||
Message {
|
||||
@@ -45,6 +45,8 @@ pub enum ResponseItem {
|
||||
Reasoning {
|
||||
id: String,
|
||||
summary: Vec<ReasoningItemReasoningSummary>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
content: Option<Vec<ReasoningItemContent>>,
|
||||
encrypted_content: Option<String>,
|
||||
},
|
||||
LocalShellCall {
|
||||
@@ -107,7 +109,7 @@ impl From<ResponseInputItem> for ResponseItem {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LocalShellStatus {
|
||||
Completed,
|
||||
@@ -115,13 +117,13 @@ pub enum LocalShellStatus {
|
||||
Incomplete,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum LocalShellAction {
|
||||
Exec(LocalShellExecAction),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct LocalShellExecAction {
|
||||
pub command: Vec<String>,
|
||||
pub timeout_ms: Option<u64>,
|
||||
@@ -130,12 +132,18 @@ pub struct LocalShellExecAction {
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ReasoningItemReasoningSummary {
|
||||
SummaryText { text: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ReasoningItemContent {
|
||||
ReasoningText { text: String },
|
||||
}
|
||||
|
||||
impl From<Vec<InputItem>> for ResponseInputItem {
|
||||
fn from(items: Vec<InputItem>) -> Self {
|
||||
Self::Message {
|
||||
@@ -183,12 +191,15 @@ pub struct ShellToolCallParams {
|
||||
// The wire format uses `timeout`, which has ambiguous units, so we use
|
||||
// `timeout_ms` as the field name so it is clear in code.
|
||||
pub timeout_ms: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub with_escalated_permissions: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub justification: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FunctionCallOutputPayload {
|
||||
pub content: String,
|
||||
#[expect(dead_code)]
|
||||
pub success: Option<bool>,
|
||||
}
|
||||
|
||||
@@ -295,6 +306,8 @@ mod tests {
|
||||
command: vec!["ls".to_string(), "-l".to_string()],
|
||||
workdir: Some("/tmp".to_string()),
|
||||
timeout_ms: Some(1000),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
},
|
||||
params
|
||||
);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use crate::model_family::ModelFamily;
|
||||
|
||||
/// Metadata about a model, particularly OpenAI models.
|
||||
/// We may want to consider including details like the pricing for
|
||||
/// input tokens, output tokens, etc., though users will need to be able to
|
||||
@@ -12,10 +14,19 @@ pub(crate) struct ModelInfo {
|
||||
pub(crate) max_output_tokens: u64,
|
||||
}
|
||||
|
||||
/// Note details such as what a model like gpt-4o is aliased to may be out of
|
||||
/// date.
|
||||
pub(crate) fn get_model_info(name: &str) -> Option<ModelInfo> {
|
||||
match name {
|
||||
pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
|
||||
match model_family.slug.as_str() {
|
||||
// OSS models have a 128k shared token pool.
|
||||
// Arbitrarily splitting it: 3/4 input context, 1/4 output.
|
||||
// https://openai.com/index/gpt-oss-model-card/
|
||||
"gpt-oss-20b" => Some(ModelInfo {
|
||||
context_window: 96_000,
|
||||
max_output_tokens: 32_000,
|
||||
}),
|
||||
"gpt-oss-120b" => Some(ModelInfo {
|
||||
context_window: 96_000,
|
||||
max_output_tokens: 32_000,
|
||||
}),
|
||||
// https://platform.openai.com/docs/models/o3
|
||||
"o3" => Some(ModelInfo {
|
||||
context_window: 200_000,
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::LazyLock;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::client_common::Prompt;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::plan_tool::PLAN_TOOL;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(crate) struct ResponsesApiTool {
|
||||
pub(crate) name: &'static str,
|
||||
pub(crate) description: &'static str,
|
||||
#[derive(Debug, Clone, Serialize, PartialEq)]
|
||||
pub struct ResponsesApiTool {
|
||||
pub(crate) name: String,
|
||||
pub(crate) description: String,
|
||||
/// TODO: Validation. When strict is set to true, the JSON schema,
|
||||
/// `required` and `additional_properties` must be present. All fields in
|
||||
/// `properties` must be present in `required`.
|
||||
pub(crate) strict: bool,
|
||||
pub(crate) parameters: JsonSchema,
|
||||
}
|
||||
|
||||
/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
|
||||
/// Responses API.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize, PartialEq)]
|
||||
#[serde(tag = "type")]
|
||||
pub(crate) enum OpenAiTool {
|
||||
#[serde(rename = "function")]
|
||||
@@ -25,78 +31,217 @@ pub(crate) enum OpenAiTool {
|
||||
LocalShell {},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ConfigShellToolType {
|
||||
DefaultShell,
|
||||
ShellWithRequest { sandbox_policy: SandboxPolicy },
|
||||
LocalShell,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolsConfig {
|
||||
pub shell_type: ConfigShellToolType,
|
||||
pub plan_tool: bool,
|
||||
}
|
||||
|
||||
impl ToolsConfig {
|
||||
pub fn new(
|
||||
model_family: &ModelFamily,
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
include_plan_tool: bool,
|
||||
) -> Self {
|
||||
let mut shell_type = if model_family.uses_local_shell_tool {
|
||||
ConfigShellToolType::LocalShell
|
||||
} else {
|
||||
ConfigShellToolType::DefaultShell
|
||||
};
|
||||
if matches!(approval_policy, AskForApproval::OnRequest) {
|
||||
shell_type = ConfigShellToolType::ShellWithRequest {
|
||||
sandbox_policy: sandbox_policy.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
shell_type,
|
||||
plan_tool: include_plan_tool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic JSON‑Schema subset needed for our tool definitions
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub(crate) enum JsonSchema {
|
||||
String,
|
||||
Number,
|
||||
Boolean {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
String {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
Number {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
Array {
|
||||
items: Box<JsonSchema>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
Object {
|
||||
properties: BTreeMap<String, JsonSchema>,
|
||||
required: &'static [&'static str],
|
||||
#[serde(rename = "additionalProperties")]
|
||||
additional_properties: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
required: Option<Vec<String>>,
|
||||
#[serde(
|
||||
rename = "additionalProperties",
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
additional_properties: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Tool usage specification
|
||||
static DEFAULT_TOOLS: LazyLock<Vec<OpenAiTool>> = LazyLock::new(|| {
|
||||
fn create_shell_tool() -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"command".to_string(),
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String),
|
||||
items: Box::new(JsonSchema::String { description: None }),
|
||||
description: None,
|
||||
},
|
||||
);
|
||||
properties.insert("workdir".to_string(), JsonSchema::String);
|
||||
properties.insert("timeout".to_string(), JsonSchema::Number);
|
||||
properties.insert(
|
||||
"workdir".to_string(),
|
||||
JsonSchema::String { description: None },
|
||||
);
|
||||
properties.insert(
|
||||
"timeout".to_string(),
|
||||
JsonSchema::Number { description: None },
|
||||
);
|
||||
|
||||
vec![OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "shell",
|
||||
description: "Runs a shell command, and returns its output.",
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "shell".to_string(),
|
||||
description: "Runs a shell command and returns its output".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: &["command"],
|
||||
additional_properties: false,
|
||||
required: Some(vec!["command".to_string()]),
|
||||
additional_properties: Some(false),
|
||||
},
|
||||
})]
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
static DEFAULT_CODEX_MODEL_TOOLS: LazyLock<Vec<OpenAiTool>> =
|
||||
LazyLock::new(|| vec![OpenAiTool::LocalShell {}]);
|
||||
fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"command".to_string(),
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String { description: None }),
|
||||
description: Some("The command to execute".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"workdir".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("The working directory to execute the command in".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"timeout".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("The timeout for the command in milliseconds".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
if matches!(sandbox_policy, SandboxPolicy::WorkspaceWrite { .. }) {
|
||||
properties.insert(
|
||||
"with_escalated_permissions".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some("Whether to request escalated permissions. Set to true if command needs to be run without sandbox restrictions".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"justification".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Only set if ask_for_escalated_permissions is true. 1-sentence explanation of why we want to run this command.".to_string()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let description = match sandbox_policy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
network_access,
|
||||
..
|
||||
} => {
|
||||
format!(
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a landlock sandbox, and some shell commands will require escalated privileges:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Reading files outside the current directory
|
||||
- Writing files outside the current directory, and protected folders like .git or .env{}
|
||||
- Examples of commands that require escalated privileges:
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#,
|
||||
if !network_access {
|
||||
"\n - Commands that require network access\n"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)
|
||||
}
|
||||
SandboxPolicy::DangerFullAccess => {
|
||||
"Runs a shell command and returns its output.".to_string()
|
||||
}
|
||||
SandboxPolicy::ReadOnly => {
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a landlock sandbox, and some shell commands (including apply_patch) will require escalated permissions:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Reading files outside the current directory
|
||||
- Writing files
|
||||
- Applying patches
|
||||
- Examples of commands that require escalated privileges:
|
||||
- apply_patch
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#.to_string()
|
||||
}
|
||||
};
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "shell".to_string(),
|
||||
description,
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["command".to_string()]),
|
||||
additional_properties: Some(false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns JSON values that are compatible with Function Calling in the
|
||||
/// Responses API:
|
||||
/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
|
||||
pub(crate) fn create_tools_json_for_responses_api(
|
||||
prompt: &Prompt,
|
||||
model: &str,
|
||||
include_plan_tool: bool,
|
||||
tools: &Vec<OpenAiTool>,
|
||||
) -> crate::error::Result<Vec<serde_json::Value>> {
|
||||
// Assemble tool list: built-in tools + any extra tools from the prompt.
|
||||
let default_tools = if model.starts_with("codex") {
|
||||
&DEFAULT_CODEX_MODEL_TOOLS
|
||||
} else {
|
||||
&DEFAULT_TOOLS
|
||||
};
|
||||
let mut tools_json = Vec::with_capacity(default_tools.len() + prompt.extra_tools.len());
|
||||
for t in default_tools.iter() {
|
||||
tools_json.push(serde_json::to_value(t)?);
|
||||
}
|
||||
tools_json.extend(
|
||||
prompt
|
||||
.extra_tools
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(name, tool)| mcp_tool_to_openai_tool(name, tool)),
|
||||
);
|
||||
let mut tools_json = Vec::new();
|
||||
|
||||
if include_plan_tool {
|
||||
tools_json.push(serde_json::to_value(PLAN_TOOL.clone())?);
|
||||
for tool in tools {
|
||||
tools_json.push(serde_json::to_value(tool)?);
|
||||
}
|
||||
|
||||
Ok(tools_json)
|
||||
@@ -106,14 +251,11 @@ pub(crate) fn create_tools_json_for_responses_api(
|
||||
/// Chat Completions API:
|
||||
/// https://platform.openai.com/docs/guides/function-calling?api-mode=chat
|
||||
pub(crate) fn create_tools_json_for_chat_completions_api(
|
||||
prompt: &Prompt,
|
||||
model: &str,
|
||||
include_plan_tool: bool,
|
||||
tools: &Vec<OpenAiTool>,
|
||||
) -> crate::error::Result<Vec<serde_json::Value>> {
|
||||
// We start with the JSON for the Responses API and than rewrite it to match
|
||||
// the chat completions tool call format.
|
||||
let responses_api_tools_json =
|
||||
create_tools_json_for_responses_api(prompt, model, include_plan_tool)?;
|
||||
let responses_api_tools_json = create_tools_json_for_responses_api(tools)?;
|
||||
let tools_json = responses_api_tools_json
|
||||
.into_iter()
|
||||
.filter_map(|mut tool| {
|
||||
@@ -136,10 +278,10 @@ pub(crate) fn create_tools_json_for_chat_completions_api(
|
||||
Ok(tools_json)
|
||||
}
|
||||
|
||||
fn mcp_tool_to_openai_tool(
|
||||
pub(crate) fn mcp_tool_to_openai_tool(
|
||||
fully_qualified_name: String,
|
||||
tool: mcp_types::Tool,
|
||||
) -> serde_json::Value {
|
||||
) -> Result<ResponsesApiTool, serde_json::Error> {
|
||||
let mcp_types::Tool {
|
||||
description,
|
||||
mut input_schema,
|
||||
@@ -154,12 +296,205 @@ fn mcp_tool_to_openai_tool(
|
||||
input_schema.properties = Some(serde_json::Value::Object(serde_json::Map::new()));
|
||||
}
|
||||
|
||||
// TODO(mbolin): Change the contract of this function to return
|
||||
// ResponsesApiTool.
|
||||
json!({
|
||||
"name": fully_qualified_name,
|
||||
"description": description,
|
||||
"parameters": input_schema,
|
||||
"type": "function",
|
||||
let serialized_input_schema = serde_json::to_value(input_schema)?;
|
||||
let input_schema = serde_json::from_value::<JsonSchema>(serialized_input_schema)?;
|
||||
|
||||
Ok(ResponsesApiTool {
|
||||
name: fully_qualified_name,
|
||||
description: description.unwrap_or_default(),
|
||||
strict: false,
|
||||
parameters: input_schema,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a list of OpenAiTools based on the provided config and MCP tools.
|
||||
/// Note that the keys of mcp_tools should be fully qualified names. See
|
||||
/// [`McpConnectionManager`] for more details.
|
||||
pub(crate) fn get_openai_tools(
|
||||
config: &ToolsConfig,
|
||||
mcp_tools: Option<HashMap<String, mcp_types::Tool>>,
|
||||
) -> Vec<OpenAiTool> {
|
||||
let mut tools: Vec<OpenAiTool> = Vec::new();
|
||||
|
||||
match &config.shell_type {
|
||||
ConfigShellToolType::DefaultShell => {
|
||||
tools.push(create_shell_tool());
|
||||
}
|
||||
ConfigShellToolType::ShellWithRequest { sandbox_policy } => {
|
||||
tools.push(create_shell_tool_for_sandbox(sandbox_policy));
|
||||
}
|
||||
ConfigShellToolType::LocalShell => {
|
||||
tools.push(OpenAiTool::LocalShell {});
|
||||
}
|
||||
}
|
||||
|
||||
if config.plan_tool {
|
||||
tools.push(PLAN_TOOL.clone());
|
||||
}
|
||||
|
||||
if let Some(mcp_tools) = mcp_tools {
|
||||
for (name, tool) in mcp_tools {
|
||||
match mcp_tool_to_openai_tool(name.clone(), tool.clone()) {
|
||||
Ok(converted_tool) => tools.push(OpenAiTool::Function(converted_tool)),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to convert {name:?} MCP tool to OpenAI tool: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tools
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::expect_used)]
|
||||
mod tests {
|
||||
use crate::model_family::find_family_for_model;
|
||||
use mcp_types::ToolInputSchema;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn assert_eq_tool_names(tools: &[OpenAiTool], expected_names: &[&str]) {
|
||||
let tool_names = tools
|
||||
.iter()
|
||||
.map(|tool| match tool {
|
||||
OpenAiTool::Function(ResponsesApiTool { name, .. }) => name,
|
||||
OpenAiTool::LocalShell {} => "local_shell",
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
tool_names.len(),
|
||||
expected_names.len(),
|
||||
"tool_name mismatch, {tool_names:?}, {expected_names:?}",
|
||||
);
|
||||
for (name, expected_name) in tool_names.iter().zip(expected_names.iter()) {
|
||||
assert_eq!(
|
||||
name, expected_name,
|
||||
"tool_name mismatch, {name:?}, {expected_name:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_openai_tools() {
|
||||
let model_family = find_family_for_model("codex-mini-latest")
|
||||
.expect("codex-mini-latest should be a valid model family");
|
||||
let config = ToolsConfig::new(
|
||||
&model_family,
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::ReadOnly,
|
||||
true,
|
||||
);
|
||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||
|
||||
assert_eq_tool_names(&tools, &["local_shell", "update_plan"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_openai_tools_default_shell() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let config = ToolsConfig::new(
|
||||
&model_family,
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::ReadOnly,
|
||||
true,
|
||||
);
|
||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||
|
||||
assert_eq_tool_names(&tools, &["shell", "update_plan"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_openai_tools_mcp_tools() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
let config = ToolsConfig::new(
|
||||
&model_family,
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::ReadOnly,
|
||||
false,
|
||||
);
|
||||
let tools = get_openai_tools(
|
||||
&config,
|
||||
Some(HashMap::from([(
|
||||
"test_server/do_something_cool".to_string(),
|
||||
mcp_types::Tool {
|
||||
name: "do_something_cool".to_string(),
|
||||
input_schema: ToolInputSchema {
|
||||
properties: Some(serde_json::json!({
|
||||
"string_argument": {
|
||||
"type": "string",
|
||||
},
|
||||
"number_argument": {
|
||||
"type": "number",
|
||||
},
|
||||
"object_argument": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"string_property": { "type": "string" },
|
||||
"number_property": { "type": "number" },
|
||||
},
|
||||
"required": [
|
||||
"string_property",
|
||||
"number_property"
|
||||
],
|
||||
"additionalProperties": Some(false),
|
||||
},
|
||||
})),
|
||||
required: None,
|
||||
r#type: "object".to_string(),
|
||||
},
|
||||
output_schema: None,
|
||||
title: None,
|
||||
annotations: None,
|
||||
description: Some("Do something cool".to_string()),
|
||||
},
|
||||
)])),
|
||||
);
|
||||
|
||||
assert_eq_tool_names(&tools, &["shell", "test_server/do_something_cool"]);
|
||||
|
||||
assert_eq!(
|
||||
tools[1],
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "test_server/do_something_cool".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
properties: BTreeMap::from([
|
||||
(
|
||||
"string_argument".to_string(),
|
||||
JsonSchema::String { description: None }
|
||||
),
|
||||
(
|
||||
"number_argument".to_string(),
|
||||
JsonSchema::Number { description: None }
|
||||
),
|
||||
(
|
||||
"object_argument".to_string(),
|
||||
JsonSchema::Object {
|
||||
properties: BTreeMap::from([
|
||||
(
|
||||
"string_property".to_string(),
|
||||
JsonSchema::String { description: None }
|
||||
),
|
||||
(
|
||||
"number_property".to_string(),
|
||||
JsonSchema::Number { description: None }
|
||||
),
|
||||
]),
|
||||
required: Some(vec![
|
||||
"string_property".to_string(),
|
||||
"number_property".to_string(),
|
||||
]),
|
||||
additional_properties: Some(false),
|
||||
},
|
||||
),
|
||||
]),
|
||||
required: None,
|
||||
additional_properties: None,
|
||||
},
|
||||
description: "Do something cool".to_string(),
|
||||
strict: false,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,23 +39,30 @@ pub struct UpdatePlanArgs {
|
||||
|
||||
pub(crate) static PLAN_TOOL: LazyLock<OpenAiTool> = LazyLock::new(|| {
|
||||
let mut plan_item_props = BTreeMap::new();
|
||||
plan_item_props.insert("step".to_string(), JsonSchema::String);
|
||||
plan_item_props.insert("status".to_string(), JsonSchema::String);
|
||||
plan_item_props.insert("step".to_string(), JsonSchema::String { description: None });
|
||||
plan_item_props.insert(
|
||||
"status".to_string(),
|
||||
JsonSchema::String { description: None },
|
||||
);
|
||||
|
||||
let plan_items_schema = JsonSchema::Array {
|
||||
description: Some("The list of steps".to_string()),
|
||||
items: Box::new(JsonSchema::Object {
|
||||
properties: plan_item_props,
|
||||
required: &["step", "status"],
|
||||
additional_properties: false,
|
||||
required: Some(vec!["step".to_string(), "status".to_string()]),
|
||||
additional_properties: Some(false),
|
||||
}),
|
||||
};
|
||||
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert("explanation".to_string(), JsonSchema::String);
|
||||
properties.insert(
|
||||
"explanation".to_string(),
|
||||
JsonSchema::String { description: None },
|
||||
);
|
||||
properties.insert("plan".to_string(), plan_items_schema);
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "update_plan",
|
||||
name: "update_plan".to_string(),
|
||||
description: r#"Use the update_plan tool to keep the user updated on the current plan for the task.
|
||||
After understanding the user's task, call the update_plan tool with an initial plan. An example of a plan:
|
||||
1. Explore the codebase to find relevant files (status: in_progress)
|
||||
@@ -66,12 +73,12 @@ Until all the steps are finished, there should always be exactly one in_progress
|
||||
Call the update_plan tool whenever you finish a step, marking the completed step as `completed` and marking the next step as `in_progress`.
|
||||
Before running a command, consider whether or not you have completed the previous step, and make sure to mark it as completed before moving on to the next step.
|
||||
Sometimes, you may need to change plans in the middle of a task: call `update_plan` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
|
||||
When all steps are completed, call update_plan one last time with all steps marked as `completed`."#,
|
||||
When all steps are completed, call update_plan one last time with all steps marked as `completed`."#.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: &["plan"],
|
||||
additional_properties: false,
|
||||
required: Some(vec!["plan".to_string()]),
|
||||
additional_properties: Some(false),
|
||||
},
|
||||
})
|
||||
});
|
||||
|
||||
@@ -13,6 +13,7 @@ use std::time::Duration;
|
||||
use mcp_types::CallToolResult;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_bytes::ByteBuf;
|
||||
use strum_macros::Display;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -121,6 +122,10 @@ pub enum Op {
|
||||
/// Request a single history entry identified by `log_id` + `offset`.
|
||||
GetHistoryEntryRequest { offset: usize, log_id: u64 },
|
||||
|
||||
/// Request the agent to summarize the current conversation context.
|
||||
/// The agent will use its existing context (either conversation history or previous response id)
|
||||
/// to generate a summary which will be returned as an AgentMessage event.
|
||||
Compact,
|
||||
/// Request to shut down codex instance.
|
||||
Shutdown,
|
||||
}
|
||||
@@ -145,13 +150,17 @@ pub enum AskForApproval {
|
||||
/// the user to approve execution without a sandbox.
|
||||
OnFailure,
|
||||
|
||||
/// The model decides when to ask the user for approval.
|
||||
OnRequest,
|
||||
|
||||
/// Never ask the user to approve commands. Failures are immediately returned
|
||||
/// to the model, and never escalated to the user for approval.
|
||||
Never,
|
||||
}
|
||||
|
||||
/// Determines execution restrictions for model shell commands.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Display)]
|
||||
#[strum(serialize_all = "kebab-case")]
|
||||
#[serde(tag = "mode", rename_all = "kebab-case")]
|
||||
pub enum SandboxPolicy {
|
||||
/// No restrictions whatsoever. Use with caution.
|
||||
@@ -175,9 +184,29 @@ pub enum SandboxPolicy {
|
||||
/// default.
|
||||
#[serde(default)]
|
||||
network_access: bool,
|
||||
|
||||
/// When set to `true`, will include defaults like the current working
|
||||
/// directory and TMPDIR (on macOS). When `false`, only `writable_roots`
|
||||
/// are used. (Mainly used for testing.)
|
||||
#[serde(default = "default_true")]
|
||||
include_default_writable_roots: bool,
|
||||
},
|
||||
}
|
||||
|
||||
/// A writable root path accompanied by a list of subpaths that should remain
|
||||
/// read‑only even when the root is writable. This is primarily used to ensure
|
||||
/// top‑level VCS metadata directories (e.g. `.git`) under a writable root are
|
||||
/// not modified by the agent.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct WritableRoot {
|
||||
pub root: PathBuf,
|
||||
pub read_only_subpaths: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl FromStr for SandboxPolicy {
|
||||
type Err = serde_json::Error;
|
||||
|
||||
@@ -199,6 +228,7 @@ impl SandboxPolicy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![],
|
||||
network_access: false,
|
||||
include_default_writable_roots: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,27 +254,51 @@ impl SandboxPolicy {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the list of writable roots that should be passed down to the
|
||||
/// Landlock rules installer, tailored to the current working directory.
|
||||
pub fn get_writable_roots_with_cwd(&self, cwd: &Path) -> Vec<PathBuf> {
|
||||
/// Returns the list of writable roots (tailored to the current working
|
||||
/// directory) together with subpaths that should remain read‑only under
|
||||
/// each writable root.
|
||||
pub fn get_writable_roots_with_cwd(&self, cwd: &Path) -> Vec<WritableRoot> {
|
||||
match self {
|
||||
SandboxPolicy::DangerFullAccess => Vec::new(),
|
||||
SandboxPolicy::ReadOnly => Vec::new(),
|
||||
SandboxPolicy::WorkspaceWrite { writable_roots, .. } => {
|
||||
let mut roots = writable_roots.clone();
|
||||
roots.push(cwd.to_path_buf());
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots,
|
||||
include_default_writable_roots,
|
||||
..
|
||||
} => {
|
||||
// Start from explicitly configured writable roots.
|
||||
let mut roots: Vec<PathBuf> = writable_roots.clone();
|
||||
|
||||
// Also include the per-user tmp dir on macOS.
|
||||
// Note this is added dynamically rather than storing it in
|
||||
// writable_roots because writable_roots contains only static
|
||||
// values deserialized from the config file.
|
||||
if cfg!(target_os = "macos") {
|
||||
if let Some(tmpdir) = std::env::var_os("TMPDIR") {
|
||||
roots.push(PathBuf::from(tmpdir));
|
||||
// Optionally include defaults (cwd and TMPDIR on macOS).
|
||||
if *include_default_writable_roots {
|
||||
roots.push(cwd.to_path_buf());
|
||||
|
||||
// Also include the per-user tmp dir on macOS.
|
||||
// Note this is added dynamically rather than storing it in
|
||||
// `writable_roots` because `writable_roots` contains only static
|
||||
// values deserialized from the config file.
|
||||
if cfg!(target_os = "macos") {
|
||||
if let Some(tmpdir) = std::env::var_os("TMPDIR") {
|
||||
roots.push(PathBuf::from(tmpdir));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each root, compute subpaths that should remain read-only.
|
||||
roots
|
||||
.into_iter()
|
||||
.map(|writable_root| {
|
||||
let mut subpaths = Vec::new();
|
||||
let top_level_git = writable_root.join(".git");
|
||||
if top_level_git.is_dir() {
|
||||
subpaths.push(top_level_git);
|
||||
}
|
||||
WritableRoot {
|
||||
root: writable_root,
|
||||
read_only_subpaths: subpaths,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -309,6 +363,12 @@ pub enum EventMsg {
|
||||
/// Agent reasoning delta event from agent.
|
||||
AgentReasoningDelta(AgentReasoningDeltaEvent),
|
||||
|
||||
/// Raw chain-of-thought from agent.
|
||||
AgentReasoningRawContent(AgentReasoningRawContentEvent),
|
||||
|
||||
/// Agent reasoning content delta event from agent.
|
||||
AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent),
|
||||
|
||||
/// Ack the client's configure message.
|
||||
SessionConfigured(SessionConfiguredEvent),
|
||||
|
||||
@@ -319,6 +379,9 @@ pub enum EventMsg {
|
||||
/// Notification that the server is about to execute a command.
|
||||
ExecCommandBegin(ExecCommandBeginEvent),
|
||||
|
||||
/// Incremental chunk of output from a running command.
|
||||
ExecCommandOutputDelta(ExecCommandOutputDeltaEvent),
|
||||
|
||||
ExecCommandEnd(ExecCommandEndEvent),
|
||||
|
||||
ExecApprovalRequest(ExecApprovalRequestEvent),
|
||||
@@ -334,6 +397,8 @@ pub enum EventMsg {
|
||||
/// Notification that a patch application has finished.
|
||||
PatchApplyEnd(PatchApplyEndEvent),
|
||||
|
||||
TurnDiff(TurnDiffEvent),
|
||||
|
||||
/// Response to GetHistoryEntryRequest.
|
||||
GetHistoryEntryResponse(GetHistoryEntryResponseEvent),
|
||||
|
||||
@@ -409,6 +474,16 @@ pub struct AgentReasoningEvent {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentReasoningRawContentEvent {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentReasoningRawContentDeltaEvent {
|
||||
pub delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentReasoningDeltaEvent {
|
||||
pub delta: String,
|
||||
@@ -470,6 +545,26 @@ pub struct ExecCommandEndEvent {
|
||||
pub stderr: String,
|
||||
/// The command's exit code.
|
||||
pub exit_code: i32,
|
||||
/// The duration of the command execution.
|
||||
pub duration: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ExecOutputStream {
|
||||
Stdout,
|
||||
Stderr,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ExecCommandOutputDeltaEvent {
|
||||
/// Identifier for the ExecCommandBegin that produced this chunk.
|
||||
pub call_id: String,
|
||||
/// Which stream produced this chunk.
|
||||
pub stream: ExecOutputStream,
|
||||
/// Raw bytes from the stream (may not be valid UTF-8).
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub chunk: ByteBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -525,6 +620,11 @@ pub struct PatchApplyEndEvent {
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TurnDiffEvent {
|
||||
pub unified_diff: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GetHistoryEntryResponseEvent {
|
||||
pub offset: usize,
|
||||
@@ -547,6 +647,10 @@ pub struct SessionConfiguredEvent {
|
||||
|
||||
/// Current number of entries in the history log.
|
||||
pub history_entry_count: usize,
|
||||
|
||||
/// Absolute path to the rollout file for this session, if recording is enabled.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rollout_path: Option<std::path::PathBuf>,
|
||||
}
|
||||
|
||||
/// User's decision in response to an ExecApprovalRequest.
|
||||
@@ -609,6 +713,7 @@ mod tests {
|
||||
model: "codex-mini-latest".to_string(),
|
||||
history_log_id: 0,
|
||||
history_entry_count: 0,
|
||||
rollout_path: None,
|
||||
}),
|
||||
};
|
||||
let serialized = serde_json::to_string(&event).unwrap();
|
||||
|
||||
@@ -66,6 +66,8 @@ pub struct SavedSession {
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RolloutRecorder {
|
||||
tx: Sender<RolloutCmd>,
|
||||
/// Absolute path to the rollout file for this session.
|
||||
path: std::path::PathBuf,
|
||||
}
|
||||
|
||||
enum RolloutCmd {
|
||||
@@ -87,6 +89,7 @@ impl RolloutRecorder {
|
||||
file,
|
||||
session_id,
|
||||
timestamp,
|
||||
path,
|
||||
} = create_log_file(config, uuid)?;
|
||||
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
@@ -118,7 +121,7 @@ impl RolloutRecorder {
|
||||
cwd,
|
||||
));
|
||||
|
||||
Ok(Self { tx })
|
||||
Ok(Self { tx, path })
|
||||
}
|
||||
|
||||
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
|
||||
@@ -223,7 +226,13 @@ impl RolloutRecorder {
|
||||
cwd,
|
||||
));
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
Ok((Self { tx }, saved))
|
||||
Ok((
|
||||
Self {
|
||||
tx,
|
||||
path: path.to_path_buf(),
|
||||
},
|
||||
saved,
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> std::io::Result<()> {
|
||||
@@ -240,6 +249,11 @@ impl RolloutRecorder {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the absolute path to the rollout file recorded by this session.
|
||||
pub fn path(&self) -> &std::path::Path {
|
||||
&self.path
|
||||
}
|
||||
}
|
||||
|
||||
struct LogFileInfo {
|
||||
@@ -251,6 +265,9 @@ struct LogFileInfo {
|
||||
|
||||
/// Timestamp for the start of the session.
|
||||
timestamp: OffsetDateTime,
|
||||
|
||||
/// Absolute path to the rollout file on disk.
|
||||
path: std::path::PathBuf,
|
||||
}
|
||||
|
||||
fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFileInfo> {
|
||||
@@ -284,6 +301,7 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
|
||||
file,
|
||||
session_id,
|
||||
timestamp,
|
||||
path,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::is_safe_command::is_known_safe_command;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum SafetyCheck {
|
||||
AutoApprove { sandbox_type: SandboxType },
|
||||
AskUser,
|
||||
@@ -31,7 +31,7 @@ pub fn assess_patch_safety(
|
||||
}
|
||||
|
||||
match policy {
|
||||
AskForApproval::OnFailure | AskForApproval::Never => {
|
||||
AskForApproval::OnFailure | AskForApproval::Never | AskForApproval::OnRequest => {
|
||||
// Continue to see if this can be auto-approved.
|
||||
}
|
||||
// TODO(ragona): I'm not sure this is actually correct? I believe in this case
|
||||
@@ -76,6 +76,7 @@ pub fn assess_command_safety(
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
approved: &HashSet<Vec<String>>,
|
||||
with_escalated_permissions: bool,
|
||||
) -> SafetyCheck {
|
||||
// A command is "trusted" because either:
|
||||
// - it belongs to a set of commands we consider "safe" by default, or
|
||||
@@ -96,12 +97,13 @@ pub fn assess_command_safety(
|
||||
};
|
||||
}
|
||||
|
||||
assess_safety_for_untrusted_command(approval_policy, sandbox_policy)
|
||||
assess_safety_for_untrusted_command(approval_policy, sandbox_policy, with_escalated_permissions)
|
||||
}
|
||||
|
||||
pub(crate) fn assess_safety_for_untrusted_command(
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
with_escalated_permissions: bool,
|
||||
) -> SafetyCheck {
|
||||
use AskForApproval::*;
|
||||
use SandboxPolicy::*;
|
||||
@@ -113,9 +115,23 @@ pub(crate) fn assess_safety_for_untrusted_command(
|
||||
// commands.
|
||||
SafetyCheck::AskUser
|
||||
}
|
||||
(OnFailure, DangerFullAccess) | (Never, DangerFullAccess) => SafetyCheck::AutoApprove {
|
||||
(OnFailure, DangerFullAccess)
|
||||
| (Never, DangerFullAccess)
|
||||
| (OnRequest, DangerFullAccess) => SafetyCheck::AutoApprove {
|
||||
sandbox_type: SandboxType::None,
|
||||
},
|
||||
(OnRequest, ReadOnly) | (OnRequest, WorkspaceWrite { .. }) => {
|
||||
if with_escalated_permissions {
|
||||
SafetyCheck::AskUser
|
||||
} else {
|
||||
match get_platform_sandbox() {
|
||||
Some(sandbox_type) => SafetyCheck::AutoApprove { sandbox_type },
|
||||
// Fall back to asking since the command is untrusted and
|
||||
// we do not have a sandbox available
|
||||
None => SafetyCheck::AskUser,
|
||||
}
|
||||
}
|
||||
}
|
||||
(Never, ReadOnly)
|
||||
| (Never, WorkspaceWrite { .. })
|
||||
| (OnFailure, ReadOnly)
|
||||
@@ -264,4 +280,47 @@ mod tests {
|
||||
&cwd,
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_escalated_privileges() {
|
||||
// Should not be a trusted command
|
||||
let command = vec!["git commit".to_string()];
|
||||
let approval_policy = AskForApproval::OnRequest;
|
||||
let sandbox_policy = SandboxPolicy::ReadOnly;
|
||||
let approved: HashSet<Vec<String>> = HashSet::new();
|
||||
let request_escalated_privileges = true;
|
||||
|
||||
let safety_check = assess_command_safety(
|
||||
&command,
|
||||
approval_policy,
|
||||
&sandbox_policy,
|
||||
&approved,
|
||||
request_escalated_privileges,
|
||||
);
|
||||
|
||||
assert_eq!(safety_check, SafetyCheck::AskUser);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_escalated_privileges_no_sandbox_fallback() {
|
||||
let command = vec!["git".to_string(), "commit".to_string()];
|
||||
let approval_policy = AskForApproval::OnRequest;
|
||||
let sandbox_policy = SandboxPolicy::ReadOnly;
|
||||
let approved: HashSet<Vec<String>> = HashSet::new();
|
||||
let request_escalated_privileges = false;
|
||||
|
||||
let safety_check = assess_command_safety(
|
||||
&command,
|
||||
approval_policy,
|
||||
&sandbox_policy,
|
||||
&approved,
|
||||
request_escalated_privileges,
|
||||
);
|
||||
|
||||
let expected = match get_platform_sandbox() {
|
||||
Some(sandbox_type) => SafetyCheck::AutoApprove { sandbox_type },
|
||||
None => SafetyCheck::AskUser,
|
||||
};
|
||||
assert_eq!(safety_check, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::path::PathBuf;
|
||||
use tokio::process::Child;
|
||||
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::spawn::CODEX_SANDBOX_ENV_VAR;
|
||||
use crate::spawn::StdioPolicy;
|
||||
use crate::spawn::spawn_child_async;
|
||||
|
||||
@@ -20,10 +21,11 @@ pub async fn spawn_command_under_seatbelt(
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
cwd: PathBuf,
|
||||
stdio_policy: StdioPolicy,
|
||||
env: HashMap<String, String>,
|
||||
mut env: HashMap<String, String>,
|
||||
) -> std::io::Result<Child> {
|
||||
let args = create_seatbelt_command_args(command, sandbox_policy, &cwd);
|
||||
let arg0 = None;
|
||||
env.insert(CODEX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string());
|
||||
spawn_child_async(
|
||||
PathBuf::from(MACOS_PATH_TO_SEATBELT_EXECUTABLE),
|
||||
args,
|
||||
@@ -50,16 +52,38 @@ fn create_seatbelt_command_args(
|
||||
)
|
||||
} else {
|
||||
let writable_roots = sandbox_policy.get_writable_roots_with_cwd(cwd);
|
||||
let (writable_folder_policies, cli_args): (Vec<String>, Vec<String>) = writable_roots
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, root)| {
|
||||
let param_name = format!("WRITABLE_ROOT_{index}");
|
||||
let policy: String = format!("(subpath (param \"{param_name}\"))");
|
||||
let cli_arg = format!("-D{param_name}={}", root.to_string_lossy());
|
||||
(policy, cli_arg)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
let mut writable_folder_policies: Vec<String> = Vec::new();
|
||||
let mut cli_args: Vec<String> = Vec::new();
|
||||
|
||||
for (index, wr) in writable_roots.iter().enumerate() {
|
||||
// Canonicalize to avoid mismatches like /var vs /private/var on macOS.
|
||||
let canonical_root = wr.root.canonicalize().unwrap_or_else(|_| wr.root.clone());
|
||||
let root_param = format!("WRITABLE_ROOT_{index}");
|
||||
cli_args.push(format!(
|
||||
"-D{root_param}={}",
|
||||
canonical_root.to_string_lossy()
|
||||
));
|
||||
|
||||
if wr.read_only_subpaths.is_empty() {
|
||||
writable_folder_policies.push(format!("(subpath (param \"{root_param}\"))"));
|
||||
} else {
|
||||
// Add parameters for each read-only subpath and generate
|
||||
// the `(require-not ...)` clauses.
|
||||
let mut require_parts: Vec<String> = Vec::new();
|
||||
require_parts.push(format!("(subpath (param \"{root_param}\"))"));
|
||||
for (subpath_index, ro) in wr.read_only_subpaths.iter().enumerate() {
|
||||
let canonical_ro = ro.canonicalize().unwrap_or_else(|_| ro.clone());
|
||||
let ro_param = format!("WRITABLE_ROOT_{index}_RO_{subpath_index}");
|
||||
cli_args.push(format!("-D{ro_param}={}", canonical_ro.to_string_lossy()));
|
||||
require_parts
|
||||
.push(format!("(require-not (subpath (param \"{ro_param}\")))"));
|
||||
}
|
||||
let policy_component = format!("(require-all {} )", require_parts.join(" "));
|
||||
writable_folder_policies.push(policy_component);
|
||||
}
|
||||
}
|
||||
|
||||
if writable_folder_policies.is_empty() {
|
||||
("".to_string(), Vec::<String>::new())
|
||||
} else {
|
||||
@@ -88,9 +112,201 @@ fn create_seatbelt_command_args(
|
||||
let full_policy = format!(
|
||||
"{MACOS_SEATBELT_BASE_POLICY}\n{file_read_policy}\n{file_write_policy}\n{network_policy}"
|
||||
);
|
||||
|
||||
let mut seatbelt_args: Vec<String> = vec!["-p".to_string(), full_policy];
|
||||
seatbelt_args.extend(extra_cli_args);
|
||||
seatbelt_args.push("--".to_string());
|
||||
seatbelt_args.extend(command);
|
||||
seatbelt_args
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![expect(clippy::expect_used)]
|
||||
use super::MACOS_SEATBELT_BASE_POLICY;
|
||||
use super::create_seatbelt_command_args;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn create_seatbelt_args_with_read_only_git_subpath() {
|
||||
// Create a temporary workspace with two writable roots: one containing
|
||||
// a top-level .git directory and one without it.
|
||||
let tmp = TempDir::new().expect("tempdir");
|
||||
let PopulatedTmp {
|
||||
root_with_git,
|
||||
root_without_git,
|
||||
root_with_git_canon,
|
||||
root_with_git_git_canon,
|
||||
root_without_git_canon,
|
||||
} = populate_tmpdir(tmp.path());
|
||||
|
||||
// Build a policy that only includes the two test roots as writable and
|
||||
// does not automatically include defaults like cwd or TMPDIR.
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![root_with_git.clone(), root_without_git.clone()],
|
||||
network_access: false,
|
||||
include_default_writable_roots: false,
|
||||
};
|
||||
|
||||
let args = create_seatbelt_command_args(
|
||||
vec!["/bin/echo".to_string(), "hello".to_string()],
|
||||
&policy,
|
||||
tmp.path(),
|
||||
);
|
||||
|
||||
// Build the expected policy text using a raw string for readability.
|
||||
// Note that the policy includes:
|
||||
// - the base policy,
|
||||
// - read-only access to the filesystem,
|
||||
// - write access to WRITABLE_ROOT_0 (but not its .git) and WRITABLE_ROOT_1.
|
||||
let expected_policy = format!(
|
||||
r#"{MACOS_SEATBELT_BASE_POLICY}
|
||||
; allow read-only file operations
|
||||
(allow file-read*)
|
||||
(allow file-write*
|
||||
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ) (subpath (param "WRITABLE_ROOT_1"))
|
||||
)
|
||||
"#,
|
||||
);
|
||||
|
||||
let expected_args = vec![
|
||||
"-p".to_string(),
|
||||
expected_policy,
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_0={}",
|
||||
root_with_git_canon.to_string_lossy()
|
||||
),
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_0_RO_0={}",
|
||||
root_with_git_git_canon.to_string_lossy()
|
||||
),
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_1={}",
|
||||
root_without_git_canon.to_string_lossy()
|
||||
),
|
||||
"--".to_string(),
|
||||
"/bin/echo".to_string(),
|
||||
"hello".to_string(),
|
||||
];
|
||||
|
||||
assert_eq!(args, expected_args);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_seatbelt_args_for_cwd_as_git_repo() {
|
||||
// Create a temporary workspace with two writable roots: one containing
|
||||
// a top-level .git directory and one without it.
|
||||
let tmp = TempDir::new().expect("tempdir");
|
||||
let PopulatedTmp {
|
||||
root_with_git,
|
||||
root_with_git_canon,
|
||||
root_with_git_git_canon,
|
||||
..
|
||||
} = populate_tmpdir(tmp.path());
|
||||
|
||||
// Build a policy that does not specify any writable_roots, but does
|
||||
// use the default ones (cwd and TMPDIR) and verifies the `.git` check
|
||||
// is done properly for cwd.
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![],
|
||||
network_access: false,
|
||||
include_default_writable_roots: true,
|
||||
};
|
||||
|
||||
let args = create_seatbelt_command_args(
|
||||
vec!["/bin/echo".to_string(), "hello".to_string()],
|
||||
&policy,
|
||||
root_with_git.as_path(),
|
||||
);
|
||||
|
||||
let tmpdir_env_var = if cfg!(target_os = "macos") {
|
||||
std::env::var("TMPDIR")
|
||||
.ok()
|
||||
.map(PathBuf::from)
|
||||
.and_then(|p| p.canonicalize().ok())
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tempdir_policy_entry = if tmpdir_env_var.is_some() {
|
||||
" (subpath (param \"WRITABLE_ROOT_1\"))"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
// Build the expected policy text using a raw string for readability.
|
||||
// Note that the policy includes:
|
||||
// - the base policy,
|
||||
// - read-only access to the filesystem,
|
||||
// - write access to WRITABLE_ROOT_0 (but not its .git) and WRITABLE_ROOT_1.
|
||||
let expected_policy = format!(
|
||||
r#"{MACOS_SEATBELT_BASE_POLICY}
|
||||
; allow read-only file operations
|
||||
(allow file-read*)
|
||||
(allow file-write*
|
||||
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ){tempdir_policy_entry}
|
||||
)
|
||||
"#,
|
||||
);
|
||||
|
||||
let mut expected_args = vec![
|
||||
"-p".to_string(),
|
||||
expected_policy,
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_0={}",
|
||||
root_with_git_canon.to_string_lossy()
|
||||
),
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_0_RO_0={}",
|
||||
root_with_git_git_canon.to_string_lossy()
|
||||
),
|
||||
];
|
||||
|
||||
if let Some(p) = tmpdir_env_var {
|
||||
expected_args.push(format!("-DWRITABLE_ROOT_1={p}"));
|
||||
}
|
||||
|
||||
expected_args.extend(vec![
|
||||
"--".to_string(),
|
||||
"/bin/echo".to_string(),
|
||||
"hello".to_string(),
|
||||
]);
|
||||
|
||||
assert_eq!(args, expected_args);
|
||||
}
|
||||
|
||||
struct PopulatedTmp {
|
||||
root_with_git: PathBuf,
|
||||
root_without_git: PathBuf,
|
||||
root_with_git_canon: PathBuf,
|
||||
root_with_git_git_canon: PathBuf,
|
||||
root_without_git_canon: PathBuf,
|
||||
}
|
||||
|
||||
fn populate_tmpdir(tmp: &Path) -> PopulatedTmp {
|
||||
let root_with_git = tmp.join("with_git");
|
||||
let root_without_git = tmp.join("no_git");
|
||||
fs::create_dir_all(&root_with_git).expect("create with_git");
|
||||
fs::create_dir_all(&root_without_git).expect("create no_git");
|
||||
fs::create_dir_all(root_with_git.join(".git")).expect("create .git");
|
||||
|
||||
// Ensure we have canonical paths for -D parameter matching.
|
||||
let root_with_git_canon = root_with_git.canonicalize().expect("canonicalize with_git");
|
||||
let root_with_git_git_canon = root_with_git_canon.join(".git");
|
||||
let root_without_git_canon = root_without_git
|
||||
.canonicalize()
|
||||
.expect("canonicalize no_git");
|
||||
PopulatedTmp {
|
||||
root_with_git,
|
||||
root_without_git,
|
||||
root_with_git_canon,
|
||||
root_with_git_git_canon,
|
||||
root_without_git_canon,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,3 +65,7 @@
|
||||
(sysctl-name "sysctl.proc_cputype")
|
||||
(sysctl-name-prefix "hw.perflevel")
|
||||
)
|
||||
|
||||
; Added on top of Chrome profile
|
||||
; Needed for python multiprocessing on MacOS for the SemLock
|
||||
(allow ipc-posix-sem)
|
||||
|
||||
@@ -215,11 +215,14 @@ mod tests {
|
||||
"HOME".to_string(),
|
||||
temp_home.path().to_str().unwrap().to_string(),
|
||||
)]),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
},
|
||||
SandboxType::None,
|
||||
Arc::new(Notify::new()),
|
||||
&SandboxPolicy::DangerFullAccess,
|
||||
&None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -17,6 +17,11 @@ use crate::protocol::SandboxPolicy;
|
||||
/// attributes, so this may change in the future.
|
||||
pub const CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR: &str = "CODEX_SANDBOX_NETWORK_DISABLED";
|
||||
|
||||
/// Should be set when the process is spawned under a sandbox. Currently, the
|
||||
/// value is "seatbelt" for macOS, but it may change in the future to
|
||||
/// accommodate sandboxing configuration and other sandboxing mechanisms.
|
||||
pub const CODEX_SANDBOX_ENV_VAR: &str = "CODEX_SANDBOX";
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum StdioPolicy {
|
||||
RedirectForShellTool,
|
||||
|
||||
887
codex-rs/core/src/turn_diff_tracker.rs
Normal file
887
codex-rs/core/src/turn_diff_tracker.rs
Normal file
@@ -0,0 +1,887 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use sha1::digest::Output;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::protocol::FileChange;
|
||||
|
||||
const ZERO_OID: &str = "0000000000000000000000000000000000000000";
|
||||
const DEV_NULL: &str = "/dev/null";
|
||||
|
||||
struct BaselineFileInfo {
|
||||
path: PathBuf,
|
||||
content: Vec<u8>,
|
||||
mode: FileMode,
|
||||
oid: String,
|
||||
}
|
||||
|
||||
/// Tracks sets of changes to files and exposes the overall unified diff.
|
||||
/// Internally, the way this works is now:
|
||||
/// 1. Maintain an in-memory baseline snapshot of files when they are first seen.
|
||||
/// For new additions, do not create a baseline so that diffs are shown as proper additions (using /dev/null).
|
||||
/// 2. Keep a stable internal filename (uuid) per external path for rename tracking.
|
||||
/// 3. To compute the aggregated unified diff, compare each baseline snapshot to the current file on disk entirely in-memory
|
||||
/// using the `similar` crate and emit unified diffs with rewritten external paths.
|
||||
#[derive(Default)]
|
||||
pub struct TurnDiffTracker {
|
||||
/// Map external path -> internal filename (uuid).
|
||||
external_to_temp_name: HashMap<PathBuf, String>,
|
||||
/// Internal filename -> baseline file info.
|
||||
baseline_file_info: HashMap<String, BaselineFileInfo>,
|
||||
/// Internal filename -> external path as of current accumulated state (after applying all changes).
|
||||
/// This is where renames are tracked.
|
||||
temp_name_to_current_path: HashMap<String, PathBuf>,
|
||||
/// Cache of known git worktree roots to avoid repeated filesystem walks.
|
||||
git_root_cache: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl TurnDiffTracker {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Front-run apply patch calls to track the starting contents of any modified files.
|
||||
/// - Creates an in-memory baseline snapshot for files that already exist on disk when first seen.
|
||||
/// - For additions, we intentionally do not create a baseline snapshot so that diffs are proper additions.
|
||||
/// - Also updates internal mappings for move/rename events.
|
||||
pub fn on_patch_begin(&mut self, changes: &HashMap<PathBuf, FileChange>) {
|
||||
for (path, change) in changes.iter() {
|
||||
// Ensure a stable internal filename exists for this external path.
|
||||
if !self.external_to_temp_name.contains_key(path) {
|
||||
let internal = Uuid::new_v4().to_string();
|
||||
self.external_to_temp_name
|
||||
.insert(path.clone(), internal.clone());
|
||||
self.temp_name_to_current_path
|
||||
.insert(internal.clone(), path.clone());
|
||||
|
||||
// If the file exists on disk now, snapshot as baseline; else leave missing to represent /dev/null.
|
||||
let baseline_file_info = if path.exists() {
|
||||
let mode = file_mode_for_path(path);
|
||||
let mode_val = mode.unwrap_or(FileMode::Regular);
|
||||
let content = blob_bytes(path, &mode_val).unwrap_or_default();
|
||||
let oid = if mode == Some(FileMode::Symlink) {
|
||||
format!("{:x}", git_blob_sha1_hex_bytes(&content))
|
||||
} else {
|
||||
self.git_blob_oid_for_path(path)
|
||||
.unwrap_or_else(|| format!("{:x}", git_blob_sha1_hex_bytes(&content)))
|
||||
};
|
||||
Some(BaselineFileInfo {
|
||||
path: path.clone(),
|
||||
content,
|
||||
mode: mode_val,
|
||||
oid,
|
||||
})
|
||||
} else {
|
||||
Some(BaselineFileInfo {
|
||||
path: path.clone(),
|
||||
content: vec![],
|
||||
mode: FileMode::Regular,
|
||||
oid: ZERO_OID.to_string(),
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(baseline_file_info) = baseline_file_info {
|
||||
self.baseline_file_info
|
||||
.insert(internal.clone(), baseline_file_info);
|
||||
}
|
||||
}
|
||||
|
||||
// Track rename/move in current mapping if provided in an Update.
|
||||
if let FileChange::Update {
|
||||
move_path: Some(dest),
|
||||
..
|
||||
} = change
|
||||
{
|
||||
let uuid_filename = match self.external_to_temp_name.get(path) {
|
||||
Some(i) => i.clone(),
|
||||
None => {
|
||||
// This should be rare, but if we haven't mapped the source, create it with no baseline.
|
||||
let i = Uuid::new_v4().to_string();
|
||||
self.baseline_file_info.insert(
|
||||
i.clone(),
|
||||
BaselineFileInfo {
|
||||
path: path.clone(),
|
||||
content: vec![],
|
||||
mode: FileMode::Regular,
|
||||
oid: ZERO_OID.to_string(),
|
||||
},
|
||||
);
|
||||
i
|
||||
}
|
||||
};
|
||||
// Update current external mapping for temp file name.
|
||||
self.temp_name_to_current_path
|
||||
.insert(uuid_filename.clone(), dest.clone());
|
||||
// Update forward file_mapping: external current -> internal name.
|
||||
self.external_to_temp_name.remove(path);
|
||||
self.external_to_temp_name
|
||||
.insert(dest.clone(), uuid_filename);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn get_path_for_internal(&self, internal: &str) -> Option<PathBuf> {
|
||||
self.temp_name_to_current_path
|
||||
.get(internal)
|
||||
.cloned()
|
||||
.or_else(|| {
|
||||
self.baseline_file_info
|
||||
.get(internal)
|
||||
.map(|info| info.path.clone())
|
||||
})
|
||||
}
|
||||
|
||||
/// Find the git worktree root for a file/directory by walking up to the first ancestor containing a `.git` entry.
|
||||
/// Uses a simple cache of known roots and avoids negative-result caching for simplicity.
|
||||
fn find_git_root_cached(&mut self, start: &Path) -> Option<PathBuf> {
|
||||
let dir = if start.is_dir() {
|
||||
start
|
||||
} else {
|
||||
start.parent()?
|
||||
};
|
||||
|
||||
// Fast path: if any cached root is an ancestor of this path, use it.
|
||||
if let Some(root) = self
|
||||
.git_root_cache
|
||||
.iter()
|
||||
.find(|r| dir.starts_with(r))
|
||||
.cloned()
|
||||
{
|
||||
return Some(root);
|
||||
}
|
||||
|
||||
// Walk up to find a `.git` marker.
|
||||
let mut cur = dir.to_path_buf();
|
||||
loop {
|
||||
let git_marker = cur.join(".git");
|
||||
if git_marker.is_dir() || git_marker.is_file() {
|
||||
if !self.git_root_cache.iter().any(|r| r == &cur) {
|
||||
self.git_root_cache.push(cur.clone());
|
||||
}
|
||||
return Some(cur);
|
||||
}
|
||||
|
||||
// On Windows, avoid walking above the drive or UNC share root.
|
||||
#[cfg(windows)]
|
||||
{
|
||||
if is_windows_drive_or_unc_root(&cur) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(parent) = cur.parent() {
|
||||
cur = parent.to_path_buf();
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a display string for `path` relative to its git root if found, else absolute.
|
||||
fn relative_to_git_root_str(&mut self, path: &Path) -> String {
|
||||
let s = if let Some(root) = self.find_git_root_cached(path) {
|
||||
if let Ok(rel) = path.strip_prefix(&root) {
|
||||
rel.display().to_string()
|
||||
} else {
|
||||
path.display().to_string()
|
||||
}
|
||||
} else {
|
||||
path.display().to_string()
|
||||
};
|
||||
s.replace('\\', "/")
|
||||
}
|
||||
|
||||
/// Ask git to compute the blob SHA-1 for the file at `path` within its repository.
|
||||
/// Returns None if no repository is found or git invocation fails.
|
||||
fn git_blob_oid_for_path(&mut self, path: &Path) -> Option<String> {
|
||||
let root = self.find_git_root_cached(path)?;
|
||||
// Compute a path relative to the repo root for better portability across platforms.
|
||||
let rel = path.strip_prefix(&root).unwrap_or(path);
|
||||
let output = Command::new("git")
|
||||
.arg("-C")
|
||||
.arg(&root)
|
||||
.arg("hash-object")
|
||||
.arg("--")
|
||||
.arg(rel)
|
||||
.output()
|
||||
.ok()?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
let s = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if s.len() == 40 { Some(s) } else { None }
|
||||
}
|
||||
|
||||
/// Recompute the aggregated unified diff by comparing all of the in-memory snapshots that were
|
||||
/// collected before the first time they were touched by apply_patch during this turn with
|
||||
/// the current repo state.
|
||||
pub fn get_unified_diff(&mut self) -> Result<Option<String>> {
|
||||
let mut aggregated = String::new();
|
||||
|
||||
// Compute diffs per tracked internal file in a stable order by external path.
|
||||
let mut baseline_file_names: Vec<String> =
|
||||
self.baseline_file_info.keys().cloned().collect();
|
||||
// Sort lexicographically by full repo-relative path to match git behavior.
|
||||
baseline_file_names.sort_by_key(|internal| {
|
||||
self.get_path_for_internal(internal)
|
||||
.map(|p| self.relative_to_git_root_str(&p))
|
||||
.unwrap_or_default()
|
||||
});
|
||||
|
||||
for internal in baseline_file_names {
|
||||
aggregated.push_str(self.get_file_diff(&internal).as_str());
|
||||
if !aggregated.ends_with('\n') {
|
||||
aggregated.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
if aggregated.trim().is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(aggregated))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_file_diff(&mut self, internal_file_name: &str) -> String {
|
||||
let mut aggregated = String::new();
|
||||
|
||||
// Snapshot lightweight fields only.
|
||||
let (baseline_external_path, baseline_mode, left_oid) = {
|
||||
if let Some(info) = self.baseline_file_info.get(internal_file_name) {
|
||||
(info.path.clone(), info.mode, info.oid.clone())
|
||||
} else {
|
||||
(PathBuf::new(), FileMode::Regular, ZERO_OID.to_string())
|
||||
}
|
||||
};
|
||||
let current_external_path = match self.get_path_for_internal(internal_file_name) {
|
||||
Some(p) => p,
|
||||
None => return aggregated,
|
||||
};
|
||||
|
||||
let current_mode = file_mode_for_path(¤t_external_path).unwrap_or(FileMode::Regular);
|
||||
let right_bytes = blob_bytes(¤t_external_path, ¤t_mode);
|
||||
|
||||
// Compute displays with &mut self before borrowing any baseline content.
|
||||
let left_display = self.relative_to_git_root_str(&baseline_external_path);
|
||||
let right_display = self.relative_to_git_root_str(¤t_external_path);
|
||||
|
||||
// Compute right oid before borrowing baseline content.
|
||||
let right_oid = if let Some(b) = right_bytes.as_ref() {
|
||||
if current_mode == FileMode::Symlink {
|
||||
format!("{:x}", git_blob_sha1_hex_bytes(b))
|
||||
} else {
|
||||
self.git_blob_oid_for_path(¤t_external_path)
|
||||
.unwrap_or_else(|| format!("{:x}", git_blob_sha1_hex_bytes(b)))
|
||||
}
|
||||
} else {
|
||||
ZERO_OID.to_string()
|
||||
};
|
||||
|
||||
// Borrow baseline content only after all &mut self uses are done.
|
||||
let left_present = left_oid.as_str() != ZERO_OID;
|
||||
let left_bytes: Option<&[u8]> = if left_present {
|
||||
self.baseline_file_info
|
||||
.get(internal_file_name)
|
||||
.map(|i| i.content.as_slice())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Fast path: identical bytes or both missing.
|
||||
if left_bytes == right_bytes.as_deref() {
|
||||
return aggregated;
|
||||
}
|
||||
|
||||
aggregated.push_str(&format!("diff --git a/{left_display} b/{right_display}\n"));
|
||||
|
||||
let is_add = !left_present && right_bytes.is_some();
|
||||
let is_delete = left_present && right_bytes.is_none();
|
||||
|
||||
if is_add {
|
||||
aggregated.push_str(&format!("new file mode {current_mode}\n"));
|
||||
} else if is_delete {
|
||||
aggregated.push_str(&format!("deleted file mode {baseline_mode}\n"));
|
||||
} else if baseline_mode != current_mode {
|
||||
aggregated.push_str(&format!("old mode {baseline_mode}\n"));
|
||||
aggregated.push_str(&format!("new mode {current_mode}\n"));
|
||||
}
|
||||
|
||||
let left_text = left_bytes.and_then(|b| std::str::from_utf8(b).ok());
|
||||
let right_text = right_bytes
|
||||
.as_deref()
|
||||
.and_then(|b| std::str::from_utf8(b).ok());
|
||||
|
||||
let can_text_diff = matches!(
|
||||
(left_text, right_text, is_add, is_delete),
|
||||
(Some(_), Some(_), _, _) | (_, Some(_), true, _) | (Some(_), _, _, true)
|
||||
);
|
||||
|
||||
if can_text_diff {
|
||||
let l = left_text.unwrap_or("");
|
||||
let r = right_text.unwrap_or("");
|
||||
|
||||
aggregated.push_str(&format!("index {left_oid}..{right_oid}\n"));
|
||||
|
||||
let old_header = if left_present {
|
||||
format!("a/{left_display}")
|
||||
} else {
|
||||
DEV_NULL.to_string()
|
||||
};
|
||||
let new_header = if right_bytes.is_some() {
|
||||
format!("b/{right_display}")
|
||||
} else {
|
||||
DEV_NULL.to_string()
|
||||
};
|
||||
|
||||
let diff = similar::TextDiff::from_lines(l, r);
|
||||
let unified = diff
|
||||
.unified_diff()
|
||||
.context_radius(3)
|
||||
.header(&old_header, &new_header)
|
||||
.to_string();
|
||||
|
||||
aggregated.push_str(&unified);
|
||||
} else {
|
||||
aggregated.push_str(&format!("index {left_oid}..{right_oid}\n"));
|
||||
let old_header = if left_present {
|
||||
format!("a/{left_display}")
|
||||
} else {
|
||||
DEV_NULL.to_string()
|
||||
};
|
||||
let new_header = if right_bytes.is_some() {
|
||||
format!("b/{right_display}")
|
||||
} else {
|
||||
DEV_NULL.to_string()
|
||||
};
|
||||
aggregated.push_str(&format!("--- {old_header}\n"));
|
||||
aggregated.push_str(&format!("+++ {new_header}\n"));
|
||||
aggregated.push_str("Binary files differ\n");
|
||||
}
|
||||
aggregated
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the Git SHA-1 blob object ID for the given content (bytes).
|
||||
fn git_blob_sha1_hex_bytes(data: &[u8]) -> Output<sha1::Sha1> {
|
||||
// Git blob hash is sha1 of: "blob <len>\0<data>"
|
||||
let header = format!("blob {}\0", data.len());
|
||||
use sha1::Digest;
|
||||
let mut hasher = sha1::Sha1::new();
|
||||
hasher.update(header.as_bytes());
|
||||
hasher.update(data);
|
||||
hasher.finalize()
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum FileMode {
|
||||
Regular,
|
||||
#[cfg(unix)]
|
||||
Executable,
|
||||
Symlink,
|
||||
}
|
||||
|
||||
impl FileMode {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
FileMode::Regular => "100644",
|
||||
#[cfg(unix)]
|
||||
FileMode::Executable => "100755",
|
||||
FileMode::Symlink => "120000",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for FileMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn file_mode_for_path(path: &Path) -> Option<FileMode> {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let meta = fs::symlink_metadata(path).ok()?;
|
||||
let ft = meta.file_type();
|
||||
if ft.is_symlink() {
|
||||
return Some(FileMode::Symlink);
|
||||
}
|
||||
let mode = meta.permissions().mode();
|
||||
let is_exec = (mode & 0o111) != 0;
|
||||
Some(if is_exec {
|
||||
FileMode::Executable
|
||||
} else {
|
||||
FileMode::Regular
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn file_mode_for_path(_path: &Path) -> Option<FileMode> {
|
||||
// Default to non-executable on non-unix.
|
||||
Some(FileMode::Regular)
|
||||
}
|
||||
|
||||
fn blob_bytes(path: &Path, mode: &FileMode) -> Option<Vec<u8>> {
|
||||
if path.exists() {
|
||||
let contents = if *mode == FileMode::Symlink {
|
||||
symlink_blob_bytes(path)
|
||||
.ok_or_else(|| anyhow!("failed to read symlink target for {}", path.display()))
|
||||
} else {
|
||||
fs::read(path)
|
||||
.with_context(|| format!("failed to read current file for diff {}", path.display()))
|
||||
};
|
||||
contents.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn symlink_blob_bytes(path: &Path) -> Option<Vec<u8>> {
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
let target = std::fs::read_link(path).ok()?;
|
||||
Some(target.as_os_str().as_bytes().to_vec())
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn symlink_blob_bytes(_path: &Path) -> Option<Vec<u8>> {
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn is_windows_drive_or_unc_root(p: &std::path::Path) -> bool {
|
||||
use std::path::Component;
|
||||
let mut comps = p.components();
|
||||
matches!(
|
||||
(comps.next(), comps.next(), comps.next()),
|
||||
(Some(Component::Prefix(_)), Some(Component::RootDir), None)
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// Compute the Git SHA-1 blob object ID for the given content (string).
|
||||
/// This delegates to the bytes version to avoid UTF-8 lossy conversions here.
|
||||
fn git_blob_sha1_hex(data: &str) -> String {
|
||||
format!("{:x}", git_blob_sha1_hex_bytes(data.as_bytes()))
|
||||
}
|
||||
|
||||
fn normalize_diff_for_test(input: &str, root: &Path) -> String {
|
||||
let root_str = root.display().to_string().replace('\\', "/");
|
||||
let replaced = input.replace(&root_str, "<TMP>");
|
||||
// Split into blocks on lines starting with "diff --git ", sort blocks for determinism, and rejoin
|
||||
let mut blocks: Vec<String> = Vec::new();
|
||||
let mut current = String::new();
|
||||
for line in replaced.lines() {
|
||||
if line.starts_with("diff --git ") && !current.is_empty() {
|
||||
blocks.push(current);
|
||||
current = String::new();
|
||||
}
|
||||
if !current.is_empty() {
|
||||
current.push('\n');
|
||||
}
|
||||
current.push_str(line);
|
||||
}
|
||||
if !current.is_empty() {
|
||||
blocks.push(current);
|
||||
}
|
||||
blocks.sort();
|
||||
let mut out = blocks.join("\n");
|
||||
if !out.ends_with('\n') {
|
||||
out.push('\n');
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulates_add_and_update() {
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("a.txt");
|
||||
|
||||
// First patch: add file (baseline should be /dev/null).
|
||||
let add_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Add {
|
||||
content: "foo\n".to_string(),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&add_changes);
|
||||
|
||||
// Simulate apply: create the file on disk.
|
||||
fs::write(&file, "foo\n").unwrap();
|
||||
let first = acc.get_unified_diff().unwrap().unwrap();
|
||||
let first = normalize_diff_for_test(&first, dir.path());
|
||||
let expected_first = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("foo\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/a.txt
|
||||
@@ -0,0 +1 @@
|
||||
+foo
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(first, expected_first);
|
||||
|
||||
// Second patch: update the file on disk.
|
||||
let update_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_changes);
|
||||
|
||||
// Simulate apply: append a new line.
|
||||
fs::write(&file, "foo\nbar\n").unwrap();
|
||||
let combined = acc.get_unified_diff().unwrap().unwrap();
|
||||
let combined = normalize_diff_for_test(&combined, dir.path());
|
||||
let expected_combined = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("foo\nbar\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/a.txt
|
||||
@@ -0,0 +1,2 @@
|
||||
+foo
|
||||
+bar
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(combined, expected_combined);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulates_delete() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("b.txt");
|
||||
fs::write(&file, "x\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let del_changes = HashMap::from([(file.clone(), FileChange::Delete)]);
|
||||
acc.on_patch_begin(&del_changes);
|
||||
|
||||
// Simulate apply: delete the file from disk.
|
||||
let baseline_mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
fs::remove_file(&file).unwrap();
|
||||
let diff = acc.get_unified_diff().unwrap().unwrap();
|
||||
let diff = normalize_diff_for_test(&diff, dir.path());
|
||||
let expected = {
|
||||
let left_oid = git_blob_sha1_hex("x\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/b.txt b/<TMP>/b.txt
|
||||
deleted file mode {baseline_mode}
|
||||
index {left_oid}..{ZERO_OID}
|
||||
--- a/<TMP>/b.txt
|
||||
+++ {DEV_NULL}
|
||||
@@ -1 +0,0 @@
|
||||
-x
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(diff, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulates_move_and_update() {
|
||||
let dir = tempdir().unwrap();
|
||||
let src = dir.path().join("src.txt");
|
||||
let dest = dir.path().join("dst.txt");
|
||||
fs::write(&src, "line\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv_changes = HashMap::from([(
|
||||
src.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: Some(dest.clone()),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&mv_changes);
|
||||
|
||||
// Simulate apply: move and update content.
|
||||
fs::rename(&src, &dest).unwrap();
|
||||
fs::write(&dest, "line2\n").unwrap();
|
||||
|
||||
let out = acc.get_unified_diff().unwrap().unwrap();
|
||||
let out = normalize_diff_for_test(&out, dir.path());
|
||||
let expected = {
|
||||
let left_oid = git_blob_sha1_hex("line\n");
|
||||
let right_oid = git_blob_sha1_hex("line2\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/src.txt b/<TMP>/dst.txt
|
||||
index {left_oid}..{right_oid}
|
||||
--- a/<TMP>/src.txt
|
||||
+++ b/<TMP>/dst.txt
|
||||
@@ -1 +1 @@
|
||||
-line
|
||||
+line2
|
||||
"#
|
||||
)
|
||||
};
|
||||
assert_eq!(out, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn move_without_1change_yields_no_diff() {
|
||||
let dir = tempdir().unwrap();
|
||||
let src = dir.path().join("moved.txt");
|
||||
let dest = dir.path().join("renamed.txt");
|
||||
fs::write(&src, "same\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv_changes = HashMap::from([(
|
||||
src.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: Some(dest.clone()),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&mv_changes);
|
||||
|
||||
// Simulate apply: move only, no content change.
|
||||
fs::rename(&src, &dest).unwrap();
|
||||
|
||||
let diff = acc.get_unified_diff().unwrap();
|
||||
assert_eq!(diff, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn move_declared_but_file_only_appears_at_dest_is_add() {
|
||||
let dir = tempdir().unwrap();
|
||||
let src = dir.path().join("src.txt");
|
||||
let dest = dir.path().join("dest.txt");
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv = HashMap::from([(
|
||||
src.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".into(),
|
||||
move_path: Some(dest.clone()),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&mv);
|
||||
// No file existed initially; create only dest
|
||||
fs::write(&dest, "hello\n").unwrap();
|
||||
let diff = acc.get_unified_diff().unwrap().unwrap();
|
||||
let diff = normalize_diff_for_test(&diff, dir.path());
|
||||
let expected = {
|
||||
let mode = file_mode_for_path(&dest).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("hello\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/src.txt b/<TMP>/dest.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/dest.txt
|
||||
@@ -0,0 +1 @@
|
||||
+hello
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(diff, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_persists_across_new_baseline_for_new_file() {
|
||||
let dir = tempdir().unwrap();
|
||||
let a = dir.path().join("a.txt");
|
||||
let b = dir.path().join("b.txt");
|
||||
fs::write(&a, "foo\n").unwrap();
|
||||
fs::write(&b, "z\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
|
||||
// First: update existing a.txt (baseline snapshot is created for a).
|
||||
let update_a = HashMap::from([(
|
||||
a.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_a);
|
||||
// Simulate apply: modify a.txt on disk.
|
||||
fs::write(&a, "foo\nbar\n").unwrap();
|
||||
let first = acc.get_unified_diff().unwrap().unwrap();
|
||||
let first = normalize_diff_for_test(&first, dir.path());
|
||||
let expected_first = {
|
||||
let left_oid = git_blob_sha1_hex("foo\n");
|
||||
let right_oid = git_blob_sha1_hex("foo\nbar\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
|
||||
index {left_oid}..{right_oid}
|
||||
--- a/<TMP>/a.txt
|
||||
+++ b/<TMP>/a.txt
|
||||
@@ -1 +1,2 @@
|
||||
foo
|
||||
+bar
|
||||
"#
|
||||
)
|
||||
};
|
||||
assert_eq!(first, expected_first);
|
||||
|
||||
// Next: introduce a brand-new path b.txt into baseline snapshots via a delete change.
|
||||
let del_b = HashMap::from([(b.clone(), FileChange::Delete)]);
|
||||
acc.on_patch_begin(&del_b);
|
||||
// Simulate apply: delete b.txt.
|
||||
let baseline_mode = file_mode_for_path(&b).unwrap_or(FileMode::Regular);
|
||||
fs::remove_file(&b).unwrap();
|
||||
|
||||
let combined = acc.get_unified_diff().unwrap().unwrap();
|
||||
let combined = normalize_diff_for_test(&combined, dir.path());
|
||||
let expected = {
|
||||
let left_oid_a = git_blob_sha1_hex("foo\n");
|
||||
let right_oid_a = git_blob_sha1_hex("foo\nbar\n");
|
||||
let left_oid_b = git_blob_sha1_hex("z\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
|
||||
index {left_oid_a}..{right_oid_a}
|
||||
--- a/<TMP>/a.txt
|
||||
+++ b/<TMP>/a.txt
|
||||
@@ -1 +1,2 @@
|
||||
foo
|
||||
+bar
|
||||
diff --git a/<TMP>/b.txt b/<TMP>/b.txt
|
||||
deleted file mode {baseline_mode}
|
||||
index {left_oid_b}..{ZERO_OID}
|
||||
--- a/<TMP>/b.txt
|
||||
+++ {DEV_NULL}
|
||||
@@ -1 +0,0 @@
|
||||
-z
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(combined, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_files_differ_update() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("bin.dat");
|
||||
|
||||
// Initial non-UTF8 bytes
|
||||
let left_bytes: Vec<u8> = vec![0xff, 0xfe, 0xfd, 0x00];
|
||||
// Updated non-UTF8 bytes
|
||||
let right_bytes: Vec<u8> = vec![0x01, 0x02, 0x03, 0x00];
|
||||
|
||||
fs::write(&file, &left_bytes).unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let update_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_changes);
|
||||
|
||||
// Apply update on disk
|
||||
fs::write(&file, &right_bytes).unwrap();
|
||||
|
||||
let diff = acc.get_unified_diff().unwrap().unwrap();
|
||||
let diff = normalize_diff_for_test(&diff, dir.path());
|
||||
let expected = {
|
||||
let left_oid = format!("{:x}", git_blob_sha1_hex_bytes(&left_bytes));
|
||||
let right_oid = format!("{:x}", git_blob_sha1_hex_bytes(&right_bytes));
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/bin.dat b/<TMP>/bin.dat
|
||||
index {left_oid}..{right_oid}
|
||||
--- a/<TMP>/bin.dat
|
||||
+++ b/<TMP>/bin.dat
|
||||
Binary files differ
|
||||
"#
|
||||
)
|
||||
};
|
||||
assert_eq!(diff, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filenames_with_spaces_add_and_update() {
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("name with spaces.txt");
|
||||
|
||||
// First patch: add file (baseline should be /dev/null).
|
||||
let add_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Add {
|
||||
content: "foo\n".to_string(),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&add_changes);
|
||||
|
||||
// Simulate apply: create the file on disk.
|
||||
fs::write(&file, "foo\n").unwrap();
|
||||
let first = acc.get_unified_diff().unwrap().unwrap();
|
||||
let first = normalize_diff_for_test(&first, dir.path());
|
||||
let expected_first = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("foo\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/name with spaces.txt b/<TMP>/name with spaces.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/name with spaces.txt
|
||||
@@ -0,0 +1 @@
|
||||
+foo
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(first, expected_first);
|
||||
|
||||
// Second patch: update the file on disk.
|
||||
let update_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_changes);
|
||||
|
||||
// Simulate apply: append a new line with a space.
|
||||
fs::write(&file, "foo\nbar baz\n").unwrap();
|
||||
let combined = acc.get_unified_diff().unwrap().unwrap();
|
||||
let combined = normalize_diff_for_test(&combined, dir.path());
|
||||
let expected_combined = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
|
||||
let right_oid = git_blob_sha1_hex("foo\nbar baz\n");
|
||||
format!(
|
||||
r#"diff --git a/<TMP>/name with spaces.txt b/<TMP>/name with spaces.txt
|
||||
new file mode {mode}
|
||||
index {ZERO_OID}..{right_oid}
|
||||
--- {DEV_NULL}
|
||||
+++ b/<TMP>/name with spaces.txt
|
||||
@@ -0,0 +1,2 @@
|
||||
+foo
|
||||
+bar baz
|
||||
"#,
|
||||
)
|
||||
};
|
||||
assert_eq!(combined, expected_combined);
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
#![allow(clippy::expect_used)]
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use std::path::PathBuf;
|
||||
|
||||
use chrono::Utc;
|
||||
use codex_core::Codex;
|
||||
use codex_core::CodexSpawnOk;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
@@ -21,14 +24,42 @@ use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::header_regex;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
use wiremock::matchers::query_param;
|
||||
|
||||
/// Build minimal SSE stream with completed marker using the JSON fixture.
|
||||
fn sse_completed(id: &str) -> String {
|
||||
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
||||
}
|
||||
|
||||
fn assert_message_role(request_body: &serde_json::Value, role: &str) {
|
||||
assert_eq!(request_body["role"].as_str().unwrap(), role);
|
||||
}
|
||||
|
||||
fn assert_message_starts_with(request_body: &serde_json::Value, text: &str) {
|
||||
let content = request_body["content"][0]["text"]
|
||||
.as_str()
|
||||
.expect("invalid message content");
|
||||
|
||||
assert!(
|
||||
content.starts_with(text),
|
||||
"expected message content '{content}' to start with '{text}'"
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_message_ends_with(request_body: &serde_json::Value, text: &str) {
|
||||
let content = request_body["content"][0]["text"]
|
||||
.as_str()
|
||||
.expect("invalid message content");
|
||||
|
||||
assert!(
|
||||
content.ends_with(text),
|
||||
"expected message content '{content}' to end with '{text}'"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_session_id_and_model_headers_in_request() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
@@ -95,8 +126,8 @@ async fn includes_session_id_and_model_headers_in_request() {
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_session_id = request.headers.get("session_id").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
let request_authorization = request.headers.get("authorization").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
|
||||
assert!(current_session_id.is_some());
|
||||
assert_eq!(
|
||||
@@ -170,6 +201,59 @@ async fn includes_base_instructions_override_in_request() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn originator_config_override_is_used() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.internal_originator = Some("my_override".to_string());
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||
config,
|
||||
Some(CodexAuth::from_api_key("Test API Key".to_string())),
|
||||
ctrl_c.clone(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
assert_eq!(request_originator.to_str().unwrap(), "my_override");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn chatgpt_auth_sends_correct_request() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
@@ -235,8 +319,8 @@ async fn chatgpt_auth_sends_correct_request() {
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_session_id = request.headers.get("session_id").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
let request_authorization = request.headers.get("authorization").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
let request_chatgpt_account_id = request.headers.get("chatgpt-account-id").unwrap();
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
@@ -315,14 +399,168 @@ async fn includes_user_instructions_message_in_request() {
|
||||
.unwrap()
|
||||
.contains("be nice")
|
||||
);
|
||||
assert_eq!(request_body["input"][0]["role"], "user");
|
||||
assert!(
|
||||
request_body["input"][0]["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.starts_with("be nice")
|
||||
);
|
||||
assert_message_role(&request_body["input"][0], "user");
|
||||
assert_message_starts_with(&request_body["input"][0], "<environment_context>\n\n");
|
||||
assert_message_ends_with(&request_body["input"][0], "</environment_context>");
|
||||
assert_message_role(&request_body["input"][1], "user");
|
||||
assert_message_starts_with(&request_body["input"][1], "<user_instructions>\n\n");
|
||||
assert_message_ends_with(&request_body["input"][1], "</user_instructions>");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn azure_overrides_assign_properties_used_for_responses_url() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { "USER" };
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
// Expect POST to /openai/responses with api-version query param
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/openai/responses"))
|
||||
.and(query_param("api-version", "2025-04-01-preview"))
|
||||
.and(header_regex("Custom-Header", "Value"))
|
||||
.and(header_regex(
|
||||
"Authorization",
|
||||
format!(
|
||||
"Bearer {}",
|
||||
std::env::var(existing_env_var_with_random_value).unwrap()
|
||||
)
|
||||
.as_str(),
|
||||
))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "custom".to_string(),
|
||||
base_url: Some(format!("{}/openai", server.uri())),
|
||||
// Reuse the existing environment variable to avoid using unsafe code
|
||||
env_key: Some(existing_env_var_with_random_value.to_string()),
|
||||
query_params: Some(std::collections::HashMap::from([(
|
||||
"api-version".to_string(),
|
||||
"2025-04-01-preview".to_string(),
|
||||
)])),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
http_headers: Some(std::collections::HashMap::from([(
|
||||
"Custom-Header".to_string(),
|
||||
"Value".to_string(),
|
||||
)])),
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = provider;
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(config, None, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn env_var_overrides_loaded_auth() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { "USER" };
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
// Expect POST to /openai/responses with api-version query param
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/openai/responses"))
|
||||
.and(query_param("api-version", "2025-04-01-preview"))
|
||||
.and(header_regex("Custom-Header", "Value"))
|
||||
.and(header_regex(
|
||||
"Authorization",
|
||||
format!(
|
||||
"Bearer {}",
|
||||
std::env::var(existing_env_var_with_random_value).unwrap()
|
||||
)
|
||||
.as_str(),
|
||||
))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "custom".to_string(),
|
||||
base_url: Some(format!("{}/openai", server.uri())),
|
||||
// Reuse the existing environment variable to avoid using unsafe code
|
||||
env_key: Some(existing_env_var_with_random_value.to_string()),
|
||||
query_params: Some(std::collections::HashMap::from([(
|
||||
"api-version".to_string(),
|
||||
"2025-04-01-preview".to_string(),
|
||||
)])),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
http_headers: Some(std::collections::HashMap::from([(
|
||||
"Custom-Header".to_string(),
|
||||
"Value".to_string(),
|
||||
)])),
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = provider;
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||
config,
|
||||
Some(auth_from_token("Default Access Token".to_string())),
|
||||
ctrl_c.clone(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
fn auth_from_token(id_token: String) -> CodexAuth {
|
||||
CodexAuth::new(
|
||||
None,
|
||||
|
||||
254
codex-rs/core/tests/compact.rs
Normal file
254
codex-rs/core/tests/compact.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
#![expect(clippy::unwrap_used)]
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::CodexSpawnOk;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_login::CodexAuth;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::Value;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
// --- Test helpers -----------------------------------------------------------
|
||||
|
||||
/// Build an SSE stream body from a list of JSON events.
|
||||
fn sse(events: Vec<Value>) -> String {
|
||||
use std::fmt::Write as _;
|
||||
let mut out = String::new();
|
||||
for ev in events {
|
||||
let kind = ev.get("type").and_then(|v| v.as_str()).unwrap();
|
||||
writeln!(&mut out, "event: {kind}").unwrap();
|
||||
if !ev.as_object().map(|o| o.len() == 1).unwrap_or(false) {
|
||||
write!(&mut out, "data: {ev}\n\n").unwrap();
|
||||
} else {
|
||||
out.push('\n');
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Convenience: SSE event for a completed response with a specific id.
|
||||
fn ev_completed(id: &str) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": id,
|
||||
"usage": {"input_tokens":0,"input_tokens_details":null,"output_tokens":0,"output_tokens_details":null,"total_tokens":0}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Convenience: SSE event for a single assistant message output item.
|
||||
fn ev_assistant_message(id: &str, text: &str) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"id": id,
|
||||
"content": [{"type": "output_text", "text": text}]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn sse_response(body: String) -> ResponseTemplate {
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(body, "text/event-stream")
|
||||
}
|
||||
|
||||
async fn mount_sse_once<M>(server: &MockServer, matcher: M, body: String)
|
||||
where
|
||||
M: wiremock::Match + Send + Sync + 'static,
|
||||
{
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(matcher)
|
||||
.respond_with(sse_response(body))
|
||||
.expect(1)
|
||||
.mount(server)
|
||||
.await;
|
||||
}
|
||||
|
||||
const FIRST_REPLY: &str = "FIRST_REPLY";
|
||||
const SUMMARY_TEXT: &str = "SUMMARY_ONLY_CONTEXT";
|
||||
const SUMMARIZE_TRIGGER: &str = "Start Summarization";
|
||||
const THIRD_USER_MSG: &str = "next turn";
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn summarize_context_three_requests_and_instructions() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Set up a mock server that we can inspect after the run.
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// SSE 1: assistant replies normally so it is recorded in history.
|
||||
let sse1 = sse(vec![
|
||||
ev_assistant_message("m1", FIRST_REPLY),
|
||||
ev_completed("r1"),
|
||||
]);
|
||||
|
||||
// SSE 2: summarizer returns a summary message.
|
||||
let sse2 = sse(vec![
|
||||
ev_assistant_message("m2", SUMMARY_TEXT),
|
||||
ev_completed("r2"),
|
||||
]);
|
||||
|
||||
// SSE 3: minimal completed; we only need to capture the request body.
|
||||
let sse3 = sse(vec![ev_completed("r3")]);
|
||||
|
||||
// Mount three expectations, one per request, matched by body content.
|
||||
let first_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains("\"text\":\"hello world\"")
|
||||
&& !body.contains(&format!("\"text\":\"{SUMMARIZE_TRIGGER}\""))
|
||||
};
|
||||
mount_sse_once(&server, first_matcher, sse1).await;
|
||||
|
||||
let second_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains(&format!("\"text\":\"{SUMMARIZE_TRIGGER}\""))
|
||||
};
|
||||
mount_sse_once(&server, second_matcher, sse2).await;
|
||||
|
||||
let third_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains(&format!("\"text\":\"{THIRD_USER_MSG}\""))
|
||||
};
|
||||
mount_sse_once(&server, third_matcher, sse3).await;
|
||||
|
||||
// Build config pointing to the mock server and spawn Codex.
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
let home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&home);
|
||||
config.model_provider = model_provider;
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||
config,
|
||||
Some(CodexAuth::from_api_key("dummy".to_string())),
|
||||
ctrl_c.clone(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// 1) Normal user input – should hit server once.
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello world".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// 2) Summarize – second hit with summarization instructions.
|
||||
codex.submit(Op::Compact).await.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// 3) Next user input – third hit; history should include only the summary.
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: THIRD_USER_MSG.into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// Inspect the three captured requests.
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(requests.len(), 3, "expected exactly three requests");
|
||||
|
||||
let req1 = &requests[0];
|
||||
let req2 = &requests[1];
|
||||
let req3 = &requests[2];
|
||||
|
||||
let body1 = req1.body_json::<serde_json::Value>().unwrap();
|
||||
let body2 = req2.body_json::<serde_json::Value>().unwrap();
|
||||
let body3 = req3.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
// System instructions should change for the summarization turn.
|
||||
let instr1 = body1.get("instructions").and_then(|v| v.as_str()).unwrap();
|
||||
let instr2 = body2.get("instructions").and_then(|v| v.as_str()).unwrap();
|
||||
assert_ne!(
|
||||
instr1, instr2,
|
||||
"summarization should override base instructions"
|
||||
);
|
||||
assert!(
|
||||
instr2.contains("You are a summarization assistant"),
|
||||
"summarization instructions not applied"
|
||||
);
|
||||
|
||||
// The summarization request should include the injected user input marker.
|
||||
let input2 = body2.get("input").and_then(|v| v.as_array()).unwrap();
|
||||
// The last item is the user message created from the injected input.
|
||||
let last2 = input2.last().unwrap();
|
||||
assert_eq!(last2.get("type").unwrap().as_str().unwrap(), "message");
|
||||
assert_eq!(last2.get("role").unwrap().as_str().unwrap(), "user");
|
||||
let text2 = last2["content"][0]["text"].as_str().unwrap();
|
||||
assert!(text2.contains(SUMMARIZE_TRIGGER));
|
||||
|
||||
// Third request must contain only the summary from step 2 as prior history plus new user msg.
|
||||
let input3 = body3.get("input").and_then(|v| v.as_array()).unwrap();
|
||||
println!("third request body: {body3}");
|
||||
assert!(
|
||||
input3.len() >= 2,
|
||||
"expected summary + new user message in third request"
|
||||
);
|
||||
|
||||
// Collect all (role, text) message tuples.
|
||||
let mut messages: Vec<(String, String)> = Vec::new();
|
||||
for item in input3 {
|
||||
if item["type"].as_str() == Some("message") {
|
||||
let role = item["role"].as_str().unwrap_or_default().to_string();
|
||||
let text = item["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
messages.push((role, text));
|
||||
}
|
||||
}
|
||||
|
||||
// Exactly one assistant message should remain after compaction and the new user message is present.
|
||||
let assistant_count = messages.iter().filter(|(r, _)| r == "assistant").count();
|
||||
assert_eq!(
|
||||
assistant_count, 1,
|
||||
"exactly one assistant message should remain after compaction"
|
||||
);
|
||||
assert!(
|
||||
messages
|
||||
.iter()
|
||||
.any(|(r, t)| r == "user" && t == THIRD_USER_MSG),
|
||||
"third request should include the new user message"
|
||||
);
|
||||
assert!(
|
||||
!messages.iter().any(|(_, t)| t.contains("hello world")),
|
||||
"third request should not include the original user input"
|
||||
);
|
||||
assert!(
|
||||
!messages.iter().any(|(_, t)| t.contains(SUMMARIZE_TRIGGER)),
|
||||
"third request should not include the summarize trigger"
|
||||
);
|
||||
}
|
||||
71
codex-rs/core/tests/exec.rs
Normal file
71
codex-rs/core/tests/exec.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
#![cfg(target_os = "macos")]
|
||||
#![expect(clippy::expect_used)]
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::exec::ExecParams;
|
||||
use codex_core::exec::SandboxType;
|
||||
use codex_core::exec::process_exec_tool_call;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_core::spawn::CODEX_SANDBOX_ENV_VAR;
|
||||
use tempfile::TempDir;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use codex_core::get_platform_sandbox;
|
||||
|
||||
async fn run_test_cmd(tmp: TempDir, cmd: Vec<&str>, should_be_ok: bool) {
|
||||
if std::env::var(CODEX_SANDBOX_ENV_VAR) == Ok("seatbelt".to_string()) {
|
||||
eprintln!("{CODEX_SANDBOX_ENV_VAR} is set to 'seatbelt', skipping test.");
|
||||
return;
|
||||
}
|
||||
|
||||
let sandbox_type = get_platform_sandbox().expect("should be able to get sandbox type");
|
||||
assert_eq!(sandbox_type, SandboxType::MacosSeatbelt);
|
||||
|
||||
let params = ExecParams {
|
||||
command: cmd.iter().map(|s| s.to_string()).collect(),
|
||||
cwd: tmp.path().to_path_buf(),
|
||||
timeout_ms: Some(1000),
|
||||
env: HashMap::new(),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
|
||||
let ctrl_c = Arc::new(Notify::new());
|
||||
let policy = SandboxPolicy::new_read_only_policy();
|
||||
|
||||
let result = process_exec_tool_call(params, sandbox_type, ctrl_c, &policy, &None, None).await;
|
||||
|
||||
assert!(result.is_ok() == should_be_ok);
|
||||
}
|
||||
|
||||
/// Command succeeds with exit code 0 normally
|
||||
#[tokio::test]
|
||||
async fn exit_code_0_succeeds() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let cmd = vec!["echo", "hello"];
|
||||
|
||||
run_test_cmd(tmp, cmd, true).await
|
||||
}
|
||||
|
||||
/// Command not found returns exit code 127, this is not considered a sandbox error
|
||||
#[tokio::test]
|
||||
async fn exit_command_not_found_is_ok() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let cmd = vec!["/bin/bash", "-c", "nonexistent_command_12345"];
|
||||
run_test_cmd(tmp, cmd, true).await
|
||||
}
|
||||
|
||||
/// Writing a file fails and should be considered a sandbox error
|
||||
#[tokio::test]
|
||||
async fn write_file_fails_as_sandbox_error() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let path = tmp.path().join("test.txt");
|
||||
let cmd = vec![
|
||||
"/user/bin/touch",
|
||||
path.to_str().expect("should be able to get path"),
|
||||
];
|
||||
|
||||
run_test_cmd(tmp, cmd, false).await;
|
||||
}
|
||||
147
codex-rs/core/tests/exec_stream_events.rs
Normal file
147
codex-rs/core/tests/exec_stream_events.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
#![cfg(unix)]
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_channel::Receiver;
|
||||
use codex_core::exec::ExecParams;
|
||||
use codex_core::exec::SandboxType;
|
||||
use codex_core::exec::StdoutStream;
|
||||
use codex_core::exec::process_exec_tool_call;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecCommandOutputDeltaEvent;
|
||||
use codex_core::protocol::ExecOutputStream;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
fn collect_stdout_events(rx: Receiver<Event>) -> Vec<u8> {
|
||||
let mut out = Vec::new();
|
||||
while let Ok(ev) = rx.try_recv() {
|
||||
if let EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
|
||||
stream: ExecOutputStream::Stdout,
|
||||
chunk,
|
||||
..
|
||||
}) = ev.msg
|
||||
{
|
||||
out.extend_from_slice(&chunk);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exec_stdout_stream_events_echo() {
|
||||
let (tx, rx) = async_channel::unbounded::<Event>();
|
||||
|
||||
let stdout_stream = StdoutStream {
|
||||
sub_id: "test-sub".to_string(),
|
||||
call_id: "call-1".to_string(),
|
||||
tx_event: tx,
|
||||
};
|
||||
|
||||
let cmd = vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
// Use printf for predictable behavior across shells
|
||||
"printf 'hello-world\n'".to_string(),
|
||||
];
|
||||
|
||||
let params = ExecParams {
|
||||
command: cmd,
|
||||
cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
|
||||
timeout_ms: Some(5_000),
|
||||
env: HashMap::new(),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
|
||||
let ctrl_c = Arc::new(Notify::new());
|
||||
let policy = SandboxPolicy::new_read_only_policy();
|
||||
|
||||
let result = process_exec_tool_call(
|
||||
params,
|
||||
SandboxType::None,
|
||||
ctrl_c,
|
||||
&policy,
|
||||
&None,
|
||||
Some(stdout_stream),
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = match result {
|
||||
Ok(r) => r,
|
||||
Err(e) => panic!("process_exec_tool_call failed: {e}"),
|
||||
};
|
||||
|
||||
assert_eq!(result.exit_code, 0);
|
||||
assert_eq!(result.stdout, "hello-world\n");
|
||||
|
||||
let streamed = collect_stdout_events(rx);
|
||||
// We should have received at least the same contents (possibly in one chunk)
|
||||
assert_eq!(String::from_utf8_lossy(&streamed), "hello-world\n");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exec_stderr_stream_events_echo() {
|
||||
let (tx, rx) = async_channel::unbounded::<Event>();
|
||||
|
||||
let stdout_stream = StdoutStream {
|
||||
sub_id: "test-sub".to_string(),
|
||||
call_id: "call-2".to_string(),
|
||||
tx_event: tx,
|
||||
};
|
||||
|
||||
let cmd = vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
// Write to stderr explicitly
|
||||
"printf 'oops\n' 1>&2".to_string(),
|
||||
];
|
||||
|
||||
let params = ExecParams {
|
||||
command: cmd,
|
||||
cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
|
||||
timeout_ms: Some(5_000),
|
||||
env: HashMap::new(),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
|
||||
let ctrl_c = Arc::new(Notify::new());
|
||||
let policy = SandboxPolicy::new_read_only_policy();
|
||||
|
||||
let result = process_exec_tool_call(
|
||||
params,
|
||||
SandboxType::None,
|
||||
ctrl_c,
|
||||
&policy,
|
||||
&None,
|
||||
Some(stdout_stream),
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = match result {
|
||||
Ok(r) => r,
|
||||
Err(e) => panic!("process_exec_tool_call failed: {e}"),
|
||||
};
|
||||
|
||||
assert_eq!(result.exit_code, 0);
|
||||
assert_eq!(result.stdout, "");
|
||||
assert_eq!(result.stderr, "oops\n");
|
||||
|
||||
// Collect only stderr delta events
|
||||
let mut err = Vec::new();
|
||||
while let Ok(ev) = rx.try_recv() {
|
||||
if let EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
|
||||
stream: ExecOutputStream::Stderr,
|
||||
chunk,
|
||||
..
|
||||
}) = ev.msg
|
||||
{
|
||||
err.extend_from_slice(&chunk);
|
||||
}
|
||||
}
|
||||
assert_eq!(String::from_utf8_lossy(&err), "oops\n");
|
||||
}
|
||||
@@ -177,8 +177,7 @@ async fn live_shell_function_call() {
|
||||
match ev.msg {
|
||||
EventMsg::ExecCommandBegin(codex_core::protocol::ExecCommandBeginEvent {
|
||||
command,
|
||||
call_id: _,
|
||||
cwd: _,
|
||||
..
|
||||
}) => {
|
||||
assert_eq!(command, vec!["echo", MARKER]);
|
||||
saw_begin = true;
|
||||
@@ -186,8 +185,7 @@ async fn live_shell_function_call() {
|
||||
EventMsg::ExecCommandEnd(codex_core::protocol::ExecCommandEndEvent {
|
||||
stdout,
|
||||
exit_code,
|
||||
call_id: _,
|
||||
stderr: _,
|
||||
..
|
||||
}) => {
|
||||
assert_eq!(exit_code, 0, "echo returned non‑zero exit code");
|
||||
assert!(stdout.contains(MARKER));
|
||||
|
||||
195
codex-rs/core/tests/sandbox.rs
Normal file
195
codex-rs/core/tests/sandbox.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
#![cfg(target_os = "macos")]
|
||||
#![expect(clippy::expect_used)]
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_core::seatbelt::spawn_command_under_seatbelt;
|
||||
use codex_core::spawn::CODEX_SANDBOX_ENV_VAR;
|
||||
use codex_core::spawn::StdioPolicy;
|
||||
use tempfile::TempDir;
|
||||
|
||||
struct TestScenario {
|
||||
repo_parent: PathBuf,
|
||||
file_outside_repo: PathBuf,
|
||||
repo_root: PathBuf,
|
||||
file_in_repo_root: PathBuf,
|
||||
file_in_dot_git_dir: PathBuf,
|
||||
}
|
||||
|
||||
struct TestExpectations {
|
||||
file_outside_repo_is_writable: bool,
|
||||
file_in_repo_root_is_writable: bool,
|
||||
file_in_dot_git_dir_is_writable: bool,
|
||||
}
|
||||
|
||||
impl TestScenario {
|
||||
async fn run_test(&self, policy: &SandboxPolicy, expectations: TestExpectations) {
|
||||
if std::env::var(CODEX_SANDBOX_ENV_VAR) == Ok("seatbelt".to_string()) {
|
||||
eprintln!("{CODEX_SANDBOX_ENV_VAR} is set to 'seatbelt', skipping test.");
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
touch(&self.file_outside_repo, policy).await,
|
||||
expectations.file_outside_repo_is_writable
|
||||
);
|
||||
assert_eq!(
|
||||
self.file_outside_repo.exists(),
|
||||
expectations.file_outside_repo_is_writable
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
touch(&self.file_in_repo_root, policy).await,
|
||||
expectations.file_in_repo_root_is_writable
|
||||
);
|
||||
assert_eq!(
|
||||
self.file_in_repo_root.exists(),
|
||||
expectations.file_in_repo_root_is_writable
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
touch(&self.file_in_dot_git_dir, policy).await,
|
||||
expectations.file_in_dot_git_dir_is_writable
|
||||
);
|
||||
assert_eq!(
|
||||
self.file_in_dot_git_dir.exists(),
|
||||
expectations.file_in_dot_git_dir_is_writable
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// If the user has added a workspace root that is not a Git repo root, then
|
||||
/// the user has to specify `--skip-git-repo-check` or go through some
|
||||
/// interstitial that indicates they are taking on some risk because Git
|
||||
/// cannot be used to backup their work before the agent begins.
|
||||
///
|
||||
/// Because the user has agreed to this risk, we do not try find all .git
|
||||
/// folders in the workspace and block them (though we could change our
|
||||
/// position on this in the future).
|
||||
#[tokio::test]
|
||||
async fn if_parent_of_repo_is_writable_then_dot_git_folder_is_writable() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let test_scenario = create_test_scenario(&tmp);
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![test_scenario.repo_parent.clone()],
|
||||
network_access: false,
|
||||
include_default_writable_roots: false,
|
||||
};
|
||||
|
||||
test_scenario
|
||||
.run_test(
|
||||
&policy,
|
||||
TestExpectations {
|
||||
file_outside_repo_is_writable: true,
|
||||
file_in_repo_root_is_writable: true,
|
||||
file_in_dot_git_dir_is_writable: true,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// When the writable root is the root of a Git repository (as evidenced by the
|
||||
/// presence of a .git folder), then the .git folder should be read-only if
|
||||
/// the policy is `WorkspaceWrite`.
|
||||
#[tokio::test]
|
||||
async fn if_git_repo_is_writable_root_then_dot_git_folder_is_read_only() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let test_scenario = create_test_scenario(&tmp);
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![test_scenario.repo_root.clone()],
|
||||
network_access: false,
|
||||
include_default_writable_roots: false,
|
||||
};
|
||||
|
||||
test_scenario
|
||||
.run_test(
|
||||
&policy,
|
||||
TestExpectations {
|
||||
file_outside_repo_is_writable: false,
|
||||
file_in_repo_root_is_writable: true,
|
||||
file_in_dot_git_dir_is_writable: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Under DangerFullAccess, all writes should be permitted anywhere on disk,
|
||||
/// including inside the .git folder.
|
||||
#[tokio::test]
|
||||
async fn danger_full_access_allows_all_writes() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let test_scenario = create_test_scenario(&tmp);
|
||||
let policy = SandboxPolicy::DangerFullAccess;
|
||||
|
||||
test_scenario
|
||||
.run_test(
|
||||
&policy,
|
||||
TestExpectations {
|
||||
file_outside_repo_is_writable: true,
|
||||
file_in_repo_root_is_writable: true,
|
||||
file_in_dot_git_dir_is_writable: true,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Under ReadOnly, writes should not be permitted anywhere on disk.
|
||||
#[tokio::test]
|
||||
async fn read_only_forbids_all_writes() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let test_scenario = create_test_scenario(&tmp);
|
||||
let policy = SandboxPolicy::ReadOnly;
|
||||
|
||||
test_scenario
|
||||
.run_test(
|
||||
&policy,
|
||||
TestExpectations {
|
||||
file_outside_repo_is_writable: false,
|
||||
file_in_repo_root_is_writable: false,
|
||||
file_in_dot_git_dir_is_writable: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
fn create_test_scenario(tmp: &TempDir) -> TestScenario {
|
||||
let repo_parent = tmp.path().to_path_buf();
|
||||
let repo_root = repo_parent.join("repo");
|
||||
let dot_git_dir = repo_root.join(".git");
|
||||
|
||||
std::fs::create_dir(&repo_root).expect("should be able to create repo root");
|
||||
std::fs::create_dir(&dot_git_dir).expect("should be able to create .git dir");
|
||||
|
||||
TestScenario {
|
||||
file_outside_repo: repo_parent.join("outside.txt"),
|
||||
repo_parent,
|
||||
file_in_repo_root: repo_root.join("repo_file.txt"),
|
||||
repo_root,
|
||||
file_in_dot_git_dir: dot_git_dir.join("dot_git_file.txt"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Note that `path` must be absolute.
|
||||
async fn touch(path: &Path, policy: &SandboxPolicy) -> bool {
|
||||
assert!(path.is_absolute(), "Path must be absolute: {path:?}");
|
||||
let mut child = spawn_command_under_seatbelt(
|
||||
vec![
|
||||
"/usr/bin/touch".to_string(),
|
||||
path.to_string_lossy().to_string(),
|
||||
],
|
||||
policy,
|
||||
std::env::current_dir().expect("should be able to get current dir"),
|
||||
StdioPolicy::RedirectForShellTool,
|
||||
HashMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("should be able to spawn command under seatbelt");
|
||||
child
|
||||
.wait()
|
||||
.await
|
||||
.expect("should be able to wait for child process")
|
||||
.success()
|
||||
}
|
||||
@@ -90,7 +90,7 @@ async fn retries_on_early_close() {
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(1),
|
||||
stream_idle_timeout_ms: Some(2000),
|
||||
requires_auth: false,
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
|
||||
@@ -25,6 +25,7 @@ codex-common = { path = "../common", features = [
|
||||
"sandbox_summary",
|
||||
] }
|
||||
codex-core = { path = "../core" }
|
||||
codex-ollama = { path = "../ollama" }
|
||||
owo-colors = "4.2.0"
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
|
||||
@@ -14,6 +14,9 @@ pub struct Cli {
|
||||
#[arg(long, short = 'm')]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[arg(long = "oss", default_value_t = false)]
|
||||
pub oss: bool,
|
||||
|
||||
/// Select the sandbox policy to use when executing model-generated shell
|
||||
/// commands.
|
||||
#[arg(long = "sandbox", short = 's')]
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
use std::path::Path;
|
||||
|
||||
use codex_common::summarize_sandbox_policy;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::model_supports_reasoning_summaries;
|
||||
use codex_core::protocol::Event;
|
||||
|
||||
pub(crate) enum CodexStatus {
|
||||
@@ -20,44 +17,14 @@ pub(crate) trait EventProcessor {
|
||||
fn process_event(&mut self, event: Event) -> CodexStatus;
|
||||
}
|
||||
|
||||
pub(crate) fn create_config_summary_entries(config: &Config) -> Vec<(&'static str, String)> {
|
||||
let mut entries = vec![
|
||||
("workdir", config.cwd.display().to_string()),
|
||||
("model", config.model.clone()),
|
||||
("provider", config.model_provider_id.clone()),
|
||||
("approval", config.approval_policy.to_string()),
|
||||
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
|
||||
];
|
||||
if config.model_provider.wire_api == WireApi::Responses
|
||||
&& model_supports_reasoning_summaries(config)
|
||||
{
|
||||
entries.push((
|
||||
"reasoning effort",
|
||||
config.model_reasoning_effort.to_string(),
|
||||
));
|
||||
entries.push((
|
||||
"reasoning summaries",
|
||||
config.model_reasoning_summary.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
entries
|
||||
}
|
||||
|
||||
pub(crate) fn handle_last_message(
|
||||
last_agent_message: Option<&str>,
|
||||
last_message_path: Option<&Path>,
|
||||
) {
|
||||
match (last_message_path, last_agent_message) {
|
||||
(Some(path), Some(msg)) => write_last_message_file(msg, Some(path)),
|
||||
(Some(path), None) => {
|
||||
write_last_message_file("", Some(path));
|
||||
eprintln!(
|
||||
"Warning: no last agent message; wrote empty content to {}",
|
||||
path.display()
|
||||
);
|
||||
}
|
||||
(None, _) => eprintln!("Warning: no file to write last message to."),
|
||||
pub(crate) fn handle_last_message(last_agent_message: Option<&str>, output_file: &Path) {
|
||||
let message = last_agent_message.unwrap_or_default();
|
||||
write_last_message_file(message, Some(output_file));
|
||||
if last_agent_message.is_none() {
|
||||
eprintln!(
|
||||
"Warning: no last agent message; wrote empty content to {}",
|
||||
output_file.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ use codex_core::plan_tool::UpdatePlanArgs;
|
||||
use codex_core::protocol::AgentMessageDeltaEvent;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::AgentReasoningDeltaEvent;
|
||||
use codex_core::protocol::AgentReasoningRawContentDeltaEvent;
|
||||
use codex_core::protocol::AgentReasoningRawContentEvent;
|
||||
use codex_core::protocol::BackgroundEventEvent;
|
||||
use codex_core::protocol::ErrorEvent;
|
||||
use codex_core::protocol::Event;
|
||||
@@ -20,6 +22,7 @@ use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use codex_core::protocol::TurnDiffEvent;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Style;
|
||||
use shlex::try_join;
|
||||
@@ -30,8 +33,8 @@ use std::time::Instant;
|
||||
|
||||
use crate::event_processor::CodexStatus;
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor::create_config_summary_entries;
|
||||
use crate::event_processor::handle_last_message;
|
||||
use codex_common::create_config_summary_entries;
|
||||
|
||||
/// This should be configurable. When used in CI, users may not want to impose
|
||||
/// a limit so they can see the full transcript.
|
||||
@@ -54,8 +57,10 @@ pub(crate) struct EventProcessorWithHumanOutput {
|
||||
|
||||
/// Whether to include `AgentReasoning` events in the output.
|
||||
show_agent_reasoning: bool,
|
||||
show_raw_agent_reasoning: bool,
|
||||
answer_started: bool,
|
||||
reasoning_started: bool,
|
||||
raw_reasoning_started: bool,
|
||||
last_message_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
@@ -80,8 +85,10 @@ impl EventProcessorWithHumanOutput {
|
||||
green: Style::new().green(),
|
||||
cyan: Style::new().cyan(),
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
raw_reasoning_started: false,
|
||||
last_message_path,
|
||||
}
|
||||
} else {
|
||||
@@ -96,8 +103,10 @@ impl EventProcessorWithHumanOutput {
|
||||
green: Style::new(),
|
||||
cyan: Style::new(),
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
raw_reasoning_started: false,
|
||||
last_message_path,
|
||||
}
|
||||
}
|
||||
@@ -106,7 +115,6 @@ impl EventProcessorWithHumanOutput {
|
||||
|
||||
struct ExecCommandBegin {
|
||||
command: Vec<String>,
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
struct PatchApplyBegin {
|
||||
@@ -170,10 +178,9 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
// Ignore.
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
handle_last_message(
|
||||
last_agent_message.as_deref(),
|
||||
self.last_message_path.as_deref(),
|
||||
);
|
||||
if let Some(output_file) = self.last_message_path.as_deref() {
|
||||
handle_last_message(last_agent_message.as_deref(), output_file);
|
||||
}
|
||||
return CodexStatus::InitiateShutdown;
|
||||
}
|
||||
EventMsg::TokenCount(TokenUsage { total_tokens, .. }) => {
|
||||
@@ -204,6 +211,32 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
#[allow(clippy::expect_used)]
|
||||
std::io::stdout().flush().expect("could not flush stdout");
|
||||
}
|
||||
EventMsg::AgentReasoningRawContent(AgentReasoningRawContentEvent { text }) => {
|
||||
if !self.show_raw_agent_reasoning {
|
||||
return CodexStatus::Running;
|
||||
}
|
||||
if !self.raw_reasoning_started {
|
||||
print!("{text}");
|
||||
#[allow(clippy::expect_used)]
|
||||
std::io::stdout().flush().expect("could not flush stdout");
|
||||
} else {
|
||||
println!();
|
||||
self.raw_reasoning_started = false;
|
||||
}
|
||||
}
|
||||
EventMsg::AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent {
|
||||
delta,
|
||||
}) => {
|
||||
if !self.show_raw_agent_reasoning {
|
||||
return CodexStatus::Running;
|
||||
}
|
||||
if !self.raw_reasoning_started {
|
||||
self.raw_reasoning_started = true;
|
||||
}
|
||||
print!("{delta}");
|
||||
#[allow(clippy::expect_used)]
|
||||
std::io::stdout().flush().expect("could not flush stdout");
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
// if answer_started is false, this means we haven't received any
|
||||
// delta. Thus, we need to print the message as a new answer.
|
||||
@@ -228,7 +261,6 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
call_id.clone(),
|
||||
ExecCommandBegin {
|
||||
command: command.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
ts_println!(
|
||||
@@ -239,20 +271,19 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
cwd.to_string_lossy(),
|
||||
);
|
||||
}
|
||||
EventMsg::ExecCommandOutputDelta(_) => {}
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
exit_code,
|
||||
}) => {
|
||||
let exec_command = self.call_id_to_command.remove(&call_id);
|
||||
let (duration, call) = if let Some(ExecCommandBegin {
|
||||
command,
|
||||
start_time,
|
||||
}) = exec_command
|
||||
let (duration, call) = if let Some(ExecCommandBegin { command, .. }) = exec_command
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!(" in {}", format_duration(duration)),
|
||||
format!("{}", escape_command(&command).style(self.bold)),
|
||||
)
|
||||
} else {
|
||||
@@ -402,6 +433,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
stdout,
|
||||
stderr,
|
||||
success,
|
||||
..
|
||||
}) => {
|
||||
let patch_begin = self.call_id_to_patch.remove(&call_id);
|
||||
|
||||
@@ -431,6 +463,10 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
EventMsg::TurnDiff(TurnDiffEvent { unified_diff }) => {
|
||||
ts_println!(self, "{}", "turn diff:".style(self.magenta));
|
||||
println!("{unified_diff}");
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
@@ -454,10 +490,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
}
|
||||
EventMsg::SessionConfigured(session_configured_event) => {
|
||||
let SessionConfiguredEvent {
|
||||
session_id,
|
||||
model,
|
||||
history_log_id: _,
|
||||
history_entry_count: _,
|
||||
session_id, model, ..
|
||||
} = session_configured_event;
|
||||
|
||||
ts_println!(
|
||||
|
||||
@@ -9,8 +9,8 @@ use serde_json::json;
|
||||
|
||||
use crate::event_processor::CodexStatus;
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor::create_config_summary_entries;
|
||||
use crate::event_processor::handle_last_message;
|
||||
use codex_common::create_config_summary_entries;
|
||||
|
||||
pub(crate) struct EventProcessorWithJsonOutput {
|
||||
last_message_path: Option<PathBuf>,
|
||||
@@ -46,10 +46,9 @@ impl EventProcessor for EventProcessorWithJsonOutput {
|
||||
CodexStatus::Running
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
handle_last_message(
|
||||
last_agent_message.as_deref(),
|
||||
self.last_message_path.as_deref(),
|
||||
);
|
||||
if let Some(output_file) = self.last_message_path.as_deref() {
|
||||
handle_last_message(last_agent_message.as_deref(), output_file);
|
||||
}
|
||||
CodexStatus::InitiateShutdown
|
||||
}
|
||||
EventMsg::ShutdownComplete => CodexStatus::Shutdown,
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use cli::Cli;
|
||||
use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
||||
use codex_core::codex_wrapper::CodexConversation;
|
||||
use codex_core::codex_wrapper::{self};
|
||||
use codex_core::config::Config;
|
||||
@@ -21,6 +22,7 @@ use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::util::is_inside_git_repo;
|
||||
use codex_ollama::DEFAULT_OSS_MODEL;
|
||||
use event_processor_with_human_output::EventProcessorWithHumanOutput;
|
||||
use event_processor_with_json_output::EventProcessorWithJsonOutput;
|
||||
use tracing::debug;
|
||||
@@ -34,7 +36,8 @@ use crate::event_processor::EventProcessor;
|
||||
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
let Cli {
|
||||
images,
|
||||
model,
|
||||
model: model_cli_arg,
|
||||
oss,
|
||||
config_profile,
|
||||
full_auto,
|
||||
dangerously_bypass_approvals_and_sandbox,
|
||||
@@ -114,6 +117,23 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
sandbox_mode_cli_arg.map(Into::<SandboxMode>::into)
|
||||
};
|
||||
|
||||
// When using `--oss`, let the bootstrapper pick the model (defaulting to
|
||||
// gpt-oss:20b) and ensure it is present locally. Also, force the built‑in
|
||||
// `oss` model provider.
|
||||
let model = if let Some(model) = model_cli_arg {
|
||||
Some(model)
|
||||
} else if oss {
|
||||
Some(DEFAULT_OSS_MODEL.to_owned())
|
||||
} else {
|
||||
None // No model specified, will use the default.
|
||||
};
|
||||
|
||||
let model_provider = if oss {
|
||||
Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_string())
|
||||
} else {
|
||||
None // No specific model provider override.
|
||||
};
|
||||
|
||||
// Load configuration and determine approval policy
|
||||
let overrides = ConfigOverrides {
|
||||
model,
|
||||
@@ -123,10 +143,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
approval_policy: Some(AskForApproval::Never),
|
||||
sandbox_mode,
|
||||
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
|
||||
model_provider: None,
|
||||
model_provider,
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions: None,
|
||||
include_plan_tool: None,
|
||||
disable_response_storage: oss.then_some(true),
|
||||
show_raw_agent_reasoning: oss.then_some(true),
|
||||
};
|
||||
// Parse `-c` overrides.
|
||||
let cli_kv_overrides = match config_overrides.parse_overrides() {
|
||||
@@ -148,6 +170,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
))
|
||||
};
|
||||
|
||||
if oss {
|
||||
codex_ollama::ensure_oss_ready(&config)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?;
|
||||
}
|
||||
|
||||
// Print the effective configuration and prompt so users can see what Codex
|
||||
// is using.
|
||||
event_processor.print_config_summary(&config, &prompt);
|
||||
@@ -188,10 +216,16 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
res = codex.next_event() => match res {
|
||||
Ok(event) => {
|
||||
debug!("Received event: {event:?}");
|
||||
|
||||
let is_shutdown_complete = matches!(event.msg, EventMsg::ShutdownComplete);
|
||||
if let Err(e) = tx.send(event) {
|
||||
error!("Error sending event: {e:?}");
|
||||
break;
|
||||
}
|
||||
if is_shutdown_complete {
|
||||
info!("Received shutdown event, exiting event loop.");
|
||||
break;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Error receiving event: {e:?}");
|
||||
|
||||
@@ -26,7 +26,7 @@ multimap = "0.10.0"
|
||||
path-absolutize = "3.1.1"
|
||||
regex-lite = "0.1"
|
||||
serde = { version = "1.0.194", features = ["derive"] }
|
||||
serde_json = "1.0.110"
|
||||
serde_json = "1.0.142"
|
||||
serde_with = { version = "3", features = ["macros"] }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -17,5 +17,5 @@ clap = { version = "4", features = ["derive"] }
|
||||
ignore = "0.4.23"
|
||||
nucleo-matcher = "0.3.1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1.0.110"
|
||||
serde_json = "1.0.142"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
@@ -36,7 +36,11 @@ pub(crate) fn apply_sandbox_policy_to_current_thread(
|
||||
}
|
||||
|
||||
if !sandbox_policy.has_full_disk_write_access() {
|
||||
let writable_roots = sandbox_policy.get_writable_roots_with_cwd(cwd);
|
||||
let writable_roots = sandbox_policy
|
||||
.get_writable_roots_with_cwd(cwd)
|
||||
.into_iter()
|
||||
.map(|writable_root| writable_root.root)
|
||||
.collect();
|
||||
install_filesystem_landlock_rules_on_current_thread(writable_roots)?;
|
||||
}
|
||||
|
||||
|
||||
@@ -44,11 +44,14 @@ async fn run_cmd(cmd: &[&str], writable_roots: &[PathBuf], timeout_ms: u64) {
|
||||
cwd: std::env::current_dir().expect("cwd should exist"),
|
||||
timeout_ms: Some(timeout_ms),
|
||||
env: create_env_from_core_vars(),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
|
||||
let sandbox_policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: writable_roots.to_vec(),
|
||||
network_access: false,
|
||||
include_default_writable_roots: true,
|
||||
};
|
||||
let sandbox_program = env!("CARGO_BIN_EXE_codex-linux-sandbox");
|
||||
let codex_linux_sandbox_exe = Some(PathBuf::from(sandbox_program));
|
||||
@@ -59,6 +62,7 @@ async fn run_cmd(cmd: &[&str], writable_roots: &[PathBuf], timeout_ms: u64) {
|
||||
ctrl_c,
|
||||
&sandbox_policy,
|
||||
&codex_linux_sandbox_exe,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -137,6 +141,8 @@ async fn assert_network_blocked(cmd: &[&str]) {
|
||||
// do not stall the suite.
|
||||
timeout_ms: Some(NETWORK_TIMEOUT_MS),
|
||||
env: create_env_from_core_vars(),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
|
||||
let sandbox_policy = SandboxPolicy::new_read_only_policy();
|
||||
@@ -149,6 +155,7 @@ async fn assert_network_blocked(cmd: &[&str]) {
|
||||
ctrl_c,
|
||||
&sandbox_policy,
|
||||
&codex_linux_sandbox_exe,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ const SOURCE_FOR_PYTHON_SERVER: &str = include_str!("./login_with_chatgpt.py");
|
||||
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
|
||||
pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug, PartialEq, Copy)]
|
||||
pub enum AuthMode {
|
||||
ApiKey,
|
||||
ChatGPT,
|
||||
|
||||
@@ -458,6 +458,7 @@ class _ApiKeyHTTPServer(http.server.HTTPServer):
|
||||
"code_challenge": self.pkce.code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"id_token_add_organizations": "true",
|
||||
"codex_cli_simplified_flow": "true",
|
||||
"state": self.state,
|
||||
}
|
||||
return f"{self.issuer}/oauth/authorize?" + urllib.parse.urlencode(params)
|
||||
@@ -686,6 +687,7 @@ LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
background: white;
|
||||
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
|
||||
}
|
||||
.inner-container {
|
||||
@@ -703,6 +705,7 @@ LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
|
||||
align-items: center;
|
||||
gap: 20px;
|
||||
display: flex;
|
||||
margin-top: 15vh;
|
||||
}
|
||||
.svg-wrapper {
|
||||
position: relative;
|
||||
@@ -710,9 +713,9 @@ LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
|
||||
.title {
|
||||
text-align: center;
|
||||
color: var(--text-primary, #0D0D0D);
|
||||
font-size: 28px;
|
||||
font-size: 32px;
|
||||
font-weight: 400;
|
||||
line-height: 36.40px;
|
||||
line-height: 40px;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
.setup-box {
|
||||
@@ -785,16 +788,26 @@ LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
|
||||
word-wrap: break-word;
|
||||
text-decoration: none;
|
||||
}
|
||||
.logo {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 4rem;
|
||||
height: 4rem;
|
||||
border-radius: 16px;
|
||||
border: .5px solid rgba(0, 0, 0, 0.1);
|
||||
box-shadow: rgba(0, 0, 0, 0.1) 0px 4px 16px 0px;
|
||||
box-sizing: border-box;
|
||||
background-color: rgb(255, 255, 255);
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="inner-container">
|
||||
<div class="content">
|
||||
<div data-svg-wrapper class="svg-wrapper">
|
||||
<svg width="56" height="56" viewBox="0 0 56 56" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M4.6665 28.0003C4.6665 15.1137 15.1132 4.66699 27.9998 4.66699C40.8865 4.66699 51.3332 15.1137 51.3332 28.0003C51.3332 40.887 40.8865 51.3337 27.9998 51.3337C15.1132 51.3337 4.6665 40.887 4.6665 28.0003ZM37.5093 18.5088C36.4554 17.7672 34.9999 18.0203 34.2583 19.0742L24.8508 32.4427L20.9764 28.1808C20.1095 27.2272 18.6338 27.1569 17.6803 28.0238C16.7267 28.8906 16.6565 30.3664 17.5233 31.3199L23.3566 37.7366C23.833 38.2606 24.5216 38.5399 25.2284 38.4958C25.9353 38.4517 26.5838 38.089 26.9914 37.5098L38.0747 21.7598C38.8163 20.7059 38.5632 19.2504 37.5093 18.5088Z" fill="var(--green-400, #04B84C)"/>
|
||||
</svg>
|
||||
<div class="logo">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" fill="none" viewBox="0 0 32 32"><path stroke="#000" stroke-linecap="round" stroke-width="2.484" d="M22.356 19.797H17.17M9.662 12.29l1.979 3.576a.511.511 0 0 1-.005.504l-1.974 3.409M30.758 16c0 8.15-6.607 14.758-14.758 14.758-8.15 0-14.758-6.607-14.758-14.758C1.242 7.85 7.85 1.242 16 1.242c8.15 0 14.758 6.608 14.758 14.758Z"></path></svg>
|
||||
</div>
|
||||
<div class="title">Signed in to Codex CLI</div>
|
||||
</div>
|
||||
|
||||
@@ -158,6 +158,8 @@ impl CodexToolCallParam {
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions,
|
||||
include_plan_tool,
|
||||
disable_response_storage: None,
|
||||
show_raw_agent_reasoning: None,
|
||||
};
|
||||
|
||||
let cli_overrides = cli_overrides
|
||||
|
||||
@@ -252,16 +252,20 @@ async fn run_codex_tool_session_inner(
|
||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::TaskStarted
|
||||
EventMsg::AgentReasoningRawContent(_)
|
||||
| EventMsg::AgentReasoningRawContentDelta(_)
|
||||
| EventMsg::TaskStarted
|
||||
| EventMsg::TokenCount(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::ShutdownComplete => {
|
||||
|
||||
124
codex-rs/mcp-server/src/conversation_loop.rs
Normal file
124
codex-rs/mcp-server/src/conversation_loop.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::OutgoingNotificationMeta;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use mcp_types::RequestId;
|
||||
use tracing::error;
|
||||
|
||||
pub async fn run_conversation_loop(
|
||||
codex: Arc<Codex>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
) {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
|
||||
// Stream events until the task needs to pause for user interaction or
|
||||
// completes.
|
||||
loop {
|
||||
match codex.next_event().await {
|
||||
Ok(event) => {
|
||||
outgoing
|
||||
.send_event_as_notification(
|
||||
&event,
|
||||
Some(OutgoingNotificationMeta::new(Some(request_id.clone()))),
|
||||
)
|
||||
.await;
|
||||
|
||||
match event.msg {
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
command,
|
||||
cwd,
|
||||
call_id,
|
||||
reason: _,
|
||||
}) => {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::Error(_) => {
|
||||
error!("Codex runtime error");
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
}) => {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::TaskComplete(_) => {}
|
||||
EventMsg::SessionConfigured(_) => {
|
||||
tracing::error!("unexpected SessionConfigured event");
|
||||
}
|
||||
EventMsg::AgentMessageDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentReasoningRawContent(_)
|
||||
| EventMsg::AgentReasoningRawContentDelta(_)
|
||||
| EventMsg::TaskStarted
|
||||
| EventMsg::TokenCount(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::ShutdownComplete => {
|
||||
// For now, we do not do anything extra for these
|
||||
// events. Note that
|
||||
// send(codex_event_to_notification(&event)) above has
|
||||
// already dispatched these events as notifications,
|
||||
// though we may want to do give different treatment to
|
||||
// individual events in the future.
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Codex runtime error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,7 @@ use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE;
|
||||
|
||||
/// Conforms to [`mcp_types::ElicitRequestParams`] so that it can be used as the
|
||||
/// `params` field of an [`ElicitRequest`].
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ExecApprovalElicitRequestParams {
|
||||
// These fields are required so that `params`
|
||||
// conforms to ElicitRequestParams.
|
||||
|
||||
@@ -17,12 +17,14 @@ use tracing_subscriber::EnvFilter;
|
||||
|
||||
mod codex_tool_config;
|
||||
mod codex_tool_runner;
|
||||
mod conversation_loop;
|
||||
mod exec_approval;
|
||||
mod json_to_toml;
|
||||
mod mcp_protocol;
|
||||
mod message_processor;
|
||||
pub mod mcp_protocol;
|
||||
pub(crate) mod message_processor;
|
||||
mod outgoing_message;
|
||||
mod patch_approval;
|
||||
pub(crate) mod tool_handlers;
|
||||
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
|
||||
@@ -7,7 +7,10 @@ use serde::Serialize;
|
||||
use strum_macros::Display;
|
||||
use uuid::Uuid;
|
||||
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::TextContent;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
@@ -118,10 +121,47 @@ pub struct ToolCallResponse {
|
||||
pub request_id: RequestId,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub is_error: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none", flatten)]
|
||||
pub result: Option<ToolCallResponseResult>,
|
||||
}
|
||||
|
||||
impl From<ToolCallResponse> for CallToolResult {
|
||||
fn from(val: ToolCallResponse) -> Self {
|
||||
let ToolCallResponse {
|
||||
request_id: _request_id,
|
||||
is_error,
|
||||
result,
|
||||
} = val;
|
||||
match result {
|
||||
Some(res) => match serde_json::to_value(&res) {
|
||||
Ok(v) => CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: v.to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error,
|
||||
structured_content: Some(v),
|
||||
},
|
||||
Err(e) => CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Failed to serialize tool result: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
},
|
||||
},
|
||||
None => CallToolResult {
|
||||
content: vec![],
|
||||
is_error,
|
||||
structured_content: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolCallResponseResult {
|
||||
@@ -132,17 +172,26 @@ pub enum ToolCallResponseResult {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ConversationCreateResult {
|
||||
pub conversation_id: ConversationId,
|
||||
pub model: String,
|
||||
#[serde(untagged)]
|
||||
pub enum ConversationCreateResult {
|
||||
Ok {
|
||||
conversation_id: ConversationId,
|
||||
model: String,
|
||||
},
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ConversationStreamResult {}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ConversationSendMessageResult {
|
||||
pub success: bool,
|
||||
// TODO: remove this status because we have is_error field in the response.
|
||||
#[serde(tag = "status", rename_all = "camelCase")]
|
||||
pub enum ConversationSendMessageResult {
|
||||
Ok,
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
@@ -449,16 +498,19 @@ mod tests {
|
||||
request_id: RequestId::Integer(1),
|
||||
is_error: None,
|
||||
result: Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult {
|
||||
ConversationCreateResult::Ok {
|
||||
conversation_id: ConversationId(uuid!("d0f6ecbe-84a2-41c1-b23d-b20473b25eab")),
|
||||
model: "o3".into(),
|
||||
},
|
||||
)),
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 1,
|
||||
"result": {
|
||||
"content": [
|
||||
{ "type": "text", "text": "{\"conversation_id\":\"d0f6ecbe-84a2-41c1-b23d-b20473b25eab\",\"model\":\"o3\"}" }
|
||||
],
|
||||
"structuredContent": {
|
||||
"conversation_id": "d0f6ecbe-84a2-41c1-b23d-b20473b25eab",
|
||||
"model": "o3"
|
||||
}
|
||||
@@ -467,6 +519,36 @@ mod tests {
|
||||
observed, expected,
|
||||
"response (ConversationCreate) must match"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_error_conversation_create_full_schema() {
|
||||
let env = ToolCallResponse {
|
||||
request_id: RequestId::Integer(2),
|
||||
is_error: Some(true),
|
||||
result: Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Error {
|
||||
message: "Failed to initialize session".into(),
|
||||
},
|
||||
)),
|
||||
};
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"content": [
|
||||
{ "type": "text", "text": "{\"message\":\"Failed to initialize session\"}" }
|
||||
],
|
||||
"isError": true,
|
||||
"structuredContent": {
|
||||
"message": "Failed to initialize session"
|
||||
}
|
||||
});
|
||||
assert_eq!(
|
||||
observed, expected,
|
||||
"error response (ConversationCreate) must match"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -478,15 +560,17 @@ mod tests {
|
||||
ConversationStreamResult {},
|
||||
)),
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 2,
|
||||
"result": {}
|
||||
"content": [ { "type": "text", "text": "{}" } ],
|
||||
"structuredContent": {}
|
||||
});
|
||||
assert_eq!(
|
||||
observed, expected,
|
||||
"response (ConversationStream) must have empty object result"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -495,18 +579,20 @@ mod tests {
|
||||
request_id: RequestId::Integer(3),
|
||||
is_error: None,
|
||||
result: Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult { success: true },
|
||||
ConversationSendMessageResult::Ok,
|
||||
)),
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 3,
|
||||
"result": { "success": true }
|
||||
"content": [ { "type": "text", "text": "{\"status\":\"ok\"}" } ],
|
||||
"structuredContent": { "status": "ok" }
|
||||
});
|
||||
assert_eq!(
|
||||
observed, expected,
|
||||
"response (ConversationSendMessageAccepted) must match"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -526,10 +612,13 @@ mod tests {
|
||||
},
|
||||
)),
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 4,
|
||||
"result": {
|
||||
"content": [
|
||||
{ "type": "text", "text": "{\"conversations\":[{\"conversation_id\":\"67e55044-10b1-426f-9247-bb680e5fe0c8\",\"title\":\"Refactor config loader\"}],\"next_cursor\":\"next123\"}" }
|
||||
],
|
||||
"structuredContent": {
|
||||
"conversations": [
|
||||
{
|
||||
"conversation_id": "67e55044-10b1-426f-9247-bb680e5fe0c8",
|
||||
@@ -543,6 +632,7 @@ mod tests {
|
||||
observed, expected,
|
||||
"response (ConversationsList with cursor) must match"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -552,15 +642,17 @@ mod tests {
|
||||
is_error: Some(true),
|
||||
result: None,
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 4,
|
||||
"content": [],
|
||||
"isError": true
|
||||
});
|
||||
assert_eq!(
|
||||
observed, expected,
|
||||
"error response must omit `result` and include `isError`"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(4));
|
||||
}
|
||||
|
||||
// ----- Notifications -----
|
||||
@@ -816,6 +908,7 @@ mod tests {
|
||||
model: "codex-mini-latest".into(),
|
||||
history_log_id: 42,
|
||||
history_entry_count: 3,
|
||||
rollout_path: None,
|
||||
}),
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -6,11 +7,17 @@ use crate::codex_tool_config::CodexToolCallParam;
|
||||
use crate::codex_tool_config::CodexToolCallReplyParam;
|
||||
use crate::codex_tool_config::create_tool_for_codex_tool_call_param;
|
||||
use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param;
|
||||
use crate::mcp_protocol::ToolCallRequestParams;
|
||||
use crate::mcp_protocol::ToolCallResponse;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::tool_handlers::create_conversation::handle_create_conversation;
|
||||
use crate::tool_handlers::send_message::handle_send_message;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::CallToolRequest;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ClientRequest;
|
||||
@@ -37,6 +44,7 @@ pub(crate) struct MessageProcessor {
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
running_session_ids: Arc<Mutex<HashSet<Uuid>>>,
|
||||
}
|
||||
|
||||
impl MessageProcessor {
|
||||
@@ -52,9 +60,22 @@ impl MessageProcessor {
|
||||
codex_linux_sandbox_exe,
|
||||
session_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_session_ids: Arc::new(Mutex::new(HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn session_map(&self) -> Arc<Mutex<HashMap<Uuid, Arc<Codex>>>> {
|
||||
self.session_map.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn outgoing(&self) -> Arc<OutgoingMessageSender> {
|
||||
self.outgoing.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn running_session_ids(&self) -> Arc<Mutex<HashSet<Uuid>>> {
|
||||
self.running_session_ids.clone()
|
||||
}
|
||||
|
||||
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
// Hold on to the ID so we can respond.
|
||||
let request_id = request.id.clone();
|
||||
@@ -300,6 +321,14 @@ impl MessageProcessor {
|
||||
params: <mcp_types::CallToolRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
) {
|
||||
tracing::info!("tools/call -> params: {:?}", params);
|
||||
// Serialize params into JSON and try to parse as new type
|
||||
if let Ok(new_params) =
|
||||
serde_json::to_value(¶ms).and_then(serde_json::from_value::<ToolCallRequestParams>)
|
||||
{
|
||||
// New tool call matched → forward
|
||||
self.handle_new_tool_calls(id, new_params).await;
|
||||
return;
|
||||
}
|
||||
let CallToolRequestParams { name, arguments } = params;
|
||||
|
||||
match name.as_str() {
|
||||
@@ -323,6 +352,29 @@ impl MessageProcessor {
|
||||
}
|
||||
}
|
||||
}
|
||||
async fn handle_new_tool_calls(&self, request_id: RequestId, params: ToolCallRequestParams) {
|
||||
match params {
|
||||
ToolCallRequestParams::ConversationCreate(args) => {
|
||||
handle_create_conversation(self, request_id, args).await;
|
||||
}
|
||||
ToolCallRequestParams::ConversationSendMessage(args) => {
|
||||
handle_send_message(self, request_id, args).await;
|
||||
}
|
||||
_ => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: "Unknown tool".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<CallToolRequest>(request_id, result)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_tool_call_codex(&self, id: RequestId, arguments: Option<serde_json::Value>) {
|
||||
let (initial_prompt, config): (String, CodexConfig) = match arguments {
|
||||
@@ -631,4 +683,20 @@ impl MessageProcessor {
|
||||
) {
|
||||
tracing::info!("notifications/message -> params: {:?}", params);
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response_with_optional_error(
|
||||
&self,
|
||||
id: RequestId,
|
||||
message: Option<ToolCallResponseResult>,
|
||||
error: Option<bool>,
|
||||
) {
|
||||
let response = ToolCallResponse {
|
||||
request_id: id.clone(),
|
||||
is_error: error,
|
||||
result: message,
|
||||
};
|
||||
let result: CallToolResult = response.into();
|
||||
self.send_response::<mcp_types::CallToolRequest>(id.clone(), result)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,6 +244,7 @@ mod tests {
|
||||
model: "gpt-4o".to_string(),
|
||||
history_log_id: 1,
|
||||
history_entry_count: 1000,
|
||||
rollout_path: None,
|
||||
}),
|
||||
};
|
||||
|
||||
@@ -284,6 +285,7 @@ mod tests {
|
||||
model: "gpt-4o".to_string(),
|
||||
history_log_id: 1,
|
||||
history_entry_count: 1000,
|
||||
rollout_path: None,
|
||||
};
|
||||
let event = Event {
|
||||
id: "1".to_string(),
|
||||
|
||||
162
codex-rs/mcp-server/src/tool_handlers/create_conversation.rs
Normal file
162
codex-rs/mcp-server/src/tool_handlers/create_conversation.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::conversation_loop::run_conversation_loop;
|
||||
use crate::json_to_toml::json_to_toml;
|
||||
use crate::mcp_protocol::ConversationCreateArgs;
|
||||
use crate::mcp_protocol::ConversationCreateResult;
|
||||
use crate::mcp_protocol::ConversationId;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::message_processor::MessageProcessor;
|
||||
|
||||
pub(crate) async fn handle_create_conversation(
|
||||
message_processor: &MessageProcessor,
|
||||
id: RequestId,
|
||||
args: ConversationCreateArgs,
|
||||
) {
|
||||
// Build ConfigOverrides from args
|
||||
let ConversationCreateArgs {
|
||||
prompt: _, // not used here; creation only establishes the session
|
||||
model,
|
||||
cwd,
|
||||
approval_policy,
|
||||
sandbox,
|
||||
config,
|
||||
profile,
|
||||
base_instructions,
|
||||
} = args;
|
||||
|
||||
// Convert config overrides JSON into CLI-style TOML overrides
|
||||
let cli_overrides: Vec<(String, toml::Value)> = match config {
|
||||
Some(v) => match v.as_object() {
|
||||
Some(map) => map
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.clone(), json_to_toml(v.clone())))
|
||||
.collect(),
|
||||
None => Vec::new(),
|
||||
},
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
let overrides = ConfigOverrides {
|
||||
model: Some(model.clone()),
|
||||
cwd: Some(PathBuf::from(cwd)),
|
||||
approval_policy,
|
||||
sandbox_mode: sandbox,
|
||||
model_provider: None,
|
||||
config_profile: profile,
|
||||
codex_linux_sandbox_exe: None,
|
||||
base_instructions,
|
||||
include_plan_tool: None,
|
||||
disable_response_storage: None,
|
||||
show_raw_agent_reasoning: None,
|
||||
};
|
||||
|
||||
let cfg: CodexConfig = match CodexConfig::load_with_cli_overrides(cli_overrides, overrides) {
|
||||
Ok(cfg) => cfg,
|
||||
Err(e) => {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Error {
|
||||
message: format!("Failed to load config: {e}"),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize Codex session
|
||||
let codex_conversation = match init_codex(cfg).await {
|
||||
Ok(conv) => conv,
|
||||
Err(e) => {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Error {
|
||||
message: format!("Failed to initialize session: {e}"),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Expect SessionConfigured; if not, return error.
|
||||
let EventMsg::SessionConfigured(SessionConfiguredEvent { model, .. }) =
|
||||
&codex_conversation.session_configured.msg
|
||||
else {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Error {
|
||||
message: "Expected SessionConfigured event".to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
let effective_model = model.clone();
|
||||
|
||||
let session_id = codex_conversation.session_id;
|
||||
let codex_arc = Arc::new(codex_conversation.codex);
|
||||
|
||||
// Store session for future calls
|
||||
insert_session(
|
||||
session_id,
|
||||
codex_arc.clone(),
|
||||
message_processor.session_map(),
|
||||
)
|
||||
.await;
|
||||
// Run the conversation loop in the background so this request can return immediately.
|
||||
let outgoing = message_processor.outgoing();
|
||||
let spawn_id = id.clone();
|
||||
tokio::spawn(async move {
|
||||
run_conversation_loop(codex_arc.clone(), outgoing, spawn_id).await;
|
||||
});
|
||||
|
||||
// Reply with the new conversation id and effective model
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Ok {
|
||||
conversation_id: ConversationId(session_id),
|
||||
model: effective_model,
|
||||
},
|
||||
)),
|
||||
Some(false),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn insert_session(
|
||||
session_id: Uuid,
|
||||
codex: Arc<Codex>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
) {
|
||||
let mut guard = session_map.lock().await;
|
||||
guard.insert(session_id, codex);
|
||||
}
|
||||
2
codex-rs/mcp-server/src/tool_handlers/mod.rs
Normal file
2
codex-rs/mcp-server/src/tool_handlers/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub(crate) mod create_conversation;
|
||||
pub(crate) mod send_message;
|
||||
124
codex-rs/mcp-server/src/tool_handlers/send_message.rs
Normal file
124
codex-rs/mcp-server/src/tool_handlers/send_message.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::mcp_protocol::ConversationSendMessageArgs;
|
||||
use crate::mcp_protocol::ConversationSendMessageResult;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::message_processor::MessageProcessor;
|
||||
|
||||
pub(crate) async fn handle_send_message(
|
||||
message_processor: &MessageProcessor,
|
||||
id: RequestId,
|
||||
arguments: ConversationSendMessageArgs,
|
||||
) {
|
||||
let ConversationSendMessageArgs {
|
||||
conversation_id,
|
||||
content: items,
|
||||
parent_message_id: _,
|
||||
conversation_overrides: _,
|
||||
} = arguments;
|
||||
|
||||
if items.is_empty() {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: "No content items provided".to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let session_id = conversation_id.0;
|
||||
let Some(codex) = get_session(session_id, message_processor.session_map()).await else {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: "Session does not exist".to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
let running = {
|
||||
let running_sessions = message_processor.running_session_ids();
|
||||
let mut running_sessions = running_sessions.lock().await;
|
||||
!running_sessions.insert(session_id)
|
||||
};
|
||||
|
||||
if running {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: "Session is already running".to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let request_id_string = match &id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(i) => i.to_string(),
|
||||
};
|
||||
|
||||
let submit_res = codex
|
||||
.submit_with_id(Submission {
|
||||
id: request_id_string,
|
||||
op: Op::UserInput { items },
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Err(e) = submit_res {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: format!("Failed to submit user input: {e}"),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Ok,
|
||||
)),
|
||||
Some(false),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn get_session(
|
||||
session_id: Uuid,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
) -> Option<Arc<Codex>> {
|
||||
let guard = session_map.lock().await;
|
||||
guard.get(&session_id).cloned()
|
||||
}
|
||||
@@ -89,14 +89,18 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
// This is the first request from the server, so the id should be 0 given
|
||||
// how things are currently implemented.
|
||||
let elicitation_request_id = RequestId::Integer(0);
|
||||
let params = serde_json::from_value::<ExecApprovalElicitRequestParams>(
|
||||
elicitation_request
|
||||
.params
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow::anyhow!("elicitation_request.params must be set"))?,
|
||||
)?;
|
||||
let expected_elicitation_request = create_expected_elicitation_request(
|
||||
elicitation_request_id.clone(),
|
||||
shell_command.clone(),
|
||||
workdir_for_shell_function_call.path(),
|
||||
codex_request_id.to_string(),
|
||||
// Internal Codex id: empirically it is 1, but this is
|
||||
// admittedly an internal detail that could change.
|
||||
"1".to_string(),
|
||||
params.codex_event_id.clone(),
|
||||
)?;
|
||||
assert_eq!(expected_elicitation_request, elicitation_request);
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ path = "lib.rs"
|
||||
anyhow = "1"
|
||||
assert_cmd = "2"
|
||||
codex-mcp-server = { path = "../.." }
|
||||
codex-core = { path = "../../../core" }
|
||||
mcp-types = { path = "../../../mcp-types" }
|
||||
pretty_assertions = "1.4.1"
|
||||
serde_json = "1"
|
||||
@@ -22,3 +23,4 @@ tokio = { version = "1", features = [
|
||||
"rt-multi-thread",
|
||||
] }
|
||||
wiremock = "0.6"
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
@@ -11,8 +11,14 @@ use tokio::process::ChildStdout;
|
||||
|
||||
use anyhow::Context;
|
||||
use assert_cmd::prelude::*;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use codex_mcp_server::CodexToolCallReplyParam;
|
||||
use codex_mcp_server::mcp_protocol::ConversationCreateArgs;
|
||||
use codex_mcp_server::mcp_protocol::ConversationId;
|
||||
use codex_mcp_server::mcp_protocol::ConversationSendMessageArgs;
|
||||
use codex_mcp_server::mcp_protocol::ToolCallRequestParams;
|
||||
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
@@ -29,6 +35,7 @@ use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::process::Command as StdCommand;
|
||||
use tokio::process::Command;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct McpProcess {
|
||||
next_request_id: AtomicI64,
|
||||
@@ -174,6 +181,61 @@ impl McpProcess {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_user_message_tool_call(
|
||||
&mut self,
|
||||
message: &str,
|
||||
session_id: &str,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = ToolCallRequestParams::ConversationSendMessage(ConversationSendMessageArgs {
|
||||
conversation_id: ConversationId(Uuid::parse_str(session_id)?),
|
||||
content: vec![InputItem::Text {
|
||||
text: message.to_string(),
|
||||
}],
|
||||
parent_message_id: None,
|
||||
conversation_overrides: None,
|
||||
});
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_conversation_create_tool_call(
|
||||
&mut self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
cwd: &str,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = ToolCallRequestParams::ConversationCreate(ConversationCreateArgs {
|
||||
prompt: prompt.to_string(),
|
||||
model: model.to_string(),
|
||||
cwd: cwd.to_string(),
|
||||
approval_policy: None,
|
||||
sandbox: None,
|
||||
config: None,
|
||||
profile: None,
|
||||
base_instructions: None,
|
||||
});
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_conversation_create_with_args(
|
||||
&mut self,
|
||||
args: ConversationCreateArgs,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = ToolCallRequestParams::ConversationCreate(args);
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
|
||||
128
codex-rs/mcp-server/tests/create_conversation.rs
Normal file
128
codex-rs/mcp-server/tests/create_conversation.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_conversation_create_and_send_message_ok() {
|
||||
// Mock server – we won't strictly rely on it, but provide one to satisfy any model wiring.
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
// Temporary Codex home with config pointing at the mock server.
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
// Start MCP server process and initialize.
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
// Create a conversation via the new tool.
|
||||
let req_id = mcp
|
||||
.send_conversation_create_tool_call("", "o3", "/repo")
|
||||
.await
|
||||
.expect("send conversationCreate");
|
||||
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(req_id)),
|
||||
)
|
||||
.await
|
||||
.expect("create response timeout")
|
||||
.expect("create response error");
|
||||
|
||||
// Structured content must include status=ok, a UUID conversation_id and the model we passed.
|
||||
let sc = &resp.result["structuredContent"];
|
||||
let conv_id = sc["conversation_id"].as_str().expect("uuid string");
|
||||
assert!(!conv_id.is_empty());
|
||||
assert_eq!(sc["model"], json!("o3"));
|
||||
|
||||
// Now send a message to the created conversation and expect an OK result.
|
||||
let send_id = mcp
|
||||
.send_user_message_tool_call("Hello", conv_id)
|
||||
.await
|
||||
.expect("send message");
|
||||
|
||||
let send_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(send_id)),
|
||||
)
|
||||
.await
|
||||
.expect("send response timeout")
|
||||
.expect("send response error");
|
||||
assert_eq!(
|
||||
send_resp.result["structuredContent"],
|
||||
json!({ "status": "ok" })
|
||||
);
|
||||
|
||||
// avoid race condition by waiting for the mock server to receive the chat.completions request
|
||||
let deadline = std::time::Instant::now() + DEFAULT_READ_TIMEOUT;
|
||||
loop {
|
||||
let requests = server.received_requests().await.unwrap_or_default();
|
||||
if !requests.is_empty() {
|
||||
break;
|
||||
}
|
||||
if std::time::Instant::now() >= deadline {
|
||||
panic!("mock server did not receive the chat.completions request in time");
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
|
||||
// Verify the outbound request body matches expectations for Chat Completions.
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let body = request
|
||||
.body_json::<serde_json::Value>()
|
||||
.expect("parse request body as JSON");
|
||||
assert_eq!(body["model"], json!("o3"));
|
||||
assert!(body["stream"].as_bool().unwrap_or(false));
|
||||
let messages = body["messages"]
|
||||
.as_array()
|
||||
.expect("messages should be array");
|
||||
let last = messages.last().expect("at least one message");
|
||||
assert_eq!(last["role"], json!("user"));
|
||||
assert_eq!(last["content"], json!("Hello"));
|
||||
|
||||
drop(server);
|
||||
}
|
||||
|
||||
// Helper to create a config.toml pointing at the mock model server.
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
163
codex-rs/mcp-server/tests/send_message.rs
Normal file
163
codex-rs/mcp-server/tests/send_message.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
use std::path::Path;
|
||||
use std::thread::sleep;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_send_message_success() {
|
||||
// Spin up a mock completions server that immediately ends the Codex turn.
|
||||
// Two Codex turns hit the mock model (session start + send-user-message). Provide two SSE responses.
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
// Create a temporary Codex home with config pointing at the mock server.
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
// Start MCP server process and initialize.
|
||||
let mut mcp_process = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize())
|
||||
.await
|
||||
.expect("init timed out")
|
||||
.expect("init failed");
|
||||
|
||||
// Kick off a Codex session so we have a valid session_id.
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
prompt: "Start a session".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("send codex tool call");
|
||||
|
||||
// Wait for the session_configured event to get the session_id.
|
||||
let session_id = mcp_process
|
||||
.read_stream_until_configured_response_message()
|
||||
.await
|
||||
.expect("read session_configured");
|
||||
|
||||
// The original codex call will finish quickly given our mock; consume its response.
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("codex response timeout")
|
||||
.expect("codex response error");
|
||||
|
||||
// Now exercise the send-user-message tool.
|
||||
let send_msg_request_id = mcp_process
|
||||
.send_user_message_tool_call("Hello again", &session_id)
|
||||
.await
|
||||
.expect("send send-message tool call");
|
||||
|
||||
let response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(send_msg_request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("send-user-message response timeout")
|
||||
.expect("send-user-message response error");
|
||||
|
||||
assert_eq!(
|
||||
JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(send_msg_request_id),
|
||||
result: json!({
|
||||
"content": [
|
||||
{
|
||||
"text": "{\"status\":\"ok\"}",
|
||||
"type": "text",
|
||||
}
|
||||
],
|
||||
"isError": false,
|
||||
"structuredContent": {
|
||||
"status": "ok"
|
||||
}
|
||||
}),
|
||||
},
|
||||
response
|
||||
);
|
||||
// wait for the server to hear the user message
|
||||
sleep(Duration::from_secs(5));
|
||||
|
||||
// Ensure the server and tempdir live until end of test
|
||||
drop(server);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_send_message_session_not_found() {
|
||||
// Start MCP without creating a Codex session
|
||||
let codex_home = TempDir::new().expect("tempdir");
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await.expect("spawn");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("timeout")
|
||||
.expect("init");
|
||||
|
||||
let unknown = uuid::Uuid::new_v4().to_string();
|
||||
let req_id = mcp
|
||||
.send_user_message_tool_call("ping", &unknown)
|
||||
.await
|
||||
.expect("send tool");
|
||||
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(req_id)),
|
||||
)
|
||||
.await
|
||||
.expect("timeout")
|
||||
.expect("resp");
|
||||
|
||||
let result = resp.result.clone();
|
||||
let content = result["content"][0]["text"].as_str().unwrap_or("");
|
||||
assert!(content.contains("Session does not exist"));
|
||||
assert_eq!(result["isError"], json!(true));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
32
codex-rs/ollama/Cargo.toml
Normal file
32
codex-rs/ollama/Cargo.toml
Normal file
@@ -0,0 +1,32 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-ollama"
|
||||
version = { workspace = true }
|
||||
|
||||
[lib]
|
||||
name = "codex_ollama"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-stream = "0.3"
|
||||
bytes = "1.10.1"
|
||||
codex-core = { path = "../core" }
|
||||
futures = "0.3"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde_json = "1"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
toml = "0.9.2"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
wiremock = "0.6"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
351
codex-rs/ollama/src/client.rs
Normal file
351
codex-rs/ollama/src/client.rs
Normal file
@@ -0,0 +1,351 @@
|
||||
use bytes::BytesMut;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::BoxStream;
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
|
||||
use crate::parser::pull_events_from_value;
|
||||
use crate::pull::PullEvent;
|
||||
use crate::pull::PullProgressReporter;
|
||||
use crate::url::base_url_to_host_root;
|
||||
use crate::url::is_openai_compatible_base_url;
|
||||
use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::config::Config;
|
||||
|
||||
const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";
|
||||
|
||||
/// Client for interacting with a local Ollama instance.
|
||||
pub struct OllamaClient {
|
||||
client: reqwest::Client,
|
||||
host_root: String,
|
||||
uses_openai_compat: bool,
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
/// Construct a client for the built‑in open‑source ("oss") model provider
|
||||
/// and verify that a local Ollama server is reachable. If no server is
|
||||
/// detected, returns an error with helpful installation/run instructions.
|
||||
pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
|
||||
// Note that we must look up the provider from the Config to ensure that
|
||||
// any overrides the user has in their config.toml are taken into
|
||||
// account.
|
||||
let provider = config
|
||||
.model_providers
|
||||
.get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found",),
|
||||
)
|
||||
})?;
|
||||
|
||||
Self::try_from_provider(provider).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
|
||||
let provider = codex_core::create_oss_provider_with_base_url(base_url);
|
||||
Self::try_from_provider(&provider).await
|
||||
}
|
||||
|
||||
/// Build a client from a provider definition and verify the server is reachable.
|
||||
async fn try_from_provider(provider: &ModelProviderInfo) -> io::Result<Self> {
|
||||
#![allow(clippy::expect_used)]
|
||||
let base_url = provider
|
||||
.base_url
|
||||
.as_ref()
|
||||
.expect("oss provider must have a base_url");
|
||||
let uses_openai_compat = is_openai_compatible_base_url(base_url)
|
||||
|| matches!(provider.wire_api, WireApi::Chat)
|
||||
&& is_openai_compatible_base_url(base_url);
|
||||
let host_root = base_url_to_host_root(base_url);
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new());
|
||||
let client = Self {
|
||||
client,
|
||||
host_root,
|
||||
uses_openai_compat,
|
||||
};
|
||||
client.probe_server().await?;
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
/// Probe whether the server is reachable by hitting the appropriate health endpoint.
|
||||
async fn probe_server(&self) -> io::Result<()> {
|
||||
let url = if self.uses_openai_compat {
|
||||
format!("{}/v1/models", self.host_root.trim_end_matches('/'))
|
||||
} else {
|
||||
format!("{}/api/tags", self.host_root.trim_end_matches('/'))
|
||||
};
|
||||
let resp = self.client.get(url).send().await.map_err(|err| {
|
||||
tracing::warn!("Failed to connect to Ollama server: {err:?}");
|
||||
io::Error::other(OLLAMA_CONNECTION_ERROR)
|
||||
})?;
|
||||
if resp.status().is_success() {
|
||||
Ok(())
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Failed to probe server at {}: HTTP {}",
|
||||
self.host_root,
|
||||
resp.status()
|
||||
);
|
||||
Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the list of model names known to the local Ollama instance.
|
||||
pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
|
||||
let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
|
||||
let resp = self
|
||||
.client
|
||||
.get(tags_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
if !resp.status().is_success() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
|
||||
let names = val
|
||||
.get("models")
|
||||
.and_then(|m| m.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
Ok(names)
|
||||
}
|
||||
|
||||
/// Start a model pull and emit streaming events. The returned stream ends when
|
||||
/// a Success event is observed or the server closes the connection.
|
||||
pub async fn pull_model_stream(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> io::Result<BoxStream<'static, PullEvent>> {
|
||||
let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
|
||||
let resp = self
|
||||
.client
|
||||
.post(url)
|
||||
.json(&serde_json::json!({"model": model, "stream": true}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(io::Error::other(format!(
|
||||
"failed to start pull: HTTP {}",
|
||||
resp.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut stream = resp.bytes_stream();
|
||||
let mut buf = BytesMut::new();
|
||||
let _pending: VecDeque<PullEvent> = VecDeque::new();
|
||||
|
||||
// Using an async stream adaptor backed by unfold-like manual loop.
|
||||
let s = async_stream::stream! {
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
buf.extend_from_slice(&bytes);
|
||||
while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
|
||||
let line = buf.split_to(pos + 1);
|
||||
if let Ok(text) = std::str::from_utf8(&line) {
|
||||
let text = text.trim();
|
||||
if text.is_empty() { continue; }
|
||||
if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
|
||||
for ev in pull_events_from_value(&value) { yield ev; }
|
||||
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
|
||||
yield PullEvent::Error(err_msg.to_string());
|
||||
return;
|
||||
}
|
||||
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
|
||||
if status == "success" { yield PullEvent::Success; return; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Connection error: end the stream.
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(s))
|
||||
}
|
||||
|
||||
/// High-level helper to pull a model and drive a progress reporter.
|
||||
pub async fn pull_with_reporter(
|
||||
&self,
|
||||
model: &str,
|
||||
reporter: &mut dyn PullProgressReporter,
|
||||
) -> io::Result<()> {
|
||||
reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
|
||||
let mut stream = self.pull_model_stream(model).await?;
|
||||
while let Some(event) = stream.next().await {
|
||||
reporter.on_event(&event)?;
|
||||
match event {
|
||||
PullEvent::Success => {
|
||||
return Ok(());
|
||||
}
|
||||
PullEvent::Error(err) => {
|
||||
// Empirically, ollama returns a 200 OK response even when
|
||||
// the output stream includes an error message. Verify with:
|
||||
//
|
||||
// `curl -i http://localhost:11434/api/pull -d '{ "model": "foobarbaz" }'`
|
||||
//
|
||||
// As such, we have to check the event stream, not the
|
||||
// HTTP response status, to determine whether to return Err.
|
||||
return Err(io::Error::other(format!("Pull failed: {err}")));
|
||||
}
|
||||
PullEvent::ChunkProgress { .. } | PullEvent::Status(_) => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(io::Error::other(
|
||||
"Pull stream ended unexpectedly without success.",
|
||||
))
|
||||
}
|
||||
|
||||
/// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
|
||||
#[cfg(test)]
|
||||
fn from_host_root(host_root: impl Into<String>) -> Self {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new());
|
||||
Self {
|
||||
client,
|
||||
host_root: host_root.into(),
|
||||
uses_openai_compat: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
use super::*;
|
||||
|
||||
// Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
|
||||
#[tokio::test]
|
||||
async fn test_fetch_models_happy_path() {
|
||||
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
tracing::info!(
|
||||
"{} is set; skipping test_fetch_models_happy_path",
|
||||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let server = wiremock::MockServer::start().await;
|
||||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||||
.and(wiremock::matchers::path("/api/tags"))
|
||||
.respond_with(
|
||||
wiremock::ResponseTemplate::new(200).set_body_raw(
|
||||
serde_json::json!({
|
||||
"models": [ {"name": "llama3.2:3b"}, {"name":"mistral"} ]
|
||||
})
|
||||
.to_string(),
|
||||
"application/json",
|
||||
),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let client = OllamaClient::from_host_root(server.uri());
|
||||
let models = client.fetch_models().await.expect("fetch models");
|
||||
assert!(models.contains(&"llama3.2:3b".to_string()));
|
||||
assert!(models.contains(&"mistral".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_probe_server_happy_path_openai_compat_and_native() {
|
||||
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
tracing::info!(
|
||||
"{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
|
||||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let server = wiremock::MockServer::start().await;
|
||||
|
||||
// Native endpoint
|
||||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||||
.and(wiremock::matchers::path("/api/tags"))
|
||||
.respond_with(wiremock::ResponseTemplate::new(200))
|
||||
.mount(&server)
|
||||
.await;
|
||||
let native = OllamaClient::from_host_root(server.uri());
|
||||
native.probe_server().await.expect("probe native");
|
||||
|
||||
// OpenAI compatibility endpoint
|
||||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||||
.and(wiremock::matchers::path("/v1/models"))
|
||||
.respond_with(wiremock::ResponseTemplate::new(200))
|
||||
.mount(&server)
|
||||
.await;
|
||||
let ollama_client =
|
||||
OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
|
||||
.await
|
||||
.expect("probe OpenAI compat");
|
||||
ollama_client
|
||||
.probe_server()
|
||||
.await
|
||||
.expect("probe OpenAI compat");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_try_from_oss_provider_ok_when_server_running() {
|
||||
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
tracing::info!(
|
||||
"{} set; skipping test_try_from_oss_provider_ok_when_server_running",
|
||||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let server = wiremock::MockServer::start().await;
|
||||
|
||||
// OpenAI‑compat models endpoint responds OK.
|
||||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||||
.and(wiremock::matchers::path("/v1/models"))
|
||||
.respond_with(wiremock::ResponseTemplate::new(200))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
|
||||
.await
|
||||
.expect("client should be created when probe succeeds");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_try_from_oss_provider_err_when_server_missing() {
|
||||
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
tracing::info!(
|
||||
"{} set; skipping test_try_from_oss_provider_err_when_server_missing",
|
||||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let server = wiremock::MockServer::start().await;
|
||||
let err = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
|
||||
.await
|
||||
.err()
|
||||
.expect("expected error");
|
||||
assert_eq!(OLLAMA_CONNECTION_ERROR, err.to_string());
|
||||
}
|
||||
}
|
||||
44
codex-rs/ollama/src/lib.rs
Normal file
44
codex-rs/ollama/src/lib.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
mod client;
|
||||
mod parser;
|
||||
mod pull;
|
||||
mod url;
|
||||
|
||||
pub use client::OllamaClient;
|
||||
use codex_core::config::Config;
|
||||
pub use pull::CliProgressReporter;
|
||||
pub use pull::PullEvent;
|
||||
pub use pull::PullProgressReporter;
|
||||
pub use pull::TuiProgressReporter;
|
||||
|
||||
/// Default OSS model to use when `--oss` is passed without an explicit `-m`.
|
||||
pub const DEFAULT_OSS_MODEL: &str = "gpt-oss:20b";
|
||||
|
||||
/// Prepare the local OSS environment when `--oss` is selected.
|
||||
///
|
||||
/// - Ensures a local Ollama server is reachable.
|
||||
/// - Checks if the model exists locally and pulls it if missing.
|
||||
pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> {
|
||||
// Only download when the requested model is the default OSS model (or when -m is not provided).
|
||||
let model = config.model.as_ref();
|
||||
|
||||
// Verify local Ollama is reachable.
|
||||
let ollama_client = crate::OllamaClient::try_from_oss_provider(config).await?;
|
||||
|
||||
// If the model is not present locally, pull it.
|
||||
match ollama_client.fetch_models().await {
|
||||
Ok(models) => {
|
||||
if !models.iter().any(|m| m == model) {
|
||||
let mut reporter = crate::CliProgressReporter::new();
|
||||
ollama_client
|
||||
.pull_with_reporter(model, &mut reporter)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
// Not fatal; higher layers may still proceed and surface errors later.
|
||||
tracing::warn!("Failed to query local models from Ollama: {}.", err);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
82
codex-rs/ollama/src/parser.rs
Normal file
82
codex-rs/ollama/src/parser.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
use crate::pull::PullEvent;
|
||||
|
||||
// Convert a single JSON object representing a pull update into one or more events.
|
||||
pub(crate) fn pull_events_from_value(value: &JsonValue) -> Vec<PullEvent> {
|
||||
let mut events = Vec::new();
|
||||
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
|
||||
events.push(PullEvent::Status(status.to_string()));
|
||||
if status == "success" {
|
||||
events.push(PullEvent::Success);
|
||||
}
|
||||
}
|
||||
let digest = value
|
||||
.get("digest")
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let total = value.get("total").and_then(|t| t.as_u64());
|
||||
let completed = value.get("completed").and_then(|t| t.as_u64());
|
||||
if total.is_some() || completed.is_some() {
|
||||
events.push(PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
});
|
||||
}
|
||||
events
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pull_events_decoder_status_and_success() {
|
||||
let v: JsonValue = serde_json::json!({"status":"verifying"});
|
||||
let events = pull_events_from_value(&v);
|
||||
assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying"));
|
||||
|
||||
let v2: JsonValue = serde_json::json!({"status":"success"});
|
||||
let events2 = pull_events_from_value(&v2);
|
||||
assert_eq!(events2.len(), 2);
|
||||
assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success"));
|
||||
assert!(matches!(events2[1], PullEvent::Success));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pull_events_decoder_progress() {
|
||||
let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100});
|
||||
let events = pull_events_from_value(&v);
|
||||
assert_eq!(events.len(), 1);
|
||||
match &events[0] {
|
||||
PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
} => {
|
||||
assert_eq!(digest, "sha256:abc");
|
||||
assert_eq!(*total, Some(100));
|
||||
assert_eq!(*completed, None);
|
||||
}
|
||||
_ => panic!("expected ChunkProgress"),
|
||||
}
|
||||
|
||||
let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42});
|
||||
let events2 = pull_events_from_value(&v2);
|
||||
assert_eq!(events2.len(), 1);
|
||||
match &events2[0] {
|
||||
PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
} => {
|
||||
assert_eq!(digest, "sha256:def");
|
||||
assert_eq!(*total, None);
|
||||
assert_eq!(*completed, Some(42));
|
||||
}
|
||||
_ => panic!("expected ChunkProgress"),
|
||||
}
|
||||
}
|
||||
}
|
||||
147
codex-rs/ollama/src/pull.rs
Normal file
147
codex-rs/ollama/src/pull.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
|
||||
/// Events emitted while pulling a model from Ollama.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PullEvent {
|
||||
/// A human-readable status message (e.g., "verifying", "writing").
|
||||
Status(String),
|
||||
/// Byte-level progress update for a specific layer digest.
|
||||
ChunkProgress {
|
||||
digest: String,
|
||||
total: Option<u64>,
|
||||
completed: Option<u64>,
|
||||
},
|
||||
/// The pull finished successfully.
|
||||
Success,
|
||||
|
||||
/// Error event with a message.
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// A simple observer for pull progress events. Implementations decide how to
|
||||
/// render progress (CLI, TUI, logs, ...).
|
||||
pub trait PullProgressReporter {
|
||||
fn on_event(&mut self, event: &PullEvent) -> io::Result<()>;
|
||||
}
|
||||
|
||||
/// A minimal CLI reporter that writes inline progress to stderr.
|
||||
pub struct CliProgressReporter {
|
||||
printed_header: bool,
|
||||
last_line_len: usize,
|
||||
last_completed_sum: u64,
|
||||
last_instant: std::time::Instant,
|
||||
totals_by_digest: HashMap<String, (u64, u64)>,
|
||||
}
|
||||
|
||||
impl Default for CliProgressReporter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CliProgressReporter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
printed_header: false,
|
||||
last_line_len: 0,
|
||||
last_completed_sum: 0,
|
||||
last_instant: std::time::Instant::now(),
|
||||
totals_by_digest: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PullProgressReporter for CliProgressReporter {
|
||||
fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
|
||||
let mut out = std::io::stderr();
|
||||
match event {
|
||||
PullEvent::Status(status) => {
|
||||
// Avoid noisy manifest messages; otherwise show status inline.
|
||||
if status.eq_ignore_ascii_case("pulling manifest") {
|
||||
return Ok(());
|
||||
}
|
||||
let pad = self.last_line_len.saturating_sub(status.len());
|
||||
let line = format!("\r{status}{}", " ".repeat(pad));
|
||||
self.last_line_len = status.len();
|
||||
out.write_all(line.as_bytes())?;
|
||||
out.flush()
|
||||
}
|
||||
PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
} => {
|
||||
if let Some(t) = *total {
|
||||
self.totals_by_digest
|
||||
.entry(digest.clone())
|
||||
.or_insert((0, 0))
|
||||
.0 = t;
|
||||
}
|
||||
if let Some(c) = *completed {
|
||||
self.totals_by_digest
|
||||
.entry(digest.clone())
|
||||
.or_insert((0, 0))
|
||||
.1 = c;
|
||||
}
|
||||
|
||||
let (sum_total, sum_completed) = self
|
||||
.totals_by_digest
|
||||
.values()
|
||||
.fold((0u64, 0u64), |acc, (t, c)| (acc.0 + *t, acc.1 + *c));
|
||||
if sum_total > 0 {
|
||||
if !self.printed_header {
|
||||
let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let header = format!("Downloading model: total {gb:.2} GB\n");
|
||||
out.write_all(b"\r\x1b[2K")?;
|
||||
out.write_all(header.as_bytes())?;
|
||||
self.printed_header = true;
|
||||
}
|
||||
let now = std::time::Instant::now();
|
||||
let dt = now
|
||||
.duration_since(self.last_instant)
|
||||
.as_secs_f64()
|
||||
.max(0.001);
|
||||
let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64;
|
||||
let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
|
||||
self.last_completed_sum = sum_completed;
|
||||
self.last_instant = now;
|
||||
|
||||
let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
|
||||
let text =
|
||||
format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s");
|
||||
let pad = self.last_line_len.saturating_sub(text.len());
|
||||
let line = format!("\r{text}{}", " ".repeat(pad));
|
||||
self.last_line_len = text.len();
|
||||
out.write_all(line.as_bytes())?;
|
||||
out.flush()
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
PullEvent::Error(_) => {
|
||||
// This will be handled by the caller, so we don't do anything
|
||||
// here or the error will be printed twice.
|
||||
Ok(())
|
||||
}
|
||||
PullEvent::Success => {
|
||||
out.write_all(b"\n")?;
|
||||
out.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and
|
||||
/// CLI behavior aligned until a dedicated TUI integration is implemented.
|
||||
#[derive(Default)]
|
||||
pub struct TuiProgressReporter(CliProgressReporter);
|
||||
|
||||
impl PullProgressReporter for TuiProgressReporter {
|
||||
fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
|
||||
self.0.on_event(event)
|
||||
}
|
||||
}
|
||||
39
codex-rs/ollama/src/url.rs
Normal file
39
codex-rs/ollama/src/url.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
/// Identify whether a base_url points at an OpenAI-compatible root (".../v1").
|
||||
pub(crate) fn is_openai_compatible_base_url(base_url: &str) -> bool {
|
||||
base_url.trim_end_matches('/').ends_with("/v1")
|
||||
}
|
||||
|
||||
/// Convert a provider base_url into the native Ollama host root.
|
||||
/// For example, "http://localhost:11434/v1" -> "http://localhost:11434".
|
||||
pub fn base_url_to_host_root(base_url: &str) -> String {
|
||||
let trimmed = base_url.trim_end_matches('/');
|
||||
if trimmed.ends_with("/v1") {
|
||||
trimmed
|
||||
.trim_end_matches("/v1")
|
||||
.trim_end_matches('/')
|
||||
.to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_base_url_to_host_root() {
|
||||
assert_eq!(
|
||||
base_url_to_host_root("http://localhost:11434/v1"),
|
||||
"http://localhost:11434"
|
||||
);
|
||||
assert_eq!(
|
||||
base_url_to_host_root("http://localhost:11434"),
|
||||
"http://localhost:11434"
|
||||
);
|
||||
assert_eq!(
|
||||
base_url_to_host_root("http://localhost:11434/"),
|
||||
"http://localhost:11434"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,33 @@ if ! git diff --quiet || ! git diff --cached --quiet || [ -n "$(git ls-files --o
|
||||
fi
|
||||
|
||||
# Fail if in a detached HEAD state.
|
||||
CURRENT_BRANCH=$(git symbolic-ref --short -q HEAD)
|
||||
CURRENT_BRANCH=$(git symbolic-ref --short -q HEAD 2>/dev/null || true)
|
||||
if [ -z "${CURRENT_BRANCH:-}" ]; then
|
||||
echo "ERROR: Could not determine the current branch (detached HEAD?)." >&2
|
||||
echo " Please run this script from a checked-out branch." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure we are on the 'main' branch before proceeding.
|
||||
if [ "${CURRENT_BRANCH}" != "main" ]; then
|
||||
echo "ERROR: Releases must be created from the 'main' branch (current: '${CURRENT_BRANCH}')." >&2
|
||||
echo " Please switch to 'main' and try again." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure the current local commit on 'main' is present on 'origin/main'.
|
||||
# This guarantees we only create releases from commits that are already on
|
||||
# the canonical repository (https://github.com/openai/codex).
|
||||
if ! git fetch --quiet origin main; then
|
||||
echo "ERROR: Failed to fetch 'origin/main'. Ensure the 'origin' remote is configured and reachable." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! git merge-base --is-ancestor HEAD origin/main; then
|
||||
echo "ERROR: Your local 'main' HEAD commit is not present on 'origin/main'." >&2
|
||||
echo " Please push your commits first (git push origin main) or check out a commit on 'origin/main'." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Create a new branch for the release and make a commit with the new version.
|
||||
if [ $# -ge 1 ]; then
|
||||
|
||||
@@ -11,12 +11,17 @@ path = "src/main.rs"
|
||||
name = "codex_tui"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[features]
|
||||
# Enable vt100-based tests (emulator) when running with `--features vt100-tests`.
|
||||
vt100-tests = []
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
base64 = "0.22.1"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-ansi-escape = { path = "../ansi-escape" }
|
||||
codex-arg0 = { path = "../arg0" }
|
||||
@@ -28,6 +33,7 @@ codex-common = { path = "../common", features = [
|
||||
codex-core = { path = "../core" }
|
||||
codex-file-search = { path = "../file-search" }
|
||||
codex-login = { path = "../login" }
|
||||
codex-ollama = { path = "../ollama" }
|
||||
color-eyre = "0.6.3"
|
||||
crossterm = { version = "0.28.1", features = ["bracketed-paste"] }
|
||||
image = { version = "^0.25.6", default-features = false, features = ["jpeg"] }
|
||||
@@ -41,10 +47,14 @@ ratatui = { version = "0.29.0", features = [
|
||||
] }
|
||||
ratatui-image = "8.0.0"
|
||||
regex-lite = "0.1"
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = { version = "1", features = ["preserve_order"] }
|
||||
shlex = "1.3.0"
|
||||
strum = "0.27.2"
|
||||
strum_macros = "0.27.2"
|
||||
supports-color = "3.0.2"
|
||||
textwrap = "0.16.2"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
@@ -57,11 +67,15 @@ tracing-appender = "0.2.3"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
tui-input = "0.14.0"
|
||||
tui-markdown = "0.3.3"
|
||||
tui-textarea = "0.7.0"
|
||||
unicode-segmentation = "1.12.0"
|
||||
unicode-width = "0.1"
|
||||
uuid = "1"
|
||||
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
insta = "1.43.1"
|
||||
pretty_assertions = "1"
|
||||
rand = "0.8"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
vt100 = "0.16.2"
|
||||
|
||||
40
codex-rs/tui/prompt_for_init_command.md
Normal file
40
codex-rs/tui/prompt_for_init_command.md
Normal file
@@ -0,0 +1,40 @@
|
||||
Generate a file named AGENTS.md that serves as a contributor guide for this repository.
|
||||
Your goal is to produce a clear, concise, and well-structured document with descriptive headings and actionable explanations for each section.
|
||||
Follow the outline below, but adapt as needed — add sections if relevant, and omit those that do not apply to this project.
|
||||
|
||||
Document Requirements
|
||||
|
||||
- Title the document "Repository Guidelines".
|
||||
- Use Markdown headings (#, ##, etc.) for structure.
|
||||
- Keep the document concise. 200-400 words is optimal.
|
||||
- Keep explanations short, direct, and specific to this repository.
|
||||
- Provide examples where helpful (commands, directory paths, naming patterns).
|
||||
- Maintain a professional, instructional tone.
|
||||
|
||||
Recommended Sections
|
||||
|
||||
Project Structure & Module Organization
|
||||
|
||||
- Outline the project structure, including where the source code, tests, and assets are located.
|
||||
|
||||
Build, Test, and Development Commands
|
||||
|
||||
- List key commands for building, testing, and running locally (e.g., npm test, make build).
|
||||
- Briefly explain what each command does.
|
||||
|
||||
Coding Style & Naming Conventions
|
||||
|
||||
- Specify indentation rules, language-specific style preferences, and naming patterns.
|
||||
- Include any formatting or linting tools used.
|
||||
|
||||
Testing Guidelines
|
||||
|
||||
- Identify testing frameworks and coverage requirements.
|
||||
- State test naming conventions and how to run tests.
|
||||
|
||||
Commit & Pull Request Guidelines
|
||||
|
||||
- Summarize commit message conventions found in the project’s Git history.
|
||||
- Outline pull request requirements (descriptions, linked issues, screenshots, etc.).
|
||||
|
||||
(Optional) Add other sections if relevant, such as Security & Configuration Tips, Architecture Overview, or Agent-Specific Instructions.
|
||||
@@ -10,13 +10,16 @@ use crate::tui;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use codex_core::protocol::Op;
|
||||
use color_eyre::eyre::Result;
|
||||
use crossterm::SynchronizedUpdate;
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use crossterm::event::KeyEventKind;
|
||||
use crossterm::terminal::supports_keyboard_enhancement;
|
||||
use ratatui::layout::Offset;
|
||||
use ratatui::prelude::Backend;
|
||||
use ratatui::text::Line;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
@@ -55,9 +58,13 @@ pub(crate) struct App<'a> {
|
||||
/// True when a redraw has been scheduled but not yet executed.
|
||||
pending_redraw: Arc<AtomicBool>,
|
||||
|
||||
pending_history_lines: Vec<Line<'static>>,
|
||||
|
||||
/// Stored parameters needed to instantiate the ChatWidget later, e.g.,
|
||||
/// after dismissing the Git-repo warning.
|
||||
chat_args: Option<ChatWidgetArgs>,
|
||||
|
||||
enhanced_keys_supported: bool,
|
||||
}
|
||||
|
||||
/// Aggregate parameters needed to create a `ChatWidget`, as creation may be
|
||||
@@ -67,6 +74,7 @@ struct ChatWidgetArgs {
|
||||
config: Config,
|
||||
initial_prompt: Option<String>,
|
||||
initial_images: Vec<PathBuf>,
|
||||
enhanced_keys_supported: bool,
|
||||
}
|
||||
|
||||
impl App<'_> {
|
||||
@@ -80,6 +88,8 @@ impl App<'_> {
|
||||
let app_event_tx = AppEventSender::new(app_event_tx);
|
||||
let pending_redraw = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let enhanced_keys_supported = supports_keyboard_enhancement().unwrap_or(false);
|
||||
|
||||
// Spawn a dedicated thread for reading the crossterm event loop and
|
||||
// re-publishing the events as AppEvents, as appropriate.
|
||||
{
|
||||
@@ -96,9 +106,7 @@ impl App<'_> {
|
||||
if let Ok(event) = crossterm::event::read() {
|
||||
match event {
|
||||
crossterm::event::Event::Key(key_event) => {
|
||||
if key_event.kind == crossterm::event::KeyEventKind::Press {
|
||||
app_event_tx.send(AppEvent::KeyEvent(key_event));
|
||||
}
|
||||
app_event_tx.send(AppEvent::KeyEvent(key_event));
|
||||
}
|
||||
crossterm::event::Event::Resize(_, _) => {
|
||||
app_event_tx.send(AppEvent::RequestRedraw);
|
||||
@@ -134,6 +142,7 @@ impl App<'_> {
|
||||
config: config.clone(),
|
||||
initial_prompt,
|
||||
initial_images,
|
||||
enhanced_keys_supported,
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
@@ -142,6 +151,7 @@ impl App<'_> {
|
||||
app_event_tx.clone(),
|
||||
initial_prompt,
|
||||
initial_images,
|
||||
enhanced_keys_supported,
|
||||
);
|
||||
(
|
||||
AppState::Chat {
|
||||
@@ -154,12 +164,14 @@ impl App<'_> {
|
||||
let file_search = FileSearchManager::new(config.cwd.clone(), app_event_tx.clone());
|
||||
Self {
|
||||
app_event_tx,
|
||||
pending_history_lines: Vec::new(),
|
||||
app_event_rx,
|
||||
app_state,
|
||||
config,
|
||||
file_search,
|
||||
pending_redraw,
|
||||
chat_args,
|
||||
enhanced_keys_supported,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,7 +211,7 @@ impl App<'_> {
|
||||
while let Ok(event) = self.app_event_rx.recv() {
|
||||
match event {
|
||||
AppEvent::InsertHistory(lines) => {
|
||||
crate::insert_history::insert_history_lines(terminal, lines);
|
||||
self.pending_history_lines.extend(lines);
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
}
|
||||
AppEvent::RequestRedraw => {
|
||||
@@ -213,6 +225,7 @@ impl App<'_> {
|
||||
KeyEvent {
|
||||
code: KeyCode::Char('c'),
|
||||
modifiers: crossterm::event::KeyModifiers::CONTROL,
|
||||
kind: KeyEventKind::Press,
|
||||
..
|
||||
} => {
|
||||
match &mut self.app_state {
|
||||
@@ -220,13 +233,15 @@ impl App<'_> {
|
||||
widget.on_ctrl_c();
|
||||
}
|
||||
AppState::GitWarning { .. } => {
|
||||
// No-op.
|
||||
// Allow exiting the app with Ctrl+C from the warning screen.
|
||||
self.app_event_tx.send(AppEvent::ExitRequest);
|
||||
}
|
||||
}
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Char('d'),
|
||||
modifiers: crossterm::event::KeyModifiers::CONTROL,
|
||||
kind: KeyEventKind::Press,
|
||||
..
|
||||
} => {
|
||||
match &mut self.app_state {
|
||||
@@ -245,9 +260,15 @@ impl App<'_> {
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
KeyEvent {
|
||||
kind: KeyEventKind::Press | KeyEventKind::Repeat,
|
||||
..
|
||||
} => {
|
||||
self.dispatch_key_event(key_event);
|
||||
}
|
||||
_ => {
|
||||
// Ignore Release key events for now.
|
||||
}
|
||||
};
|
||||
}
|
||||
AppEvent::Paste(text) => {
|
||||
@@ -274,10 +295,24 @@ impl App<'_> {
|
||||
self.app_event_tx.clone(),
|
||||
None,
|
||||
Vec::new(),
|
||||
self.enhanced_keys_supported,
|
||||
));
|
||||
self.app_state = AppState::Chat { widget: new_widget };
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
}
|
||||
SlashCommand::Init => {
|
||||
// Guard: do not run if a task is active.
|
||||
if let AppState::Chat { widget } = &mut self.app_state {
|
||||
const INIT_PROMPT: &str = include_str!("../prompt_for_init_command.md");
|
||||
widget.submit_text_message(INIT_PROMPT.to_string());
|
||||
}
|
||||
}
|
||||
SlashCommand::Compact => {
|
||||
if let AppState::Chat { widget } = &mut self.app_state {
|
||||
widget.clear_token_usage();
|
||||
self.app_event_tx.send(AppEvent::CodexOp(Op::Compact));
|
||||
}
|
||||
}
|
||||
SlashCommand::Quit => {
|
||||
break;
|
||||
}
|
||||
@@ -302,16 +337,48 @@ impl App<'_> {
|
||||
widget.add_diff_output(text);
|
||||
}
|
||||
}
|
||||
SlashCommand::Status => {
|
||||
if let AppState::Chat { widget } = &mut self.app_state {
|
||||
widget.add_status_output();
|
||||
}
|
||||
}
|
||||
#[cfg(debug_assertions)]
|
||||
SlashCommand::TestApproval => {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::FileChange;
|
||||
|
||||
self.app_event_tx.send(AppEvent::CodexEvent(Event {
|
||||
id: "1".to_string(),
|
||||
msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
call_id: "1".to_string(),
|
||||
command: vec!["git".into(), "apply".into()],
|
||||
cwd: self.config.cwd.clone(),
|
||||
reason: Some("test".to_string()),
|
||||
}),
|
||||
// msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
// call_id: "1".to_string(),
|
||||
// command: vec!["git".into(), "apply".into()],
|
||||
// cwd: self.config.cwd.clone(),
|
||||
// reason: Some("test".to_string()),
|
||||
// }),
|
||||
msg: EventMsg::ApplyPatchApprovalRequest(
|
||||
ApplyPatchApprovalRequestEvent {
|
||||
call_id: "1".to_string(),
|
||||
changes: HashMap::from([
|
||||
(
|
||||
PathBuf::from("/tmp/test.txt"),
|
||||
FileChange::Add {
|
||||
content: "test".to_string(),
|
||||
},
|
||||
),
|
||||
(
|
||||
PathBuf::from("/tmp/test2.txt"),
|
||||
FileChange::Update {
|
||||
unified_diff: "+test\n-test2".to_string(),
|
||||
move_path: None,
|
||||
},
|
||||
),
|
||||
]),
|
||||
reason: None,
|
||||
grant_root: Some(PathBuf::from("/tmp")),
|
||||
},
|
||||
),
|
||||
}));
|
||||
}
|
||||
},
|
||||
@@ -361,8 +428,9 @@ impl App<'_> {
|
||||
let size = terminal.size()?;
|
||||
let desired_height = match &self.app_state {
|
||||
AppState::Chat { widget } => widget.desired_height(size.width),
|
||||
AppState::GitWarning { .. } => 10,
|
||||
AppState::GitWarning { .. } => size.height,
|
||||
};
|
||||
|
||||
let mut area = terminal.viewport_area;
|
||||
area.height = desired_height.min(size.height);
|
||||
area.width = size.width;
|
||||
@@ -376,14 +444,22 @@ impl App<'_> {
|
||||
terminal.clear()?;
|
||||
terminal.set_viewport_area(area);
|
||||
}
|
||||
match &mut self.app_state {
|
||||
AppState::Chat { widget } => {
|
||||
terminal.draw(|frame| frame.render_widget_ref(&**widget, frame.area()))?;
|
||||
}
|
||||
AppState::GitWarning { screen } => {
|
||||
terminal.draw(|frame| frame.render_widget_ref(&*screen, frame.area()))?;
|
||||
}
|
||||
if !self.pending_history_lines.is_empty() {
|
||||
crate::insert_history::insert_history_lines(
|
||||
terminal,
|
||||
self.pending_history_lines.clone(),
|
||||
);
|
||||
self.pending_history_lines.clear();
|
||||
}
|
||||
terminal.draw(|frame| match &mut self.app_state {
|
||||
AppState::Chat { widget } => {
|
||||
if let Some((x, y)) = widget.cursor_pos(frame.area()) {
|
||||
frame.set_cursor_position((x, y));
|
||||
}
|
||||
frame.render_widget_ref(&**widget, frame.area())
|
||||
}
|
||||
AppState::GitWarning { screen } => frame.render_widget_ref(&*screen, frame.area()),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -407,6 +483,7 @@ impl App<'_> {
|
||||
self.app_event_tx.clone(),
|
||||
args.initial_prompt,
|
||||
args.initial_images,
|
||||
args.enhanced_keys_supported,
|
||||
));
|
||||
self.app_state = AppState::Chat { widget };
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
|
||||
@@ -99,6 +99,7 @@ mod tests {
|
||||
let mut pane = BottomPane::new(super::super::BottomPaneParams {
|
||||
app_event_tx: AppEventSender::new(tx_raw2),
|
||||
has_input_focus: true,
|
||||
enhanced_keys_supported: false,
|
||||
});
|
||||
assert_eq!(CancellationEvent::Handled, view.on_ctrl_c(&mut pane));
|
||||
assert!(view.queue.is_empty());
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use tui_textarea::CursorMove;
|
||||
use tui_textarea::TextArea;
|
||||
|
||||
use crate::app_event::AppEvent;
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
use codex_core::protocol::Op;
|
||||
@@ -67,59 +64,52 @@ impl ChatComposerHistory {
|
||||
|
||||
/// Should Up/Down key presses be interpreted as history navigation given
|
||||
/// the current content and cursor position of `textarea`?
|
||||
pub fn should_handle_navigation(&self, textarea: &TextArea) -> bool {
|
||||
pub fn should_handle_navigation(&self, text: &str, cursor: usize) -> bool {
|
||||
if self.history_entry_count == 0 && self.local_history.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if textarea.is_empty() {
|
||||
if text.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Textarea is not empty – only navigate when cursor is at start and
|
||||
// text matches last recalled history entry so regular editing is not
|
||||
// hijacked.
|
||||
let (row, col) = textarea.cursor();
|
||||
if row != 0 || col != 0 {
|
||||
if cursor != 0 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let lines = textarea.lines();
|
||||
matches!(&self.last_history_text, Some(prev) if prev == &lines.join("\n"))
|
||||
matches!(&self.last_history_text, Some(prev) if prev == text)
|
||||
}
|
||||
|
||||
/// Handle <Up>. Returns true when the key was consumed and the caller
|
||||
/// should request a redraw.
|
||||
pub fn navigate_up(&mut self, textarea: &mut TextArea, app_event_tx: &AppEventSender) -> bool {
|
||||
pub fn navigate_up(&mut self, app_event_tx: &AppEventSender) -> Option<String> {
|
||||
let total_entries = self.history_entry_count + self.local_history.len();
|
||||
if total_entries == 0 {
|
||||
return false;
|
||||
return None;
|
||||
}
|
||||
|
||||
let next_idx = match self.history_cursor {
|
||||
None => (total_entries as isize) - 1,
|
||||
Some(0) => return true, // already at oldest
|
||||
Some(0) => return None, // already at oldest
|
||||
Some(idx) => idx - 1,
|
||||
};
|
||||
|
||||
self.history_cursor = Some(next_idx);
|
||||
self.populate_history_at_index(next_idx as usize, textarea, app_event_tx);
|
||||
true
|
||||
self.populate_history_at_index(next_idx as usize, app_event_tx)
|
||||
}
|
||||
|
||||
/// Handle <Down>.
|
||||
pub fn navigate_down(
|
||||
&mut self,
|
||||
textarea: &mut TextArea,
|
||||
app_event_tx: &AppEventSender,
|
||||
) -> bool {
|
||||
pub fn navigate_down(&mut self, app_event_tx: &AppEventSender) -> Option<String> {
|
||||
let total_entries = self.history_entry_count + self.local_history.len();
|
||||
if total_entries == 0 {
|
||||
return false;
|
||||
return None;
|
||||
}
|
||||
|
||||
let next_idx_opt = match self.history_cursor {
|
||||
None => return false, // not browsing
|
||||
None => return None, // not browsing
|
||||
Some(idx) if (idx as usize) + 1 >= total_entries => None,
|
||||
Some(idx) => Some(idx + 1),
|
||||
};
|
||||
@@ -127,16 +117,15 @@ impl ChatComposerHistory {
|
||||
match next_idx_opt {
|
||||
Some(idx) => {
|
||||
self.history_cursor = Some(idx);
|
||||
self.populate_history_at_index(idx as usize, textarea, app_event_tx);
|
||||
self.populate_history_at_index(idx as usize, app_event_tx)
|
||||
}
|
||||
None => {
|
||||
// Past newest – clear and exit browsing mode.
|
||||
self.history_cursor = None;
|
||||
self.last_history_text = None;
|
||||
self.replace_textarea_content(textarea, "");
|
||||
Some(String::new())
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Integrate a GetHistoryEntryResponse event.
|
||||
@@ -145,19 +134,18 @@ impl ChatComposerHistory {
|
||||
log_id: u64,
|
||||
offset: usize,
|
||||
entry: Option<String>,
|
||||
textarea: &mut TextArea,
|
||||
) -> bool {
|
||||
) -> Option<String> {
|
||||
if self.history_log_id != Some(log_id) {
|
||||
return false;
|
||||
return None;
|
||||
}
|
||||
let Some(text) = entry else { return false };
|
||||
let text = entry?;
|
||||
self.fetched_history.insert(offset, text.clone());
|
||||
|
||||
if self.history_cursor == Some(offset as isize) {
|
||||
self.replace_textarea_content(textarea, &text);
|
||||
return true;
|
||||
self.last_history_text = Some(text.clone());
|
||||
return Some(text);
|
||||
}
|
||||
false
|
||||
None
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
@@ -167,21 +155,20 @@ impl ChatComposerHistory {
|
||||
fn populate_history_at_index(
|
||||
&mut self,
|
||||
global_idx: usize,
|
||||
textarea: &mut TextArea,
|
||||
app_event_tx: &AppEventSender,
|
||||
) {
|
||||
) -> Option<String> {
|
||||
if global_idx >= self.history_entry_count {
|
||||
// Local entry.
|
||||
if let Some(text) = self
|
||||
.local_history
|
||||
.get(global_idx - self.history_entry_count)
|
||||
{
|
||||
let t = text.clone();
|
||||
self.replace_textarea_content(textarea, &t);
|
||||
self.last_history_text = Some(text.clone());
|
||||
return Some(text.clone());
|
||||
}
|
||||
} else if let Some(text) = self.fetched_history.get(&global_idx) {
|
||||
let t = text.clone();
|
||||
self.replace_textarea_content(textarea, &t);
|
||||
self.last_history_text = Some(text.clone());
|
||||
return Some(text.clone());
|
||||
} else if let Some(log_id) = self.history_log_id {
|
||||
let op = Op::GetHistoryEntryRequest {
|
||||
offset: global_idx,
|
||||
@@ -189,14 +176,7 @@ impl ChatComposerHistory {
|
||||
};
|
||||
app_event_tx.send(AppEvent::CodexOp(op));
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_textarea_content(&mut self, textarea: &mut TextArea, text: &str) {
|
||||
textarea.select_all();
|
||||
textarea.cut();
|
||||
let _ = textarea.insert_str(text);
|
||||
textarea.move_cursor(CursorMove::Jump(0, 0));
|
||||
self.last_history_text = Some(text.to_string());
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,11 +197,9 @@ mod tests {
|
||||
// Pretend there are 3 persistent entries.
|
||||
history.set_metadata(1, 3);
|
||||
|
||||
let mut textarea = TextArea::default();
|
||||
|
||||
// First Up should request offset 2 (latest) and await async data.
|
||||
assert!(history.should_handle_navigation(&textarea));
|
||||
assert!(history.navigate_up(&mut textarea, &tx));
|
||||
assert!(history.should_handle_navigation("", 0));
|
||||
assert!(history.navigate_up(&tx).is_none()); // don't replace the text yet
|
||||
|
||||
// Verify that an AppEvent::CodexOp with the correct GetHistoryEntryRequest was sent.
|
||||
let event = rx.try_recv().expect("expected AppEvent to be sent");
|
||||
@@ -235,14 +213,15 @@ mod tests {
|
||||
},
|
||||
history_request1
|
||||
);
|
||||
assert_eq!(textarea.lines().join("\n"), ""); // still empty
|
||||
|
||||
// Inject the async response.
|
||||
assert!(history.on_entry_response(1, 2, Some("latest".into()), &mut textarea));
|
||||
assert_eq!(textarea.lines().join("\n"), "latest");
|
||||
assert_eq!(
|
||||
Some("latest".into()),
|
||||
history.on_entry_response(1, 2, Some("latest".into()))
|
||||
);
|
||||
|
||||
// Next Up should move to offset 1.
|
||||
assert!(history.navigate_up(&mut textarea, &tx));
|
||||
assert!(history.navigate_up(&tx).is_none()); // don't replace the text yet
|
||||
|
||||
// Verify second CodexOp event for offset 1.
|
||||
let event2 = rx.try_recv().expect("expected second event");
|
||||
@@ -257,7 +236,9 @@ mod tests {
|
||||
history_request_2
|
||||
);
|
||||
|
||||
history.on_entry_response(1, 1, Some("older".into()), &mut textarea);
|
||||
assert_eq!(textarea.lines().join("\n"), "older");
|
||||
assert_eq!(
|
||||
Some("older".into()),
|
||||
history.on_entry_response(1, 1, Some("older".into()))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,3 +188,38 @@ impl WidgetRef for CommandPopup {
|
||||
table.render(area, buf);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn filter_includes_init_when_typing_prefix() {
|
||||
let mut popup = CommandPopup::new();
|
||||
// Simulate the composer line starting with '/in' so the popup filters
|
||||
// matching commands by prefix.
|
||||
popup.on_composer_text_change("/in".to_string());
|
||||
|
||||
// Access the filtered list via the selected command and ensure that
|
||||
// one of the matches is the new "init" command.
|
||||
let matches = popup.filtered_commands();
|
||||
assert!(
|
||||
matches.iter().any(|cmd| cmd.command() == "init"),
|
||||
"expected '/init' to appear among filtered commands"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn selecting_init_by_exact_match() {
|
||||
let mut popup = CommandPopup::new();
|
||||
popup.on_composer_text_change("/init".to_string());
|
||||
|
||||
// When an exact match exists, the selected command should be that
|
||||
// command by default.
|
||||
let selected = popup.selected_command();
|
||||
match selected {
|
||||
Some(cmd) => assert_eq!(cmd.command(), "init"),
|
||||
None => panic!("expected a selected command for exact match"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
45
codex-rs/tui/src/bottom_pane/live_ring_widget.rs
Normal file
45
codex-rs/tui/src/bottom_pane/live_ring_widget.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::Rect;
|
||||
use ratatui::text::Line;
|
||||
use ratatui::widgets::Paragraph;
|
||||
use ratatui::widgets::WidgetRef;
|
||||
|
||||
/// Minimal rendering-only widget for the transient ring rows.
|
||||
pub(crate) struct LiveRingWidget {
|
||||
max_rows: u16,
|
||||
rows: Vec<Line<'static>>, // newest at the end
|
||||
}
|
||||
|
||||
impl LiveRingWidget {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_rows: 3,
|
||||
rows: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_max_rows(&mut self, n: u16) {
|
||||
self.max_rows = n.max(1);
|
||||
}
|
||||
|
||||
pub fn set_rows(&mut self, rows: Vec<Line<'static>>) {
|
||||
self.rows = rows;
|
||||
}
|
||||
|
||||
pub fn desired_height(&self, _width: u16) -> u16 {
|
||||
let len = self.rows.len() as u16;
|
||||
len.min(self.max_rows)
|
||||
}
|
||||
}
|
||||
|
||||
impl WidgetRef for LiveRingWidget {
|
||||
fn render_ref(&self, area: Rect, buf: &mut Buffer) {
|
||||
if area.height == 0 {
|
||||
return;
|
||||
}
|
||||
let visible = self.rows.len().saturating_sub(self.max_rows as usize);
|
||||
let slice = &self.rows[visible..];
|
||||
let para = Paragraph::new(slice.to_vec());
|
||||
para.render_ref(area, buf);
|
||||
}
|
||||
}
|
||||
@@ -4,12 +4,12 @@ use crate::app_event::AppEvent;
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
use crate::user_approval_widget::ApprovalRequest;
|
||||
use bottom_pane_view::BottomPaneView;
|
||||
use bottom_pane_view::ConditionalUpdate;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use codex_file_search::FileMatch;
|
||||
use crossterm::event::KeyEvent;
|
||||
use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::Rect;
|
||||
use ratatui::text::Line;
|
||||
use ratatui::widgets::WidgetRef;
|
||||
|
||||
mod approval_modal_view;
|
||||
@@ -18,7 +18,9 @@ mod chat_composer;
|
||||
mod chat_composer_history;
|
||||
mod command_popup;
|
||||
mod file_search_popup;
|
||||
mod live_ring_widget;
|
||||
mod status_indicator_view;
|
||||
mod textarea;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum CancellationEvent {
|
||||
@@ -29,6 +31,7 @@ pub(crate) enum CancellationEvent {
|
||||
pub(crate) use chat_composer::ChatComposer;
|
||||
pub(crate) use chat_composer::InputResult;
|
||||
|
||||
use crate::status_indicator_widget::StatusIndicatorWidget;
|
||||
use approval_modal_view::ApprovalModalView;
|
||||
use status_indicator_view::StatusIndicatorView;
|
||||
|
||||
@@ -36,7 +39,7 @@ use status_indicator_view::StatusIndicatorView;
|
||||
pub(crate) struct BottomPane<'a> {
|
||||
/// Composer is retained even when a BottomPaneView is displayed so the
|
||||
/// input state is retained when the view is closed.
|
||||
composer: ChatComposer<'a>,
|
||||
composer: ChatComposer,
|
||||
|
||||
/// If present, this is displayed instead of the `composer`.
|
||||
active_view: Option<Box<dyn BottomPaneView<'a> + 'a>>,
|
||||
@@ -45,30 +48,88 @@ pub(crate) struct BottomPane<'a> {
|
||||
has_input_focus: bool,
|
||||
is_task_running: bool,
|
||||
ctrl_c_quit_hint: bool,
|
||||
|
||||
/// Optional live, multi‑line status/"live cell" rendered directly above
|
||||
/// the composer while a task is running. Unlike `active_view`, this does
|
||||
/// not replace the composer; it augments it.
|
||||
live_status: Option<StatusIndicatorWidget>,
|
||||
|
||||
/// Optional transient ring shown above the composer. This is a rendering-only
|
||||
/// container used during development before we wire it to ChatWidget events.
|
||||
live_ring: Option<live_ring_widget::LiveRingWidget>,
|
||||
|
||||
/// True if the active view is the StatusIndicatorView that replaces the
|
||||
/// composer during a running task.
|
||||
status_view_active: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct BottomPaneParams {
|
||||
pub(crate) app_event_tx: AppEventSender,
|
||||
pub(crate) has_input_focus: bool,
|
||||
pub(crate) enhanced_keys_supported: bool,
|
||||
}
|
||||
|
||||
impl BottomPane<'_> {
|
||||
const BOTTOM_PAD_LINES: u16 = 2;
|
||||
pub fn new(params: BottomPaneParams) -> Self {
|
||||
let enhanced_keys_supported = params.enhanced_keys_supported;
|
||||
Self {
|
||||
composer: ChatComposer::new(params.has_input_focus, params.app_event_tx.clone()),
|
||||
composer: ChatComposer::new(
|
||||
params.has_input_focus,
|
||||
params.app_event_tx.clone(),
|
||||
enhanced_keys_supported,
|
||||
),
|
||||
active_view: None,
|
||||
app_event_tx: params.app_event_tx,
|
||||
has_input_focus: params.has_input_focus,
|
||||
is_task_running: false,
|
||||
ctrl_c_quit_hint: false,
|
||||
live_status: None,
|
||||
live_ring: None,
|
||||
status_view_active: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn desired_height(&self, width: u16) -> u16 {
|
||||
self.active_view
|
||||
let overlay_status_h = self
|
||||
.live_status
|
||||
.as_ref()
|
||||
.map(|v| v.desired_height(width))
|
||||
.unwrap_or(self.composer.desired_height())
|
||||
.map(|s| s.desired_height(width))
|
||||
.unwrap_or(0);
|
||||
let ring_h = self
|
||||
.live_ring
|
||||
.as_ref()
|
||||
.map(|r| r.desired_height(width))
|
||||
.unwrap_or(0);
|
||||
|
||||
let view_height = if let Some(view) = self.active_view.as_ref() {
|
||||
// Add a single blank spacer line between live ring and status view when active.
|
||||
let spacer = if self.live_ring.is_some() && self.status_view_active {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
spacer + view.desired_height(width)
|
||||
} else {
|
||||
self.composer.desired_height(width)
|
||||
};
|
||||
|
||||
overlay_status_h
|
||||
.saturating_add(ring_h)
|
||||
.saturating_add(view_height)
|
||||
.saturating_add(Self::BOTTOM_PAD_LINES)
|
||||
}
|
||||
|
||||
pub fn cursor_pos(&self, area: Rect) -> Option<(u16, u16)> {
|
||||
// Hide the cursor whenever an overlay view is active (e.g. the
|
||||
// status indicator shown while a task is running, or approval modal).
|
||||
// In these states the textarea is not interactable, so we should not
|
||||
// show its caret.
|
||||
if self.active_view.is_some() {
|
||||
None
|
||||
} else {
|
||||
self.composer.cursor_pos(area)
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward a key event to the active view or the composer.
|
||||
@@ -78,9 +139,10 @@ impl BottomPane<'_> {
|
||||
if !view.is_complete() {
|
||||
self.active_view = Some(view);
|
||||
} else if self.is_task_running {
|
||||
self.active_view = Some(Box::new(StatusIndicatorView::new(
|
||||
self.app_event_tx.clone(),
|
||||
)));
|
||||
let mut v = StatusIndicatorView::new(self.app_event_tx.clone());
|
||||
v.update_text("waiting for model".to_string());
|
||||
self.active_view = Some(Box::new(v));
|
||||
self.status_view_active = true;
|
||||
}
|
||||
self.request_redraw();
|
||||
InputResult::None
|
||||
@@ -107,9 +169,11 @@ impl BottomPane<'_> {
|
||||
if !view.is_complete() {
|
||||
self.active_view = Some(view);
|
||||
} else if self.is_task_running {
|
||||
self.active_view = Some(Box::new(StatusIndicatorView::new(
|
||||
self.app_event_tx.clone(),
|
||||
)));
|
||||
// Modal aborted but task still running – restore status indicator.
|
||||
let mut v = StatusIndicatorView::new(self.app_event_tx.clone());
|
||||
v.update_text("waiting for model".to_string());
|
||||
self.active_view = Some(Box::new(v));
|
||||
self.status_view_active = true;
|
||||
}
|
||||
self.show_ctrl_c_quit_hint();
|
||||
}
|
||||
@@ -129,19 +193,42 @@ impl BottomPane<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the status indicator text (only when the `StatusIndicatorView` is
|
||||
/// active).
|
||||
/// Update the status indicator text. Prefer replacing the composer with
|
||||
/// the StatusIndicatorView so the input pane shows a single-line status
|
||||
/// like: `▌ Working waiting for model`.
|
||||
pub(crate) fn update_status_text(&mut self, text: String) {
|
||||
if let Some(view) = &mut self.active_view {
|
||||
match view.update_status_text(text) {
|
||||
ConditionalUpdate::NeedsRedraw => {
|
||||
self.request_redraw();
|
||||
}
|
||||
ConditionalUpdate::NoRedraw => {
|
||||
// No redraw needed.
|
||||
}
|
||||
let mut handled_by_view = false;
|
||||
if let Some(view) = self.active_view.as_mut() {
|
||||
if matches!(
|
||||
view.update_status_text(text.clone()),
|
||||
bottom_pane_view::ConditionalUpdate::NeedsRedraw
|
||||
) {
|
||||
handled_by_view = true;
|
||||
}
|
||||
} else {
|
||||
let mut v = StatusIndicatorView::new(self.app_event_tx.clone());
|
||||
v.update_text(text.clone());
|
||||
self.active_view = Some(Box::new(v));
|
||||
self.status_view_active = true;
|
||||
handled_by_view = true;
|
||||
}
|
||||
|
||||
// Fallback: if the current active view did not consume status updates
|
||||
// and no modal view is active, present an overlay above the composer.
|
||||
// If a modal is active, do NOT render the overlay to avoid drawing
|
||||
// over the dialog.
|
||||
if !handled_by_view && self.active_view.is_none() {
|
||||
if self.live_status.is_none() {
|
||||
self.live_status = Some(StatusIndicatorWidget::new(self.app_event_tx.clone()));
|
||||
}
|
||||
if let Some(status) = &mut self.live_status {
|
||||
status.update_text(text);
|
||||
}
|
||||
} else if !handled_by_view {
|
||||
// Ensure any previous overlay is cleared when a modal becomes active.
|
||||
self.live_status = None;
|
||||
}
|
||||
self.request_redraw();
|
||||
}
|
||||
|
||||
pub(crate) fn show_ctrl_c_quit_hint(&mut self) {
|
||||
@@ -167,27 +254,23 @@ impl BottomPane<'_> {
|
||||
pub fn set_task_running(&mut self, running: bool) {
|
||||
self.is_task_running = running;
|
||||
|
||||
match (running, self.active_view.is_some()) {
|
||||
(true, false) => {
|
||||
// Show status indicator overlay.
|
||||
if running {
|
||||
if self.active_view.is_none() {
|
||||
self.active_view = Some(Box::new(StatusIndicatorView::new(
|
||||
self.app_event_tx.clone(),
|
||||
)));
|
||||
self.request_redraw();
|
||||
self.status_view_active = true;
|
||||
}
|
||||
(false, true) => {
|
||||
if let Some(mut view) = self.active_view.take() {
|
||||
if view.should_hide_when_task_is_done() {
|
||||
// Leave self.active_view as None.
|
||||
self.request_redraw();
|
||||
} else {
|
||||
// Preserve the view.
|
||||
self.active_view = Some(view);
|
||||
}
|
||||
self.request_redraw();
|
||||
} else {
|
||||
self.live_status = None;
|
||||
// Drop the status view when a task completes, but keep other
|
||||
// modal views (e.g. approval dialogs).
|
||||
if let Some(mut view) = self.active_view.take() {
|
||||
if !view.should_hide_when_task_is_done() {
|
||||
self.active_view = Some(view);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// No change.
|
||||
self.status_view_active = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -229,6 +312,9 @@ impl BottomPane<'_> {
|
||||
// Otherwise create a new approval modal overlay.
|
||||
let modal = ApprovalModalView::new(request, self.app_event_tx.clone());
|
||||
self.active_view = Some(Box::new(modal));
|
||||
// Hide any overlay status while a modal is visible.
|
||||
self.live_status = None;
|
||||
self.status_view_active = false;
|
||||
self.request_redraw()
|
||||
}
|
||||
|
||||
@@ -262,15 +348,82 @@ impl BottomPane<'_> {
|
||||
self.composer.on_file_search_result(query, matches);
|
||||
self.request_redraw();
|
||||
}
|
||||
|
||||
/// Set the rows and cap for the transient live ring overlay.
|
||||
pub(crate) fn set_live_ring_rows(&mut self, max_rows: u16, rows: Vec<Line<'static>>) {
|
||||
let mut w = live_ring_widget::LiveRingWidget::new();
|
||||
w.set_max_rows(max_rows);
|
||||
w.set_rows(rows);
|
||||
self.live_ring = Some(w);
|
||||
}
|
||||
|
||||
pub(crate) fn clear_live_ring(&mut self) {
|
||||
self.live_ring = None;
|
||||
}
|
||||
|
||||
// Removed restart_live_status_with_text – no longer used by the current streaming UI.
|
||||
}
|
||||
|
||||
impl WidgetRef for &BottomPane<'_> {
|
||||
fn render_ref(&self, area: Rect, buf: &mut Buffer) {
|
||||
// Show BottomPaneView if present.
|
||||
if let Some(ov) = &self.active_view {
|
||||
ov.render(area, buf);
|
||||
} else {
|
||||
(&self.composer).render_ref(area, buf);
|
||||
let mut y_offset = 0u16;
|
||||
if let Some(ring) = &self.live_ring {
|
||||
let live_h = ring.desired_height(area.width).min(area.height);
|
||||
if live_h > 0 {
|
||||
let live_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y,
|
||||
width: area.width,
|
||||
height: live_h,
|
||||
};
|
||||
ring.render_ref(live_rect, buf);
|
||||
y_offset = live_h;
|
||||
}
|
||||
}
|
||||
// Spacer between live ring and status view when active
|
||||
if self.live_ring.is_some() && self.status_view_active && y_offset < area.height {
|
||||
// Leave one empty line
|
||||
y_offset = y_offset.saturating_add(1);
|
||||
}
|
||||
if let Some(status) = &self.live_status {
|
||||
let live_h = status
|
||||
.desired_height(area.width)
|
||||
.min(area.height.saturating_sub(y_offset));
|
||||
if live_h > 0 {
|
||||
let live_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y + y_offset,
|
||||
width: area.width,
|
||||
height: live_h,
|
||||
};
|
||||
status.render_ref(live_rect, buf);
|
||||
y_offset = y_offset.saturating_add(live_h);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(view) = &self.active_view {
|
||||
if y_offset < area.height {
|
||||
// Reserve bottom padding lines; keep at least 1 line for the view.
|
||||
let avail = area.height - y_offset;
|
||||
let pad = BottomPane::BOTTOM_PAD_LINES.min(avail.saturating_sub(1));
|
||||
let view_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y + y_offset,
|
||||
width: area.width,
|
||||
height: avail - pad,
|
||||
};
|
||||
view.render(view_rect, buf);
|
||||
}
|
||||
} else if y_offset < area.height {
|
||||
let composer_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y + y_offset,
|
||||
width: area.width,
|
||||
// Reserve bottom padding
|
||||
height: (area.height - y_offset)
|
||||
- BottomPane::BOTTOM_PAD_LINES.min((area.height - y_offset).saturating_sub(1)),
|
||||
};
|
||||
(&self.composer).render_ref(composer_rect, buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -279,6 +432,9 @@ impl WidgetRef for &BottomPane<'_> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::app_event::AppEvent;
|
||||
use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::Rect;
|
||||
use ratatui::text::Line;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
@@ -298,10 +454,323 @@ mod tests {
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
app_event_tx: tx,
|
||||
has_input_focus: true,
|
||||
enhanced_keys_supported: false,
|
||||
});
|
||||
pane.push_approval_request(exec_request());
|
||||
assert_eq!(CancellationEvent::Handled, pane.on_ctrl_c());
|
||||
assert!(pane.ctrl_c_quit_hint_visible());
|
||||
assert_eq!(CancellationEvent::Ignored, pane.on_ctrl_c());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn live_ring_renders_above_composer() {
|
||||
let (tx_raw, _rx) = channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
app_event_tx: tx,
|
||||
has_input_focus: true,
|
||||
enhanced_keys_supported: false,
|
||||
});
|
||||
|
||||
// Provide 4 rows with max_rows=3; only the last 3 should be visible.
|
||||
pane.set_live_ring_rows(
|
||||
3,
|
||||
vec![
|
||||
Line::from("one".to_string()),
|
||||
Line::from("two".to_string()),
|
||||
Line::from("three".to_string()),
|
||||
Line::from("four".to_string()),
|
||||
],
|
||||
);
|
||||
|
||||
let area = Rect::new(0, 0, 10, 5);
|
||||
let mut buf = Buffer::empty(area);
|
||||
(&pane).render_ref(area, &mut buf);
|
||||
|
||||
// Extract the first 3 rows and assert they contain the last three lines.
|
||||
let mut lines: Vec<String> = Vec::new();
|
||||
for y in 0..3 {
|
||||
let mut s = String::new();
|
||||
for x in 0..area.width {
|
||||
s.push(buf[(x, y)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
lines.push(s.trim_end().to_string());
|
||||
}
|
||||
assert_eq!(lines, vec!["two", "three", "four"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn status_indicator_visible_with_live_ring() {
|
||||
let (tx_raw, _rx) = channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
app_event_tx: tx,
|
||||
has_input_focus: true,
|
||||
enhanced_keys_supported: false,
|
||||
});
|
||||
|
||||
// Simulate task running which replaces composer with the status indicator.
|
||||
pane.set_task_running(true);
|
||||
pane.update_status_text("waiting for model".to_string());
|
||||
|
||||
// Provide 2 rows in the live ring (e.g., streaming CoT) and ensure the
|
||||
// status indicator remains visible below them.
|
||||
pane.set_live_ring_rows(
|
||||
2,
|
||||
vec![
|
||||
Line::from("cot1".to_string()),
|
||||
Line::from("cot2".to_string()),
|
||||
],
|
||||
);
|
||||
|
||||
// Allow some frames so the dot animation is present.
|
||||
std::thread::sleep(std::time::Duration::from_millis(120));
|
||||
|
||||
// Height should include both ring rows, 1 spacer, and the 1-line status.
|
||||
let area = Rect::new(0, 0, 30, 4);
|
||||
let mut buf = Buffer::empty(area);
|
||||
(&pane).render_ref(area, &mut buf);
|
||||
|
||||
// Top two rows are the live ring.
|
||||
let mut r0 = String::new();
|
||||
let mut r1 = String::new();
|
||||
for x in 0..area.width {
|
||||
r0.push(buf[(x, 0)].symbol().chars().next().unwrap_or(' '));
|
||||
r1.push(buf[(x, 1)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(r0.contains("cot1"), "expected first live row: {r0:?}");
|
||||
assert!(r1.contains("cot2"), "expected second live row: {r1:?}");
|
||||
|
||||
// Row 2 is the spacer (blank)
|
||||
let mut r2 = String::new();
|
||||
for x in 0..area.width {
|
||||
r2.push(buf[(x, 2)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(r2.trim().is_empty(), "expected blank spacer line: {r2:?}");
|
||||
|
||||
// Bottom row is the status line; it should contain the left bar and "Working".
|
||||
let mut r3 = String::new();
|
||||
for x in 0..area.width {
|
||||
r3.push(buf[(x, 3)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert_eq!(buf[(0, 3)].symbol().chars().next().unwrap_or(' '), '▌');
|
||||
assert!(
|
||||
r3.contains("Working"),
|
||||
"expected Working header in status line: {r3:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overlay_not_shown_above_approval_modal() {
|
||||
let (tx_raw, _rx) = channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
app_event_tx: tx,
|
||||
has_input_focus: true,
|
||||
enhanced_keys_supported: false,
|
||||
});
|
||||
|
||||
// Create an approval modal (active view).
|
||||
pane.push_approval_request(exec_request());
|
||||
// Attempt to update status; this should NOT create an overlay while modal is visible.
|
||||
pane.update_status_text("running command".to_string());
|
||||
|
||||
// Render and verify the top row does not include the Working header overlay.
|
||||
let area = Rect::new(0, 0, 60, 6);
|
||||
let mut buf = Buffer::empty(area);
|
||||
(&pane).render_ref(area, &mut buf);
|
||||
|
||||
let mut r0 = String::new();
|
||||
for x in 0..area.width {
|
||||
r0.push(buf[(x, 0)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(
|
||||
!r0.contains("Working"),
|
||||
"overlay Working header should not render above modal"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn composer_not_shown_after_denied_if_task_running() {
|
||||
let (tx_raw, rx) = channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
app_event_tx: tx.clone(),
|
||||
has_input_focus: true,
|
||||
enhanced_keys_supported: false,
|
||||
});
|
||||
|
||||
// Start a running task so the status indicator replaces the composer.
|
||||
pane.set_task_running(true);
|
||||
pane.update_status_text("waiting for model".to_string());
|
||||
|
||||
// Push an approval modal (e.g., command approval) which should hide the status view.
|
||||
pane.push_approval_request(exec_request());
|
||||
|
||||
// Simulate pressing 'n' (deny) on the modal.
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use crossterm::event::KeyModifiers;
|
||||
pane.handle_key_event(KeyEvent::new(KeyCode::Char('n'), KeyModifiers::NONE));
|
||||
|
||||
// After denial, since the task is still running, the status indicator
|
||||
// should be restored as the active view; the composer should NOT be visible.
|
||||
assert!(
|
||||
pane.status_view_active,
|
||||
"status view should be active after denial"
|
||||
);
|
||||
assert!(pane.active_view.is_some(), "active view should be present");
|
||||
|
||||
// Render and ensure the top row includes the Working header instead of the composer.
|
||||
// Give the animation thread a moment to tick.
|
||||
std::thread::sleep(std::time::Duration::from_millis(120));
|
||||
let area = Rect::new(0, 0, 40, 3);
|
||||
let mut buf = Buffer::empty(area);
|
||||
(&pane).render_ref(area, &mut buf);
|
||||
let mut row0 = String::new();
|
||||
for x in 0..area.width {
|
||||
row0.push(buf[(x, 0)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(
|
||||
row0.contains("Working"),
|
||||
"expected Working header after denial: {row0:?}"
|
||||
);
|
||||
|
||||
// Drain the channel to avoid unused warnings.
|
||||
drop(rx);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn status_indicator_visible_during_command_execution() {
|
||||
let (tx_raw, _rx) = channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
app_event_tx: tx,
|
||||
has_input_focus: true,
|
||||
enhanced_keys_supported: false,
|
||||
});
|
||||
|
||||
// Begin a task: show initial status.
|
||||
pane.set_task_running(true);
|
||||
pane.update_status_text("waiting for model".to_string());
|
||||
|
||||
// As a long-running command begins (post-approval), ensure the status
|
||||
// indicator is visible while we wait for the command to run.
|
||||
pane.update_status_text("running command".to_string());
|
||||
|
||||
// Allow some frames so the animation thread ticks.
|
||||
std::thread::sleep(std::time::Duration::from_millis(120));
|
||||
|
||||
// Render and confirm the line contains the "Working" header.
|
||||
let area = Rect::new(0, 0, 40, 3);
|
||||
let mut buf = Buffer::empty(area);
|
||||
(&pane).render_ref(area, &mut buf);
|
||||
|
||||
let mut row0 = String::new();
|
||||
for x in 0..area.width {
|
||||
row0.push(buf[(x, 0)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(
|
||||
row0.contains("Working"),
|
||||
"expected Working header: {row0:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bottom_padding_present_for_status_view() {
|
||||
let (tx_raw, _rx) = channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
app_event_tx: tx,
|
||||
has_input_focus: true,
|
||||
enhanced_keys_supported: false,
|
||||
});
|
||||
|
||||
// Activate spinner (status view replaces composer) with no live ring.
|
||||
pane.set_task_running(true);
|
||||
pane.update_status_text("waiting for model".to_string());
|
||||
|
||||
// Use height == desired_height; expect 1 status row at top and 2 bottom padding rows.
|
||||
let height = pane.desired_height(30);
|
||||
assert!(
|
||||
height >= 3,
|
||||
"expected at least 3 rows with bottom padding; got {height}"
|
||||
);
|
||||
let area = Rect::new(0, 0, 30, height);
|
||||
let mut buf = Buffer::empty(area);
|
||||
(&pane).render_ref(area, &mut buf);
|
||||
|
||||
// Top row contains the status header
|
||||
let mut top = String::new();
|
||||
for x in 0..area.width {
|
||||
top.push(buf[(x, 0)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert_eq!(buf[(0, 0)].symbol().chars().next().unwrap_or(' '), '▌');
|
||||
assert!(
|
||||
top.contains("Working"),
|
||||
"expected Working header on top row: {top:?}"
|
||||
);
|
||||
|
||||
// Bottom two rows are blank padding
|
||||
let mut r_last = String::new();
|
||||
let mut r_last2 = String::new();
|
||||
for x in 0..area.width {
|
||||
r_last.push(buf[(x, height - 1)].symbol().chars().next().unwrap_or(' '));
|
||||
r_last2.push(buf[(x, height - 2)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(
|
||||
r_last.trim().is_empty(),
|
||||
"expected last row blank: {r_last:?}"
|
||||
);
|
||||
assert!(
|
||||
r_last2.trim().is_empty(),
|
||||
"expected second-to-last row blank: {r_last2:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bottom_padding_shrinks_when_tiny() {
|
||||
let (tx_raw, _rx) = channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
app_event_tx: tx,
|
||||
has_input_focus: true,
|
||||
enhanced_keys_supported: false,
|
||||
});
|
||||
|
||||
pane.set_task_running(true);
|
||||
pane.update_status_text("waiting for model".to_string());
|
||||
|
||||
// Height=2 → pad shrinks to 1; bottom row is blank, top row has spinner.
|
||||
let area2 = Rect::new(0, 0, 20, 2);
|
||||
let mut buf2 = Buffer::empty(area2);
|
||||
(&pane).render_ref(area2, &mut buf2);
|
||||
let mut row0 = String::new();
|
||||
let mut row1 = String::new();
|
||||
for x in 0..area2.width {
|
||||
row0.push(buf2[(x, 0)].symbol().chars().next().unwrap_or(' '));
|
||||
row1.push(buf2[(x, 1)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(
|
||||
row0.contains("Working"),
|
||||
"expected Working header on row 0: {row0:?}"
|
||||
);
|
||||
assert!(
|
||||
row1.trim().is_empty(),
|
||||
"expected bottom padding on row 1: {row1:?}"
|
||||
);
|
||||
|
||||
// Height=1 → no padding; single row is the spinner.
|
||||
let area1 = Rect::new(0, 0, 20, 1);
|
||||
let mut buf1 = Buffer::empty(area1);
|
||||
(&pane).render_ref(area1, &mut buf1);
|
||||
let mut only = String::new();
|
||||
for x in 0..area1.width {
|
||||
only.push(buf1[(x, 0)].symbol().chars().next().unwrap_or(' '));
|
||||
}
|
||||
assert!(
|
||||
only.contains("Working"),
|
||||
"expected Working header with no padding: {only:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,4 +11,4 @@ expression: terminal.backend()
|
||||
"▌ "
|
||||
"▌ "
|
||||
"▌ "
|
||||
" ⏎ send Shift+⏎ newline Ctrl+C quit "
|
||||
" ⏎ send Ctrl+J newline Ctrl+C quit "
|
||||
|
||||
@@ -11,4 +11,4 @@ expression: terminal.backend()
|
||||
"▌ "
|
||||
"▌ "
|
||||
"▌ "
|
||||
" ⏎ send Shift+⏎ newline Ctrl+C quit "
|
||||
" ⏎ send Ctrl+J newline Ctrl+C quit "
|
||||
|
||||
@@ -11,4 +11,4 @@ expression: terminal.backend()
|
||||
"▌ "
|
||||
"▌ "
|
||||
"▌ "
|
||||
" ⏎ send Shift+⏎ newline Ctrl+C quit "
|
||||
" ⏎ send Ctrl+J newline Ctrl+C quit "
|
||||
|
||||
@@ -11,4 +11,4 @@ expression: terminal.backend()
|
||||
"▌ "
|
||||
"▌ "
|
||||
"▌ "
|
||||
" ⏎ send Shift+⏎ newline Ctrl+C quit "
|
||||
" ⏎ send Ctrl+J newline Ctrl+C quit "
|
||||
|
||||
@@ -11,4 +11,4 @@ expression: terminal.backend()
|
||||
"▌ "
|
||||
"▌ "
|
||||
"▌ "
|
||||
" ⏎ send Shift+⏎ newline Ctrl+C quit "
|
||||
" ⏎ send Ctrl+J newline Ctrl+C quit "
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user